diff --git a/README.md b/README.md index bc238f3..bfcac8d 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,10 @@ curl -L \ token 无效、已禁用、已过期或使用次数耗尽:返回 403。 +## Web 首页 + +访问根路径 `/` 会返回一个简单的项目介绍页,包含:GitHub 链接输入框、Hugs-Proxy Token 输入框、以及“下载”按钮(表单 POST 到 `/download`)。 + ## tokenctl 用法 ```text @@ -62,6 +66,7 @@ tokenctl issue-token --user [--expires <7d|30d|RFC3339>] [--uses ] tokenctl disable-token --token [--db ] tokenctl list-users [--db ] tokenctl list-tokens [--db ] +tokenctl list-usage [--limit ] [--db ] ``` 示例: @@ -69,6 +74,7 @@ tokenctl list-tokens [--db ] ```bash go run ./cmd/tokenctl list-users go run ./cmd/tokenctl list-tokens +go run ./cmd/tokenctl list-usage --limit 100 go run ./cmd/tokenctl issue-token --user alice --uses 5 go run ./cmd/tokenctl disable-token --token ``` @@ -82,6 +88,7 @@ go run ./cmd/tokenctl disable-token --token - HUGS_PROXY_DB_PATH:SQLite 文件路径,默认 ./hugs_proxy.db - HUGS_PROXY_DEFAULT_TOKEN_TTL:默认 token 有效期,默认 30d - HUGS_PROXY_DB_BUSY_TIMEOUT_MS:SQLite busy_timeout,默认 5000 +- HUGS_PROXY_GITEA_REPO_URL:Web 首页底部“我的 Gitea 仓库”跳转地址 说明: diff --git a/cmd/tokenctl/main.go b/cmd/tokenctl/main.go index e0055db..378aba4 100644 --- a/cmd/tokenctl/main.go +++ b/cmd/tokenctl/main.go @@ -33,6 +33,8 @@ func main() { handleListUsers(cfg, os.Args[2:]) case "list-tokens": handleListTokens(cfg, os.Args[2:]) + case "list-usage": + handleListUsage(cfg, os.Args[2:]) default: printUsage() os.Exit(1) @@ -160,7 +162,44 @@ func handleListTokens(cfg config.Config, args []string) { if t.MaxUses == 0 { maxUsesDisplay = "unlimited" } - fmt.Printf("%d\t%s\t%t\t%d/%s\t%s\t%s\t%s\n", t.ID, t.Username, t.Disabled, t.UsedCount, maxUsesDisplay, t.ExpiresAt.Format(time.RFC3339), t.CreatedAt.Format(time.RFC3339), t.Token) + fmt.Printf("%d\t%s\t%t\t%d/%s\t%s\t%s\t%s\n", t.ID, t.Username, t.Disabled, t.UsedCount, maxUsesDisplay, t.ExpiresAt.Format(time.RFC3339), t.CreatedAt.Format(time.RFC3339), t.Token.Token) + } +} + +func handleListUsage(cfg config.Config, args []string) { + fs := flag.NewFlagSet("list-usage", flag.ExitOnError) + dbPath := fs.String("db", cfg.DBPath, "SQLite DB path") + limit := fs.Int("limit", 50, "max rows to show, <=0 means all") + _ = fs.Parse(args) + + store := openStore(*dbPath, cfg.BusyTimeoutMS) + defer store.Close() + + entries, err := store.ListUsageEntries(context.Background(), *limit) + if err != nil { + log.Fatal(err) + } + + fmt.Println("time\trequest_ip\tuser_id\ttoken_id\thttp_status\tsuccess\terror_reason\toriginal_url") + for _, e := range entries { + userID := "-" + tokenID := "-" + if e.UserID.Valid { + userID = fmt.Sprintf("%d", e.UserID.Int64) + } + if e.TokenID.Valid { + tokenID = fmt.Sprintf("%d", e.TokenID.Int64) + } + fmt.Printf("%s\t%s\t%s\t%s\t%d\t%t\t%s\t%s\n", + e.OccurredAt.Format(time.RFC3339), + e.RequestIP, + userID, + tokenID, + e.HTTPStatus, + e.Success, + e.ErrorReason, + e.OriginalURL, + ) } } @@ -194,5 +233,6 @@ func printUsage() { tokenctl issue-token --user [--expires <7d|30d|RFC3339>] [--uses ] [--db ] tokenctl disable-token --token [--db ] tokenctl list-users [--db ] - tokenctl list-tokens [--db ]`) + tokenctl list-tokens [--db ] + tokenctl list-usage [--limit ] [--db ]`) } diff --git a/internal/config/config.go b/internal/config/config.go index d9a9045..d920d54 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -19,6 +19,7 @@ type Config struct { Host string Port string DBPath string + GiteaRepoURL string DefaultTokenTTL time.Duration BusyTimeoutMS int } @@ -27,6 +28,7 @@ func Load() Config { host := getenvOrDefault("HUGS_PROXY_HOST", defaultHost) port := getenvOrDefault("HUGS_PROXY_PORT", defaultPort) dbPath := getenvOrDefault("HUGS_PROXY_DB_PATH", defaultDBPath) + giteaRepoURL := getenvOrDefault("HUGS_PROXY_GITEA_REPO_URL", "https://gitea.gangary.cn/gary/Hugs-Proxy") ttlRaw := getenvOrDefault("HUGS_PROXY_DEFAULT_TOKEN_TTL", defaultTokenTTLString) busyTimeout := getenvIntOrDefault("HUGS_PROXY_DB_BUSY_TIMEOUT_MS", defaultBusyTimeoutMilli) @@ -39,6 +41,7 @@ func Load() Config { Host: host, Port: port, DBPath: dbPath, + GiteaRepoURL: giteaRepoURL, DefaultTokenTTL: ttl, BusyTimeoutMS: busyTimeout, } diff --git a/internal/db/store.go b/internal/db/store.go index 2582513..919a065 100644 --- a/internal/db/store.go +++ b/internal/db/store.go @@ -363,7 +363,7 @@ func (s *Store) ListTokens(ctx context.Context) ([]TokenWithUser, error) { var expiresAt string var disabledAt sql.NullString var disabledInt int - if err := rows.Scan(&t.ID, &t.UserID, &t.Token, &createdAt, &expiresAt, &disabledInt, &disabledAt, &t.MaxUses, &t.UsedCount, &t.Username); err != nil { + if err := rows.Scan(&t.ID, &t.UserID, &t.Token.Token, &createdAt, &expiresAt, &disabledInt, &disabledAt, &t.MaxUses, &t.UsedCount, &t.Username); err != nil { return nil, err } ct, err := time.Parse(timeFormat, createdAt) @@ -392,6 +392,44 @@ func (s *Store) ListTokens(ctx context.Context) ([]TokenWithUser, error) { return out, nil } +func (s *Store) ListUsageEntries(ctx context.Context, limit int) ([]UsageEntry, error) { + query := `SELECT created_at, request_ip, user_id, token_id, original_url, http_status, success, error_reason + FROM token_usage + ORDER BY id DESC` + args := []any{} + if limit > 0 { + query += ` LIMIT ?` + args = append(args, limit) + } + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var out []UsageEntry + for rows.Next() { + var e UsageEntry + var createdAt string + var successInt int + if err := rows.Scan(&createdAt, &e.RequestIP, &e.UserID, &e.TokenID, &e.OriginalURL, &e.HTTPStatus, &successInt, &e.ErrorReason); err != nil { + return nil, err + } + parsed, err := time.Parse(timeFormat, createdAt) + if err != nil { + return nil, err + } + e.OccurredAt = parsed + e.Success = successInt == 1 + out = append(out, e) + } + if err := rows.Err(); err != nil { + return nil, err + } + return out, nil +} + func (s *Store) RecordUsageAndMaybeIncrement(ctx context.Context, entry UsageEntry, increment bool) error { if entry.OccurredAt.IsZero() { entry.OccurredAt = time.Now().UTC() diff --git a/main.go b/main.go index c0f8458..124f587 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "fmt" "io" "log" @@ -52,6 +53,242 @@ var ( auditor *audit.Logger ) +const landingPageHTML = ` + + + + + Hugs-Proxy + + + + +
+
+
+
+ +
+

