4.3.1 TCP 并发服务器

4.3.1 TCP 并发服务器 #

高并发 TCP 服务器是现代网络应用的核心组件。本节将深入探讨如何设计和实现能够处理大量并发连接的 TCP 服务器,包括连接管理、负载均衡、性能优化等关键技术。

并发模型设计 #

Goroutine-per-Connection 模型 #

这是 Go 语言中最常见的并发模型,为每个连接创建一个 goroutine:

package main

import (
    "bufio"
    "fmt"
    "io"
    "net"
    "sync"
    "time"
)

// ConcurrentServer is a goroutine-per-connection echo server with graceful
// shutdown: live connections are tracked so Shutdown can close them all and
// wait for their handler goroutines to finish.
type ConcurrentServer struct {
    listener    net.Listener      // accept socket
    connections map[net.Conn]bool // live connections; guarded by mutex
    mutex       sync.RWMutex      // protects connections
    wg          sync.WaitGroup    // one count per active handler goroutine
    shutdown    chan struct{}     // closed to request shutdown
}

// NewConcurrentServer binds a TCP listener on address and returns a server
// ready to Start. Use Shutdown to release the listener.
func NewConcurrentServer(address string) (*ConcurrentServer, error) {
    l, err := net.Listen("tcp", address)
    if err != nil {
        return nil, err
    }

    srv := &ConcurrentServer{
        listener:    l,
        connections: make(map[net.Conn]bool),
        shutdown:    make(chan struct{}),
    }
    return srv, nil
}

// Start accepts connections until Shutdown is called, spawning one handler
// goroutine per connection. It returns nil on graceful shutdown, or the
// first unexpected Accept error otherwise.
func (s *ConcurrentServer) Start() error {
    fmt.Printf("TCP 并发服务器启动在 %s\n", s.listener.Addr())

    for {
        select {
        case <-s.shutdown:
            return nil
        default:
        }

        // Use a short accept deadline so the loop re-checks shutdown at
        // least once per second even when no clients connect.
        if tcpListener, ok := s.listener.(*net.TCPListener); ok {
            _ = tcpListener.SetDeadline(time.Now().Add(1 * time.Second))
        }

        conn, err := s.listener.Accept()
        if err != nil {
            if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
                continue // deadline expired; loop to re-check shutdown
            }
            // Shutdown closes the listener, which surfaces here as a
            // non-timeout error; the original returned that error instead
            // of reporting a clean exit.
            select {
            case <-s.shutdown:
                return nil
            default:
                return err
            }
        }

        s.addConnection(conn)
        s.wg.Add(1)
        go s.handleConnection(conn)
    }
}

// handleConnection serves a single client until the server shuts down, the
// client disconnects, or a read/write fails. It always deregisters the
// connection, closes it, and signals the WaitGroup on exit.
func (s *ConcurrentServer) handleConnection(conn net.Conn) {
    defer func() {
        s.removeConnection(conn)
        conn.Close()
        s.wg.Done()
    }()

    fmt.Printf("新连接: %s\n", conn.RemoteAddr())

    reader := bufio.NewReader(conn)
    writer := bufio.NewWriter(conn)

    for {
        select {
        case <-s.shutdown:
            return
        default:
            // Bound each read so an idle client cannot pin this goroutine
            // and so shutdown is observed within 30s.
            conn.SetReadDeadline(time.Now().Add(30 * time.Second))

            message, err := reader.ReadString('\n')
            if err != nil {
                if err == io.EOF {
                    fmt.Printf("客户端断开连接: %s\n", conn.RemoteAddr())
                } else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
                    fmt.Printf("连接超时: %s\n", conn.RemoteAddr())
                } else {
                    fmt.Printf("读取错误: %v\n", err)
                }
                return
            }

            // 处理消息
            response := s.processMessage(message, conn)

            // 发送响应
            conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
            if _, err = writer.WriteString(response); err != nil {
                fmt.Printf("写入错误: %v\n", err)
                return
            }
            // A failed Flush means the response never reached the client;
            // the original ignored this error and kept reading.
            if err = writer.Flush(); err != nil {
                fmt.Printf("写入错误: %v\n", err)
                return
            }
        }
    }
}

// processMessage builds the reply for one inbound line; currently a plain
// echo of the client's message.
func (s *ConcurrentServer) processMessage(message string, conn net.Conn) string {
    return "Echo: " + message
}

// addConnection records conn in the live-connection set.
func (s *ConcurrentServer) addConnection(conn net.Conn) {
    s.mutex.Lock()
    s.connections[conn] = true
    s.mutex.Unlock()
}

// removeConnection drops conn from the live-connection set.
func (s *ConcurrentServer) removeConnection(conn net.Conn) {
    s.mutex.Lock()
    delete(s.connections, conn)
    s.mutex.Unlock()
}

// GetConnectionCount reports how many connections are currently tracked.
func (s *ConcurrentServer) GetConnectionCount() int {
    s.mutex.RLock()
    n := len(s.connections)
    s.mutex.RUnlock()
    return n
}

// Shutdown signals the accept loop to stop, closes the listener and every
// tracked connection, then waits for all handler goroutines to finish.
// The listener's Close error, if any, is returned.
func (s *ConcurrentServer) Shutdown() error {
    close(s.shutdown)

    err := s.listener.Close()

    s.mutex.Lock()
    for conn := range s.connections {
        conn.Close()
    }
    s.mutex.Unlock()

    s.wg.Wait()
    return err
}

Worker Pool 模型 #

对于需要限制 goroutine 数量的场景,可以使用 worker pool 模型:

