Circular dependancy issue when working on hardening
This commit is contained in:
@@ -3,6 +3,7 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"html/template"
|
"html/template"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"synlotto-website/helpers"
|
||||||
"synlotto-website/models"
|
"synlotto-website/models"
|
||||||
|
|
||||||
"github.com/gorilla/csrf"
|
"github.com/gorilla/csrf"
|
||||||
@@ -15,8 +16,18 @@ func Login(w http.ResponseWriter, r *http.Request) {
|
|||||||
"templates/account/login.html",
|
"templates/account/login.html",
|
||||||
))
|
))
|
||||||
|
|
||||||
|
session, _ := helpers.GetSession(w, r)
|
||||||
|
|
||||||
|
var flash string
|
||||||
|
if f, ok := session.Values["flash"].(string); ok {
|
||||||
|
flash = f
|
||||||
|
delete(session.Values, "flash")
|
||||||
|
session.Save(r, w)
|
||||||
|
}
|
||||||
|
|
||||||
tmpl.ExecuteTemplate(w, "layout", map[string]interface{}{
|
tmpl.ExecuteTemplate(w, "layout", map[string]interface{}{
|
||||||
"csrfField": csrf.TemplateField(r),
|
"csrfField": csrf.TemplateField(r),
|
||||||
|
"Flash": flash,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -30,19 +41,30 @@ func Login(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
session, _ := GetSession(w, r)
|
session, err := helpers.GetSession(w, r)
|
||||||
session.Values["user_id"] = user.Id
|
if err != nil {
|
||||||
|
http.Error(w, "Session error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
session.Options.MaxAge = -1
|
||||||
session.Save(r, w)
|
session.Save(r, w)
|
||||||
|
|
||||||
|
newSession, _ := helpers.GetSession(w, r)
|
||||||
|
newSession.Values["user_id"] = user.Id
|
||||||
|
newSession.Save(r, w)
|
||||||
|
|
||||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Logout(w http.ResponseWriter, r *http.Request) {
|
func Logout(w http.ResponseWriter, r *http.Request) {
|
||||||
session, _ := GetSession(w, r)
|
session, _ := helpers.GetSession(w, r)
|
||||||
|
|
||||||
session.Options.MaxAge = -1
|
session.Options.MaxAge = -1
|
||||||
session.Save(r, w)
|
session.Save(r, w)
|
||||||
|
|
||||||
|
newSession, _ := helpers.GetSession(w, r)
|
||||||
|
newSession.Values["flash"] = "You’ve been logged out"
|
||||||
|
newSession.Save(r, w)
|
||||||
|
|
||||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ func Home(db *sql.DB) http.HandlerFunc {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// ✅ Add sorted ball list
|
|
||||||
res.SortedBalls = []int{
|
res.SortedBalls = []int{
|
||||||
res.Ball1, res.Ball2, res.Ball3, res.Ball4, res.Ball5,
|
res.Ball1, res.Ball2, res.Ball3, res.Ball4, res.Ball5,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,27 +3,11 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gorilla/sessions"
|
"synlotto-website/helpers"
|
||||||
)
|
)
|
||||||
|
|
||||||
var store = sessions.NewCookieStore([]byte("super-secret-key")) // ToDo: Make global
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
store.Options = &sessions.Options{
|
|
||||||
Path: "/",
|
|
||||||
MaxAge: 86400 * 1,
|
|
||||||
HttpOnly: true,
|
|
||||||
Secure: true,
|
|
||||||
SameSite: http.SameSiteStrictMode,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetSession(w http.ResponseWriter, r *http.Request) (*sessions.Session, error) {
|
|
||||||
return store.Get(r, "session-name")
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetCurrentUserID(r *http.Request) (int, bool) {
|
func GetCurrentUserID(r *http.Request) (int, bool) {
|
||||||
session, err := GetSession(nil, r)
|
session, err := helpers.GetSession(nil, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
|||||||
23
helpers/session.go
Normal file
23
helpers/session.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package helpers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/sessions"
|
||||||
|
)
|
||||||
|
|
||||||
|
var store = sessions.NewCookieStore([]byte("super-secret-key")) // move this here
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
store.Options = &sessions.Options{
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: 86400 * 1,
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
SameSite: http.SameSiteStrictMode,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSession(w http.ResponseWriter, r *http.Request) (*sessions.Session, error) {
|
||||||
|
return store.Get(r, "session-name")
|
||||||
|
}
|
||||||
@@ -2,6 +2,8 @@ package helpers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"html/template"
|
"html/template"
|
||||||
|
"net/http"
|
||||||
|
"synlotto-website/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TemplateFuncs() template.FuncMap {
|
func TemplateFuncs() template.FuncMap {
|
||||||
@@ -23,3 +25,24 @@ func TemplateFuncs() template.FuncMap {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TemplateContext(w http.ResponseWriter, r *http.Request) map[string]interface{} {
|
||||||
|
session, _ := GetSession(w, r)
|
||||||
|
|
||||||
|
var flash string
|
||||||
|
if f, ok := session.Values["flash"].(string); ok {
|
||||||
|
flash = f
|
||||||
|
delete(session.Values, "flash")
|
||||||
|
session.Save(r, w)
|
||||||
|
}
|
||||||
|
|
||||||
|
var currentUser *models.User
|
||||||
|
if userId, ok := session.Values["user_id"].(int); ok {
|
||||||
|
currentUser = models.GetUserByID(userId)
|
||||||
|
}
|
||||||
|
|
||||||
|
return map[string]interface{}{
|
||||||
|
"Flash": flash,
|
||||||
|
"User": currentUser,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -35,3 +35,17 @@ func GetUserByUsername(username string) *User {
|
|||||||
}
|
}
|
||||||
return &user
|
return &user
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetUserByID(id int) *User {
|
||||||
|
row := db.QueryRow("SELECT id, username, password_hash FROM users WHERE id = ?", id)
|
||||||
|
|
||||||
|
var user User
|
||||||
|
err := row.Scan(&user.Id, &user.Username, &user.PasswordHash)
|
||||||
|
if err != nil {
|
||||||
|
if err != sql.ErrNoRows {
|
||||||
|
log.Println("DB error:", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &user
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
{{ define "content" }}
|
{{ define "content" }}
|
||||||
|
{{ if .Flash }}
|
||||||
|
<p style="color: green;">{{ .Flash }}</p>
|
||||||
|
{{ end }}
|
||||||
<h2>Login</h2>
|
<h2>Login</h2>
|
||||||
<form method="POST" action="/login">
|
<form method="POST" action="/login">
|
||||||
{{ .csrfField }}
|
{{ .csrfField }}
|
||||||
|
|||||||
@@ -3,6 +3,11 @@
|
|||||||
<html lang="en">
|
<html lang="en">
|
||||||
<head>
|
<head>
|
||||||
<meta charset="UTF-8">
|
<meta charset="UTF-8">
|
||||||
|
{{ if .User }}
|
||||||
|
<p>Hello, {{ .User.Username }} | <a href="/logout">Logout</a></p>
|
||||||
|
{{ else }}
|
||||||
|
<p><a href="/login">Login</a></p>
|
||||||
|
{{ end }}
|
||||||
<title>Lotto Tracker</title>
|
<title>Lotto Tracker</title>
|
||||||
<style>
|
<style>
|
||||||
body { font-family: Arial, sans-serif; margin: 40px; }
|
body { font-family: Arial, sans-serif; margin: 40px; }
|
||||||
|
|||||||
Reference in New Issue
Block a user