Files
Hugs-Proxy/internal/db/store.go
T
gary 62e076111f 初步实现鉴权
Co-authored-by: Copilot <copilot@github.com>
2026-04-23 20:21:35 +08:00

496 lines
11 KiB
Go

package db
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"errors"
"fmt"
"strings"
"time"
_ "modernc.org/sqlite"
)
// timeFormat is the layout used to encode every timestamp that the store
// writes to, and parses from, TEXT columns.
const timeFormat = time.RFC3339Nano

// Sentinel errors returned by Store methods; compare them with errors.Is.
var (
	// ErrNotFound is returned when a looked-up user or token does not exist.
	ErrNotFound = errors.New("not found")

	// ErrTokenDisabled is returned when a token has been disabled.
	ErrTokenDisabled = errors.New("token disabled")

	// ErrTokenExpired is returned when a token's expiry time has passed.
	ErrTokenExpired = errors.New("token expired")

	// ErrTokenExhausted is returned when a token has no remaining uses.
	ErrTokenExhausted = errors.New("token exhausted")
)
// Store persists users, tokens, token usage records and site-wide counters in
// a SQLite database. Construct it with NewStore and release it with Close.
type Store struct {
	db *sql.DB // pooled database handle opened by NewStore
}
// User is one row of the users table.
type User struct {
	ID        int64
	Username  string    // unique across users (UNIQUE constraint in the schema)
	CreatedAt time.Time // creation time; CreateUser records it in UTC
}
// Token is one row of the tokens table: an access token issued to a user.
type Token struct {
	ID         int64
	UserID     int64      // owning user (users.id)
	Token      string     // random value produced by generateToken; unique
	CreatedAt  time.Time  // issue time, stored in UTC
	ExpiresAt  time.Time  // validation fails once the current time is after this
	Disabled   bool       // set by DisableToken
	DisabledAt *time.Time // nil when disabled_at is NULL
	MaxUses    int        // maximum number of consumptions; 0 means unlimited
	UsedCount  int        // number of successful consumptions so far
}
// TokenWithUser is a Token together with its owner's username, as returned by
// ListTokens.
type TokenWithUser struct {
	Token
	Username string // users.username of Token.UserID
}
// UsageEntry is one request record to be written to the token_usage table by
// RecordUsageAndMaybeIncrement.
type UsageEntry struct {
	RequestIP   string        // client address; stored as "unknown" when empty
	UserID      sql.NullInt64 // stored as NULL when not Valid
	TokenID     sql.NullInt64 // stored as NULL when not Valid
	OriginalURL string
	HTTPStatus  int
	Success     bool   // stored as 1/0
	ErrorReason string // stored as given
	OccurredAt  time.Time // replaced with the current UTC time when zero
}
// NewStore opens (creating if necessary) the SQLite database at dbPath,
// configures it and applies the schema migrations.
//
// busyTimeoutMS is how long, in milliseconds, SQLite waits on a locked
// database before giving up with SQLITE_BUSY.
//
// busy_timeout and foreign_keys are per-connection settings in SQLite, while
// *sql.DB is a pool that opens further connections on demand. Running them
// once with db.Exec configures only whichever connection served that call.
// They are therefore passed as modernc.org/sqlite "_pragma" DSN parameters,
// which the driver applies to every connection it opens.
func NewStore(dbPath string, busyTimeoutMS int) (*Store, error) {
	db, err := sql.Open("sqlite", withConnPragmas(dbPath, busyTimeoutMS))
	if err != nil {
		return nil, err
	}
	// journal_mode=WAL is persisted in the database file itself, so setting
	// it once on any connection is sufficient.
	if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil {
		_ = db.Close()
		return nil, err
	}
	s := &Store{db: db}
	if err := s.migrate(context.Background()); err != nil {
		_ = db.Close()
		return nil, err
	}
	return s, nil
}

