3.8.2 WebSocket 服务端实现

3.8.2 WebSocket 服务端实现 #

WebSocket 服务端是实时通信系统的核心,负责处理客户端连接、消息路由、状态管理等关键功能。本节将详细介绍如何使用 Go 语言构建高性能、可扩展的 WebSocket 服务端。

基础服务端实现 #

使用 Gorilla WebSocket #

Gorilla WebSocket 是 Go 语言中使用最广泛的 WebSocket 库之一,提供了完整的 WebSocket 协议(RFC 6455)实现。

# The module path must match the import prefix used by the packages in the
# later sections, which import "your-project/internal/...".
go mod init your-project
go get github.com/gorilla/websocket
# Extra dependencies used by the later sections: client IDs (uuid),
# Prometheus metrics and the PostgreSQL driver.
go get github.com/google/uuid github.com/prometheus/client_golang github.com/lib/pq

简单的 WebSocket 服务器 #

// main.go
package main

import (
    "fmt"
    "log"
    "net/http"

    "github.com/gorilla/websocket"
)

// upgrader turns plain HTTP requests into WebSocket connections with 1 KiB
// read/write buffers. CheckOrigin accepts every Origin so the demo page works
// from anywhere; outside local experiments this allows cross-site WebSocket
// hijacking, so restrict it in production.
var upgrader = websocket.Upgrader{
    ReadBufferSize:  1024,
    WriteBufferSize: 1024,
    CheckOrigin:     func(*http.Request) bool { return true },
}

// handleWebSocket upgrades the request and echoes every frame back to the
// sender, with its original message type, until a read or write fails.
func handleWebSocket(w http.ResponseWriter, r *http.Request) {
    ws, err := upgrader.Upgrade(w, r, nil)
    if err != nil {
        log.Printf("WebSocket 升级失败: %v", err)
        return
    }
    peer := ws.RemoteAddr()
    // Deferred calls run last-in-first-out: the "closed" line is logged
    // first, then the connection is closed.
    defer ws.Close()
    defer log.Printf("WebSocket 连接关闭: %s", peer)

    log.Printf("新的 WebSocket 连接: %s", peer)

    for {
        kind, payload, readErr := ws.ReadMessage()
        if readErr != nil {
            // Normal "going away"/abnormal closures are expected; log the rest.
            if websocket.IsUnexpectedCloseError(readErr, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
                log.Printf("WebSocket 错误: %v", readErr)
            }
            return
        }

        log.Printf("收到消息 [%d]: %s", kind, string(payload))

        if writeErr := ws.WriteMessage(kind, payload); writeErr != nil {
            log.Printf("发送消息失败: %v", writeErr)
            return
        }
    }
}

// main serves ./static at "/" and the echo endpoint at "/ws" on port 8080
// using the default ServeMux.
func main() {
    const listenAddr = ":8080"

    mux := http.DefaultServeMux
    mux.HandleFunc("/ws", handleWebSocket)
    mux.Handle("/", http.FileServer(http.Dir("./static/")))

    fmt.Println("WebSocket 服务器启动在 :8080")
    log.Fatal(http.ListenAndServe(listenAddr, nil))
}

客户端测试页面 #

创建 static/index.html 文件:

<!DOCTYPE html>
<html lang="zh-CN">
  <head>
    <meta charset="utf-8" />
    <title>WebSocket 测试</title>
  </head>
  <body>
    <div>
      <h2>WebSocket 测试客户端</h2>
      <div>
        <input
          type="text"
          id="messageInput"
          placeholder="输入消息..."
          style="width: 300px;"
        />
        <button onclick="sendMessage()">发送</button>
        <button onclick="connect()">连接</button>
        <button onclick="disconnect()">断开</button>
      </div>
      <div>
        <h3>连接状态: <span id="status">未连接</span></h3>
        <h3>消息记录:</h3>
        <div
          id="messages"
          style="border: 1px solid #ccc; height: 300px; overflow-y: scroll; padding: 10px;"
        ></div>
      </div>
    </div>

    <script>
      let ws = null;
      const statusElement = document.getElementById("status");
      const messagesElement = document.getElementById("messages");
      const messageInput = document.getElementById("messageInput");

      // Build the endpoint from the page's own origin so the page also works
      // behind a real host name or over HTTPS (wss://). A page opened straight
      // from disk (file://) has no usable host, so fall back to the local
      // development server.
      function endpointURL() {
        if (location.protocol === "https:") {
          return "wss://" + location.host + "/ws";
        }
        if (location.protocol === "http:") {
          return "ws://" + location.host + "/ws";
        }
        return "ws://localhost:8080/ws";
      }

      function setStatus(text, color) {
        statusElement.textContent = text;
        statusElement.style.color = color;
      }

      function connect() {
        // Also refuse while a handshake is still in progress; otherwise a
        // second click opens a second socket and orphans the first one.
        if (
          ws &&
          (ws.readyState === WebSocket.OPEN ||
            ws.readyState === WebSocket.CONNECTING)
        ) {
          addMessage("已经连接");
          return;
        }

        const socket = new WebSocket(endpointURL());
        ws = socket;

        // Every handler ignores events from a socket that has since been
        // replaced, so a late "close" from an old connection cannot overwrite
        // the status of the current one.
        socket.onopen = function () {
          if (ws !== socket) return;
          setStatus("已连接", "green");
          addMessage("WebSocket 连接已建立");
        };

        socket.onmessage = function (event) {
          if (ws !== socket) return;
          addMessage("收到: " + event.data);
        };

        socket.onclose = function () {
          if (ws !== socket) return;
          setStatus("已断开", "red");
          addMessage("WebSocket 连接已关闭");
        };

        socket.onerror = function (event) {
          if (ws !== socket) return;
          setStatus("错误", "red");
          // The error event is a plain Event without a message; concatenating
          // the object itself would print "[object Event]".
          addMessage("WebSocket 错误: " + event.type);
        };
      }

      function disconnect() {
        if (ws) {
          ws.close();
        }
      }

      function sendMessage() {
        if (!ws || ws.readyState !== WebSocket.OPEN) {
          addMessage("请先连接 WebSocket");
          return;
        }

        const message = messageInput.value.trim();
        if (message) {
          ws.send(message);
          addMessage("发送: " + message);
          messageInput.value = "";
        }
      }

      // Append a timestamped line to the log and keep it scrolled to the end.
      // textContent (not innerHTML) keeps server data from being parsed as HTML.
      function addMessage(message) {
        const div = document.createElement("div");
        div.textContent = new Date().toLocaleTimeString() + " - " + message;
        messagesElement.appendChild(div);
        messagesElement.scrollTop = messagesElement.scrollHeight;
      }

      // Enter sends the message. keydown replaces the deprecated keypress
      // event; the isComposing check skips the Enter that confirms an IME
      // composition (e.g. Chinese input), which must not send half-typed text.
      messageInput.addEventListener("keydown", function (e) {
        if (e.key === "Enter" && !e.isComposing) {
          sendMessage();
        }
      });

      // Connect automatically once the page has loaded.
      window.onload = function () {
        connect();
      };
    </script>
  </body>
