4.2.1 TCP 编程基础

4.2.1 TCP 编程基础 #

TCP(Transmission Control Protocol)是一种面向连接的、可靠的传输层协议。它提供了数据的有序传输、错误检测和重传机制,是构建可靠网络应用的基础。Go 语言通过 net 包提供了完善的 TCP 编程支持,本节将详细介绍如何在 Go 中进行 TCP 编程。

TCP 协议特性 #

面向连接 #

TCP 是面向连接的协议,在数据传输前需要通过三次握手建立连接,传输完成后需要通过四次挥手关闭连接:

客户端                    服务器
   |                        |
   |-------- SYN --------->|  (建立连接请求)
   |<------- SYN+ACK ------|  (确认并请求连接)
   |-------- ACK --------->|  (确认连接建立)
   |                        |
   |<====== 数据传输 ======>|
   |                        |
   |-------- FIN --------->|  (关闭连接请求)
   |<------- ACK ----------|  (确认关闭)
   |<------- FIN ----------|  (服务器关闭请求)
   |-------- ACK --------->|  (确认关闭)

可靠传输 #

TCP 提供可靠的数据传输保证:

  • 数据完整性 - 通过校验和检测数据错误,并借助确认应答与超时重传恢复出错或丢失的数据
  • 数据有序性 - 通过序列号保证数据按序到达
  • 流量控制 - 通过滑动窗口控制发送速率
  • 拥塞控制 - 根据网络状况调整传输速率

TCP 客户端编程 #

基础 TCP 客户端 #

package main

import (
    "fmt"
    "net"
    "time"
)

// basicTCPClient connects to localhost:8080, sends one greeting and prints
// the first chunk of the reply. Every failure is reported on stdout and ends
// the function early.
func basicTCPClient() {
    conn, dialErr := net.Dial("tcp", "localhost:8080")
    if dialErr != nil {
        fmt.Printf("连接失败: %v\n", dialErr)
        return
    }
    defer conn.Close()

    fmt.Printf("连接成功: %s -> %s\n", conn.LocalAddr(), conn.RemoteAddr())

    // Send the request.
    greeting := []byte("Hello, TCP Server!")
    if _, writeErr := conn.Write(greeting); writeErr != nil {
        fmt.Printf("发送数据失败: %v\n", writeErr)
        return
    }

    // A single Read yields whatever has arrived so far, at most 1024 bytes.
    reply := make([]byte, 1024)
    received, readErr := conn.Read(reply)
    if readErr != nil {
        fmt.Printf("接收数据失败: %v\n", readErr)
        return
    }

    fmt.Printf("收到响应: %s\n", string(reply[:received]))
}

带超时的 TCP 客户端 #

// timeoutTCPClient is the basic client with time limits added: 5 seconds to
// establish the connection and 10 seconds each for reading and writing.
// A read that hits its deadline is reported separately from other failures.
func timeoutTCPClient() {
    conn, dialErr := net.DialTimeout("tcp", "localhost:8080", 5*time.Second)
    if dialErr != nil {
        fmt.Printf("连接超时: %v\n", dialErr)
        return
    }
    defer conn.Close()

    // Deadlines are absolute points in time, not per-call durations.
    conn.SetReadDeadline(time.Now().Add(10 * time.Second))
    conn.SetWriteDeadline(time.Now().Add(10 * time.Second))

    if _, writeErr := conn.Write([]byte("Hello with timeout")); writeErr != nil {
        fmt.Printf("发送失败: %v\n", writeErr)
        return
    }

    reply := make([]byte, 1024)
    received, readErr := conn.Read(reply)
    if readErr == nil {
        fmt.Printf("收到数据: %s\n", string(reply[:received]))
        return
    }

    // Tell an expired deadline apart from every other read failure.
    if timeoutErr, isNetErr := readErr.(net.Error); isNetErr && timeoutErr.Timeout() {
        fmt.Println("读取超时")
        return
    }
    fmt.Printf("读取失败: %v\n", readErr)
}

高级 TCP 客户端 #

import (
    "bufio"
    "fmt"
    "io"
    "net"
    "time"
)

// TCPClient exchanges length-prefixed messages over a TCP connection. Each
// message on the wire is a 4-byte big-endian payload length followed by the
// payload bytes.
type TCPClient struct {
    conn   net.Conn      // underlying connection; closed by Close
    reader *bufio.Reader // buffered reads from conn, used by ReceiveMessage
    writer *bufio.Writer // buffered writes to conn, used by SendMessage
}