// WorkerPoolServer bounds concurrency with a fixed set of workers that pull
// accepted connections from a shared, buffered job queue.
type WorkerPoolServer struct {
    listener    net.Listener   // accept socket
    workerCount int            // number of workers launched by Start
    jobQueue    chan net.Conn  // buffered hand-off queue (2 slots per worker)
    workers     []*Worker      // populated by Start
    shutdown    chan struct{}  // closed to request shutdown
    wg          sync.WaitGroup // one count per worker goroutine
}

// Worker consumes connections from the shared job queue until quit closes.
type Worker struct {
    id       int           // worker index, used in log output
    jobQueue chan net.Conn // shared with the server and all other workers
    quit     chan struct{} // closed by Shutdown to stop this worker
}

// NewWorkerPoolServer listens on address and prepares a pool of workerCount
// workers sharing a buffered job queue (two slots per worker).
func NewWorkerPoolServer(address string, workerCount int) (*WorkerPoolServer, error) {
    l, err := net.Listen("tcp", address)
    if err != nil {
        return nil, err
    }

    srv := &WorkerPoolServer{
        listener:    l,
        workerCount: workerCount,
        jobQueue:    make(chan net.Conn, workerCount*2),
        shutdown:    make(chan struct{}),
    }
    return srv, nil
}

// Start launches the worker pool, then accepts connections and dispatches
// each one to the shared job queue. When the queue is full the connection
// is rejected immediately. Returns nil once shutdown is requested.
func (s *WorkerPoolServer) Start() error {
    // Spin up the fixed worker set first so the queue has consumers.
    s.workers = make([]*Worker, s.workerCount)
    for i := range s.workers {
        w := &Worker{
            id:       i,
            jobQueue: s.jobQueue,
            quit:     make(chan struct{}),
        }
        s.workers[i] = w
        s.wg.Add(1)
        go w.start(&s.wg)
    }

    fmt.Printf("Worker Pool 服务器启动,%d 个 worker\n", s.workerCount)

    for {
        select {
        case <-s.shutdown:
            return nil
        default:
        }

        conn, err := s.listener.Accept()
        if err != nil {
            // Shutdown closes the listener; treat that error as clean exit.
            select {
            case <-s.shutdown:
                return nil
            default:
                fmt.Printf("接受连接失败: %v\n", err)
                continue
            }
        }

        // Hand the connection to a worker without blocking forever.
        select {
        case s.jobQueue <- conn:
            // dispatched
        case <-s.shutdown:
            conn.Close()
            return nil
        default:
            // Queue full: reject rather than queue unbounded work.
            fmt.Printf("服务器繁忙,拒绝连接: %s\n", conn.RemoteAddr())
            conn.Close()
        }
    }
}

// start consumes connections from the shared queue until quit is closed,
// signalling wg.Done on exit so the pool can drain cleanly.
func (w *Worker) start(wg *sync.WaitGroup) {
    defer wg.Done()

    for {
        select {
        case <-w.quit:
            return
        case conn := <-w.jobQueue:
            w.handleConnection(conn)
        }
    }
}

// handleConnection serves one client: each newline-terminated message is
// echoed back prefixed with the worker id, until the client disconnects,
// a deadline expires, or a write fails.
func (w *Worker) handleConnection(conn net.Conn) {
    defer conn.Close()

    fmt.Printf("Worker %d 处理连接: %s\n", w.id, conn.RemoteAddr())

    reader := bufio.NewReader(conn)
    writer := bufio.NewWriter(conn)

    for {
        conn.SetReadDeadline(time.Now().Add(30 * time.Second))

        message, err := reader.ReadString('\n')
        if err != nil {
            if err != io.EOF {
                fmt.Printf("Worker %d 读取错误: %v\n", w.id, err)
            }
            break
        }

        // 处理消息
        response := fmt.Sprintf("Worker %d Echo: %s", w.id, message)

        conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
        // The original dropped write errors and kept reading from a
        // connection it could no longer answer; bail out instead.
        if _, err := writer.WriteString(response); err != nil {
            fmt.Printf("Worker %d 写入错误: %v\n", w.id, err)
            break
        }
        if err := writer.Flush(); err != nil {
            fmt.Printf("Worker %d 写入错误: %v\n", w.id, err)
            break
        }
    }

    fmt.Printf("Worker %d 完成处理: %s\n", w.id, conn.RemoteAddr())
}

// Shutdown stops the accept loop and all workers, closes the listener, and
// waits for in-flight work to finish. Connections still sitting in the job
// queue (accepted but never picked up) are closed so they do not leak.
func (s *WorkerPoolServer) Shutdown() error {
    close(s.shutdown)

    // 停止所有 worker
    for _, worker := range s.workers {
        close(worker.quit)
    }

    // 关闭监听器
    err := s.listener.Close()

    // 等待所有 worker 结束
    s.wg.Wait()

    // Drain and close any queued connections; no worker will ever take
    // them now. The original leaked these.
    for {
        select {
        case conn := <-s.jobQueue:
            conn.Close()
        default:
            return err
        }
    }
}

连接管理 #

连接池管理 #

import (
    "context"
    "sync"
    "time"
)

// ConnectionManager tracks connections by remote address, recording
// activity stats and closing connections idle longer than maxIdle.
type ConnectionManager struct {
    connections map[string]*ManagedConnection // keyed by RemoteAddr().String(); guarded by mutex
    mutex       sync.RWMutex
    maxIdle     time.Duration // idle threshold before a connection is closed
    cleanupTicker *time.Ticker // drives the periodic idle sweep
    ctx         context.Context
    cancel      context.CancelFunc // stops the cleanup goroutine
}

// ManagedConnection pairs a connection with its activity bookkeeping.
// The embedded mutex guards the mutable fields (LastActive, byte counters,
// State).
type ManagedConnection struct {
    Conn       net.Conn
    LastActive time.Time // last recorded activity, via UpdateActivity
    BytesSent  int64
    BytesRecv  int64
    State      ConnectionState
    mutex      sync.RWMutex
}

