初步实现鉴权

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
2026-04-23 20:21:35 +08:00
parent 847ce0c6a8
commit 62e076111f
10 changed files with 1212 additions and 87 deletions
+51
View File
@@ -0,0 +1,51 @@
package audit
import (
"context"
"database/sql"
"strings"
"time"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/db"
)
type Entry struct {
RequestIP string
UserID int64
HasUser bool
TokenID int64
HasToken bool
OriginalURL string
HTTPStatus int
Success bool
ErrorReason string
CountAsSuccess bool
OccurredAt time.Time
}
type Logger struct {
store *db.Store
}
func NewLogger(store *db.Store) *Logger {
return &Logger{store: store}
}
func (l *Logger) Log(ctx context.Context, entry Entry) error {
usage := db.UsageEntry{
RequestIP: strings.TrimSpace(entry.RequestIP),
OriginalURL: strings.TrimSpace(entry.OriginalURL),
HTTPStatus: entry.HTTPStatus,
Success: entry.Success,
ErrorReason: strings.TrimSpace(entry.ErrorReason),
OccurredAt: entry.OccurredAt,
}
if entry.HasUser {
usage.UserID = sql.NullInt64{Int64: entry.UserID, Valid: true}
}
if entry.HasToken {
usage.TokenID = sql.NullInt64{Int64: entry.TokenID, Valid: true}
}
return l.store.RecordUsageAndMaybeIncrement(ctx, usage, entry.CountAsSuccess)
}
+90
View File
@@ -0,0 +1,90 @@
package auth
import (
"context"
"errors"
"net/http"
"strings"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/db"
)
type contextKey string
const authInfoKey contextKey = "authInfo"
type AuthInfo struct {
UserID int64
TokenID int64
Token string
}
type FailureRecorder func(r *http.Request, token string, statusCode int, reason string, userID int64, tokenID int64)
func Middleware(store *db.Store, onFailure FailureRecorder, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenValue, ok := bearerToken(r.Header.Get("Authorization"))
if !ok {
http.Error(w, "Missing bearer token.", http.StatusUnauthorized)
if onFailure != nil {
onFailure(r, "", http.StatusUnauthorized, "missing_token", 0, 0)
}
return
}
token, err := store.ValidateAndConsumeToken(r.Context(), tokenValue)
if err != nil {
status := http.StatusForbidden
reason := "invalid_token"
switch {
case errors.Is(err, db.ErrNotFound):
reason = "invalid_token"
case errors.Is(err, db.ErrTokenDisabled):
reason = "token_disabled"
case errors.Is(err, db.ErrTokenExpired):
reason = "token_expired"
case errors.Is(err, db.ErrTokenExhausted):
reason = "token_exhausted"
default:
status = http.StatusInternalServerError
reason = "token_lookup_error"
}
http.Error(w, http.StatusText(status), status)
if onFailure != nil {
if token != nil {
onFailure(r, tokenValue, status, reason, token.UserID, token.ID)
} else {
onFailure(r, tokenValue, status, reason, 0, 0)
}
}
return
}
authInfo := AuthInfo{UserID: token.UserID, TokenID: token.ID, Token: tokenValue}
ctx := context.WithValue(r.Context(), authInfoKey, authInfo)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func FromContext(ctx context.Context) (AuthInfo, bool) {
v := ctx.Value(authInfoKey)
if v == nil {
return AuthInfo{}, false
}
info, ok := v.(AuthInfo)
return info, ok
}
func bearerToken(authorizationHeader string) (string, bool) {
parts := strings.Fields(strings.TrimSpace(authorizationHeader))
if len(parts) != 2 {
return "", false
}
if !strings.EqualFold(parts[0], "Bearer") {
return "", false
}
if parts[1] == "" {
return "", false
}
return parts[1], true
}
+78
View File
@@ -0,0 +1,78 @@
package config
import (
"os"
"strconv"
"strings"
"time"
)
const (
defaultHost = "127.0.0.1"
defaultPort = "2005"
defaultDBPath = "./hugs_proxy.db"
defaultTokenTTLString = "30d"
defaultBusyTimeoutMilli = 5000
)
type Config struct {
Host string
Port string
DBPath string
DefaultTokenTTL time.Duration
BusyTimeoutMS int
}
func Load() Config {
host := getenvOrDefault("HUGS_PROXY_HOST", defaultHost)
port := getenvOrDefault("HUGS_PROXY_PORT", defaultPort)
dbPath := getenvOrDefault("HUGS_PROXY_DB_PATH", defaultDBPath)
ttlRaw := getenvOrDefault("HUGS_PROXY_DEFAULT_TOKEN_TTL", defaultTokenTTLString)
busyTimeout := getenvIntOrDefault("HUGS_PROXY_DB_BUSY_TIMEOUT_MS", defaultBusyTimeoutMilli)
ttl, err := ParseExpiryDuration(ttlRaw)
if err != nil {
ttl, _ = ParseExpiryDuration(defaultTokenTTLString)
}
return Config{
Host: host,
Port: port,
DBPath: dbPath,
DefaultTokenTTL: ttl,
BusyTimeoutMS: busyTimeout,
}
}
func ParseExpiryDuration(s string) (time.Duration, error) {
s = strings.TrimSpace(strings.ToLower(s))
if strings.HasSuffix(s, "d") {
daysPart := strings.TrimSuffix(s, "d")
days, err := strconv.Atoi(daysPart)
if err != nil {
return 0, err
}
return time.Duration(days) * 24 * time.Hour, nil
}
return time.ParseDuration(s)
}
func getenvOrDefault(key string, fallback string) string {
v := strings.TrimSpace(os.Getenv(key))
if v == "" {
return fallback
}
return v
}
func getenvIntOrDefault(key string, fallback int) int {
v := strings.TrimSpace(os.Getenv(key))
if v == "" {
return fallback
}
n, err := strconv.Atoi(v)
if err != nil {
return fallback
}
return n
}
+495
View File
@@ -0,0 +1,495 @@
package db
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"errors"
"fmt"
"strings"
"time"
_ "modernc.org/sqlite"
)
const timeFormat = time.RFC3339Nano
var ErrNotFound = errors.New("not found")
var (
ErrTokenDisabled = errors.New("token disabled")
ErrTokenExpired = errors.New("token expired")
ErrTokenExhausted = errors.New("token exhausted")
)
type Store struct {
db *sql.DB
}
type User struct {
ID int64
Username string
CreatedAt time.Time
}
type Token struct {
ID int64
UserID int64
Token string
CreatedAt time.Time
ExpiresAt time.Time
Disabled bool
DisabledAt *time.Time
MaxUses int
UsedCount int
}
type TokenWithUser struct {
Token
Username string
}
type UsageEntry struct {
RequestIP string
UserID sql.NullInt64
TokenID sql.NullInt64
OriginalURL string
HTTPStatus int
Success bool
ErrorReason string
OccurredAt time.Time
}
func NewStore(dbPath string, busyTimeoutMS int) (*Store, error) {
db, err := sql.Open("sqlite", dbPath)
if err != nil {
return nil, err
}
if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil {
db.Close()
return nil, err
}
if _, err := db.Exec(fmt.Sprintf("PRAGMA busy_timeout = %d;", busyTimeoutMS)); err != nil {
db.Close()
return nil, err
}
if _, err := db.Exec("PRAGMA foreign_keys = ON;"); err != nil {
db.Close()
return nil, err
}
s := &Store{db: db}
if err := s.migrate(context.Background()); err != nil {
db.Close()
return nil, err
}
return s, nil
}
func (s *Store) Close() error {
return s.db.Close()
}
func (s *Store) migrate(ctx context.Context) error {
stmts := []string{
`CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
created_at TEXT NOT NULL
);`,
`CREATE TABLE IF NOT EXISTS tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
token TEXT NOT NULL UNIQUE,
created_at TEXT NOT NULL,
expires_at TEXT NOT NULL,
disabled INTEGER NOT NULL DEFAULT 0,
max_uses INTEGER NOT NULL DEFAULT 0,
used_count INTEGER NOT NULL DEFAULT 0,
disabled_at TEXT,
FOREIGN KEY(user_id) REFERENCES users(id)
);`,
`CREATE TABLE IF NOT EXISTS site_stats (
id INTEGER PRIMARY KEY,
total_accelerated_count INTEGER NOT NULL DEFAULT 0,
updated_at TEXT NOT NULL
);`,
`CREATE TABLE IF NOT EXISTS token_usage (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_at TEXT NOT NULL,
request_ip TEXT NOT NULL,
user_id INTEGER,
token_id INTEGER,
original_url TEXT NOT NULL,
http_status INTEGER NOT NULL,
success INTEGER NOT NULL,
error_reason TEXT NOT NULL DEFAULT '',
FOREIGN KEY(user_id) REFERENCES users(id),
FOREIGN KEY(token_id) REFERENCES tokens(id)
);`,
`CREATE INDEX IF NOT EXISTS idx_tokens_token ON tokens(token);`,
`CREATE INDEX IF NOT EXISTS idx_tokens_user_id ON tokens(user_id);`,
`CREATE INDEX IF NOT EXISTS idx_token_usage_created_at ON token_usage(created_at);`,
`CREATE INDEX IF NOT EXISTS idx_token_usage_user_id ON token_usage(user_id);`,
`CREATE INDEX IF NOT EXISTS idx_token_usage_token_id ON token_usage(token_id);`,
}
for _, stmt := range stmts {
if _, err := s.db.ExecContext(ctx, stmt); err != nil {
return err
}
}
if _, err := s.db.ExecContext(ctx, `ALTER TABLE tokens ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 0;`); err != nil && !isDuplicateColumnErr(err) {
return err
}
if _, err := s.db.ExecContext(ctx, `ALTER TABLE tokens ADD COLUMN used_count INTEGER NOT NULL DEFAULT 0;`); err != nil && !isDuplicateColumnErr(err) {
return err
}
_, err := s.db.ExecContext(
ctx,
`INSERT INTO site_stats (id, total_accelerated_count, updated_at)
VALUES (1, 0, ?)
ON CONFLICT(id) DO NOTHING;`,
time.Now().UTC().Format(timeFormat),
)
return err
}
func (s *Store) CreateUser(ctx context.Context, username string) (*User, error) {
now := time.Now().UTC()
res, err := s.db.ExecContext(
ctx,
"INSERT INTO users (username, created_at) VALUES (?, ?);",
username,
now.Format(timeFormat),
)
if err != nil {
return nil, err
}
id, err := res.LastInsertId()
if err != nil {
return nil, err
}
return &User{ID: id, Username: username, CreatedAt: now}, nil
}
func (s *Store) GetUserByName(ctx context.Context, username string) (*User, error) {
row := s.db.QueryRowContext(ctx, "SELECT id, username, created_at FROM users WHERE username = ?;", username)
var u User
var createdAt string
if err := row.Scan(&u.ID, &u.Username, &createdAt); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
return nil, err
}
parsed, err := time.Parse(timeFormat, createdAt)
if err != nil {
return nil, err
}
u.CreatedAt = parsed
return &u, nil
}
func (s *Store) ListUsers(ctx context.Context) ([]User, error) {
rows, err := s.db.QueryContext(ctx, "SELECT id, username, created_at FROM users ORDER BY id ASC;")
if err != nil {
return nil, err
}
defer rows.Close()
var out []User
for rows.Next() {
var u User
var createdAt string
if err := rows.Scan(&u.ID, &u.Username, &createdAt); err != nil {
return nil, err
}
parsed, err := time.Parse(timeFormat, createdAt)
if err != nil {
return nil, err
}
u.CreatedAt = parsed
out = append(out, u)
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
func (s *Store) IssueToken(ctx context.Context, userID int64, expiresAt time.Time, maxUses int) (*Token, error) {
if maxUses < 0 {
return nil, fmt.Errorf("maxUses cannot be negative")
}
tokenValue, err := generateToken()
if err != nil {
return nil, err
}
now := time.Now().UTC()
res, err := s.db.ExecContext(
ctx,
`INSERT INTO tokens (user_id, token, created_at, expires_at, disabled, max_uses, used_count)
VALUES (?, ?, ?, ?, 0, ?, 0);`,
userID,
tokenValue,
now.Format(timeFormat),
expiresAt.UTC().Format(timeFormat),
maxUses,
)
if err != nil {
return nil, err
}
id, err := res.LastInsertId()
if err != nil {
return nil, err
}
return &Token{ID: id, UserID: userID, Token: tokenValue, CreatedAt: now, ExpiresAt: expiresAt.UTC(), MaxUses: maxUses, UsedCount: 0}, nil
}
func (s *Store) GetToken(ctx context.Context, tokenValue string) (*Token, error) {
row := s.db.QueryRowContext(
ctx,
`SELECT id, user_id, token, created_at, expires_at, disabled, disabled_at, max_uses, used_count
FROM tokens
WHERE token = ?;`,
tokenValue,
)
return scanToken(row)
}
func (s *Store) ValidateAndConsumeToken(ctx context.Context, tokenValue string) (*Token, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
_ = tx.Rollback()
}
}()
row := tx.QueryRowContext(
ctx,
`SELECT id, user_id, token, created_at, expires_at, disabled, disabled_at, max_uses, used_count
FROM tokens
WHERE token = ?;`,
tokenValue,
)
token, err := scanToken(row)
if err != nil {
return nil, err
}
now := time.Now().UTC()
if token.Disabled {
return nil, ErrTokenDisabled
}
if now.After(token.ExpiresAt) {
return nil, ErrTokenExpired
}
if token.MaxUses > 0 && token.UsedCount >= token.MaxUses {
return nil, ErrTokenExhausted
}
res, err := tx.ExecContext(
ctx,
`UPDATE tokens
SET used_count = used_count + 1
WHERE id = ?
AND (max_uses = 0 OR used_count < max_uses);`,
token.ID,
)
if err != nil {
return nil, err
}
affected, err := res.RowsAffected()
if err != nil {
return nil, err
}
if affected == 0 {
return nil, ErrTokenExhausted
}
if err = tx.Commit(); err != nil {
return nil, err
}
token.UsedCount++
return token, nil
}
func (s *Store) DisableToken(ctx context.Context, tokenValue string) error {
now := time.Now().UTC().Format(timeFormat)
res, err := s.db.ExecContext(
ctx,
"UPDATE tokens SET disabled = 1, disabled_at = ? WHERE token = ?;",
now,
tokenValue,
)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return ErrNotFound
}
return nil
}
func (s *Store) ListTokens(ctx context.Context) ([]TokenWithUser, error) {
rows, err := s.db.QueryContext(
ctx,
`SELECT t.id, t.user_id, t.token, t.created_at, t.expires_at, t.disabled, t.disabled_at, t.max_uses, t.used_count, u.username
FROM tokens t
INNER JOIN users u ON u.id = t.user_id
ORDER BY t.id ASC;`,
)
if err != nil {
return nil, err
}
defer rows.Close()
var out []TokenWithUser
for rows.Next() {
var t TokenWithUser
var createdAt string
var expiresAt string
var disabledAt sql.NullString
var disabledInt int
if err := rows.Scan(&t.ID, &t.UserID, &t.Token, &createdAt, &expiresAt, &disabledInt, &disabledAt, &t.MaxUses, &t.UsedCount, &t.Username); err != nil {
return nil, err
}
ct, err := time.Parse(timeFormat, createdAt)
if err != nil {
return nil, err
}
et, err := time.Parse(timeFormat, expiresAt)
if err != nil {
return nil, err
}
t.CreatedAt = ct
t.ExpiresAt = et
t.Disabled = disabledInt == 1
if disabledAt.Valid {
dt, err := time.Parse(timeFormat, disabledAt.String)
if err != nil {
return nil, err
}
t.DisabledAt = &dt
}
out = append(out, t)
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
func (s *Store) RecordUsageAndMaybeIncrement(ctx context.Context, entry UsageEntry, increment bool) error {
if entry.OccurredAt.IsZero() {
entry.OccurredAt = time.Now().UTC()
}
if entry.RequestIP == "" {
entry.RequestIP = "unknown"
}
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer func() {
if err != nil {
_ = tx.Rollback()
}
}()
successInt := 0
if entry.Success {
successInt = 1
}
_, err = tx.ExecContext(
ctx,
`INSERT INTO token_usage (created_at, request_ip, user_id, token_id, original_url, http_status, success, error_reason)
VALUES (?, ?, ?, ?, ?, ?, ?, ?);`,
entry.OccurredAt.UTC().Format(timeFormat),
entry.RequestIP,
entry.UserID,
entry.TokenID,
entry.OriginalURL,
entry.HTTPStatus,
successInt,
entry.ErrorReason,
)
if err != nil {
return err
}
if increment {
_, err = tx.ExecContext(
ctx,
`UPDATE site_stats
SET total_accelerated_count = total_accelerated_count + 1,
updated_at = ?
WHERE id = 1;`,
time.Now().UTC().Format(timeFormat),
)
if err != nil {
return err
}
}
return tx.Commit()
}
func generateToken() (string, error) {
buf := make([]byte, 32)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(buf), nil
}
func scanToken(scanner interface{ Scan(dest ...any) error }) (*Token, error) {
var t Token
var createdAt string
var expiresAt string
var disabledAt sql.NullString
var disabledInt int
if err := scanner.Scan(&t.ID, &t.UserID, &t.Token, &createdAt, &expiresAt, &disabledInt, &disabledAt, &t.MaxUses, &t.UsedCount); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
return nil, err
}
ct, err := time.Parse(timeFormat, createdAt)
if err != nil {
return nil, err
}
et, err := time.Parse(timeFormat, expiresAt)
if err != nil {
return nil, err
}
t.CreatedAt = ct
t.ExpiresAt = et
t.Disabled = disabledInt == 1
if disabledAt.Valid {
dt, err := time.Parse(timeFormat, disabledAt.String)
if err != nil {
return nil, err
}
t.DisabledAt = &dt
}
return &t, nil
}
func isDuplicateColumnErr(err error) bool {
return strings.Contains(strings.ToLower(err.Error()), "duplicate column name")
}