// NewTCPClient dials address over TCP with a 10 second connect timeout and
// wraps the connection in buffered reader/writer halves. It returns the dial
// error unchanged on failure.
func NewTCPClient(address string) (*TCPClient, error) {
    conn, err := net.DialTimeout("tcp", address, 10*time.Second)
    if err != nil {
        return nil, err
    }

    return &TCPClient{
        conn:   conn,
        reader: bufio.NewReader(conn),
        writer: bufio.NewWriter(conn),
    }, nil
}

// SendMessage writes message as one frame (4-byte big-endian length, then the
// payload) and flushes the buffered writer so the frame is handed to the
// connection before returning.
//
// It returns an error, without writing anything, when the payload does not
// fit in the 32-bit length prefix; otherwise it returns the first write or
// flush error.
func (c *TCPClient) SendMessage(message string) error {
    // A longer payload would have its length silently truncated in the
    // header and desynchronise the stream for the receiver.
    if uint64(len(message)) > 0xFFFFFFFF {
        return fmt.Errorf("消息过长: %d 字节超出 4 字节长度前缀的上限", len(message))
    }

    size := uint32(len(message))
    header := [4]byte{byte(size >> 24), byte(size >> 16), byte(size >> 8), byte(size)}

    if _, err := c.writer.Write(header[:]); err != nil {
        return err
    }
    if _, err := c.writer.WriteString(message); err != nil {
        return err
    }

    return c.writer.Flush()
}

// ReceiveMessage reads one frame and returns its payload.
//
// The length announced by the peer is not trusted for allocation: the payload
// is read incrementally, so memory use tracks the bytes that actually arrive
// rather than the (up to 4 GiB) value in the header. The trade-off is some
// extra copying for very large messages.
//
// Errors match io.ReadFull semantics: io.EOF when the stream ends before any
// byte of the header or payload is read, io.ErrUnexpectedEOF when it ends
// part-way through either, and any other read error unchanged.
func (c *TCPClient) ReceiveMessage() (string, error) {
    var header [4]byte
    if _, err := io.ReadFull(c.reader, header[:]); err != nil {
        return "", err
    }

    // Decode as uint32 so the value can never turn negative on platforms
    // where int is 32 bits wide.
    size := uint32(header[0])<<24 | uint32(header[1])<<16 |
        uint32(header[2])<<8 | uint32(header[3])

    payload, err := io.ReadAll(io.LimitReader(c.reader, int64(size)))
    if err != nil {
        return "", err
    }
    if uint64(len(payload)) < uint64(size) {
        // io.ReadAll reports a clean end of stream as success, so a short
        // payload has to be turned back into an error here.
        if len(payload) == 0 {
            return "", io.EOF
        }
        return "", io.ErrUnexpectedEOF
    }

    return string(payload), nil
}

// Close closes the underlying connection and returns its error.
func (c *TCPClient) Close() error {
    return c.conn.Close()
}

// demonstrateAdvancedClient performs one request/response round trip with
// TCPClient against localhost:8080 and prints the outcome of each step.
func demonstrateAdvancedClient() {
    client, err := NewTCPClient("localhost:8080")
    if err != nil {
        fmt.Printf("创建客户端失败: %v\n", err)
        return
    }
    defer client.Close()

    // Send one framed message.
    if sendErr := client.SendMessage("Hello, Advanced TCP!"); sendErr != nil {
        fmt.Printf("发送消息失败: %v\n", sendErr)
        return
    }

    // Wait for one framed reply.
    reply, recvErr := client.ReceiveMessage()
    if recvErr != nil {
        fmt.Printf("接收消息失败: %v\n", recvErr)
        return
    }

    fmt.Printf("收到响应: %s\n", reply)
}

TCP 服务器编程 #

基础 TCP 服务器 #

// basicTCPServer listens on port 8080 and serves every accepted connection
// on its own goroutine via handleConnection. It returns only when the
// listener cannot be created; accept failures are printed and the loop
// carries on.
func basicTCPServer() {
    ln, listenErr := net.Listen("tcp", ":8080")
    if listenErr != nil {
        fmt.Printf("监听失败: %v\n", listenErr)
        return
    }
    defer ln.Close()

    fmt.Println("TCP 服务器启动,监听端口 8080")

    for {
        conn, acceptErr := ln.Accept()
        if acceptErr == nil {
            // One goroutine per connection keeps the accept loop free.
            go handleConnection(conn)
            continue
        }
        fmt.Printf("接受连接失败: %v\n", acceptErr)
    }
}

