2.3.1 Mutex 互斥锁

2.3.1 Mutex 互斥锁 #

Mutex 概述 #

Mutex(互斥锁)是最基本的同步原语之一,它提供了对共享资源的排他性访问。在任何时刻,只有一个 Goroutine 可以持有 Mutex,其他试图获取同一个 Mutex 的 Goroutine 将被阻塞,直到 Mutex 被释放。

Mutex 的特点 #

  • 排他性:同一时刻只能有一个 Goroutine 持有锁
  • 可重入性:Go 的 Mutex 不是可重入的,同一个 Goroutine 不能重复获取同一个锁
  • 公平性:Go 的 Mutex 实现了一定程度的公平性,避免饥饿
  • 零值可用:Mutex 的零值是一个有效的、未锁定的互斥锁

Mutex 的基本使用 #

基本语法 #

package main

import (
    "fmt"
    "sync"
    "time"
)

type Counter struct {
    mu    sync.Mutex
    value int
}

func (c *Counter) Increment() {
    c.mu.Lock()
    defer c.mu.Unlock()
    c.value++
}

func (c *Counter) Decrement() {
    c.mu.Lock()
    defer c.mu.Unlock()
    c.value--
}

func (c *Counter) Value() int {
    c.mu.Lock()
    defer c.mu.Unlock()
    return c.value
}

func main() {
    counter := &Counter{}
    var wg sync.WaitGroup

    // 启动多个 Goroutine 并发修改计数器
    for i := 0; i < 10; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            for j := 0; j < 1000; j++ {
                counter.Increment()
            }
            fmt.Printf("Goroutine %d completed\n", id)
        }(i)
    }

    wg.Wait()
    fmt.Printf("Final counter value: %d\n", counter.Value())
}

不使用 Mutex 的问题 #

让我们看看不使用 Mutex 会发生什么:

package main

import (
    "fmt"
    "sync"
)

type UnsafeCounter struct {
    value int
}

func (c *UnsafeCounter) Increment() {
    c.value++ // 竞态条件!
}

func (c *UnsafeCounter) Value() int {
    return c.value // 竞态条件!
}

func demonstrateRaceCondition() {
    counter := &UnsafeCounter{}
    var wg sync.WaitGroup

    // 启动多个 Goroutine 并发修改计数器
    for i := 0; i < 10; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            for j := 0; j < 1000; j++ {
                counter.Increment()
            }
        }()
    }

    wg.Wait()
    fmt.Printf("Unsafe counter value: %d (expected: 10000)\n", counter.Value())
}

func main() {
    // 运行多次观察结果的不一致性
    for i := 0; i < 5; i++ {
        demonstrateRaceCondition()
    }
}

Mutex 的内部实现原理 #

状态表示 #

Go 的 Mutex 使用一个 int32 值来表示状态:

type Mutex struct {
    state int32
    sema  uint32
}

状态位的含义:

  • 第 0 位:锁定状态(1 表示已锁定)
  • 第 1 位:唤醒状态(1 表示有 Goroutine 被唤醒)
  • 第 2 位:饥饿模式(1 表示处于饥饿模式)
  • 其余位:等待者数量

公平性机制 #

package main

import (
    "fmt"
    "runtime"
    "sync"
    "time"
)

func demonstrateFairness() {
    var mu sync.Mutex
    var wg sync.WaitGroup

    // 记录每个 Goroutine 获取锁的次数
    counts := make([]int, 5)

    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()

            for j := 0; j < 100; j++ {
                mu.Lock()
                counts[id]++

                // 模拟一些工作
                time.Sleep(time.Microsecond)

                mu.Unlock()

                // 让出 CPU,给其他 Goroutine 机会
                runtime.Gosched()
            }
        }(i)
    }

    wg.Wait()

    fmt.Println("Lock acquisition counts:")
    for i, count := range counts {
        fmt.Printf("Goroutine %d: %d times\n", i, count)
    }
}

func main() {
    demonstrateFairness()
}

