4.2.4 WebSocket 服务器 #
WebSocket 是一种在单个 TCP 连接上进行全双工通信的协议,它使得客户端和服务器之间的数据交换变得更加简单,允许服务端主动向客户端推送数据。本节将详细介绍如何在 Go 语言中开发 WebSocket 服务器。
WebSocket 协议基础 #
WebSocket 握手过程 #
WebSocket 连接始于 HTTP 握手,然后升级为 WebSocket 协议:
客户端                        服务器
  |                             |
  |------- HTTP 握手 ---------->|  (Upgrade: websocket)
  |<------ 握手响应 ------------|  (101 Switching Protocols)
  |                             |
  |<======== WebSocket ========>|  (全双工通信)
  |                             |
WebSocket 特点 #
- 全双工通信 - 客户端和服务器可以同时发送数据
- 低延迟 - 握手完成后,每条消息只需携带很小的帧头,无需再重复发送 HTTP 请求/响应头
- 持久连接 - 连接保持打开状态
- 支持二进制和文本数据 - 可以传输各种类型的数据
- 心跳检测 - 协议定义了 ping/pong 控制帧,收到 ping 的一方须回复 pong;但 ping 不会自动发送,需要应用自行定期发送并据此判断连接是否存活(见下文的 writePump 与 readPump)
基础 WebSocket 服务器 #
使用 gorilla/websocket #
首先安装 gorilla/websocket 包:
go mod init websocket-example
go get github.com/gorilla/websocket
简单的 WebSocket 服务器 #
package main
import (
"fmt"
"log"
"net/http"
"github.com/gorilla/websocket"
)
// upgrader turns plain HTTP requests into WebSocket connections using 1 KiB
// read and write buffers.
//
// CheckOrigin accepts every Origin header, i.e. cross-origin connections are
// allowed.
// NOTE(review): accepting any Origin exposes the endpoint to cross-site
// WebSocket hijacking; restrict it before using this outside local development.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1 << 10,
	WriteBufferSize: 1 << 10,
	CheckOrigin:     func(*http.Request) bool { return true },
}
// wsHandler upgrades the request to a WebSocket connection and echoes every
// message back to the sender, preserving its message type, until reading or
// writing fails. Unexpected close errors and write errors are logged; the
// connection is closed when the handler returns.
func wsHandler(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("WebSocket 升级失败: %v", err)
		return
	}
	defer conn.Close()

	peer := conn.RemoteAddr()
	fmt.Printf("新的 WebSocket 连接: %s\n", peer)
	// Registered after conn.Close, so (LIFO) it prints before the socket closes.
	defer fmt.Printf("WebSocket 连接关闭: %s\n", peer)

	for {
		kind, payload, err := conn.ReadMessage()
		if err != nil {
			// Normal "going away"/abnormal closures are routine; log anything else.
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("WebSocket 错误: %v", err)
			}
			return
		}
		fmt.Printf("收到消息: %s\n", string(payload))

		if err := conn.WriteMessage(kind, payload); err != nil {
			log.Printf("写入消息失败: %v", err)
			return
		}
	}
}
// main serves the echo WebSocket endpoint at /ws and the files under ./static/
// at /, listening on :8080. log.Fatal reports why the server stopped.
func main() {
	mux := http.DefaultServeMux
	mux.HandleFunc("/ws", wsHandler)
	mux.Handle("/", http.FileServer(http.Dir("./static/")))

	const addr = ":8080"
	fmt.Println("WebSocket 服务器启动在 " + addr)
	log.Fatal(http.ListenAndServe(addr, nil))
}
客户端测试页面 #
创建 static/index.html 文件:
<!DOCTYPE html>
<html>
<head>
<title>WebSocket 测试</title>
</head>
<body>
<div id="messages"></div>
<input type="text" id="messageInput" placeholder="输入消息..." />
<button onclick="sendMessage()">发送</button>
<button onclick="disconnect()">断开连接</button>
<script>
// Connect to the /ws echo endpoint on the same origin that served this page
// (wss: for HTTPS pages). A page opened straight from disk has no host, so
// fall back to the development server address.
const wsScheme = location.protocol === "https:" ? "wss:" : "ws:";
const wsHost = location.host || "localhost:8080";
let ws = new WebSocket(`${wsScheme}//${wsHost}/ws`);
let messages = document.getElementById("messages");

ws.onopen = function (event) {
  addMessage("连接已建立");
};
ws.onmessage = function (event) {
  addMessage("收到: " + event.data);
};
ws.onclose = function (event) {
  addMessage("连接已关闭");
};
// The error callback receives a plain Event. Concatenating it directly would
// only print "[object Event]".
ws.onerror = function (error) {
  addMessage("错误: " + (error.message || error.type));
};

