Refactor and remove sqlite and replace with MySQL
This commit is contained in:
78
internal/storage/mysql/db.go
Normal file
78
internal/storage/mysql/db.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/mysql"
|
||||
iofs "github.com/golang-migrate/migrate/v4/source/iofs"
|
||||
)
|
||||
|
||||
//go:embed migrations/*.sql
|
||||
var migrationFiles embed.FS
|
||||
|
||||
var DB *sql.DB
|
||||
|
||||
// InitDB connects to MySQL, runs migrations, and returns the DB handle.
|
||||
func InitDB() *sql.DB {
|
||||
cfg := getDSNFromEnv()
|
||||
|
||||
db, err := sql.Open("mysql", cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("❌ Failed to connect to MySQL: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
log.Fatalf("❌ MySQL not reachable: %v", err)
|
||||
}
|
||||
|
||||
if err := runMigrations(db); err != nil {
|
||||
log.Fatalf("❌ Migration failed: %v", err)
|
||||
}
|
||||
|
||||
DB = db
|
||||
return db
|
||||
}
|
||||
|
||||
// runMigrations applies any pending .sql files in migrations/
|
||||
func runMigrations(db *sql.DB) error {
|
||||
driver, err := mysql.WithInstance(db, &mysql.Config{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
src, err := iofs.New(migrationFiles, "migrations")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m, err := migrate.NewWithInstance("iofs", src, "mysql", driver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = m.Up()
|
||||
if err == migrate.ErrNoChange {
|
||||
log.Println("✅ Database schema up to date.")
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func getDSNFromEnv() string {
|
||||
user := os.Getenv("DB_USER")
|
||||
pass := os.Getenv("DB_PASS")
|
||||
host := os.Getenv("DB_HOST") // e.g. localhost or 127.0.0.1
|
||||
port := os.Getenv("DB_PORT") // e.g. 3306
|
||||
name := os.Getenv("DB_NAME") // e.g. synlotto
|
||||
params := "parseTime=true&multiStatements=true"
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?%s",
|
||||
user, pass, host, port, name, params)
|
||||
return dsn
|
||||
}
|
||||
Reference in New Issue
Block a user