4.2.4 WebSocket 服务器

4.2.4 WebSocket 服务器 #

WebSocket 是一种在单个 TCP 连接上进行全双工通信的协议,它使得客户端和服务器之间的数据交换变得更加简单,允许服务端主动向客户端推送数据。本节将详细介绍如何在 Go 语言中开发 WebSocket 服务器。

WebSocket 协议基础 #

WebSocket 握手过程 #

WebSocket 连接始于 HTTP 握手,然后升级为 WebSocket 协议:

客户端                    服务器
   |                        |
   |------ HTTP 握手 ------>|  (Upgrade: websocket)
   |<----- 握手响应 --------|  (101 Switching Protocols)
   |                        |
   |<====== WebSocket =====>|  (全双工通信)
   |                        |

WebSocket 特点 #

  • 全双工通信 - 客户端和服务器可以同时发送数据
  • 低开销、低延迟 - 握手完成后每条消息只需很小的帧头,无需重复发送 HTTP 请求/响应头
  • 持久连接 - 连接保持打开状态
  • 支持二进制和文本数据 - 可以传输各种类型的数据
  • 心跳检测 - 协议定义了 ping/pong 控制帧,可用于发现失效连接;但定期发送 ping 需要由应用自己实现(如下文 writePump 中的定时器)

基础 WebSocket 服务器 #

使用 gorilla/websocket #

首先安装 gorilla/websocket 包:

go mod init websocket-example
go get github.com/gorilla/websocket

简单的 WebSocket 服务器 #

package main

import (
    "fmt"
    "log"
    "net/http"

    "github.com/gorilla/websocket"
)

// upgrader converts an incoming HTTP request into a WebSocket connection,
// using 1 KiB read and write buffers.
//
// SECURITY NOTE(review): CheckOrigin accepts every Origin header, so any web
// page a user visits can open a connection to this endpoint with that user's
// cookies (cross-site WebSocket hijacking). Fine for a local demo; restrict
// the allowed origins before deploying.
var upgrader = websocket.Upgrader{
    // Allow cross-origin connections (demo only).
    CheckOrigin: func(*http.Request) bool {
        return true
    },
    ReadBufferSize:  1024,
    WriteBufferSize: 1024,
}

// wsHandler upgrades the request to a WebSocket and echoes every message it
// receives back to the sender with the same message type. It stops when a
// read or write fails (including a normal close by the peer).
func wsHandler(w http.ResponseWriter, r *http.Request) {
    conn, err := upgrader.Upgrade(w, r, nil)
    if err != nil {
        log.Printf("WebSocket 升级失败: %v", err)
        return
    }
    defer conn.Close()

    peer := conn.RemoteAddr()
    fmt.Printf("新的 WebSocket 连接: %s\n", peer)

    for {
        kind, payload, readErr := conn.ReadMessage()
        if readErr != nil {
            // GoingAway and AbnormalClosure are treated as expected closes
            // (e.g. a browser tab closing) and are not logged.
            if websocket.IsUnexpectedCloseError(readErr, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
                log.Printf("WebSocket 错误: %v", readErr)
            }
            break
        }

        fmt.Printf("收到消息: %s\n", string(payload))

        if writeErr := conn.WriteMessage(kind, payload); writeErr != nil {
            log.Printf("写入消息失败: %v", writeErr)
            break
        }
    }

    fmt.Printf("WebSocket 连接关闭: %s\n", peer)
}

// main serves the echo endpoint at /ws and the files under ./static/ at /
// on port 8080. log.Fatal terminates the process if the server stops.
func main() {
    const addr = ":8080"

    http.HandleFunc("/ws", wsHandler)
    // Serve the test page (static/index.html) and any other static assets.
    http.Handle("/", http.FileServer(http.Dir("./static/")))

    fmt.Println("WebSocket 服务器启动在 " + addr)
    log.Fatal(http.ListenAndServe(addr, nil))
}

客户端测试页面 #

创建 static/index.html 文件:

