Stack of changes to get gin, scs, nosurf running.
This commit is contained in:
@@ -1,57 +1,84 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
securityHelpers "synlotto-website/internal/helpers/security"
|
||||
templateHelpers "synlotto-website/internal/helpers/template"
|
||||
"synlotto-website/internal/logging"
|
||||
|
||||
"synlotto-website/internal/http/middleware"
|
||||
"synlotto-website/internal/logging"
|
||||
"synlotto-website/internal/platform/bootstrap"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func AdminOnly(db *sql.DB, next http.HandlerFunc) http.HandlerFunc {
|
||||
return middleware.Auth(true)(func(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := securityHelpers.GetCurrentUserID(r)
|
||||
if !ok || !securityHelpers.IsAdmin(db, userID) {
|
||||
log.Printf("⛔️ Unauthorized admin attempt: user_id=%v, IP=%s, Path=%s", userID, r.RemoteAddr, r.URL.Path)
|
||||
templateHelpers.RenderError(w, r, http.StatusForbidden)
|
||||
const insertRegistrationSQL = `
|
||||
INSERT INTO audit_registration
|
||||
(user_id, username, email, ip, user_agent, timestamp)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
func AdminOnly() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
app := c.MustGet("app").(*bootstrap.App)
|
||||
sm := app.SessionManager
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Require logged in (assumes RequireAuth already ran; this is a safety net)
|
||||
v := sm.Get(ctx, "user_id")
|
||||
var uid int64
|
||||
switch t := v.(type) {
|
||||
case int64:
|
||||
uid = t
|
||||
case int:
|
||||
uid = int64(t)
|
||||
default:
|
||||
c.Redirect(http.StatusSeeOther, "/account/login")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
ip := r.RemoteAddr
|
||||
ua := r.UserAgent()
|
||||
path := r.URL.Path
|
||||
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO admin_access_log (user_id, path, ip, user_agent)
|
||||
VALUES (?, ?, ?, ?)`,
|
||||
userID, path, ip, ua,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("⚠️ Failed to log admin access: %v", err)
|
||||
// Check admin
|
||||
if !securityHelpers.IsAdmin(app.DB, int(uid)) {
|
||||
// Optional: log access attempt here or in a helper
|
||||
c.String(http.StatusForbidden, "Forbidden")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("🛡️ Admin access: user_id=%d IP=%s Path=%s", userID, ip, path)
|
||||
// Optionally record access (moved here from storage)
|
||||
_, _ = app.DB.Exec(`
|
||||
INSERT INTO admin_access_log (user_id, path, ip, user_agent, accessed_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
`, uid, c.Request.URL.Path, c.ClientIP(), c.Request.UserAgent(), time.Now().UTC())
|
||||
|
||||
next(w, r)
|
||||
})
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// Todo has to add in - db *sql.DB to make this work should this not be an import as all functions use it, more importantly no functions in storage just sql?
|
||||
func LogLoginAttempt(db *sql.DB, r *http.Request, username string, success bool) {
|
||||
ip := r.RemoteAddr
|
||||
userAgent := r.UserAgent()
|
||||
|
||||
// Handler Call - auditlogStorage.LogLoginAttempt(db, r.RemoteAddr, r.UserAgent(), username, ok)
|
||||
func LogLoginAttempt(db *sql.DB, rIP, rUA, username string, success bool) {
|
||||
_, err := db.Exec(
|
||||
`INSERT INTO audit_login (username, success, ip, user_agent, timestamp)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
username, success, ip, userAgent, time.Now().UTC(),
|
||||
VALUES ($1, $2, $3, $4, $5)`,
|
||||
username, success, rIP, rUA, time.Now().UTC(),
|
||||
)
|
||||
if err != nil {
|
||||
logging.Info("❌ Failed to log login:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func LogSignup(db *sql.DB, userID int64, username, email, ip, userAgent string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := db.ExecContext(ctx, insertRegistrationSQL,
|
||||
userID, username, email, ip, userAgent, time.Now().UTC(),
|
||||
)
|
||||
if err != nil {
|
||||
logging.Info("❌ Failed to log registration: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/mysql"
|
||||
iofs "github.com/golang-migrate/migrate/v4/source/iofs"
|
||||
)
|
||||
|
||||
//go:embed migrations/*.sql
|
||||
var migrationFiles embed.FS
|
||||
|
||||
var DB *sql.DB
|
||||
|
||||
// InitDB connects to MySQL, runs migrations, and returns the DB handle.
|
||||
func InitDB() *sql.DB {
|
||||
cfg := getDSNFromEnv()
|
||||
|
||||
db, err := sql.Open("mysql", cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("❌ Failed to connect to MySQL: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
log.Fatalf("❌ MySQL not reachable: %v", err)
|
||||
}
|
||||
|
||||
if err := runMigrations(db); err != nil {
|
||||
log.Fatalf("❌ Migration failed: %v", err)
|
||||
}
|
||||
|
||||
DB = db
|
||||
return db
|
||||
}
|
||||
|
||||
// runMigrations applies any pending .sql files in migrations/
|
||||
func runMigrations(db *sql.DB) error {
|
||||
driver, err := mysql.WithInstance(db, &mysql.Config{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
src, err := iofs.New(migrationFiles, "migrations")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m, err := migrate.NewWithInstance("iofs", src, "mysql", driver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = m.Up()
|
||||
if err == migrate.ErrNoChange {
|
||||
log.Println("✅ Database schema up to date.")
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func getDSNFromEnv() string {
|
||||
user := os.Getenv("DB_USER")
|
||||
pass := os.Getenv("DB_PASS")
|
||||
host := os.Getenv("DB_HOST") // e.g. localhost or 127.0.0.1
|
||||
port := os.Getenv("DB_PORT") // e.g. 3306
|
||||
name := os.Getenv("DB_NAME") // e.g. synlotto
|
||||
params := "parseTime=true&multiStatements=true"
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?%s",
|
||||
user, pass, host, port, name, params)
|
||||
return dsn
|
||||
}
|
||||
@@ -4,6 +4,20 @@
|
||||
-- - utf8mb4 for full Unicode
|
||||
-- Booleans are TINYINT(1). Dates use DATE/DATETIME/TIMESTAMP as appropriate.
|
||||
|
||||
CREATE TABLE audit_registration (
|
||||
id BIGINT UNSIGNED AUTO_INCREMENT PRIMARY KEY,
|
||||
user_id BIGINT UNSIGNED NOT NULL,
|
||||
username VARCHAR(255) NOT NULL,
|
||||
email VARCHAR(255) NOT NULL,
|
||||
ip VARCHAR(45) NOT NULL,
|
||||
user_agent VARCHAR(500),
|
||||
timestamp DATETIME NOT NULL,
|
||||
INDEX (user_id),
|
||||
CONSTRAINT fk_audit_registration_users
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- USERS
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
|
||||
|
||||
6
internal/storage/migrations/embed.go
Normal file
6
internal/storage/migrations/embed.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package migrations
|
||||
|
||||
import _ "embed"
|
||||
|
||||
//go:embed 0001_initial_create.up.sql
|
||||
var InitialSchema string
|
||||
6
internal/storage/migrations/read.go
Normal file
6
internal/storage/migrations/read.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package migrations
|
||||
|
||||
const ProbeUsersTable = `
|
||||
SELECT COUNT(*)
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE() AND table_name = 'users'`
|
||||
@@ -1,238 +0,0 @@
|
||||
package storage
|
||||
|
||||
const SchemaUsers = `
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
is_admin BOOLEAN
|
||||
);`
|
||||
|
||||
const SchemaThunderballResults = `
|
||||
CREATE TABLE IF NOT EXISTS results_thunderball (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
draw_date TEXT NOT NULL UNIQUE,
|
||||
draw_id INTEGER NOT NULL UNIQUE,
|
||||
machine TEXT,
|
||||
ballset TEXT,
|
||||
ball1 INTEGER,
|
||||
ball2 INTEGER,
|
||||
ball3 INTEGER,
|
||||
ball4 INTEGER,
|
||||
ball5 INTEGER,
|
||||
thunderball INTEGER
|
||||
);`
|
||||
|
||||
const SchemaThunderballPrizes = `
|
||||
CREATE TABLE IF NOT EXISTS prizes_thunderball (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
draw_id INTEGER NOT NULL,
|
||||
draw_date TEXT,
|
||||
prize1 TEXT,
|
||||
prize1_winners INTEGER,
|
||||
prize1_per_winner INTEGER,
|
||||
prize1_fund INTEGER,
|
||||
prize2 TEXT,
|
||||
prize2_winners INTEGER,
|
||||
prize2_per_winner INTEGER,
|
||||
prize2_fund INTEGER,
|
||||
prize3 TEXT,
|
||||
prize3_winners INTEGER,
|
||||
prize3_per_winner INTEGER,
|
||||
prize3_fund INTEGER,
|
||||
prize4 TEXT,
|
||||
prize4_winners INTEGER,
|
||||
prize4_per_winner INTEGER,
|
||||
prize4_fund INTEGER,
|
||||
prize5 TEXT,
|
||||
prize5_winners INTEGER,
|
||||
prize5_per_winner INTEGER,
|
||||
prize5_fund INTEGER,
|
||||
prize6 TEXT,
|
||||
prize6_winners INTEGER,
|
||||
prize6_per_winner INTEGER,
|
||||
prize6_fund INTEGER,
|
||||
prize7 TEXT,
|
||||
prize7_winners INTEGER,
|
||||
prize7_per_winner INTEGER,
|
||||
prize7_fund INTEGER,
|
||||
prize8 TEXT,
|
||||
prize8_winners INTEGER,
|
||||
prize8_per_winner INTEGER,
|
||||
prize8_fund INTEGER,
|
||||
prize9 TEXT,
|
||||
prize9_winners INTEGER,
|
||||
prize9_per_winner INTEGER,
|
||||
prize9_fund INTEGER,
|
||||
total_winners INTEGER,
|
||||
total_prize_fund INTEGER,
|
||||
FOREIGN KEY (draw_date) REFERENCES results_thunderball(draw_date)
|
||||
);`
|
||||
|
||||
const SchemaLottoResults = `
|
||||
CREATE TABLE IF NOT EXISTS results_lotto (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
draw_date TEXT NOT NULL UNIQUE,
|
||||
draw_id INTEGER NOT NULL UNIQUE,
|
||||
machine TEXT,
|
||||
ballset TEXT,
|
||||
ball1 INTEGER,
|
||||
ball2 INTEGER,
|
||||
ball3 INTEGER,
|
||||
ball4 INTEGER,
|
||||
ball5 INTEGER,
|
||||
ball6 INTEGER,
|
||||
bonusball INTEGER
|
||||
);`
|
||||
|
||||
const SchemaMyTickets = `
|
||||
CREATE TABLE IF NOT EXISTS my_tickets (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
userId INTEGER NOT NULL,
|
||||
game_type TEXT NOT NULL,
|
||||
draw_date TEXT NOT NULL,
|
||||
ball1 INTEGER,
|
||||
ball2 INTEGER,
|
||||
ball3 INTEGER,
|
||||
ball4 INTEGER,
|
||||
ball5 INTEGER,
|
||||
ball6 INTEGER,
|
||||
bonus1 INTEGER,
|
||||
bonus2 INTEGER,
|
||||
duplicate BOOLEAN DEFAULT 0,
|
||||
purchase_date TEXT,
|
||||
purchase_method TEXT,
|
||||
image_path TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
matched_main INTEGER,
|
||||
matched_bonus INTEGER,
|
||||
prize_tier TEXT,
|
||||
is_winner BOOLEAN,
|
||||
prize_amount INTEGER,
|
||||
prize_label TEXT,
|
||||
syndicate_id INTEGER,
|
||||
FOREIGN KEY (userId) REFERENCES users(id)
|
||||
);`
|
||||
|
||||
const SchemaUsersMessages = `
|
||||
CREATE TABLE IF NOT EXISTS users_messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
senderId INTEGER NOT NULL REFERENCES users(id),
|
||||
recipientId INTEGER NOT NULL REFERENCES users(id),
|
||||
subject TEXT NOT NULL,
|
||||
message TEXT,
|
||||
is_read BOOLEAN DEFAULT FALSE,
|
||||
is_archived BOOLEAN DEFAULT FALSE,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
archived_at TIMESTAMP
|
||||
);`
|
||||
|
||||
const SchemaUsersNotifications = `
|
||||
CREATE TABLE IF NOT EXISTS users_notification (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id),
|
||||
subject TEXT,
|
||||
body TEXT,
|
||||
is_read BOOLEAN DEFAULT FALSE,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);`
|
||||
|
||||
const SchemaAuditLog = `
|
||||
CREATE TABLE IF NOT EXISTS auditlog (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT,
|
||||
success INTEGER,
|
||||
timestamp TEXT
|
||||
);`
|
||||
|
||||
const SchemaLogTicketMatching = `
|
||||
CREATE TABLE IF NOT EXISTS log_ticket_matching (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
triggered_by TEXT,
|
||||
run_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
tickets_matched INTEGER,
|
||||
winners_found INTEGER,
|
||||
notes TEXT
|
||||
);`
|
||||
|
||||
const SchemaAdminAccessLog = `
|
||||
CREATE TABLE IF NOT EXISTS admin_access_log (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER,
|
||||
accessed_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
path TEXT,
|
||||
ip TEXT,
|
||||
user_agent TEXT
|
||||
);`
|
||||
|
||||
const SchemaNewAuditLog = `
|
||||
CREATE TABLE IF NOT EXISTS audit_log (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER,
|
||||
username TEXT,
|
||||
action TEXT,
|
||||
path TEXT,
|
||||
ip TEXT,
|
||||
user_agent TEXT,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);`
|
||||
|
||||
const SchemaAuditLogin = `
|
||||
CREATE TABLE IF NOT EXISTS audit_login (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT,
|
||||
success BOOLEAN,
|
||||
ip TEXT,
|
||||
user_agent TEXT,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);`
|
||||
|
||||
const SchemaSyndicates = `
|
||||
CREATE TABLE IF NOT EXISTS syndicates (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
owner_id INTEGER NOT NULL,
|
||||
join_code TEXT UNIQUE,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (owner_id) REFERENCES users(id)
|
||||
);`
|
||||
|
||||
const SchemaSyndicateMembers = `
|
||||
CREATE TABLE IF NOT EXISTS syndicate_members (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
syndicate_id INTEGER NOT NULL,
|
||||
user_id INTEGER NOT NULL,
|
||||
role TEXT DEFAULT 'member', -- owner, manager, member
|
||||
status TEXT DEFAULT 'active',
|
||||
joined_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (syndicate_id) REFERENCES syndicates(id),
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);`
|
||||
|
||||
const SchemaSyndicateInvites = `
|
||||
CREATE TABLE IF NOT EXISTS syndicate_invites (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
syndicate_id INTEGER NOT NULL,
|
||||
invited_user_id INTEGER NOT NULL,
|
||||
sent_by_user_id INTEGER NOT NULL,
|
||||
status TEXT DEFAULT 'pending',
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY(syndicate_id) REFERENCES syndicates(id),
|
||||
FOREIGN KEY(invited_user_id) REFERENCES users(id)
|
||||
);`
|
||||
|
||||
const SchemaSyndicateInviteTokens = `
|
||||
CREATE TABLE IF NOT EXISTS syndicate_invite_tokens (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
syndicate_id INTEGER NOT NULL,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
invited_by_user_id INTEGER NOT NULL,
|
||||
accepted_by_user_id INTEGER,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
accepted_at TIMESTAMP,
|
||||
expires_at TIMESTAMP,
|
||||
FOREIGN KEY (syndicate_id) REFERENCES syndicates(id),
|
||||
FOREIGN KEY (invited_by_user_id) REFERENCES users(id),
|
||||
FOREIGN KEY (accepted_by_user_id) REFERENCES users(id)
|
||||
);`
|
||||
@@ -130,6 +130,35 @@ func GetSyndicateMembers(db *sql.DB, syndicateID int) []models.SyndicateMember {
|
||||
return members
|
||||
}
|
||||
|
||||
func GetSyndicateTickets(db *sql.DB, syndicateID int) []models.Ticket {
|
||||
rows, err := db.Query(`
|
||||
SELECT id, userId, syndicateId, game_type, draw_date, ball1, ball2, ball3, ball4, ball5, ball6,
|
||||
bonus1, bonus2, matched_main, matched_bonus, prize_tier, prize_amount, prize_label, is_winner
|
||||
FROM my_tickets
|
||||
WHERE syndicateId = ?
|
||||
ORDER BY draw_date DESC
|
||||
`, syndicateID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tickets []models.Ticket
|
||||
for rows.Next() {
|
||||
var t models.Ticket
|
||||
err := rows.Scan(
|
||||
&t.Id, &t.UserId, &t.SyndicateId, &t.GameType, &t.DrawDate,
|
||||
&t.Ball1, &t.Ball2, &t.Ball3, &t.Ball4, &t.Ball5, &t.Ball6,
|
||||
&t.Bonus1, &t.Bonus2, &t.MatchedMain, &t.MatchedBonus,
|
||||
&t.PrizeTier, &t.PrizeAmount, &t.PrizeLabel, &t.IsWinner,
|
||||
)
|
||||
if err == nil {
|
||||
tickets = append(tickets, t)
|
||||
}
|
||||
}
|
||||
return tickets
|
||||
}
|
||||
|
||||
func IsSyndicateManager(db *sql.DB, syndicateID, userID int) bool {
|
||||
var count int
|
||||
err := db.QueryRow(`
|
||||
|
||||
@@ -1,125 +0,0 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"synlotto-website/internal/models"
|
||||
"time"
|
||||
)
|
||||
|
||||
// todo should be a ticket function?
|
||||
func GetSyndicateTickets(db *sql.DB, syndicateID int) []models.Ticket {
|
||||
rows, err := db.Query(`
|
||||
SELECT id, userId, syndicateId, game_type, draw_date, ball1, ball2, ball3, ball4, ball5, ball6,
|
||||
bonus1, bonus2, matched_main, matched_bonus, prize_tier, prize_amount, prize_label, is_winner
|
||||
FROM my_tickets
|
||||
WHERE syndicateId = ?
|
||||
ORDER BY draw_date DESC
|
||||
`, syndicateID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tickets []models.Ticket
|
||||
for rows.Next() {
|
||||
var t models.Ticket
|
||||
err := rows.Scan(
|
||||
&t.Id, &t.UserId, &t.SyndicateId, &t.GameType, &t.DrawDate,
|
||||
&t.Ball1, &t.Ball2, &t.Ball3, &t.Ball4, &t.Ball5, &t.Ball6,
|
||||
&t.Bonus1, &t.Bonus2, &t.MatchedMain, &t.MatchedBonus,
|
||||
&t.PrizeTier, &t.PrizeAmount, &t.PrizeLabel, &t.IsWinner,
|
||||
)
|
||||
if err == nil {
|
||||
tickets = append(tickets, t)
|
||||
}
|
||||
}
|
||||
return tickets
|
||||
}
|
||||
|
||||
// both a read and inset break up
|
||||
func AcceptInvite(db *sql.DB, inviteID, userID int) error {
|
||||
var syndicateID int
|
||||
err := db.QueryRow(`
|
||||
SELECT syndicate_id FROM syndicate_invites
|
||||
WHERE id = ? AND invited_user_id = ? AND status = 'pending'
|
||||
`, inviteID, userID).Scan(&syndicateID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := UpdateInviteStatus(db, inviteID, "accepted"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO syndicate_members (syndicate_id, user_id, joined_at)
|
||||
VALUES (?, ?, CURRENT_TIMESTAMP)
|
||||
`, syndicateID, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
func CreateSyndicate(db *sql.DB, ownerID int, name, description string) (int64, error) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
result, err := tx.Exec(`
|
||||
INSERT INTO syndicates (name, description, owner_id, created_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
`, name, description, ownerID, time.Now())
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to create syndicate: %w", err)
|
||||
}
|
||||
|
||||
syndicateID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get syndicate ID: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`
|
||||
INSERT INTO syndicate_members (syndicate_id, user_id, role, joined_at)
|
||||
VALUES (?, ?, 'manager', CURRENT_TIMESTAMP)
|
||||
`, syndicateID, ownerID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to add owner as member: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return 0, fmt.Errorf("commit failed: %w", err)
|
||||
}
|
||||
|
||||
return syndicateID, nil
|
||||
}
|
||||
|
||||
func InviteToSyndicate(db *sql.DB, inviterID, syndicateID int, username string) error {
|
||||
var inviteeID int
|
||||
err := db.QueryRow(`
|
||||
SELECT id FROM users WHERE username = ?
|
||||
`, username).Scan(&inviteeID)
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("user not found")
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var count int
|
||||
err = db.QueryRow(`
|
||||
SELECT COUNT(*) FROM syndicate_members
|
||||
WHERE syndicate_id = ? AND user_id = ?
|
||||
`, syndicateID, inviteeID).Scan(&count)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 0 {
|
||||
return fmt.Errorf("user already a member or invited")
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO syndicate_members (syndicate_id, user_id, is_manager, status)
|
||||
VALUES (?, ?, 0, 'invited')
|
||||
`, syndicateID, inviteeID)
|
||||
return err
|
||||
}
|
||||
@@ -2,6 +2,8 @@ package storage
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func UpdateInviteStatus(db *sql.DB, inviteID int, status string) error {
|
||||
@@ -12,3 +14,90 @@ func UpdateInviteStatus(db *sql.DB, inviteID int, status string) error {
|
||||
`, status, inviteID)
|
||||
return err
|
||||
}
|
||||
|
||||
// ToDo: both a read and inset break up
|
||||
func AcceptInvite(db *sql.DB, inviteID, userID int) error {
|
||||
var syndicateID int
|
||||
err := db.QueryRow(`
|
||||
SELECT syndicate_id FROM syndicate_invites
|
||||
WHERE id = ? AND invited_user_id = ? AND status = 'pending'
|
||||
`, inviteID, userID).Scan(&syndicateID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := UpdateInviteStatus(db, inviteID, "accepted"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO syndicate_members (syndicate_id, user_id, joined_at)
|
||||
VALUES (?, ?, CURRENT_TIMESTAMP)
|
||||
`, syndicateID, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
func CreateSyndicate(db *sql.DB, ownerID int, name, description string) (int64, error) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
result, err := tx.Exec(`
|
||||
INSERT INTO syndicates (name, description, owner_id, created_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
`, name, description, ownerID, time.Now())
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to create syndicate: %w", err)
|
||||
}
|
||||
|
||||
syndicateID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get syndicate ID: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`
|
||||
INSERT INTO syndicate_members (syndicate_id, user_id, role, joined_at)
|
||||
VALUES (?, ?, 'manager', CURRENT_TIMESTAMP)
|
||||
`, syndicateID, ownerID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to add owner as member: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return 0, fmt.Errorf("commit failed: %w", err)
|
||||
}
|
||||
|
||||
return syndicateID, nil
|
||||
}
|
||||
|
||||
func InviteToSyndicate(db *sql.DB, inviterID, syndicateID int, username string) error {
|
||||
var inviteeID int
|
||||
err := db.QueryRow(`
|
||||
SELECT id FROM users WHERE username = ?
|
||||
`, username).Scan(&inviteeID)
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("user not found")
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var count int
|
||||
err = db.QueryRow(`
|
||||
SELECT COUNT(*) FROM syndicate_members
|
||||
WHERE syndicate_id = ? AND user_id = ?
|
||||
`, syndicateID, inviteeID).Scan(&count)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 0 {
|
||||
return fmt.Errorf("user already a member or invited")
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO syndicate_members (syndicate_id, user_id, is_manager, status)
|
||||
VALUES (?, ?, 0, 'invited')
|
||||
`, syndicateID, inviteeID)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,22 +1,27 @@
|
||||
package storage
|
||||
package usersStorage
|
||||
|
||||
// ToDo.. "errors" should this not be using my custom log wrapper
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UsersRepo struct{ db *sql.DB }
|
||||
const CreateUserSQL = `
|
||||
INSERT INTO users (username, email, password_hash, created_at, updated_at)
|
||||
VALUES (?, ?, ?, UTC_TIMESTAMP(), UTC_TIMESTAMP())`
|
||||
|
||||
func NewUsersRepo(db *sql.DB) *UsersRepo {
|
||||
return &UsersRepo{db: db}
|
||||
}
|
||||
func CreateUser(db *sql.DB, username, email, passwordHash string) (int64, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// ToDo: should the function be in sql?
|
||||
func (r *UsersRepo) Create(ctx context.Context, username, passwordHash string, isAdmin bool) error {
|
||||
_, err := r.db.ExecContext(ctx,
|
||||
`INSERT INTO users (username, password_hash, is_admin) VALUES (?, ?, ?)`,
|
||||
username, passwordHash, isAdmin,
|
||||
)
|
||||
return err
|
||||
res, err := db.ExecContext(ctx, CreateUserSQL, username, email, passwordHash)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil || id == 0 {
|
||||
return 0, errors.New("could not get insert id")
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
@@ -1,34 +1,72 @@
|
||||
package storage
|
||||
package usersStorage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"synlotto-website/internal/logging"
|
||||
"synlotto-website/internal/models"
|
||||
"time"
|
||||
)
|
||||
|
||||
func GetUserByID(db *sql.DB, id int) *models.User {
|
||||
row := db.QueryRow("SELECT id, username, password_hash, is_admin FROM users WHERE id = ?", id)
|
||||
const (
|
||||
UsernameExistsSQL = `
|
||||
SELECT EXISTS(SELECT 1 FROM users WHERE username = ? LIMIT 1)`
|
||||
|
||||
var user models.User
|
||||
err := row.Scan(&user.Id, &user.Username, &user.PasswordHash, &user.IsAdmin)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
logging.Error("DB error:", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
EmailExistsSQL = `
|
||||
SELECT EXISTS(SELECT 1 FROM users WHERE email = ? LIMIT 1)`
|
||||
|
||||
return &user
|
||||
GetByUsernameSQL = `
|
||||
SELECT id, username, email, password_hash, created_at, updated_at
|
||||
FROM users
|
||||
WHERE username = ?
|
||||
LIMIT 1`
|
||||
|
||||
GetByIDSQL = `
|
||||
SELECT id, username, email, password_hash, is_admin, created_at, updated_at
|
||||
FROM users
|
||||
WHERE id = ?
|
||||
LIMIT 1`
|
||||
)
|
||||
|
||||
func UsernameExists(db *sql.DB, username string) bool {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
var exists bool
|
||||
_ = db.QueryRowContext(ctx, UsernameExistsSQL, username).Scan(&exists)
|
||||
return exists
|
||||
}
|
||||
|
||||
func GetUserByUsername(db *sql.DB, username string) *models.User {
|
||||
row := db.QueryRow(`SELECT id, username, password_hash, is_admin FROM users WHERE username = ?`, username)
|
||||
func EmailExists(db *sql.DB, email string) bool {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
var exists bool
|
||||
_ = db.QueryRowContext(ctx, EmailExistsSQL, email).Scan(&exists)
|
||||
return exists
|
||||
}
|
||||
|
||||
var u models.User
|
||||
err := row.Scan(&u.Id, &u.Username, &u.PasswordHash, &u.IsAdmin)
|
||||
func GetUserByUsername(db *sql.DB, username string) *User {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
var u User
|
||||
err := db.QueryRowContext(ctx, GetByUsernameSQL, username).
|
||||
Scan(&u.Id, &u.Username, &u.Email, &u.PasswordHash, &u.CreatedAt, &u.UpdatedAt)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &u
|
||||
}
|
||||
|
||||
func GetUserByID(db *sql.DB, id int) *User {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var u User
|
||||
err := db.QueryRowContext(ctx, GetByIDSQL, id).
|
||||
Scan(&u.Id, &u.Username, &u.Email, &u.PasswordHash, &u.IsAdmin, &u.CreatedAt, &u.UpdatedAt)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
logging.Error("GetUserByID: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return &u
|
||||
}
|
||||
|
||||
5
internal/storage/users/types.go
Normal file
5
internal/storage/users/types.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package usersStorage
|
||||
|
||||
import "synlotto-website/internal/models"
|
||||
|
||||
type User = models.User
|
||||
Reference in New Issue
Block a user