Rewiring CSRF protection and movign some functionality to the bootstrapping stage.

This commit is contained in:
2025-04-16 09:50:58 +01:00
parent 4bb3b58ddb
commit 2440b3a668
7 changed files with 123 additions and 109 deletions

View File

@@ -1,4 +1,4 @@
package security
package bootstrap
import (
"fmt"

View File

@@ -1,14 +1,25 @@
package bootstrap
import (
"bytes"
"crypto/rand"
"encoding/base64"
"fmt"
"net/http"
"os"
securityhandlers "synlotto-website/handlers/security"
helpers "synlotto-website/helpers/session"
"synlotto-website/logging"
"synlotto-website/models"
"github.com/gorilla/sessions"
)
var (
sessionStore *sessions.CookieStore
sessionName string
authKey []byte
encryptKey []byte
)
func InitSession(cfg *models.Config) error {
@@ -41,7 +52,7 @@ func InitSession(cfg *models.Config) error {
}
}
return securityhandlers.LoadSessionKeys(
return loadSessionKeys(
authPath,
encPath,
cfg.Session.Name,
@@ -59,3 +70,41 @@ func generateRandomBytes(length int) ([]byte, error) {
}
return b, nil
}
func loadSessionKeys(authPath, encryptionPath, name string, isProduction bool) error {
var err error
rawAuth, err := os.ReadFile(authPath)
if err != nil {
return fmt.Errorf("error reading auth key: %w", err)
}
authKey, err = base64.StdEncoding.DecodeString(string(bytes.TrimSpace(rawAuth)))
if err != nil {
return fmt.Errorf("error decoding auth key: %w", err)
}
rawEnc, err := os.ReadFile(encryptionPath)
if err != nil {
return fmt.Errorf("error reading encryption key: %w", err)
}
encryptKey, err = base64.StdEncoding.DecodeString(string(bytes.TrimSpace(rawEnc)))
if err != nil {
return fmt.Errorf("error decoding encryption key: %w", err)
}
if len(authKey) != 32 || len(encryptKey) != 32 {
return fmt.Errorf("auth and encryption keys must be 32 bytes each (got auth=%d, enc=%d)", len(authKey), len(encryptKey))
}
sessionStore = sessions.NewCookieStore(authKey, encryptKey)
sessionStore.Options = &sessions.Options{
Path: "/",
MaxAge: 86400 * 1,
HttpOnly: true,
Secure: isProduction,
SameSite: http.SameSiteLaxMode,
}
sessionName = name
return nil
}

View File

@@ -1,101 +0,0 @@
package security
import (
"bytes"
"encoding/base64"
"encoding/gob"
"fmt"
"net/http"
"os"
"time"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
)
var (
sessionStore *sessions.CookieStore
sessionName string
authKey []byte
encryptKey []byte
)
func init() {
gob.Register(time.Time{})
}
func LoadSessionKeys(authPath, encryptionPath, name string, isProduction bool) error {
var err error
rawAuth, err := os.ReadFile(authPath)
if err != nil {
return fmt.Errorf("error reading auth key: %w", err)
}
authKey, err = base64.StdEncoding.DecodeString(string(bytes.TrimSpace(rawAuth)))
if err != nil {
return fmt.Errorf("error decoding auth key: %w", err)
}
rawEnc, err := os.ReadFile(encryptionPath)
if err != nil {
return fmt.Errorf("error reading encryption key: %w", err)
}
encryptKey, err = base64.StdEncoding.DecodeString(string(bytes.TrimSpace(rawEnc)))
if err != nil {
return fmt.Errorf("error decoding encryption key: %w", err)
}
if len(authKey) != 32 || len(encryptKey) != 32 {
return fmt.Errorf("auth and encryption keys must be 32 bytes each (got auth=%d, enc=%d)", len(authKey), len(encryptKey))
}
sessionStore = sessions.NewCookieStore(authKey, encryptKey)
sessionStore.Options = &sessions.Options{
Path: "/",
MaxAge: 86400 * 1,
HttpOnly: true,
Secure: isProduction,
SameSite: http.SameSiteLaxMode,
}
sessionName = name
return nil
}
func GetSession(w http.ResponseWriter, r *http.Request) (*sessions.Session, error) {
return sessionStore.Get(r, sessionName)
}
func SecureCookie(w http.ResponseWriter, name, value string, isProduction bool) error {
s := securecookie.New(authKey, encryptKey)
encoded, err := s.Encode(name, value)
if err != nil {
return err
}
http.SetCookie(w, &http.Cookie{
Name: name,
Value: encoded,
Path: "/",
HttpOnly: true,
Secure: isProduction,
SameSite: http.SameSiteStrictMode,
})
return nil
}
func LoadKeyFromFile(path string) ([]byte, error) {
key, err := os.ReadFile(path)
if err != nil {
return nil, err
}
return bytes.TrimSpace(key), nil
}
func ZeroBytes(b []byte) {
for i := range b {
b[i] = 0
}
}

View File

@@ -0,0 +1,16 @@
package handlers
import (
"net/http"
"github.com/gorilla/sessions"
)
var (
sessionStore *sessions.CookieStore
sessionName string
)
func GetSession(w http.ResponseWriter, r *http.Request) (*sessions.Session, error) {
return sessionStore.Get(r, sessionName)
}

32
handlers/session/auth.go Normal file
View File

@@ -0,0 +1,32 @@
package handlers
import (
"net/http"
"github.com/gorilla/securecookie"
)
var (
authKey []byte
encryptKey []byte
)
func SecureCookie(w http.ResponseWriter, name, value string, isProduction bool) error {
s := securecookie.New(authKey, encryptKey)
encoded, err := s.Encode(name, value)
if err != nil {
return err
}
http.SetCookie(w, &http.Cookie{
Name: name,
Value: encoded,
Path: "/",
HttpOnly: true,
Secure: isProduction,
SameSite: http.SameSiteStrictMode,
})
return nil
}

20
helpers/session/loader.go Normal file
View File

@@ -0,0 +1,20 @@
package helpers
import (
"bytes"
"os"
)
func LoadKeyFromFile(path string) ([]byte, error) {
key, err := os.ReadFile(path)
if err != nil {
return nil, err
}
return bytes.TrimSpace(key), nil
}
func ZeroBytes(b []byte) {
for i := range b {
b[i] = 0
}
}

View File

@@ -4,8 +4,6 @@ import (
"log"
"net/http"
securityhandlers "synlotto-website/handlers/security"
"synlotto-website/bootstrap"
"synlotto-website/config"
"synlotto-website/handlers"
@@ -32,7 +30,7 @@ func main() {
logging.Error("❌ Failed to init session: %v", err)
}
err = securityhandlers.InitCSRFProtection([]byte(appState.Config.CSRF.CSRFKey), appState.Config.HttpServer.ProductionMode)
err = bootstrap.InitCSRFProtection([]byte(appState.Config.CSRF.CSRFKey), appState.Config.HttpServer.ProductionMode)
if err != nil {
logging.Error("Failed to init CSRF: %v", err)
}
@@ -46,7 +44,7 @@ func main() {
mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static"))))
mux.HandleFunc("/", handlers.Home(db))
wrapped := securityhandlers.CSRFMiddleware(mux)
wrapped := bootstrap.CSRFMiddleware(mux)
wrapped = middleware.RateLimit(wrapped)
wrapped = middleware.EnforceHTTPS(wrapped, appState.Config.HttpServer.ProductionMode)
wrapped = middleware.SecureHeaders(wrapped)