diff --git a/handlers/account.go b/handlers/account.go index 2d53a6d..32d3933 100644 --- a/handlers/account.go +++ b/handlers/account.go @@ -5,30 +5,28 @@ import ( "net/http" "synlotto-website/helpers" "synlotto-website/models" + "time" "github.com/gorilla/csrf" ) func Login(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodGet { + session, _ := helpers.GetSession(w, r) + if _, ok := session.Values["user_id"].(int); ok { + http.Redirect(w, r, "/", http.StatusSeeOther) + return + } + tmpl := template.Must(template.ParseFiles( "templates/layout.html", "templates/account/login.html", )) - session, _ := helpers.GetSession(w, r) + context := helpers.TemplateContext(w, r) + context["csrfField"] = csrf.TemplateField(r) - var flash string - if f, ok := session.Values["flash"].(string); ok { - flash = f - delete(session.Values, "flash") - session.Save(r, w) - } - - tmpl.ExecuteTemplate(w, "layout", map[string]interface{}{ - "csrfField": csrf.TemplateField(r), - "Flash": flash, - }) + tmpl.ExecuteTemplate(w, "layout", context) return } @@ -36,21 +34,27 @@ func Login(w http.ResponseWriter, r *http.Request) { password := r.FormValue("password") user := models.GetUserByUsername(username) - if user == nil || !CheckPasswordHash(user.PasswordHash, password) { + if user == nil || !helpers.CheckPasswordHash(user.PasswordHash, password) { http.Error(w, "Invalid credentials", http.StatusUnauthorized) return } - session, err := helpers.GetSession(w, r) - if err != nil { - http.Error(w, "Session error", http.StatusInternalServerError) - return - } + session, _ := helpers.GetSession(w, r) session.Options.MaxAge = -1 session.Save(r, w) + remember := r.FormValue("remember") == "on" + newSession, _ := helpers.GetSession(w, r) newSession.Values["user_id"] = user.Id + newSession.Values["last_activity"] = time.Now() + + if remember { + newSession.Options.MaxAge = 60 * 60 * 24 * 30 // 30 days + } else { + newSession.Options.MaxAge = 0 + } + newSession.Save(r, w) http.Redirect(w, r, "/", http.StatusSeeOther) @@ -84,7 +88,7 @@ func Signup(w http.ResponseWriter, r *http.Request) { username := r.FormValue("username") password := r.FormValue("password") - hashed, err := HashPassword(password) + hashed, err := helpers.HashPassword(password) if err != nil { http.Error(w, "Server error", http.StatusInternalServerError) return diff --git a/handlers/session.go b/handlers/session.go deleted file mode 100644 index 7f72a21..0000000 --- a/handlers/session.go +++ /dev/null @@ -1,28 +0,0 @@ -package handlers - -import ( - "net/http" - - "synlotto-website/helpers" -) - -func GetCurrentUserID(r *http.Request) (int, bool) { - session, err := helpers.GetSession(nil, r) - if err != nil { - return 0, false - } - - id, ok := session.Values["user_id"].(int) - return id, ok -} - -func RequireAuth(next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - _, ok := GetCurrentUserID(r) - if !ok { - http.Redirect(w, r, "/login", http.StatusSeeOther) - return - } - next(w, r) - } -} diff --git a/handlers/ticket_handler.go b/handlers/ticket_handler.go index 6970e13..0c847db 100644 --- a/handlers/ticket_handler.go +++ b/handlers/ticket_handler.go @@ -35,7 +35,7 @@ func NewTicket(db *sql.DB) http.HandlerFunc { func SubmitTicket(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - if _, ok := GetCurrentUserID(r); !ok { + if _, ok := helpers.GetCurrentUserID(r); !ok { http.Redirect(w, r, "/login", http.StatusSeeOther) return } diff --git a/handlers/auth.go b/helpers/auth.go similarity index 95% rename from handlers/auth.go rename to helpers/auth.go index 160b364..8bc848f 100644 --- a/handlers/auth.go +++ b/helpers/auth.go @@ -1,4 +1,4 @@ -package handlers +package helpers import "golang.org/x/crypto/bcrypt" diff --git a/helpers/session.go b/helpers/session.go index bc1d4e0..5c86c67 100644 --- a/helpers/session.go +++ b/helpers/session.go @@ -2,11 +2,13 @@ package helpers import ( "net/http" + "time" "github.com/gorilla/sessions" ) -var store = sessions.NewCookieStore([]byte("super-secret-key")) // move this here +var store = sessions.NewCookieStore([]byte("super-secret-key")) // //ToDo make key global +const SessionTimeout = 30 * time.Minute func init() { store.Options = &sessions.Options{ @@ -21,3 +23,48 @@ func init() { func GetSession(w http.ResponseWriter, r *http.Request) (*sessions.Session, error) { return store.Get(r, "session-name") } + +func IsSessionExpired(session *sessions.Session) bool { + last, ok := session.Values["last_activity"].(time.Time) + if !ok { + return false + } + return time.Since(last) > SessionTimeout +} + +func UpdateSessionActivity(session *sessions.Session, r *http.Request, w http.ResponseWriter) { + session.Values["last_activity"] = time.Now() + session.Save(r, w) +} + +func AuthMiddleware(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + session, _ := GetSession(w, r) + + if IsSessionExpired(session) { + session.Options.MaxAge = -1 + session.Save(r, w) + + newSession, _ := GetSession(w, r) + newSession.Values["flash"] = "Your session has timed out." + newSession.Save(r, w) + + http.Redirect(w, r, "/login", http.StatusSeeOther) + return + } + + UpdateSessionActivity(session, r, w) + + next(w, r) + } +} + +func GetCurrentUserID(r *http.Request) (int, bool) { + session, err := GetSession(nil, r) + if err != nil { + return 0, false + } + + id, ok := session.Values["user_id"].(int) + return id, ok +} diff --git a/main.go b/main.go index 4797d94..f90325f 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "net/http" "synlotto-website/handlers" "synlotto-website/helpers" + "synlotto-website/middleware" "synlotto-website/models" "synlotto-website/storage" @@ -17,20 +18,21 @@ func main() { csrfMiddleware := csrf.Protect( []byte("32-byte-long-auth-key-here"), // TodO: Make Global - csrf.Secure(false), + csrf.Secure(true), + csrf.Path("/"), ) mux := http.NewServeMux() mux.HandleFunc("/", handlers.Home(db)) - mux.HandleFunc("/new", handlers.NewDraw) + mux.HandleFunc("/new", handlers.NewDraw) // ToDo: needs to be wrapped in admin auth mux.HandleFunc("/submit", handlers.Submit) mux.HandleFunc("/ticket", handlers.NewTicket(db)) - mux.HandleFunc("/tickets", handlers.ListTickets(db)) - mux.HandleFunc("/submit-ticket", handlers.RequireAuth(handlers.SubmitTicket(db))) - mux.HandleFunc("/login", handlers.Login) + mux.HandleFunc("/tickets", middleware.Auth(true)(handlers.ListTickets(db))) + mux.HandleFunc("/submit-ticket", helpers.AuthMiddleware(handlers.SubmitTicket(db))) + mux.HandleFunc("/login", middleware.Auth(false)(handlers.Login)) mux.HandleFunc("/logout", handlers.Logout) - mux.HandleFunc("/signup", handlers.Signup) + mux.HandleFunc("/signup", middleware.Auth(false)(handlers.Signup)) // Result pages mux.HandleFunc("/results/thunderball", handlers.ResultsThunderball(db)) diff --git a/middleware/auth.go b/middleware/auth.go new file mode 100644 index 0000000..40c4bb2 --- /dev/null +++ b/middleware/auth.go @@ -0,0 +1,45 @@ +package middleware + +import ( + "net/http" + "time" + + "synlotto-website/helpers" +) + +const SessionTimeout = 30 * time.Minute + +func Auth(required bool) func(http.HandlerFunc) http.HandlerFunc { + return func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + session, _ := helpers.GetSession(w, r) + + _, ok := session.Values["user_id"].(int) + + if required && !ok { + http.Redirect(w, r, "/login", http.StatusSeeOther) + return + } + + if ok { + last, hasLast := session.Values["last_activity"].(time.Time) + if hasLast && time.Since(last) > SessionTimeout { + session.Options.MaxAge = -1 + session.Save(r, w) + + newSession, _ := helpers.GetSession(w, r) + newSession.Values["flash"] = "Your session has timed out." + newSession.Save(r, w) + + http.Redirect(w, r, "/login", http.StatusSeeOther) + return + } + + session.Values["last_activity"] = time.Now() + session.Save(r, w) + } + + next(w, r) + } + } +} diff --git a/templates/account/login.html b/templates/account/login.html index fb58e38..c918d5f 100644 --- a/templates/account/login.html +++ b/templates/account/login.html @@ -7,6 +7,7 @@ {{ .csrfField }}

+ {{ end }} \ No newline at end of file