<!DOCTYPE html>
<html>
  <head>
    <meta charset="utf-8" />
    <title>WebSocket 测试</title>
  </head>
  <body>
    <!-- Minimal test page for the echo server: every sent line comes back. -->
    <div id="messages"></div>
    <input type="text" id="messageInput" placeholder="输入消息..." />
    <button onclick="sendMessage()">发送</button>
    <button onclick="disconnect()">断开连接</button>

    <script>
      // Connect back to the server that served this page (ws:// or wss://).
      // When the file is opened from disk (file://, empty location.host),
      // fall back to the default address of the Go example.
      const wsScheme = location.protocol === "https:" ? "wss://" : "ws://";
      const wsHost = location.host || "localhost:8080";
      let ws = new WebSocket(wsScheme + wsHost + "/ws");
      let messages = document.getElementById("messages");

      ws.onopen = function (event) {
        addMessage("连接已建立");
      };

      ws.onmessage = function (event) {
        addMessage("收到: " + event.data);
      };

      ws.onclose = function (event) {
        addMessage("连接已关闭");
      };

      // The error event carries no useful detail (string concatenation only
      // produced "[object Event]"), so report a fixed message instead.
      ws.onerror = function () {
        addMessage("错误: WebSocket 连接出错");
      };

      function sendMessage() {
        let input = document.getElementById("messageInput");
        if (!input.value) {
          return;
        }
        // send() throws while the socket is still CONNECTING and silently
        // discards data once it is CLOSING/CLOSED, so only send when OPEN.
        if (ws.readyState !== WebSocket.OPEN) {
          addMessage("未连接,无法发送");
          return;
        }
        ws.send(input.value);
        addMessage("发送: " + input.value);
        input.value = "";
      }

      function disconnect() {
        ws.close();
      }

      // Append a timestamped line; textContent keeps echoed data from being
      // interpreted as HTML.
      function addMessage(message) {
        let div = document.createElement("div");
        div.textContent = new Date().toLocaleTimeString() + " - " + message;
        messages.appendChild(div);
        messages.scrollTop = messages.scrollHeight;
      }

      // Send on Enter.
      document
        .getElementById("messageInput")
        .addEventListener("keypress", function (e) {
          if (e.key === "Enter") {
            sendMessage();
          }
        });
    </script>
  </body>
</html>

高级 WebSocket 服务器 #

连接管理器 #

import (
    "encoding/json"
    "sync"
    "time"
)

// MessageType names the kind of a chat Message. It is carried on the wire
// as the JSON "type" field.
type MessageType string

// Message kinds exchanged between the hub and browser clients.
const (
    // MessageTypeText is plain chat text.
    MessageTypeText MessageType = "text"
    // MessageTypeJoin is the system notice that a user joined.
    MessageTypeJoin MessageType = "join"
    // MessageTypeLeave is the system notice that a user left.
    MessageTypeLeave MessageType = "leave"
    // MessageTypeBroadcast is delivered to every connected client.
    MessageTypeBroadcast MessageType = "broadcast"
    // MessageTypePrivate is delivered only to the user named in Message.To.
    MessageTypePrivate MessageType = "private"
)

// Message is the JSON envelope exchanged between browser clients and the Hub.
// encoding/json emits keys in struct field order, so keep the order stable.
type Message struct {
    // Kind of message; one of the MessageType constants.
    Type      MessageType `json:"type"`
    // Sender username, or "system" for notices the hub generates.
    // readPump overwrites it with the connection's username, so clients
    // cannot choose it.
    From      string      `json:"from"`
    // Recipient username for private messages; omitted from JSON when empty.
    To        string      `json:"to,omitempty"`
    // Message text.
    Content   string      `json:"content"`
    // Set server-side when the message is created or received.
    Timestamp time.Time   `json:"timestamp"`
}

// Client is one WebSocket connection registered with a Hub.
type Client struct {
    // Random hex identifier (generateClientID); shown in hub log output.
    ID       string
    // Underlying WebSocket connection, used by readPump and writePump.
    Conn     *websocket.Conn
    // Outgoing message queue drained by writePump. The Hub closes it to make
    // writePump send a close frame and return.
    Send     chan Message
    // Hub this client is registered with.
    Hub      *Hub
    // Display name, taken from the "username" query parameter.
    Username string
}

// Hub tracks the connected clients and fans messages out to them. Run
// consumes the register, unregister and broadcast channels. mutex guards
// clients so HTTP handlers (GetOnlineUsers, GetClientCount) can read it
// while Run modifies it.
type Hub struct {
    // Set of registered clients; values are always true.
    clients    map[*Client]bool
    // Messages read from clients, routed by handleMessage.
    broadcast  chan Message
    // Clients to add (sent by wsHandlerAdvanced).
    register   chan *Client
    // Clients to remove (sent by readPump on exit).
    unregister chan *Client
    // Guards clients.
    mutex      sync.RWMutex
}

// NewHub returns a Hub with no clients and unbuffered broadcast, register
// and unregister channels. Sends on those channels block until Run is
// receiving, so start `go hub.Run()` before handling connections.
func NewHub() *Hub {
    h := new(Hub)
    h.clients = make(map[*Client]bool)
    h.broadcast = make(chan Message)
    h.register = make(chan *Client)
    h.unregister = make(chan *Client)
    return h
}

