# 3.8.3 WebSocket 客户端开发
WebSocket 客户端是实时通信系统的重要组成部分,负责与服务端建立连接、发送接收消息、处理连接状态等。本节将详细介绍如何使用 Go 语言开发功能完整、稳定可靠的 WebSocket 客户端。
## 基础客户端实现
### 简单的 WebSocket 客户端
// cmd/client/main.go
package main
import (
"bufio"
"fmt"
"log"
"net/url"
"os"
"os/signal"
"time"
"github.com/gorilla/websocket"
)
// main dials the demo WebSocket server, prints every message it receives and
// forwards each line typed on stdin as a text message. Typing "quit" (or
// ending stdin) starts the close handshake; Ctrl+C does the same and exits
// after waiting at most one second for the server to close.
//
// gorilla/websocket supports only one concurrent writer per connection, so
// every write (text frames and close frames alike) is issued from the select
// loop in this goroutine; the stdin goroutine only hands lines over.
func main() {
	u := url.URL{Scheme: "ws", Host: "localhost:8080", Path: "/ws"}
	log.Printf("连接到 %s", u.String())

	conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
	if err != nil {
		log.Fatal("连接失败:", err)
	}
	defer conn.Close()

	interrupt := make(chan os.Signal, 1)
	signal.Notify(interrupt, os.Interrupt)

	// done is closed when the read loop stops (peer closed or a read failed).
	done := make(chan struct{})
	go func() {
		defer close(done)
		for {
			_, message, err := conn.ReadMessage()
			if err != nil {
				log.Println("读取消息失败:", err)
				return
			}
			fmt.Printf("收到消息: %s\n", message)
		}
	}()

	// input carries stdin lines to the writer loop; it is closed on "quit"
	// or end of input. This goroutine never touches conn.
	input := make(chan string)
	go func() {
		defer close(input)
		scanner := bufio.NewScanner(os.Stdin)
		fmt.Println("请输入消息 (输入 'quit' 退出):")
		for scanner.Scan() {
			text := scanner.Text()
			if text == "quit" {
				return
			}
			select {
			case input <- text:
			case <-done:
				return
			}
		}
		if err := scanner.Err(); err != nil {
			log.Println("读取输入失败:", err)
		}
	}()

	closeMsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
	// lines is set to nil once we stop sending; a nil channel never fires.
	var lines <-chan string = input
	for {
		select {
		case <-done:
			log.Println("连接已关闭")
			return
		case text, ok := <-lines:
			if !ok {
				// "quit" or EOF: send the close frame and keep waiting for the
				// server to close the connection (or for Ctrl+C).
				if err := conn.WriteMessage(websocket.CloseMessage, closeMsg); err != nil {
					log.Println("发送关闭消息失败:", err)
				}
				lines = nil
				continue
			}
			if err := conn.WriteMessage(websocket.TextMessage, []byte(text)); err != nil {
				log.Println("发送消息失败:", err)
				lines = nil // stop sending after a failed write
			}
		case <-interrupt:
			log.Println("收到中断信号")
			if err := conn.WriteMessage(websocket.CloseMessage, closeMsg); err != nil {
				log.Println("发送关闭消息失败:", err)
				return
			}
			select {
			case <-done:
			case <-time.After(time.Second):
			}
			return
		}
	}
}
## 高级客户端实现
### 客户端结构设计
// internal/client/client.go
package client
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"sync"
"time"
"github.com/gorilla/websocket"
)
// ConnectionState describes where a Client is in its connection lifecycle.
type ConnectionState int

// Lifecycle states of a Client.
const (
	StateDisconnected ConnectionState = iota
	StateConnecting
	StateConnected
	StateReconnecting
)

// stateNames holds the printable label of each known state, indexed by value.
var stateNames = [...]string{
	StateDisconnected: "DISCONNECTED",
	StateConnecting:   "CONNECTING",
	StateConnected:    "CONNECTED",
	StateReconnecting: "RECONNECTING",
}

// String implements fmt.Stringer; values outside the known range print as "UNKNOWN".
func (s ConnectionState) String() string {
	if s < 0 || int(s) >= len(stateNames) {
		return "UNKNOWN"
	}
	return stateNames[s]
}
// Message is the JSON envelope exchanged with the chat server. Type selects
// the kind of event; From/To/Room are left out of the JSON when empty.
//
// NOTE(review): `omitempty` has no effect on a time.Time (struct) field, so a
// zero Timestamp is still encoded as "0001-01-01T00:00:00Z".
type Message struct {
	Type      string      `json:"type"`
	From      string      `json:"from,omitempty"`
	To        string      `json:"to,omitempty"`
	Room      string      `json:"room,omitempty"`
	Content   interface{} `json:"content"`
	Timestamp time.Time   `json:"timestamp,omitempty"`
}

// EventHandler receives a Client's lifecycle and message callbacks.
type EventHandler interface {
	// OnConnect is called after a connection has been established.
	OnConnect()
	// OnDisconnect is called when a connection ends; err describes why.
	OnDisconnect(err error)
	// OnMessage is called for every incoming message that decoded successfully.
	OnMessage(msg *Message)
	// OnError reports non-fatal errors such as failed reconnect attempts.
	OnError(err error)
}
// Config holds the tunables of a Client.
//
// Non-positive PingInterval, WriteTimeout and ReadTimeout disable the ping
// timer or the corresponding deadline. A MaxReconnectTries <= 0 means
// unlimited reconnect attempts.
// NOTE(review): PongTimeout is not read anywhere in this file.
type Config struct {
	URL               string
	Headers           http.Header
	ReconnectInterval time.Duration
	MaxReconnectTries int
	PingInterval      time.Duration
	PongTimeout       time.Duration
	WriteTimeout      time.Duration
	ReadTimeout       time.Duration
	MessageBufferSize int
	EnableCompression bool
}

