package main import ( "context" "errors" "fmt" "io" "log" "net" "net/http" "net/url" "regexp" "strings" "gitea.gangary.cn/gary/Hugs-Proxy/internal/audit" "gitea.gangary.cn/gary/Hugs-Proxy/internal/auth" "gitea.gangary.cn/gary/Hugs-Proxy/internal/config" "gitea.gangary.cn/gary/Hugs-Proxy/internal/db" ) // ======================== // 配置区域 // ======================== const ( sizeLimit = int64(1024 * 1024 * 1024 * 1) // 允许的文件大小, 1GB ) // 先生效白名单再匹配黑名单 // 每行一个规则,示例: // user1 # 封禁user1的所有仓库 // user1/repo1 # 封禁user1的repo1 // */repo1 # 封禁所有叫做repo1的仓库 var ( whiteListStr = `` blackListStr = `` ) // ======================== // 全局变量与预编译正则 // ======================== var ( whiteList [][]string blackList [][]string exp1 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P[^/]+)/(?P[^/]+)/(?:releases|archive)/.*$`) exp2 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P[^/]+)/(?P[^/]+)/(?:blob|raw)/.*$`) exp3 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P[^/]+)/(?P[^/]+)/(?:info|git-).*$`) exp4 = regexp.MustCompile(`^(?:https?://)?raw\.(?:githubusercontent|github)\.com/(?P[^/]+)/(?P[^/]+)/.+?/.+$`) exp5 = regexp.MustCompile(`^(?:https?://)?gist\.(?:githubusercontent|github)\.com/(?P[^/]+)/.+?/.+$`) httpClient *http.Client store *db.Store auditor *audit.Logger ) const landingPageHTML = ` Hugs-Proxy

Hugs-Proxy

轻量的学术资源加速反向代理:支持常见 GitHub/Raw/Gist 下载格式,并带 SQLite 鉴权与访问审计。

  • 把目标 URL 放在路径里即可加速:https://github.com/...
  • Bearer Token 鉴权与使用次数控制
  • 重定向 Location 改写,尽量全程走代理
  • 大文件保护(>1GB 自动 302 回源)

加速下载

提示:token 通过表单 POST 发送,不会出现在地址栏。也可使用 curl:` + "\n" + `curl -L -H \"Authorization: Bearer <TOKEN>\" \"https://hugs.you/https://github.com/...\"

