2.4.4 Context 上下文管理

2.4.4 Context 上下文管理 #

Context 是 Go 语言中用于处理请求范围数据、取消信号和超时的标准机制。它在并发编程中扮演着至关重要的角色,特别是在需要协调多个 goroutine、处理超时和取消操作的场景中。

Context 基础概念 #

Context 接口定义了四个方法:

type Context interface {
    // Deadline 返回context的截止时间
    Deadline() (deadline time.Time, ok bool)

    // Done 返回一个channel,当context被取消时会关闭
    Done() <-chan struct{}

    // Err 返回context被取消的原因
    Err() error

    // Value 返回与key关联的值
    Value(key interface{}) interface{}
}

Context 的创建和使用 #

基础 Context 类型 #

package main

import (
    "context"
    "fmt"
    "time"
)

func main() {
    // 1. Background Context - 通常用作根context
    bgCtx := context.Background()
    fmt.Printf("Background context: %v\n", bgCtx)

    // 2. TODO Context - 当不确定使用哪种context时的占位符
    todoCtx := context.TODO()
    fmt.Printf("TODO context: %v\n", todoCtx)

    // 3. WithCancel - 可取消的context
    cancelCtx, cancel := context.WithCancel(bgCtx)
    defer cancel() // 确保资源清理

    go func() {
        select {
        case <-cancelCtx.Done():
            fmt.Println("Context被取消:", cancelCtx.Err())
        }
    }()

    time.Sleep(time.Millisecond * 100)
    cancel() // 取消context
    time.Sleep(time.Millisecond * 100)

    // 4. WithTimeout - 带超时的context
    timeoutCtx, timeoutCancel := context.WithTimeout(bgCtx, time.Second*2)
    defer timeoutCancel()

    go func() {
        select {
        case <-timeoutCtx.Done():
            fmt.Println("Context超时:", timeoutCtx.Err())
        }
    }()

    time.Sleep(time.Second * 3) // 等待超时

    // 5. WithDeadline - 带截止时间的context
    deadline := time.Now().Add(time.Second * 1)
    deadlineCtx, deadlineCancel := context.WithDeadline(bgCtx, deadline)
    defer deadlineCancel()

    go func() {
        select {
        case <-deadlineCtx.Done():
            fmt.Println("Context到达截止时间:", deadlineCtx.Err())
        }
    }()

    time.Sleep(time.Second * 2)
}

Context 传递值 #

package main

import (
    "context"
    "fmt"
)

// 定义context key的类型,避免冲突
type contextKey string

const (
    UserIDKey    contextKey = "userID"
    RequestIDKey contextKey = "requestID"
    TraceIDKey   contextKey = "traceID"
)

// 用户信息结构
type UserInfo struct {
    ID   int
    Name string
    Role string
}

func main() {
    // 创建带值的context
    ctx := context.Background()

    // 添加用户ID
    ctx = context.WithValue(ctx, UserIDKey, 12345)

    // 添加请求ID
    ctx = context.WithValue(ctx, RequestIDKey, "req-abc-123")

    // 添加用户信息
    userInfo := &UserInfo{
        ID:   12345,
        Name: "张三",
        Role: "admin",
    }
    ctx = context.WithValue(ctx, "user", userInfo)

    // 传递context到其他函数
    processRequest(ctx)
}

func processRequest(ctx context.Context) {
    // 从context中获取值
    userID := ctx.Value(UserIDKey)
    requestID := ctx.Value(RequestIDKey)
    user := ctx.Value("user")

    fmt.Printf("处理请求 - UserID: %v, RequestID: %v\n", userID, requestID)

    if userInfo, ok := user.(*UserInfo); ok {
        fmt.Printf("用户信息 - ID: %d, Name: %s, Role: %s\n",
            userInfo.ID, userInfo.Name, userInfo.Role)
    }

    // 继续传递context
    handleDatabase(ctx)
    handleCache(ctx)
}

func handleDatabase(ctx context.Context) {
    userID := ctx.Value(UserIDKey)
    fmt.Printf("数据库操作 - UserID: %v\n", userID)
}

func handleCache(ctx context.Context) {
    requestID := ctx.Value(RequestIDKey)
    fmt.Printf("缓存操作 - RequestID: %v\n", requestID)
}

