diff options
Diffstat (limited to 'backend')
| -rw-r--r-- | backend/admin.go | 63 | ||||
| -rw-r--r-- | backend/logger.go | 9 | ||||
| -rw-r--r-- | backend/router.go | 23 |
3 files changed, 88 insertions, 7 deletions
diff --git a/backend/admin.go b/backend/admin.go index 9cb5376..e5f5696 100644 --- a/backend/admin.go +++ b/backend/admin.go @@ -1,8 +1,13 @@ package backend import ( + "context" + "math" "net/http" "strconv" + "strings" + "sync" + "time" "git.anhgelus.world/anhgelus/small-web/backend/storage" "github.com/go-chi/chi/v5" @@ -16,6 +21,64 @@ type adminData struct { CurrentPage int } +type to struct { + n int + since time.Time +} + +type tos struct { + mu sync.Mutex + tos map[string]*to +} + +var timeouts = tos{tos: make(map[string]*to)} + +func handleTimeout(ctx context.Context) bool { + ip := ctx.Value(ipAdressKey).(string) + parsed := strings.Split(ip, ":") + ip = parsed[0] + + timeouts.mu.Lock() + defer timeouts.mu.Unlock() + + v, ok := timeouts.tos[ip] + if !ok { + timeouts.tos[ip] = &to{n: 1} + return false + } + dur := func() time.Duration { return time.Duration(math.Pow10(v.n/4)) * time.Second } + if time.Since(v.since) <= dur() { + return true + } + v.n++ + if v.n%4 != 0 { + return false + } + v.since = time.Now() + GetLogger(ctx).Warn("rate limiting IP", "ip", ip, "duration", dur().String()) + go func(v *to) { + time.Sleep(3 * time.Hour) + v.n = max(v.n-4, 0) + }(v) + return true +} + +func resetTimeout(ctx context.Context) { + ip := ctx.Value(ipAdressKey).(string) + parsed := strings.Split(ip, ":") + ip = parsed[0] + + timeouts.mu.Lock() + defer timeouts.mu.Unlock() + + v, ok := timeouts.tos[ip] + if !ok { + return + } + v.n = 0 + v.since = time.Unix(0, 0) +} + func HandleAdmin(r *chi.Mux) { r.Get("/admin", func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() diff --git a/backend/logger.go b/backend/logger.go index da5fadb..f3dbc3c 100644 --- a/backend/logger.go +++ b/backend/logger.go @@ -24,10 +24,11 @@ type customWriter struct { } func (c *customWriter) WriteHeader(statusCode int) { - c.statusCode = statusCode - if statusCode != c.statusCode { - c.ResponseWriter.WriteHeader(statusCode) + if statusCode == c.statusCode { + return } + c.statusCode = statusCode + c.ResponseWriter.WriteHeader(statusCode) } func GetStatusCode(ctx context.Context) func() int { @@ -58,7 +59,7 @@ func SetLogger(l *slog.Logger) func(http.Handler) http.Handler { next.ServeHTTP(ww, r.WithContext(ctx)) - if ww.statusCode == http.StatusNotFound { + if ww.statusCode == http.StatusNotFound || ww.statusCode == http.StatusTooManyRequests { return } var lvl slog.Level diff --git a/backend/router.go b/backend/router.go index c742006..0a58b26 100644 --- a/backend/router.go +++ b/backend/router.go @@ -25,6 +25,7 @@ const ( assetsFSKey = "assets_fs" debugKey = "debug" loginKey = "login" + ipAdressKey = "ip_adress" ) //go:embed templates @@ -71,7 +72,15 @@ func NewRouter(debug bool, cfg *Config, db *sql.DB, assets fs.FS) *chi.Mux { }) }) // context - setContext := func(ctx context.Context) context.Context { + setContext := func(ctx context.Context, r *http.Request) context.Context { + ip := r.Header.Get("X-Real-Ip") + if ip == "" { + ip = r.Header.Get("X-Forwarded-For") + } + if ip == "" { + ip = r.RemoteAddr + } + ctx = context.WithValue(ctx, ipAdressKey, ip) ctx = context.WithValue(ctx, configKey, cfg) ctx = context.WithValue(ctx, assetsFSKey, assets) ctx = context.WithValue(ctx, debugKey, debug) @@ -80,7 +89,7 @@ func NewRouter(debug bool, cfg *Config, db *sql.DB, assets fs.FS) *chi.Mux { r.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { next.ServeHTTP(w, r.WithContext( - setContext(r.Context()), + setContext(r.Context(), r), )) }) }) @@ -90,10 +99,17 @@ func NewRouter(debug bool, cfg *Config, db *sql.DB, assets fs.FS) *chi.Mux { _, pass, ok := r.BasicAuth() ctx := r.Context() if ok { + if handleTimeout(ctx) { + http.Error(w, "Too many requests", http.StatusTooManyRequests) + return + } cfg := ctx.Value(configKey).(*Config) passHash := sha256.Sum256([]byte(pass)) rightPassHash := sha256.Sum256([]byte(cfg.AdminPassword)) ok = subtle.ConstantTimeCompare(passHash[:], rightPassHash[:]) == 1 + if ok { + resetTimeout(ctx) + } } ctx = context.WithValue(ctx, loginKey, ok) next.ServeHTTP(w, r.WithContext(ctx)) @@ -111,7 +127,8 @@ func NewRouter(debug bool, cfg *Config, db *sql.DB, assets fs.FS) *chi.Mux { } ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - if err := storage.UpdateStats(setContext(ctx), r, cfg.Domain); err != nil { + err := storage.UpdateStats(setContext(ctx, r), r, cfg.Domain) + if err != nil { logger.Error("updating stats", "error", err) } }(r.Context(), r) |