Hugs-Proxy

+

轻量的学术资源加速反向代理:支持常见 GitHub/Raw/Gist 下载格式,并带 SQLite 鉴权与访问审计。

+
+
+
    +
  • 把目标 URL 放在路径里即可加速:https://github.com/...
  • +
  • Bearer Token 鉴权与使用次数控制
  • +
  • 重定向 Location 改写,尽量全程走代理
  • +
  • 大文件保护(>1GB 自动 302 回源)
  • +
+
+ +
+

加速下载

+
+
+ + +
+
+ + +
+ +

提示:token 通过表单 POST 发送,不会出现在地址栏。也可使用 curl:` + "\n" + `curl -L -H \"Authorization: Bearer <TOKEN>\" \"https://hugs.you/https://github.com/...\"

+
+
+
+ + +
+ + +` + func init() { // 1. 初始化列表 whiteList = parseList(whiteListStr) @@ -79,7 +316,23 @@ func main() { base := http.HandlerFunc(routeHandler) protected := auth.Middleware(store, auditAuthFailure, base) - http.Handle("/", protected) + root := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/" && (r.Method == http.MethodGet || r.Method == http.MethodHead): + serveLandingPage(w, r, cfg) + return + case r.URL.Path == "/favicon.ico" || r.URL.Path == "/webicon.png": + serveWebIcon(w, r) + return + case r.URL.Path == "/download": + downloadHandler(w, r) + return + default: + protected.ServeHTTP(w, r) + return + } + }) + http.Handle("/", root) addr := fmt.Sprintf("%s:%s", cfg.Host, cfg.Port) log.Printf("服务器启动成功,正在监听 %s", addr) @@ -88,6 +341,217 @@ func main() { } } +func serveWebIcon(w http.ResponseWriter, r *http.Request) { + http.ServeFile(w, r, "./webicon.png") +} + +func serveLandingPage(w http.ResponseWriter, r *http.Request, cfg config.Config) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + page := strings.ReplaceAll(landingPageHTML, "{{GITEA_REPO_URL}}", htmlAttrEscape(cfg.GiteaRepoURL)) + if r.Method == http.MethodHead { + w.WriteHeader(http.StatusOK) + return + } + _, _ = io.WriteString(w, page) +} + +func downloadHandler(w http.ResponseWriter, r *http.Request) { + // 为了避免 token 出现在 URL,下载入口仅支持表单 POST。 + if r.Method != http.MethodPost { + w.Header().Set("Allow", http.MethodPost) + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + return + } + + // ParseForm 会读取 body(application/x-www-form-urlencoded / multipart/form-data)。 + if err := r.ParseForm(); err != nil { + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + + urlInput := strings.TrimSpace(r.FormValue("url")) + tokenValue := strings.TrimSpace(r.FormValue("token")) + if tokenValue == "" { + auditRequestFailure(r, http.StatusUnauthorized, "missing_token", 0, 0, urlInput) + http.Error(w, "Missing token.", http.StatusUnauthorized) + return + } + + // 先校验并消耗一次 token。 + tok, err := store.ValidateAndConsumeToken(r.Context(), tokenValue) + if err != nil { + status, reason := classifyTokenError(err) + + var userID int64 + var tokenID int64 + if !errors.Is(err, db.ErrNotFound) { + if t, getErr := store.GetToken(r.Context(), tokenValue); getErr == nil { + userID = t.UserID + tokenID = t.ID + } + } + + auditRequestFailure(r, status, reason, userID, tokenID, urlInput) + http.Error(w, http.StatusText(status), status) + return + } + + recorder := newStatusRecorder(w) + normalizedURL := urlInput + success := false + errorReason := "" + + defer func() { + statusCode := recorder.StatusCode() + if statusCode >= 200 && statusCode < 400 { + success = true + } + if err := auditor.Log(r.Context(), audit.Entry{ + RequestIP: clientIP(r), + UserID: tok.UserID, + HasUser: true, + TokenID: tok.ID, + HasToken: true, + OriginalURL: normalizedURL, + HTTPStatus: statusCode, + Success: success, + ErrorReason: errorReason, + CountAsSuccess: success, + }); err != nil { + log.Printf("audit log failed: %v", err) + } + }() + + if urlInput == "" { + errorReason = "missing_url" + http.Error(recorder, "Missing url.", http.StatusBadRequest) + return + } + + u, prepErr := prepareTargetURL(urlInput) + if prepErr != nil { + errorReason = prepErr.Reason + http.Error(recorder, prepErr.Message, prepErr.StatusCode) + return + } + normalizedURL = u + + // 表单是 POST,但下载应当总是 GET。 + r2 := r.Clone(r.Context()) + r2.Method = http.MethodGet + r2.Body = nil + proxy(recorder, r2, u) +} + +func classifyTokenError(err error) (int, string) { + switch { + case errors.Is(err, db.ErrNotFound): + return http.StatusForbidden, "invalid_token" + case errors.Is(err, db.ErrTokenDisabled): + return http.StatusForbidden, "token_disabled" + case errors.Is(err, db.ErrTokenExpired): + return http.StatusForbidden, "token_expired" + case errors.Is(err, db.ErrTokenExhausted): + return http.StatusForbidden, "token_exhausted" + default: + return http.StatusInternalServerError, "token_lookup_error" + } +} + +func auditRequestFailure(r *http.Request, statusCode int, reason string, userID int64, tokenID int64, originalURL string) { + entry := audit.Entry{ + RequestIP: clientIP(r), + OriginalURL: strings.TrimSpace(originalURL), + HTTPStatus: statusCode, + Success: false, + ErrorReason: reason, + CountAsSuccess: false, + } + if entry.OriginalURL == "" { + entry.OriginalURL = strings.TrimPrefix(r.URL.RequestURI(), "/") + } + if userID > 0 { + entry.UserID = userID + entry.HasUser = true + } + if tokenID > 0 { + entry.TokenID = tokenID + entry.HasToken = true + } + if err := auditor.Log(context.Background(), entry); err != nil { + log.Printf("audit failure log failed: %v", err) + } +} + +func htmlAttrEscape(s string) string { + // 只用于 attribute:非常轻量的转义,避免引入额外依赖。 + // 允许 http(s) URL;其他字符做最小替换。 + s = strings.ReplaceAll(s, "&", "&") + s = strings.ReplaceAll(s, "\"", """) + s = strings.ReplaceAll(s, "<", "<") + s = strings.ReplaceAll(s, ">", ">") + return s +} + +type targetPrepareError struct { + StatusCode int + Reason string + Message string +} + +func prepareTargetURL(raw string) (string, *targetPrepareError) { + u := strings.TrimSpace(raw) + if u == "" { + return "", &targetPrepareError{StatusCode: http.StatusBadRequest, Reason: "missing_url", Message: "Missing url."} + } + + if strings.HasPrefix(u, "https:/") && !strings.HasPrefix(u, "https://") { + u = "https://" + strings.TrimPrefix(u, "https:/") + } else if strings.HasPrefix(u, "http:/") && !strings.HasPrefix(u, "http://") { + u = "http://" + strings.TrimPrefix(u, "http:/") + } + + if !strings.HasPrefix(u, "http") { + u = "https://" + u + } + + m := checkURL(u) + if m == nil { + return "", &targetPrepareError{StatusCode: http.StatusForbidden, Reason: "invalid_input", Message: "Invalid input."} + } + + if len(whiteList) > 0 { + allowed := false + for _, i := range whiteList { + if matchRule(m, i) { + allowed = true + break + } + } + if !allowed { + return "", &targetPrepareError{StatusCode: http.StatusForbidden, Reason: "forbidden_by_white_list", Message: "Forbidden by white list."} + } + } + + for _, i := range blackList { + if matchRule(m, i) { + return "", &targetPrepareError{StatusCode: http.StatusForbidden, Reason: "forbidden_by_black_list", Message: "Forbidden by black list."} + } + } + + // 将网页浏览的blob链接统一替换为raw下载链接 + if exp2.MatchString(u) { + u = strings.Replace(u, "/blob/", "/raw/", 1) + } + + parsedURL, err := url.Parse(u) + if err == nil { + u = parsedURL.String() + } + + return u, nil +} + func routeHandler(w http.ResponseWriter, r *http.Request) { authInfo, ok := auth.FromContext(r.Context()) if !ok { @@ -122,57 +586,14 @@ func routeHandler(w http.ResponseWriter, r *http.Request) { } }() - u := originalInput - if strings.HasPrefix(u, "https:/") && !strings.HasPrefix(u, "https://") { - u = "https://" + strings.TrimPrefix(u, "https:/") - } else if strings.HasPrefix(u, "http:/") && !strings.HasPrefix(u, "http://") { - u = "http://" + strings.TrimPrefix(u, "http:/") - } - - if !strings.HasPrefix(u, "http") { - u = "https://" + u - } - - m := checkURL(u) - if m == nil { - errorReason = "invalid_input" - http.Error(recorder, "Invalid input.", http.StatusForbidden) + u, prepErr := prepareTargetURL(originalInput) + if prepErr != nil { + errorReason = prepErr.Reason + http.Error(recorder, prepErr.Message, prepErr.StatusCode) return } - if len(whiteList) > 0 { - allowed := false - for _, i := range whiteList { - if matchRule(m, i) { - allowed = true - break - } - } - if !allowed { - errorReason = "forbidden_by_white_list" - http.Error(recorder, "Forbidden by white list.", http.StatusForbidden) - return - } - } - - for _, i := range blackList { - if matchRule(m, i) { - errorReason = "forbidden_by_black_list" - http.Error(recorder, "Forbidden by black list.", http.StatusForbidden) - return - } - } - - // 将网页浏览的blob链接统一替换为raw下载链接 - if exp2.MatchString(u) { - u = strings.Replace(u, "/blob/", "/raw/", 1) - } - // 发起代理请求 - parsedURL, err := url.Parse(u) - if err == nil { - u = parsedURL.String() - } normalizedURL = u proxy(recorder, r, u) } @@ -236,25 +657,7 @@ func proxy(w http.ResponseWriter, r *http.Request, targetURL string) { } func auditAuthFailure(r *http.Request, token string, statusCode int, reason string, userID int64, tokenID int64) { - entry := audit.Entry{ - RequestIP: clientIP(r), - OriginalURL: strings.TrimPrefix(r.URL.RequestURI(), "/"), - HTTPStatus: statusCode, - Success: false, - ErrorReason: reason, - CountAsSuccess: false, - } - if userID > 0 { - entry.UserID = userID - entry.HasUser = true - } - if tokenID > 0 { - entry.TokenID = tokenID - entry.HasToken = true - } - if err := auditor.Log(context.Background(), entry); err != nil { - log.Printf("auth failure audit log failed: %v", err) - } + auditRequestFailure(r, statusCode, reason, userID, tokenID, "") } func clientIP(r *http.Request) string { diff --git a/webicon.png b/webicon.png new file mode 100644 index 0000000..d9f4b5a Binary files /dev/null and b/webicon.png differ