3.9.3 Session 管理

3.9.3 Session 管理 #

Session(会话)是 Web 应用中维护用户状态的传统方法。与无状态的 JWT 不同,Session 在服务端存储用户状态信息,通过 Session ID 在客户端和服务端之间建立关联。

Session 基础概念 #

Session 工作原理 #

  1. 用户登录:服务器验证用户凭据后创建 Session
  2. Session 存储:服务器将 Session 数据存储在内存、数据库或缓存中
  3. Session ID:服务器生成唯一的 Session ID 并发送给客户端
  4. 客户端存储:客户端通常将 Session ID 存储在 Cookie 中
  5. 后续请求:客户端在每次请求中携带 Session ID
  6. 状态维护:服务器根据 Session ID 查找对应的 Session 数据

Session vs JWT #

特性 Session JWT
存储位置 服务端 客户端
状态性 有状态 无状态
扩展性 需要共享存储 天然支持分布式
安全性 服务端控制 依赖签名验证
撤销能力 容易撤销 难以撤销
网络开销

Go 中的 Session 实现 #

基础 Session 管理器 #

// pkg/session/manager.go
package session

import (
    "crypto/rand"
    "encoding/hex"
    "errors"
    "sync"
    "time"
)

var (
    ErrSessionNotFound = errors.New("session不存在")
    ErrSessionExpired  = errors.New("session已过期")
)

// Session 会话数据
type Session struct {
    ID        string                 `json:"id"`
    UserID    interface{}            `json:"user_id"`
    Data      map[string]interface{} `json:"data"`
    CreatedAt time.Time              `json:"created_at"`
    UpdatedAt time.Time              `json:"updated_at"`
    ExpiresAt time.Time              `json:"expires_at"`
}

// NewSession 创建新会话
func NewSession(userID interface{}, duration time.Duration) *Session {
    now := time.Now()
    return &Session{
        ID:        generateSessionID(),
        UserID:    userID,
        Data:      make(map[string]interface{}),
        CreatedAt: now,
        UpdatedAt: now,
        ExpiresAt: now.Add(duration),
    }
}

// IsExpired 检查会话是否过期
func (s *Session) IsExpired() bool {
    return time.Now().After(s.ExpiresAt)
}

// Set 设置会话数据
func (s *Session) Set(key string, value interface{}) {
    s.Data[key] = value
    s.UpdatedAt = time.Now()
}

// Get 获取会话数据
func (s *Session) Get(key string) (interface{}, bool) {
    value, exists := s.Data[key]
    return value, exists
}

// Delete 删除会话数据
func (s *Session) Delete(key string) {
    delete(s.Data, key)
    s.UpdatedAt = time.Now()
}

// Extend 延长会话有效期
func (s *Session) Extend(duration time.Duration) {
    s.ExpiresAt = time.Now().Add(duration)
    s.UpdatedAt = time.Now()
}

// Store Session存储接口
type Store interface {
    Create(session *Session) error
    Get(sessionID string) (*Session, error)
    Update(session *Session) error
    Delete(sessionID string) error
    Cleanup() error
}

// Manager Session管理器
type Manager struct {
    store    Store
    duration time.Duration
    mutex    sync.RWMutex
}

// NewManager 创建Session管理器
func NewManager(store Store, duration time.Duration) *Manager {
    manager := &Manager{
        store:    store,
        duration: duration,
    }

    // 启动清理goroutine
    go manager.startCleanup()

    return manager
}

// CreateSession 创建会话
func (m *Manager) CreateSession(userID interface{}) (*Session, error) {
    session := NewSession(userID, m.duration)

    if err := m.store.Create(session); err != nil {
        return nil, err
    }

    return session, nil
}

// GetSession 获取会话
func (m *Manager) GetSession(sessionID string) (*Session, error) {
    session, err := m.store.Get(sessionID)
    if err != nil {
        return nil, err
    }

    if session.IsExpired() {
        m.store.Delete(sessionID)
        return nil, ErrSessionExpired
    }

    return session, nil
}

// UpdateSession 更新会话
func (m *Manager) UpdateSession(session *Session) error {
    return m.store.Update(session)
}

// DestroySession 销毁会话
func (m *Manager) DestroySession(sessionID string) error {
    return m.store.Delete(sessionID)
}

