2.4.4 Context 上下文管理 #
Context 是 Go 语言中用于处理请求范围数据、取消信号和超时的标准机制。它在并发编程中扮演着至关重要的角色,特别是在需要协调多个 goroutine、处理超时和取消操作的场景中。
Context 基础概念 #
Context 接口定义了四个方法:
type Context interface {
// Deadline 返回context的截止时间
Deadline() (deadline time.Time, ok bool)
// Done 返回一个channel,当context被取消时会关闭
Done() <-chan struct{}
// Err 返回context被取消的原因
Err() error
// Value 返回与key关联的值
Value(key interface{}) interface{}
}
Context 的创建和使用 #
基础 Context 类型 #
package main
import (
"context"
"fmt"
"time"
)
func main() {
// 1. Background Context - 通常用作根context
bgCtx := context.Background()
fmt.Printf("Background context: %v\n", bgCtx)
// 2. TODO Context - 当不确定使用哪种context时的占位符
todoCtx := context.TODO()
fmt.Printf("TODO context: %v\n", todoCtx)
// 3. WithCancel - 可取消的context
cancelCtx, cancel := context.WithCancel(bgCtx)
defer cancel() // 确保资源清理
go func() {
select {
case <-cancelCtx.Done():
fmt.Println("Context被取消:", cancelCtx.Err())
}
}()
time.Sleep(time.Millisecond * 100)
cancel() // 取消context
time.Sleep(time.Millisecond * 100)
// 4. WithTimeout - 带超时的context
timeoutCtx, timeoutCancel := context.WithTimeout(bgCtx, time.Second*2)
defer timeoutCancel()
go func() {
select {
case <-timeoutCtx.Done():
fmt.Println("Context超时:", timeoutCtx.Err())
}
}()
time.Sleep(time.Second * 3) // 等待超时
// 5. WithDeadline - 带截止时间的context
deadline := time.Now().Add(time.Second * 1)
deadlineCtx, deadlineCancel := context.WithDeadline(bgCtx, deadline)
defer deadlineCancel()
go func() {
select {
case <-deadlineCtx.Done():
fmt.Println("Context到达截止时间:", deadlineCtx.Err())
}
}()
time.Sleep(time.Second * 2)
}
Context 传递值 #
package main
import (
"context"
"fmt"
)
// 定义context key的类型,避免冲突
type contextKey string
const (
UserIDKey contextKey = "userID"
RequestIDKey contextKey = "requestID"
TraceIDKey contextKey = "traceID"
)
// 用户信息结构
type UserInfo struct {
ID int
Name string
Role string
}
func main() {
// 创建带值的context
ctx := context.Background()
// 添加用户ID
ctx = context.WithValue(ctx, UserIDKey, 12345)
// 添加请求ID
ctx = context.WithValue(ctx, RequestIDKey, "req-abc-123")
// 添加用户信息
userInfo := &UserInfo{
ID: 12345,
Name: "张三",
Role: "admin",
}
ctx = context.WithValue(ctx, "user", userInfo)
// 传递context到其他函数
processRequest(ctx)
}
func processRequest(ctx context.Context) {
// 从context中获取值
userID := ctx.Value(UserIDKey)
requestID := ctx.Value(RequestIDKey)
user := ctx.Value("user")
fmt.Printf("处理请求 - UserID: %v, RequestID: %v\n", userID, requestID)
if userInfo, ok := user.(*UserInfo); ok {
fmt.Printf("用户信息 - ID: %d, Name: %s, Role: %s\n",
userInfo.ID, userInfo.Name, userInfo.Role)
}
// 继续传递context
handleDatabase(ctx)
handleCache(ctx)
}
func handleDatabase(ctx context.Context) {
userID := ctx.Value(UserIDKey)
fmt.Printf("数据库操作 - UserID: %v\n", userID)
}
func handleCache(ctx context.Context) {
requestID := ctx.Value(RequestIDKey)
fmt.Printf("缓存操作 - RequestID: %v\n", requestID)
}
Context 在并发控制中的应用 #
超时控制 #
package main
import (
"context"
"fmt"
"math/rand"
"sync"
"time"
)
// 模拟数据库查询
func queryDatabase(ctx context.Context, query string) (string, error) {
// 模拟查询时间
queryTime := time.Duration(rand.Intn(3000)) * time.Millisecond
select {
case <-time.After(queryTime):
return fmt.Sprintf("查询结果: %s", query), nil
case <-ctx.Done():
return "", ctx.Err()
}
}
// 模拟API调用
func callExternalAPI(ctx context.Context, url string) (string, error) {
// 模拟API调用时间
callTime := time.Duration(rand.Intn(2000)) * time.Millisecond
select {
case <-time.After(callTime):
return fmt.Sprintf("API响应: %s", url), nil
case <-ctx.Done():
return "", ctx.Err()
}
}
// 并发处理多个任务,带超时控制
func processWithTimeout(timeout time.Duration) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
var wg sync.WaitGroup
results := make(chan string, 3)
errors := make(chan error, 3)
// 任务1:数据库查询
wg.Add(1)
go func() {
defer wg.Done()
result, err := queryDatabase(ctx, "SELECT * FROM users")
if err != nil {
errors <- fmt.Errorf("数据库查询失败: %v", err)
return
}
results <- result
}()
// 任务2:外部API调用
wg.Add(1)
go func() {
defer wg.Done()
result, err := callExternalAPI(ctx, "https://api.example.com/data")
if err != nil {
errors <- fmt.Errorf("API调用失败: %v", err)
return
}
results <- result
}()
// 任务3:另一个数据库查询
wg.Add(1)
go func() {
defer wg.Done()
result, err := queryDatabase(ctx, "SELECT * FROM orders")
if err != nil {
errors <- fmt.Errorf("订单查询失败: %v", err)
return
}
results <- result
}()
// 等待所有任务完成
go func() {
wg.Wait()
close(results)
close(errors)
}()
// 收集结果
fmt.Printf("开始处理,超时时间: %v\n", timeout)
for {
select {
case result, ok := <-results:
if !ok {
results = nil
} else {
fmt.Printf("✓ %s\n", result)
}
case err, ok := <-errors:
if !ok {
errors = nil
} else {
fmt.Printf("✗ %s\n", err)
}
}
if results == nil && errors == nil {
break
}
}
fmt.Println("处理完成\n")
}
func main() {
rand.Seed(time.Now().UnixNano())
// 测试不同的超时时间
fmt.Println("=== 测试超时控制 ===")
// 短超时(可能导致超时)
processWithTimeout(time.Second * 1)
// 中等超时
processWithTimeout(time.Second * 2)
// 长超时(通常能完成)
processWithTimeout(time.Second * 5)
}
取消传播 #
package main
import (
"context"
"fmt"
"sync"
"time"
)
// Worker 工作者结构
type Worker struct {
id int
name string
}
// 工作者执行任务
func (w *Worker) doWork(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
fmt.Printf("Worker %d (%s) 开始工作\n", w.id, w.name)
for i := 0; i < 10; i++ {
select {
case <-ctx.Done():
fmt.Printf("Worker %d (%s) 收到取消信号: %v\n", w.id, w.name, ctx.Err())
return
default:
fmt.Printf("Worker %d (%s) 执行任务 %d\n", w.id, w.name, i+1)
time.Sleep(time.Millisecond * 500)
}
}
fmt.Printf("Worker %d (%s) 完成所有任务\n", w.id, w.name)
}
// 管理器协调多个工作者
func manager(ctx context.Context) {
// 创建子context,可以独立取消
managerCtx, managerCancel := context.WithCancel(ctx)
defer managerCancel()
var wg sync.WaitGroup
// 创建多个工作者
workers := []Worker{
{1, "数据处理器"},
{2, "文件上传器"},
{3, "邮件发送器"},
{4, "日志记录器"},
}
// 启动所有工作者
for _, worker := range workers {
wg.Add(1)
go worker.doWork(managerCtx, &wg)
}
// 模拟管理器在3秒后决定取消所有任务
go func() {
time.Sleep(time.Second * 3)
fmt.Println("管理器决定取消所有任务")
managerCancel()
}()
// 等待所有工作者完成或被取消
wg.Wait()
fmt.Println("管理器:所有工作者已停止")
}
// 监控系统
func monitoringSystem(ctx context.Context) {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
fmt.Println("监控系统启动")
for {
select {
case <-ticker.C:
fmt.Println("监控系统:系统运行正常")
case <-ctx.Done():
fmt.Printf("监控系统收到停止信号: %v\n", ctx.Err())
return
}
}
}
func main() {
// 创建根context
rootCtx, rootCancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
// 启动管理器
wg.Add(1)
go func() {
defer wg.Done()
manager(rootCtx)
}()
// 启动监控系统
wg.Add(1)
go func() {
defer wg.Done()
monitoringSystem(rootCtx)
}()
// 模拟系统运行5秒后关闭
time.Sleep(time.Second * 5)
fmt.Println("主程序决定关闭系统")
rootCancel()
// 等待所有组件停止
wg.Wait()
fmt.Println("系统已完全关闭")
}
Context 最佳实践 #
1. Context 传递规范 #
package main
import (
"context"
"fmt"
"net/http"
"time"
)
// 正确的函数签名:context作为第一个参数
func processRequest(ctx context.Context, userID int, data string) error {
// 检查context是否已取消
if err := ctx.Err(); err != nil {
return err
}
// 创建子context用于数据库操作
dbCtx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
return queryUserData(dbCtx, userID, data)
}
func queryUserData(ctx context.Context, userID int, data string) error {
// 模拟数据库查询
select {
case <-time.After(time.Second * 2):
fmt.Printf("查询用户 %d 的数据: %s\n", userID, data)
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// HTTP处理器中使用context
func httpHandler(w http.ResponseWriter, r *http.Request) {
// 从请求中获取context
ctx := r.Context()
// 添加请求特定的值
ctx = context.WithValue(ctx, "requestID", generateRequestID())
// 设置请求超时
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
// 处理请求
if err := processRequest(ctx, 123, "user data"); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("请求处理成功"))
}
func generateRequestID() string {
return fmt.Sprintf("req-%d", time.Now().UnixNano())
}
// 错误示例:不要这样做
type BadService struct {
ctx context.Context // 不要在结构体中存储context
}
// 正确示例:context作为方法参数
type GoodService struct {
config Config
}
type Config struct {
Timeout time.Duration
}
func (s *GoodService) ProcessData(ctx context.Context, data string) error {
// 使用传入的context
ctx, cancel := context.WithTimeout(ctx, s.config.Timeout)
defer cancel()
return s.doProcess(ctx, data)
}
func (s *GoodService) doProcess(ctx context.Context, data string) error {
// 实际处理逻辑
select {
case <-time.After(time.Second):
fmt.Printf("处理数据: %s\n", data)
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func main() {
service := &GoodService{
config: Config{Timeout: time.Second * 5},
}
ctx := context.Background()
if err := service.ProcessData(ctx, "test data"); err != nil {
fmt.Printf("处理失败: %v\n", err)
}
}
2. Context 值的安全使用 #
package main
import (
"context"
"fmt"
)
// 定义类型安全的context key
type contextKey int
const (
userContextKey contextKey = iota
requestContextKey
traceContextKey
)
// 用户信息
type User struct {
ID int
Name string
}
// 请求信息
type RequestInfo struct {
ID string
Method string
Path string
StartTime time.Time
}
// 链路追踪信息
type TraceInfo struct {
TraceID string
SpanID string
}
// Context辅助函数
func WithUser(ctx context.Context, user *User) context.Context {
return context.WithValue(ctx, userContextKey, user)
}
func GetUser(ctx context.Context) (*User, bool) {
user, ok := ctx.Value(userContextKey).(*User)
return user, ok
}
func WithRequestInfo(ctx context.Context, req *RequestInfo) context.Context {
return context.WithValue(ctx, requestContextKey, req)
}
func GetRequestInfo(ctx context.Context) (*RequestInfo, bool) {
req, ok := ctx.Value(requestContextKey).(*RequestInfo)
return req, ok
}
func WithTrace(ctx context.Context, trace *TraceInfo) context.Context {
return context.WithValue(ctx, traceContextKey, trace)
}
func GetTrace(ctx context.Context) (*TraceInfo, bool) {
trace, ok := ctx.Value(traceContextKey).(*TraceInfo)
return trace, ok
}
// 业务逻辑函数
func businessLogic(ctx context.Context) {
// 安全地获取context中的值
if user, ok := GetUser(ctx); ok {
fmt.Printf("当前用户: %s (ID: %d)\n", user.Name, user.ID)
}
if req, ok := GetRequestInfo(ctx); ok {
fmt.Printf("请求信息: %s %s (ID: %s)\n", req.Method, req.Path, req.ID)
}
if trace, ok := GetTrace(ctx); ok {
fmt.Printf("链路追踪: TraceID=%s, SpanID=%s\n", trace.TraceID, trace.SpanID)
}
}
func main() {
// 构建context
ctx := context.Background()
// 添加用户信息
user := &User{ID: 123, Name: "张三"}
ctx = WithUser(ctx, user)
// 添加请求信息
reqInfo := &RequestInfo{
ID: "req-001",
Method: "GET",
Path: "/api/users",
StartTime: time.Now(),
}
ctx = WithRequestInfo(ctx, reqInfo)
// 添加链路追踪信息
trace := &TraceInfo{
TraceID: "trace-abc-123",
SpanID: "span-def-456",
}
ctx = WithTrace(ctx, trace)
// 调用业务逻辑
businessLogic(ctx)
}
3. Context 超时和取消的优雅处理 #
package main
import (
"context"
"fmt"
"sync"
"time"
)
// 任务结果
type TaskResult struct {
ID string
Result interface{}
Error error
}
// 任务管理器
type TaskManager struct {
tasks map[string]context.CancelFunc
results chan TaskResult
mu sync.RWMutex
}
func NewTaskManager() *TaskManager {
return &TaskManager{
tasks: make(map[string]context.CancelFunc),
results: make(chan TaskResult, 100),
}
}
// 提交任务
func (tm *TaskManager) SubmitTask(taskID string, timeout time.Duration, task func(context.Context) (interface{}, error)) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
tm.mu.Lock()
tm.tasks[taskID] = cancel
tm.mu.Unlock()
go func() {
defer func() {
tm.mu.Lock()
delete(tm.tasks, taskID)
tm.mu.Unlock()
}()
result, err := task(ctx)
tm.results <- TaskResult{
ID: taskID,
Result: result,
Error: err,
}
}()
}
// 取消任务
func (tm *TaskManager) CancelTask(taskID string) bool {
tm.mu.RLock()
cancel, exists := tm.tasks[taskID]
tm.mu.RUnlock()
if exists {
cancel()
return true
}
return false
}
// 获取结果
func (tm *TaskManager) GetResults() <-chan TaskResult {
return tm.results
}
// 关闭任务管理器
func (tm *TaskManager) Close() {
tm.mu.Lock()
defer tm.mu.Unlock()
// 取消所有未完成的任务
for _, cancel := range tm.tasks {
cancel()
}
close(tm.results)
}
// 示例任务函数
func longRunningTask(ctx context.Context, taskName string, duration time.Duration) (interface{}, error) {
fmt.Printf("任务 %s 开始执行\n", taskName)
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
startTime := time.Now()
for {
select {
case <-ticker.C:
elapsed := time.Since(startTime)
fmt.Printf("任务 %s 执行中... (已执行 %.1f 秒)\n", taskName, elapsed.Seconds())
if elapsed >= duration {
result := fmt.Sprintf("任务 %s 完成,耗时 %.1f 秒", taskName, elapsed.Seconds())
fmt.Println(result)
return result, nil
}
case <-ctx.Done():
elapsed := time.Since(startTime)
err := fmt.Errorf("任务 %s 被取消,已执行 %.1f 秒,原因: %v",
taskName, elapsed.Seconds(), ctx.Err())
fmt.Println(err)
return nil, err
}
}
}
func main() {
tm := NewTaskManager()
defer tm.Close()
// 提交多个任务
tm.SubmitTask("task1", time.Second*3, func(ctx context.Context) (interface{}, error) {
return longRunningTask(ctx, "快速任务", time.Second*2)
})
tm.SubmitTask("task2", time.Second*5, func(ctx context.Context) (interface{}, error) {
return longRunningTask(ctx, "中等任务", time.Second*4)
})
tm.SubmitTask("task3", time.Second*3, func(ctx context.Context) (interface{}, error) {
return longRunningTask(ctx, "长时间任务", time.Second*6)
})
// 2秒后取消task2
go func() {
time.Sleep(time.Second * 2)
if tm.CancelTask("task2") {
fmt.Println("手动取消了 task2")
}
}()
// 收集结果
resultCount := 0
timeout := time.After(time.Second * 8)
for {
select {
case result := <-tm.GetResults():
resultCount++
if result.Error != nil {
fmt.Printf("任务 %s 失败: %v\n", result.ID, result.Error)
} else {
fmt.Printf("任务 %s 成功: %v\n", result.ID, result.Result)
}
if resultCount >= 3 {
fmt.Println("所有任务处理完成")
return
}
case <-timeout:
fmt.Println("等待结果超时")
return
}
}
}
Context 性能考虑 #
1. 避免过度使用 WithValue #
// 不好的做法:频繁创建context
func badExample(ctx context.Context) {
for i := 0; i < 1000; i++ {
newCtx := context.WithValue(ctx, fmt.Sprintf("key%d", i), i)
// 使用newCtx...
_ = newCtx
}
}
// 好的做法:使用结构体传递数据
type RequestData struct {
UserID int
RequestID string
Values map[string]interface{}
}
func goodExample(ctx context.Context, data *RequestData) {
// 直接使用结构体中的数据
fmt.Printf("UserID: %d, RequestID: %s\n", data.UserID, data.RequestID)
}
2. 合理使用超时 #
// 根据操作类型设置合理的超时时间
func setAppropriateTimeout(ctx context.Context, operationType string) (context.Context, context.CancelFunc) {
var timeout time.Duration
switch operationType {
case "database":
timeout = time.Second * 5
case "cache":
timeout = time.Millisecond * 100
case "external_api":
timeout = time.Second * 10
default:
timeout = time.Second * 3
}
return context.WithTimeout(ctx, timeout)
}
Context 是 Go 并发编程的重要工具,正确使用 Context 可以让程序更加健壮和高效。在实际开发中,要遵循 Context 的最佳实践,合理设置超时时间,正确处理取消信号,并避免在 Context 中存储过多的值。