diff --git a/.gitignore b/.gitignore index 5b90e79..a687bb4 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,7 @@ go.work.sum # env file .env +# database file +*.db +*.db-shm +*.db-wal \ No newline at end of file diff --git a/README.md b/README.md index 9c6b04c..bc238f3 100644 --- a/README.md +++ b/README.md @@ -1,110 +1,140 @@ - # Hugs-Proxy -> 本项目目前支持Github,后面可能会增加HuggingFace...... +一个轻量的 GitHub 资源加速反向代理工具。当前版本在原有代理能力上新增了 SQLite 持久化鉴权和访问审计。 -一个**轻量的 GitHub 资源加速/反代下载**小工具:把 GitHub 的下载链接(Release、Archive、Raw、Gist、部分 git/info 相关请求)通过本地 HTTP 服务转发,从而在某些网络环境下提升可用性。 +## 功能概览 -本项目默认只监听 `127.0.0.1`,更改为对外监听前请务必阅读“安全提示”。 - -## 特性 - -- 支持代理的链接类型(不匹配会返回 403): - - `github.com///releases/...` - - `github.com///archive/...` - - `github.com///blob/...`(会自动改写为 `.../raw/...` 进行下载) - - `github.com///raw/...` - - `github.com///info/...`、`github.com///git-...` - - `raw.githubusercontent.com///...` - - `gist.github.com/...`、`gist.githubusercontent.com/...` -- 白名单/黑名单(先白名单,后黑名单) -- 大文件保护:响应体 `Content-Length` 超过 1GB 时直接 302 重定向到源站,避免本机带宽/内存压力 -- 处理上游重定向:对可识别的 GitHub 下载链接会改写 `Location`,让跳转继续走本代理 +- 支持 GitHub 资源代理(不匹配返回 403): + - github.com///releases/... + - github.com///archive/... + - github.com///blob/...(自动改写为 /raw/) + - github.com///raw/... + - github.com///info/... 与 github.com///git-... + - raw.githubusercontent.com///... + - gist.github.com/... 与 gist.githubusercontent.com/... +- Bearer Token 鉴权(Authorization 头) +- SQLite 持久化:用户、token、访问明细、全站累计计数 +- 白名单/黑名单(先白名单再黑名单) +- 大文件保护(>1GB 直接 302 回源) +- 上游重定向 Location 改写(可识别 GitHub URL 时继续经由代理) ## 快速开始 -### 1) 运行 - -在仓库根目录执行: +### 1. 启动服务 ```bash go run . ``` -默认监听:`127.0.0.1:2005` +默认监听地址:127.0.0.1:2005 -### 2) 使用方式 +### 2. 初始化用户与 token -访问路径就是你要代理的目标 URL(去掉前面的 `/`): - -```text -http://127.0.0.1:2005/<目标URL> +```bash +go run ./cmd/tokenctl create-user --name alice +go run ./cmd/tokenctl issue-token --user alice --expires 30d --uses 1 ``` -目标 URL 可以写全(推荐),也可以省略 scheme(会自动补 `https://`)。示例: +说明: -- Release 文件: +- --expires 支持 7d/30d 这类天数写法 +- 也支持 RFC3339 绝对时间(例如 2026-12-31T23:59:59Z) +- 不传 --expires 时使用默认有效期(环境变量可配置) +- --uses 指定可用次数(例如 1 为一次性 token,5 为可用 5 次) -```text -http://127.0.0.1:2005/https://github.com/OWNER/REPO/releases/download/v1.0.0/app-darwin-amd64.zip +### 3. 携带 Bearer Token 请求 + +```bash +curl -L \ + -H "Authorization: Bearer " \ + "http://127.0.0.1:2005/https://github.com/OWNER/REPO/releases/download/v1.0.0/file.zip" ``` -- 仓库归档(archive): +无 token:返回 401。 + +token 无效、已禁用、已过期或使用次数耗尽:返回 403。 + +## tokenctl 用法 ```text -http://127.0.0.1:2005/https://github.com/OWNER/REPO/archive/refs/heads/main.zip +tokenctl create-user --name [--db ] +tokenctl issue-token --user [--expires <7d|30d|RFC3339>] [--uses ] [--db ] +tokenctl disable-token --token [--db ] +tokenctl list-users [--db ] +tokenctl list-tokens [--db ] ``` -- Raw 文件(也可以给 `blob`,服务端会自动替换为 `raw`): +示例: -```text -http://127.0.0.1:2005/https://github.com/OWNER/REPO/blob/main/README.md +```bash +go run ./cmd/tokenctl list-users +go run ./cmd/tokenctl list-tokens +go run ./cmd/tokenctl issue-token --user alice --uses 5 +go run ./cmd/tokenctl disable-token --token ``` -- raw.githubusercontent.com: +## 环境变量配置 -```text -http://127.0.0.1:2005/https://raw.githubusercontent.com/OWNER/REPO/main/README.md -``` +服务与 CLI 共用同一套配置: -如果你传入的链接不属于上面支持的 GitHub 资源格式,会返回:`403 Invalid input.` +- HUGS_PROXY_HOST:监听地址,默认 127.0.0.1 +- HUGS_PROXY_PORT:监听端口,默认 2005 +- HUGS_PROXY_DB_PATH:SQLite 文件路径,默认 ./hugs_proxy.db +- HUGS_PROXY_DEFAULT_TOKEN_TTL:默认 token 有效期,默认 30d +- HUGS_PROXY_DB_BUSY_TIMEOUT_MS:SQLite busy_timeout,默认 5000 -## 配置 +说明: -目前所有配置都在 [main.go](main.go) 顶部的“配置区域”,修改后重新运行即可: +- 代理大文件阈值 sizeLimit 仍在 main.go 内部常量(默认 1GB)。 +- 白名单与黑名单规则仍在 main.go 的 whiteListStr 与 blackListStr。 -- `host`:监听地址(默认 `127.0.0.1`) -- `port`:监听端口(默认 `2005`) -- `sizeLimit`:大文件阈值(默认 `1GB`) -- `whiteListStr`:白名单(多行字符串,每行一条规则) -- `blackListStr`:黑名单(多行字符串,每行一条规则) +## 数据库说明 -### 白名单/黑名单规则 +程序启动会自动执行幂等建表迁移(CREATE TABLE IF NOT EXISTS),并设置: -每行一条,支持三种写法: +- PRAGMA journal_mode = WAL +- PRAGMA busy_timeout = 5000(或你设置的环境变量值) -- `user1`:匹配/封禁 `user1` 下的所有仓库 -- `user1/repo1`:匹配/封禁 `user1/repo1` -- `*/repo1`:匹配/封禁所有名为 `repo1` 的仓库 +表结构语义: + +- users:用户主体 +- tokens:token 主体(含过期、禁用、可用次数上限与已使用次数) +- token_usage:访问审计明细(时间/IP/用户/token/原始 URL/HTTP 状态/成功标志/失败原因) +- site_stats:总体累计计数(total_accelerated_count) + +## 白名单/黑名单规则 + +每行一条,支持: + +- user1:匹配 user1 下所有仓库 +- user1/repo1:匹配 user1/repo1 +- \*/repo1:匹配所有 repo1 判定顺序: -1. **白名单优先生效**:如果白名单非空,则必须至少命中一条白名单规则,否则直接拒绝(403)。 -2. **再匹配黑名单**:命中任意黑名单规则则拒绝(403)。 +1. 白名单优先(若白名单非空,必须命中) +2. 再匹配黑名单(命中即拒绝) +## 验证步骤 +1. 构建:go build ./... +2. 无 token 请求:应返回 401 +3. 签发 token 后请求合法 URL:应成功且代理行为与旧版本一致 +4. 禁用或过期 token 请求:应返回 403,并记录失败 usage +5. 多次请求后检查 site_stats.total_accelerated_count 累计值 +6. 检查 token_usage 记录字段完整性 -## 常见问题 +## 排错 -### 为什么有时会直接跳转到 GitHub? - -当上游响应 `Content-Length` 大于 `sizeLimit`(默认 1GB)时,程序会返回 302 重定向到目标 URL,而不是继续转发大文件内容。 - -### 支持 Git Clone / Git LFS 吗? - -本项目主要面向“下载/获取资源”。它只放行并代理部分与 GitHub 资源下载相关的 URL 形态;不保证覆盖完整的 Git 协议或 Git LFS 场景。 +- 启动报数据库错误: + - 检查 HUGS_PROXY_DB_PATH 是否可写 + - 检查目录权限 +- 总是 401: + - 确认请求头格式为 Authorization: Bearer +- 总是 403: + - 检查 token 是否过期或禁用 + - 检查 URL 是否命中支持规则以及白黑名单规则 ## License -详见 License 文件 - +详见仓库中的 LICENSE 文件。 diff --git a/cmd/tokenctl/main.go b/cmd/tokenctl/main.go new file mode 100644 index 0000000..e0055db --- /dev/null +++ b/cmd/tokenctl/main.go @@ -0,0 +1,198 @@ +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "log" + "os" + "strings" + "time" + + "gitea.gangary.cn/gary/Hugs-Proxy/internal/config" + "gitea.gangary.cn/gary/Hugs-Proxy/internal/db" +) + +func main() { + cfg := config.Load() + if len(os.Args) < 2 { + printUsage() + os.Exit(1) + } + + sub := os.Args[1] + switch sub { + case "create-user": + handleCreateUser(cfg, os.Args[2:]) + case "issue-token": + handleIssueToken(cfg, os.Args[2:]) + case "disable-token": + handleDisableToken(cfg, os.Args[2:]) + case "list-users": + handleListUsers(cfg, os.Args[2:]) + case "list-tokens": + handleListTokens(cfg, os.Args[2:]) + default: + printUsage() + os.Exit(1) + } +} + +func handleCreateUser(cfg config.Config, args []string) { + fs := flag.NewFlagSet("create-user", flag.ExitOnError) + dbPath := fs.String("db", cfg.DBPath, "SQLite DB path") + name := fs.String("name", "", "username") + _ = fs.Parse(args) + + if strings.TrimSpace(*name) == "" { + log.Fatal("--name is required") + } + + store := openStore(*dbPath, cfg.BusyTimeoutMS) + defer store.Close() + + u, err := store.CreateUser(context.Background(), strings.TrimSpace(*name)) + if err != nil { + log.Fatal(err) + } + fmt.Printf("created user: id=%d username=%s created_at=%s\n", u.ID, u.Username, u.CreatedAt.Format(time.RFC3339)) +} + +func handleIssueToken(cfg config.Config, args []string) { + fs := flag.NewFlagSet("issue-token", flag.ExitOnError) + dbPath := fs.String("db", cfg.DBPath, "SQLite DB path") + username := fs.String("user", "", "username") + expires := fs.String("expires", "", "expires duration like 7d/30d or RFC3339 time") + uses := fs.Int("uses", 1, "max usable times for this token, must be > 0") + _ = fs.Parse(args) + + if strings.TrimSpace(*username) == "" { + log.Fatal("--user is required") + } + if *uses <= 0 { + log.Fatal("--uses must be > 0") + } + + store := openStore(*dbPath, cfg.BusyTimeoutMS) + defer store.Close() + + ctx := context.Background() + u, err := store.GetUserByName(ctx, strings.TrimSpace(*username)) + if err != nil { + if errors.Is(err, db.ErrNotFound) { + log.Fatalf("user %q not found", *username) + } + log.Fatal(err) + } + + expiresAt, err := parseExpiresAt(*expires, cfg.DefaultTokenTTL) + if err != nil { + log.Fatal(err) + } + + tok, err := store.IssueToken(ctx, u.ID, expiresAt, *uses) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("issued token: id=%d user=%s token=%s expires_at=%s max_uses=%d\n", tok.ID, u.Username, tok.Token, tok.ExpiresAt.Format(time.RFC3339), tok.MaxUses) +} + +func handleDisableToken(cfg config.Config, args []string) { + fs := flag.NewFlagSet("disable-token", flag.ExitOnError) + dbPath := fs.String("db", cfg.DBPath, "SQLite DB path") + token := fs.String("token", "", "token value") + _ = fs.Parse(args) + + if strings.TrimSpace(*token) == "" { + log.Fatal("--token is required") + } + + store := openStore(*dbPath, cfg.BusyTimeoutMS) + defer store.Close() + + err := store.DisableToken(context.Background(), strings.TrimSpace(*token)) + if err != nil { + if errors.Is(err, db.ErrNotFound) { + log.Fatalf("token not found") + } + log.Fatal(err) + } + fmt.Println("token disabled") +} + +func handleListUsers(cfg config.Config, args []string) { + fs := flag.NewFlagSet("list-users", flag.ExitOnError) + dbPath := fs.String("db", cfg.DBPath, "SQLite DB path") + _ = fs.Parse(args) + + store := openStore(*dbPath, cfg.BusyTimeoutMS) + defer store.Close() + + users, err := store.ListUsers(context.Background()) + if err != nil { + log.Fatal(err) + } + + fmt.Println("id\tusername\tcreated_at") + for _, u := range users { + fmt.Printf("%d\t%s\t%s\n", u.ID, u.Username, u.CreatedAt.Format(time.RFC3339)) + } +} + +func handleListTokens(cfg config.Config, args []string) { + fs := flag.NewFlagSet("list-tokens", flag.ExitOnError) + dbPath := fs.String("db", cfg.DBPath, "SQLite DB path") + _ = fs.Parse(args) + + store := openStore(*dbPath, cfg.BusyTimeoutMS) + defer store.Close() + + tokens, err := store.ListTokens(context.Background()) + if err != nil { + log.Fatal(err) + } + + fmt.Println("id\tuser\tdisabled\tused/max\texpires_at\tcreated_at\ttoken") + for _, t := range tokens { + maxUsesDisplay := fmt.Sprintf("%d", t.MaxUses) + if t.MaxUses == 0 { + maxUsesDisplay = "unlimited" + } + fmt.Printf("%d\t%s\t%t\t%d/%s\t%s\t%s\t%s\n", t.ID, t.Username, t.Disabled, t.UsedCount, maxUsesDisplay, t.ExpiresAt.Format(time.RFC3339), t.CreatedAt.Format(time.RFC3339), t.Token) + } +} + +func parseExpiresAt(raw string, defaultTTL time.Duration) (time.Time, error) { + now := time.Now().UTC() + clean := strings.TrimSpace(raw) + if clean == "" { + return now.Add(defaultTTL), nil + } + if t, err := time.Parse(time.RFC3339, clean); err == nil { + return t.UTC(), nil + } + d, err := config.ParseExpiryDuration(clean) + if err != nil { + return time.Time{}, fmt.Errorf("invalid --expires value: %w", err) + } + return now.Add(d), nil +} + +func openStore(dbPath string, busyTimeoutMS int) *db.Store { + store, err := db.NewStore(strings.TrimSpace(dbPath), busyTimeoutMS) + if err != nil { + log.Fatal(err) + } + return store +} + +func printUsage() { + fmt.Println(`tokenctl usage: + tokenctl create-user --name [--db ] + tokenctl issue-token --user [--expires <7d|30d|RFC3339>] [--uses ] [--db ] + tokenctl disable-token --token [--db ] + tokenctl list-users [--db ] + tokenctl list-tokens [--db ]`) +} diff --git a/go.mod b/go.mod index 5201e7a..70aecc3 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,17 @@ module gitea.gangary.cn/gary/Hugs-Proxy go 1.26.2 + +require modernc.org/sqlite v1.49.1 + +require ( + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + golang.org/x/sys v0.42.0 // indirect + modernc.org/libc v1.72.0 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..11fc64a --- /dev/null +++ b/go.sum @@ -0,0 +1,51 @@ +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= +modernc.org/cc/v4 v4.27.3 h1:uNCgn37E5U09mTv1XgskEVUJ8ADKpmFMPxzGJ0TSo+U= +modernc.org/cc/v4 v4.27.3/go.mod h1:3YjcbCqhoTTHPycJDRl2WZKKFj0nwcOIPBfEZK0Hdk8= +modernc.org/ccgo/v4 v4.32.4 h1:L5OB8rpEX4ZsXEQwGozRfJyJSFHbbNVOoQ59DU9/KuU= +modernc.org/ccgo/v4 v4.32.4/go.mod h1:lY7f+fiTDHfcv6YlRgSkxYfhs+UvOEEzj49jAn2TOx0= +modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM= +modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo= +modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.72.0 h1:IEu559v9a0XWjw0DPoVKtXpO2qt5NVLAnFaBbjq+n8c= +modernc.org/libc v1.72.0/go.mod h1:tTU8DL8A+XLVkEY3x5E/tO7s2Q/q42EtnNWda/L5QhQ= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.49.1 h1:dYGHTKcX1sJ+EQDnUzvz4TJ5GbuvhNJa8Fg6ElGx73U= +modernc.org/sqlite v1.49.1/go.mod h1:m0w8xhwYUVY3H6pSDwc3gkJ/irZT/0YEXwBlhaxQEew= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/internal/audit/logger.go b/internal/audit/logger.go new file mode 100644 index 0000000..e8ceea7 --- /dev/null +++ b/internal/audit/logger.go @@ -0,0 +1,51 @@ +package audit + +import ( + "context" + "database/sql" + "strings" + "time" + + "gitea.gangary.cn/gary/Hugs-Proxy/internal/db" +) + +type Entry struct { + RequestIP string + UserID int64 + HasUser bool + TokenID int64 + HasToken bool + OriginalURL string + HTTPStatus int + Success bool + ErrorReason string + CountAsSuccess bool + OccurredAt time.Time +} + +type Logger struct { + store *db.Store +} + +func NewLogger(store *db.Store) *Logger { + return &Logger{store: store} +} + +func (l *Logger) Log(ctx context.Context, entry Entry) error { + usage := db.UsageEntry{ + RequestIP: strings.TrimSpace(entry.RequestIP), + OriginalURL: strings.TrimSpace(entry.OriginalURL), + HTTPStatus: entry.HTTPStatus, + Success: entry.Success, + ErrorReason: strings.TrimSpace(entry.ErrorReason), + OccurredAt: entry.OccurredAt, + } + if entry.HasUser { + usage.UserID = sql.NullInt64{Int64: entry.UserID, Valid: true} + } + if entry.HasToken { + usage.TokenID = sql.NullInt64{Int64: entry.TokenID, Valid: true} + } + + return l.store.RecordUsageAndMaybeIncrement(ctx, usage, entry.CountAsSuccess) +} diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go new file mode 100644 index 0000000..c47bb74 --- /dev/null +++ b/internal/auth/middleware.go @@ -0,0 +1,90 @@ +package auth + +import ( + "context" + "errors" + "net/http" + "strings" + + "gitea.gangary.cn/gary/Hugs-Proxy/internal/db" +) + +type contextKey string + +const authInfoKey contextKey = "authInfo" + +type AuthInfo struct { + UserID int64 + TokenID int64 + Token string +} + +type FailureRecorder func(r *http.Request, token string, statusCode int, reason string, userID int64, tokenID int64) + +func Middleware(store *db.Store, onFailure FailureRecorder, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokenValue, ok := bearerToken(r.Header.Get("Authorization")) + if !ok { + http.Error(w, "Missing bearer token.", http.StatusUnauthorized) + if onFailure != nil { + onFailure(r, "", http.StatusUnauthorized, "missing_token", 0, 0) + } + return + } + + token, err := store.ValidateAndConsumeToken(r.Context(), tokenValue) + if err != nil { + status := http.StatusForbidden + reason := "invalid_token" + switch { + case errors.Is(err, db.ErrNotFound): + reason = "invalid_token" + case errors.Is(err, db.ErrTokenDisabled): + reason = "token_disabled" + case errors.Is(err, db.ErrTokenExpired): + reason = "token_expired" + case errors.Is(err, db.ErrTokenExhausted): + reason = "token_exhausted" + default: + status = http.StatusInternalServerError + reason = "token_lookup_error" + } + http.Error(w, http.StatusText(status), status) + if onFailure != nil { + if token != nil { + onFailure(r, tokenValue, status, reason, token.UserID, token.ID) + } else { + onFailure(r, tokenValue, status, reason, 0, 0) + } + } + return + } + + authInfo := AuthInfo{UserID: token.UserID, TokenID: token.ID, Token: tokenValue} + ctx := context.WithValue(r.Context(), authInfoKey, authInfo) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func FromContext(ctx context.Context) (AuthInfo, bool) { + v := ctx.Value(authInfoKey) + if v == nil { + return AuthInfo{}, false + } + info, ok := v.(AuthInfo) + return info, ok +} + +func bearerToken(authorizationHeader string) (string, bool) { + parts := strings.Fields(strings.TrimSpace(authorizationHeader)) + if len(parts) != 2 { + return "", false + } + if !strings.EqualFold(parts[0], "Bearer") { + return "", false + } + if parts[1] == "" { + return "", false + } + return parts[1], true +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..d9a9045 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,78 @@ +package config + +import ( + "os" + "strconv" + "strings" + "time" +) + +const ( + defaultHost = "127.0.0.1" + defaultPort = "2005" + defaultDBPath = "./hugs_proxy.db" + defaultTokenTTLString = "30d" + defaultBusyTimeoutMilli = 5000 +) + +type Config struct { + Host string + Port string + DBPath string + DefaultTokenTTL time.Duration + BusyTimeoutMS int +} + +func Load() Config { + host := getenvOrDefault("HUGS_PROXY_HOST", defaultHost) + port := getenvOrDefault("HUGS_PROXY_PORT", defaultPort) + dbPath := getenvOrDefault("HUGS_PROXY_DB_PATH", defaultDBPath) + ttlRaw := getenvOrDefault("HUGS_PROXY_DEFAULT_TOKEN_TTL", defaultTokenTTLString) + busyTimeout := getenvIntOrDefault("HUGS_PROXY_DB_BUSY_TIMEOUT_MS", defaultBusyTimeoutMilli) + + ttl, err := ParseExpiryDuration(ttlRaw) + if err != nil { + ttl, _ = ParseExpiryDuration(defaultTokenTTLString) + } + + return Config{ + Host: host, + Port: port, + DBPath: dbPath, + DefaultTokenTTL: ttl, + BusyTimeoutMS: busyTimeout, + } +} + +func ParseExpiryDuration(s string) (time.Duration, error) { + s = strings.TrimSpace(strings.ToLower(s)) + if strings.HasSuffix(s, "d") { + daysPart := strings.TrimSuffix(s, "d") + days, err := strconv.Atoi(daysPart) + if err != nil { + return 0, err + } + return time.Duration(days) * 24 * time.Hour, nil + } + return time.ParseDuration(s) +} + +func getenvOrDefault(key string, fallback string) string { + v := strings.TrimSpace(os.Getenv(key)) + if v == "" { + return fallback + } + return v +} + +func getenvIntOrDefault(key string, fallback int) int { + v := strings.TrimSpace(os.Getenv(key)) + if v == "" { + return fallback + } + n, err := strconv.Atoi(v) + if err != nil { + return fallback + } + return n +} diff --git a/internal/db/store.go b/internal/db/store.go new file mode 100644 index 0000000..2582513 --- /dev/null +++ b/internal/db/store.go @@ -0,0 +1,495 @@ +package db + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/base64" + "errors" + "fmt" + "strings" + "time" + + _ "modernc.org/sqlite" +) + +const timeFormat = time.RFC3339Nano + +var ErrNotFound = errors.New("not found") + +var ( + ErrTokenDisabled = errors.New("token disabled") + ErrTokenExpired = errors.New("token expired") + ErrTokenExhausted = errors.New("token exhausted") +) + +type Store struct { + db *sql.DB +} + +type User struct { + ID int64 + Username string + CreatedAt time.Time +} + +type Token struct { + ID int64 + UserID int64 + Token string + CreatedAt time.Time + ExpiresAt time.Time + Disabled bool + DisabledAt *time.Time + MaxUses int + UsedCount int +} + +type TokenWithUser struct { + Token + Username string +} + +type UsageEntry struct { + RequestIP string + UserID sql.NullInt64 + TokenID sql.NullInt64 + OriginalURL string + HTTPStatus int + Success bool + ErrorReason string + OccurredAt time.Time +} + +func NewStore(dbPath string, busyTimeoutMS int) (*Store, error) { + db, err := sql.Open("sqlite", dbPath) + if err != nil { + return nil, err + } + + if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil { + db.Close() + return nil, err + } + if _, err := db.Exec(fmt.Sprintf("PRAGMA busy_timeout = %d;", busyTimeoutMS)); err != nil { + db.Close() + return nil, err + } + if _, err := db.Exec("PRAGMA foreign_keys = ON;"); err != nil { + db.Close() + return nil, err + } + + s := &Store{db: db} + if err := s.migrate(context.Background()); err != nil { + db.Close() + return nil, err + } + return s, nil +} + +func (s *Store) Close() error { + return s.db.Close() +} + +func (s *Store) migrate(ctx context.Context) error { + stmts := []string{ + `CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT NOT NULL UNIQUE, + created_at TEXT NOT NULL + );`, + `CREATE TABLE IF NOT EXISTS tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + token TEXT NOT NULL UNIQUE, + created_at TEXT NOT NULL, + expires_at TEXT NOT NULL, + disabled INTEGER NOT NULL DEFAULT 0, + max_uses INTEGER NOT NULL DEFAULT 0, + used_count INTEGER NOT NULL DEFAULT 0, + disabled_at TEXT, + FOREIGN KEY(user_id) REFERENCES users(id) + );`, + `CREATE TABLE IF NOT EXISTS site_stats ( + id INTEGER PRIMARY KEY, + total_accelerated_count INTEGER NOT NULL DEFAULT 0, + updated_at TEXT NOT NULL + );`, + `CREATE TABLE IF NOT EXISTS token_usage ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + created_at TEXT NOT NULL, + request_ip TEXT NOT NULL, + user_id INTEGER, + token_id INTEGER, + original_url TEXT NOT NULL, + http_status INTEGER NOT NULL, + success INTEGER NOT NULL, + error_reason TEXT NOT NULL DEFAULT '', + FOREIGN KEY(user_id) REFERENCES users(id), + FOREIGN KEY(token_id) REFERENCES tokens(id) + );`, + `CREATE INDEX IF NOT EXISTS idx_tokens_token ON tokens(token);`, + `CREATE INDEX IF NOT EXISTS idx_tokens_user_id ON tokens(user_id);`, + `CREATE INDEX IF NOT EXISTS idx_token_usage_created_at ON token_usage(created_at);`, + `CREATE INDEX IF NOT EXISTS idx_token_usage_user_id ON token_usage(user_id);`, + `CREATE INDEX IF NOT EXISTS idx_token_usage_token_id ON token_usage(token_id);`, + } + + for _, stmt := range stmts { + if _, err := s.db.ExecContext(ctx, stmt); err != nil { + return err + } + } + + if _, err := s.db.ExecContext(ctx, `ALTER TABLE tokens ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 0;`); err != nil && !isDuplicateColumnErr(err) { + return err + } + if _, err := s.db.ExecContext(ctx, `ALTER TABLE tokens ADD COLUMN used_count INTEGER NOT NULL DEFAULT 0;`); err != nil && !isDuplicateColumnErr(err) { + return err + } + + _, err := s.db.ExecContext( + ctx, + `INSERT INTO site_stats (id, total_accelerated_count, updated_at) + VALUES (1, 0, ?) + ON CONFLICT(id) DO NOTHING;`, + time.Now().UTC().Format(timeFormat), + ) + return err +} + +func (s *Store) CreateUser(ctx context.Context, username string) (*User, error) { + now := time.Now().UTC() + res, err := s.db.ExecContext( + ctx, + "INSERT INTO users (username, created_at) VALUES (?, ?);", + username, + now.Format(timeFormat), + ) + if err != nil { + return nil, err + } + id, err := res.LastInsertId() + if err != nil { + return nil, err + } + return &User{ID: id, Username: username, CreatedAt: now}, nil +} + +func (s *Store) GetUserByName(ctx context.Context, username string) (*User, error) { + row := s.db.QueryRowContext(ctx, "SELECT id, username, created_at FROM users WHERE username = ?;", username) + var u User + var createdAt string + if err := row.Scan(&u.ID, &u.Username, &createdAt); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } + return nil, err + } + parsed, err := time.Parse(timeFormat, createdAt) + if err != nil { + return nil, err + } + u.CreatedAt = parsed + return &u, nil +} + +func (s *Store) ListUsers(ctx context.Context) ([]User, error) { + rows, err := s.db.QueryContext(ctx, "SELECT id, username, created_at FROM users ORDER BY id ASC;") + if err != nil { + return nil, err + } + defer rows.Close() + + var out []User + for rows.Next() { + var u User + var createdAt string + if err := rows.Scan(&u.ID, &u.Username, &createdAt); err != nil { + return nil, err + } + parsed, err := time.Parse(timeFormat, createdAt) + if err != nil { + return nil, err + } + u.CreatedAt = parsed + out = append(out, u) + } + if err := rows.Err(); err != nil { + return nil, err + } + return out, nil +} + +func (s *Store) IssueToken(ctx context.Context, userID int64, expiresAt time.Time, maxUses int) (*Token, error) { + if maxUses < 0 { + return nil, fmt.Errorf("maxUses cannot be negative") + } + tokenValue, err := generateToken() + if err != nil { + return nil, err + } + now := time.Now().UTC() + res, err := s.db.ExecContext( + ctx, + `INSERT INTO tokens (user_id, token, created_at, expires_at, disabled, max_uses, used_count) + VALUES (?, ?, ?, ?, 0, ?, 0);`, + userID, + tokenValue, + now.Format(timeFormat), + expiresAt.UTC().Format(timeFormat), + maxUses, + ) + if err != nil { + return nil, err + } + id, err := res.LastInsertId() + if err != nil { + return nil, err + } + return &Token{ID: id, UserID: userID, Token: tokenValue, CreatedAt: now, ExpiresAt: expiresAt.UTC(), MaxUses: maxUses, UsedCount: 0}, nil +} + +func (s *Store) GetToken(ctx context.Context, tokenValue string) (*Token, error) { + row := s.db.QueryRowContext( + ctx, + `SELECT id, user_id, token, created_at, expires_at, disabled, disabled_at, max_uses, used_count + FROM tokens + WHERE token = ?;`, + tokenValue, + ) + return scanToken(row) +} + +func (s *Store) ValidateAndConsumeToken(ctx context.Context, tokenValue string) (*Token, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + _ = tx.Rollback() + } + }() + + row := tx.QueryRowContext( + ctx, + `SELECT id, user_id, token, created_at, expires_at, disabled, disabled_at, max_uses, used_count + FROM tokens + WHERE token = ?;`, + tokenValue, + ) + token, err := scanToken(row) + if err != nil { + return nil, err + } + + now := time.Now().UTC() + if token.Disabled { + return nil, ErrTokenDisabled + } + if now.After(token.ExpiresAt) { + return nil, ErrTokenExpired + } + if token.MaxUses > 0 && token.UsedCount >= token.MaxUses { + return nil, ErrTokenExhausted + } + + res, err := tx.ExecContext( + ctx, + `UPDATE tokens + SET used_count = used_count + 1 + WHERE id = ? + AND (max_uses = 0 OR used_count < max_uses);`, + token.ID, + ) + if err != nil { + return nil, err + } + affected, err := res.RowsAffected() + if err != nil { + return nil, err + } + if affected == 0 { + return nil, ErrTokenExhausted + } + + if err = tx.Commit(); err != nil { + return nil, err + } + token.UsedCount++ + return token, nil +} + +func (s *Store) DisableToken(ctx context.Context, tokenValue string) error { + now := time.Now().UTC().Format(timeFormat) + res, err := s.db.ExecContext( + ctx, + "UPDATE tokens SET disabled = 1, disabled_at = ? WHERE token = ?;", + now, + tokenValue, + ) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + return ErrNotFound + } + return nil +} + +func (s *Store) ListTokens(ctx context.Context) ([]TokenWithUser, error) { + rows, err := s.db.QueryContext( + ctx, + `SELECT t.id, t.user_id, t.token, t.created_at, t.expires_at, t.disabled, t.disabled_at, t.max_uses, t.used_count, u.username + FROM tokens t + INNER JOIN users u ON u.id = t.user_id + ORDER BY t.id ASC;`, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var out []TokenWithUser + for rows.Next() { + var t TokenWithUser + var createdAt string + var expiresAt string + var disabledAt sql.NullString + var disabledInt int + if err := rows.Scan(&t.ID, &t.UserID, &t.Token, &createdAt, &expiresAt, &disabledInt, &disabledAt, &t.MaxUses, &t.UsedCount, &t.Username); err != nil { + return nil, err + } + ct, err := time.Parse(timeFormat, createdAt) + if err != nil { + return nil, err + } + et, err := time.Parse(timeFormat, expiresAt) + if err != nil { + return nil, err + } + t.CreatedAt = ct + t.ExpiresAt = et + t.Disabled = disabledInt == 1 + if disabledAt.Valid { + dt, err := time.Parse(timeFormat, disabledAt.String) + if err != nil { + return nil, err + } + t.DisabledAt = &dt + } + out = append(out, t) + } + if err := rows.Err(); err != nil { + return nil, err + } + return out, nil +} + +func (s *Store) RecordUsageAndMaybeIncrement(ctx context.Context, entry UsageEntry, increment bool) error { + if entry.OccurredAt.IsZero() { + entry.OccurredAt = time.Now().UTC() + } + if entry.RequestIP == "" { + entry.RequestIP = "unknown" + } + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { + if err != nil { + _ = tx.Rollback() + } + }() + + successInt := 0 + if entry.Success { + successInt = 1 + } + + _, err = tx.ExecContext( + ctx, + `INSERT INTO token_usage (created_at, request_ip, user_id, token_id, original_url, http_status, success, error_reason) + VALUES (?, ?, ?, ?, ?, ?, ?, ?);`, + entry.OccurredAt.UTC().Format(timeFormat), + entry.RequestIP, + entry.UserID, + entry.TokenID, + entry.OriginalURL, + entry.HTTPStatus, + successInt, + entry.ErrorReason, + ) + if err != nil { + return err + } + + if increment { + _, err = tx.ExecContext( + ctx, + `UPDATE site_stats + SET total_accelerated_count = total_accelerated_count + 1, + updated_at = ? + WHERE id = 1;`, + time.Now().UTC().Format(timeFormat), + ) + if err != nil { + return err + } + } + + return tx.Commit() +} + +func generateToken() (string, error) { + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(buf), nil +} + +func scanToken(scanner interface{ Scan(dest ...any) error }) (*Token, error) { + var t Token + var createdAt string + var expiresAt string + var disabledAt sql.NullString + var disabledInt int + if err := scanner.Scan(&t.ID, &t.UserID, &t.Token, &createdAt, &expiresAt, &disabledInt, &disabledAt, &t.MaxUses, &t.UsedCount); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } + return nil, err + } + ct, err := time.Parse(timeFormat, createdAt) + if err != nil { + return nil, err + } + et, err := time.Parse(timeFormat, expiresAt) + if err != nil { + return nil, err + } + t.CreatedAt = ct + t.ExpiresAt = et + t.Disabled = disabledInt == 1 + if disabledAt.Valid { + dt, err := time.Parse(timeFormat, disabledAt.String) + if err != nil { + return nil, err + } + t.DisabledAt = &dt + } + return &t, nil +} + +func isDuplicateColumnErr(err error) bool { + return strings.Contains(strings.ToLower(err.Error()), "duplicate column name") +} diff --git a/main.go b/main.go index 40637d9..c0f8458 100644 --- a/main.go +++ b/main.go @@ -1,22 +1,27 @@ package main import ( + "context" "fmt" "io" "log" + "net" "net/http" "net/url" "regexp" "strings" + + "gitea.gangary.cn/gary/Hugs-Proxy/internal/audit" + "gitea.gangary.cn/gary/Hugs-Proxy/internal/auth" + "gitea.gangary.cn/gary/Hugs-Proxy/internal/config" + "gitea.gangary.cn/gary/Hugs-Proxy/internal/db" ) // ======================== // 配置区域 // ======================== const ( - sizeLimit = int64(1024 * 1024 * 1024 * 1) // 允许的文件大小, 1GB - host = "127.0.0.1" - port = "2005" + sizeLimit = int64(1024 * 1024 * 1024 * 1) // 允许的文件大小, 1GB ) // 先生效白名单再匹配黑名单 @@ -29,7 +34,6 @@ var ( blackListStr = `` ) - // ======================== // 全局变量与预编译正则 // ======================== @@ -37,7 +41,6 @@ var ( whiteList [][]string blackList [][]string - exp1 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P[^/]+)/(?P[^/]+)/(?:releases|archive)/.*$`) exp2 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P[^/]+)/(?P[^/]+)/(?:blob|raw)/.*$`) exp3 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P[^/]+)/(?P[^/]+)/(?:info|git-).*$`) @@ -45,6 +48,8 @@ var ( exp5 = regexp.MustCompile(`^(?:https?://)?gist\.(?:githubusercontent|github)\.com/(?P[^/]+)/.+?/.+$`) httpClient *http.Client + store *db.Store + auditor *audit.Logger ) func init() { @@ -61,8 +66,22 @@ func init() { } func main() { - http.HandleFunc("/", routeHandler) - addr := fmt.Sprintf("%s:%s", host, port) + cfg := config.Load() + + var err error + store, err = db.NewStore(cfg.DBPath, cfg.BusyTimeoutMS) + if err != nil { + log.Fatal(err) + } + defer store.Close() + + auditor = audit.NewLogger(store) + + base := http.HandlerFunc(routeHandler) + protected := auth.Middleware(store, auditAuthFailure, base) + http.Handle("/", protected) + + addr := fmt.Sprintf("%s:%s", cfg.Host, cfg.Port) log.Printf("服务器启动成功,正在监听 %s", addr) if err := http.ListenAndServe(addr, nil); err != nil { log.Fatal(err) @@ -70,8 +89,40 @@ func main() { } func routeHandler(w http.ResponseWriter, r *http.Request) { + authInfo, ok := auth.FromContext(r.Context()) + if !ok { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } - u := strings.TrimPrefix(r.URL.RequestURI(), "/") + recorder := newStatusRecorder(w) + originalInput := strings.TrimPrefix(r.URL.RequestURI(), "/") + normalizedURL := originalInput + success := false + errorReason := "" + + defer func() { + statusCode := recorder.StatusCode() + if statusCode >= 200 && statusCode < 400 { + success = true + } + if err := auditor.Log(r.Context(), audit.Entry{ + RequestIP: clientIP(r), + UserID: authInfo.UserID, + HasUser: true, + TokenID: authInfo.TokenID, + HasToken: true, + OriginalURL: normalizedURL, + HTTPStatus: statusCode, + Success: success, + ErrorReason: errorReason, + CountAsSuccess: success, + }); err != nil { + log.Printf("audit log failed: %v", err) + } + }() + + u := originalInput if strings.HasPrefix(u, "https:/") && !strings.HasPrefix(u, "https://") { u = "https://" + strings.TrimPrefix(u, "https:/") } else if strings.HasPrefix(u, "http:/") && !strings.HasPrefix(u, "http://") { @@ -84,8 +135,9 @@ func routeHandler(w http.ResponseWriter, r *http.Request) { m := checkURL(u) if m == nil { - http.Error(w, "Invalid input.", http.StatusForbidden) - return + errorReason = "invalid_input" + http.Error(recorder, "Invalid input.", http.StatusForbidden) + return } if len(whiteList) > 0 { @@ -97,14 +149,16 @@ func routeHandler(w http.ResponseWriter, r *http.Request) { } } if !allowed { - http.Error(w, "Forbidden by white list.", http.StatusForbidden) - return + errorReason = "forbidden_by_white_list" + http.Error(recorder, "Forbidden by white list.", http.StatusForbidden) + return } } for _, i := range blackList { if matchRule(m, i) { - http.Error(w, "Forbidden by black list.", http.StatusForbidden) + errorReason = "forbidden_by_black_list" + http.Error(recorder, "Forbidden by black list.", http.StatusForbidden) return } } @@ -119,18 +173,19 @@ func routeHandler(w http.ResponseWriter, r *http.Request) { if err == nil { u = parsedURL.String() } - proxy(w, r, u) + normalizedURL = u + proxy(recorder, r, u) } func proxy(w http.ResponseWriter, r *http.Request, targetURL string) { // 修正由于多重代理可能造成的 URL 格式错误 - if strings.HasPrefix(targetURL, "https:/") && ! strings.HasPrefix(targetURL, "https://") { + if strings.HasPrefix(targetURL, "https:/") && !strings.HasPrefix(targetURL, "https://") { targetURL = "https://" + targetURL[7:] } req, err := http.NewRequest(r.Method, targetURL, r.Body) if err != nil { - http.Error(w, "server error " + err.Error(), http.StatusInternalServerError) + http.Error(w, "server error "+err.Error(), http.StatusInternalServerError) return } @@ -146,7 +201,7 @@ func proxy(w http.ResponseWriter, r *http.Request, targetURL string) { resp, err := httpClient.Do(req) if err != nil { - http.Error(w, "server error " + err.Error(), http.StatusInternalServerError) + http.Error(w, "server error "+err.Error(), http.StatusInternalServerError) return } defer resp.Body.Close() @@ -180,6 +235,66 @@ func proxy(w http.ResponseWriter, r *http.Request, targetURL string) { io.Copy(w, resp.Body) } +func auditAuthFailure(r *http.Request, token string, statusCode int, reason string, userID int64, tokenID int64) { + entry := audit.Entry{ + RequestIP: clientIP(r), + OriginalURL: strings.TrimPrefix(r.URL.RequestURI(), "/"), + HTTPStatus: statusCode, + Success: false, + ErrorReason: reason, + CountAsSuccess: false, + } + if userID > 0 { + entry.UserID = userID + entry.HasUser = true + } + if tokenID > 0 { + entry.TokenID = tokenID + entry.HasToken = true + } + if err := auditor.Log(context.Background(), entry); err != nil { + log.Printf("auth failure audit log failed: %v", err) + } +} + +func clientIP(r *http.Request) string { + if xff := strings.TrimSpace(r.Header.Get("X-Forwarded-For")); xff != "" { + parts := strings.Split(xff, ",") + if len(parts) > 0 { + return strings.TrimSpace(parts[0]) + } + } + if xrip := strings.TrimSpace(r.Header.Get("X-Real-IP")); xrip != "" { + return xrip + } + host, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr)) + if err == nil && host != "" { + return host + } + if strings.TrimSpace(r.RemoteAddr) != "" { + return strings.TrimSpace(r.RemoteAddr) + } + return "unknown" +} + +type statusRecorder struct { + http.ResponseWriter + statusCode int +} + +func newStatusRecorder(w http.ResponseWriter) *statusRecorder { + return &statusRecorder{ResponseWriter: w, statusCode: http.StatusOK} +} + +func (r *statusRecorder) WriteHeader(statusCode int) { + r.statusCode = statusCode + r.ResponseWriter.WriteHeader(statusCode) +} + +func (r *statusRecorder) StatusCode() int { + return r.statusCode +} + func matchRule(m []string, i []string) bool { // m 通常为 [author, repo] 或 [author] if len(i) == 1 { @@ -194,7 +309,7 @@ func matchRule(m []string, i []string) bool { } func checkURL(u string) []string { - exps := []*regexp.Regexp {exp1, exp2, exp3, exp4, exp5} + exps := []*regexp.Regexp{exp1, exp2, exp3, exp4, exp5} for _, exp := range exps { matches := exp.FindStringSubmatch(u) if matches != nil { @@ -206,16 +321,15 @@ func checkURL(u string) []string { } return result } - } + } return nil } - func parseList(s string) [][]string { var res [][]string lines := strings.Split(s, "\n") for _, line := range lines { - line = strings.TrimSpace(line) + line = strings.TrimSpace(line) if line != "" { parts := strings.Split(line, "/") var cleaned []string @@ -226,4 +340,4 @@ func parseList(s string) [][]string { } } return res -} \ No newline at end of file +}