Mutex 的高级用法 #

1. 条件性加锁 #

package main

import (
    "fmt"
    "sync"
    "time"
)

type ConditionalMutex struct {
    mu        sync.Mutex
    condition bool
}

func (cm *ConditionalMutex) LockIf(condition bool) bool {
    cm.mu.Lock()
    defer cm.mu.Unlock()

    if cm.condition == condition {
        // 满足条件,保持锁定状态
        return true
    }

    // 不满足条件,释放锁
    return false
}

func (cm *ConditionalMutex) SetCondition(condition bool) {
    cm.mu.Lock()
    defer cm.mu.Unlock()
    cm.condition = condition
}

func (cm *ConditionalMutex) Unlock() {
    cm.mu.Unlock()
}

func conditionalLockingExample() {
    cm := &ConditionalMutex{}
    var wg sync.WaitGroup

    // 设置条件
    go func() {
        time.Sleep(1 * time.Second)
        cm.SetCondition(true)
        fmt.Println("Condition set to true")
    }()

    // 尝试条件性加锁
    for i := 0; i < 3; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()

            if cm.LockIf(true) {
                fmt.Printf("Goroutine %d: Successfully acquired conditional lock\n", id)
                time.Sleep(500 * time.Millisecond)
                cm.Unlock()
            } else {
                fmt.Printf("Goroutine %d: Failed to acquire conditional lock\n", id)
            }
        }(i)
    }

    wg.Wait()
}

func main() {
    conditionalLockingExample()
}

2. 超时锁 #

虽然 Go 的标准 Mutex 不支持超时,但我们可以使用 Channel 实现:

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

type TimeoutMutex struct {
    ch chan struct{}
}

func NewTimeoutMutex() *TimeoutMutex {
    ch := make(chan struct{}, 1)
    ch <- struct{}{} // 初始状态为可用
    return &TimeoutMutex{ch: ch}
}

func (tm *TimeoutMutex) Lock() {
    <-tm.ch
}

func (tm *TimeoutMutex) TryLock() bool {
    select {
    case <-tm.ch:
        return true
    default:
        return false
    }
}

func (tm *TimeoutMutex) LockWithTimeout(timeout time.Duration) bool {
    select {
    case <-tm.ch:
        return true
    case <-time.After(timeout):
        return false
    }
}

func (tm *TimeoutMutex) LockWithContext(ctx context.Context) bool {
    select {
    case <-tm.ch:
        return true
    case <-ctx.Done():
        return false
    }
}

func (tm *TimeoutMutex) Unlock() {
    select {
    case tm.ch <- struct{}{}:
    default:
        panic("unlock of unlocked mutex")
    }
}

func timeoutMutexExample() {
    tm := NewTimeoutMutex()
    var wg sync.WaitGroup

    // 长时间持有锁的 Goroutine
    wg.Add(1)
    go func() {
        defer wg.Done()
        tm.Lock()
        fmt.Println("Long-running task started")
        time.Sleep(3 * time.Second)
        fmt.Println("Long-running task completed")
        tm.Unlock()
    }()

    time.Sleep(500 * time.Millisecond)

    // 尝试获取锁的 Goroutine
    wg.Add(1)
    go func() {
        defer wg.Done()

        // 尝试立即获取锁
        if tm.TryLock() {
            fmt.Println("Immediately acquired lock")
            tm.Unlock()
        } else {
            fmt.Println("Failed to acquire lock immediately")
        }

        // 尝试带超时获取锁
        if tm.LockWithTimeout(1 * time.Second) {
            fmt.Println("Acquired lock with timeout")
            tm.Unlock()
        } else {
            fmt.Println("Failed to acquire lock within timeout")
        }

        // 使用 context 获取锁
        ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
        defer cancel()

        if tm.LockWithContext(ctx) {
            fmt.Println("Acquired lock with context")
            tm.Unlock()
        } else {
            fmt.Println("Failed to acquire lock with context")
        }
    }()

    wg.Wait()
}