// Send the input box contents (if any) and echo them into the log.
function sendMessage() {
  let input = document.getElementById("messageInput");
  if (!input.value) {
    return;
  }
  // send() throws while the socket is still CONNECTING and silently discards
  // data once it is CLOSING/CLOSED, so check the state first.
  if (ws.readyState !== WebSocket.OPEN) {
    addMessage("发送失败: 连接未打开");
    return;
  }
  ws.send(input.value);
  addMessage("发送: " + input.value);
  input.value = "";
}

function disconnect() {
  ws.close();
}

// Append a timestamped line. textContent ensures server-supplied text is never
// interpreted as HTML.
function addMessage(message) {
  let div = document.createElement("div");
  div.textContent = new Date().toLocaleTimeString() + " - " + message;
  messages.appendChild(div);
  messages.scrollTop = messages.scrollHeight;
}

// Pressing Enter in the input box sends the message.
document
  .getElementById("messageInput")
  .addEventListener("keypress", function (e) {
    if (e.key === "Enter") {
      sendMessage();
    }
  });
</script>
</body>
</html>
高级 WebSocket 服务器 #
连接管理器 #
import (
"encoding/json"
"sync"
"time"
)
// MessageType identifies the kind of a chat Message. It is serialized as the
// "type" field of the JSON payload.
type MessageType string
const (
MessageTypeText MessageType = "text"           // plain chat text
MessageTypeJoin MessageType = "join"           // system notice: a user joined
MessageTypeLeave MessageType = "leave"         // system notice: a user left
MessageTypeBroadcast MessageType = "broadcast" // message for every connected client
MessageTypePrivate MessageType = "private"     // message for the single user named in To
)
// Message is the JSON envelope exchanged between clients and the server.
type Message struct {
Type MessageType `json:"type"`
From string `json:"from"`         // sender's username; "system" for join/leave notices
To string `json:"to,omitempty"`   // recipient username, only meaningful for private messages
Content string `json:"content"`
Timestamp time.Time `json:"timestamp"`
}
// Client is one connected WebSocket participant managed by a Hub.
// NOTE: a later version of this type in the room section adds a Room field.
type Client struct {
ID string              // random hex identifier (see generateClientID)
Conn *websocket.Conn   // underlying WebSocket connection
Send chan Message      // outbound queue drained by writePump; closed by the Hub on removal
Hub *Hub               // hub this client is registered with
Username string        // display name from the "username" query parameter
}
// Hub is the central registry of connected clients. Registrations,
// unregistrations and incoming messages arrive on channels consumed by Run.
type Hub struct {
clients map[*Client]bool    // registered clients, used as a set
broadcast chan Message      // messages read from clients, dispatched by handleMessage
register chan *Client       // clients to add
unregister chan *Client     // clients to remove
mutex sync.RWMutex          // guards clients
}
// NewHub returns a Hub with its client set and channels initialised. All three
// channels are unbuffered, so senders block until Run receives; start Run in
// its own goroutine before registering clients.
func NewHub() *Hub {
	h := new(Hub)
	h.clients = make(map[*Client]bool)
	h.broadcast = make(chan Message)
	h.register = make(chan *Client)
	h.unregister = make(chan *Client)
	return h
}
// Run is the Hub's event loop and never returns. It handles registration,
// unregistration and message dispatch one event at a time. The clients map is
// still guarded by mutex because GetOnlineUsers and GetClientCount read it
// from HTTP handlers.
func (h *Hub) Run() {
	for {
		select {
		case c := <-h.register:
			h.mutex.Lock()
			h.clients[c] = true
			h.mutex.Unlock()
			fmt.Printf("客户端连接: %s (%s)\n", c.Username, c.ID)
			// Announce the newcomer to everyone except the newcomer.
			h.broadcastToAll(Message{
				Type:      MessageTypeJoin,
				From:      "system",
				Content:   fmt.Sprintf("%s 加入了聊天室", c.Username),
				Timestamp: time.Now(),
			}, c)

		case c := <-h.unregister:
			h.mutex.Lock()
			_, known := h.clients[c]
			if known {
				delete(h.clients, c)
				close(c.Send)
			}
			h.mutex.Unlock()
			if !known {
				// Already removed (e.g. dropped by broadcastToAll): nothing to announce.
				continue
			}
			fmt.Printf("客户端断开: %s (%s)\n", c.Username, c.ID)
			h.broadcastToAll(Message{
				Type:      MessageTypeLeave,
				From:      "system",
				Content:   fmt.Sprintf("%s 离开了聊天室", c.Username),
				Timestamp: time.Now(),
			}, nil)

		case msg := <-h.broadcast:
			h.handleMessage(msg)
		}
	}
}
// handleMessage routes a message read from a client (received on h.broadcast).
// Private messages go only to the user named in To. Broadcast, text and any
// other type are sent to every client.
//
// join and leave are system-only types. Run builds them itself and hands them
// straight to broadcastToAll, and in the code shown only the client read loops
// send on h.broadcast. So a join/leave arriving here was forged by a client,
// e.g. to fake a "someone joined" notice that other clients render as a system
// message. Such messages are dropped.
func (h *Hub) handleMessage(message Message) {
	switch message.Type {
	case MessageTypePrivate:
		h.sendPrivateMessage(message)
	case MessageTypeJoin, MessageTypeLeave:
		// Forged system notification: ignore.
	default:
		// MessageTypeBroadcast, MessageTypeText and unknown types.
		h.broadcastToAll(message, nil)
	}
}
// broadcastToAll queues message on the Send channel of every registered client
// except exclude (nil excludes nobody). Sends never block. A client whose Send
// buffer is full is treated as dead: it is removed from the hub and its Send
// channel is closed, which makes its writePump send a close frame and exit.
//
// Because this deletes from h.clients, it must hold the write lock. Under a
// read lock the delete would race with GetOnlineUsers/GetClientCount reading
// the map concurrently, which is a data race and can crash the process with
// "concurrent map read and map write". Callers must not hold h.mutex.
func (h *Hub) broadcastToAll(message Message, exclude *Client) {
	h.mutex.Lock()
	defer h.mutex.Unlock()
	for client := range h.clients {
		if client == exclude {
			continue
		}
		select {
		case client.Send <- message:
		default:
			// Slow or stuck consumer: drop it rather than block the hub.
			delete(h.clients, client)
			close(client.Send)
		}
	}
}
// sendPrivateMessage delivers message to one registered client whose Username
// equals message.To (the first match found in map iteration order). Delivery is
// best-effort: if no such user exists, or that client's Send buffer is full,
// the message is silently dropped.
func (h *Hub) sendPrivateMessage(message Message) {
	h.mutex.RLock()
	defer h.mutex.RUnlock()

	var target *Client
	for c := range h.clients {
		if c.Username == message.To {
			target = c
			break
		}
	}
	if target == nil {
		return
	}

	select {
	case target.Send <- message:
	default:
		// Recipient's buffer is full: drop the message.
	}
}
// GetOnlineUsers returns a snapshot of the usernames of all registered
// clients, in map iteration (i.e. unspecified) order. The result is never nil.
func (h *Hub) GetOnlineUsers() []string {
	h.mutex.RLock()
	defer h.mutex.RUnlock()

	names := make([]string, len(h.clients))
	i := 0
	for c := range h.clients {
		names[i] = c.Username
		i++
	}
	return names
}
// GetClientCount reports how many clients are currently registered.
func (h *Hub) GetClientCount() int {
	h.mutex.RLock()
	count := len(h.clients)
	h.mutex.RUnlock()
	return count
}
客户端处理 #
import (
"crypto/rand"
"encoding/hex"
"time"
)
// readPump reads JSON messages from the client's connection and forwards them
// to the hub until reading fails. On exit it unregisters the client and closes
// the connection.
//
// Each message is limited to maxMessageSize bytes. The read deadline starts at
// pongWait and is pushed out by pongWait only when a pong arrives, so the
// connection stays alive only while the peer answers writePump's pings.
//
// From and Timestamp are always overwritten server-side, so clients cannot
// choose their sender identity.
func (c *Client) readPump() {
	const (
		maxMessageSize = 512
		pongWait       = 60 * time.Second
	)
	defer func() {
		c.Hub.unregister <- c
		c.Conn.Close()
	}()

	c.Conn.SetReadLimit(maxMessageSize)
	c.Conn.SetReadDeadline(time.Now().Add(pongWait))
	c.Conn.SetPongHandler(func(string) error {
		c.Conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	for {
		var message Message
		if err := c.Conn.ReadJSON(&message); err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				// log.Printf terminates the line and matches the rest of the file.
				log.Printf("WebSocket 错误: %v", err)
			}
			return
		}
		message.From = c.Username
		message.Timestamp = time.Now()
		c.Hub.broadcast <- message
	}
}
// writePump drains c.Send to the connection and sends a ping every pingPeriod.
// Every write gets a writeWait deadline. It returns, stopping the ticker and
// closing the connection, when a write fails or when the hub closes c.Send (in
// which case a close frame is sent first). All data and ping writes for this
// client go through this goroutine; gorilla/websocket allows only one
// concurrent writer per connection.
func (c *Client) writePump() {
	const (
		writeWait = 10 * time.Second
		// pongWait must match readPump's read deadline. Pinging at 90% of it
		// (54s) lets the client's pong arrive before that deadline expires.
		pongWait   = 60 * time.Second
		pingPeriod = pongWait * 9 / 10
	)
	ticker := time.NewTicker(pingPeriod)
	defer func() {
		ticker.Stop()
		c.Conn.Close()
	}()

	for {
		select {
		case message, ok := <-c.Send:
			c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
			if !ok {
				// The hub closed Send: tell the peer we are closing.
				c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}
			if err := c.Conn.WriteJSON(message); err != nil {
				log.Printf("写入消息失败: %v", err)
				return
			}
		case <-ticker.C:
			c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
			if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				return
			}
		}
	}
}
// generateClientID returns a random 128-bit identifier encoded as 32 lowercase
// hex characters.
//
// It panics if the system's secure random source fails. Ignoring that error
// could leave the buffer (partly) zero-filled and hand different clients the
// same ID.
func generateClientID() string {
	var buf [16]byte
	if _, err := rand.Read(buf[:]); err != nil {
		panic("generateClientID: crypto/rand failed: " + err.Error())
	}
	return hex.EncodeToString(buf[:])
}
// wsHandlerAdvanced upgrades the request, wraps the connection in a Client and
// hands it to hub. The display name comes from the "username" query parameter
// and defaults to "匿名用户". The send to hub.register blocks until hub.Run
// receives it. The client's write and read loops then run on their own
// goroutines.
func wsHandlerAdvanced(hub *Hub, w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("WebSocket 升级失败: %v", err)
		return
	}

	name := r.URL.Query().Get("username")
	if len(name) == 0 {
		name = "匿名用户"
	}

	c := &Client{
		ID:       generateClientID(),
		Conn:     conn,
		Send:     make(chan Message, 256), // buffered so the hub never blocks on a slow writer
		Hub:      hub,
		Username: name,
	}
	hub.register <- c

	go c.writePump()
	go c.readPump()
}
聊天室实现 #
完整的聊天室服务器 #
下面的 main 函数替代前文基础示例中的 main(同一个包中只能有一个 main 函数),并复用上文定义的 Hub、Client 与 upgrader:
// main starts the chat server on :8080:
//   - /ws         WebSocket endpoint (wsHandlerAdvanced)
//   - /api/users  JSON list of online usernames and their count
//   - /api/stats  JSON with the online client count and the server time (RFC 3339)
//   - /           static files from ./static/
//
// http.ListenAndServe configures no timeouts at all, so an explicit
// http.Server with ReadHeaderTimeout is used instead. A client that trickles
// request headers (Slowloris) can then no longer hold a connection forever.
// Only the header phase is bounded, so long-lived WebSocket connections are
// unaffected.
// NOTE(review): ReadTimeout/WriteTimeout are deliberately left unset; verify
// how they interact with upgraded connections before adding them.
func main() {
	hub := NewHub()
	go hub.Run()

	// writeJSON sends v as a JSON response. The response is already committed
	// by the time encoding fails, so the error can only be logged.
	writeJSON := func(w http.ResponseWriter, v interface{}) {
		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(v); err != nil {
			log.Printf("写入 JSON 响应失败: %v", err)
		}
	}

	// WebSocket endpoint.
	http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
		wsHandlerAdvanced(hub, w, r)
	})

	// REST endpoints.
	http.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
		users := hub.GetOnlineUsers()
		writeJSON(w, map[string]interface{}{
			"users": users,
			"count": len(users),
		})
	})
	http.HandleFunc("/api/stats", func(w http.ResponseWriter, r *http.Request) {
		writeJSON(w, map[string]interface{}{
			"online_users": hub.GetClientCount(),
			"server_time":  time.Now().Format(time.RFC3339),
		})
	})

	// Static files.
	http.Handle("/", http.FileServer(http.Dir("./static/")))

	srv := &http.Server{
		Addr:              ":8080",
		ReadHeaderTimeout: 10 * time.Second,
	}
	fmt.Println("聊天室服务器启动在 :8080")
	log.Fatal(srv.ListenAndServe())
}
高级聊天室客户端 #
创建 static/chat.html 文件:
<!DOCTYPE html>
<html>
<head>
<title>WebSocket 聊天室</title>
<style>
body {
font-family: Arial, sans-serif;
margin: 20px;
}
#chatContainer {
display: flex;
height: 500px;
}
#messages {
flex: 1;
border: 1px solid #ccc;
padding: 10px;
overflow-y: auto;
}
#userList {
width: 200px;
border: 1px solid #ccc;
padding: 10px;
margin-left: 10px;
}
#inputContainer {
margin-top: 10px;
}
#messageInput {
width: 70%;
padding: 5px;
}
#sendButton {
padding: 5px 10px;
}
.message {
margin: 5px 0;
}
.system {
color: #666;
font-style: italic;
}
.private {
color: #007bff;
}
.timestamp {
color: #999;
font-size: 0.8em;
}
</style>
</head>
<body>
<h1>WebSocket 聊天室</h1>
<div id="loginContainer">
<input type="text" id="usernameInput" placeholder="输入用户名" />
<button onclick="connect()">连接</button>
</div>
<div id="chatContainer" style="display: none;">
<div id="messages"></div>
<div id="userList">
<h3>在线用户</h3>
<div id="users"></div>
</div>
</div>
<div id="inputContainer" style="display: none;">
<input type="text" id="messageInput" placeholder="输入消息..." />
<button id="sendButton" onclick="sendMessage()">发送</button>
<button onclick="disconnect()">断开连接</button>
</div>
<script>
let ws = null;
let username = "";

