2.3.2 RWMutex 读写锁

2.3.2 RWMutex 读写锁 #

RWMutex 概述 #

RWMutex(读写锁)是一种特殊的锁,它允许多个读者同时访问共享资源,但写者需要独占访问。这种设计在读多写少的场景下能显著提高性能,因为读操作之间不会相互阻塞。

RWMutex 的特点 #

  • 读读共享:多个读者可以同时持有读锁,读操作之间互不阻塞(注意:Go 的 RWMutex 并非"读者优先",一旦有写者在等待,后续新的读者会被阻塞)
  • 写者独占:写者需要独占访问,与所有读者和其他写者互斥
  • 读写互斥:读锁和写锁之间是互斥的
  • 写写互斥:多个写者之间也是互斥的
  • 零值可用:RWMutex 的零值是一个有效的、未锁定的读写锁

适用场景 #

  • 读多写少:大量读操作,少量写操作
  • 配置管理:频繁读取配置,偶尔更新配置
  • 缓存系统:频繁查询缓存,偶尔更新缓存
  • 统计数据:频繁读取统计信息,定期更新统计

RWMutex 的基本使用 #

基本语法 #

package main

import (
    "fmt"
    "sync"
    "time"
)

// SafeMap is a string-to-int map whose accessors go through a sync.RWMutex:
// lookups take the shared read lock, mutations take the exclusive write lock.
type SafeMap struct {
    mu   sync.RWMutex
    data map[string]int
}

// NewSafeMap returns an empty SafeMap with its backing map allocated.
func NewSafeMap() *SafeMap {
    m := new(SafeMap)
    m.data = make(map[string]int)
    return m
}

// Get looks up key under the read lock and reports whether it was present.
func (sm *SafeMap) Get(key string) (int, bool) {
    sm.mu.RLock()
    found, present := sm.data[key]
    sm.mu.RUnlock()

    return found, present
}

// Keys returns a snapshot of all keys, taken under the read lock.
// The order of the returned keys is unspecified (map iteration order).
func (sm *SafeMap) Keys() []string {
    sm.mu.RLock()
    defer sm.mu.RUnlock()

    snapshot := make([]string, len(sm.data))
    next := 0
    for name := range sm.data {
        snapshot[next] = name
        next++
    }
    return snapshot
}

// Set stores value under key while holding the write lock.
func (sm *SafeMap) Set(key string, value int) {
    sm.mu.Lock()
    sm.data[key] = value
    sm.mu.Unlock()
}

// Delete removes key (if present) while holding the write lock.
func (sm *SafeMap) Delete(key string) {
    sm.mu.Lock()
    delete(sm.data, key)
    sm.mu.Unlock()
}

// Len returns the number of stored entries, read under the read lock.
func (sm *SafeMap) Len() int {
    sm.mu.RLock()
    size := len(sm.data)
    sm.mu.RUnlock()

    return size
}

// basicRWMutexExample seeds a SafeMap with ten keys, then runs five reader
// goroutines and two writer goroutines against it concurrently and prints the
// final map length once all of them have finished.
func basicRWMutexExample() {
    const (
        keyCount    = 10
        readerCount = 5
        writerCount = 2
    )

    var wg sync.WaitGroup
    safeMap := NewSafeMap()

    // Seed the map before any goroutine is started.
    for n := 0; n < keyCount; n++ {
        safeMap.Set(fmt.Sprintf("key-%d", n), n*10)
    }

    // A reader performs 100 lookups, cycling through the seeded keys.
    reader := func(id int) {
        defer wg.Done()

        for round := 0; round < 100; round++ {
            key := fmt.Sprintf("key-%d", round%keyCount)
            value, ok := safeMap.Get(key)
            if ok {
                fmt.Printf("Reader %d: %s = %d\n", id, key, value)
            }
            time.Sleep(time.Millisecond)
        }
    }

    // A writer performs 20 updates, cycling through the same keys.
    writer := func(id int) {
        defer wg.Done()

        for round := 0; round < 20; round++ {
            key := fmt.Sprintf("key-%d", round%keyCount)
            newValue := (id+1)*100 + round
            safeMap.Set(key, newValue)
            fmt.Printf("Writer %d: Set %s = %d\n", id, key, newValue)
            time.Sleep(10 * time.Millisecond)
        }
    }

    wg.Add(readerCount)
    for id := 0; id < readerCount; id++ {
        go reader(id)
    }

    wg.Add(writerCount)
    for id := 0; id < writerCount; id++ {
        go writer(id)
    }

    wg.Wait()
    fmt.Printf("Final map length: %d\n", safeMap.Len())
}

// main runs the basic SafeMap reader/writer example.
func main() {
    basicRWMutexExample()
}

RWMutex vs Mutex 性能对比 #

性能测试 #

package main

import (
    "fmt"
    "math/rand"
    "sync"
    "time"
)

// MutexCounter is an int64 counter guarded by a plain sync.Mutex, so reads
// and writes all contend for the same exclusive lock.
type MutexCounter struct {
    mu    sync.Mutex
    value int64
}

// Read returns the current value while holding the mutex.
func (c *MutexCounter) Read() int64 {
    c.mu.Lock()
    current := c.value
    c.mu.Unlock()
    return current
}