func main() {
    timeoutMutexExample()
}

3. 可重入锁 #

Go 的标准 Mutex 不是可重入的,但我们可以实现一个:

package main

import (
    "fmt"
    "runtime"
    "sync"
)

type ReentrantMutex struct {
    mu       sync.Mutex
    owner    int64  // Goroutine ID
    recursion int32 // 递归深度
}

func (rm *ReentrantMutex) Lock() {
    gid := getGoroutineID()

    rm.mu.Lock()
    defer rm.mu.Unlock()

    if rm.owner == gid {
        rm.recursion++
        return
    }

    // 等待直到可以获取锁
    for rm.owner != 0 {
        rm.mu.Unlock()
        runtime.Gosched()
        rm.mu.Lock()
    }

    rm.owner = gid
    rm.recursion = 1
}

func (rm *ReentrantMutex) Unlock() {
    gid := getGoroutineID()

    rm.mu.Lock()
    defer rm.mu.Unlock()

    if rm.owner != gid {
        panic("unlock of mutex not owned by current goroutine")
    }

    rm.recursion--
    if rm.recursion == 0 {
        rm.owner = 0
    }
}

// 获取当前 Goroutine ID(简化实现)
func getGoroutineID() int64 {
    var buf [64]byte
    n := runtime.Stack(buf[:], false)
    // 解析 stack trace 获取 goroutine ID
    // 这里使用简化的实现
    return int64(runtime.NumGoroutine()) // 简化实现,实际应该解析 stack
}

func reentrantMutexExample() {
    rm := &ReentrantMutex{}

    var recursiveFunction func(int)
    recursiveFunction = func(depth int) {
        rm.Lock()
        defer rm.Unlock()

        fmt.Printf("Depth: %d\n", depth)

        if depth > 0 {
            recursiveFunction(depth - 1)
        }
    }

    recursiveFunction(3)
}

func main() {
    reentrantMutexExample()
}

性能优化技巧 #

1. 减少锁的粒度 #

package main

import (
    "fmt"
    "sync"
    "time"
)

// 粗粒度锁:整个结构体一个锁
type CoarseGrainedCache struct {
    mu   sync.Mutex
    data map[string]string
}

func (c *CoarseGrainedCache) Get(key string) (string, bool) {
    c.mu.Lock()
    defer c.mu.Unlock()
    value, ok := c.data[key]
    return value, ok
}

func (c *CoarseGrainedCache) Set(key, value string) {
    c.mu.Lock()
    defer c.mu.Unlock()
    if c.data == nil {
        c.data = make(map[string]string)
    }
    c.data[key] = value
}

// 细粒度锁:每个桶一个锁
type FineGrainedCache struct {
    buckets []bucket
    size    int
}

type bucket struct {
    mu   sync.Mutex
    data map[string]string
}

func NewFineGrainedCache(size int) *FineGrainedCache {
    buckets := make([]bucket, size)
    for i := range buckets {
        buckets[i].data = make(map[string]string)
    }
    return &FineGrainedCache{
        buckets: buckets,
        size:    size,
    }
}

func (c *FineGrainedCache) hash(key string) int {
    h := 0
    for _, b := range []byte(key) {
        h = h*31 + int(b)
    }
    return h % c.size
}

func (c *FineGrainedCache) Get(key string) (string, bool) {
    bucket := &c.buckets[c.hash(key)]
    bucket.mu.Lock()
    defer bucket.mu.Unlock()
    value, ok := bucket.data[key]
    return value, ok
}

func (c *FineGrainedCache) Set(key, value string) {
    bucket := &c.buckets[c.hash(key)]
    bucket.mu.Lock()
    defer bucket.mu.Unlock()
    bucket.data[key] = value
}