// wsURL builds a WebSocket URL on the origin that served this page (wss: for
// HTTPS pages). A page opened straight from disk has no host, so fall back to
// the development server address.
function wsURL(path) {
  const scheme = location.protocol === "https:" ? "wss:" : "ws:";
  const host = location.host || "localhost:8080";
  return `${scheme}//${host}${path}`;
}

// connect opens the chat socket for the entered username and wires up the UI.
function connect() {
  // Ignore repeated clicks / Enter presses while a socket is connecting or
  // open. Otherwise each press opens another connection for the same user.
  if (
    ws &&
    (ws.readyState === WebSocket.CONNECTING || ws.readyState === WebSocket.OPEN)
  ) {
    return;
  }
  username = document.getElementById("usernameInput").value.trim();
  if (!username) {
    alert("请输入用户名");
    return;
  }
  ws = new WebSocket(wsURL(`/ws?username=${encodeURIComponent(username)}`));
  ws.onopen = function (event) {
    document.getElementById("loginContainer").style.display = "none";
    document.getElementById("chatContainer").style.display = "flex";
    document.getElementById("inputContainer").style.display = "block";
    addMessage("系统", "连接成功", "system");
    // NOTE: the server registers this client with the hub only after the
    // handshake, and our own join notice is not sent to us. This first list
    // may therefore not include the current user yet.
    loadOnlineUsers();
  };
  ws.onmessage = function (event) {
    let message;
    try {
      message = JSON.parse(event.data);
    } catch (e) {
      console.error("无法解析消息:", e);
      return;
    }
    handleMessage(message);
  };
  ws.onclose = function (event) {
    addMessage("系统", "连接已关闭", "system");
    document.getElementById("loginContainer").style.display = "block";
    document.getElementById("chatContainer").style.display = "none";
    document.getElementById("inputContainer").style.display = "none";
  };
  ws.onerror = function (error) {
    addMessage("系统", "连接错误", "system");
  };
}

