Files
Hugs-Proxy/main.go
T
gary f8051ea8c9 v0.1 公开测试版
Co-authored-by: Copilot <copilot@github.com>
2026-04-23 21:48:04 +08:00

747 lines
20 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"context"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"regexp"
"strings"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/audit"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/auth"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/config"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/db"
)
// ========================
// Configuration
// ========================

// sizeLimit is the largest upstream response, judged by its declared
// Content-Length, that the proxy streams through itself. Anything larger is
// answered with a 302 pointing at the original URL instead. 1 GiB.
const sizeLimit int64 = 1 << 30
// Access rules, one per line. The whitelist is checked first (when it has
// any rules, a target must match one of them), then the blacklist.
// Rule forms:
//
//	user1        every repository owned by user1
//	user1/repo1  only user1's repo1
//	*/repo1      every repository named repo1, whatever the owner
var whiteListStr = ``

// blackListStr uses the same rule format as whiteListStr.
var blackListStr = ``
// ========================
// Global state and precompiled patterns
// ========================

// Parsed access rules, filled in by init from whiteListStr / blackListStr.
var whiteList, blackList [][]string

// Accepted GitHub URL shapes. The scheme is optional. Each pattern captures
// the owner as <author>; all but the gist pattern also capture <repo>.
// checkURL returns the named captures in that order.
var (
	// Release downloads and source archives.
	exp1 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:releases|archive)/.*$`)
	// File views (blob) and raw file links on github.com.
	exp2 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:blob|raw)/.*$`)
	// Paths starting with "info" or "git-" after the repo (presumably git smart-HTTP traffic).
	exp3 = regexp.MustCompile(`^(?:https?://)?github\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/(?:info|git-).*$`)
	// raw.githubusercontent.com / raw.github.com files.
	exp4 = regexp.MustCompile(`^(?:https?://)?raw\.(?:githubusercontent|github)\.com/(?P<author>[^/]+)/(?P<repo>[^/]+)/.+?/.+$`)
	// Gists: only the owner is captured.
	exp5 = regexp.MustCompile(`^(?:https?://)?gist\.(?:githubusercontent|github)\.com/(?P<author>[^/]+)/.+?/.+$`)
)

