3.8.3 WebSocket 客户端开发

3.8.3 WebSocket 客户端开发

WebSocket 客户端是实时通信系统的重要组成部分,负责与服务端建立连接、收发消息以及管理连接状态(如心跳检测与断线重连)。本节将详细介绍如何使用 Go 语言开发功能完整、稳定可靠的 WebSocket 客户端。

基础客户端实现

简单的 WebSocket 客户端

// cmd/client/main.go
package main

import (
    "bufio"
    "fmt"
    "log"
    "net/url"
    "os"
    "os/signal"
    "time"

    "github.com/gorilla/websocket"
)

// main runs a minimal interactive WebSocket client.
//
// What it does:
//   - dials the server and prints every message it receives;
//   - sends each line typed on stdin as a text frame;
//   - closes the connection when the user types "quit", stdin ends, or the
//     process receives an interrupt (Ctrl+C).
//
// gorilla/websocket allows at most one concurrent caller of the data-write methods
// (WriteMessage etc.), but Close and WriteControl may be called concurrently with all
// other methods. The stdin goroutine owns WriteMessage, so close frames are always
// sent with WriteControl. This avoids a concurrent-write panic when Ctrl+C arrives
// while that goroutine is mid-write.
func main() {
    // Server address.
    u := url.URL{Scheme: "ws", Host: "localhost:8080", Path: "/ws"}
    log.Printf("连接到 %s", u.String())

    // Establish the WebSocket connection.
    conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
    if err != nil {
        log.Fatal("连接失败:", err)
    }
    defer conn.Close()

    // Deliver Ctrl+C as a channel event instead of killing the process.
    interrupt := make(chan os.Signal, 1)
    signal.Notify(interrupt, os.Interrupt)

    // sendClose sends a normal-closure close frame. It is safe to call from any goroutine.
    sendClose := func() error {
        return conn.WriteControl(
            websocket.CloseMessage,
            websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
            time.Now().Add(time.Second),
        )
    }

    // Receiver: done is closed as soon as reading fails (including after the close handshake).
    done := make(chan struct{})
    go func() {
        defer close(done)
        for {
            _, message, err := conn.ReadMessage()
            if err != nil {
                log.Println("读取消息失败:", err)
                return
            }
            fmt.Printf("收到消息: %s\n", message)
        }
    }()

    // Sender: the only goroutine that calls WriteMessage.
    go func() {
        scanner := bufio.NewScanner(os.Stdin)
        fmt.Println("请输入消息 (输入 'quit' 退出):")

        for scanner.Scan() {
            text := scanner.Text()
            if text == "quit" {
                break
            }

            if err := conn.WriteMessage(websocket.TextMessage, []byte(text)); err != nil {
                log.Println("发送消息失败:", err)
                return
            }
        }
        // Report stdin read errors instead of treating them like a normal EOF.
        if err := scanner.Err(); err != nil {
            log.Println("读取输入失败:", err)
        }

        // Start the close handshake; the server's reply ends the receiver loop.
        if err := sendClose(); err != nil {
            log.Println("发送关闭消息失败:", err)
        }
    }()

    // Wait for the connection to end or for an interrupt.
    select {
    case <-done:
        log.Println("连接已关闭")
    case <-interrupt:
        log.Println("收到中断信号")

        // Graceful close: send the close frame, then give the server up to one
        // second to echo it before returning (which closes the socket via defer).
        if err := sendClose(); err != nil {
            log.Println("发送关闭消息失败:", err)
            return
        }

        select {
        case <-done:
        case <-time.After(time.Second):
        }
    }
}

高级客户端实现

客户端结构设计

// internal/client/client.go
package client

import (
    "context"
    "encoding/json"
    "fmt"
    "log"
    "net/http"
    "net/url"
    "sync"
    "time"

    "github.com/gorilla/websocket"
)

// ConnectionState describes where a Client is in its connection lifecycle.
type ConnectionState int

// Lifecycle states; the zero value is StateDisconnected.
const (
    StateDisconnected ConnectionState = iota // no live connection
    StateConnecting                          // Connect has been called and the first dial is in progress
    StateConnected                           // a connection is established
    StateReconnecting                        // waiting to redial after a dropped connection
)

// String returns the upper-case name of s, or "UNKNOWN" for values outside
// the declared constants.
func (s ConnectionState) String() string {
    names := [...]string{
        StateDisconnected: "DISCONNECTED",
        StateConnecting:   "CONNECTING",
        StateConnected:    "CONNECTED",
        StateReconnecting: "RECONNECTING",
    }
    if s < 0 || int(s) >= len(names) {
        return "UNKNOWN"
    }
    return names[s]
}

// Message is the JSON envelope exchanged with the server. Type selects how the
// other fields are interpreted; the helpers on Client send "chat", "join_room",
// "leave_room" and "private_message", and ChatEventHandler.OnMessage handles the
// types it receives.
//
// Content is decoded into an interface{}, so after json.Unmarshal it holds a
// generic JSON value (string, float64, map[string]interface{}, ...).
//
// NOTE(review): encoding/json ignores `omitempty` on struct-typed fields such as
// time.Time, so a zero Timestamp is still serialized ("0001-01-01T00:00:00Z").
// Receivers should treat the zero time as "not set".
type Message struct {
    Type      string      `json:"type"`
    From      string      `json:"from,omitempty"`
    To        string      `json:"to,omitempty"`
    Room      string      `json:"room,omitempty"`
    Content   interface{} `json:"content"`
    Timestamp time.Time   `json:"timestamp,omitempty"`
}

// EventHandler receives lifecycle and message callbacks from a Client.
//
// Callbacks are invoked synchronously from the Client's own goroutines: OnMessage
// runs inside the read loop, so a slow callback delays further reads.
//
// Calling Client.Disconnect from inside a callback while the client is connected
// can deadlock. Disconnect waits for those same goroutines to exit.
type EventHandler interface {
    // OnConnect is called after every successful connection, initial or reconnect.
    OnConnect()
    // OnDisconnect is called when reading or writing the connection fails; err is that failure.
    OnDisconnect(err error)
    // OnMessage is called for every incoming frame that decodes as a Message.
    OnMessage(msg *Message)
    // OnError is called when a reconnect attempt fails.
    OnError(err error)
}

