62e076111f
Co-authored-by: Copilot <copilot@github.com>
91 lines
2.2 KiB
Go
91 lines
2.2 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"gitea.gangary.cn/gary/Hugs-Proxy/internal/db"
|
|
)
|
|
|
|
// contextKey is a package-private key type for context values. Because
// context keys compare by both type and value, keys of this type cannot
// collide with string keys (or keys of other types) set by other packages.
type contextKey string

// authInfoKey is the context key under which Middleware stores the AuthInfo
// of an authenticated request.
const authInfoKey contextKey = "authInfo"

// AuthInfo describes the caller behind an authenticated request.
type AuthInfo struct {
	// UserID and TokenID are copied from the token record returned by the
	// token store (its UserID and ID fields respectively).
	UserID, TokenID int64
	// Token is the raw bearer token value presented by the client.
	Token string
}
|
|
|
|
// FailureRecorder is called by Middleware after it has written a rejection
// response. Its arguments are:
//   - r: the rejected request;
//   - token: the presented bearer token, or "" when none could be parsed;
//   - statusCode: the HTTP status sent to the client;
//   - reason: a short code such as "missing_token" or "token_expired";
//   - userID, tokenID: 0 unless a token record was available to attribute
//     the failure to.
type FailureRecorder func(r *http.Request, token string, statusCode int, reason string, userID, tokenID int64)
|
|
|
|
func Middleware(store *db.Store, onFailure FailureRecorder, next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
tokenValue, ok := bearerToken(r.Header.Get("Authorization"))
|
|
if !ok {
|
|
http.Error(w, "Missing bearer token.", http.StatusUnauthorized)
|
|
if onFailure != nil {
|
|
onFailure(r, "", http.StatusUnauthorized, "missing_token", 0, 0)
|
|
}
|
|
return
|
|
}
|
|
|
|
token, err := store.ValidateAndConsumeToken(r.Context(), tokenValue)
|
|
if err != nil {
|
|
status := http.StatusForbidden
|
|
reason := "invalid_token"
|
|
switch {
|
|
case errors.Is(err, db.ErrNotFound):
|
|
reason = "invalid_token"
|
|
case errors.Is(err, db.ErrTokenDisabled):
|
|
reason = "token_disabled"
|
|
case errors.Is(err, db.ErrTokenExpired):
|
|
reason = "token_expired"
|
|
case errors.Is(err, db.ErrTokenExhausted):
|
|
reason = "token_exhausted"
|
|
default:
|
|
status = http.StatusInternalServerError
|
|
reason = "token_lookup_error"
|
|
}
|
|
http.Error(w, http.StatusText(status), status)
|
|
if onFailure != nil {
|
|
if token != nil {
|
|
onFailure(r, tokenValue, status, reason, token.UserID, token.ID)
|
|
} else {
|
|
onFailure(r, tokenValue, status, reason, 0, 0)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
authInfo := AuthInfo{UserID: token.UserID, TokenID: token.ID, Token: tokenValue}
|
|
ctx := context.WithValue(r.Context(), authInfoKey, authInfo)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
func FromContext(ctx context.Context) (AuthInfo, bool) {
|
|
v := ctx.Value(authInfoKey)
|
|
if v == nil {
|
|
return AuthInfo{}, false
|
|
}
|
|
info, ok := v.(AuthInfo)
|
|
return info, ok
|
|
}
|
|
|
|
// bearerToken extracts the credential from an Authorization header of the
// form "Bearer <token>". The scheme is compared case-insensitively, and any
// amount of surrounding or separating whitespace is tolerated. Any other shape,
// including extra fields or another scheme, yields ("", false).
func bearerToken(authorizationHeader string) (string, bool) {
	// strings.Fields already drops leading/trailing whitespace and never
	// yields empty fields, so a two-field match guarantees a non-empty token.
	if fields := strings.Fields(authorizationHeader); len(fields) == 2 && strings.EqualFold(fields[0], "Bearer") {
		return fields[1], true
	}
	return "", false
}
|