// Runtime dependencies: httpClient is built in init; store and auditor in main.
var (
	httpClient *http.Client
	store      *db.Store
	auditor    *audit.Logger
)
// landingPageHTML is the page served on GET/HEAD "/". serveLandingPage
// replaces the {{GITEA_REPO_URL}} placeholder with the attribute-escaped
// repository URL from config. The download form POSTs "url" and "token" to
// /download. The curl hint is plain HTML text inside a Go raw string, so its
// quotes are written literally (no backslash escapes) and stay copy-pasteable.
const landingPageHTML = `<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Hugs-Proxy</title>
<style>
:root {
--bg1: #070a12;
--bg2: #0b1024;
--fg: rgba(255,255,255,.92);
--muted: rgba(255,255,255,.65);
--card: rgba(255,255,255,.06);
--card2: rgba(255,255,255,.10);
--stroke: rgba(255,255,255,.16);
--accent: #6ee7ff;
--accent2: #a78bfa;
}
* { box-sizing: border-box; }
html, body { height: 100%; }
body {
margin: 0;
font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, "Apple Color Emoji", "Segoe UI Emoji";
color: var(--fg);
background: radial-gradient(1200px 700px at 10% 10%, rgba(167,139,250,.18), transparent 60%),
radial-gradient(900px 600px at 90% 20%, rgba(110,231,255,.18), transparent 55%),
linear-gradient(180deg, var(--bg1), var(--bg2));
overflow-x: hidden;
}
.grid {
position: fixed;
inset: 0;
background-image:
linear-gradient(to right, rgba(255,255,255,.05) 1px, transparent 1px),
linear-gradient(to bottom, rgba(255,255,255,.05) 1px, transparent 1px);
background-size: 80px 80px;
mask-image: radial-gradient(60% 60% at 50% 30%, black 30%, transparent 75%);
opacity: .25;
pointer-events: none;
transform: translateZ(0);
animation: drift 14s linear infinite;
}
@keyframes drift {
0% { transform: translate3d(0,0,0); }
100% { transform: translate3d(-80px,-80px,0); }
}
.wrap {
max-width: 980px;
margin: 0 auto;
padding: 56px 20px 24px;
}
.hero {
display: grid;
gap: 18px;
align-items: start;
grid-template-columns: 1.15fr .85fr;
}
@media (max-width: 860px) {
.hero { grid-template-columns: 1fr; }
}
.brand {
display: flex;
gap: 12px;
align-items: center;
}
.badge {
width: 44px;
height: 44px;
border-radius: 12px;
background: linear-gradient(135deg, rgba(110,231,255,.25), rgba(167,139,250,.25));
border: 1px solid var(--stroke);
box-shadow: 0 18px 50px rgba(0,0,0,.4);
display: grid;
place-items: center;
}
.badge svg { opacity: .92; }
h1 {
margin: 0;
font-size: 42px;
letter-spacing: -0.02em;
line-height: 1.05;
}
.sub {
margin: 8px 0 0;
color: var(--muted);
font-size: 15px;
line-height: 1.6;
}
.card {
background: linear-gradient(180deg, var(--card), rgba(255,255,255,.03));
border: 1px solid var(--stroke);
border-radius: 18px;
padding: 18px;
box-shadow: 0 18px 60px rgba(0,0,0,.45);
backdrop-filter: blur(10px);
}
.card h2 {
margin: 0 0 10px;
font-size: 16px;
letter-spacing: .02em;
color: rgba(255,255,255,.9);
}
.list {
margin: 0;
padding-left: 18px;
color: var(--muted);
line-height: 1.7;
font-size: 14px;
}
label {
display: block;
font-size: 12px;
color: rgba(255,255,255,.72);
margin: 12px 0 6px;
}
input {
width: 100%;
padding: 12px 12px;
border-radius: 12px;
border: 1px solid rgba(255,255,255,.18);
background: rgba(0,0,0,.22);
color: var(--fg);
outline: none;
}
input:focus {
border-color: rgba(110,231,255,.45);
box-shadow: 0 0 0 4px rgba(110,231,255,.12);
}
.row {
display: grid;
grid-template-columns: 1fr;
gap: 12px;
}
.btn {
margin-top: 14px;
width: 100%;
padding: 12px 14px;
border-radius: 12px;
border: 1px solid rgba(255,255,255,.18);
background: linear-gradient(135deg, rgba(110,231,255,.20), rgba(167,139,250,.20));
color: rgba(255,255,255,.92);
font-weight: 600;
letter-spacing: .02em;
cursor: pointer;
transition: transform .08s ease, border-color .2s ease, background .2s ease;
}
.btn:hover { transform: translateY(-1px); border-color: rgba(255,255,255,.26); }
.btn:active { transform: translateY(0px); }
.hint {
margin-top: 10px;
color: rgba(255,255,255,.58);
font-size: 12px;
line-height: 1.6;
}
.footer {
margin-top: 26px;
display: flex;
justify-content: center;
}
.gitea {
display: inline-flex;
align-items: center;
gap: 10px;
padding: 10px 12px;
border-radius: 999px;
border: 1px solid rgba(255,255,255,.14);
background: rgba(255,255,255,.04);
text-decoration: none;
color: rgba(255,255,255,.8);
transition: background .2s ease, border-color .2s ease;
}
.gitea:hover { background: rgba(255,255,255,.06); border-color: rgba(255,255,255,.24); }
.gitea svg { opacity: .9; }
.sr { position: absolute; width: 1px; height: 1px; padding: 0; margin: -1px; overflow: hidden; clip: rect(0,0,0,0); white-space: nowrap; border: 0; }
</style>
</head>
<body>
<div class="grid" aria-hidden="true"></div>
<div class="wrap">
<div class="hero">
<div class="card">
<div class="brand">
<div class="badge" aria-hidden="true">
<svg width="22" height="22" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2Z" stroke="rgba(255,255,255,.9)" stroke-width="1.6"/>
<path d="M7.5 12h9" stroke="rgba(255,255,255,.9)" stroke-width="1.6" stroke-linecap="round"/>
<path d="M12 7.5v9" stroke="rgba(255,255,255,.9)" stroke-width="1.6" stroke-linecap="round"/>
</svg>
</div>
<div>
<h1>Hugs-Proxy</h1>
<p class="sub">轻量的学术资源加速反向代理:支持常见 GitHub/Raw/Gist 下载格式,并带 SQLite 鉴权与访问审计。</p>
</div>
</div>
<ul class="list">
<li>把目标 URL 放在路径里即可加速:<span style="color:rgba(255,255,255,.78)"><b>https://github.com/...</b></span></li>
<li>Bearer Token 鉴权与使用次数控制</li>
<li>重定向 Location 改写,尽量全程走代理</li>
<li>大文件保护(>1GB 自动 302 回源)</li>
</ul>
</div>
<div class="card">
<h2>加速下载</h2>
<form class="row" action="/download" method="post">
<div>
<label for="url">GitHub 链接</label>
<input id="url" name="url" inputmode="url" autocomplete="url" placeholder="https://github.com/OWNER/REPO/releases/download/..." required />
</div>
<div>
<label for="token">Token(Hugs-Proxy 签发)</label>
<input id="token" name="token" type="password" autocomplete="off" placeholder="粘贴 token" required />
</div>
<button class="btn" type="submit">下载(走加速代理)</button>
<p class="hint">提示:token 通过表单 POST 发送,不会出现在地址栏。也可使用 curl:<br /><code>curl -L -H "Authorization: Bearer &lt;TOKEN&gt;" "https://hugs.you/https://github.com/..."</code></p>
</form>
</div>
</div>
<div class="footer">
<a class="gitea" href="{{GITEA_REPO_URL}}" target="_blank" rel="noreferrer">
<span class="sr">Gitea Repo</span>
<svg width="18" height="18" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg" aria-hidden="true">
<path d="M8 19a4 4 0 0 1-4-4V7.8a3.8 3.8 0 0 1 3.8-3.8H16" stroke="rgba(255,255,255,.85)" stroke-width="1.6" stroke-linecap="round"/>
<path d="M10 5h7a3 3 0 0 1 3 3v11" stroke="rgba(255,255,255,.85)" stroke-width="1.6" stroke-linecap="round"/>
<path d="M10 9h10" stroke="rgba(255,255,255,.65)" stroke-width="1.6" stroke-linecap="round"/>
<path d="M10 13h10" stroke="rgba(255,255,255,.65)" stroke-width="1.6" stroke-linecap="round"/>
</svg>
<span>Gitea 项目仓库</span>
</a>
</div>
</div>
</body>
</html>
`
func init() {
// 1. 初始化列表
whiteList = parseList(whiteListStr)
blackList = parseList(blackListStr)
httpClient = &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
Timeout: 0,
}
}
// main loads configuration, opens the store, wires the HTTP routes and serves
// until ListenAndServe fails.
//
// Routes (on http.DefaultServeMux):
//   - GET/HEAD "/"                   landing page
//   - "/favicon.ico", "/webicon.png" static icon
//   - "/download"                    form download (token in the POST body)
//   - anything else                  auth.Middleware, then routeHandler
func main() {
	cfg := config.Load()

	var err error
	store, err = db.NewStore(cfg.DBPath, cfg.BusyTimeoutMS)
	if err != nil {
		log.Fatal(err)
	}
	defer store.Close()
	auditor = audit.NewLogger(store)

	protected := auth.Middleware(store, auditAuthFailure, http.HandlerFunc(routeHandler))
	root := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case r.URL.Path == "/" && (r.Method == http.MethodGet || r.Method == http.MethodHead):
			serveLandingPage(w, r, cfg)
		case r.URL.Path == "/favicon.ico" || r.URL.Path == "/webicon.png":
			serveWebIcon(w, r)
		case r.URL.Path == "/download":
			downloadHandler(w, r)
		default:
			protected.ServeHTTP(w, r)
		}
	})
	http.Handle("/", root)

	addr := fmt.Sprintf("%s:%s", cfg.Host, cfg.Port)
	log.Printf("服务器启动成功,正在监听 %s", addr)
	if err := http.ListenAndServe(addr, nil); err != nil {
		// log.Fatal ends the process via os.Exit, which does not run deferred
		// calls, so the deferred store.Close above would never execute.
		// Close the store explicitly first.
		store.Close()
		log.Fatal(err)
	}
}
// serveWebIcon answers both /favicon.ico and /webicon.png with the PNG file
// in the process's working directory.
func serveWebIcon(w http.ResponseWriter, r *http.Request) {
	const iconFile = "./webicon.png"
	http.ServeFile(w, r, iconFile)
}
// serveLandingPage writes landingPageHTML with the configured Gitea repo URL
// substituted (attribute-escaped). HEAD requests receive the same
// Content-Type and a 200 status, but no body.
func serveLandingPage(w http.ResponseWriter, r *http.Request, cfg config.Config) {
	w.Header().Set("Content-Type", "text/html; charset=utf-8")

	if r.Method == http.MethodHead {
		w.WriteHeader(http.StatusOK)
		return
	}

	repoURL := htmlAttrEscape(cfg.GiteaRepoURL)
	body := strings.ReplaceAll(landingPageHTML, "{{GITEA_REPO_URL}}", repoURL)
	_, _ = io.WriteString(w, body)
}
// downloadHandler serves the landing page's download form: POST /download
// with form fields "url" (the GitHub link) and "token".
//
// The token is validated and one use is consumed before the URL is examined.
// Token failures are audited through auditRequestFailure. Once the token is
// accepted, the outcome is audited exactly once when the handler returns, and
// the target is fetched with a GET via proxy.
func downloadHandler(w http.ResponseWriter, r *http.Request) {
	// POST only, so the token never ends up in a URL (address bar, history, access logs).
	if r.Method != http.MethodPost {
		w.Header().Set("Allow", http.MethodPost)
		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
		return
	}
	// ParseForm reads an application/x-www-form-urlencoded body; FormValue
	// below also parses multipart/form-data on demand.
	if err := r.ParseForm(); err != nil {
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}
	urlInput := strings.TrimSpace(r.FormValue("url"))
	tokenValue := strings.TrimSpace(r.FormValue("token"))
	if tokenValue == "" {
		auditRequestFailure(r, http.StatusUnauthorized, "missing_token", 0, 0, urlInput)
		http.Error(w, "Missing token.", http.StatusUnauthorized)
		return
	}

	// Validate the token and consume one use.
	tok, err := store.ValidateAndConsumeToken(r.Context(), tokenValue)
	if err != nil {
		status, reason := classifyTokenError(err)
		var userID, tokenID int64
		// A token that exists but is unusable can still be attributed to its owner in the audit row.
		if !errors.Is(err, db.ErrNotFound) {
			if t, getErr := store.GetToken(r.Context(), tokenValue); getErr == nil {
				userID = t.UserID
				tokenID = t.ID
			}
		}
		auditRequestFailure(r, status, reason, userID, tokenID, urlInput)
		http.Error(w, http.StatusText(status), status)
		return
	}

	recorder := newStatusRecorder(w)
	normalizedURL := urlInput
	errorReason := ""
	defer func() {
		statusCode := recorder.StatusCode()
		success := statusCode >= 200 && statusCode < 400
		// Use a detached context: r.Context() is cancelled when the client
		// disconnects (e.g. an aborted download). The audit write would then
		// run with a cancelled context, and if auditor.Log honours
		// cancellation that row is lost. auditRequestFailure uses
		// context.Background() for the same reason.
		if err := auditor.Log(context.Background(), audit.Entry{
			RequestIP:      clientIP(r),
			UserID:         tok.UserID,
			HasUser:        true,
			TokenID:        tok.ID,
			HasToken:       true,
			OriginalURL:    normalizedURL,
			HTTPStatus:     statusCode,
			Success:        success,
			ErrorReason:    errorReason,
			CountAsSuccess: success,
		}); err != nil {
			log.Printf("audit log failed: %v", err)
		}
	}()

	if urlInput == "" {
		errorReason = "missing_url"
		http.Error(recorder, "Missing url.", http.StatusBadRequest)
		return
	}
	u, prepErr := prepareTargetURL(urlInput)
	if prepErr != nil {
		errorReason = prepErr.Reason
		http.Error(recorder, prepErr.Message, prepErr.StatusCode)
		return
	}
	normalizedURL = u

	// The form arrives as a POST, but the download itself is always a GET.
	// Drop the headers that describe the form body, so they are not sent upstream.
	r2 := r.Clone(r.Context())
	r2.Method = http.MethodGet
	r2.Body = nil
	r2.ContentLength = 0
	r2.Header.Del("Content-Type")
	r2.Header.Del("Content-Length")
	proxy(recorder, r2, u)
}
// classifyTokenError maps a token validation error to the HTTP status and
// audit reason to report. The known db sentinel errors all produce 403, and
// anything else is treated as an internal lookup failure (500).
func classifyTokenError(err error) (int, string) {
	known := []struct {
		sentinel error
		reason   string
	}{
		{db.ErrNotFound, "invalid_token"},
		{db.ErrTokenDisabled, "token_disabled"},
		{db.ErrTokenExpired, "token_expired"},
		{db.ErrTokenExhausted, "token_exhausted"},
	}
	for _, k := range known {
		if errors.Is(err, k.sentinel) {
			return http.StatusForbidden, k.reason
		}
	}
	return http.StatusInternalServerError, "token_lookup_error"
}
// auditRequestFailure records a failed request in the audit log.
//
// originalURL is trimmed. When it is empty, the request URI without its
// leading "/" is recorded instead. userID and tokenID are attached only when
// positive. The write uses a background context, and a logging failure is
// reported through log.Printf, not returned.
func auditRequestFailure(r *http.Request, statusCode int, reason string, userID int64, tokenID int64, originalURL string) {
	target := strings.TrimSpace(originalURL)
	if len(target) == 0 {
		target = strings.TrimPrefix(r.URL.RequestURI(), "/")
	}

	var entry audit.Entry
	entry.RequestIP = clientIP(r)
	entry.OriginalURL = target
	entry.HTTPStatus = statusCode
	entry.Success = false
	entry.ErrorReason = reason
	entry.CountAsSuccess = false
	if userID > 0 {
		entry.UserID, entry.HasUser = userID, true
	}
	if tokenID > 0 {
		entry.TokenID, entry.HasToken = tokenID, true
	}

	if logErr := auditor.Log(context.Background(), entry); logErr != nil {
		log.Printf("audit failure log failed: %v", logErr)
	}
}
// attrEscaper escapes the characters that can break out of, or inject into,
// a quoted HTML attribute value. strings.Replacer does this in a single pass,
// so already-produced entities are never escaped a second time.
var attrEscaper = strings.NewReplacer(
	"&", "&amp;",
	`"`, "&quot;",
	"'", "&#39;",
	"<", "&lt;",
	">", "&gt;",
)