// Config controls how a Client connects, keeps its connection alive and reconnects.
type Config struct {
    // URL is the ws:// or wss:// endpoint to dial; it has no default.
    URL               string
    // Headers are extra HTTP headers sent with the opening handshake.
    Headers           http.Header
    // ReconnectInterval is the wait before each reconnect attempt.
    ReconnectInterval time.Duration
    // MaxReconnectTries bounds consecutive reconnect attempts; <= 0 means unlimited.
    MaxReconnectTries int
    // PingInterval is the keep-alive ping period; <= 0 disables pings.
    PingInterval      time.Duration
    // PongTimeout is not read by the Client code in this file.
    PongTimeout       time.Duration
    // WriteTimeout is the per-write deadline; <= 0 disables it.
    WriteTimeout      time.Duration
    // ReadTimeout is the read deadline, refreshed on every pong; <= 0 disables it.
    ReadTimeout       time.Duration
    // MessageBufferSize is the capacity of the outgoing message buffer.
    MessageBufferSize int
    // EnableCompression requests per-message compression during the handshake.
    EnableCompression bool
}

// DefaultConfig returns a new Config populated with the package defaults.
// URL and Headers are left empty; callers must set URL before connecting.
func DefaultConfig() *Config {
    cfg := new(Config)
    cfg.ReconnectInterval = 5 * time.Second
    cfg.MaxReconnectTries = 10
    cfg.PingInterval = 30 * time.Second
    cfg.PongTimeout = 10 * time.Second
    cfg.WriteTimeout = 10 * time.Second
    cfg.ReadTimeout = time.Minute
    cfg.MessageBufferSize = 256
    cfg.EnableCompression = true
    return cfg
}

// Client is a WebSocket client with connection-state tracking, event callbacks
// and automatic reconnection. Create one with NewClient.
//
// NOTE(review): conn is replaced by connect() on every (re)connect and read by
// Disconnect and the pump goroutines without holding a lock. Confirm whether
// callers can trigger this concurrently.
type Client struct {
    config          *Config
    conn            *websocket.Conn
    state           ConnectionState
    stateMu         sync.RWMutex

    // Event callbacks; nil means no events are delivered.
    eventHandler    EventHandler

    // Outgoing messages, already JSON-encoded, consumed by writePump.
    sendChan        chan []byte

    // Control channels. Both are created in NewClient but are not read anywhere
    // in the code shown here.
    closeChan       chan struct{}
    reconnectChan   chan struct{}

    // Number of reconnect attempts since the last successful connect.
    reconnectCount  int

    // Cancelled by Disconnect; stops the pumps and any pending reconnect.
    ctx             context.Context
    cancel          context.CancelFunc

    // Tracks the readPump/writePump goroutines so Disconnect can wait for them.
    wg              sync.WaitGroup
}

// NewClient builds a Client in the DISCONNECTED state; call Connect to dial.
//
// A nil config falls back to DefaultConfig(). handler may be nil, in which case
// no events are delivered. The client owns a cancellable context that Disconnect
// cancels, so a Client is not meant to be reused after Disconnect.
func NewClient(config *Config, handler EventHandler) *Client {
    cfg := config
    if cfg == nil {
        cfg = DefaultConfig()
    }

    c := &Client{
        config:       cfg,
        state:        StateDisconnected,
        eventHandler: handler,
    }
    c.ctx, c.cancel = context.WithCancel(context.Background())
    c.sendChan = make(chan []byte, cfg.MessageBufferSize)
    c.closeChan = make(chan struct{})
    c.reconnectChan = make(chan struct{})
    return c
}

// Connect dials the configured URL. It fails immediately if the client is not
// in the DISCONNECTED state. Otherwise it moves to CONNECTING and returns
// the result of the dial.
func (c *Client) Connect() error {
    c.stateMu.Lock()
    idle := c.state == StateDisconnected
    if idle {
        c.state = StateConnecting
    }
    c.stateMu.Unlock()

    if !idle {
        return fmt.Errorf("客户端已连接或正在连接")
    }
    return c.connect()
}

// connect establishes one connection. It is used both by Connect and by every
// reconnect attempt. In order, it:
//   1. dials c.config.URL;
//   2. installs the new connection as c.conn;
//   3. marks the client CONNECTED and resets the reconnect counter;
//   4. starts one readPump and one writePump for the connection;
//   5. fires OnConnect.
//
// On any failure the state is reset to DISCONNECTED and a wrapped error is returned.
func (c *Client) connect() error {
    u, err := url.Parse(c.config.URL)
    if err != nil {
        c.setState(StateDisconnected)
        return fmt.Errorf("无效的 URL: %w", err)
    }

    dialer := websocket.Dialer{
        HandshakeTimeout:  10 * time.Second,
        EnableCompression: c.config.EnableCompression,
    }

    // DialContext ties the handshake to the client's context. A Disconnect issued
    // while a (re)connect is dialing therefore aborts the dial, instead of letting
    // it finish and start new pumps on a client that is shutting down.
    conn, _, err := dialer.DialContext(c.ctx, u.String(), c.config.Headers)
    if err != nil {
        c.setState(StateDisconnected)
        return fmt.Errorf("连接失败: %w", err)
    }

    c.conn = conn
    c.setState(StateConnected)
    c.reconnectCount = 0

    // Exactly one reader and one writer per connection, as gorilla/websocket requires.
    c.wg.Add(2)
    go c.readPump()
    go c.writePump()

    if c.eventHandler != nil {
        c.eventHandler.OnConnect()
    }

    log.Printf("WebSocket 连接已建立: %s", u.String())
    return nil
}

// Disconnect closes the connection and stops the client. It is a no-op if the
// client is already DISCONNECTED.
//
// Steps:
//   1. cancel the client context (stops pumps and any pending reconnect);
//   2. send a best-effort normal-closure frame;
//   3. close the socket;
//   4. wait for the pump goroutines;
//   5. mark the client DISCONNECTED.
//
// The close frame is sent with WriteControl rather than WriteMessage.
// gorilla/websocket permits only one concurrent data writer, and writePump may be
// writing at this moment; WriteControl and Close are documented as safe to call
// concurrently with every other method.
//
// Do not call Disconnect from an EventHandler callback while connected. The
// callbacks run on the pump goroutines this method waits for.
func (c *Client) Disconnect() error {
    if c.GetState() == StateDisconnected {
        return nil
    }

    c.cancel()

    if conn := c.conn; conn != nil {
        timeout := c.config.WriteTimeout
        if timeout <= 0 {
            // WriteControl with a zero deadline would wait indefinitely; bound it.
            timeout = time.Second
        }
        closeFrame := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
        // Best effort: the peer may already be gone, and we close regardless.
        _ = conn.WriteControl(websocket.CloseMessage, closeFrame, time.Now().Add(timeout))
        conn.Close()
    }

    c.wg.Wait()

    c.setState(StateDisconnected)
    log.Println("WebSocket 连接已断开")
    return nil
}

