3.8.2 WebSocket 服务端实现 #
WebSocket 服务端是实时通信系统的核心,负责处理客户端连接、消息路由、状态管理等关键功能。本节将详细介绍如何使用 Go 语言构建高性能、可扩展的 WebSocket 服务端。
基础服务端实现 #
使用 Gorilla WebSocket #
Gorilla WebSocket 是 Go 语言中最流行的 WebSocket 库,提供了完整的 WebSocket 协议实现。
go mod init websocket-server
go get github.com/gorilla/websocket
简单的 WebSocket 服务器 #
// main.go
package main
import (
"fmt"
"log"
"net/http"
"github.com/gorilla/websocket"
)
// upgrader holds the settings used to upgrade HTTP requests to WebSocket
// connections.
var upgrader = websocket.Upgrader{
	// CheckOrigin returns true for every request, so connections from any
	// Origin (including cross-origin pages) are accepted.
	CheckOrigin: func(r *http.Request) bool {
		return true
	},
	// Read and write buffer sizes.
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
}
// handleWebSocket upgrades the request to a WebSocket connection and echoes
// every received message back to the sender until a read or write fails.
func handleWebSocket(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("WebSocket 升级失败: %v", err)
		return
	}
	defer conn.Close()

	log.Printf("新的 WebSocket 连接: %s", conn.RemoteAddr())

	for {
		kind, payload, readErr := conn.ReadMessage()
		if readErr != nil {
			// Going-away and abnormal closures are expected disconnects and
			// are not logged as errors.
			unexpected := websocket.IsUnexpectedCloseError(readErr, websocket.CloseGoingAway, websocket.CloseAbnormalClosure)
			if unexpected {
				log.Printf("WebSocket 错误: %v", readErr)
			}
			break
		}
		log.Printf("收到消息 [%d]: %s", kind, string(payload))

		// Echo the message back using the same frame type.
		if writeErr := conn.WriteMessage(kind, payload); writeErr != nil {
			log.Printf("发送消息失败: %v", writeErr)
			break
		}
	}

	log.Printf("WebSocket 连接关闭: %s", conn.RemoteAddr())
}
// main serves ./static/ at "/" and the echo endpoint at "/ws" on port 8080.
func main() {
	staticFiles := http.FileServer(http.Dir("./static/"))
	http.Handle("/", staticFiles)
	http.HandleFunc("/ws", handleWebSocket)

	fmt.Println("WebSocket 服务器启动在 :8080")
	// ListenAndServe only returns on failure; log.Fatal reports it and exits.
	log.Fatal(http.ListenAndServe(":8080", nil))
}
客户端测试页面 #
创建 static/index.html 文件:
<!DOCTYPE html>
<html>
<head>
<title>WebSocket 测试</title>
<meta charset="utf-8" />
</head>
<body>
<div>
<h2>WebSocket 测试客户端</h2>
<div>
<input
type="text"
id="messageInput"
placeholder="输入消息..."
style="width: 300px;"
/>
<button onclick="sendMessage()">发送</button>
<button onclick="connect()">连接</button>
<button onclick="disconnect()">断开</button>
</div>
<div>
<h3>连接状态: <span id="status">未连接</span></h3>
<h3>消息记录:</h3>
<div
id="messages"
style="border: 1px solid #ccc; height: 300px; overflow-y: scroll; padding: 10px;"
></div>
</div>
</div>
<script>
// Current WebSocket instance; stays null until connect() creates one.
let ws = null;
// Cached references to the page elements that show status, the message log,
// and the text input.
const statusElement = document.getElementById("status");
const messagesElement = document.getElementById("messages");
const messageInput = document.getElementById("messageInput");
// Opens the WebSocket connection and wires the handlers that update the
// status label and the message log.
function connect() {
  if (ws && ws.readyState === WebSocket.OPEN) {
    addMessage("已经连接");
    return;
  }
  // A socket that is still connecting must not be replaced, otherwise a
  // second click would leave two live connections.
  if (ws && ws.readyState === WebSocket.CONNECTING) {
    return;
  }

  // Derive the endpoint from the page location so the page also works when
  // served from another host or over HTTPS; fall back to the local default
  // when there is no host (e.g. the file was opened directly from disk).
  const scheme = window.location.protocol === "https:" ? "wss:" : "ws:";
  const url = window.location.host
    ? scheme + "//" + window.location.host + "/ws"
    : "ws://localhost:8080/ws";
  ws = new WebSocket(url);

  ws.onopen = function (event) {
    statusElement.textContent = "已连接";
    statusElement.style.color = "green";
    addMessage("WebSocket 连接已建立");
  };
  ws.onmessage = function (event) {
    addMessage("收到: " + event.data);
  };
  ws.onclose = function (event) {
    statusElement.textContent = "已断开";
    statusElement.style.color = "red";
    addMessage("WebSocket 连接已关闭");
  };
  ws.onerror = function (error) {
    statusElement.textContent = "错误";
    statusElement.style.color = "red";
    // The handler receives an Event object; concatenating it directly would
    // print "[object Event]", so prefer a readable field.
    addMessage("WebSocket 错误: " + (error.message || error.type || error));
  };
}
// Closes the current WebSocket, if one has been created.
function disconnect() {
  if (!ws) {
    return;
  }
  ws.close();
}
// Sends the trimmed input text over the open socket and clears the input.
// Empty input is ignored; without an open socket a hint is logged instead.
function sendMessage() {
  const isOpen = Boolean(ws) && ws.readyState === WebSocket.OPEN;
  if (!isOpen) {
    addMessage("请先连接 WebSocket");
    return;
  }
  const text = messageInput.value.trim();
  if (!text) {
    return;
  }
  ws.send(text);
  addMessage("发送: " + text);
  messageInput.value = "";
}
// Appends a timestamped line to the message log and scrolls it to the bottom.
function addMessage(message) {
  const entry = document.createElement("div");
  const timestamp = new Date().toLocaleTimeString();
  entry.textContent = timestamp + " - " + message;
  messagesElement.appendChild(entry);
  messagesElement.scrollTop = messagesElement.scrollHeight;
}
// Send the message when Enter is pressed in the input box.
messageInput.addEventListener("keypress", function (e) {
  if (e.key === "Enter") {
    sendMessage();
  }
});
// Connect automatically once the page has loaded.
window.onload = function () {
  connect();
};
</script>
</body>
</html>
连接管理 #
连接池实现 #
// internal/hub/hub.go
package hub
import (
	"encoding/json"
	"log"
	"sync"
	"time"

	"github.com/gorilla/websocket"
)
// Client represents one connected WebSocket peer tracked by a Hub.
type Client struct {
	ID     string          // unique client identifier
	Conn   *websocket.Conn // underlying WebSocket connection
	Send   chan []byte     // outbound message channel
	Hub    *Hub            // hub this client belongs to
	UserID string          // user ID
	Rooms  map[string]bool // rooms the client has joined
	mu     sync.RWMutex    // read-write lock; held while Rooms is modified
}
// Message is the JSON envelope exchanged between the server and clients.
type Message struct {
	Type    string      `json:"type"`
	From    string      `json:"from,omitempty"`
	To      string      `json:"to,omitempty"`
	Room    string      `json:"room,omitempty"`
	Content interface{} `json:"content"`
	// Timestamp records when the server handled the message. The message
	// router assigns it and the storage layer persists and restores it, so
	// the field has to be part of the envelope.
	Timestamp time.Time `json:"timestamp"`
}
// Hub keeps track of all connected clients and the rooms they have joined.
type Hub struct {
	// Registered clients.
	clients map[*Client]bool
	// Room membership, keyed by room name.
	rooms map[string]map[*Client]bool
	// Clients waiting to be registered.
	register chan *Client
	// Clients waiting to be unregistered.
	unregister chan *Client
	// Messages to broadcast to every client.
	broadcast chan []byte
	// Messages addressed to a room.
	roomMessage chan *RoomMessage
	// Messages addressed to a single user.
	privateMessage chan *PrivateMessage
	// Read-write lock taken by the Hub methods that touch clients and rooms.
	mu sync.RWMutex
}
// RoomMessage is a payload addressed to every member of a room.
type RoomMessage struct {
	Room    string  // target room name
	Message []byte  // encoded payload
	Sender  *Client // originating client; skipped when the message is delivered
}
// PrivateMessage is a payload addressed to a single user.
type PrivateMessage struct {
	To      string  // UserID of the recipient
	Message []byte  // encoded payload
	Sender  *Client // originating client
}
// NewHub returns a Hub with empty client and room tables and unbuffered
// control channels.
func NewHub() *Hub {
	h := new(Hub)
	h.clients = make(map[*Client]bool)
	h.rooms = make(map[string]map[*Client]bool)
	h.register = make(chan *Client)
	h.unregister = make(chan *Client)
	h.broadcast = make(chan []byte)
	h.roomMessage = make(chan *RoomMessage)
	h.privateMessage = make(chan *PrivateMessage)
	return h
}
// Run is the hub's event loop: it waits on the control channels and hands
// each event to the matching handler. It never returns.
func (h *Hub) Run() {
	for {
		select {
		case joining := <-h.register:
			h.registerClient(joining)
		case leaving := <-h.unregister:
			h.unregisterClient(leaving)
		case payload := <-h.broadcast:
			h.broadcastMessage(payload)
		case toRoom := <-h.roomMessage:
			h.sendRoomMessage(toRoom)
		case toUser := <-h.privateMessage:
			h.sendPrivateMessage(toUser)
		}
	}
}
// registerClient adds client to the hub and queues a welcome message for it.
// If the client's Send channel cannot accept the message, the channel is
// closed and the client is removed again.
func (h *Hub) registerClient(client *Client) {
	h.mu.Lock()
	defer h.mu.Unlock()

	h.clients[client] = true
	log.Printf("客户端注册: %s (总数: %d)", client.ID, len(h.clients))

	welcome, err := json.Marshal(Message{
		Type:    "system",
		Content: "欢迎连接到 WebSocket 服务器",
	})
	if err != nil {
		return
	}

	select {
	case client.Send <- welcome:
	default:
		close(client.Send)
		delete(h.clients, client)
	}
}
// unregisterClient removes a registered client from every room it joined and
// from the hub, then closes its Send channel. Unknown clients are ignored.
func (h *Hub) unregisterClient(client *Client) {
	h.mu.Lock()
	defer h.mu.Unlock()

	if _, registered := h.clients[client]; !registered {
		return
	}

	for room := range client.Rooms {
		h.leaveRoomUnsafe(client, room)
	}
	delete(h.clients, client)
	close(client.Send)
	log.Printf("客户端注销: %s (总数: %d)", client.ID, len(h.clients))
}
// broadcastMessage queues message on every registered client's Send channel.
//
// A client whose channel cannot accept the message is dropped: its channel is
// closed and it is removed from the client table and from every room. The
// write lock is required because dropping a client mutates the hub's maps,
// and the room cleanup keeps later room deliveries from sending on a closed
// channel (which would panic).
func (h *Hub) broadcastMessage(message []byte) {
	h.mu.Lock()
	defer h.mu.Unlock()

	for client := range h.clients {
		select {
		case client.Send <- message:
		default:
			close(client.Send)
			delete(h.clients, client)
			for name, members := range h.rooms {
				delete(members, client)
				if len(members) == 0 {
					delete(h.rooms, name)
				}
			}
		}
	}
}
// sendRoomMessage queues roomMsg.Message for every member of roomMsg.Room
// except the sender. Unknown rooms are ignored.
//
// A member whose Send channel cannot accept the message is dropped: its
// channel is closed and it is removed from the client table and from every
// room. The write lock is required because dropping a member mutates the
// hub's maps; removing it from all rooms prevents a later delivery from
// sending on the closed channel.
func (h *Hub) sendRoomMessage(roomMsg *RoomMessage) {
	h.mu.Lock()
	defer h.mu.Unlock()

	for member := range h.rooms[roomMsg.Room] {
		// The sender does not receive its own message.
		if member == roomMsg.Sender {
			continue
		}
		select {
		case member.Send <- roomMsg.Message:
		default:
			close(member.Send)
			delete(h.clients, member)
			for name, members := range h.rooms {
				delete(members, member)
				if len(members) == 0 {
					delete(h.rooms, name)
				}
			}
		}
	}
}
// sendPrivateMessage queues privateMsg.Message for the first registered
// client whose UserID equals privateMsg.To. Nothing happens when no client
// matches.
//
// If the recipient's Send channel cannot accept the message, the recipient is
// dropped: its channel is closed and it is removed from the client table and
// from every room. The write lock is required because dropping the recipient
// mutates the hub's maps.
func (h *Hub) sendPrivateMessage(privateMsg *PrivateMessage) {
	h.mu.Lock()
	defer h.mu.Unlock()

	for client := range h.clients {
		if client.UserID != privateMsg.To {
			continue
		}
		select {
		case client.Send <- privateMsg.Message:
		default:
			close(client.Send)
			delete(h.clients, client)
			for name, members := range h.rooms {
				delete(members, client)
				if len(members) == 0 {
					delete(h.rooms, name)
				}
			}
		}
		// Only the first matching client is addressed.
		break
	}
}
// JoinRoom adds client to room (creating the room if needed), records the
// membership on the client, and notifies the other members with a
// "user_joined" message.
//
// The notification is delivered directly while the hub lock is held. Routing
// it through the hub's room-message channel from here would block forever:
// that channel is unbuffered and its consumer needs the same lock this method
// is holding.
//
// A member whose Send channel cannot accept the notification is dropped: its
// channel is closed and it is removed from the client table and every room.
func (h *Hub) JoinRoom(client *Client, room string) {
	h.mu.Lock()
	defer h.mu.Unlock()

	if h.rooms[room] == nil {
		h.rooms[room] = make(map[*Client]bool)
	}
	h.rooms[room][client] = true

	client.mu.Lock()
	client.Rooms[room] = true
	client.mu.Unlock()

	log.Printf("客户端 %s 加入房间 %s", client.ID, room)

	data, err := json.Marshal(Message{
		Type:    "user_joined",
		From:    client.UserID,
		Room:    room,
		Content: client.UserID + " 加入了房间",
	})
	if err != nil {
		return
	}

	for member := range h.rooms[room] {
		// The joining client is not notified about itself.
		if member == client {
			continue
		}
		select {
		case member.Send <- data:
		default:
			close(member.Send)
			delete(h.clients, member)
			for name, members := range h.rooms {
				delete(members, member)
				if len(members) == 0 {
					delete(h.rooms, name)
				}
			}
		}
	}
}
// LeaveRoom removes client from room while holding the hub's write lock.
func (h *Hub) LeaveRoom(client *Client, room string) {
	h.mu.Lock()
	h.leaveRoomUnsafe(client, room)
	h.mu.Unlock()
}
// leaveRoomUnsafe removes client from room, deletes the room when it becomes
// empty, and notifies the remaining members with a "user_left" message.
//
// It does not take h.mu itself; the callers in this package acquire the write
// lock before calling it.
//
// A remaining member whose Send channel cannot accept the notification is
// dropped: its channel is closed and it is removed from the client table and
// from every room, so no later delivery can send on the closed channel.
func (h *Hub) leaveRoomUnsafe(client *Client, room string) {
	if members, ok := h.rooms[room]; ok {
		delete(members, client)
		if len(members) == 0 {
			delete(h.rooms, room)
		}
	}

	client.mu.Lock()
	delete(client.Rooms, room)
	client.mu.Unlock()

	log.Printf("客户端 %s 离开房间 %s", client.ID, room)

	data, err := json.Marshal(Message{
		Type:    "user_left",
		From:    client.UserID,
		Room:    room,
		Content: client.UserID + " 离开了房间",
	})
	if err != nil {
		return
	}

	// Ranging over a missing room is a no-op, so an emptied room needs no
	// special case.
	for member := range h.rooms[room] {
		select {
		case member.Send <- data:
		default:
			close(member.Send)
			delete(h.clients, member)
			for name, members := range h.rooms {
				delete(members, member)
				if len(members) == 0 {
					delete(h.rooms, name)
				}
			}
		}
	}
}
// GetOnlineCount returns the number of clients currently registered.
func (h *Hub) GetOnlineCount() int {
	h.mu.RLock()
	online := len(h.clients)
	h.mu.RUnlock()
	return online
}
// GetRoomUserCount returns how many clients are in room, or 0 when the room
// does not exist.
func (h *Hub) GetRoomUserCount(room string) int {
	h.mu.RLock()
	defer h.mu.RUnlock()

	// Looking up a missing room yields a nil map, whose length is 0.
	members := h.rooms[room]
	return len(members)
}
客户端处理 #
// internal/client/client.go
package client
import (
"encoding/json"
"log"
"net/http"
"time"
"github.com/gorilla/websocket"
"github.com/google/uuid"
"your-project/internal/hub"
)
// Timing and size limits applied to each WebSocket connection.
const (
	// Time allowed for a write to complete.
	writeWait = 10 * time.Second
	// Time allowed to wait for the next pong from the peer.
	pongWait = 60 * time.Second
	// Interval between pings; must be shorter than pongWait.
	pingPeriod = (pongWait * 9) / 10
	// Maximum message size, passed to the connection's read limit.
	maxMessageSize = 512
)
// upgrader holds the settings used to upgrade HTTP requests to WebSocket
// connections.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	CheckOrigin: func(r *http.Request) bool {
		// Accepts every Origin; production deployments should validate it.
		return true
	},
}
// HandleWebSocket upgrades the request to a WebSocket connection, wraps it in
// a hub.Client, registers the client with the hub, and starts its read and
// write pumps.
//
// The user ID is taken from the "user_id" query parameter; when it is empty a
// random "anonymous_" ID is generated. The parameter is not authenticated
// here.
//
// The hub parameter is named h so that it does not shadow the imported hub
// package, which the body needs in order to refer to hub.Client.
func HandleWebSocket(h *hub.Hub, w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("WebSocket 升级失败: %v", err)
		return
	}

	userID := r.URL.Query().Get("user_id")
	if userID == "" {
		userID = "anonymous_" + uuid.New().String()[:8]
	}

	client := &hub.Client{
		ID:     uuid.New().String(),
		Conn:   conn,
		Send:   make(chan []byte, 256),
		Hub:    h,
		UserID: userID,
		Rooms:  make(map[string]bool),
	}

	// NOTE(review): this relies on the hub exposing an exported Register
	// channel — confirm the field name against the Hub declaration.
	h.Register <- client

	// One goroutine per direction.
	go client.writePump()
	go client.readPump()
}
// readPump reads messages from the connection and passes each one to
// handleMessage. When reading fails it unregisters the client from the hub
// and closes the connection.
//
// NOTE(review): the receiver type is declared in package hub, and Go does not
// allow methods to be declared on types from another package — confirm where
// these methods are meant to live.
func (c *hub.Client) readPump() {
	defer func() {
		c.Hub.Unregister <- c
		c.Conn.Close()
	}()

	// The read deadline is pushed forward now and on every pong, so a peer
	// that stops answering pings makes ReadMessage time out.
	extendDeadline := func() {
		c.Conn.SetReadDeadline(time.Now().Add(pongWait))
	}

	c.Conn.SetReadLimit(maxMessageSize)
	extendDeadline()
	c.Conn.SetPongHandler(func(string) error {
		extendDeadline()
		return nil
	})

	for {
		_, payload, err := c.Conn.ReadMessage()
		if err == nil {
			c.handleMessage(payload)
			continue
		}
		if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
			log.Printf("WebSocket 错误: %v", err)
		}
		return
	}
}
// writePump writes queued messages from c.Send to the connection and sends a
// ping every pingPeriod. It returns, closing the connection, when c.Send is
// closed or a write fails.
//
// NOTE(review): the receiver type is declared in package hub, and Go does not
// allow methods to be declared on types from another package — confirm where
// these methods are meant to live.
func (c *hub.Client) writePump() {
	pinger := time.NewTicker(pingPeriod)
	defer func() {
		pinger.Stop()
		c.Conn.Close()
	}()

	for {
		select {
		case first, open := <-c.Send:
			c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
			if !open {
				// The channel was closed: tell the peer and stop.
				c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}

			frame, err := c.Conn.NextWriter(websocket.TextMessage)
			if err != nil {
				return
			}
			frame.Write(first)

			// Drain whatever is already queued into the same frame,
			// separated by newlines.
			for queued := len(c.Send); queued > 0; queued-- {
				frame.Write([]byte{'\n'})
				frame.Write(<-c.Send)
			}

			if err := frame.Close(); err != nil {
				return
			}

		case <-pinger.C:
			c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
			if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				return
			}
		}
	}
}
// handleMessage decodes a raw JSON message, stamps it with the client's user
// ID as sender, and dispatches it by message type. Undecodable messages and
// unknown types are logged and dropped.
//
// NOTE(review): the receiver type is declared in package hub, and Go does not
// allow methods to be declared on types from another package — confirm where
// these methods are meant to live.
func (c *hub.Client) handleMessage(message []byte) {
	var msg hub.Message
	if err := json.Unmarshal(message, &msg); err != nil {
		log.Printf("消息解析失败: %v", err)
		return
	}

	// The sender is always set by the server, overriding any client value.
	msg.From = c.UserID

	var handle func(*hub.Message)
	switch msg.Type {
	case "chat":
		handle = c.handleChatMessage
	case "join_room":
		handle = c.handleJoinRoom
	case "leave_room":
		handle = c.handleLeaveRoom
	case "private_message":
		handle = c.handlePrivateMessage
	case "ping":
		handle = c.handlePing
	default:
		log.Printf("未知消息类型: %s", msg.Type)
		return
	}
	handle(&msg)
}
// handleChatMessage forwards a chat message to the hub: to the named room
// when msg.Room is set, otherwise as a broadcast.
func (c *hub.Client) handleChatMessage(msg *hub.Message) {
	data, err := json.Marshal(msg)
	if err != nil {
		log.Printf("消息序列化失败: %v", err)
		return
	}

	if msg.Room == "" {
		c.Hub.Broadcast <- data
		return
	}

	c.Hub.RoomMessage <- &hub.RoomMessage{
		Room:    msg.Room,
		Message: data,
		Sender:  c,
	}
}
// handleJoinRoom adds the client to msg.Room and queues a "room_joined"
// confirmation for it. Messages without a room are ignored.
func (c *hub.Client) handleJoinRoom(msg *hub.Message) {
	if msg.Room == "" {
		return
	}
	c.Hub.JoinRoom(c, msg.Room)

	data, err := json.Marshal(hub.Message{
		Type:    "room_joined",
		Room:    msg.Room,
		Content: "成功加入房间 " + msg.Room,
	})
	if err != nil {
		return
	}

	select {
	case c.Send <- data:
	default:
		// The send buffer is full, so the client is disconnected. Send must
		// not be closed here: the hub closes it when the client is
		// unregistered, and closing a channel twice panics. Closing the
		// connection makes the read loop fail, which triggers that normal
		// unregister path. The Close error is ignored because the
		// connection is being abandoned anyway.
		c.Conn.Close()
	}
}
// handleLeaveRoom removes the client from msg.Room and queues a "room_left"
// confirmation for it. Messages without a room are ignored.
func (c *hub.Client) handleLeaveRoom(msg *hub.Message) {
	if msg.Room == "" {
		return
	}
	c.Hub.LeaveRoom(c, msg.Room)

	data, err := json.Marshal(hub.Message{
		Type:    "room_left",
		Room:    msg.Room,
		Content: "已离开房间 " + msg.Room,
	})
	if err != nil {
		return
	}

	select {
	case c.Send <- data:
	default:
		// The send buffer is full, so the client is disconnected. Send must
		// not be closed here: the hub closes it when the client is
		// unregistered, and closing a channel twice panics. Closing the
		// connection makes the read loop fail, which triggers that normal
		// unregister path. The Close error is ignored because the
		// connection is being abandoned anyway.
		c.Conn.Close()
	}
}
// handlePrivateMessage forwards the message to the hub for delivery to the
// user named in msg.To. Messages without a recipient are ignored.
func (c *hub.Client) handlePrivateMessage(msg *hub.Message) {
	if msg.To == "" {
		return
	}

	payload, err := json.Marshal(msg)
	if err != nil {
		log.Printf("消息序列化失败: %v", err)
		return
	}

	outgoing := &hub.PrivateMessage{
		To:      msg.To,
		Message: payload,
		Sender:  c,
	}
	c.Hub.PrivateMessage <- outgoing
}
// handlePing answers an application-level "ping" message by queueing a
// "pong" message for the client. The incoming message content is not used.
func (c *hub.Client) handlePing(msg *hub.Message) {
	data, err := json.Marshal(hub.Message{
		Type:    "pong",
		Content: "pong",
	})
	if err != nil {
		return
	}

	select {
	case c.Send <- data:
	default:
		// The send buffer is full, so the client is disconnected. Send must
		// not be closed here: the hub closes it when the client is
		// unregistered, and closing a channel twice panics. Closing the
		// connection makes the read loop fail, which triggers that normal
		// unregister path. The Close error is ignored because the
		// connection is being abandoned anyway.
		c.Conn.Close()
	}
}
消息路由与处理 #
消息路由器 #
// internal/router/router.go
package router
import (
"encoding/json"
"log"
"strings"
"time"
"your-project/internal/hub"
)
// MessageRouter rate-limits, filters, and dispatches incoming messages to the
// handler registered for their type.
type MessageRouter struct {
	hub       *hub.Hub                  // hub the router was created for
	handlers  map[string]MessageHandler // handlers keyed by message type
	filters   []MessageFilter           // filters applied in registration order
	rateLimit *RateLimiter              // per-client rate limiter
}
// MessageHandler processes one decoded message on behalf of a client.
type MessageHandler interface {
	Handle(client *hub.Client, msg *hub.Message) error
}
// MessageFilter decides whether a message may be processed; the router
// rejects the message when Filter returns false.
type MessageFilter interface {
	Filter(client *hub.Client, msg *hub.Message) bool
}
// RateLimiter limits how many requests each client may make within a time
// window.
// NOTE(review): no lock is declared here; confirm callers never use one
// limiter from several goroutines at once.
type RateLimiter struct {
	requests map[string][]time.Time // accepted request times, keyed by client ID
	limit    int                    // maximum requests per window
	window   time.Duration          // length of the window
}
// NewMessageRouter builds a router for h with the default handlers and
// filters installed and a limit of 10 messages per minute per client.
func NewMessageRouter(h *hub.Hub) *MessageRouter {
	limiter := &RateLimiter{
		requests: make(map[string][]time.Time),
		limit:    10, // at most 10 messages per minute
		window:   time.Minute,
	}
	r := &MessageRouter{
		hub:       h,
		handlers:  make(map[string]MessageHandler),
		filters:   make([]MessageFilter, 0),
		rateLimit: limiter,
	}

	// Default handlers, keyed by message type.
	defaults := []struct {
		msgType string
		handler MessageHandler
	}{
		{"chat", &ChatHandler{}},
		{"join_room", &JoinRoomHandler{}},
		{"leave_room", &LeaveRoomHandler{}},
		{"private_message", &PrivateMessageHandler{}},
		{"get_online_users", &OnlineUsersHandler{}},
	}
	for _, d := range defaults {
		r.RegisterHandler(d.msgType, d.handler)
	}

	// Default filters, applied in this order.
	for _, f := range []MessageFilter{&ProfanityFilter{}, &LengthFilter{MaxLength: 1000}} {
		r.AddFilter(f)
	}

	return r
}
// RegisterHandler installs handler for msgType, replacing any handler that
// was registered for that type before.
func (r *MessageRouter) RegisterHandler(msgType string, handler MessageHandler) {
	r.handlers[msgType] = handler
}
// AddFilter appends filter to the list of filters applied to every message.
func (r *MessageRouter) AddFilter(filter MessageFilter) {
	r.filters = append(r.filters, filter)
}
// HandleMessage decodes a raw JSON message from client and runs it through
// the rate limiter, the filters, and finally the handler for its type.
//
// It returns the decode error or the handler's error. Rate-limit, filter, and
// unknown-type rejections are reported to the client as an error message and
// return nil.
func (r *MessageRouter) HandleMessage(client *hub.Client, message []byte) error {
	msg := new(hub.Message)
	if err := json.Unmarshal(message, msg); err != nil {
		return err
	}

	// Sender and timestamp are always set by the server.
	msg.From = client.UserID
	msg.Timestamp = time.Now()

	if allowed := r.rateLimit.Allow(client.ID); !allowed {
		r.sendError(client, "消息发送过于频繁,请稍后再试")
		return nil
	}

	for i := range r.filters {
		if r.filters[i].Filter(client, msg) {
			continue
		}
		r.sendError(client, "消息被过滤器拒绝")
		return nil
	}

	if handler, found := r.handlers[msg.Type]; found {
		return handler.Handle(client, msg)
	}
	r.sendError(client, "未知的消息类型: "+msg.Type)
	return nil
}
// sendError queues an "error" message carrying errorMsg for client.
func (r *MessageRouter) sendError(client *hub.Client, errorMsg string) {
	data, err := json.Marshal(hub.Message{
		Type:    "error",
		Content: errorMsg,
	})
	if err != nil {
		return
	}

	select {
	case client.Send <- data:
	default:
		// The send buffer is full, so the client is disconnected. Send must
		// not be closed here: the hub closes it when the client is
		// unregistered, and closing a channel twice panics. Closing the
		// connection lets the client's read loop fail and go through that
		// normal unregister path. The Close error is ignored because the
		// connection is being abandoned anyway.
		client.Conn.Close()
	}
}
// Allow reports whether clientID may make another request now. Timestamps
// older than the window are discarded first; an allowed request is recorded.
func (rl *RateLimiter) Allow(clientID string) bool {
	now := time.Now()

	// Keep only the requests that still fall inside the window.
	if history, seen := rl.requests[clientID]; seen {
		var recent []time.Time
		for i := range history {
			if now.Sub(history[i]) >= rl.window {
				continue
			}
			recent = append(recent, history[i])
		}
		rl.requests[clientID] = recent
	}

	current := rl.requests[clientID]
	if len(current) < rl.limit {
		rl.requests[clientID] = append(current, now)
		return true
	}
	return false
}
// ChatHandler forwards chat messages to the hub.
type ChatHandler struct{}