// htmlAttrEscape returns s made safe for use inside a quoted HTML attribute
// (either quote style). It is lightweight escaping for config-supplied URLs
// such as the Gitea repo link, and does not validate the URL scheme.
func htmlAttrEscape(s string) string {
	return attrEscaper.Replace(s)
}
// targetPrepareError explains why prepareTargetURL rejected an input. The
// handlers send Message with StatusCode to the client and store Reason as
// the audit entry's ErrorReason.
type targetPrepareError struct {
	StatusCode int    // HTTP status to answer with
	Reason     string // short machine-readable code, e.g. "missing_url"
	Message    string // human-readable response body
}
// containsDotSegment reports whether the slash-separated path p has a "."
// or ".." segment.
func containsDotSegment(p string) bool {
	for _, seg := range strings.Split(p, "/") {
		if seg == "." || seg == ".." {
			return true
		}
	}
	return false
}

// prepareTargetURL normalizes raw user input into the upstream URL to fetch,
// or explains why it must be refused.
//
// Steps:
//  1. Repair a collapsed scheme ("https:/x" → "https://x") and default to https.
//  2. Require one of the accepted GitHub URL shapes (checkURL).
//  3. Refuse paths with "." / ".." segments. The owner/repo captured by the
//     regexes would then not be the repository the upstream server resolves,
//     which defeats the list checks.
//  4. Apply the whitelist (if non-empty), then the blacklist.
//  5. Turn github.com ".../OWNER/REPO/blob/..." page links into ".../raw/..." links.
func prepareTargetURL(raw string) (string, *targetPrepareError) {
	u := strings.TrimSpace(raw)
	if u == "" {
		return "", &targetPrepareError{StatusCode: http.StatusBadRequest, Reason: "missing_url", Message: "Missing url."}
	}

	// Path cleaning (e.g. by ServeMux or an intermediate proxy) collapses "//" after the scheme.
	if strings.HasPrefix(u, "https:/") && !strings.HasPrefix(u, "https://") {
		u = "https://" + strings.TrimPrefix(u, "https:/")
	} else if strings.HasPrefix(u, "http:/") && !strings.HasPrefix(u, "http://") {
		u = "http://" + strings.TrimPrefix(u, "http:/")
	}
	if !strings.HasPrefix(u, "http") {
		u = "https://" + u
	}

	m := checkURL(u)
	if m == nil {
		return "", &targetPrepareError{StatusCode: http.StatusForbidden, Reason: "invalid_input", Message: "Invalid input."}
	}
	// url.Parse decodes the path, so percent-encoded dots are caught as well.
	if parsed, err := url.Parse(u); err == nil && containsDotSegment(parsed.Path) {
		return "", &targetPrepareError{StatusCode: http.StatusForbidden, Reason: "invalid_input", Message: "Invalid input."}
	}

	if len(whiteList) > 0 {
		allowed := false
		for _, rule := range whiteList {
			if matchRule(m, rule) {
				allowed = true
				break
			}
		}
		if !allowed {
			return "", &targetPrepareError{StatusCode: http.StatusForbidden, Reason: "forbidden_by_white_list", Message: "Forbidden by white list."}
		}
	}
	for _, rule := range blackList {
		if matchRule(m, rule) {
			return "", &targetPrepareError{StatusCode: http.StatusForbidden, Reason: "forbidden_by_black_list", Message: "Forbidden by black list."}
		}
	}

	// Rewrite only the segment right after OWNER/REPO. A plain first-match
	// replace of "/blob/" would hit an owner or repository literally named "blob".
	// In exp2, capture group 2 is <repo>, so loc[5] is the index just past it.
	if loc := exp2.FindStringSubmatchIndex(u); loc != nil {
		repoEnd := loc[5]
		if rest := u[repoEnd:]; strings.HasPrefix(rest, "/blob/") {
			u = u[:repoEnd] + "/raw/" + rest[len("/blob/"):]
		}
	}

	// Canonical re-encoding, when the URL parses.
	if parsedURL, err := url.Parse(u); err == nil {
		u = parsedURL.String()
	}
	return u, nil
}
// routeHandler serves "/<target-url>" requests that passed auth.Middleware.
// The target is taken from the raw request URI (so the query string is
// kept), validated and normalized by prepareTargetURL, then proxied. The
// outcome is audited exactly once when the handler returns.
func routeHandler(w http.ResponseWriter, r *http.Request) {
	authInfo, ok := auth.FromContext(r.Context())
	if !ok {
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
		return
	}

	recorder := newStatusRecorder(w)
	originalInput := strings.TrimPrefix(r.URL.RequestURI(), "/")
	normalizedURL := originalInput
	errorReason := ""
	defer func() {
		statusCode := recorder.StatusCode()
		success := statusCode >= 200 && statusCode < 400
		// Use a detached context: r.Context() is cancelled when the client
		// disconnects mid-transfer. The audit write would then run with a
		// cancelled context, and if auditor.Log honours cancellation, exactly
		// the aborted downloads would go unrecorded. auditRequestFailure
		// already uses context.Background().
		if err := auditor.Log(context.Background(), audit.Entry{
			RequestIP:      clientIP(r),
			UserID:         authInfo.UserID,
			HasUser:        true,
			TokenID:        authInfo.TokenID,
			HasToken:       true,
			OriginalURL:    normalizedURL,
			HTTPStatus:     statusCode,
			Success:        success,
			ErrorReason:    errorReason,
			CountAsSuccess: success,
		}); err != nil {
			log.Printf("audit log failed: %v", err)
		}
	}()

	u, prepErr := prepareTargetURL(originalInput)
	if prepErr != nil {
		errorReason = prepErr.Reason
		http.Error(recorder, prepErr.Message, prepErr.StatusCode)
		return
	}

	normalizedURL = u
	proxy(recorder, r, u)
}
// maxProxyRedirects bounds how many upstream redirects proxy follows on the
// server side, so a redirect loop ends in an error instead of unbounded recursion.
const maxProxyRedirects = 10

