3.8.1 WebSocket 协议详解

3.8.1 WebSocket 协议详解 #

WebSocket 是一种基于 TCP 的应用层通信协议,它在 2011 年被 IETF 标准化为 RFC 6455。WebSocket 协议使得客户端和服务器之间可以建立持久的全双工通信连接,极大地提升了实时应用的性能和用户体验。

WebSocket 协议概述 #

协议特点 #

1. 全双工通信 WebSocket 连接建立后,客户端和服务端都可以主动发送数据,无需等待对方的请求。

2. 低开销 相比传统的 HTTP 轮询,WebSocket 只需要一次握手,后续通信无需重复的 HTTP 头部信息。

3. 实时性 消息可以即时传递,无需轮询等待,大大降低了通信延迟。

4. 协议升级 WebSocket 通过 HTTP 升级机制建立连接,兼容现有的网络基础设施。

与 HTTP 的比较 #

特性 HTTP WebSocket
通信模式 请求-响应 全双工
连接持久性 短连接 长连接
服务端推送 不支持 原生支持
协议开销 每次请求都有头部 握手后开销很小
实时性 需要轮询 即时通信

WebSocket 握手过程 #

客户端握手请求 #

WebSocket 连接始于一个标准的 HTTP 请求,包含特殊的升级头部:

GET /chat HTTP/1.1
Host: example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Sec-WebSocket-Version: 13
Sec-WebSocket-Protocol: chat, superchat
Origin: http://example.com

关键头部说明:

  • Upgrade: websocket:请求协议升级到 WebSocket
  • Connection: Upgrade:表示这是一个升级请求
  • Sec-WebSocket-Key:客户端生成的随机字符串,用于验证
  • Sec-WebSocket-Version:WebSocket 协议版本,当前为 13
  • Sec-WebSocket-Protocol:可选的子协议列表
  • Origin:请求的来源,用于安全验证

服务端握手响应 #

服务端验证请求后,返回升级确认响应:

HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
Sec-WebSocket-Protocol: chat

关键头部说明:

  • 101 Switching Protocols:表示协议切换成功
  • Sec-WebSocket-Accept:服务端根据客户端的 Key 计算得出的确认值
  • Sec-WebSocket-Protocol:服务端选择的子协议

握手验证算法 #

服务端需要验证 Sec-WebSocket-Accept 的计算:

package main

import (
    "crypto/sha1"
    "encoding/base64"
    "fmt"
)

const websocketMagicString = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

func generateAcceptKey(clientKey string) string {
    // 将客户端密钥与魔术字符串连接
    combined := clientKey + websocketMagicString

    // 计算 SHA-1 哈希
    hash := sha1.Sum([]byte(combined))

    // Base64 编码
    return base64.StdEncoding.EncodeToString(hash[:])
}

func main() {
    clientKey := "dGhlIHNhbXBsZSBub25jZQ=="
    acceptKey := generateAcceptKey(clientKey)
    fmt.Printf("Sec-WebSocket-Accept: %s\n", acceptKey)
    // 输出: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
}

WebSocket 数据帧格式 #

帧结构 #

WebSocket 数据以帧(Frame)的形式传输,每个帧包含以下结构:

 0                   1                   2                   3
 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-------+-+-------------+-------------------------------+
|F|R|R|R| opcode|M| Payload len |    Extended payload length    |
|I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
|N|V|V|V|       |S|             |   (if payload len==126/127)   |
| |1|2|3|       |K|             |                               |
+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
|     Extended payload length continued, if payload len == 127  |
+ - - - - - - - - - - - - - - - +-------------------------------+
|                               |Masking-key, if MASK set to 1  |
+-------------------------------+-------------------------------+
| Masking-key (continued)       |          Payload Data         |
+-------------------------------- - - - - - - - - - - - - - - - +
:                     Payload Data continued ...                :
+ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
|                     Payload Data continued ...                |
+---------------------------------------------------------------+

字段说明 #

1. FIN (1 bit)

  • 表示这是否为消息的最后一个片段
  • 1 表示最后一个片段,0 表示还有后续片段