// Run is the hub's event loop. It never returns, so start it with
// `go hub.Run()`.
//
// On register it adds the client and tells every other client that the
// user joined. On unregister of a client that is still known it removes the
// client, closes its Send channel (which makes writePump return) and tells
// everyone that the user left. Messages from clients go to handleMessage.
func (h *Hub) Run() {
    for {
        select {
        case joined := <-h.register:
            h.mutex.Lock()
            h.clients[joined] = true
            h.mutex.Unlock()

            fmt.Printf("客户端连接: %s (%s)\n", joined.Username, joined.ID)
            h.broadcastToAll(Message{
                Type:      MessageTypeJoin,
                From:      "system",
                Content:   fmt.Sprintf("%s 加入了聊天室", joined.Username),
                Timestamp: time.Now(),
            }, joined)

        case left := <-h.unregister:
            h.mutex.Lock()
            _, known := h.clients[left]
            if known {
                delete(h.clients, left)
                close(left.Send)
            }
            h.mutex.Unlock()

            // broadcastToAll may already have dropped this client and closed
            // its Send channel; closing it again would panic.
            if !known {
                continue
            }

            fmt.Printf("客户端断开: %s (%s)\n", left.Username, left.ID)
            h.broadcastToAll(Message{
                Type:      MessageTypeLeave,
                From:      "system",
                Content:   fmt.Sprintf("%s 离开了聊天室", left.Username),
                Timestamp: time.Now(),
            }, nil)

        case msg := <-h.broadcast:
            h.handleMessage(msg)
        }
    }
}

// handleMessage routes a message that a client sent to the hub (via
// h.broadcast). Private messages go only to the user named in To; every
// other message is broadcast to all clients, including the sender.
//
// "join" and "leave" are reserved for the notices Run generates itself. The
// browser client renders those types as system notices, so a client that
// sent them could impersonate the system. They are downgraded to ordinary
// broadcasts.
func (h *Hub) handleMessage(message Message) {
    switch message.Type {
    case MessageTypePrivate:
        h.sendPrivateMessage(message)
    case MessageTypeJoin, MessageTypeLeave:
        message.Type = MessageTypeBroadcast
        h.broadcastToAll(message, nil)
    default:
        h.broadcastToAll(message, nil)
    }
}

// broadcastToAll queues message on the Send channel of every registered
// client except exclude (pass nil to include everyone).
//
// Sends are non-blocking. A client whose Send buffer is full is treated as
// stalled: it is removed from the hub and its Send channel is closed, which
// makes its writePump close the connection.
//
// This takes the write lock, not RLock, because stalled clients are deleted
// from h.clients here. Mutating the map under a read lock races with
// concurrent readers such as GetOnlineUsers and GetClientCount.
func (h *Hub) broadcastToAll(message Message, exclude *Client) {
    h.mutex.Lock()
    defer h.mutex.Unlock()

    for client := range h.clients {
        if client == exclude {
            continue
        }
        select {
        case client.Send <- message:
        default:
            // Deleting the current key while ranging over a map is safe in Go.
            delete(h.clients, client)
            close(client.Send)
        }
    }
}

// sendPrivateMessage delivers message to one registered client whose
// Username equals message.To. The first match in map iteration order wins,
// since nothing in this file enforces unique usernames.
//
// Delivery is best-effort: if no client matches, or the matching client's
// Send buffer is full, the message is dropped without notice.
func (h *Hub) sendPrivateMessage(message Message) {
    h.mutex.RLock()
    defer h.mutex.RUnlock()

    for c := range h.clients {
        if c.Username != message.To {
            continue
        }
        select {
        case c.Send <- message:
        default:
            // Recipient is stalled; drop rather than block the hub.
        }
        return
    }
}

// GetOnlineUsers returns the usernames of all registered clients in
// unspecified (map iteration) order. The slice is never nil, so it encodes
// as [] rather than null in JSON.
func (h *Hub) GetOnlineUsers() []string {
    h.mutex.RLock()
    defer h.mutex.RUnlock()

    names := make([]string, len(h.clients))
    i := 0
    for c := range h.clients {
        names[i] = c.Username
        i++
    }
    return names
}

// GetClientCount reports how many clients are currently registered.
func (h *Hub) GetClientCount() int {
    h.mutex.RLock()
    n := len(h.clients)
    h.mutex.RUnlock()
    return n
}

客户端处理 #

import (
    "crypto/rand"
    "encoding/hex"
    "time"
)

// readPump reads JSON messages from the connection and forwards them to the
// hub until a read fails (including a normal close). Each message is
// stamped with this client's username and the server time, so clients
// cannot choose From or Timestamp.
//
// The read deadline starts at pongWait and is pushed out again whenever a
// pong arrives. writePump pings every 54s, so a peer that stops answering
// makes the read fail. On exit the client is unregistered and the
// connection closed.
func (c *Client) readPump() {
    const (
        maxMessageSize = 512              // bytes; see Conn.SetReadLimit
        pongWait       = 60 * time.Second // must exceed writePump's ping period
    )

    defer func() {
        c.Hub.unregister <- c
        c.Conn.Close()
    }()

    c.Conn.SetReadLimit(maxMessageSize)
    c.Conn.SetReadDeadline(time.Now().Add(pongWait))
    c.Conn.SetPongHandler(func(string) error {
        c.Conn.SetReadDeadline(time.Now().Add(pongWait))
        return nil
    })

    for {
        var message Message
        if err := c.Conn.ReadJSON(&message); err != nil {
            // Use log (as wsHandler does) so the entry is newline-terminated
            // and timestamped.
            if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
                log.Printf("WebSocket 错误: %v", err)
            }
            return
        }

        message.From = c.Username
        message.Timestamp = time.Now()

        c.Hub.broadcast <- message
    }
}