// hopByHopHeaders are connection-scoped header fields that a proxy must not
// forward in either direction.
var hopByHopHeaders = []string{
	"Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Proxy-Connection",
	"Te",
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}

// copyHeaders adds every header in src to dst, except hop-by-hop headers,
// headers named in src's Connection field, and the extra names in drop.
func copyHeaders(dst, src http.Header, drop ...string) {
	skip := make(map[string]bool, len(hopByHopHeaders)+len(drop))
	for _, name := range hopByHopHeaders {
		skip[http.CanonicalHeaderKey(name)] = true
	}
	for _, name := range drop {
		skip[http.CanonicalHeaderKey(name)] = true
	}
	for _, field := range src["Connection"] {
		for _, name := range strings.Split(field, ",") {
			if name = strings.TrimSpace(name); name != "" {
				skip[http.CanonicalHeaderKey(name)] = true
			}
		}
	}
	for key, values := range src {
		if skip[http.CanonicalHeaderKey(key)] {
			continue
		}
		for _, v := range values {
			dst.Add(key, v)
		}
	}
}

// proxy fetches targetURL with r's method and body and streams the upstream
// response to w.
//
//   - If the declared Content-Length exceeds sizeLimit, the client gets a 302
//     to targetURL instead of the body.
//   - A Location that checkURL accepts is normally rewritten to "/<location>",
//     so the client's next hop comes back through this proxy.
//   - Requests from the /download form carried their token in the POST body,
//     so the client could not authenticate that next hop. For them, such
//     redirects are re-validated with prepareTargetURL and followed server-side.
//   - Any other Location is followed server-side, at most maxProxyRedirects times.
func proxy(w http.ResponseWriter, r *http.Request, targetURL string) {
	proxyFollow(w, r, targetURL, 0)
}