// SendMessage JSON-encodes msg and queues it for writePump without blocking.
//
// It returns an error in these cases:
//   - the client is not CONNECTED;
//   - the client has been shut down;
//   - the message cannot be encoded (the encoding error is wrapped);
//   - the send buffer is full.
func (c *Client) SendMessage(msg *Message) error {
    if c.GetState() != StateConnected {
        return fmt.Errorf("客户端未连接")
    }
    // Check cancellation up front. State stays CONNECTED until Disconnect finishes.
    // If both the buffer and ctx.Done() are ready, select picks a case at random,
    // so a message could be accepted after shutdown and then silently dropped.
    if c.ctx.Err() != nil {
        return fmt.Errorf("客户端已关闭")
    }

    data, err := json.Marshal(msg)
    if err != nil {
        return fmt.Errorf("消息序列化失败: %w", err)
    }

    select {
    case c.sendChan <- data:
        return nil
    case <-c.ctx.Done():
        return fmt.Errorf("客户端已关闭")
    default:
        return fmt.Errorf("发送缓冲区已满")
    }
}

// SendText queues a "chat" message whose Content is text.
func (c *Client) SendText(text string) error {
    return c.SendMessage(&Message{Type: "chat", Content: text})
}

// JoinRoom queues a "join_room" request for room.
func (c *Client) JoinRoom(room string) error {
    return c.SendMessage(&Message{Type: "join_room", Room: room})
}

// LeaveRoom queues a "leave_room" request for room.
func (c *Client) LeaveRoom(room string) error {
    return c.SendMessage(&Message{Type: "leave_room", Room: room})
}

// SendPrivateMessage queues a "private_message" addressed to the user `to`.
func (c *Client) SendPrivateMessage(to, content string) error {
    return c.SendMessage(&Message{Type: "private_message", To: to, Content: content})
}

// GetState returns the current connection state; it is safe for concurrent use.
func (c *Client) GetState() ConnectionState {
    c.stateMu.RLock()
    current := c.state
    c.stateMu.RUnlock()
    return current
}

// setState records a new connection state under the state lock.
func (c *Client) setState(state ConnectionState) {
    c.stateMu.Lock()
    defer c.stateMu.Unlock()
    c.state = state
}

// readPump is the single reader for one connection.
//
// It keeps the read deadline fresh on every pong, decodes each incoming frame as
// a JSON Message for OnMessage (skipping frames that fail to decode), and on the
// first read error hands off to handleDisconnect, which may reconnect.
//
// The connection is captured into a local when the pump starts.
// handleDisconnect can reconnect and overwrite c.conn before this function
// returns. A deferred `c.conn.Close()` would then close the *new* connection,
// immediately causing another disconnect/reconnect cycle.
func (c *Client) readPump() {
    defer c.wg.Done()

    conn := c.conn
    defer func() {
        if conn != nil {
            conn.Close()
        }
    }()

    if c.config.ReadTimeout > 0 {
        conn.SetReadDeadline(time.Now().Add(c.config.ReadTimeout))
    }

    // Each pong proves the peer is alive, so extend the read deadline.
    conn.SetPongHandler(func(string) error {
        if c.config.ReadTimeout > 0 {
            conn.SetReadDeadline(time.Now().Add(c.config.ReadTimeout))
        }
        return nil
    })

    for {
        select {
        case <-c.ctx.Done():
            return
        default:
        }

        _, data, err := conn.ReadMessage()
        if err != nil {
            if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
                log.Printf("WebSocket 读取错误: %v", err)
            }

            c.handleDisconnect(err)
            return
        }

        var msg Message
        if err := json.Unmarshal(data, &msg); err != nil {
            log.Printf("消息解析失败: %v", err)
            continue
        }

        if c.eventHandler != nil {
            c.eventHandler.OnMessage(&msg)
        }
    }
}

// writePump is the single writer for one connection.
//
// It sends queued messages from sendChan, emits a ping every
// Config.PingInterval, and on shutdown (context cancelled) sends a close frame.
//
// On a write failure it closes its connection and returns; it does not call
// handleDisconnect itself. Closing the socket makes readPump's ReadMessage fail,
// and readPump then drives the single disconnect/reconnect sequence. When both
// pumps reported the failure, two reconnects ran concurrently.
//
// The connection is captured into a local for the same reason as in readPump: a
// reconnect replaces c.conn, and this pump must only ever close its own socket.
// NOTE(review): until its next write fails, this pump may still take one message
// from sendChan after its connection died; that message is then lost.
func (c *Client) writePump() {
    defer c.wg.Done()

    conn := c.conn
    defer func() {
        if conn != nil {
            conn.Close()
        }
    }()

    // A nil channel never becomes ready in select, so PingInterval <= 0 cleanly
    // disables pings. Receiving from a nil *time.Ticker's C would panic instead.
    var pingC <-chan time.Time
    if c.config.PingInterval > 0 {
        ticker := time.NewTicker(c.config.PingInterval)
        defer ticker.Stop()
        pingC = ticker.C
    }

    armDeadline := func() {
        if c.config.WriteTimeout > 0 {
            conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout))
        }
    }

    for {
        select {
        case message := <-c.sendChan:
            armDeadline()
            if err := conn.WriteMessage(websocket.TextMessage, message); err != nil {
                log.Printf("发送消息失败: %v", err)
                return
            }

        case <-pingC:
            armDeadline()
            if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
                log.Printf("发送 Ping 失败: %v", err)
                return
            }

        case <-c.ctx.Done():
            armDeadline()
            conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
            return
        }
    }
}

// handleDisconnect marks the client DISCONNECTED and notifies OnDisconnect with
// the error that ended the connection. It then tries to reconnect, unless the
// client context has been cancelled (i.e. Disconnect was requested).
func (c *Client) handleDisconnect(err error) {
    c.setState(StateDisconnected)

    if h := c.eventHandler; h != nil {
        h.OnDisconnect(err)
    }

    // A cancelled context means this shutdown was requested; do not redial.
    if c.ctx.Err() != nil {
        return
    }
    c.tryReconnect()
}