</html>

连接管理 #

连接池实现 #

// internal/hub/hub.go
package hub

import (
    "encoding/json"
    "log"
    "sync"
    "time"

    "github.com/gorilla/websocket"
)

// Client is a single WebSocket connection tracked by a Hub.
//
// Rooms is written while holding mu (JoinRoom, leaveRoomUnsafe). Send is
// closed by the Hub when the client is unregistered; only the Hub should
// close it.
type Client struct {
    ID     string          // per-connection identifier (HandleWebSocket assigns a UUID)
    Conn   *websocket.Conn // underlying WebSocket connection
    Send   chan []byte     // outbound queue drained by the client's write loop
    Hub    *Hub            // hub this client is registered with
    UserID string          // user identity; several connections may share one UserID
    Rooms  map[string]bool // set of joined room names
    mu     sync.RWMutex    // guards Rooms
}

// Message is the JSON envelope exchanged with clients and persisted by the
// storage layer.
//
// Timestamp is stamped by the router when it accepts a message and is the
// value the storage layer writes and reads back; both reference this field,
// so it must exist for them to compile. With Go 1.24+ the omitzero option
// leaves an unset (zero) timestamp out of the JSON; older toolchains ignore
// the option and encode the zero time.
type Message struct {
    Type      string      `json:"type"`
    From      string      `json:"from,omitempty"`
    To        string      `json:"to,omitempty"`
    Room      string      `json:"room,omitempty"`
    Content   interface{} `json:"content"`
    Timestamp time.Time   `json:"timestamp,omitzero"`
}

// Hub owns all client connections and room memberships.
//
// Run consumes the channel events below on one goroutine. Exported methods
// such as JoinRoom and LeaveRoom are called from client goroutines. All of
// them synchronize access to clients and rooms through mu. The channels are
// created unbuffered by NewHub.
type Hub struct {
    // Registered clients, used as a set (values are always true).
    clients map[*Client]bool

    // Room name -> set of member clients; a room is deleted when it empties.
    rooms map[string]map[*Client]bool

    // Clients to add; Run hands them to registerClient.
    register chan *Client

    // Clients to remove; Run hands them to unregisterClient.
    unregister chan *Client

    // Payloads for every registered client.
    broadcast chan []byte

    // Payloads for the members of one room (the sender is skipped).
    roomMessage chan *RoomMessage

    // Payloads addressed to a user ID.
    privateMessage chan *PrivateMessage

    // Guards clients and rooms.
    mu sync.RWMutex
}

// RoomMessage asks the hub to deliver Message (already-encoded JSON) to every
// member of Room except Sender.
type RoomMessage struct {
    Room    string
    Message []byte
    Sender  *Client
}

// PrivateMessage asks the hub to deliver Message (already-encoded JSON) to
// the connection(s) whose UserID equals To. Sender is the originating client.
type PrivateMessage struct {
    To      string
    Message []byte
    Sender  *Client
}

// NewHub returns an empty Hub. Its channels are unbuffered, so sends to them
// block until Run receives; start Run in its own goroutine first.
func NewHub() *Hub {
    h := new(Hub)
    h.clients = make(map[*Client]bool)
    h.rooms = make(map[string]map[*Client]bool)
    h.register = make(chan *Client)
    h.unregister = make(chan *Client)
    h.broadcast = make(chan []byte)
    h.roomMessage = make(chan *RoomMessage)
    h.privateMessage = make(chan *PrivateMessage)
    return h
}

// Run is the hub's event loop and never returns; start it with `go h.Run()`.
// Each event is handled to completion before the next one is received.
func (h *Hub) Run() {
    for {
        select {
        case c := <-h.register:
            h.registerClient(c)
        case c := <-h.unregister:
            h.unregisterClient(c)
        case payload := <-h.broadcast:
            h.broadcastMessage(payload)
        case rm := <-h.roomMessage:
            h.sendRoomMessage(rm)
        case pm := <-h.privateMessage:
            h.sendPrivateMessage(pm)
        }
    }
}

// registerClient adds client to the registered set and queues a "system"
// welcome message. If that message cannot be queued because Send is full,
// the client is closed and removed again.
func (h *Hub) registerClient(client *Client) {
    h.mu.Lock()
    defer h.mu.Unlock()

    h.clients[client] = true
    log.Printf("客户端注册: %s (总数: %d)", client.ID, len(h.clients))

    welcome, err := json.Marshal(Message{
        Type:    "system",
        Content: "欢迎连接到 WebSocket 服务器",
    })
    if err != nil {
        return
    }

    select {
    case client.Send <- welcome:
    default:
        close(client.Send)
        delete(h.clients, client)
    }
}

// unregisterClient removes a registered client from every room it joined,
// drops it from the registered set and closes its Send channel. Unknown (or
// already removed) clients are ignored, so Send is never closed twice here.
func (h *Hub) unregisterClient(client *Client) {
    h.mu.Lock()
    defer h.mu.Unlock()

    if _, registered := h.clients[client]; !registered {
        return
    }

    // leaveRoomUnsafe deletes the current key from client.Rooms; deleting the
    // key being visited during a map range is safe in Go.
    for room := range client.Rooms {
        h.leaveRoomUnsafe(client, room)
    }

    delete(h.clients, client)
    close(client.Send)
    log.Printf("客户端注销: %s (总数: %d)", client.ID, len(h.clients))
}

// broadcastMessage queues message for every registered client.
//
// A client whose Send buffer is full cannot keep up. Its connection is
// closed here, so its read loop fails and unregisters it through the normal
// path. unregisterClient then removes it from every room and closes Send
// exactly once.
//
// Closing Send and deleting the client here is avoided for two reasons.
// First, it would mutate h.clients while holding only a read lock. Second,
// it would leave the client in its rooms with a closed channel, so the next
// room message would panic with "send on closed channel". gorilla/websocket
// allows Conn.Close concurrently with the connection's other methods.
func (h *Hub) broadcastMessage(message []byte) {
    h.mu.RLock()
    defer h.mu.RUnlock()

    for client := range h.clients {
        select {
        case client.Send <- message:
        default:
            // Best effort: an error only means the connection is already closed.
            _ = client.Conn.Close()
        }
    }
}