// proxyFollow implements proxy; redirects counts the server-side hops already taken.
func proxyFollow(w http.ResponseWriter, r *http.Request, targetURL string, redirects int) {
	// Repair "https:/host/..." left behind when multiple proxies or path cleaning collapse the "//".
	if strings.HasPrefix(targetURL, "https:/") && !strings.HasPrefix(targetURL, "https://") {
		targetURL = "https://" + strings.TrimPrefix(targetURL, "https:/")
	}

	// Bind the upstream request to the client's context, so an abandoned
	// download also cancels the upstream transfer.
	req, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, r.Body)
	if err != nil {
		http.Error(w, "server error "+err.Error(), http.StatusInternalServerError)
		return
	}
	// Authorization carries this proxy's Bearer token and Cookie carries this
	// proxy's cookies. Forwarding either would leak them to the upstream host.
	copyHeaders(req.Header, r.Header, "Host", "Authorization", "Cookie")

	resp, err := httpClient.Do(req)
	if err != nil {
		http.Error(w, "server error "+err.Error(), http.StatusInternalServerError)
		return
	}
	defer resp.Body.Close()

	// Too large to relay: send the client straight to the source.
	if resp.ContentLength > sizeLimit {
		http.Redirect(w, r, targetURL, http.StatusFound)
		return
	}

	if loc := resp.Header.Get("Location"); loc != "" {
		// A relative Location ("/path", "//host/path") must be resolved
		// against the URL just fetched before it can be requested or matched.
		if ref, perr := url.Parse(loc); perr == nil && !ref.IsAbs() {
			loc = req.URL.ResolveReference(ref).String()
		}
		proxiable := checkURL(loc) != nil
		if proxiable && r.URL.Path != "/download" {
			resp.Header.Set("Location", "/"+loc)
		} else {
			if proxiable {
				// Apply the same validation and white/black lists the client's own follow-up request would get.
				next, prepErr := prepareTargetURL(loc)
				if prepErr != nil {
					http.Error(w, prepErr.Message, prepErr.StatusCode)
					return
				}
				loc = next
			}
			if redirects >= maxProxyRedirects {
				http.Error(w, "too many redirects", http.StatusBadGateway)
				return
			}
			proxyFollow(w, r, loc, redirects+1)
			return
		}
	}

	copyHeaders(w.Header(), resp.Header)
	w.WriteHeader(resp.StatusCode)
	// io.Copy streams the body. The status has already been sent, so a copy
	// error (typically the client going away) cannot be reported and is ignored.
	_, _ = io.Copy(w, resp.Body)
}
// auditAuthFailure is the failure callback handed to auth.Middleware. The
// token argument is not passed on, so the raw secret is not written to the
// audit log through this path. No original URL is given either, so
// auditRequestFailure falls back to the request URI.
func auditAuthFailure(r *http.Request, token string, statusCode int, reason string, userID int64, tokenID int64) {
	const noOriginalURL = ""
	auditRequestFailure(r, statusCode, reason, userID, tokenID, noOriginalURL)
}
// clientIP returns the best-effort client address for audit records. It
// checks, in order:
//  1. the first entry of X-Forwarded-For (trimmed, possibly empty);
//  2. X-Real-IP;
//  3. the host part of RemoteAddr;
//  4. RemoteAddr as-is;
//  5. "unknown".
//
// NOTE(review): both forwarding headers are client-controlled unless a
// trusted reverse proxy overwrites them, so direct clients can spoof the
// recorded IP.
func clientIP(r *http.Request) string {
	if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-For")); forwarded != "" {
		if comma := strings.IndexByte(forwarded, ','); comma >= 0 {
			forwarded = forwarded[:comma]
		}
		return strings.TrimSpace(forwarded)
	}
	if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); realIP != "" {
		return realIP
	}

	remote := strings.TrimSpace(r.RemoteAddr)
	if host, _, splitErr := net.SplitHostPort(remote); splitErr == nil && host != "" {
		return host
	}
	if remote == "" {
		return "unknown"
	}
	return remote
}
// statusRecorder wraps an http.ResponseWriter and remembers the status code
// actually sent to the client, for audit logging.
type statusRecorder struct {
	http.ResponseWriter
	statusCode  int
	wroteHeader bool
}

