初步实现鉴权

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
2026-04-23 20:21:35 +08:00
parent 847ce0c6a8
commit 62e076111f
10 changed files with 1212 additions and 87 deletions
+136 -22
View File
@@ -1,22 +1,27 @@
package main
import (
"context"
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"regexp"
"strings"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/audit"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/auth"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/config"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/db"
)
// ========================
// 配置区域
// ========================
const (
sizeLimit = int64(1024 * 1024 * 1024 * 1) // 允许的文件大小, 1GB
host = "127.0.0.1"
port = "2005"
sizeLimit = int64(1024 * 1024 * 1024 * 1) // 允许的文件大小, 1GB
)
// 先生效白名单再匹配黑名单
@@ -29,7 +34,6 @@ var (
blackListStr = ``
)
// ========================
// 全局变量与预编译正则
// ========================
@@ -37,7 +41,6 @@ var (
whiteList [][]string
blackList [][]string
exp1 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:releases|archive)/.*$`)
exp2 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:blob|raw)/.*$`)
exp3 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:info|git-).*$`)
@@ -45,6 +48,8 @@ var (
exp5 = regexp.MustCompile(`^(?:https?://)?gist\.(?:githubusercontent|github)\.com/(?P<author>[^/]+)/.+?/.+$`)
httpClient *http.Client
store *db.Store
auditor *audit.Logger
)
func init() {
@@ -61,8 +66,22 @@ func init() {
}
func main() {
http.HandleFunc("/", routeHandler)
addr := fmt.Sprintf("%s:%s", host, port)
cfg := config.Load()
var err error
store, err = db.NewStore(cfg.DBPath, cfg.BusyTimeoutMS)
if err != nil {
log.Fatal(err)
}
defer store.Close()
auditor = audit.NewLogger(store)
base := http.HandlerFunc(routeHandler)
protected := auth.Middleware(store, auditAuthFailure, base)
http.Handle("/", protected)
addr := fmt.Sprintf("%s:%s", cfg.Host, cfg.Port)
log.Printf("服务器启动成功,正在监听 %s", addr)
if err := http.ListenAndServe(addr, nil); err != nil {
log.Fatal(err)
@@ -70,8 +89,40 @@ func main() {
}
func routeHandler(w http.ResponseWriter, r *http.Request) {
authInfo, ok := auth.FromContext(r.Context())
if !ok {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
u := strings.TrimPrefix(r.URL.RequestURI(), "/")
recorder := newStatusRecorder(w)
originalInput := strings.TrimPrefix(r.URL.RequestURI(), "/")
normalizedURL := originalInput
success := false
errorReason := ""
defer func() {
statusCode := recorder.StatusCode()
if statusCode >= 200 && statusCode < 400 {
success = true
}
if err := auditor.Log(r.Context(), audit.Entry{
RequestIP: clientIP(r),
UserID: authInfo.UserID,
HasUser: true,
TokenID: authInfo.TokenID,
HasToken: true,
OriginalURL: normalizedURL,
HTTPStatus: statusCode,
Success: success,
ErrorReason: errorReason,
CountAsSuccess: success,
}); err != nil {
log.Printf("audit log failed: %v", err)
}
}()
u := originalInput
if strings.HasPrefix(u, "https:/") && !strings.HasPrefix(u, "https://") {
u = "https://" + strings.TrimPrefix(u, "https:/")
} else if strings.HasPrefix(u, "http:/") && !strings.HasPrefix(u, "http://") {
@@ -84,8 +135,9 @@ func routeHandler(w http.ResponseWriter, r *http.Request) {
m := checkURL(u)
if m == nil {
http.Error(w, "Invalid input.", http.StatusForbidden)
return
errorReason = "invalid_input"
http.Error(recorder, "Invalid input.", http.StatusForbidden)
return
}
if len(whiteList) > 0 {
@@ -97,14 +149,16 @@ func routeHandler(w http.ResponseWriter, r *http.Request) {
}
}
if !allowed {
http.Error(w, "Forbidden by white list.", http.StatusForbidden)
return
errorReason = "forbidden_by_white_list"
http.Error(recorder, "Forbidden by white list.", http.StatusForbidden)
return
}
}
for _, i := range blackList {
if matchRule(m, i) {
http.Error(w, "Forbidden by black list.", http.StatusForbidden)
errorReason = "forbidden_by_black_list"
http.Error(recorder, "Forbidden by black list.", http.StatusForbidden)
return
}
}
@@ -119,18 +173,19 @@ func routeHandler(w http.ResponseWriter, r *http.Request) {
if err == nil {
u = parsedURL.String()
}
proxy(w, r, u)
normalizedURL = u
proxy(recorder, r, u)
}
func proxy(w http.ResponseWriter, r *http.Request, targetURL string) {
// 修正由于多重代理可能造成的 URL 格式错误
if strings.HasPrefix(targetURL, "https:/") && ! strings.HasPrefix(targetURL, "https://") {
if strings.HasPrefix(targetURL, "https:/") && !strings.HasPrefix(targetURL, "https://") {
targetURL = "https://" + targetURL[7:]
}
req, err := http.NewRequest(r.Method, targetURL, r.Body)
if err != nil {
http.Error(w, "server error " + err.Error(), http.StatusInternalServerError)
http.Error(w, "server error "+err.Error(), http.StatusInternalServerError)
return
}
@@ -146,7 +201,7 @@ func proxy(w http.ResponseWriter, r *http.Request, targetURL string) {
resp, err := httpClient.Do(req)
if err != nil {
http.Error(w, "server error " + err.Error(), http.StatusInternalServerError)
http.Error(w, "server error "+err.Error(), http.StatusInternalServerError)
return
}
defer resp.Body.Close()
@@ -180,6 +235,66 @@ func proxy(w http.ResponseWriter, r *http.Request, targetURL string) {
io.Copy(w, resp.Body)
}
func auditAuthFailure(r *http.Request, token string, statusCode int, reason string, userID int64, tokenID int64) {
entry := audit.Entry{
RequestIP: clientIP(r),
OriginalURL: strings.TrimPrefix(r.URL.RequestURI(), "/"),
HTTPStatus: statusCode,
Success: false,
ErrorReason: reason,
CountAsSuccess: false,
}
if userID > 0 {
entry.UserID = userID
entry.HasUser = true
}
if tokenID > 0 {
entry.TokenID = tokenID
entry.HasToken = true
}
if err := auditor.Log(context.Background(), entry); err != nil {
log.Printf("auth failure audit log failed: %v", err)
}
}
func clientIP(r *http.Request) string {
if xff := strings.TrimSpace(r.Header.Get("X-Forwarded-For")); xff != "" {
parts := strings.Split(xff, ",")
if len(parts) > 0 {
return strings.TrimSpace(parts[0])
}
}
if xrip := strings.TrimSpace(r.Header.Get("X-Real-IP")); xrip != "" {
return xrip
}
host, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr))
if err == nil && host != "" {
return host
}
if strings.TrimSpace(r.RemoteAddr) != "" {
return strings.TrimSpace(r.RemoteAddr)
}
return "unknown"
}
type statusRecorder struct {
http.ResponseWriter
statusCode int
}
func newStatusRecorder(w http.ResponseWriter) *statusRecorder {
return &statusRecorder{ResponseWriter: w, statusCode: http.StatusOK}
}
func (r *statusRecorder) WriteHeader(statusCode int) {
r.statusCode = statusCode
r.ResponseWriter.WriteHeader(statusCode)
}
func (r *statusRecorder) StatusCode() int {
return r.statusCode
}
func matchRule(m []string, i []string) bool {
// m 通常为 [author, repo] 或 [author]
if len(i) == 1 {
@@ -194,7 +309,7 @@ func matchRule(m []string, i []string) bool {
}
func checkURL(u string) []string {
exps := []*regexp.Regexp {exp1, exp2, exp3, exp4, exp5}
exps := []*regexp.Regexp{exp1, exp2, exp3, exp4, exp5}
for _, exp := range exps {
matches := exp.FindStringSubmatch(u)
if matches != nil {
@@ -206,16 +321,15 @@ func checkURL(u string) []string {
}
return result
}
}
}
return nil
}
func parseList(s string) [][]string {
var res [][]string
lines := strings.Split(s, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
line = strings.TrimSpace(line)
if line != "" {
parts := strings.Split(line, "/")
var cleaned []string
@@ -226,4 +340,4 @@ func parseList(s string) [][]string {
}
}
return res
}
}