4.3.1 TCP 并发服务器 #
高并发 TCP 服务器是现代网络应用的核心组件。本节将深入探讨如何设计和实现能够处理大量并发连接的 TCP 服务器,包括连接管理、负载均衡、性能优化等关键技术。
并发模型设计 #
Goroutine-per-Connection 模型 #
这是 Go 语言中最常见的并发模型,为每个连接创建一个 goroutine:
package main
import (
"bufio"
"fmt"
"io"
"net"
"sync"
"time"
)
// ConcurrentServer is a goroutine-per-connection TCP server. It tracks
// every live connection so Shutdown can close them all and wait for the
// handler goroutines to drain.
type ConcurrentServer struct {
	listener    net.Listener       // accepting socket
	connections map[net.Conn]bool  // live connections, guarded by mutex
	mutex       sync.RWMutex       // protects connections
	wg          sync.WaitGroup     // one count per handler goroutine
	shutdown    chan struct{}      // closed to request shutdown
}
// NewConcurrentServer binds a TCP listener on address and returns a
// server that is ready for Start. The caller must eventually call
// Shutdown to release the listener.
func NewConcurrentServer(address string) (*ConcurrentServer, error) {
	ln, err := net.Listen("tcp", address)
	if err != nil {
		return nil, err
	}
	srv := &ConcurrentServer{
		listener:    ln,
		connections: map[net.Conn]bool{},
		shutdown:    make(chan struct{}),
	}
	return srv, nil
}
// Start runs the accept loop until Shutdown is requested or Accept fails
// with a non-timeout error. Accept is bounded by a 1-second deadline so
// the loop re-checks the shutdown channel at least once per second. Each
// accepted connection is registered and served in its own goroutine.
func (s *ConcurrentServer) Start() error {
	fmt.Printf("TCP 并发服务器启动在 %s\n", s.listener.Addr())
	for {
		select {
		case <-s.shutdown:
			return nil
		default:
			// Put a deadline on Accept so shutdown is observed promptly.
			if tcpListener, ok := s.listener.(*net.TCPListener); ok {
				tcpListener.SetDeadline(time.Now().Add(1 * time.Second))
			}
			conn, err := s.listener.Accept()
			if err != nil {
				if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
					continue // deadline expired; loop and re-check shutdown
				}
				return err
			}
			s.addConnection(conn)
			s.wg.Add(1)
			go s.handleConnection(conn)
		}
	}
}
// handleConnection serves one client with a newline-delimited echo
// protocol. It runs in its own goroutine; on any exit path it
// deregisters the connection, closes it, and releases the WaitGroup.
func (s *ConcurrentServer) handleConnection(conn net.Conn) {
	defer func() {
		s.removeConnection(conn)
		conn.Close()
		s.wg.Done()
	}()
	fmt.Printf("新连接: %s\n", conn.RemoteAddr())
	reader := bufio.NewReader(conn)
	writer := bufio.NewWriter(conn)
	for {
		select {
		case <-s.shutdown:
			return
		default:
			// Drop clients that send no complete line within 30 seconds.
			conn.SetReadDeadline(time.Now().Add(30 * time.Second))
			message, err := reader.ReadString('\n')
			if err != nil {
				if err == io.EOF {
					fmt.Printf("客户端断开连接: %s\n", conn.RemoteAddr())
				} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
					fmt.Printf("连接超时: %s\n", conn.RemoteAddr())
				} else {
					fmt.Printf("读取错误: %v\n", err)
				}
				return
			}
			// Compute the reply for this line.
			response := s.processMessage(message, conn)
			// Send the response, bounded by a 10-second write deadline.
			conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
			_, err = writer.WriteString(response)
			if err != nil {
				fmt.Printf("写入错误: %v\n", err)
				return
			}
			writer.Flush()
		}
	}
}
// processMessage produces the reply for one inbound line. This demo
// implementation simply echoes the message back with a prefix.
func (s *ConcurrentServer) processMessage(message string, conn net.Conn) string {
	return "Echo: " + message
}
// addConnection records conn in the live-connection set.
func (s *ConcurrentServer) addConnection(conn net.Conn) {
	s.mutex.Lock()
	s.connections[conn] = true
	s.mutex.Unlock()
}
// removeConnection drops conn from the live-connection set; it is a
// no-op if the connection was never registered.
func (s *ConcurrentServer) removeConnection(conn net.Conn) {
	s.mutex.Lock()
	delete(s.connections, conn)
	s.mutex.Unlock()
}
// GetConnectionCount reports how many connections are currently tracked.
func (s *ConcurrentServer) GetConnectionCount() int {
	s.mutex.RLock()
	n := len(s.connections)
	s.mutex.RUnlock()
	return n
}
// Shutdown stops the accept loop, closes the listener and every tracked
// connection, then blocks until all handler goroutines have finished.
// It returns the listener's Close error, if any.
func (s *ConcurrentServer) Shutdown() error {
	close(s.shutdown)
	// Close the listener so a blocked Accept returns.
	err := s.listener.Close()
	// Close all live connections so blocked reads in handlers fail fast.
	s.mutex.Lock()
	for conn := range s.connections {
		conn.Close()
	}
	s.mutex.Unlock()
	// Wait for every handler goroutine to exit.
	s.wg.Wait()
	return err
}
Worker Pool 模型 #
对于需要限制 goroutine 数量的场景,可以使用 worker pool 模型:
// WorkerPoolServer bounds concurrency with a fixed set of workers that
// consume accepted connections from a shared queue.
type WorkerPoolServer struct {
	listener    net.Listener
	workerCount int           // number of worker goroutines
	jobQueue    chan net.Conn // buffered hand-off queue (cap = 2*workers)
	workers     []*Worker
	shutdown    chan struct{} // closed to request shutdown
	wg          sync.WaitGroup
}
// Worker is one consumer of the shared connection queue.
type Worker struct {
	id       int
	jobQueue chan net.Conn // shared with the server and all workers
	quit     chan struct{} // closed to stop this worker
}
// NewWorkerPoolServer listens on address and prepares a pool of
// workerCount workers. The job queue is buffered at twice the worker
// count so short accept bursts do not immediately reject clients.
func NewWorkerPoolServer(address string, workerCount int) (*WorkerPoolServer, error) {
	ln, err := net.Listen("tcp", address)
	if err != nil {
		return nil, err
	}
	srv := &WorkerPoolServer{
		listener:    ln,
		workerCount: workerCount,
		jobQueue:    make(chan net.Conn, workerCount*2),
		shutdown:    make(chan struct{}),
	}
	return srv, nil
}
// Start launches the worker goroutines and then runs the accept loop,
// dispatching each accepted connection onto the job queue. When the
// queue is full the connection is rejected immediately rather than
// blocking the accept loop. Returns nil once shutdown is requested.
func (s *WorkerPoolServer) Start() error {
	// Spin up the fixed worker set.
	s.workers = make([]*Worker, s.workerCount)
	for i := 0; i < s.workerCount; i++ {
		worker := &Worker{
			id:       i,
			jobQueue: s.jobQueue,
			quit:     make(chan struct{}),
		}
		s.workers[i] = worker
		s.wg.Add(1)
		go worker.start(&s.wg)
	}
	fmt.Printf("Worker Pool 服务器启动,%d 个 worker\n", s.workerCount)
	// Accept loop.
	for {
		select {
		case <-s.shutdown:
			return nil
		default:
			conn, err := s.listener.Accept()
			if err != nil {
				// During shutdown Accept fails because the listener was
				// closed; treat that as a clean exit.
				select {
				case <-s.shutdown:
					return nil
				default:
					fmt.Printf("接受连接失败: %v\n", err)
					continue
				}
			}
			// Hand the connection to a worker, or reject if saturated.
			select {
			case s.jobQueue <- conn:
			case <-s.shutdown:
				conn.Close()
				return nil
			default:
				// Queue full: shed load instead of blocking.
				fmt.Printf("服务器繁忙,拒绝连接: %s\n", conn.RemoteAddr())
				conn.Close()
			}
		}
	}
}
// start consumes connections from the shared job queue until quit is
// closed. It must run in its own goroutine; wg is released on exit.
func (w *Worker) start(wg *sync.WaitGroup) {
	defer wg.Done()
	for {
		select {
		case <-w.quit:
			return
		case c := <-w.jobQueue:
			w.handleConnection(c)
		}
	}
}
// handleConnection serves one client with a newline-delimited echo
// protocol until the client disconnects, a deadline expires, or an I/O
// error occurs. The connection is always closed before returning.
//
// Bug fix: the original ignored the errors from WriteString and Flush,
// so a dead or stalled peer kept the loop running and silently dropped
// responses; any write failure now terminates the session.
func (w *Worker) handleConnection(conn net.Conn) {
	defer conn.Close()
	fmt.Printf("Worker %d 处理连接: %s\n", w.id, conn.RemoteAddr())
	reader := bufio.NewReader(conn)
	writer := bufio.NewWriter(conn)
	for {
		// Drop clients that send no complete line within 30 seconds.
		conn.SetReadDeadline(time.Now().Add(30 * time.Second))
		message, err := reader.ReadString('\n')
		if err != nil {
			if err != io.EOF {
				fmt.Printf("Worker %d 读取错误: %v\n", w.id, err)
			}
			break
		}
		// Echo the line back, tagged with this worker's id.
		response := fmt.Sprintf("Worker %d Echo: %s", w.id, message)
		conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
		if _, err := writer.WriteString(response); err != nil {
			fmt.Printf("Worker %d 写入错误: %v\n", w.id, err)
			break
		}
		if err := writer.Flush(); err != nil {
			fmt.Printf("Worker %d 写入错误: %v\n", w.id, err)
			break
		}
	}
	fmt.Printf("Worker %d 完成处理: %s\n", w.id, conn.RemoteAddr())
}
// Shutdown stops the accept loop, signals every worker to quit, closes
// the listener, and waits for the workers to drain. It returns the
// listener's Close error, if any.
func (s *WorkerPoolServer) Shutdown() error {
	close(s.shutdown)
	for _, w := range s.workers {
		close(w.quit)
	}
	err := s.listener.Close()
	s.wg.Wait()
	return err
}
连接管理 #
连接池管理 #
import (
"context"
"sync"
"time"
)
// ConnectionManager tracks connections by remote address and reaps ones
// that have been idle longer than maxIdle via a background cleanup loop.
type ConnectionManager struct {
	connections   map[string]*ManagedConnection // keyed by RemoteAddr().String()
	mutex         sync.RWMutex                  // protects connections
	maxIdle       time.Duration                 // idle threshold before reaping
	cleanupTicker *time.Ticker                  // drives cleanupLoop
	ctx           context.Context               // cancellation for the loop
	cancel        context.CancelFunc
}
// ManagedConnection is per-connection bookkeeping: traffic counters,
// last-activity timestamp, and lifecycle state.
type ManagedConnection struct {
	Conn       net.Conn
	LastActive time.Time // updated on every send/receive
	BytesSent  int64
	BytesRecv  int64
	State      ConnectionState
	mutex      sync.RWMutex // protects the fields above
}
// ConnectionState is the lifecycle phase of a managed connection.
type ConnectionState int

