rework to lighten the main, refactor wrappers. Rehandle csrf and pull config items.

This commit is contained in:
2025-04-15 22:19:55 +01:00
parent 0a5d61ea1e
commit 0a21973237
7 changed files with 142 additions and 11 deletions

View File

@@ -8,14 +8,16 @@ import (
"regexp" "regexp"
"sort" "sort"
"strconv" "strconv"
"synlotto-website/helpers" "synlotto-website/helpers"
"synlotto-website/middleware"
"synlotto-website/models" "synlotto-website/models"
) )
func ResultsThunderball(db *sql.DB) http.HandlerFunc { func ResultsThunderball(db *sql.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ip, _, _ := net.SplitHostPort(r.RemoteAddr) ip, _, _ := net.SplitHostPort(r.RemoteAddr)
limiter := helpers.GetVisitorLimiter(ip) limiter := middleware.GetVisitorLimiter(ip)
if !limiter.Allow() { if !limiter.Allow() {
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)

26
handlers/security/csrf.go Normal file
View File

@@ -0,0 +1,26 @@
package security
import (
"fmt"
"net/http"
"github.com/gorilla/csrf"
)
var CSRFMiddleware func(http.Handler) http.Handler
func InitCSRFProtection(csrfKey []byte, isProduction bool) error {
if len(csrfKey) != 32 {
return fmt.Errorf("csrf key must be 32 bytes, got %d", len(csrfKey))
}
CSRFMiddleware = csrf.Protect(
csrfKey,
csrf.Secure(isProduction),
csrf.SameSite(csrf.SameSiteStrictMode),
csrf.Path("/"),
csrf.HttpOnly(true),
)
return nil
}

View File

@@ -0,0 +1,94 @@
package security
import (
"bytes"
"encoding/gob"
"fmt"
"net/http"
"os"
"time"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
)
var (
sessionStore *sessions.CookieStore
sessionName string
authKey []byte
encryptKey []byte
)
func init() {
gob.Register(time.Time{})
}
func LoadSessionKeys(authPath, encryptionPath, name string, isProduction bool) error {
var err error
authKey, err = os.ReadFile(authPath)
if err != nil {
return fmt.Errorf("error loading auth key: %w", err)
}
encryptKey, err = os.ReadFile(encryptionPath)
if err != nil {
return fmt.Errorf("error loading encryption key: %w", err)
}
authKey = bytes.TrimSpace(authKey)
encryptKey = bytes.TrimSpace(encryptKey)
if len(authKey) != 32 || len(encryptKey) != 32 {
return fmt.Errorf("auth and encryption keys must be 32 bytes each")
}
sessionStore = sessions.NewCookieStore(authKey, encryptKey)
sessionStore.Options = &sessions.Options{
Path: "/",
MaxAge: 86400 * 1,
HttpOnly: true,
Secure: isProduction,
SameSite: http.SameSiteLaxMode,
}
sessionName = name
return nil
}
func GetSession(w http.ResponseWriter, r *http.Request) (*sessions.Session, error) {
return sessionStore.Get(r, sessionName)
}
func SecureCookie(w http.ResponseWriter, name, value string, isProduction bool) error {
s := securecookie.New(authKey, encryptKey)
encoded, err := s.Encode(name, value)
if err != nil {
return err
}
http.SetCookie(w, &http.Cookie{
Name: name,
Value: encoded,
Path: "/",
HttpOnly: true,
Secure: isProduction,
SameSite: http.SameSiteStrictMode,
})
return nil
}
func LoadKeyFromFile(path string) ([]byte, error) {
key, err := os.ReadFile(path)
if err != nil {
return nil, err
}
return bytes.TrimSpace(key), nil
}
func ZeroBytes(b []byte) {
for i := range b {
b[i] = 0
}
}

18
main.go
View File

@@ -3,17 +3,17 @@ package main
import ( import (
"log" "log"
"net/http" "net/http"
securityhandlers "synlotto-website/handlers/security"
"synlotto-website/bootstrap" "synlotto-website/bootstrap"
"synlotto-website/config" "synlotto-website/config"
"synlotto-website/handlers" "synlotto-website/handlers"
"synlotto-website/helpers"
"synlotto-website/logging" "synlotto-website/logging"
"synlotto-website/middleware" "synlotto-website/middleware"
"synlotto-website/models" "synlotto-website/models"
"synlotto-website/routes" "synlotto-website/routes"
"synlotto-website/storage" "synlotto-website/storage"
"github.com/gorilla/csrf"
) )
func main() { func main() {
@@ -27,11 +27,10 @@ func main() {
db := storage.InitDB("synlotto.db") db := storage.InitDB("synlotto.db")
models.SetDB(db) // Should be in storage not models. models.SetDB(db) // Should be in storage not models.
csrfMiddleware := csrf.Protect( err = securityhandlers.InitCSRFProtection([]byte(appState.Config.CSRF.CSRFKey), appState.Config.HttpServer.ProductionMode)
[]byte("abcdefghijklmnopqrstuvwx12345678"), // TodO: Make Global if err != nil {
csrf.Secure(true), logging.Error("Failed to init CSRF: %v", err)
csrf.Path("/"), }
)
mux := http.NewServeMux() mux := http.NewServeMux()
routes.SetupAdminRoutes(mux, db) routes.SetupAdminRoutes(mux, db)
@@ -42,7 +41,8 @@ func main() {
mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static")))) mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static"))))
mux.HandleFunc("/", handlers.Home(db)) mux.HandleFunc("/", handlers.Home(db))
wrapped := helpers.RateLimit(csrfMiddleware(mux)) wrapped := securityhandlers.CSRFMiddleware(mux)
wrapped = middleware.RateLimit(wrapped)
wrapped = middleware.EnforceHTTPS(wrapped, appState.Config.HttpServer.ProductionMode) wrapped = middleware.EnforceHTTPS(wrapped, appState.Config.HttpServer.ProductionMode)
wrapped = middleware.SecureHeaders(wrapped) wrapped = middleware.SecureHeaders(wrapped)
wrapped = middleware.Recover(wrapped) wrapped = middleware.Recover(wrapped)

View File

@@ -1,4 +1,4 @@
package helpers package middleware
import ( import (
"net" "net"

View File

@@ -4,10 +4,15 @@ import (
"database/sql" "database/sql"
"log" "log"
"synlotto-website/config"
"synlotto-website/logging"
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
) )
func InitDB(filepath string) *sql.DB { func InitDB(filepath string) *sql.DB {
var err error
cfg := config.Get()
db, err := sql.Open("sqlite", filepath) db, err := sql.Open("sqlite", filepath)
if err != nil { if err != nil {
log.Fatal("❌ Failed to open DB:", err) log.Fatal("❌ Failed to open DB:", err)
@@ -30,6 +35,10 @@ func InitDB(filepath string) *sql.DB {
SchemaSyndicateInvites, SchemaSyndicateInvites,
SchemaSyndicateInviteTokens, SchemaSyndicateInviteTokens,
} }
if cfg == nil {
logging.Error("❌ config is nil — did config.Init() run before InitDB?")
panic("config not ready")
}
for _, stmt := range schemas { for _, stmt := range schemas {
if _, err := db.Exec(stmt); err != nil { if _, err := db.Exec(stmt); err != nil {