// sendRoomMessage queues roomMsg.Message for every member of roomMsg.Room
// except the sender. Messages for unknown rooms are dropped.
//
// A member whose Send buffer is full has its connection closed rather than
// being removed here. Its read loop then fails and unregisters it, and
// unregisterClient removes it from every room and closes Send once. Closing
// Send and deleting map entries here would write shared maps under a read
// lock. It would also strand the member in its other rooms with a closed
// channel, and the next send to those rooms would panic.
func (h *Hub) sendRoomMessage(roomMsg *RoomMessage) {
    h.mu.RLock()
    defer h.mu.RUnlock()

    members, ok := h.rooms[roomMsg.Room]
    if !ok {
        return
    }

    for member := range members {
        if member == roomMsg.Sender {
            continue // the sender already has its own message
        }
        select {
        case member.Send <- roomMsg.Message:
        default:
            // Best effort: an error only means the connection is already closed.
            _ = member.Conn.Close()
        }
    }
}

// sendPrivateMessage queues privateMsg.Message for every connection whose
// UserID equals privateMsg.To. A user connected from several tabs or devices
// receives it on each; stopping at the first match would pick an arbitrary
// session, because map iteration order is random. Nothing is sent if the
// user is offline.
//
// A recipient whose Send buffer is full has its connection closed. Its read
// loop then unregisters it, and unregisterClient cleans up rooms and closes
// Send once. This avoids mutating shared maps under the read lock.
func (h *Hub) sendPrivateMessage(privateMsg *PrivateMessage) {
    h.mu.RLock()
    defer h.mu.RUnlock()

    for client := range h.clients {
        if client.UserID != privateMsg.To {
            continue
        }
        select {
        case client.Send <- privateMsg.Message:
        default:
            // Best effort: an error only means the connection is already closed.
            _ = client.Conn.Close()
        }
    }
}

// JoinRoom adds client to room, creating the room if needed, and announces
// the join ("user_joined") to the room's other members via the Run loop.
//
// The membership change is made under h.mu, but the announcement is sent on
// h.roomMessage only after the lock is released. That channel is unbuffered
// and only Run receives from it. Sending while holding h.mu would deadlock
// whenever Run is itself blocked waiting for h.mu, e.g. inside
// registerClient. For the same reason JoinRoom must not be called from the
// Run goroutine.
func (h *Hub) JoinRoom(client *Client, room string) {
    h.mu.Lock()
    members, ok := h.rooms[room]
    if !ok {
        members = make(map[*Client]bool)
        h.rooms[room] = members
    }
    members[client] = true
    client.mu.Lock()
    client.Rooms[room] = true
    client.mu.Unlock()
    h.mu.Unlock()

    log.Printf("客户端 %s 加入房间 %s", client.ID, room)

    joinMsg := Message{
        Type:    "user_joined",
        From:    client.UserID,
        Room:    room,
        Content: client.UserID + " 加入了房间",
    }
    data, err := json.Marshal(joinMsg)
    if err != nil {
        return
    }

    h.roomMessage <- &RoomMessage{
        Room:    room,
        Message: data,
        Sender:  client,
    }
}

// LeaveRoom removes client from room. It takes the hub's write lock and
// delegates to leaveRoomUnsafe, which requires that lock to be held.
func (h *Hub) LeaveRoom(client *Client, room string) {
    h.mu.Lock()
    defer h.mu.Unlock()

    h.leaveRoomUnsafe(client, room)
}

// leaveRoomUnsafe removes client from room; the caller must hold h.mu for
// writing. A room left empty is deleted.
//
// The departure is logged, and a "user_left" notice is sent to the remaining
// members, only if client was actually a member. Leaving a room it never
// joined must not produce a bogus notice.
//
// A remaining member whose Send buffer is full has its connection closed
// rather than its channel. Its read loop then unregisters it, and
// unregisterClient removes it from every room and closes Send once. Closing
// Send here would strand the member in its other rooms with a closed channel.
func (h *Hub) leaveRoomUnsafe(client *Client, room string) {
    members, ok := h.rooms[room]
    wasMember := ok && members[client]

    if ok {
        delete(members, client)
        if len(members) == 0 {
            delete(h.rooms, room)
        }
    }

    client.mu.Lock()
    delete(client.Rooms, room)
    client.mu.Unlock()

    if !wasMember {
        return
    }

    log.Printf("客户端 %s 离开房间 %s", client.ID, room)

    leaveMsg := Message{
        Type:    "user_left",
        From:    client.UserID,
        Room:    room,
        Content: client.UserID + " 离开了房间",
    }
    data, err := json.Marshal(leaveMsg)
    if err != nil {
        return
    }

    // members is the live member set if the room survived; if the room was
    // just deleted it is empty and nothing is sent.
    for member := range members {
        select {
        case member.Send <- data:
        default:
            // Best effort: an error only means the connection is already closed.
            _ = member.Conn.Close()
        }
    }
}

// GetOnlineCount reports how many connections are currently registered.
func (h *Hub) GetOnlineCount() int {
    h.mu.RLock()
    count := len(h.clients)
    h.mu.RUnlock()
    return count
}

// GetRoomUserCount reports how many clients are in room (0 if it does not exist).
func (h *Hub) GetRoomUserCount(room string) int {
    h.mu.RLock()
    defer h.mu.RUnlock()

    // Indexing a missing room yields a nil map, whose length is 0.
    return len(h.rooms[room])
}

客户端处理 #

// internal/client/client.go
package client

import (
    "encoding/json"
    "log"
    "net/http"
    "time"

    "github.com/gorilla/websocket"
    "github.com/google/uuid"

    "your-project/internal/hub"
)

const (
    // writeWait bounds how long a single frame write (including pings) may take.
    writeWait = 10 * time.Second

    // pongWait is how long the connection may go without a pong: each pong
    // pushes the read deadline this far into the future.
    pongWait = 60 * time.Second

    // pingPeriod must be shorter than pongWait so a ping and its pong fit
    // inside every read-deadline window; 90% leaves headroom.
    pingPeriod = (pongWait * 9) / 10

    // maxMessageSize is the largest inbound message in bytes. A larger
    // message makes the read fail and the connection is dropped. The router's
    // LengthFilter is configured with MaxLength: 1000. CJK text takes 3 bytes
    // per character in UTF-8 and there is a JSON envelope on top, so the
    // former 512-byte limit disconnected clients long before that filter
    // could apply.
    maxMessageSize = 4096
)

// upgrader performs the HTTP -> WebSocket handshake with 1 KiB buffers.
// It accepts every Origin; in production CheckOrigin should verify the
// Origin header against an allow-list.
var upgrader = websocket.Upgrader{
    CheckOrigin:     func(*http.Request) bool { return true },
    ReadBufferSize:  1024,
    WriteBufferSize: 1024,
}