// newStatusRecorder wraps w. The status defaults to 200, which is what
// net/http sends when a handler writes a body without calling WriteHeader.
func newStatusRecorder(w http.ResponseWriter) *statusRecorder {
	return &statusRecorder{ResponseWriter: w, statusCode: http.StatusOK}
}

// WriteHeader records only the first status. Later (superfluous) calls are
// not sent to the client by net/http, so recording them would misreport the
// outcome.
func (r *statusRecorder) WriteHeader(statusCode int) {
	if !r.wroteHeader {
		r.statusCode = statusCode
		r.wroteHeader = true
	}
	r.ResponseWriter.WriteHeader(statusCode)
}

// Write marks the header as sent (an implicit 200) before delegating.
func (r *statusRecorder) Write(b []byte) (int, error) {
	r.wroteHeader = true
	return r.ResponseWriter.Write(b)
}

// Flush forwards to the wrapped writer when it supports http.Flusher, so
// wrapping does not hide streaming support.
func (r *statusRecorder) Flush() {
	if f, ok := r.ResponseWriter.(http.Flusher); ok {
		f.Flush()
	}
}

// Unwrap exposes the wrapped writer to http.ResponseController.
func (r *statusRecorder) Unwrap() http.ResponseWriter {
	return r.ResponseWriter
}

