4.4.4 共享内存与信号量

4.4.4 共享内存与信号量 #

共享内存和信号量是进程间通信和同步的重要机制。共享内存提供了高效的数据共享方式,而信号量则用于控制对共享资源的访问。本节将详细介绍如何在 Go 语言中实现和使用这些机制。

共享内存基础 #

共享内存概念 #

共享内存是最快的进程间通信方式,它允许多个进程访问同一块物理内存区域。主要特点:

  1. 高效性:直接内存访问,无需数据拷贝
  2. 同步需求:需要额外的同步机制防止竞态条件
  3. 持久性:进程结束后共享内存可以继续存在
  4. 大小限制:受系统限制影响

System V 共享内存 #

使用 System V IPC 机制创建共享内存:

package main

import (
    "fmt"
    "syscall"
    "unsafe"
    "log"
    "time"
)

const (
    IPC_CREAT = 01000
    IPC_EXCL  = 02000
    IPC_RMID  = 0
    IPC_STAT  = 2
)

// 共享内存管理器
type SharedMemory struct {
    key  int
    id   int
    size int
    addr unsafe.Pointer
}

// 创建共享内存
func NewSharedMemory(key, size int) (*SharedMemory, error) {
    // 创建共享内存段
    id, _, errno := syscall.Syscall(syscall.SYS_SHMGET,
        uintptr(key),
        uintptr(size),
        uintptr(IPC_CREAT|0666))

    if errno != 0 {
        return nil, fmt.Errorf("创建共享内存失败: %v", errno)
    }

    return &SharedMemory{
        key:  key,
        id:   int(id),
        size: size,
    }, nil
}

// 获取已存在的共享内存
func GetSharedMemory(key int) (*SharedMemory, error) {
    // 获取共享内存段
    id, _, errno := syscall.Syscall(syscall.SYS_SHMGET,
        uintptr(key), 0, 0)

    if errno != 0 {
        return nil, fmt.Errorf("获取共享内存失败: %v", errno)
    }

    return &SharedMemory{
        key: key,
        id:  int(id),
    }, nil
}

// 附加到共享内存
func (sm *SharedMemory) Attach() error {
    addr, _, errno := syscall.Syscall(syscall.SYS_SHMAT,
        uintptr(sm.id), 0, 0)

    if errno != 0 {
        return fmt.Errorf("附加共享内存失败: %v", errno)
    }

    sm.addr = unsafe.Pointer(addr)
    return nil
}

// 分离共享内存
func (sm *SharedMemory) Detach() error {
    if sm.addr == nil {
        return nil
    }

    _, _, errno := syscall.Syscall(syscall.SYS_SHMDT,
        uintptr(sm.addr), 0, 0)

    if errno != 0 {
        return fmt.Errorf("分离共享内存失败: %v", errno)
    }

    sm.addr = nil
    return nil
}

// 删除共享内存
func (sm *SharedMemory) Remove() error {
    _, _, errno := syscall.Syscall(syscall.SYS_SHMCTL,
        uintptr(sm.id), IPC_RMID, 0)

    if errno != 0 {
        return fmt.Errorf("删除共享内存失败: %v", errno)
    }

    return nil
}

// 写入数据
func (sm *SharedMemory) Write(data []byte) error {
    if sm.addr == nil {
        return fmt.Errorf("共享内存未附加")
    }

    if len(data) > sm.size {
        return fmt.Errorf("数据大小超过共享内存大小")
    }

    // 将数据复制到共享内存
    dest := (*[1 << 30]byte)(sm.addr)[:sm.size:sm.size]
    copy(dest, data)

    return nil
}

// 读取数据
func (sm *SharedMemory) Read(size int) ([]byte, error) {
    if sm.addr == nil {
        return nil, fmt.Errorf("共享内存未附加")
    }

    if size > sm.size {
        size = sm.size
    }

    // 从共享内存读取数据
    src := (*[1 << 30]byte)(sm.addr)[:size:size]
    data := make([]byte, size)
    copy(data, src)

    return data, nil
}

