package main import ( "context" "fmt" "io" "log" "net" "net/http" "net/url" "regexp" "strings" "gitea.gangary.cn/gary/Hugs-Proxy/internal/audit" "gitea.gangary.cn/gary/Hugs-Proxy/internal/auth" "gitea.gangary.cn/gary/Hugs-Proxy/internal/config" "gitea.gangary.cn/gary/Hugs-Proxy/internal/db" ) // ======================== // 配置区域 // ======================== const ( sizeLimit = int64(1024 * 1024 * 1024 * 1) // 允许的文件大小, 1GB ) // 先生效白名单再匹配黑名单 // 每行一个规则,示例: // user1 # 封禁user1的所有仓库 // user1/repo1 # 封禁user1的repo1 // */repo1 # 封禁所有叫做repo1的仓库 var ( whiteListStr = `` blackListStr = `` ) // ======================== // 全局变量与预编译正则 // ======================== var ( whiteList [][]string blackList [][]string exp1 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P[^/]+)/(?P[^/]+)/(?:releases|archive)/.*$`) exp2 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P[^/]+)/(?P[^/]+)/(?:blob|raw)/.*$`) exp3 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P[^/]+)/(?P[^/]+)/(?:info|git-).*$`) exp4 = regexp.MustCompile(`^(?:https?://)?raw\.(?:githubusercontent|github)\.com/(?P[^/]+)/(?P[^/]+)/.+?/.+$`) exp5 = regexp.MustCompile(`^(?:https?://)?gist\.(?:githubusercontent|github)\.com/(?P[^/]+)/.+?/.+$`) httpClient *http.Client store *db.Store auditor *audit.Logger ) func init() { // 1. 初始化列表 whiteList = parseList(whiteListStr) blackList = parseList(blackListStr) httpClient = &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, Timeout: 0, } } func main() { cfg := config.Load() var err error store, err = db.NewStore(cfg.DBPath, cfg.BusyTimeoutMS) if err != nil { log.Fatal(err) } defer store.Close() auditor = audit.NewLogger(store) base := http.HandlerFunc(routeHandler) protected := auth.Middleware(store, auditAuthFailure, base) http.Handle("/", protected) addr := fmt.Sprintf("%s:%s", cfg.Host, cfg.Port) log.Printf("服务器启动成功,正在监听 %s", addr) if err := http.ListenAndServe(addr, nil); err != nil { log.Fatal(err) } } func routeHandler(w http.ResponseWriter, r *http.Request) { authInfo, ok := auth.FromContext(r.Context()) if !ok { http.Error(w, "Unauthorized", http.StatusUnauthorized) return } recorder := newStatusRecorder(w) originalInput := strings.TrimPrefix(r.URL.RequestURI(), "/") normalizedURL := originalInput success := false errorReason := "" defer func() { statusCode := recorder.StatusCode() if statusCode >= 200 && statusCode < 400 { success = true } if err := auditor.Log(r.Context(), audit.Entry{ RequestIP: clientIP(r), UserID: authInfo.UserID, HasUser: true, TokenID: authInfo.TokenID, HasToken: true, OriginalURL: normalizedURL, HTTPStatus: statusCode, Success: success, ErrorReason: errorReason, CountAsSuccess: success, }); err != nil { log.Printf("audit log failed: %v", err) } }() u := originalInput if strings.HasPrefix(u, "https:/") && !strings.HasPrefix(u, "https://") { u = "https://" + strings.TrimPrefix(u, "https:/") } else if strings.HasPrefix(u, "http:/") && !strings.HasPrefix(u, "http://") { u = "http://" + strings.TrimPrefix(u, "http:/") } if !strings.HasPrefix(u, "http") { u = "https://" + u } m := checkURL(u) if m == nil { errorReason = "invalid_input" http.Error(recorder, "Invalid input.", http.StatusForbidden) return } if len(whiteList) > 0 { allowed := false for _, i := range whiteList { if matchRule(m, i) { allowed = true break } } if !allowed { errorReason = "forbidden_by_white_list" http.Error(recorder, "Forbidden by white list.", http.StatusForbidden) return } } for _, i := range blackList { if matchRule(m, i) { errorReason = "forbidden_by_black_list" http.Error(recorder, "Forbidden by black list.", http.StatusForbidden) return } } // 将网页浏览的blob链接统一替换为raw下载链接 if exp2.MatchString(u) { u = strings.Replace(u, "/blob/", "/raw/", 1) } // 发起代理请求 parsedURL, err := url.Parse(u) if err == nil { u = parsedURL.String() } normalizedURL = u proxy(recorder, r, u) } func proxy(w http.ResponseWriter, r *http.Request, targetURL string) { // 修正由于多重代理可能造成的 URL 格式错误 if strings.HasPrefix(targetURL, "https:/") && !strings.HasPrefix(targetURL, "https://") { targetURL = "https://" + targetURL[7:] } req, err := http.NewRequest(r.Method, targetURL, r.Body) if err != nil { http.Error(w, "server error "+err.Error(), http.StatusInternalServerError) return } // 拷贝 Header, 但不包含 Host for k, vv := range r.Header { if strings.EqualFold(k, "Host") { continue } for _, v := range vv { req.Header.Add(k, v) } } resp, err := httpClient.Do(req) if err != nil { http.Error(w, "server error "+err.Error(), http.StatusInternalServerError) return } defer resp.Body.Close() // 校验大文件限制 if resp.ContentLength > sizeLimit { http.Redirect(w, r, targetURL, http.StatusFound) return } // 处理重定向 Location Header if loc := resp.Header.Get("Location"); loc != "" { if checkURL(loc) != nil { resp.Header.Set("Location", "/"+loc) } else { // 如果不符合 Github 下载格式,则递归代理目标地址 proxy(w, r, loc) return } } // 拷贝响应的 Header 并写入状态码 for k, vv := range resp.Header { for _, v := range vv { w.Header().Add(k, v) } } w.WriteHeader(resp.StatusCode) // Go 会原生以 Stream 方式拷贝数据给 Client 端 io.Copy(w, resp.Body) } func auditAuthFailure(r *http.Request, token string, statusCode int, reason string, userID int64, tokenID int64) { entry := audit.Entry{ RequestIP: clientIP(r), OriginalURL: strings.TrimPrefix(r.URL.RequestURI(), "/"), HTTPStatus: statusCode, Success: false, ErrorReason: reason, CountAsSuccess: false, } if userID > 0 { entry.UserID = userID entry.HasUser = true } if tokenID > 0 { entry.TokenID = tokenID entry.HasToken = true } if err := auditor.Log(context.Background(), entry); err != nil { log.Printf("auth failure audit log failed: %v", err) } } func clientIP(r *http.Request) string { if xff := strings.TrimSpace(r.Header.Get("X-Forwarded-For")); xff != "" { parts := strings.Split(xff, ",") if len(parts) > 0 { return strings.TrimSpace(parts[0]) } } if xrip := strings.TrimSpace(r.Header.Get("X-Real-IP")); xrip != "" { return xrip } host, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr)) if err == nil && host != "" { return host } if strings.TrimSpace(r.RemoteAddr) != "" { return strings.TrimSpace(r.RemoteAddr) } return "unknown" } type statusRecorder struct { http.ResponseWriter statusCode int } func newStatusRecorder(w http.ResponseWriter) *statusRecorder { return &statusRecorder{ResponseWriter: w, statusCode: http.StatusOK} } func (r *statusRecorder) WriteHeader(statusCode int) { r.statusCode = statusCode r.ResponseWriter.WriteHeader(statusCode) } func (r *statusRecorder) StatusCode() int { return r.statusCode } func matchRule(m []string, i []string) bool { // m 通常为 [author, repo] 或 [author] if len(i) == 1 { return len(m) >= 1 && m[0] == i[0] } else if len(i) == 2 { if i[0] == "*" && len(m) >= 2 && m[1] == i[1] { return true } return len(m) >= 2 && m[0] == i[0] && m[1] == i[1] } return false } func checkURL(u string) []string { exps := []*regexp.Regexp{exp1, exp2, exp3, exp4, exp5} for _, exp := range exps { matches := exp.FindStringSubmatch(u) if matches != nil { var result []string for i, name := range exp.SubexpNames() { if i > 0 && name != "" { result = append(result, matches[i]) } } return result } } return nil } func parseList(s string) [][]string { var res [][]string lines := strings.Split(s, "\n") for _, line := range lines { line = strings.TrimSpace(line) if line != "" { parts := strings.Split(line, "/") var cleaned []string for _, p := range parts { cleaned = append(cleaned, strings.ReplaceAll(p, " ", "")) } res = append(res, cleaned) } } return res }