function disconnect() {
  if (ws) {
    ws.close();
  }
}

// sendMessage broadcasts the trimmed input; the server echoes it back to us.
function sendMessage() {
  const input = document.getElementById("messageInput");
  const content = input.value.trim();
  if (!content) return;
  // send() throws while CONNECTING and silently drops data once CLOSING/CLOSED.
  if (!ws || ws.readyState !== WebSocket.OPEN) {
    addMessage("系统", "连接未打开,消息未发送", "system");
    return;
  }
  const message = {
    type: "broadcast",
    content: content,
  };
  ws.send(JSON.stringify(message));
  input.value = "";
}

// handleMessage renders one server message according to its type.
function handleMessage(message) {
  switch (message.type) {
    case "join":
    case "leave":
      addMessage("系统", message.content, "system");
      loadOnlineUsers();
      break;
    case "broadcast":
    case "text":
      addMessage(message.from, message.content);
      break;
    case "private":
      addMessage(`${message.from} (私聊)`, message.content, "private");
      break;
  }
}

// addMessage appends "[time] from: content" to the log. The line is built from
// DOM nodes with text content instead of innerHTML. `from` (a username chosen
// by the client) and `content` come from other users, so interpolating them
// into HTML would let anyone inject markup or script into every participant's
// page (XSS).
function addMessage(from, content, className = "") {
  const messages = document.getElementById("messages");
  const div = document.createElement("div");
  div.className = "message " + className;

  const stamp = document.createElement("span");
  stamp.className = "timestamp";
  stamp.textContent = `[${new Date().toLocaleTimeString()}]`;

  const sender = document.createElement("strong");
  sender.textContent = `${from}:`;

  // Non-Node arguments to append() are inserted as plain text nodes.
  div.append(stamp, " ", sender, " ", `${content}`);
  messages.appendChild(div);
  messages.scrollTop = messages.scrollHeight;
}