// Write adds delta to the counter while holding the mutex.
func (c *MutexCounter) Write(delta int64) {
    c.mu.Lock()
    c.value = c.value + delta
    c.mu.Unlock()
}

// RWMutexCounter is an int64 counter guarded by a sync.RWMutex: Read takes
// the shared read lock, Write takes the exclusive write lock.
type RWMutexCounter struct {
    mu    sync.RWMutex
    value int64
}

// Read returns the current value while holding the read lock.
func (c *RWMutexCounter) Read() int64 {
    c.mu.RLock()
    current := c.value
    c.mu.RUnlock()
    return current
}

// Write adds delta to the counter while holding the write lock.
func (c *RWMutexCounter) Write(delta int64) {
    c.mu.Lock()
    c.value = c.value + delta
    c.mu.Unlock()
}

// runLockBenchmark starts numGoroutines workers that repeatedly call read
// (with probability readRatio) or write (otherwise) for roughly duration.
// It returns the average wall-clock time per completed operation, i.e. the
// elapsed time divided by the total number of operations across all workers.
// If no operation completed, the elapsed time itself is returned.
//
// Each worker uses its own rand.Rand so that the shared, lock-protected
// global math/rand source does not add contention unrelated to the lock
// being measured.
func runLockBenchmark(read, write func(), readRatio float64, numGoroutines int, duration time.Duration) time.Duration {
    workers := numGoroutines
    if workers < 0 {
        workers = 0
    }

    var wg sync.WaitGroup
    stop := make(chan struct{})
    // One slot per worker; each worker writes only its own slot, once, on exit.
    opsPerWorker := make([]int64, workers)

    start := time.Now()

    for i := 0; i < workers; i++ {
        wg.Add(1)
        go func(slot int) {
            defer wg.Done()

            rng := rand.New(rand.NewSource(start.UnixNano() + int64(slot)))
            var ops int64
            defer func() { opsPerWorker[slot] = ops }()

            for {
                select {
                case <-stop:
                    return
                default:
                }

                if rng.Float64() < readRatio {
                    read()
                } else {
                    write()
                }
                ops++
            }
        }(i)
    }

    // Let the workers run for the requested time, then stop and join them.
    time.Sleep(duration)
    close(stop)
    wg.Wait()

    elapsed := time.Since(start)

    var total int64
    for _, n := range opsPerWorker {
        total += n
    }
    if total == 0 {
        return elapsed
    }
    return elapsed / time.Duration(total)
}

// benchmarkMutex exercises a MutexCounter from numGoroutines goroutines for
// roughly duration, choosing Read with probability readRatio and Write(1)
// otherwise. It returns the average time per operation (lower is better).
//
// The previous version returned the total elapsed time, which is always about
// duration and therefore said nothing about the lock's throughput.
func benchmarkMutex(readRatio float64, numGoroutines int, duration time.Duration) time.Duration {
    counter := &MutexCounter{}
    return runLockBenchmark(
        func() { counter.Read() },
        func() { counter.Write(1) },
        readRatio, numGoroutines, duration,
    )
}

// benchmarkRWMutex is the RWMutexCounter counterpart of benchmarkMutex and
// returns the average time per operation measured the same way, so the two
// results can be compared directly.
func benchmarkRWMutex(readRatio float64, numGoroutines int, duration time.Duration) time.Duration {
    counter := &RWMutexCounter{}
    return runLockBenchmark(
        func() { counter.Read() },
        func() { counter.Write(1) },
        readRatio, numGoroutines, duration,
    )
}

// performanceComparison runs benchmarkMutex and benchmarkRWMutex for a series
// of read ratios (10 goroutines, 2 seconds each) and prints one table row per
// ratio containing both returned durations and their quotient
// (Mutex result divided by RWMutex result).
func performanceComparison() {
    const (
        duration      = 2 * time.Second
        numGoroutines = 10
    )

    fmt.Println("Performance Comparison: Mutex vs RWMutex")
    fmt.Println("Read Ratio | Mutex Time | RWMutex Time | Improvement")
    fmt.Println("-----------|------------|--------------|------------")

    for _, ratio := range []float64{0.5, 0.7, 0.8, 0.9, 0.95, 0.99} {
        plain := benchmarkMutex(ratio, numGoroutines, duration)
        rw := benchmarkRWMutex(ratio, numGoroutines, duration)
        speedup := float64(plain) / float64(rw)

        fmt.Printf("%.2f       | %-10v | %-12v | %.2fx\n",
            ratio, plain, rw, speedup)
    }
}

// main seeds the global math/rand source from the current time and runs the
// Mutex vs RWMutex comparison.
// NOTE(review): rand.Seed is deprecated as of Go 1.20, where the global source
// is seeded randomly by default — confirm the Go version this example targets.
func main() {
    rand.Seed(time.Now().UnixNano())
    performanceComparison()
}

RWMutex 的高级用法 #

1. 配置管理器 #

package main

import (
    "encoding/json"
    "fmt"
    "sync"
    "time"
)

// Config is the application configuration: database settings, server
// settings and a set of named feature flags. It maps to JSON via the tags.
type Config struct {
    Database struct {
        Host     string `json:"host"`
        Port     int    `json:"port"`
        Username string `json:"username"`
        Password string `json:"password"`
    } `json:"database"`

    Server struct {
        Port    int  `json:"port"`
        Debug   bool `json:"debug"`
        Timeout int  `json:"timeout"`
    } `json:"server"`

    Features map[string]bool `json:"features"`
}

