diff options
| author | Anhgelus Morhtuuzh <william@herges.fr> | 2025-12-26 23:15:06 +0100 |
|---|---|---|
| committer | Anhgelus Morhtuuzh <william@herges.fr> | 2025-12-26 23:15:06 +0100 |
| commit | 563494cb7779a840bd650f0cf215b6e6ae7080ed (patch) | |
| tree | 8a2630444cbfbbe6db9c073450467eeb01e028cd /backend/router.go | |
| parent | af11793ca48244eafd7dcdf66ac1dff83995a775 (diff) | |
feat(backend): introduce rate limit to protect auth
Diffstat (limited to 'backend/router.go')
| -rw-r--r-- | backend/router.go | 23 |
1 files changed, 20 insertions, 3 deletions
diff --git a/backend/router.go b/backend/router.go index c742006..0a58b26 100644 --- a/backend/router.go +++ b/backend/router.go @@ -25,6 +25,7 @@ const ( assetsFSKey = "assets_fs" debugKey = "debug" loginKey = "login" + ipAdressKey = "ip_adress" ) //go:embed templates @@ -71,7 +72,15 @@ func NewRouter(debug bool, cfg *Config, db *sql.DB, assets fs.FS) *chi.Mux { }) }) // context - setContext := func(ctx context.Context) context.Context { + setContext := func(ctx context.Context, r *http.Request) context.Context { + ip := r.Header.Get("X-Real-Ip") + if ip == "" { + ip = r.Header.Get("X-Forwarded-For") + } + if ip == "" { + ip = r.RemoteAddr + } + ctx = context.WithValue(ctx, ipAdressKey, ip) ctx = context.WithValue(ctx, configKey, cfg) ctx = context.WithValue(ctx, assetsFSKey, assets) ctx = context.WithValue(ctx, debugKey, debug) @@ -80,7 +89,7 @@ func NewRouter(debug bool, cfg *Config, db *sql.DB, assets fs.FS) *chi.Mux { r.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { next.ServeHTTP(w, r.WithContext( - setContext(r.Context()), + setContext(r.Context(), r), )) }) }) @@ -90,10 +99,17 @@ func NewRouter(debug bool, cfg *Config, db *sql.DB, assets fs.FS) *chi.Mux { _, pass, ok := r.BasicAuth() ctx := r.Context() if ok { + if handleTimeout(ctx) { + http.Error(w, "Too many requests", http.StatusTooManyRequests) + return + } cfg := ctx.Value(configKey).(*Config) passHash := sha256.Sum256([]byte(pass)) rightPassHash := sha256.Sum256([]byte(cfg.AdminPassword)) ok = subtle.ConstantTimeCompare(passHash[:], rightPassHash[:]) == 1 + if ok { + resetTimeout(ctx) + } } ctx = context.WithValue(ctx, loginKey, ok) next.ServeHTTP(w, r.WithContext(ctx)) @@ -111,7 +127,8 @@ func NewRouter(debug bool, cfg *Config, db *sql.DB, assets fs.FS) *chi.Mux { } ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - if err := storage.UpdateStats(setContext(ctx), r, cfg.Domain); err != nil { + err := storage.UpdateStats(setContext(ctx, r), r, cfg.Domain) + if err != nil { logger.Error("updating stats", "error", err) } }(r.Context(), r) |