// ExtendSession 延长会话
func (m *Manager) ExtendSession(sessionID string) error {
    session, err := m.GetSession(sessionID)
    if err != nil {
        return err
    }

    session.Extend(m.duration)
    return m.store.Update(session)
}

// startCleanup 启动清理过期会话
func (m *Manager) startCleanup() {
    ticker := time.NewTicker(1 * time.Hour)
    defer ticker.Stop()

    for range ticker.C {
        m.store.Cleanup()
    }
}

// generateSessionID 生成会话ID
func generateSessionID() string {
    bytes := make([]byte, 32)
    rand.Read(bytes)
    return hex.EncodeToString(bytes)
}

内存存储实现 #

// pkg/session/memory_store.go
package session

import (
    "sync"
    "time"
)

// MemoryStore 内存存储
type MemoryStore struct {
    sessions map[string]*Session
    mutex    sync.RWMutex
}

// NewMemoryStore 创建内存存储
func NewMemoryStore() *MemoryStore {
    return &MemoryStore{
        sessions: make(map[string]*Session),
    }
}

// Create 创建会话
func (s *MemoryStore) Create(session *Session) error {
    s.mutex.Lock()
    defer s.mutex.Unlock()

    s.sessions[session.ID] = session
    return nil
}

// Get 获取会话
func (s *MemoryStore) Get(sessionID string) (*Session, error) {
    s.mutex.RLock()
    defer s.mutex.RUnlock()

    session, exists := s.sessions[sessionID]
    if !exists {
        return nil, ErrSessionNotFound
    }

    return session, nil
}

// Update 更新会话
func (s *MemoryStore) Update(session *Session) error {
    s.mutex.Lock()
    defer s.mutex.Unlock()

    if _, exists := s.sessions[session.ID]; !exists {
        return ErrSessionNotFound
    }

    s.sessions[session.ID] = session
    return nil
}

// Delete 删除会话
func (s *MemoryStore) Delete(sessionID string) error {
    s.mutex.Lock()
    defer s.mutex.Unlock()

    delete(s.sessions, sessionID)
    return nil
}

// Cleanup 清理过期会话
func (s *MemoryStore) Cleanup() error {
    s.mutex.Lock()
    defer s.mutex.Unlock()

    now := time.Now()
    for id, session := range s.sessions {
        if now.After(session.ExpiresAt) {
            delete(s.sessions, id)
        }
    }

    return nil
}

Redis 存储实现 #

// pkg/session/redis_store.go
package session

import (
    "encoding/json"
    "time"

    "github.com/go-redis/redis/v8"
    "golang.org/x/net/context"
)

// RedisStore Redis存储
type RedisStore struct {
    client *redis.Client
    prefix string
}

// NewRedisStore 创建Redis存储
func NewRedisStore(client *redis.Client, prefix string) *RedisStore {
    return &RedisStore{
        client: client,
        prefix: prefix,
    }
}

// getKey 获取Redis键
func (s *RedisStore) getKey(sessionID string) string {
    return s.prefix + sessionID
}

// Create 创建会话
func (s *RedisStore) Create(session *Session) error {
    ctx := context.Background()
    key := s.getKey(session.ID)

    data, err := json.Marshal(session)
    if err != nil {
        return err
    }

    duration := time.Until(session.ExpiresAt)
    return s.client.Set(ctx, key, data, duration).Err()
}

// Get 获取会话
func (s *RedisStore) Get(sessionID string) (*Session, error) {
    ctx := context.Background()
    key := s.getKey(sessionID)

    data, err := s.client.Get(ctx, key).Result()
    if err != nil {
        if err == redis.Nil {
            return nil, ErrSessionNotFound
        }
        return nil, err
    }

    var session Session
    if err := json.Unmarshal([]byte(data), &session); err != nil {
        return nil, err
    }

    return &session, nil
}

// Update 更新会话
func (s *RedisStore) Update(session *Session) error {
    ctx := context.Background()
    key := s.getKey(session.ID)

    data, err := json.Marshal(session)
    if err != nil {
        return err
    }

    duration := time.Until(session.ExpiresAt)
    return s.client.Set(ctx, key, data, duration).Err()
}