// DefaultConfig returns a fresh Config with sensible defaults. URL and
// Headers are left empty for the caller to fill in.
func DefaultConfig() *Config {
	cfg := new(Config)
	cfg.ReconnectInterval = 5 * time.Second
	cfg.MaxReconnectTries = 10
	cfg.PingInterval = 30 * time.Second
	cfg.PongTimeout = 10 * time.Second
	cfg.WriteTimeout = 10 * time.Second
	cfg.ReadTimeout = 60 * time.Second
	cfg.MessageBufferSize = 256
	cfg.EnableCompression = true
	return cfg
}
// Client is a reconnecting WebSocket client that exchanges JSON Messages.
// Build it with NewClient; the zero value has no config and is not usable.
type Client struct {
	config *Config

	// conn is the current connection; connect replaces it on every
	// successful (re)connect.
	conn *websocket.Conn

	// state is guarded by stateMu; access it through GetState/setState.
	state   ConnectionState
	stateMu sync.RWMutex

	// eventHandler receives lifecycle and message callbacks; may be nil.
	eventHandler EventHandler

	// sendChan buffers encoded outbound messages for writePump.
	sendChan chan []byte

	// NOTE(review): closeChan and reconnectChan are created by NewClient but
	// never used elsewhere in this file; candidates for removal.
	closeChan     chan struct{}
	reconnectChan chan struct{}

	// reconnectCount counts reconnect attempts since the last successful connect.
	reconnectCount int

	// ctx is cancelled by Disconnect to stop the pumps and reconnection.
	ctx    context.Context
	cancel context.CancelFunc

	// wg tracks the running read and write pumps.
	wg sync.WaitGroup
}
// NewClient builds a disconnected Client that reports events to handler
// (which may be nil). A nil config falls back to DefaultConfig().
//
// A non-positive MessageBufferSize also falls back to the default buffer
// size. make panics on a negative size. With an unbuffered channel,
// SendMessage (which never blocks) would only succeed while writePump
// happens to be waiting on it.
func NewClient(config *Config, handler EventHandler) *Client {
	if config == nil {
		config = DefaultConfig()
	}

	bufSize := config.MessageBufferSize
	if bufSize <= 0 {
		bufSize = DefaultConfig().MessageBufferSize
	}

	ctx, cancel := context.WithCancel(context.Background())
	return &Client{
		config:        config,
		state:         StateDisconnected,
		eventHandler:  handler,
		sendChan:      make(chan []byte, bufSize),
		closeChan:     make(chan struct{}),
		reconnectChan: make(chan struct{}),
		ctx:           ctx,
		cancel:        cancel,
	}
}
// Connect dials the configured server once. It returns an error if the
// client is not in StateDisconnected (already connected, connecting or
// reconnecting). Otherwise it moves to StateConnecting and delegates to connect.
//
// NOTE(review): the client context is created once in NewClient and
// cancelled by Disconnect, so reusing a Client after Disconnect is
// presumably unsupported — confirm.
func (c *Client) Connect() error {
	c.stateMu.Lock()
	busy := c.state != StateDisconnected
	if !busy {
		c.state = StateConnecting
	}
	c.stateMu.Unlock()

	if busy {
		return fmt.Errorf("客户端已连接或正在连接")
	}
	return c.connect()
}
// connect dials c.config.URL. On success it installs the connection, marks
// the client StateConnected and resets the reconnect counter. It then starts
// the read and write pumps and fires OnConnect. On any failure the state
// returns to StateDisconnected and a wrapped error is returned.
//
// The dial is bound to c.ctx, which Disconnect cancels. Shutting the client
// down therefore also aborts a dial (or reconnect) that is still in progress,
// instead of leaving a fresh connection open after shutdown.
func (c *Client) connect() error {
	u, err := url.Parse(c.config.URL)
	if err != nil {
		c.setState(StateDisconnected)
		return fmt.Errorf("无效的 URL: %w", err)
	}

	dialer := websocket.Dialer{
		HandshakeTimeout:  10 * time.Second,
		EnableCompression: c.config.EnableCompression,
	}
	conn, _, err := dialer.DialContext(c.ctx, u.String(), c.config.Headers)
	if err != nil {
		c.setState(StateDisconnected)
		return fmt.Errorf("连接失败: %w", err)
	}

	c.conn = conn
	c.setState(StateConnected)
	c.reconnectCount = 0

	c.wg.Add(2)
	go c.readPump()
	go c.writePump()

	if c.eventHandler != nil {
		c.eventHandler.OnConnect()
	}

	log.Printf("WebSocket 连接已建立: %s", u.String())
	return nil
}
// Disconnect shuts the client down, then marks it StateDisconnected. It
// cancels the client context, which stops the pumps and any pending
// reconnect. It then sends a normal-closure close frame, closes the socket
// and waits for the pumps to exit. It does nothing if the client is already
// disconnected, and always returns nil.
//
// The close frame is sent with WriteControl because gorilla/websocket allows
// WriteControl and Close to run concurrently with other writes. A
// WriteMessage here could race with writePump, which is the connection's
// regular writer.
func (c *Client) Disconnect() error {
	if c.GetState() == StateDisconnected {
		return nil
	}

	c.cancel()

	if c.conn != nil {
		timeout := c.config.WriteTimeout
		if timeout <= 0 {
			timeout = time.Second
		}
		closeFrame := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
		// An error is expected if writePump already sent its own close frame
		// or the socket is already gone; the close is best-effort.
		_ = c.conn.WriteControl(websocket.CloseMessage, closeFrame, time.Now().Add(timeout))
		_ = c.conn.Close()
	}

	c.wg.Wait()
	c.setState(StateDisconnected)
	log.Println("WebSocket 连接已断开")
	return nil
}
// SendMessage JSON-encodes msg and hands it to writePump without blocking.
// It fails when the client is not connected, when the message cannot be
// encoded, when the client has been shut down, or when the send buffer is full.
func (c *Client) SendMessage(msg *Message) error {
	if state := c.GetState(); state != StateConnected {
		return fmt.Errorf("客户端未连接")
	}

	payload, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("消息序列化失败: %v", err)
	}

	select {
	case <-c.ctx.Done():
		return fmt.Errorf("客户端已关闭")
	case c.sendChan <- payload:
		return nil
	default:
		return fmt.Errorf("发送缓冲区已满")
	}
}
// SendText sends a "chat" message whose content is text.
func (c *Client) SendText(text string) error {
	return c.SendMessage(&Message{Type: "chat", Content: text})
}

// JoinRoom sends a "join_room" request for room.
func (c *Client) JoinRoom(room string) error {
	return c.SendMessage(&Message{Type: "join_room", Room: room})
}

// LeaveRoom sends a "leave_room" request for room.
func (c *Client) LeaveRoom(room string) error {
	return c.SendMessage(&Message{Type: "leave_room", Room: room})
}

