diff --git a/handlers/account.go b/handlers/account.go
index 2d53a6d..32d3933 100644
--- a/handlers/account.go
+++ b/handlers/account.go
@@ -5,30 +5,28 @@ import (
"net/http"
"synlotto-website/helpers"
"synlotto-website/models"
+ "time"
"github.com/gorilla/csrf"
)
func Login(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet {
+ session, _ := helpers.GetSession(w, r)
+ if _, ok := session.Values["user_id"].(int); ok {
+ http.Redirect(w, r, "/", http.StatusSeeOther)
+ return
+ }
+
tmpl := template.Must(template.ParseFiles(
"templates/layout.html",
"templates/account/login.html",
))
- session, _ := helpers.GetSession(w, r)
+ context := helpers.TemplateContext(w, r)
+ context["csrfField"] = csrf.TemplateField(r)
- var flash string
- if f, ok := session.Values["flash"].(string); ok {
- flash = f
- delete(session.Values, "flash")
- session.Save(r, w)
- }
-
- tmpl.ExecuteTemplate(w, "layout", map[string]interface{}{
- "csrfField": csrf.TemplateField(r),
- "Flash": flash,
- })
+ tmpl.ExecuteTemplate(w, "layout", context)
return
}
@@ -36,21 +34,27 @@ func Login(w http.ResponseWriter, r *http.Request) {
password := r.FormValue("password")
user := models.GetUserByUsername(username)
- if user == nil || !CheckPasswordHash(user.PasswordHash, password) {
+ if user == nil || !helpers.CheckPasswordHash(user.PasswordHash, password) {
http.Error(w, "Invalid credentials", http.StatusUnauthorized)
return
}
- session, err := helpers.GetSession(w, r)
- if err != nil {
- http.Error(w, "Session error", http.StatusInternalServerError)
- return
- }
+ session, _ := helpers.GetSession(w, r)
session.Options.MaxAge = -1
session.Save(r, w)
+ remember := r.FormValue("remember") == "on"
+
newSession, _ := helpers.GetSession(w, r)
newSession.Values["user_id"] = user.Id
+ newSession.Values["last_activity"] = time.Now()
+
+ if remember {
+ newSession.Options.MaxAge = 60 * 60 * 24 * 30 // 30 days
+ } else {
+ newSession.Options.MaxAge = 0
+ }
+
newSession.Save(r, w)
http.Redirect(w, r, "/", http.StatusSeeOther)
@@ -84,7 +88,7 @@ func Signup(w http.ResponseWriter, r *http.Request) {
username := r.FormValue("username")
password := r.FormValue("password")
- hashed, err := HashPassword(password)
+ hashed, err := helpers.HashPassword(password)
if err != nil {
http.Error(w, "Server error", http.StatusInternalServerError)
return
diff --git a/handlers/session.go b/handlers/session.go
deleted file mode 100644
index 7f72a21..0000000
--- a/handlers/session.go
+++ /dev/null
@@ -1,28 +0,0 @@
-package handlers
-
-import (
- "net/http"
-
- "synlotto-website/helpers"
-)
-
-func GetCurrentUserID(r *http.Request) (int, bool) {
- session, err := helpers.GetSession(nil, r)
- if err != nil {
- return 0, false
- }
-
- id, ok := session.Values["user_id"].(int)
- return id, ok
-}
-
-func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- _, ok := GetCurrentUserID(r)
- if !ok {
- http.Redirect(w, r, "/login", http.StatusSeeOther)
- return
- }
- next(w, r)
- }
-}
diff --git a/handlers/ticket_handler.go b/handlers/ticket_handler.go
index 6970e13..0c847db 100644
--- a/handlers/ticket_handler.go
+++ b/handlers/ticket_handler.go
@@ -35,7 +35,7 @@ func NewTicket(db *sql.DB) http.HandlerFunc {
func SubmitTicket(db *sql.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
- if _, ok := GetCurrentUserID(r); !ok {
+ if _, ok := helpers.GetCurrentUserID(r); !ok {
http.Redirect(w, r, "/login", http.StatusSeeOther)
return
}
diff --git a/handlers/auth.go b/helpers/auth.go
similarity index 95%
rename from handlers/auth.go
rename to helpers/auth.go
index 160b364..8bc848f 100644
--- a/handlers/auth.go
+++ b/helpers/auth.go
@@ -1,4 +1,4 @@
-package handlers
+package helpers
import "golang.org/x/crypto/bcrypt"
diff --git a/helpers/session.go b/helpers/session.go
index bc1d4e0..5c86c67 100644
--- a/helpers/session.go
+++ b/helpers/session.go
@@ -2,11 +2,13 @@ package helpers
import (
"net/http"
+ "time"
"github.com/gorilla/sessions"
)
-var store = sessions.NewCookieStore([]byte("super-secret-key")) // move this here
+var store = sessions.NewCookieStore([]byte("super-secret-key")) // //ToDo make key global
+const SessionTimeout = 30 * time.Minute
func init() {
store.Options = &sessions.Options{
@@ -21,3 +23,48 @@ func init() {
func GetSession(w http.ResponseWriter, r *http.Request) (*sessions.Session, error) {
return store.Get(r, "session-name")
}
+
+func IsSessionExpired(session *sessions.Session) bool {
+ last, ok := session.Values["last_activity"].(time.Time)
+ if !ok {
+ return false
+ }
+ return time.Since(last) > SessionTimeout
+}
+
+func UpdateSessionActivity(session *sessions.Session, r *http.Request, w http.ResponseWriter) {
+ session.Values["last_activity"] = time.Now()
+ session.Save(r, w)
+}
+
+func AuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ session, _ := GetSession(w, r)
+
+ if IsSessionExpired(session) {
+ session.Options.MaxAge = -1
+ session.Save(r, w)
+
+ newSession, _ := GetSession(w, r)
+ newSession.Values["flash"] = "Your session has timed out."
+ newSession.Save(r, w)
+
+ http.Redirect(w, r, "/login", http.StatusSeeOther)
+ return
+ }
+
+ UpdateSessionActivity(session, r, w)
+
+ next(w, r)
+ }
+}
+
+func GetCurrentUserID(r *http.Request) (int, bool) {
+ session, err := GetSession(nil, r)
+ if err != nil {
+ return 0, false
+ }
+
+ id, ok := session.Values["user_id"].(int)
+ return id, ok
+}
diff --git a/main.go b/main.go
index 4797d94..f90325f 100644
--- a/main.go
+++ b/main.go
@@ -5,6 +5,7 @@ import (
"net/http"
"synlotto-website/handlers"
"synlotto-website/helpers"
+ "synlotto-website/middleware"
"synlotto-website/models"
"synlotto-website/storage"
@@ -17,20 +18,21 @@ func main() {
csrfMiddleware := csrf.Protect(
[]byte("32-byte-long-auth-key-here"), // TodO: Make Global
- csrf.Secure(false),
+ csrf.Secure(true),
+ csrf.Path("/"),
)
mux := http.NewServeMux()
mux.HandleFunc("/", handlers.Home(db))
- mux.HandleFunc("/new", handlers.NewDraw)
+ mux.HandleFunc("/new", handlers.NewDraw) // ToDo: needs to be wrapped in admin auth
mux.HandleFunc("/submit", handlers.Submit)
mux.HandleFunc("/ticket", handlers.NewTicket(db))
- mux.HandleFunc("/tickets", handlers.ListTickets(db))
- mux.HandleFunc("/submit-ticket", handlers.RequireAuth(handlers.SubmitTicket(db)))
- mux.HandleFunc("/login", handlers.Login)
+ mux.HandleFunc("/tickets", middleware.Auth(true)(handlers.ListTickets(db)))
+ mux.HandleFunc("/submit-ticket", helpers.AuthMiddleware(handlers.SubmitTicket(db)))
+ mux.HandleFunc("/login", middleware.Auth(false)(handlers.Login))
mux.HandleFunc("/logout", handlers.Logout)
- mux.HandleFunc("/signup", handlers.Signup)
+ mux.HandleFunc("/signup", middleware.Auth(false)(handlers.Signup))
// Result pages
mux.HandleFunc("/results/thunderball", handlers.ResultsThunderball(db))
diff --git a/middleware/auth.go b/middleware/auth.go
new file mode 100644
index 0000000..40c4bb2
--- /dev/null
+++ b/middleware/auth.go
@@ -0,0 +1,45 @@
+package middleware
+
+import (
+ "net/http"
+ "time"
+
+ "synlotto-website/helpers"
+)
+
+const SessionTimeout = 30 * time.Minute
+
+func Auth(required bool) func(http.HandlerFunc) http.HandlerFunc {
+ return func(next http.HandlerFunc) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ session, _ := helpers.GetSession(w, r)
+
+ _, ok := session.Values["user_id"].(int)
+
+ if required && !ok {
+ http.Redirect(w, r, "/login", http.StatusSeeOther)
+ return
+ }
+
+ if ok {
+ last, hasLast := session.Values["last_activity"].(time.Time)
+ if hasLast && time.Since(last) > SessionTimeout {
+ session.Options.MaxAge = -1
+ session.Save(r, w)
+
+ newSession, _ := helpers.GetSession(w, r)
+ newSession.Values["flash"] = "Your session has timed out."
+ newSession.Save(r, w)
+
+ http.Redirect(w, r, "/login", http.StatusSeeOther)
+ return
+ }
+
+ session.Values["last_activity"] = time.Now()
+ session.Save(r, w)
+ }
+
+ next(w, r)
+ }
+ }
+}
diff --git a/templates/account/login.html b/templates/account/login.html
index fb58e38..c918d5f 100644
--- a/templates/account/login.html
+++ b/templates/account/login.html
@@ -7,6 +7,7 @@
{{ .csrfField }}
+
{{ end }}
\ No newline at end of file