// withConnPragmas appends modernc.org/sqlite "_pragma" query parameters that
// set busy_timeout and enable foreign-key enforcement on every new connection.
func withConnPragmas(dbPath string, busyTimeoutMS int) string {
	sep := "?"
	if strings.Contains(dbPath, "?") {
		// dbPath already carries query parameters; extend them.
		sep = "&"
	}
	return fmt.Sprintf("%s%s_pragma=busy_timeout(%d)&_pragma=foreign_keys(1)", dbPath, sep, busyTimeoutMS)
}
// Close closes the underlying database handle; no new queries can be started
// through the Store afterwards.
func (s *Store) Close() error {
	handle := s.db
	return handle.Close()
}
// migrate creates the schema when it is missing and upgrades older databases.
// Every step is idempotent (CREATE ... IF NOT EXISTS, ALTER TABLE with the
// duplicate-column error tolerated, INSERT ... ON CONFLICT DO NOTHING), so it
// is safe to run on every startup.
func (s *Store) migrate(ctx context.Context) error {
	// NOTE: tokens.token is already UNIQUE, which SQLite backs with its own
	// index, so idx_tokens_token is redundant but harmless.
	schema := []string{
		`CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
created_at TEXT NOT NULL
);`,
		`CREATE TABLE IF NOT EXISTS tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
token TEXT NOT NULL UNIQUE,
created_at TEXT NOT NULL,
expires_at TEXT NOT NULL,
disabled INTEGER NOT NULL DEFAULT 0,
max_uses INTEGER NOT NULL DEFAULT 0,
used_count INTEGER NOT NULL DEFAULT 0,
disabled_at TEXT,
FOREIGN KEY(user_id) REFERENCES users(id)
);`,
		`CREATE TABLE IF NOT EXISTS site_stats (
id INTEGER PRIMARY KEY,
total_accelerated_count INTEGER NOT NULL DEFAULT 0,
updated_at TEXT NOT NULL
);`,
		`CREATE TABLE IF NOT EXISTS token_usage (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_at TEXT NOT NULL,
request_ip TEXT NOT NULL,
user_id INTEGER,
token_id INTEGER,
original_url TEXT NOT NULL,
http_status INTEGER NOT NULL,
success INTEGER NOT NULL,
error_reason TEXT NOT NULL DEFAULT '',
FOREIGN KEY(user_id) REFERENCES users(id),
FOREIGN KEY(token_id) REFERENCES tokens(id)
);`,
		`CREATE INDEX IF NOT EXISTS idx_tokens_token ON tokens(token);`,
		`CREATE INDEX IF NOT EXISTS idx_tokens_user_id ON tokens(user_id);`,
		`CREATE INDEX IF NOT EXISTS idx_token_usage_created_at ON token_usage(created_at);`,
		`CREATE INDEX IF NOT EXISTS idx_token_usage_user_id ON token_usage(user_id);`,
		`CREATE INDEX IF NOT EXISTS idx_token_usage_token_id ON token_usage(token_id);`,
	}
	for i := range schema {
		if _, err := s.db.ExecContext(ctx, schema[i]); err != nil {
			return err
		}
	}

	// Databases created before these columns existed need them added. On a
	// database that already has them, SQLite rejects the ALTER with a
	// "duplicate column name" error, which is expected and skipped.
	addedColumns := []string{
		`ALTER TABLE tokens ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 0;`,
		`ALTER TABLE tokens ADD COLUMN used_count INTEGER NOT NULL DEFAULT 0;`,
	}
	for _, alter := range addedColumns {
		_, err := s.db.ExecContext(ctx, alter)
		if err == nil || isDuplicateColumnErr(err) {
			continue
		}
		return err
	}

	// Seed the single site_stats row (id = 1) that usage recording updates.
	seededAt := time.Now().UTC().Format(timeFormat)
	if _, err := s.db.ExecContext(
		ctx,
		`INSERT INTO site_stats (id, total_accelerated_count, updated_at)
VALUES (1, 0, ?)
ON CONFLICT(id) DO NOTHING;`,
		seededAt,
	); err != nil {
		return err
	}
	return nil
}
// CreateUser inserts a user called username, stamped with the current UTC
// time, and returns the stored record. The database enforces that usernames
// are unique; a duplicate makes the INSERT fail and that error is returned.
func (s *Store) CreateUser(ctx context.Context, username string) (*User, error) {
	createdAt := time.Now().UTC()
	const insertUser = "INSERT INTO users (username, created_at) VALUES (?, ?);"

	result, err := s.db.ExecContext(ctx, insertUser, username, createdAt.Format(timeFormat))
	if err != nil {
		return nil, err
	}
	newID, err := result.LastInsertId()
	if err != nil {
		return nil, err
	}

	user := &User{ID: newID, Username: username, CreatedAt: createdAt}
	return user, nil
}
// GetUserByName fetches the user whose username equals username exactly.
// It returns ErrNotFound when there is no such user.
func (s *Store) GetUserByName(ctx context.Context, username string) (*User, error) {
	var (
		user       User
		rawCreated string
	)
	err := s.db.QueryRowContext(ctx, "SELECT id, username, created_at FROM users WHERE username = ?;", username).
		Scan(&user.ID, &user.Username, &rawCreated)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, err
	}

	created, err := time.Parse(timeFormat, rawCreated)
	if err != nil {
		return nil, err
	}
	user.CreatedAt = created
	return &user, nil
}
// ListUsers returns all users ordered by ascending id. The slice is nil when
// the table is empty.
func (s *Store) ListUsers(ctx context.Context) ([]User, error) {
	rows, err := s.db.QueryContext(ctx, "SELECT id, username, created_at FROM users ORDER BY id ASC;")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var users []User
	for rows.Next() {
		var (
			u   User
			raw string
		)
		if scanErr := rows.Scan(&u.ID, &u.Username, &raw); scanErr != nil {
			return nil, scanErr
		}
		created, parseErr := time.Parse(timeFormat, raw)
		if parseErr != nil {
			return nil, parseErr
		}
		u.CreatedAt = created
		users = append(users, u)
	}

	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return users, nil
}
// IssueToken creates a new random token for userID. The token expires at
// expiresAt (normalized to UTC) and may be consumed at most maxUses times;
// maxUses == 0 means unlimited and a negative value is rejected.
func (s *Store) IssueToken(ctx context.Context, userID int64, expiresAt time.Time, maxUses int) (*Token, error) {
	if maxUses < 0 {
		return nil, errors.New("maxUses cannot be negative")
	}

	secret, err := generateToken()
	if err != nil {
		return nil, err
	}

	created := time.Now().UTC()
	expires := expiresAt.UTC()
	result, err := s.db.ExecContext(
		ctx,
		`INSERT INTO tokens (user_id, token, created_at, expires_at, disabled, max_uses, used_count)
VALUES (?, ?, ?, ?, 0, ?, 0);`,
		userID,
		secret,
		created.Format(timeFormat),
		expires.Format(timeFormat),
		maxUses,
	)
	if err != nil {
		return nil, err
	}
	newID, err := result.LastInsertId()
	if err != nil {
		return nil, err
	}

	issued := &Token{
		ID:        newID,
		UserID:    userID,
		Token:     secret,
		CreatedAt: created,
		ExpiresAt: expires,
		MaxUses:   maxUses,
		UsedCount: 0,
	}
	return issued, nil
}
// GetToken loads the token whose value is tokenValue without consuming it or
// checking its validity. It returns ErrNotFound for an unknown value.
func (s *Store) GetToken(ctx context.Context, tokenValue string) (*Token, error) {
	const selectToken = `SELECT id, user_id, token, created_at, expires_at, disabled, disabled_at, max_uses, used_count
FROM tokens
WHERE token = ?;`
	return scanToken(s.db.QueryRowContext(ctx, selectToken, tokenValue))
}
// ValidateAndConsumeToken checks that tokenValue names a token that is not
// disabled, not expired and not exhausted. If all checks pass, it increments
// the token's used_count within a single transaction. The returned Token
// carries the incremented UsedCount.
//
// It returns ErrNotFound for an unknown value, ErrTokenDisabled,
// ErrTokenExpired or ErrTokenExhausted when the corresponding check fails,
// or the underlying database error.
func (s *Store) ValidateAndConsumeToken(ctx context.Context, tokenValue string) (*Token, error) {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	// Roll back unconditionally. Several return paths below hand back a
	// sentinel error without assigning err. A rollback guarded by
	// "err != nil" would therefore leave the transaction, its pooled
	// connection, and any write lock taken by the UPDATE open. After a
	// successful Commit, Rollback only returns sql.ErrTxDone, which is
	// deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	row := tx.QueryRowContext(
		ctx,
		`SELECT id, user_id, token, created_at, expires_at, disabled, disabled_at, max_uses, used_count
FROM tokens
WHERE token = ?;`,
		tokenValue,
	)
	token, err := scanToken(row)
	if err != nil {
		return nil, err
	}

	now := time.Now().UTC()
	switch {
	case token.Disabled:
		return nil, ErrTokenDisabled
	case now.After(token.ExpiresAt):
		return nil, ErrTokenExpired
	case token.MaxUses > 0 && token.UsedCount >= token.MaxUses:
		return nil, ErrTokenExhausted
	}

	// The WHERE clause re-checks the quota inside the UPDATE itself, so the
	// increment can never push used_count past a non-zero max_uses.
	res, err := tx.ExecContext(
		ctx,
		`UPDATE tokens
SET used_count = used_count + 1
WHERE id = ?
AND (max_uses = 0 OR used_count < max_uses);`,
		token.ID,
	)
	if err != nil {
		return nil, err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return nil, err
	}
	if affected == 0 {
		return nil, ErrTokenExhausted
	}
	if err := tx.Commit(); err != nil {
		return nil, err
	}
	token.UsedCount++
	return token, nil
}
// DisableToken marks the token with value tokenValue as disabled. It records
// the current UTC time in disabled_at, overwriting any earlier value. It
// returns ErrNotFound when no token has that value.
func (s *Store) DisableToken(ctx context.Context, tokenValue string) error {
	disabledAt := time.Now().UTC().Format(timeFormat)
	result, err := s.db.ExecContext(ctx, "UPDATE tokens SET disabled = 1, disabled_at = ? WHERE token = ?;", disabledAt, tokenValue)
	if err != nil {
		return err
	}

	changed, err := result.RowsAffected()
	switch {
	case err != nil:
		return err
	case changed == 0:
		return ErrNotFound
	default:
		return nil
	}
}
// ListTokens returns every token joined with its owner's username, ordered by
// ascending token id. The slice is nil when there are no tokens.
func (s *Store) ListTokens(ctx context.Context) ([]TokenWithUser, error) {
	const listQuery = `SELECT t.id, t.user_id, t.token, t.created_at, t.expires_at, t.disabled, t.disabled_at, t.max_uses, t.used_count, u.username
FROM tokens t
INNER JOIN users u ON u.id = t.user_id
ORDER BY t.id ASC;`

	rows, err := s.db.QueryContext(ctx, listQuery)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []TokenWithUser
	for rows.Next() {
		var (
			item         TokenWithUser
			created      string
			expires      string
			disabledFlag int
			disabledWhen sql.NullString
		)
		if scanErr := rows.Scan(
			&item.ID, &item.UserID, &item.Token, &created, &expires,
			&disabledFlag, &disabledWhen, &item.MaxUses, &item.UsedCount, &item.Username,
		); scanErr != nil {
			return nil, scanErr
		}

		createdAt, parseErr := time.Parse(timeFormat, created)
		if parseErr != nil {
			return nil, parseErr
		}
		expiresAt, parseErr := time.Parse(timeFormat, expires)
		if parseErr != nil {
			return nil, parseErr
		}
		item.CreatedAt = createdAt
		item.ExpiresAt = expiresAt
		item.Disabled = disabledFlag == 1

		if disabledWhen.Valid {
			when, parseErr := time.Parse(timeFormat, disabledWhen.String)
			if parseErr != nil {
				return nil, parseErr
			}
			item.DisabledAt = &when
		}
		result = append(result, item)
	}

	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return result, nil
}
// RecordUsageAndMaybeIncrement appends entry to the token_usage table. When
// increment is true, it also bumps site_stats.total_accelerated_count (row
// id = 1) in the same transaction, so both writes commit or fail together.
//
// Before writing, a zero OccurredAt is replaced with the current UTC time and
// an empty RequestIP with "unknown".
func (s *Store) RecordUsageAndMaybeIncrement(ctx context.Context, entry UsageEntry, increment bool) error {
	if entry.OccurredAt.IsZero() {
		entry.OccurredAt = time.Now().UTC()
	}
	if entry.RequestIP == "" {
		entry.RequestIP = "unknown"
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Every failure path returns before Commit; once Commit has run,
	// Rollback is a no-op that reports sql.ErrTxDone, which is ignored.
	defer func() { _ = tx.Rollback() }()

	var successFlag int
	if entry.Success {
		successFlag = 1
	}

	if _, err := tx.ExecContext(
		ctx,
		`INSERT INTO token_usage (created_at, request_ip, user_id, token_id, original_url, http_status, success, error_reason)
VALUES (?, ?, ?, ?, ?, ?, ?, ?);`,
		entry.OccurredAt.UTC().Format(timeFormat),
		entry.RequestIP,
		entry.UserID,
		entry.TokenID,
		entry.OriginalURL,
		entry.HTTPStatus,
		successFlag,
		entry.ErrorReason,
	); err != nil {
		return err
	}

	if !increment {
		return tx.Commit()
	}
	if _, err := tx.ExecContext(
		ctx,
		`UPDATE site_stats
SET total_accelerated_count = total_accelerated_count + 1,
updated_at = ?
WHERE id = 1;`,
		time.Now().UTC().Format(timeFormat),
	); err != nil {
		return err
	}
	return tx.Commit()
}
// generateToken returns a new token value: 32 bytes read from crypto/rand,
// encoded as unpadded URL-safe base64 (43 characters).
func generateToken() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	encoded := base64.RawURLEncoding.EncodeToString(raw[:])
	return encoded, nil
}
// scanToken reads one tokens row from scanner (callers in this file pass a
// *sql.Row). The columns must be id, user_id, token, created_at, expires_at,
// disabled, disabled_at, max_uses and used_count, in that order.
// sql.ErrNoRows is translated to ErrNotFound.
func scanToken(scanner interface{ Scan(dest ...any) error }) (*Token, error) {
	var (
		tok        Token
		created    string
		expires    string
		disabled   int
		disabledAt sql.NullString
	)
	err := scanner.Scan(&tok.ID, &tok.UserID, &tok.Token, &created, &expires, &disabled, &disabledAt, &tok.MaxUses, &tok.UsedCount)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	if err != nil {
		return nil, err
	}

	createdAt, err := time.Parse(timeFormat, created)
	if err != nil {
		return nil, err
	}
	expiresAt, err := time.Parse(timeFormat, expires)
	if err != nil {
		return nil, err
	}
	tok.CreatedAt = createdAt
	tok.ExpiresAt = expiresAt
	tok.Disabled = disabled == 1

	if !disabledAt.Valid {
		return &tok, nil
	}
	at, err := time.Parse(timeFormat, disabledAt.String)
	if err != nil {
		return nil, err
	}
	tok.DisabledAt = &at
	return &tok, nil
}
// isDuplicateColumnErr reports whether err's message contains SQLite's
// "duplicate column name" text (case-insensitively). That is the error
// ALTER TABLE ... ADD COLUMN produces when the column already exists.
// err must be non-nil.
func isDuplicateColumnErr(err error) bool {
	const marker = "duplicate column name"
	message := strings.ToLower(err.Error())
	return strings.Contains(message, marker)
}