// tryReconnect redials until one of these happens:
//   - a connection succeeds;
//   - the attempt budget is used up (Config.MaxReconnectTries; <= 0 means unlimited);
//   - the client context is cancelled.
//
// Before every attempt it sets the state to RECONNECTING and waits
// Config.ReconnectInterval. Each failed attempt is reported through OnError.
//
// Attempts run in a loop on the calling goroutine (a pump goroutine tracked by
// c.wg) instead of re-spawning `go c.tryReconnect()` after each failure. A
// spawned goroutine is invisible to the WaitGroup: Disconnect's wg.Wait could
// return while it was about to call connect(). connect's wg.Add(2) would then
// race with (or follow) that Wait, leaving pumps running after Disconnect.
func (c *Client) tryReconnect() {
    for {
        if c.config.MaxReconnectTries > 0 && c.reconnectCount >= c.config.MaxReconnectTries {
            log.Printf("达到最大重连次数 (%d),停止重连", c.config.MaxReconnectTries)
            return
        }

        c.setState(StateReconnecting)
        c.reconnectCount++

        log.Printf("尝试重连 (%d/%d)...", c.reconnectCount, c.config.MaxReconnectTries)

        select {
        case <-time.After(c.config.ReconnectInterval):
        case <-c.ctx.Done():
            return
        }

        err := c.connect()
        if err == nil {
            return
        }

        log.Printf("重连失败: %v", err)
        if c.eventHandler != nil {
            c.eventHandler.OnError(err)
        }
    }
}

事件处理器实现

// internal/client/handler.go
package client

import (
    "fmt"
    "log"
    "time"
)

// DefaultEventHandler logs every event with the standard logger and then forwards
// it to the matching optional callback. Any callback left nil is skipped.
type DefaultEventHandler struct {
    OnConnectFunc    func()
    OnDisconnectFunc func(error)
    OnMessageFunc    func(*Message)
    OnErrorFunc      func(error)
}

// OnConnect logs the new connection and calls OnConnectFunc.
func (h *DefaultEventHandler) OnConnect() {
    log.Println("WebSocket 连接已建立")
    if fn := h.OnConnectFunc; fn != nil {
        fn()
    }
}

// OnDisconnect logs the disconnect (distinguishing a nil error) and calls OnDisconnectFunc.
func (h *DefaultEventHandler) OnDisconnect(err error) {
    switch err {
    case nil:
        log.Println("WebSocket 连接正常断开")
    default:
        log.Printf("WebSocket 连接断开: %v", err)
    }
    if fn := h.OnDisconnectFunc; fn != nil {
        fn(err)
    }
}

// OnMessage logs the message with field names and calls OnMessageFunc.
func (h *DefaultEventHandler) OnMessage(msg *Message) {
    log.Printf("收到消息: %+v", msg)
    if fn := h.OnMessageFunc; fn != nil {
        fn(msg)
    }
}

// OnError logs err and calls OnErrorFunc.
func (h *DefaultEventHandler) OnError(err error) {
    log.Printf("WebSocket 错误: %v", err)
    if fn := h.OnErrorFunc; fn != nil {
        fn(err)
    }
}

// ChatEventHandler is an EventHandler that renders chat traffic to stdout. It
// also tracks which rooms the server has confirmed this user joined.
//
// NOTE(review): Rooms is modified from the Client's read goroutine (via
// OnMessage) without locking. Reading it from another goroutine while the
// client is connected would be a data race.
type ChatEventHandler struct {
    UserID   string
    Username string
    Rooms    map[string]bool
}

// NewChatEventHandler returns a handler for the given user with an empty room set.
func NewChatEventHandler(userID, username string) *ChatEventHandler {
    return &ChatEventHandler{
        UserID:   userID,
        Username: username,
        Rooms:    make(map[string]bool),
    }
}

// OnConnect prints a connection banner for the current user.
func (h *ChatEventHandler) OnConnect() {
    fmt.Printf("✅ 连接成功!用户: %s\n", h.Username)
}

// OnDisconnect prints why the connection ended (or that it closed normally).
func (h *ChatEventHandler) OnDisconnect(err error) {
    if err != nil {
        fmt.Printf("❌ 连接断开: %v\n", err)
    } else {
        fmt.Println("👋 连接已关闭")
    }
}

// OnMessage dispatches on msg.Type; unknown types are printed verbatim.
func (h *ChatEventHandler) OnMessage(msg *Message) {
    switch msg.Type {
    case "chat":
        h.handleChatMessage(msg)
    case "system":
        h.handleSystemMessage(msg)
    case "user_joined":
        h.handleUserJoined(msg)
    case "user_left":
        h.handleUserLeft(msg)
    case "private_message":
        h.handlePrivateMessage(msg)
    case "room_joined":
        h.handleRoomJoined(msg)
    case "room_left":
        h.handleRoomLeft(msg)
    case "error":
        h.handleError(msg)
    default:
        fmt.Printf("🔍 未知消息类型: %s, 内容: %v\n", msg.Type, msg.Content)
    }
}

// OnError prints a client-side error.
func (h *ChatEventHandler) OnError(err error) {
    fmt.Printf("⚠️  错误: %v\n", err)
}

// displayTime formats the message time as HH:MM:SS. It prefers the server-supplied
// Timestamp and falls back to the local clock when that is the zero time. Chat and
// private messages both use it so they render times the same way.
func (h *ChatEventHandler) displayTime(msg *Message) string {
    t := msg.Timestamp
    if t.IsZero() {
        t = time.Now()
    }
    return t.Format("15:04:05")
}

func (h *ChatEventHandler) handleChatMessage(msg *Message) {
    timestamp := h.displayTime(msg)
    if msg.Room != "" {
        fmt.Printf("[%s] [%s] %s: %v\n", timestamp, msg.Room, msg.From, msg.Content)
        return
    }
    fmt.Printf("[%s] %s: %v\n", timestamp, msg.From, msg.Content)
}

func (h *ChatEventHandler) handleSystemMessage(msg *Message) {
    fmt.Printf("🔔 系统消息: %v\n", msg.Content)
}

func (h *ChatEventHandler) handleUserJoined(msg *Message) {
    fmt.Printf("👤 %s 加入了房间 %s\n", msg.From, msg.Room)
}

func (h *ChatEventHandler) handleUserLeft(msg *Message) {
    fmt.Printf("👤 %s 离开了房间 %s\n", msg.From, msg.Room)
}

func (h *ChatEventHandler) handlePrivateMessage(msg *Message) {
    fmt.Printf("[%s] 💬 %s (私聊): %v\n", h.displayTime(msg), msg.From, msg.Content)
}

// handleRoomJoined records the server-confirmed room in Rooms.
func (h *ChatEventHandler) handleRoomJoined(msg *Message) {
    h.Rooms[msg.Room] = true
    fmt.Printf("✅ 成功加入房间: %s\n", msg.Room)
}

// handleRoomLeft removes the room from Rooms.
func (h *ChatEventHandler) handleRoomLeft(msg *Message) {
    delete(h.Rooms, msg.Room)
    fmt.Printf("👋 已离开房间: %s\n", msg.Room)
}

