Files
Hugs-Proxy/main.go
T
gary 62e076111f 初步实现鉴权
Co-authored-by: Copilot <copilot@github.com>
2026-04-23 20:21:35 +08:00

344 lines
8.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"context"
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"regexp"
"strings"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/audit"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/auth"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/config"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/db"
)
// ========================
// Configuration
// ========================

// sizeLimit is the largest upstream Content-Length (in bytes) that proxy
// streams through itself; a bigger response gets a 302 redirect to the
// upstream URL instead. 1 << 30 bytes = 1 GiB.
const sizeLimit int64 = 1 << 30
// Built-in rule lists, parsed by parseList in init.
// The white list (allow-list) is checked first, then the black list
// (deny-list); an empty white list allows every supported URL.
// One rule per line. Examples, written as deny-list entries (the "# ..." parts
// are annotations for this comment):
// user1 # block every repository owned by user1
// user1/repo1 # block user1's repo1
// */repo1 # block every repository named repo1
var (
whiteListStr = ``
blackListStr = ``
)
// ========================
// Global state and precompiled URL patterns
// ========================
var (
// whiteList / blackList hold whiteListStr / blackListStr as parsed by
// parseList in init: one []string per rule, split on "/".
whiteList [][]string
blackList [][]string
// The patterns below are the only URL shapes the proxy accepts (see
// checkURL). Each may start with http:// or https:// and captures the owner
// as "author" and, for all but exp5, the repository as "repo".
//
// exp1: github.com/<author>/<repo>/releases/... and .../archive/...
exp1 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:releases|archive)/.*$`)
// exp2: github.com/<author>/<repo>/blob/... and .../raw/...
// (routeHandler rewrites blob page links to raw links before proxying).
exp2 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:blob|raw)/.*$`)
// exp3: github.com/<author>/<repo>/info... and .../git-...; presumably the
// Git smart-HTTP endpoints (info/refs, git-upload-pack) used by git clone.
exp3 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:info|git-).*$`)
// exp4: raw.githubusercontent.com or raw.github.com /<author>/<repo>/<x>/<rest>.
exp4 = regexp.MustCompile(`^(?:https?://)?raw\.(?:githubusercontent|github)\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/.+?/.+$`)
// exp5: gist.githubusercontent.com or gist.github.com /<author>/<x>/<rest>;
// only the author is captured, so two-segment list rules never match gists.
exp5 = regexp.MustCompile(`^(?:https?://)?gist\.(?:githubusercontent|github)\.com/(?P<author>[^/]+)/.+?/.+$`)
// httpClient is the shared upstream client built in init. It never follows
// redirects itself; proxy inspects each Location header.
httpClient *http.Client
// store and auditor are initialised in main.
store *db.Store
auditor *audit.Logger
)
func init() {
// 1. 初始化列表
whiteList = parseList(whiteListStr)
blackList = parseList(blackListStr)
httpClient = &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
Timeout: 0,
}
}
// main starts the authenticated GitHub download proxy and exits with a
// logged error if startup or serving fails.
func main() {
	if err := runServer(); err != nil {
		log.Fatal(err)
	}
}