// 共享内存示例
func sharedMemoryExample() {
    key := 12345
    size := 1024

    // 创建共享内存
    shm, err := NewSharedMemory(key, size)
    if err != nil {
        log.Fatal(err)
    }

    fmt.Printf("共享内存创建成功,ID: %d\n", shm.id)

    // 附加到共享内存
    err = shm.Attach()
    if err != nil {
        log.Fatal(err)
    }
    defer shm.Detach()
    defer shm.Remove()

    // 写入数据
    message := "Hello, Shared Memory!"
    err = shm.Write([]byte(message))
    if err != nil {
        log.Fatal(err)
    }

    fmt.Printf("写入数据: %s\n", message)

    // 读取数据
    data, err := shm.Read(len(message))
    if err != nil {
        log.Fatal(err)
    }

    fmt.Printf("读取数据: %s\n", string(data))
}

func main() {
    fmt.Println("=== 共享内存示例 ===")
    sharedMemoryExample()
}

POSIX 共享内存 #

使用 POSIX 共享内存机制:

package main

import (
    "fmt"
    "syscall"
    "unsafe"
    "log"
    "os"
)

// POSIX 共享内存管理器
type POSIXSharedMemory struct {
    name string
    fd   int
    size int
    addr unsafe.Pointer
}

// 创建 POSIX 共享内存
func NewPOSIXSharedMemory(name string, size int) (*POSIXSharedMemory, error) {
    // 创建共享内存对象
    fd, err := syscall.Open("/dev/shm/"+name,
        syscall.O_CREAT|syscall.O_RDWR|syscall.O_EXCL, 0666)
    if err != nil {
        return nil, fmt.Errorf("创建 POSIX 共享内存失败: %v", err)
    }

    // 设置大小
    err = syscall.Ftruncate(fd, int64(size))
    if err != nil {
        syscall.Close(fd)
        return nil, fmt.Errorf("设置共享内存大小失败: %v", err)
    }

    return &POSIXSharedMemory{
        name: name,
        fd:   fd,
        size: size,
    }, nil
}

// 打开已存在的 POSIX 共享内存
func OpenPOSIXSharedMemory(name string) (*POSIXSharedMemory, error) {
    fd, err := syscall.Open("/dev/shm/"+name, syscall.O_RDWR, 0)
    if err != nil {
        return nil, fmt.Errorf("打开 POSIX 共享内存失败: %v", err)
    }

    // 获取文件大小
    var stat syscall.Stat_t
    err = syscall.Fstat(fd, &stat)
    if err != nil {
        syscall.Close(fd)
        return nil, fmt.Errorf("获取共享内存大小失败: %v", err)
    }

    return &POSIXSharedMemory{
        name: name,
        fd:   fd,
        size: int(stat.Size),
    }, nil
}

// 映射内存
func (psm *POSIXSharedMemory) Map() error {
    addr, err := syscall.Mmap(psm.fd, 0, psm.size,
        syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED)
    if err != nil {
        return fmt.Errorf("映射内存失败: %v", err)
    }

    psm.addr = unsafe.Pointer(&addr[0])
    return nil
}

// 解除映射
func (psm *POSIXSharedMemory) Unmap() error {
    if psm.addr == nil {
        return nil
    }

    data := (*[1 << 30]byte)(psm.addr)[:psm.size:psm.size]
    err := syscall.Munmap(data)
    if err != nil {
        return fmt.Errorf("解除映射失败: %v", err)
    }

    psm.addr = nil
    return nil
}

// 关闭共享内存
func (psm *POSIXSharedMemory) Close() error {
    if psm.addr != nil {
        psm.Unmap()
    }

    err := syscall.Close(psm.fd)
    if err != nil {
        return fmt.Errorf("关闭共享内存失败: %v", err)
    }

    return nil
}

// 删除共享内存
func (psm *POSIXSharedMemory) Unlink() error {
    return os.Remove("/dev/shm/" + psm.name)
}

// 写入数据
func (psm *POSIXSharedMemory) Write(data []byte) error {
    if psm.addr == nil {
        return fmt.Errorf("内存未映射")
    }

    if len(data) > psm.size {
        return fmt.Errorf("数据大小超过共享内存大小")
    }

    dest := (*[1 << 30]byte)(psm.addr)[:psm.size:psm.size]
    copy(dest, data)

    return nil
}