func (h *ChatEventHandler) handleError(msg *Message) {
    fmt.Printf("❌ 服务器错误: %v\n", msg.Content)
}

客户端管理器

连接池管理

// internal/client/pool.go
package client

import (
    "fmt"
    "sync"
    "time"
)

// ConnectionPool holds named Clients, up to maxClients of them. It is safe for
// concurrent use.
type ConnectionPool struct {
    clients    map[string]*Client // keyed by caller-chosen id
    mu         sync.RWMutex       // guards clients
    maxClients int                // GetClient refuses to create beyond this many
}

// NewConnectionPool returns an empty pool that holds at most maxClients clients.
func NewConnectionPool(maxClients int) *ConnectionPool {
    p := &ConnectionPool{maxClients: maxClients}
    p.clients = make(map[string]*Client)
    return p
}

// GetClient returns the client registered under id. If there is none, it creates
// and registers one with NewClient, unless the pool is already at capacity.
//
// config and handler are only used when a new client is created. A newly
// created client is not connected yet; the caller must call Connect.
func (p *ConnectionPool) GetClient(id string, config *Config, handler EventHandler) (*Client, error) {
    p.mu.Lock()
    defer p.mu.Unlock()

    existing, found := p.clients[id]
    switch {
    case found:
        return existing, nil
    case len(p.clients) >= p.maxClients:
        return nil, fmt.Errorf("连接池已满,最大连接数: %d", p.maxClients)
    }

    fresh := NewClient(config, handler)
    p.clients[id] = fresh
    return fresh, nil
}

// RemoveClient unregisters the client with the given id, if any, and disconnects it.
//
// The client is removed from the map under the lock, but Disconnect runs after
// the lock is released. Disconnect can block: it writes a close frame with a
// deadline and waits for the client's goroutines. Those goroutines invoke
// EventHandler callbacks (e.g. OnDisconnect), and a callback that calls back into
// the pool would deadlock against a held pool lock. Holding the lock would also
// stall every other pool operation for the duration.
func (p *ConnectionPool) RemoveClient(id string) {
    p.mu.Lock()
    client, exists := p.clients[id]
    delete(p.clients, id)
    p.mu.Unlock()

    if exists {
        client.Disconnect()
    }
}

// GetAllClients returns a snapshot copy of the id → client map. Changes to the
// returned map do not affect the pool.
func (p *ConnectionPool) GetAllClients() map[string]*Client {
    p.mu.RLock()
    defer p.mu.RUnlock()

    snapshot := make(map[string]*Client, len(p.clients))
    for key, value := range p.clients {
        snapshot[key] = value
    }
    return snapshot
}

// Broadcast queues msg on every client that is currently CONNECTED.
// Per-client send errors (e.g. a full buffer) are ignored.
func (p *ConnectionPool) Broadcast(msg *Message) {
    p.mu.RLock()
    defer p.mu.RUnlock()

    for _, cl := range p.clients {
        if cl.GetState() != StateConnected {
            continue
        }
        _ = cl.SendMessage(msg) // best effort
    }
}

// CloseAll unregisters every client and disconnects each of them.
//
// The map is swapped out under the lock and the clients are disconnected after
// it is released. Disconnect blocks while it closes the socket and waits for the
// client's goroutines. Doing that with the pool locked stalls all other pool
// operations and can deadlock if an EventHandler callback uses the pool.
func (p *ConnectionPool) CloseAll() {
    p.mu.Lock()
    clients := p.clients
    p.clients = make(map[string]*Client)
    p.mu.Unlock()

    for _, client := range clients {
        client.Disconnect()
    }
}

// GetStats reports the pool size, its capacity, and how many clients are in each
// connection state (keyed by ConnectionState.String()).
func (p *ConnectionPool) GetStats() map[string]interface{} {
    p.mu.RLock()
    defer p.mu.RUnlock()

    byState := make(map[string]int)
    for _, cl := range p.clients {
        byState[cl.GetState().String()]++
    }

    return map[string]interface{}{
        "total_clients": len(p.clients),
        "max_clients":   p.maxClients,
        "state_count":   byState,
    }
}

重连策略

// internal/client/reconnect.go
package client

import (
    "fmt"
    "math"
    "math/rand"
    "time"
)

// ReconnectStrategy decides how long to wait before a reconnect attempt and
// whether another attempt should be made. attempt counts failed attempts so
// far, starting at 0.
type ReconnectStrategy interface {
    // NextDelay returns the wait before the attempt numbered `attempt`.
    NextDelay(attempt int) time.Duration
    // ShouldReconnect reports whether attempt number `attempt` may be made.
    ShouldReconnect(attempt int) bool
}

// FixedIntervalStrategy waits the same Interval before every attempt.
// MaxRetries <= 0 allows unlimited attempts.
type FixedIntervalStrategy struct {
    Interval   time.Duration
    MaxRetries int
}

// NextDelay always returns s.Interval, regardless of the attempt number.
func (s *FixedIntervalStrategy) NextDelay(attempt int) time.Duration {
    _ = attempt
    return s.Interval
}

// ShouldReconnect allows attempts 0..MaxRetries-1, or any attempt when MaxRetries <= 0.
func (s *FixedIntervalStrategy) ShouldReconnect(attempt int) bool {
    if s.MaxRetries > 0 {
        return attempt < s.MaxRetries
    }
    return true
}

// ExponentialBackoffStrategy waits InitialDelay * Multiplier^attempt, capped at
// MaxDelay, optionally with up to 10% random jitter. MaxRetries <= 0 allows
// unlimited attempts.
type ExponentialBackoffStrategy struct {
    InitialDelay time.Duration
    MaxDelay     time.Duration // <= 0 means no cap
    Multiplier   float64
    MaxRetries   int
    Jitter       bool
}

// NextDelay returns the backoff for the given attempt.
//
// Rules:
//   - MaxDelay <= 0 means "no cap". Clamping to a non-positive cap would make
//     every retry immediate, turning reconnects into a hot loop.
//   - Jitter is applied after the cap by subtracting up to 10%, so the result
//     never exceeds MaxDelay. Clients that all sit at the cap still spread out.
//   - Very large results saturate at the maximum Duration. Converting an
//     out-of-range float (such as +Inf from math.Pow) to an integer is
//     implementation-defined in Go.
func (s *ExponentialBackoffStrategy) NextDelay(attempt int) time.Duration {
    delay := float64(s.InitialDelay) * math.Pow(s.Multiplier, float64(attempt))

    if s.MaxDelay > 0 && delay > float64(s.MaxDelay) {
        delay = float64(s.MaxDelay)
    }

    if s.Jitter {
        delay -= rand.Float64() * 0.1 * delay // up to 10% earlier
    }

    if delay >= float64(math.MaxInt64) {
        return time.Duration(math.MaxInt64)
    }
    return time.Duration(delay)
}

