diff options
| author | Anhgelus Morhtuuzh <william@herges.fr> | 2026-01-02 15:42:05 +0100 |
|---|---|---|
| committer | Anhgelus Morhtuuzh <william@herges.fr> | 2026-01-02 15:42:05 +0100 |
| commit | e840a9baf47f47bd533fca96ae341b0f4b1196cf (patch) | |
| tree | e9c2f6253c86113a55e93d3a1f76425c154ae54c /backend | |
| parent | 20a69a3f84efde6219798f7db81a5aadca03fba1 (diff) | |
feat(backend): clean rate limit
Diffstat (limited to 'backend')
| -rw-r--r-- | backend/admin.go | 26 | ||||
| -rw-r--r-- | backend/router.go | 16 |
2 files changed, 31 insertions, 11 deletions
diff --git a/backend/admin.go b/backend/admin.go index b5cd695..6fc61eb 100644 --- a/backend/admin.go +++ b/backend/admin.go @@ -32,7 +32,24 @@ type tos struct { var timeouts = tos{tos: make(map[string]*to)} -func handleTimeout(ctx context.Context) bool { +func rateLimitDuration(n int) time.Duration { + return time.Duration(math.Pow10(n/4)) * time.Second +} + +func isRateLimited(ctx context.Context) bool { + ip := ctx.Value(storage.IPAddressKey).(string) + + timeouts.mu.Lock() + defer timeouts.mu.Unlock() + + v, ok := timeouts.tos[ip] + if !ok { + return false + } + return time.Since(v.since) <= rateLimitDuration(v.n) +} + +func rateLimit(ctx context.Context) bool { ip := ctx.Value(storage.IPAddressKey).(string) timeouts.mu.Lock() @@ -43,8 +60,7 @@ func handleTimeout(ctx context.Context) bool { timeouts.tos[ip] = &to{n: 1} return false } - dur := func() time.Duration { return time.Duration(math.Pow10(v.n/4)) * time.Second } - if time.Since(v.since) <= dur() { + if time.Since(v.since) <= rateLimitDuration(v.n) { return true } v.n++ @@ -52,7 +68,7 @@ func handleTimeout(ctx context.Context) bool { return false } v.since = time.Now() - GetLogger(ctx).Warn("rate limiting IP", "ip", ip, "duration", dur().String()) + GetLogger(ctx).Warn("rate limiting IP", "ip", ip, "duration", rateLimitDuration(v.n).String()) go func(v *to, ip string) { time.Sleep(3 * time.Hour) v.n = max(v.n-4, 0) @@ -65,7 +81,7 @@ func handleTimeout(ctx context.Context) bool { return true } -func resetTimeout(ctx context.Context) { +func resetRateLimit(ctx context.Context) { ip := ctx.Value(storage.IPAddressKey).(string) timeouts.mu.Lock() diff --git a/backend/router.go b/backend/router.go index 3928119..60223d3 100644 --- a/backend/router.go +++ b/backend/router.go @@ -98,25 +98,29 @@ func NewRouter(debug bool, cfg *Config, db *sql.DB, assets fs.FS) *chi.Mux { // login r.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, pass, ok := r.BasicAuth() ctx := r.Context() + if isRateLimited(ctx) { + http.Error(w, "Too many requests", http.StatusTooManyRequests) + return + } + _, pass, ok := r.BasicAuth() if ok { - if handleTimeout(ctx) { - http.Error(w, "Too many requests", http.StatusTooManyRequests) - return - } cfg := ctx.Value(configKey).(*Config) passHash := sha256.Sum256([]byte(pass)) rightPassHash := sha256.Sum256([]byte(cfg.AdminPassword)) ok = subtle.ConstantTimeCompare(passHash[:], rightPassHash[:]) == 1 if ok { - resetTimeout(ctx) + resetRateLimit(ctx) + } else if rateLimit(ctx) { + http.Error(w, "Too many requests", http.StatusTooManyRequests) + return } } ctx = context.WithValue(ctx, loginKey, ok) next.ServeHTTP(w, r.WithContext(ctx)) }) }) + // stats r.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { next.ServeHTTP(w, r) |