2. RSV1, RSV2, RSV3 (各 1 bit)

  • 保留位,必须为 0,除非扩展定义了非零值的含义

3. Opcode (4 bits)

  • 定义帧的类型:
    • 0x0:继续帧
    • 0x1:文本帧
    • 0x2:二进制帧
    • 0x8:连接关闭
    • 0x9:Ping
    • 0xA:Pong

4. MASK (1 bit)

  • 表示载荷数据是否被掩码
  • 客户端发送的帧必须设置为 1
  • 服务端发送的帧必须设置为 0

5. Payload Length (7 bits, 7+16 bits, 或 7+64 bits)

  • 载荷数据的长度
  • 如果值为 0-125,则为实际长度
  • 如果值为 126,则后续 16 位为实际长度
  • 如果值为 127,则后续 64 位为实际长度

帧类型详解 #

package main

import (
    "fmt"
)

// WebSocket 操作码常量
const (
    OpcodeContinuation = 0x0
    OpcodeText         = 0x1
    OpcodeBinary       = 0x2
    OpcodeClose        = 0x8
    OpcodePing         = 0x9
    OpcodePong         = 0xA
)

// 帧类型描述
func getOpcodeDescription(opcode byte) string {
    switch opcode {
    case OpcodeContinuation:
        return "继续帧 - 分片消息的后续部分"
    case OpcodeText:
        return "文本帧 - UTF-8 编码的文本数据"
    case OpcodeBinary:
        return "二进制帧 - 任意二进制数据"
    case OpcodeClose:
        return "关闭帧 - 连接关闭请求"
    case OpcodePing:
        return "Ping 帧 - 心跳检测"
    case OpcodePong:
        return "Pong 帧 - 心跳响应"
    default:
        return "未知帧类型"
    }
}

func main() {
    opcodes := []byte{0x0, 0x1, 0x2, 0x8, 0x9, 0xA}

    for _, opcode := range opcodes {
        fmt.Printf("Opcode 0x%X: %s\n", opcode, getOpcodeDescription(opcode))
    }
}

消息分片机制 #

分片原理 #

当消息过大时,WebSocket 支持将消息分割成多个帧传输:

package main

import (
    "fmt"
    "strings"
)

// 模拟消息分片
type Frame struct {
    FIN    bool   // 是否为最后一帧
    Opcode byte   // 操作码
    Data   string // 数据
}

func fragmentMessage(message string, chunkSize int) []Frame {
    var frames []Frame

    for i := 0; i < len(message); i += chunkSize {
        end := i + chunkSize
        if end > len(message) {
            end = len(message)
        }

        chunk := message[i:end]

        frame := Frame{
            Data: chunk,
        }

        if i == 0 {
            // 第一帧使用文本操作码
            frame.Opcode = OpcodeText
        } else {
            // 后续帧使用继续操作码
            frame.Opcode = OpcodeContinuation
        }

        // 最后一帧设置 FIN 标志
        if end == len(message) {
            frame.FIN = true
        }

        frames = append(frames, frame)
    }

    return frames
}

func reassembleMessage(frames []Frame) string {
    var parts []string

    for _, frame := range frames {
        parts = append(parts, frame.Data)
    }

    return strings.Join(parts, "")
}

func main() {
    message := "这是一个很长的消息,需要分片传输以避免单个帧过大"
    chunkSize := 10

    // 分片
    frames := fragmentMessage(message, chunkSize)

    fmt.Printf("原始消息: %s\n", message)
    fmt.Printf("分片数量: %d\n\n", len(frames))

    for i, frame := range frames {
        fmt.Printf("帧 %d:\n", i+1)
        fmt.Printf("  FIN: %t\n", frame.FIN)
        fmt.Printf("  Opcode: 0x%X (%s)\n", frame.Opcode, getOpcodeDescription(frame.Opcode))
        fmt.Printf("  Data: %s\n\n", frame.Data)
    }

    // 重组
    reassembled := reassembleMessage(frames)
    fmt.Printf("重组消息: %s\n", reassembled)
    fmt.Printf("消息完整性: %t\n", message == reassembled)
}