// ConfigManager owns a *Config and guards it with a sync.RWMutex: getters
// take the read lock, mutators take the write lock. The manager never shares
// its Config (or its Features map) with callers; values are copied on the
// way in and on the way out.
type ConfigManager struct {
    mu     sync.RWMutex
    config *Config
}

// NewConfigManager returns a manager holding an empty Config.
func NewConfigManager() *ConfigManager {
    return &ConfigManager{
        config: &Config{},
    }
}

// cloneConfig returns a copy of src that shares no mutable state with it:
// the struct is copied by value and the Features map is duplicated.
func cloneConfig(src *Config) *Config {
    dup := *src
    if src.Features != nil {
        dup.Features = make(map[string]bool, len(src.Features))
        for name, enabled := range src.Features {
            dup.Features[name] = enabled
        }
    }
    return &dup
}

// GetDatabaseConfig returns host, port, username and password of the
// database section, read under the read lock.
func (cm *ConfigManager) GetDatabaseConfig() (string, int, string, string) {
    cm.mu.RLock()
    defer cm.mu.RUnlock()

    db := cm.config.Database
    return db.Host, db.Port, db.Username, db.Password
}

// GetServerConfig returns port, debug flag and timeout of the server
// section, read under the read lock.
func (cm *ConfigManager) GetServerConfig() (int, bool, int) {
    cm.mu.RLock()
    defer cm.mu.RUnlock()

    srv := cm.config.Server
    return srv.Port, srv.Debug, srv.Timeout
}

// IsFeatureEnabled reports whether feature is present and set to true.
// Unknown features (and a nil feature map) yield false.
func (cm *ConfigManager) IsFeatureEnabled(feature string) bool {
    cm.mu.RLock()
    defer cm.mu.RUnlock()

    // Indexing a nil map is legal and yields the zero value (false).
    return cm.config.Features[feature]
}

// GetConfig returns a deep copy of the current configuration; modifying the
// result does not affect the manager.
func (cm *ConfigManager) GetConfig() Config {
    cm.mu.RLock()
    defer cm.mu.RUnlock()

    return *cloneConfig(cm.config)
}

// UpdateConfig replaces the current configuration with a copy of newConfig.
//
// The manager stores a copy rather than the caller's pointer, so later
// changes the caller makes to newConfig (including its Features map) cannot
// bypass the lock. A nil newConfig is ignored: storing it would make every
// getter dereference a nil pointer.
func (cm *ConfigManager) UpdateConfig(newConfig *Config) {
    if newConfig == nil {
        return
    }
    // Copy before taking the lock to keep the critical section short.
    replacement := cloneConfig(newConfig)

    cm.mu.Lock()
    defer cm.mu.Unlock()

    cm.config = replacement
    fmt.Println("Configuration updated")
}

// SetFeature sets a single feature flag under the write lock, allocating the
// feature map on first use.
func (cm *ConfigManager) SetFeature(feature string, enabled bool) {
    cm.mu.Lock()
    defer cm.mu.Unlock()

    if cm.config.Features == nil {
        cm.config.Features = make(map[string]bool)
    }

    cm.config.Features[feature] = enabled
    fmt.Printf("Feature %s set to %v\n", feature, enabled)
}

// LoadFromJSON parses jsonData into a fresh Config and, on success, installs
// it as the current configuration. On a parse error the current configuration
// is left untouched and the json.Unmarshal error is returned.
func (cm *ConfigManager) LoadFromJSON(jsonData []byte) error {
    // Decode outside the lock so readers are not blocked by parsing.
    var newConfig Config
    if err := json.Unmarshal(jsonData, &newConfig); err != nil {
        return err
    }

    cm.mu.Lock()
    defer cm.mu.Unlock()

    cm.config = &newConfig
    fmt.Println("Configuration loaded from JSON")
    return nil
}

// configManagerExample installs an initial configuration, then runs ten
// reader goroutines that repeatedly print the database address and the state
// of feature_a, alongside one updater goroutine that toggles feature_b and
// bumps the server port five times. It returns when all goroutines are done.
func configManagerExample() {
    var wg sync.WaitGroup
    cm := NewConfigManager()

    // Build the initial configuration and hand it to the manager.
    var initial Config
    initial.Database.Host = "localhost"
    initial.Database.Port = 5432
    initial.Database.Username = "user"
    initial.Database.Password = "password"
    initial.Server.Port = 8080
    initial.Server.Debug = true
    initial.Server.Timeout = 30
    initial.Features = map[string]bool{
        "feature_a": true,
        "feature_b": false,
        "feature_c": true,
    }
    cm.UpdateConfig(&initial)

    // A reader polls the configuration 50 times, 10ms apart.
    reader := func(id int) {
        defer wg.Done()

        for round := 0; round < 50; round++ {
            host, port, _, _ := cm.GetDatabaseConfig()
            fmt.Printf("Reader %d: DB %s:%d\n", id, host, port)

            featureA := cm.IsFeatureEnabled("feature_a")
            if featureA {
                fmt.Printf("Reader %d: Feature A is enabled\n", id)
            }

            time.Sleep(10 * time.Millisecond)
        }
    }

    // The updater changes the configuration every 200ms, five times.
    updater := func() {
        defer wg.Done()

        for round := 0; round < 5; round++ {
            time.Sleep(200 * time.Millisecond)

            // Alternate feature_b between enabled and disabled.
            cm.SetFeature("feature_b", round%2 == 0)

            // Take a copy, change the server port, and install the copy.
            snapshot := cm.GetConfig()
            snapshot.Server.Port = 8080 + round
            cm.UpdateConfig(&snapshot)
        }
    }

    for id := 0; id < 10; id++ {
        wg.Add(1)
        go reader(id)
    }

    wg.Add(1)
    go updater()

    wg.Wait()
}