// SendPrivateMessage sends a "private_message" addressed to the user to.
func (c *Client) SendPrivateMessage(to, content string) error {
	return c.SendMessage(&Message{Type: "private_message", To: to, Content: content})
}
// GetState returns the current connection state; safe for concurrent use.
func (c *Client) GetState() ConnectionState {
	c.stateMu.RLock()
	current := c.state
	c.stateMu.RUnlock()
	return current
}

// setState records a new connection state under the state lock.
func (c *Client) setState(state ConnectionState) {
	c.stateMu.Lock()
	defer c.stateMu.Unlock()
	c.state = state
}
// readPump is the single reader for the connection installed by connect. It
// decodes every incoming frame as a JSON Message and passes it to OnMessage.
// Frames that fail to decode are logged and skipped. It returns when c.ctx
// is cancelled or a read fails. After a read failure it closes its
// connection and reports the failure through handleDisconnect, which may
// reconnect before this goroutine releases its wait-group slot.
//
// The connection is captured once at start. handleDisconnect can install a
// replacement in c.conn before this function returns, and the deferred
// cleanup must close this goroutine's own connection, not the new one.
func (c *Client) readPump() {
	defer c.wg.Done()

	conn := c.conn
	var readErr error
	defer func() {
		conn.Close()
		if readErr != nil {
			c.handleDisconnect(readErr)
		}
	}()

	// refreshDeadline pushes the read deadline ReadTimeout into the future;
	// a non-positive ReadTimeout leaves reads without a deadline.
	refreshDeadline := func() {
		if c.config.ReadTimeout > 0 {
			conn.SetReadDeadline(time.Now().Add(c.config.ReadTimeout))
		}
	}
	refreshDeadline()
	conn.SetPongHandler(func(string) error {
		refreshDeadline()
		return nil
	})

	for {
		select {
		case <-c.ctx.Done():
			return
		default:
		}

		_, payload, err := conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("WebSocket 读取错误: %v", err)
			}
			readErr = err
			return
		}
		// Any inbound frame proves the peer is alive, so it also extends the
		// deadline. Otherwise, with pings disabled, an active peer would still
		// time out after ReadTimeout.
		refreshDeadline()

		var msg Message
		if err := json.Unmarshal(payload, &msg); err != nil {
			log.Printf("消息解析失败: %v", err)
			continue
		}
		if c.eventHandler != nil {
			c.eventHandler.OnMessage(&msg)
		}
	}
}
// writePump is the regular writer for the connection installed by connect.
// It drains sendChan and sends a ping every PingInterval (when positive).
// Once c.ctx is cancelled, it sends a normal-closure close frame and exits.
//
// On a write error it only closes its connection and returns. Closing the
// connection makes readPump's ReadMessage fail, and readPump is the single
// place that reports the disconnect. If both pumps reported it, OnDisconnect
// would fire twice and two reconnects would start for a single drop.
func (c *Client) writePump() {
	defer c.wg.Done()

	// A reconnect may replace c.conn while this goroutine is still running;
	// it must never write to or close the replacement.
	conn := c.conn
	defer conn.Close()

	// A nil channel never fires, so pings are disabled cleanly when
	// PingInterval <= 0 (there is no ticker to read from in that case).
	var pingC <-chan time.Time
	if c.config.PingInterval > 0 {
		ticker := time.NewTicker(c.config.PingInterval)
		defer ticker.Stop()
		pingC = ticker.C
	}

	setDeadline := func() {
		if c.config.WriteTimeout > 0 {
			conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout))
		}
	}

	for {
		select {
		case message := <-c.sendChan:
			setDeadline()
			if err := conn.WriteMessage(websocket.TextMessage, message); err != nil {
				log.Printf("发送消息失败: %v", err)
				return
			}
		case <-pingC:
			setDeadline()
			if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				log.Printf("发送 Ping 失败: %v", err)
				return
			}
		case <-c.ctx.Done():
			setDeadline()
			conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
			return
		}
	}
}
// handleDisconnect reacts to a lost connection. It marks the client
// disconnected and fires OnDisconnect(err). Unless the client is being shut
// down (c.ctx cancelled), it then tries to reconnect.
//
// Only the caller that moves the state away from StateConnected does this
// work. A second report for the same drop (e.g. from the other pump) is
// ignored, so OnDisconnect fires once and at most one reconnect runs.
func (c *Client) handleDisconnect(err error) {
	c.stateMu.Lock()
	if c.state != StateConnected {
		c.stateMu.Unlock()
		return
	}
	c.state = StateDisconnected
	c.stateMu.Unlock()

	if c.eventHandler != nil {
		c.eventHandler.OnDisconnect(err)
	}

	select {
	case <-c.ctx.Done():
		// Intentional shutdown: do not reconnect.
		return
	default:
		c.tryReconnect()
	}
}
// tryReconnect re-dials the server until it succeeds, the retry budget is
// exhausted or the client is shut down. MaxReconnectTries <= 0 means
// unlimited retries. It waits ReconnectInterval before every attempt and
// reports each failed attempt through OnError.
//
// It loops in the calling goroutine (a pump still counted in c.wg), so
// Disconnect's wg.Wait also waits for it. Spawning an untracked goroutine
// per retry would let a retry dial after Disconnect has returned.
func (c *Client) tryReconnect() {
	for {
		if c.config.MaxReconnectTries > 0 && c.reconnectCount >= c.config.MaxReconnectTries {
			log.Printf("达到最大重连次数 (%d),停止重连", c.config.MaxReconnectTries)
			return
		}

		c.setState(StateReconnecting)
		c.reconnectCount++
		log.Printf("尝试重连 (%d/%d)...", c.reconnectCount, c.config.MaxReconnectTries)

		select {
		case <-time.After(c.config.ReconnectInterval):
		case <-c.ctx.Done():
			return
		}
		// select picks at random when both cases are ready; re-check so a
		// cancelled client never dials again.
		if c.ctx.Err() != nil {
			return
		}

		err := c.connect()
		if err == nil {
			return
		}
		log.Printf("重连失败: %v", err)
		if c.eventHandler != nil {
			c.eventHandler.OnError(err)
		}
	}
}
### 事件处理器实现
// internal/client/handler.go
package client
import (
"fmt"
"log"
"time"
)
// DefaultEventHandler logs every event and then forwards it to the matching
// optional callback. Nil callbacks are skipped, so the zero value is usable.
type DefaultEventHandler struct {
	OnConnectFunc    func()
	OnDisconnectFunc func(error)
	OnMessageFunc    func(*Message)
	OnErrorFunc      func(error)
}