// HandleWebSocket upgrades the request, creates a hub.Client for it, hands
// the client to h's Run loop and starts the client's write and read loops.
//
// The *hub.Hub parameter is named h (it was previously named hub). That name
// shadowed the imported hub package, so `hub.Client` below resolved against
// the variable instead of the package type and the file did not compile.
//
// The user ID comes from the "user_id" query parameter, falling back to
// "anonymous_" plus 8 random hex characters. NOTE(review): the query
// parameter is not authenticated, so any caller can claim any user ID;
// derive it from a verified session or token in production.
func HandleWebSocket(h *hub.Hub, w http.ResponseWriter, r *http.Request) {
    conn, err := upgrader.Upgrade(w, r, nil)
    if err != nil {
        log.Printf("WebSocket 升级失败: %v", err)
        return
    }

    userID := r.URL.Query().Get("user_id")
    if userID == "" {
        userID = "anonymous_" + uuid.New().String()[:8]
    }

    client := &hub.Client{
        ID:     uuid.New().String(),
        Conn:   conn,
        Send:   make(chan []byte, 256),
        Hub:    h,
        UserID: userID,
        Rooms:  make(map[string]bool),
    }

    // NOTE(review): the hub package shown in this section names this channel
    // `register` (unexported); it must be exported (or wrapped in a method)
    // for this line to compile.
    h.Register <- client

    go client.writePump()
    go client.readPump()
}

// readPump reads messages from the connection and passes each one to
// handleMessage until a read fails. On exit it unregisters the client from
// the hub and closes the connection.
//
// Reads are limited to maxMessageSize bytes. The read deadline starts at
// pongWait and is pushed pongWait further on every pong; inbound data
// messages do not extend it.
//
// NOTE(review): Go does not allow declaring methods on a type from another
// package (hub.Client). These pumps need to live in package hub or become
// functions that take a *hub.Client.
func (c *hub.Client) readPump() {
    defer func() {
        c.Hub.Unregister <- c
        c.Conn.Close()
    }()

    conn := c.Conn
    conn.SetReadLimit(maxMessageSize)
    extendDeadline := func() { conn.SetReadDeadline(time.Now().Add(pongWait)) }
    extendDeadline()
    conn.SetPongHandler(func(string) error {
        extendDeadline()
        return nil
    })

    for {
        _, payload, err := conn.ReadMessage()
        if err == nil {
            c.handleMessage(payload)
            continue
        }
        if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
            log.Printf("WebSocket 错误: %v", err)
        }
        return
    }
}

// writePump is intended to be the only goroutine writing to c.Conn
// (gorilla/websocket supports at most one concurrent writer). It forwards
// every queued message and sends a ping every pingPeriod. It exits when Send
// is closed or a write fails, closing the connection on the way out.
//
// Each queued message is written as its own text frame. Coalescing queued
// messages into one frame separated by '\n' would break clients: every
// payload is a standalone JSON document, and JSON.parse(event.data) on a
// multi-document frame fails.
func (c *hub.Client) writePump() {
    ticker := time.NewTicker(pingPeriod)
    defer func() {
        ticker.Stop()
        c.Conn.Close()
    }()

    for {
        select {
        case message, ok := <-c.Send:
            c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
            if !ok {
                // The hub closed Send: tell the peer we are going away.
                c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
                return
            }
            if err := c.Conn.WriteMessage(websocket.TextMessage, message); err != nil {
                return
            }

        case <-ticker.C:
            c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
            if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
                return
            }
        }
    }
}

// handleMessage decodes one inbound JSON message, overwrites its From with
// this connection's UserID (the client-supplied value is never trusted) and
// dispatches it by Type. Undecodable and unknown messages are logged and
// dropped.
func (c *hub.Client) handleMessage(message []byte) {
    var msg hub.Message
    if err := json.Unmarshal(message, &msg); err != nil {
        log.Printf("消息解析失败: %v", err)
        return
    }
    msg.From = c.UserID

    handlers := map[string]func(*hub.Message){
        "chat":            c.handleChatMessage,
        "join_room":       c.handleJoinRoom,
        "leave_room":      c.handleLeaveRoom,
        "private_message": c.handlePrivateMessage,
        "ping":            c.handlePing,
    }

    handle, known := handlers[msg.Type]
    if !known {
        log.Printf("未知消息类型: %s", msg.Type)
        return
    }
    handle(&msg)
}

// handleChatMessage re-encodes msg and hands it to the hub: to the room's
// other members when msg.Room is set, otherwise to every connected client.
func (c *hub.Client) handleChatMessage(msg *hub.Message) {
    data, err := json.Marshal(msg)
    if err != nil {
        log.Printf("消息序列化失败: %v", err)
        return
    }

    if msg.Room == "" {
        c.Hub.Broadcast <- data
        return
    }

    c.Hub.RoomMessage <- &hub.RoomMessage{
        Room:    msg.Room,
        Message: data,
        Sender:  c,
    }
}

// handleJoinRoom joins the room named in msg.Room (ignored when empty) and
// queues a "room_joined" confirmation for this client.
//
// If the client's Send buffer is full, the connection is closed rather than
// the channel. The hub closes Send when it unregisters the client, and a
// second close panics with "close of closed channel". Closing the connection
// makes the read loop exit and unregister through the normal path.
func (c *hub.Client) handleJoinRoom(msg *hub.Message) {
    if msg.Room == "" {
        return
    }
    c.Hub.JoinRoom(c, msg.Room)

    response := hub.Message{
        Type:    "room_joined",
        Room:    msg.Room,
        Content: "成功加入房间 " + msg.Room,
    }
    data, err := json.Marshal(response)
    if err != nil {
        return
    }

    select {
    case c.Send <- data:
    default:
        // Best effort: an error only means the connection is already closed.
        _ = c.Conn.Close()
    }
}

// handleLeaveRoom leaves the room named in msg.Room (ignored when empty) and
// queues a "room_left" confirmation for this client.
//
// If the client's Send buffer is full, the connection is closed rather than
// the channel. The hub owns Send and closes it on unregister, so closing it
// here as well would panic with "close of closed channel".
func (c *hub.Client) handleLeaveRoom(msg *hub.Message) {
    if msg.Room == "" {
        return
    }
    c.Hub.LeaveRoom(c, msg.Room)

    response := hub.Message{
        Type:    "room_left",
        Room:    msg.Room,
        Content: "已离开房间 " + msg.Room,
    }
    data, err := json.Marshal(response)
    if err != nil {
        return
    }

    select {
    case c.Send <- data:
    default:
        // Best effort: an error only means the connection is already closed.
        _ = c.Conn.Close()
    }
}

// handlePrivateMessage re-encodes msg and asks the hub to deliver it to user
// msg.To. Messages without a recipient are ignored.
func (c *hub.Client) handlePrivateMessage(msg *hub.Message) {
    if msg.To == "" {
        return
    }

    encoded, err := json.Marshal(msg)
    if err != nil {
        log.Printf("消息序列化失败: %v", err)
        return
    }

    c.Hub.PrivateMessage <- &hub.PrivateMessage{
        To:      msg.To,
        Message: encoded,
        Sender:  c,
    }
}

