Stack of changes to get gin, scs, nosurf running.

This commit is contained in:
2025-10-28 11:56:42 +00:00
parent 07117ba35e
commit 86be6479f1
65 changed files with 1890 additions and 1503 deletions

View File

@@ -1,57 +1,84 @@
package storage
import (
"context"
"database/sql"
"log"
"net/http"
"time"
securityHelpers "synlotto-website/internal/helpers/security"
templateHelpers "synlotto-website/internal/helpers/template"
"synlotto-website/internal/logging"
"synlotto-website/internal/http/middleware"
"synlotto-website/internal/logging"
"synlotto-website/internal/platform/bootstrap"
"github.com/gin-gonic/gin"
)
func AdminOnly(db *sql.DB, next http.HandlerFunc) http.HandlerFunc {
return middleware.Auth(true)(func(w http.ResponseWriter, r *http.Request) {
userID, ok := securityHelpers.GetCurrentUserID(r)
if !ok || !securityHelpers.IsAdmin(db, userID) {
log.Printf("⛔️ Unauthorized admin attempt: user_id=%v, IP=%s, Path=%s", userID, r.RemoteAddr, r.URL.Path)
templateHelpers.RenderError(w, r, http.StatusForbidden)
const insertRegistrationSQL = `
INSERT INTO audit_registration
(user_id, username, email, ip, user_agent, timestamp)
VALUES (?, ?, ?, ?, ?, ?)
`
func AdminOnly() gin.HandlerFunc {
return func(c *gin.Context) {
app := c.MustGet("app").(*bootstrap.App)
sm := app.SessionManager
ctx := c.Request.Context()
// Require logged in (assumes RequireAuth already ran; this is a safety net)
v := sm.Get(ctx, "user_id")
var uid int64
switch t := v.(type) {
case int64:
uid = t
case int:
uid = int64(t)
default:
c.Redirect(http.StatusSeeOther, "/account/login")
c.Abort()
return
}
ip := r.RemoteAddr
ua := r.UserAgent()
path := r.URL.Path
_, err := db.Exec(`
INSERT INTO admin_access_log (user_id, path, ip, user_agent)
VALUES (?, ?, ?, ?)`,
userID, path, ip, ua,
)
if err != nil {
log.Printf("⚠️ Failed to log admin access: %v", err)
// Check admin
if !securityHelpers.IsAdmin(app.DB, int(uid)) {
// Optional: log access attempt here or in a helper
c.String(http.StatusForbidden, "Forbidden")
c.Abort()
return
}
log.Printf("🛡️ Admin access: user_id=%d IP=%s Path=%s", userID, ip, path)
// Optionally record access (moved here from storage)
_, _ = app.DB.Exec(`
INSERT INTO admin_access_log (user_id, path, ip, user_agent, accessed_at)
VALUES ($1, $2, $3, $4, $5)
`, uid, c.Request.URL.Path, c.ClientIP(), c.Request.UserAgent(), time.Now().UTC())
next(w, r)
})
c.Next()
}
}
// Todo has to add in - db *sql.DB to make this work should this not be an import as all functions use it, more importantly no functions in storage just sql?
func LogLoginAttempt(db *sql.DB, r *http.Request, username string, success bool) {
ip := r.RemoteAddr
userAgent := r.UserAgent()
// Handler Call - auditlogStorage.LogLoginAttempt(db, r.RemoteAddr, r.UserAgent(), username, ok)
func LogLoginAttempt(db *sql.DB, rIP, rUA, username string, success bool) {
_, err := db.Exec(
`INSERT INTO audit_login (username, success, ip, user_agent, timestamp)
VALUES (?, ?, ?, ?, ?)`,
username, success, ip, userAgent, time.Now().UTC(),
VALUES ($1, $2, $3, $4, $5)`,
username, success, rIP, rUA, time.Now().UTC(),
)
if err != nil {
logging.Info("❌ Failed to log login:", err)
}
}
func LogSignup(db *sql.DB, userID int64, username, email, ip, userAgent string) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_, err := db.ExecContext(ctx, insertRegistrationSQL,
userID, username, email, ip, userAgent, time.Now().UTC(),
)
if err != nil {
logging.Info("❌ Failed to log registration: %v", err)
}
}

View File

@@ -1,78 +0,0 @@
package storage
import (
"database/sql"
"embed"
"fmt"
"log"
"os"
_ "github.com/go-sql-driver/mysql"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/mysql"
iofs "github.com/golang-migrate/migrate/v4/source/iofs"
)
//go:embed migrations/*.sql
var migrationFiles embed.FS
var DB *sql.DB
// InitDB connects to MySQL, runs migrations, and returns the DB handle.
func InitDB() *sql.DB {
cfg := getDSNFromEnv()
db, err := sql.Open("mysql", cfg)
if err != nil {
log.Fatalf("❌ Failed to connect to MySQL: %v", err)
}
if err := db.Ping(); err != nil {
log.Fatalf("❌ MySQL not reachable: %v", err)
}
if err := runMigrations(db); err != nil {
log.Fatalf("❌ Migration failed: %v", err)
}
DB = db
return db
}
// runMigrations applies any pending .sql files in migrations/
func runMigrations(db *sql.DB) error {
driver, err := mysql.WithInstance(db, &mysql.Config{})
if err != nil {
return err
}
src, err := iofs.New(migrationFiles, "migrations")
if err != nil {
return err
}
m, err := migrate.NewWithInstance("iofs", src, "mysql", driver)
if err != nil {
return err
}
err = m.Up()
if err == migrate.ErrNoChange {
log.Println("✅ Database schema up to date.")
return nil
}
return err
}
func getDSNFromEnv() string {
user := os.Getenv("DB_USER")
pass := os.Getenv("DB_PASS")
host := os.Getenv("DB_HOST") // e.g. localhost or 127.0.0.1
port := os.Getenv("DB_PORT") // e.g. 3306
name := os.Getenv("DB_NAME") // e.g. synlotto
params := "parseTime=true&multiStatements=true"
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?%s",
user, pass, host, port, name, params)
return dsn
}

View File

@@ -4,6 +4,20 @@
-- - utf8mb4 for full Unicode
-- Booleans are TINYINT(1). Dates use DATE/DATETIME/TIMESTAMP as appropriate.
CREATE TABLE audit_registration (
id BIGINT UNSIGNED AUTO_INCREMENT PRIMARY KEY,
user_id BIGINT UNSIGNED NOT NULL,
username VARCHAR(255) NOT NULL,
email VARCHAR(255) NOT NULL,
ip VARCHAR(45) NOT NULL,
user_agent VARCHAR(500),
timestamp DATETIME NOT NULL,
INDEX (user_id),
CONSTRAINT fk_audit_registration_users
FOREIGN KEY (user_id) REFERENCES users(id)
ON DELETE CASCADE
);
-- USERS
CREATE TABLE IF NOT EXISTS users (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,

View File

@@ -0,0 +1,6 @@
package migrations
import _ "embed"
//go:embed 0001_initial_create.up.sql
var InitialSchema string

View File

@@ -0,0 +1,6 @@
package migrations
const ProbeUsersTable = `
SELECT COUNT(*)
FROM information_schema.tables
WHERE table_schema = DATABASE() AND table_name = 'users'`

View File

@@ -1,238 +0,0 @@
package storage
const SchemaUsers = `
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
is_admin BOOLEAN
);`
const SchemaThunderballResults = `
CREATE TABLE IF NOT EXISTS results_thunderball (
id INTEGER PRIMARY KEY AUTOINCREMENT,
draw_date TEXT NOT NULL UNIQUE,
draw_id INTEGER NOT NULL UNIQUE,
machine TEXT,
ballset TEXT,
ball1 INTEGER,
ball2 INTEGER,
ball3 INTEGER,
ball4 INTEGER,
ball5 INTEGER,
thunderball INTEGER
);`
const SchemaThunderballPrizes = `
CREATE TABLE IF NOT EXISTS prizes_thunderball (
id INTEGER PRIMARY KEY AUTOINCREMENT,
draw_id INTEGER NOT NULL,
draw_date TEXT,
prize1 TEXT,
prize1_winners INTEGER,
prize1_per_winner INTEGER,
prize1_fund INTEGER,
prize2 TEXT,
prize2_winners INTEGER,
prize2_per_winner INTEGER,
prize2_fund INTEGER,
prize3 TEXT,
prize3_winners INTEGER,
prize3_per_winner INTEGER,
prize3_fund INTEGER,
prize4 TEXT,
prize4_winners INTEGER,
prize4_per_winner INTEGER,
prize4_fund INTEGER,
prize5 TEXT,
prize5_winners INTEGER,
prize5_per_winner INTEGER,
prize5_fund INTEGER,
prize6 TEXT,
prize6_winners INTEGER,
prize6_per_winner INTEGER,
prize6_fund INTEGER,
prize7 TEXT,
prize7_winners INTEGER,
prize7_per_winner INTEGER,
prize7_fund INTEGER,
prize8 TEXT,
prize8_winners INTEGER,
prize8_per_winner INTEGER,
prize8_fund INTEGER,
prize9 TEXT,
prize9_winners INTEGER,
prize9_per_winner INTEGER,
prize9_fund INTEGER,
total_winners INTEGER,
total_prize_fund INTEGER,
FOREIGN KEY (draw_date) REFERENCES results_thunderball(draw_date)
);`
const SchemaLottoResults = `
CREATE TABLE IF NOT EXISTS results_lotto (
id INTEGER PRIMARY KEY AUTOINCREMENT,
draw_date TEXT NOT NULL UNIQUE,
draw_id INTEGER NOT NULL UNIQUE,
machine TEXT,
ballset TEXT,
ball1 INTEGER,
ball2 INTEGER,
ball3 INTEGER,
ball4 INTEGER,
ball5 INTEGER,
ball6 INTEGER,
bonusball INTEGER
);`
const SchemaMyTickets = `
CREATE TABLE IF NOT EXISTS my_tickets (
id INTEGER PRIMARY KEY AUTOINCREMENT,
userId INTEGER NOT NULL,
game_type TEXT NOT NULL,
draw_date TEXT NOT NULL,
ball1 INTEGER,
ball2 INTEGER,
ball3 INTEGER,
ball4 INTEGER,
ball5 INTEGER,
ball6 INTEGER,
bonus1 INTEGER,
bonus2 INTEGER,
duplicate BOOLEAN DEFAULT 0,
purchase_date TEXT,
purchase_method TEXT,
image_path TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
matched_main INTEGER,
matched_bonus INTEGER,
prize_tier TEXT,
is_winner BOOLEAN,
prize_amount INTEGER,
prize_label TEXT,
syndicate_id INTEGER,
FOREIGN KEY (userId) REFERENCES users(id)
);`
const SchemaUsersMessages = `
CREATE TABLE IF NOT EXISTS users_messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
senderId INTEGER NOT NULL REFERENCES users(id),
recipientId INTEGER NOT NULL REFERENCES users(id),
subject TEXT NOT NULL,
message TEXT,
is_read BOOLEAN DEFAULT FALSE,
is_archived BOOLEAN DEFAULT FALSE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
archived_at TIMESTAMP
);`
const SchemaUsersNotifications = `
CREATE TABLE IF NOT EXISTS users_notification (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id),
subject TEXT,
body TEXT,
is_read BOOLEAN DEFAULT FALSE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);`
const SchemaAuditLog = `
CREATE TABLE IF NOT EXISTS auditlog (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT,
success INTEGER,
timestamp TEXT
);`
const SchemaLogTicketMatching = `
CREATE TABLE IF NOT EXISTS log_ticket_matching (
id INTEGER PRIMARY KEY AUTOINCREMENT,
triggered_by TEXT,
run_at DATETIME DEFAULT CURRENT_TIMESTAMP,
tickets_matched INTEGER,
winners_found INTEGER,
notes TEXT
);`
const SchemaAdminAccessLog = `
CREATE TABLE IF NOT EXISTS admin_access_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER,
accessed_at DATETIME DEFAULT CURRENT_TIMESTAMP,
path TEXT,
ip TEXT,
user_agent TEXT
);`
const SchemaNewAuditLog = `
CREATE TABLE IF NOT EXISTS audit_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER,
username TEXT,
action TEXT,
path TEXT,
ip TEXT,
user_agent TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
);`
const SchemaAuditLogin = `
CREATE TABLE IF NOT EXISTS audit_login (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT,
success BOOLEAN,
ip TEXT,
user_agent TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
);`
const SchemaSyndicates = `
CREATE TABLE IF NOT EXISTS syndicates (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
description TEXT,
owner_id INTEGER NOT NULL,
join_code TEXT UNIQUE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (owner_id) REFERENCES users(id)
);`
const SchemaSyndicateMembers = `
CREATE TABLE IF NOT EXISTS syndicate_members (
id INTEGER PRIMARY KEY AUTOINCREMENT,
syndicate_id INTEGER NOT NULL,
user_id INTEGER NOT NULL,
role TEXT DEFAULT 'member', -- owner, manager, member
status TEXT DEFAULT 'active',
joined_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (syndicate_id) REFERENCES syndicates(id),
FOREIGN KEY (user_id) REFERENCES users(id)
);`
const SchemaSyndicateInvites = `
CREATE TABLE IF NOT EXISTS syndicate_invites (
id INTEGER PRIMARY KEY AUTOINCREMENT,
syndicate_id INTEGER NOT NULL,
invited_user_id INTEGER NOT NULL,
sent_by_user_id INTEGER NOT NULL,
status TEXT DEFAULT 'pending',
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(syndicate_id) REFERENCES syndicates(id),
FOREIGN KEY(invited_user_id) REFERENCES users(id)
);`
const SchemaSyndicateInviteTokens = `
CREATE TABLE IF NOT EXISTS syndicate_invite_tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
syndicate_id INTEGER NOT NULL,
token TEXT NOT NULL UNIQUE,
invited_by_user_id INTEGER NOT NULL,
accepted_by_user_id INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
accepted_at TIMESTAMP,
expires_at TIMESTAMP,
FOREIGN KEY (syndicate_id) REFERENCES syndicates(id),
FOREIGN KEY (invited_by_user_id) REFERENCES users(id),
FOREIGN KEY (accepted_by_user_id) REFERENCES users(id)
);`

View File

@@ -130,6 +130,35 @@ func GetSyndicateMembers(db *sql.DB, syndicateID int) []models.SyndicateMember {
return members
}
func GetSyndicateTickets(db *sql.DB, syndicateID int) []models.Ticket {
rows, err := db.Query(`
SELECT id, userId, syndicateId, game_type, draw_date, ball1, ball2, ball3, ball4, ball5, ball6,
bonus1, bonus2, matched_main, matched_bonus, prize_tier, prize_amount, prize_label, is_winner
FROM my_tickets
WHERE syndicateId = ?
ORDER BY draw_date DESC
`, syndicateID)
if err != nil {
return nil
}
defer rows.Close()
var tickets []models.Ticket
for rows.Next() {
var t models.Ticket
err := rows.Scan(
&t.Id, &t.UserId, &t.SyndicateId, &t.GameType, &t.DrawDate,
&t.Ball1, &t.Ball2, &t.Ball3, &t.Ball4, &t.Ball5, &t.Ball6,
&t.Bonus1, &t.Bonus2, &t.MatchedMain, &t.MatchedBonus,
&t.PrizeTier, &t.PrizeAmount, &t.PrizeLabel, &t.IsWinner,
)
if err == nil {
tickets = append(tickets, t)
}
}
return tickets
}
func IsSyndicateManager(db *sql.DB, syndicateID, userID int) bool {
var count int
err := db.QueryRow(`