连接管理 #

连接状态 #

WebSocket 连接具有以下状态:

package main

import (
    "fmt"
    "time"
)

// WebSocket 连接状态
type ConnectionState int

const (
    StateConnecting ConnectionState = iota
    StateOpen
    StateClosing
    StateClosed
)

func (s ConnectionState) String() string {
    switch s {
    case StateConnecting:
        return "CONNECTING"
    case StateOpen:
        return "OPEN"
    case StateClosing:
        return "CLOSING"
    case StateClosed:
        return "CLOSED"
    default:
        return "UNKNOWN"
    }
}

// WebSocket 连接模拟
type WebSocketConnection struct {
    state     ConnectionState
    createdAt time.Time
}

func NewWebSocketConnection() *WebSocketConnection {
    return &WebSocketConnection{
        state:     StateConnecting,
        createdAt: time.Now(),
    }
}

func (conn *WebSocketConnection) GetState() ConnectionState {
    return conn.state
}

func (conn *WebSocketConnection) Open() {
    if conn.state == StateConnecting {
        conn.state = StateOpen
        fmt.Printf("连接已建立,状态: %s\n", conn.state)
    }
}

func (conn *WebSocketConnection) Close() {
    if conn.state == StateOpen {
        conn.state = StateClosing
        fmt.Printf("开始关闭连接,状态: %s\n", conn.state)

        // 模拟关闭过程
        time.Sleep(100 * time.Millisecond)

        conn.state = StateClosed
        fmt.Printf("连接已关闭,状态: %s\n", conn.state)
    }
}

func (conn *WebSocketConnection) IsOpen() bool {
    return conn.state == StateOpen
}

func main() {
    conn := NewWebSocketConnection()

    fmt.Printf("初始状态: %s\n", conn.GetState())

    // 建立连接
    conn.Open()

    // 检查连接状态
    if conn.IsOpen() {
        fmt.Println("连接可用,可以发送消息")
    }

    // 关闭连接
    conn.Close()

    fmt.Printf("连接存活时间: %v\n", time.Since(conn.createdAt))
}

心跳机制 #

WebSocket 提供了 Ping/Pong 帧来实现心跳检测:

package main

import (
    "fmt"
    "time"
)

// 心跳管理器
type HeartbeatManager struct {
    interval    time.Duration
    timeout     time.Duration
    lastPong    time.Time
    ticker      *time.Ticker
    stopChan    chan struct{}
    isActive    bool
}

func NewHeartbeatManager(interval, timeout time.Duration) *HeartbeatManager {
    return &HeartbeatManager{
        interval: interval,
        timeout:  timeout,
        lastPong: time.Now(),
        stopChan: make(chan struct{}),
    }
}

func (hm *HeartbeatManager) Start() {
    if hm.isActive {
        return
    }

    hm.isActive = true
    hm.ticker = time.NewTicker(hm.interval)

    go func() {
        for {
            select {
            case <-hm.ticker.C:
                // 检查是否超时
                if time.Since(hm.lastPong) > hm.timeout {
                    fmt.Println("心跳超时,连接可能已断开")
                    hm.Stop()
                    return
                }

                // 发送 Ping
                fmt.Printf("发送 Ping 帧 - %s\n", time.Now().Format("15:04:05"))

                // 模拟发送 Ping 帧的逻辑
                go hm.simulatePong()

            case <-hm.stopChan:
                return
            }
        }
    }()

    fmt.Printf("心跳管理器已启动,间隔: %v,超时: %v\n", hm.interval, hm.timeout)
}

func (hm *HeartbeatManager) Stop() {
    if !hm.isActive {
        return
    }

    hm.isActive = false
    if hm.ticker != nil {
        hm.ticker.Stop()
    }

    close(hm.stopChan)
    fmt.Println("心跳管理器已停止")
}

func (hm *HeartbeatManager) OnPong() {
    hm.lastPong = time.Now()
    fmt.Printf("收到 Pong 帧 - %s\n", hm.lastPong.Format("15:04:05"))
}