// writePump forwards messages from c.Send to the connection as JSON and
// sends a ping every pingPeriod so readPump's read deadline keeps getting
// extended by the peer's pongs.
//
// When the hub closes c.Send, writePump sends a close frame and returns.
// Any write error also ends the loop. The connection is closed on exit.
func (c *Client) writePump() {
    const (
        writeWait  = 10 * time.Second // deadline for each individual write
        pingPeriod = 54 * time.Second // must be shorter than readPump's 60s pong wait
    )

    ticker := time.NewTicker(pingPeriod)
    defer func() {
        ticker.Stop()
        c.Conn.Close()
    }()

    for {
        select {
        case message, ok := <-c.Send:
            c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
            if !ok {
                // The hub closed the channel. Best-effort close frame; we
                // return either way.
                c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
                return
            }
            if err := c.Conn.WriteJSON(message); err != nil {
                log.Printf("写入消息失败: %v", err)
                return
            }

        case <-ticker.C:
            c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
            if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
                return
            }
        }
    }
}

// generateClientID returns a random 128-bit identifier encoded as 32
// lowercase hex characters.
//
// crypto/rand.Read can return an error on older Go versions. Ignoring it
// would leave the buffer zeroed and give every client the same ID. In that
// case this falls back to a time-based ID of the same length, which is
// weaker but still usually distinct between calls.
func generateClientID() string {
    buf := make([]byte, 16)
    if _, err := rand.Read(buf); err != nil {
        return fmt.Sprintf("%032x", time.Now().UnixNano())
    }
    return hex.EncodeToString(buf)
}

// wsHandlerAdvanced upgrades the request, registers a new Client with hub
// and starts the client's write and read goroutines, then returns.
//
// The display name comes from the "username" query parameter; an empty or
// missing value becomes the default anonymous name.
func wsHandlerAdvanced(hub *Hub, w http.ResponseWriter, r *http.Request) {
    conn, err := upgrader.Upgrade(w, r, nil)
    if err != nil {
        log.Printf("WebSocket 升级失败: %v", err)
        return
    }

    name := r.URL.Query().Get("username")
    if name == "" {
        name = "匿名用户"
    }

    c := &Client{
        ID:       generateClientID(),
        Conn:     conn,
        Hub:      hub,
        Username: name,
        // Buffered so a burst of broadcasts does not immediately count as a
        // stalled client in broadcastToAll.
        Send: make(chan Message, 256),
    }
    hub.register <- c

    // From here on the two pumps own the connection.
    go c.writePump()
    go c.readPump()
}

聊天室实现 #

完整的聊天室服务器 #

// main runs the chat server on :8080. It serves:
//   - /ws: the WebSocket endpoint,
//   - /api/users: JSON list of online usernames and their count,
//   - /api/stats: JSON with the online client count and server time,
//   - /: static files from ./static/.
func main() {
    hub := NewHub()
    go hub.Run()

    // writeJSON sends payload as a JSON response. Encode errors (typically
    // a client that went away) are ignored, as before.
    writeJSON := func(w http.ResponseWriter, payload map[string]interface{}) {
        w.Header().Set("Content-Type", "application/json")
        json.NewEncoder(w).Encode(payload)
    }

    http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
        wsHandlerAdvanced(hub, w, r)
    })

    http.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
        online := hub.GetOnlineUsers()
        writeJSON(w, map[string]interface{}{
            "users": online,
            "count": len(online),
        })
    })

    http.HandleFunc("/api/stats", func(w http.ResponseWriter, r *http.Request) {
        writeJSON(w, map[string]interface{}{
            "online_users": hub.GetClientCount(),
            "server_time":  time.Now().Format(time.RFC3339),
        })
    })

    http.Handle("/", http.FileServer(http.Dir("./static/")))

    fmt.Println("聊天室服务器启动在 :8080")
    log.Fatal(http.ListenAndServe(":8080", nil))
}

高级聊天室客户端 #

创建 static/chat.html 文件:

<!DOCTYPE html>
<html>
  <head>
    <meta charset="utf-8" />
    <title>WebSocket 聊天室</title>
    <style>
      body {
        font-family: Arial, sans-serif;
        margin: 20px;
      }
      #chatContainer {
        display: flex;
        height: 500px;
      }
      #messages {
        flex: 1;
        border: 1px solid #ccc;
        padding: 10px;
        overflow-y: auto;
      }
      #userList {
        width: 200px;
        border: 1px solid #ccc;
        padding: 10px;
        margin-left: 10px;
      }
      #inputContainer {
        margin-top: 10px;
      }
      #messageInput {
        width: 70%;
        padding: 5px;
      }
      #sendButton {
        padding: 5px 10px;
      }
      .message {
        margin: 5px 0;
      }
      .system {
        color: #666;
        font-style: italic;
      }
      .private {
        color: #007bff;
      }
      .timestamp {
        color: #999;
        font-size: 0.8em;
      }
    </style>
  </head>
  <body>
    <h1>WebSocket 聊天室</h1>

    <div id="loginContainer">
      <input type="text" id="usernameInput" placeholder="输入用户名" />
      <button onclick="connect()">连接</button>
    </div>

    <div id="chatContainer" style="display: none;">
      <div id="messages"></div>
      <div id="userList">
        <h3>在线用户</h3>
        <div id="users"></div>
      </div>
    </div>

    <div id="inputContainer" style="display: none;">
      <input type="text" id="messageInput" placeholder="输入消息..." />
      <button id="sendButton" onclick="sendMessage()">发送</button>
      <button onclick="disconnect()">断开连接</button>
    </div>

    <script>
      let ws = null;
      let username = "";

      // Build the WebSocket URL from the page's own origin, so the client
      // works on any host/port and over HTTPS (wss://). Fall back to the demo
      // address when the page is opened from disk.
      function buildSocketURL(name) {
        const scheme = location.protocol === "https:" ? "wss://" : "ws://";
        const host = location.host || "localhost:8080";
        return `${scheme}${host}/ws?username=${encodeURIComponent(name)}`;
      }

      function connect() {
        username = document.getElementById("usernameInput").value.trim();
        if (!username) {
          alert("请输入用户名");
          return;
        }

        ws = new WebSocket(buildSocketURL(username));

        ws.onopen = function (event) {
          document.getElementById("loginContainer").style.display = "none";
          document.getElementById("chatContainer").style.display = "flex";
          document.getElementById("inputContainer").style.display = "block";
          addMessage("系统", "连接成功", "system");
          loadOnlineUsers();
        };

        ws.onmessage = function (event) {
          const message = JSON.parse(event.data);
          handleMessage(message);
        };

        ws.onclose = function (event) {
          addMessage("系统", "连接已关闭", "system");
          document.getElementById("loginContainer").style.display = "block";
          document.getElementById("chatContainer").style.display = "none";
          document.getElementById("inputContainer").style.display = "none";
        };

        ws.onerror = function (error) {
          addMessage("系统", "连接错误", "system");
        };
      }

      function disconnect() {
        if (ws) {
          ws.close();
        }
      }

      function sendMessage() {
        const input = document.getElementById("messageInput");
        const content = input.value.trim();
        if (!content) return;
        // send() throws while CONNECTING and silently drops data once the
        // socket is closing/closed, so only send on an open connection.
        if (!ws || ws.readyState !== WebSocket.OPEN) return;

        const message = {
          type: "broadcast",
          content: content,
        };

        ws.send(JSON.stringify(message));
        input.value = "";
      }

      // Render a server message according to its type; unknown types are ignored.
      function handleMessage(message) {
        switch (message.type) {
          case "join":
          case "leave":
            addMessage("系统", message.content, "system");
            loadOnlineUsers();
            break;
          case "broadcast":
          case "text":
            addMessage(message.from, message.content);
            break;
          case "private":
            addMessage(`${message.from} (私聊)`, message.content, "private");
            break;
        }
      }

      function addMessage(from, content, className = "") {
        const messages = document.getElementById("messages");
        const div = document.createElement("div");
        div.className = "message " + className;

        // Build the line from DOM nodes and text. Usernames and message text
        // come from other users; inserting them via innerHTML would let
        // anyone inject HTML/JavaScript into every participant's page.
        const time = document.createElement("span");
        time.className = "timestamp";
        time.textContent = `[${new Date().toLocaleTimeString()}]`;

        const sender = document.createElement("strong");
        sender.textContent = `${from}:`;

        div.append(time, " ", sender, " ", content);

        messages.appendChild(div);
        messages.scrollTop = messages.scrollHeight;
      }

      function loadOnlineUsers() {
        fetch("/api/users")
          .then((response) => response.json())
          .then((data) => {
            const usersDiv = document.getElementById("users");
            usersDiv.innerHTML = "";
            data.users.forEach((user) => {
              const userDiv = document.createElement("div");
              userDiv.textContent = user;
              if (user === username) {
                userDiv.style.fontWeight = "bold";
              }
              usersDiv.appendChild(userDiv);
            });
          })
          .catch((error) => console.error("加载用户列表失败:", error));
      }

      // Send on Enter.
      document
        .getElementById("messageInput")
        .addEventListener("keypress", function (e) {
          if (e.key === "Enter") {
            sendMessage();
          }
        });

      // Connect on Enter in the username field.
      document
        .getElementById("usernameInput")
        .addEventListener("keypress", function (e) {
          if (e.key === "Enter") {
            connect();
          }
        });
    </script>
  </body>
</html>

房间管理系统 #

多房间聊天实现 #

// Room is a named group of clients; JoinRoom and LeaveRoom announce
// membership changes to the room's members via Room.broadcast.
type Room struct {
    // Identifier used as the key in RoomManager.rooms.
    ID      string
    // Human-readable name.
    Name    string
    // Set of members; values are always true. Guarded by mutex.
    Clients map[*Client]bool
    // Guards Clients. Room.broadcast expects the caller to hold it.
    mutex   sync.RWMutex
}