// handlePing answers an application-level "ping" message with a "pong"
// message queued on this client's Send channel.
//
// If Send is full, the connection is closed rather than the channel. The hub
// closes Send when it unregisters the client, and closing it twice panics.
func (c *hub.Client) handlePing(msg *hub.Message) {
    response := hub.Message{
        Type:    "pong",
        Content: "pong",
    }
    data, err := json.Marshal(response)
    if err != nil {
        return
    }

    select {
    case c.Send <- data:
    default:
        // Best effort: an error only means the connection is already closed.
        _ = c.Conn.Close()
    }
}

消息路由与处理 #

消息路由器 #

// internal/router/router.go
package router

import (
    "encoding/json"
    "log"
    "strings"
    "time"

    "your-project/internal/hub"
)

// MessageRouter validates incoming messages and dispatches them to a handler
// chosen by Message.Type. Validation applies the rate limit first, then each
// filter in order.
//
// Fields: hub is the hub handlers act on; handlers maps a message type to its
// handler; filters run in registration order; rateLimit throttles per client
// ID.
//
// NOTE(review): in this section the client package's read loop calls its own
// handleMessage, and setupRoutes never uses the router. Confirm how
// HandleMessage is meant to be wired in.
type MessageRouter struct {
    hub       *hub.Hub
    handlers  map[string]MessageHandler
    filters   []MessageFilter
    rateLimit *RateLimiter
}

// MessageHandler processes one validated message of the type it was
// registered for. A returned error is passed back to HandleMessage's caller.
type MessageHandler interface {
    Handle(client *hub.Client, msg *hub.Message) error
}

// MessageFilter decides whether a message may proceed. Returning true lets
// it through. Returning false rejects it, and HandleMessage then sends the
// client an "error" message.
type MessageFilter interface {
    Filter(client *hub.Client, msg *hub.Message) bool
}

// RateLimiter is a sliding-window limiter: at most limit accepted requests
// per key within the last window. requests holds, per client ID, the
// timestamps of accepted requests.
//
// NOTE(review): the struct carries no lock of its own, and the map must not
// be accessed concurrently without external synchronization.
type RateLimiter struct {
    requests map[string][]time.Time
    limit    int
    window   time.Duration
}

// NewMessageRouter returns a router for h with the built-in handlers and
// filters installed. The limit is 10 messages per client per minute, and
// filters run profanity first, then length (max 1000).
func NewMessageRouter(h *hub.Hub) *MessageRouter {
    limiter := &RateLimiter{
        requests: make(map[string][]time.Time),
        limit:    10, // at most 10 messages per window
        window:   time.Minute,
    }

    r := &MessageRouter{
        hub:       h,
        handlers:  make(map[string]MessageHandler),
        filters:   make([]MessageFilter, 0),
        rateLimit: limiter,
    }

    // Handler registration order is irrelevant: each type has its own key.
    defaults := map[string]MessageHandler{
        "chat":             &ChatHandler{},
        "join_room":        &JoinRoomHandler{},
        "leave_room":       &LeaveRoomHandler{},
        "private_message":  &PrivateMessageHandler{},
        "get_online_users": &OnlineUsersHandler{},
    }
    for msgType, handler := range defaults {
        r.RegisterHandler(msgType, handler)
    }

    // Filter order matters: they run in the order added.
    r.AddFilter(&ProfanityFilter{})
    r.AddFilter(&LengthFilter{MaxLength: 1000})

    return r
}

// RegisterHandler installs handler for messages whose Type is msgType,
// replacing any handler previously registered for that type.
func (r *MessageRouter) RegisterHandler(msgType string, handler MessageHandler) {
    table := r.handlers
    table[msgType] = handler
}

// AddFilter appends filter to the chain. Filters run in the order they were
// added, and the first rejection stops processing.
func (r *MessageRouter) AddFilter(filter MessageFilter) {
    chain := r.filters
    chain = append(chain, filter)
    r.filters = chain
}

// HandleMessage decodes message, stamps its sender and receive time, then
// runs the rate limit, the filters and finally the handler for its type.
//
// Only a JSON decode failure or a handler error is returned. Rate-limit
// hits, filter rejections and unknown types are reported to the client as
// "error" messages and return nil.
func (r *MessageRouter) HandleMessage(client *hub.Client, message []byte) error {
    msg := hub.Message{}
    if err := json.Unmarshal(message, &msg); err != nil {
        return err
    }
    msg.From, msg.Timestamp = client.UserID, time.Now()

    if allowed := r.rateLimit.Allow(client.ID); !allowed {
        r.sendError(client, "消息发送过于频繁,请稍后再试")
        return nil
    }

    for i := range r.filters {
        if r.filters[i].Filter(client, &msg) {
            continue
        }
        r.sendError(client, "消息被过滤器拒绝")
        return nil
    }

    if handler, found := r.handlers[msg.Type]; found {
        return handler.Handle(client, &msg)
    }
    r.sendError(client, "未知的消息类型: "+msg.Type)
    return nil
}

// sendError queues an "error" message carrying errorMsg for client.
//
// If the client's Send buffer is full, the connection is closed instead of
// the channel. The hub closes Send when the client unregisters, and closing
// it here too would panic with "close of closed channel". Closing the
// connection makes the client's read loop exit and unregister normally.
func (r *MessageRouter) sendError(client *hub.Client, errorMsg string) {
    response := hub.Message{
        Type:    "error",
        Content: errorMsg,
    }
    data, err := json.Marshal(response)
    if err != nil {
        return
    }

    select {
    case client.Send <- data:
    default:
        // Best effort: an error only means the connection is already closed.
        _ = client.Conn.Close()
    }
}

// rateLimitMu serializes every RateLimiter.Allow call. RateLimiter keeps its
// state in a plain map, and messages from different clients are processed
// on different goroutines. Unsynchronized map writes are a data race and can
// abort the process with "fatal error: concurrent map writes". A one-slot
// channel acts as a mutex without an extra import. The lock is shared by all
// limiters, which is coarse but correct.
var rateLimitMu = make(chan struct{}, 1)

// Allow reports whether clientID may send another message now. It uses a
// sliding window: at most rl.limit accepted calls within the last rl.window.
// Accepted calls are recorded; rejected calls are not. Safe for concurrent
// use.
func (rl *RateLimiter) Allow(clientID string) bool {
    rateLimitMu <- struct{}{}
    defer func() { <-rateLimitMu }()

    now := time.Now()

    // Drop timestamps that fell out of the window. Filtering in place reuses
    // the existing backing array instead of allocating a new slice each call.
    history := rl.requests[clientID]
    recent := history[:0]
    for _, ts := range history {
        if now.Sub(ts) < rl.window {
            recent = append(recent, ts)
        }
    }

    if len(recent) >= rl.limit {
        rl.requests[clientID] = recent
        return false
    }

    rl.requests[clientID] = append(recent, now)
    return true
}