// 模拟接收 Pong 响应
func (hm *HeartbeatManager) simulatePong() {
    // 模拟网络延迟
    time.Sleep(50 * time.Millisecond)
    hm.OnPong()
}

func main() {
    // 创建心跳管理器:每 3 秒发送一次 Ping,超时时间 10 秒
    heartbeat := NewHeartbeatManager(3*time.Second, 10*time.Second)

    // 启动心跳
    heartbeat.Start()

    // 运行 15 秒
    time.Sleep(15 * time.Second)

    // 停止心跳
    heartbeat.Stop()
}

关闭握手 #

正常关闭流程 #

WebSocket 连接的正常关闭需要经过握手过程:

package main

import (
    "fmt"
    "time"
)

// 关闭状态码
const (
    CloseNormalClosure           = 1000
    CloseGoingAway              = 1001
    CloseProtocolError          = 1002
    CloseUnsupportedData        = 1003
    CloseNoStatusReceived       = 1005
    CloseAbnormalClosure        = 1006
    CloseInvalidFramePayloadData = 1007
    ClosePolicyViolation        = 1008
    CloseMessageTooBig          = 1009
    CloseMandatoryExtension     = 1010
    CloseInternalServerErr      = 1011
    CloseServiceRestart         = 1012
    CloseTryAgainLater          = 1013
    CloseBadGateway             = 1014
    CloseTLSHandshake           = 1015
)

// 关闭帧
type CloseFrame struct {
    Code   uint16
    Reason string
}

func getCloseCodeDescription(code uint16) string {
    switch code {
    case CloseNormalClosure:
        return "正常关闭"
    case CloseGoingAway:
        return "端点离开"
    case CloseProtocolError:
        return "协议错误"
    case CloseUnsupportedData:
        return "不支持的数据"
    case CloseNoStatusReceived:
        return "未收到状态码"
    case CloseAbnormalClosure:
        return "异常关闭"
    case CloseInvalidFramePayloadData:
        return "无效的帧载荷数据"
    case ClosePolicyViolation:
        return "策略违规"
    case CloseMessageTooBig:
        return "消息过大"
    case CloseMandatoryExtension:
        return "强制扩展"
    case CloseInternalServerErr:
        return "内部服务器错误"
    case CloseServiceRestart:
        return "服务重启"
    case CloseTryAgainLater:
        return "稍后重试"
    case CloseBadGateway:
        return "网关错误"
    case CloseTLSHandshake:
        return "TLS 握手失败"
    default:
        return "未知错误"
    }
}

// WebSocket 关闭管理器
type CloseManager struct {
    closeInitiated bool
    closeReceived  bool
    closeFrame     *CloseFrame
}

func NewCloseManager() *CloseManager {
    return &CloseManager{}
}

// 发起关闭
func (cm *CloseManager) InitiateClose(code uint16, reason string) {
    if cm.closeInitiated {
        return
    }

    cm.closeInitiated = true
    cm.closeFrame = &CloseFrame{
        Code:   code,
        Reason: reason,
    }

    fmt.Printf("发起关闭握手:\n")
    fmt.Printf("  状态码: %d (%s)\n", code, getCloseCodeDescription(code))
    fmt.Printf("  原因: %s\n", reason)

    // 模拟发送关闭帧
    go cm.simulateCloseResponse()
}

// 接收关闭帧
func (cm *CloseManager) ReceiveClose(code uint16, reason string) {
    cm.closeReceived = true

    fmt.Printf("收到关闭帧:\n")
    fmt.Printf("  状态码: %d (%s)\n", code, getCloseCodeDescription(code))
    fmt.Printf("  原因: %s\n", reason)

    if !cm.closeInitiated {
        // 如果我们没有发起关闭,需要回复关闭帧
        fmt.Println("回复关闭帧确认")
        cm.closeInitiated = true
    }

    fmt.Println("关闭握手完成,连接将被关闭")
}