// handleConnection echoes data back to one client until the peer closes the
// connection or an I/O error occurs. Each chunk read (at most 1024 bytes) is
// answered with "Echo: " followed by that chunk. The connection is closed on
// return.
func handleConnection(conn net.Conn) {
    defer conn.Close()
    // Registered second, so it runs first: the farewell line is printed
    // before the connection is closed.
    defer func() {
        fmt.Printf("连接关闭: %s\n", conn.RemoteAddr())
    }()

    fmt.Printf("新连接: %s\n", conn.RemoteAddr())

    chunk := make([]byte, 1024)
    for {
        size, readErr := conn.Read(chunk)
        if readErr == io.EOF {
            // The peer closed the connection cleanly; nothing to report.
            return
        }
        if readErr != nil {
            fmt.Printf("读取错误: %v\n", readErr)
            return
        }

        text := string(chunk[:size])
        fmt.Printf("收到消息: %s\n", text)

        // Echo the chunk back to the sender.
        if _, writeErr := conn.Write([]byte("Echo: " + text)); writeErr != nil {
            fmt.Printf("发送响应失败: %v\n", writeErr)
            return
        }
    }
}

并发 TCP 服务器 #

import (
    "context"
    "fmt"
    "net"
    "sync"
    "time"
)

// TCPServer is a small chat-style server: every chunk received from one
// client is broadcast to all other connected clients.
type TCPServer struct {
    address  string             // listen address passed to net.Listen
    listener net.Listener       // set by Start
    clients  map[net.Conn]bool  // currently registered connections; guarded by mutex
    mutex    sync.RWMutex       // protects clients
    ctx      context.Context    // cancelled by Stop to end the background loops
    cancel   context.CancelFunc // cancels ctx
}

// NewTCPServer returns a server configured for address. Nothing is bound
// until Start is called.
func NewTCPServer(address string) *TCPServer {
    ctx, cancel := context.WithCancel(context.Background())
    return &TCPServer{
        address: address,
        clients: make(map[net.Conn]bool),
        ctx:     ctx,
        cancel:  cancel,
    }
}

// Start binds the listener and launches the accept loop on a background
// goroutine. It returns the net.Listen error if binding fails.
func (s *TCPServer) Start() error {
    listener, err := net.Listen("tcp", s.address)
    if err != nil {
        return err
    }
    s.listener = listener

    fmt.Printf("TCP 服务器启动: %s\n", s.address)

    go s.acceptConnections()
    return nil
}

// acceptConnections accepts clients until the server context is cancelled.
// For a *net.TCPListener the accept call is bounded to one second so that
// cancellation is noticed even when no client connects.
func (s *TCPServer) acceptConnections() {
    for {
        select {
        case <-s.ctx.Done():
            return
        default:
        }

        if tcpListener, ok := s.listener.(*net.TCPListener); ok {
            tcpListener.SetDeadline(time.Now().Add(1 * time.Second))
        }

        conn, err := s.listener.Accept()
        if err != nil {
            if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
                continue // deadline expired; re-check the context
            }
            if s.ctx.Err() != nil {
                // Stop closed the listener; this is shutdown, not a failure.
                return
            }
            fmt.Printf("接受连接失败: %v\n", err)
            continue
        }

        s.addClient(conn)
        go s.handleClient(conn)
    }
}

// handleClient reads from conn until it fails, the peer disconnects or the
// server is stopped, broadcasting every chunk (at most 1024 bytes) to the
// other clients. The connection is unregistered and closed on return.
func (s *TCPServer) handleClient(conn net.Conn) {
    defer func() {
        s.removeClient(conn)
        conn.Close()
    }()

    fmt.Printf("客户端连接: %s\n", conn.RemoteAddr())

    buffer := make([]byte, 1024)
    for {
        select {
        case <-s.ctx.Done():
            return
        default:
        }

        // Bound each read so an idle client cannot pin this goroutine
        // indefinitely.
        conn.SetReadDeadline(time.Now().Add(30 * time.Second))

        n, err := conn.Read(buffer)
        if err != nil {
            if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
                continue // idle client; keep waiting
            }
            // Stay quiet for a clean disconnect and for the read error
            // caused by Stop closing the connection.
            if err != io.EOF && s.ctx.Err() == nil {
                fmt.Printf("读取错误: %v\n", err)
            }
            return
        }

        message := string(buffer[:n])
        fmt.Printf("收到消息 [%s]: %s\n", conn.RemoteAddr(), message)

        s.broadcast(message, conn)
    }
}

// addClient registers conn as a broadcast recipient.
func (s *TCPServer) addClient(conn net.Conn) {
    s.mutex.Lock()
    defer s.mutex.Unlock()
    s.clients[conn] = true
}