// ConnectionState describes the lifecycle phase of a managed connection.
type ConnectionState int

const (
    StateActive ConnectionState = iota // seen activity recently
    StateIdle                          // exceeded maxIdle; about to be closed
    StateClosing                       // being removed from the manager
)

// NewConnectionManager creates a manager that closes connections idle for
// longer than maxIdle. A background goroutine sweeps roughly twice per
// maxIdle period until Shutdown is called.
func NewConnectionManager(maxIdle time.Duration) *ConnectionManager {
    ctx, cancel := context.WithCancel(context.Background())

    cm := &ConnectionManager{
        connections: make(map[string]*ManagedConnection),
        maxIdle:     maxIdle,
        ctx:         ctx,
        cancel:      cancel,
    }

    // time.NewTicker panics on a non-positive interval, so clamp the sweep
    // period when maxIdle is zero or negative.
    interval := maxIdle / 2
    if interval <= 0 {
        interval = time.Minute
    }
    cm.cleanupTicker = time.NewTicker(interval)
    go cm.cleanupLoop()

    return cm
}

// AddConnection registers conn, keyed by its remote address, and returns
// the tracking record.
func (cm *ConnectionManager) AddConnection(conn net.Conn) *ManagedConnection {
    mc := &ManagedConnection{
        Conn:       conn,
        LastActive: time.Now(),
        State:      StateActive,
    }

    cm.mutex.Lock()
    cm.connections[conn.RemoteAddr().String()] = mc
    cm.mutex.Unlock()

    return mc
}

// RemoveConnection marks conn as closing and drops it from the registry.
// Unknown connections are ignored.
func (cm *ConnectionManager) RemoveConnection(conn net.Conn) {
    cm.mutex.Lock()
    defer cm.mutex.Unlock()

    key := conn.RemoteAddr().String()
    mc, ok := cm.connections[key]
    if !ok {
        return
    }

    mc.mutex.Lock()
    mc.State = StateClosing
    mc.mutex.Unlock()

    delete(cm.connections, key)
}

// UpdateActivity refreshes the last-active timestamp and byte counters for
// conn, if it is still tracked.
func (cm *ConnectionManager) UpdateActivity(conn net.Conn, bytesSent, bytesRecv int64) {
    cm.mutex.RLock()
    mc, ok := cm.connections[conn.RemoteAddr().String()]
    cm.mutex.RUnlock()
    if !ok {
        return
    }

    mc.mutex.Lock()
    mc.LastActive = time.Now()
    mc.BytesSent += bytesSent
    mc.BytesRecv += bytesRecv
    mc.mutex.Unlock()
}

// cleanupLoop periodically sweeps idle connections until the manager's
// context is cancelled.
func (cm *ConnectionManager) cleanupLoop() {
    for {
        select {
        case <-cm.cleanupTicker.C:
            cm.cleanupIdleConnections()
        case <-cm.ctx.Done():
            return
        }
    }
}

// cleanupIdleConnections closes and removes connections whose last activity
// is older than maxIdle. It runs under the manager's write lock.
func (cm *ConnectionManager) cleanupIdleConnections() {
    cm.mutex.Lock()
    defer cm.mutex.Unlock()

    now := time.Now()
    var toClose []string

    for addr, managed := range cm.connections {
        // Take the write lock: State may be mutated below. The original
        // wrote State while holding only RLock, which is a data race.
        managed.mutex.Lock()
        if now.Sub(managed.LastActive) > cm.maxIdle && managed.State == StateActive {
            managed.State = StateIdle
            toClose = append(toClose, addr)
        }
        managed.mutex.Unlock()
    }

    for _, addr := range toClose {
        if managed, exists := cm.connections[addr]; exists {
            fmt.Printf("关闭空闲连接: %s\n", addr)
            managed.Conn.Close()
            delete(cm.connections, addr)
        }
    }
}

// GetStats returns a snapshot of connection counts and aggregate byte
// totals, keyed for merging into other stats maps.
func (cm *ConnectionManager) GetStats() map[string]interface{} {
    cm.mutex.RLock()
    defer cm.mutex.RUnlock()

    var active, idle int
    var sent, recv int64
    for _, mc := range cm.connections {
        mc.mutex.RLock()
        switch mc.State {
        case StateActive:
            active++
        case StateIdle:
            idle++
        }
        sent += mc.BytesSent
        recv += mc.BytesRecv
        mc.mutex.RUnlock()
    }

    return map[string]interface{}{
        "total_connections":  len(cm.connections),
        "active_connections": active,
        "idle_connections":   idle,
        "total_bytes_sent":   sent,
        "total_bytes_recv":   recv,
    }
}

// Shutdown stops the cleanup goroutine, closes every tracked connection,
// and resets the registry.
func (cm *ConnectionManager) Shutdown() {
    cm.cancel()
    cm.cleanupTicker.Stop()

    cm.mutex.Lock()
    defer cm.mutex.Unlock()

    for _, mc := range cm.connections {
        mc.Conn.Close()
    }
    cm.connections = make(map[string]*ManagedConnection)
}

负载均衡 #

// LoadBalancer distributes picks across registered backends using a
// configurable selection algorithm.
type LoadBalancer struct {
    servers   []*ServerInfo // registered backends; guarded by mutex
    current   int           // round-robin cursor, advanced by the pick methods
    mutex     sync.RWMutex
    algorithm LoadBalanceAlgorithm
}

// ServerInfo describes one backend. The per-server mutex guards the
// mutable counters.
type ServerInfo struct {
    Address     string
    Weight      int  // relative share for weighted round-robin
    Connections int  // open connection count, used by least-connections
    Active      bool // false removes the server from rotation
    mutex       sync.RWMutex
}

