aboutsummaryrefslogtreecommitdiff
path: root/backend
diff options
context:
space:
mode:
Diffstat (limited to 'backend')
-rw-r--r--backend/admin.go63
-rw-r--r--backend/logger.go9
-rw-r--r--backend/router.go23
3 files changed, 88 insertions, 7 deletions
diff --git a/backend/admin.go b/backend/admin.go
index 9cb5376..e5f5696 100644
--- a/backend/admin.go
+++ b/backend/admin.go
@@ -1,8 +1,13 @@
package backend
import (
+ "context"
+ "math"
"net/http"
"strconv"
+ "strings"
+ "sync"
+ "time"
"git.anhgelus.world/anhgelus/small-web/backend/storage"
"github.com/go-chi/chi/v5"
@@ -16,6 +21,64 @@ type adminData struct {
CurrentPage int
}
+type to struct {
+ n int
+ since time.Time
+}
+
+type tos struct {
+ mu sync.Mutex
+ tos map[string]*to
+}
+
+var timeouts = tos{tos: make(map[string]*to)}
+
+func handleTimeout(ctx context.Context) bool {
+ ip := ctx.Value(ipAdressKey).(string)
+ parsed := strings.Split(ip, ":")
+ ip = parsed[0]
+
+ timeouts.mu.Lock()
+ defer timeouts.mu.Unlock()
+
+ v, ok := timeouts.tos[ip]
+ if !ok {
+ timeouts.tos[ip] = &to{n: 1}
+ return false
+ }
+ dur := func() time.Duration { return time.Duration(math.Pow10(v.n/4)) * time.Second }
+ if time.Since(v.since) <= dur() {
+ return true
+ }
+ v.n++
+ if v.n%4 != 0 {
+ return false
+ }
+ v.since = time.Now()
+ GetLogger(ctx).Warn("rate limiting IP", "ip", ip, "duration", dur().String())
+ go func(v *to) {
+ time.Sleep(3 * time.Hour)
+ v.n = max(v.n-4, 0)
+ }(v)
+ return true
+}
+
+func resetTimeout(ctx context.Context) {
+ ip := ctx.Value(ipAdressKey).(string)
+ parsed := strings.Split(ip, ":")
+ ip = parsed[0]
+
+ timeouts.mu.Lock()
+ defer timeouts.mu.Unlock()
+
+ v, ok := timeouts.tos[ip]
+ if !ok {
+ return
+ }
+ v.n = 0
+ v.since = time.Unix(0, 0)
+}
+
func HandleAdmin(r *chi.Mux) {
r.Get("/admin", func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
diff --git a/backend/logger.go b/backend/logger.go
index da5fadb..f3dbc3c 100644
--- a/backend/logger.go
+++ b/backend/logger.go
@@ -24,10 +24,11 @@ type customWriter struct {
}
func (c *customWriter) WriteHeader(statusCode int) {
- c.statusCode = statusCode
- if statusCode != c.statusCode {
- c.ResponseWriter.WriteHeader(statusCode)
+ if statusCode == c.statusCode {
+ return
}
+ c.statusCode = statusCode
+ c.ResponseWriter.WriteHeader(statusCode)
}
func GetStatusCode(ctx context.Context) func() int {
@@ -58,7 +59,7 @@ func SetLogger(l *slog.Logger) func(http.Handler) http.Handler {
next.ServeHTTP(ww, r.WithContext(ctx))
- if ww.statusCode == http.StatusNotFound {
+ if ww.statusCode == http.StatusNotFound || ww.statusCode == http.StatusTooManyRequests {
return
}
var lvl slog.Level
diff --git a/backend/router.go b/backend/router.go
index c742006..0a58b26 100644
--- a/backend/router.go
+++ b/backend/router.go
@@ -25,6 +25,7 @@ const (
assetsFSKey = "assets_fs"
debugKey = "debug"
loginKey = "login"
+ ipAdressKey = "ip_adress"
)
//go:embed templates
@@ -71,7 +72,15 @@ func NewRouter(debug bool, cfg *Config, db *sql.DB, assets fs.FS) *chi.Mux {
})
})
// context
- setContext := func(ctx context.Context) context.Context {
+ setContext := func(ctx context.Context, r *http.Request) context.Context {
+ ip := r.Header.Get("X-Real-Ip")
+ if ip == "" {
+ ip = r.Header.Get("X-Forwarded-For")
+ }
+ if ip == "" {
+ ip = r.RemoteAddr
+ }
+ ctx = context.WithValue(ctx, ipAdressKey, ip)
ctx = context.WithValue(ctx, configKey, cfg)
ctx = context.WithValue(ctx, assetsFSKey, assets)
ctx = context.WithValue(ctx, debugKey, debug)
@@ -80,7 +89,7 @@ func NewRouter(debug bool, cfg *Config, db *sql.DB, assets fs.FS) *chi.Mux {
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r.WithContext(
- setContext(r.Context()),
+ setContext(r.Context(), r),
))
})
})
@@ -90,10 +99,17 @@ func NewRouter(debug bool, cfg *Config, db *sql.DB, assets fs.FS) *chi.Mux {
_, pass, ok := r.BasicAuth()
ctx := r.Context()
if ok {
+ if handleTimeout(ctx) {
+ http.Error(w, "Too many requests", http.StatusTooManyRequests)
+ return
+ }
cfg := ctx.Value(configKey).(*Config)
passHash := sha256.Sum256([]byte(pass))
rightPassHash := sha256.Sum256([]byte(cfg.AdminPassword))
ok = subtle.ConstantTimeCompare(passHash[:], rightPassHash[:]) == 1
+ if ok {
+ resetTimeout(ctx)
+ }
}
ctx = context.WithValue(ctx, loginKey, ok)
next.ServeHTTP(w, r.WithContext(ctx))
@@ -111,7 +127,8 @@ func NewRouter(debug bool, cfg *Config, db *sql.DB, assets fs.FS) *chi.Mux {
}
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
- if err := storage.UpdateStats(setContext(ctx), r, cfg.Domain); err != nil {
+ err := storage.UpdateStats(setContext(ctx, r), r, cfg.Domain)
+ if err != nil {
logger.Error("updating stats", "error", err)
}
}(r.Context(), r)