// main runs the configuration manager example.
func main() {
    configManagerExample()
}

2. 缓存系统 #

package main

import (
    "fmt"
    "sync"
    "time"
)

// CacheItem is a cached value together with the instant it expires.
type CacheItem struct {
    Value      interface{}
    Expiration time.Time
}

// IsExpired reports whether the item's expiration instant lies in the past.
func (item *CacheItem) IsExpired() bool {
    return item.expiredAt(time.Now())
}

// expiredAt reports whether the item is expired as of now. Every expiry
// decision in this file goes through this one predicate so that Get, Exists,
// Keys, Stats and the cleanup pass always agree (an item whose Expiration
// equals now is still considered valid).
func (item *CacheItem) expiredAt(now time.Time) bool {
    return now.After(item.Expiration)
}

// Cache is a TTL cache guarded by a sync.RWMutex. Lookups take the read
// lock; Set, Delete, Clear and the periodic cleanup take the write lock.
// Caches created by NewCache run a background goroutine; call Close to stop it.
type Cache struct {
    mu    sync.RWMutex
    items map[string]*CacheItem

    stop      chan struct{} // closed by Close to end the cleanup goroutine
    closeOnce sync.Once     // makes Close safe to call more than once
}

// NewCache returns an empty cache and starts a goroutine that purges expired
// items once per minute until Close is called.
func NewCache() *Cache {
    cache := &Cache{
        items: make(map[string]*CacheItem),
        stop:  make(chan struct{}),
    }

    go cache.cleanup()

    return cache
}

// Close stops the background cleanup goroutine. It is safe to call multiple
// times. The cache remains usable afterwards, but expired items are no longer
// purged automatically (they are still hidden from Get, Exists and Keys).
func (c *Cache) Close() {
    c.closeOnce.Do(func() {
        if c.stop != nil {
            close(c.stop)
        }
    })
}

// Get returns the value stored under key and true, or nil and false when the
// key is absent or its item has expired.
func (c *Cache) Get(key string) (interface{}, bool) {
    c.mu.RLock()
    defer c.mu.RUnlock()

    item, exists := c.items[key]
    if !exists || item.expiredAt(time.Now()) {
        return nil, false
    }

    return item.Value, true
}

// Exists reports whether key is present and not expired.
func (c *Cache) Exists(key string) bool {
    c.mu.RLock()
    defer c.mu.RUnlock()

    item, exists := c.items[key]
    return exists && !item.expiredAt(time.Now())
}

// Keys returns the keys of all non-expired items in unspecified order.
// The result is nil when there are none.
func (c *Cache) Keys() []string {
    c.mu.RLock()
    defer c.mu.RUnlock()

    var keys []string
    now := time.Now()

    for key, item := range c.items {
        if !item.expiredAt(now) {
            keys = append(keys, key)
        }
    }

    return keys
}

// Stats returns the total number of stored items (expired ones included)
// and how many of them are currently expired.
func (c *Cache) Stats() (int, int) {
    c.mu.RLock()
    defer c.mu.RUnlock()

    total := len(c.items)
    expired := 0
    now := time.Now()

    for _, item := range c.items {
        if item.expiredAt(now) {
            expired++
        }
    }

    return total, expired
}

// Set stores value under key; the item expires ttl after the call.
func (c *Cache) Set(key string, value interface{}, ttl time.Duration) {
    c.mu.Lock()
    defer c.mu.Unlock()

    c.items[key] = &CacheItem{
        Value:      value,
        Expiration: time.Now().Add(ttl),
    }
}

// Delete removes key from the cache if present.
func (c *Cache) Delete(key string) {
    c.mu.Lock()
    defer c.mu.Unlock()

    delete(c.items, key)
}

// cleanup purges expired items once per minute until Close is called.
// It is started by NewCache in its own goroutine.
func (c *Cache) cleanup() {
    ticker := time.NewTicker(1 * time.Minute)
    defer ticker.Stop()

    for {
        select {
        case <-c.stop:
            return
        case <-ticker.C:
            c.removeExpired()
        }
    }
}

// removeExpired deletes every expired item under the write lock.
func (c *Cache) removeExpired() {
    c.mu.Lock()
    defer c.mu.Unlock()

    now := time.Now()
    for key, item := range c.items {
        if item.expiredAt(now) {
            delete(c.items, key)
        }
    }
}

// Clear drops all items by replacing the backing map.
func (c *Cache) Clear() {
    c.mu.Lock()
    defer c.mu.Unlock()

    c.items = make(map[string]*CacheItem)
}

