diff --git a/handlers/results.go b/handlers/results.go index fbc5101..fa18262 100644 --- a/handlers/results.go +++ b/handlers/results.go @@ -8,14 +8,16 @@ import ( "regexp" "sort" "strconv" + "synlotto-website/helpers" + "synlotto-website/middleware" "synlotto-website/models" ) func ResultsThunderball(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ip, _, _ := net.SplitHostPort(r.RemoteAddr) - limiter := helpers.GetVisitorLimiter(ip) + limiter := middleware.GetVisitorLimiter(ip) if !limiter.Allow() { http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) diff --git a/handlers/security/csrf.go b/handlers/security/csrf.go new file mode 100644 index 0000000..c44b08c --- /dev/null +++ b/handlers/security/csrf.go @@ -0,0 +1,26 @@ +package security + +import ( + "fmt" + "net/http" + + "github.com/gorilla/csrf" +) + +var CSRFMiddleware func(http.Handler) http.Handler + +func InitCSRFProtection(csrfKey []byte, isProduction bool) error { + if len(csrfKey) != 32 { + return fmt.Errorf("csrf key must be 32 bytes, got %d", len(csrfKey)) + } + + CSRFMiddleware = csrf.Protect( + csrfKey, + csrf.Secure(isProduction), + csrf.SameSite(csrf.SameSiteStrictMode), + csrf.Path("/"), + csrf.HttpOnly(true), + ) + + return nil +} diff --git a/handlers/security/session.go b/handlers/security/session.go new file mode 100644 index 0000000..9c1677a --- /dev/null +++ b/handlers/security/session.go @@ -0,0 +1,94 @@ +package security + +import ( + "bytes" + "encoding/gob" + "fmt" + "net/http" + "os" + "time" + + "github.com/gorilla/securecookie" + "github.com/gorilla/sessions" +) + +var ( + sessionStore *sessions.CookieStore + sessionName string + authKey []byte + encryptKey []byte +) + +func init() { + gob.Register(time.Time{}) +} + +func LoadSessionKeys(authPath, encryptionPath, name string, isProduction bool) error { + var err error + authKey, err = os.ReadFile(authPath) + if err != nil { + return fmt.Errorf("error loading auth key: %w", err) + } + + encryptKey, err = os.ReadFile(encryptionPath) + if err != nil { + return fmt.Errorf("error loading encryption key: %w", err) + } + + authKey = bytes.TrimSpace(authKey) + encryptKey = bytes.TrimSpace(encryptKey) + + if len(authKey) != 32 || len(encryptKey) != 32 { + return fmt.Errorf("auth and encryption keys must be 32 bytes each") + } + + sessionStore = sessions.NewCookieStore(authKey, encryptKey) + sessionStore.Options = &sessions.Options{ + Path: "/", + MaxAge: 86400 * 1, + HttpOnly: true, + Secure: isProduction, + SameSite: http.SameSiteLaxMode, + } + + sessionName = name + return nil +} + +func GetSession(w http.ResponseWriter, r *http.Request) (*sessions.Session, error) { + return sessionStore.Get(r, sessionName) +} + +func SecureCookie(w http.ResponseWriter, name, value string, isProduction bool) error { + s := securecookie.New(authKey, encryptKey) + + encoded, err := s.Encode(name, value) + if err != nil { + return err + } + + http.SetCookie(w, &http.Cookie{ + Name: name, + Value: encoded, + Path: "/", + HttpOnly: true, + Secure: isProduction, + SameSite: http.SameSiteStrictMode, + }) + + return nil +} + +func LoadKeyFromFile(path string) ([]byte, error) { + key, err := os.ReadFile(path) + if err != nil { + return nil, err + } + return bytes.TrimSpace(key), nil +} + +func ZeroBytes(b []byte) { + for i := range b { + b[i] = 0 + } +} diff --git a/main.go b/main.go index a0a3ed7..e446e74 100644 --- a/main.go +++ b/main.go @@ -3,17 +3,17 @@ package main import ( "log" "net/http" + + securityhandlers "synlotto-website/handlers/security" + "synlotto-website/bootstrap" "synlotto-website/config" "synlotto-website/handlers" - "synlotto-website/helpers" "synlotto-website/logging" "synlotto-website/middleware" "synlotto-website/models" "synlotto-website/routes" "synlotto-website/storage" - - "github.com/gorilla/csrf" ) func main() { @@ -27,11 +27,10 @@ func main() { db := storage.InitDB("synlotto.db") models.SetDB(db) // Should be in storage not models. - csrfMiddleware := csrf.Protect( - []byte("abcdefghijklmnopqrstuvwx12345678"), // TodO: Make Global - csrf.Secure(true), - csrf.Path("/"), - ) + err = securityhandlers.InitCSRFProtection([]byte(appState.Config.CSRF.CSRFKey), appState.Config.HttpServer.ProductionMode) + if err != nil { + logging.Error("Failed to init CSRF: %v", err) + } mux := http.NewServeMux() routes.SetupAdminRoutes(mux, db) @@ -42,7 +41,8 @@ func main() { mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static")))) mux.HandleFunc("/", handlers.Home(db)) - wrapped := helpers.RateLimit(csrfMiddleware(mux)) + wrapped := securityhandlers.CSRFMiddleware(mux) + wrapped = middleware.RateLimit(wrapped) wrapped = middleware.EnforceHTTPS(wrapped, appState.Config.HttpServer.ProductionMode) wrapped = middleware.SecureHeaders(wrapped) wrapped = middleware.Recover(wrapped) diff --git a/middleware/security.go b/middleware/headers.go similarity index 100% rename from middleware/security.go rename to middleware/headers.go diff --git a/helpers/ratelimit.go b/middleware/ratelimit.go similarity index 97% rename from helpers/ratelimit.go rename to middleware/ratelimit.go index d4220ca..b658f16 100644 --- a/helpers/ratelimit.go +++ b/middleware/ratelimit.go @@ -1,4 +1,4 @@ -package helpers +package middleware import ( "net" diff --git a/storage/db.go b/storage/db.go index ffe8334..d99c27c 100644 --- a/storage/db.go +++ b/storage/db.go @@ -4,10 +4,15 @@ import ( "database/sql" "log" + "synlotto-website/config" + "synlotto-website/logging" + _ "modernc.org/sqlite" ) func InitDB(filepath string) *sql.DB { + var err error + cfg := config.Get() db, err := sql.Open("sqlite", filepath) if err != nil { log.Fatal("❌ Failed to open DB:", err) @@ -30,6 +35,10 @@ func InitDB(filepath string) *sql.DB { SchemaSyndicateInvites, SchemaSyndicateInviteTokens, } + if cfg == nil { + logging.Error("❌ config is nil — did config.Init() run before InitDB?") + panic("config not ready") + } for _, stmt := range schemas { if _, err := db.Exec(stmt); err != nil {