diff --git a/handlers/account.go b/handlers/account.go index c1ec735..2d53a6d 100644 --- a/handlers/account.go +++ b/handlers/account.go @@ -3,6 +3,7 @@ package handlers import ( "html/template" "net/http" + "synlotto-website/helpers" "synlotto-website/models" "github.com/gorilla/csrf" @@ -15,8 +16,18 @@ func Login(w http.ResponseWriter, r *http.Request) { "templates/account/login.html", )) + session, _ := helpers.GetSession(w, r) + + var flash string + if f, ok := session.Values["flash"].(string); ok { + flash = f + delete(session.Values, "flash") + session.Save(r, w) + } + tmpl.ExecuteTemplate(w, "layout", map[string]interface{}{ "csrfField": csrf.TemplateField(r), + "Flash": flash, }) return } @@ -30,19 +41,30 @@ func Login(w http.ResponseWriter, r *http.Request) { return } - session, _ := GetSession(w, r) - session.Values["user_id"] = user.Id + session, err := helpers.GetSession(w, r) + if err != nil { + http.Error(w, "Session error", http.StatusInternalServerError) + return + } + session.Options.MaxAge = -1 session.Save(r, w) + newSession, _ := helpers.GetSession(w, r) + newSession.Values["user_id"] = user.Id + newSession.Save(r, w) + http.Redirect(w, r, "/", http.StatusSeeOther) } func Logout(w http.ResponseWriter, r *http.Request) { - session, _ := GetSession(w, r) - + session, _ := helpers.GetSession(w, r) session.Options.MaxAge = -1 session.Save(r, w) + newSession, _ := helpers.GetSession(w, r) + newSession.Values["flash"] = "You’ve been logged out" + newSession.Save(r, w) + http.Redirect(w, r, "/login", http.StatusSeeOther) } diff --git a/handlers/draw_handler.go b/handlers/draw_handler.go index e8cf0cd..5126464 100644 --- a/handlers/draw_handler.go +++ b/handlers/draw_handler.go @@ -40,7 +40,6 @@ func Home(db *sql.DB) http.HandlerFunc { continue } - // ✅ Add sorted ball list res.SortedBalls = []int{ res.Ball1, res.Ball2, res.Ball3, res.Ball4, res.Ball5, } diff --git a/handlers/session.go b/handlers/session.go index 2f3e740..7f72a21 100644 --- a/handlers/session.go +++ b/handlers/session.go @@ -3,27 +3,11 @@ package handlers import ( "net/http" - "github.com/gorilla/sessions" + "synlotto-website/helpers" ) -var store = sessions.NewCookieStore([]byte("super-secret-key")) // ToDo: Make global - -func init() { - store.Options = &sessions.Options{ - Path: "/", - MaxAge: 86400 * 1, - HttpOnly: true, - Secure: true, - SameSite: http.SameSiteStrictMode, - } -} - -func GetSession(w http.ResponseWriter, r *http.Request) (*sessions.Session, error) { - return store.Get(r, "session-name") -} - func GetCurrentUserID(r *http.Request) (int, bool) { - session, err := GetSession(nil, r) + session, err := helpers.GetSession(nil, r) if err != nil { return 0, false } diff --git a/helpers/session.go b/helpers/session.go new file mode 100644 index 0000000..bc1d4e0 --- /dev/null +++ b/helpers/session.go @@ -0,0 +1,23 @@ +package helpers + +import ( + "net/http" + + "github.com/gorilla/sessions" +) + +var store = sessions.NewCookieStore([]byte("super-secret-key")) // move this here + +func init() { + store.Options = &sessions.Options{ + Path: "/", + MaxAge: 86400 * 1, + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteStrictMode, + } +} + +func GetSession(w http.ResponseWriter, r *http.Request) (*sessions.Session, error) { + return store.Get(r, "session-name") +} diff --git a/helpers/template.go b/helpers/template.go index ebeb89d..7d92a3a 100644 --- a/helpers/template.go +++ b/helpers/template.go @@ -2,6 +2,8 @@ package helpers import ( "html/template" + "net/http" + "synlotto-website/models" ) func TemplateFuncs() template.FuncMap { @@ -23,3 +25,24 @@ func TemplateFuncs() template.FuncMap { }, } } + +func TemplateContext(w http.ResponseWriter, r *http.Request) map[string]interface{} { + session, _ := GetSession(w, r) + + var flash string + if f, ok := session.Values["flash"].(string); ok { + flash = f + delete(session.Values, "flash") + session.Save(r, w) + } + + var currentUser *models.User + if userId, ok := session.Values["user_id"].(int); ok { + currentUser = models.GetUserByID(userId) + } + + return map[string]interface{}{ + "Flash": flash, + "User": currentUser, + } +} diff --git a/models/user.go b/models/user.go index b399a9c..1744e48 100644 --- a/models/user.go +++ b/models/user.go @@ -35,3 +35,17 @@ func GetUserByUsername(username string) *User { } return &user } + +func GetUserByID(id int) *User { + row := db.QueryRow("SELECT id, username, password_hash FROM users WHERE id = ?", id) + + var user User + err := row.Scan(&user.Id, &user.Username, &user.PasswordHash) + if err != nil { + if err != sql.ErrNoRows { + log.Println("DB error:", err) + } + return nil + } + return &user +} diff --git a/templates/account/login.html b/templates/account/login.html index 720dc1b..fb58e38 100644 --- a/templates/account/login.html +++ b/templates/account/login.html @@ -1,4 +1,7 @@ {{ define "content" }} +{{ if .Flash }} +
{{ .Flash }}
+{{ end }}