// LoadBalanceAlgorithm selects the strategy used by GetNextServer.
type LoadBalanceAlgorithm int

const (
    RoundRobin LoadBalanceAlgorithm = iota // cycle through active servers
    WeightedRoundRobin                     // proportional to server Weight
    LeastConnections                       // fewest open connections wins
)

// NewLoadBalancer returns an empty balancer using the given algorithm.
func NewLoadBalancer(algorithm LoadBalanceAlgorithm) *LoadBalancer {
    lb := &LoadBalancer{algorithm: algorithm}
    lb.servers = make([]*ServerInfo, 0)
    return lb
}

// AddServer registers a backend with the given weight. New servers start
// in the active state.
func (lb *LoadBalancer) AddServer(address string, weight int) {
    info := &ServerInfo{
        Address: address,
        Weight:  weight,
        Active:  true,
    }

    lb.mutex.Lock()
    lb.servers = append(lb.servers, info)
    lb.mutex.Unlock()
}

// GetNextServer picks a backend according to the configured algorithm, or
// nil when no servers are registered. It takes the write lock because the
// round-robin strategies advance lb.current; the original held only RLock,
// racing on that counter.
func (lb *LoadBalancer) GetNextServer() *ServerInfo {
    lb.mutex.Lock()
    defer lb.mutex.Unlock()

    if len(lb.servers) == 0 {
        return nil
    }

    switch lb.algorithm {
    case RoundRobin:
        return lb.roundRobin()
    case WeightedRoundRobin:
        return lb.weightedRoundRobin()
    case LeastConnections:
        return lb.leastConnections()
    default:
        return lb.roundRobin()
    }
}

// roundRobin cycles through the active servers. The caller must hold
// lb.mutex, since lb.current is advanced here.
func (lb *LoadBalancer) roundRobin() *ServerInfo {
    active := lb.getActiveServers()
    if len(active) == 0 {
        return nil
    }

    idx := lb.current % len(active)
    lb.current++
    return active[idx]
}

// weightedRoundRobin distributes picks proportionally to server weights,
// falling back to plain round-robin when all weights are zero. The caller
// must hold lb.mutex.
func (lb *LoadBalancer) weightedRoundRobin() *ServerInfo {
    active := lb.getActiveServers()
    if len(active) == 0 {
        return nil
    }

    total := 0
    for _, srv := range active {
        total += srv.Weight
    }
    if total == 0 {
        return lb.roundRobin()
    }

    target := lb.current % total
    lb.current++

    // Walk the cumulative weight until the target bucket is reached.
    acc := 0
    for _, srv := range active {
        acc += srv.Weight
        if target < acc {
            return srv
        }
    }
    return active[0]
}

// leastConnections picks the active server with the fewest open
// connections (the first one wins on ties).
func (lb *LoadBalancer) leastConnections() *ServerInfo {
    active := lb.getActiveServers()
    if len(active) == 0 {
        return nil
    }

    selected := active[0]
    selected.mutex.RLock()
    best := selected.Connections
    selected.mutex.RUnlock()

    for _, srv := range active[1:] {
        srv.mutex.RLock()
        count := srv.Connections
        srv.mutex.RUnlock()

        if count < best {
            best = count
            selected = srv
        }
    }
    return selected
}

// getActiveServers filters the registered servers down to the active ones.
// The caller must hold lb.mutex (read or write).
func (lb *LoadBalancer) getActiveServers() []*ServerInfo {
    active := make([]*ServerInfo, 0, len(lb.servers))
    for _, srv := range lb.servers {
        if srv.Active {
            active = append(active, srv)
        }
    }
    return active
}

// UpdateServerConnections adjusts a backend's connection count by delta,
// clamping at zero. Unknown addresses are ignored.
func (lb *LoadBalancer) UpdateServerConnections(address string, delta int) {
    lb.mutex.RLock()
    defer lb.mutex.RUnlock()

    for _, srv := range lb.servers {
        if srv.Address != address {
            continue
        }
        srv.mutex.Lock()
        srv.Connections += delta
        if srv.Connections < 0 {
            srv.Connections = 0
        }
        srv.mutex.Unlock()
        return
    }
}

// SetServerActive flips a backend's availability flag. Unknown addresses
// are ignored.
func (lb *LoadBalancer) SetServerActive(address string, active bool) {
    lb.mutex.RLock()
    defer lb.mutex.RUnlock()

    for _, server := range lb.servers {
        if server.Address == address {
            // Guard the write with the per-server lock, matching
            // UpdateServerConnections; the original mutated Active with no
            // lock held for writing.
            server.mutex.Lock()
            server.Active = active
            server.mutex.Unlock()
            break
        }
    }
}

性能优化 #

零拷贝优化 #

import (
    "os"
    "syscall"
)

// ZeroCopyServer demonstrates serving file contents through the kernel's
// sendfile(2) path, avoiding user-space copies.
type ZeroCopyServer struct {
    listener net.Listener
}

// sendFile streams filename to conn via sendfile(2) when conn is a
// *net.TCPConn; other connection types are rejected. Unlike the original,
// it loops until the whole file is sent: sendfile may transfer fewer bytes
// than requested, and EAGAIN requires waiting for writability.
func (s *ZeroCopyServer) sendFile(conn net.Conn, filename string) error {
    file, err := os.Open(filename)
    if err != nil {
        return err
    }
    defer file.Close()

    // 获取文件信息
    stat, err := file.Stat()
    if err != nil {
        return err
    }

    tcpConn, ok := conn.(*net.TCPConn)
    if !ok {
        return fmt.Errorf("不支持零拷贝")
    }

    rawConn, err := tcpConn.SyscallConn()
    if err != nil {
        return err
    }

    // rawConn.Write re-invokes the callback each time the socket becomes
    // writable; returning false means "not done yet, wait again".
    remain := stat.Size()
    var sendErr error
    err = rawConn.Write(func(fd uintptr) bool {
        for remain > 0 {
            n, werr := syscall.Sendfile(int(fd), int(file.Fd()), nil, int(remain))
            if n > 0 {
                remain -= int64(n)
            }
            if werr == syscall.EAGAIN {
                return false // socket buffer full; retry when writable
            }
            if werr != nil {
                sendErr = werr
                return true
            }
        }
        return true
    })
    if err != nil {
        return err
    }
    return sendErr
}