// Handle sends msg to the room named in msg.Room, or broadcasts it when no
// room is set. It returns the JSON encoding error, if any.
func (h *ChatHandler) Handle(client *hub.Client, msg *hub.Message) error {
	data, err := json.Marshal(msg)
	if err != nil {
		return err
	}

	if msg.Room == "" {
		client.Hub.Broadcast <- data
		return nil
	}

	client.Hub.RoomMessage <- &hub.RoomMessage{
		Room:    msg.Room,
		Message: data,
		Sender:  client,
	}
	return nil
}
// JoinRoomHandler handles requests to join a room.
type JoinRoomHandler struct{}

// Handle adds client to msg.Room; a message without a room is a no-op. It
// always returns nil.
func (h *JoinRoomHandler) Handle(client *hub.Client, msg *hub.Message) error {
	if room := msg.Room; room != "" {
		client.Hub.JoinRoom(client, room)
	}
	return nil
}
// LeaveRoomHandler handles requests to leave a room.
type LeaveRoomHandler struct{}

// Handle removes client from msg.Room; a message without a room is a no-op.
// It always returns nil.
func (h *LeaveRoomHandler) Handle(client *hub.Client, msg *hub.Message) error {
	if room := msg.Room; room != "" {
		client.Hub.LeaveRoom(client, room)
	}
	return nil
}
// PrivateMessageHandler forwards one-to-one messages to the hub.
type PrivateMessageHandler struct{}