// cacheExample runs eight reader goroutines and two writer goroutines
// against one Cache, plus a goroutine that prints the cache statistics ten
// times at 500ms intervals. It returns once every goroutine has finished.
func cacheExample() {
    var wg sync.WaitGroup
    cache := NewCache()

    // A reader performs 100 lookups over 20 rotating keys, 10ms apart.
    reader := func(id int) {
        defer wg.Done()

        for round := 0; round < 100; round++ {
            key := fmt.Sprintf("key-%d", round%20)

            value, hit := cache.Get(key)
            if hit {
                fmt.Printf("Reader %d: Got %s = %v\n", id, key, value)
            } else {
                fmt.Printf("Reader %d: Miss %s\n", id, key)
            }

            time.Sleep(10 * time.Millisecond)
        }
    }

    // A writer stores 50 values with a 5s TTL, 50ms apart.
    writer := func(id int) {
        defer wg.Done()

        for round := 0; round < 50; round++ {
            key := fmt.Sprintf("key-%d", round%20)
            value := fmt.Sprintf("value-%d-%d", id, round)

            cache.Set(key, value, 5*time.Second)
            fmt.Printf("Writer %d: Set %s = %s\n", id, key, value)

            time.Sleep(50 * time.Millisecond)
        }
    }

    // The reporter prints total/expired counts every 500ms.
    reporter := func() {
        defer wg.Done()

        for tick := 0; tick < 10; tick++ {
            time.Sleep(500 * time.Millisecond)
            total, expired := cache.Stats()
            fmt.Printf("Stats: Total=%d, Expired=%d\n", total, expired)
        }
    }

    for id := 0; id < 8; id++ {
        wg.Add(1)
        go reader(id)
    }

    for id := 0; id < 2; id++ {
        wg.Add(1)
        go writer(id)
    }

    wg.Add(1)
    go reporter()

    wg.Wait()
}

// main runs the cache example.
func main() {
    cacheExample()
}

3. 统计计数器 #

package main

import (
    "fmt"
    "sync"
    "time"
)

// StatsCounter keeps two int64 tallies per key behind a sync.RWMutex:
// counts (resettable) and totals (never reset by this type).
type StatsCounter struct {
    mu     sync.RWMutex
    counts map[string]int64
    totals map[string]int64
}

// NewStatsCounter returns a counter with both maps allocated and empty.
func NewStatsCounter() *StatsCounter {
    sc := new(StatsCounter)
    sc.counts = map[string]int64{}
    sc.totals = map[string]int64{}
    return sc
}

// copyCounts returns a fresh map holding the same entries as src.
func copyCounts(src map[string]int64) map[string]int64 {
    dst := map[string]int64{}
    for name, n := range src {
        dst[name] = n
    }
    return dst
}

// Increment adds delta to both the current count and the total of key,
// under the write lock.
func (sc *StatsCounter) Increment(key string, delta int64) {
    sc.mu.Lock()
    sc.counts[key] = sc.counts[key] + delta
    sc.totals[key] = sc.totals[key] + delta
    sc.mu.Unlock()
}

// Get returns the current count of key (0 if unknown).
func (sc *StatsCounter) Get(key string) int64 {
    sc.mu.RLock()
    current := sc.counts[key]
    sc.mu.RUnlock()

    return current
}

// GetTotal returns the total of key (0 if unknown).
func (sc *StatsCounter) GetTotal(key string) int64 {
    sc.mu.RLock()
    total := sc.totals[key]
    sc.mu.RUnlock()

    return total
}

// GetAll returns a copy of all current counts.
func (sc *StatsCounter) GetAll() map[string]int64 {
    sc.mu.RLock()
    defer sc.mu.RUnlock()

    return copyCounts(sc.counts)
}

// GetAllTotals returns a copy of all totals.
func (sc *StatsCounter) GetAllTotals() map[string]int64 {
    sc.mu.RLock()
    defer sc.mu.RUnlock()

    return copyCounts(sc.totals)
}

// Reset discards all current counts by replacing the counts map; totals are
// left untouched.
func (sc *StatsCounter) Reset() {
    sc.mu.Lock()
    sc.counts = map[string]int64{}
    sc.mu.Unlock()
}

// ResetKey sets the current count of key to zero (the key remains present in
// the counts map); its total is left untouched.
func (sc *StatsCounter) ResetKey(key string) {
    sc.mu.Lock()
    sc.counts[key] = 0
    sc.mu.Unlock()
}

// GetRate returns the current count of key divided by duration in seconds,
// or 0 when duration is zero.
func (sc *StatsCounter) GetRate(key string, duration time.Duration) float64 {
    sc.mu.RLock()
    defer sc.mu.RUnlock()

    current := sc.counts[key]

    if secs := duration.Seconds(); secs != 0 {
        return float64(current) / secs
    }
    return 0
}