Context 在并发控制中的应用 #

超时控制 #

package main

import (
    "context"
    "fmt"
    "math/rand"
    "sync"
    "time"
)

// 模拟数据库查询
func queryDatabase(ctx context.Context, query string) (string, error) {
    // 模拟查询时间
    queryTime := time.Duration(rand.Intn(3000)) * time.Millisecond

    select {
    case <-time.After(queryTime):
        return fmt.Sprintf("查询结果: %s", query), nil
    case <-ctx.Done():
        return "", ctx.Err()
    }
}

// 模拟API调用
func callExternalAPI(ctx context.Context, url string) (string, error) {
    // 模拟API调用时间
    callTime := time.Duration(rand.Intn(2000)) * time.Millisecond

    select {
    case <-time.After(callTime):
        return fmt.Sprintf("API响应: %s", url), nil
    case <-ctx.Done():
        return "", ctx.Err()
    }
}

// 并发处理多个任务,带超时控制
func processWithTimeout(timeout time.Duration) {
    ctx, cancel := context.WithTimeout(context.Background(), timeout)
    defer cancel()

    var wg sync.WaitGroup
    results := make(chan string, 3)
    errors := make(chan error, 3)

    // 任务1:数据库查询
    wg.Add(1)
    go func() {
        defer wg.Done()
        result, err := queryDatabase(ctx, "SELECT * FROM users")
        if err != nil {
            errors <- fmt.Errorf("数据库查询失败: %v", err)
            return
        }
        results <- result
    }()

    // 任务2:外部API调用
    wg.Add(1)
    go func() {
        defer wg.Done()
        result, err := callExternalAPI(ctx, "https://api.example.com/data")
        if err != nil {
            errors <- fmt.Errorf("API调用失败: %v", err)
            return
        }
        results <- result
    }()

    // 任务3:另一个数据库查询
    wg.Add(1)
    go func() {
        defer wg.Done()
        result, err := queryDatabase(ctx, "SELECT * FROM orders")
        if err != nil {
            errors <- fmt.Errorf("订单查询失败: %v", err)
            return
        }
        results <- result
    }()

    // 等待所有任务完成
    go func() {
        wg.Wait()
        close(results)
        close(errors)
    }()

    // 收集结果
    fmt.Printf("开始处理,超时时间: %v\n", timeout)

    for {
        select {
        case result, ok := <-results:
            if !ok {
                results = nil
            } else {
                fmt.Printf("✓ %s\n", result)
            }
        case err, ok := <-errors:
            if !ok {
                errors = nil
            } else {
                fmt.Printf("✗ %s\n", err)
            }
        }

        if results == nil && errors == nil {
            break
        }
    }

    fmt.Println("处理完成\n")
}

func main() {
    rand.Seed(time.Now().UnixNano())

    // 测试不同的超时时间
    fmt.Println("=== 测试超时控制 ===")

    // 短超时(可能导致超时)
    processWithTimeout(time.Second * 1)

    // 中等超时
    processWithTimeout(time.Second * 2)

    // 长超时(通常能完成)
    processWithTimeout(time.Second * 5)
}

取消传播 #

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

// Worker 工作者结构
type Worker struct {
    id   int
    name string
}

// 工作者执行任务
func (w *Worker) doWork(ctx context.Context, wg *sync.WaitGroup) {
    defer wg.Done()

    fmt.Printf("Worker %d (%s) 开始工作\n", w.id, w.name)

    for i := 0; i < 10; i++ {
        select {
        case <-ctx.Done():
            fmt.Printf("Worker %d (%s) 收到取消信号: %v\n", w.id, w.name, ctx.Err())
            return
        default:
            fmt.Printf("Worker %d (%s) 执行任务 %d\n", w.id, w.name, i+1)
            time.Sleep(time.Millisecond * 500)
        }
    }

    fmt.Printf("Worker %d (%s) 完成所有任务\n", w.id, w.name)
}