const (
	StateActive  ConnectionState = iota // serving traffic
	StateIdle                           // exceeded the idle threshold
	StateClosing                        // being removed
)
// NewConnectionManager creates a manager that reaps connections idle for
// longer than maxIdle, scanning at half that interval.
//
// NOTE(review): time.NewTicker panics on a non-positive duration, so
// maxIdle must be at least 2ns — confirm callers always pass a sensible
// value.
func NewConnectionManager(maxIdle time.Duration) *ConnectionManager {
	ctx, cancel := context.WithCancel(context.Background())
	cm := &ConnectionManager{
		connections: make(map[string]*ManagedConnection),
		maxIdle:     maxIdle,
		ctx:         ctx,
		cancel:      cancel,
	}
	// Start the background reaper.
	cm.cleanupTicker = time.NewTicker(maxIdle / 2)
	go cm.cleanupLoop()
	return cm
}
// AddConnection starts tracking conn, keyed by its remote address, and
// returns the tracking record.
func (cm *ConnectionManager) AddConnection(conn net.Conn) *ManagedConnection {
	mc := &ManagedConnection{
		Conn:       conn,
		LastActive: time.Now(),
		State:      StateActive,
	}
	key := conn.RemoteAddr().String()
	cm.mutex.Lock()
	cm.connections[key] = mc
	cm.mutex.Unlock()
	return mc
}
// RemoveConnection marks conn as closing and stops tracking it. Unknown
// connections are ignored.
func (cm *ConnectionManager) RemoveConnection(conn net.Conn) {
	key := conn.RemoteAddr().String()
	cm.mutex.Lock()
	defer cm.mutex.Unlock()
	managed, ok := cm.connections[key]
	if !ok {
		return
	}
	managed.mutex.Lock()
	managed.State = StateClosing
	managed.mutex.Unlock()
	delete(cm.connections, key)
}
// UpdateActivity refreshes the last-active timestamp and adds the given
// byte counts to conn's traffic counters. Unknown connections are ignored.
func (cm *ConnectionManager) UpdateActivity(conn net.Conn, bytesSent, bytesRecv int64) {
	cm.mutex.RLock()
	mc, ok := cm.connections[conn.RemoteAddr().String()]
	cm.mutex.RUnlock()
	if !ok {
		return
	}
	mc.mutex.Lock()
	defer mc.mutex.Unlock()
	mc.LastActive = time.Now()
	mc.BytesSent += bytesSent
	mc.BytesRecv += bytesRecv
}
// cleanupLoop runs idle-connection reaping on every ticker fire until
// the manager's context is cancelled. It runs in its own goroutine.
func (cm *ConnectionManager) cleanupLoop() {
	for {
		select {
		case <-cm.cleanupTicker.C:
			cm.cleanupIdleConnections()
		case <-cm.ctx.Done():
			return
		}
	}
}
// cleanupIdleConnections closes and removes every active connection
// whose last activity is older than maxIdle. It does two passes under a
// single write lock: first mark candidates idle, then close and delete
// them (avoids mutating the map while ranging over it).
func (cm *ConnectionManager) cleanupIdleConnections() {
	cm.mutex.Lock()
	defer cm.mutex.Unlock()
	now := time.Now()
	var toClose []string
	for addr, managed := range cm.connections {
		managed.mutex.RLock()
		// NOTE(review): State is written here under only an RLock on the
		// per-connection mutex — confirm this is acceptable or upgrade.
		if now.Sub(managed.LastActive) > cm.maxIdle && managed.State == StateActive {
			managed.State = StateIdle
			toClose = append(toClose, addr)
		}
		managed.mutex.RUnlock()
	}
	for _, addr := range toClose {
		if managed, exists := cm.connections[addr]; exists {
			fmt.Printf("关闭空闲连接: %s\n", addr)
			managed.Conn.Close()
			delete(cm.connections, addr)
		}
	}
}
// GetStats returns a snapshot of connection counts (total/active/idle)
// and aggregate byte counters. Keys: total_connections (int),
// active_connections (int), idle_connections (int),
// total_bytes_sent (int64), total_bytes_recv (int64).
func (cm *ConnectionManager) GetStats() map[string]interface{} {
	cm.mutex.RLock()
	defer cm.mutex.RUnlock()
	stats := map[string]interface{}{
		"total_connections":  len(cm.connections),
		"active_connections": 0,
		"idle_connections":   0,
		"total_bytes_sent":   int64(0),
		"total_bytes_recv":   int64(0),
	}
	for _, managed := range cm.connections {
		managed.mutex.RLock()
		if managed.State == StateActive {
			stats["active_connections"] = stats["active_connections"].(int) + 1
		} else if managed.State == StateIdle {
			stats["idle_connections"] = stats["idle_connections"].(int) + 1
		}
		stats["total_bytes_sent"] = stats["total_bytes_sent"].(int64) + managed.BytesSent
		stats["total_bytes_recv"] = stats["total_bytes_recv"].(int64) + managed.BytesRecv
		managed.mutex.RUnlock()
	}
	return stats
}
// Shutdown stops the cleanup loop, closes every tracked connection, and
// resets the tracking map.
func (cm *ConnectionManager) Shutdown() {
	cm.cancel()
	cm.cleanupTicker.Stop()
	cm.mutex.Lock()
	defer cm.mutex.Unlock()
	for _, mc := range cm.connections {
		mc.Conn.Close()
	}
	cm.connections = map[string]*ManagedConnection{}
}
负载均衡 #
// LoadBalancer distributes work across a set of backends using a
// pluggable selection algorithm.
type LoadBalancer struct {
	servers   []*ServerInfo
	current   int // round-robin cursor, guarded by mutex
	mutex     sync.RWMutex
	algorithm LoadBalanceAlgorithm
}
// ServerInfo describes one backend and its live state.
type ServerInfo struct {
	Address     string
	Weight      int  // relative weight for weighted round-robin
	Connections int  // current connection count, guarded by mutex
	Active      bool // whether this backend is eligible for selection
	mutex       sync.RWMutex
}
// LoadBalanceAlgorithm selects the backend-picking strategy.
type LoadBalanceAlgorithm int