// OnConnect logs the new connection and calls OnConnectFunc if set.
func (h *DefaultEventHandler) OnConnect() {
	log.Println("WebSocket 连接已建立")
	if cb := h.OnConnectFunc; cb != nil {
		cb()
	}
}

// OnDisconnect logs the disconnect (a nil err is logged as a normal close)
// and calls OnDisconnectFunc if set.
func (h *DefaultEventHandler) OnDisconnect(err error) {
	switch {
	case err != nil:
		log.Printf("WebSocket 连接断开: %v", err)
	default:
		log.Println("WebSocket 连接正常断开")
	}
	if cb := h.OnDisconnectFunc; cb != nil {
		cb(err)
	}
}

// OnMessage logs msg with field names and calls OnMessageFunc if set.
func (h *DefaultEventHandler) OnMessage(msg *Message) {
	log.Printf("收到消息: %+v", msg)
	if cb := h.OnMessageFunc; cb != nil {
		cb(msg)
	}
}

// OnError logs err and calls OnErrorFunc if set.
func (h *DefaultEventHandler) OnError(err error) {
	log.Printf("WebSocket 错误: %v", err)
	if cb := h.OnErrorFunc; cb != nil {
		cb(err)
	}
}
// ChatEventHandler prints chat events to stdout and tracks the rooms the
// server has confirmed for this user.
type ChatEventHandler struct {
	UserID   string
	Username string
	// Rooms maps room name to true for rooms currently joined; it is updated
	// by handleRoomJoined and handleRoomLeft.
	Rooms map[string]bool
}

// NewChatEventHandler returns a handler for the given user with no rooms joined.
func NewChatEventHandler(userID, username string) *ChatEventHandler {
	h := new(ChatEventHandler)
	h.UserID, h.Username = userID, username
	h.Rooms = map[string]bool{}
	return h
}
// OnConnect announces a successful connection for this user.
func (h *ChatEventHandler) OnConnect() {
	fmt.Printf("✅ 连接成功!用户: %s\n", h.Username)
}

// OnDisconnect prints why the connection ended; a nil err is shown as a normal close.
func (h *ChatEventHandler) OnDisconnect(err error) {
	if err == nil {
		fmt.Println("👋 连接已关闭")
		return
	}
	fmt.Printf("❌ 连接断开: %v\n", err)
}

// chatMessageHandlers routes a server message type to the ChatEventHandler
// method that renders it; OnMessage prints anything else as unknown.
var chatMessageHandlers = map[string]func(*ChatEventHandler, *Message){
	"chat":            (*ChatEventHandler).handleChatMessage,
	"system":          (*ChatEventHandler).handleSystemMessage,
	"user_joined":     (*ChatEventHandler).handleUserJoined,
	"user_left":       (*ChatEventHandler).handleUserLeft,
	"private_message": (*ChatEventHandler).handlePrivateMessage,
	"room_joined":     (*ChatEventHandler).handleRoomJoined,
	"room_left":       (*ChatEventHandler).handleRoomLeft,
	"error":           (*ChatEventHandler).handleError,
}

// OnMessage dispatches msg by its Type.
func (h *ChatEventHandler) OnMessage(msg *Message) {
	handle, known := chatMessageHandlers[msg.Type]
	if !known {
		fmt.Printf("🔍 未知消息类型: %s, 内容: %v\n", msg.Type, msg.Content)
		return
	}
	handle(h, msg)
}

// OnError prints a client-side error.
func (h *ChatEventHandler) OnError(err error) {
	fmt.Printf("⚠️ 错误: %v\n", err)
}
// handleChatMessage prints "[HH:MM:SS] [room] from: content", or the same
// line without the room part when the message has no room. It uses the
// message's own timestamp when set and the local clock otherwise.
func (h *ChatEventHandler) handleChatMessage(msg *Message) {
	when := msg.Timestamp
	if when.IsZero() {
		when = time.Now()
	}
	clock := when.Format("15:04:05")

	if msg.Room == "" {
		fmt.Printf("[%s] %s: %v\n", clock, msg.From, msg.Content)
		return
	}
	fmt.Printf("[%s] [%s] %s: %v\n", clock, msg.Room, msg.From, msg.Content)
}
// handleSystemMessage prints a server notice.
func (h *ChatEventHandler) handleSystemMessage(msg *Message) {
	content := msg.Content
	fmt.Printf("🔔 系统消息: %v\n", content)
}

// handleUserJoined announces that msg.From entered msg.Room.
func (h *ChatEventHandler) handleUserJoined(msg *Message) {
	who, room := msg.From, msg.Room
	fmt.Printf("👤 %s 加入了房间 %s\n", who, room)
}

// handleUserLeft announces that msg.From left msg.Room.
func (h *ChatEventHandler) handleUserLeft(msg *Message) {
	who, room := msg.From, msg.Room
	fmt.Printf("👤 %s 离开了房间 %s\n", who, room)
}
// handlePrivateMessage prints "[HH:MM:SS] 💬 from (私聊): content".
// Like handleChatMessage, it shows the message's own timestamp when the
// server supplied one and falls back to the local clock otherwise.
func (h *ChatEventHandler) handlePrivateMessage(msg *Message) {
	sent := msg.Timestamp
	if sent.IsZero() {
		sent = time.Now()
	}
	fmt.Printf("[%s] 💬 %s (私聊): %v\n", sent.Format("15:04:05"), msg.From, msg.Content)
}
// handleRoomJoined records msg.Room as joined and confirms it to the user.
// NOTE(review): Rooms is mutated from whichever goroutine delivers events
// and is not locked; confirm no other goroutine reads it concurrently.
func (h *ChatEventHandler) handleRoomJoined(msg *Message) {
	room := msg.Room
	h.Rooms[room] = true
	fmt.Printf("✅ 成功加入房间: %s\n", room)
}

// handleRoomLeft forgets msg.Room and confirms the departure.
func (h *ChatEventHandler) handleRoomLeft(msg *Message) {
	room := msg.Room
	delete(h.Rooms, room)
	fmt.Printf("👋 已离开房间: %s\n", room)
}

// handleError prints an error reported by the server.
func (h *ChatEventHandler) handleError(msg *Message) {
	detail := msg.Content
	fmt.Printf("❌ 服务器错误: %v\n", detail)
}
## 客户端管理器
### 连接池管理
// internal/client/pool.go
package client
import (
"fmt"
"sync"
"time"
)
// ConnectionPool keeps named Clients, at most maxClients of them, behind a
// read-write mutex.
//
// NOTE(review): this file imports "time" but nothing in it uses time, which
// the Go compiler rejects ("imported and not used"); that import should go.
type ConnectionPool struct {
	clients    map[string]*Client
	mu         sync.RWMutex
	maxClients int
}