// 管理器协调多个工作者
func manager(ctx context.Context) {
    // 创建子context,可以独立取消
    managerCtx, managerCancel := context.WithCancel(ctx)
    defer managerCancel()

    var wg sync.WaitGroup

    // 创建多个工作者
    workers := []Worker{
        {1, "数据处理器"},
        {2, "文件上传器"},
        {3, "邮件发送器"},
        {4, "日志记录器"},
    }

    // 启动所有工作者
    for _, worker := range workers {
        wg.Add(1)
        go worker.doWork(managerCtx, &wg)
    }

    // 模拟管理器在3秒后决定取消所有任务
    go func() {
        time.Sleep(time.Second * 3)
        fmt.Println("管理器决定取消所有任务")
        managerCancel()
    }()

    // 等待所有工作者完成或被取消
    wg.Wait()
    fmt.Println("管理器:所有工作者已停止")
}

// 监控系统
func monitoringSystem(ctx context.Context) {
    ticker := time.NewTicker(time.Second)
    defer ticker.Stop()

    fmt.Println("监控系统启动")

    for {
        select {
        case <-ticker.C:
            fmt.Println("监控系统:系统运行正常")
        case <-ctx.Done():
            fmt.Printf("监控系统收到停止信号: %v\n", ctx.Err())
            return
        }
    }
}

func main() {
    // 创建根context
    rootCtx, rootCancel := context.WithCancel(context.Background())

    var wg sync.WaitGroup

    // 启动管理器
    wg.Add(1)
    go func() {
        defer wg.Done()
        manager(rootCtx)
    }()

    // 启动监控系统
    wg.Add(1)
    go func() {
        defer wg.Done()
        monitoringSystem(rootCtx)
    }()

    // 模拟系统运行5秒后关闭
    time.Sleep(time.Second * 5)
    fmt.Println("主程序决定关闭系统")
    rootCancel()

    // 等待所有组件停止
    wg.Wait()
    fmt.Println("系统已完全关闭")
}

Context 最佳实践 #

1. Context 传递规范 #

package main

import (
    "context"
    "fmt"
    "net/http"
    "time"
)

// 正确的函数签名:context作为第一个参数
func processRequest(ctx context.Context, userID int, data string) error {
    // 检查context是否已取消
    if err := ctx.Err(); err != nil {
        return err
    }

    // 创建子context用于数据库操作
    dbCtx, cancel := context.WithTimeout(ctx, time.Second*5)
    defer cancel()

    return queryUserData(dbCtx, userID, data)
}

func queryUserData(ctx context.Context, userID int, data string) error {
    // 模拟数据库查询
    select {
    case <-time.After(time.Second * 2):
        fmt.Printf("查询用户 %d 的数据: %s\n", userID, data)
        return nil
    case <-ctx.Done():
        return ctx.Err()
    }
}

// HTTP处理器中使用context
func httpHandler(w http.ResponseWriter, r *http.Request) {
    // 从请求中获取context
    ctx := r.Context()

    // 添加请求特定的值
    ctx = context.WithValue(ctx, "requestID", generateRequestID())

    // 设置请求超时
    ctx, cancel := context.WithTimeout(ctx, time.Second*10)
    defer cancel()

    // 处理请求
    if err := processRequest(ctx, 123, "user data"); err != nil {
        http.Error(w, err.Error(), http.StatusInternalServerError)
        return
    }

    w.WriteHeader(http.StatusOK)
    w.Write([]byte("请求处理成功"))
}

func generateRequestID() string {
    return fmt.Sprintf("req-%d", time.Now().UnixNano())
}

// 错误示例:不要这样做
type BadService struct {
    ctx context.Context // 不要在结构体中存储context
}

// 正确示例:context作为方法参数
type GoodService struct {
    config Config
}

type Config struct {
    Timeout time.Duration
}

func (s *GoodService) ProcessData(ctx context.Context, data string) error {
    // 使用传入的context
    ctx, cancel := context.WithTimeout(ctx, s.config.Timeout)
    defer cancel()

    return s.doProcess(ctx, data)
}

func (s *GoodService) doProcess(ctx context.Context, data string) error {
    // 实际处理逻辑
    select {
    case <-time.After(time.Second):
        fmt.Printf("处理数据: %s\n", data)
        return nil
    case <-ctx.Done():
        return ctx.Err()
    }
}

func main() {
    service := &GoodService{
        config: Config{Timeout: time.Second * 5},
    }

    ctx := context.Background()
    if err := service.ProcessData(ctx, "test data"); err != nil {
        fmt.Printf("处理失败: %v\n", err)
    }
}