// removeClient unregisters conn; unknown connections are ignored.
func (s *TCPServer) removeClient(conn net.Conn) {
    s.mutex.Lock()
    defer s.mutex.Unlock()
    delete(s.clients, conn)
}

// broadcast sends message, prefixed with the sender's address, to every
// registered client except sender. Each write is limited to five seconds;
// failures are printed and do not stop delivery to the remaining clients.
func (s *TCPServer) broadcast(message string, sender net.Conn) {
    // Snapshot the recipients and release the lock before any network I/O.
    // Writing while holding the lock would let one stalled client block
    // addClient/removeClient, and with them the accept loop.
    s.mutex.RLock()
    recipients := make([]net.Conn, 0, len(s.clients))
    for client := range s.clients {
        if client != sender {
            recipients = append(recipients, client)
        }
    }
    s.mutex.RUnlock()

    payload := []byte(fmt.Sprintf("广播 [%s]: %s", sender.RemoteAddr(), message))

    for _, client := range recipients {
        client.SetWriteDeadline(time.Now().Add(5 * time.Second))
        if _, err := client.Write(payload); err != nil {
            fmt.Printf("广播失败 [%s]: %v\n", client.RemoteAddr(), err)
        }
    }
}

// Stop cancels the server context, closes the listener and closes every
// registered client connection so handlers blocked in Read return promptly.
// It returns the error from closing the listener, if any.
func (s *TCPServer) Stop() error {
    s.cancel()

    var err error
    if s.listener != nil {
        err = s.listener.Close()
    }

    s.mutex.Lock()
    for conn := range s.clients {
        conn.Close()
        delete(s.clients, conn)
    }
    s.mutex.Unlock()

    return err
}

// GetClientCount returns the number of currently registered clients.
func (s *TCPServer) GetClientCount() int {
    s.mutex.RLock()
    defer s.mutex.RUnlock()
    return len(s.clients)
}

// demonstrateConcurrentServer starts a TCPServer on port 8080 and then prints
// the number of connected clients every ten seconds. After a successful
// start it never returns.
func demonstrateConcurrentServer() {
    server := NewTCPServer(":8080")

    if err := server.Start(); err != nil {
        fmt.Printf("启动服务器失败: %v\n", err)
        return
    }

    fmt.Println("服务器运行中,按 Ctrl+C 停止")

    // Report the client count once per tick.
    ticker := time.NewTicker(10 * time.Second)
    defer ticker.Stop()

    for range ticker.C {
        fmt.Printf("当前客户端数量: %d\n", server.GetClientCount())
    }
}

TCP 连接管理 #

连接状态监控 #

type (
    // ConnectionMonitor tracks per-connection activity, keyed by the
    // string form of each connection's remote address.
    ConnectionMonitor struct {
        connections map[string]*ConnectionInfo
        mutex       sync.RWMutex // guards connections and the entries it points to
    }

    // ConnectionInfo is the bookkeeping record kept for one connection.
    ConnectionInfo struct {
        Conn        net.Conn
        ConnectedAt time.Time // when the connection was registered
        LastActive  time.Time // time of the most recent UpdateActivity call
        BytesSent   int64     // running total reported through UpdateActivity
        BytesRecv   int64     // running total reported through UpdateActivity
    }
)

// NewConnectionMonitor returns an empty monitor.
func NewConnectionMonitor() *ConnectionMonitor {
    cm := new(ConnectionMonitor)
    cm.connections = make(map[string]*ConnectionInfo)
    return cm
}

// AddConnection starts tracking conn. An existing entry with the same remote
// address is replaced.
func (cm *ConnectionMonitor) AddConnection(conn net.Conn) {
    cm.mutex.Lock()
    defer cm.mutex.Unlock()

    entry := &ConnectionInfo{
        Conn:        conn,
        ConnectedAt: time.Now(),
        LastActive:  time.Now(),
    }
    cm.connections[conn.RemoteAddr().String()] = entry
}

// UpdateActivity adds bytesSent and bytesRecv to the counters of conn and
// refreshes its last-active time. Untracked connections are ignored.
func (cm *ConnectionMonitor) UpdateActivity(conn net.Conn, bytesSent, bytesRecv int64) {
    cm.mutex.Lock()
    defer cm.mutex.Unlock()

    entry, tracked := cm.connections[conn.RemoteAddr().String()]
    if !tracked {
        return
    }
    entry.LastActive = time.Now()
    entry.BytesSent += bytesSent
    entry.BytesRecv += bytesRecv
}

// RemoveConnection stops tracking conn. It does not close the connection.
func (cm *ConnectionMonitor) RemoveConnection(conn net.Conn) {
    cm.mutex.Lock()
    defer cm.mutex.Unlock()

    delete(cm.connections, conn.RemoteAddr().String())
}