// NewConnectionPool returns an empty pool that admits at most maxClients clients.
func NewConnectionPool(maxClients int) *ConnectionPool {
	p := &ConnectionPool{maxClients: maxClients}
	p.clients = make(map[string]*Client)
	return p
}
// GetClient returns the client registered under id. If there is none, it
// creates one with NewClient(config, handler), registers it and returns it;
// the new client is not connected yet. When id already exists, config and
// handler are ignored. Creation fails once the pool holds maxClients
// clients, so a pool with maxClients <= 0 never creates any.
func (p *ConnectionPool) GetClient(id string, config *Config, handler EventHandler) (*Client, error) {
	p.mu.Lock()
	defer p.mu.Unlock()

	existing, ok := p.clients[id]
	switch {
	case ok:
		return existing, nil
	case len(p.clients) >= p.maxClients:
		return nil, fmt.Errorf("连接池已满,最大连接数: %d", p.maxClients)
	}

	created := NewClient(config, handler)
	p.clients[id] = created
	return created, nil
}
// RemoveClient unregisters the client stored under id, if any, and then
// disconnects it. Disconnect waits for the client's goroutines, so it runs
// after the pool lock is released. Otherwise one slow shutdown would block
// every other pool operation.
func (p *ConnectionPool) RemoveClient(id string) {
	p.mu.Lock()
	client, exists := p.clients[id]
	delete(p.clients, id)
	p.mu.Unlock()

	if exists {
		client.Disconnect()
	}
}
// GetAllClients returns a copy of the id-to-client map; modifying the
// returned map does not affect the pool.
func (p *ConnectionPool) GetAllClients() map[string]*Client {
	p.mu.RLock()
	defer p.mu.RUnlock()

	snapshot := make(map[string]*Client, len(p.clients))
	for id := range p.clients {
		snapshot[id] = p.clients[id]
	}
	return snapshot
}
// Broadcast queues msg on every client currently in StateConnected.
// Per-client send errors (e.g. a full send buffer) are ignored.
func (p *ConnectionPool) Broadcast(msg *Message) {
	p.mu.RLock()
	defer p.mu.RUnlock()

	for _, cl := range p.clients {
		if cl.GetState() != StateConnected {
			continue
		}
		_ = cl.SendMessage(msg)
	}
}
// CloseAll empties the pool and disconnects every client that was in it.
// The map is swapped out under the lock and the (blocking) Disconnect calls
// run after it is released, so other pool operations are not stalled.
func (p *ConnectionPool) CloseAll() {
	p.mu.Lock()
	closing := p.clients
	p.clients = make(map[string]*Client)
	p.mu.Unlock()

	for _, client := range closing {
		client.Disconnect()
	}
}
// GetStats reports the pool size, its capacity and how many clients are in
// each connection state (keyed by ConnectionState.String()).
func (p *ConnectionPool) GetStats() map[string]interface{} {
	p.mu.RLock()
	defer p.mu.RUnlock()

	byState := make(map[string]int)
	for _, cl := range p.clients {
		byState[cl.GetState().String()]++
	}

	return map[string]interface{}{
		"total_clients": len(p.clients),
		"max_clients":   p.maxClients,
		"state_count":   byState,
	}
}
### 重连策略
// internal/client/reconnect.go
package client
import (
	"fmt"
	"math"
	"math/rand"
	"time"
)
// ReconnectStrategy decides whether another reconnect attempt should be made
// and how long to wait before it. attempt is the 0-based attempt index.
type ReconnectStrategy interface {
	NextDelay(attempt int) time.Duration
	ShouldReconnect(attempt int) bool
}

// FixedIntervalStrategy waits the same Interval before every attempt.
// MaxRetries <= 0 allows unlimited attempts.
type FixedIntervalStrategy struct {
	Interval   time.Duration
	MaxRetries int
}

// NextDelay always returns s.Interval; the attempt index is irrelevant.
func (s *FixedIntervalStrategy) NextDelay(_ int) time.Duration { return s.Interval }

// ShouldReconnect reports whether attempt is still within the retry budget.
func (s *FixedIntervalStrategy) ShouldReconnect(attempt int) bool {
	if s.MaxRetries > 0 {
		return attempt < s.MaxRetries
	}
	return true
}
// ExponentialBackoffStrategy waits InitialDelay * Multiplier^attempt before
// each attempt, capped at MaxDelay, optionally plus up to 10% random jitter.
// MaxRetries <= 0 allows unlimited attempts.
type ExponentialBackoffStrategy struct {
	InitialDelay time.Duration
	MaxDelay     time.Duration // <= 0 means no cap
	Multiplier   float64
	MaxRetries   int
	Jitter       bool
}

// NextDelay returns the wait before the given 0-based attempt.
//
// A non-positive MaxDelay is treated as "no cap" rather than a cap of zero.
// A zero cap would turn a strategy that merely left MaxDelay unset into a
// zero-delay retry loop. Jitter is added after capping, so a jittered delay
// may exceed MaxDelay by up to 10%. The result is clamped to the largest
// Duration, because converting an out-of-range float64 to an integer type is
// implementation-specific in Go. Such values include +Inf once
// Multiplier^attempt overflows.
func (s *ExponentialBackoffStrategy) NextDelay(attempt int) time.Duration {
	delay := float64(s.InitialDelay) * math.Pow(s.Multiplier, float64(attempt))
	if s.MaxDelay > 0 && delay > float64(s.MaxDelay) {
		delay = float64(s.MaxDelay)
	}

	if s.Jitter {
		delay += rand.Float64() * 0.1 * delay // up to 10% extra
	}

	switch {
	case math.IsNaN(delay): // e.g. 0 * +Inf
		return 0
	case delay >= float64(math.MaxInt64):
		return time.Duration(math.MaxInt64)
	}
	return time.Duration(delay)
}

// ShouldReconnect reports whether attempt is still within the retry budget.
func (s *ExponentialBackoffStrategy) ShouldReconnect(attempt int) bool {
	return s.MaxRetries <= 0 || attempt < s.MaxRetries
}
// LinearBackoffStrategy waits InitialDelay + attempt*Increment before each
// attempt, capped at MaxDelay. MaxRetries <= 0 allows unlimited attempts.
type LinearBackoffStrategy struct {
	InitialDelay time.Duration
	MaxDelay     time.Duration // <= 0 means no cap
	Increment    time.Duration
	MaxRetries   int
}