// Delete 删除会话
func (s *RedisStore) Delete(sessionID string) error {
    ctx := context.Background()
    key := s.getKey(sessionID)

    return s.client.Del(ctx, key).Err()
}

// Cleanup Redis会自动清理过期键
func (s *RedisStore) Cleanup() error {
    return nil
}

数据库存储实现 #

// pkg/session/db_store.go
package session

import (
    "encoding/json"
    "time"

    "gorm.io/gorm"
)

// DBSession 数据库会话模型
type DBSession struct {
    ID        string    `gorm:"primaryKey;size:64"`
    UserID    string    `gorm:"index;size:64"`
    Data      string    `gorm:"type:text"`
    CreatedAt time.Time
    UpdatedAt time.Time
    ExpiresAt time.Time `gorm:"index"`
}

// DBStore 数据库存储
type DBStore struct {
    db *gorm.DB
}

// NewDBStore 创建数据库存储
func NewDBStore(db *gorm.DB) *DBStore {
    // 自动迁移
    db.AutoMigrate(&DBSession{})

    return &DBStore{db: db}
}

// Create 创建会话
func (s *DBStore) Create(session *Session) error {
    data, err := json.Marshal(session.Data)
    if err != nil {
        return err
    }

    dbSession := &DBSession{
        ID:        session.ID,
        UserID:    session.UserID.(string),
        Data:      string(data),
        CreatedAt: session.CreatedAt,
        UpdatedAt: session.UpdatedAt,
        ExpiresAt: session.ExpiresAt,
    }

    return s.db.Create(dbSession).Error
}

// Get 获取会话
func (s *DBStore) Get(sessionID string) (*Session, error) {
    var dbSession DBSession
    if err := s.db.First(&dbSession, "id = ?", sessionID).Error; err != nil {
        if err == gorm.ErrRecordNotFound {
            return nil, ErrSessionNotFound
        }
        return nil, err
    }

    var data map[string]interface{}
    if err := json.Unmarshal([]byte(dbSession.Data), &data); err != nil {
        return nil, err
    }

    session := &Session{
        ID:        dbSession.ID,
        UserID:    dbSession.UserID,
        Data:      data,
        CreatedAt: dbSession.CreatedAt,
        UpdatedAt: dbSession.UpdatedAt,
        ExpiresAt: dbSession.ExpiresAt,
    }

    return session, nil
}

// Update 更新会话
func (s *DBStore) Update(session *Session) error {
    data, err := json.Marshal(session.Data)
    if err != nil {
        return err
    }

    return s.db.Model(&DBSession{}).
        Where("id = ?", session.ID).
        Updates(map[string]interface{}{
            "data":       string(data),
            "updated_at": session.UpdatedAt,
            "expires_at": session.ExpiresAt,
        }).Error
}

// Delete 删除会话
func (s *DBStore) Delete(sessionID string) error {
    return s.db.Delete(&DBSession{}, "id = ?", sessionID).Error
}

// Cleanup 清理过期会话
func (s *DBStore) Cleanup() error {
    return s.db.Delete(&DBSession{}, "expires_at < ?", time.Now()).Error
}

Session 中间件 #

// middleware/session.go
package middleware

import (
    "net/http"

    "github.com/gin-gonic/gin"
    "your-project/pkg/session"
)

const SessionKey = "session"

// SessionMiddleware Session中间件
func SessionMiddleware(manager *session.Manager) gin.HandlerFunc {
    return func(c *gin.Context) {
        // 从Cookie中获取Session ID
        sessionID, err := c.Cookie("session_id")
        if err != nil {
            // 没有Session ID,继续处理
            c.Next()
            return
        }

        // 获取Session
        sess, err := manager.GetSession(sessionID)
        if err != nil {
            // Session无效,清除Cookie
            c.SetCookie("session_id", "", -1, "/", "", false, true)
            c.Next()
            return
        }

        // 将Session存储到上下文
        c.Set(SessionKey, sess)
        c.Set("user_id", sess.UserID)

        c.Next()

        // 请求处理完成后,更新Session
        if updatedSess, exists := c.Get(SessionKey); exists {
            manager.UpdateSession(updatedSess.(*session.Session))
        }
    }
}