2. Context 值的安全使用 #

package main

import (
    "context"
    "fmt"
)

// 定义类型安全的context key
type contextKey int

const (
    userContextKey contextKey = iota
    requestContextKey
    traceContextKey
)

// 用户信息
type User struct {
    ID   int
    Name string
}

// 请求信息
type RequestInfo struct {
    ID        string
    Method    string
    Path      string
    StartTime time.Time
}

// 链路追踪信息
type TraceInfo struct {
    TraceID string
    SpanID  string
}

// Context辅助函数
func WithUser(ctx context.Context, user *User) context.Context {
    return context.WithValue(ctx, userContextKey, user)
}

func GetUser(ctx context.Context) (*User, bool) {
    user, ok := ctx.Value(userContextKey).(*User)
    return user, ok
}

func WithRequestInfo(ctx context.Context, req *RequestInfo) context.Context {
    return context.WithValue(ctx, requestContextKey, req)
}

func GetRequestInfo(ctx context.Context) (*RequestInfo, bool) {
    req, ok := ctx.Value(requestContextKey).(*RequestInfo)
    return req, ok
}

func WithTrace(ctx context.Context, trace *TraceInfo) context.Context {
    return context.WithValue(ctx, traceContextKey, trace)
}

func GetTrace(ctx context.Context) (*TraceInfo, bool) {
    trace, ok := ctx.Value(traceContextKey).(*TraceInfo)
    return trace, ok
}

// 业务逻辑函数
func businessLogic(ctx context.Context) {
    // 安全地获取context中的值
    if user, ok := GetUser(ctx); ok {
        fmt.Printf("当前用户: %s (ID: %d)\n", user.Name, user.ID)
    }

    if req, ok := GetRequestInfo(ctx); ok {
        fmt.Printf("请求信息: %s %s (ID: %s)\n", req.Method, req.Path, req.ID)
    }

    if trace, ok := GetTrace(ctx); ok {
        fmt.Printf("链路追踪: TraceID=%s, SpanID=%s\n", trace.TraceID, trace.SpanID)
    }
}

func main() {
    // 构建context
    ctx := context.Background()

    // 添加用户信息
    user := &User{ID: 123, Name: "张三"}
    ctx = WithUser(ctx, user)

    // 添加请求信息
    reqInfo := &RequestInfo{
        ID:        "req-001",
        Method:    "GET",
        Path:      "/api/users",
        StartTime: time.Now(),
    }
    ctx = WithRequestInfo(ctx, reqInfo)

    // 添加链路追踪信息
    trace := &TraceInfo{
        TraceID: "trace-abc-123",
        SpanID:  "span-def-456",
    }
    ctx = WithTrace(ctx, trace)

    // 调用业务逻辑
    businessLogic(ctx)
}

3. Context 超时和取消的优雅处理 #

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

// 任务结果
type TaskResult struct {
    ID     string
    Result interface{}
    Error  error
}

// 任务管理器
type TaskManager struct {
    tasks   map[string]context.CancelFunc
    results chan TaskResult
    mu      sync.RWMutex
}

func NewTaskManager() *TaskManager {
    return &TaskManager{
        tasks:   make(map[string]context.CancelFunc),
        results: make(chan TaskResult, 100),
    }
}

// 提交任务
func (tm *TaskManager) SubmitTask(taskID string, timeout time.Duration, task func(context.Context) (interface{}, error)) {
    ctx, cancel := context.WithTimeout(context.Background(), timeout)

    tm.mu.Lock()
    tm.tasks[taskID] = cancel
    tm.mu.Unlock()

    go func() {
        defer func() {
            tm.mu.Lock()
            delete(tm.tasks, taskID)
            tm.mu.Unlock()
        }()

        result, err := task(ctx)

        tm.results <- TaskResult{
            ID:     taskID,
            Result: result,
            Error:  err,
        }
    }()
}

// 取消任务
func (tm *TaskManager) CancelTask(taskID string) bool {
    tm.mu.RLock()
    cancel, exists := tm.tasks[taskID]
    tm.mu.RUnlock()

    if exists {
        cancel()
        return true
    }
    return false
}

