3.8.1 WebSocket 协议详解 #
WebSocket 是一种基于 TCP 的应用层通信协议,它在 2011 年被 IETF 标准化为 RFC 6455。WebSocket 协议使得客户端和服务器之间可以建立持久的全双工通信连接,极大地提升了实时应用的性能和用户体验。
WebSocket 协议概述 #
协议特点 #
1. 全双工通信 WebSocket 连接建立后,客户端和服务端都可以主动发送数据,无需等待对方的请求。
2. 低开销 相比传统的 HTTP 轮询,WebSocket 只需要一次握手,后续通信无需重复的 HTTP 头部信息。
3. 实时性 消息可以即时传递,无需轮询等待,大大降低了通信延迟。
4. 协议升级 WebSocket 通过 HTTP 升级机制建立连接,兼容现有的网络基础设施。
与 HTTP 的比较 #
特性 | HTTP | WebSocket |
---|---|---|
通信模式 | 请求-响应 | 全双工 |
连接持久性 | 短连接 | 长连接 |
服务端推送 | 不支持 | 原生支持 |
协议开销 | 每次请求都有头部 | 握手后开销很小 |
实时性 | 需要轮询 | 即时通信 |
WebSocket 握手过程 #
客户端握手请求 #
WebSocket 连接始于一个标准的 HTTP 请求,包含特殊的升级头部:
GET /chat HTTP/1.1
Host: example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Sec-WebSocket-Version: 13
Sec-WebSocket-Protocol: chat, superchat
Origin: http://example.com
关键头部说明:
Upgrade: websocket
:请求协议升级到 WebSocketConnection: Upgrade
:表示这是一个升级请求Sec-WebSocket-Key
:客户端生成的随机字符串,用于验证Sec-WebSocket-Version
:WebSocket 协议版本,当前为 13Sec-WebSocket-Protocol
:可选的子协议列表Origin
:请求的来源,用于安全验证
服务端握手响应 #
服务端验证请求后,返回升级确认响应:
HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
Sec-WebSocket-Protocol: chat
关键头部说明:
101 Switching Protocols
:表示协议切换成功Sec-WebSocket-Accept
:服务端根据客户端的 Key 计算得出的确认值Sec-WebSocket-Protocol
:服务端选择的子协议
握手验证算法 #
服务端需要验证 Sec-WebSocket-Accept
的计算:
package main
import (
"crypto/sha1"
"encoding/base64"
"fmt"
)
const websocketMagicString = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
func generateAcceptKey(clientKey string) string {
// 将客户端密钥与魔术字符串连接
combined := clientKey + websocketMagicString
// 计算 SHA-1 哈希
hash := sha1.Sum([]byte(combined))
// Base64 编码
return base64.StdEncoding.EncodeToString(hash[:])
}
func main() {
clientKey := "dGhlIHNhbXBsZSBub25jZQ=="
acceptKey := generateAcceptKey(clientKey)
fmt.Printf("Sec-WebSocket-Accept: %s\n", acceptKey)
// 输出: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
}
WebSocket 数据帧格式 #
帧结构 #
WebSocket 数据以帧(Frame)的形式传输,每个帧包含以下结构:
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-------+-+-------------+-------------------------------+
|F|R|R|R| opcode|M| Payload len | Extended payload length |
|I|S|S|S| (4) |A| (7) | (16/64) |
|N|V|V|V| |S| | (if payload len==126/127) |
| |1|2|3| |K| | |
+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
| Extended payload length continued, if payload len == 127 |
+ - - - - - - - - - - - - - - - +-------------------------------+
| |Masking-key, if MASK set to 1 |
+-------------------------------+-------------------------------+
| Masking-key (continued) | Payload Data |
+-------------------------------- - - - - - - - - - - - - - - - +
: Payload Data continued ... :
+ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
| Payload Data continued ... |
+---------------------------------------------------------------+
字段说明 #
1. FIN (1 bit)
- 表示这是否为消息的最后一个片段
- 1 表示最后一个片段,0 表示还有后续片段
2. RSV1, RSV2, RSV3 (各 1 bit)
- 保留位,必须为 0,除非扩展定义了非零值的含义
3. Opcode (4 bits)
- 定义帧的类型:
- 0x0:继续帧
- 0x1:文本帧
- 0x2:二进制帧
- 0x8:连接关闭
- 0x9:Ping
- 0xA:Pong
4. MASK (1 bit)
- 表示载荷数据是否被掩码
- 客户端发送的帧必须设置为 1
- 服务端发送的帧必须设置为 0
5. Payload Length (7 bits, 7+16 bits, 或 7+64 bits)
- 载荷数据的长度
- 如果值为 0-125,则为实际长度
- 如果值为 126,则后续 16 位为实际长度
- 如果值为 127,则后续 64 位为实际长度
帧类型详解 #
package main
import (
"fmt"
)
// WebSocket 操作码常量
const (
OpcodeContinuation = 0x0
OpcodeText = 0x1
OpcodeBinary = 0x2
OpcodeClose = 0x8
OpcodePing = 0x9
OpcodePong = 0xA
)
// 帧类型描述
func getOpcodeDescription(opcode byte) string {
switch opcode {
case OpcodeContinuation:
return "继续帧 - 分片消息的后续部分"
case OpcodeText:
return "文本帧 - UTF-8 编码的文本数据"
case OpcodeBinary:
return "二进制帧 - 任意二进制数据"
case OpcodeClose:
return "关闭帧 - 连接关闭请求"
case OpcodePing:
return "Ping 帧 - 心跳检测"
case OpcodePong:
return "Pong 帧 - 心跳响应"
default:
return "未知帧类型"
}
}
func main() {
opcodes := []byte{0x0, 0x1, 0x2, 0x8, 0x9, 0xA}
for _, opcode := range opcodes {
fmt.Printf("Opcode 0x%X: %s\n", opcode, getOpcodeDescription(opcode))
}
}
消息分片机制 #
分片原理 #
当消息过大时,WebSocket 支持将消息分割成多个帧传输:
package main
import (
"fmt"
"strings"
)
// 模拟消息分片
type Frame struct {
FIN bool // 是否为最后一帧
Opcode byte // 操作码
Data string // 数据
}
func fragmentMessage(message string, chunkSize int) []Frame {
var frames []Frame
for i := 0; i < len(message); i += chunkSize {
end := i + chunkSize
if end > len(message) {
end = len(message)
}
chunk := message[i:end]
frame := Frame{
Data: chunk,
}
if i == 0 {
// 第一帧使用文本操作码
frame.Opcode = OpcodeText
} else {
// 后续帧使用继续操作码
frame.Opcode = OpcodeContinuation
}
// 最后一帧设置 FIN 标志
if end == len(message) {
frame.FIN = true
}
frames = append(frames, frame)
}
return frames
}
func reassembleMessage(frames []Frame) string {
var parts []string
for _, frame := range frames {
parts = append(parts, frame.Data)
}
return strings.Join(parts, "")
}
func main() {
message := "这是一个很长的消息,需要分片传输以避免单个帧过大"
chunkSize := 10
// 分片
frames := fragmentMessage(message, chunkSize)
fmt.Printf("原始消息: %s\n", message)
fmt.Printf("分片数量: %d\n\n", len(frames))
for i, frame := range frames {
fmt.Printf("帧 %d:\n", i+1)
fmt.Printf(" FIN: %t\n", frame.FIN)
fmt.Printf(" Opcode: 0x%X (%s)\n", frame.Opcode, getOpcodeDescription(frame.Opcode))
fmt.Printf(" Data: %s\n\n", frame.Data)
}
// 重组
reassembled := reassembleMessage(frames)
fmt.Printf("重组消息: %s\n", reassembled)
fmt.Printf("消息完整性: %t\n", message == reassembled)
}
连接管理 #
连接状态 #
WebSocket 连接具有以下状态:
package main
import (
"fmt"
"time"
)
// WebSocket 连接状态
type ConnectionState int
const (
StateConnecting ConnectionState = iota
StateOpen
StateClosing
StateClosed
)
func (s ConnectionState) String() string {
switch s {
case StateConnecting:
return "CONNECTING"
case StateOpen:
return "OPEN"
case StateClosing:
return "CLOSING"
case StateClosed:
return "CLOSED"
default:
return "UNKNOWN"
}
}
// WebSocket 连接模拟
type WebSocketConnection struct {
state ConnectionState
createdAt time.Time
}
func NewWebSocketConnection() *WebSocketConnection {
return &WebSocketConnection{
state: StateConnecting,
createdAt: time.Now(),
}
}
func (conn *WebSocketConnection) GetState() ConnectionState {
return conn.state
}
func (conn *WebSocketConnection) Open() {
if conn.state == StateConnecting {
conn.state = StateOpen
fmt.Printf("连接已建立,状态: %s\n", conn.state)
}
}
func (conn *WebSocketConnection) Close() {
if conn.state == StateOpen {
conn.state = StateClosing
fmt.Printf("开始关闭连接,状态: %s\n", conn.state)
// 模拟关闭过程
time.Sleep(100 * time.Millisecond)
conn.state = StateClosed
fmt.Printf("连接已关闭,状态: %s\n", conn.state)
}
}
func (conn *WebSocketConnection) IsOpen() bool {
return conn.state == StateOpen
}
func main() {
conn := NewWebSocketConnection()
fmt.Printf("初始状态: %s\n", conn.GetState())
// 建立连接
conn.Open()
// 检查连接状态
if conn.IsOpen() {
fmt.Println("连接可用,可以发送消息")
}
// 关闭连接
conn.Close()
fmt.Printf("连接存活时间: %v\n", time.Since(conn.createdAt))
}
心跳机制 #
WebSocket 提供了 Ping/Pong 帧来实现心跳检测:
package main
import (
"fmt"
"time"
)
// 心跳管理器
type HeartbeatManager struct {
interval time.Duration
timeout time.Duration
lastPong time.Time
ticker *time.Ticker
stopChan chan struct{}
isActive bool
}
func NewHeartbeatManager(interval, timeout time.Duration) *HeartbeatManager {
return &HeartbeatManager{
interval: interval,
timeout: timeout,
lastPong: time.Now(),
stopChan: make(chan struct{}),
}
}
func (hm *HeartbeatManager) Start() {
if hm.isActive {
return
}
hm.isActive = true
hm.ticker = time.NewTicker(hm.interval)
go func() {
for {
select {
case <-hm.ticker.C:
// 检查是否超时
if time.Since(hm.lastPong) > hm.timeout {
fmt.Println("心跳超时,连接可能已断开")
hm.Stop()
return
}
// 发送 Ping
fmt.Printf("发送 Ping 帧 - %s\n", time.Now().Format("15:04:05"))
// 模拟发送 Ping 帧的逻辑
go hm.simulatePong()
case <-hm.stopChan:
return
}
}
}()
fmt.Printf("心跳管理器已启动,间隔: %v,超时: %v\n", hm.interval, hm.timeout)
}
func (hm *HeartbeatManager) Stop() {
if !hm.isActive {
return
}
hm.isActive = false
if hm.ticker != nil {
hm.ticker.Stop()
}
close(hm.stopChan)
fmt.Println("心跳管理器已停止")
}
func (hm *HeartbeatManager) OnPong() {
hm.lastPong = time.Now()
fmt.Printf("收到 Pong 帧 - %s\n", hm.lastPong.Format("15:04:05"))
}
// 模拟接收 Pong 响应
func (hm *HeartbeatManager) simulatePong() {
// 模拟网络延迟
time.Sleep(50 * time.Millisecond)
hm.OnPong()
}
func main() {
// 创建心跳管理器:每 3 秒发送一次 Ping,超时时间 10 秒
heartbeat := NewHeartbeatManager(3*time.Second, 10*time.Second)
// 启动心跳
heartbeat.Start()
// 运行 15 秒
time.Sleep(15 * time.Second)
// 停止心跳
heartbeat.Stop()
}
关闭握手 #
正常关闭流程 #
WebSocket 连接的正常关闭需要经过握手过程:
package main
import (
"fmt"
"time"
)
// 关闭状态码
const (
CloseNormalClosure = 1000
CloseGoingAway = 1001
CloseProtocolError = 1002
CloseUnsupportedData = 1003
CloseNoStatusReceived = 1005
CloseAbnormalClosure = 1006
CloseInvalidFramePayloadData = 1007
ClosePolicyViolation = 1008
CloseMessageTooBig = 1009
CloseMandatoryExtension = 1010
CloseInternalServerErr = 1011
CloseServiceRestart = 1012
CloseTryAgainLater = 1013
CloseBadGateway = 1014
CloseTLSHandshake = 1015
)
// 关闭帧
type CloseFrame struct {
Code uint16
Reason string
}
func getCloseCodeDescription(code uint16) string {
switch code {
case CloseNormalClosure:
return "正常关闭"
case CloseGoingAway:
return "端点离开"
case CloseProtocolError:
return "协议错误"
case CloseUnsupportedData:
return "不支持的数据"
case CloseNoStatusReceived:
return "未收到状态码"
case CloseAbnormalClosure:
return "异常关闭"
case CloseInvalidFramePayloadData:
return "无效的帧载荷数据"
case ClosePolicyViolation:
return "策略违规"
case CloseMessageTooBig:
return "消息过大"
case CloseMandatoryExtension:
return "强制扩展"
case CloseInternalServerErr:
return "内部服务器错误"
case CloseServiceRestart:
return "服务重启"
case CloseTryAgainLater:
return "稍后重试"
case CloseBadGateway:
return "网关错误"
case CloseTLSHandshake:
return "TLS 握手失败"
default:
return "未知错误"
}
}
// WebSocket 关闭管理器
type CloseManager struct {
closeInitiated bool
closeReceived bool
closeFrame *CloseFrame
}
func NewCloseManager() *CloseManager {
return &CloseManager{}
}
// 发起关闭
func (cm *CloseManager) InitiateClose(code uint16, reason string) {
if cm.closeInitiated {
return
}
cm.closeInitiated = true
cm.closeFrame = &CloseFrame{
Code: code,
Reason: reason,
}
fmt.Printf("发起关闭握手:\n")
fmt.Printf(" 状态码: %d (%s)\n", code, getCloseCodeDescription(code))
fmt.Printf(" 原因: %s\n", reason)
// 模拟发送关闭帧
go cm.simulateCloseResponse()
}
// 接收关闭帧
func (cm *CloseManager) ReceiveClose(code uint16, reason string) {
cm.closeReceived = true
fmt.Printf("收到关闭帧:\n")
fmt.Printf(" 状态码: %d (%s)\n", code, getCloseCodeDescription(code))
fmt.Printf(" 原因: %s\n", reason)
if !cm.closeInitiated {
// 如果我们没有发起关闭,需要回复关闭帧
fmt.Println("回复关闭帧确认")
cm.closeInitiated = true
}
fmt.Println("关闭握手完成,连接将被关闭")
}
// 模拟接收关闭响应
func (cm *CloseManager) simulateCloseResponse() {
// 模拟网络延迟
time.Sleep(100 * time.Millisecond)
// 模拟对方回复关闭帧
cm.ReceiveClose(cm.closeFrame.Code, "确认关闭")
}
func main() {
closeManager := NewCloseManager()
// 模拟正常关闭
fmt.Println("=== 正常关闭流程 ===")
closeManager.InitiateClose(CloseNormalClosure, "用户主动关闭")
time.Sleep(200 * time.Millisecond)
// 模拟异常关闭
fmt.Println("\n=== 异常关闭流程 ===")
closeManager2 := NewCloseManager()
closeManager2.InitiateClose(CloseInternalServerErr, "服务器内部错误")
time.Sleep(200 * time.Millisecond)
}
安全考虑 #
同源策略 #
WebSocket 连接受到同源策略的限制:
package main
import (
"fmt"
"net/url"
"strings"
)
// 同源检查器
type OriginChecker struct {
allowedOrigins []string
allowAll bool
}
func NewOriginChecker(allowedOrigins []string) *OriginChecker {
return &OriginChecker{
allowedOrigins: allowedOrigins,
allowAll: len(allowedOrigins) == 0,
}
}
func (oc *OriginChecker) IsAllowed(origin string) bool {
if oc.allowAll {
return true
}
if origin == "" {
return false
}
// 解析 Origin URL
originURL, err := url.Parse(origin)
if err != nil {
return false
}
// 检查是否在允许列表中
for _, allowed := range oc.allowedOrigins {
if oc.matchOrigin(originURL, allowed) {
return true
}
}
return false
}
func (oc *OriginChecker) matchOrigin(originURL *url.URL, pattern string) bool {
// 支持通配符匹配
if pattern == "*" {
return true
}
// 精确匹配
if originURL.String() == pattern {
return true
}
// 域名匹配(支持子域名)
if strings.HasPrefix(pattern, "*.") {
domain := strings.TrimPrefix(pattern, "*.")
return strings.HasSuffix(originURL.Host, domain)
}
return false
}
func main() {
// 创建同源检查器
checker := NewOriginChecker([]string{
"https://example.com",
"https://app.example.com",
"*.trusted-domain.com",
})
// 测试不同的 Origin
testOrigins := []string{
"https://example.com",
"https://app.example.com",
"https://sub.trusted-domain.com",
"https://malicious.com",
"http://example.com", // 协议不匹配
"",
}
fmt.Println("同源策略检查结果:")
for _, origin := range testOrigins {
allowed := checker.IsAllowed(origin)
status := "拒绝"
if allowed {
status = "允许"
}
fmt.Printf("Origin: %-30s -> %s\n", origin, status)
}
}
输入验证 #
WebSocket 消息需要进行严格的输入验证:
package main
import (
"encoding/json"
"fmt"
"strings"
"unicode/utf8"
)
// 消息验证器
type MessageValidator struct {
maxTextLength int
maxBinaryLength int
allowedTypes map[string]bool
}
func NewMessageValidator() *MessageValidator {
return &MessageValidator{
maxTextLength: 1024 * 1024, // 1MB
maxBinaryLength: 10 * 1024 * 1024, // 10MB
allowedTypes: map[string]bool{
"message": true,
"ping": true,
"join": true,
"leave": true,
},
}
}
// 验证文本消息
func (mv *MessageValidator) ValidateTextMessage(data []byte) error {
// 检查长度
if len(data) > mv.maxTextLength {
return fmt.Errorf("文本消息过长: %d > %d", len(data), mv.maxTextLength)
}
// 检查 UTF-8 编码
if !utf8.Valid(data) {
return fmt.Errorf("无效的 UTF-8 编码")
}
// 检查是否包含控制字符
text := string(data)
for _, r := range text {
if r < 32 && r != '\n' && r != '\r' && r != '\t' {
return fmt.Errorf("包含无效的控制字符: %U", r)
}
}
return nil
}
// 验证二进制消息
func (mv *MessageValidator) ValidateBinaryMessage(data []byte) error {
// 检查长度
if len(data) > mv.maxBinaryLength {
return fmt.Errorf("二进制消息过长: %d > %d", len(data), mv.maxBinaryLength)
}
return nil
}
// 验证 JSON 消息
func (mv *MessageValidator) ValidateJSONMessage(data []byte) error {
// 先验证文本消息
if err := mv.ValidateTextMessage(data); err != nil {
return err
}
// 解析 JSON
var msg map[string]interface{}
if err := json.Unmarshal(data, &msg); err != nil {
return fmt.Errorf("无效的 JSON 格式: %v", err)
}
// 检查消息类型
msgType, ok := msg["type"].(string)
if !ok {
return fmt.Errorf("缺少消息类型字段")
}
if !mv.allowedTypes[msgType] {
return fmt.Errorf("不允许的消息类型: %s", msgType)
}
// 检查必需字段
switch msgType {
case "message":
if _, ok := msg["content"]; !ok {
return fmt.Errorf("消息类型 'message' 缺少 'content' 字段")
}
case "join", "leave":
if _, ok := msg["room"]; !ok {
return fmt.Errorf("消息类型 '%s' 缺少 'room' 字段", msgType)
}
}
return nil
}
// 清理和过滤消息内容
func (mv *MessageValidator) SanitizeMessage(content string) string {
// 移除潜在的恶意内容
content = strings.ReplaceAll(content, "<script>", "")
content = strings.ReplaceAll(content, "</script>", "")
content = strings.ReplaceAll(content, "javascript:", "")
content = strings.ReplaceAll(content, "data:", "")
// 限制长度
if len(content) > 1000 {
content = content[:1000] + "..."
}
return content
}
func main() {
validator := NewMessageValidator()
// 测试不同类型的消息
testMessages := []struct {
name string
data []byte
}{
{
name: "有效的 JSON 消息",
data: []byte(`{"type":"message","content":"Hello, World!"}`),
},
{
name: "无效的消息类型",
data: []byte(`{"type":"admin","content":"Delete all data"}`),
},
{
name: "缺少必需字段",
data: []byte(`{"type":"message"}`),
},
{
name: "无效的 UTF-8",
data: []byte{0xff, 0xfe, 0xfd},
},
{
name: "包含控制字符",
data: []byte("Hello\x00World"),
},
}
fmt.Println("消息验证结果:")
for _, test := range testMessages {
err := validator.ValidateJSONMessage(test.data)
status := "通过"
if err != nil {
status = fmt.Sprintf("失败: %v", err)
}
fmt.Printf("%-20s -> %s\n", test.name, status)
}
// 测试消息清理
fmt.Println("\n消息清理测试:")
maliciousContent := `<script>alert('XSS')</script>Hello javascript:void(0) World`
cleaned := validator.SanitizeMessage(maliciousContent)
fmt.Printf("原始内容: %s\n", maliciousContent)
fmt.Printf("清理后: %s\n", cleaned)
}
小结 #
本节详细介绍了 WebSocket 协议的核心概念和技术细节:
- 协议特点:全双工通信、低开销、实时性、协议升级
- 握手过程:HTTP 升级请求和响应、密钥验证算法
- 数据帧格式:帧结构、字段含义、操作码类型
- 消息分片:大消息的分割和重组机制
- 连接管理:状态管理、心跳检测、关闭握手
- 安全考虑:同源策略、输入验证、消息清理
理解这些基础知识为后续的 WebSocket 服务端和客户端开发奠定了坚实的基础。在下一节中,我们将学习如何使用 Go 语言实现 WebSocket 服务端。