Files
Hugs-Proxy/internal/auth/middleware.go
T
gary 62e076111f 初步实现鉴权
Co-authored-by: Copilot <copilot@github.com>
2026-04-23 20:21:35 +08:00

91 lines
2.2 KiB
Go

package auth
import (
"context"
"errors"
"net/http"
"strings"
"gitea.gangary.cn/gary/Hugs-Proxy/internal/db"
)
// contextKey is an unexported key type for values this package stores in a
// context.Context. Being a distinct defined type, it cannot collide with keys
// of plain string type or with keys defined by other packages.
type contextKey string

// authInfoKey is the context key under which Middleware stores the caller's
// AuthInfo; read it back with FromContext.
const authInfoKey contextKey = "authInfo"
// AuthInfo identifies the authenticated caller of a request. Middleware
// attaches it to the request context after the store accepts the bearer
// token, and handlers retrieve it with FromContext.
type AuthInfo struct {
	// UserID is the user ID recorded on the accepted token by the store
	// (presumably the token's owner — confirm against db.Store).
	UserID int64

	// TokenID is the store's ID of the accepted token.
	TokenID int64

	// Token is the raw bearer token value taken from the Authorization
	// header. It is a credential; avoid logging or echoing it.
	Token string
}
// FailureRecorder is invoked synchronously by Middleware after it has rejected
// a request and already written the error response. Passing a nil
// FailureRecorder to Middleware disables reporting.
//
//   - r is the rejected request.
//   - token is the raw bearer token from the Authorization header, or "" when
//     no usable token was present. It is a credential; implementations should
//     not log it verbatim.
//   - statusCode is the HTTP status that was written to the client.
//   - reason is a short machine-readable code. Middleware emits
//     "missing_token", "invalid_token", "token_disabled", "token_expired",
//     "token_exhausted" or "token_lookup_error".
//   - userID and tokenID come from the token returned by the store alongside
//     the error, when there is one; otherwise both are 0.
type FailureRecorder func(r *http.Request, token string, statusCode int, reason string, userID int64, tokenID int64)
// Middleware wraps next with bearer-token authentication backed by store.
//
// For each request it extracts the token from the Authorization header (see
// bearerToken) and passes it to store.ValidateAndConsumeToken (which, per its
// name, presumably also consumes one use of the token — see db.Store). On
// success the caller's identity is attached to the request context as an
// AuthInfo, retrievable with FromContext, and next is invoked. On failure an
// error response is written and, if onFailure is non-nil, the rejection is
// reported to it synchronously:
//
//   - no usable bearer token: 401 Unauthorized, reason "missing_token"
//     (with a "WWW-Authenticate: Bearer" challenge);
//   - db.ErrNotFound:         403 Forbidden, reason "invalid_token";
//   - db.ErrTokenDisabled:    403 Forbidden, reason "token_disabled";
//   - db.ErrTokenExpired:     403 Forbidden, reason "token_expired";
//   - db.ErrTokenExhausted:   403 Forbidden, reason "token_exhausted";
//   - any other error:        500 Internal Server Error, reason "token_lookup_error".
func Middleware(store *db.Store, onFailure FailureRecorder, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		tokenValue, ok := bearerToken(r.Header.Get("Authorization"))
		if !ok {
			// RFC 6750 §3: a 401 for a protected resource must carry a
			// WWW-Authenticate challenge. No error code is included because the
			// request presented no credentials (§3.1).
			w.Header().Set("WWW-Authenticate", "Bearer")
			http.Error(w, "Missing bearer token.", http.StatusUnauthorized)
			if onFailure != nil {
				onFailure(r, "", http.StatusUnauthorized, "missing_token", 0, 0)
			}
			return
		}

		token, err := store.ValidateAndConsumeToken(r.Context(), tokenValue)
		if err != nil {
			status, reason := tokenErrorStatus(err)
			http.Error(w, http.StatusText(status), status)
			if onFailure != nil {
				// The store may return the matching token alongside the error;
				// when it does, attribute the failure to that token and user.
				var userID, tokenID int64
				if token != nil {
					userID, tokenID = token.UserID, token.ID
				}
				onFailure(r, tokenValue, status, reason, userID, tokenID)
			}
			return
		}

		authInfo := AuthInfo{UserID: token.UserID, TokenID: token.ID, Token: tokenValue}
		ctx := context.WithValue(r.Context(), authInfoKey, authInfo)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// tokenErrorStatus maps an error from db.Store.ValidateAndConsumeToken to the
// HTTP status written to the client and the reason passed to FailureRecorder.
// Recognized token-state errors yield 403 Forbidden; anything else is treated
// as an internal lookup failure (500).
func tokenErrorStatus(err error) (status int, reason string) {
	switch {
	case errors.Is(err, db.ErrNotFound):
		return http.StatusForbidden, "invalid_token"
	case errors.Is(err, db.ErrTokenDisabled):
		return http.StatusForbidden, "token_disabled"
	case errors.Is(err, db.ErrTokenExpired):
		return http.StatusForbidden, "token_expired"
	case errors.Is(err, db.ErrTokenExhausted):
		return http.StatusForbidden, "token_exhausted"
	default:
		return http.StatusInternalServerError, "token_lookup_error"
	}
}
// FromContext returns the AuthInfo that Middleware stored in ctx. The boolean
// result is false, and the AuthInfo is the zero value, when ctx carries nothing
// under the package's auth key or the stored value is not an AuthInfo.
func FromContext(ctx context.Context) (AuthInfo, bool) {
	// A comma-ok assertion on a missing (nil) value simply reports false, so
	// no separate nil check is needed.
	if info, found := ctx.Value(authInfoKey).(AuthInfo); found {
		return info, true
	}
	return AuthInfo{}, false
}
// bearerToken extracts the credential from an Authorization header of the form
// "Bearer <token>". The scheme name is matched case-insensitively, and the
// scheme and token may be separated (and surrounded) by any run of whitespace.
// It returns ("", false) unless the header consists of exactly two
// whitespace-separated fields, the first of which is "Bearer".
func bearerToken(authorizationHeader string) (string, bool) {
	// strings.Fields already ignores leading/trailing whitespace and never
	// yields empty fields, so no separate trim or empty-token check is needed.
	fields := strings.Fields(authorizationHeader)
	if len(fields) != 2 || !strings.EqualFold(fields[0], "Bearer") {
		return "", false
	}
	return fields[1], true
}