// NextDelay returns the wait before the given 0-based attempt.
//
// A non-positive MaxDelay is treated as "no cap"; capping at zero would make
// a strategy that merely left MaxDelay unset retry in a tight loop. For
// non-negative InitialDelay and positive Increment, the result saturates at
// the largest Duration instead of wrapping around on int64 overflow.
func (s *LinearBackoffStrategy) NextDelay(attempt int) time.Duration {
	const maxDuration = time.Duration(math.MaxInt64)

	var delay time.Duration
	if attempt > 0 && s.Increment > 0 && s.InitialDelay >= 0 &&
		time.Duration(attempt) > (maxDuration-s.InitialDelay)/s.Increment {
		delay = maxDuration
	} else {
		delay = s.InitialDelay + time.Duration(attempt)*s.Increment
	}

	if s.MaxDelay > 0 && delay > s.MaxDelay {
		delay = s.MaxDelay
	}
	return delay
}

// ShouldReconnect reports whether attempt is still within the retry budget.
func (s *LinearBackoffStrategy) ShouldReconnect(attempt int) bool {
	return s.MaxRetries <= 0 || attempt < s.MaxRetries
}
// ReconnectableClient pairs a Client with a pluggable ReconnectStrategy that
// tryReconnectWithStrategy consults.
type ReconnectableClient struct {
	*Client
	strategy ReconnectStrategy
}

// NewReconnectableClient wraps NewClient(config, handler) with strategy.
// NOTE(review): a nil strategy is stored as-is; tryReconnectWithStrategy
// would panic on it.
func NewReconnectableClient(config *Config, handler EventHandler, strategy ReconnectStrategy) *ReconnectableClient {
	return &ReconnectableClient{
		Client:   NewClient(config, handler),
		strategy: strategy,
	}
}
// tryReconnectWithStrategy re-dials the server using rc.strategy. The
// strategy decides the retry budget (ShouldReconnect) and the wait before
// each attempt (NextDelay). Every failed attempt and the final give-up are
// reported through OnError. It returns as soon as a dial succeeds or the
// client context is cancelled.
//
// NOTE(review): Go embedding does not override methods. Client's own
// disconnect handling calls Client.tryReconnect and never reaches this
// method, and nothing in this file calls it. Confirm how it is meant to be
// wired in (e.g. a strategy hook on Client).
func (rc *ReconnectableClient) tryReconnectWithStrategy() {
	for attempt := 0; rc.strategy.ShouldReconnect(attempt); attempt++ {
		// Mirror Client.tryReconnect so Connect() is refused while retrying.
		rc.setState(StateReconnecting)

		select {
		case <-time.After(rc.strategy.NextDelay(attempt)):
		case <-rc.ctx.Done():
			return
		}
		// select picks at random when both are ready; never dial after shutdown.
		if rc.ctx.Err() != nil {
			return
		}

		err := rc.connect()
		if err == nil {
			return
		}
		if rc.eventHandler != nil {
			rc.eventHandler.OnError(fmt.Errorf("重连失败 (尝试 %d): %w", attempt+1, err))
		}
	}

	if rc.eventHandler != nil {
		rc.eventHandler.OnError(fmt.Errorf("达到最大重连次数,停止重连"))
	}
}
## 实用工具和扩展
### 消息队列
// internal/client/queue.go
package client
import (
"container/list"
"sync"
"time"
)
// QueueItem is one pending message plus the bookkeeping needed to expire
// and retry it.
type QueueItem struct {
	Message   *Message
	Timestamp time.Time // when the item was enqueued; used for expiry
	Retries   int       // failed delivery attempts, maintained by the consumer
}

// MessageQueue is a bounded FIFO of messages that is safe for concurrent use.
// When it is full, Enqueue evicts the oldest item to make room. A background
// goroutine started by NewMessageQueue drops items older than maxAge; call
// Close to stop that goroutine once the queue is no longer needed.
type MessageQueue struct {
	queue   *list.List // of *QueueItem, oldest at the front
	mu      sync.Mutex
	maxSize int
	maxAge  time.Duration

	stop     chan struct{} // closed by Close to end cleanup
	stopOnce sync.Once
}

// NewMessageQueue creates an empty queue holding at most maxSize items and
// starts its expiry goroutine.
func NewMessageQueue(maxSize int, maxAge time.Duration) *MessageQueue {
	mq := &MessageQueue{
		queue:   list.New(),
		maxSize: maxSize,
		maxAge:  maxAge,
		stop:    make(chan struct{}),
	}
	go mq.cleanup()
	return mq
}

// Close stops the background expiry goroutine; calling it again is a no-op.
// The queue remains usable afterwards, but expired items are no longer purged.
func (mq *MessageQueue) Close() {
	mq.stopOnce.Do(func() { close(mq.stop) })
}

// Enqueue appends msg, evicting the oldest item first if the queue is full.
// It always returns true.
func (mq *MessageQueue) Enqueue(msg *Message) bool {
	mq.mu.Lock()
	defer mq.mu.Unlock()

	if mq.queue.Len() >= mq.maxSize {
		if oldest := mq.queue.Front(); oldest != nil {
			mq.queue.Remove(oldest)
		}
	}
	mq.queue.PushBack(&QueueItem{Message: msg, Timestamp: time.Now()})
	return true
}

// Dequeue removes and returns the oldest item, or nil if the queue is empty.
func (mq *MessageQueue) Dequeue() *QueueItem {
	mq.mu.Lock()
	defer mq.mu.Unlock()

	front := mq.queue.Front()
	if front == nil {
		return nil
	}
	return mq.queue.Remove(front).(*QueueItem)
}

// Size returns the number of queued items.
func (mq *MessageQueue) Size() int {
	mq.mu.Lock()
	defer mq.mu.Unlock()
	return mq.queue.Len()
}

// cleanup purges expired items periodically until Close is called. It ticks
// every minute, or every maxAge when that is shorter, so short-lived items
// do not linger up to a full minute past their age limit.
func (mq *MessageQueue) cleanup() {
	interval := time.Minute
	if mq.maxAge > 0 && mq.maxAge < interval {
		interval = mq.maxAge
	}
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-mq.stop:
			return
		case <-ticker.C:
			mq.dropExpired(time.Now())
		}
	}
}