// Handle passes msg to the hub for delivery to the user named in msg.To. A
// message without a recipient is a no-op. It returns the JSON encoding error,
// if any.
func (h *PrivateMessageHandler) Handle(client *hub.Client, msg *hub.Message) error {
	if msg.To == "" {
		return nil
	}

	payload, err := json.Marshal(msg)
	if err != nil {
		return err
	}

	outgoing := &hub.PrivateMessage{
		To:      msg.To,
		Message: payload,
		Sender:  client,
	}
	client.Hub.PrivateMessage <- outgoing
	return nil
}
// OnlineUsersHandler answers requests for the number of online users.
type OnlineUsersHandler struct{}

// Handle queues an "online_users" message for client containing the hub's
// current client count. It always returns nil.
func (h *OnlineUsersHandler) Handle(client *hub.Client, msg *hub.Message) error {
	response := hub.Message{
		Type: "online_users",
		Content: map[string]interface{}{
			"count": client.Hub.GetOnlineCount(),
		},
	}
	data, err := json.Marshal(response)
	if err != nil {
		return nil
	}

	select {
	case client.Send <- data:
	default:
		// The send buffer is full, so the client is disconnected. Send must
		// not be closed here: the hub closes it when the client is
		// unregistered, and closing a channel twice panics. Closing the
		// connection lets the client's read loop fail and go through that
		// normal unregister path. The Close error is ignored because the
		// connection is being abandoned anyway.
		client.Conn.Close()
	}
	return nil
}
// ProfanityFilter rejects text messages that contain a blocked word.
type ProfanityFilter struct{}