// ShouldReconnect allows attempts 0..MaxRetries-1, or any attempt when MaxRetries <= 0.
func (s *ExponentialBackoffStrategy) ShouldReconnect(attempt int) bool {
    return s.MaxRetries <= 0 || attempt < s.MaxRetries
}

// LinearBackoffStrategy waits InitialDelay + attempt*Increment, capped at
// MaxDelay. MaxRetries <= 0 allows unlimited attempts.
type LinearBackoffStrategy struct {
    InitialDelay time.Duration
    MaxDelay     time.Duration // <= 0 means no cap
    Increment    time.Duration
    MaxRetries   int
}

// NextDelay returns the linear backoff for attempt. MaxDelay <= 0 means "no
// cap"; clamping to a non-positive cap would make every retry immediate.
func (s *LinearBackoffStrategy) NextDelay(attempt int) time.Duration {
    delay := s.InitialDelay + time.Duration(attempt)*s.Increment

    if s.MaxDelay > 0 && delay > s.MaxDelay {
        return s.MaxDelay
    }
    return delay
}

// ShouldReconnect allows attempts 0..MaxRetries-1, or any attempt when MaxRetries <= 0.
func (s *LinearBackoffStrategy) ShouldReconnect(attempt int) bool {
    return s.MaxRetries <= 0 || attempt < s.MaxRetries
}

// ReconnectableClient embeds a Client and carries a ReconnectStrategy for
// tryReconnectWithStrategy.
type ReconnectableClient struct {
    *Client
    strategy ReconnectStrategy
}

// NewReconnectableClient wraps NewClient(config, handler) together with strategy.
func NewReconnectableClient(config *Config, handler EventHandler, strategy ReconnectStrategy) *ReconnectableClient {
    return &ReconnectableClient{
        Client:   NewClient(config, handler),
        strategy: strategy,
    }
}

// tryReconnectWithStrategy redials according to rc.strategy.
//
// Each attempt sets the state to RECONNECTING, waits strategy.NextDelay(attempt)
// and then redials. It stops when a connection succeeds, when the strategy refuses
// another attempt, or when the client context is cancelled. Failed attempts and
// final exhaustion are reported through OnError.
//
// A nil strategy falls back to a FixedIntervalStrategy built from Config
// instead of panicking.
//
// NOTE(review): nothing in the code shown calls this method.
// Client.handleDisconnect calls Client.tryReconnect, and Go embedding does not
// redirect that call here, so the strategy currently has no effect on the
// client's real reconnect behaviour.
func (rc *ReconnectableClient) tryReconnectWithStrategy() {
    strategy := rc.strategy
    if strategy == nil {
        strategy = &FixedIntervalStrategy{
            Interval:   rc.config.ReconnectInterval,
            MaxRetries: rc.config.MaxReconnectTries,
        }
    }

    for attempt := 0; strategy.ShouldReconnect(attempt); {
        rc.setState(StateReconnecting)

        select {
        case <-time.After(strategy.NextDelay(attempt)):
        case <-rc.ctx.Done():
            return
        }

        err := rc.connect()
        if err == nil {
            return // reconnected
        }

        attempt++
        if rc.eventHandler != nil {
            rc.eventHandler.OnError(fmt.Errorf("重连失败 (尝试 %d): %w", attempt, err))
        }
    }

    // The strategy refused further attempts.
    if rc.eventHandler != nil {
        rc.eventHandler.OnError(fmt.Errorf("达到最大重连次数,停止重连"))
    }
}

实用工具和扩展

消息队列

// internal/client/queue.go
package client

import (
    "container/list"
    "sync"
    "time"
)

// QueueItem is one entry in a MessageQueue.
type QueueItem struct {
    // Message is the payload waiting to be sent.
    Message   *Message
    // Timestamp is when the item was enqueued; cleanup drops items older than maxAge.
    Timestamp time.Time
    // Retries counts failed send attempts; Enqueue creates items with 0.
    Retries   int
}

// MessageQueue is a bounded FIFO of pending messages, safe for concurrent use.
// When full, Enqueue evicts the oldest entry. A background goroutine drops
// entries older than maxAge once a minute.
type MessageQueue struct {
    queue   *list.List    // elements are *QueueItem, oldest at the front
    mu      sync.Mutex    // guards queue
    maxSize int           // capacity before Enqueue starts evicting
    maxAge  time.Duration // entries older than this are discarded by cleanup
}

// NewMessageQueue creates an empty queue and starts its cleanup goroutine.
// NOTE(review): cleanup loops over a ticker with no exit path, so that goroutine
// lives for the rest of the process.
func NewMessageQueue(maxSize int, maxAge time.Duration) *MessageQueue {
    mq := new(MessageQueue)
    mq.queue = list.New()
    mq.maxSize, mq.maxAge = maxSize, maxAge

    go mq.cleanup()
    return mq
}

// Enqueue appends msg as a new item stamped with the current time and zero retries.
// If the queue is already at maxSize, the oldest item is evicted first. It always
// returns true.
func (mq *MessageQueue) Enqueue(msg *Message) bool {
    mq.mu.Lock()
    defer mq.mu.Unlock()

    // Drop-oldest policy: make room at the front before appending.
    if mq.queue.Len() >= mq.maxSize {
        if front := mq.queue.Front(); front != nil {
            mq.queue.Remove(front)
        }
    }

    mq.queue.PushBack(&QueueItem{Message: msg, Timestamp: time.Now()})
    return true
}

// Dequeue removes and returns the oldest item, or nil when the queue is empty.
func (mq *MessageQueue) Dequeue() *QueueItem {
    mq.mu.Lock()
    defer mq.mu.Unlock()

    if oldest := mq.queue.Front(); oldest != nil {
        return mq.queue.Remove(oldest).(*QueueItem)
    }
    return nil
}

// Size returns the number of queued items.
func (mq *MessageQueue) Size() int {
    mq.mu.Lock()
    n := mq.queue.Len()
    mq.mu.Unlock()
    return n
}

// cleanup runs forever. Once a minute it removes items from the front of the queue
// while they are older than maxAge. It relies on items being in enqueue-time order,
// so it stops at the first item that has not expired.
func (mq *MessageQueue) cleanup() {
    ticker := time.NewTicker(time.Minute)
    defer ticker.Stop()

    for range ticker.C {
        mq.mu.Lock()
        now := time.Now()
        for e := mq.queue.Front(); e != nil && now.Sub(e.Value.(*QueueItem).Timestamp) > mq.maxAge; e = mq.queue.Front() {
            mq.queue.Remove(e)
        }
        mq.mu.Unlock()
    }
}