// statsCounterExample runs five goroutines that read and print the request
// and error counts, three goroutines that generate events, and one reporter
// that prints every counter once per second (ten times) and resets the
// current counts on every third report.
func statsCounterExample() {
    var wg sync.WaitGroup
    counter := NewStatsCounter()

    // A reader samples the counters 50 times, 20ms apart.
    reader := func(id int) {
        defer wg.Done()

        for round := 0; round < 50; round++ {
            reqCount := counter.Get("requests")
            errCount := counter.Get("errors")

            fmt.Printf("Reader %d: Requests=%d, Errors=%d\n", id, reqCount, errCount)

            // The error rate is only defined once a request has been seen.
            if reqCount > 0 {
                errorRate := float64(errCount) / float64(reqCount) * 100
                fmt.Printf("Reader %d: Error rate=%.2f%%\n", id, errorRate)
            }

            time.Sleep(20 * time.Millisecond)
        }
    }

    // A generator records 100 requests, an error on every tenth one, and one
    // of three rotating event types per request.
    generator := func(id int) {
        defer wg.Done()

        for round := 0; round < 100; round++ {
            counter.Increment("requests", 1)

            if round%10 == 0 {
                counter.Increment("errors", 1)
            }

            counter.Increment(fmt.Sprintf("event_type_%d", round%3), 1)

            time.Sleep(10 * time.Millisecond)
        }
    }

    // The reporter prints current and total values for every key.
    reporter := func() {
        defer wg.Done()

        for tick := 0; tick < 10; tick++ {
            time.Sleep(1 * time.Second)

            fmt.Println("\n=== Stats Report ===")
            current := counter.GetAll()
            totals := counter.GetAllTotals()

            for key, value := range current {
                fmt.Printf("%s: Current=%d, Total=%d\n", key, value, totals[key])
            }

            if tick%3 == 2 {
                counter.Reset()
                fmt.Println("Current counters reset")
            }
        }
    }

    for id := 0; id < 5; id++ {
        wg.Add(1)
        go reader(id)
    }

    for id := 0; id < 3; id++ {
        wg.Add(1)
        go generator(id)
    }

    wg.Add(1)
    go reporter()

    wg.Wait()
}

// main runs the statistics counter example.
func main() {
    statsCounterExample()
}

RWMutex 的内部实现 #

实现原理 #

RWMutex 内部使用了多个字段来管理状态:

// RWMutex shows the field layout of the standard library's sync.RWMutex as
// presented in this section.
// NOTE(review): newer Go releases declare the two counters as atomic.Int32 —
// confirm against the Go version being described.
type RWMutex struct {
    w           Mutex  // held if there are pending writers
    writerSem   uint32 // semaphore for writers to wait for completing readers
    readerSem   uint32 // semaphore for readers to wait for completing writers
    readerCount int32  // number of pending readers
    readerWait  int32  // number of departing readers a pending writer still waits for (not "waiting readers")
}

锁的获取过程 #

package main

import (
    "fmt"
    "runtime"
    "sync"
    "sync/atomic"
    "time"
)

// RWMutexMonitor wraps a sync.RWMutex and prints how many readers and
// writers are involved each time the lock is acquired or released.
//
// readerCount is the number of goroutines that currently HOLD the read lock.
// It is incremented only after RLock has succeeded and decremented before
// RUnlock, so a reader that is still blocked behind a writer is never reported
// as a holder. writerCount is the number of goroutines that are waiting for
// or holding the write lock.
type RWMutexMonitor struct {
    mu          sync.RWMutex
    readerCount int32
    writerCount int32
}

// RLock acquires the read lock, then records and prints the new holder count.
func (m *RWMutexMonitor) RLock() {
    m.mu.RLock()
    readers := atomic.AddInt32(&m.readerCount, 1)
    fmt.Printf("Reader acquired lock (readers: %d, writers: %d)\n",
        readers, atomic.LoadInt32(&m.writerCount))
}

// RUnlock records that this reader is leaving, releases the read lock and
// prints the counts.
func (m *RWMutexMonitor) RUnlock() {
    // Decrement first so the count never overstates the number of holders.
    readers := atomic.AddInt32(&m.readerCount, -1)
    m.mu.RUnlock()
    fmt.Printf("Reader released lock (readers: %d, writers: %d)\n",
        readers, atomic.LoadInt32(&m.writerCount))
}

// Lock registers a pending writer, prints that it is waiting, acquires the
// write lock and prints again once it holds the lock.
func (m *RWMutexMonitor) Lock() {
    writers := atomic.AddInt32(&m.writerCount, 1)
    fmt.Printf("Writer waiting for lock (readers: %d, writers: %d)\n",
        atomic.LoadInt32(&m.readerCount), writers)

    m.mu.Lock()
    fmt.Printf("Writer acquired lock (readers: %d, writers: %d)\n",
        atomic.LoadInt32(&m.readerCount), atomic.LoadInt32(&m.writerCount))
}

// Unlock unregisters the writer, releases the write lock and prints the counts.
func (m *RWMutexMonitor) Unlock() {
    writers := atomic.AddInt32(&m.writerCount, -1)
    m.mu.Unlock()
    fmt.Printf("Writer released lock (readers: %d, writers: %d)\n",
        atomic.LoadInt32(&m.readerCount), writers)
}