// 读取数据
func (psm *POSIXSharedMemory) Read(size int) ([]byte, error) {
    if psm.addr == nil {
        return nil, fmt.Errorf("内存未映射")
    }

    if size > psm.size {
        size = psm.size
    }

    src := (*[1 << 30]byte)(psm.addr)[:size:size]
    data := make([]byte, size)
    copy(data, src)

    return data, nil
}

func posixSharedMemoryExample() {
    name := "test_shm"
    size := 1024

    // 创建 POSIX 共享内存
    shm, err := NewPOSIXSharedMemory(name, size)
    if err != nil {
        log.Fatal(err)
    }
    defer shm.Close()
    defer shm.Unlink()

    // 映射内存
    err = shm.Map()
    if err != nil {
        log.Fatal(err)
    }
    defer shm.Unmap()

    // 写入数据
    message := "Hello, POSIX Shared Memory!"
    err = shm.Write([]byte(message))
    if err != nil {
        log.Fatal(err)
    }

    fmt.Printf("写入数据: %s\n", message)

    // 读取数据
    data, err := shm.Read(len(message))
    if err != nil {
        log.Fatal(err)
    }

    fmt.Printf("读取数据: %s\n", string(data))
}

func main() {
    fmt.Println("=== POSIX 共享内存示例 ===")
    posixSharedMemoryExample()
}

信号量机制 #

System V 信号量 #

信号量用于控制对共享资源的访问:

package main

import (
    "fmt"
    "syscall"
    "unsafe"
    "log"
    "time"
)

// 信号量操作结构
type SemBuf struct {
    Num uint16
    Op  int16
    Flg int16
}

// 信号量管理器
type Semaphore struct {
    key int
    id  int
}

// 创建信号量
func NewSemaphore(key, nsems int, initVal int) (*Semaphore, error) {
    // 创建信号量集
    id, _, errno := syscall.Syscall(syscall.SYS_SEMGET,
        uintptr(key),
        uintptr(nsems),
        uintptr(IPC_CREAT|0666))

    if errno != 0 {
        return nil, fmt.Errorf("创建信号量失败: %v", errno)
    }

    sem := &Semaphore{
        key: key,
        id:  int(id),
    }

    // 初始化信号量值
    if initVal >= 0 {
        err := sem.SetValue(0, initVal)
        if err != nil {
            return nil, err
        }
    }

    return sem, nil
}

// 获取已存在的信号量
func GetSemaphore(key int) (*Semaphore, error) {
    id, _, errno := syscall.Syscall(syscall.SYS_SEMGET,
        uintptr(key), 0, 0)

    if errno != 0 {
        return nil, fmt.Errorf("获取信号量失败: %v", errno)
    }

    return &Semaphore{
        key: key,
        id:  int(id),
    }, nil
}

// 设置信号量值
func (s *Semaphore) SetValue(semnum, value int) error {
    // 使用 SETVAL 命令设置信号量值
    _, _, errno := syscall.Syscall6(syscall.SYS_SEMCTL,
        uintptr(s.id),
        uintptr(semnum),
        16, // SETVAL
        uintptr(value), 0, 0)

    if errno != 0 {
        return fmt.Errorf("设置信号量值失败: %v", errno)
    }

    return nil
}

// 获取信号量值
func (s *Semaphore) GetValue(semnum int) (int, error) {
    val, _, errno := syscall.Syscall(syscall.SYS_SEMCTL,
        uintptr(s.id),
        uintptr(semnum),
        12) // GETVAL

    if errno != 0 {
        return 0, fmt.Errorf("获取信号量值失败: %v", errno)
    }

    return int(val), nil
}

// P 操作(等待/减少)
func (s *Semaphore) Wait(semnum int) error {
    sembuf := SemBuf{
        Num: uint16(semnum),
        Op:  -1,
        Flg: 0,
    }

    _, _, errno := syscall.Syscall(syscall.SYS_SEMOP,
        uintptr(s.id),
        uintptr(unsafe.Pointer(&sembuf)),
        1)

    if errno != 0 {
        return fmt.Errorf("信号量等待失败: %v", errno)
    }

    return nil
}

// V 操作(信号/增加)
func (s *Semaphore) Signal(semnum int) error {
    sembuf := SemBuf{
        Num: uint16(semnum),
        Op:  1,
        Flg: 0,
    }

    _, _, errno := syscall.Syscall(syscall.SYS_SEMOP,
        uintptr(s.id),
        uintptr(unsafe.Pointer(&sembuf)),
        1)

    if errno != 0 {
        return fmt.Errorf("信号量信号失败: %v", errno)
    }

    return nil
}