// GetStats returns a snapshot of every tracked connection. The records are
// copies, so changing them does not affect the monitor.
func (cm *ConnectionMonitor) GetStats() map[string]ConnectionInfo {
    cm.mutex.RLock()
    defer cm.mutex.RUnlock()

    snapshot := make(map[string]ConnectionInfo, len(cm.connections))
    for addr, entry := range cm.connections {
        snapshot[addr] = *entry
    }
    return snapshot
}

// CleanupIdleConnections closes and forgets every connection whose last
// activity is more than timeout in the past.
func (cm *ConnectionMonitor) CleanupIdleConnections(timeout time.Duration) {
    cm.mutex.Lock()
    defer cm.mutex.Unlock()

    now := time.Now()
    for addr, entry := range cm.connections {
        if idle := now.Sub(entry.LastActive); idle <= timeout {
            continue
        }
        fmt.Printf("关闭空闲连接: %s\n", addr)
        entry.Conn.Close()
        delete(cm.connections, addr)
    }
}

连接池实现 #

// TCPConnectionPool keeps up to maxSize idle TCP connections to one address
// for reuse. Connections are dialled on demand when the pool is empty; the
// pool does not limit how many connections are checked out at once.
type TCPConnectionPool struct {
    address     string
    maxSize     int                      // capacity of the idle queue
    connections chan net.Conn            // idle connections; closed by Close
    factory     func() (net.Conn, error) // creates a new connection
    mutex       sync.Mutex               // guards closed and every operation on connections
    closed      bool                     // set once by Close
}

// NewTCPConnectionPool returns an empty pool for address that keeps at most
// maxSize idle connections. New connections are dialled with a 10 second
// timeout.
func NewTCPConnectionPool(address string, maxSize int) *TCPConnectionPool {
    pool := &TCPConnectionPool{
        address:     address,
        maxSize:     maxSize,
        connections: make(chan net.Conn, maxSize),
    }

    pool.factory = func() (net.Conn, error) {
        return net.DialTimeout("tcp", address, 10*time.Second)
    }

    return pool
}

// Get returns an idle connection that passes isConnValid, or a newly created
// one when no usable idle connection is available. An idle connection that
// fails validation is closed before it is replaced. Get returns an error if
// the pool has been closed or the factory fails.
func (p *TCPConnectionPool) Get() (net.Conn, error) {
    // The closed check and the channel receive happen under one lock so
    // that Close cannot close the channel in between, which would hand a
    // nil connection to the validity probe.
    p.mutex.Lock()
    if p.closed {
        p.mutex.Unlock()
        return nil, fmt.Errorf("连接池已关闭")
    }
    var idle net.Conn
    select {
    case idle = <-p.connections:
    default:
        // Nothing idle; fall through to the factory.
    }
    p.mutex.Unlock()

    if idle != nil {
        if p.isConnValid(idle) {
            return idle, nil
        }
        // Release the dead connection instead of leaking its descriptor.
        idle.Close()
    }

    return p.factory()
}

// Put returns conn to the pool for reuse. The connection is closed instead
// when the pool is already full or has been closed. A nil conn is ignored.
func (p *TCPConnectionPool) Put(conn net.Conn) {
    if conn == nil {
        return
    }

    // The send must happen under the same lock as the closed check:
    // sending on a channel that Close has closed would panic.
    pooled := false
    p.mutex.Lock()
    if !p.closed {
        select {
        case p.connections <- conn:
            pooled = true
        default:
            // Idle queue is full.
        }
    }
    p.mutex.Unlock()

    if !pooled {
        conn.Close()
    }
}

// isConnValid probes conn with a 1 millisecond read. A timeout means the
// connection is open and idle, so it is reported as valid. Any other outcome
// reports it as invalid, including a successful read, which consumes one
// byte of unexpected data. The read deadline is cleared again before
// returning.
func (p *TCPConnectionPool) isConnValid(conn net.Conn) bool {
    conn.SetReadDeadline(time.Now().Add(1 * time.Millisecond))
    defer conn.SetReadDeadline(time.Time{})

    probe := make([]byte, 1)
    _, err := conn.Read(probe)

    if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
        return true
    }

    return false
}

// Close marks the pool as closed and closes every idle connection. Calling
// it more than once is harmless. Connections that are checked out are not
// affected; they are closed when handed back through Put.
func (p *TCPConnectionPool) Close() {
    p.mutex.Lock()
    defer p.mutex.Unlock()

    if p.closed {
        return
    }

    p.closed = true
    close(p.connections)

    // Drain whatever is still queued; the range ends once the closed
    // channel is empty.
    for conn := range p.connections {
        conn.Close()
    }
}