内存池优化 #

// MemoryPool serves []byte buffers from three size-classed sync.Pools to
// cut per-request allocations.
type MemoryPool struct {
    small  sync.Pool // 1KB buffers
    medium sync.Pool // 4KB buffers
    large  sync.Pool // 16KB buffers
}

// NewMemoryPool builds a three-tier buffer pool with 1KB, 4KB and 16KB
// size classes.
func NewMemoryPool() *MemoryPool {
    newSized := func(n int) func() interface{} {
        return func() interface{} { return make([]byte, n) }
    }
    return &MemoryPool{
        small:  sync.Pool{New: newSized(1024)},
        medium: sync.Pool{New: newSized(4096)},
        large:  sync.Pool{New: newSized(16384)},
    }
}

// GetBuffer returns a buffer of exactly size bytes, drawn from the smallest
// fitting pool; requests above 16KB are allocated directly.
func (mp *MemoryPool) GetBuffer(size int) []byte {
    var pool *sync.Pool
    switch {
    case size <= 1024:
        pool = &mp.small
    case size <= 4096:
        pool = &mp.medium
    case size <= 16384:
        pool = &mp.large
    default:
        return make([]byte, size)
    }
    return pool.Get().([]byte)[:size]
}

// PutBuffer returns buf to the pool matching its capacity. Buffers of any
// other capacity are simply dropped for the GC to reclaim.
func (mp *MemoryPool) PutBuffer(buf []byte) {
    switch cap(buf) {
    case 1024:
        mp.small.Put(buf[:1024])
    case 4096:
        mp.medium.Put(buf[:4096])
    case 16384:
        mp.large.Put(buf[:16384])
    }
}

批量处理优化 #

// BatchProcessor coalesces many small writes into fewer large ones: chunks
// queue up until batchSize of them accumulate or a 10ms timer fires.
type BatchProcessor struct {
    conn       net.Conn
    buffer     [][]byte    // pending chunks; guarded by mutex
    batchSize  int         // flush threshold
    flushTimer *time.Timer // pending delayed flush, if any
    mutex      sync.Mutex
}

// NewBatchProcessor wraps conn with a write batcher that flushes after
// batchSize queued chunks (or a short timer; see Write).
func NewBatchProcessor(conn net.Conn, batchSize int) *BatchProcessor {
    bp := &BatchProcessor{
        conn:      conn,
        batchSize: batchSize,
    }
    bp.buffer = make([][]byte, 0, batchSize)
    return bp
}

// Write queues data and flushes synchronously once the batch is full;
// otherwise it (re)arms a 10ms timer so partial batches still go out
// promptly. The timer flush is best-effort: its error is discarded.
func (bp *BatchProcessor) Write(data []byte) error {
    bp.mutex.Lock()
    defer bp.mutex.Unlock()

    bp.buffer = append(bp.buffer, data)
    if len(bp.buffer) >= bp.batchSize {
        return bp.flush()
    }

    if t := bp.flushTimer; t != nil {
        t.Stop()
    }
    bp.flushTimer = time.AfterFunc(10*time.Millisecond, func() {
        bp.mutex.Lock()
        defer bp.mutex.Unlock()
        bp.flush()
    })
    return nil
}

// flush concatenates all buffered chunks and writes them to the connection
// in a single call, then resets the buffer and cancels any pending timer.
// The caller must hold bp.mutex.
func (bp *BatchProcessor) flush() error {
    if len(bp.buffer) == 0 {
        return nil
    }

    // Size the combined slice up front to avoid repeated growth.
    total := 0
    for _, chunk := range bp.buffer {
        total += len(chunk)
    }
    combined := make([]byte, 0, total)
    for _, chunk := range bp.buffer {
        combined = append(combined, chunk...)
    }

    _, err := bp.conn.Write(combined)

    // Keep the backing array for reuse; drop only the contents.
    bp.buffer = bp.buffer[:0]

    if t := bp.flushTimer; t != nil {
        t.Stop()
        bp.flushTimer = nil
    }
    return err
}

// Close cancels any pending flush timer and writes out whatever remains
// buffered, returning the final write error.
func (bp *BatchProcessor) Close() error {
    bp.mutex.Lock()
    defer bp.mutex.Unlock()

    if t := bp.flushTimer; t != nil {
        t.Stop()
    }
    return bp.flush()
}

高可用性设计 #

健康检查 #

// HealthChecker periodically TCP-dials registered servers and records their
// health state.
type HealthChecker struct {
    servers   map[string]*ServerHealth // keyed by address; guarded by mutex
    interval  time.Duration            // time between probe sweeps
    timeout   time.Duration            // per-dial timeout
    mutex     sync.RWMutex
    ctx       context.Context
    cancel    context.CancelFunc // stops the probe loop
}

// ServerHealth is the probe history for one server; guarded by the
// checker's mutex.
type ServerHealth struct {
    Address     string
    Healthy     bool
    LastCheck   time.Time
    FailCount   int // failed probes; >= 3 marks the server unhealthy
    SuccessCount int // successful probes; >= 2 restores health
}

