@@ -1,22 +1,27 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gitea.gangary.cn/gary/Hugs-Proxy/internal/audit"
|
||||
"gitea.gangary.cn/gary/Hugs-Proxy/internal/auth"
|
||||
"gitea.gangary.cn/gary/Hugs-Proxy/internal/config"
|
||||
"gitea.gangary.cn/gary/Hugs-Proxy/internal/db"
|
||||
)
|
||||
|
||||
// ========================
|
||||
// 配置区域
|
||||
// ========================
|
||||
const (
|
||||
sizeLimit = int64(1024 * 1024 * 1024 * 1) // 允许的文件大小, 1GB
|
||||
host = "127.0.0.1"
|
||||
port = "2005"
|
||||
sizeLimit = int64(1024 * 1024 * 1024 * 1) // 允许的文件大小, 1GB
|
||||
)
|
||||
|
||||
// 先生效白名单再匹配黑名单
|
||||
@@ -29,7 +34,6 @@ var (
|
||||
blackListStr = ``
|
||||
)
|
||||
|
||||
|
||||
// ========================
|
||||
// 全局变量与预编译正则
|
||||
// ========================
|
||||
@@ -37,7 +41,6 @@ var (
|
||||
whiteList [][]string
|
||||
blackList [][]string
|
||||
|
||||
|
||||
exp1 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:releases|archive)/.*$`)
|
||||
exp2 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:blob|raw)/.*$`)
|
||||
exp3 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:info|git-).*$`)
|
||||
@@ -45,6 +48,8 @@ var (
|
||||
exp5 = regexp.MustCompile(`^(?:https?://)?gist\.(?:githubusercontent|github)\.com/(?P<author>[^/]+)/.+?/.+$`)
|
||||
|
||||
httpClient *http.Client
|
||||
store *db.Store
|
||||
auditor *audit.Logger
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -61,8 +66,22 @@ func init() {
|
||||
}
|
||||
|
||||
func main() {
|
||||
http.HandleFunc("/", routeHandler)
|
||||
addr := fmt.Sprintf("%s:%s", host, port)
|
||||
cfg := config.Load()
|
||||
|
||||
var err error
|
||||
store, err = db.NewStore(cfg.DBPath, cfg.BusyTimeoutMS)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
auditor = audit.NewLogger(store)
|
||||
|
||||
base := http.HandlerFunc(routeHandler)
|
||||
protected := auth.Middleware(store, auditAuthFailure, base)
|
||||
http.Handle("/", protected)
|
||||
|
||||
addr := fmt.Sprintf("%s:%s", cfg.Host, cfg.Port)
|
||||
log.Printf("服务器启动成功,正在监听 %s", addr)
|
||||
if err := http.ListenAndServe(addr, nil); err != nil {
|
||||
log.Fatal(err)
|
||||
@@ -70,8 +89,40 @@ func main() {
|
||||
}
|
||||
|
||||
func routeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
authInfo, ok := auth.FromContext(r.Context())
|
||||
if !ok {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
u := strings.TrimPrefix(r.URL.RequestURI(), "/")
|
||||
recorder := newStatusRecorder(w)
|
||||
originalInput := strings.TrimPrefix(r.URL.RequestURI(), "/")
|
||||
normalizedURL := originalInput
|
||||
success := false
|
||||
errorReason := ""
|
||||
|
||||
defer func() {
|
||||
statusCode := recorder.StatusCode()
|
||||
if statusCode >= 200 && statusCode < 400 {
|
||||
success = true
|
||||
}
|
||||
if err := auditor.Log(r.Context(), audit.Entry{
|
||||
RequestIP: clientIP(r),
|
||||
UserID: authInfo.UserID,
|
||||
HasUser: true,
|
||||
TokenID: authInfo.TokenID,
|
||||
HasToken: true,
|
||||
OriginalURL: normalizedURL,
|
||||
HTTPStatus: statusCode,
|
||||
Success: success,
|
||||
ErrorReason: errorReason,
|
||||
CountAsSuccess: success,
|
||||
}); err != nil {
|
||||
log.Printf("audit log failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
u := originalInput
|
||||
if strings.HasPrefix(u, "https:/") && !strings.HasPrefix(u, "https://") {
|
||||
u = "https://" + strings.TrimPrefix(u, "https:/")
|
||||
} else if strings.HasPrefix(u, "http:/") && !strings.HasPrefix(u, "http://") {
|
||||
@@ -84,8 +135,9 @@ func routeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
m := checkURL(u)
|
||||
if m == nil {
|
||||
http.Error(w, "Invalid input.", http.StatusForbidden)
|
||||
return
|
||||
errorReason = "invalid_input"
|
||||
http.Error(recorder, "Invalid input.", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if len(whiteList) > 0 {
|
||||
@@ -97,14 +149,16 @@ func routeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
http.Error(w, "Forbidden by white list.", http.StatusForbidden)
|
||||
return
|
||||
errorReason = "forbidden_by_white_list"
|
||||
http.Error(recorder, "Forbidden by white list.", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, i := range blackList {
|
||||
if matchRule(m, i) {
|
||||
http.Error(w, "Forbidden by black list.", http.StatusForbidden)
|
||||
errorReason = "forbidden_by_black_list"
|
||||
http.Error(recorder, "Forbidden by black list.", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -119,18 +173,19 @@ func routeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if err == nil {
|
||||
u = parsedURL.String()
|
||||
}
|
||||
proxy(w, r, u)
|
||||
normalizedURL = u
|
||||
proxy(recorder, r, u)
|
||||
}
|
||||
|
||||
func proxy(w http.ResponseWriter, r *http.Request, targetURL string) {
|
||||
// 修正由于多重代理可能造成的 URL 格式错误
|
||||
if strings.HasPrefix(targetURL, "https:/") && ! strings.HasPrefix(targetURL, "https://") {
|
||||
if strings.HasPrefix(targetURL, "https:/") && !strings.HasPrefix(targetURL, "https://") {
|
||||
targetURL = "https://" + targetURL[7:]
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(r.Method, targetURL, r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "server error " + err.Error(), http.StatusInternalServerError)
|
||||
http.Error(w, "server error "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -146,7 +201,7 @@ func proxy(w http.ResponseWriter, r *http.Request, targetURL string) {
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
http.Error(w, "server error " + err.Error(), http.StatusInternalServerError)
|
||||
http.Error(w, "server error "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
@@ -180,6 +235,66 @@ func proxy(w http.ResponseWriter, r *http.Request, targetURL string) {
|
||||
io.Copy(w, resp.Body)
|
||||
}
|
||||
|
||||
func auditAuthFailure(r *http.Request, token string, statusCode int, reason string, userID int64, tokenID int64) {
|
||||
entry := audit.Entry{
|
||||
RequestIP: clientIP(r),
|
||||
OriginalURL: strings.TrimPrefix(r.URL.RequestURI(), "/"),
|
||||
HTTPStatus: statusCode,
|
||||
Success: false,
|
||||
ErrorReason: reason,
|
||||
CountAsSuccess: false,
|
||||
}
|
||||
if userID > 0 {
|
||||
entry.UserID = userID
|
||||
entry.HasUser = true
|
||||
}
|
||||
if tokenID > 0 {
|
||||
entry.TokenID = tokenID
|
||||
entry.HasToken = true
|
||||
}
|
||||
if err := auditor.Log(context.Background(), entry); err != nil {
|
||||
log.Printf("auth failure audit log failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func clientIP(r *http.Request) string {
|
||||
if xff := strings.TrimSpace(r.Header.Get("X-Forwarded-For")); xff != "" {
|
||||
parts := strings.Split(xff, ",")
|
||||
if len(parts) > 0 {
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
}
|
||||
if xrip := strings.TrimSpace(r.Header.Get("X-Real-IP")); xrip != "" {
|
||||
return xrip
|
||||
}
|
||||
host, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr))
|
||||
if err == nil && host != "" {
|
||||
return host
|
||||
}
|
||||
if strings.TrimSpace(r.RemoteAddr) != "" {
|
||||
return strings.TrimSpace(r.RemoteAddr)
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
type statusRecorder struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func newStatusRecorder(w http.ResponseWriter) *statusRecorder {
|
||||
return &statusRecorder{ResponseWriter: w, statusCode: http.StatusOK}
|
||||
}
|
||||
|
||||
func (r *statusRecorder) WriteHeader(statusCode int) {
|
||||
r.statusCode = statusCode
|
||||
r.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (r *statusRecorder) StatusCode() int {
|
||||
return r.statusCode
|
||||
}
|
||||
|
||||
func matchRule(m []string, i []string) bool {
|
||||
// m 通常为 [author, repo] 或 [author]
|
||||
if len(i) == 1 {
|
||||
@@ -194,7 +309,7 @@ func matchRule(m []string, i []string) bool {
|
||||
}
|
||||
|
||||
func checkURL(u string) []string {
|
||||
exps := []*regexp.Regexp {exp1, exp2, exp3, exp4, exp5}
|
||||
exps := []*regexp.Regexp{exp1, exp2, exp3, exp4, exp5}
|
||||
for _, exp := range exps {
|
||||
matches := exp.FindStringSubmatch(u)
|
||||
if matches != nil {
|
||||
@@ -206,16 +321,15 @@ func checkURL(u string) []string {
|
||||
}
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
func parseList(s string) [][]string {
|
||||
var res [][]string
|
||||
lines := strings.Split(s, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" {
|
||||
parts := strings.Split(line, "/")
|
||||
var cleaned []string
|
||||
@@ -226,4 +340,4 @@ func parseList(s string) [][]string {
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user