package db import ( "context" "crypto/rand" "database/sql" "encoding/base64" "errors" "fmt" "strings" "time" _ "modernc.org/sqlite" ) const timeFormat = time.RFC3339Nano var ErrNotFound = errors.New("not found") var ( ErrTokenDisabled = errors.New("token disabled") ErrTokenExpired = errors.New("token expired") ErrTokenExhausted = errors.New("token exhausted") ) type Store struct { db *sql.DB } type User struct { ID int64 Username string CreatedAt time.Time } type Token struct { ID int64 UserID int64 Token string CreatedAt time.Time ExpiresAt time.Time Disabled bool DisabledAt *time.Time MaxUses int UsedCount int } type TokenWithUser struct { Token Username string } type UsageEntry struct { RequestIP string UserID sql.NullInt64 TokenID sql.NullInt64 OriginalURL string HTTPStatus int Success bool ErrorReason string OccurredAt time.Time } func NewStore(dbPath string, busyTimeoutMS int) (*Store, error) { db, err := sql.Open("sqlite", dbPath) if err != nil { return nil, err } if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil { db.Close() return nil, err } if _, err := db.Exec(fmt.Sprintf("PRAGMA busy_timeout = %d;", busyTimeoutMS)); err != nil { db.Close() return nil, err } if _, err := db.Exec("PRAGMA foreign_keys = ON;"); err != nil { db.Close() return nil, err } s := &Store{db: db} if err := s.migrate(context.Background()); err != nil { db.Close() return nil, err } return s, nil } func (s *Store) Close() error { return s.db.Close() } func (s *Store) migrate(ctx context.Context) error { stmts := []string{ `CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL UNIQUE, created_at TEXT NOT NULL );`, `CREATE TABLE IF NOT EXISTS tokens ( id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL, token TEXT NOT NULL UNIQUE, created_at TEXT NOT NULL, expires_at TEXT NOT NULL, disabled INTEGER NOT NULL DEFAULT 0, max_uses INTEGER NOT NULL DEFAULT 0, used_count INTEGER NOT NULL DEFAULT 0, disabled_at TEXT, FOREIGN KEY(user_id) REFERENCES users(id) );`, `CREATE TABLE IF NOT EXISTS site_stats ( id INTEGER PRIMARY KEY, total_accelerated_count INTEGER NOT NULL DEFAULT 0, updated_at TEXT NOT NULL );`, `CREATE TABLE IF NOT EXISTS token_usage ( id INTEGER PRIMARY KEY AUTOINCREMENT, created_at TEXT NOT NULL, request_ip TEXT NOT NULL, user_id INTEGER, token_id INTEGER, original_url TEXT NOT NULL, http_status INTEGER NOT NULL, success INTEGER NOT NULL, error_reason TEXT NOT NULL DEFAULT '', FOREIGN KEY(user_id) REFERENCES users(id), FOREIGN KEY(token_id) REFERENCES tokens(id) );`, `CREATE INDEX IF NOT EXISTS idx_tokens_token ON tokens(token);`, `CREATE INDEX IF NOT EXISTS idx_tokens_user_id ON tokens(user_id);`, `CREATE INDEX IF NOT EXISTS idx_token_usage_created_at ON token_usage(created_at);`, `CREATE INDEX IF NOT EXISTS idx_token_usage_user_id ON token_usage(user_id);`, `CREATE INDEX IF NOT EXISTS idx_token_usage_token_id ON token_usage(token_id);`, } for _, stmt := range stmts { if _, err := s.db.ExecContext(ctx, stmt); err != nil { return err } } if _, err := s.db.ExecContext(ctx, `ALTER TABLE tokens ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 0;`); err != nil && !isDuplicateColumnErr(err) { return err } if _, err := s.db.ExecContext(ctx, `ALTER TABLE tokens ADD COLUMN used_count INTEGER NOT NULL DEFAULT 0;`); err != nil && !isDuplicateColumnErr(err) { return err } _, err := s.db.ExecContext( ctx, `INSERT INTO site_stats (id, total_accelerated_count, updated_at) VALUES (1, 0, ?) ON CONFLICT(id) DO NOTHING;`, time.Now().UTC().Format(timeFormat), ) return err } func (s *Store) CreateUser(ctx context.Context, username string) (*User, error) { now := time.Now().UTC() res, err := s.db.ExecContext( ctx, "INSERT INTO users (username, created_at) VALUES (?, ?);", username, now.Format(timeFormat), ) if err != nil { return nil, err } id, err := res.LastInsertId() if err != nil { return nil, err } return &User{ID: id, Username: username, CreatedAt: now}, nil } func (s *Store) GetUserByName(ctx context.Context, username string) (*User, error) { row := s.db.QueryRowContext(ctx, "SELECT id, username, created_at FROM users WHERE username = ?;", username) var u User var createdAt string if err := row.Scan(&u.ID, &u.Username, &createdAt); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } return nil, err } parsed, err := time.Parse(timeFormat, createdAt) if err != nil { return nil, err } u.CreatedAt = parsed return &u, nil } func (s *Store) ListUsers(ctx context.Context) ([]User, error) { rows, err := s.db.QueryContext(ctx, "SELECT id, username, created_at FROM users ORDER BY id ASC;") if err != nil { return nil, err } defer rows.Close() var out []User for rows.Next() { var u User var createdAt string if err := rows.Scan(&u.ID, &u.Username, &createdAt); err != nil { return nil, err } parsed, err := time.Parse(timeFormat, createdAt) if err != nil { return nil, err } u.CreatedAt = parsed out = append(out, u) } if err := rows.Err(); err != nil { return nil, err } return out, nil } func (s *Store) IssueToken(ctx context.Context, userID int64, expiresAt time.Time, maxUses int) (*Token, error) { if maxUses < 0 { return nil, fmt.Errorf("maxUses cannot be negative") } tokenValue, err := generateToken() if err != nil { return nil, err } now := time.Now().UTC() res, err := s.db.ExecContext( ctx, `INSERT INTO tokens (user_id, token, created_at, expires_at, disabled, max_uses, used_count) VALUES (?, ?, ?, ?, 0, ?, 0);`, userID, tokenValue, now.Format(timeFormat), expiresAt.UTC().Format(timeFormat), maxUses, ) if err != nil { return nil, err } id, err := res.LastInsertId() if err != nil { return nil, err } return &Token{ID: id, UserID: userID, Token: tokenValue, CreatedAt: now, ExpiresAt: expiresAt.UTC(), MaxUses: maxUses, UsedCount: 0}, nil } func (s *Store) GetToken(ctx context.Context, tokenValue string) (*Token, error) { row := s.db.QueryRowContext( ctx, `SELECT id, user_id, token, created_at, expires_at, disabled, disabled_at, max_uses, used_count FROM tokens WHERE token = ?;`, tokenValue, ) return scanToken(row) } func (s *Store) ValidateAndConsumeToken(ctx context.Context, tokenValue string) (*Token, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return nil, err } defer func() { if err != nil { _ = tx.Rollback() } }() row := tx.QueryRowContext( ctx, `SELECT id, user_id, token, created_at, expires_at, disabled, disabled_at, max_uses, used_count FROM tokens WHERE token = ?;`, tokenValue, ) token, err := scanToken(row) if err != nil { return nil, err } now := time.Now().UTC() if token.Disabled { return nil, ErrTokenDisabled } if now.After(token.ExpiresAt) { return nil, ErrTokenExpired } if token.MaxUses > 0 && token.UsedCount >= token.MaxUses { return nil, ErrTokenExhausted } res, err := tx.ExecContext( ctx, `UPDATE tokens SET used_count = used_count + 1 WHERE id = ? AND (max_uses = 0 OR used_count < max_uses);`, token.ID, ) if err != nil { return nil, err } affected, err := res.RowsAffected() if err != nil { return nil, err } if affected == 0 { return nil, ErrTokenExhausted } if err = tx.Commit(); err != nil { return nil, err } token.UsedCount++ return token, nil } func (s *Store) DisableToken(ctx context.Context, tokenValue string) error { now := time.Now().UTC().Format(timeFormat) res, err := s.db.ExecContext( ctx, "UPDATE tokens SET disabled = 1, disabled_at = ? WHERE token = ?;", now, tokenValue, ) if err != nil { return err } affected, err := res.RowsAffected() if err != nil { return err } if affected == 0 { return ErrNotFound } return nil } func (s *Store) ListTokens(ctx context.Context) ([]TokenWithUser, error) { rows, err := s.db.QueryContext( ctx, `SELECT t.id, t.user_id, t.token, t.created_at, t.expires_at, t.disabled, t.disabled_at, t.max_uses, t.used_count, u.username FROM tokens t INNER JOIN users u ON u.id = t.user_id ORDER BY t.id ASC;`, ) if err != nil { return nil, err } defer rows.Close() var out []TokenWithUser for rows.Next() { var t TokenWithUser var createdAt string var expiresAt string var disabledAt sql.NullString var disabledInt int if err := rows.Scan(&t.ID, &t.UserID, &t.Token, &createdAt, &expiresAt, &disabledInt, &disabledAt, &t.MaxUses, &t.UsedCount, &t.Username); err != nil { return nil, err } ct, err := time.Parse(timeFormat, createdAt) if err != nil { return nil, err } et, err := time.Parse(timeFormat, expiresAt) if err != nil { return nil, err } t.CreatedAt = ct t.ExpiresAt = et t.Disabled = disabledInt == 1 if disabledAt.Valid { dt, err := time.Parse(timeFormat, disabledAt.String) if err != nil { return nil, err } t.DisabledAt = &dt } out = append(out, t) } if err := rows.Err(); err != nil { return nil, err } return out, nil } func (s *Store) RecordUsageAndMaybeIncrement(ctx context.Context, entry UsageEntry, increment bool) error { if entry.OccurredAt.IsZero() { entry.OccurredAt = time.Now().UTC() } if entry.RequestIP == "" { entry.RequestIP = "unknown" } tx, err := s.db.BeginTx(ctx, nil) if err != nil { return err } defer func() { if err != nil { _ = tx.Rollback() } }() successInt := 0 if entry.Success { successInt = 1 } _, err = tx.ExecContext( ctx, `INSERT INTO token_usage (created_at, request_ip, user_id, token_id, original_url, http_status, success, error_reason) VALUES (?, ?, ?, ?, ?, ?, ?, ?);`, entry.OccurredAt.UTC().Format(timeFormat), entry.RequestIP, entry.UserID, entry.TokenID, entry.OriginalURL, entry.HTTPStatus, successInt, entry.ErrorReason, ) if err != nil { return err } if increment { _, err = tx.ExecContext( ctx, `UPDATE site_stats SET total_accelerated_count = total_accelerated_count + 1, updated_at = ? WHERE id = 1;`, time.Now().UTC().Format(timeFormat), ) if err != nil { return err } } return tx.Commit() } func generateToken() (string, error) { buf := make([]byte, 32) if _, err := rand.Read(buf); err != nil { return "", err } return base64.RawURLEncoding.EncodeToString(buf), nil } func scanToken(scanner interface{ Scan(dest ...any) error }) (*Token, error) { var t Token var createdAt string var expiresAt string var disabledAt sql.NullString var disabledInt int if err := scanner.Scan(&t.ID, &t.UserID, &t.Token, &createdAt, &expiresAt, &disabledInt, &disabledAt, &t.MaxUses, &t.UsedCount); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } return nil, err } ct, err := time.Parse(timeFormat, createdAt) if err != nil { return nil, err } et, err := time.Parse(timeFormat, expiresAt) if err != nil { return nil, err } t.CreatedAt = ct t.ExpiresAt = et t.Disabled = disabledInt == 1 if disabledAt.Valid { dt, err := time.Parse(timeFormat, disabledAt.String) if err != nil { return nil, err } t.DisabledAt = &dt } return &t, nil } func isDuplicateColumnErr(err error) bool { return strings.Contains(strings.ToLower(err.Error()), "duplicate column name") }