View File

@@ -1,125 +0,0 @@
package storage
import (
"database/sql"
"fmt"
"synlotto-website/internal/models"
"time"
)
// todo should be a ticket function?
func GetSyndicateTickets(db *sql.DB, syndicateID int) []models.Ticket {
rows, err := db.Query(`
SELECT id, userId, syndicateId, game_type, draw_date, ball1, ball2, ball3, ball4, ball5, ball6,
bonus1, bonus2, matched_main, matched_bonus, prize_tier, prize_amount, prize_label, is_winner
FROM my_tickets
WHERE syndicateId = ?
ORDER BY draw_date DESC
`, syndicateID)
if err != nil {
return nil
}
defer rows.Close()
var tickets []models.Ticket
for rows.Next() {
var t models.Ticket
err := rows.Scan(
&t.Id, &t.UserId, &t.SyndicateId, &t.GameType, &t.DrawDate,
&t.Ball1, &t.Ball2, &t.Ball3, &t.Ball4, &t.Ball5, &t.Ball6,
&t.Bonus1, &t.Bonus2, &t.MatchedMain, &t.MatchedBonus,
&t.PrizeTier, &t.PrizeAmount, &t.PrizeLabel, &t.IsWinner,
)
if err == nil {
tickets = append(tickets, t)
}
}
return tickets
}
// both a read and inset break up
func AcceptInvite(db *sql.DB, inviteID, userID int) error {
var syndicateID int
err := db.QueryRow(`
SELECT syndicate_id FROM syndicate_invites
WHERE id = ? AND invited_user_id = ? AND status = 'pending'
`, inviteID, userID).Scan(&syndicateID)
if err != nil {
return err
}
if err := UpdateInviteStatus(db, inviteID, "accepted"); err != nil {
return err
}
_, err = db.Exec(`
INSERT INTO syndicate_members (syndicate_id, user_id, joined_at)
VALUES (?, ?, CURRENT_TIMESTAMP)
`, syndicateID, userID)
return err
}
func CreateSyndicate(db *sql.DB, ownerID int, name, description string) (int64, error) {
tx, err := db.Begin()
if err != nil {
return 0, err
}
defer tx.Rollback()
result, err := tx.Exec(`
INSERT INTO syndicates (name, description, owner_id, created_at)
VALUES (?, ?, ?, ?)
`, name, description, ownerID, time.Now())
if err != nil {
return 0, fmt.Errorf("failed to create syndicate: %w", err)
}
syndicateID, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("failed to get syndicate ID: %w", err)
}
_, err = tx.Exec(`
INSERT INTO syndicate_members (syndicate_id, user_id, role, joined_at)
VALUES (?, ?, 'manager', CURRENT_TIMESTAMP)
`, syndicateID, ownerID)
if err != nil {
return 0, fmt.Errorf("failed to add owner as member: %w", err)
}
if err := tx.Commit(); err != nil {
return 0, fmt.Errorf("commit failed: %w", err)
}
return syndicateID, nil
}
func InviteToSyndicate(db *sql.DB, inviterID, syndicateID int, username string) error {
var inviteeID int
err := db.QueryRow(`
SELECT id FROM users WHERE username = ?
`, username).Scan(&inviteeID)
if err == sql.ErrNoRows {
return fmt.Errorf("user not found")
} else if err != nil {
return err
}
var count int
err = db.QueryRow(`
SELECT COUNT(*) FROM syndicate_members
WHERE syndicate_id = ? AND user_id = ?
`, syndicateID, inviteeID).Scan(&count)
if err != nil {
return err
}
if count > 0 {
return fmt.Errorf("user already a member or invited")
}
_, err = db.Exec(`
INSERT INTO syndicate_members (syndicate_id, user_id, is_manager, status)
VALUES (?, ?, 0, 'invited')
`, syndicateID, inviteeID)
return err
}