// RequireSession 要求Session的中间件
func RequireSession() gin.HandlerFunc {
    return func(c *gin.Context) {
        _, exists := c.Get(SessionKey)
        if !exists {
            c.JSON(http.StatusUnauthorized, gin.H{
                "error": "需要登录",
            })
            c.Abort()
            return
        }

        c.Next()
    }
}

// GetSession 从上下文获取Session
func GetSession(c *gin.Context) (*session.Session, bool) {
    sess, exists := c.Get(SessionKey)
    if !exists {
        return nil, false
    }
    return sess.(*session.Session), true
}

认证控制器 #

// controllers/session_auth.go
package controllers

import (
    "net/http"
    "time"

    "github.com/gin-gonic/gin"
    "your-project/middleware"
    "your-project/models"
    "your-project/pkg/session"
)

type SessionAuthController struct {
    userService    *models.UserService
    sessionManager *session.Manager
}

func NewSessionAuthController(userService *models.UserService, sessionManager *session.Manager) *SessionAuthController {
    return &SessionAuthController{
        userService:    userService,
        sessionManager: sessionManager,
    }
}

// LoginRequest 登录请求
type LoginRequest struct {
    Username   string `json:"username" binding:"required"`
    Password   string `json:"password" binding:"required"`
    RememberMe bool   `json:"remember_me"`
}

// Login 用户登录
func (ctrl *SessionAuthController) Login(c *gin.Context) {
    var req LoginRequest
    if err := c.ShouldBindJSON(&req); err != nil {
        c.JSON(http.StatusBadRequest, gin.H{
            "error": err.Error(),
        })
        return
    }

    // 验证用户
    user, err := ctrl.userService.GetUserByUsername(req.Username)
    if err != nil {
        c.JSON(http.StatusUnauthorized, gin.H{
            "error": "用户名或密码错误",
        })
        return
    }

    if !user.CheckPassword(req.Password) {
        c.JSON(http.StatusUnauthorized, gin.H{
            "error": "用户名或密码错误",
        })
        return
    }

    // 创建Session
    sess, err := ctrl.sessionManager.CreateSession(user.ID)
    if err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{
            "error": "创建会话失败",
        })
        return
    }

    // 存储用户信息到Session
    sess.Set("username", user.Username)
    sess.Set("role", user.Role)
    sess.Set("login_time", time.Now())

    // 设置Cookie
    maxAge := int(24 * time.Hour / time.Second) // 24小时
    if req.RememberMe {
        maxAge = int(30 * 24 * time.Hour / time.Second) // 30天
        sess.Extend(30 * 24 * time.Hour)
    }

    c.SetCookie("session_id", sess.ID, maxAge, "/", "", false, true)

    // 更新Session
    ctrl.sessionManager.UpdateSession(sess)

    c.JSON(http.StatusOK, gin.H{
        "message": "登录成功",
        "user":    user,
    })
}

// Logout 用户登出
func (ctrl *SessionAuthController) Logout(c *gin.Context) {
    sess, exists := middleware.GetSession(c)
    if !exists {
        c.JSON(http.StatusOK, gin.H{
            "message": "已登出",
        })
        return
    }

    // 销毁Session
    ctrl.sessionManager.DestroySession(sess.ID)

    // 清除Cookie
    c.SetCookie("session_id", "", -1, "/", "", false, true)

    c.JSON(http.StatusOK, gin.H{
        "message": "登出成功",
    })
}

// GetProfile 获取用户信息
func (ctrl *SessionAuthController) GetProfile(c *gin.Context) {
    sess, exists := middleware.GetSession(c)
    if !exists {
        c.JSON(http.StatusUnauthorized, gin.H{
            "error": "未登录",
        })
        return
    }

    userID := sess.UserID.(uint)
    user, err := ctrl.userService.GetUserByID(userID)
    if err != nil {
        c.JSON(http.StatusNotFound, gin.H{
            "error": "用户不存在",
        })
        return
    }

    // 获取Session中的额外信息
    username, _ := sess.Get("username")
    role, _ := sess.Get("role")
    loginTime, _ := sess.Get("login_time")

    c.JSON(http.StatusOK, gin.H{
        "user": user,
        "session_info": gin.H{
            "username":   username,
            "role":       role,
            "login_time": loginTime,
            "session_id": sess.ID,
        },
    })
}