` func init() { // 1. 初始化列表 whiteList = parseList(whiteListStr) blackList = parseList(blackListStr) httpClient = &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, Timeout: 0, } } func main() { cfg := config.Load() var err error store, err = db.NewStore(cfg.DBPath, cfg.BusyTimeoutMS) if err != nil { log.Fatal(err) } defer store.Close() auditor = audit.NewLogger(store) base := http.HandlerFunc(routeHandler) protected := auth.Middleware(store, auditAuthFailure, base) root := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case r.URL.Path == "/" && (r.Method == http.MethodGet || r.Method == http.MethodHead): serveLandingPage(w, r, cfg) return case r.URL.Path == "/favicon.ico" || r.URL.Path == "/webicon.png": serveWebIcon(w, r) return case r.URL.Path == "/download": downloadHandler(w, r) return default: protected.ServeHTTP(w, r) return } }) http.Handle("/", root) addr := fmt.Sprintf("%s:%s", cfg.Host, cfg.Port) log.Printf("服务器启动成功,正在监听 %s", addr) if err := http.ListenAndServe(addr, nil); err != nil { log.Fatal(err) } } func serveWebIcon(w http.ResponseWriter, r *http.Request) { http.ServeFile(w, r, "./webicon.png") } func serveLandingPage(w http.ResponseWriter, r *http.Request, cfg config.Config) { w.Header().Set("Content-Type", "text/html; charset=utf-8") page := strings.ReplaceAll(landingPageHTML, "{{GITEA_REPO_URL}}", htmlAttrEscape(cfg.GiteaRepoURL)) if r.Method == http.MethodHead { w.WriteHeader(http.StatusOK) return } _, _ = io.WriteString(w, page) } func downloadHandler(w http.ResponseWriter, r *http.Request) { // 为了避免 token 出现在 URL,下载入口仅支持表单 POST。 if r.Method != http.MethodPost { w.Header().Set("Allow", http.MethodPost) http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) return } // ParseForm 会读取 body(application/x-www-form-urlencoded / multipart/form-data)。 if err := r.ParseForm(); err != nil { http.Error(w, "Bad Request", http.StatusBadRequest) return } urlInput := strings.TrimSpace(r.FormValue("url")) tokenValue := strings.TrimSpace(r.FormValue("token")) if tokenValue == "" { auditRequestFailure(r, http.StatusUnauthorized, "missing_token", 0, 0, urlInput) http.Error(w, "Missing token.", http.StatusUnauthorized) return } // 先校验并消耗一次 token。 tok, err := store.ValidateAndConsumeToken(r.Context(), tokenValue) if err != nil { status, reason := classifyTokenError(err) var userID int64 var tokenID int64 if !errors.Is(err, db.ErrNotFound) { if t, getErr := store.GetToken(r.Context(), tokenValue); getErr == nil { userID = t.UserID tokenID = t.ID } } auditRequestFailure(r, status, reason, userID, tokenID, urlInput) http.Error(w, http.StatusText(status), status) return } recorder := newStatusRecorder(w) normalizedURL := urlInput success := false errorReason := "" defer func() { statusCode := recorder.StatusCode() if statusCode >= 200 && statusCode < 400 { success = true } if err := auditor.Log(r.Context(), audit.Entry{ RequestIP: clientIP(r), UserID: tok.UserID, HasUser: true, TokenID: tok.ID, HasToken: true, OriginalURL: normalizedURL, HTTPStatus: statusCode, Success: success, ErrorReason: errorReason, CountAsSuccess: success, }); err != nil { log.Printf("audit log failed: %v", err) } }() if urlInput == "" { errorReason = "missing_url" http.Error(recorder, "Missing url.", http.StatusBadRequest) return } u, prepErr := prepareTargetURL(urlInput) if prepErr != nil { errorReason = prepErr.Reason http.Error(recorder, prepErr.Message, prepErr.StatusCode) return } normalizedURL = u // 表单是 POST,但下载应当总是 GET。 r2 := r.Clone(r.Context()) r2.Method = http.MethodGet r2.Body = nil proxy(recorder, r2, u) } func classifyTokenError(err error) (int, string) { switch { case errors.Is(err, db.ErrNotFound): return http.StatusForbidden, "invalid_token" case errors.Is(err, db.ErrTokenDisabled): return http.StatusForbidden, "token_disabled" case errors.Is(err, db.ErrTokenExpired): return http.StatusForbidden, "token_expired" case errors.Is(err, db.ErrTokenExhausted): return http.StatusForbidden, "token_exhausted" default: return http.StatusInternalServerError, "token_lookup_error" } } func auditRequestFailure(r *http.Request, statusCode int, reason string, userID int64, tokenID int64, originalURL string) { entry := audit.Entry{ RequestIP: clientIP(r), OriginalURL: strings.TrimSpace(originalURL), HTTPStatus: statusCode, Success: false, ErrorReason: reason, CountAsSuccess: false, } if entry.OriginalURL == "" { entry.OriginalURL = strings.TrimPrefix(r.URL.RequestURI(), "/") } if userID > 0 { entry.UserID = userID entry.HasUser = true } if tokenID > 0 { entry.TokenID = tokenID entry.HasToken = true } if err := auditor.Log(context.Background(), entry); err != nil { log.Printf("audit failure log failed: %v", err) } } func htmlAttrEscape(s string) string { // 只用于 attribute:非常轻量的转义,避免引入额外依赖。 // 允许 http(s) URL;其他字符做最小替换。 s = strings.ReplaceAll(s, "&", "&") s = strings.ReplaceAll(s, "\"", """) s = strings.ReplaceAll(s, "<", "<") s = strings.ReplaceAll(s, ">", ">") return s } type targetPrepareError struct { StatusCode int Reason string Message string } func prepareTargetURL(raw string) (string, *targetPrepareError) { u := strings.TrimSpace(raw) if u == "" { return "", &targetPrepareError{StatusCode: http.StatusBadRequest, Reason: "missing_url", Message: "Missing url."} } if strings.HasPrefix(u, "https:/") && !strings.HasPrefix(u, "https://") { u = "https://" + strings.TrimPrefix(u, "https:/") } else if strings.HasPrefix(u, "http:/") && !strings.HasPrefix(u, "http://") { u = "http://" + strings.TrimPrefix(u, "http:/") } if !strings.HasPrefix(u, "http") { u = "https://" + u } m := checkURL(u) if m == nil { return "", &targetPrepareError{StatusCode: http.StatusForbidden, Reason: "invalid_input", Message: "Invalid input."} } if len(whiteList) > 0 { allowed := false for _, i := range whiteList { if matchRule(m, i) { allowed = true break } } if !allowed { return "", &targetPrepareError{StatusCode: http.StatusForbidden, Reason: "forbidden_by_white_list", Message: "Forbidden by white list."} } } for _, i := range blackList { if matchRule(m, i) { return "", &targetPrepareError{StatusCode: http.StatusForbidden, Reason: "forbidden_by_black_list", Message: "Forbidden by black list."} } } // 将网页浏览的blob链接统一替换为raw下载链接 if exp2.MatchString(u) { u = strings.Replace(u, "/blob/", "/raw/", 1) } parsedURL, err := url.Parse(u) if err == nil { u = parsedURL.String() } return u, nil } func routeHandler(w http.ResponseWriter, r *http.Request) { authInfo, ok := auth.FromContext(r.Context()) if !ok { http.Error(w, "Unauthorized", http.StatusUnauthorized) return } recorder := newStatusRecorder(w) originalInput := strings.TrimPrefix(r.URL.RequestURI(), "/") normalizedURL := originalInput success := false errorReason := "" defer func() { statusCode := recorder.StatusCode() if statusCode >= 200 && statusCode < 400 { success = true } if err := auditor.Log(r.Context(), audit.Entry{ RequestIP: clientIP(r), UserID: authInfo.UserID, HasUser: true, TokenID: authInfo.TokenID, HasToken: true, OriginalURL: normalizedURL, HTTPStatus: statusCode, Success: success, ErrorReason: errorReason, CountAsSuccess: success, }); err != nil { log.Printf("audit log failed: %v", err) } }() u, prepErr := prepareTargetURL(originalInput) if prepErr != nil { errorReason = prepErr.Reason http.Error(recorder, prepErr.Message, prepErr.StatusCode) return } // 发起代理请求 normalizedURL = u proxy(recorder, r, u) } func proxy(w http.ResponseWriter, r *http.Request, targetURL string) { // 修正由于多重代理可能造成的 URL 格式错误 if strings.HasPrefix(targetURL, "https:/") && !strings.HasPrefix(targetURL, "https://") { targetURL = "https://" + targetURL[7:] } req, err := http.NewRequest(r.Method, targetURL, r.Body) if err != nil { http.Error(w, "server error "+err.Error(), http.StatusInternalServerError) return } // 拷贝 Header, 但不包含 Host for k, vv := range r.Header { if strings.EqualFold(k, "Host") { continue } for _, v := range vv { req.Header.Add(k, v) } } resp, err := httpClient.Do(req) if err != nil { http.Error(w, "server error "+err.Error(), http.StatusInternalServerError) return } defer resp.Body.Close() // 校验大文件限制 if resp.ContentLength > sizeLimit { http.Redirect(w, r, targetURL, http.StatusFound) return } // 处理重定向 Location Header if loc := resp.Header.Get("Location"); loc != "" { if checkURL(loc) != nil { resp.Header.Set("Location", "/"+loc) } else { // 如果不符合 Github 下载格式,则递归代理目标地址 proxy(w, r, loc) return } } // 拷贝响应的 Header 并写入状态码 for k, vv := range resp.Header { for _, v := range vv { w.Header().Add(k, v) } } w.WriteHeader(resp.StatusCode) // Go 会原生以 Stream 方式拷贝数据给 Client 端 io.Copy(w, resp.Body) } func auditAuthFailure(r *http.Request, token string, statusCode int, reason string, userID int64, tokenID int64) { auditRequestFailure(r, statusCode, reason, userID, tokenID, "") } func clientIP(r *http.Request) string { if xff := strings.TrimSpace(r.Header.Get("X-Forwarded-For")); xff != "" { parts := strings.Split(xff, ",") if len(parts) > 0 { return strings.TrimSpace(parts[0]) } } if xrip := strings.TrimSpace(r.Header.Get("X-Real-IP")); xrip != "" { return xrip } host, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr)) if err == nil && host != "" { return host } if strings.TrimSpace(r.RemoteAddr) != "" { return strings.TrimSpace(r.RemoteAddr) } return "unknown" } type statusRecorder struct { http.ResponseWriter statusCode int } func newStatusRecorder(w http.ResponseWriter) *statusRecorder { return &statusRecorder{ResponseWriter: w, statusCode: http.StatusOK} } func (r *statusRecorder) WriteHeader(statusCode int) { r.statusCode = statusCode r.ResponseWriter.WriteHeader(statusCode) } func (r *statusRecorder) StatusCode() int { return r.statusCode } func matchRule(m []string, i []string) bool { // m 通常为 [author, repo] 或 [author] if len(i) == 1 { return len(m) >= 1 && m[0] == i[0] } else if len(i) == 2 { if i[0] == "*" && len(m) >= 2 && m[1] == i[1] { return true } return len(m) >= 2 && m[0] == i[0] && m[1] == i[1] } return false } func checkURL(u string) []string { exps := []*regexp.Regexp{exp1, exp2, exp3, exp4, exp5} for _, exp := range exps { matches := exp.FindStringSubmatch(u) if matches != nil { var result []string for i, name := range exp.SubexpNames() { if i > 0 && name != "" { result = append(result, matches[i]) } } return result } } return nil } func parseList(s string) [][]string { var res [][]string lines := strings.Split(s, "\n") for _, line := range lines { line = strings.TrimSpace(line) if line != "" { parts := strings.Split(line, "/") var cleaned []string for _, p := range parts { cleaned = append(cleaned, strings.ReplaceAll(p, " ", "")) } res = append(res, cleaned) } } return res }