// StatusCode returns the status sent (or to be sent) to the client.
func (r *statusRecorder) StatusCode() int {
	return r.statusCode
}
// matchRule reports whether the captures m from checkURL ([owner, repo], or
// just [owner] for gists) match the list rule i produced by parseList:
//
//	[owner]        any repository of owner
//	[owner, repo]  that exact repository
//	[*, repo]      a repository named repo under any owner
//
// Rules of any other length never match. Names are compared
// case-insensitively: GitHub resolves owner/repo names regardless of case,
// so an exact-case compare would let "User1/..." slip past a "user1" rule.
func matchRule(m []string, i []string) bool {
	switch len(i) {
	case 1:
		return len(m) >= 1 && strings.EqualFold(m[0], i[0])
	case 2:
		if len(m) < 2 || !strings.EqualFold(m[1], i[1]) {
			return false
		}
		return i[0] == "*" || strings.EqualFold(m[0], i[0])
	default:
		return false
	}
}
// checkURL tries the accepted GitHub URL patterns in order (exp1..exp5). For
// the first one that matches u, it returns that pattern's named captures
// (author, then repo where present). It returns nil when no pattern matches.
func checkURL(u string) []string {
	for _, re := range [...]*regexp.Regexp{exp1, exp2, exp3, exp4, exp5} {
		sub := re.FindStringSubmatch(u)
		if sub == nil {
			continue
		}
		var named []string
		names := re.SubexpNames()
		for idx := 1; idx < len(names); idx++ {
			if names[idx] != "" {
				named = append(named, sub[idx])
			}
		}
		return named
	}
	return nil
}
// parseList turns the multi-line rule text used by whiteListStr/blackListStr
// into rules for matchRule. Each non-blank line becomes one rule, split on
// "/".
//
// A trailing "# ..." annotation is ignored, as in the documented examples
// ("user1 # ..."). GitHub names cannot contain '#'. Spaces are removed from
// every part. Lines that are blank after this are skipped, and no rules
// yields nil.
func parseList(s string) [][]string {
	var rules [][]string
	for _, line := range strings.Split(s, "\n") {
		if hash := strings.IndexByte(line, '#'); hash >= 0 {
			line = line[:hash]
		}
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}
		parts := strings.Split(line, "/")
		rule := make([]string, 0, len(parts))
		for _, p := range parts {
			rule = append(rule, strings.ReplaceAll(p, " ", ""))
		}
		rules = append(rules, rule)
	}
	return rules
}