// demonstrateConnectionPool runs five concurrent request/response exchanges
// through a shared pool for localhost:8080 and returns once all of them have
// finished. A connection whose exchange fails is closed rather than returned
// to the pool.
func demonstrateConnectionPool() {
    pool := NewTCPConnectionPool("localhost:8080", 10)
    defer pool.Close()

    // Wait for the workers explicitly so the pool is not closed while a
    // worker is still using it.
    var wg sync.WaitGroup
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()

            conn, err := pool.Get()
            if err != nil {
                fmt.Printf("获取连接失败 [%d]: %v\n", id, err)
                return
            }

            // Bound the whole exchange so an unresponsive server cannot
            // keep this worker, and therefore wg.Wait, blocked forever.
            conn.SetDeadline(time.Now().Add(5 * time.Second))

            message := fmt.Sprintf("Hello from client %d", id)
            if _, err := conn.Write([]byte(message)); err != nil {
                fmt.Printf("发送失败 [%d]: %v\n", id, err)
                conn.Close()
                return
            }

            buffer := make([]byte, 1024)
            n, err := conn.Read(buffer)
            if err != nil {
                fmt.Printf("读取失败 [%d]: %v\n", id, err)
                conn.Close()
                return
            }

            fmt.Printf("客户端 %d 收到: %s\n", id, string(buffer[:n]))

            // Clear the deadline so it does not expire on the next user
            // of this connection, then hand the connection back.
            conn.SetDeadline(time.Time{})
            pool.Put(conn)
        }(i)
    }

    wg.Wait()
}

错误处理和重连机制 #

自动重连客户端 #

// ReconnectClient is a TCP client that re-dials its server in the background
// after a send or receive fails.
type ReconnectClient struct {
    address       string
    conn          net.Conn      // current connection; nil while disconnected
    maxRetries    int           // dial attempts per reconnect cycle
    retryInterval time.Duration // pause between failed attempts
    connected     bool
    monitoring    bool          // true once the monitor goroutine has been started
    mutex         sync.RWMutex  // guards conn, connected, monitoring and the closing of stopChan
    stopChan      chan struct{} // closed by Close to stop background work
    reconnectChan chan struct{} // capacity 1: coalesces reconnect requests
}

// NewReconnectClient returns a disconnected client for address that makes up
// to 5 attempts per reconnect cycle, 2 seconds apart.
func NewReconnectClient(address string) *ReconnectClient {
    return &ReconnectClient{
        address:       address,
        maxRetries:    5,
        retryInterval: 2 * time.Second,
        stopChan:      make(chan struct{}),
        reconnectChan: make(chan struct{}, 1),
    }
}

// Connect dials the server with a 10 second timeout and, on the first
// successful call, starts the goroutine that services reconnect requests.
// Calling Connect again replaces and closes the previous connection without
// starting another monitor. It returns the dial error on failure.
func (rc *ReconnectClient) Connect() error {
    conn, err := net.DialTimeout("tcp", rc.address, 10*time.Second)
    if err != nil {
        return err
    }

    rc.mutex.Lock()
    previous := rc.conn
    rc.conn = conn
    rc.connected = true
    startMonitor := !rc.monitoring
    rc.monitoring = true
    rc.mutex.Unlock()

    if previous != nil {
        previous.Close()
    }

    fmt.Printf("连接成功: %s\n", rc.address)

    if startMonitor {
        go rc.monitorConnection()
    }

    return nil
}

// monitorConnection runs one reconnect cycle per request until the client is
// closed.
func (rc *ReconnectClient) monitorConnection() {
    for {
        select {
        case <-rc.stopChan:
            return
        case <-rc.reconnectChan:
            rc.attemptReconnect()
        }
    }
}