// RoomManager owns the id -> *Room registry. Its mutex guards only the
// rooms map; each Room's membership is guarded by that Room's own mutex.
type RoomManager struct {
    rooms map[string]*Room
    mutex sync.RWMutex
}

// NewRoomManager returns a RoomManager with no rooms.
func NewRoomManager() *RoomManager {
    rm := new(RoomManager)
    rm.rooms = make(map[string]*Room)
    return rm
}

// CreateRoom registers a new, empty room under id and returns it.
//
// An existing room with the same id is silently replaced. Its members keep
// their pointer to the old *Room, which is no longer reachable through rm.
func (rm *RoomManager) CreateRoom(id, name string) *Room {
    created := &Room{ID: id, Name: name, Clients: make(map[*Client]bool)}

    rm.mutex.Lock()
    rm.rooms[id] = created
    rm.mutex.Unlock()

    return created
}

// GetRoom looks up the room registered under id. The bool is false, and
// the *Room nil, when no such room exists.
func (rm *RoomManager) GetRoom(id string) (*Room, bool) {
    rm.mutex.RLock()
    defer rm.mutex.RUnlock()

    if r, found := rm.rooms[id]; found {
        return r, true
    }
    return nil, false
}

// JoinRoom adds client to the room registered under roomID, records it in
// client.Room and notifies the room's other members. It returns an error if
// no such room exists.
//
// A Client tracks a single Room, so a client already in a different room
// leaves it first; otherwise it would stay a member of the old room and
// keep receiving its broadcasts. Joining the room the client is already in
// is a no-op, which avoids a duplicate "joined" notice.
func (rm *RoomManager) JoinRoom(client *Client, roomID string) error {
    room, exists := rm.GetRoom(roomID)
    if !exists {
        return fmt.Errorf("房间不存在: %s", roomID)
    }
    if client.Room == room {
        return nil
    }
    if client.Room != nil {
        // Leave before taking this room's lock, so two room locks are never
        // held at the same time.
        rm.LeaveRoom(client)
    }

    room.mutex.Lock()
    defer room.mutex.Unlock()

    room.Clients[client] = true
    client.Room = room

    // Tell everyone else in the room; the joining client is excluded.
    message := Message{
        Type:      MessageTypeJoin,
        From:      "system",
        Content:   fmt.Sprintf("%s 加入了房间", client.Username),
        Timestamp: time.Now(),
    }
    room.broadcast(message, client)

    return nil
}

// LeaveRoom removes client from its current room (no-op if it has none),
// clears client.Room and tells the remaining members that the user left.
//
// NOTE(review): client.Room is read before any lock is taken, so concurrent
// JoinRoom/LeaveRoom calls for the same client race.
func (rm *RoomManager) LeaveRoom(client *Client) {
    room := client.Room
    if room == nil {
        return
    }

    room.mutex.Lock()
    defer room.mutex.Unlock()

    delete(room.Clients, client)
    client.Room = nil

    notice := Message{
        Type:      MessageTypeLeave,
        From:      "system",
        Content:   fmt.Sprintf("%s 离开了房间", client.Username),
        Timestamp: time.Now(),
    }
    room.broadcast(notice, nil)
}

// broadcast queues message for every member of r except exclude (nil
// includes everyone). It reads r.Clients without locking; the caller must
// hold r.mutex, as JoinRoom and LeaveRoom do.
//
// Sends are non-blocking. A member whose Send buffer is full simply misses
// this message. The Send channel is deliberately NOT closed here: the Hub
// also closes that same channel when it unregisters or evicts the client,
// and closing a channel twice panics.
//
// NOTE(review): sending on a channel the Hub has already closed also
// panics. Members should leave their room (LeaveRoom) before the Hub
// closes their Send channel.
func (r *Room) broadcast(message Message, exclude *Client) {
    for member := range r.Clients {
        if member == exclude {
            continue
        }
        select {
        case member.Send <- message:
        default:
            // Stalled member: drop this message for it.
        }
    }
}

// GetClientCount reports the current number of members in r.
func (r *Room) GetClientCount() int {
    r.mutex.RLock()
    count := len(r.Clients)
    r.mutex.RUnlock()
    return count
}

// Client (extended for the room system) replaces the earlier Client
// declaration in this tutorial: it has the same fields plus Room.
type Client struct {
    // Random hex identifier (generateClientID).
    ID       string
    // Underlying WebSocket connection.
    Conn     *websocket.Conn
    // Outgoing message queue drained by writePump; the Hub closes it.
    Send     chan Message
    // Hub the client is registered with.
    Hub      *Hub
    // Room the client is currently in, or nil; maintained by
    // RoomManager.JoinRoom/LeaveRoom.
    Room     *Room
    // Display name.
    Username string
}

性能优化 #

连接池管理 #

