diff --git a/handlers/security/csrf.go b/bootstrap/csrf.go similarity index 96% rename from handlers/security/csrf.go rename to bootstrap/csrf.go index c44b08c..a576df6 100644 --- a/handlers/security/csrf.go +++ b/bootstrap/csrf.go @@ -1,4 +1,4 @@ -package security +package bootstrap import ( "fmt" diff --git a/bootstrap/session.go b/bootstrap/session.go index 6edaed1..326c2c6 100644 --- a/bootstrap/session.go +++ b/bootstrap/session.go @@ -1,14 +1,25 @@ package bootstrap import ( + "bytes" "crypto/rand" + "encoding/base64" + "fmt" + "net/http" "os" - securityhandlers "synlotto-website/handlers/security" - helpers "synlotto-website/helpers/session" "synlotto-website/logging" "synlotto-website/models" + + "github.com/gorilla/sessions" +) + +var ( + sessionStore *sessions.CookieStore + sessionName string + authKey []byte + encryptKey []byte ) func InitSession(cfg *models.Config) error { @@ -41,7 +52,7 @@ func InitSession(cfg *models.Config) error { } } - return securityhandlers.LoadSessionKeys( + return loadSessionKeys( authPath, encPath, cfg.Session.Name, @@ -59,3 +70,41 @@ func generateRandomBytes(length int) ([]byte, error) { } return b, nil } + +func loadSessionKeys(authPath, encryptionPath, name string, isProduction bool) error { + var err error + + rawAuth, err := os.ReadFile(authPath) + if err != nil { + return fmt.Errorf("error reading auth key: %w", err) + } + authKey, err = base64.StdEncoding.DecodeString(string(bytes.TrimSpace(rawAuth))) + if err != nil { + return fmt.Errorf("error decoding auth key: %w", err) + } + + rawEnc, err := os.ReadFile(encryptionPath) + if err != nil { + return fmt.Errorf("error reading encryption key: %w", err) + } + encryptKey, err = base64.StdEncoding.DecodeString(string(bytes.TrimSpace(rawEnc))) + if err != nil { + return fmt.Errorf("error decoding encryption key: %w", err) + } + + if len(authKey) != 32 || len(encryptKey) != 32 { + return fmt.Errorf("auth and encryption keys must be 32 bytes each (got auth=%d, enc=%d)", len(authKey), len(encryptKey)) + } + + sessionStore = sessions.NewCookieStore(authKey, encryptKey) + sessionStore.Options = &sessions.Options{ + Path: "/", + MaxAge: 86400 * 1, + HttpOnly: true, + Secure: isProduction, + SameSite: http.SameSiteLaxMode, + } + + sessionName = name + return nil +} diff --git a/handlers/security/session.go b/handlers/security/session.go deleted file mode 100644 index ccd1d51..0000000 --- a/handlers/security/session.go +++ /dev/null @@ -1,101 +0,0 @@ -package security - -import ( - "bytes" - "encoding/base64" - "encoding/gob" - "fmt" - "net/http" - "os" - "time" - - "github.com/gorilla/securecookie" - "github.com/gorilla/sessions" -) - -var ( - sessionStore *sessions.CookieStore - sessionName string - authKey []byte - encryptKey []byte -) - -func init() { - gob.Register(time.Time{}) -} - -func LoadSessionKeys(authPath, encryptionPath, name string, isProduction bool) error { - var err error - - rawAuth, err := os.ReadFile(authPath) - if err != nil { - return fmt.Errorf("error reading auth key: %w", err) - } - authKey, err = base64.StdEncoding.DecodeString(string(bytes.TrimSpace(rawAuth))) - if err != nil { - return fmt.Errorf("error decoding auth key: %w", err) - } - - rawEnc, err := os.ReadFile(encryptionPath) - if err != nil { - return fmt.Errorf("error reading encryption key: %w", err) - } - encryptKey, err = base64.StdEncoding.DecodeString(string(bytes.TrimSpace(rawEnc))) - if err != nil { - return fmt.Errorf("error decoding encryption key: %w", err) - } - - if len(authKey) != 32 || len(encryptKey) != 32 { - return fmt.Errorf("auth and encryption keys must be 32 bytes each (got auth=%d, enc=%d)", len(authKey), len(encryptKey)) - } - - sessionStore = sessions.NewCookieStore(authKey, encryptKey) - sessionStore.Options = &sessions.Options{ - Path: "/", - MaxAge: 86400 * 1, - HttpOnly: true, - Secure: isProduction, - SameSite: http.SameSiteLaxMode, - } - - sessionName = name - return nil -} - -func GetSession(w http.ResponseWriter, r *http.Request) (*sessions.Session, error) { - return sessionStore.Get(r, sessionName) -} - -func SecureCookie(w http.ResponseWriter, name, value string, isProduction bool) error { - s := securecookie.New(authKey, encryptKey) - - encoded, err := s.Encode(name, value) - if err != nil { - return err - } - - http.SetCookie(w, &http.Cookie{ - Name: name, - Value: encoded, - Path: "/", - HttpOnly: true, - Secure: isProduction, - SameSite: http.SameSiteStrictMode, - }) - - return nil -} - -func LoadKeyFromFile(path string) ([]byte, error) { - key, err := os.ReadFile(path) - if err != nil { - return nil, err - } - return bytes.TrimSpace(key), nil -} - -func ZeroBytes(b []byte) { - for i := range b { - b[i] = 0 - } -} diff --git a/handlers/session/account.go b/handlers/session/account.go new file mode 100644 index 0000000..a6743ad --- /dev/null +++ b/handlers/session/account.go @@ -0,0 +1,16 @@ +package handlers + +import ( + "net/http" + + "github.com/gorilla/sessions" +) + +var ( + sessionStore *sessions.CookieStore + sessionName string +) + +func GetSession(w http.ResponseWriter, r *http.Request) (*sessions.Session, error) { + return sessionStore.Get(r, sessionName) +} diff --git a/handlers/session/auth.go b/handlers/session/auth.go new file mode 100644 index 0000000..6da5493 --- /dev/null +++ b/handlers/session/auth.go @@ -0,0 +1,32 @@ +package handlers + +import ( + "net/http" + + "github.com/gorilla/securecookie" +) + +var ( + authKey []byte + encryptKey []byte +) + +func SecureCookie(w http.ResponseWriter, name, value string, isProduction bool) error { + s := securecookie.New(authKey, encryptKey) + + encoded, err := s.Encode(name, value) + if err != nil { + return err + } + + http.SetCookie(w, &http.Cookie{ + Name: name, + Value: encoded, + Path: "/", + HttpOnly: true, + Secure: isProduction, + SameSite: http.SameSiteStrictMode, + }) + + return nil +} diff --git a/helpers/session/loader.go b/helpers/session/loader.go new file mode 100644 index 0000000..d1153f0 --- /dev/null +++ b/helpers/session/loader.go @@ -0,0 +1,20 @@ +package helpers + +import ( + "bytes" + "os" +) + +func LoadKeyFromFile(path string) ([]byte, error) { + key, err := os.ReadFile(path) + if err != nil { + return nil, err + } + return bytes.TrimSpace(key), nil +} + +func ZeroBytes(b []byte) { + for i := range b { + b[i] = 0 + } +} diff --git a/main.go b/main.go index 7cf1d57..8c48285 100644 --- a/main.go +++ b/main.go @@ -4,8 +4,6 @@ import ( "log" "net/http" - securityhandlers "synlotto-website/handlers/security" - "synlotto-website/bootstrap" "synlotto-website/config" "synlotto-website/handlers" @@ -32,7 +30,7 @@ func main() { logging.Error("❌ Failed to init session: %v", err) } - err = securityhandlers.InitCSRFProtection([]byte(appState.Config.CSRF.CSRFKey), appState.Config.HttpServer.ProductionMode) + err = bootstrap.InitCSRFProtection([]byte(appState.Config.CSRF.CSRFKey), appState.Config.HttpServer.ProductionMode) if err != nil { logging.Error("Failed to init CSRF: %v", err) } @@ -46,7 +44,7 @@ func main() { mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static")))) mux.HandleFunc("/", handlers.Home(db)) - wrapped := securityhandlers.CSRFMiddleware(mux) + wrapped := bootstrap.CSRFMiddleware(mux) wrapped = middleware.RateLimit(wrapped) wrapped = middleware.EnforceHTTPS(wrapped, appState.Config.HttpServer.ProductionMode) wrapped = middleware.SecureHeaders(wrapped)