package Session import ( "context" "crypto/rand" "encoding/hex" "net/http" "time" ) // returns the session identified by session_id or creates a new one and returns that func (sm *sessionManager) Session(session_id string) *session { sm.mtx.Lock() s, ok := sm.sessions[session_id] if !ok { s = &session{ data: make(map[string]any), expiry: time.Now().Add(Session_Lifetime), } sm.sessions[session_id] = s sm.mtx.Unlock() } else { sm.mtx.Unlock() s.mtx.Lock() if time.Now().After(s.expiry) { // expired, clear it s.data = make(map[string]any) s.expiry = time.Now().Add(Session_Lifetime) } else { // valid, extend expiry s.expiry = time.Now().Add(Session_Lifetime) } s.mtx.Unlock() } return s } func GetSession(ctx context.Context) *session { s, _ := ctx.Value("session").(*session) return s } func (sm *sessionManager) SessionMW(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var sid string cookie, err := r.Cookie(cookie_key) if err == http.ErrNoCookie { sid = newSessionID() } else if err == nil { sid = cookie.Value } else { http.Error(w, "Invalid session", http.StatusBadRequest) return } http.SetCookie(w, &http.Cookie{ Name: cookie_key, Value: sid, Path: "/", MaxAge: int(Session_Lifetime.Seconds()), HttpOnly: true, Secure: true, }) ctx := context.WithValue(r.Context(), "session", sm.Session(sid)) r = r.WithContext(ctx) next.ServeHTTP(w, r) }) } func (sm *sessionManager) StartGC(interval time.Duration) { go func() { ticker := time.NewTicker(interval) defer ticker.Stop() for range ticker.C { sm.cleanup() } }() } func newSessionID() string { b := make([]byte, 32) rand.Read(b) return hex.EncodeToString(b) } func (sm *sessionManager) cleanup() { now := time.Now() sm.mtx.Lock() defer sm.mtx.Unlock() for id, s := range sm.sessions { s.mtx.Lock() expired := now.After(s.expiry) s.mtx.Unlock() if expired { delete(sm.sessions, id) } } } func NewSessionManager(gc_interval time.Duration) *sessionManager { sm := &sessionManager{ sessions: make(map[string]*session), } sm.StartGC(gc_interval) return sm }