// QueuedClient is a Client that buffers messages sent while disconnected and
// flushes them once the connection is back.
type QueuedClient struct {
    *Client
    queue *MessageQueue
}

// NewQueuedClient builds the Client and its MessageQueue, then starts the
// background flusher (processQueue). The flusher stops when the client's
// context is cancelled.
func NewQueuedClient(config *Config, handler EventHandler, queueSize int, maxAge time.Duration) *QueuedClient {
    qc := &QueuedClient{
        Client: NewClient(config, handler),
        queue:  NewMessageQueue(queueSize, maxAge),
    }

    go qc.processQueue()
    return qc
}

// SendMessageQueued sends msg immediately when connected (returning any send
// error). Otherwise it parks msg in the queue for later delivery and returns nil.
func (qc *QueuedClient) SendMessageQueued(msg *Message) error {
    if qc.GetState() != StateConnected {
        qc.queue.Enqueue(msg)
        return nil
    }
    return qc.Client.SendMessage(msg)
}

// processQueue flushes queued messages once a second while the client is
// CONNECTED. It returns when the client's context is cancelled.
func (qc *QueuedClient) processQueue() {
    ticker := time.NewTicker(time.Second)
    defer ticker.Stop()

    for {
        select {
        case <-ticker.C:
            if qc.GetState() == StateConnected {
                qc.flushQueue()
            }
        case <-qc.ctx.Done():
            return
        }
    }
}

// flushQueue makes one pass over the items present when it starts. A failed
// send increments the item's Retries and puts it back unless it has now failed
// 3 times.
//
// Two details matter:
//   - The pass is bounded by the size at entry. Failed items are re-queued, so
//     a persistently failing send (e.g. a full send buffer) would otherwise spin
//     here forever.
//   - The same *QueueItem is re-queued. Enqueue would wrap the message in a
//     fresh item with Retries=0, making the retry limit unreachable. The
//     timestamp is refreshed so the queue stays in time order for cleanup, and
//     the drop-oldest capacity policy of Enqueue is applied.
func (qc *QueuedClient) flushQueue() {
    for n := qc.queue.Size(); n > 0; n-- {
        item := qc.queue.Dequeue()
        if item == nil {
            return
        }

        if err := qc.Client.SendMessage(item.Message); err == nil {
            continue
        }

        item.Retries++
        if item.Retries >= 3 {
            continue // give up on this message
        }

        mq := qc.queue
        mq.mu.Lock()
        if mq.queue.Len() >= mq.maxSize {
            if oldest := mq.queue.Front(); oldest != nil {
                mq.queue.Remove(oldest)
            }
        }
        item.Timestamp = time.Now()
        mq.queue.PushBack(item)
        mq.mu.Unlock()
    }
}

性能监控

// internal/client/metrics.go
package client

import (
    "sync"
    "time"
)

// ClientMetrics accumulates per-client counters and timings.
// All methods are safe for concurrent use.
type ClientMetrics struct {
    mu                    sync.RWMutex
    connectTime           time.Time // start of the current connection; zero when not connected
    totalMessagesSent     int64
    totalMessagesReceived int64
    totalReconnects       int64
    lastMessageTime       time.Time
    connectionDuration    time.Duration // sum of completed connections' lifetimes
    averageLatency        time.Duration // mean of latencyMeasurements
    latencyMeasurements   []time.Duration
}

// NewClientMetrics returns an empty collector with room for the latency window.
func NewClientMetrics() *ClientMetrics {
    return &ClientMetrics{
        latencyMeasurements: make([]time.Duration, 0, 100),
    }
}

// OnConnect marks the start of a connection.
func (m *ClientMetrics) OnConnect() {
    m.mu.Lock()
    defer m.mu.Unlock()
    m.connectTime = time.Now()
}

// OnDisconnect adds the current connection's lifetime to the total.
//
// The start time is cleared afterwards. A second OnDisconnect for the same
// connection (the Client may report one drop from both its read and write
// goroutines) then adds nothing, instead of counting the interval twice.
func (m *ClientMetrics) OnDisconnect() {
    m.mu.Lock()
    defer m.mu.Unlock()
    if m.connectTime.IsZero() {
        return
    }
    m.connectionDuration += time.Since(m.connectTime)
    m.connectTime = time.Time{}
}

// OnReconnect counts one reconnect.
func (m *ClientMetrics) OnReconnect() {
    m.mu.Lock()
    defer m.mu.Unlock()
    m.totalReconnects++
}

// OnMessageSent counts one outgoing message.
func (m *ClientMetrics) OnMessageSent() {
    m.mu.Lock()
    defer m.mu.Unlock()
    m.totalMessagesSent++
}

// OnMessageReceived counts one incoming message and records when it arrived.
func (m *ClientMetrics) OnMessageReceived() {
    m.mu.Lock()
    defer m.mu.Unlock()
    m.totalMessagesReceived++
    m.lastMessageTime = time.Now()
}

// RecordLatency adds a sample and recomputes the average over the most recent 100 samples.
func (m *ClientMetrics) RecordLatency(latency time.Duration) {
    m.mu.Lock()
    defer m.mu.Unlock()

    m.latencyMeasurements = append(m.latencyMeasurements, latency)

    // Keep only the most recent 100 samples.
    if len(m.latencyMeasurements) > 100 {
        m.latencyMeasurements = m.latencyMeasurements[1:]
    }

    var total time.Duration
    for _, l := range m.latencyMeasurements {
        total += l
    }
    m.averageLatency = total / time.Duration(len(m.latencyMeasurements))
}

// GetStats returns a snapshot of the counters.
//
// connection_duration_seconds covers completed connections only; the
// currently open connection is added when OnDisconnect runs.
// last_message_ago_seconds is present only after a message has been received.
func (m *ClientMetrics) GetStats() map[string]interface{} {
    m.mu.RLock()
    defer m.mu.RUnlock()

    stats := make(map[string]interface{})
    stats["total_messages_sent"] = m.totalMessagesSent
    stats["total_messages_received"] = m.totalMessagesReceived
    stats["total_reconnects"] = m.totalReconnects
    stats["connection_duration_seconds"] = m.connectionDuration.Seconds()
    stats["average_latency_ms"] = m.averageLatency.Milliseconds()

    if !m.lastMessageTime.IsZero() {
        stats["last_message_ago_seconds"] = time.Since(m.lastMessageTime).Seconds()
    }

    return stats
}