// 删除信号量
func (s *Semaphore) Remove() error {
    _, _, errno := syscall.Syscall(syscall.SYS_SEMCTL,
        uintptr(s.id), 0, IPC_RMID)

    if errno != 0 {
        return fmt.Errorf("删除信号量失败: %v", errno)
    }

    return nil
}

// 信号量示例
func semaphoreExample() {
    key := 54321

    // 创建信号量(初始值为 1,用作互斥锁)
    sem, err := NewSemaphore(key, 1, 1)
    if err != nil {
        log.Fatal(err)
    }
    defer sem.Remove()

    fmt.Printf("信号量创建成功,ID: %d\n", sem.id)

    // 获取初始值
    val, err := sem.GetValue(0)
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("信号量初始值: %d\n", val)

    // P 操作(获取锁)
    fmt.Println("执行 P 操作(获取锁)")
    err = sem.Wait(0)
    if err != nil {
        log.Fatal(err)
    }

    val, _ = sem.GetValue(0)
    fmt.Printf("P 操作后信号量值: %d\n", val)

    // 模拟临界区操作
    fmt.Println("在临界区中...")
    time.Sleep(2 * time.Second)

    // V 操作(释放锁)
    fmt.Println("执行 V 操作(释放锁)")
    err = sem.Signal(0)
    if err != nil {
        log.Fatal(err)
    }

    val, _ = sem.GetValue(0)
    fmt.Printf("V 操作后信号量值: %d\n", val)
}

func main() {
    fmt.Println("=== 信号量示例 ===")
    semaphoreExample()
}

生产者-消费者问题 #

使用共享内存和信号量实现经典的生产者-消费者问题:

package main

import (
    "fmt"
    "log"
    "time"
    "sync"
    "math/rand"
)

// 缓冲区结构
type Buffer struct {
    data  []int
    size  int
    in    int
    out   int
    count int
}

// 生产者-消费者系统
type ProducerConsumerSystem struct {
    shm       *SharedMemory
    buffer    *Buffer
    mutex     *Semaphore  // 互斥信号量
    empty     *Semaphore  // 空槽信号量
    full      *Semaphore  // 满槽信号量
    bufferSize int
}

// 创建生产者-消费者系统
func NewProducerConsumerSystem(shmKey, semKey, bufferSize int) (*ProducerConsumerSystem, error) {
    // 创建共享内存
    shm, err := NewSharedMemory(shmKey, 1024)
    if err != nil {
        return nil, err
    }

    err = shm.Attach()
    if err != nil {
        return nil, err
    }

    // 创建信号量
    mutex, err := NewSemaphore(semKey, 1, 1)      // 互斥锁
    if err != nil {
        return nil, err
    }

    empty, err := NewSemaphore(semKey+1, 1, bufferSize) // 空槽数量
    if err != nil {
        return nil, err
    }

    full, err := NewSemaphore(semKey+2, 1, 0)     // 满槽数量
    if err != nil {
        return nil, err
    }

    // 初始化缓冲区
    buffer := &Buffer{
        data: make([]int, bufferSize),
        size: bufferSize,
        in:   0,
        out:  0,
        count: 0,
    }

    return &ProducerConsumerSystem{
        shm:        shm,
        buffer:     buffer,
        mutex:      mutex,
        empty:      empty,
        full:       full,
        bufferSize: bufferSize,
    }, nil
}

// 清理资源
func (pcs *ProducerConsumerSystem) Cleanup() {
    if pcs.shm != nil {
        pcs.shm.Detach()
        pcs.shm.Remove()
    }
    if pcs.mutex != nil {
        pcs.mutex.Remove()
    }
    if pcs.empty != nil {
        pcs.empty.Remove()
    }
    if pcs.full != nil {
        pcs.full.Remove()
    }
}