// UpdateProfile 更新用户信息
func (ctrl *SessionAuthController) UpdateProfile(c *gin.Context) {
    sess, exists := middleware.GetSession(c)
    if !exists {
        c.JSON(http.StatusUnauthorized, gin.H{
            "error": "未登录",
        })
        return
    }

    var req struct {
        Email string `json:"email" binding:"required,email"`
        Name  string `json:"name" binding:"required"`
    }

    if err := c.ShouldBindJSON(&req); err != nil {
        c.JSON(http.StatusBadRequest, gin.H{
            "error": err.Error(),
        })
        return
    }

    // 更新用户信息
    userID := sess.UserID.(uint)
    user, err := ctrl.userService.GetUserByID(userID)
    if err != nil {
        c.JSON(http.StatusNotFound, gin.H{
            "error": "用户不存在",
        })
        return
    }

    user.Email = req.Email
    // user.Name = req.Name // 假设User模型有Name字段

    if err := ctrl.userService.UpdateUser(user); err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{
            "error": "更新失败",
        })
        return
    }

    // 更新Session中的信息
    sess.Set("last_update", time.Now())

    c.JSON(http.StatusOK, gin.H{
        "message": "更新成功",
        "user":    user,
    })
}

高级 Session 功能 #

Session 安全增强 #

// pkg/session/security.go
package session

import (
    "crypto/hmac"
    "crypto/sha256"
    "encoding/hex"
    "net"
    "net/http"
)

// SecurityConfig 安全配置
type SecurityConfig struct {
    CheckIP       bool   // 检查IP地址
    CheckUserAgent bool  // 检查User-Agent
    SecretKey     string // 签名密钥
}

// SecureSession 安全会话
type SecureSession struct {
    *Session
    IPAddress string `json:"ip_address"`
    UserAgent string `json:"user_agent"`
    Signature string `json:"signature"`
}

// NewSecureSession 创建安全会话
func NewSecureSession(session *Session, req *http.Request, config *SecurityConfig) *SecureSession {
    secureSession := &SecureSession{
        Session:   session,
        IPAddress: getClientIP(req),
        UserAgent: req.UserAgent(),
    }

    if config.SecretKey != "" {
        secureSession.Signature = secureSession.generateSignature(config.SecretKey)
    }

    return secureSession
}

// Validate 验证会话安全性
func (s *SecureSession) Validate(req *http.Request, config *SecurityConfig) bool {
    if config.CheckIP && s.IPAddress != getClientIP(req) {
        return false
    }

    if config.CheckUserAgent && s.UserAgent != req.UserAgent() {
        return false
    }

    if config.SecretKey != "" {
        expectedSignature := s.generateSignature(config.SecretKey)
        if s.Signature != expectedSignature {
            return false
        }
    }

    return true
}

// generateSignature 生成签名
func (s *SecureSession) generateSignature(secretKey string) string {
    data := s.ID + s.IPAddress + s.UserAgent
    h := hmac.New(sha256.New, []byte(secretKey))
    h.Write([]byte(data))
    return hex.EncodeToString(h.Sum(nil))
}

// getClientIP 获取客户端IP
func getClientIP(req *http.Request) string {
    // 检查X-Forwarded-For头
    if xff := req.Header.Get("X-Forwarded-For"); xff != "" {
        return xff
    }

    // 检查X-Real-IP头
    if xri := req.Header.Get("X-Real-IP"); xri != "" {
        return xri
    }

    // 使用RemoteAddr
    ip, _, _ := net.SplitHostPort(req.RemoteAddr)
    return ip
}

Session 监控和统计 #

// pkg/session/monitor.go
package session

import (
    "sync"
    "time"
)

// SessionStats 会话统计
type SessionStats struct {
    TotalSessions   int64     `json:"total_sessions"`
    ActiveSessions  int64     `json:"active_sessions"`
    ExpiredSessions int64     `json:"expired_sessions"`
    LastCleanup     time.Time `json:"last_cleanup"`
    mutex           sync.RWMutex
}

// Monitor 会话监控器
type Monitor struct {
    stats *SessionStats
    store Store
}

// NewMonitor 创建监控器
func NewMonitor(store Store) *Monitor {
    return &Monitor{
        stats: &SessionStats{},
        store: store,
    }
}

