Stack of changes to get gin, scs, nosurf running.
This commit is contained in:
68
internal/helpers/database/statements.go
Normal file
68
internal/helpers/database/statements.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package databaseHelpers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"database/sql"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ExecScript executes a multi-statement SQL script.
|
||||
// It only requires that statements end with ';' and ignores '--' comments.
|
||||
// (Good for simple DDL/DML. If you add routines/triggers, upgrade later.)
|
||||
func ExecScript(tx *sql.Tx, script string) error {
|
||||
sc := bufio.NewScanner(strings.NewReader(script))
|
||||
sc.Split(splitStatements)
|
||||
|
||||
for sc.Scan() {
|
||||
stmt := strings.TrimSpace(sc.Text())
|
||||
if stmt == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := tx.Exec(stmt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return sc.Err()
|
||||
}
|
||||
|
||||
// splitStatements separates statements at ';'
|
||||
// and strips whitespace and '--' comments.
|
||||
func splitStatements(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
// skip whitespace and comments
|
||||
start := 0
|
||||
for {
|
||||
// whitespace
|
||||
for start < len(data) {
|
||||
switch data[start] {
|
||||
case ' ', '\t', '\n', '\r':
|
||||
start++
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
// '-- comment'
|
||||
if start+1 < len(data) && data[start] == '-' && data[start+1] == '-' {
|
||||
i := start + 2
|
||||
for i < len(data) && data[i] != '\n' {
|
||||
i++
|
||||
}
|
||||
if i >= len(data) {
|
||||
return len(data), nil, nil
|
||||
}
|
||||
start = i + 1
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// detect semicolon termination
|
||||
for i := start; i < len(data); i++ {
|
||||
if data[i] == ';' {
|
||||
return i + 1, data[start:i], nil
|
||||
}
|
||||
}
|
||||
if atEOF && start < len(data) {
|
||||
return len(data), data[start:], nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
}
|
||||
19
internal/helpers/http/request.go
Normal file
19
internal/helpers/http/request.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package httpHelpers
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func ClientIP(r *http.Request) string {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
parts := strings.Split(xff, ",")
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
@@ -3,15 +3,12 @@ package security
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
httpHelpers "synlotto-website/internal/helpers/http"
|
||||
"synlotto-website/internal/platform/sessionkeys"
|
||||
|
||||
"github.com/alexedwards/scs/v2"
|
||||
)
|
||||
|
||||
func GetCurrentUserID(r *http.Request) (int, bool) {
|
||||
session, err := httpHelpers.GetSession(nil, r)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
id, ok := session.Values["user_id"].(int)
|
||||
return id, ok
|
||||
func GetCurrentUserID(sm *scs.SessionManager, r *http.Request) (int, bool) {
|
||||
userID := sm.GetInt(r.Context(), sessionkeys.UserID)
|
||||
return userID, userID != 0
|
||||
}
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
)
|
||||
|
||||
func LoadKeyFromFile(path string) ([]byte, error) {
|
||||
key, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bytes.TrimSpace(key), nil
|
||||
}
|
||||
|
||||
func ZeroBytes(b []byte) {
|
||||
for i := range b {
|
||||
b[i] = 0
|
||||
}
|
||||
}
|
||||
69
internal/helpers/session/remember.go
Normal file
69
internal/helpers/session/remember.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"time"
|
||||
)
|
||||
|
||||
func randomBase64(n int) (string, error) {
|
||||
b := make([]byte, n)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func HashVerifier(verifier string) string {
|
||||
sum := sha256.Sum256([]byte(verifier))
|
||||
return base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// StoreToken inserts a new token row
|
||||
func StoreToken(db *sql.DB, userID int64, selector, verifierHash string, expiresAt time.Time) error {
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO remember_tokens (user_id, selector, verifier_hash, issued_at, expires_at)
|
||||
VALUES ($1,$2,$3,NOW(),$4)`, userID, selector, verifierHash, expiresAt)
|
||||
return err
|
||||
}
|
||||
|
||||
// FindToken fetches selector+hash
|
||||
func FindToken(db *sql.DB, selector string) (userID int64, verifierHash string, expiresAt time.Time, revokedAt *time.Time, err error) {
|
||||
err = db.QueryRow(`SELECT user_id, verifier_hash, expires_at, revoked_at FROM remember_tokens WHERE selector=$1`, selector).
|
||||
Scan(&userID, &verifierHash, &expiresAt, &revokedAt)
|
||||
return
|
||||
}
|
||||
|
||||
// RevokeToken marks token as revoked
|
||||
func RevokeToken(db *sql.DB, selector string) error {
|
||||
_, err := db.Exec(`UPDATE remember_tokens SET revoked_at=NOW() WHERE selector=$1`, selector)
|
||||
return err
|
||||
}
|
||||
|
||||
// GenerateAndStore creates a new remember-me token, stores it server-side,
|
||||
// and returns the cookie-safe plaintext value to set on the client
|
||||
func GenerateAndStore(db *sql.DB, userID int64, duration time.Duration) (string, time.Time, error) {
|
||||
selector, err := randomBase64(16)
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
|
||||
verifier, err := randomBase64(32)
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
|
||||
hash := HashVerifier(verifier)
|
||||
expires := time.Now().Add(duration)
|
||||
|
||||
if err := StoreToken(db, userID, selector, hash, expires); err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
|
||||
// The client cookie value contains selector + verifier
|
||||
cookieVal := selector + ":" + verifier
|
||||
|
||||
return cookieVal, expires, nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package helpers
|
||||
package templateHelper
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
httpHelpers "synlotto-website/internal/helpers/http"
|
||||
"synlotto-website/internal/models"
|
||||
|
||||
"github.com/alexedwards/scs/v2"
|
||||
"github.com/justinas/nosurf"
|
||||
)
|
||||
|
||||
@@ -27,19 +27,15 @@ func InitSiteMeta(name string, yearStart, yearEnd int) {
|
||||
}
|
||||
}
|
||||
|
||||
var sm *scs.SessionManager
|
||||
|
||||
func InitSessionManager(manager *scs.SessionManager) {
|
||||
sm = manager
|
||||
}
|
||||
|
||||
func TemplateContext(w http.ResponseWriter, r *http.Request, data models.TemplateData) map[string]interface{} {
|
||||
session, _ := httpHelpers.GetSession(w, r)
|
||||
|
||||
var flash string
|
||||
if f, ok := session.Values["flash"].(string); ok {
|
||||
flash = f
|
||||
delete(session.Values, "flash")
|
||||
session.Save(r, w)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"CSRFToken": nosurf.Token(r),
|
||||
"Flash": flash,
|
||||
"User": data.User,
|
||||
"IsAdmin": data.IsAdmin,
|
||||
"NotificationCount": data.NotificationCount,
|
||||
@@ -105,18 +101,18 @@ func TemplateFuncs() template.FuncMap {
|
||||
|
||||
func LoadTemplateFiles(name string, files ...string) *template.Template {
|
||||
shared := []string{
|
||||
"templates/main/layout.html",
|
||||
"templates/main/topbar.html",
|
||||
"templates/main/footer.html",
|
||||
"web/templates/main/layout.html",
|
||||
"web/templates/main/topbar.html",
|
||||
"web/templates/main/footer.html",
|
||||
}
|
||||
all := append(shared, files...)
|
||||
return template.Must(template.New(name).Funcs(TemplateFuncs()).ParseFiles(all...))
|
||||
}
|
||||
|
||||
func SetFlash(w http.ResponseWriter, r *http.Request, message string) {
|
||||
session, _ := httpHelpers.GetSession(w, r)
|
||||
session.Values["flash"] = message
|
||||
session.Save(r, w)
|
||||
func SetFlash(r *http.Request, message string) {
|
||||
if sm != nil {
|
||||
sm.Put(r.Context(), "flash", message)
|
||||
}
|
||||
}
|
||||
|
||||
func InSlice(n int, list []int) bool {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package helpers
|
||||
package templateHelper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package helpers
|
||||
package templateHelper
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
Reference in New Issue
Block a user