// attemptReconnect drops the current connection and dials up to maxRetries
// times, pausing retryInterval between failures. It gives up early when the
// client is closed, and never installs a connection on a closed client.
func (rc *ReconnectClient) attemptReconnect() {
    rc.mutex.Lock()
    rc.connected = false
    if rc.conn != nil {
        rc.conn.Close()
        rc.conn = nil
    }
    rc.mutex.Unlock()

    fmt.Println("开始重连...")

    for attempt := 1; attempt <= rc.maxRetries; attempt++ {
        select {
        case <-rc.stopChan:
            return
        default:
        }

        fmt.Printf("重连尝试 %d/%d\n", attempt, rc.maxRetries)

        conn, err := net.DialTimeout("tcp", rc.address, 10*time.Second)
        if err != nil {
            fmt.Printf("重连失败: %v\n", err)
            if attempt < rc.maxRetries {
                // Wait for the back-off, but wake up immediately if the
                // client is closed in the meantime.
                timer := time.NewTimer(rc.retryInterval)
                select {
                case <-rc.stopChan:
                    timer.Stop()
                    return
                case <-timer.C:
                }
            }
            continue
        }

        rc.mutex.Lock()
        select {
        case <-rc.stopChan:
            // Close ran while we were dialling; storing the connection
            // now would leak it.
            rc.mutex.Unlock()
            conn.Close()
            return
        default:
        }
        rc.conn = conn
        rc.connected = true
        rc.mutex.Unlock()

        fmt.Println("重连成功")
        return
    }

    fmt.Println("重连失败,已达到最大重试次数")
}

// requestReconnect asks the monitor for a reconnect cycle without blocking;
// a request that is already pending absorbs this one.
func (rc *ReconnectClient) requestReconnect() {
    select {
    case rc.reconnectChan <- struct{}{}:
    default:
    }
}

// Send writes data to the current connection. It returns an error when the
// client is disconnected. A failed write is returned to the caller and also
// requests a background reconnect; the data is not resent.
func (rc *ReconnectClient) Send(data []byte) error {
    rc.mutex.RLock()
    conn := rc.conn
    connected := rc.connected
    rc.mutex.RUnlock()

    if !connected || conn == nil {
        return fmt.Errorf("连接未建立")
    }

    if _, err := conn.Write(data); err != nil {
        rc.requestReconnect()
        return err
    }

    return nil
}

// Receive performs one read of at most 1024 bytes from the current
// connection. It returns an error when the client is disconnected. A failed
// read is returned to the caller and also requests a background reconnect.
func (rc *ReconnectClient) Receive() ([]byte, error) {
    rc.mutex.RLock()
    conn := rc.conn
    connected := rc.connected
    rc.mutex.RUnlock()

    if !connected || conn == nil {
        return nil, fmt.Errorf("连接未建立")
    }

    buffer := make([]byte, 1024)
    n, err := conn.Read(buffer)
    if err != nil {
        rc.requestReconnect()
        return nil, err
    }

    return buffer[:n], nil
}

// Close stops background reconnection and closes the current connection.
// It is safe to call more than once.
func (rc *ReconnectClient) Close() {
    rc.mutex.Lock()
    defer rc.mutex.Unlock()

    // Closing an already closed channel panics, so signal shutdown only
    // once. The check is race-free because it runs under the mutex.
    select {
    case <-rc.stopChan:
    default:
        close(rc.stopChan)
    }

    if rc.conn != nil {
        rc.conn.Close()
        rc.conn = nil
    }
    rc.connected = false
}

// demonstrateReconnectClient connects to localhost:8080 and sends ten
// numbered messages two seconds apart, printing the result of every send.
// The client is closed on return.
func demonstrateReconnectClient() {
    client := NewReconnectClient("localhost:8080")
    defer client.Close()

    if connectErr := client.Connect(); connectErr != nil {
        fmt.Printf("初始连接失败: %v\n", connectErr)
        return
    }

    const rounds = 10
    for seq := 0; seq < rounds; seq++ {
        text := fmt.Sprintf("Message %d", seq)
        if sendErr := client.Send([]byte(text)); sendErr != nil {
            fmt.Printf("发送失败: %v\n", sendErr)
        } else {
            fmt.Printf("发送成功: %s\n", text)
        }

        // Pause between messages.
        time.Sleep(2 * time.Second)
    }
}

性能优化 #

零拷贝优化 #

import (
    "net"
    "os"
    "syscall"
)

