f8051ea8c9
Co-authored-by: Copilot <copilot@github.com>
534 lines
12 KiB
Go
534 lines
12 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
// timeFormat is the layout used for every timestamp this package stores in a
// TEXT column; values are formatted in UTC before being written.
const timeFormat = time.RFC3339Nano

// Sentinel errors returned by Store methods. Compare them with errors.Is.
var (
	// ErrNotFound reports that the requested row does not exist.
	ErrNotFound = errors.New("not found")

	// ErrTokenDisabled reports that the token has been disabled.
	ErrTokenDisabled = errors.New("token disabled")
	// ErrTokenExpired reports that the token's expiry time has passed.
	ErrTokenExpired = errors.New("token expired")
	// ErrTokenExhausted reports that the token has reached its use limit.
	ErrTokenExhausted = errors.New("token exhausted")
)
|
|
|
|
// Store is the SQLite-backed persistence layer for users, tokens, token usage
// records and site statistics. Obtain one from NewStore and release it with
// Close.
type Store struct {
	// db is the database/sql handle (a connection pool) for the SQLite file.
	db *sql.DB
}
|
|
|
|
// User mirrors one row of the users table.
type User struct {
	ID        int64     // primary key (users.id)
	Username  string    // UNIQUE in the schema
	CreatedAt time.Time // creation instant, stored and returned in UTC
}
|
|
|
|
// Token mirrors one row of the tokens table: a random secret issued to a
// user, bounded by an expiry time and, optionally, by a number of uses.
type Token struct {
	ID         int64
	UserID     int64      // owning user (users.id)
	Token      string     // random secret produced by generateToken
	CreatedAt  time.Time  // issue instant, UTC
	ExpiresAt  time.Time  // expiry instant, UTC
	Disabled   bool       // true once DisableToken has run for this token
	DisabledAt *time.Time // nil while disabled_at is NULL
	MaxUses    int        // use limit; 0 means unlimited
	UsedCount  int        // successful consumptions so far
}
|
|
|
|
type TokenWithUser struct {
|
|
Token
|
|
Username string
|
|
}
|
|
|
|
// UsageEntry mirrors one row of the token_usage table, which records a single
// request handled with (or without) a token.
type UsageEntry struct {
	// RequestIP is the client address; an empty value is stored as "unknown".
	RequestIP string

	// UserID and TokenID map to nullable columns; they are invalid (NULL)
	// when no user or token is associated with the request.
	UserID  sql.NullInt64
	TokenID sql.NullInt64

	OriginalURL string
	HTTPStatus  int
	Success     bool
	ErrorReason string

	// OccurredAt is persisted in the created_at column (UTC). A zero value is
	// replaced with the current time when the entry is recorded.
	OccurredAt time.Time
}
|
|
|
|
func NewStore(dbPath string, busyTimeoutMS int) (*Store, error) {
|
|
db, err := sql.Open("sqlite", dbPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil {
|
|
db.Close()
|
|
return nil, err
|
|
}
|
|
if _, err := db.Exec(fmt.Sprintf("PRAGMA busy_timeout = %d;", busyTimeoutMS)); err != nil {
|
|
db.Close()
|
|
return nil, err
|
|
}
|
|
if _, err := db.Exec("PRAGMA foreign_keys = ON;"); err != nil {
|
|
db.Close()
|
|
return nil, err
|
|
}
|
|
|
|
s := &Store{db: db}
|
|
if err := s.migrate(context.Background()); err != nil {
|
|
db.Close()
|
|
return nil, err
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
func (s *Store) Close() error {
|
|
return s.db.Close()
|
|
}
|
|
|
|
func (s *Store) migrate(ctx context.Context) error {
|
|
stmts := []string{
|
|
`CREATE TABLE IF NOT EXISTS users (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
username TEXT NOT NULL UNIQUE,
|
|
created_at TEXT NOT NULL
|
|
);`,
|
|
`CREATE TABLE IF NOT EXISTS tokens (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
user_id INTEGER NOT NULL,
|
|
token TEXT NOT NULL UNIQUE,
|
|
created_at TEXT NOT NULL,
|
|
expires_at TEXT NOT NULL,
|
|
disabled INTEGER NOT NULL DEFAULT 0,
|
|
max_uses INTEGER NOT NULL DEFAULT 0,
|
|
used_count INTEGER NOT NULL DEFAULT 0,
|
|
disabled_at TEXT,
|
|
FOREIGN KEY(user_id) REFERENCES users(id)
|
|
);`,
|
|
`CREATE TABLE IF NOT EXISTS site_stats (
|
|
id INTEGER PRIMARY KEY,
|
|
total_accelerated_count INTEGER NOT NULL DEFAULT 0,
|
|
updated_at TEXT NOT NULL
|
|
);`,
|
|
`CREATE TABLE IF NOT EXISTS token_usage (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
created_at TEXT NOT NULL,
|
|
request_ip TEXT NOT NULL,
|
|
user_id INTEGER,
|
|
token_id INTEGER,
|
|
original_url TEXT NOT NULL,
|
|
http_status INTEGER NOT NULL,
|
|
success INTEGER NOT NULL,
|
|
error_reason TEXT NOT NULL DEFAULT '',
|
|
FOREIGN KEY(user_id) REFERENCES users(id),
|
|
FOREIGN KEY(token_id) REFERENCES tokens(id)
|
|
);`,
|
|
`CREATE INDEX IF NOT EXISTS idx_tokens_token ON tokens(token);`,
|
|
`CREATE INDEX IF NOT EXISTS idx_tokens_user_id ON tokens(user_id);`,
|
|
`CREATE INDEX IF NOT EXISTS idx_token_usage_created_at ON token_usage(created_at);`,
|
|
`CREATE INDEX IF NOT EXISTS idx_token_usage_user_id ON token_usage(user_id);`,
|
|
`CREATE INDEX IF NOT EXISTS idx_token_usage_token_id ON token_usage(token_id);`,
|
|
}
|
|
|
|
for _, stmt := range stmts {
|
|
if _, err := s.db.ExecContext(ctx, stmt); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if _, err := s.db.ExecContext(ctx, `ALTER TABLE tokens ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 0;`); err != nil && !isDuplicateColumnErr(err) {
|
|
return err
|
|
}
|
|
if _, err := s.db.ExecContext(ctx, `ALTER TABLE tokens ADD COLUMN used_count INTEGER NOT NULL DEFAULT 0;`); err != nil && !isDuplicateColumnErr(err) {
|
|
return err
|
|
}
|
|
|
|
_, err := s.db.ExecContext(
|
|
ctx,
|
|
`INSERT INTO site_stats (id, total_accelerated_count, updated_at)
|
|
VALUES (1, 0, ?)
|
|
ON CONFLICT(id) DO NOTHING;`,
|
|
time.Now().UTC().Format(timeFormat),
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) CreateUser(ctx context.Context, username string) (*User, error) {
|
|
now := time.Now().UTC()
|
|
res, err := s.db.ExecContext(
|
|
ctx,
|
|
"INSERT INTO users (username, created_at) VALUES (?, ?);",
|
|
username,
|
|
now.Format(timeFormat),
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
id, err := res.LastInsertId()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &User{ID: id, Username: username, CreatedAt: now}, nil
|
|
}
|
|
|
|
func (s *Store) GetUserByName(ctx context.Context, username string) (*User, error) {
|
|
row := s.db.QueryRowContext(ctx, "SELECT id, username, created_at FROM users WHERE username = ?;", username)
|
|
var u User
|
|
var createdAt string
|
|
if err := row.Scan(&u.ID, &u.Username, &createdAt); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
parsed, err := time.Parse(timeFormat, createdAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
u.CreatedAt = parsed
|
|
return &u, nil
|
|
}
|
|
|
|
func (s *Store) ListUsers(ctx context.Context) ([]User, error) {
|
|
rows, err := s.db.QueryContext(ctx, "SELECT id, username, created_at FROM users ORDER BY id ASC;")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []User
|
|
for rows.Next() {
|
|
var u User
|
|
var createdAt string
|
|
if err := rows.Scan(&u.ID, &u.Username, &createdAt); err != nil {
|
|
return nil, err
|
|
}
|
|
parsed, err := time.Parse(timeFormat, createdAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
u.CreatedAt = parsed
|
|
out = append(out, u)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (s *Store) IssueToken(ctx context.Context, userID int64, expiresAt time.Time, maxUses int) (*Token, error) {
|
|
if maxUses < 0 {
|
|
return nil, fmt.Errorf("maxUses cannot be negative")
|
|
}
|
|
tokenValue, err := generateToken()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
now := time.Now().UTC()
|
|
res, err := s.db.ExecContext(
|
|
ctx,
|
|
`INSERT INTO tokens (user_id, token, created_at, expires_at, disabled, max_uses, used_count)
|
|
VALUES (?, ?, ?, ?, 0, ?, 0);`,
|
|
userID,
|
|
tokenValue,
|
|
now.Format(timeFormat),
|
|
expiresAt.UTC().Format(timeFormat),
|
|
maxUses,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
id, err := res.LastInsertId()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &Token{ID: id, UserID: userID, Token: tokenValue, CreatedAt: now, ExpiresAt: expiresAt.UTC(), MaxUses: maxUses, UsedCount: 0}, nil
|
|
}
|
|
|
|
func (s *Store) GetToken(ctx context.Context, tokenValue string) (*Token, error) {
|
|
row := s.db.QueryRowContext(
|
|
ctx,
|
|
`SELECT id, user_id, token, created_at, expires_at, disabled, disabled_at, max_uses, used_count
|
|
FROM tokens
|
|
WHERE token = ?;`,
|
|
tokenValue,
|
|
)
|
|
return scanToken(row)
|
|
}
|
|
|
|
func (s *Store) ValidateAndConsumeToken(ctx context.Context, tokenValue string) (*Token, error) {
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() {
|
|
if err != nil {
|
|
_ = tx.Rollback()
|
|
}
|
|
}()
|
|
|
|
row := tx.QueryRowContext(
|
|
ctx,
|
|
`SELECT id, user_id, token, created_at, expires_at, disabled, disabled_at, max_uses, used_count
|
|
FROM tokens
|
|
WHERE token = ?;`,
|
|
tokenValue,
|
|
)
|
|
token, err := scanToken(row)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
if token.Disabled {
|
|
return nil, ErrTokenDisabled
|
|
}
|
|
if now.After(token.ExpiresAt) {
|
|
return nil, ErrTokenExpired
|
|
}
|
|
if token.MaxUses > 0 && token.UsedCount >= token.MaxUses {
|
|
return nil, ErrTokenExhausted
|
|
}
|
|
|
|
res, err := tx.ExecContext(
|
|
ctx,
|
|
`UPDATE tokens
|
|
SET used_count = used_count + 1
|
|
WHERE id = ?
|
|
AND (max_uses = 0 OR used_count < max_uses);`,
|
|
token.ID,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if affected == 0 {
|
|
return nil, ErrTokenExhausted
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
return nil, err
|
|
}
|
|
token.UsedCount++
|
|
return token, nil
|
|
}
|
|
|
|
func (s *Store) DisableToken(ctx context.Context, tokenValue string) error {
|
|
now := time.Now().UTC().Format(timeFormat)
|
|
res, err := s.db.ExecContext(
|
|
ctx,
|
|
"UPDATE tokens SET disabled = 1, disabled_at = ? WHERE token = ?;",
|
|
now,
|
|
tokenValue,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if affected == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) ListTokens(ctx context.Context) ([]TokenWithUser, error) {
|
|
rows, err := s.db.QueryContext(
|
|
ctx,
|
|
`SELECT t.id, t.user_id, t.token, t.created_at, t.expires_at, t.disabled, t.disabled_at, t.max_uses, t.used_count, u.username
|
|
FROM tokens t
|
|
INNER JOIN users u ON u.id = t.user_id
|
|
ORDER BY t.id ASC;`,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []TokenWithUser
|
|
for rows.Next() {
|
|
var t TokenWithUser
|
|
var createdAt string
|
|
var expiresAt string
|
|
var disabledAt sql.NullString
|
|
var disabledInt int
|
|
if err := rows.Scan(&t.ID, &t.UserID, &t.Token.Token, &createdAt, &expiresAt, &disabledInt, &disabledAt, &t.MaxUses, &t.UsedCount, &t.Username); err != nil {
|
|
return nil, err
|
|
}
|
|
ct, err := time.Parse(timeFormat, createdAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
et, err := time.Parse(timeFormat, expiresAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
t.CreatedAt = ct
|
|
t.ExpiresAt = et
|
|
t.Disabled = disabledInt == 1
|
|
if disabledAt.Valid {
|
|
dt, err := time.Parse(timeFormat, disabledAt.String)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
t.DisabledAt = &dt
|
|
}
|
|
out = append(out, t)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (s *Store) ListUsageEntries(ctx context.Context, limit int) ([]UsageEntry, error) {
|
|
query := `SELECT created_at, request_ip, user_id, token_id, original_url, http_status, success, error_reason
|
|
FROM token_usage
|
|
ORDER BY id DESC`
|
|
args := []any{}
|
|
if limit > 0 {
|
|
query += ` LIMIT ?`
|
|
args = append(args, limit)
|
|
}
|
|
|
|
rows, err := s.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []UsageEntry
|
|
for rows.Next() {
|
|
var e UsageEntry
|
|
var createdAt string
|
|
var successInt int
|
|
if err := rows.Scan(&createdAt, &e.RequestIP, &e.UserID, &e.TokenID, &e.OriginalURL, &e.HTTPStatus, &successInt, &e.ErrorReason); err != nil {
|
|
return nil, err
|
|
}
|
|
parsed, err := time.Parse(timeFormat, createdAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
e.OccurredAt = parsed
|
|
e.Success = successInt == 1
|
|
out = append(out, e)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (s *Store) RecordUsageAndMaybeIncrement(ctx context.Context, entry UsageEntry, increment bool) error {
|
|
if entry.OccurredAt.IsZero() {
|
|
entry.OccurredAt = time.Now().UTC()
|
|
}
|
|
if entry.RequestIP == "" {
|
|
entry.RequestIP = "unknown"
|
|
}
|
|
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
if err != nil {
|
|
_ = tx.Rollback()
|
|
}
|
|
}()
|
|
|
|
successInt := 0
|
|
if entry.Success {
|
|
successInt = 1
|
|
}
|
|
|
|
_, err = tx.ExecContext(
|
|
ctx,
|
|
`INSERT INTO token_usage (created_at, request_ip, user_id, token_id, original_url, http_status, success, error_reason)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?);`,
|
|
entry.OccurredAt.UTC().Format(timeFormat),
|
|
entry.RequestIP,
|
|
entry.UserID,
|
|
entry.TokenID,
|
|
entry.OriginalURL,
|
|
entry.HTTPStatus,
|
|
successInt,
|
|
entry.ErrorReason,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if increment {
|
|
_, err = tx.ExecContext(
|
|
ctx,
|
|
`UPDATE site_stats
|
|
SET total_accelerated_count = total_accelerated_count + 1,
|
|
updated_at = ?
|
|
WHERE id = 1;`,
|
|
time.Now().UTC().Format(timeFormat),
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// generateToken returns a fresh secret: 32 bytes from crypto/rand encoded as
// unpadded URL-safe base64, i.e. 43 characters from [A-Za-z0-9_-].
func generateToken() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
}
|
|
|
|
func scanToken(scanner interface{ Scan(dest ...any) error }) (*Token, error) {
|
|
var t Token
|
|
var createdAt string
|
|
var expiresAt string
|
|
var disabledAt sql.NullString
|
|
var disabledInt int
|
|
if err := scanner.Scan(&t.ID, &t.UserID, &t.Token, &createdAt, &expiresAt, &disabledInt, &disabledAt, &t.MaxUses, &t.UsedCount); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
ct, err := time.Parse(timeFormat, createdAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
et, err := time.Parse(timeFormat, expiresAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
t.CreatedAt = ct
|
|
t.ExpiresAt = et
|
|
t.Disabled = disabledInt == 1
|
|
if disabledAt.Valid {
|
|
dt, err := time.Parse(timeFormat, disabledAt.String)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
t.DisabledAt = &dt
|
|
}
|
|
return &t, nil
|
|
}
|
|
|
|
// isDuplicateColumnErr reports whether err is SQLite's "duplicate column name"
// error, which ALTER TABLE ... ADD COLUMN returns when the column already
// exists. The match is a case-insensitive substring test on the error text. A
// nil error is never a duplicate-column error; previously a nil err would
// have panicked on err.Error().
func isDuplicateColumnErr(err error) bool {
	if err == nil {
		return false
	}
	return strings.Contains(strings.ToLower(err.Error()), "duplicate column name")
}
|