Refactor and remove sqlite and replace with MySQL
This commit is contained in:
50
internal/helpers/ballslice.go
Normal file
50
internal/helpers/ballslice.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"synlotto-website/models"
|
||||
)
|
||||
|
||||
func BuildBallsSlice(t models.Ticket) []int {
|
||||
balls := []int{t.Ball1, t.Ball2, t.Ball3, t.Ball4, t.Ball5}
|
||||
if t.GameType == "Lotto" && t.Ball6 > 0 {
|
||||
balls = append(balls, t.Ball6)
|
||||
}
|
||||
|
||||
return balls
|
||||
}
|
||||
|
||||
func BuildBonusSlice(t models.Ticket) []int {
|
||||
var bonuses []int
|
||||
if t.Bonus1 != nil {
|
||||
bonuses = append(bonuses, *t.Bonus1)
|
||||
}
|
||||
if t.Bonus2 != nil {
|
||||
bonuses = append(bonuses, *t.Bonus2)
|
||||
}
|
||||
|
||||
return bonuses
|
||||
}
|
||||
|
||||
// BuildBallsFromNulls builds main balls from sql.NullInt64 values
|
||||
func BuildBallsFromNulls(vals ...sql.NullInt64) []int {
|
||||
var result []int
|
||||
for _, v := range vals {
|
||||
if v.Valid {
|
||||
result = append(result, int(v.Int64))
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// BuildBonusFromNulls builds bonus balls from two sql.NullInt64 values
|
||||
func BuildBonusFromNulls(b1, b2 sql.NullInt64) []int {
|
||||
var result []int
|
||||
if b1.Valid {
|
||||
result = append(result, int(b1.Int64))
|
||||
}
|
||||
if b2.Valid {
|
||||
result = append(result, int(b2.Int64))
|
||||
}
|
||||
return result
|
||||
}
|
||||
21
internal/helpers/distinctresults.go
Normal file
21
internal/helpers/distinctresults.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package helpers
|
||||
|
||||
import "database/sql"
|
||||
|
||||
func GetDistinctValues(db *sql.DB, column string) ([]string, error) {
|
||||
query := "SELECT DISTINCT " + column + " FROM results_thunderball ORDER BY " + column
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var values []string
|
||||
for rows.Next() {
|
||||
var val string
|
||||
if err := rows.Scan(&val); err == nil {
|
||||
values = append(values, val)
|
||||
}
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
51
internal/helpers/http/session.go
Normal file
51
internal/helpers/http/session.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
session "synlotto-website/handlers/session"
|
||||
|
||||
"synlotto-website/constants"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
func GetSession(w http.ResponseWriter, r *http.Request) (*sessions.Session, error) {
|
||||
return session.GetSession(w, r)
|
||||
}
|
||||
|
||||
func IsSessionExpired(session *sessions.Session) bool {
|
||||
last, ok := session.Values["last_activity"].(time.Time)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return time.Since(last) > constants.SessionDuration
|
||||
}
|
||||
|
||||
func UpdateSessionActivity(session *sessions.Session, r *http.Request, w http.ResponseWriter) {
|
||||
session.Values["last_activity"] = time.Now().UTC()
|
||||
session.Save(r, w)
|
||||
}
|
||||
|
||||
func AuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
session, _ := GetSession(w, r)
|
||||
|
||||
if IsSessionExpired(session) {
|
||||
session.Options.MaxAge = -1
|
||||
session.Save(r, w)
|
||||
|
||||
newSession, _ := GetSession(w, r)
|
||||
newSession.Values["flash"] = "Your session has timed out."
|
||||
newSession.Save(r, w)
|
||||
|
||||
http.Redirect(w, r, "/account/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
UpdateSessionActivity(session, r, w)
|
||||
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
13
internal/helpers/intptr.go
Normal file
13
internal/helpers/intptr.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func IntPtrIfValid(val sql.NullInt64) *int {
|
||||
if val.Valid {
|
||||
n := int(val.Int64)
|
||||
return &n
|
||||
}
|
||||
return nil
|
||||
}
|
||||
16
internal/helpers/match.go
Normal file
16
internal/helpers/match.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package helpers
|
||||
|
||||
func CountMatches(a, b []int) int {
|
||||
m := make(map[int]bool)
|
||||
for _, n := range b {
|
||||
m[n] = true
|
||||
}
|
||||
match := 0
|
||||
for _, n := range a {
|
||||
if m[n] {
|
||||
match++
|
||||
}
|
||||
}
|
||||
|
||||
return match
|
||||
}
|
||||
8
internal/helpers/nullable.go
Normal file
8
internal/helpers/nullable.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package helpers
|
||||
|
||||
func Nullable(val int) *int {
|
||||
if val == 0 {
|
||||
return nil
|
||||
}
|
||||
return &val
|
||||
}
|
||||
14
internal/helpers/parse.go
Normal file
14
internal/helpers/parse.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package helpers
|
||||
|
||||
import "strconv"
|
||||
|
||||
func ParseIntSlice(input []string) []int {
|
||||
var out []int
|
||||
for _, s := range input {
|
||||
n, err := strconv.Atoi(s)
|
||||
if err == nil {
|
||||
out = append(out, n)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
16
internal/helpers/security/admin.go
Normal file
16
internal/helpers/security/admin.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"log"
|
||||
)
|
||||
|
||||
func IsAdmin(db *sql.DB, userID int) bool {
|
||||
var isAdmin bool
|
||||
err := db.QueryRow(`SELECT is_admin FROM users WHERE id = ?`, userID).Scan(&isAdmin)
|
||||
if err != nil {
|
||||
log.Printf("⚠️ Failed to check is_admin for user %d: %v", userID, err)
|
||||
return false
|
||||
}
|
||||
return isAdmin
|
||||
}
|
||||
13
internal/helpers/security/password.go
Normal file
13
internal/helpers/security/password.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package security
|
||||
|
||||
import "golang.org/x/crypto/bcrypt"
|
||||
|
||||
func HashPassword(password string) (string, error) {
|
||||
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
func CheckPasswordHash(hash, password string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
15
internal/helpers/security/token.go
Normal file
15
internal/helpers/security/token.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
func GenerateSecureToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
17
internal/helpers/security/users.go
Normal file
17
internal/helpers/security/users.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
httpHelpers "synlotto-website/helpers/http"
|
||||
)
|
||||
|
||||
func GetCurrentUserID(r *http.Request) (int, bool) {
|
||||
session, err := httpHelpers.GetSession(nil, r)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
id, ok := session.Values["user_id"].(int)
|
||||
return id, ok
|
||||
}
|
||||
7
internal/helpers/session/encoding.go
Normal file
7
internal/helpers/session/encoding.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package helpers
|
||||
|
||||
import "encoding/base64"
|
||||
|
||||
func EncodeKey(b []byte) string {
|
||||
return base64.StdEncoding.EncodeToString(b)
|
||||
}
|
||||
20
internal/helpers/session/loader.go
Normal file
20
internal/helpers/session/loader.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
)
|
||||
|
||||
func LoadKeyFromFile(path string) ([]byte, error) {
|
||||
key, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bytes.TrimSpace(key), nil
|
||||
}
|
||||
|
||||
func ZeroBytes(b []byte) {
|
||||
for i := range b {
|
||||
b[i] = 0
|
||||
}
|
||||
}
|
||||
8
internal/helpers/strconv.go
Normal file
8
internal/helpers/strconv.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package helpers
|
||||
|
||||
import "strconv"
|
||||
|
||||
func Atoi(s string) int {
|
||||
n, _ := strconv.Atoi(s)
|
||||
return n
|
||||
}
|
||||
151
internal/helpers/template/build.go
Normal file
151
internal/helpers/template/build.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"synlotto-website/config"
|
||||
helpers "synlotto-website/helpers/http"
|
||||
"synlotto-website/models"
|
||||
|
||||
"github.com/gorilla/csrf"
|
||||
)
|
||||
|
||||
func TemplateContext(w http.ResponseWriter, r *http.Request, data models.TemplateData) map[string]interface{} {
|
||||
cfg := config.Get()
|
||||
if cfg == nil {
|
||||
log.Println("⚠️ Config not initialized!")
|
||||
}
|
||||
session, _ := helpers.GetSession(w, r)
|
||||
|
||||
var flash string
|
||||
if f, ok := session.Values["flash"].(string); ok {
|
||||
flash = f
|
||||
delete(session.Values, "flash")
|
||||
session.Save(r, w)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"CSRFField": csrf.TemplateField(r),
|
||||
"Flash": flash,
|
||||
"User": data.User,
|
||||
"IsAdmin": data.IsAdmin,
|
||||
"NotificationCount": data.NotificationCount,
|
||||
"Notifications": data.Notifications,
|
||||
"MessageCount": data.MessageCount,
|
||||
"Messages": data.Messages,
|
||||
"SiteName": cfg.Site.SiteName,
|
||||
"CopyrightYearStart": cfg.Site.CopyrightYearStart,
|
||||
}
|
||||
}
|
||||
|
||||
func TemplateFuncs() template.FuncMap {
|
||||
return template.FuncMap{
|
||||
"plus1": func(i int) int { return i + 1 },
|
||||
"minus1": func(i int) int {
|
||||
if i > 1 {
|
||||
return i - 1
|
||||
}
|
||||
return 0
|
||||
},
|
||||
"mul": func(a, b int) int { return a * b },
|
||||
"add": func(a, b int) int { return a + b },
|
||||
"sub": func(a, b int) int { return a - b },
|
||||
"min": func(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
},
|
||||
"intVal": func(p *int) int {
|
||||
if p == nil {
|
||||
return 0
|
||||
}
|
||||
return *p
|
||||
},
|
||||
"inSlice": InSlice,
|
||||
"lower": lower,
|
||||
"truncate": func(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max] + "..."
|
||||
},
|
||||
"PageRange": PageRange,
|
||||
"now": time.Now,
|
||||
"humanTime": func(v interface{}) string {
|
||||
switch t := v.(type) {
|
||||
case time.Time:
|
||||
return t.Local().Format("02 Jan 2006 15:04")
|
||||
case string:
|
||||
parsed, err := time.Parse(time.RFC3339, t)
|
||||
if err == nil {
|
||||
return parsed.Local().Format("02 Jan 2006 15:04")
|
||||
}
|
||||
return t
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
},
|
||||
"rangeClass": rangeClass,
|
||||
}
|
||||
}
|
||||
|
||||
func LoadTemplateFiles(name string, files ...string) *template.Template {
|
||||
shared := []string{
|
||||
"templates/main/layout.html",
|
||||
"templates/main/topbar.html",
|
||||
"templates/main/footer.html",
|
||||
}
|
||||
all := append(shared, files...)
|
||||
|
||||
return template.Must(template.New(name).Funcs(TemplateFuncs()).ParseFiles(all...))
|
||||
}
|
||||
|
||||
func SetFlash(w http.ResponseWriter, r *http.Request, message string) {
|
||||
session, _ := helpers.GetSession(w, r)
|
||||
session.Values["flash"] = message
|
||||
session.Save(r, w)
|
||||
}
|
||||
|
||||
func InSlice(n int, list []int) bool {
|
||||
for _, v := range list {
|
||||
if v == n {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func lower(input string) string {
|
||||
return strings.ToLower(input)
|
||||
}
|
||||
|
||||
func PageRange(current, total int) []int {
|
||||
var pages []int
|
||||
for i := 1; i <= total; i++ {
|
||||
pages = append(pages, i)
|
||||
}
|
||||
return pages
|
||||
}
|
||||
|
||||
func rangeClass(n int) string {
|
||||
switch {
|
||||
case n >= 1 && n <= 9:
|
||||
return "01-09"
|
||||
case n >= 10 && n <= 19:
|
||||
return "10-19"
|
||||
case n >= 20 && n <= 29:
|
||||
return "20-29"
|
||||
case n >= 30 && n <= 39:
|
||||
return "30-39"
|
||||
case n >= 40 && n <= 49:
|
||||
return "40-49"
|
||||
default:
|
||||
return "50-plus"
|
||||
}
|
||||
}
|
||||
39
internal/helpers/template/error.go
Normal file
39
internal/helpers/template/error.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"synlotto-website/models"
|
||||
)
|
||||
|
||||
func RenderError(w http.ResponseWriter, r *http.Request, statusCode int) {
|
||||
log.Printf("⚙️ RenderError called with status: %d", statusCode)
|
||||
|
||||
context := TemplateContext(w, r, models.TemplateData{})
|
||||
|
||||
pagePath := fmt.Sprintf("templates/error/%d.html", statusCode)
|
||||
log.Printf("📄 Checking for template file: %s", pagePath)
|
||||
|
||||
if _, err := os.Stat(pagePath); err != nil {
|
||||
log.Printf("🚫 Template file missing: %s", err)
|
||||
http.Error(w, http.StatusText(statusCode), statusCode)
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("✅ Template file found, loading...")
|
||||
|
||||
tmpl := LoadTemplateFiles(fmt.Sprintf("%d.html", statusCode), pagePath)
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
err := tmpl.ExecuteTemplate(w, "layout", context)
|
||||
if err != nil {
|
||||
log.Printf("❌ Failed to render error page layout: %v", err)
|
||||
http.Error(w, http.StatusText(statusCode), statusCode)
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("✅ Successfully rendered error page") // ToDo: log these to database
|
||||
}
|
||||
26
internal/helpers/template/pagination.go
Normal file
26
internal/helpers/template/pagination.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func GetTotalPages(db *sql.DB, tableName, whereClause string, args []interface{}, pageSize int) (totalPages, totalCount int) {
|
||||
query := "SELECT COUNT(*) FROM " + tableName + " " + whereClause
|
||||
row := db.QueryRow(query, args...)
|
||||
if err := row.Scan(&totalCount); err != nil {
|
||||
return 1, 0
|
||||
}
|
||||
totalPages = (totalCount + pageSize - 1) / pageSize
|
||||
if totalPages < 1 {
|
||||
totalPages = 1
|
||||
}
|
||||
return totalPages, totalCount
|
||||
}
|
||||
|
||||
func MakePageRange(current, total int) []int {
|
||||
var pages []int
|
||||
for i := 1; i <= total; i++ {
|
||||
pages = append(pages, i)
|
||||
}
|
||||
return pages
|
||||
}
|
||||
Reference in New Issue
Block a user