// 生产者
func (pcs *ProducerConsumerSystem) Producer(id int, itemCount int, wg *sync.WaitGroup) {
    defer wg.Done()

    for i := 0; i < itemCount; i++ {
        item := rand.Intn(100)

        // 等待空槽
        err := pcs.empty.Wait(0)
        if err != nil {
            log.Printf("生产者 %d 等待空槽失败: %v", id, err)
            continue
        }

        // 获取互斥锁
        err = pcs.mutex.Wait(0)
        if err != nil {
            log.Printf("生产者 %d 获取锁失败: %v", id, err)
            continue
        }

        // 生产物品
        pcs.buffer.data[pcs.buffer.in] = item
        pcs.buffer.in = (pcs.buffer.in + 1) % pcs.buffer.size
        pcs.buffer.count++

        fmt.Printf("生产者 %d 生产物品 %d,缓冲区数量: %d\n", id, item, pcs.buffer.count)

        // 释放互斥锁
        pcs.mutex.Signal(0)

        // 增加满槽数量
        pcs.full.Signal(0)

        // 模拟生产时间
        time.Sleep(time.Duration(rand.Intn(1000)) * time.Millisecond)
    }

    fmt.Printf("生产者 %d 完成\n", id)
}

// 消费者
func (pcs *ProducerConsumerSystem) Consumer(id int, wg *sync.WaitGroup) {
    defer wg.Done()

    consumedCount := 0

    for {
        // 等待满槽
        err := pcs.full.Wait(0)
        if err != nil {
            log.Printf("消费者 %d 等待满槽失败: %v", id, err)
            break
        }

        // 获取互斥锁
        err = pcs.mutex.Wait(0)
        if err != nil {
            log.Printf("消费者 %d 获取锁失败: %v", id, err)
            continue
        }

        // 消费物品
        item := pcs.buffer.data[pcs.buffer.out]
        pcs.buffer.out = (pcs.buffer.out + 1) % pcs.buffer.size
        pcs.buffer.count--
        consumedCount++

        fmt.Printf("消费者 %d 消费物品 %d,缓冲区数量: %d\n", id, item, pcs.buffer.count)

        // 释放互斥锁
        pcs.mutex.Signal(0)

        // 增加空槽数量
        pcs.empty.Signal(0)

        // 模拟消费时间
        time.Sleep(time.Duration(rand.Intn(1500)) * time.Millisecond)

        // 简单的退出条件
        if consumedCount >= 5 {
            break
        }
    }

    fmt.Printf("消费者 %d 完成,消费了 %d 个物品\n", id, consumedCount)
}

func producerConsumerExample() {
    rand.Seed(time.Now().UnixNano())

    // 创建系统
    system, err := NewProducerConsumerSystem(11111, 22222, 5)
    if err != nil {
        log.Fatal(err)
    }
    defer system.Cleanup()

    var wg sync.WaitGroup

    // 启动生产者
    for i := 0; i < 2; i++ {
        wg.Add(1)
        go system.Producer(i, 3, &wg)
    }

    // 启动消费者
    for i := 0; i < 2; i++ {
        wg.Add(1)
        go system.Consumer(i, &wg)
    }

    wg.Wait()
    fmt.Println("生产者-消费者系统运行完成")
}

func main() {
    fmt.Println("=== 生产者-消费者问题示例 ===")
    producerConsumerExample()
}

读者-写者问题 #

使用信号量解决读者-写者问题:

package main

import (
    "fmt"
    "log"
    "time"
    "sync"
    "math/rand"
)

// 读者-写者系统
type ReaderWriterSystem struct {
    shm         *SharedMemory
    data        []byte
    readCount   int
    mutex       *Semaphore  // 保护读者计数
    writeMutex  *Semaphore  // 写者互斥
}

// 创建读者-写者系统
func NewReaderWriterSystem(shmKey, semKey int) (*ReaderWriterSystem, error) {
    // 创建共享内存
    shm, err := NewSharedMemory(shmKey, 1024)
    if err != nil {
        return nil, err
    }

    err = shm.Attach()
    if err != nil {
        return nil, err
    }

    // 初始化数据
    initialData := "Initial shared data"
    err = shm.Write([]byte(initialData))
    if err != nil {
        return nil, err
    }

    // 创建信号量
    mutex, err := NewSemaphore(semKey, 1, 1)
    if err != nil {
        return nil, err
    }

    writeMutex, err := NewSemaphore(semKey+1, 1, 1)
    if err != nil {
        return nil, err
    }

    return &ReaderWriterSystem{
        shm:        shm,
        data:       make([]byte, 1024),
        readCount:  0,
        mutex:      mutex,
        writeMutex: writeMutex,
    }, nil
}