// rwMutexInternalsExample shows how readers and a writer queue on one
// RWMutexMonitor: three readers take the read lock for 2s each, a writer
// requests the write lock 500ms later, and a late reader requests the read
// lock another 500ms after that. Once everything is started it prints the
// number of live goroutines (most of them parked on the lock at that point).
//
// The goroutine count is reported through runtime.NumGoroutine; this program
// imports "runtime", and without a use of it the file does not compile.
func rwMutexInternalsExample() {
    monitor := &RWMutexMonitor{}
    var wg sync.WaitGroup

    // Start the initial readers.
    for i := 0; i < 3; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()

            monitor.RLock()
            fmt.Printf("Reader %d working...\n", id)
            time.Sleep(2 * time.Second)
            monitor.RUnlock()
        }(i)
    }

    time.Sleep(500 * time.Millisecond)

    // Start the writer while the readers still hold the lock.
    wg.Add(1)
    go func() {
        defer wg.Done()

        monitor.Lock()
        fmt.Println("Writer working...")
        time.Sleep(1 * time.Second)
        monitor.Unlock()
    }()

    time.Sleep(500 * time.Millisecond)

    // Start one more reader after the writer has queued.
    wg.Add(1)
    go func() {
        defer wg.Done()

        monitor.RLock()
        fmt.Println("Late reader working...")
        time.Sleep(1 * time.Second)
        monitor.RUnlock()
    }()

    // Includes the calling goroutine.
    fmt.Printf("Goroutines alive: %d\n", runtime.NumGoroutine())

    wg.Wait()
}

// main runs the RWMutex state monitoring example.
func main() {
    rwMutexInternalsExample()
}

常见陷阱和最佳实践 #

1. 避免写者饥饿 #

package main

import (
    "fmt"
    "sync"
    "time"
)

// writerStarvationDemo measures how long a writer has to wait for the write
// lock while ten readers keep re-acquiring the read lock with only 1ms gaps.
//
// Per the sync.RWMutex documentation, once Lock has been called, later RLock
// calls block until the writer has acquired and released the lock. The wait
// printed here is therefore expected to be bounded by the readers that were
// already inside their critical section, not by the total read traffic.
func writerStarvationDemo() {
    var (
        mu sync.RWMutex
        wg sync.WaitGroup
    )

    // A reader holds the read lock for 10ms, 100 times, pausing 1ms between
    // acquisitions.
    reader := func(id int) {
        defer wg.Done()

        for round := 0; round < 100; round++ {
            mu.RLock()
            fmt.Printf("Reader %d reading...\n", id)
            time.Sleep(10 * time.Millisecond)
            mu.RUnlock()

            time.Sleep(time.Millisecond)
        }
    }

    for id := 0; id < 10; id++ {
        wg.Add(1)
        go reader(id)
    }

    // Give the readers a head start before the writer shows up.
    time.Sleep(100 * time.Millisecond)

    wg.Add(1)
    go func() {
        defer wg.Done()

        requested := time.Now()
        mu.Lock()
        waited := time.Since(requested)

        fmt.Printf("Writer acquired lock after waiting %v\n", waited)
        fmt.Println("Writer writing...")
        time.Sleep(100 * time.Millisecond)
        mu.Unlock()
        fmt.Println("Writer finished")
    }()

    wg.Wait()
}

// FairRWMutex wraps a sync.RWMutex and makes readers back off briefly when a
// writer has announced itself. The announcement is a single token in the
// one-slot writerQueue channel.
//
// Limitations: the flag is a single slot, so with several concurrent writers
// the first Unlock clears it even if other writers are still waiting.
// sync.RWMutex itself already blocks new readers once a writer is waiting, so
// this wrapper is illustrative rather than necessary.
type FairRWMutex struct {
    mu          sync.RWMutex
    writerQueue chan struct{} // holds one token while a writer has announced itself
    writerCount int32         // not used by any method in this block
}

// NewFairRWMutex returns a FairRWMutex with its one-slot flag channel allocated.
func NewFairRWMutex() *FairRWMutex {
    return &FairRWMutex{
        writerQueue: make(chan struct{}, 1),
    }
}

// RLock yields for 1ms if a writer has announced itself, then acquires the
// read lock.
//
// The flag is inspected with len, which does not modify the channel. The
// earlier approach (receive the token, then send it back) could block the
// reader when a writer refilled the slot in between, and could leave a stale
// token behind after that writer's Unlock, making every later reader sleep.
func (frw *FairRWMutex) RLock() {
    if len(frw.writerQueue) > 0 {
        time.Sleep(time.Millisecond) // brief back-off so the writer can get in
    }

    frw.mu.RLock()
}

// RUnlock releases the read lock.
func (frw *FairRWMutex) RUnlock() {
    frw.mu.RUnlock()
}

// Lock announces a writer (non-blocking; a no-op if the flag is already set)
// and then acquires the write lock.
func (frw *FairRWMutex) Lock() {
    select {
    case frw.writerQueue <- struct{}{}:
    default:
    }

    frw.mu.Lock()
}

// Unlock releases the write lock and clears the writer flag (non-blocking).
func (frw *FairRWMutex) Unlock() {
    frw.mu.Unlock()

    select {
    case <-frw.writerQueue:
    default:
    }
}

// main prints a heading and runs the writer wait-time demo.
func main() {
    fmt.Println("Demonstrating writer starvation:")
    writerStarvationDemo()
}

2. 正确的锁升级 #

package main

import (
    "fmt"
    "sync"
)

// UpgradeableRWMutex wraps a sync.RWMutex to illustrate how (not) to move
// from a read lock to a write lock.
type UpgradeableRWMutex struct {
    mu sync.RWMutex
}