const (
	RoundRobin         LoadBalanceAlgorithm = iota // cycle through backends
	WeightedRoundRobin                             // cycle, proportional to Weight
	LeastConnections                               // pick lowest Connections
)
// NewLoadBalancer returns an empty balancer using the given algorithm.
func NewLoadBalancer(algorithm LoadBalanceAlgorithm) *LoadBalancer {
	lb := &LoadBalancer{algorithm: algorithm}
	lb.servers = make([]*ServerInfo, 0)
	return lb
}
// AddServer registers a backend with the given weight; new backends
// start out active.
func (lb *LoadBalancer) AddServer(address string, weight int) {
	srv := &ServerInfo{Address: address, Weight: weight, Active: true}
	lb.mutex.Lock()
	lb.servers = append(lb.servers, srv)
	lb.mutex.Unlock()
}
// GetNextServer picks a backend using the configured algorithm, or nil
// when no server has been registered.
//
// Bug fix: this previously held only the read lock, but the round-robin
// strategies it dispatches to mutate lb.current, so concurrent callers
// raced on that counter. Selection now runs under the write lock.
func (lb *LoadBalancer) GetNextServer() *ServerInfo {
	lb.mutex.Lock()
	defer lb.mutex.Unlock()
	if len(lb.servers) == 0 {
		return nil
	}
	switch lb.algorithm {
	case RoundRobin:
		return lb.roundRobin()
	case WeightedRoundRobin:
		return lb.weightedRoundRobin()
	case LeastConnections:
		return lb.leastConnections()
	default:
		return lb.roundRobin()
	}
}
// roundRobin cycles through the active backends, or returns nil when
// none are active.
//
// NOTE(review): this writes lb.current, so the caller must hold the
// write lock — confirm GetNextServer acquires an exclusive lock, not
// just RLock, before dispatching here.
func (lb *LoadBalancer) roundRobin() *ServerInfo {
	activeServers := lb.getActiveServers()
	if len(activeServers) == 0 {
		return nil
	}
	server := activeServers[lb.current%len(activeServers)]
	lb.current++
	return server
}
// weightedRoundRobin distributes picks proportionally to each active
// backend's Weight by mapping the advancing cursor onto cumulative
// weight ranges. Falls back to plain round-robin when all weights are
// zero; returns nil when no backend is active.
//
// NOTE(review): lb.current is written here, so the caller must hold the
// write lock — confirm GetNextServer uses Lock, not RLock.
func (lb *LoadBalancer) weightedRoundRobin() *ServerInfo {
	activeServers := lb.getActiveServers()
	if len(activeServers) == 0 {
		return nil
	}
	totalWeight := 0
	for _, server := range activeServers {
		totalWeight += server.Weight
	}
	if totalWeight == 0 {
		return lb.roundRobin()
	}
	// Map the cursor into [0, totalWeight) and walk cumulative weights.
	target := lb.current % totalWeight
	lb.current++
	currentWeight := 0
	for _, server := range activeServers {
		currentWeight += server.Weight
		if target < currentWeight {
			return server
		}
	}
	return activeServers[0]
}
// leastConnections picks the active backend with the fewest current
// connections, or nil when none are active.
func (lb *LoadBalancer) leastConnections() *ServerInfo {
	candidates := lb.getActiveServers()
	if len(candidates) == 0 {
		return nil
	}
	var best *ServerInfo
	bestCount := int(^uint(0) >> 1) // max int sentinel
	for _, srv := range candidates {
		srv.mutex.RLock()
		n := srv.Connections
		srv.mutex.RUnlock()
		if n < bestCount {
			bestCount = n
			best = srv
		}
	}
	return best
}
// getActiveServers returns the backends currently marked Active. The
// caller is expected to hold lb.mutex.
func (lb *LoadBalancer) getActiveServers() []*ServerInfo {
	var out []*ServerInfo
	for _, srv := range lb.servers {
		if !srv.Active {
			continue
		}
		out = append(out, srv)
	}
	return out
}
// UpdateServerConnections adjusts a backend's connection count by delta,
// clamping at zero. Unknown addresses are ignored.
func (lb *LoadBalancer) UpdateServerConnections(address string, delta int) {
	lb.mutex.RLock()
	defer lb.mutex.RUnlock()
	for _, srv := range lb.servers {
		if srv.Address != address {
			continue
		}
		srv.mutex.Lock()
		srv.Connections += delta
		if srv.Connections < 0 {
			srv.Connections = 0
		}
		srv.mutex.Unlock()
		return
	}
}
// SetServerActive marks a backend as eligible (or not) for selection.
// Unknown addresses are ignored.
//
// Bug fix: the original mutated server.Active while holding only the
// read lock, racing with getActiveServers readers; this now takes the
// balancer's write lock.
func (lb *LoadBalancer) SetServerActive(address string, active bool) {
	lb.mutex.Lock()
	defer lb.mutex.Unlock()
	for _, server := range lb.servers {
		if server.Address == address {
			server.Active = active
			break
		}
	}
}
性能优化 #
零拷贝优化 #
import (
"os"
"syscall"
)
// ZeroCopyServer demonstrates serving file content via the sendfile
// system call, avoiding a user-space copy.
type ZeroCopyServer struct {
	listener net.Listener
}
// sendFile transmits filename's contents over conn using the sendfile
// system call, bypassing user-space buffers. It only works when conn is
// a *net.TCPConn; other connection types return an error.
//
// NOTE(review): syscall.Sendfile is Unix-specific and may transfer fewer
// than stat.Size() bytes in one call; a production version should loop
// until the full size is sent — confirm before relying on this.
func (s *ZeroCopyServer) sendFile(conn net.Conn, filename string) error {
	file, err := os.Open(filename)
	if err != nil {
		return err
	}
	defer file.Close()
	// Size is needed to tell sendfile how many bytes to transfer.
	stat, err := file.Stat()
	if err != nil {
		return err
	}
	// Reach the raw socket fd through SyscallConn and invoke sendfile.
	if tcpConn, ok := conn.(*net.TCPConn); ok {
		rawConn, err := tcpConn.SyscallConn()
		if err != nil {
			return err
		}
		var sendErr error
		err = rawConn.Control(func(fd uintptr) {
			_, sendErr = syscall.Sendfile(int(fd), int(file.Fd()), nil, int(stat.Size()))
		})
		if err != nil {
			return err
		}
		return sendErr
	}
	return fmt.Errorf("不支持零拷贝")
}
内存池优化 #
// MemoryPool provides reusable byte buffers in three size classes to
// reduce per-request allocations.
type MemoryPool struct {
	small  sync.Pool // 1 KiB buffers
	medium sync.Pool // 4 KiB buffers
	large  sync.Pool // 16 KiB buffers
}
// NewMemoryPool builds a three-tier buffer pool (1 KiB / 4 KiB / 16 KiB).
func NewMemoryPool() *MemoryPool {
	sized := func(n int) func() interface{} {
		return func() interface{} { return make([]byte, n) }
	}
	return &MemoryPool{
		small:  sync.Pool{New: sized(1024)},
		medium: sync.Pool{New: sized(4096)},
		large:  sync.Pool{New: sized(16384)},
	}
}
// GetBuffer returns a buffer of the requested size, taken from the
// smallest fitting pool tier. Requests larger than the biggest tier are
// allocated directly and will not be pooled on return.
func (mp *MemoryPool) GetBuffer(size int) []byte {
	var pool *sync.Pool
	switch {
	case size <= 1024:
		pool = &mp.small
	case size <= 4096:
		pool = &mp.medium
	case size <= 16384:
		pool = &mp.large
	default:
		return make([]byte, size)
	}
	return pool.Get().([]byte)[:size]
}
// PutBuffer returns buf to its tier, matched by capacity. Buffers whose
// capacity is not exactly a tier size are dropped for the GC to collect.
//
// NOTE(review): putting a []byte (not a pointer) into sync.Pool boxes
// the slice header on every Put (staticcheck SA6002); storing *[]byte
// or a small wrapper would avoid that allocation — confirm if it
// matters on this path.
func (mp *MemoryPool) PutBuffer(buf []byte) {
	capacity := cap(buf)
	switch {
	case capacity == 1024:
		mp.small.Put(buf[:1024])
	case capacity == 4096:
		mp.medium.Put(buf[:4096])
	case capacity == 16384:
		mp.large.Put(buf[:16384])
	}
	// Other sizes are discarded and left to the garbage collector.
}
批量处理优化 #
// BatchProcessor coalesces small writes and flushes them as one network
// write, either when batchSize messages accumulate or after a short
// timer expires.
type BatchProcessor struct {
	conn       net.Conn
	buffer     [][]byte    // pending messages, guarded by mutex
	batchSize  int         // flush threshold
	flushTimer *time.Timer // delayed flush for partial batches
	mutex      sync.Mutex
}
// NewBatchProcessor wraps conn with write batching; up to batchSize
// messages are buffered before a combined send.
func NewBatchProcessor(conn net.Conn, batchSize int) *BatchProcessor {
	bp := &BatchProcessor{conn: conn, batchSize: batchSize}
	bp.buffer = make([][]byte, 0, batchSize)
	return bp
}
// Write queues data for batched sending. A full batch is flushed
// immediately; otherwise a 10ms timer is (re)armed so partial batches
// still go out promptly.
//
// NOTE(review): the timer callback discards flush's error, and data is
// retained by reference (callers must not reuse the slice) — confirm
// both are acceptable for the intended use.
func (bp *BatchProcessor) Write(data []byte) error {
	bp.mutex.Lock()
	defer bp.mutex.Unlock()
	bp.buffer = append(bp.buffer, data)
	if len(bp.buffer) >= bp.batchSize {
		return bp.flush()
	}
	// Re-arm the delayed flush for this partial batch.
	if bp.flushTimer != nil {
		bp.flushTimer.Stop()
	}
	bp.flushTimer = time.AfterFunc(10*time.Millisecond, func() {
		bp.mutex.Lock()
		bp.flush()
		bp.mutex.Unlock()
	})
	return nil
}
// flush concatenates all buffered messages into one payload, writes it
// to the connection, clears the buffer, and cancels any pending timer.
// The caller must hold bp.mutex.
func (bp *BatchProcessor) flush() error {
	if len(bp.buffer) == 0 {
		return nil
	}
	// Size the combined payload exactly to avoid re-growth.
	totalSize := 0
	for _, data := range bp.buffer {
		totalSize += len(data)
	}
	// Concatenate the pending messages.
	combined := make([]byte, 0, totalSize)
	for _, data := range bp.buffer {
		combined = append(combined, data...)
	}
	// Single write for the whole batch.
	_, err := bp.conn.Write(combined)
	// Reset state; keep the backing array for reuse.
	bp.buffer = bp.buffer[:0]
	if bp.flushTimer != nil {
		bp.flushTimer.Stop()
		bp.flushTimer = nil
	}
	return err
}
// Close cancels any pending timer and flushes whatever is still
// buffered, returning the flush error if one occurs.
func (bp *BatchProcessor) Close() error {
	bp.mutex.Lock()
	defer bp.mutex.Unlock()
	if t := bp.flushTimer; t != nil {
		t.Stop()
	}
	return bp.flush()
}
高可用性设计 #
健康检查 #
// HealthChecker probes registered backends on a fixed interval by
// attempting a TCP dial, tracking per-server health state.
type HealthChecker struct {
	servers  map[string]*ServerHealth // keyed by address, guarded by mutex
	interval time.Duration            // time between probe rounds
	timeout  time.Duration            // per-probe dial timeout
	mutex    sync.RWMutex
	ctx      context.Context // cancellation for the probe loop
	cancel   context.CancelFunc
}
// ServerHealth records the probe history and current verdict for one
// backend.
type ServerHealth struct {
	Address      string
	Healthy      bool
	LastCheck    time.Time
	FailCount    int // failed probes (thresholds unhealthy transition)
	SuccessCount int // successful probes (thresholds recovery)
}
// NewHealthChecker starts a background probe loop that checks every
// registered server each interval, using timeout per dial attempt.
// Call Shutdown to stop the loop.
func NewHealthChecker(interval, timeout time.Duration) *HealthChecker {
	ctx, cancel := context.WithCancel(context.Background())
	checker := &HealthChecker{
		servers:  map[string]*ServerHealth{},
		interval: interval,
		timeout:  timeout,
		ctx:      ctx,
		cancel:   cancel,
	}
	go checker.checkLoop()
	return checker
}
// AddServer registers address for periodic probing; it starts out
// assumed healthy.
func (hc *HealthChecker) AddServer(address string) {
	entry := &ServerHealth{Address: address, Healthy: true}
	hc.mutex.Lock()
	hc.servers[address] = entry
	hc.mutex.Unlock()
}
// checkLoop probes all servers once per interval until the checker's
// context is cancelled. It runs in its own goroutine.
func (hc *HealthChecker) checkLoop() {
	ticker := time.NewTicker(hc.interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			hc.checkAllServers()
		case <-hc.ctx.Done():
			return
		}
	}
}
// checkAllServers snapshots the registered servers, probes them all in
// parallel, and waits for every probe to finish.
func (hc *HealthChecker) checkAllServers() {
	hc.mutex.RLock()
	snapshot := make([]*ServerHealth, 0, len(hc.servers))
	for _, srv := range hc.servers {
		snapshot = append(snapshot, srv)
	}
	hc.mutex.RUnlock()

	var wg sync.WaitGroup
	for _, srv := range snapshot {
		srv := srv
		wg.Add(1)
		go func() {
			defer wg.Done()
			hc.checkServer(srv)
		}()
	}
	wg.Wait()
}
// checkServer performs one TCP dial probe against server and updates
// its health state: 3 consecutive failures mark it unhealthy, 2
// consecutive successes restore it.
//
// Bug fix: the original never reset the opposite counter after a probe,
// so sporadic failures accumulated forever and would eventually mark a
// mostly-healthy server unhealthy (and stale successes could trigger an
// instant recovery). Both counters now track consecutive results.
func (hc *HealthChecker) checkServer(server *ServerHealth) {
	ctx, cancel := context.WithTimeout(hc.ctx, hc.timeout)
	defer cancel()
	var d net.Dialer
	conn, err := d.DialContext(ctx, "tcp", server.Address)
	hc.mutex.Lock()
	defer hc.mutex.Unlock()
	server.LastCheck = time.Now()
	if err != nil {
		server.SuccessCount = 0 // a failure breaks any success streak
		server.FailCount++
		if server.FailCount >= 3 && server.Healthy {
			server.Healthy = false
			fmt.Printf("服务器 %s 标记为不健康\n", server.Address)
		}
		return
	}
	conn.Close()
	server.FailCount = 0 // a success breaks any failure streak
	server.SuccessCount++
	if !server.Healthy && server.SuccessCount >= 2 {
		server.Healthy = true
		fmt.Printf("服务器 %s 恢复健康\n", server.Address)
	}
}
// IsHealthy reports the current verdict for address; unregistered
// addresses are treated as unhealthy.
func (hc *HealthChecker) IsHealthy(address string) bool {
	hc.mutex.RLock()
	defer hc.mutex.RUnlock()
	server, ok := hc.servers[address]
	return ok && server.Healthy
}
// Shutdown cancels the background probe loop.
func (hc *HealthChecker) Shutdown() { hc.cancel() }
故障转移 #
// FailoverManager routes traffic to a primary backend, falling back to
// secondaries (and preferring to fail back to primary) based on the
// health checker's verdicts.
type FailoverManager struct {
	primary   string
	secondary []string
	current   string // currently selected address, guarded by mutex
	checker   *HealthChecker
	mutex     sync.RWMutex
}
// NewFailoverManager builds a manager that starts out pointed at the
// primary address.
func NewFailoverManager(primary string, secondary []string, checker *HealthChecker) *FailoverManager {
	fm := &FailoverManager{
		primary:   primary,
		secondary: secondary,
		checker:   checker,
	}
	fm.current = primary
	return fm
}
// GetCurrentServer returns the address traffic should use, performing
// failover when the current selection is unhealthy: prefer failing back
// to the primary, then any healthy secondary, and finally the primary
// even if unhealthy.
//
// Bug fix: the original "upgraded" its RLock to a Lock in place
// (RUnlock, Lock, then a defer that Unlocked and re-RLocked before the
// top-level deferred RUnlock ran). That pattern is racy between the two
// lock acquisitions and easy to break; this version uses a clean
// fast path under RLock and a slow path under a single write lock with
// a double-check.
func (fm *FailoverManager) GetCurrentServer() string {
	// Fast path: the current selection is healthy.
	fm.mutex.RLock()
	current := fm.current
	fm.mutex.RUnlock()
	if fm.checker.IsHealthy(current) {
		return current
	}
	// Slow path: fail over under the write lock.
	fm.mutex.Lock()
	defer fm.mutex.Unlock()
	// Double-check: another goroutine may have already failed over.
	if fm.checker.IsHealthy(fm.current) {
		return fm.current
	}
	// Prefer failing back to the primary.
	if fm.current != fm.primary && fm.checker.IsHealthy(fm.primary) {
		fmt.Printf("故障转移到主服务器: %s\n", fm.primary)
		fm.current = fm.primary
		return fm.current
	}
	// Otherwise take the first healthy secondary.
	for _, server := range fm.secondary {
		if fm.checker.IsHealthy(server) {
			fmt.Printf("故障转移到备用服务器: %s\n", server)
			fm.current = server
			return fm.current
		}
	}
	// Nothing is healthy; settle on the primary.
	fmt.Printf("所有服务器都不健康,使用主服务器: %s\n", fm.primary)
	fm.current = fm.primary
	return fm.current
}
完整示例:高性能聊天服务器 #
package main
import (
"bufio"
"context"
"encoding/json"
"fmt"
"net"
"sync"
"time"
)
// ChatServer is a room-based chat server speaking newline-delimited
// JSON messages over TCP.
type ChatServer struct {
	listener    net.Listener
	clients     map[string]*Client // keyed by client ID, guarded by mutex
	rooms       map[string]*Room   // keyed by room name, guarded by mutex
	connManager *ConnectionManager // idle tracking and traffic stats
	memPool     *MemoryPool        // buffer pool (reserved for handlers)
	mutex       sync.RWMutex       // protects clients and rooms
	ctx         context.Context    // cancellation for the accept loop
	cancel      context.CancelFunc
}
// Client is one connected chat user.
type Client struct {
	ID       string // remote address, used as the unique key
	Conn     net.Conn
	Username string // set on join
	Room     string // current room; empty until joined
	Writer   *bufio.Writer
	LastSeen time.Time
}
// Room is a named group of clients that receive each other's messages.
type Room struct {
	Name    string
	Clients map[string]*Client // keyed by client ID, guarded by mutex
	mutex   sync.RWMutex
}
// Message is the JSON wire format exchanged with chat clients.
type Message struct {
	Type      string    `json:"type"`            // join / leave / chat / private / notification / error
	From      string    `json:"from"`            // sender username, filled by the server
	To        string    `json:"to,omitempty"`    // target user for private messages
	Room      string    `json:"room,omitempty"`  // target room
	Content   string    `json:"content"`         // payload (username on join)
	Timestamp time.Time `json:"timestamp"`       // set by the server
}
// NewChatServer listens on address and assembles a chat server with a
// connection manager (5-minute idle limit) and a shared buffer pool.
func NewChatServer(address string) (*ChatServer, error) {
	ln, err := net.Listen("tcp", address)
	if err != nil {
		return nil, err
	}
	ctx, cancel := context.WithCancel(context.Background())
	srv := &ChatServer{
		listener:    ln,
		clients:     map[string]*Client{},
		rooms:       map[string]*Room{},
		connManager: NewConnectionManager(5 * time.Minute),
		memPool:     NewMemoryPool(),
		ctx:         ctx,
		cancel:      cancel,
	}
	return srv, nil
}
// Start runs the accept loop until the server's context is cancelled,
// spawning one goroutine per client connection.
func (s *ChatServer) Start() error {
	fmt.Printf("聊天服务器启动在 %s\n", s.listener.Addr())
	for {
		select {
		case <-s.ctx.Done():
			return nil
		default:
			conn, err := s.listener.Accept()
			if err != nil {
				// During shutdown Accept fails because the listener was
				// closed; treat that as a clean exit.
				select {
				case <-s.ctx.Done():
					return nil
				default:
					fmt.Printf("接受连接失败: %v\n", err)
					continue
				}
			}
			go s.handleClient(conn)
		}
	}
}
// handleClient reads newline-delimited JSON messages from conn and
// dispatches them until the client disconnects, a read deadline expires,
// or the server shuts down. The connection is tracked by the connection
// manager for idle cleanup and traffic stats.
//
// Bug fix: the original assigned AddConnection's return value to an
// unused local (`managed`), which is a compile error in Go ("declared
// and not used"); the return value is now simply discarded.
func (s *ChatServer) handleClient(conn net.Conn) {
	defer conn.Close()
	// Register with the connection manager; the tracking record itself
	// is not needed here.
	s.connManager.AddConnection(conn)
	defer s.connManager.RemoveConnection(conn)
	client := &Client{
		ID:       conn.RemoteAddr().String(),
		Conn:     conn,
		Writer:   bufio.NewWriter(conn),
		LastSeen: time.Now(),
	}
	reader := bufio.NewReader(conn)
	for {
		select {
		case <-s.ctx.Done():
			return
		default:
			// Drop clients that send nothing for 30 seconds.
			conn.SetReadDeadline(time.Now().Add(30 * time.Second))
			line, err := reader.ReadString('\n')
			if err != nil {
				s.removeClient(client)
				return
			}
			// Record activity for idle tracking and stats.
			client.LastSeen = time.Now()
			s.connManager.UpdateActivity(conn, 0, int64(len(line)))
			var msg Message
			if err := json.Unmarshal([]byte(line), &msg); err != nil {
				s.sendError(client, "无效的消息格式")
				continue
			}
			s.handleMessage(client, &msg)
		}
	}
}
// handleMessage stamps the sender and timestamp onto msg and dispatches
// it by type; unknown types produce an error reply.
//
// NOTE(review): handleLeave and handlePrivateMessage are referenced here
// but not defined in this file — confirm they exist elsewhere, otherwise
// this does not compile.
func (s *ChatServer) handleMessage(client *Client, msg *Message) {
	msg.From = client.Username
	msg.Timestamp = time.Now()
	switch msg.Type {
	case "join":
		s.handleJoin(client, msg)
	case "leave":
		s.handleLeave(client, msg)
	case "chat":
		s.handleChat(client, msg)
	case "private":
		s.handlePrivateMessage(client, msg)
	default:
		s.sendError(client, "未知的消息类型")
	}
}
// handleJoin registers the client under the requested username, places
// it in the requested room (creating the room if needed), and notifies
// the room's other members.
//
// Bug fix: the original held s.mutex (write lock) for the whole body and
// then called broadcastToRoom, which acquires s.mutex.RLock. Go's
// sync.RWMutex is not reentrant, so every join deadlocked the server.
// The server lock is now released before broadcasting.
func (s *ChatServer) handleJoin(client *Client, msg *Message) {
	s.mutex.Lock()
	client.Username = msg.Content
	client.Room = msg.Room
	// Register the client globally.
	s.clients[client.ID] = client
	// Find or create the room.
	room, exists := s.rooms[msg.Room]
	if !exists {
		room = &Room{
			Name:    msg.Room,
			Clients: make(map[string]*Client),
		}
		s.rooms[msg.Room] = room
	}
	s.mutex.Unlock()
	// Add the client to the room under the room's own lock.
	room.mutex.Lock()
	room.Clients[client.ID] = client
	room.mutex.Unlock()
	// Notify the other members; no server lock may be held here because
	// broadcastToRoom takes s.mutex.RLock itself.
	notification := &Message{
		Type:      "notification",
		Content:   fmt.Sprintf("%s 加入了房间", client.Username),
		Room:      msg.Room,
		Timestamp: time.Now(),
	}
	s.broadcastToRoom(msg.Room, notification, client.ID)
}
// handleChat relays a chat message to everyone in the sender's current
// room; clients that have not joined a room get an error reply.
func (s *ChatServer) handleChat(client *Client, msg *Message) {
	room := client.Room
	if room == "" {
		s.sendError(client, "请先加入房间")
		return
	}
	msg.Room = room
	s.broadcastToRoom(room, msg, "")
}
// broadcastToRoom sends msg (JSON + newline) to every member of the
// named room except excludeID (pass "" to include everyone). Unknown
// rooms are silently ignored.
//
// This takes s.mutex.RLock, so it must NOT be called while holding the
// server's write lock — sync.RWMutex is not reentrant.
//
// NOTE(review): the json.Marshal error is discarded; Message contains
// only marshal-safe fields today, but confirm if the type grows.
func (s *ChatServer) broadcastToRoom(roomName string, msg *Message, excludeID string) {
	s.mutex.RLock()
	room, exists := s.rooms[roomName]
	s.mutex.RUnlock()
	if !exists {
		return
	}
	room.mutex.RLock()
	defer room.mutex.RUnlock()
	data, _ := json.Marshal(msg)
	data = append(data, '\n')
	for clientID, client := range room.Clients {
		if clientID != excludeID {
			s.sendToClient(client, data)
		}
	}
}
// sendToClient writes data to the client under a 10-second deadline and
// records the traffic with the connection manager.
//
// NOTE(review): the Write and Flush errors are discarded, so a dead
// client is only detected by its own read loop — confirm this is the
// intended cleanup path.
func (s *ChatServer) sendToClient(client *Client, data []byte) {
	client.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
	client.Writer.Write(data)
	client.Writer.Flush()
	s.connManager.UpdateActivity(client.Conn, int64(len(data)), 0)
}
// sendError delivers an error-type message containing the given text to
// a single client.
func (s *ChatServer) sendError(client *Client, message string) {
	reply := &Message{
		Type:      "error",
		Content:   message,
		Timestamp: time.Now(),
	}
	payload, _ := json.Marshal(reply)
	payload = append(payload, '\n')
	s.sendToClient(client, payload)
}
// removeClient deregisters a disconnected client from the server and
// its room, then notifies the room's remaining members.
//
// Bug fix: the original held s.mutex (write lock) for the whole body and
// then called broadcastToRoom, which acquires s.mutex.RLock — a
// guaranteed deadlock since sync.RWMutex is not reentrant. The server
// lock is now released before broadcasting.
func (s *ChatServer) removeClient(client *Client) {
	s.mutex.Lock()
	delete(s.clients, client.ID)
	var room *Room
	if client.Room != "" {
		room = s.rooms[client.Room] // nil if the room no longer exists
	}
	s.mutex.Unlock()
	if room == nil {
		return
	}
	// Drop the client from the room under the room's own lock.
	room.mutex.Lock()
	delete(room.Clients, client.ID)
	room.mutex.Unlock()
	// Notify remaining members; no server lock may be held here because
	// broadcastToRoom takes s.mutex.RLock itself.
	if client.Username != "" {
		notification := &Message{
			Type:      "notification",
			Content:   fmt.Sprintf("%s 离开了房间", client.Username),
			Room:      client.Room,
			Timestamp: time.Now(),
		}
		s.broadcastToRoom(client.Room, notification, client.ID)
	}
}
// GetStats returns a snapshot with total client/room counts, a per-room
// member count under "rooms", plus everything reported by the
// connection manager (merged at the top level).
func (s *ChatServer) GetStats() map[string]interface{} {
	s.mutex.RLock()
	defer s.mutex.RUnlock()
	stats := map[string]interface{}{
		"total_clients": len(s.clients),
		"total_rooms":   len(s.rooms),
		"rooms":         make(map[string]int),
	}
	for roomName, room := range s.rooms {
		room.mutex.RLock()
		stats["rooms"].(map[string]int)[roomName] = len(room.Clients)
		room.mutex.RUnlock()
	}
	// Merge the connection manager's counters into the same map.
	connStats := s.connManager.GetStats()
	for k, v := range connStats {
		stats[k] = v
	}
	return stats
}
// Shutdown cancels the accept loop, stops the connection manager, and
// closes the listener, returning the listener's Close error.
func (s *ChatServer) Shutdown() error {
	s.cancel()
	s.connManager.Shutdown()
	return s.listener.Close()
}
// main boots the chat server on :8080, logs aggregate stats every 30
// seconds in a background goroutine, and blocks in the accept loop.
func main() {
	server, err := NewChatServer(":8080")
	if err != nil {
		fmt.Printf("创建服务器失败: %v\n", err)
		return
	}
	// Periodic stats reporter.
	go func() {
		ticker := time.NewTicker(30 * time.Second)
		defer ticker.Stop()
		for range ticker.C {
			stats := server.GetStats()
			fmt.Printf("服务器统计: %+v\n", stats)
		}
	}()
	fmt.Println("高性能聊天服务器启动...")
	if err := server.Start(); err != nil {
		fmt.Printf("服务器运行失败: %v\n", err)
	}
}
小结 #
本节详细介绍了 TCP 并发服务器的设计与实现,包括:
- 并发模型 - Goroutine-per-Connection 和 Worker Pool 模型
- 连接管理 - 连接池管理和生命周期控制
- 负载均衡 - 多种负载均衡算法的实现
- 性能优化 - 零拷贝、内存池、批量处理等优化技术
- 高可用性 - 健康检查和故障转移机制
- 完整示例 - 高性能聊天服务器的实现
掌握这些技术后,你就能够构建高性能、高可用的 TCP 并发服务器。在下一节中,我们将学习 UDP 广播与组播技术。