// 清理资源
func (rws *ReaderWriterSystem) Cleanup() {
    if rws.shm != nil {
        rws.shm.Detach()
        rws.shm.Remove()
    }
    if rws.mutex != nil {
        rws.mutex.Remove()
    }
    if rws.writeMutex != nil {
        rws.writeMutex.Remove()
    }
}

// 读者
func (rws *ReaderWriterSystem) Reader(id int, wg *sync.WaitGroup) {
    defer wg.Done()

    for i := 0; i < 3; i++ {
        // 获取读者计数锁
        err := rws.mutex.Wait(0)
        if err != nil {
            log.Printf("读者 %d 获取计数锁失败: %v", id, err)
            continue
        }

        rws.readCount++
        if rws.readCount == 1 {
            // 第一个读者需要获取写锁
            err = rws.writeMutex.Wait(0)
            if err != nil {
                log.Printf("读者 %d 获取写锁失败: %v", id, err)
                rws.readCount--
                rws.mutex.Signal(0)
                continue
            }
        }

        // 释放读者计数锁
        rws.mutex.Signal(0)

        // 读取数据
        data, err := rws.shm.Read(100)
        if err != nil {
            log.Printf("读者 %d 读取数据失败: %v", id, err)
        } else {
            // 找到字符串结束位置
            end := 0
            for i, b := range data {
                if b == 0 {
                    end = i
                    break
                }
            }
            if end == 0 {
                end = len(data)
            }

            fmt.Printf("读者 %d 读取数据: %s\n", id, string(data[:end]))
        }

        // 模拟读取时间
        time.Sleep(time.Duration(rand.Intn(1000)) * time.Millisecond)

        // 获取读者计数锁
        err = rws.mutex.Wait(0)
        if err != nil {
            log.Printf("读者 %d 获取计数锁失败: %v", id, err)
            continue
        }

        rws.readCount--
        if rws.readCount == 0 {
            // 最后一个读者释放写锁
            rws.writeMutex.Signal(0)
        }

        // 释放读者计数锁
        rws.mutex.Signal(0)

        // 读者间隔
        time.Sleep(time.Duration(rand.Intn(2000)) * time.Millisecond)
    }

    fmt.Printf("读者 %d 完成\n", id)
}

// 写者
func (rws *ReaderWriterSystem) Writer(id int, wg *sync.WaitGroup) {
    defer wg.Done()

    for i := 0; i < 2; i++ {
        // 获取写锁
        err := rws.writeMutex.Wait(0)
        if err != nil {
            log.Printf("写者 %d 获取写锁失败: %v", id, err)
            continue
        }

        // 写入数据
        newData := fmt.Sprintf("Data written by writer %d at %d", id, time.Now().Unix())
        err = rws.shm.Write([]byte(newData))
        if err != nil {
            log.Printf("写者 %d 写入数据失败: %v", id, err)
        } else {
            fmt.Printf("写者 %d 写入数据: %s\n", id, newData)
        }

        // 模拟写入时间
        time.Sleep(time.Duration(rand.Intn(1500)) * time.Millisecond)

        // 释放写锁
        rws.writeMutex.Signal(0)

        // 写者间隔
        time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
    }

    fmt.Printf("写者 %d 完成\n", id)
}

func readerWriterExample() {
    rand.Seed(time.Now().UnixNano())

    // 创建系统
    system, err := NewReaderWriterSystem(33333, 44444)
    if err != nil {
        log.Fatal(err)
    }
    defer system.Cleanup()

    var wg sync.WaitGroup

    // 启动读者
    for i := 0; i < 3; i++ {
        wg.Add(1)
        go system.Reader(i, &wg)
    }

    // 启动写者
    for i := 0; i < 2; i++ {
        wg.Add(1)
        go system.Writer(i, &wg)
    }

    wg.Wait()
    fmt.Println("读者-写者系统运行完成")
}

func main() {
    fmt.Println("=== 读者-写者问题示例 ===")
    readerWriterExample()
}

实践练习 #

练习 1:分布式计数器 #

使用共享内存和信号量实现一个分布式计数器:

package main

import (
    "fmt"
    "log"
    "time"
    "sync"
    "math/rand"
    "unsafe"
)