// GetStats 获取统计信息
func (m *Monitor) GetStats() *SessionStats {
    m.stats.mutex.RLock()
    defer m.stats.mutex.RUnlock()

    // 复制统计信息
    stats := *m.stats
    return &stats
}

// IncrementTotal 增加总会话数
func (m *Monitor) IncrementTotal() {
    m.stats.mutex.Lock()
    defer m.stats.mutex.Unlock()

    m.stats.TotalSessions++
    m.stats.ActiveSessions++
}

// DecrementActive 减少活跃会话数
func (m *Monitor) DecrementActive() {
    m.stats.mutex.Lock()
    defer m.stats.mutex.Unlock()

    m.stats.ActiveSessions--
}

// IncrementExpired 增加过期会话数
func (m *Monitor) IncrementExpired() {
    m.stats.mutex.Lock()
    defer m.stats.mutex.Unlock()

    m.stats.ExpiredSessions++
}

// UpdateCleanupTime 更新清理时间
func (m *Monitor) UpdateCleanupTime() {
    m.stats.mutex.Lock()
    defer m.stats.mutex.Unlock()

    m.stats.LastCleanup = time.Now()
}

完整示例 #

// main.go
package main

import (
    "log"
    "time"

    "github.com/gin-gonic/gin"
    "github.com/go-redis/redis/v8"
    "gorm.io/driver/sqlite"
    "gorm.io/gorm"

    "your-project/controllers"
    "your-project/middleware"
    "your-project/models"
    "your-project/pkg/session"
)

func main() {
    // 初始化数据库
    db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{})
    if err != nil {
        log.Fatal("数据库连接失败:", err)
    }

    // 初始化Redis客户端
    rdb := redis.NewClient(&redis.Options{
        Addr: "localhost:6379",
    })

    // 选择Session存储方式
    // 1. 内存存储(开发环境)
    // store := session.NewMemoryStore()

    // 2. Redis存储(推荐)
    store := session.NewRedisStore(rdb, "session:")

    // 3. 数据库存储
    // store := session.NewDBStore(db)

    // 初始化Session管理器
    sessionManager := session.NewManager(store, 24*time.Hour)

    // 初始化服务
    userService := models.NewUserService(db)

    // 初始化控制器
    authController := controllers.NewSessionAuthController(userService, sessionManager)

    // 初始化路由
    r := gin.Default()

    // 使用Session中间件
    r.Use(middleware.SessionMiddleware(sessionManager))

    // 公开路由
    public := r.Group("/api/v1")
    {
        public.POST("/login", authController.Login)
        public.POST("/logout", authController.Logout)
    }

    // 需要登录的路由
    protected := r.Group("/api/v1")
    protected.Use(middleware.RequireSession())
    {
        protected.GET("/profile", authController.GetProfile)
        protected.PUT("/profile", authController.UpdateProfile)

        // 管理员路由
        admin := protected.Group("/admin")
        admin.Use(func(c *gin.Context) {
            sess, _ := middleware.GetSession(c)
            role, _ := sess.Get("role")
            if role != "admin" {
                c.JSON(403, gin.H{"error": "权限不足"})
                c.Abort()
                return
            }
            c.Next()
        })
        {
            admin.GET("/users", func(c *gin.Context) {
                c.JSON(200, gin.H{"message": "用户列表"})
            })
        }
    }

    log.Println("服务器启动在 :8080")
    r.Run(":8080")
}

Session 最佳实践 #

1. 安全考虑 #

  • HTTPS 传输:在生产环境中必须使用 HTTPS
  • HttpOnly Cookie:防止 XSS 攻击
  • Secure Cookie:仅在 HTTPS 下传输
  • SameSite 属性:防止 CSRF 攻击

2. 性能优化 #

  • 选择合适的存储:Redis 适合分布式环境
  • 设置合理的过期时间:平衡安全性和用户体验
  • 定期清理过期 Session:避免存储空间浪费

3. 扩展性考虑 #

  • 分布式存储:使用 Redis 或数据库存储
  • Session 复制:在集群环境中同步 Session
  • 负载均衡:使用粘性会话或共享存储

Session 管理为 Web 应用提供了传统而可靠的状态管理方案,通过合理的设计和实现,可以构建出安全、高效的会话管理系统。