// NewHealthChecker starts a background loop that probes registered servers
// every interval, using timeout for each dial attempt.
func NewHealthChecker(interval, timeout time.Duration) *HealthChecker {
    ctx, cancel := context.WithCancel(context.Background())

    hc := &HealthChecker{
        servers:  make(map[string]*ServerHealth),
        interval: interval,
        timeout:  timeout,
        ctx:      ctx,
        cancel:   cancel,
    }
    go hc.checkLoop()
    return hc
}

// AddServer registers address for health probing; servers start healthy.
// Re-adding an existing address resets its probe counters.
func (hc *HealthChecker) AddServer(address string) {
    entry := &ServerHealth{Address: address, Healthy: true}

    hc.mutex.Lock()
    hc.servers[address] = entry
    hc.mutex.Unlock()
}

// checkLoop drives the periodic sweep until Shutdown cancels the context.
func (hc *HealthChecker) checkLoop() {
    ticker := time.NewTicker(hc.interval)
    defer ticker.Stop()

    for {
        select {
        case <-ticker.C:
            hc.checkAllServers()
        case <-hc.ctx.Done():
            return
        }
    }
}

// checkAllServers probes every registered server concurrently and waits
// for all probes to complete.
func (hc *HealthChecker) checkAllServers() {
    // Snapshot under the read lock so probes run without holding it.
    hc.mutex.RLock()
    targets := make([]*ServerHealth, 0, len(hc.servers))
    for _, sh := range hc.servers {
        targets = append(targets, sh)
    }
    hc.mutex.RUnlock()

    var wg sync.WaitGroup
    wg.Add(len(targets))
    for _, sh := range targets {
        go func(target *ServerHealth) {
            defer wg.Done()
            hc.checkServer(target)
        }(sh)
    }
    wg.Wait()
}

// checkServer performs one TCP dial probe against server and updates its
// health record. FailCount and SuccessCount are consecutive streaks: each
// result resets the opposite counter, so 3 consecutive failures mark the
// server unhealthy and 2 consecutive successes restore it. (The original
// let both counters grow forever, so sporadic, non-consecutive results
// eventually tripped the thresholds.)
func (hc *HealthChecker) checkServer(server *ServerHealth) {
    ctx, cancel := context.WithTimeout(hc.ctx, hc.timeout)
    defer cancel()

    var d net.Dialer
    conn, err := d.DialContext(ctx, "tcp", server.Address)

    hc.mutex.Lock()
    defer hc.mutex.Unlock()

    server.LastCheck = time.Now()

    if err != nil {
        server.FailCount++
        server.SuccessCount = 0 // a failure breaks any success streak
        if server.FailCount >= 3 && server.Healthy {
            server.Healthy = false
            fmt.Printf("服务器 %s 标记为不健康\n", server.Address)
        }
    } else {
        conn.Close()
        server.SuccessCount++
        server.FailCount = 0 // a success breaks any failure streak
        if !server.Healthy && server.SuccessCount >= 2 {
            server.Healthy = true
            fmt.Printf("服务器 %s 恢复健康\n", server.Address)
        }
    }
}

// IsHealthy reports the last probed status of address; unknown addresses
// count as unhealthy.
func (hc *HealthChecker) IsHealthy(address string) bool {
    hc.mutex.RLock()
    defer hc.mutex.RUnlock()

    sh, ok := hc.servers[address]
    return ok && sh.Healthy
}

// Shutdown cancels the checker's context, stopping the probe loop.
func (hc *HealthChecker) Shutdown() {
    hc.cancel()
}

故障转移 #

// FailoverManager tracks which server traffic should target, preferring
// the primary and falling back to the secondary list based on health
// checks.
type FailoverManager struct {
    primary   string   // preferred server
    secondary []string // ordered fallback list
    current   string   // server currently in use; guarded by mutex
    checker   *HealthChecker
    mutex     sync.RWMutex
}

// NewFailoverManager routes traffic to primary first, consulting checker
// to decide when to fail over to a secondary.
func NewFailoverManager(primary string, secondary []string, checker *HealthChecker) *FailoverManager {
    fm := &FailoverManager{
        primary:   primary,
        secondary: secondary,
        checker:   checker,
    }
    fm.current = primary
    return fm
}

// GetCurrentServer returns the address traffic should use: the current
// server while it is healthy, otherwise the primary or the first healthy
// secondary (falling back to the primary when nothing is healthy).
//
// The original released its read lock, took the write lock, and used a
// deferred Unlock-then-RLock dance to restore state — leaving an unlocked
// window and fragile lock ordering. This version uses a lock-free-read
// fast path and holds the write lock for the whole failover decision.
func (fm *FailoverManager) GetCurrentServer() string {
    // Fast path: current server is healthy; only a read lock is needed.
    fm.mutex.RLock()
    current := fm.current
    fm.mutex.RUnlock()
    if fm.checker.IsHealthy(current) {
        return current
    }

    // Failover path: decide and update under the write lock.
    fm.mutex.Lock()
    defer fm.mutex.Unlock()

    // 双重检查: another goroutine may have failed over already.
    if fm.checker.IsHealthy(fm.current) {
        return fm.current
    }

    // 尝试主服务器
    if fm.current != fm.primary && fm.checker.IsHealthy(fm.primary) {
        fmt.Printf("故障转移到主服务器: %s\n", fm.primary)
        fm.current = fm.primary
        return fm.current
    }

    // 尝试备用服务器
    for _, server := range fm.secondary {
        if fm.checker.IsHealthy(server) {
            fmt.Printf("故障转移到备用服务器: %s\n", server)
            fm.current = server
            return fm.current
        }
    }

    // 所有服务器都不健康,返回主服务器
    fmt.Printf("所有服务器都不健康,使用主服务器: %s\n", fm.primary)
    fm.current = fm.primary
    return fm.current
}

完整示例:高性能聊天服务器 #

package main

import (
    "bufio"
    "context"
    "encoding/json"
    "fmt"
    "net"
    "sync"
    "time"
)