// dropExpired removes items older than maxAge (relative to now) from the
// front of the queue. Items are appended in time order, so the scan stops
// at the first item that is still fresh.
func (mq *MessageQueue) dropExpired(now time.Time) {
	mq.mu.Lock()
	defer mq.mu.Unlock()

	for e := mq.queue.Front(); e != nil; {
		if now.Sub(e.Value.(*QueueItem).Timestamp) <= mq.maxAge {
			return
		}
		next := e.Next()
		mq.queue.Remove(e)
		e = next
	}
}
// QueuedClient is a Client that can hold outgoing messages in a MessageQueue
// while it is not connected; processQueue flushes them later.
type QueuedClient struct {
	*Client
	queue *MessageQueue
}

// NewQueuedClient builds the underlying Client and a queue of queueSize
// items that expire after maxAge, then starts the background flusher.
func NewQueuedClient(config *Config, handler EventHandler, queueSize int, maxAge time.Duration) *QueuedClient {
	qc := &QueuedClient{
		Client: NewClient(config, handler),
		queue:  NewMessageQueue(queueSize, maxAge),
	}
	go qc.processQueue()
	return qc
}
// SendMessageQueued sends msg immediately when the client is connected and
// no earlier messages are still waiting. Otherwise it appends msg to the
// queue that processQueue flushes. Queuing behind a non-empty backlog keeps
// messages in submission order after a reconnect. Errors from a direct send
// are returned unchanged. Queuing always succeeds, since a full queue evicts
// its oldest entry.
func (qc *QueuedClient) SendMessageQueued(msg *Message) error {
	if qc.GetState() == StateConnected && qc.queue.Size() == 0 {
		return qc.Client.SendMessage(msg)
	}
	qc.queue.Enqueue(msg)
	return nil
}
// processQueue flushes queued messages once per second while the client is
// connected, and exits when the client context is cancelled.
func (qc *QueuedClient) processQueue() {
	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if qc.GetState() == StateConnected {
				qc.flushQueue()
			}
		case <-qc.ctx.Done():
			return
		}
	}
}

// flushQueue sends, in order, at most the items queued when it started. It
// stops at the first failed send (typically a full send buffer or a dropped
// connection) instead of retrying in a tight loop. The failed item goes back
// to the front with its Retries count kept and is dropped after its third
// failure.
func (qc *QueuedClient) flushQueue() {
	for pending := qc.queue.Size(); pending > 0; pending-- {
		item := qc.queue.Dequeue()
		if item == nil {
			return
		}
		if err := qc.Client.SendMessage(item.Message); err != nil {
			item.Retries++
			if item.Retries < 3 {
				qc.queue.requeueFront(item)
			}
			return
		}
	}
}

// requeueFront puts item back at the head of the queue, keeping its original
// Timestamp and Retries. If the queue is already full, the item is dropped,
// matching Enqueue's policy of evicting the oldest entry.
func (mq *MessageQueue) requeueFront(item *QueueItem) {
	mq.mu.Lock()
	defer mq.mu.Unlock()
	if mq.queue.Len() >= mq.maxSize {
		return
	}
	mq.queue.PushFront(item)
}
### 性能监控
// internal/client/metrics.go
package client
import (
"sync"
"time"
)
// ClientMetrics accumulates connection and traffic statistics for one
// client. All methods are safe for concurrent use.
type ClientMetrics struct {
	mu                    sync.RWMutex
	connectTime           time.Time // start of the current session; zero while disconnected
	totalMessagesSent     int64
	totalMessagesReceived int64
	totalReconnects       int64
	lastMessageTime       time.Time
	connectionDuration    time.Duration // total length of finished sessions
	averageLatency        time.Duration
	latencyMeasurements   []time.Duration // most recent samples, oldest first
}

// maxLatencySamples bounds the sliding window behind averageLatency.
const maxLatencySamples = 100

// NewClientMetrics returns an empty collector.
func NewClientMetrics() *ClientMetrics {
	return &ClientMetrics{
		latencyMeasurements: make([]time.Duration, 0, maxLatencySamples),
	}
}

// OnConnect marks the start of a connected session.
func (m *ClientMetrics) OnConnect() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.connectTime = time.Now()
}

// OnDisconnect ends the current session and adds its length to the total.
// The session start is then cleared. A repeated disconnect report for the
// same session, or one without a preceding OnConnect, therefore adds nothing.
func (m *ClientMetrics) OnDisconnect() {
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.connectTime.IsZero() {
		return
	}
	m.connectionDuration += time.Since(m.connectTime)
	m.connectTime = time.Time{}
}

// OnReconnect counts one reconnect.
func (m *ClientMetrics) OnReconnect() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.totalReconnects++
}

// OnMessageSent counts one outgoing message.
func (m *ClientMetrics) OnMessageSent() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.totalMessagesSent++
}

// OnMessageReceived counts one incoming message and records when it arrived.
func (m *ClientMetrics) OnMessageReceived() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.totalMessagesReceived++
	m.lastMessageTime = time.Now()
}

// RecordLatency adds a latency sample. The average is taken over the most
// recent maxLatencySamples samples.
func (m *ClientMetrics) RecordLatency(latency time.Duration) {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.latencyMeasurements = append(m.latencyMeasurements, latency)
	if len(m.latencyMeasurements) > maxLatencySamples {
		m.latencyMeasurements = m.latencyMeasurements[1:]
	}

	var total time.Duration
	for _, l := range m.latencyMeasurements {
		total += l
	}
	m.averageLatency = total / time.Duration(len(m.latencyMeasurements))
}

// GetStats returns a snapshot of the counters. connection_duration_seconds
// includes the session in progress, if any. last_message_ago_seconds is
// present only after a message has been received.
func (m *ClientMetrics) GetStats() map[string]interface{} {
	m.mu.RLock()
	defer m.mu.RUnlock()

	connected := m.connectionDuration
	if !m.connectTime.IsZero() {
		connected += time.Since(m.connectTime)
	}

	stats := map[string]interface{}{
		"total_messages_sent":         m.totalMessagesSent,
		"total_messages_received":     m.totalMessagesReceived,
		"total_reconnects":            m.totalReconnects,
		"connection_duration_seconds": connected.Seconds(),
		"average_latency_ms":          m.averageLatency.Milliseconds(),
	}
	if !m.lastMessageTime.IsZero() {
		stats["last_message_ago_seconds"] = time.Since(m.lastMessageTime).Seconds()
	}
	return stats
}
// MetricsClient is a Client whose events are also recorded in a ClientMetrics.
//
// NOTE(review): the wrapper only records connect, disconnect and received
// messages. Nothing in this file calls ClientMetrics.OnMessageSent,
// OnReconnect or RecordLatency, so those stats stay zero unless the caller
// records them through GetMetrics().
type MetricsClient struct {
	*Client
	metrics *ClientMetrics
}

