@@ -0,0 +1,495 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
const timeFormat = time.RFC3339Nano
|
||||
|
||||
var ErrNotFound = errors.New("not found")
|
||||
|
||||
var (
|
||||
ErrTokenDisabled = errors.New("token disabled")
|
||||
ErrTokenExpired = errors.New("token expired")
|
||||
ErrTokenExhausted = errors.New("token exhausted")
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int64
|
||||
Username string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type Token struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
Token string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
Disabled bool
|
||||
DisabledAt *time.Time
|
||||
MaxUses int
|
||||
UsedCount int
|
||||
}
|
||||
|
||||
type TokenWithUser struct {
|
||||
Token
|
||||
Username string
|
||||
}
|
||||
|
||||
type UsageEntry struct {
|
||||
RequestIP string
|
||||
UserID sql.NullInt64
|
||||
TokenID sql.NullInt64
|
||||
OriginalURL string
|
||||
HTTPStatus int
|
||||
Success bool
|
||||
ErrorReason string
|
||||
OccurredAt time.Time
|
||||
}
|
||||
|
||||
func NewStore(dbPath string, busyTimeoutMS int) (*Store, error) {
|
||||
db, err := sql.Open("sqlite", dbPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
if _, err := db.Exec(fmt.Sprintf("PRAGMA busy_timeout = %d;", busyTimeoutMS)); err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
if _, err := db.Exec("PRAGMA foreign_keys = ON;"); err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := &Store{db: db}
|
||||
if err := s.migrate(context.Background()); err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Store) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
func (s *Store) migrate(ctx context.Context) error {
|
||||
stmts := []string{
|
||||
`CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
created_at TEXT NOT NULL
|
||||
);`,
|
||||
`CREATE TABLE IF NOT EXISTS tokens (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
created_at TEXT NOT NULL,
|
||||
expires_at TEXT NOT NULL,
|
||||
disabled INTEGER NOT NULL DEFAULT 0,
|
||||
max_uses INTEGER NOT NULL DEFAULT 0,
|
||||
used_count INTEGER NOT NULL DEFAULT 0,
|
||||
disabled_at TEXT,
|
||||
FOREIGN KEY(user_id) REFERENCES users(id)
|
||||
);`,
|
||||
`CREATE TABLE IF NOT EXISTS site_stats (
|
||||
id INTEGER PRIMARY KEY,
|
||||
total_accelerated_count INTEGER NOT NULL DEFAULT 0,
|
||||
updated_at TEXT NOT NULL
|
||||
);`,
|
||||
`CREATE TABLE IF NOT EXISTS token_usage (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
created_at TEXT NOT NULL,
|
||||
request_ip TEXT NOT NULL,
|
||||
user_id INTEGER,
|
||||
token_id INTEGER,
|
||||
original_url TEXT NOT NULL,
|
||||
http_status INTEGER NOT NULL,
|
||||
success INTEGER NOT NULL,
|
||||
error_reason TEXT NOT NULL DEFAULT '',
|
||||
FOREIGN KEY(user_id) REFERENCES users(id),
|
||||
FOREIGN KEY(token_id) REFERENCES tokens(id)
|
||||
);`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_tokens_token ON tokens(token);`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_tokens_user_id ON tokens(user_id);`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_token_usage_created_at ON token_usage(created_at);`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_token_usage_user_id ON token_usage(user_id);`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_token_usage_token_id ON token_usage(token_id);`,
|
||||
}
|
||||
|
||||
for _, stmt := range stmts {
|
||||
if _, err := s.db.ExecContext(ctx, stmt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := s.db.ExecContext(ctx, `ALTER TABLE tokens ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 0;`); err != nil && !isDuplicateColumnErr(err) {
|
||||
return err
|
||||
}
|
||||
if _, err := s.db.ExecContext(ctx, `ALTER TABLE tokens ADD COLUMN used_count INTEGER NOT NULL DEFAULT 0;`); err != nil && !isDuplicateColumnErr(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := s.db.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO site_stats (id, total_accelerated_count, updated_at)
|
||||
VALUES (1, 0, ?)
|
||||
ON CONFLICT(id) DO NOTHING;`,
|
||||
time.Now().UTC().Format(timeFormat),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) CreateUser(ctx context.Context, username string) (*User, error) {
|
||||
now := time.Now().UTC()
|
||||
res, err := s.db.ExecContext(
|
||||
ctx,
|
||||
"INSERT INTO users (username, created_at) VALUES (?, ?);",
|
||||
username,
|
||||
now.Format(timeFormat),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &User{ID: id, Username: username, CreatedAt: now}, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetUserByName(ctx context.Context, username string) (*User, error) {
|
||||
row := s.db.QueryRowContext(ctx, "SELECT id, username, created_at FROM users WHERE username = ?;", username)
|
||||
var u User
|
||||
var createdAt string
|
||||
if err := row.Scan(&u.ID, &u.Username, &createdAt); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
parsed, err := time.Parse(timeFormat, createdAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.CreatedAt = parsed
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListUsers(ctx context.Context) ([]User, error) {
|
||||
rows, err := s.db.QueryContext(ctx, "SELECT id, username, created_at FROM users ORDER BY id ASC;")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []User
|
||||
for rows.Next() {
|
||||
var u User
|
||||
var createdAt string
|
||||
if err := rows.Scan(&u.ID, &u.Username, &createdAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsed, err := time.Parse(timeFormat, createdAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.CreatedAt = parsed
|
||||
out = append(out, u)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Store) IssueToken(ctx context.Context, userID int64, expiresAt time.Time, maxUses int) (*Token, error) {
|
||||
if maxUses < 0 {
|
||||
return nil, fmt.Errorf("maxUses cannot be negative")
|
||||
}
|
||||
tokenValue, err := generateToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
res, err := s.db.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO tokens (user_id, token, created_at, expires_at, disabled, max_uses, used_count)
|
||||
VALUES (?, ?, ?, ?, 0, ?, 0);`,
|
||||
userID,
|
||||
tokenValue,
|
||||
now.Format(timeFormat),
|
||||
expiresAt.UTC().Format(timeFormat),
|
||||
maxUses,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Token{ID: id, UserID: userID, Token: tokenValue, CreatedAt: now, ExpiresAt: expiresAt.UTC(), MaxUses: maxUses, UsedCount: 0}, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetToken(ctx context.Context, tokenValue string) (*Token, error) {
|
||||
row := s.db.QueryRowContext(
|
||||
ctx,
|
||||
`SELECT id, user_id, token, created_at, expires_at, disabled, disabled_at, max_uses, used_count
|
||||
FROM tokens
|
||||
WHERE token = ?;`,
|
||||
tokenValue,
|
||||
)
|
||||
return scanToken(row)
|
||||
}
|
||||
|
||||
func (s *Store) ValidateAndConsumeToken(ctx context.Context, tokenValue string) (*Token, error) {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
row := tx.QueryRowContext(
|
||||
ctx,
|
||||
`SELECT id, user_id, token, created_at, expires_at, disabled, disabled_at, max_uses, used_count
|
||||
FROM tokens
|
||||
WHERE token = ?;`,
|
||||
tokenValue,
|
||||
)
|
||||
token, err := scanToken(row)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
if token.Disabled {
|
||||
return nil, ErrTokenDisabled
|
||||
}
|
||||
if now.After(token.ExpiresAt) {
|
||||
return nil, ErrTokenExpired
|
||||
}
|
||||
if token.MaxUses > 0 && token.UsedCount >= token.MaxUses {
|
||||
return nil, ErrTokenExhausted
|
||||
}
|
||||
|
||||
res, err := tx.ExecContext(
|
||||
ctx,
|
||||
`UPDATE tokens
|
||||
SET used_count = used_count + 1
|
||||
WHERE id = ?
|
||||
AND (max_uses = 0 OR used_count < max_uses);`,
|
||||
token.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if affected == 0 {
|
||||
return nil, ErrTokenExhausted
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token.UsedCount++
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *Store) DisableToken(ctx context.Context, tokenValue string) error {
|
||||
now := time.Now().UTC().Format(timeFormat)
|
||||
res, err := s.db.ExecContext(
|
||||
ctx,
|
||||
"UPDATE tokens SET disabled = 1, disabled_at = ? WHERE token = ?;",
|
||||
now,
|
||||
tokenValue,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) ListTokens(ctx context.Context) ([]TokenWithUser, error) {
|
||||
rows, err := s.db.QueryContext(
|
||||
ctx,
|
||||
`SELECT t.id, t.user_id, t.token, t.created_at, t.expires_at, t.disabled, t.disabled_at, t.max_uses, t.used_count, u.username
|
||||
FROM tokens t
|
||||
INNER JOIN users u ON u.id = t.user_id
|
||||
ORDER BY t.id ASC;`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []TokenWithUser
|
||||
for rows.Next() {
|
||||
var t TokenWithUser
|
||||
var createdAt string
|
||||
var expiresAt string
|
||||
var disabledAt sql.NullString
|
||||
var disabledInt int
|
||||
if err := rows.Scan(&t.ID, &t.UserID, &t.Token, &createdAt, &expiresAt, &disabledInt, &disabledAt, &t.MaxUses, &t.UsedCount, &t.Username); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ct, err := time.Parse(timeFormat, createdAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
et, err := time.Parse(timeFormat, expiresAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.CreatedAt = ct
|
||||
t.ExpiresAt = et
|
||||
t.Disabled = disabledInt == 1
|
||||
if disabledAt.Valid {
|
||||
dt, err := time.Parse(timeFormat, disabledAt.String)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.DisabledAt = &dt
|
||||
}
|
||||
out = append(out, t)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Store) RecordUsageAndMaybeIncrement(ctx context.Context, entry UsageEntry, increment bool) error {
|
||||
if entry.OccurredAt.IsZero() {
|
||||
entry.OccurredAt = time.Now().UTC()
|
||||
}
|
||||
if entry.RequestIP == "" {
|
||||
entry.RequestIP = "unknown"
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
successInt := 0
|
||||
if entry.Success {
|
||||
successInt = 1
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO token_usage (created_at, request_ip, user_id, token_id, original_url, http_status, success, error_reason)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?);`,
|
||||
entry.OccurredAt.UTC().Format(timeFormat),
|
||||
entry.RequestIP,
|
||||
entry.UserID,
|
||||
entry.TokenID,
|
||||
entry.OriginalURL,
|
||||
entry.HTTPStatus,
|
||||
successInt,
|
||||
entry.ErrorReason,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if increment {
|
||||
_, err = tx.ExecContext(
|
||||
ctx,
|
||||
`UPDATE site_stats
|
||||
SET total_accelerated_count = total_accelerated_count + 1,
|
||||
updated_at = ?
|
||||
WHERE id = 1;`,
|
||||
time.Now().UTC().Format(timeFormat),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func generateToken() (string, error) {
|
||||
buf := make([]byte, 32)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
func scanToken(scanner interface{ Scan(dest ...any) error }) (*Token, error) {
|
||||
var t Token
|
||||
var createdAt string
|
||||
var expiresAt string
|
||||
var disabledAt sql.NullString
|
||||
var disabledInt int
|
||||
if err := scanner.Scan(&t.ID, &t.UserID, &t.Token, &createdAt, &expiresAt, &disabledInt, &disabledAt, &t.MaxUses, &t.UsedCount); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
ct, err := time.Parse(timeFormat, createdAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
et, err := time.Parse(timeFormat, expiresAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.CreatedAt = ct
|
||||
t.ExpiresAt = et
|
||||
t.Disabled = disabledInt == 1
|
||||
if disabledAt.Valid {
|
||||
dt, err := time.Parse(timeFormat, disabledAt.String)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.DisabledAt = &dt
|
||||
}
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
func isDuplicateColumnErr(err error) bool {
|
||||
return strings.Contains(strings.ToLower(err.Error()), "duplicate column name")
|
||||
}
|
||||
Reference in New Issue
Block a user