// ChatHandler forwards chat messages to the room's other members when Room
// is set, and to every client otherwise.
type ChatHandler struct{}

func (h *ChatHandler) Handle(client *hub.Client, msg *hub.Message) error {
    payload, err := json.Marshal(msg)
    if err != nil {
        return err
    }

    if msg.Room == "" {
        client.Hub.Broadcast <- payload
        return nil
    }

    client.Hub.RoomMessage <- &hub.RoomMessage{
        Room:    msg.Room,
        Message: payload,
        Sender:  client,
    }
    return nil
}

// JoinRoomHandler adds the sender to msg.Room. Messages without a room are
// ignored, and it never returns an error.
type JoinRoomHandler struct{}

func (h *JoinRoomHandler) Handle(client *hub.Client, msg *hub.Message) error {
    if msg.Room == "" {
        return nil
    }
    client.Hub.JoinRoom(client, msg.Room)
    return nil
}

// LeaveRoomHandler removes the sender from msg.Room. Messages without a room
// are ignored, and it never returns an error.
type LeaveRoomHandler struct{}

func (h *LeaveRoomHandler) Handle(client *hub.Client, msg *hub.Message) error {
    if msg.Room == "" {
        return nil
    }
    client.Hub.LeaveRoom(client, msg.Room)
    return nil
}

// PrivateMessageHandler asks the hub to deliver msg to user msg.To. Messages
// without a recipient are ignored.
type PrivateMessageHandler struct{}

func (h *PrivateMessageHandler) Handle(client *hub.Client, msg *hub.Message) error {
    if msg.To == "" {
        return nil
    }

    payload, err := json.Marshal(msg)
    if err != nil {
        return err
    }

    client.Hub.PrivateMessage <- &hub.PrivateMessage{
        To:      msg.To,
        Message: payload,
        Sender:  client,
    }
    return nil
}

// OnlineUsersHandler replies to the sender with the current number of
// registered connections, as {"count": n} in an "online_users" message.
type OnlineUsersHandler struct{}

// Handle never returns an error. If the reply cannot be queued because Send
// is full, the connection is closed rather than the channel. The hub closes
// Send when the client unregisters, and closing it twice panics.
func (h *OnlineUsersHandler) Handle(client *hub.Client, msg *hub.Message) error {
    response := hub.Message{
        Type: "online_users",
        Content: map[string]interface{}{
            "count": client.Hub.GetOnlineCount(),
        },
    }

    data, err := json.Marshal(response)
    if err != nil {
        return nil
    }

    select {
    case client.Send <- data:
    default:
        // Best effort: an error only means the connection is already closed.
        _ = client.Conn.Close()
    }
    return nil
}

// ProfanityFilter rejects text messages containing any word from a small
// built-in list. Matching is a case-insensitive substring check. Non-string
// content passes unchecked.
type ProfanityFilter struct{}

func (f *ProfanityFilter) Filter(client *hub.Client, msg *hub.Message) bool {
    content, isText := msg.Content.(string)
    if !isText {
        return true
    }

    lowered := strings.ToLower(content)
    for _, banned := range []string{"脏话1", "脏话2", "spam"} {
        if !strings.Contains(lowered, banned) {
            continue
        }
        log.Printf("消息被脏话过滤器拒绝: %s", content)
        return false
    }
    return true
}

// LengthFilter rejects text messages whose content is longer than MaxLength
// characters. Non-string content is not checked.
type LengthFilter struct {
    MaxLength int
}

// Filter counts Unicode code points rather than bytes. With len(content), a
// Chinese message was capped at roughly a third of MaxLength characters,
// because each CJK character takes 3 bytes in UTF-8.
func (f *LengthFilter) Filter(client *hub.Client, msg *hub.Message) bool {
    content, ok := msg.Content.(string)
    if !ok {
        return true
    }

    length := len([]rune(content))
    if length > f.MaxLength {
        log.Printf("消息被长度过滤器拒绝: 长度 %d > %d", length, f.MaxLength)
        return false
    }
    return true
}

高级特性 #

消息持久化 #

// internal/storage/storage.go
package storage

import (
    "database/sql"
    "encoding/json"
    "time"

    _ "github.com/lib/pq"
    "your-project/internal/hub"
)

// MessageStorage persists chat messages and serves paged history.
// limit and offset page through the results; the PostgreSQL implementation
// returns the newest messages first.
type MessageStorage interface {
    // SaveMessage stores one message.
    SaveMessage(msg *hub.Message) error
    // GetRoomHistory returns messages posted to room.
    GetRoomHistory(room string, limit int, offset int) ([]*hub.Message, error)
    // GetPrivateHistory returns messages exchanged between user1 and user2 in
    // either direction.
    GetPrivateHistory(user1, user2 string, limit int, offset int) ([]*hub.Message, error)
}

// PostgreSQLStorage implements MessageStorage on a database/sql pool opened
// with the lib/pq driver ("postgres").
type PostgreSQLStorage struct {
    db *sql.DB
}

// NewPostgreSQLStorage opens a PostgreSQL-backed store for dsn and creates
// the messages table and its indexes if they do not exist yet.
//
// sql.Open only validates its arguments and does not connect, so the first
// real round-trip is createTables. If that fails, the pool is closed before
// returning so the *sql.DB and any connection it opened are not leaked.
func NewPostgreSQLStorage(dsn string) (*PostgreSQLStorage, error) {
    db, err := sql.Open("postgres", dsn)
    if err != nil {
        return nil, err
    }

    s := &PostgreSQLStorage{db: db}
    if err := s.createTables(); err != nil {
        db.Close()
        return nil, err
    }

    return s, nil
}

// createTables idempotently creates the messages table and its room, private
// and timestamp indexes. Everything is sent in one Exec call.
func (s *PostgreSQLStorage) createTables() error {
    const schema = `
    CREATE TABLE IF NOT EXISTS messages (
        id SERIAL PRIMARY KEY,
        type VARCHAR(50) NOT NULL,
        from_user VARCHAR(100) NOT NULL,
        to_user VARCHAR(100),
        room_name VARCHAR(100),
        content JSONB NOT NULL,
        timestamp TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
        created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
    );

    CREATE INDEX IF NOT EXISTS idx_messages_room ON messages(room_name, timestamp);
    CREATE INDEX IF NOT EXISTS idx_messages_private ON messages(from_user, to_user, timestamp);
    CREATE INDEX IF NOT EXISTS idx_messages_timestamp ON messages(timestamp);
    `

    if _, err := s.db.Exec(schema); err != nil {
        return err
    }
    return nil
}