// BadUpgrade shows the WRONG way to upgrade: calling Lock while still holding
// the read lock. The Lock call is left commented out because the write lock
// waits for all readers to leave, including the caller itself, so it would
// never return.
func (u *UpgradeableRWMutex) BadUpgrade() {
    u.mu.RLock()
    defer u.mu.RUnlock()

    // Acquiring the write lock here would deadlock:
    // u.mu.Lock() // deadlock!

    fmt.Println("This would deadlock if we tried to upgrade")
}

// GoodUpgrade shows the correct pattern: release the read lock, acquire the
// write lock, and re-check the condition, because another goroutine may have
// changed the state between RUnlock and Lock. It returns true when the work
// was done under the write lock and false when a read lock was sufficient.
//
// Every path releases exactly the lock it holds. The earlier version deferred
// Unlock and then also called Unlock explicitly on the downgrade path, which
// is a double unlock (a fatal runtime error) as soon as the re-check can fail.
func (u *UpgradeableRWMutex) GoodUpgrade() bool {
    u.mu.RLock()

    // Decide under the read lock whether a write is needed.
    needUpgrade := true // placeholder for a real condition

    if !needUpgrade {
        defer u.mu.RUnlock()
        fmt.Println("Using read lock")
        return false
    }

    // RWMutex cannot be upgraded in place: drop the read lock first.
    u.mu.RUnlock()
    u.mu.Lock()

    // Re-check: the state may have changed while no lock was held.
    stillNeeded := true // placeholder for re-evaluating the condition

    if stillNeeded {
        defer u.mu.Unlock()
        fmt.Println("Successfully upgraded to write lock")
        return true
    }

    // No write needed after all: swap the write lock for a read lock.
    // This is not atomic either, so state may change between the two calls.
    u.mu.Unlock()
    u.mu.RLock()
    defer u.mu.RUnlock()

    fmt.Println("Using read lock")
    return false
}

// lockUpgradeExample runs the incorrect and the correct upgrade pattern, in
// that order, on one UpgradeableRWMutex.
func lockUpgradeExample() {
    var m UpgradeableRWMutex

    m.BadUpgrade()
    m.GoodUpgrade()
}

// main runs the lock upgrade example.
func main() {
    lockUpgradeExample()
}

3. 避免递归锁定 #

package main

import (
    "fmt"
    "sync"
)

// RecursiveRWMutex wraps a sync.RWMutex to illustrate nested-locking
// mistakes and the pattern that avoids them.
type RecursiveRWMutex struct {
    mu sync.RWMutex
}

// BadRecursiveRead takes the read lock twice in the same goroutine.
//
// This is an anti-pattern. The sync documentation prohibits recursive read
// locking: if another goroutine calls Lock between the two RLock calls, the
// second RLock blocks behind the pending writer while the writer waits for the
// first read lock to be released, and neither can proceed. It completes here
// only because no writer is competing for the lock.
func (r *RecursiveRWMutex) BadRecursiveRead() {
    r.mu.RLock()
    defer r.mu.RUnlock()

    fmt.Println("First level read")

    // NOT safe in general: deadlocks if a writer queues up at this point.
    r.mu.RLock()
    defer r.mu.RUnlock()

    fmt.Println("Second level read - works only while no writer is waiting")
}

// BadReadToWrite shows that requesting the write lock while holding the read
// lock deadlocks; the Lock call is left commented out for that reason.
func (r *RecursiveRWMutex) BadReadToWrite() {
    r.mu.RLock()
    defer r.mu.RUnlock()

    fmt.Println("Holding read lock")

    // Acquiring the write lock here would deadlock:
    // r.mu.Lock() // deadlock!

    fmt.Println("This would deadlock")
}

// GoodSeparateOperations shows the recommended shape: read under the read
// lock, release it, process the data without any lock, then take the write
// lock only for the write itself. No lock is ever nested inside another.
func (r *RecursiveRWMutex) GoodSeparateOperations() {
    // Read phase.
    r.mu.RLock()
    fmt.Println("Reading data")
    data := "some data"
    r.mu.RUnlock()

    // Processing phase (no lock held).
    processedData := "processed " + data

    // Write phase.
    r.mu.Lock()
    fmt.Printf("Writing: %s\n", processedData)
    r.mu.Unlock()
}

// recursiveLockExample runs the two anti-pattern demonstrations followed by
// the recommended pattern on one RecursiveRWMutex.
func recursiveLockExample() {
    var m RecursiveRWMutex

    m.BadRecursiveRead()
    m.BadReadToWrite()
    m.GoodSeparateOperations()
}

// main runs the recursive locking example.
func main() {
    recursiveLockExample()
}

小结 #

在本节中,我们深入学习了:

  1. RWMutex 基础:读写锁的特点和基本用法
  2. 性能优势:在读多写少场景下的性能提升
  3. 实际应用:配置管理、缓存系统、统计计数器
  4. 内部实现:RWMutex 的工作原理和状态管理
  5. 最佳实践:避免写者饥饿、正确的锁升级、避免递归锁定

RWMutex 是处理读多写少场景的理想选择,但需要注意其使用陷阱。在下一节中,我们将学习 WaitGroup 和 Once,它们提供了不同的同步机制。

练习题 #

  1. 实现一个支持过期时间的读写安全缓存系统
  2. 设计一个配置热加载系统,支持配置文件的动态更新
  3. 创建一个多级缓存系统,使用 RWMutex 保护不同级别的缓存数据