// runServer opens the store, installs the auth-protected route handler and
// serves on cfg.Host:cfg.Port until the listener fails.
//
// It returns the error instead of calling log.Fatal itself: log.Fatal exits
// via os.Exit, which would skip the deferred store.Close.
func runServer() error {
	cfg := config.Load()

	// Declared separately so the package-level store is assigned, not shadowed.
	var err error
	store, err = db.NewStore(cfg.DBPath, cfg.BusyTimeoutMS)
	if err != nil {
		return fmt.Errorf("opening store: %w", err)
	}
	defer store.Close()
	auditor = audit.NewLogger(store)

	// Registered on http.DefaultServeMux (nil handler below). DefaultServeMux
	// redirects paths with repeated slashes to a cleaned form; routeHandler's
	// "https:/" repair appears to compensate for that.
	http.Handle("/", auth.Middleware(store, auditAuthFailure, http.HandlerFunc(routeHandler)))

	// JoinHostPort brackets IPv6 literals ("[::1]:8080"); plain
	// "host:port" formatting would produce an unusable "::1:8080".
	addr := net.JoinHostPort(cfg.Host, cfg.Port)
	log.Printf("服务器启动成功,正在监听 %s", addr)
	if err := http.ListenAndServe(addr, nil); err != nil {
		return fmt.Errorf("serving on %s: %w", addr, err)
	}
	return nil
}
// routeHandler serves one proxy request that has passed auth.Middleware.
//
// The request path (including any query string) is the target URL:
//   - a scheme collapsed to "https:/" or "http:/" is repaired;
//   - "https://" is assumed when no http scheme is present;
//   - the URL must match a supported GitHub pattern (checkURL), otherwise 403.
//
// The white list (when non-empty) and then the black list are enforced,
// GitHub "blob" page links are rewritten to "raw" download links, and the
// result is proxied.
//
// Once auth info is present, exactly one audit entry is written (from the
// deferred function), using the status code the recorder saw.
func routeHandler(w http.ResponseWriter, r *http.Request) {
	authInfo, ok := auth.FromContext(r.Context())
	if !ok {
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
		return
	}

	recorder := newStatusRecorder(w)
	originalInput := strings.TrimPrefix(r.URL.RequestURI(), "/")
	normalizedURL := originalInput
	success := false
	errorReason := ""

	defer func() {
		statusCode := recorder.StatusCode()
		if statusCode >= 200 && statusCode < 400 {
			success = true
		}
		// Deliberately not r.Context(): net/http cancels it when the client's
		// connection closes, which could make auditor.Log fail and drop the
		// entry for an aborted download. auditAuthFailure also uses
		// context.Background().
		if err := auditor.Log(context.Background(), audit.Entry{
			RequestIP:      clientIP(r),
			UserID:         authInfo.UserID,
			HasUser:        true,
			TokenID:        authInfo.TokenID,
			HasToken:       true,
			OriginalURL:    normalizedURL,
			HTTPStatus:     statusCode,
			Success:        success,
			ErrorReason:    errorReason,
			CountAsSuccess: success,
		}); err != nil {
			log.Printf("audit log failed: %v", err)
		}
	}()

	u := originalInput
	if strings.HasPrefix(u, "https:/") && !strings.HasPrefix(u, "https://") {
		u = "https://" + strings.TrimPrefix(u, "https:/")
	} else if strings.HasPrefix(u, "http:/") && !strings.HasPrefix(u, "http://") {
		u = "http://" + strings.TrimPrefix(u, "http:/")
	}
	if !strings.HasPrefix(u, "http") {
		u = "https://" + u
	}

	m := checkURL(u)
	if m == nil {
		errorReason = "invalid_input"
		http.Error(recorder, "Invalid input.", http.StatusForbidden)
		return
	}

	if len(whiteList) > 0 {
		allowed := false
		for _, i := range whiteList {
			if matchRule(m, i) {
				allowed = true
				break
			}
		}
		if !allowed {
			errorReason = "forbidden_by_white_list"
			http.Error(recorder, "Forbidden by white list.", http.StatusForbidden)
			return
		}
	}
	for _, i := range blackList {
		if matchRule(m, i) {
			errorReason = "forbidden_by_black_list"
			http.Error(recorder, "Forbidden by black list.", http.StatusForbidden)
			return
		}
	}

	// Turn browser "blob" page links into "raw" download links.
	u = blobToRawURL(u)

	// Normalise through net/url when possible, then proxy.
	parsedURL, err := url.Parse(u)
	if err == nil {
		u = parsedURL.String()
	}
	normalizedURL = u
	proxy(recorder, r, u)
}

// blobToRawURL rewrites github.com/<author>/<repo>/blob/... to
// github.com/<author>/<repo>/raw/.... Only the path segment directly after
// the repository is replaced. A "blob" author name or a "blob" directory
// deeper in a raw file path is left alone. Any URL that is not a blob link
// is returned unchanged.
func blobToRawURL(u string) string {
	loc := exp2.FindStringSubmatchIndex(u)
	if loc == nil {
		return u
	}
	repoEnd := loc[2*exp2.SubexpIndex("repo")+1]
	rest := u[repoEnd:]
	if !strings.HasPrefix(rest, "/blob/") {
		return u
	}
	return u[:repoEnd] + "/raw/" + strings.TrimPrefix(rest, "/blob/")
}
// maxProxyRedirects caps how many redirects proxy follows server-side for a
// single client request, so an upstream redirect loop cannot recurse
// forever. 10 matches the default redirect limit of net/http's Client.
const maxProxyRedirects = 10

// proxy fetches targetURL with the client's method, headers and body, and
// streams the upstream response back to w.
//
//   - A response whose Content-Length exceeds sizeLimit is not streamed; the
//     client gets a 302 redirect to targetURL instead.
//   - A Location that is itself a supported GitHub URL (checkURL) is rewritten
//     to "/<url>", so the client's next request comes back through this proxy.
//   - Any other Location (resolved against targetURL if relative) is followed
//     server-side, up to maxProxyRedirects times.
//
// The upstream request is bound to r.Context(), so it is aborted when the
// client disconnects.
func proxy(w http.ResponseWriter, r *http.Request, targetURL string) {
	proxyFollow(w, r, targetURL, 0)
}