// NewMetricsClient builds a Client whose handler is wrapped so that events
// update the metrics before reaching handler (which may be nil).
func NewMetricsClient(config *Config, handler EventHandler) *MetricsClient {
	m := NewClientMetrics()
	c := NewClient(config, &metricsEventHandler{EventHandler: handler, metrics: m})
	return &MetricsClient{Client: c, metrics: m}
}

// GetMetrics returns the live metrics collector (not a copy).
func (mc *MetricsClient) GetMetrics() *ClientMetrics { return mc.metrics }
// metricsEventHandler updates a ClientMetrics and then forwards each event
// to the wrapped EventHandler, if one is set.
type metricsEventHandler struct {
	EventHandler
	metrics *ClientMetrics
}

// OnConnect starts a metrics session, then forwards.
func (h *metricsEventHandler) OnConnect() {
	h.metrics.OnConnect()
	if next := h.EventHandler; next != nil {
		next.OnConnect()
	}
}

// OnDisconnect closes the metrics session, then forwards.
func (h *metricsEventHandler) OnDisconnect(err error) {
	h.metrics.OnDisconnect()
	if next := h.EventHandler; next != nil {
		next.OnDisconnect(err)
	}
}

// OnMessage counts the received message, then forwards.
func (h *metricsEventHandler) OnMessage(msg *Message) {
	h.metrics.OnMessageReceived()
	if next := h.EventHandler; next != nil {
		next.OnMessage(msg)
	}
}

// OnError only forwards; errors are not counted.
func (h *metricsEventHandler) OnError(err error) {
	if next := h.EventHandler; next != nil {
		next.OnError(err)
	}
}
## 完整的客户端示例
// cmd/chat-client/main.go
package main
import (
"bufio"
"flag"
"fmt"
"log"
"os"
"strings"
"time"
"your-project/internal/client"
)
var (
	serverURL = flag.String("url", "ws://localhost:8080/ws", "WebSocket 服务器地址")
	userID    = flag.String("user", "", "用户 ID")
	username  = flag.String("name", "", "用户名")
)

// main runs an interactive chat client. It connects to -url, then reads
// stdin line by line: lines starting with "/" are commands, anything else is
// sent as a chat message. When stdin ends, the connection is closed
// gracefully.
func main() {
	flag.Parse()
	if *userID == "" {
		*userID = fmt.Sprintf("user_%d", time.Now().Unix())
	}
	if *username == "" {
		*username = *userID
	}

	config := client.DefaultConfig()
	config.URL = *serverURL

	handler := client.NewChatEventHandler(*userID, *username)
	c := client.NewMetricsClient(config, handler)

	if err := c.Connect(); err != nil {
		log.Fatalf("连接失败: %v", err)
	}
	// Close gracefully when stdin ends (Ctrl+D, piped input); "/quit" exits
	// through handleCommand instead.
	defer c.Disconnect()

	go startCLI(c)

	fmt.Println("聊天客户端已启动!")
	fmt.Println("命令:")
	fmt.Println(" /join <room> - 加入房间")
	fmt.Println(" /leave <room> - 离开房间")
	fmt.Println(" /pm <user> <msg> - 发送私信")
	fmt.Println(" /stats - 显示统计信息")
	fmt.Println(" /quit - 退出")
	fmt.Println()

	scanner := bufio.NewScanner(os.Stdin)
	for scanner.Scan() {
		input := strings.TrimSpace(scanner.Text())
		if input == "" {
			continue
		}
		if strings.HasPrefix(input, "/") {
			handleCommand(c, input)
			continue
		}
		// Surface failures (not connected, buffer full) instead of dropping the message silently.
		if err := c.SendText(input); err != nil {
			fmt.Printf("⚠️ 发送失败: %v\n", err)
		}
	}
	if err := scanner.Err(); err != nil {
		log.Printf("读取输入失败: %v", err)
	}
}
// handleCommand executes one slash command typed by the user. Usage errors
// are printed. Errors returned by the client (e.g. not connected, send
// buffer full) are reported instead of being silently discarded. "/quit"
// disconnects and terminates the process.
func handleCommand(c *client.MetricsClient, command string) {
	parts := strings.Fields(command)
	if len(parts) == 0 {
		return
	}

	var err error
	switch parts[0] {
	case "/join":
		if len(parts) < 2 {
			fmt.Println("用法: /join <room>")
			return
		}
		err = c.JoinRoom(parts[1])
	case "/leave":
		if len(parts) < 2 {
			fmt.Println("用法: /leave <room>")
			return
		}
		err = c.LeaveRoom(parts[1])
	case "/pm":
		if len(parts) < 3 {
			fmt.Println("用法: /pm <user> <message>")
			return
		}
		err = c.SendPrivateMessage(parts[1], strings.Join(parts[2:], " "))
	case "/stats":
		stats := c.GetMetrics().GetStats()
		fmt.Println("📊 客户端统计:")
		for key, value := range stats {
			fmt.Printf(" %s: %v\n", key, value)
		}
	case "/quit":
		fmt.Println("正在断开连接...")
		c.Disconnect()
		os.Exit(0)
	default:
		fmt.Printf("未知命令: %s\n", parts[0])
	}

	if err != nil {
		fmt.Printf("⚠️ %s 失败: %v\n", parts[0], err)
	}
}
// startCLI is a placeholder for a richer terminal UI (for example one built
// with termui or another TUI library). It currently does nothing and
// returns immediately.
func startCLI(_ *client.MetricsClient) {}
## 小结
本节详细介绍了 WebSocket 客户端的开发:
- 基础实现:简单的 WebSocket 客户端连接和消息处理
- 高级客户端:状态管理、事件处理、自动重连
- 连接管理:连接池、重连策略、消息队列
- 扩展功能:性能监控、指标收集、错误处理
- 实用工具:命令行界面、配置管理、日志记录
通过这些内容,您可以构建出功能完整、稳定可靠的 WebSocket 客户端应用。在下一节中,我们将通过构建一个完整的实时聊天系统来综合运用所学知识。