2.3.3 WaitGroup 与 Once

2.3.3 WaitGroup 与 Once #

WaitGroup 概述 #

WaitGroup 是 Go 语言中用于等待一组 Goroutine 完成的同步原语。它提供了一种简单而有效的方式来协调多个并发任务的执行,确保主程序在所有子任务完成后再继续执行。

WaitGroup 的特点 #

  • 计数器机制:内部维护一个计数器,记录待完成的任务数量
  • 阻塞等待:Wait() 方法会阻塞直到计数器归零
  • 线程安全:Add、Done、Wait 可以安全地在多个 Goroutine 中并发调用;但 WaitGroup 在首次使用后不能被复制,因此传递给函数时应传指针(*sync.WaitGroup)
  • 零值可用:WaitGroup 的零值是一个有效的、计数为零的等待组

基本方法 #

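  • Add(delta int):将计数器加上 delta(delta 可以为负数,计数器变为负数时会 panic);应在启动 Goroutine 之前调用 Add,而不是在 Goroutine 内部调用,否则 Wait() 可能在计数增加之前就返回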
  • Done():减少计数器的值(等价于 Add(-1))
  • Wait():阻塞等待直到计数器为零

WaitGroup 的基本使用 #

基本语法 #

package main

import (
    "fmt"
    "sync"
    "time"
)

// worker simulates a unit of work lasting id seconds and marks itself
// finished on wg when it returns.
func worker(id int, wg *sync.WaitGroup) {
    // Deferring Done guarantees the counter is decremented on every exit path.
    defer wg.Done()

    fmt.Printf("Worker %d starting\n", id)
    pause := time.Second * time.Duration(id) // stand-in for real work
    time.Sleep(pause)
    fmt.Printf("Worker %d completed\n", id)
}

// basicWaitGroupExample launches five workers and blocks until every one of
// them has called Done.
func basicWaitGroupExample() {
    const workerCount = 5
    var wg sync.WaitGroup

    for id := 1; id <= workerCount; id++ {
        // Add before starting the goroutine so Wait can never observe a
        // counter that has not been incremented yet.
        wg.Add(1)
        go worker(id, &wg)
    }

    fmt.Println("Waiting for all workers to complete...")
    wg.Wait()
    fmt.Println("All workers completed!")
}

// main runs the basic WaitGroup example.
func main() {
    basicWaitGroupExample()
}

批量添加任务 #

package main

import (
    "fmt"
    "sync"
    "time"
)

// batchWorker consumes jobs until the channel is closed, then signals wg.
func batchWorker(id int, jobs <-chan int, wg *sync.WaitGroup) {
    defer wg.Done()

    for {
        job, ok := <-jobs
        if !ok {
            break // channel closed and drained
        }
        fmt.Printf("Worker %d processing job %d\n", id, job)
        time.Sleep(100 * time.Millisecond) // simulated processing time
    }

    fmt.Printf("Worker %d finished\n", id)
}

// batchWaitGroupExample starts a fixed pool of workers with a single Add,
// feeds them jobs over a channel and waits for the pool to drain it.
func batchWaitGroupExample() {
    const (
        numWorkers = 3
        numJobs    = 10
    )

    var wg sync.WaitGroup
    jobs := make(chan int, numJobs) // buffered, so sending never blocks

    // One Add for the whole pool, done before any worker starts.
    wg.Add(numWorkers)
    for id := 1; id <= numWorkers; id++ {
        go batchWorker(id, jobs, &wg)
    }

    for job := 1; job <= numJobs; job++ {
        jobs <- job
    }
    // Closing ends each worker's receive loop, which triggers its Done.
    close(jobs)

    wg.Wait()
    fmt.Println("All batch workers completed!")
}

// main runs the worker-pool (batch Add) WaitGroup example.
func main() {
    batchWaitGroupExample()
}

WaitGroup 的高级用法 #

1. 嵌套 WaitGroup #

package main

import (
    "fmt"
    "sync"
    "time"
)

// subTask simulates one 500ms subtask belonging to task taskID and signals
// wg when it returns.
func subTask(taskID, subTaskID int, wg *sync.WaitGroup) {
    defer wg.Done()

    label := fmt.Sprintf("Task %d - SubTask %d", taskID, subTaskID)
    fmt.Println(label + " starting")
    time.Sleep(500 * time.Millisecond) // simulated work
    fmt.Println(label + " completed")
}

// mainTask runs three subtasks tracked by its own local WaitGroup, waits for
// them, and only then signals the parent wg. Because the inner WaitGroup is
// local, each main task waits only for its own subtasks.
func mainTask(taskID int, wg *sync.WaitGroup) {
    defer wg.Done()

    fmt.Printf("Main Task %d starting\n", taskID)

    var children sync.WaitGroup
    children.Add(3)
    for sub := 1; sub <= 3; sub++ {
        go subTask(taskID, sub, &children)
    }
    children.Wait()

    fmt.Printf("Main Task %d completed (all subtasks done)\n", taskID)
}

// nestedWaitGroupExample starts three main tasks, each of which waits on its
// own subtasks, and waits for all main tasks to finish.
func nestedWaitGroupExample() {
    const mainTasks = 3
    var mainWG sync.WaitGroup

    mainWG.Add(mainTasks)
    for id := 1; id <= mainTasks; id++ {
        go mainTask(id, &mainWG)
    }

    // Returns only after every main task, and therefore every subtask, is done.
    mainWG.Wait()
    fmt.Println("All main tasks completed!")
}

// main runs the nested WaitGroup example.
func main() {
    nestedWaitGroupExample()
}

2. 带超时的 WaitGroup #

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

// TimeoutWaitGroup wraps sync.WaitGroup and adds waits that give up after a
// timeout or when a context is done.
//
// A wait that gives up does not stop the workers: the helper goroutine
// started by WaitWithContext keeps blocking in wg.Wait() until the counter
// really reaches zero, and it is leaked if that never happens.
type TimeoutWaitGroup struct {
    wg sync.WaitGroup
}

// Add adds delta to the counter, exactly like sync.WaitGroup.Add.
func (twg *TimeoutWaitGroup) Add(delta int) {
    twg.wg.Add(delta)
}

// Done decrements the counter by one.
func (twg *TimeoutWaitGroup) Done() {
    twg.wg.Done()
}

// Wait blocks until the counter reaches zero, with no time limit.
func (twg *TimeoutWaitGroup) Wait() {
    twg.wg.Wait()
}

// WaitWithTimeout waits until the counter reaches zero or timeout elapses.
// It returns true if the counter reached zero and false on timeout.
func (twg *TimeoutWaitGroup) WaitWithTimeout(timeout time.Duration) bool {
    // Delegate to WaitWithContext so the waiting logic lives in one place;
    // cancel releases the context's timer as soon as the wait ends.
    ctx, cancel := context.WithTimeout(context.Background(), timeout)
    defer cancel()
    return twg.WaitWithContext(ctx)
}

// WaitWithContext waits until the counter reaches zero or ctx is done.
// It returns true if the counter reached zero and false if ctx ended first.
func (twg *TimeoutWaitGroup) WaitWithContext(ctx context.Context) bool {
    done := make(chan struct{})

    go func() {
        defer close(done)
        twg.wg.Wait()
    }()

    select {
    case <-done:
        return true // counter reached zero
    case <-ctx.Done():
        return false // timeout or cancellation
    }
}

// slowWorker sleeps for duration to simulate work, then signals wg.
func slowWorker(id int, duration time.Duration, wg *TimeoutWaitGroup) {
    // Deferred so the counter drops even if the sleep is interrupted by a panic.
    defer wg.Done()

    fmt.Printf("Slow worker %d starting (will take %v)\n", id, duration)
    time.Sleep(duration) // simulated work
    fmt.Printf("Slow worker %d completed\n", id)
}

// timeoutWaitGroupExample shows a bounded wait with a fixed timeout and with a
// context deadline. Workers that are still running when a wait gives up are
// not stopped; they are simply abandoned when the program exits.
func timeoutWaitGroupExample() {
    // Part 1: a fixed timeout. Worker 1 (1s) fits within 3s, worker 2 (5s) does not.
    var twg TimeoutWaitGroup
    twg.Add(2)
    go slowWorker(1, 1*time.Second, &twg) // fast task
    go slowWorker(2, 5*time.Second, &twg) // slow task

    fmt.Println("Waiting with 3 second timeout...")
    outcome := "Timeout occurred, some workers may still be running"
    if twg.WaitWithTimeout(3 * time.Second) {
        outcome = "All workers completed within timeout"
    }
    fmt.Println(outcome)

    // Part 2: a context deadline of 2s against a 4s worker.
    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
    defer cancel()

    var twg2 TimeoutWaitGroup
    twg2.Add(1)
    go slowWorker(3, 4*time.Second, &twg2)

    fmt.Println("Waiting with context timeout...")
    outcome = "Context timeout occurred"
    if twg2.WaitWithContext(ctx) {
        outcome = "Worker completed within context timeout"
    }
    fmt.Println(outcome)
}

// main runs the timeout/context WaitGroup example.
func main() {
    timeoutWaitGroupExample()
}

3. 动态任务管理 #

package main

import (
    "fmt"
    "math/rand"
    "sort"
    "sync"
    "time"
)

// TaskManager runs each task in its own goroutine and tracks all of them with
// a single WaitGroup, so WaitAll also covers subtasks spawned while running.
// mu guards tasks, results and taskID.
type TaskManager struct {
    wg      sync.WaitGroup
    mu      sync.Mutex
    tasks   map[int]string // task ID -> description
    results map[int]string // task ID -> result text, set when the task finishes
    taskID  int            // last ID handed out; IDs start at 1
}

// NewTaskManager returns a TaskManager with its maps allocated.
func NewTaskManager() *TaskManager {
    tm := &TaskManager{}
    tm.tasks = map[int]string{}
    tm.results = map[int]string{}
    return tm
}

// AddTask registers description under a fresh ID, starts it in a new
// goroutine and returns the ID. It is also called from inside running tasks.
func (tm *TaskManager) AddTask(description string) int {
    tm.mu.Lock()
    defer tm.mu.Unlock()

    tm.taskID++
    id := tm.taskID
    tm.tasks[id] = description

    // Increment the counter before launching so WaitAll cannot miss this task.
    tm.wg.Add(1)
    go tm.executeTask(id, description)

    return id
}

// executeTask simulates running one task for 1-3 seconds, records its result
// and, with 30% probability, spawns a follow-up subtask.
func (tm *TaskManager) executeTask(taskID int, description string) {
    // Runs after any AddTask call below, so the WaitGroup counter stays
    // positive until a spawned subtask has been registered.
    defer tm.wg.Done()

    fmt.Printf("Task %d started: %s\n", taskID, description)

    elapsed := time.Second * time.Duration(1+rand.Intn(3))
    time.Sleep(elapsed) // simulated work

    outcome := fmt.Sprintf("Result of task %d (took %v)", taskID, elapsed)

    tm.mu.Lock()
    tm.results[taskID] = outcome
    tm.mu.Unlock()

    fmt.Printf("Task %d completed: %s\n", taskID, outcome)

    // 30% chance of spawning a subtask.
    if rand.Float32() >= 0.3 {
        return
    }
    childID := tm.AddTask(fmt.Sprintf("Subtask of %d", taskID))
    fmt.Printf("Task %d spawned subtask %d\n", taskID, childID)
}

// WaitAll blocks until every task, including spawned subtasks, has finished.
func (tm *TaskManager) WaitAll() {
    tm.wg.Wait()
}

// GetResults returns a copy of the results collected so far, so callers can
// use it without holding tm.mu.
func (tm *TaskManager) GetResults() map[int]string {
    tm.mu.Lock()
    defer tm.mu.Unlock()

    snapshot := make(map[int]string, len(tm.results))
    for id, res := range tm.results {
        snapshot[id] = res
    }
    return snapshot
}

// dynamicTaskExample submits four tasks, waits for them and for any subtasks
// they spawn, then prints every result in task-ID order.
func dynamicTaskExample() {
    // rand.Seed is deprecated since Go 1.20 (the global source is seeded
    // automatically); it is kept here for older toolchains.
    rand.Seed(time.Now().UnixNano())

    tm := NewTaskManager()

    // Initial tasks; each may spawn further subtasks while it runs.
    initialTasks := []string{
        "Process data file A",
        "Generate report B",
        "Send notifications C",
        "Update database D",
    }

    for _, task := range initialTasks {
        taskID := tm.AddTask(task)
        fmt.Printf("Added initial task %d: %s\n", taskID, task)
    }

    // Wait for all tasks, including dynamically spawned subtasks.
    fmt.Println("Waiting for all tasks to complete...")
    tm.WaitAll()

    // Map iteration order is random; sort the IDs so the report is stable.
    fmt.Println("\n=== Task Results ===")
    results := tm.GetResults()
    ids := make([]int, 0, len(results))
    for id := range results {
        ids = append(ids, id)
    }
    sort.Ints(ids)
    for _, taskID := range ids {
        fmt.Printf("Task %d: %s\n", taskID, results[taskID])
    }

    fmt.Printf("Total tasks completed: %d\n", len(results))
}

// main runs the dynamic task management example.
func main() {
    dynamicTaskExample()
}

Once 概述 #

Once 是 Go 语言中用于确保某个操作只执行一次的同步原语。无论有多少个 Goroutine 调用 Once.Do(),传入的函数只会被执行一次,这在初始化场景中非常有用。

Once 的特点 #

  • 单次执行:确保函数只被执行一次(即使该函数发生 panic,也视为已经执行,之后的 Do 调用不会再执行它)
  • 线程安全:可以安全地在多个 Goroutine 中使用
  • 阻塞等待:如果函数正在执行,其他调用会等待其完成
  • 零值可用:Once 的零值是一个有效的、未执行的 Once

Once 的基本使用 #

基本语法 #

package main

import (
    "fmt"
    "sync"
    "time"
)

// once guards the lazy creation of instance; in this example instance is
// assigned only inside once.Do in GetInstance.
var (
    once     sync.Once
    instance *Singleton
)

// Singleton is the type whose single shared instance GetInstance returns.
type Singleton struct {
    data string
}

// GetInstance returns the shared *Singleton, creating it on the first call.
// Concurrent first callers block inside once.Do until creation finishes, so
// every caller observes the fully built instance.
func GetInstance() *Singleton {
    initialize := func() {
        fmt.Println("Creating singleton instance...")
        time.Sleep(100 * time.Millisecond) // simulated initialization cost
        instance = &Singleton{data: "I am a singleton"}
        fmt.Println("Singleton instance created")
    }

    once.Do(initialize)
    return instance
}

// basicOnceExample has five goroutines request the singleton concurrently;
// the creation messages are printed only once.
func basicOnceExample() {
    const callers = 5
    var wg sync.WaitGroup

    wg.Add(callers)
    for id := 1; id <= callers; id++ {
        go func(id int) {
            defer wg.Done()

            fmt.Printf("Goroutine %d requesting singleton...\n", id)
            s := GetInstance()
            fmt.Printf("Goroutine %d got singleton: %s\n", id, s.data)
        }(id)
    }

    wg.Wait()
    fmt.Println("All goroutines completed")
}

// main runs the sync.Once singleton example.
func main() {
    basicOnceExample()
}

配置初始化 #

package main

import (
    "encoding/json"
    "fmt"
    "sync"
)

// Config is the application configuration decoded from JSON.
type Config struct {
    Database struct {
        Host     string `json:"host"`
        Port     int    `json:"port"`
        Username string `json:"username"`
    } `json:"database"`

    Server struct {
        Port int `json:"port"`
    } `json:"server"`
}

// configOnce guards the one-time load; config and configErr hold its outcome.
// Both are written only inside configOnce.Do, so they are safe to read after
// Do returns.
var (
    configOnce sync.Once
    config     *Config
    configErr  error
)

// loadConfig returns the process-wide configuration, loading it on first use.
//
// The load runs exactly once; every later call (from any goroutine) returns
// the same result, including the same error if the first load failed. On
// failure the returned *Config is nil.
func loadConfig() (*Config, error) {
    configOnce.Do(func() {
        fmt.Println("Loading configuration...")

        // Simulates loading from a file or environment variables.
        configData := `{
            "database": {
                "host": "localhost",
                "port": 5432,
                "username": "admin"
            },
            "server": {
                "port": 8080
            }
        }`

        // Decode into a local value and publish it only on success, so
        // callers never receive a half-populated Config alongside an error.
        var cfg Config
        if err := json.Unmarshal([]byte(configData), &cfg); err != nil {
            configErr = err
            fmt.Printf("Failed to load config: %v\n", configErr)
            return
        }

        config = &cfg
        fmt.Println("Configuration loaded successfully")
    })

    return config, configErr
}

// configInitExample has three services request the configuration at the
// same time; loadConfig performs the actual load only once.
func configInitExample() {
    var wg sync.WaitGroup

    for svc := 1; svc <= 3; svc++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()

            fmt.Printf("Service %d requesting config...\n", id)
            cfg, err := loadConfig()
            if err == nil {
                fmt.Printf("Service %d got config - DB: %s:%d, Server: %d\n",
                    id, cfg.Database.Host, cfg.Database.Port, cfg.Server.Port)
                return
            }
            fmt.Printf("Service %d failed to get config: %v\n", id, err)
        }(svc)
    }

    wg.Wait()
}

// main runs the sync.Once configuration-loading example.
func main() {
    configInitExample()
}

Once 的高级用法 #

1. 重置 Once #

package main

import (
    "fmt"
    "sync"
    "time"
)

// ResettableOnce behaves like sync.Once, except that Reset allows the function
// to run again. All state is guarded by mu.
//
// f runs while mu is held, so concurrent Do/Reset/IsDone calls block until f
// returns, and calling any of them from inside f deadlocks. If f panics, done
// stays false and the next Do runs f again (unlike sync.Once).
type ResettableOnce struct {
    mu   sync.Mutex
    done bool
}

// Do calls f unless a function has already run since creation or the last Reset.
func (ro *ResettableOnce) Do(f func()) {
    ro.mu.Lock()
    defer ro.mu.Unlock()

    if !ro.done {
        f()
        ro.done = true
    }
}

// Reset marks the once as not yet run, so the next Do calls its function.
func (ro *ResettableOnce) Reset() {
    ro.mu.Lock()
    defer ro.mu.Unlock()
    ro.done = false
}

// IsDone reports whether Do has run a function since creation or the last Reset.
func (ro *ResettableOnce) IsDone() bool {
    ro.mu.Lock()
    defer ro.mu.Unlock()
    return ro.done
}

var (
    resettableOnce ResettableOnce
    counter        int // written only by incrementCounter, which runs inside resettableOnce.Do
)

// incrementCounter is the function guarded by resettableOnce. The short sleep
// simulates real work (like the other Once examples), making it observable
// that concurrent Do callers wait for it. It is also this example's only use
// of the imported "time" package, without which the program does not compile.
func incrementCounter() {
    time.Sleep(10 * time.Millisecond) // simulated work
    counter++
    fmt.Printf("Counter incremented to: %d\n", counter)
}

// resettableOnceExample runs two rounds of three concurrent Do calls with a
// Reset in between, so incrementCounter runs once per round.
func resettableOnceExample() {
    var wg sync.WaitGroup

    // runRound starts three goroutines that all race to call Do, then waits.
    runRound := func() {
        for id := 1; id <= 3; id++ {
            wg.Add(1)
            go func(id int) {
                defer wg.Done()
                fmt.Printf("Goroutine %d calling Do()\n", id)
                resettableOnce.Do(incrementCounter)
            }(id)
        }
        wg.Wait()
    }

    fmt.Println("=== First Round ===")
    runRound()
    fmt.Printf("After first round, done: %v\n", resettableOnce.IsDone())

    fmt.Println("\n=== Resetting ===")
    resettableOnce.Reset()
    fmt.Printf("After reset, done: %v\n", resettableOnce.IsDone())

    fmt.Println("\n=== Second Round ===")
    runRound()

    fmt.Printf("Final counter value: %d\n", counter)
}

// main runs the resettable Once example.
func main() {
    resettableOnceExample()
}

2. 条件 Once #

package main

import (
    "fmt"
    "sync"
    "time"
)

// ConditionalOnce runs a function at most once, and only when its condition
// (if any) reports true at the moment Do is called. Attempts made while the
// condition is false do not count, so a later Do may still run.
// mu guards done; both condition and f are invoked while mu is held.
type ConditionalOnce struct {
    mu        sync.Mutex
    done      bool
    condition func() bool // nil means "always allowed"
}

// NewConditionalOnce returns a ConditionalOnce gated by condition; a nil
// condition lets the first Do run unconditionally.
func NewConditionalOnce(condition func() bool) *ConditionalOnce {
    co := new(ConditionalOnce)
    co.condition = condition
    return co
}

// Do runs f and returns true if f has not run yet and the condition holds.
// It returns false if a function already ran or the condition is currently false.
func (co *ConditionalOnce) Do(f func()) bool {
    co.mu.Lock()
    defer co.mu.Unlock()

    switch {
    case co.done:
        return false // already executed
    case co.condition != nil && !co.condition():
        return false // condition not met yet
    }

    f()
    co.done = true
    return true
}

// IsDone reports whether Do has already run its function.
func (co *ConditionalOnce) IsDone() bool {
    co.mu.Lock()
    defer co.mu.Unlock()
    return co.done
}

// startTime records when package variables were initialized; condOnce only
// lets its function run once more than 2 seconds have elapsed since then.
var (
    startTime = time.Now()
    condOnce  = NewConditionalOnce(func() bool {
        // Allow execution only after the program has run for more than 2 seconds.
        return time.Since(startTime) > 2*time.Second
    })
)

// expensiveOperation simulates a costly operation that takes about 500ms.
func expensiveOperation() {
    const cost = 500 * time.Millisecond

    fmt.Println("Executing expensive operation...")
    time.Sleep(cost)
    fmt.Println("Expensive operation completed")
}

// conditionalOnceExample has five goroutines try condOnce at staggered times;
// early attempts fail the condition and exactly one later attempt runs it.
func conditionalOnceExample() {
    var wg sync.WaitGroup

    for id := 1; id <= 5; id++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()

            // Goroutine id waits id*500ms before its attempt.
            time.Sleep(time.Duration(id) * 500 * time.Millisecond)

            fmt.Printf("Goroutine %d attempting execution at %v\n",
                id, time.Since(startTime))

            switch {
            case condOnce.Do(expensiveOperation):
                fmt.Printf("Goroutine %d successfully executed the operation\n", id)
            case condOnce.IsDone():
                fmt.Printf("Goroutine %d: operation already done\n", id)
            default:
                fmt.Printf("Goroutine %d: condition not met\n", id)
            }
        }(id)
    }

    wg.Wait()
}

// main runs the conditional Once example.
func main() {
    conditionalOnceExample()
}

3. 多阶段初始化 #

package main

import (
    "fmt"
    "sync"
    "time"
)

// MultiStageInitializer performs three dependent initialization stages, each
// at most once; requesting a stage first runs every earlier stage.
// The stageNDone flags are guarded by mu and reported by GetStatus.
type MultiStageInitializer struct {
    stage1Once sync.Once
    stage2Once sync.Once
    stage3Once sync.Once

    stage1Done bool
    stage2Done bool
    stage3Done bool

    mu sync.RWMutex
}

// runStage executes one stage's simulated work exactly once via once, then
// sets *flag under msi.mu. Concurrent callers block in once.Do until the
// first one finishes.
func (msi *MultiStageInitializer) runStage(once *sync.Once, stage int, work time.Duration, flag *bool) {
    once.Do(func() {
        fmt.Printf("Initializing Stage %d...\n", stage)
        time.Sleep(work) // simulated initialization work

        msi.mu.Lock()
        *flag = true
        msi.mu.Unlock()

        fmt.Printf("Stage %d initialization completed\n", stage)
    })
}

// InitStage1 runs stage 1 if it has not run yet.
func (msi *MultiStageInitializer) InitStage1() {
    msi.runStage(&msi.stage1Once, 1, 500*time.Millisecond, &msi.stage1Done)
}

// InitStage2 ensures stage 1 has run, then runs stage 2 if needed.
func (msi *MultiStageInitializer) InitStage2() {
    msi.InitStage1() // stage 2 depends on stage 1
    msi.runStage(&msi.stage2Once, 2, 300*time.Millisecond, &msi.stage2Done)
}

// InitStage3 ensures stages 1 and 2 have run, then runs stage 3 if needed.
func (msi *MultiStageInitializer) InitStage3() {
    msi.InitStage2() // stage 3 depends on stage 2
    msi.runStage(&msi.stage3Once, 3, 200*time.Millisecond, &msi.stage3Done)
}

// GetStatus reports which stages have completed, as (stage1, stage2, stage3).
func (msi *MultiStageInitializer) GetStatus() (bool, bool, bool) {
    msi.mu.RLock()
    defer msi.mu.RUnlock()
    return msi.stage1Done, msi.stage2Done, msi.stage3Done
}

// multiStageExample has several services concurrently request different
// stages; each stage still runs only once, in dependency order.
func multiStageExample() {
    initializer := &MultiStageInitializer{}
    var wg sync.WaitGroup

    scenarios := []struct {
        name  string
        stage int
    }{
        {"Service A", 1},
        {"Service B", 2},
        {"Service C", 3},
        {"Service D", 2},
        {"Service E", 3},
        {"Service F", 1},
    }

    // Dispatch table from requested stage to its init method (read-only,
    // so concurrent lookups are safe).
    stageInit := map[int]func(){
        1: initializer.InitStage1,
        2: initializer.InitStage2,
        3: initializer.InitStage3,
    }

    for _, sc := range scenarios {
        wg.Add(1)
        go func(name string, stage int) {
            defer wg.Done()

            fmt.Printf("%s requesting stage %d initialization\n", name, stage)
            if run, ok := stageInit[stage]; ok {
                run()
            }

            s1, s2, s3 := initializer.GetStatus()
            fmt.Printf("%s completed - Status: Stage1=%v, Stage2=%v, Stage3=%v\n",
                name, s1, s2, s3)
        }(sc.name, sc.stage)
    }

    wg.Wait()

    s1, s2, s3 := initializer.GetStatus()
    fmt.Printf("Final status: Stage1=%v, Stage2=%v, Stage3=%v\n", s1, s2, s3)
}

// main runs the multi-stage initialization example.
func main() {
    multiStageExample()
}

常见陷阱和最佳实践 #

1. WaitGroup 计数器管理 #

package main

import (
    "fmt"
    "sync"
    "time"
)

// Bad example: the counter does not match the number of goroutines.
//
// badWaitGroupExample adds 3 to the counter but starts only two goroutines,
// so the counter can never return to zero. The Wait call is commented out
// because it would block forever. The function also returns without waiting,
// so the "Worker 1"/"Worker 2" lines may be printed after whatever the caller
// prints next.
func badWaitGroupExample() {
    var wg sync.WaitGroup

    // Wrong: Add and Done do not match.
    wg.Add(3)

    go func() {
        defer wg.Done()
        fmt.Println("Worker 1")
    }()

    go func() {
        defer wg.Done()
        fmt.Println("Worker 2")
        // Both goroutines here do call Done; in real code a mismatch often
        // comes from a goroutine that forgets Done (or calls it twice, which
        // drives the counter negative and panics).
    }()

    // Only 2 goroutines were started, but Add(3) was called,
    // so Wait() would block forever.

    // wg.Wait() // would wait forever (or die with "all goroutines are asleep" if nothing else is running)

    fmt.Println("This example shows incorrect usage")
}

// Good example: the counter always matches the work.
//
// goodWaitGroupExample sizes the WaitGroup from the task list itself, and
// every goroutine defers exactly one Done, so Wait is guaranteed to return.
func goodWaitGroupExample() {
    tasks := []string{"Task A", "Task B", "Task C"}

    var wg sync.WaitGroup
    wg.Add(len(tasks)) // one count per goroutine launched below

    for idx := range tasks {
        go func(id int, taskName string) {
            defer wg.Done()

            fmt.Printf("Processing %s\n", taskName)
            delay := time.Duration(id+1) * 100 * time.Millisecond
            time.Sleep(delay) // simulated work
            fmt.Printf("Completed %s\n", taskName)
        }(idx, tasks[idx])
    }

    wg.Wait()
    fmt.Println("All tasks completed correctly")
}

// main runs the incorrect WaitGroup example followed by the correct one.
// NOTE(review): badWaitGroupExample returns without waiting, so its worker
// output may appear after the "Good Example" header.
func main() {
    fmt.Println("=== Bad Example ===")
    badWaitGroupExample()

    fmt.Println("\n=== Good Example ===")
    goodWaitGroupExample()
}

2. Once 的错误处理 #

package main

import (
    "fmt"
    "sync"
)

// State for badInitialization: initOnce guards a one-time initialization
// whose outcome is stored in initResult and initError.
var (
    initOnce   sync.Once
    initResult string
    initError  error
)

// Bad example: error handling inside sync.Once.
//
// badInitialization stores the outcome of a one-time initialization in
// initResult/initError. The success branch is hard-coded off (if false), so
// the single run always records an error. Because initOnce never runs its
// function again, every later call returns that same error with no retry.
func badInitialization() (string, error) {
    initOnce.Do(func() {
        // Simulates an initialization that may fail.
        if false { // stand-in for some real success condition
            initResult = "Success"
            initError = nil
        } else {
            initResult = ""
            initError = fmt.Errorf("initialization failed")
        }
    })

    return initResult, initError
}

// Good example: a Once-like helper that retries after failure.
//
// ErrorHandlingOnce calls f until one call succeeds. After the first success
// the cached result is returned and f is never called again. A failed attempt
// is not remembered as done, so the next Do retries. f runs while mu is held,
// so concurrent callers are serialized.
type ErrorHandlingOnce struct {
    mu     sync.Mutex
    done   bool   // true once f has returned a nil error
    result string // result of the latest call to f
    err    error  // error of the latest call to f; nil once done
}

// Do returns the cached result if an earlier call succeeded; otherwise it
// calls f, caches its outcome, and returns it.
func (eho *ErrorHandlingOnce) Do(f func() (string, error)) (string, error) {
    eho.mu.Lock()
    defer eho.mu.Unlock()

    // done is only ever set together with a nil err, so it alone
    // identifies a cached success.
    if eho.done {
        return eho.result, eho.err
    }

    eho.result, eho.err = f()
    eho.done = eho.err == nil
    return eho.result, eho.err
}

// goodInitOnce caches the first successful initialization for goodInitialization.
var goodInitOnce ErrorHandlingOnce

// goodInitialization runs its initialization through goodInitOnce, so it is
// retried after failures but succeeds at most once.
func goodInitialization() (string, error) {
    attempt := func() (string, error) {
        // Simulated initialization logic.
        fmt.Println("Attempting initialization...")

        // Real code would perform its (possibly complex) checks here.
        success := true
        if !success {
            return "", fmt.Errorf("initialization failed")
        }
        return "Initialization successful", nil
    }

    return goodInitOnce.Do(attempt)
}

// onceErrorHandlingExample calls each initializer twice. With sync.Once the
// first error is returned again; with ErrorHandlingOnce the successful result
// is cached and the initialization is not attempted a second time.
func onceErrorHandlingExample() {
    // report prints one call's outcome, prefixed by label.
    report := func(label, result string, err error) {
        fmt.Printf("%sResult: %s, Error: %v\n", label, result, err)
    }

    fmt.Println("=== Bad Error Handling ===")
    res, err := badInitialization()
    report("", res, err)
    // Second call: the error state has been "remembered" by sync.Once.
    res, err = badInitialization()
    report("Second call - ", res, err)

    fmt.Println("\n=== Good Error Handling ===")
    res, err = goodInitialization()
    report("", res, err)
    // Second call: served from the cached success.
    res, err = goodInitialization()
    report("Second call - ", res, err)
}

// main runs the Once error-handling comparison.
func main() {
    onceErrorHandlingExample()
}

3. 组合使用 WaitGroup 和 Once #

package main

import (
    "fmt"
    "sync"
    "time"
)

// ServiceManager owns a set of named services and coordinates starting and
// stopping them. initOnce and shutdownOnce make Initialize and Shutdown run
// at most once each.
type ServiceManager struct {
    initOnce    sync.Once
    shutdownOnce sync.Once

    // services is filled inside initOnce; StartAll and Shutdown read it
    // while holding mu's read lock.
    services map[string]*Service
    mu       sync.RWMutex

    // initialized is set by Initialize; nothing in this example reads it.
    initialized bool
    // shutdown is set under mu by Shutdown and checked by StartAll.
    shutdown    bool
}

// Service is a simulated service whose running flag is guarded by mu.
type Service struct {
    name    string
    running bool
    mu      sync.Mutex
}

// Start marks the service as running after a simulated startup delay.
// It returns an error if the service is already running.
func (s *Service) Start() error {
    s.mu.Lock()
    defer s.mu.Unlock()

    if s.running {
        return fmt.Errorf("service %s is already running", s.name)
    }

    const startupDelay = 100 * time.Millisecond // simulated startup time
    fmt.Printf("Starting service: %s\n", s.name)
    time.Sleep(startupDelay)
    s.running = true
    fmt.Printf("Service %s started\n", s.name)
    return nil
}

// Stop marks the service as stopped after a simulated shutdown delay.
// It returns an error if the service is not running.
func (s *Service) Stop() error {
    s.mu.Lock()
    defer s.mu.Unlock()

    if !s.running {
        return fmt.Errorf("service %s is not running", s.name)
    }

    const stopDelay = 50 * time.Millisecond // simulated stop time
    fmt.Printf("Stopping service: %s\n", s.name)
    time.Sleep(stopDelay)
    s.running = false
    fmt.Printf("Service %s stopped\n", s.name)
    return nil
}

// NewServiceManager returns a ServiceManager with an empty service map; the
// services themselves are created lazily by Initialize.
func NewServiceManager() *ServiceManager {
    sm := new(ServiceManager)
    sm.services = map[string]*Service{}
    return sm
}

// Initialize creates the manager's services on the first call; later calls
// are no-ops. Callers that return from initOnce.Do see the populated map.
func (sm *ServiceManager) Initialize() {
    sm.initOnce.Do(func() {
        fmt.Println("Initializing service manager...")

        for _, name := range []string{"Database", "Cache", "Logger", "Monitor"} {
            sm.services[name] = &Service{name: name}
        }

        sm.initialized = true
        fmt.Println("Service manager initialized")
    })
}

// StartAll initializes the manager if needed and starts every service
// concurrently, waiting for all of them. It returns an error if the manager
// has been shut down, or the first Start error received if any service failed.
//
// NOTE(review): concurrent StartAll calls race to start the same services, so
// callers that lose the race get "already running" errors.
func (sm *ServiceManager) StartAll() error {
    sm.Initialize()

    sm.mu.RLock()
    defer sm.mu.RUnlock()

    if sm.shutdown {
        return fmt.Errorf("service manager is shutdown")
    }

    var wg sync.WaitGroup
    // Buffered to len(services) so no goroutine blocks on send. Named errCh
    // so it does not shadow the standard errors package.
    errCh := make(chan error, len(sm.services))

    wg.Add(len(sm.services))
    for _, svc := range sm.services {
        go func(svc *Service) {
            defer wg.Done()
            if err := svc.Start(); err != nil {
                errCh <- err
            }
        }(svc)
    }

    wg.Wait()
    close(errCh)

    // Only non-nil errors are sent; report the first one, if any.
    if err, ok := <-errCh; ok {
        return err
    }
    return nil
}

// Shutdown marks the manager as shut down, so later StartAll calls fail, and
// stops every service concurrently, waiting for all of them. Only the first
// call does any work; concurrent callers block until it completes.
func (sm *ServiceManager) Shutdown() {
    sm.shutdownOnce.Do(func() {
        fmt.Println("Shutting down service manager...")

        // Taking the write lock waits for any in-flight StartAll, which
        // holds the read lock for its whole duration.
        sm.mu.Lock()
        sm.shutdown = true
        sm.mu.Unlock()

        var wg sync.WaitGroup

        sm.mu.RLock()
        for _, service := range sm.services {
            wg.Add(1)
            go func(svc *Service) {
                defer wg.Done()
                // Report services that could not be stopped (e.g. ones that
                // were never started) instead of silently dropping the error.
                if err := svc.Stop(); err != nil {
                    fmt.Printf("Shutdown: %v\n", err)
                }
            }(service)
        }
        sm.mu.RUnlock()

        wg.Wait()
        fmt.Println("Service manager shutdown completed")
    })
}

// combinedExample has several clients start the services concurrently, waits
// one second, then has several clients shut them down concurrently.
func combinedExample() {
    sm := NewServiceManager()
    var wg sync.WaitGroup

    // spawn runs action for clients 1..n concurrently and waits for all of them.
    spawn := func(n int, action func(id int)) {
        for id := 1; id <= n; id++ {
            wg.Add(1)
            go func(id int) {
                defer wg.Done()
                action(id)
            }(id)
        }
        wg.Wait()
    }

    // Several clients try to start the services at the same time.
    spawn(3, func(id int) {
        fmt.Printf("Client %d starting services...\n", id)
        if err := sm.StartAll(); err != nil {
            fmt.Printf("Client %d failed to start services: %v\n", id, err)
            return
        }
        fmt.Printf("Client %d successfully started services\n", id)
    })

    // Let the services run briefly before shutting down.
    time.Sleep(1 * time.Second)

    // Several clients try to shut down; shutdownOnce makes only one do the work.
    spawn(2, func(id int) {
        fmt.Printf("Client %d shutting down services...\n", id)
        sm.Shutdown()
    })
}

// main runs the combined WaitGroup + Once service-manager example.
func main() {
    combinedExample()
}

小结 #

在本节中,我们深入学习了:

  1. WaitGroup:等待多个 Goroutine 完成的同步机制

    • 基本用法:Add、Done、Wait
    • 高级用法:嵌套 WaitGroup、带超时的等待、动态任务管理
  2. Once:确保函数只执行一次的同步原语

    • 基本用法:单例模式、配置初始化
    • 高级用法:可重置 Once、条件 Once、多阶段初始化
  3. 最佳实践

    • 确保 WaitGroup 计数器匹配
    • 正确处理 Once 中的错误
    • 组合使用不同的同步原语

WaitGroup 和 Once 是 Go 语言中非常实用的同步原语,它们解决了并发编程中的常见问题。在下一节中,我们将学习 Cond 条件变量,它提供了更复杂的等待和通知机制。

练习题 #

  1. 实现一个并发安全的任务队列,使用 WaitGroup 等待所有任务完成
  2. 设计一个资源池管理器,使用 Once 确保资源只初始化一次
  3. 创建一个分阶段的应用启动器,确保各个组件按正确顺序初始化