func benchmarkCache(name string, cache interface{}) {
    var wg sync.WaitGroup
    start := time.Now()

    // 根据缓存类型执行操作
    switch c := cache.(type) {
    case *CoarseGrainedCache:
        for i := 0; i < 100; i++ {
            wg.Add(1)
            go func(id int) {
                defer wg.Done()
                for j := 0; j < 100; j++ {
                    key := fmt.Sprintf("key-%d-%d", id, j)
                    c.Set(key, fmt.Sprintf("value-%d-%d", id, j))
                    c.Get(key)
                }
            }(i)
        }
    case *FineGrainedCache:
        for i := 0; i < 100; i++ {
            wg.Add(1)
            go func(id int) {
                defer wg.Done()
                for j := 0; j < 100; j++ {
                    key := fmt.Sprintf("key-%d-%d", id, j)
                    c.Set(key, fmt.Sprintf("value-%d-%d", id, j))
                    c.Get(key)
                }
            }(i)
        }
    }

    wg.Wait()
    duration := time.Since(start)
    fmt.Printf("%s took: %v\n", name, duration)
}

func lockGranularityExample() {
    coarseCache := &CoarseGrainedCache{}
    fineCache := NewFineGrainedCache(16)

    benchmarkCache("Coarse-grained cache", coarseCache)
    benchmarkCache("Fine-grained cache", fineCache)
}

func main() {
    lockGranularityExample()
}

2. 使用读写分离 #

package main

import (
    "fmt"
    "sync"
    "time"
)

type Stats struct {
    mu     sync.Mutex
    counts map[string]int64
}

func (s *Stats) Increment(key string) {
    s.mu.Lock()
    defer s.mu.Unlock()

    if s.counts == nil {
        s.counts = make(map[string]int64)
    }
    s.counts[key]++
}

func (s *Stats) Get(key string) int64 {
    s.mu.Lock()
    defer s.mu.Unlock()
    return s.counts[key]
}

func (s *Stats) GetAll() map[string]int64 {
    s.mu.Lock()
    defer s.mu.Unlock()

    result := make(map[string]int64)
    for k, v := range s.counts {
        result[k] = v
    }
    return result
}

func statsExample() {
    stats := &Stats{}
    var wg sync.WaitGroup

    // 写操作
    for i := 0; i < 10; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            for j := 0; j < 1000; j++ {
                stats.Increment(fmt.Sprintf("counter-%d", id%3))
            }
        }(i)
    }

    // 读操作
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            for j := 0; j < 100; j++ {
                stats.Get(fmt.Sprintf("counter-%d", id%3))
                time.Sleep(time.Millisecond)
            }
        }(i)
    }

    wg.Wait()

    fmt.Println("Final stats:", stats.GetAll())
}

func main() {
    statsExample()
}

常见陷阱和最佳实践 #

1. 避免死锁 #

package main

import (
    "fmt"
    "sync"
    "time"
)

type Account struct {
    mu      sync.Mutex
    id      int
    balance int
}

// 错误的转账实现:可能导致死锁
func badTransfer(from, to *Account, amount int) {
    from.mu.Lock()
    defer from.mu.Unlock()

    to.mu.Lock() // 可能死锁
    defer to.mu.Unlock()

    from.balance -= amount
    to.balance += amount
}

// 正确的转账实现:按 ID 排序获取锁
func goodTransfer(from, to *Account, amount int) {
    if from.id == to.id {
        return // 同一账户,无需转账
    }

    // 按 ID 排序,确保锁的获取顺序一致
    first, second := from, to
    if from.id > to.id {
        first, second = to, from
    }

    first.mu.Lock()
    defer first.mu.Unlock()

    second.mu.Lock()
    defer second.mu.Unlock()

    from.balance -= amount
    to.balance += amount
}

func deadlockExample() {
    account1 := &Account{id: 1, balance: 1000}
    account2 := &Account{id: 2, balance: 1000}

    var wg sync.WaitGroup

    // 使用正确的转账方法
    for i := 0; i < 10; i++ {
        wg.Add(2)

        go func() {
            defer wg.Done()
            goodTransfer(account1, account2, 10)
        }()

        go func() {
            defer wg.Done()
            goodTransfer(account2, account1, 10)
        }()
    }

    wg.Wait()

    fmt.Printf("Account 1 balance: %d\n", account1.balance)
    fmt.Printf("Account 2 balance: %d\n", account2.balance)
}