// ChatServer is a room-based chat service speaking newline-delimited JSON
// messages, with idle-connection tracking and pooled buffers.
type ChatServer struct {
    listener    net.Listener
    clients     map[string]*Client // keyed by remote address; guarded by mutex
    rooms       map[string]*Room   // keyed by room name; guarded by mutex
    connManager *ConnectionManager
    memPool     *MemoryPool
    mutex       sync.RWMutex
    ctx         context.Context
    cancel      context.CancelFunc // stops the accept and client loops
}

// Client is one connected chat user.
type Client struct {
    ID       string // remote address, used as the registry key
    Conn     net.Conn
    Username string // set by the "join" message
    Room     string // current room; empty until joined
    Writer   *bufio.Writer
    LastSeen time.Time
}

// Room groups clients; its own mutex guards the Clients map.
type Room struct {
    Name    string
    Clients map[string]*Client
    mutex   sync.RWMutex
}

// Message is the wire format: one JSON object per line. Type selects the
// handler ("join", "leave", "chat", "private").
type Message struct {
    Type      string    `json:"type"`
    From      string    `json:"from"`
    To        string    `json:"to,omitempty"`
    Room      string    `json:"room,omitempty"`
    Content   string    `json:"content"`
    Timestamp time.Time `json:"timestamp"`
}

// NewChatServer listens on address and wires up connection tracking
// (5-minute idle limit) and a shared memory pool.
func NewChatServer(address string) (*ChatServer, error) {
    l, err := net.Listen("tcp", address)
    if err != nil {
        return nil, err
    }

    ctx, cancel := context.WithCancel(context.Background())

    srv := &ChatServer{
        listener:    l,
        clients:     make(map[string]*Client),
        rooms:       make(map[string]*Room),
        connManager: NewConnectionManager(5 * time.Minute),
        memPool:     NewMemoryPool(),
        ctx:         ctx,
        cancel:      cancel,
    }
    return srv, nil
}

// Start runs the accept loop until the server context is cancelled; each
// accepted connection is served by its own goroutine.
func (s *ChatServer) Start() error {
    fmt.Printf("聊天服务器启动在 %s\n", s.listener.Addr())

    for {
        select {
        case <-s.ctx.Done():
            return nil
        default:
        }

        conn, err := s.listener.Accept()
        if err != nil {
            // Shutdown closes the listener; treat that as a clean exit.
            select {
            case <-s.ctx.Done():
                return nil
            default:
                fmt.Printf("接受连接失败: %v\n", err)
                continue
            }
        }

        go s.handleClient(conn)
    }
}

// handleClient owns one client connection: it registers the connection for
// idle tracking, reads newline-delimited JSON messages, and dispatches
// them until the client disconnects, a read times out, or the server
// stops.
func (s *ChatServer) handleClient(conn net.Conn) {
    defer conn.Close()

    // Track the connection for idle cleanup and stats. The returned record
    // is not needed here (UpdateActivity resolves connections by address);
    // the original bound it to an unused variable, which does not compile.
    s.connManager.AddConnection(conn)
    defer s.connManager.RemoveConnection(conn)

    client := &Client{
        ID:       conn.RemoteAddr().String(),
        Conn:     conn,
        Writer:   bufio.NewWriter(conn),
        LastSeen: time.Now(),
    }
    // Deregister on every exit path; the original skipped removeClient
    // when leaving via context cancellation.
    defer s.removeClient(client)

    reader := bufio.NewReader(conn)

    for {
        select {
        case <-s.ctx.Done():
            return
        default:
            conn.SetReadDeadline(time.Now().Add(30 * time.Second))

            line, err := reader.ReadString('\n')
            if err != nil {
                return
            }

            // 更新活动时间
            client.LastSeen = time.Now()
            s.connManager.UpdateActivity(conn, 0, int64(len(line)))

            var msg Message
            if err := json.Unmarshal([]byte(line), &msg); err != nil {
                s.sendError(client, "无效的消息格式")
                continue
            }

            s.handleMessage(client, &msg)
        }
    }
}

// handleMessage stamps the message with the sender and time, then routes
// it to the handler matching its Type. Unknown types get an error reply.
func (s *ChatServer) handleMessage(client *Client, msg *Message) {
    msg.From = client.Username
    msg.Timestamp = time.Now()

    handlers := map[string]func(*Client, *Message){
        "join":    s.handleJoin,
        "leave":   s.handleLeave,
        "chat":    s.handleChat,
        "private": s.handlePrivateMessage,
    }
    if handle, ok := handlers[msg.Type]; ok {
        handle(client, msg)
        return
    }
    s.sendError(client, "未知的消息类型")
}

// handleJoin registers the client under the requested username, places it
// in the named room (creating the room if needed), and notifies the other
// room members.
//
// The server lock is released before broadcasting: broadcastToRoom takes
// s.mutex.RLock, and Go's sync.RWMutex is not reentrant, so the original
// (which held the write lock via defer across the broadcast) deadlocked.
func (s *ChatServer) handleJoin(client *Client, msg *Message) {
    s.mutex.Lock()

    client.Username = msg.Content
    client.Room = msg.Room

    // 添加到客户端列表
    s.clients[client.ID] = client

    // 添加到房间
    room, exists := s.rooms[msg.Room]
    if !exists {
        room = &Room{
            Name:    msg.Room,
            Clients: make(map[string]*Client),
        }
        s.rooms[msg.Room] = room
    }
    s.mutex.Unlock()

    room.mutex.Lock()
    room.Clients[client.ID] = client
    room.mutex.Unlock()

    // 通知房间内其他用户
    notification := &Message{
        Type:      "notification",
        Content:   fmt.Sprintf("%s 加入了房间", client.Username),
        Room:      msg.Room,
        Timestamp: time.Now(),
    }
    s.broadcastToRoom(msg.Room, notification, client.ID)
}