// Filter returns false when msg.Content is a string containing one of the
// blocked words (the content is lower-cased before matching). Content that is
// not a string always passes.
func (f *ProfanityFilter) Filter(client *hub.Client, msg *hub.Message) bool {
	content, isText := msg.Content.(string)
	if !isText {
		return true
	}

	lowered := strings.ToLower(content)
	for _, banned := range [...]string{"脏话1", "脏话2", "spam"} {
		if !strings.Contains(lowered, banned) {
			continue
		}
		log.Printf("消息被脏话过滤器拒绝: %s", content)
		return false
	}
	return true
}
// LengthFilter rejects text messages longer than MaxLength.
type LengthFilter struct {
	MaxLength int // maximum content length, compared against len(content) in bytes
}

// Filter returns false when msg.Content is a string whose byte length exceeds
// MaxLength. Content that is not a string always passes.
func (f *LengthFilter) Filter(client *hub.Client, msg *hub.Message) bool {
	content, isText := msg.Content.(string)
	if !isText {
		return true
	}

	size := len(content)
	if size <= f.MaxLength {
		return true
	}
	log.Printf("消息被长度过滤器拒绝: 长度 %d > %d", size, f.MaxLength)
	return false
}
高级特性 #
消息持久化 #
// internal/storage/storage.go
package storage
import (
"database/sql"
"encoding/json"
"time"
_ "github.com/lib/pq"
"your-project/internal/hub"
)
// MessageStorage persists messages and retrieves message history.
type MessageStorage interface {
	// SaveMessage stores one message.
	SaveMessage(msg *hub.Message) error
	// GetRoomHistory returns up to limit messages of a room, skipping offset.
	GetRoomHistory(room string, limit int, offset int) ([]*hub.Message, error)
	// GetPrivateHistory returns up to limit messages exchanged between two
	// users, skipping offset.
	GetPrivateHistory(user1, user2 string, limit int, offset int) ([]*hub.Message, error)
}
// PostgreSQLStorage is a storage implementation backed by a *sql.DB opened
// with the "postgres" driver.
type PostgreSQLStorage struct {
	db *sql.DB
}
// NewPostgreSQLStorage opens a database handle for dsn and makes sure the
// messages table and its indexes exist.
//
// It returns the error from opening the handle or from creating the tables.
// When table creation fails the handle is closed before returning, so a
// failed call does not leak it.
func NewPostgreSQLStorage(dsn string) (*PostgreSQLStorage, error) {
	db, err := sql.Open("postgres", dsn)
	if err != nil {
		return nil, err
	}

	storage := &PostgreSQLStorage{db: db}
	if err := storage.createTables(); err != nil {
		// Best-effort cleanup; the table-creation error is the one worth
		// reporting, so the Close error is deliberately ignored.
		db.Close()
		return nil, err
	}
	return storage, nil
}
// createTables creates the messages table and its indexes if they do not
// exist yet, returning the error from executing the statements.
func (s *PostgreSQLStorage) createTables() error {
	const schema = `
CREATE TABLE IF NOT EXISTS messages (
id SERIAL PRIMARY KEY,
type VARCHAR(50) NOT NULL,
from_user VARCHAR(100) NOT NULL,
to_user VARCHAR(100),
room_name VARCHAR(100),
content JSONB NOT NULL,
timestamp TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_messages_room ON messages(room_name, timestamp);
CREATE INDEX IF NOT EXISTS idx_messages_private ON messages(from_user, to_user, timestamp);
CREATE INDEX IF NOT EXISTS idx_messages_timestamp ON messages(timestamp);
`
	if _, err := s.db.Exec(schema); err != nil {
		return err
	}
	return nil
}
// SaveMessage inserts msg into the messages table, storing its content as
// JSON. It returns the encoding error or the error from the insert.
func (s *PostgreSQLStorage) SaveMessage(msg *hub.Message) error {
	const insert = `
INSERT INTO messages (type, from_user, to_user, room_name, content, timestamp)
VALUES ($1, $2, $3, $4, $5, $6)
`
	encoded, err := json.Marshal(msg.Content)
	if err != nil {
		return err
	}

	if _, err := s.db.Exec(insert, msg.Type, msg.From, msg.To, msg.Room, encoded, msg.Timestamp); err != nil {
		return err
	}
	return nil
}
// GetRoomHistory returns up to limit messages of room, newest first, skipping
// the first offset rows.
//
// It returns the query error, or the error met while iterating the result
// set. A row that cannot be decoded is skipped, so one bad row does not hide
// the rest of the history.
func (s *PostgreSQLStorage) GetRoomHistory(room string, limit int, offset int) ([]*hub.Message, error) {
	query := `
SELECT type, from_user, to_user, room_name, content, timestamp
FROM messages
WHERE room_name = $1
ORDER BY timestamp DESC
LIMIT $2 OFFSET $3
`
	rows, err := s.db.Query(query, room, limit, offset)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var messages []*hub.Message
	for rows.Next() {
		var (
			msg         hub.Message
			contentJSON []byte
			// to_user and room_name are declared without NOT NULL, so they
			// are scanned through NullString; scanning NULL straight into a
			// string fails and would silently drop the row.
			toUser, roomName sql.NullString
		)
		if err := rows.Scan(&msg.Type, &msg.From, &toUser, &roomName, &contentJSON, &msg.Timestamp); err != nil {
			continue
		}
		if err := json.Unmarshal(contentJSON, &msg.Content); err != nil {
			continue
		}
		msg.To = toUser.String
		msg.Room = roomName.String
		messages = append(messages, &msg)
	}

	// Next returns false both at the end of the rows and on failure; Err
	// tells the two apart so a truncated result is not reported as success.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return messages, nil
}
// GetPrivateHistory returns up to limit messages sent between user1 and user2
// in either direction, newest first, skipping the first offset rows.
//
// It returns the query error, or the error met while iterating the result
// set. A row that cannot be decoded is skipped, so one bad row does not hide
// the rest of the history.
func (s *PostgreSQLStorage) GetPrivateHistory(user1, user2 string, limit int, offset int) ([]*hub.Message, error) {
	query := `
SELECT type, from_user, to_user, room_name, content, timestamp
FROM messages
WHERE (from_user = $1 AND to_user = $2) OR (from_user = $2 AND to_user = $1)
ORDER BY timestamp DESC
LIMIT $3 OFFSET $4
`
	rows, err := s.db.Query(query, user1, user2, limit, offset)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var messages []*hub.Message
	for rows.Next() {
		var (
			msg         hub.Message
			contentJSON []byte
			// to_user and room_name are declared without NOT NULL, so they
			// are scanned through NullString; scanning NULL straight into a
			// string fails and would silently drop the row.
			toUser, roomName sql.NullString
		)
		if err := rows.Scan(&msg.Type, &msg.From, &toUser, &roomName, &contentJSON, &msg.Timestamp); err != nil {
			continue
		}
		if err := json.Unmarshal(contentJSON, &msg.Content); err != nil {
			continue
		}
		msg.To = toUser.String
		msg.Room = roomName.String
		messages = append(messages, &msg)
	}

	// Next returns false both at the end of the rows and on failure; Err
	// tells the two apart so a truncated result is not reported as success.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return messages, nil
}
// Close closes the underlying database handle and returns its error.
func (s *PostgreSQLStorage) Close() error {
	return s.db.Close()
}
监控和指标 #
// internal/metrics/metrics.go
package metrics
import (
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
// Prometheus metrics created through promauto.
var (
	// Gauge of currently active WebSocket connections.
	activeConnections = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "websocket_active_connections",
		Help: "当前活跃的 WebSocket 连接数",
	})
	// Counter of messages, labelled by message type and direction.
	messagesTotal = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Name: "websocket_messages_total",
			Help: "WebSocket 消息总数",
		},
		[]string{"type", "direction"},
	)
	// Histogram of message processing time in seconds, labelled by type.
	messageProcessingDuration = promauto.NewHistogramVec(
		prometheus.HistogramOpts{
			Name: "websocket_message_processing_duration_seconds",
			Help: "消息处理时间",
		},
		[]string{"type"},
	)
	// Gauge of currently active rooms.
	activeRooms = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "websocket_active_rooms",
		Help: "当前活跃的房间数",
	})
)
// MetricsCollector keeps in-process counters alongside the Prometheus
// metrics.
type MetricsCollector struct {
	mu              sync.RWMutex     // taken by the methods that read or update the fields below
	connectionCount int              // current number of connections
	roomCount       int              // current number of rooms
	messageStats    map[string]int64 // counts keyed by "<type>_received" / "<type>_sent"
	startTime       time.Time        // creation time, used to report uptime
}
// NewMetricsCollector returns a collector with empty message statistics whose
// uptime clock starts now.
func NewMetricsCollector() *MetricsCollector {
	mc := new(MetricsCollector)
	mc.messageStats = make(map[string]int64)
	mc.startTime = time.Now()
	return mc
}
// OnConnect records a new connection and publishes the updated count to the
// active-connections gauge.
func (mc *MetricsCollector) OnConnect() {
	mc.mu.Lock()
	defer mc.mu.Unlock()

	mc.connectionCount += 1
	current := float64(mc.connectionCount)
	activeConnections.Set(current)
}
// OnDisconnect records a closed connection and publishes the updated count to
// the active-connections gauge.
func (mc *MetricsCollector) OnDisconnect() {
	mc.mu.Lock()
	defer mc.mu.Unlock()

	mc.connectionCount -= 1
	current := float64(mc.connectionCount)
	activeConnections.Set(current)
}
// OnRoomCreated records a new room and publishes the updated count to the
// active-rooms gauge.
func (mc *MetricsCollector) OnRoomCreated() {
	mc.mu.Lock()
	defer mc.mu.Unlock()

	mc.roomCount += 1
	current := float64(mc.roomCount)
	activeRooms.Set(current)
}
// OnRoomDestroyed records a removed room and publishes the updated count to
// the active-rooms gauge.
func (mc *MetricsCollector) OnRoomDestroyed() {
	mc.mu.Lock()
	defer mc.mu.Unlock()

	mc.roomCount -= 1
	current := float64(mc.roomCount)
	activeRooms.Set(current)
}
// OnMessageReceived counts one received message of msgType, both in the
// Prometheus counter and in the in-process statistics.
func (mc *MetricsCollector) OnMessageReceived(msgType string) {
	messagesTotal.WithLabelValues(msgType, "received").Inc()

	statKey := msgType + "_received"
	mc.mu.Lock()
	mc.messageStats[statKey]++
	mc.mu.Unlock()
}
// OnMessageSent counts one sent message of msgType, both in the Prometheus
// counter and in the in-process statistics.
func (mc *MetricsCollector) OnMessageSent(msgType string) {
	messagesTotal.WithLabelValues(msgType, "sent").Inc()

	statKey := msgType + "_sent"
	mc.mu.Lock()
	mc.messageStats[statKey]++
	mc.mu.Unlock()
}
// RecordMessageProcessingTime observes duration, converted to seconds, in the
// processing-time histogram under the msgType label.
func (mc *MetricsCollector) RecordMessageProcessingTime(msgType string, duration time.Duration) {
	messageProcessingDuration.WithLabelValues(msgType).Observe(duration.Seconds())
}
// GetStats returns a snapshot of the collector: active connection and room
// counts, uptime in seconds, and a copy of the per-type message statistics.
func (mc *MetricsCollector) GetStats() map[string]interface{} {
	mc.mu.RLock()
	defer mc.mu.RUnlock()

	// Copy the statistics so callers cannot touch the collector's own map.
	perType := make(map[string]int64, len(mc.messageStats))
	for name, count := range mc.messageStats {
		perType[name] = count
	}

	return map[string]interface{}{
		"active_connections": mc.connectionCount,
		"active_rooms":       mc.roomCount,
		"uptime_seconds":     time.Since(mc.startTime).Seconds(),
		"message_stats":      perType,
	}
}
完整的服务器实现 #
// cmd/server/main.go
package main
import (
"encoding/json"
"flag"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"github.com/prometheus/client_golang/prometheus/promhttp"
"your-project/internal/client"
"your-project/internal/hub"
"your-project/internal/metrics"
"your-project/internal/router"
"your-project/internal/storage"
)
// Command-line flags.
var (
	// addr is the HTTP listen address.
	addr = flag.String("addr", ":8080", "HTTP 服务地址")
	// dbURL is the database connection string; empty means no storage is set up.
	dbURL = flag.String("db", "", "数据库连接字符串")
)
// main wires the hub, metrics collector, message router and optional storage
// together, registers the HTTP routes, and serves on the configured address.
func main() {
	flag.Parse()

	h := hub.NewHub()
	metricsCollector := metrics.NewMetricsCollector()
	messageRouter := router.NewMessageRouter(h)

	// Storage is optional and only set up when a database URL was given.
	//
	// The concrete value is assigned to the interface only on success. If the
	// nil pointer returned on failure were stored in the interface, the
	// interface itself would be non-nil, the history routes would still be
	// registered, and their first request would dereference a nil storage.
	var messageStorage storage.MessageStorage
	if *dbURL != "" {
		pg, err := storage.NewPostgreSQLStorage(*dbURL)
		if err != nil {
			log.Printf("数据库连接失败: %v", err)
		} else {
			messageStorage = pg
			log.Println("数据库连接成功")
		}
	}

	go h.Run()

	setupRoutes(h, messageRouter, metricsCollector, messageStorage)
	setupGracefulShutdown()

	fmt.Printf("WebSocket 服务器启动在 %s\n", *addr)
	log.Fatal(http.ListenAndServe(*addr, nil))
}
// setupRoutes registers the HTTP routes on the default mux: static files,
// the WebSocket endpoint, the stats API, the room-history API (only when
// storage is non-nil), the Prometheus endpoint, and a health check.
func setupRoutes(h *hub.Hub, router *router.MessageRouter, metrics *metrics.MetricsCollector, storage storage.MessageStorage) {
	// writeJSON sends payload as a JSON response body.
	writeJSON := func(w http.ResponseWriter, payload interface{}) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(payload)
	}

	serveWS := func(w http.ResponseWriter, r *http.Request) {
		client.HandleWebSocket(h, w, r)
	}
	serveStats := func(w http.ResponseWriter, r *http.Request) {
		writeJSON(w, metrics.GetStats())
	}
	serveRoomHistory := func(w http.ResponseWriter, r *http.Request) {
		room := r.URL.Query().Get("room")
		if room == "" {
			http.Error(w, "缺少房间参数", http.StatusBadRequest)
			return
		}
		// Always returns the 50 most recent messages.
		messages, err := storage.GetRoomHistory(room, 50, 0)
		if err != nil {
			http.Error(w, "获取历史消息失败", http.StatusInternalServerError)
			return
		}
		writeJSON(w, messages)
	}
	serveHealth := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("OK"))
	}

	http.Handle("/", http.FileServer(http.Dir("./static/")))
	http.HandleFunc("/ws", serveWS)
	http.HandleFunc("/api/stats", serveStats)
	if storage != nil {
		http.HandleFunc("/api/history/room", serveRoomHistory)
	}
	http.Handle("/metrics", promhttp.Handler())
	http.HandleFunc("/health", serveHealth)
}
// setupGracefulShutdown starts a goroutine that waits for an interrupt or
// SIGTERM, logs a shutdown message, and exits the process with status 0.
// The process exits immediately; connections are not drained first.
func setupGracefulShutdown() {
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, os.Interrupt, syscall.SIGTERM)

	go func(stop <-chan os.Signal) {
		<-stop
		log.Println("正在关闭服务器...")
		os.Exit(0)
	}(signals)
}
小结 #
本节详细介绍了 WebSocket 服务端的实现:
- 基础实现:使用 Gorilla WebSocket 库创建基本服务器
- 连接管理:Hub 模式管理客户端连接和房间
- 消息处理:读写分离、消息路由、过滤器机制
- 高级特性:消息持久化、监控指标、速率限制
- 完整架构:模块化设计、优雅关闭、健康检查
通过这些内容,您可以构建出功能完整、性能优良的 WebSocket 服务端。在下一节中,我们将学习如何开发 WebSocket 客户端。