// 获取结果
func (tm *TaskManager) GetResults() <-chan TaskResult {
    return tm.results
}

// 关闭任务管理器
func (tm *TaskManager) Close() {
    tm.mu.Lock()
    defer tm.mu.Unlock()

    // 取消所有未完成的任务
    for _, cancel := range tm.tasks {
        cancel()
    }

    close(tm.results)
}

// 示例任务函数
func longRunningTask(ctx context.Context, taskName string, duration time.Duration) (interface{}, error) {
    fmt.Printf("任务 %s 开始执行\n", taskName)

    ticker := time.NewTicker(time.Second)
    defer ticker.Stop()

    startTime := time.Now()

    for {
        select {
        case <-ticker.C:
            elapsed := time.Since(startTime)
            fmt.Printf("任务 %s 执行中... (已执行 %.1f 秒)\n", taskName, elapsed.Seconds())

            if elapsed >= duration {
                result := fmt.Sprintf("任务 %s 完成,耗时 %.1f 秒", taskName, elapsed.Seconds())
                fmt.Println(result)
                return result, nil
            }

        case <-ctx.Done():
            elapsed := time.Since(startTime)
            err := fmt.Errorf("任务 %s 被取消,已执行 %.1f 秒,原因: %v",
                taskName, elapsed.Seconds(), ctx.Err())
            fmt.Println(err)
            return nil, err
        }
    }
}

func main() {
    tm := NewTaskManager()
    defer tm.Close()

    // 提交多个任务
    tm.SubmitTask("task1", time.Second*3, func(ctx context.Context) (interface{}, error) {
        return longRunningTask(ctx, "快速任务", time.Second*2)
    })

    tm.SubmitTask("task2", time.Second*5, func(ctx context.Context) (interface{}, error) {
        return longRunningTask(ctx, "中等任务", time.Second*4)
    })

    tm.SubmitTask("task3", time.Second*3, func(ctx context.Context) (interface{}, error) {
        return longRunningTask(ctx, "长时间任务", time.Second*6)
    })

    // 2秒后取消task2
    go func() {
        time.Sleep(time.Second * 2)
        if tm.CancelTask("task2") {
            fmt.Println("手动取消了 task2")
        }
    }()

    // 收集结果
    resultCount := 0
    timeout := time.After(time.Second * 8)

    for {
        select {
        case result := <-tm.GetResults():
            resultCount++
            if result.Error != nil {
                fmt.Printf("任务 %s 失败: %v\n", result.ID, result.Error)
            } else {
                fmt.Printf("任务 %s 成功: %v\n", result.ID, result.Result)
            }

            if resultCount >= 3 {
                fmt.Println("所有任务处理完成")
                return
            }

        case <-timeout:
            fmt.Println("等待结果超时")
            return
        }
    }
}

Context 性能考虑 #

1. 避免过度使用 WithValue #

// 不好的做法:频繁创建context
func badExample(ctx context.Context) {
    for i := 0; i < 1000; i++ {
        newCtx := context.WithValue(ctx, fmt.Sprintf("key%d", i), i)
        // 使用newCtx...
        _ = newCtx
    }
}

// 好的做法:使用结构体传递数据
type RequestData struct {
    UserID    int
    RequestID string
    Values    map[string]interface{}
}

func goodExample(ctx context.Context, data *RequestData) {
    // 直接使用结构体中的数据
    fmt.Printf("UserID: %d, RequestID: %s\n", data.UserID, data.RequestID)
}

2. 合理使用超时 #

// 根据操作类型设置合理的超时时间
func setAppropriateTimeout(ctx context.Context, operationType string) (context.Context, context.CancelFunc) {
    var timeout time.Duration

    switch operationType {
    case "database":
        timeout = time.Second * 5
    case "cache":
        timeout = time.Millisecond * 100
    case "external_api":
        timeout = time.Second * 10
    default:
        timeout = time.Second * 3
    }

    return context.WithTimeout(ctx, timeout)
}

Context 是 Go 并发编程的重要工具,正确使用 Context 可以让程序更加健壮和高效。在实际开发中,要遵循 Context 的最佳实践,合理设置超时时间,正确处理取消信号,并避免在 Context 中存储过多的值。