// 模拟接收关闭响应
func (cm *CloseManager) simulateCloseResponse() {
    // 模拟网络延迟
    time.Sleep(100 * time.Millisecond)

    // 模拟对方回复关闭帧
    cm.ReceiveClose(cm.closeFrame.Code, "确认关闭")
}

func main() {
    closeManager := NewCloseManager()

    // 模拟正常关闭
    fmt.Println("=== 正常关闭流程 ===")
    closeManager.InitiateClose(CloseNormalClosure, "用户主动关闭")

    time.Sleep(200 * time.Millisecond)

    // 模拟异常关闭
    fmt.Println("\n=== 异常关闭流程 ===")
    closeManager2 := NewCloseManager()
    closeManager2.InitiateClose(CloseInternalServerErr, "服务器内部错误")

    time.Sleep(200 * time.Millisecond)
}

安全考虑 #

同源策略 #

WebSocket 连接受到同源策略的限制:

package main

import (
    "fmt"
    "net/url"
    "strings"
)

// 同源检查器
type OriginChecker struct {
    allowedOrigins []string
    allowAll       bool
}

func NewOriginChecker(allowedOrigins []string) *OriginChecker {
    return &OriginChecker{
        allowedOrigins: allowedOrigins,
        allowAll:       len(allowedOrigins) == 0,
    }
}

func (oc *OriginChecker) IsAllowed(origin string) bool {
    if oc.allowAll {
        return true
    }

    if origin == "" {
        return false
    }

    // 解析 Origin URL
    originURL, err := url.Parse(origin)
    if err != nil {
        return false
    }

    // 检查是否在允许列表中
    for _, allowed := range oc.allowedOrigins {
        if oc.matchOrigin(originURL, allowed) {
            return true
        }
    }

    return false
}

func (oc *OriginChecker) matchOrigin(originURL *url.URL, pattern string) bool {
    // 支持通配符匹配
    if pattern == "*" {
        return true
    }

    // 精确匹配
    if originURL.String() == pattern {
        return true
    }

    // 域名匹配(支持子域名)
    if strings.HasPrefix(pattern, "*.") {
        domain := strings.TrimPrefix(pattern, "*.")
        return strings.HasSuffix(originURL.Host, domain)
    }

    return false
}

func main() {
    // 创建同源检查器
    checker := NewOriginChecker([]string{
        "https://example.com",
        "https://app.example.com",
        "*.trusted-domain.com",
    })

    // 测试不同的 Origin
    testOrigins := []string{
        "https://example.com",
        "https://app.example.com",
        "https://sub.trusted-domain.com",
        "https://malicious.com",
        "http://example.com", // 协议不匹配
        "",
    }

    fmt.Println("同源策略检查结果:")
    for _, origin := range testOrigins {
        allowed := checker.IsAllowed(origin)
        status := "拒绝"
        if allowed {
            status = "允许"
        }
        fmt.Printf("Origin: %-30s -> %s\n", origin, status)
    }
}

输入验证 #

WebSocket 消息需要进行严格的输入验证:

package main

import (
    "encoding/json"
    "fmt"
    "strings"
    "unicode/utf8"
)

// 消息验证器
type MessageValidator struct {
    maxTextLength   int
    maxBinaryLength int
    allowedTypes    map[string]bool
}

func NewMessageValidator() *MessageValidator {
    return &MessageValidator{
        maxTextLength:   1024 * 1024, // 1MB
        maxBinaryLength: 10 * 1024 * 1024, // 10MB
        allowedTypes: map[string]bool{
            "message": true,
            "ping":    true,
            "join":    true,
            "leave":   true,
        },
    }
}

// 验证文本消息
func (mv *MessageValidator) ValidateTextMessage(data []byte) error {
    // 检查长度
    if len(data) > mv.maxTextLength {
        return fmt.Errorf("文本消息过长: %d > %d", len(data), mv.maxTextLength)
    }

    // 检查 UTF-8 编码
    if !utf8.Valid(data) {
        return fmt.Errorf("无效的 UTF-8 编码")
    }

    // 检查是否包含控制字符
    text := string(data)
    for _, r := range text {
        if r < 32 && r != '\n' && r != '\r' && r != '\t' {
            return fmt.Errorf("包含无效的控制字符: %U", r)
        }
    }

    return nil
}