// proxyFollow implements proxy. redirects counts the server-side redirects
// already followed for this client request.
func proxyFollow(w http.ResponseWriter, r *http.Request, targetURL string, redirects int) {
	if redirects > maxProxyRedirects {
		http.Error(w, "too many redirects", http.StatusBadGateway)
		return
	}
	// Repair a scheme collapsed to "https:/" (e.g. by chained proxies or
	// path cleaning).
	if strings.HasPrefix(targetURL, "https:/") && !strings.HasPrefix(targetURL, "https://") {
		targetURL = "https://" + targetURL[len("https:/"):]
	}

	// NOTE(review): on a followed redirect r.Body has already been consumed
	// by the previous upstream request; harmless for GET, not for POST/PUT.
	req, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, r.Body)
	if err != nil {
		http.Error(w, "server error "+err.Error(), http.StatusInternalServerError)
		return
	}
	// Copy the client's headers except Host. (net/http already moves Host
	// out of r.Header into r.Host; the check is defensive.)
	// NOTE(review): every other client header is forwarded upstream, including
	// whatever credential auth.Middleware consumed — confirm it never reaches
	// GitHub or hosts reached via followed redirects.
	for k, vv := range r.Header {
		if strings.EqualFold(k, "Host") {
			continue
		}
		for _, v := range vv {
			req.Header.Add(k, v)
		}
	}

	resp, err := httpClient.Do(req)
	if err != nil {
		http.Error(w, "server error "+err.Error(), http.StatusInternalServerError)
		return
	}
	defer resp.Body.Close()

	// Too large to stream through the proxy: send the client upstream.
	if resp.ContentLength > sizeLimit {
		http.Redirect(w, r, targetURL, http.StatusFound)
		return
	}

	if loc := resp.Header.Get("Location"); loc != "" {
		// Location may be relative; resolve it against the URL just fetched
		// so it can be matched and, if needed, requested.
		if ref, perr := url.Parse(loc); perr == nil && !ref.IsAbs() {
			loc = req.URL.ResolveReference(ref).String()
		}
		if checkURL(loc) == nil {
			// Not a supported GitHub URL: fetch it server-side.
			proxyFollow(w, r, loc, redirects+1)
			return
		}
		resp.Header.Set("Location", "/"+loc)
	}

	// Copy upstream headers and status, then stream the body.
	for k, vv := range resp.Header {
		for _, v := range vv {
			w.Header().Add(k, v)
		}
	}
	w.WriteHeader(resp.StatusCode)
	// Headers are already sent, so a failure here can only be logged.
	if _, err := io.Copy(w, resp.Body); err != nil {
		log.Printf("proxy: streaming %s: %v", targetURL, err)
	}
}
// auditAuthFailure is the callback handed to auth.Middleware in main. Judging
// by its name it is presumably invoked when a request fails authentication
// (confirm in the auth package).
//
// It writes a failed audit entry for r with the given status and reason.
// userID / tokenID are attached only when positive, and the token argument is
// not recorded. A logging failure is reported via log but not returned.
func auditAuthFailure(r *http.Request, token string, statusCode int, reason string, userID int64, tokenID int64) {
	entry := audit.Entry{
		RequestIP:      clientIP(r),
		OriginalURL:    strings.TrimPrefix(r.URL.RequestURI(), "/"),
		HTTPStatus:     statusCode,
		ErrorReason:    reason,
		Success:        false,
		CountAsSuccess: false,
	}
	if userID > 0 {
		entry.UserID, entry.HasUser = userID, true
	}
	if tokenID > 0 {
		entry.TokenID, entry.HasToken = tokenID, true
	}

	err := auditor.Log(context.Background(), entry)
	if err != nil {
		log.Printf("auth failure audit log failed: %v", err)
	}
}
// clientIP returns a best-effort client address for audit logging. It tries,
// in order:
//  1. the first comma-separated entry of X-Forwarded-For, trimmed;
//  2. X-Real-IP, trimmed;
//  3. the host part of RemoteAddr;
//  4. RemoteAddr itself, trimmed;
//  5. "unknown".
//
// NOTE(review): both headers are taken from the request as sent, so a client
// that reaches this server directly can put any value in them. The result is
// only trustworthy behind a reverse proxy that overwrites these headers.
func clientIP(r *http.Request) string {
	if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-For")); forwarded != "" {
		first := forwarded
		if comma := strings.Index(forwarded, ","); comma >= 0 {
			first = forwarded[:comma]
		}
		return strings.TrimSpace(first)
	}
	if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); realIP != "" {
		return realIP
	}

	remote := strings.TrimSpace(r.RemoteAddr)
	if host, _, err := net.SplitHostPort(remote); err == nil && host != "" {
		return host
	}
	if remote == "" {
		return "unknown"
	}
	return remote
}
// statusRecorder wraps an http.ResponseWriter and remembers the status code
// sent to the client, so routeHandler can audit the outcome.
type statusRecorder struct {
	http.ResponseWriter
	statusCode  int
	wroteHeader bool
}