func main() {
    deadlockExample()
}

2. 正确使用 defer #

package main

import (
    "fmt"
    "sync"
)

type SafeMap struct {
    mu   sync.Mutex
    data map[string]int
}

func (sm *SafeMap) Get(key string) (int, bool) {
    sm.mu.Lock()
    defer sm.mu.Unlock() // 确保解锁

    if sm.data == nil {
        return 0, false
    }

    value, ok := sm.data[key]
    return value, ok
}

func (sm *SafeMap) Set(key string, value int) {
    sm.mu.Lock()
    defer sm.mu.Unlock() // 确保解锁

    if sm.data == nil {
        sm.data = make(map[string]int)
    }

    sm.data[key] = value
}

// 错误示例:忘记解锁
func (sm *SafeMap) BadMethod(key string) int {
    sm.mu.Lock()

    if sm.data == nil {
        return 0 // 忘记解锁!
    }

    value := sm.data[key]
    sm.mu.Unlock()
    return value
}

func deferExample() {
    sm := &SafeMap{}
    var wg sync.WaitGroup

    for i := 0; i < 10; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()

            key := fmt.Sprintf("key-%d", id)
            sm.Set(key, id*10)

            if value, ok := sm.Get(key); ok {
                fmt.Printf("Key: %s, Value: %d\n", key, value)
            }
        }(i)
    }

    wg.Wait()
}

func main() {
    deferExample()
}

3. 避免长时间持有锁 #

package main

import (
    "fmt"
    "sync"
    "time"
)

type DataProcessor struct {
    mu   sync.Mutex
    data []int
}

// 错误示例:长时间持有锁
func (dp *DataProcessor) BadProcess() {
    dp.mu.Lock()
    defer dp.mu.Unlock()

    // 长时间的处理逻辑
    for i := range dp.data {
        time.Sleep(time.Millisecond) // 模拟耗时操作
        dp.data[i] *= 2
    }
}

// 正确示例:最小化锁持有时间
func (dp *DataProcessor) GoodProcess() {
    // 先复制数据
    dp.mu.Lock()
    dataCopy := make([]int, len(dp.data))
    copy(dataCopy, dp.data)
    dp.mu.Unlock()

    // 在锁外进行处理
    for i := range dataCopy {
        time.Sleep(time.Millisecond) // 模拟耗时操作
        dataCopy[i] *= 2
    }

    // 写回结果
    dp.mu.Lock()
    dp.data = dataCopy
    dp.mu.Unlock()
}

func lockHoldingExample() {
    dp := &DataProcessor{
        data: []int{1, 2, 3, 4, 5},
    }

    start := time.Now()
    dp.GoodProcess()
    fmt.Printf("Good process took: %v\n", time.Since(start))

    fmt.Printf("Processed data: %v\n", dp.data)
}

func main() {
    lockHoldingExample()
}

小结 #

在本节中,我们深入学习了:

  1. Mutex 基础:基本用法、内部实现和公平性机制
  2. 高级用法:条件性加锁、超时锁、可重入锁
  3. 性能优化:减少锁粒度、读写分离
  4. 最佳实践:避免死锁、正确使用 defer、最小化锁持有时间

Mutex 是并发编程中最基础的同步原语,正确使用 Mutex 对于编写安全、高效的并发程序至关重要。在下一节中,我们将学习读写锁 RWMutex,它在读多写少的场景下提供了更好的性能。

练习题 #

  1. 实现一个线程安全的 LRU 缓存,使用 Mutex 保护内部数据结构
  2. 设计一个支持超时的互斥锁,并实现相应的测试用例
  3. 创建一个银行转账系统,确保在高并发情况下不会出现死锁和数据不一致