package auth import ( "context" "errors" "net/http" "strings" "gitea.gangary.cn/gary/Hugs-Proxy/internal/db" ) type contextKey string const authInfoKey contextKey = "authInfo" type AuthInfo struct { UserID int64 TokenID int64 Token string } type FailureRecorder func(r *http.Request, token string, statusCode int, reason string, userID int64, tokenID int64) func Middleware(store *db.Store, onFailure FailureRecorder, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tokenValue, ok := bearerToken(r.Header.Get("Authorization")) if !ok { http.Error(w, "Missing bearer token.", http.StatusUnauthorized) if onFailure != nil { onFailure(r, "", http.StatusUnauthorized, "missing_token", 0, 0) } return } token, err := store.ValidateAndConsumeToken(r.Context(), tokenValue) if err != nil { status := http.StatusForbidden reason := "invalid_token" switch { case errors.Is(err, db.ErrNotFound): reason = "invalid_token" case errors.Is(err, db.ErrTokenDisabled): reason = "token_disabled" case errors.Is(err, db.ErrTokenExpired): reason = "token_expired" case errors.Is(err, db.ErrTokenExhausted): reason = "token_exhausted" default: status = http.StatusInternalServerError reason = "token_lookup_error" } http.Error(w, http.StatusText(status), status) if onFailure != nil { if token != nil { onFailure(r, tokenValue, status, reason, token.UserID, token.ID) } else { onFailure(r, tokenValue, status, reason, 0, 0) } } return } authInfo := AuthInfo{UserID: token.UserID, TokenID: token.ID, Token: tokenValue} ctx := context.WithValue(r.Context(), authInfoKey, authInfo) next.ServeHTTP(w, r.WithContext(ctx)) }) } func FromContext(ctx context.Context) (AuthInfo, bool) { v := ctx.Value(authInfoKey) if v == nil { return AuthInfo{}, false } info, ok := v.(AuthInfo) return info, ok } func bearerToken(authorizationHeader string) (string, bool) { parts := strings.Fields(strings.TrimSpace(authorizationHeader)) if len(parts) != 2 { return "", false } if !strings.EqualFold(parts[0], "Bearer") { return "", false } if parts[1] == "" { return "", false } return parts[1], true }