Rewiring CSRF protection and movign some functionality to the bootstrapping stage.
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
package security
|
package bootstrap
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -1,14 +1,25 @@
|
|||||||
package bootstrap
|
package bootstrap
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
securityhandlers "synlotto-website/handlers/security"
|
|
||||||
|
|
||||||
helpers "synlotto-website/helpers/session"
|
helpers "synlotto-website/helpers/session"
|
||||||
"synlotto-website/logging"
|
"synlotto-website/logging"
|
||||||
"synlotto-website/models"
|
"synlotto-website/models"
|
||||||
|
|
||||||
|
"github.com/gorilla/sessions"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
sessionStore *sessions.CookieStore
|
||||||
|
sessionName string
|
||||||
|
authKey []byte
|
||||||
|
encryptKey []byte
|
||||||
)
|
)
|
||||||
|
|
||||||
func InitSession(cfg *models.Config) error {
|
func InitSession(cfg *models.Config) error {
|
||||||
@@ -41,7 +52,7 @@ func InitSession(cfg *models.Config) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return securityhandlers.LoadSessionKeys(
|
return loadSessionKeys(
|
||||||
authPath,
|
authPath,
|
||||||
encPath,
|
encPath,
|
||||||
cfg.Session.Name,
|
cfg.Session.Name,
|
||||||
@@ -59,3 +70,41 @@ func generateRandomBytes(length int) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func loadSessionKeys(authPath, encryptionPath, name string, isProduction bool) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
rawAuth, err := os.ReadFile(authPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error reading auth key: %w", err)
|
||||||
|
}
|
||||||
|
authKey, err = base64.StdEncoding.DecodeString(string(bytes.TrimSpace(rawAuth)))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error decoding auth key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rawEnc, err := os.ReadFile(encryptionPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error reading encryption key: %w", err)
|
||||||
|
}
|
||||||
|
encryptKey, err = base64.StdEncoding.DecodeString(string(bytes.TrimSpace(rawEnc)))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error decoding encryption key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(authKey) != 32 || len(encryptKey) != 32 {
|
||||||
|
return fmt.Errorf("auth and encryption keys must be 32 bytes each (got auth=%d, enc=%d)", len(authKey), len(encryptKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionStore = sessions.NewCookieStore(authKey, encryptKey)
|
||||||
|
sessionStore.Options = &sessions.Options{
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: 86400 * 1,
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: isProduction,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionName = name
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,101 +0,0 @@
|
|||||||
package security
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/gob"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gorilla/securecookie"
|
|
||||||
"github.com/gorilla/sessions"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
sessionStore *sessions.CookieStore
|
|
||||||
sessionName string
|
|
||||||
authKey []byte
|
|
||||||
encryptKey []byte
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
gob.Register(time.Time{})
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadSessionKeys(authPath, encryptionPath, name string, isProduction bool) error {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
rawAuth, err := os.ReadFile(authPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error reading auth key: %w", err)
|
|
||||||
}
|
|
||||||
authKey, err = base64.StdEncoding.DecodeString(string(bytes.TrimSpace(rawAuth)))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error decoding auth key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rawEnc, err := os.ReadFile(encryptionPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error reading encryption key: %w", err)
|
|
||||||
}
|
|
||||||
encryptKey, err = base64.StdEncoding.DecodeString(string(bytes.TrimSpace(rawEnc)))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error decoding encryption key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(authKey) != 32 || len(encryptKey) != 32 {
|
|
||||||
return fmt.Errorf("auth and encryption keys must be 32 bytes each (got auth=%d, enc=%d)", len(authKey), len(encryptKey))
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionStore = sessions.NewCookieStore(authKey, encryptKey)
|
|
||||||
sessionStore.Options = &sessions.Options{
|
|
||||||
Path: "/",
|
|
||||||
MaxAge: 86400 * 1,
|
|
||||||
HttpOnly: true,
|
|
||||||
Secure: isProduction,
|
|
||||||
SameSite: http.SameSiteLaxMode,
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionName = name
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetSession(w http.ResponseWriter, r *http.Request) (*sessions.Session, error) {
|
|
||||||
return sessionStore.Get(r, sessionName)
|
|
||||||
}
|
|
||||||
|
|
||||||
func SecureCookie(w http.ResponseWriter, name, value string, isProduction bool) error {
|
|
||||||
s := securecookie.New(authKey, encryptKey)
|
|
||||||
|
|
||||||
encoded, err := s.Encode(name, value)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: name,
|
|
||||||
Value: encoded,
|
|
||||||
Path: "/",
|
|
||||||
HttpOnly: true,
|
|
||||||
Secure: isProduction,
|
|
||||||
SameSite: http.SameSiteStrictMode,
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadKeyFromFile(path string) ([]byte, error) {
|
|
||||||
key, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return bytes.TrimSpace(key), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ZeroBytes(b []byte) {
|
|
||||||
for i := range b {
|
|
||||||
b[i] = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
16
handlers/session/account.go
Normal file
16
handlers/session/account.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/sessions"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
sessionStore *sessions.CookieStore
|
||||||
|
sessionName string
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetSession(w http.ResponseWriter, r *http.Request) (*sessions.Session, error) {
|
||||||
|
return sessionStore.Get(r, sessionName)
|
||||||
|
}
|
||||||
32
handlers/session/auth.go
Normal file
32
handlers/session/auth.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/securecookie"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
authKey []byte
|
||||||
|
encryptKey []byte
|
||||||
|
)
|
||||||
|
|
||||||
|
func SecureCookie(w http.ResponseWriter, name, value string, isProduction bool) error {
|
||||||
|
s := securecookie.New(authKey, encryptKey)
|
||||||
|
|
||||||
|
encoded, err := s.Encode(name, value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: encoded,
|
||||||
|
Path: "/",
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: isProduction,
|
||||||
|
SameSite: http.SameSiteStrictMode,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
20
helpers/session/loader.go
Normal file
20
helpers/session/loader.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package helpers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
func LoadKeyFromFile(path string) ([]byte, error) {
|
||||||
|
key, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return bytes.TrimSpace(key), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ZeroBytes(b []byte) {
|
||||||
|
for i := range b {
|
||||||
|
b[i] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
6
main.go
6
main.go
@@ -4,8 +4,6 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
securityhandlers "synlotto-website/handlers/security"
|
|
||||||
|
|
||||||
"synlotto-website/bootstrap"
|
"synlotto-website/bootstrap"
|
||||||
"synlotto-website/config"
|
"synlotto-website/config"
|
||||||
"synlotto-website/handlers"
|
"synlotto-website/handlers"
|
||||||
@@ -32,7 +30,7 @@ func main() {
|
|||||||
logging.Error("❌ Failed to init session: %v", err)
|
logging.Error("❌ Failed to init session: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = securityhandlers.InitCSRFProtection([]byte(appState.Config.CSRF.CSRFKey), appState.Config.HttpServer.ProductionMode)
|
err = bootstrap.InitCSRFProtection([]byte(appState.Config.CSRF.CSRFKey), appState.Config.HttpServer.ProductionMode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Error("Failed to init CSRF: %v", err)
|
logging.Error("Failed to init CSRF: %v", err)
|
||||||
}
|
}
|
||||||
@@ -46,7 +44,7 @@ func main() {
|
|||||||
mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static"))))
|
mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static"))))
|
||||||
mux.HandleFunc("/", handlers.Home(db))
|
mux.HandleFunc("/", handlers.Home(db))
|
||||||
|
|
||||||
wrapped := securityhandlers.CSRFMiddleware(mux)
|
wrapped := bootstrap.CSRFMiddleware(mux)
|
||||||
wrapped = middleware.RateLimit(wrapped)
|
wrapped = middleware.RateLimit(wrapped)
|
||||||
wrapped = middleware.EnforceHTTPS(wrapped, appState.Config.HttpServer.ProductionMode)
|
wrapped = middleware.EnforceHTTPS(wrapped, appState.Config.HttpServer.ProductionMode)
|
||||||
wrapped = middleware.SecureHeaders(wrapped)
|
wrapped = middleware.SecureHeaders(wrapped)
|
||||||
|
|||||||
Reference in New Issue
Block a user