83 lines
2.1 KiB
Go
83 lines
2.1 KiB
Go
package main
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
|
|
_ "github.com/lib/pq"
|
|
)
|
|
|
|
// db is the package-level database handle. It is nil until InitDB succeeds
// and is what GetDB returns.
var db *sql.DB

// InitDB opens the PostgreSQL database and, once the connection has been
// verified, stores the handle in the package-level db variable.
//
// The connection string is taken from the DATABASE_URL environment variable.
// When that variable is unset or empty, a local-development default is used.
//
// It returns a wrapped error if the handle cannot be opened or if the
// database does not answer a ping. On failure the package-level handle is
// left exactly as it was before the call and the handle opened here is closed.
func InitDB() error {
	dbURL := os.Getenv("DATABASE_URL")
	if dbURL == "" {
		// Default for local development only: fixed credentials, TLS disabled.
		dbURL = "postgres://postgres:password@localhost:5432/counter_db?sslmode=disable"
	}

	// Work on a local handle so the package-level db is only replaced once the
	// connection is known to be usable.
	conn, err := sql.Open("postgres", dbURL)
	if err != nil {
		return fmt.Errorf("failed to open database: %w", err)
	}

	// sql.Open may only validate its arguments; Ping forces a real connection.
	if err := conn.Ping(); err != nil {
		// Release the handle, otherwise it leaks. The Close error is dropped on
		// purpose: the ping failure is the error worth reporting.
		_ = conn.Close()
		return fmt.Errorf("failed to ping database: %w", err)
	}

	db = conn
	log.Println("Database connection established")
	return nil
}
|
|
|
|
// CreateTables creates the necessary database tables
|
|
func CreateTables() error {
|
|
queries := []string{
|
|
`CREATE TABLE IF NOT EXISTS users (
|
|
id SERIAL PRIMARY KEY,
|
|
username VARCHAR(50) UNIQUE NOT NULL,
|
|
email VARCHAR(255) UNIQUE NOT NULL,
|
|
password VARCHAR(255) NOT NULL,
|
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
)`,
|
|
`CREATE TABLE IF NOT EXISTS counters (
|
|
id SERIAL PRIMARY KEY,
|
|
user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
|
|
name VARCHAR(100) NOT NULL,
|
|
description TEXT,
|
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
)`,
|
|
`CREATE TABLE IF NOT EXISTS counter_entries (
|
|
id SERIAL PRIMARY KEY,
|
|
counter_id INTEGER REFERENCES counters(id) ON DELETE CASCADE,
|
|
value INTEGER NOT NULL,
|
|
date DATE NOT NULL,
|
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
)`,
|
|
`CREATE INDEX IF NOT EXISTS idx_counters_user_id ON counters(user_id)`,
|
|
`CREATE INDEX IF NOT EXISTS idx_counter_entries_counter_id ON counter_entries(counter_id)`,
|
|
`CREATE INDEX IF NOT EXISTS idx_counter_entries_date ON counter_entries(date)`,
|
|
}
|
|
|
|
for _, query := range queries {
|
|
if _, err := db.Exec(query); err != nil {
|
|
return fmt.Errorf("failed to execute query: %w", err)
|
|
}
|
|
}
|
|
|
|
log.Println("Database tables created successfully")
|
|
return nil
|
|
}
|
|
|
|
// GetDB returns the database connection
|
|
func GetDB() *sql.DB {
|
|
return db
|
|
}
|