aboutsummaryrefslogtreecommitdiff
path: root/backend
diff options
context:
space:
mode:
Diffstat (limited to 'backend')
-rw-r--r--backend/router.go22
-rw-r--r--backend/stats.go26
2 files changed, 37 insertions, 11 deletions
diff --git a/backend/router.go b/backend/router.go
index 7f9de5c..14abc06 100644
--- a/backend/router.go
+++ b/backend/router.go
@@ -22,6 +22,7 @@ const (
configKey = "config"
assetsFSKey = "assets_fs"
debugKey = "debug"
+ dbKey = "db"
)
//go:embed templates
@@ -87,20 +88,27 @@ func NewRouter(debug bool, cfg *Config, db *sql.DB, assets fs.FS) *chi.Mux {
})
})
// context
+ setContext := func(ctx context.Context) context.Context {
+ ctx = context.WithValue(ctx, configKey, cfg)
+ ctx = context.WithValue(ctx, assetsFSKey, assets)
+ ctx = context.WithValue(ctx, debugKey, debug)
+ return context.WithValue(ctx, dbKey, db)
+ }
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- ctx := context.WithValue(r.Context(), configKey, cfg)
- ctx = context.WithValue(ctx, assetsFSKey, assets)
- ctx = context.WithValue(ctx, debugKey, debug)
- next.ServeHTTP(w, r.WithContext(ctx))
+ next.ServeHTTP(w, r.WithContext(setContext(r.Context())))
})
})
// stats
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if err := UpdateStats(r.Context(), db, r); err != nil {
- slog.Error("updating stats", "error", err)
- }
+ go func(r *http.Request) {
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+ defer cancel()
+ if err := UpdateStats(setContext(ctx), r); err != nil {
+ slog.Error("updating stats", "error", err)
+ }
+ }(r)
next.ServeHTTP(w, r)
})
})
diff --git a/backend/stats.go b/backend/stats.go
index e2ad08b..2b07221 100644
--- a/backend/stats.go
+++ b/backend/stats.go
@@ -3,14 +3,20 @@ package backend
import (
"context"
"database/sql"
+ "fmt"
+ "log/slog"
"net/http"
"regexp"
"strings"
)
-var trimRefererReg = regexp.MustCompile(`https?://([a-z-0-9.]+)(:\d+)?/.*`)
+var trimRefererReg = regexp.MustCompile(`https?://([a-z-0-9.]+(:\d+)?)/.*`)
-func UpdateStats(ctx context.Context, db *sql.DB, r *http.Request) error {
+func getDB(ctx context.Context) *sql.DB {
+ return ctx.Value(dbKey).(*sql.DB)
+}
+
+func UpdateStats(ctx context.Context, r *http.Request) error {
target := r.URL.Path
if strings.HasPrefix(target, "/assets") || strings.HasPrefix(target, "/static") {
return nil
@@ -24,20 +30,32 @@ func UpdateStats(ctx context.Context, db *sql.DB, r *http.Request) error {
return nil
}
ref = subs[1]
- if ref == ctx.Value(configKey).(*Config).Domain {
+ if ref == ctx.Value(configKey).(*Config).Domain || ref == fmt.Sprintf("localhost:%d", 8000) {
ref = subs[0][strings.Index(subs[0], ref)+len(ref):]
+ if ref == target {
+ return nil
+ }
}
+ db := getDB(ctx)
rows, err := db.QueryContext(ctx, "SELECT id, visit FROM stats WHERE origin = ? AND target = ?", ref, target)
if err != nil {
return err
}
+ defer func() {
+ if err == nil {
+ slog.Debug("stats updated")
+ }
+ }()
if !rows.Next() {
_, err = db.ExecContext(ctx, "INSERT INTO stats (origin, target, visit) VALUES (?, ?, 1)", ref, target)
return err
}
var id uint
var nb uint
- rows.Scan(&id, &nb)
+ err = rows.Scan(&id, &nb)
+ if err != nil {
+ return err
+ }
err = rows.Close()
if err != nil {
return err