// SaveMessage inserts msg into the messages table. Content is stored as
// JSON; an empty To or Room is stored as an empty string, not NULL.
//
// A zero msg.Timestamp is replaced with the current time. An explicit value
// always overrides the column's DEFAULT NOW(), so passing the zero value
// through would store 0001-01-01 and break the history queries, which order
// by this column.
func (s *PostgreSQLStorage) SaveMessage(msg *hub.Message) error {
    contentJSON, err := json.Marshal(msg.Content)
    if err != nil {
        return err
    }

    ts := msg.Timestamp
    if ts.IsZero() {
        ts = time.Now()
    }

    const query = `
    INSERT INTO messages (type, from_user, to_user, room_name, content, timestamp)
    VALUES ($1, $2, $3, $4, $5, $6)
    `

    _, err = s.db.Exec(query, msg.Type, msg.From, msg.To, msg.Room, contentJSON, ts)
    return err
}

// GetRoomHistory returns up to limit messages posted to room, newest first,
// after skipping the first offset.
func (s *PostgreSQLStorage) GetRoomHistory(room string, limit int, offset int) ([]*hub.Message, error) {
    const query = `
    SELECT type, from_user, COALESCE(to_user, ''), COALESCE(room_name, ''), content, timestamp
    FROM messages
    WHERE room_name = $1
    ORDER BY timestamp DESC
    LIMIT $2 OFFSET $3
    `

    rows, err := s.db.Query(query, room, limit, offset)
    if err != nil {
        return nil, err
    }
    return scanMessages(rows)
}

// GetPrivateHistory returns up to limit messages exchanged between user1 and
// user2 in either direction, newest first, after skipping the first offset.
func (s *PostgreSQLStorage) GetPrivateHistory(user1, user2 string, limit int, offset int) ([]*hub.Message, error) {
    const query = `
    SELECT type, from_user, COALESCE(to_user, ''), COALESCE(room_name, ''), content, timestamp
    FROM messages
    WHERE (from_user = $1 AND to_user = $2) OR (from_user = $2 AND to_user = $1)
    ORDER BY timestamp DESC
    LIMIT $3 OFFSET $4
    `

    rows, err := s.db.Query(query, user1, user2, limit, offset)
    if err != nil {
        return nil, err
    }
    return scanMessages(rows)
}

// scanMessages reads every row of a history query into messages and closes
// rows.
//
// Rows that fail to scan, or whose content is not valid JSON, are skipped
// (best effort). An error that stops iteration early is reported through
// rows.Err() rather than returning a silently truncated list. to_user and
// room_name are nullable, so the queries COALESCE them to ''; scanning NULL
// into a string field would fail and drop the row.
func scanMessages(rows *sql.Rows) ([]*hub.Message, error) {
    defer rows.Close()

    var messages []*hub.Message
    for rows.Next() {
        var msg hub.Message
        var contentJSON []byte

        if err := rows.Scan(&msg.Type, &msg.From, &msg.To, &msg.Room, &contentJSON, &msg.Timestamp); err != nil {
            continue
        }
        if err := json.Unmarshal(contentJSON, &msg.Content); err != nil {
            continue
        }

        messages = append(messages, &msg)
    }

    if err := rows.Err(); err != nil {
        return nil, err
    }
    return messages, nil
}

// Close closes the underlying database connection pool.
func (s *PostgreSQLStorage) Close() error {
    pool := s.db
    return pool.Close()
}

监控和指标 #

// internal/metrics/metrics.go
package metrics

import (
    "sync"
    "time"

    "github.com/prometheus/client_golang/prometheus"
    "github.com/prometheus/client_golang/prometheus/promauto"
)

// Prometheus collectors. promauto registers each one with the default
// registry when the package is initialised, and panics if a name is
// registered twice.
var (
    // Gauge: number of currently open WebSocket connections.
    activeConnections = promauto.NewGauge(prometheus.GaugeOpts{
        Name: "websocket_active_connections",
        Help: "当前活跃的 WebSocket 连接数",
    })

    // Counter of messages, labelled by message type and direction
    // (MetricsCollector uses "received" and "sent").
    messagesTotal = promauto.NewCounterVec(
        prometheus.CounterOpts{
            Name: "websocket_messages_total",
            Help: "WebSocket 消息总数",
        },
        []string{"type", "direction"},
    )

    // Histogram of message processing time in seconds, labelled by type.
    // No Buckets are set, so the client library's default buckets apply.
    messageProcessingDuration = promauto.NewHistogramVec(
        prometheus.HistogramOpts{
            Name: "websocket_message_processing_duration_seconds",
            Help: "消息处理时间",
        },
        []string{"type"},
    )

    // Gauge: number of rooms that currently exist.
    activeRooms = promauto.NewGauge(prometheus.GaugeOpts{
        Name: "websocket_active_rooms",
        Help: "当前活跃的房间数",
    })
)

// MetricsCollector keeps in-process counters that mirror the Prometheus
// metrics and backs the /api/stats endpoint.
//
// mu guards connectionCount, roomCount and messageStats. messageStats is
// keyed "<type>_received" / "<type>_sent". startTime is set once by
// NewMetricsCollector and used to compute uptime.
type MetricsCollector struct {
    mu                sync.RWMutex
    connectionCount   int
    roomCount         int
    messageStats      map[string]int64
    startTime         time.Time
}

// NewMetricsCollector returns a collector with zero counts whose uptime is
// measured from now.
func NewMetricsCollector() *MetricsCollector {
    mc := &MetricsCollector{startTime: time.Now()}
    mc.messageStats = make(map[string]int64)
    return mc
}

// OnConnect counts a newly opened connection and publishes the new total to
// the websocket_active_connections gauge.
func (mc *MetricsCollector) OnConnect() {
    mc.mu.Lock()
    mc.connectionCount++
    activeConnections.Set(float64(mc.connectionCount))
    mc.mu.Unlock()
}

// OnDisconnect counts a closed connection and publishes the new total to the
// websocket_active_connections gauge.
func (mc *MetricsCollector) OnDisconnect() {
    mc.mu.Lock()
    mc.connectionCount--
    activeConnections.Set(float64(mc.connectionCount))
    mc.mu.Unlock()
}

// OnRoomCreated counts a new room and publishes the total to the
// websocket_active_rooms gauge.
func (mc *MetricsCollector) OnRoomCreated() {
    mc.mu.Lock()
    mc.roomCount++
    activeRooms.Set(float64(mc.roomCount))
    mc.mu.Unlock()
}

// OnRoomDestroyed counts a removed room and publishes the total to the
// websocket_active_rooms gauge.
func (mc *MetricsCollector) OnRoomDestroyed() {
    mc.mu.Lock()
    mc.roomCount--
    activeRooms.Set(float64(mc.roomCount))
    mc.mu.Unlock()
}