// ConnectionPool caps the number of simultaneously admitted clients.
// Clients that arrive once the cap is reached are parked in waitingQueue
// (bounded by its buffer) and admitted later by Release.
type ConnectionPool struct {
    // Upper bound on admitted clients.
    maxConnections int
    // Currently admitted clients; guarded by mutex.
    activeConns    int
    // Clients waiting for a free slot.
    waitingQueue   chan *Client
    // Guards activeConns (held by TryAccept and Release).
    mutex          sync.Mutex
}

// NewConnectionPool returns a pool that admits up to maxConnections clients
// at once and can queue up to 100 more.
func NewConnectionPool(maxConnections int) *ConnectionPool {
    const queueCapacity = 100

    cp := &ConnectionPool{maxConnections: maxConnections}
    cp.waitingQueue = make(chan *Client, queueCapacity)
    return cp
}

// TryAccept admits client and returns true if fewer than maxConnections
// clients are active. Otherwise it tries to park client in the waiting
// queue (Release admits queued clients later) and returns false.
//
// NOTE(review): false means either "queued" or "rejected because the queue
// is full", and callers cannot tell which. A caller that closes the
// connection on false would leave a dead client in the queue. Consider
// returning a distinct status.
func (cp *ConnectionPool) TryAccept(client *Client) bool {
    cp.mutex.Lock()
    defer cp.mutex.Unlock()

    if cp.activeConns >= cp.maxConnections {
        // Best-effort enqueue; a full queue means the client is dropped.
        select {
        case cp.waitingQueue <- client:
        default:
        }
        return false
    }

    cp.activeConns++
    return true
}

// Release frees one admitted slot. If a client is waiting, it immediately
// takes the freed slot and is registered with its hub from a new goroutine,
// so Release itself never waits on the hub.
//
// The counter is clamped at zero. An unmatched Release (for example after
// TryAccept returned false) would otherwise drive activeConns negative and
// let the pool admit more than maxConnections clients.
func (cp *ConnectionPool) Release() {
    cp.mutex.Lock()
    defer cp.mutex.Unlock()

    if cp.activeConns > 0 {
        cp.activeConns--
    }

    // Hand the freed slot to the next waiting client, if any.
    select {
    case next := <-cp.waitingQueue:
        cp.activeConns++
        go func() {
            next.Hub.register <- next
        }()
    default:
        // Nobody is waiting.
    }
}

消息压缩 #

import (
    "compress/flate"
)

// upgraderWithCompression matches upgrader (1 KiB buffers, any origin) but
// also attempts to negotiate per-message compression during the handshake.
//
// SECURITY NOTE(review): CheckOrigin accepts every Origin header, so any web
// page can open a connection here with the visitor's cookies. Restrict it
// outside local demos.
var upgraderWithCompression = websocket.Upgrader{
    EnableCompression: true,
    ReadBufferSize:    1024,
    WriteBufferSize:   1024,
    CheckOrigin:       func(*http.Request) bool { return true },
}

// wsHandlerWithCompression upgrades the request with
// upgraderWithCompression and configures outgoing messages to be compressed
// at flate.BestSpeed (favouring CPU time over compression ratio).
//
// The message loop is left out of this example. The connection is closed
// when the handler returns, so the stub does not leak sockets. If the
// handling is moved into goroutines (as wsHandlerAdvanced does), remove the
// defer and let those goroutines close conn.
func wsHandlerWithCompression(w http.ResponseWriter, r *http.Request) {
    conn, err := upgraderWithCompression.Upgrade(w, r, nil)
    if err != nil {
        log.Printf("WebSocket 升级失败: %v", err)
        return
    }
    defer conn.Close()

    conn.EnableWriteCompression(true)
    if err := conn.SetCompressionLevel(flate.BestSpeed); err != nil {
        log.Printf("设置压缩级别失败: %v", err)
    }

    // Handle the connection here...
}

内存优化 #

import (
    "sync"
)

// messagePool recycles *Message values between reads; New supplies a zero
// Message when the pool is empty.
//
// NOTE(review): Message is small, so sync.Pool overhead may outweigh the
// saved allocation. Benchmark before relying on it.
var messagePool = sync.Pool{
    New: func() interface{} { return new(Message) },
}

// getMessageFromPool takes a *Message from messagePool (a fresh zero value
// when the pool is empty). Hand it back with putMessageToPool when done.
func getMessageFromPool() *Message {
    pooled := messagePool.Get()
    return pooled.(*Message)
}

// putMessageToPool clears every field of msg and returns it to messagePool,
// so the next getMessageFromPool caller sees no data from the previous use.
// msg must not be used after this call.
//
// NOTE(review): `*msg = Message{}` would also reset any fields added to
// Message later.
func putMessageToPool(msg *Message) {
    msg.Type, msg.Content = "", ""
    msg.From, msg.To = "", ""
    msg.Timestamp = time.Time{}

    messagePool.Put(msg)
}