// MetricsClient is a Client whose events are also recorded in a ClientMetrics.
type MetricsClient struct {
    *Client
    metrics *ClientMetrics
}

// NewMetricsClient creates a Client whose handler is wrapped so that connect,
// disconnect and received-message events update a shared ClientMetrics before
// reaching handler. handler may be nil.
func NewMetricsClient(config *Config, handler EventHandler) *MetricsClient {
    collector := NewClientMetrics()
    return &MetricsClient{
        Client: NewClient(config, &metricsEventHandler{
            EventHandler: handler,
            metrics:      collector,
        }),
        metrics: collector,
    }
}

// GetMetrics returns the live metrics collector (not a copy).
func (mc *MetricsClient) GetMetrics() *ClientMetrics {
    return mc.metrics
}

// metricsEventHandler records metrics for each event and then forwards it to
// the wrapped EventHandler, if one is set.
type metricsEventHandler struct {
    EventHandler
    metrics *ClientMetrics
}

func (h *metricsEventHandler) OnConnect() {
    h.metrics.OnConnect()
    if inner := h.EventHandler; inner != nil {
        inner.OnConnect()
    }
}

func (h *metricsEventHandler) OnDisconnect(err error) {
    h.metrics.OnDisconnect()
    if inner := h.EventHandler; inner != nil {
        inner.OnDisconnect(err)
    }
}

func (h *metricsEventHandler) OnMessage(msg *Message) {
    h.metrics.OnMessageReceived()
    if inner := h.EventHandler; inner != nil {
        inner.OnMessage(msg)
    }
}

// OnError has no metric of its own; it only forwards.
func (h *metricsEventHandler) OnError(err error) {
    if inner := h.EventHandler; inner != nil {
        inner.OnError(err)
    }
}

完整的客户端示例

// cmd/chat-client/main.go
package main

import (
    "bufio"
    "flag"
    "fmt"
    "log"
    "os"
    "strings"
    "time"

    "your-project/internal/client"
)

// Command-line flags. Empty -user and -name values are filled in by main.
var (
    // -url: server endpoint.
    serverURL = flag.String("url", "ws://localhost:8080/ws", "WebSocket 服务器地址")
    // -user: user ID; defaults to "user_<unix seconds>" when empty.
    userID    = flag.String("user", "", "用户 ID")
    // -name: display name; defaults to the user ID when empty.
    username  = flag.String("name", "", "用户名")
)

// main runs the interactive chat client.
//
// Steps:
//   1. parse flags and fill in default user ID / name;
//   2. connect a MetricsClient;
//   3. treat each stdin line as either a /command or a chat message.
//
// Send failures (e.g. while the client is reconnecting) are reported to the user
// instead of being dropped. When stdin ends, the connection is closed gracefully.
func main() {
    flag.Parse()

    if *userID == "" {
        *userID = fmt.Sprintf("user_%d", time.Now().Unix())
    }
    if *username == "" {
        *username = *userID
    }

    config := client.DefaultConfig()
    config.URL = *serverURL

    handler := client.NewChatEventHandler(*userID, *username)
    c := client.NewMetricsClient(config, handler)

    if err := c.Connect(); err != nil {
        log.Fatalf("连接失败: %v", err)
    }

    go startCLI(c)

    fmt.Println("聊天客户端已启动!")
    fmt.Println("命令:")
    fmt.Println("  /join <room>     - 加入房间")
    fmt.Println("  /leave <room>    - 离开房间")
    fmt.Println("  /pm <user> <msg> - 发送私信")
    fmt.Println("  /stats           - 显示统计信息")
    fmt.Println("  /quit            - 退出")
    fmt.Println()

    scanner := bufio.NewScanner(os.Stdin)
    for scanner.Scan() {
        input := strings.TrimSpace(scanner.Text())
        if input == "" {
            continue
        }

        if strings.HasPrefix(input, "/") {
            handleCommand(c, input)
            continue
        }

        if err := c.SendText(input); err != nil {
            fmt.Printf("⚠️  发送失败: %v\n", err)
        }
    }
    if err := scanner.Err(); err != nil {
        log.Printf("读取输入失败: %v", err)
    }

    // stdin ended (Ctrl+D or exhausted piped input): close the connection
    // properly rather than just letting the process exit.
    c.Disconnect()
}

// handleCommand executes one slash command typed by the user.
//
// Errors returned by the client helpers are printed, so the user can tell that,
// for example, a /join was not sent because the client is disconnected.
func handleCommand(c *client.MetricsClient, command string) {
    parts := strings.Fields(command)
    if len(parts) == 0 {
        return
    }

    report := func(err error) {
        if err != nil {
            fmt.Printf("⚠️  命令执行失败: %v\n", err)
        }
    }

    switch parts[0] {
    case "/join":
        if len(parts) < 2 {
            fmt.Println("用法: /join <room>")
            return
        }
        report(c.JoinRoom(parts[1]))

    case "/leave":
        if len(parts) < 2 {
            fmt.Println("用法: /leave <room>")
            return
        }
        report(c.LeaveRoom(parts[1]))

    case "/pm":
        if len(parts) < 3 {
            fmt.Println("用法: /pm <user> <message>")
            return
        }
        message := strings.Join(parts[2:], " ")
        report(c.SendPrivateMessage(parts[1], message))

    case "/stats":
        stats := c.GetMetrics().GetStats()
        fmt.Println("📊 客户端统计:")
        for key, value := range stats {
            fmt.Printf("  %s: %v\n", key, value)
        }

    case "/quit":
        fmt.Println("正在断开连接...")
        report(c.Disconnect())
        os.Exit(0)

    default:
        fmt.Printf("未知命令: %s\n", parts[0])
    }
}

// startCLI is a placeholder for a richer terminal UI. It currently does nothing
// and returns immediately; main's stdin loop handles all input.
func startCLI(c *client.MetricsClient) {
    // A more elaborate CLI could be added here,
    // e.g. one built with termui or another TUI library.
}

小结

本节详细介绍了 WebSocket 客户端的开发:

  1. 基础实现:简单的 WebSocket 客户端连接和消息处理
  2. 高级客户端:状态管理、事件处理、自动重连
  3. 连接管理:连接池、重连策略、消息队列
  4. 扩展功能:性能监控、指标收集、错误处理
  5. 实用工具:命令行界面、配置管理、日志记录

通过这些内容,您可以构建出功能完整、稳定可靠的 WebSocket 客户端应用。在下一节中,我们将通过构建一个完整的实时聊天系统来综合运用所学知识。