// newStatusRecorder wraps w. The recorded status starts at 200, which is
// what net/http sends when a body is written without calling WriteHeader.
func newStatusRecorder(w http.ResponseWriter) *statusRecorder {
	return &statusRecorder{ResponseWriter: w, statusCode: http.StatusOK}
}

// WriteHeader records the first status code and forwards every call to the
// wrapped writer. Later calls do not overwrite the recorded code, because the
// client already received the first one; net/http ignores and logs
// superfluous calls.
func (r *statusRecorder) WriteHeader(statusCode int) {
	if !r.wroteHeader {
		r.statusCode = statusCode
		r.wroteHeader = true
	}
	r.ResponseWriter.WriteHeader(statusCode)
}

// Write forwards body bytes. A Write before any WriteHeader implicitly sends
// 200, so the recorded status is locked in at that point.
func (r *statusRecorder) Write(b []byte) (int, error) {
	r.wroteHeader = true
	return r.ResponseWriter.Write(b)
}

// Flush forwards to the wrapped writer when it implements http.Flusher, so
// the wrapper does not hide flushing from streaming code. Flushing sends the
// header (200 if none was set), so the recorded status is locked in.
func (r *statusRecorder) Flush() {
	if f, ok := r.ResponseWriter.(http.Flusher); ok {
		r.wroteHeader = true
		f.Flush()
	}
}

// Unwrap returns the wrapped writer (used by http.ResponseController).
func (r *statusRecorder) Unwrap() http.ResponseWriter {
	return r.ResponseWriter
}

// StatusCode returns the status code sent (or implied) so far.
func (r *statusRecorder) StatusCode() int {
	return r.statusCode
}
// matchRule reports whether the names captured by checkURL match one parsed
// list rule.
//
// m is checkURL's result: [author, repo], or [author] for gist URLs. i is a
// rule from parseList, either [author] or [author, repo]:
//   - a "*" segment matches any name;
//   - names are compared case-insensitively, because GitHub resolves owners
//     and repositories case-insensitively (otherwise "User1/Repo1" would
//     slip past a "user1/repo1" black-list rule);
//   - a trailing ".git" on the repository is ignored on both sides, so clone
//     URLs such as ".../repo.git/info/refs" hit the same rule as the
//     repository itself.
//
// Rules with any other number of segments never match.
func matchRule(m []string, i []string) bool {
	// trimGit drops a trailing ".git" (any case) from a repository name.
	trimGit := func(s string) string {
		const suffix = ".git"
		if len(s) >= len(suffix) && strings.EqualFold(s[len(s)-len(suffix):], suffix) {
			return s[:len(s)-len(suffix)]
		}
		return s
	}
	// segment reports whether one captured name satisfies one rule segment.
	segment := func(got, rule string) bool {
		return rule == "*" || strings.EqualFold(got, rule)
	}

	switch len(i) {
	case 1:
		return len(m) >= 1 && segment(m[0], i[0])
	case 2:
		return len(m) >= 2 && segment(m[0], i[0]) && segment(trimGit(m[1]), trimGit(i[1]))
	default:
		return false
	}
}
// checkURL matches u against the supported URL patterns exp1..exp5, tried in
// that order. It returns the named captures of the first matching pattern, in
// group order: [author, repo] for exp1-exp4, [author] for exp5 (gists). It
// returns nil when no pattern matches.
func checkURL(u string) []string {
	for _, re := range [...]*regexp.Regexp{exp1, exp2, exp3, exp4, exp5} {
		sub := re.FindStringSubmatch(u)
		if sub == nil {
			continue
		}
		var captures []string
		names := re.SubexpNames()
		for idx := 1; idx < len(names); idx++ {
			if names[idx] != "" {
				captures = append(captures, sub[idx])
			}
		}
		return captures
	}
	return nil
}
// parseList turns a newline-separated rule list into rules split on "/".
//
// On each line, everything from the first "#" onward is a comment. This lets
// the documented "user1 # note" form be used as written. All whitespace is
// removed from each segment (GitHub names contain none), so
// "user1 / repo1  # note" yields ["user1", "repo1"]. Lines that are empty
// after this are skipped. Empty segments, e.g. from a trailing "/", are kept
// as "" and therefore never match in matchRule.
func parseList(s string) [][]string {
	var res [][]string
	for _, line := range strings.Split(s, "\n") {
		if hash := strings.IndexByte(line, '#'); hash >= 0 {
			line = line[:hash]
		}
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}
		parts := strings.Split(line, "/")
		cleaned := make([]string, 0, len(parts))
		for _, p := range parts {
			cleaned = append(cleaned, strings.Join(strings.Fields(p), ""))
		}
		res = append(res, cleaned)
	}
	return res
}