@@ -25,3 +25,7 @@ go.work.sum
|
||||
# env file
|
||||
.env
|
||||
|
||||
# database file
|
||||
*.db
|
||||
*.db-shm
|
||||
*.db-wal
|
||||
@@ -1,110 +1,140 @@
|
||||
|
||||
# Hugs-Proxy
|
||||
|
||||
> 本项目目前支持Github,后面可能会增加HuggingFace......
|
||||
一个轻量的 GitHub 资源加速反向代理工具。当前版本在原有代理能力上新增了 SQLite 持久化鉴权和访问审计。
|
||||
|
||||
一个**轻量的 GitHub 资源加速/反代下载**小工具:把 GitHub 的下载链接(Release、Archive、Raw、Gist、部分 git/info 相关请求)通过本地 HTTP 服务转发,从而在某些网络环境下提升可用性。
|
||||
## 功能概览
|
||||
|
||||
本项目默认只监听 `127.0.0.1`,更改为对外监听前请务必阅读“安全提示”。
|
||||
|
||||
## 特性
|
||||
|
||||
- 支持代理的链接类型(不匹配会返回 403):
|
||||
- `github.com/<owner>/<repo>/releases/...`
|
||||
- `github.com/<owner>/<repo>/archive/...`
|
||||
- `github.com/<owner>/<repo>/blob/...`(会自动改写为 `.../raw/...` 进行下载)
|
||||
- `github.com/<owner>/<repo>/raw/...`
|
||||
- `github.com/<owner>/<repo>/info/...`、`github.com/<owner>/<repo>/git-...`
|
||||
- `raw.githubusercontent.com/<owner>/<repo>/...`
|
||||
- `gist.github.com/...`、`gist.githubusercontent.com/...`
|
||||
- 白名单/黑名单(先白名单,后黑名单)
|
||||
- 大文件保护:响应体 `Content-Length` 超过 1GB 时直接 302 重定向到源站,避免本机带宽/内存压力
|
||||
- 处理上游重定向:对可识别的 GitHub 下载链接会改写 `Location`,让跳转继续走本代理
|
||||
- 支持 GitHub 资源代理(不匹配返回 403):
|
||||
- github.com/<owner>/<repo>/releases/...
|
||||
- github.com/<owner>/<repo>/archive/...
|
||||
- github.com/<owner>/<repo>/blob/...(自动改写为 /raw/)
|
||||
- github.com/<owner>/<repo>/raw/...
|
||||
- github.com/<owner>/<repo>/info/... 与 github.com/<owner>/<repo>/git-...
|
||||
- raw.githubusercontent.com/<owner>/<repo>/...
|
||||
- gist.github.com/... 与 gist.githubusercontent.com/...
|
||||
- Bearer Token 鉴权(Authorization 头)
|
||||
- SQLite 持久化:用户、token、访问明细、全站累计计数
|
||||
- 白名单/黑名单(先白名单再黑名单)
|
||||
- 大文件保护(>1GB 直接 302 回源)
|
||||
- 上游重定向 Location 改写(可识别 GitHub URL 时继续经由代理)
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1) 运行
|
||||
|
||||
在仓库根目录执行:
|
||||
### 1. 启动服务
|
||||
|
||||
```bash
|
||||
go run .
|
||||
```
|
||||
|
||||
默认监听:`127.0.0.1:2005`
|
||||
默认监听地址:127.0.0.1:2005
|
||||
|
||||
### 2) 使用方式
|
||||
### 2. 初始化用户与 token
|
||||
|
||||
访问路径就是你要代理的目标 URL(去掉前面的 `/`):
|
||||
|
||||
```text
|
||||
http://127.0.0.1:2005/<目标URL>
|
||||
```bash
|
||||
go run ./cmd/tokenctl create-user --name alice
|
||||
go run ./cmd/tokenctl issue-token --user alice --expires 30d --uses 1
|
||||
```
|
||||
|
||||
目标 URL 可以写全(推荐),也可以省略 scheme(会自动补 `https://`)。示例:
|
||||
说明:
|
||||
|
||||
- Release 文件:
|
||||
- --expires 支持 7d/30d 这类天数写法
|
||||
- 也支持 RFC3339 绝对时间(例如 2026-12-31T23:59:59Z)
|
||||
- 不传 --expires 时使用默认有效期(环境变量可配置)
|
||||
- --uses 指定可用次数(例如 1 为一次性 token,5 为可用 5 次)
|
||||
|
||||
```text
|
||||
http://127.0.0.1:2005/https://github.com/OWNER/REPO/releases/download/v1.0.0/app-darwin-amd64.zip
|
||||
### 3. 携带 Bearer Token 请求
|
||||
|
||||
```bash
|
||||
curl -L \
|
||||
-H "Authorization: Bearer <TOKEN>" \
|
||||
"http://127.0.0.1:2005/https://github.com/OWNER/REPO/releases/download/v1.0.0/file.zip"
|
||||
```
|
||||
|
||||
- 仓库归档(archive):
|
||||
无 token:返回 401。
|
||||
|
||||
token 无效、已禁用、已过期或使用次数耗尽:返回 403。
|
||||
|
||||
## tokenctl 用法
|
||||
|
||||
```text
|
||||
http://127.0.0.1:2005/https://github.com/OWNER/REPO/archive/refs/heads/main.zip
|
||||
tokenctl create-user --name <username> [--db <path>]
|
||||
tokenctl issue-token --user <username> [--expires <7d|30d|RFC3339>] [--uses <n>] [--db <path>]
|
||||
tokenctl disable-token --token <token> [--db <path>]
|
||||
tokenctl list-users [--db <path>]
|
||||
tokenctl list-tokens [--db <path>]
|
||||
```
|
||||
|
||||
- Raw 文件(也可以给 `blob`,服务端会自动替换为 `raw`):
|
||||
示例:
|
||||
|
||||
```text
|
||||
http://127.0.0.1:2005/https://github.com/OWNER/REPO/blob/main/README.md
|
||||
```bash
|
||||
go run ./cmd/tokenctl list-users
|
||||
go run ./cmd/tokenctl list-tokens
|
||||
go run ./cmd/tokenctl issue-token --user alice --uses 5
|
||||
go run ./cmd/tokenctl disable-token --token <TOKEN>
|
||||
```
|
||||
|
||||
- raw.githubusercontent.com:
|
||||
## 环境变量配置
|
||||
|
||||
```text
|
||||
http://127.0.0.1:2005/https://raw.githubusercontent.com/OWNER/REPO/main/README.md
|
||||
```
|
||||
服务与 CLI 共用同一套配置:
|
||||
|
||||
如果你传入的链接不属于上面支持的 GitHub 资源格式,会返回:`403 Invalid input.`
|
||||
- HUGS_PROXY_HOST:监听地址,默认 127.0.0.1
|
||||
- HUGS_PROXY_PORT:监听端口,默认 2005
|
||||
- HUGS_PROXY_DB_PATH:SQLite 文件路径,默认 ./hugs_proxy.db
|
||||
- HUGS_PROXY_DEFAULT_TOKEN_TTL:默认 token 有效期,默认 30d
|
||||
- HUGS_PROXY_DB_BUSY_TIMEOUT_MS:SQLite busy_timeout,默认 5000
|
||||
|
||||
## 配置
|
||||
说明:
|
||||
|
||||
目前所有配置都在 [main.go](main.go) 顶部的“配置区域”,修改后重新运行即可:
|
||||
- 代理大文件阈值 sizeLimit 仍在 main.go 内部常量(默认 1GB)。
|
||||
- 白名单与黑名单规则仍在 main.go 的 whiteListStr 与 blackListStr。
|
||||
|
||||
- `host`:监听地址(默认 `127.0.0.1`)
|
||||
- `port`:监听端口(默认 `2005`)
|
||||
- `sizeLimit`:大文件阈值(默认 `1GB`)
|
||||
- `whiteListStr`:白名单(多行字符串,每行一条规则)
|
||||
- `blackListStr`:黑名单(多行字符串,每行一条规则)
|
||||
## 数据库说明
|
||||
|
||||
### 白名单/黑名单规则
|
||||
程序启动会自动执行幂等建表迁移(CREATE TABLE IF NOT EXISTS),并设置:
|
||||
|
||||
每行一条,支持三种写法:
|
||||
- PRAGMA journal_mode = WAL
|
||||
- PRAGMA busy_timeout = 5000(或你设置的环境变量值)
|
||||
|
||||
- `user1`:匹配/封禁 `user1` 下的所有仓库
|
||||
- `user1/repo1`:匹配/封禁 `user1/repo1`
|
||||
- `*/repo1`:匹配/封禁所有名为 `repo1` 的仓库
|
||||
表结构语义:
|
||||
|
||||
- users:用户主体
|
||||
- tokens:token 主体(含过期、禁用、可用次数上限与已使用次数)
|
||||
- token_usage:访问审计明细(时间/IP/用户/token/原始 URL/HTTP 状态/成功标志/失败原因)
|
||||
- site_stats:总体累计计数(total_accelerated_count)
|
||||
|
||||
## 白名单/黑名单规则
|
||||
|
||||
每行一条,支持:
|
||||
|
||||
- user1:匹配 user1 下所有仓库
|
||||
- user1/repo1:匹配 user1/repo1
|
||||
- \*/repo1:匹配所有 repo1
|
||||
|
||||
判定顺序:
|
||||
|
||||
1. **白名单优先生效**:如果白名单非空,则必须至少命中一条白名单规则,否则直接拒绝(403)。
|
||||
2. **再匹配黑名单**:命中任意黑名单规则则拒绝(403)。
|
||||
1. 白名单优先(若白名单非空,必须命中)
|
||||
2. 再匹配黑名单(命中即拒绝)
|
||||
|
||||
## 验证步骤
|
||||
|
||||
1. 构建:go build ./...
|
||||
2. 无 token 请求:应返回 401
|
||||
3. 签发 token 后请求合法 URL:应成功且代理行为与旧版本一致
|
||||
4. 禁用或过期 token 请求:应返回 403,并记录失败 usage
|
||||
5. 多次请求后检查 site_stats.total_accelerated_count 累计值
|
||||
6. 检查 token_usage 记录字段完整性
|
||||
|
||||
## 常见问题
|
||||
## 排错
|
||||
|
||||
### 为什么有时会直接跳转到 GitHub?
|
||||
|
||||
当上游响应 `Content-Length` 大于 `sizeLimit`(默认 1GB)时,程序会返回 302 重定向到目标 URL,而不是继续转发大文件内容。
|
||||
|
||||
### 支持 Git Clone / Git LFS 吗?
|
||||
|
||||
本项目主要面向“下载/获取资源”。它只放行并代理部分与 GitHub 资源下载相关的 URL 形态;不保证覆盖完整的 Git 协议或 Git LFS 场景。
|
||||
- 启动报数据库错误:
|
||||
- 检查 HUGS_PROXY_DB_PATH 是否可写
|
||||
- 检查目录权限
|
||||
- 总是 401:
|
||||
- 确认请求头格式为 Authorization: Bearer <TOKEN>
|
||||
- 总是 403:
|
||||
- 检查 token 是否过期或禁用
|
||||
- 检查 URL 是否命中支持规则以及白黑名单规则
|
||||
|
||||
## License
|
||||
|
||||
详见 License 文件
|
||||
|
||||
详见仓库中的 LICENSE 文件。
|
||||
|
||||
@@ -0,0 +1,198 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.gangary.cn/gary/Hugs-Proxy/internal/config"
|
||||
"gitea.gangary.cn/gary/Hugs-Proxy/internal/db"
|
||||
)
|
||||
|
||||
func main() {
|
||||
cfg := config.Load()
|
||||
if len(os.Args) < 2 {
|
||||
printUsage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
sub := os.Args[1]
|
||||
switch sub {
|
||||
case "create-user":
|
||||
handleCreateUser(cfg, os.Args[2:])
|
||||
case "issue-token":
|
||||
handleIssueToken(cfg, os.Args[2:])
|
||||
case "disable-token":
|
||||
handleDisableToken(cfg, os.Args[2:])
|
||||
case "list-users":
|
||||
handleListUsers(cfg, os.Args[2:])
|
||||
case "list-tokens":
|
||||
handleListTokens(cfg, os.Args[2:])
|
||||
default:
|
||||
printUsage()
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func handleCreateUser(cfg config.Config, args []string) {
|
||||
fs := flag.NewFlagSet("create-user", flag.ExitOnError)
|
||||
dbPath := fs.String("db", cfg.DBPath, "SQLite DB path")
|
||||
name := fs.String("name", "", "username")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if strings.TrimSpace(*name) == "" {
|
||||
log.Fatal("--name is required")
|
||||
}
|
||||
|
||||
store := openStore(*dbPath, cfg.BusyTimeoutMS)
|
||||
defer store.Close()
|
||||
|
||||
u, err := store.CreateUser(context.Background(), strings.TrimSpace(*name))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("created user: id=%d username=%s created_at=%s\n", u.ID, u.Username, u.CreatedAt.Format(time.RFC3339))
|
||||
}
|
||||
|
||||
func handleIssueToken(cfg config.Config, args []string) {
|
||||
fs := flag.NewFlagSet("issue-token", flag.ExitOnError)
|
||||
dbPath := fs.String("db", cfg.DBPath, "SQLite DB path")
|
||||
username := fs.String("user", "", "username")
|
||||
expires := fs.String("expires", "", "expires duration like 7d/30d or RFC3339 time")
|
||||
uses := fs.Int("uses", 1, "max usable times for this token, must be > 0")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if strings.TrimSpace(*username) == "" {
|
||||
log.Fatal("--user is required")
|
||||
}
|
||||
if *uses <= 0 {
|
||||
log.Fatal("--uses must be > 0")
|
||||
}
|
||||
|
||||
store := openStore(*dbPath, cfg.BusyTimeoutMS)
|
||||
defer store.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
u, err := store.GetUserByName(ctx, strings.TrimSpace(*username))
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
log.Fatalf("user %q not found", *username)
|
||||
}
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
expiresAt, err := parseExpiresAt(*expires, cfg.DefaultTokenTTL)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tok, err := store.IssueToken(ctx, u.ID, expiresAt, *uses)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("issued token: id=%d user=%s token=%s expires_at=%s max_uses=%d\n", tok.ID, u.Username, tok.Token, tok.ExpiresAt.Format(time.RFC3339), tok.MaxUses)
|
||||
}
|
||||
|
||||
func handleDisableToken(cfg config.Config, args []string) {
|
||||
fs := flag.NewFlagSet("disable-token", flag.ExitOnError)
|
||||
dbPath := fs.String("db", cfg.DBPath, "SQLite DB path")
|
||||
token := fs.String("token", "", "token value")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if strings.TrimSpace(*token) == "" {
|
||||
log.Fatal("--token is required")
|
||||
}
|
||||
|
||||
store := openStore(*dbPath, cfg.BusyTimeoutMS)
|
||||
defer store.Close()
|
||||
|
||||
err := store.DisableToken(context.Background(), strings.TrimSpace(*token))
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
log.Fatalf("token not found")
|
||||
}
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println("token disabled")
|
||||
}
|
||||
|
||||
func handleListUsers(cfg config.Config, args []string) {
|
||||
fs := flag.NewFlagSet("list-users", flag.ExitOnError)
|
||||
dbPath := fs.String("db", cfg.DBPath, "SQLite DB path")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
store := openStore(*dbPath, cfg.BusyTimeoutMS)
|
||||
defer store.Close()
|
||||
|
||||
users, err := store.ListUsers(context.Background())
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Println("id\tusername\tcreated_at")
|
||||
for _, u := range users {
|
||||
fmt.Printf("%d\t%s\t%s\n", u.ID, u.Username, u.CreatedAt.Format(time.RFC3339))
|
||||
}
|
||||
}
|
||||
|
||||
func handleListTokens(cfg config.Config, args []string) {
|
||||
fs := flag.NewFlagSet("list-tokens", flag.ExitOnError)
|
||||
dbPath := fs.String("db", cfg.DBPath, "SQLite DB path")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
store := openStore(*dbPath, cfg.BusyTimeoutMS)
|
||||
defer store.Close()
|
||||
|
||||
tokens, err := store.ListTokens(context.Background())
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Println("id\tuser\tdisabled\tused/max\texpires_at\tcreated_at\ttoken")
|
||||
for _, t := range tokens {
|
||||
maxUsesDisplay := fmt.Sprintf("%d", t.MaxUses)
|
||||
if t.MaxUses == 0 {
|
||||
maxUsesDisplay = "unlimited"
|
||||
}
|
||||
fmt.Printf("%d\t%s\t%t\t%d/%s\t%s\t%s\t%s\n", t.ID, t.Username, t.Disabled, t.UsedCount, maxUsesDisplay, t.ExpiresAt.Format(time.RFC3339), t.CreatedAt.Format(time.RFC3339), t.Token)
|
||||
}
|
||||
}
|
||||
|
||||
func parseExpiresAt(raw string, defaultTTL time.Duration) (time.Time, error) {
|
||||
now := time.Now().UTC()
|
||||
clean := strings.TrimSpace(raw)
|
||||
if clean == "" {
|
||||
return now.Add(defaultTTL), nil
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, clean); err == nil {
|
||||
return t.UTC(), nil
|
||||
}
|
||||
d, err := config.ParseExpiryDuration(clean)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("invalid --expires value: %w", err)
|
||||
}
|
||||
return now.Add(d), nil
|
||||
}
|
||||
|
||||
func openStore(dbPath string, busyTimeoutMS int) *db.Store {
|
||||
store, err := db.NewStore(strings.TrimSpace(dbPath), busyTimeoutMS)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
func printUsage() {
|
||||
fmt.Println(`tokenctl usage:
|
||||
tokenctl create-user --name <username> [--db <path>]
|
||||
tokenctl issue-token --user <username> [--expires <7d|30d|RFC3339>] [--uses <n>] [--db <path>]
|
||||
tokenctl disable-token --token <token> [--db <path>]
|
||||
tokenctl list-users [--db <path>]
|
||||
tokenctl list-tokens [--db <path>]`)
|
||||
}
|
||||
@@ -1,3 +1,17 @@
|
||||
module gitea.gangary.cn/gary/Hugs-Proxy
|
||||
|
||||
go 1.26.2
|
||||
|
||||
require modernc.org/sqlite v1.49.1
|
||||
|
||||
require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
modernc.org/libc v1.72.0 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
)
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
|
||||
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
|
||||
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
|
||||
modernc.org/cc/v4 v4.27.3 h1:uNCgn37E5U09mTv1XgskEVUJ8ADKpmFMPxzGJ0TSo+U=
|
||||
modernc.org/cc/v4 v4.27.3/go.mod h1:3YjcbCqhoTTHPycJDRl2WZKKFj0nwcOIPBfEZK0Hdk8=
|
||||
modernc.org/ccgo/v4 v4.32.4 h1:L5OB8rpEX4ZsXEQwGozRfJyJSFHbbNVOoQ59DU9/KuU=
|
||||
modernc.org/ccgo/v4 v4.32.4/go.mod h1:lY7f+fiTDHfcv6YlRgSkxYfhs+UvOEEzj49jAn2TOx0=
|
||||
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
|
||||
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
|
||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
|
||||
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||
modernc.org/libc v1.72.0 h1:IEu559v9a0XWjw0DPoVKtXpO2qt5NVLAnFaBbjq+n8c=
|
||||
modernc.org/libc v1.72.0/go.mod h1:tTU8DL8A+XLVkEY3x5E/tO7s2Q/q42EtnNWda/L5QhQ=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.49.1 h1:dYGHTKcX1sJ+EQDnUzvz4TJ5GbuvhNJa8Fg6ElGx73U=
|
||||
modernc.org/sqlite v1.49.1/go.mod h1:m0w8xhwYUVY3H6pSDwc3gkJ/irZT/0YEXwBlhaxQEew=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||
@@ -0,0 +1,51 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.gangary.cn/gary/Hugs-Proxy/internal/db"
|
||||
)
|
||||
|
||||
type Entry struct {
|
||||
RequestIP string
|
||||
UserID int64
|
||||
HasUser bool
|
||||
TokenID int64
|
||||
HasToken bool
|
||||
OriginalURL string
|
||||
HTTPStatus int
|
||||
Success bool
|
||||
ErrorReason string
|
||||
CountAsSuccess bool
|
||||
OccurredAt time.Time
|
||||
}
|
||||
|
||||
type Logger struct {
|
||||
store *db.Store
|
||||
}
|
||||
|
||||
func NewLogger(store *db.Store) *Logger {
|
||||
return &Logger{store: store}
|
||||
}
|
||||
|
||||
func (l *Logger) Log(ctx context.Context, entry Entry) error {
|
||||
usage := db.UsageEntry{
|
||||
RequestIP: strings.TrimSpace(entry.RequestIP),
|
||||
OriginalURL: strings.TrimSpace(entry.OriginalURL),
|
||||
HTTPStatus: entry.HTTPStatus,
|
||||
Success: entry.Success,
|
||||
ErrorReason: strings.TrimSpace(entry.ErrorReason),
|
||||
OccurredAt: entry.OccurredAt,
|
||||
}
|
||||
if entry.HasUser {
|
||||
usage.UserID = sql.NullInt64{Int64: entry.UserID, Valid: true}
|
||||
}
|
||||
if entry.HasToken {
|
||||
usage.TokenID = sql.NullInt64{Int64: entry.TokenID, Valid: true}
|
||||
}
|
||||
|
||||
return l.store.RecordUsageAndMaybeIncrement(ctx, usage, entry.CountAsSuccess)
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"gitea.gangary.cn/gary/Hugs-Proxy/internal/db"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const authInfoKey contextKey = "authInfo"
|
||||
|
||||
type AuthInfo struct {
|
||||
UserID int64
|
||||
TokenID int64
|
||||
Token string
|
||||
}
|
||||
|
||||
type FailureRecorder func(r *http.Request, token string, statusCode int, reason string, userID int64, tokenID int64)
|
||||
|
||||
func Middleware(store *db.Store, onFailure FailureRecorder, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tokenValue, ok := bearerToken(r.Header.Get("Authorization"))
|
||||
if !ok {
|
||||
http.Error(w, "Missing bearer token.", http.StatusUnauthorized)
|
||||
if onFailure != nil {
|
||||
onFailure(r, "", http.StatusUnauthorized, "missing_token", 0, 0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
token, err := store.ValidateAndConsumeToken(r.Context(), tokenValue)
|
||||
if err != nil {
|
||||
status := http.StatusForbidden
|
||||
reason := "invalid_token"
|
||||
switch {
|
||||
case errors.Is(err, db.ErrNotFound):
|
||||
reason = "invalid_token"
|
||||
case errors.Is(err, db.ErrTokenDisabled):
|
||||
reason = "token_disabled"
|
||||
case errors.Is(err, db.ErrTokenExpired):
|
||||
reason = "token_expired"
|
||||
case errors.Is(err, db.ErrTokenExhausted):
|
||||
reason = "token_exhausted"
|
||||
default:
|
||||
status = http.StatusInternalServerError
|
||||
reason = "token_lookup_error"
|
||||
}
|
||||
http.Error(w, http.StatusText(status), status)
|
||||
if onFailure != nil {
|
||||
if token != nil {
|
||||
onFailure(r, tokenValue, status, reason, token.UserID, token.ID)
|
||||
} else {
|
||||
onFailure(r, tokenValue, status, reason, 0, 0)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
authInfo := AuthInfo{UserID: token.UserID, TokenID: token.ID, Token: tokenValue}
|
||||
ctx := context.WithValue(r.Context(), authInfoKey, authInfo)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func FromContext(ctx context.Context) (AuthInfo, bool) {
|
||||
v := ctx.Value(authInfoKey)
|
||||
if v == nil {
|
||||
return AuthInfo{}, false
|
||||
}
|
||||
info, ok := v.(AuthInfo)
|
||||
return info, ok
|
||||
}
|
||||
|
||||
func bearerToken(authorizationHeader string) (string, bool) {
|
||||
parts := strings.Fields(strings.TrimSpace(authorizationHeader))
|
||||
if len(parts) != 2 {
|
||||
return "", false
|
||||
}
|
||||
if !strings.EqualFold(parts[0], "Bearer") {
|
||||
return "", false
|
||||
}
|
||||
if parts[1] == "" {
|
||||
return "", false
|
||||
}
|
||||
return parts[1], true
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultHost = "127.0.0.1"
|
||||
defaultPort = "2005"
|
||||
defaultDBPath = "./hugs_proxy.db"
|
||||
defaultTokenTTLString = "30d"
|
||||
defaultBusyTimeoutMilli = 5000
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Host string
|
||||
Port string
|
||||
DBPath string
|
||||
DefaultTokenTTL time.Duration
|
||||
BusyTimeoutMS int
|
||||
}
|
||||
|
||||
func Load() Config {
|
||||
host := getenvOrDefault("HUGS_PROXY_HOST", defaultHost)
|
||||
port := getenvOrDefault("HUGS_PROXY_PORT", defaultPort)
|
||||
dbPath := getenvOrDefault("HUGS_PROXY_DB_PATH", defaultDBPath)
|
||||
ttlRaw := getenvOrDefault("HUGS_PROXY_DEFAULT_TOKEN_TTL", defaultTokenTTLString)
|
||||
busyTimeout := getenvIntOrDefault("HUGS_PROXY_DB_BUSY_TIMEOUT_MS", defaultBusyTimeoutMilli)
|
||||
|
||||
ttl, err := ParseExpiryDuration(ttlRaw)
|
||||
if err != nil {
|
||||
ttl, _ = ParseExpiryDuration(defaultTokenTTLString)
|
||||
}
|
||||
|
||||
return Config{
|
||||
Host: host,
|
||||
Port: port,
|
||||
DBPath: dbPath,
|
||||
DefaultTokenTTL: ttl,
|
||||
BusyTimeoutMS: busyTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func ParseExpiryDuration(s string) (time.Duration, error) {
|
||||
s = strings.TrimSpace(strings.ToLower(s))
|
||||
if strings.HasSuffix(s, "d") {
|
||||
daysPart := strings.TrimSuffix(s, "d")
|
||||
days, err := strconv.Atoi(daysPart)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return time.Duration(days) * 24 * time.Hour, nil
|
||||
}
|
||||
return time.ParseDuration(s)
|
||||
}
|
||||
|
||||
func getenvOrDefault(key string, fallback string) string {
|
||||
v := strings.TrimSpace(os.Getenv(key))
|
||||
if v == "" {
|
||||
return fallback
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func getenvIntOrDefault(key string, fallback int) int {
|
||||
v := strings.TrimSpace(os.Getenv(key))
|
||||
if v == "" {
|
||||
return fallback
|
||||
}
|
||||
n, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
return n
|
||||
}
|
||||
@@ -0,0 +1,495 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
const timeFormat = time.RFC3339Nano
|
||||
|
||||
var ErrNotFound = errors.New("not found")
|
||||
|
||||
var (
|
||||
ErrTokenDisabled = errors.New("token disabled")
|
||||
ErrTokenExpired = errors.New("token expired")
|
||||
ErrTokenExhausted = errors.New("token exhausted")
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int64
|
||||
Username string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type Token struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
Token string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
Disabled bool
|
||||
DisabledAt *time.Time
|
||||
MaxUses int
|
||||
UsedCount int
|
||||
}
|
||||
|
||||
type TokenWithUser struct {
|
||||
Token
|
||||
Username string
|
||||
}
|
||||
|
||||
type UsageEntry struct {
|
||||
RequestIP string
|
||||
UserID sql.NullInt64
|
||||
TokenID sql.NullInt64
|
||||
OriginalURL string
|
||||
HTTPStatus int
|
||||
Success bool
|
||||
ErrorReason string
|
||||
OccurredAt time.Time
|
||||
}
|
||||
|
||||
func NewStore(dbPath string, busyTimeoutMS int) (*Store, error) {
|
||||
db, err := sql.Open("sqlite", dbPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
if _, err := db.Exec(fmt.Sprintf("PRAGMA busy_timeout = %d;", busyTimeoutMS)); err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
if _, err := db.Exec("PRAGMA foreign_keys = ON;"); err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := &Store{db: db}
|
||||
if err := s.migrate(context.Background()); err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Store) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
func (s *Store) migrate(ctx context.Context) error {
|
||||
stmts := []string{
|
||||
`CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
created_at TEXT NOT NULL
|
||||
);`,
|
||||
`CREATE TABLE IF NOT EXISTS tokens (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
created_at TEXT NOT NULL,
|
||||
expires_at TEXT NOT NULL,
|
||||
disabled INTEGER NOT NULL DEFAULT 0,
|
||||
max_uses INTEGER NOT NULL DEFAULT 0,
|
||||
used_count INTEGER NOT NULL DEFAULT 0,
|
||||
disabled_at TEXT,
|
||||
FOREIGN KEY(user_id) REFERENCES users(id)
|
||||
);`,
|
||||
`CREATE TABLE IF NOT EXISTS site_stats (
|
||||
id INTEGER PRIMARY KEY,
|
||||
total_accelerated_count INTEGER NOT NULL DEFAULT 0,
|
||||
updated_at TEXT NOT NULL
|
||||
);`,
|
||||
`CREATE TABLE IF NOT EXISTS token_usage (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
created_at TEXT NOT NULL,
|
||||
request_ip TEXT NOT NULL,
|
||||
user_id INTEGER,
|
||||
token_id INTEGER,
|
||||
original_url TEXT NOT NULL,
|
||||
http_status INTEGER NOT NULL,
|
||||
success INTEGER NOT NULL,
|
||||
error_reason TEXT NOT NULL DEFAULT '',
|
||||
FOREIGN KEY(user_id) REFERENCES users(id),
|
||||
FOREIGN KEY(token_id) REFERENCES tokens(id)
|
||||
);`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_tokens_token ON tokens(token);`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_tokens_user_id ON tokens(user_id);`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_token_usage_created_at ON token_usage(created_at);`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_token_usage_user_id ON token_usage(user_id);`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_token_usage_token_id ON token_usage(token_id);`,
|
||||
}
|
||||
|
||||
for _, stmt := range stmts {
|
||||
if _, err := s.db.ExecContext(ctx, stmt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := s.db.ExecContext(ctx, `ALTER TABLE tokens ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 0;`); err != nil && !isDuplicateColumnErr(err) {
|
||||
return err
|
||||
}
|
||||
if _, err := s.db.ExecContext(ctx, `ALTER TABLE tokens ADD COLUMN used_count INTEGER NOT NULL DEFAULT 0;`); err != nil && !isDuplicateColumnErr(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := s.db.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO site_stats (id, total_accelerated_count, updated_at)
|
||||
VALUES (1, 0, ?)
|
||||
ON CONFLICT(id) DO NOTHING;`,
|
||||
time.Now().UTC().Format(timeFormat),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) CreateUser(ctx context.Context, username string) (*User, error) {
|
||||
now := time.Now().UTC()
|
||||
res, err := s.db.ExecContext(
|
||||
ctx,
|
||||
"INSERT INTO users (username, created_at) VALUES (?, ?);",
|
||||
username,
|
||||
now.Format(timeFormat),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &User{ID: id, Username: username, CreatedAt: now}, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetUserByName(ctx context.Context, username string) (*User, error) {
|
||||
row := s.db.QueryRowContext(ctx, "SELECT id, username, created_at FROM users WHERE username = ?;", username)
|
||||
var u User
|
||||
var createdAt string
|
||||
if err := row.Scan(&u.ID, &u.Username, &createdAt); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
parsed, err := time.Parse(timeFormat, createdAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.CreatedAt = parsed
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListUsers(ctx context.Context) ([]User, error) {
|
||||
rows, err := s.db.QueryContext(ctx, "SELECT id, username, created_at FROM users ORDER BY id ASC;")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []User
|
||||
for rows.Next() {
|
||||
var u User
|
||||
var createdAt string
|
||||
if err := rows.Scan(&u.ID, &u.Username, &createdAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsed, err := time.Parse(timeFormat, createdAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.CreatedAt = parsed
|
||||
out = append(out, u)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Store) IssueToken(ctx context.Context, userID int64, expiresAt time.Time, maxUses int) (*Token, error) {
|
||||
if maxUses < 0 {
|
||||
return nil, fmt.Errorf("maxUses cannot be negative")
|
||||
}
|
||||
tokenValue, err := generateToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
res, err := s.db.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO tokens (user_id, token, created_at, expires_at, disabled, max_uses, used_count)
|
||||
VALUES (?, ?, ?, ?, 0, ?, 0);`,
|
||||
userID,
|
||||
tokenValue,
|
||||
now.Format(timeFormat),
|
||||
expiresAt.UTC().Format(timeFormat),
|
||||
maxUses,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Token{ID: id, UserID: userID, Token: tokenValue, CreatedAt: now, ExpiresAt: expiresAt.UTC(), MaxUses: maxUses, UsedCount: 0}, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetToken(ctx context.Context, tokenValue string) (*Token, error) {
|
||||
row := s.db.QueryRowContext(
|
||||
ctx,
|
||||
`SELECT id, user_id, token, created_at, expires_at, disabled, disabled_at, max_uses, used_count
|
||||
FROM tokens
|
||||
WHERE token = ?;`,
|
||||
tokenValue,
|
||||
)
|
||||
return scanToken(row)
|
||||
}
|
||||
|
||||
func (s *Store) ValidateAndConsumeToken(ctx context.Context, tokenValue string) (*Token, error) {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
row := tx.QueryRowContext(
|
||||
ctx,
|
||||
`SELECT id, user_id, token, created_at, expires_at, disabled, disabled_at, max_uses, used_count
|
||||
FROM tokens
|
||||
WHERE token = ?;`,
|
||||
tokenValue,
|
||||
)
|
||||
token, err := scanToken(row)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
if token.Disabled {
|
||||
return nil, ErrTokenDisabled
|
||||
}
|
||||
if now.After(token.ExpiresAt) {
|
||||
return nil, ErrTokenExpired
|
||||
}
|
||||
if token.MaxUses > 0 && token.UsedCount >= token.MaxUses {
|
||||
return nil, ErrTokenExhausted
|
||||
}
|
||||
|
||||
res, err := tx.ExecContext(
|
||||
ctx,
|
||||
`UPDATE tokens
|
||||
SET used_count = used_count + 1
|
||||
WHERE id = ?
|
||||
AND (max_uses = 0 OR used_count < max_uses);`,
|
||||
token.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if affected == 0 {
|
||||
return nil, ErrTokenExhausted
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token.UsedCount++
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *Store) DisableToken(ctx context.Context, tokenValue string) error {
|
||||
now := time.Now().UTC().Format(timeFormat)
|
||||
res, err := s.db.ExecContext(
|
||||
ctx,
|
||||
"UPDATE tokens SET disabled = 1, disabled_at = ? WHERE token = ?;",
|
||||
now,
|
||||
tokenValue,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) ListTokens(ctx context.Context) ([]TokenWithUser, error) {
|
||||
rows, err := s.db.QueryContext(
|
||||
ctx,
|
||||
`SELECT t.id, t.user_id, t.token, t.created_at, t.expires_at, t.disabled, t.disabled_at, t.max_uses, t.used_count, u.username
|
||||
FROM tokens t
|
||||
INNER JOIN users u ON u.id = t.user_id
|
||||
ORDER BY t.id ASC;`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []TokenWithUser
|
||||
for rows.Next() {
|
||||
var t TokenWithUser
|
||||
var createdAt string
|
||||
var expiresAt string
|
||||
var disabledAt sql.NullString
|
||||
var disabledInt int
|
||||
if err := rows.Scan(&t.ID, &t.UserID, &t.Token, &createdAt, &expiresAt, &disabledInt, &disabledAt, &t.MaxUses, &t.UsedCount, &t.Username); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ct, err := time.Parse(timeFormat, createdAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
et, err := time.Parse(timeFormat, expiresAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.CreatedAt = ct
|
||||
t.ExpiresAt = et
|
||||
t.Disabled = disabledInt == 1
|
||||
if disabledAt.Valid {
|
||||
dt, err := time.Parse(timeFormat, disabledAt.String)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.DisabledAt = &dt
|
||||
}
|
||||
out = append(out, t)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Store) RecordUsageAndMaybeIncrement(ctx context.Context, entry UsageEntry, increment bool) error {
|
||||
if entry.OccurredAt.IsZero() {
|
||||
entry.OccurredAt = time.Now().UTC()
|
||||
}
|
||||
if entry.RequestIP == "" {
|
||||
entry.RequestIP = "unknown"
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
successInt := 0
|
||||
if entry.Success {
|
||||
successInt = 1
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO token_usage (created_at, request_ip, user_id, token_id, original_url, http_status, success, error_reason)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?);`,
|
||||
entry.OccurredAt.UTC().Format(timeFormat),
|
||||
entry.RequestIP,
|
||||
entry.UserID,
|
||||
entry.TokenID,
|
||||
entry.OriginalURL,
|
||||
entry.HTTPStatus,
|
||||
successInt,
|
||||
entry.ErrorReason,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if increment {
|
||||
_, err = tx.ExecContext(
|
||||
ctx,
|
||||
`UPDATE site_stats
|
||||
SET total_accelerated_count = total_accelerated_count + 1,
|
||||
updated_at = ?
|
||||
WHERE id = 1;`,
|
||||
time.Now().UTC().Format(timeFormat),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func generateToken() (string, error) {
|
||||
buf := make([]byte, 32)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
func scanToken(scanner interface{ Scan(dest ...any) error }) (*Token, error) {
|
||||
var t Token
|
||||
var createdAt string
|
||||
var expiresAt string
|
||||
var disabledAt sql.NullString
|
||||
var disabledInt int
|
||||
if err := scanner.Scan(&t.ID, &t.UserID, &t.Token, &createdAt, &expiresAt, &disabledInt, &disabledAt, &t.MaxUses, &t.UsedCount); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
ct, err := time.Parse(timeFormat, createdAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
et, err := time.Parse(timeFormat, expiresAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.CreatedAt = ct
|
||||
t.ExpiresAt = et
|
||||
t.Disabled = disabledInt == 1
|
||||
if disabledAt.Valid {
|
||||
dt, err := time.Parse(timeFormat, disabledAt.String)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.DisabledAt = &dt
|
||||
}
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
func isDuplicateColumnErr(err error) bool {
|
||||
return strings.Contains(strings.ToLower(err.Error()), "duplicate column name")
|
||||
}
|
||||
@@ -1,13 +1,20 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gitea.gangary.cn/gary/Hugs-Proxy/internal/audit"
|
||||
"gitea.gangary.cn/gary/Hugs-Proxy/internal/auth"
|
||||
"gitea.gangary.cn/gary/Hugs-Proxy/internal/config"
|
||||
"gitea.gangary.cn/gary/Hugs-Proxy/internal/db"
|
||||
)
|
||||
|
||||
// ========================
|
||||
@@ -15,8 +22,6 @@ import (
|
||||
// ========================
|
||||
const (
|
||||
sizeLimit = int64(1024 * 1024 * 1024 * 1) // 允许的文件大小, 1GB
|
||||
host = "127.0.0.1"
|
||||
port = "2005"
|
||||
)
|
||||
|
||||
// 先生效白名单再匹配黑名单
|
||||
@@ -29,7 +34,6 @@ var (
|
||||
blackListStr = ``
|
||||
)
|
||||
|
||||
|
||||
// ========================
|
||||
// 全局变量与预编译正则
|
||||
// ========================
|
||||
@@ -37,7 +41,6 @@ var (
|
||||
whiteList [][]string
|
||||
blackList [][]string
|
||||
|
||||
|
||||
exp1 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:releases|archive)/.*$`)
|
||||
exp2 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:blob|raw)/.*$`)
|
||||
exp3 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:info|git-).*$`)
|
||||
@@ -45,6 +48,8 @@ var (
|
||||
exp5 = regexp.MustCompile(`^(?:https?://)?gist\.(?:githubusercontent|github)\.com/(?P<author>[^/]+)/.+?/.+$`)
|
||||
|
||||
httpClient *http.Client
|
||||
store *db.Store
|
||||
auditor *audit.Logger
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -61,8 +66,22 @@ func init() {
|
||||
}
|
||||
|
||||
func main() {
|
||||
http.HandleFunc("/", routeHandler)
|
||||
addr := fmt.Sprintf("%s:%s", host, port)
|
||||
cfg := config.Load()
|
||||
|
||||
var err error
|
||||
store, err = db.NewStore(cfg.DBPath, cfg.BusyTimeoutMS)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
auditor = audit.NewLogger(store)
|
||||
|
||||
base := http.HandlerFunc(routeHandler)
|
||||
protected := auth.Middleware(store, auditAuthFailure, base)
|
||||
http.Handle("/", protected)
|
||||
|
||||
addr := fmt.Sprintf("%s:%s", cfg.Host, cfg.Port)
|
||||
log.Printf("服务器启动成功,正在监听 %s", addr)
|
||||
if err := http.ListenAndServe(addr, nil); err != nil {
|
||||
log.Fatal(err)
|
||||
@@ -70,8 +89,40 @@ func main() {
|
||||
}
|
||||
|
||||
func routeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
authInfo, ok := auth.FromContext(r.Context())
|
||||
if !ok {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
u := strings.TrimPrefix(r.URL.RequestURI(), "/")
|
||||
recorder := newStatusRecorder(w)
|
||||
originalInput := strings.TrimPrefix(r.URL.RequestURI(), "/")
|
||||
normalizedURL := originalInput
|
||||
success := false
|
||||
errorReason := ""
|
||||
|
||||
defer func() {
|
||||
statusCode := recorder.StatusCode()
|
||||
if statusCode >= 200 && statusCode < 400 {
|
||||
success = true
|
||||
}
|
||||
if err := auditor.Log(r.Context(), audit.Entry{
|
||||
RequestIP: clientIP(r),
|
||||
UserID: authInfo.UserID,
|
||||
HasUser: true,
|
||||
TokenID: authInfo.TokenID,
|
||||
HasToken: true,
|
||||
OriginalURL: normalizedURL,
|
||||
HTTPStatus: statusCode,
|
||||
Success: success,
|
||||
ErrorReason: errorReason,
|
||||
CountAsSuccess: success,
|
||||
}); err != nil {
|
||||
log.Printf("audit log failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
u := originalInput
|
||||
if strings.HasPrefix(u, "https:/") && !strings.HasPrefix(u, "https://") {
|
||||
u = "https://" + strings.TrimPrefix(u, "https:/")
|
||||
} else if strings.HasPrefix(u, "http:/") && !strings.HasPrefix(u, "http://") {
|
||||
@@ -84,7 +135,8 @@ func routeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
m := checkURL(u)
|
||||
if m == nil {
|
||||
http.Error(w, "Invalid input.", http.StatusForbidden)
|
||||
errorReason = "invalid_input"
|
||||
http.Error(recorder, "Invalid input.", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -97,14 +149,16 @@ func routeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
http.Error(w, "Forbidden by white list.", http.StatusForbidden)
|
||||
errorReason = "forbidden_by_white_list"
|
||||
http.Error(recorder, "Forbidden by white list.", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, i := range blackList {
|
||||
if matchRule(m, i) {
|
||||
http.Error(w, "Forbidden by black list.", http.StatusForbidden)
|
||||
errorReason = "forbidden_by_black_list"
|
||||
http.Error(recorder, "Forbidden by black list.", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -119,7 +173,8 @@ func routeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if err == nil {
|
||||
u = parsedURL.String()
|
||||
}
|
||||
proxy(w, r, u)
|
||||
normalizedURL = u
|
||||
proxy(recorder, r, u)
|
||||
}
|
||||
|
||||
func proxy(w http.ResponseWriter, r *http.Request, targetURL string) {
|
||||
@@ -180,6 +235,66 @@ func proxy(w http.ResponseWriter, r *http.Request, targetURL string) {
|
||||
io.Copy(w, resp.Body)
|
||||
}
|
||||
|
||||
func auditAuthFailure(r *http.Request, token string, statusCode int, reason string, userID int64, tokenID int64) {
|
||||
entry := audit.Entry{
|
||||
RequestIP: clientIP(r),
|
||||
OriginalURL: strings.TrimPrefix(r.URL.RequestURI(), "/"),
|
||||
HTTPStatus: statusCode,
|
||||
Success: false,
|
||||
ErrorReason: reason,
|
||||
CountAsSuccess: false,
|
||||
}
|
||||
if userID > 0 {
|
||||
entry.UserID = userID
|
||||
entry.HasUser = true
|
||||
}
|
||||
if tokenID > 0 {
|
||||
entry.TokenID = tokenID
|
||||
entry.HasToken = true
|
||||
}
|
||||
if err := auditor.Log(context.Background(), entry); err != nil {
|
||||
log.Printf("auth failure audit log failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func clientIP(r *http.Request) string {
|
||||
if xff := strings.TrimSpace(r.Header.Get("X-Forwarded-For")); xff != "" {
|
||||
parts := strings.Split(xff, ",")
|
||||
if len(parts) > 0 {
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
}
|
||||
if xrip := strings.TrimSpace(r.Header.Get("X-Real-IP")); xrip != "" {
|
||||
return xrip
|
||||
}
|
||||
host, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr))
|
||||
if err == nil && host != "" {
|
||||
return host
|
||||
}
|
||||
if strings.TrimSpace(r.RemoteAddr) != "" {
|
||||
return strings.TrimSpace(r.RemoteAddr)
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
type statusRecorder struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func newStatusRecorder(w http.ResponseWriter) *statusRecorder {
|
||||
return &statusRecorder{ResponseWriter: w, statusCode: http.StatusOK}
|
||||
}
|
||||
|
||||
func (r *statusRecorder) WriteHeader(statusCode int) {
|
||||
r.statusCode = statusCode
|
||||
r.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (r *statusRecorder) StatusCode() int {
|
||||
return r.statusCode
|
||||
}
|
||||
|
||||
func matchRule(m []string, i []string) bool {
|
||||
// m 通常为 [author, repo] 或 [author]
|
||||
if len(i) == 1 {
|
||||
@@ -210,7 +325,6 @@ func checkURL(u string) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
func parseList(s string) [][]string {
|
||||
var res [][]string
|
||||
lines := strings.Split(s, "\n")
|
||||
|
||||
Reference in New Issue
Block a user