初步实现鉴权

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
2026-04-23 20:21:35 +08:00
parent 847ce0c6a8
commit 62e076111f
10 changed files with 1212 additions and 87 deletions
+4
View File
@@ -25,3 +25,7 @@ go.work.sum
# env file # env file
.env .env
# database file
*.db
*.db-shm
*.db-wal
+95 -65
View File
@@ -1,110 +1,140 @@
# Hugs-Proxy # Hugs-Proxy
> 本项目目前支持Github,后面可能会增加HuggingFace...... 一个轻量的 GitHub 资源加速反向代理工具。当前版本在原有代理能力上新增了 SQLite 持久化鉴权和访问审计。
一个**轻量的 GitHub 资源加速/反代下载**小工具:把 GitHub 的下载链接(Release、Archive、Raw、Gist、部分 git/info 相关请求)通过本地 HTTP 服务转发,从而在某些网络环境下提升可用性。 ## 功能概览
本项目默认只监听 `127.0.0.1`,更改为对外监听前请务必阅读“安全提示”。 - 支持 GitHub 资源代理(不匹配返回 403):
- github.com/<owner>/<repo>/releases/...
## 特性 - github.com/<owner>/<repo>/archive/...
- github.com/<owner>/<repo>/blob/...(自动改写为 /raw/
- 支持代理的链接类型(不匹配会返回 403): - github.com/<owner>/<repo>/raw/...
- `github.com/<owner>/<repo>/releases/...` - github.com/<owner>/<repo>/info/... 与 github.com/<owner>/<repo>/git-...
- `github.com/<owner>/<repo>/archive/...` - raw.githubusercontent.com/<owner>/<repo>/...
- `github.com/<owner>/<repo>/blob/...`(会自动改写为 `.../raw/...` 进行下载) - gist.github.com/... 与 gist.githubusercontent.com/...
- `github.com/<owner>/<repo>/raw/...` - Bearer Token 鉴权(Authorization 头)
- `github.com/<owner>/<repo>/info/...``github.com/<owner>/<repo>/git-...` - SQLite 持久化:用户、token、访问明细、全站累计计数
- `raw.githubusercontent.com/<owner>/<repo>/...` - 白名单/黑名单(先白名单再黑名单)
- `gist.github.com/...``gist.githubusercontent.com/...` - 大文件保护(>1GB 直接 302 回源)
- 白名单/黑名单(先白名单,后黑名单 - 上游重定向 Location 改写(可识别 GitHub URL 时继续经由代理
- 大文件保护:响应体 `Content-Length` 超过 1GB 时直接 302 重定向到源站,避免本机带宽/内存压力
- 处理上游重定向:对可识别的 GitHub 下载链接会改写 `Location`,让跳转继续走本代理
## 快速开始 ## 快速开始
### 1) 运行 ### 1. 启动服务
在仓库根目录执行:
```bash ```bash
go run . go run .
``` ```
默认监听:`127.0.0.1:2005` 默认监听地址127.0.0.1:2005
### 2) 使用方式 ### 2. 初始化用户与 token
访问路径就是你要代理的目标 URL(去掉前面的 `/`): ```bash
go run ./cmd/tokenctl create-user --name alice
```text go run ./cmd/tokenctl issue-token --user alice --expires 30d --uses 1
http://127.0.0.1:2005/<目标URL>
``` ```
目标 URL 可以写全(推荐),也可以省略 scheme(会自动补 `https://`)。示例 说明
- Release 文件: - --expires 支持 7d/30d 这类天数写法
- 也支持 RFC3339 绝对时间(例如 2026-12-31T23:59:59Z
- 不传 --expires 时使用默认有效期(环境变量可配置)
- --uses 指定可用次数(例如 1 为一次性 token,5 为可用 5 次)
```text ### 3. 携带 Bearer Token 请求
http://127.0.0.1:2005/https://github.com/OWNER/REPO/releases/download/v1.0.0/app-darwin-amd64.zip
```bash
curl -L \
-H "Authorization: Bearer <TOKEN>" \
"http://127.0.0.1:2005/https://github.com/OWNER/REPO/releases/download/v1.0.0/file.zip"
``` ```
- 仓库归档(archive): 无 token:返回 401。
token 无效、已禁用、已过期或使用次数耗尽:返回 403。
## tokenctl 用法
```text ```text
http://127.0.0.1:2005/https://github.com/OWNER/REPO/archive/refs/heads/main.zip tokenctl create-user --name <username> [--db <path>]
tokenctl issue-token --user <username> [--expires <7d|30d|RFC3339>] [--uses <n>] [--db <path>]
tokenctl disable-token --token <token> [--db <path>]
tokenctl list-users [--db <path>]
tokenctl list-tokens [--db <path>]
``` ```
- Raw 文件(也可以给 `blob`,服务端会自动替换为 `raw` 示例
```text ```bash
http://127.0.0.1:2005/https://github.com/OWNER/REPO/blob/main/README.md go run ./cmd/tokenctl list-users
go run ./cmd/tokenctl list-tokens
go run ./cmd/tokenctl issue-token --user alice --uses 5
go run ./cmd/tokenctl disable-token --token <TOKEN>
``` ```
- raw.githubusercontent.com ## 环境变量配置
```text 服务与 CLI 共用同一套配置:
http://127.0.0.1:2005/https://raw.githubusercontent.com/OWNER/REPO/main/README.md
```
如果你传入的链接不属于上面支持的 GitHub 资源格式,会返回:`403 Invalid input.` - HUGS_PROXY_HOST:监听地址,默认 127.0.0.1
- HUGS_PROXY_PORT:监听端口,默认 2005
- HUGS_PROXY_DB_PATHSQLite 文件路径,默认 ./hugs_proxy.db
- HUGS_PROXY_DEFAULT_TOKEN_TTL:默认 token 有效期,默认 30d
- HUGS_PROXY_DB_BUSY_TIMEOUT_MSSQLite busy_timeout,默认 5000
## 配置 说明:
目前所有配置都在 [main.go](main.go) 顶部的“配置区域”,修改后重新运行即可: - 代理大文件阈值 sizeLimit 仍在 main.go 内部常量(默认 1GB)。
- 白名单与黑名单规则仍在 main.go 的 whiteListStr 与 blackListStr。
- `host`:监听地址(默认 `127.0.0.1` ## 数据库说明
- `port`:监听端口(默认 `2005`
- `sizeLimit`:大文件阈值(默认 `1GB`
- `whiteListStr`:白名单(多行字符串,每行一条规则)
- `blackListStr`:黑名单(多行字符串,每行一条规则)
### 白名单/黑名单规则 程序启动会自动执行幂等建表迁移(CREATE TABLE IF NOT EXISTS),并设置:
每行一条,支持三种写法: - PRAGMA journal_mode = WAL
- PRAGMA busy_timeout = 5000(或你设置的环境变量值)
- `user1`:匹配/封禁 `user1` 下的所有仓库 表结构语义:
- `user1/repo1`:匹配/封禁 `user1/repo1`
- `*/repo1`:匹配/封禁所有名为 `repo1` 的仓库 - users:用户主体
- tokens:token 主体(含过期、禁用、可用次数上限与已使用次数)
- token_usage:访问审计明细(时间/IP/用户/token/原始 URL/HTTP 状态/成功标志/失败原因)
- site_stats:总体累计计数(total_accelerated_count
## 白名单/黑名单规则
每行一条,支持:
- user1:匹配 user1 下所有仓库
- user1/repo1:匹配 user1/repo1
- \*/repo1:匹配所有 repo1
判定顺序: 判定顺序:
1. **白名单优先生效**:如果白名单非空,必须至少命中一条白名单规则,否则直接拒绝(403)。 1. 白名单优先(若白名单非空,必须命中)
2. **再匹配黑名单**:命中任意黑名单规则则拒绝(403)。 2. 再匹配黑名单(命中即拒绝)
## 验证步骤
1. 构建:go build ./...
2. 无 token 请求:应返回 401
3. 签发 token 后请求合法 URL:应成功且代理行为与旧版本一致
4. 禁用或过期 token 请求:应返回 403,并记录失败 usage
5. 多次请求后检查 site_stats.total_accelerated_count 累计值
6. 检查 token_usage 记录字段完整性
## 常见问题 ## 排错
### 为什么有时会直接跳转到 GitHub? - 启动报数据库错误:
- 检查 HUGS_PROXY_DB_PATH 是否可写
当上游响应 `Content-Length` 大于 `sizeLimit`(默认 1GB)时,程序会返回 302 重定向到目标 URL,而不是继续转发大文件内容。 - 检查目录权限
- 总是 401
### 支持 Git Clone / Git LFS 吗? - 确认请求头格式为 Authorization: Bearer <TOKEN>
- 总是 403
本项目主要面向“下载/获取资源”。它只放行并代理部分与 GitHub 资源下载相关的 URL 形态;不保证覆盖完整的 Git 协议或 Git LFS 场景。 - 检查 token 是否过期或禁用
- 检查 URL 是否命中支持规则以及白黑名单规则
## License ## License
详见 License 文件 详见仓库中的 LICENSE 文件
+198
View File
@@ -0,0 +1,198 @@
package main
import (
"context"
"errors"
"flag"
"fmt"
"log"
"os"
"strings"
"time"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/config"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/db"
)
func main() {
cfg := config.Load()
if len(os.Args) < 2 {
printUsage()
os.Exit(1)
}
sub := os.Args[1]
switch sub {
case "create-user":
handleCreateUser(cfg, os.Args[2:])
case "issue-token":
handleIssueToken(cfg, os.Args[2:])
case "disable-token":
handleDisableToken(cfg, os.Args[2:])
case "list-users":
handleListUsers(cfg, os.Args[2:])
case "list-tokens":
handleListTokens(cfg, os.Args[2:])
default:
printUsage()
os.Exit(1)
}
}
func handleCreateUser(cfg config.Config, args []string) {
fs := flag.NewFlagSet("create-user", flag.ExitOnError)
dbPath := fs.String("db", cfg.DBPath, "SQLite DB path")
name := fs.String("name", "", "username")
_ = fs.Parse(args)
if strings.TrimSpace(*name) == "" {
log.Fatal("--name is required")
}
store := openStore(*dbPath, cfg.BusyTimeoutMS)
defer store.Close()
u, err := store.CreateUser(context.Background(), strings.TrimSpace(*name))
if err != nil {
log.Fatal(err)
}
fmt.Printf("created user: id=%d username=%s created_at=%s\n", u.ID, u.Username, u.CreatedAt.Format(time.RFC3339))
}
func handleIssueToken(cfg config.Config, args []string) {
fs := flag.NewFlagSet("issue-token", flag.ExitOnError)
dbPath := fs.String("db", cfg.DBPath, "SQLite DB path")
username := fs.String("user", "", "username")
expires := fs.String("expires", "", "expires duration like 7d/30d or RFC3339 time")
uses := fs.Int("uses", 1, "max usable times for this token, must be > 0")
_ = fs.Parse(args)
if strings.TrimSpace(*username) == "" {
log.Fatal("--user is required")
}
if *uses <= 0 {
log.Fatal("--uses must be > 0")
}
store := openStore(*dbPath, cfg.BusyTimeoutMS)
defer store.Close()
ctx := context.Background()
u, err := store.GetUserByName(ctx, strings.TrimSpace(*username))
if err != nil {
if errors.Is(err, db.ErrNotFound) {
log.Fatalf("user %q not found", *username)
}
log.Fatal(err)
}
expiresAt, err := parseExpiresAt(*expires, cfg.DefaultTokenTTL)
if err != nil {
log.Fatal(err)
}
tok, err := store.IssueToken(ctx, u.ID, expiresAt, *uses)
if err != nil {
log.Fatal(err)
}
fmt.Printf("issued token: id=%d user=%s token=%s expires_at=%s max_uses=%d\n", tok.ID, u.Username, tok.Token, tok.ExpiresAt.Format(time.RFC3339), tok.MaxUses)
}
func handleDisableToken(cfg config.Config, args []string) {
fs := flag.NewFlagSet("disable-token", flag.ExitOnError)
dbPath := fs.String("db", cfg.DBPath, "SQLite DB path")
token := fs.String("token", "", "token value")
_ = fs.Parse(args)
if strings.TrimSpace(*token) == "" {
log.Fatal("--token is required")
}
store := openStore(*dbPath, cfg.BusyTimeoutMS)
defer store.Close()
err := store.DisableToken(context.Background(), strings.TrimSpace(*token))
if err != nil {
if errors.Is(err, db.ErrNotFound) {
log.Fatalf("token not found")
}
log.Fatal(err)
}
fmt.Println("token disabled")
}
func handleListUsers(cfg config.Config, args []string) {
fs := flag.NewFlagSet("list-users", flag.ExitOnError)
dbPath := fs.String("db", cfg.DBPath, "SQLite DB path")
_ = fs.Parse(args)
store := openStore(*dbPath, cfg.BusyTimeoutMS)
defer store.Close()
users, err := store.ListUsers(context.Background())
if err != nil {
log.Fatal(err)
}
fmt.Println("id\tusername\tcreated_at")
for _, u := range users {
fmt.Printf("%d\t%s\t%s\n", u.ID, u.Username, u.CreatedAt.Format(time.RFC3339))
}
}
func handleListTokens(cfg config.Config, args []string) {
fs := flag.NewFlagSet("list-tokens", flag.ExitOnError)
dbPath := fs.String("db", cfg.DBPath, "SQLite DB path")
_ = fs.Parse(args)
store := openStore(*dbPath, cfg.BusyTimeoutMS)
defer store.Close()
tokens, err := store.ListTokens(context.Background())
if err != nil {
log.Fatal(err)
}
fmt.Println("id\tuser\tdisabled\tused/max\texpires_at\tcreated_at\ttoken")
for _, t := range tokens {
maxUsesDisplay := fmt.Sprintf("%d", t.MaxUses)
if t.MaxUses == 0 {
maxUsesDisplay = "unlimited"
}
fmt.Printf("%d\t%s\t%t\t%d/%s\t%s\t%s\t%s\n", t.ID, t.Username, t.Disabled, t.UsedCount, maxUsesDisplay, t.ExpiresAt.Format(time.RFC3339), t.CreatedAt.Format(time.RFC3339), t.Token)
}
}
func parseExpiresAt(raw string, defaultTTL time.Duration) (time.Time, error) {
now := time.Now().UTC()
clean := strings.TrimSpace(raw)
if clean == "" {
return now.Add(defaultTTL), nil
}
if t, err := time.Parse(time.RFC3339, clean); err == nil {
return t.UTC(), nil
}
d, err := config.ParseExpiryDuration(clean)
if err != nil {
return time.Time{}, fmt.Errorf("invalid --expires value: %w", err)
}
return now.Add(d), nil
}
func openStore(dbPath string, busyTimeoutMS int) *db.Store {
store, err := db.NewStore(strings.TrimSpace(dbPath), busyTimeoutMS)
if err != nil {
log.Fatal(err)
}
return store
}
func printUsage() {
fmt.Println(`tokenctl usage:
tokenctl create-user --name <username> [--db <path>]
tokenctl issue-token --user <username> [--expires <7d|30d|RFC3339>] [--uses <n>] [--db <path>]
tokenctl disable-token --token <token> [--db <path>]
tokenctl list-users [--db <path>]
tokenctl list-tokens [--db <path>]`)
}
+14
View File
@@ -1,3 +1,17 @@
module gitea.gangary.cn/gary/Hugs-Proxy module gitea.gangary.cn/gary/Hugs-Proxy
go 1.26.2 go 1.26.2
require modernc.org/sqlite v1.49.1
require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
golang.org/x/sys v0.42.0 // indirect
modernc.org/libc v1.72.0 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
)
+51
View File
@@ -0,0 +1,51 @@
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
modernc.org/cc/v4 v4.27.3 h1:uNCgn37E5U09mTv1XgskEVUJ8ADKpmFMPxzGJ0TSo+U=
modernc.org/cc/v4 v4.27.3/go.mod h1:3YjcbCqhoTTHPycJDRl2WZKKFj0nwcOIPBfEZK0Hdk8=
modernc.org/ccgo/v4 v4.32.4 h1:L5OB8rpEX4ZsXEQwGozRfJyJSFHbbNVOoQ59DU9/KuU=
modernc.org/ccgo/v4 v4.32.4/go.mod h1:lY7f+fiTDHfcv6YlRgSkxYfhs+UvOEEzj49jAn2TOx0=
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.72.0 h1:IEu559v9a0XWjw0DPoVKtXpO2qt5NVLAnFaBbjq+n8c=
modernc.org/libc v1.72.0/go.mod h1:tTU8DL8A+XLVkEY3x5E/tO7s2Q/q42EtnNWda/L5QhQ=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.49.1 h1:dYGHTKcX1sJ+EQDnUzvz4TJ5GbuvhNJa8Fg6ElGx73U=
modernc.org/sqlite v1.49.1/go.mod h1:m0w8xhwYUVY3H6pSDwc3gkJ/irZT/0YEXwBlhaxQEew=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
+51
View File
@@ -0,0 +1,51 @@
package audit
import (
"context"
"database/sql"
"strings"
"time"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/db"
)
type Entry struct {
RequestIP string
UserID int64
HasUser bool
TokenID int64
HasToken bool
OriginalURL string
HTTPStatus int
Success bool
ErrorReason string
CountAsSuccess bool
OccurredAt time.Time
}
type Logger struct {
store *db.Store
}
func NewLogger(store *db.Store) *Logger {
return &Logger{store: store}
}
func (l *Logger) Log(ctx context.Context, entry Entry) error {
usage := db.UsageEntry{
RequestIP: strings.TrimSpace(entry.RequestIP),
OriginalURL: strings.TrimSpace(entry.OriginalURL),
HTTPStatus: entry.HTTPStatus,
Success: entry.Success,
ErrorReason: strings.TrimSpace(entry.ErrorReason),
OccurredAt: entry.OccurredAt,
}
if entry.HasUser {
usage.UserID = sql.NullInt64{Int64: entry.UserID, Valid: true}
}
if entry.HasToken {
usage.TokenID = sql.NullInt64{Int64: entry.TokenID, Valid: true}
}
return l.store.RecordUsageAndMaybeIncrement(ctx, usage, entry.CountAsSuccess)
}
+90
View File
@@ -0,0 +1,90 @@
package auth
import (
"context"
"errors"
"net/http"
"strings"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/db"
)
type contextKey string
const authInfoKey contextKey = "authInfo"
type AuthInfo struct {
UserID int64
TokenID int64
Token string
}
type FailureRecorder func(r *http.Request, token string, statusCode int, reason string, userID int64, tokenID int64)
func Middleware(store *db.Store, onFailure FailureRecorder, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenValue, ok := bearerToken(r.Header.Get("Authorization"))
if !ok {
http.Error(w, "Missing bearer token.", http.StatusUnauthorized)
if onFailure != nil {
onFailure(r, "", http.StatusUnauthorized, "missing_token", 0, 0)
}
return
}
token, err := store.ValidateAndConsumeToken(r.Context(), tokenValue)
if err != nil {
status := http.StatusForbidden
reason := "invalid_token"
switch {
case errors.Is(err, db.ErrNotFound):
reason = "invalid_token"
case errors.Is(err, db.ErrTokenDisabled):
reason = "token_disabled"
case errors.Is(err, db.ErrTokenExpired):
reason = "token_expired"
case errors.Is(err, db.ErrTokenExhausted):
reason = "token_exhausted"
default:
status = http.StatusInternalServerError
reason = "token_lookup_error"
}
http.Error(w, http.StatusText(status), status)
if onFailure != nil {
if token != nil {
onFailure(r, tokenValue, status, reason, token.UserID, token.ID)
} else {
onFailure(r, tokenValue, status, reason, 0, 0)
}
}
return
}
authInfo := AuthInfo{UserID: token.UserID, TokenID: token.ID, Token: tokenValue}
ctx := context.WithValue(r.Context(), authInfoKey, authInfo)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func FromContext(ctx context.Context) (AuthInfo, bool) {
v := ctx.Value(authInfoKey)
if v == nil {
return AuthInfo{}, false
}
info, ok := v.(AuthInfo)
return info, ok
}
func bearerToken(authorizationHeader string) (string, bool) {
parts := strings.Fields(strings.TrimSpace(authorizationHeader))
if len(parts) != 2 {
return "", false
}
if !strings.EqualFold(parts[0], "Bearer") {
return "", false
}
if parts[1] == "" {
return "", false
}
return parts[1], true
}
+78
View File
@@ -0,0 +1,78 @@
package config
import (
"os"
"strconv"
"strings"
"time"
)
const (
defaultHost = "127.0.0.1"
defaultPort = "2005"
defaultDBPath = "./hugs_proxy.db"
defaultTokenTTLString = "30d"
defaultBusyTimeoutMilli = 5000
)
type Config struct {
Host string
Port string
DBPath string
DefaultTokenTTL time.Duration
BusyTimeoutMS int
}
func Load() Config {
host := getenvOrDefault("HUGS_PROXY_HOST", defaultHost)
port := getenvOrDefault("HUGS_PROXY_PORT", defaultPort)
dbPath := getenvOrDefault("HUGS_PROXY_DB_PATH", defaultDBPath)
ttlRaw := getenvOrDefault("HUGS_PROXY_DEFAULT_TOKEN_TTL", defaultTokenTTLString)
busyTimeout := getenvIntOrDefault("HUGS_PROXY_DB_BUSY_TIMEOUT_MS", defaultBusyTimeoutMilli)
ttl, err := ParseExpiryDuration(ttlRaw)
if err != nil {
ttl, _ = ParseExpiryDuration(defaultTokenTTLString)
}
return Config{
Host: host,
Port: port,
DBPath: dbPath,
DefaultTokenTTL: ttl,
BusyTimeoutMS: busyTimeout,
}
}
func ParseExpiryDuration(s string) (time.Duration, error) {
s = strings.TrimSpace(strings.ToLower(s))
if strings.HasSuffix(s, "d") {
daysPart := strings.TrimSuffix(s, "d")
days, err := strconv.Atoi(daysPart)
if err != nil {
return 0, err
}
return time.Duration(days) * 24 * time.Hour, nil
}
return time.ParseDuration(s)
}
func getenvOrDefault(key string, fallback string) string {
v := strings.TrimSpace(os.Getenv(key))
if v == "" {
return fallback
}
return v
}
func getenvIntOrDefault(key string, fallback int) int {
v := strings.TrimSpace(os.Getenv(key))
if v == "" {
return fallback
}
n, err := strconv.Atoi(v)
if err != nil {
return fallback
}
return n
}
+495
View File
@@ -0,0 +1,495 @@
package db
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"errors"
"fmt"
"strings"
"time"
_ "modernc.org/sqlite"
)
const timeFormat = time.RFC3339Nano
var ErrNotFound = errors.New("not found")
var (
ErrTokenDisabled = errors.New("token disabled")
ErrTokenExpired = errors.New("token expired")
ErrTokenExhausted = errors.New("token exhausted")
)
type Store struct {
db *sql.DB
}
type User struct {
ID int64
Username string
CreatedAt time.Time
}
type Token struct {
ID int64
UserID int64
Token string
CreatedAt time.Time
ExpiresAt time.Time
Disabled bool
DisabledAt *time.Time
MaxUses int
UsedCount int
}
type TokenWithUser struct {
Token
Username string
}
type UsageEntry struct {
RequestIP string
UserID sql.NullInt64
TokenID sql.NullInt64
OriginalURL string
HTTPStatus int
Success bool
ErrorReason string
OccurredAt time.Time
}
func NewStore(dbPath string, busyTimeoutMS int) (*Store, error) {
db, err := sql.Open("sqlite", dbPath)
if err != nil {
return nil, err
}
if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil {
db.Close()
return nil, err
}
if _, err := db.Exec(fmt.Sprintf("PRAGMA busy_timeout = %d;", busyTimeoutMS)); err != nil {
db.Close()
return nil, err
}
if _, err := db.Exec("PRAGMA foreign_keys = ON;"); err != nil {
db.Close()
return nil, err
}
s := &Store{db: db}
if err := s.migrate(context.Background()); err != nil {
db.Close()
return nil, err
}
return s, nil
}
func (s *Store) Close() error {
return s.db.Close()
}
func (s *Store) migrate(ctx context.Context) error {
stmts := []string{
`CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
created_at TEXT NOT NULL
);`,
`CREATE TABLE IF NOT EXISTS tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
token TEXT NOT NULL UNIQUE,
created_at TEXT NOT NULL,
expires_at TEXT NOT NULL,
disabled INTEGER NOT NULL DEFAULT 0,
max_uses INTEGER NOT NULL DEFAULT 0,
used_count INTEGER NOT NULL DEFAULT 0,
disabled_at TEXT,
FOREIGN KEY(user_id) REFERENCES users(id)
);`,
`CREATE TABLE IF NOT EXISTS site_stats (
id INTEGER PRIMARY KEY,
total_accelerated_count INTEGER NOT NULL DEFAULT 0,
updated_at TEXT NOT NULL
);`,
`CREATE TABLE IF NOT EXISTS token_usage (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_at TEXT NOT NULL,
request_ip TEXT NOT NULL,
user_id INTEGER,
token_id INTEGER,
original_url TEXT NOT NULL,
http_status INTEGER NOT NULL,
success INTEGER NOT NULL,
error_reason TEXT NOT NULL DEFAULT '',
FOREIGN KEY(user_id) REFERENCES users(id),
FOREIGN KEY(token_id) REFERENCES tokens(id)
);`,
`CREATE INDEX IF NOT EXISTS idx_tokens_token ON tokens(token);`,
`CREATE INDEX IF NOT EXISTS idx_tokens_user_id ON tokens(user_id);`,
`CREATE INDEX IF NOT EXISTS idx_token_usage_created_at ON token_usage(created_at);`,
`CREATE INDEX IF NOT EXISTS idx_token_usage_user_id ON token_usage(user_id);`,
`CREATE INDEX IF NOT EXISTS idx_token_usage_token_id ON token_usage(token_id);`,
}
for _, stmt := range stmts {
if _, err := s.db.ExecContext(ctx, stmt); err != nil {
return err
}
}
if _, err := s.db.ExecContext(ctx, `ALTER TABLE tokens ADD COLUMN max_uses INTEGER NOT NULL DEFAULT 0;`); err != nil && !isDuplicateColumnErr(err) {
return err
}
if _, err := s.db.ExecContext(ctx, `ALTER TABLE tokens ADD COLUMN used_count INTEGER NOT NULL DEFAULT 0;`); err != nil && !isDuplicateColumnErr(err) {
return err
}
_, err := s.db.ExecContext(
ctx,
`INSERT INTO site_stats (id, total_accelerated_count, updated_at)
VALUES (1, 0, ?)
ON CONFLICT(id) DO NOTHING;`,
time.Now().UTC().Format(timeFormat),
)
return err
}
func (s *Store) CreateUser(ctx context.Context, username string) (*User, error) {
now := time.Now().UTC()
res, err := s.db.ExecContext(
ctx,
"INSERT INTO users (username, created_at) VALUES (?, ?);",
username,
now.Format(timeFormat),
)
if err != nil {
return nil, err
}
id, err := res.LastInsertId()
if err != nil {
return nil, err
}
return &User{ID: id, Username: username, CreatedAt: now}, nil
}
func (s *Store) GetUserByName(ctx context.Context, username string) (*User, error) {
row := s.db.QueryRowContext(ctx, "SELECT id, username, created_at FROM users WHERE username = ?;", username)
var u User
var createdAt string
if err := row.Scan(&u.ID, &u.Username, &createdAt); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
return nil, err
}
parsed, err := time.Parse(timeFormat, createdAt)
if err != nil {
return nil, err
}
u.CreatedAt = parsed
return &u, nil
}
func (s *Store) ListUsers(ctx context.Context) ([]User, error) {
rows, err := s.db.QueryContext(ctx, "SELECT id, username, created_at FROM users ORDER BY id ASC;")
if err != nil {
return nil, err
}
defer rows.Close()
var out []User
for rows.Next() {
var u User
var createdAt string
if err := rows.Scan(&u.ID, &u.Username, &createdAt); err != nil {
return nil, err
}
parsed, err := time.Parse(timeFormat, createdAt)
if err != nil {
return nil, err
}
u.CreatedAt = parsed
out = append(out, u)
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
func (s *Store) IssueToken(ctx context.Context, userID int64, expiresAt time.Time, maxUses int) (*Token, error) {
if maxUses < 0 {
return nil, fmt.Errorf("maxUses cannot be negative")
}
tokenValue, err := generateToken()
if err != nil {
return nil, err
}
now := time.Now().UTC()
res, err := s.db.ExecContext(
ctx,
`INSERT INTO tokens (user_id, token, created_at, expires_at, disabled, max_uses, used_count)
VALUES (?, ?, ?, ?, 0, ?, 0);`,
userID,
tokenValue,
now.Format(timeFormat),
expiresAt.UTC().Format(timeFormat),
maxUses,
)
if err != nil {
return nil, err
}
id, err := res.LastInsertId()
if err != nil {
return nil, err
}
return &Token{ID: id, UserID: userID, Token: tokenValue, CreatedAt: now, ExpiresAt: expiresAt.UTC(), MaxUses: maxUses, UsedCount: 0}, nil
}
func (s *Store) GetToken(ctx context.Context, tokenValue string) (*Token, error) {
row := s.db.QueryRowContext(
ctx,
`SELECT id, user_id, token, created_at, expires_at, disabled, disabled_at, max_uses, used_count
FROM tokens
WHERE token = ?;`,
tokenValue,
)
return scanToken(row)
}
func (s *Store) ValidateAndConsumeToken(ctx context.Context, tokenValue string) (*Token, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
_ = tx.Rollback()
}
}()
row := tx.QueryRowContext(
ctx,
`SELECT id, user_id, token, created_at, expires_at, disabled, disabled_at, max_uses, used_count
FROM tokens
WHERE token = ?;`,
tokenValue,
)
token, err := scanToken(row)
if err != nil {
return nil, err
}
now := time.Now().UTC()
if token.Disabled {
return nil, ErrTokenDisabled
}
if now.After(token.ExpiresAt) {
return nil, ErrTokenExpired
}
if token.MaxUses > 0 && token.UsedCount >= token.MaxUses {
return nil, ErrTokenExhausted
}
res, err := tx.ExecContext(
ctx,
`UPDATE tokens
SET used_count = used_count + 1
WHERE id = ?
AND (max_uses = 0 OR used_count < max_uses);`,
token.ID,
)
if err != nil {
return nil, err
}
affected, err := res.RowsAffected()
if err != nil {
return nil, err
}
if affected == 0 {
return nil, ErrTokenExhausted
}
if err = tx.Commit(); err != nil {
return nil, err
}
token.UsedCount++
return token, nil
}
func (s *Store) DisableToken(ctx context.Context, tokenValue string) error {
now := time.Now().UTC().Format(timeFormat)
res, err := s.db.ExecContext(
ctx,
"UPDATE tokens SET disabled = 1, disabled_at = ? WHERE token = ?;",
now,
tokenValue,
)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return ErrNotFound
}
return nil
}
func (s *Store) ListTokens(ctx context.Context) ([]TokenWithUser, error) {
rows, err := s.db.QueryContext(
ctx,
`SELECT t.id, t.user_id, t.token, t.created_at, t.expires_at, t.disabled, t.disabled_at, t.max_uses, t.used_count, u.username
FROM tokens t
INNER JOIN users u ON u.id = t.user_id
ORDER BY t.id ASC;`,
)
if err != nil {
return nil, err
}
defer rows.Close()
var out []TokenWithUser
for rows.Next() {
var t TokenWithUser
var createdAt string
var expiresAt string
var disabledAt sql.NullString
var disabledInt int
if err := rows.Scan(&t.ID, &t.UserID, &t.Token, &createdAt, &expiresAt, &disabledInt, &disabledAt, &t.MaxUses, &t.UsedCount, &t.Username); err != nil {
return nil, err
}
ct, err := time.Parse(timeFormat, createdAt)
if err != nil {
return nil, err
}
et, err := time.Parse(timeFormat, expiresAt)
if err != nil {
return nil, err
}
t.CreatedAt = ct
t.ExpiresAt = et
t.Disabled = disabledInt == 1
if disabledAt.Valid {
dt, err := time.Parse(timeFormat, disabledAt.String)
if err != nil {
return nil, err
}
t.DisabledAt = &dt
}
out = append(out, t)
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
func (s *Store) RecordUsageAndMaybeIncrement(ctx context.Context, entry UsageEntry, increment bool) error {
if entry.OccurredAt.IsZero() {
entry.OccurredAt = time.Now().UTC()
}
if entry.RequestIP == "" {
entry.RequestIP = "unknown"
}
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer func() {
if err != nil {
_ = tx.Rollback()
}
}()
successInt := 0
if entry.Success {
successInt = 1
}
_, err = tx.ExecContext(
ctx,
`INSERT INTO token_usage (created_at, request_ip, user_id, token_id, original_url, http_status, success, error_reason)
VALUES (?, ?, ?, ?, ?, ?, ?, ?);`,
entry.OccurredAt.UTC().Format(timeFormat),
entry.RequestIP,
entry.UserID,
entry.TokenID,
entry.OriginalURL,
entry.HTTPStatus,
successInt,
entry.ErrorReason,
)
if err != nil {
return err
}
if increment {
_, err = tx.ExecContext(
ctx,
`UPDATE site_stats
SET total_accelerated_count = total_accelerated_count + 1,
updated_at = ?
WHERE id = 1;`,
time.Now().UTC().Format(timeFormat),
)
if err != nil {
return err
}
}
return tx.Commit()
}
func generateToken() (string, error) {
buf := make([]byte, 32)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(buf), nil
}
func scanToken(scanner interface{ Scan(dest ...any) error }) (*Token, error) {
var t Token
var createdAt string
var expiresAt string
var disabledAt sql.NullString
var disabledInt int
if err := scanner.Scan(&t.ID, &t.UserID, &t.Token, &createdAt, &expiresAt, &disabledInt, &disabledAt, &t.MaxUses, &t.UsedCount); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
return nil, err
}
ct, err := time.Parse(timeFormat, createdAt)
if err != nil {
return nil, err
}
et, err := time.Parse(timeFormat, expiresAt)
if err != nil {
return nil, err
}
t.CreatedAt = ct
t.ExpiresAt = et
t.Disabled = disabledInt == 1
if disabledAt.Valid {
dt, err := time.Parse(timeFormat, disabledAt.String)
if err != nil {
return nil, err
}
t.DisabledAt = &dt
}
return &t, nil
}
func isDuplicateColumnErr(err error) bool {
return strings.Contains(strings.ToLower(err.Error()), "duplicate column name")
}
+132 -18
View File
@@ -1,22 +1,27 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"log" "log"
"net"
"net/http" "net/http"
"net/url" "net/url"
"regexp" "regexp"
"strings" "strings"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/audit"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/auth"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/config"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/db"
) )
// ======================== // ========================
// 配置区域 // 配置区域
// ======================== // ========================
const ( const (
sizeLimit = int64(1024 * 1024 * 1024 * 1) // 允许的文件大小, 1GB sizeLimit = int64(1024 * 1024 * 1024 * 1) // 允许的文件大小, 1GB
host = "127.0.0.1"
port = "2005"
) )
// 先生效白名单再匹配黑名单 // 先生效白名单再匹配黑名单
@@ -29,7 +34,6 @@ var (
blackListStr = `` blackListStr = ``
) )
// ======================== // ========================
// 全局变量与预编译正则 // 全局变量与预编译正则
// ======================== // ========================
@@ -37,7 +41,6 @@ var (
whiteList [][]string whiteList [][]string
blackList [][]string blackList [][]string
exp1 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:releases|archive)/.*$`) exp1 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:releases|archive)/.*$`)
exp2 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:blob|raw)/.*$`) exp2 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:blob|raw)/.*$`)
exp3 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:info|git-).*$`) exp3 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:info|git-).*$`)
@@ -45,6 +48,8 @@ var (
exp5 = regexp.MustCompile(`^(?:https?://)?gist\.(?:githubusercontent|github)\.com/(?P<author>[^/]+)/.+?/.+$`) exp5 = regexp.MustCompile(`^(?:https?://)?gist\.(?:githubusercontent|github)\.com/(?P<author>[^/]+)/.+?/.+$`)
httpClient *http.Client httpClient *http.Client
store *db.Store
auditor *audit.Logger
) )
func init() { func init() {
@@ -61,8 +66,22 @@ func init() {
} }
func main() { func main() {
http.HandleFunc("/", routeHandler) cfg := config.Load()
addr := fmt.Sprintf("%s:%s", host, port)
var err error
store, err = db.NewStore(cfg.DBPath, cfg.BusyTimeoutMS)
if err != nil {
log.Fatal(err)
}
defer store.Close()
auditor = audit.NewLogger(store)
base := http.HandlerFunc(routeHandler)
protected := auth.Middleware(store, auditAuthFailure, base)
http.Handle("/", protected)
addr := fmt.Sprintf("%s:%s", cfg.Host, cfg.Port)
log.Printf("服务器启动成功,正在监听 %s", addr) log.Printf("服务器启动成功,正在监听 %s", addr)
if err := http.ListenAndServe(addr, nil); err != nil { if err := http.ListenAndServe(addr, nil); err != nil {
log.Fatal(err) log.Fatal(err)
@@ -70,8 +89,40 @@ func main() {
} }
func routeHandler(w http.ResponseWriter, r *http.Request) { func routeHandler(w http.ResponseWriter, r *http.Request) {
authInfo, ok := auth.FromContext(r.Context())
if !ok {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
u := strings.TrimPrefix(r.URL.RequestURI(), "/") recorder := newStatusRecorder(w)
originalInput := strings.TrimPrefix(r.URL.RequestURI(), "/")
normalizedURL := originalInput
success := false
errorReason := ""
defer func() {
statusCode := recorder.StatusCode()
if statusCode >= 200 && statusCode < 400 {
success = true
}
if err := auditor.Log(r.Context(), audit.Entry{
RequestIP: clientIP(r),
UserID: authInfo.UserID,
HasUser: true,
TokenID: authInfo.TokenID,
HasToken: true,
OriginalURL: normalizedURL,
HTTPStatus: statusCode,
Success: success,
ErrorReason: errorReason,
CountAsSuccess: success,
}); err != nil {
log.Printf("audit log failed: %v", err)
}
}()
u := originalInput
if strings.HasPrefix(u, "https:/") && !strings.HasPrefix(u, "https://") { if strings.HasPrefix(u, "https:/") && !strings.HasPrefix(u, "https://") {
u = "https://" + strings.TrimPrefix(u, "https:/") u = "https://" + strings.TrimPrefix(u, "https:/")
} else if strings.HasPrefix(u, "http:/") && !strings.HasPrefix(u, "http://") { } else if strings.HasPrefix(u, "http:/") && !strings.HasPrefix(u, "http://") {
@@ -84,7 +135,8 @@ func routeHandler(w http.ResponseWriter, r *http.Request) {
m := checkURL(u) m := checkURL(u)
if m == nil { if m == nil {
http.Error(w, "Invalid input.", http.StatusForbidden) errorReason = "invalid_input"
http.Error(recorder, "Invalid input.", http.StatusForbidden)
return return
} }
@@ -97,14 +149,16 @@ func routeHandler(w http.ResponseWriter, r *http.Request) {
} }
} }
if !allowed { if !allowed {
http.Error(w, "Forbidden by white list.", http.StatusForbidden) errorReason = "forbidden_by_white_list"
http.Error(recorder, "Forbidden by white list.", http.StatusForbidden)
return return
} }
} }
for _, i := range blackList { for _, i := range blackList {
if matchRule(m, i) { if matchRule(m, i) {
http.Error(w, "Forbidden by black list.", http.StatusForbidden) errorReason = "forbidden_by_black_list"
http.Error(recorder, "Forbidden by black list.", http.StatusForbidden)
return return
} }
} }
@@ -119,18 +173,19 @@ func routeHandler(w http.ResponseWriter, r *http.Request) {
if err == nil { if err == nil {
u = parsedURL.String() u = parsedURL.String()
} }
proxy(w, r, u) normalizedURL = u
proxy(recorder, r, u)
} }
func proxy(w http.ResponseWriter, r *http.Request, targetURL string) { func proxy(w http.ResponseWriter, r *http.Request, targetURL string) {
// 修正由于多重代理可能造成的 URL 格式错误 // 修正由于多重代理可能造成的 URL 格式错误
if strings.HasPrefix(targetURL, "https:/") && ! strings.HasPrefix(targetURL, "https://") { if strings.HasPrefix(targetURL, "https:/") && !strings.HasPrefix(targetURL, "https://") {
targetURL = "https://" + targetURL[7:] targetURL = "https://" + targetURL[7:]
} }
req, err := http.NewRequest(r.Method, targetURL, r.Body) req, err := http.NewRequest(r.Method, targetURL, r.Body)
if err != nil { if err != nil {
http.Error(w, "server error " + err.Error(), http.StatusInternalServerError) http.Error(w, "server error "+err.Error(), http.StatusInternalServerError)
return return
} }
@@ -146,7 +201,7 @@ func proxy(w http.ResponseWriter, r *http.Request, targetURL string) {
resp, err := httpClient.Do(req) resp, err := httpClient.Do(req)
if err != nil { if err != nil {
http.Error(w, "server error " + err.Error(), http.StatusInternalServerError) http.Error(w, "server error "+err.Error(), http.StatusInternalServerError)
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
@@ -180,6 +235,66 @@ func proxy(w http.ResponseWriter, r *http.Request, targetURL string) {
io.Copy(w, resp.Body) io.Copy(w, resp.Body)
} }
func auditAuthFailure(r *http.Request, token string, statusCode int, reason string, userID int64, tokenID int64) {
entry := audit.Entry{
RequestIP: clientIP(r),
OriginalURL: strings.TrimPrefix(r.URL.RequestURI(), "/"),
HTTPStatus: statusCode,
Success: false,
ErrorReason: reason,
CountAsSuccess: false,
}
if userID > 0 {
entry.UserID = userID
entry.HasUser = true
}
if tokenID > 0 {
entry.TokenID = tokenID
entry.HasToken = true
}
if err := auditor.Log(context.Background(), entry); err != nil {
log.Printf("auth failure audit log failed: %v", err)
}
}
func clientIP(r *http.Request) string {
if xff := strings.TrimSpace(r.Header.Get("X-Forwarded-For")); xff != "" {
parts := strings.Split(xff, ",")
if len(parts) > 0 {
return strings.TrimSpace(parts[0])
}
}
if xrip := strings.TrimSpace(r.Header.Get("X-Real-IP")); xrip != "" {
return xrip
}
host, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr))
if err == nil && host != "" {
return host
}
if strings.TrimSpace(r.RemoteAddr) != "" {
return strings.TrimSpace(r.RemoteAddr)
}
return "unknown"
}
type statusRecorder struct {
http.ResponseWriter
statusCode int
}
func newStatusRecorder(w http.ResponseWriter) *statusRecorder {
return &statusRecorder{ResponseWriter: w, statusCode: http.StatusOK}
}
func (r *statusRecorder) WriteHeader(statusCode int) {
r.statusCode = statusCode
r.ResponseWriter.WriteHeader(statusCode)
}
func (r *statusRecorder) StatusCode() int {
return r.statusCode
}
func matchRule(m []string, i []string) bool { func matchRule(m []string, i []string) bool {
// m 通常为 [author, repo] 或 [author] // m 通常为 [author, repo] 或 [author]
if len(i) == 1 { if len(i) == 1 {
@@ -194,7 +309,7 @@ func matchRule(m []string, i []string) bool {
} }
func checkURL(u string) []string { func checkURL(u string) []string {
exps := []*regexp.Regexp {exp1, exp2, exp3, exp4, exp5} exps := []*regexp.Regexp{exp1, exp2, exp3, exp4, exp5}
for _, exp := range exps { for _, exp := range exps {
matches := exp.FindStringSubmatch(u) matches := exp.FindStringSubmatch(u)
if matches != nil { if matches != nil {
@@ -210,12 +325,11 @@ func checkURL(u string) []string {
return nil return nil
} }
func parseList(s string) [][]string { func parseList(s string) [][]string {
var res [][]string var res [][]string
lines := strings.Split(s, "\n") lines := strings.Split(s, "\n")
for _, line := range lines { for _, line := range lines {
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
if line != "" { if line != "" {
parts := strings.Split(line, "/") parts := strings.Split(line, "/")
var cleaned []string var cleaned []string