// OnMessageReceived counts one inbound message of msgType, both in
// Prometheus (direction "received") and in the local stats map.
func (mc *MetricsCollector) OnMessageReceived(msgType string) {
    messagesTotal.WithLabelValues(msgType, "received").Inc()

    key := msgType + "_received"
    mc.mu.Lock()
    defer mc.mu.Unlock()
    mc.messageStats[key]++
}

// OnMessageSent counts one outbound message of msgType, both in Prometheus
// (direction "sent") and in the local stats map.
func (mc *MetricsCollector) OnMessageSent(msgType string) {
    messagesTotal.WithLabelValues(msgType, "sent").Inc()

    key := msgType + "_sent"
    mc.mu.Lock()
    defer mc.mu.Unlock()
    mc.messageStats[key]++
}

// RecordMessageProcessingTime observes duration, in seconds, in the
// processing-time histogram for msgType.
func (mc *MetricsCollector) RecordMessageProcessingTime(msgType string, duration time.Duration) {
    seconds := duration.Seconds()
    messageProcessingDuration.WithLabelValues(msgType).Observe(seconds)
}

// GetStats returns a snapshot with the following keys: active_connections,
// active_rooms, uptime_seconds and message_stats. message_stats is a copy of
// the per-type counters, so callers never share the live map.
func (mc *MetricsCollector) GetStats() map[string]interface{} {
    mc.mu.RLock()
    defer mc.mu.RUnlock()

    perType := make(map[string]int64, len(mc.messageStats))
    for key, count := range mc.messageStats {
        perType[key] = count
    }

    return map[string]interface{}{
        "active_connections": mc.connectionCount,
        "active_rooms":       mc.roomCount,
        "uptime_seconds":     time.Since(mc.startTime).Seconds(),
        "message_stats":      perType,
    }
}

完整的服务器实现 #

// cmd/server/main.go
package main

import (
    "encoding/json"
    "flag"
    "fmt"
    "log"
    "net/http"
    "os"
    "os/signal"
    "syscall"

    "github.com/prometheus/client_golang/prometheus/promhttp"

    "your-project/internal/client"
    "your-project/internal/hub"
    "your-project/internal/metrics"
    "your-project/internal/router"
    "your-project/internal/storage"
)

// Command-line flags: the listen address and an optional PostgreSQL DSN
// (history persistence is disabled when it is empty).
var (
    addr  = flag.String("addr", ":8080", "HTTP 服务地址")
    dbURL = flag.String("db", "", "数据库连接字符串")
)

// main wires the hub, metrics, router and optional storage together,
// registers the HTTP routes and serves on *addr.
func main() {
    flag.Parse()

    h := hub.NewHub()
    metricsCollector := metrics.NewMetricsCollector()
    messageRouter := router.NewMessageRouter(h)

    // Optional persistence. Assign to the interface only on success.
    // NewPostgreSQLStorage returns a nil *PostgreSQLStorage on error, and
    // storing that in an interface yields a non-nil interface value. setupRoutes
    // would then register the history endpoint, and the first request would
    // panic with a nil-pointer dereference.
    var messageStorage storage.MessageStorage
    if *dbURL != "" {
        pg, err := storage.NewPostgreSQLStorage(*dbURL)
        if err != nil {
            log.Printf("数据库连接失败: %v", err)
        } else {
            log.Println("数据库连接成功")
            messageStorage = pg
        }
    }

    go h.Run()

    setupRoutes(h, messageRouter, metricsCollector, messageStorage)
    setupGracefulShutdown()

    fmt.Printf("WebSocket 服务器启动在 %s\n", *addr)
    log.Fatal(http.ListenAndServe(*addr, nil))
}

// setupRoutes registers every HTTP endpoint on http.DefaultServeMux:
// - static files
// - the WebSocket endpoint
// - JSON stats
// - the latest 50 messages of a room (only when store is non-nil)
// - Prometheus metrics
// - a health check
//
// Parameters are named so they no longer shadow the imported packages.
// NOTE(review): the *router.MessageRouter argument is accepted but unused;
// confirm whether the router is meant to be wired into the WebSocket path.
func setupRoutes(h *hub.Hub, _ *router.MessageRouter, mc *metrics.MetricsCollector, store storage.MessageStorage) {
    mux := http.DefaultServeMux

    mux.Handle("/", http.FileServer(http.Dir("./static/")))

    mux.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
        client.HandleWebSocket(h, w, r)
    })

    mux.HandleFunc("/api/stats", func(w http.ResponseWriter, r *http.Request) {
        snapshot := mc.GetStats()
        w.Header().Set("Content-Type", "application/json")
        json.NewEncoder(w).Encode(snapshot)
    })

    if store != nil {
        mux.HandleFunc("/api/history/room", func(w http.ResponseWriter, r *http.Request) {
            room := r.URL.Query().Get("room")
            if room == "" {
                http.Error(w, "缺少房间参数", http.StatusBadRequest)
                return
            }

            history, err := store.GetRoomHistory(room, 50, 0)
            if err != nil {
                http.Error(w, "获取历史消息失败", http.StatusInternalServerError)
                return
            }

            w.Header().Set("Content-Type", "application/json")
            json.NewEncoder(w).Encode(history)
        })
    }

    mux.Handle("/metrics", promhttp.Handler())

    mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
        w.WriteHeader(http.StatusOK)
        w.Write([]byte("OK"))
    })
}

// setupGracefulShutdown exits the process with status 0 when SIGINT or
// SIGTERM arrives.
//
// NOTE(review): despite its name this is not a graceful shutdown. os.Exit
// skips deferred calls, drops open WebSocket connections and leaves the
// storage unclosed. A real graceful stop needs an *http.Server whose
// Shutdown main can call. Shutdown does not touch hijacked connections such
// as WebSockets, so those must be closed separately.
func setupGracefulShutdown() {
    sigs := make(chan os.Signal, 1)
    signal.Notify(sigs, os.Interrupt, syscall.SIGTERM)

    go func(incoming <-chan os.Signal) {
        <-incoming
        log.Println("正在关闭服务器...")
        os.Exit(0)
    }(sigs)
}

小结 #

本节详细介绍了 WebSocket 服务端的实现:

  1. 基础实现:使用 Gorilla WebSocket 库创建基本服务器
  2. 连接管理:Hub 模式管理客户端连接和房间
  3. 消息处理:读写分离、消息路由、过滤器机制
  4. 高级特性:消息持久化、监控指标、速率限制
  5. 完整架构:模块化设计、退出信号处理、健康检查(示例中的关闭逻辑收到信号后直接退出,生产环境应基于 http.Server.Shutdown 实现真正的优雅关闭)

通过这些内容,您可以构建出功能完整、性能优良的 WebSocket 服务端。在下一节中,我们将学习如何开发 WebSocket 客户端。