// 验证二进制消息
func (mv *MessageValidator) ValidateBinaryMessage(data []byte) error {
    // 检查长度
    if len(data) > mv.maxBinaryLength {
        return fmt.Errorf("二进制消息过长: %d > %d", len(data), mv.maxBinaryLength)
    }

    return nil
}

// 验证 JSON 消息
func (mv *MessageValidator) ValidateJSONMessage(data []byte) error {
    // 先验证文本消息
    if err := mv.ValidateTextMessage(data); err != nil {
        return err
    }

    // 解析 JSON
    var msg map[string]interface{}
    if err := json.Unmarshal(data, &msg); err != nil {
        return fmt.Errorf("无效的 JSON 格式: %v", err)
    }

    // 检查消息类型
    msgType, ok := msg["type"].(string)
    if !ok {
        return fmt.Errorf("缺少消息类型字段")
    }

    if !mv.allowedTypes[msgType] {
        return fmt.Errorf("不允许的消息类型: %s", msgType)
    }

    // 检查必需字段
    switch msgType {
    case "message":
        if _, ok := msg["content"]; !ok {
            return fmt.Errorf("消息类型 'message' 缺少 'content' 字段")
        }
    case "join", "leave":
        if _, ok := msg["room"]; !ok {
            return fmt.Errorf("消息类型 '%s' 缺少 'room' 字段", msgType)
        }
    }

    return nil
}

// 清理和过滤消息内容
func (mv *MessageValidator) SanitizeMessage(content string) string {
    // 移除潜在的恶意内容
    content = strings.ReplaceAll(content, "<script>", "")
    content = strings.ReplaceAll(content, "</script>", "")
    content = strings.ReplaceAll(content, "javascript:", "")
    content = strings.ReplaceAll(content, "data:", "")

    // 限制长度
    if len(content) > 1000 {
        content = content[:1000] + "..."
    }

    return content
}

func main() {
    validator := NewMessageValidator()

    // 测试不同类型的消息
    testMessages := []struct {
        name string
        data []byte
    }{
        {
            name: "有效的 JSON 消息",
            data: []byte(`{"type":"message","content":"Hello, World!"}`),
        },
        {
            name: "无效的消息类型",
            data: []byte(`{"type":"admin","content":"Delete all data"}`),
        },
        {
            name: "缺少必需字段",
            data: []byte(`{"type":"message"}`),
        },
        {
            name: "无效的 UTF-8",
            data: []byte{0xff, 0xfe, 0xfd},
        },
        {
            name: "包含控制字符",
            data: []byte("Hello\x00World"),
        },
    }

    fmt.Println("消息验证结果:")
    for _, test := range testMessages {
        err := validator.ValidateJSONMessage(test.data)
        status := "通过"
        if err != nil {
            status = fmt.Sprintf("失败: %v", err)
        }
        fmt.Printf("%-20s -> %s\n", test.name, status)
    }

    // 测试消息清理
    fmt.Println("\n消息清理测试:")
    maliciousContent := `<script>alert('XSS')</script>Hello javascript:void(0) World`
    cleaned := validator.SanitizeMessage(maliciousContent)
    fmt.Printf("原始内容: %s\n", maliciousContent)
    fmt.Printf("清理后: %s\n", cleaned)
}

小结 #

本节详细介绍了 WebSocket 协议的核心概念和技术细节:

  1. 协议特点:全双工通信、低开销、实时性、协议升级
  2. 握手过程:HTTP 升级请求和响应、密钥验证算法
  3. 数据帧格式:帧结构、字段含义、操作码类型
  4. 消息分片:大消息的分割和重组机制
  5. 连接管理:状态管理、心跳检测、关闭握手
  6. 安全考虑:同源策略、输入验证、消息清理

理解这些基础知识为后续的 WebSocket 服务端和客户端开发奠定了坚实的基础。在下一节中,我们将学习如何使用 Go 语言实现 WebSocket 服务端。