// 分布式计数器
type DistributedCounter struct {
    shm     *SharedMemory
    mutex   *Semaphore
    counter *int64
}

// 创建分布式计数器
func NewDistributedCounter(shmKey, semKey int) (*DistributedCounter, error) {
    // 创建共享内存
    shm, err := NewSharedMemory(shmKey, 64)
    if err != nil {
        return nil, err
    }

    err = shm.Attach()
    if err != nil {
        return nil, err
    }

    // 创建互斥信号量
    mutex, err := NewSemaphore(semKey, 1, 1)
    if err != nil {
        return nil, err
    }

    // 初始化计数器
    counter := (*int64)(shm.addr)
    *counter = 0

    return &DistributedCounter{
        shm:     shm,
        mutex:   mutex,
        counter: counter,
    }, nil
}

// 清理资源
func (dc *DistributedCounter) Cleanup() {
    if dc.shm != nil {
        dc.shm.Detach()
        dc.shm.Remove()
    }
    if dc.mutex != nil {
        dc.mutex.Remove()
    }
}

// 增加计数
func (dc *DistributedCounter) Increment(processID int) error {
    // 获取锁
    err := dc.mutex.Wait(0)
    if err != nil {
        return err
    }
    defer dc.mutex.Signal(0)

    // 读取当前值
    oldValue := *dc.counter

    // 模拟一些处理时间
    time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)

    // 增加计数
    *dc.counter = oldValue + 1

    fmt.Printf("进程 %d: %d -> %d\n", processID, oldValue, *dc.counter)

    return nil
}

// 获取当前计数
func (dc *DistributedCounter) GetValue() int64 {
    // 获取锁
    err := dc.mutex.Wait(0)
    if err != nil {
        return -1
    }
    defer dc.mutex.Signal(0)

    return *dc.counter
}

// 工作进程
func worker(id int, counter *DistributedCounter, iterations int, wg *sync.WaitGroup) {
    defer wg.Done()

    fmt.Printf("工作进程 %d 启动\n", id)

    for i := 0; i < iterations; i++ {
        err := counter.Increment(id)
        if err != nil {
            log.Printf("进程 %d 增加计数失败: %v", id, err)
        }

        // 随机等待
        time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond)
    }

    fmt.Printf("工作进程 %d 完成\n", id)
}

func distributedCounterExample() {
    rand.Seed(time.Now().UnixNano())

    // 创建分布式计数器
    counter, err := NewDistributedCounter(55555, 66666)
    if err != nil {
        log.Fatal(err)
    }
    defer counter.Cleanup()

    fmt.Printf("初始计数值: %d\n", counter.GetValue())

    var wg sync.WaitGroup
    processCount := 5
    iterationsPerProcess := 10

    // 启动多个工作进程
    for i := 0; i < processCount; i++ {
        wg.Add(1)
        go worker(i, counter, iterationsPerProcess, &wg)
    }

    wg.Wait()

    finalValue := counter.GetValue()
    expectedValue := int64(processCount * iterationsPerProcess)

    fmt.Printf("最终计数值: %d\n", finalValue)
    fmt.Printf("期望计数值: %d\n", expectedValue)

    if finalValue == expectedValue {
        fmt.Println("✓ 分布式计数器工作正常")
    } else {
        fmt.Println("✗ 分布式计数器存在问题")
    }
}

func main() {
    fmt.Println("=== 分布式计数器示例 ===")
    distributedCounterExample()
}

总结 #

本节详细介绍了共享内存和信号量的使用方法:

  1. 共享内存机制:学会了使用 System V 和 POSIX 共享内存进行高效的进程间数据共享
  2. 信号量同步:掌握了使用信号量进行进程间同步和互斥控制
  3. 经典问题:实现了生产者-消费者和读者-写者等经典并发问题的解决方案
  4. 实践应用:构建了分布式计数器等实际应用场景

这些机制为构建高性能的多进程系统提供了强大的基础工具。通过合理使用共享内存和信号量,可以实现高效的进程间协作和数据共享。

至此,第 4 章第 4 节"进程与系统调用"的内容已经完成。我们学习了进程创建与管理、进程间通信、系统调用详解以及共享内存与信号量等重要概念和实践技能。这些知识为深入理解操作系统和进行系统级编程奠定了坚实的基础。