View File

@@ -2,6 +2,8 @@ package storage
import (
"database/sql"
"fmt"
"time"
)
func UpdateInviteStatus(db *sql.DB, inviteID int, status string) error {
@@ -12,3 +14,90 @@ func UpdateInviteStatus(db *sql.DB, inviteID int, status string) error {
`, status, inviteID)
return err
}
// ToDo: both a read and inset break up
func AcceptInvite(db *sql.DB, inviteID, userID int) error {
var syndicateID int
err := db.QueryRow(`
SELECT syndicate_id FROM syndicate_invites
WHERE id = ? AND invited_user_id = ? AND status = 'pending'
`, inviteID, userID).Scan(&syndicateID)
if err != nil {
return err
}
if err := UpdateInviteStatus(db, inviteID, "accepted"); err != nil {
return err
}
_, err = db.Exec(`
INSERT INTO syndicate_members (syndicate_id, user_id, joined_at)
VALUES (?, ?, CURRENT_TIMESTAMP)
`, syndicateID, userID)
return err
}
func CreateSyndicate(db *sql.DB, ownerID int, name, description string) (int64, error) {
tx, err := db.Begin()
if err != nil {
return 0, err
}
defer tx.Rollback()
result, err := tx.Exec(`
INSERT INTO syndicates (name, description, owner_id, created_at)
VALUES (?, ?, ?, ?)
`, name, description, ownerID, time.Now())
if err != nil {
return 0, fmt.Errorf("failed to create syndicate: %w", err)
}
syndicateID, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("failed to get syndicate ID: %w", err)
}
_, err = tx.Exec(`
INSERT INTO syndicate_members (syndicate_id, user_id, role, joined_at)
VALUES (?, ?, 'manager', CURRENT_TIMESTAMP)
`, syndicateID, ownerID)
if err != nil {
return 0, fmt.Errorf("failed to add owner as member: %w", err)
}
if err := tx.Commit(); err != nil {
return 0, fmt.Errorf("commit failed: %w", err)
}
return syndicateID, nil
}
func InviteToSyndicate(db *sql.DB, inviterID, syndicateID int, username string) error {
var inviteeID int
err := db.QueryRow(`
SELECT id FROM users WHERE username = ?
`, username).Scan(&inviteeID)
if err == sql.ErrNoRows {
return fmt.Errorf("user not found")
} else if err != nil {
return err
}
var count int
err = db.QueryRow(`
SELECT COUNT(*) FROM syndicate_members
WHERE syndicate_id = ? AND user_id = ?
`, syndicateID, inviteeID).Scan(&count)
if err != nil {
return err
}
if count > 0 {
return fmt.Errorf("user already a member or invited")
}
_, err = db.Exec(`
INSERT INTO syndicate_members (syndicate_id, user_id, is_manager, status)
VALUES (?, ?, 0, 'invited')
`, syndicateID, inviteeID)
return err
}

View File

@@ -1,22 +1,27 @@
package storage
package usersStorage
// ToDo.. "errors" should this not be using my custom log wrapper
import (
"context"
"database/sql"
"errors"
"time"
)
type UsersRepo struct{ db *sql.DB }
const CreateUserSQL = `
INSERT INTO users (username, email, password_hash, created_at, updated_at)
VALUES (?, ?, ?, UTC_TIMESTAMP(), UTC_TIMESTAMP())`
func NewUsersRepo(db *sql.DB) *UsersRepo {
return &UsersRepo{db: db}
}
func CreateUser(db *sql.DB, username, email, passwordHash string) (int64, error) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
// ToDo: should the function be in sql?
func (r *UsersRepo) Create(ctx context.Context, username, passwordHash string, isAdmin bool) error {
_, err := r.db.ExecContext(ctx,
`INSERT INTO users (username, password_hash, is_admin) VALUES (?, ?, ?)`,
username, passwordHash, isAdmin,
)
return err
res, err := db.ExecContext(ctx, CreateUserSQL, username, email, passwordHash)
if err != nil {
return 0, err
}
id, err := res.LastInsertId()
if err != nil || id == 0 {
return 0, errors.New("could not get insert id")
}
return id, nil
}

View File

@@ -1,34 +1,72 @@
package storage
package usersStorage
import (
"context"
"database/sql"
"synlotto-website/internal/logging"
"synlotto-website/internal/models"
"time"
)
func GetUserByID(db *sql.DB, id int) *models.User {
row := db.QueryRow("SELECT id, username, password_hash, is_admin FROM users WHERE id = ?", id)
const (
UsernameExistsSQL = `
SELECT EXISTS(SELECT 1 FROM users WHERE username = ? LIMIT 1)`
var user models.User
err := row.Scan(&user.Id, &user.Username, &user.PasswordHash, &user.IsAdmin)
if err != nil {
if err != sql.ErrNoRows {
logging.Error("DB error:", err)
}
return nil
}
EmailExistsSQL = `
SELECT EXISTS(SELECT 1 FROM users WHERE email = ? LIMIT 1)`
return &user
GetByUsernameSQL = `
SELECT id, username, email, password_hash, created_at, updated_at
FROM users
WHERE username = ?
LIMIT 1`
GetByIDSQL = `
SELECT id, username, email, password_hash, is_admin, created_at, updated_at
FROM users
WHERE id = ?
LIMIT 1`
)
func UsernameExists(db *sql.DB, username string) bool {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
var exists bool
_ = db.QueryRowContext(ctx, UsernameExistsSQL, username).Scan(&exists)
return exists
}
func GetUserByUsername(db *sql.DB, username string) *models.User {
row := db.QueryRow(`SELECT id, username, password_hash, is_admin FROM users WHERE username = ?`, username)
func EmailExists(db *sql.DB, email string) bool {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
var exists bool
_ = db.QueryRowContext(ctx, EmailExistsSQL, email).Scan(&exists)
return exists
}
var u models.User
err := row.Scan(&u.Id, &u.Username, &u.PasswordHash, &u.IsAdmin)
func GetUserByUsername(db *sql.DB, username string) *User {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
var u User
err := db.QueryRowContext(ctx, GetByUsernameSQL, username).
Scan(&u.Id, &u.Username, &u.Email, &u.PasswordHash, &u.CreatedAt, &u.UpdatedAt)
if err != nil {
return nil
}
return &u
}
func GetUserByID(db *sql.DB, id int) *User {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
var u User
err := db.QueryRowContext(ctx, GetByIDSQL, id).
Scan(&u.Id, &u.Username, &u.Email, &u.PasswordHash, &u.IsAdmin, &u.CreatedAt, &u.UpdatedAt)
if err != nil {
if err != sql.ErrNoRows {
logging.Error("GetUserByID: %v", err)
}
return nil
}
return &u
}

View File

@@ -0,0 +1,5 @@
package usersStorage
import "synlotto-website/internal/models"
type User = models.User