// readPumpOptimized is readPump using pooled Message buffers. Each frame is
// decoded into a *Message from messagePool and stamped with the sender and
// time. A copy goes onto the hub's broadcast channel, then the buffer is
// returned to the pool.
//
// It applies the same read limit, read deadline and pong handler as
// readPump. Without them a client could send unbounded frames, and a peer
// that silently disappeared would block the read indefinitely instead of
// timing out.
func (c *Client) readPumpOptimized() {
    const (
        maxMessageSize = 512              // bytes; see Conn.SetReadLimit
        pongWait       = 60 * time.Second // must exceed writePump's ping period
    )

    defer func() {
        c.Hub.unregister <- c
        c.Conn.Close()
    }()

    c.Conn.SetReadLimit(maxMessageSize)
    c.Conn.SetReadDeadline(time.Now().Add(pongWait))
    c.Conn.SetPongHandler(func(string) error {
        c.Conn.SetReadDeadline(time.Now().Add(pongWait))
        return nil
    })

    for {
        message := getMessageFromPool()
        if err := c.Conn.ReadJSON(message); err != nil {
            putMessageToPool(message)
            if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
                log.Printf("WebSocket 错误: %v", err)
            }
            return
        }

        message.From = c.Username
        message.Timestamp = time.Now()
        // The channel receives a copy, so the pooled value can be recycled
        // right away.
        c.Hub.broadcast <- *message

        putMessageToPool(message)
    }
}

监控和调试 #

WebSocket 监控 #

// WSMonitor keeps running WebSocket counters. Every field is guarded by
// mutex, so its methods may be called from many goroutines.
type WSMonitor struct {
    totalConnections  int64
    activeConnections int64
    messagesSent      int64
    messagesReceived  int64
    errors            int64
    mutex             sync.RWMutex
}

// update runs fn while holding the write lock.
func (m *WSMonitor) update(fn func()) {
    m.mutex.Lock()
    defer m.mutex.Unlock()
    fn()
}

// RecordConnection counts a newly opened connection (total and active).
func (m *WSMonitor) RecordConnection() {
    m.update(func() {
        m.totalConnections++
        m.activeConnections++
    })
}

// RecordDisconnection decrements the active-connection count. It is not
// clamped, so unmatched calls make it negative.
func (m *WSMonitor) RecordDisconnection() {
    m.update(func() { m.activeConnections-- })
}

// RecordMessageSent counts one outgoing message.
func (m *WSMonitor) RecordMessageSent() {
    m.update(func() { m.messagesSent++ })
}

// RecordMessageReceived counts one incoming message.
func (m *WSMonitor) RecordMessageReceived() {
    m.update(func() { m.messagesReceived++ })
}

// RecordError counts one error.
func (m *WSMonitor) RecordError() {
    m.update(func() { m.errors++ })
}

// GetStats returns a snapshot of all counters keyed by metric name.
func (m *WSMonitor) GetStats() map[string]int64 {
    m.mutex.RLock()
    defer m.mutex.RUnlock()

    stats := make(map[string]int64, 5)
    stats["total_connections"] = m.totalConnections
    stats["active_connections"] = m.activeConnections
    stats["messages_sent"] = m.messagesSent
    stats["messages_received"] = m.messagesReceived
    stats["errors"] = m.errors
    return stats
}

// wsMonitor is the process-wide monitor instance.
var wsMonitor = new(WSMonitor)

// monitorHandler serves the current wsMonitor counters as a JSON object.
func monitorHandler(w http.ResponseWriter, r *http.Request) {
    w.Header().Set("Content-Type", "application/json")
    // The Encode error is ignored; there is no useful way to report it to
    // the client at this point.
    json.NewEncoder(w).Encode(wsMonitor.GetStats())
}

健康检查 #

// healthCheckHandler reports service health as JSON.
//
// status is "healthy", or "warning" (plus a message) once more than 1000
// WebSocket connections are active. The response always uses HTTP 200
// because no other status code is written.
func healthCheckHandler(w http.ResponseWriter, r *http.Request) {
    const maxHealthyConnections = 1000

    stats := wsMonitor.GetStats()
    active := stats["active_connections"]

    health := map[string]interface{}{
        "status":    "healthy",
        "timestamp": time.Now().Format(time.RFC3339),
        "websocket": map[string]interface{}{
            "active_connections": active,
            "total_connections":  stats["total_connections"],
        },
    }
    if active > maxHealthyConnections {
        health["status"] = "warning"
        health["message"] = "高连接数"
    }

    w.Header().Set("Content-Type", "application/json")
    json.NewEncoder(w).Encode(health)
}

小结 #

本节详细介绍了 Go 语言中的 WebSocket 服务器开发,包括:

  1. WebSocket 基础 - 协议特点和握手过程
  2. 基础服务器 - 简单的 WebSocket 服务器实现
  3. 高级服务器 - 连接管理、消息处理和广播机制
  4. 聊天室实现 - 完整的实时聊天应用
  5. 房间管理 - 多房间聊天系统
  6. 性能优化 - 连接池、压缩、内存优化等技术
  7. 监控调试 - 连接监控和健康检查

掌握这些 WebSocket 开发技术后,你就能够构建高性能的实时通信应用。至此,我们完成了网络编程基础章节的学习,为后续的高级网络编程和系统编程打下了坚实的基础。