aboutsummaryrefslogtreecommitdiff
path: root/backend/router.go
diff options
context:
space:
mode:
Diffstat (limited to 'backend/router.go')
-rw-r--r--backend/router.go23
1 files changed, 20 insertions, 3 deletions
diff --git a/backend/router.go b/backend/router.go
index c742006..0a58b26 100644
--- a/backend/router.go
+++ b/backend/router.go
@@ -25,6 +25,7 @@ const (
assetsFSKey = "assets_fs"
debugKey = "debug"
loginKey = "login"
+ ipAdressKey = "ip_adress"
)
//go:embed templates
@@ -71,7 +72,15 @@ func NewRouter(debug bool, cfg *Config, db *sql.DB, assets fs.FS) *chi.Mux {
})
})
// context
- setContext := func(ctx context.Context) context.Context {
+ setContext := func(ctx context.Context, r *http.Request) context.Context {
+ ip := r.Header.Get("X-Real-Ip")
+ if ip == "" {
+ ip = r.Header.Get("X-Forwarded-For")
+ }
+ if ip == "" {
+ ip = r.RemoteAddr
+ }
+ ctx = context.WithValue(ctx, ipAdressKey, ip)
ctx = context.WithValue(ctx, configKey, cfg)
ctx = context.WithValue(ctx, assetsFSKey, assets)
ctx = context.WithValue(ctx, debugKey, debug)
@@ -80,7 +89,7 @@ func NewRouter(debug bool, cfg *Config, db *sql.DB, assets fs.FS) *chi.Mux {
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r.WithContext(
- setContext(r.Context()),
+ setContext(r.Context(), r),
))
})
})
@@ -90,10 +99,17 @@ func NewRouter(debug bool, cfg *Config, db *sql.DB, assets fs.FS) *chi.Mux {
_, pass, ok := r.BasicAuth()
ctx := r.Context()
if ok {
+ if handleTimeout(ctx) {
+ http.Error(w, "Too many requests", http.StatusTooManyRequests)
+ return
+ }
cfg := ctx.Value(configKey).(*Config)
passHash := sha256.Sum256([]byte(pass))
rightPassHash := sha256.Sum256([]byte(cfg.AdminPassword))
ok = subtle.ConstantTimeCompare(passHash[:], rightPassHash[:]) == 1
+ if ok {
+ resetTimeout(ctx)
+ }
}
ctx = context.WithValue(ctx, loginKey, ok)
next.ServeHTTP(w, r.WithContext(ctx))
@@ -111,7 +127,8 @@ func NewRouter(debug bool, cfg *Config, db *sql.DB, assets fs.FS) *chi.Mux {
}
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
- if err := storage.UpdateStats(setContext(ctx), r, cfg.Domain); err != nil {
+ err := storage.UpdateStats(setContext(ctx, r), r, cfg.Domain)
+ if err != nil {
logger.Error("updating stats", "error", err)
}
}(r.Context(), r)