// sendFileZeroCopy transmits the whole of filename over conn using the
// sendfile system call.
//
// conn must be a *net.TCPConn; any other connection type yields an error.
// Errors from opening or inspecting the file, from obtaining the raw
// connection and from sendfile itself are returned unchanged. An error is
// also returned if the file turns out shorter than its reported size.
//
// NOTE(review): this relies on syscall.Sendfile accepting a nil offset (it
// then reads from the file's current position), which is the Linux
// behaviour. Confirm before using it on other platforms.
func sendFileZeroCopy(conn net.Conn, filename string) error {
    file, err := os.Open(filename)
    if err != nil {
        return err
    }
    defer file.Close()

    stat, err := file.Stat()
    if err != nil {
        return err
    }

    tcpConn, ok := conn.(*net.TCPConn)
    if !ok {
        return fmt.Errorf("不支持零拷贝传输")
    }

    rawConn, err := tcpConn.SyscallConn()
    if err != nil {
        return err
    }

    // Cap each request so the count always fits in an int, even where int
    // is 32 bits wide.
    const maxChunk = 1 << 30

    srcFd := int(file.Fd())
    remaining := stat.Size()
    var sendErr error

    // RawConn.Write is the hook meant for I/O on the descriptor. The socket
    // is non-blocking, so sendfile can move only part of the data or fail
    // with EAGAIN. Returning false makes Write wait until the socket is
    // writable and then call the function again.
    err = rawConn.Write(func(fd uintptr) bool {
        for remaining > 0 {
            chunk := maxChunk
            if remaining < int64(maxChunk) {
                chunk = int(remaining)
            }

            n, sfErr := syscall.Sendfile(int(fd), srcFd, nil, chunk)
            if n > 0 {
                remaining -= int64(n)
            }

            switch {
            case sfErr == syscall.EAGAIN:
                return false // socket buffer full; wait for writability
            case sfErr == syscall.EINTR:
                continue // interrupted by a signal; retry
            case sfErr != nil:
                sendErr = sfErr
                return true
            case n == 0:
                // End of file reached before the expected size was sent.
                sendErr = fmt.Errorf("文件提前结束: 还有 %d 字节未发送", remaining)
                return true
            }
        }
        return true
    })

    if err != nil {
        return err
    }
    return sendErr
}

批量处理优化 #

// BatchProcessor coalesces small writes into larger ones. Data is buffered
// until at least batchSize bytes have accumulated or 100 milliseconds pass
// after the most recent Write, whichever happens first. It is safe for
// concurrent use.
type BatchProcessor struct {
    conn       net.Conn
    buffer     []byte      // pending bytes not yet written to conn
    batchSize  int         // flush threshold in bytes
    flushTimer *time.Timer // pending delayed flush; nil when none is scheduled
    mu         sync.Mutex  // guards buffer and flushTimer (the timer fires on its own goroutine)
}

// NewBatchProcessor returns a processor that writes to conn and flushes
// whenever batchSize bytes are pending.
func NewBatchProcessor(conn net.Conn, batchSize int) *BatchProcessor {
    return &BatchProcessor{
        conn:      conn,
        buffer:    make([]byte, 0, batchSize),
        batchSize: batchSize,
    }
}

// Write appends data to the pending batch. When the batch reaches batchSize
// it is written out immediately and the result of that write is returned.
// Otherwise a flush is scheduled 100 milliseconds from now, replacing any
// earlier schedule, and nil is returned. The error of a timer-driven flush
// is not reported to any caller.
func (bp *BatchProcessor) Write(data []byte) error {
    bp.mu.Lock()
    defer bp.mu.Unlock()

    bp.buffer = append(bp.buffer, data...)

    if len(bp.buffer) >= bp.batchSize {
        return bp.flushLocked()
    }

    // Restart the delayed flush.
    if bp.flushTimer != nil {
        bp.flushTimer.Stop()
    }
    bp.flushTimer = time.AfterFunc(100*time.Millisecond, func() {
        bp.flush()
    })

    return nil
}

// flush writes out the pending batch, if any, and cancels the delayed flush.
func (bp *BatchProcessor) flush() error {
    bp.mu.Lock()
    defer bp.mu.Unlock()
    return bp.flushLocked()
}

// flushLocked does the work of flush; the caller must hold bp.mu. The batch
// is discarded even when the write fails.
func (bp *BatchProcessor) flushLocked() error {
    if bp.flushTimer != nil {
        bp.flushTimer.Stop()
        bp.flushTimer = nil
    }

    if len(bp.buffer) == 0 {
        return nil
    }

    _, err := bp.conn.Write(bp.buffer)
    bp.buffer = bp.buffer[:0] // reuse the backing array

    return err
}

// Close cancels the delayed flush and writes out whatever is still pending.
// It does not close the underlying connection.
func (bp *BatchProcessor) Close() error {
    return bp.flush()
}

小结 #

本节详细介绍了 Go 语言中的 TCP 编程基础,包括:

  1. TCP 协议特性 - 面向连接、可靠传输的特点
  2. TCP 客户端 - 基础客户端、超时控制、高级客户端实现
  3. TCP 服务器 - 基础服务器、并发服务器、连接管理
  4. 连接管理 - 连接监控、连接池、状态管理
  5. 错误处理 - 重连机制、错误恢复策略
  6. 性能优化 - 零拷贝、批量处理等优化技术

掌握这些 TCP 编程技术后,你就能够构建可靠、高性能的网络应用程序。在下一节中,我们将学习 UDP 编程的相关知识。