3.9.3 Session 管理 #
Session(会话)是 Web 应用中维护用户状态的传统方法。与无状态的 JWT 不同,Session 在服务端存储用户状态信息,通过 Session ID 在客户端和服务端之间建立关联。
Session 基础概念 #
Session 工作原理 #
- 用户登录:服务器验证用户凭据后创建 Session
- Session 存储:服务器将 Session 数据存储在内存、数据库或缓存中
- Session ID:服务器生成唯一的 Session ID 并发送给客户端
- 客户端存储:客户端通常将 Session ID 存储在 Cookie 中
- 后续请求:客户端在每次请求中携带 Session ID
- 状态维护:服务器根据 Session ID 查找对应的 Session 数据
Session vs JWT #
特性 | Session | JWT |
---|---|---|
存储位置 | 服务端 | 客户端 |
状态性 | 有状态 | 无状态 |
扩展性 | 需要共享存储 | 天然支持分布式 |
安全性 | 服务端控制 | 依赖签名验证 |
撤销能力 | 容易撤销 | 难以撤销 |
网络开销 | 小 | 大 |
Go 中的 Session 实现 #
基础 Session 管理器 #
// pkg/session/manager.go
package session
import (
"crypto/rand"
"encoding/hex"
"errors"
"sync"
"time"
)
var (
ErrSessionNotFound = errors.New("session不存在")
ErrSessionExpired = errors.New("session已过期")
)
// Session 会话数据
type Session struct {
ID string `json:"id"`
UserID interface{} `json:"user_id"`
Data map[string]interface{} `json:"data"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ExpiresAt time.Time `json:"expires_at"`
}
// NewSession 创建新会话
func NewSession(userID interface{}, duration time.Duration) *Session {
now := time.Now()
return &Session{
ID: generateSessionID(),
UserID: userID,
Data: make(map[string]interface{}),
CreatedAt: now,
UpdatedAt: now,
ExpiresAt: now.Add(duration),
}
}
// IsExpired 检查会话是否过期
func (s *Session) IsExpired() bool {
return time.Now().After(s.ExpiresAt)
}
// Set 设置会话数据
func (s *Session) Set(key string, value interface{}) {
s.Data[key] = value
s.UpdatedAt = time.Now()
}
// Get 获取会话数据
func (s *Session) Get(key string) (interface{}, bool) {
value, exists := s.Data[key]
return value, exists
}
// Delete 删除会话数据
func (s *Session) Delete(key string) {
delete(s.Data, key)
s.UpdatedAt = time.Now()
}
// Extend 延长会话有效期
func (s *Session) Extend(duration time.Duration) {
s.ExpiresAt = time.Now().Add(duration)
s.UpdatedAt = time.Now()
}
// Store Session存储接口
type Store interface {
Create(session *Session) error
Get(sessionID string) (*Session, error)
Update(session *Session) error
Delete(sessionID string) error
Cleanup() error
}
// Manager Session管理器
type Manager struct {
store Store
duration time.Duration
mutex sync.RWMutex
}
// NewManager 创建Session管理器
func NewManager(store Store, duration time.Duration) *Manager {
manager := &Manager{
store: store,
duration: duration,
}
// 启动清理goroutine
go manager.startCleanup()
return manager
}
// CreateSession 创建会话
func (m *Manager) CreateSession(userID interface{}) (*Session, error) {
session := NewSession(userID, m.duration)
if err := m.store.Create(session); err != nil {
return nil, err
}
return session, nil
}
// GetSession 获取会话
func (m *Manager) GetSession(sessionID string) (*Session, error) {
session, err := m.store.Get(sessionID)
if err != nil {
return nil, err
}
if session.IsExpired() {
m.store.Delete(sessionID)
return nil, ErrSessionExpired
}
return session, nil
}
// UpdateSession 更新会话
func (m *Manager) UpdateSession(session *Session) error {
return m.store.Update(session)
}
// DestroySession 销毁会话
func (m *Manager) DestroySession(sessionID string) error {
return m.store.Delete(sessionID)
}
// ExtendSession 延长会话
func (m *Manager) ExtendSession(sessionID string) error {
session, err := m.GetSession(sessionID)
if err != nil {
return err
}
session.Extend(m.duration)
return m.store.Update(session)
}
// startCleanup 启动清理过期会话
func (m *Manager) startCleanup() {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for range ticker.C {
m.store.Cleanup()
}
}
// generateSessionID 生成会话ID
func generateSessionID() string {
bytes := make([]byte, 32)
rand.Read(bytes)
return hex.EncodeToString(bytes)
}
内存存储实现 #
// pkg/session/memory_store.go
package session
import (
"sync"
"time"
)
// MemoryStore 内存存储
type MemoryStore struct {
sessions map[string]*Session
mutex sync.RWMutex
}
// NewMemoryStore 创建内存存储
func NewMemoryStore() *MemoryStore {
return &MemoryStore{
sessions: make(map[string]*Session),
}
}
// Create 创建会话
func (s *MemoryStore) Create(session *Session) error {
s.mutex.Lock()
defer s.mutex.Unlock()
s.sessions[session.ID] = session
return nil
}
// Get 获取会话
func (s *MemoryStore) Get(sessionID string) (*Session, error) {
s.mutex.RLock()
defer s.mutex.RUnlock()
session, exists := s.sessions[sessionID]
if !exists {
return nil, ErrSessionNotFound
}
return session, nil
}
// Update 更新会话
func (s *MemoryStore) Update(session *Session) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if _, exists := s.sessions[session.ID]; !exists {
return ErrSessionNotFound
}
s.sessions[session.ID] = session
return nil
}
// Delete 删除会话
func (s *MemoryStore) Delete(sessionID string) error {
s.mutex.Lock()
defer s.mutex.Unlock()
delete(s.sessions, sessionID)
return nil
}
// Cleanup 清理过期会话
func (s *MemoryStore) Cleanup() error {
s.mutex.Lock()
defer s.mutex.Unlock()
now := time.Now()
for id, session := range s.sessions {
if now.After(session.ExpiresAt) {
delete(s.sessions, id)
}
}
return nil
}
Redis 存储实现 #
// pkg/session/redis_store.go
package session
import (
"encoding/json"
"time"
"github.com/go-redis/redis/v8"
"golang.org/x/net/context"
)
// RedisStore Redis存储
type RedisStore struct {
client *redis.Client
prefix string
}
// NewRedisStore 创建Redis存储
func NewRedisStore(client *redis.Client, prefix string) *RedisStore {
return &RedisStore{
client: client,
prefix: prefix,
}
}
// getKey 获取Redis键
func (s *RedisStore) getKey(sessionID string) string {
return s.prefix + sessionID
}
// Create 创建会话
func (s *RedisStore) Create(session *Session) error {
ctx := context.Background()
key := s.getKey(session.ID)
data, err := json.Marshal(session)
if err != nil {
return err
}
duration := time.Until(session.ExpiresAt)
return s.client.Set(ctx, key, data, duration).Err()
}
// Get 获取会话
func (s *RedisStore) Get(sessionID string) (*Session, error) {
ctx := context.Background()
key := s.getKey(sessionID)
data, err := s.client.Get(ctx, key).Result()
if err != nil {
if err == redis.Nil {
return nil, ErrSessionNotFound
}
return nil, err
}
var session Session
if err := json.Unmarshal([]byte(data), &session); err != nil {
return nil, err
}
return &session, nil
}
// Update 更新会话
func (s *RedisStore) Update(session *Session) error {
ctx := context.Background()
key := s.getKey(session.ID)
data, err := json.Marshal(session)
if err != nil {
return err
}
duration := time.Until(session.ExpiresAt)
return s.client.Set(ctx, key, data, duration).Err()
}
// Delete 删除会话
func (s *RedisStore) Delete(sessionID string) error {
ctx := context.Background()
key := s.getKey(sessionID)
return s.client.Del(ctx, key).Err()
}
// Cleanup Redis会自动清理过期键
func (s *RedisStore) Cleanup() error {
return nil
}
数据库存储实现 #
// pkg/session/db_store.go
package session
import (
"encoding/json"
"time"
"gorm.io/gorm"
)
// DBSession 数据库会话模型
type DBSession struct {
ID string `gorm:"primaryKey;size:64"`
UserID string `gorm:"index;size:64"`
Data string `gorm:"type:text"`
CreatedAt time.Time
UpdatedAt time.Time
ExpiresAt time.Time `gorm:"index"`
}
// DBStore 数据库存储
type DBStore struct {
db *gorm.DB
}
// NewDBStore 创建数据库存储
func NewDBStore(db *gorm.DB) *DBStore {
// 自动迁移
db.AutoMigrate(&DBSession{})
return &DBStore{db: db}
}
// Create 创建会话
func (s *DBStore) Create(session *Session) error {
data, err := json.Marshal(session.Data)
if err != nil {
return err
}
dbSession := &DBSession{
ID: session.ID,
UserID: session.UserID.(string),
Data: string(data),
CreatedAt: session.CreatedAt,
UpdatedAt: session.UpdatedAt,
ExpiresAt: session.ExpiresAt,
}
return s.db.Create(dbSession).Error
}
// Get 获取会话
func (s *DBStore) Get(sessionID string) (*Session, error) {
var dbSession DBSession
if err := s.db.First(&dbSession, "id = ?", sessionID).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, ErrSessionNotFound
}
return nil, err
}
var data map[string]interface{}
if err := json.Unmarshal([]byte(dbSession.Data), &data); err != nil {
return nil, err
}
session := &Session{
ID: dbSession.ID,
UserID: dbSession.UserID,
Data: data,
CreatedAt: dbSession.CreatedAt,
UpdatedAt: dbSession.UpdatedAt,
ExpiresAt: dbSession.ExpiresAt,
}
return session, nil
}
// Update 更新会话
func (s *DBStore) Update(session *Session) error {
data, err := json.Marshal(session.Data)
if err != nil {
return err
}
return s.db.Model(&DBSession{}).
Where("id = ?", session.ID).
Updates(map[string]interface{}{
"data": string(data),
"updated_at": session.UpdatedAt,
"expires_at": session.ExpiresAt,
}).Error
}
// Delete 删除会话
func (s *DBStore) Delete(sessionID string) error {
return s.db.Delete(&DBSession{}, "id = ?", sessionID).Error
}
// Cleanup 清理过期会话
func (s *DBStore) Cleanup() error {
return s.db.Delete(&DBSession{}, "expires_at < ?", time.Now()).Error
}
Session 中间件 #
// middleware/session.go
package middleware
import (
"net/http"
"github.com/gin-gonic/gin"
"your-project/pkg/session"
)
const SessionKey = "session"
// SessionMiddleware Session中间件
func SessionMiddleware(manager *session.Manager) gin.HandlerFunc {
return func(c *gin.Context) {
// 从Cookie中获取Session ID
sessionID, err := c.Cookie("session_id")
if err != nil {
// 没有Session ID,继续处理
c.Next()
return
}
// 获取Session
sess, err := manager.GetSession(sessionID)
if err != nil {
// Session无效,清除Cookie
c.SetCookie("session_id", "", -1, "/", "", false, true)
c.Next()
return
}
// 将Session存储到上下文
c.Set(SessionKey, sess)
c.Set("user_id", sess.UserID)
c.Next()
// 请求处理完成后,更新Session
if updatedSess, exists := c.Get(SessionKey); exists {
manager.UpdateSession(updatedSess.(*session.Session))
}
}
}
// RequireSession 要求Session的中间件
func RequireSession() gin.HandlerFunc {
return func(c *gin.Context) {
_, exists := c.Get(SessionKey)
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "需要登录",
})
c.Abort()
return
}
c.Next()
}
}
// GetSession 从上下文获取Session
func GetSession(c *gin.Context) (*session.Session, bool) {
sess, exists := c.Get(SessionKey)
if !exists {
return nil, false
}
return sess.(*session.Session), true
}
认证控制器 #
// controllers/session_auth.go
package controllers
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"your-project/middleware"
"your-project/models"
"your-project/pkg/session"
)
type SessionAuthController struct {
userService *models.UserService
sessionManager *session.Manager
}
func NewSessionAuthController(userService *models.UserService, sessionManager *session.Manager) *SessionAuthController {
return &SessionAuthController{
userService: userService,
sessionManager: sessionManager,
}
}
// LoginRequest 登录请求
type LoginRequest struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
RememberMe bool `json:"remember_me"`
}
// Login 用户登录
func (ctrl *SessionAuthController) Login(c *gin.Context) {
var req LoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": err.Error(),
})
return
}
// 验证用户
user, err := ctrl.userService.GetUserByUsername(req.Username)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "用户名或密码错误",
})
return
}
if !user.CheckPassword(req.Password) {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "用户名或密码错误",
})
return
}
// 创建Session
sess, err := ctrl.sessionManager.CreateSession(user.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "创建会话失败",
})
return
}
// 存储用户信息到Session
sess.Set("username", user.Username)
sess.Set("role", user.Role)
sess.Set("login_time", time.Now())
// 设置Cookie
maxAge := int(24 * time.Hour / time.Second) // 24小时
if req.RememberMe {
maxAge = int(30 * 24 * time.Hour / time.Second) // 30天
sess.Extend(30 * 24 * time.Hour)
}
c.SetCookie("session_id", sess.ID, maxAge, "/", "", false, true)
// 更新Session
ctrl.sessionManager.UpdateSession(sess)
c.JSON(http.StatusOK, gin.H{
"message": "登录成功",
"user": user,
})
}
// Logout 用户登出
func (ctrl *SessionAuthController) Logout(c *gin.Context) {
sess, exists := middleware.GetSession(c)
if !exists {
c.JSON(http.StatusOK, gin.H{
"message": "已登出",
})
return
}
// 销毁Session
ctrl.sessionManager.DestroySession(sess.ID)
// 清除Cookie
c.SetCookie("session_id", "", -1, "/", "", false, true)
c.JSON(http.StatusOK, gin.H{
"message": "登出成功",
})
}
// GetProfile 获取用户信息
func (ctrl *SessionAuthController) GetProfile(c *gin.Context) {
sess, exists := middleware.GetSession(c)
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "未登录",
})
return
}
userID := sess.UserID.(uint)
user, err := ctrl.userService.GetUserByID(userID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "用户不存在",
})
return
}
// 获取Session中的额外信息
username, _ := sess.Get("username")
role, _ := sess.Get("role")
loginTime, _ := sess.Get("login_time")
c.JSON(http.StatusOK, gin.H{
"user": user,
"session_info": gin.H{
"username": username,
"role": role,
"login_time": loginTime,
"session_id": sess.ID,
},
})
}
// UpdateProfile 更新用户信息
func (ctrl *SessionAuthController) UpdateProfile(c *gin.Context) {
sess, exists := middleware.GetSession(c)
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "未登录",
})
return
}
var req struct {
Email string `json:"email" binding:"required,email"`
Name string `json:"name" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": err.Error(),
})
return
}
// 更新用户信息
userID := sess.UserID.(uint)
user, err := ctrl.userService.GetUserByID(userID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "用户不存在",
})
return
}
user.Email = req.Email
// user.Name = req.Name // 假设User模型有Name字段
if err := ctrl.userService.UpdateUser(user); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "更新失败",
})
return
}
// 更新Session中的信息
sess.Set("last_update", time.Now())
c.JSON(http.StatusOK, gin.H{
"message": "更新成功",
"user": user,
})
}
高级 Session 功能 #
Session 安全增强 #
// pkg/session/security.go
package session
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"net"
"net/http"
)
// SecurityConfig 安全配置
type SecurityConfig struct {
CheckIP bool // 检查IP地址
CheckUserAgent bool // 检查User-Agent
SecretKey string // 签名密钥
}
// SecureSession 安全会话
type SecureSession struct {
*Session
IPAddress string `json:"ip_address"`
UserAgent string `json:"user_agent"`
Signature string `json:"signature"`
}
// NewSecureSession 创建安全会话
func NewSecureSession(session *Session, req *http.Request, config *SecurityConfig) *SecureSession {
secureSession := &SecureSession{
Session: session,
IPAddress: getClientIP(req),
UserAgent: req.UserAgent(),
}
if config.SecretKey != "" {
secureSession.Signature = secureSession.generateSignature(config.SecretKey)
}
return secureSession
}
// Validate 验证会话安全性
func (s *SecureSession) Validate(req *http.Request, config *SecurityConfig) bool {
if config.CheckIP && s.IPAddress != getClientIP(req) {
return false
}
if config.CheckUserAgent && s.UserAgent != req.UserAgent() {
return false
}
if config.SecretKey != "" {
expectedSignature := s.generateSignature(config.SecretKey)
if s.Signature != expectedSignature {
return false
}
}
return true
}
// generateSignature 生成签名
func (s *SecureSession) generateSignature(secretKey string) string {
data := s.ID + s.IPAddress + s.UserAgent
h := hmac.New(sha256.New, []byte(secretKey))
h.Write([]byte(data))
return hex.EncodeToString(h.Sum(nil))
}
// getClientIP 获取客户端IP
func getClientIP(req *http.Request) string {
// 检查X-Forwarded-For头
if xff := req.Header.Get("X-Forwarded-For"); xff != "" {
return xff
}
// 检查X-Real-IP头
if xri := req.Header.Get("X-Real-IP"); xri != "" {
return xri
}
// 使用RemoteAddr
ip, _, _ := net.SplitHostPort(req.RemoteAddr)
return ip
}
Session 监控和统计 #
// pkg/session/monitor.go
package session
import (
"sync"
"time"
)
// SessionStats 会话统计
type SessionStats struct {
TotalSessions int64 `json:"total_sessions"`
ActiveSessions int64 `json:"active_sessions"`
ExpiredSessions int64 `json:"expired_sessions"`
LastCleanup time.Time `json:"last_cleanup"`
mutex sync.RWMutex
}
// Monitor 会话监控器
type Monitor struct {
stats *SessionStats
store Store
}
// NewMonitor 创建监控器
func NewMonitor(store Store) *Monitor {
return &Monitor{
stats: &SessionStats{},
store: store,
}
}
// GetStats 获取统计信息
func (m *Monitor) GetStats() *SessionStats {
m.stats.mutex.RLock()
defer m.stats.mutex.RUnlock()
// 复制统计信息
stats := *m.stats
return &stats
}
// IncrementTotal 增加总会话数
func (m *Monitor) IncrementTotal() {
m.stats.mutex.Lock()
defer m.stats.mutex.Unlock()
m.stats.TotalSessions++
m.stats.ActiveSessions++
}
// DecrementActive 减少活跃会话数
func (m *Monitor) DecrementActive() {
m.stats.mutex.Lock()
defer m.stats.mutex.Unlock()
m.stats.ActiveSessions--
}
// IncrementExpired 增加过期会话数
func (m *Monitor) IncrementExpired() {
m.stats.mutex.Lock()
defer m.stats.mutex.Unlock()
m.stats.ExpiredSessions++
}
// UpdateCleanupTime 更新清理时间
func (m *Monitor) UpdateCleanupTime() {
m.stats.mutex.Lock()
defer m.stats.mutex.Unlock()
m.stats.LastCleanup = time.Now()
}
完整示例 #
// main.go
package main
import (
"log"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"your-project/controllers"
"your-project/middleware"
"your-project/models"
"your-project/pkg/session"
)
func main() {
// 初始化数据库
db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{})
if err != nil {
log.Fatal("数据库连接失败:", err)
}
// 初始化Redis客户端
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
})
// 选择Session存储方式
// 1. 内存存储(开发环境)
// store := session.NewMemoryStore()
// 2. Redis存储(推荐)
store := session.NewRedisStore(rdb, "session:")
// 3. 数据库存储
// store := session.NewDBStore(db)
// 初始化Session管理器
sessionManager := session.NewManager(store, 24*time.Hour)
// 初始化服务
userService := models.NewUserService(db)
// 初始化控制器
authController := controllers.NewSessionAuthController(userService, sessionManager)
// 初始化路由
r := gin.Default()
// 使用Session中间件
r.Use(middleware.SessionMiddleware(sessionManager))
// 公开路由
public := r.Group("/api/v1")
{
public.POST("/login", authController.Login)
public.POST("/logout", authController.Logout)
}
// 需要登录的路由
protected := r.Group("/api/v1")
protected.Use(middleware.RequireSession())
{
protected.GET("/profile", authController.GetProfile)
protected.PUT("/profile", authController.UpdateProfile)
// 管理员路由
admin := protected.Group("/admin")
admin.Use(func(c *gin.Context) {
sess, _ := middleware.GetSession(c)
role, _ := sess.Get("role")
if role != "admin" {
c.JSON(403, gin.H{"error": "权限不足"})
c.Abort()
return
}
c.Next()
})
{
admin.GET("/users", func(c *gin.Context) {
c.JSON(200, gin.H{"message": "用户列表"})
})
}
}
log.Println("服务器启动在 :8080")
r.Run(":8080")
}
Session 最佳实践 #
1. 安全考虑 #
- HTTPS 传输:在生产环境中必须使用 HTTPS
- HttpOnly Cookie:防止 XSS 攻击
- Secure Cookie:仅在 HTTPS 下传输
- SameSite 属性:防止 CSRF 攻击
2. 性能优化 #
- 选择合适的存储:Redis 适合分布式环境
- 设置合理的过期时间:平衡安全性和用户体验
- 定期清理过期 Session:避免存储空间浪费
3. 扩展性考虑 #
- 分布式存储:使用 Redis 或数据库存储
- Session 复制:在集群环境中同步 Session
- 负载均衡:使用粘性会话或共享存储
Session 管理为 Web 应用提供了传统而可靠的状态管理方案,通过合理的设计和实现,可以构建出安全、高效的会话管理系统。