// handleChat relays a chat message to everyone in the sender's current
// room (including the sender). Clients must join a room first.
func (s *ChatServer) handleChat(client *Client, msg *Message) {
    if client.Room == "" {
        s.sendError(client, "请先加入房间")
        return
    }

    msg.Room = client.Room
    s.broadcastToRoom(client.Room, msg, "")
}

// broadcastToRoom sends msg, JSON-encoded and newline-terminated, to every
// client in roomName except excludeID. Recipients are snapshotted under
// the room lock so slow client writes do not block room membership
// changes.
func (s *ChatServer) broadcastToRoom(roomName string, msg *Message, excludeID string) {
    s.mutex.RLock()
    room, exists := s.rooms[roomName]
    s.mutex.RUnlock()

    if !exists {
        return
    }

    // The original discarded the Marshal error and would have sent a bare
    // newline on failure.
    data, err := json.Marshal(msg)
    if err != nil {
        fmt.Printf("消息序列化失败: %v\n", err)
        return
    }
    data = append(data, '\n')

    room.mutex.RLock()
    recipients := make([]*Client, 0, len(room.Clients))
    for clientID, client := range room.Clients {
        if clientID != excludeID {
            recipients = append(recipients, client)
        }
    }
    room.mutex.RUnlock()

    for _, client := range recipients {
        s.sendToClient(client, data)
    }
}

// sendToClient writes data to one client with a 10s deadline and records
// the sent bytes. Write/Flush errors were silently dropped by the
// original; they are now logged and skip the activity update.
//
// NOTE(review): Client.Writer has no per-client lock, so concurrent
// broadcasts to the same client could interleave writes — confirm callers
// serialize per client.
func (s *ChatServer) sendToClient(client *Client, data []byte) {
    client.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
    if _, err := client.Writer.Write(data); err != nil {
        fmt.Printf("写入错误: %v\n", err)
        return
    }
    if err := client.Writer.Flush(); err != nil {
        fmt.Printf("写入错误: %v\n", err)
        return
    }

    s.connManager.UpdateActivity(client.Conn, int64(len(data)), 0)
}

// sendError delivers an error-type message to a single client.
func (s *ChatServer) sendError(client *Client, message string) {
    payload := &Message{
        Type:      "error",
        Content:   message,
        Timestamp: time.Now(),
    }

    data, _ := json.Marshal(payload)
    s.sendToClient(client, append(data, '\n'))
}

// removeClient deregisters the client from the server and its room, then
// notifies the remaining room members.
//
// The server lock is released before broadcasting: broadcastToRoom takes
// s.mutex.RLock, and sync.RWMutex is not reentrant, so the original
// (which broadcast while holding the write lock via defer) deadlocked.
func (s *ChatServer) removeClient(client *Client) {
    s.mutex.Lock()
    delete(s.clients, client.ID)
    room, exists := s.rooms[client.Room]
    s.mutex.Unlock()

    if client.Room == "" || !exists {
        return
    }

    room.mutex.Lock()
    delete(room.Clients, client.ID)
    room.mutex.Unlock()

    // 通知房间内其他用户
    if client.Username != "" {
        notification := &Message{
            Type:      "notification",
            Content:   fmt.Sprintf("%s 离开了房间", client.Username),
            Room:      client.Room,
            Timestamp: time.Now(),
        }
        s.broadcastToRoom(client.Room, notification, client.ID)
    }
}

// GetStats returns a snapshot of client/room counts merged with the
// connection manager's statistics.
func (s *ChatServer) GetStats() map[string]interface{} {
    s.mutex.RLock()
    defer s.mutex.RUnlock()

    roomSizes := make(map[string]int, len(s.rooms))
    for name, room := range s.rooms {
        room.mutex.RLock()
        roomSizes[name] = len(room.Clients)
        room.mutex.RUnlock()
    }

    stats := map[string]interface{}{
        "total_clients": len(s.clients),
        "total_rooms":   len(s.rooms),
        "rooms":         roomSizes,
    }

    // 合并连接管理器统计
    for k, v := range s.connManager.GetStats() {
        stats[k] = v
    }
    return stats
}

// Shutdown cancels the accept and client loops, stops the connection
// manager, and closes the listener, returning the listener's Close error.
func (s *ChatServer) Shutdown() error {
    s.cancel()
    s.connManager.Shutdown()
    return s.listener.Close()
}

// main wires up the chat server on :8080, dumps stats every 30 seconds,
// and blocks in the accept loop.
func main() {
    server, err := NewChatServer(":8080")
    if err != nil {
        fmt.Printf("创建服务器失败: %v\n", err)
        return
    }

    // 启动统计协程
    go func() {
        ticker := time.NewTicker(30 * time.Second)
        defer ticker.Stop()

        for range ticker.C {
            fmt.Printf("服务器统计: %+v\n", server.GetStats())
        }
    }()

    fmt.Println("高性能聊天服务器启动...")
    if err := server.Start(); err != nil {
        fmt.Printf("服务器运行失败: %v\n", err)
    }
}

小结 #

本节详细介绍了 TCP 并发服务器的设计与实现,包括:

  1. 并发模型 - Goroutine-per-Connection 和 Worker Pool 模型
  2. 连接管理 - 连接池管理和生命周期控制
  3. 负载均衡 - 多种负载均衡算法的实现
  4. 性能优化 - 零拷贝、内存池、批量处理等优化技术
  5. 高可用性 - 健康检查和故障转移机制
  6. 完整示例 - 高性能聊天服务器的实现

掌握这些技术后,你就能够构建高性能、高可用的 TCP 并发服务器。在下一节中,我们将学习 UDP 广播与组播技术。