// loadOnlineUsers refreshes the user list, highlighting the current user.
function loadOnlineUsers() {
  fetch("/api/users")
    .then((response) => {
      // fetch() only rejects on network errors; surface HTTP errors too.
      if (!response.ok) {
        throw new Error(`HTTP ${response.status}`);
      }
      return response.json();
    })
    .then((data) => {
      const usersDiv = document.getElementById("users");
      usersDiv.innerHTML = "";
      data.users.forEach((user) => {
        const userDiv = document.createElement("div");
        userDiv.textContent = user;
        if (user === username) {
          userDiv.style.fontWeight = "bold";
        }
        usersDiv.appendChild(userDiv);
      });
    })
    .catch((error) => console.error("加载用户列表失败:", error));
}

// Enter sends a message / connects, mirroring the buttons.
document
  .getElementById("messageInput")
  .addEventListener("keypress", function (e) {
    if (e.key === "Enter") {
      sendMessage();
    }
  });
document
  .getElementById("usernameInput")
  .addEventListener("keypress", function (e) {
    if (e.key === "Enter") {
      connect();
    }
  });
</script>
</body>
</html>
房间管理系统 #
多房间聊天实现 #
// Room is a named chat room with its own member set.
type Room struct {
ID string
Name string
Clients map[*Client]bool // current members, used as a set; guarded by mutex
mutex sync.RWMutex
}
// RoomManager owns all rooms, keyed by room ID.
type RoomManager struct {
rooms map[string]*Room // room ID -> room; guarded by mutex
mutex sync.RWMutex
}
// NewRoomManager returns a RoomManager with no rooms, ready for CreateRoom.
func NewRoomManager() *RoomManager {
	rm := &RoomManager{}
	rm.rooms = make(map[string]*Room)
	return rm
}
// CreateRoom registers a room with the given id and display name and returns
// it.
//
// If a room with that id already exists, the existing room is returned and
// name is ignored. Replacing it would orphan its current members: their
// client.Room would still point at the old *Room, while JoinRoom (via GetRoom)
// would add newcomers to the new one, so the two groups could no longer see
// each other.
func (rm *RoomManager) CreateRoom(id, name string) *Room {
	rm.mutex.Lock()
	defer rm.mutex.Unlock()
	if existing, ok := rm.rooms[id]; ok {
		return existing
	}
	room := &Room{
		ID:      id,
		Name:    name,
		Clients: make(map[*Client]bool),
	}
	rm.rooms[id] = room
	return room
}
// GetRoom looks up a room by id; the bool reports whether it exists.
func (rm *RoomManager) GetRoom(id string) (*Room, bool) {
	rm.mutex.RLock()
	room, ok := rm.rooms[id]
	rm.mutex.RUnlock()
	return room, ok
}
// JoinRoom moves client into the room with the given id and announces the
// arrival to the room's other members. It returns an error if no such room
// exists.
//
// A client belongs to at most one room. If it is currently in a different
// room, it leaves that room first, which notifies that room's members.
// Otherwise it would stay in the old room's Clients map and keep receiving
// that room's messages even though client.Room no longer points at it.
func (rm *RoomManager) JoinRoom(client *Client, roomID string) error {
	room, exists := rm.GetRoom(roomID)
	if !exists {
		return fmt.Errorf("房间不存在: %s", roomID)
	}

	// Leave before taking room.mutex so two room locks are never held at once.
	if client.Room != nil && client.Room != room {
		rm.LeaveRoom(client)
	}

	room.mutex.Lock()
	defer room.mutex.Unlock()
	room.Clients[client] = true
	client.Room = room

	// Notify everyone in the room except the newcomer.
	room.broadcast(Message{
		Type:      MessageTypeJoin,
		From:      "system",
		Content:   fmt.Sprintf("%s 加入了房间", client.Username),
		Timestamp: time.Now(),
	}, client)
	return nil
}
// LeaveRoom removes client from its current room, clears client.Room and tells
// the remaining members that it left. It is a no-op when client.Room is nil.
func (rm *RoomManager) LeaveRoom(client *Client) {
	room := client.Room
	if room == nil {
		return
	}

	room.mutex.Lock()
	defer room.mutex.Unlock()
	delete(room.Clients, client)
	client.Room = nil

	room.broadcast(Message{
		Type:      MessageTypeLeave,
		From:      "system",
		Content:   fmt.Sprintf("%s 离开了房间", client.Username),
		Timestamp: time.Now(),
	}, nil)
}
// broadcast queues message on the Send channel of every room member except
// exclude (nil excludes nobody). Sends never block: a member whose Send buffer
// is full is removed from this room. Callers must hold r.mutex for writing,
// because the member map is modified.
//
// The member's Send channel is deliberately NOT closed here. It belongs to the
// Hub, which closes it when the client is unregistered (or dropped by
// Hub.broadcastToAll) while the client is still in hub.clients. Closing it here
// as well would make that later close panic with "close of closed channel".
func (r *Room) broadcast(message Message, exclude *Client) {
	for client := range r.Clients {
		if client == exclude {
			continue
		}
		select {
		case client.Send <- message:
		default:
			// Slow consumer: drop it from this room only.
			delete(r.Clients, client)
		}
	}
}
// GetClientCount reports how many members the room currently has.
func (r *Room) GetClientCount() int {
	r.mutex.RLock()
	n := len(r.Clients)
	r.mutex.RUnlock()
	return n
}
// Updated Client definition for the room system: the earlier Client plus a
// Room field. It replaces the earlier type, since Go allows only one Client
// declaration per package.
type Client struct {
ID string
Conn *websocket.Conn
Send chan Message
Hub *Hub
Room *Room // current room, or nil; set by RoomManager.JoinRoom, cleared by LeaveRoom
Username string
}
性能优化 #
连接池管理 #
// ConnectionPool caps the number of concurrently active clients and keeps a
// queue of clients waiting for a free slot.
type ConnectionPool struct {
maxConnections int           // upper bound on activeConns
activeConns int              // clients currently admitted
waitingQueue chan *Client    // clients queued by TryAccept, promoted by Release
mutex sync.Mutex             // serialises admission and release decisions
}
// NewConnectionPool returns a pool that admits up to maxConnections active
// clients and can queue up to 100 more.
func NewConnectionPool(maxConnections int) *ConnectionPool {
	const queueCapacity = 100
	cp := &ConnectionPool{maxConnections: maxConnections}
	cp.waitingQueue = make(chan *Client, queueCapacity)
	return cp
}
// TryAccept reports whether client may become active now; when it may, the
// active count is incremented.
//
// When the pool is full it returns false. The client is also put on the
// waiting queue if there is room (Release later registers it with its hub);
// if the queue is full too, the client is simply turned away. Both outcomes
// return false, so callers cannot tell them apart.
// NOTE(review): consider reporting "queued" and "rejected" distinctly.
func (cp *ConnectionPool) TryAccept(client *Client) bool {
	cp.mutex.Lock()
	defer cp.mutex.Unlock()

	if cp.activeConns >= cp.maxConnections {
		// Best-effort enqueue; a full queue means rejection.
		select {
		case cp.waitingQueue <- client:
		default:
		}
		return false
	}

	cp.activeConns++
	return true
}
// Release marks one active connection as finished. If a client is waiting in
// the queue, it takes over the freed slot (the active count goes back up). It
// is then registered with its hub on a new goroutine, so Release itself never
// blocks on the hub.
//
// The active count never drops below zero. An unmatched Release (without a
// preceding successful TryAccept) would otherwise let the pool admit more than
// maxConnections clients.
//
// NOTE(review): a promoted client is only registered with the hub; nothing
// here starts its readPump/writePump. Confirm the caller handles that.
func (cp *ConnectionPool) Release() {
	cp.mutex.Lock()
	defer cp.mutex.Unlock()

	if cp.activeConns > 0 {
		cp.activeConns--
	}

	select {
	case client := <-cp.waitingQueue:
		cp.activeConns++
		go func() {
			client.Hub.register <- client
		}()
	default:
		// Nobody is waiting.
	}
}
消息压缩 #
import (
"compress/flate"
)
// upgraderWithCompression is configured like upgrader (1 KiB buffers, every
// Origin accepted) but also tries to negotiate per-message compression during
// the handshake.
// NOTE(review): accepting any Origin allows cross-site WebSocket hijacking.
var upgraderWithCompression = websocket.Upgrader{
	EnableCompression: true,
	ReadBufferSize:    1 << 10,
	WriteBufferSize:   1 << 10,
	CheckOrigin:       func(*http.Request) bool { return true },
}
// wsHandlerWithCompression upgrades the request with upgraderWithCompression,
// enables compression for outgoing messages and selects flate.BestSpeed as
// the compression level.
//
// The connection is closed when the handler returns (it was never closed
// before), so the socket is not leaked however the placeholder below is
// filled in.
func wsHandlerWithCompression(w http.ResponseWriter, r *http.Request) {
	conn, err := upgraderWithCompression.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("WebSocket 升级失败: %v", err)
		return
	}
	defer conn.Close()

	conn.EnableWriteCompression(true)
	// SetCompressionLevel returns an error for an invalid level. Compression
	// stays enabled at its default level if that happens, so just log it.
	if err := conn.SetCompressionLevel(flate.BestSpeed); err != nil {
		log.Printf("设置压缩级别失败: %v", err)
	}

	// Handle the connection here...
}
内存优化 #
import (
"sync"
)
// messagePool recycles *Message values for readPumpOptimized. New supplies a
// zero Message whenever the pool is empty.
var messagePool = sync.Pool{
	New: func() interface{} { return new(Message) },
}
// getMessageFromPool takes a *Message from messagePool: either a recycled value
// reset by putMessageToPool, or a fresh zero value from the pool's New.
func getMessageFromPool() *Message {
	msg := messagePool.Get().(*Message)
	return msg
}
// putMessageToPool zeroes every field of msg, so no previous content leaks to
// the next user, and returns it to messagePool. msg must not be used afterwards.
func putMessageToPool(msg *Message) {
	*msg = Message{}
	messagePool.Put(msg)
}
// readPumpOptimized is readPump with the decode target taken from messagePool.
// Each message is decoded into a pooled *Message, a copy is sent to the hub,
// and the pooled value is returned immediately. On exit it unregisters the
// client and closes the connection.
//
// It applies the same limits as readPump. Without a read limit a peer could
// send arbitrarily large messages. Without the read deadline and pong handler,
// a silent peer would only be noticed once a write on the connection failed.
func (c *Client) readPumpOptimized() {
	const (
		maxMessageSize = 512
		pongWait       = 60 * time.Second
	)
	defer func() {
		c.Hub.unregister <- c
		c.Conn.Close()
	}()

	c.Conn.SetReadLimit(maxMessageSize)
	c.Conn.SetReadDeadline(time.Now().Add(pongWait))
	c.Conn.SetPongHandler(func(string) error {
		c.Conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	for {
		message := getMessageFromPool()
		if err := c.Conn.ReadJSON(message); err != nil {
			putMessageToPool(message)
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("WebSocket 错误: %v", err)
			}
			return
		}
		message.From = c.Username
		message.Timestamp = time.Now()
		// Send a copy: the pooled value is reset and reused right after.
		c.Hub.broadcast <- *message
		putMessageToPool(message)
	}
}
监控和调试 #
WebSocket 监控 #
// WSMonitor aggregates WebSocket connection and message counters. Every
// counter is read and written under mutex, so the methods may be called from
// any goroutine.
type WSMonitor struct {
	totalConnections  int64
	activeConnections int64
	messagesSent      int64
	messagesReceived  int64
	errors            int64
	mutex             sync.RWMutex
}

// bump adds delta to each of the given counters while holding the write lock.
func (m *WSMonitor) bump(delta int64, counters ...*int64) {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	for _, counter := range counters {
		*counter += delta
	}
}

// RecordConnection counts a new connection (both total and active).
func (m *WSMonitor) RecordConnection() {
	m.bump(1, &m.totalConnections, &m.activeConnections)
}

// RecordDisconnection decrements the active connection count.
func (m *WSMonitor) RecordDisconnection() {
	m.bump(-1, &m.activeConnections)
}

// RecordMessageSent counts one outgoing message.
func (m *WSMonitor) RecordMessageSent() {
	m.bump(1, &m.messagesSent)
}

// RecordMessageReceived counts one incoming message.
func (m *WSMonitor) RecordMessageReceived() {
	m.bump(1, &m.messagesReceived)
}

// RecordError counts one error.
func (m *WSMonitor) RecordError() {
	m.bump(1, &m.errors)
}

// GetStats returns a consistent snapshot of all counters, keyed by their
// snake_case names.
func (m *WSMonitor) GetStats() map[string]int64 {
	m.mutex.RLock()
	defer m.mutex.RUnlock()
	stats := make(map[string]int64, 5)
	stats["total_connections"] = m.totalConnections
	stats["active_connections"] = m.activeConnections
	stats["messages_sent"] = m.messagesSent
	stats["messages_received"] = m.messagesReceived
	stats["errors"] = m.errors
	return stats
}
// wsMonitor is the process-wide WSMonitor read by the monitoring and health
// endpoints.
var wsMonitor = new(WSMonitor)
// monitorHandler serves a snapshot of the wsMonitor counters as a JSON object.
func monitorHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(wsMonitor.GetStats())
}
健康检查 #
// healthCheckHandler reports service health as JSON. It includes the current
// time (RFC 3339) and the active/total connection counts. "status" is
// "healthy" unless more than 1000 connections are active; then it becomes
// "warning" and a "message" field is added. The HTTP status is always 200.
func healthCheckHandler(w http.ResponseWriter, r *http.Request) {
	const activeConnWarnThreshold = 1000

	stats := wsMonitor.GetStats()
	active := stats["active_connections"]

	health := map[string]interface{}{
		"status":    "healthy",
		"timestamp": time.Now().Format(time.RFC3339),
		"websocket": map[string]interface{}{
			"active_connections": active,
			"total_connections":  stats["total_connections"],
		},
	}
	if active > activeConnWarnThreshold {
		health["status"] = "warning"
		health["message"] = "高连接数"
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(health)
}
小结 #
本节详细介绍了 Go 语言中的 WebSocket 服务器开发,包括:
- WebSocket 基础 - 协议特点和握手过程
- 基础服务器 - 简单的 WebSocket 服务器实现
- 高级服务器 - 连接管理、消息处理和广播机制
- 聊天室实现 - 完整的实时聊天应用
- 房间管理 - 多房间聊天系统
- 性能优化 - 连接池、压缩、内存优化等技术
- 监控调试 - 连接监控和健康检查
掌握这些 WebSocket 开发技术后,你就能够构建高性能的实时通信应用。至此,我们完成了网络编程基础章节的学习,为后续的高级网络编程和系统编程打下了坚实的基础。