4.2.1 TCP 编程基础 #
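TCP(Transmission Control Protocol)是一种面向连接的、可靠的传输层协议。它提供了数据的有序传输、错误检测和重传机制,是构建可靠网络应用的基础。需要注意的是,TCP 传输的是没有消息边界的字节流:发送方的一次 Write 不一定对应接收方的一次 Read,因此应用层通常需要自行定义消息格式(例如后文高级客户端使用的"4 字节长度前缀"协议)。Go 语言通过 net 包提供了完善的 TCP 编程支持,本节将详细介绍如何在 Go 中进行 TCP 编程。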
TCP 协议特性 #
面向连接 #
TCP 是面向连接的协议,在数据传输前需要建立连接,传输完成后需要关闭连接:
客户端 服务器
| |
|-------- SYN --------->| (建立连接请求)
|<------- SYN+ACK ------| (确认并请求连接)
|-------- ACK --------->| (确认连接建立)
| |
|<====== 数据传输 ======>|
| |
|-------- FIN --------->| (关闭连接请求)
|<------- ACK ----------| (确认关闭)
|<------- FIN ----------| (服务器关闭请求)
|-------- ACK --------->| (确认关闭)
可靠传输 #
TCP 提供可靠的数据传输保证:
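- 数据完整性 - 通过校验和检测传输中损坏的报文段,损坏的报文段会被丢弃并由发送方重传
- 确认与重传 - 接收方对收到的数据发送确认(ACK),发送方在超时或收到重复确认后重传未被确认的数据
- 数据有序性 - 通过序列号保证数据按序交付给应用层,并丢弃重复数据
- 流量控制 - 通过接收方通告的滑动窗口,防止发送方的发送速度超过接收方的处理能力
- 拥塞控制 - 根据网络拥塞状况(如丢包)动态调整拥塞窗口,从而控制传输速率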
TCP 客户端编程 #
基础 TCP 客户端 #
package main
import (
"fmt"
"net"
"time"
)
// basicTCPClient dials localhost:8080, sends a single greeting and prints
// the first chunk of the reply (at most 1024 bytes). Any failure is printed
// and ends the exchange; the connection is always closed on return.
func basicTCPClient() {
	conn, err := net.Dial("tcp", "localhost:8080")
	if err != nil {
		fmt.Printf("连接失败: %v\n", err)
		return
	}
	defer conn.Close()

	fmt.Printf("连接成功: %s -> %s\n", conn.LocalAddr(), conn.RemoteAddr())

	if _, werr := conn.Write([]byte("Hello, TCP Server!")); werr != nil {
		fmt.Printf("发送数据失败: %v\n", werr)
		return
	}

	// A single Read: TCP is a byte stream, so this is just whatever part of
	// the reply has arrived so far (up to the buffer size).
	var reply [1024]byte
	n, rerr := conn.Read(reply[:])
	if rerr != nil {
		fmt.Printf("接收数据失败: %v\n", rerr)
		return
	}
	fmt.Printf("收到响应: %s\n", string(reply[:n]))
}
带超时的 TCP 客户端 #
// timeoutTCPClient is basicTCPClient with bounded waits: dialing gives up
// after 5 seconds, and the write and the read each get an absolute deadline
// 10 seconds from when it is set. A read that fails because its deadline
// passed is reported separately from other read failures.
func timeoutTCPClient() {
	const (
		dialTimeout = 5 * time.Second
		ioTimeout   = 10 * time.Second
	)

	conn, err := net.DialTimeout("tcp", "localhost:8080", dialTimeout)
	if err != nil {
		fmt.Printf("连接超时: %v\n", err)
		return
	}
	defer conn.Close()

	// Deadlines are absolute instants, not per-call durations.
	conn.SetReadDeadline(time.Now().Add(ioTimeout))
	conn.SetWriteDeadline(time.Now().Add(ioTimeout))

	if _, err := conn.Write([]byte("Hello with timeout")); err != nil {
		fmt.Printf("发送失败: %v\n", err)
		return
	}

	reply := make([]byte, 1024)
	n, err := conn.Read(reply)
	switch netErr, isNetErr := err.(net.Error); {
	case err == nil:
		fmt.Printf("收到数据: %s\n", string(reply[:n]))
	case isNetErr && netErr.Timeout():
		fmt.Println("读取超时")
	default:
		fmt.Printf("读取失败: %v\n", err)
	}
}
高级 TCP 客户端 #
import (
"bufio"
"fmt"
"io"
"net"
"time"
)
// maxMessageSize bounds the payload of one length-prefixed frame. The
// 4-byte header can announce up to 4 GiB; without a bound, ReceiveMessage
// would allocate that much on the say-so of the remote peer.
const maxMessageSize = 16 << 20 // 16 MiB

// TCPClient speaks a simple framed protocol over TCP: every message is a
// 4-byte big-endian payload length followed by that many payload bytes.
// Framing is needed because TCP itself is a byte stream with no message
// boundaries.
type TCPClient struct {
	conn   net.Conn
	reader *bufio.Reader
	writer *bufio.Writer
}

// NewTCPClient dials address over TCP (10-second connect timeout) and wraps
// the connection in buffered reader/writer halves.
func NewTCPClient(address string) (*TCPClient, error) {
	conn, err := net.DialTimeout("tcp", address, 10*time.Second)
	if err != nil {
		return nil, err
	}
	return &TCPClient{
		conn:   conn,
		reader: bufio.NewReader(conn),
		writer: bufio.NewWriter(conn),
	}, nil
}

// SendMessage writes message as a single frame and flushes it.
//
// Messages longer than maxMessageSize are rejected before anything is
// written, so the stream never contains a partial or mis-sized frame (the
// old code silently truncated lengths that did not fit in 32 bits).
func (c *TCPClient) SendMessage(message string) error {
	if len(message) > maxMessageSize {
		return fmt.Errorf("消息过长: %d 字节 (上限 %d)", len(message), maxMessageSize)
	}

	n := uint32(len(message))
	header := [4]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)}
	if _, err := c.writer.Write(header[:]); err != nil {
		return err
	}
	if _, err := c.writer.WriteString(message); err != nil {
		return err
	}
	// Header and payload leave in one flush, i.e. normally one conn.Write.
	return c.writer.Flush()
}

// ReceiveMessage reads one frame and returns its payload.
//
// It returns an error if the stream ends mid-frame (io.ReadFull reports
// io.ErrUnexpectedEOF) or if the announced length exceeds maxMessageSize.
// After such an error the stream is no longer aligned on a frame boundary,
// so the connection should be closed.
func (c *TCPClient) ReceiveMessage() (string, error) {
	var header [4]byte
	if _, err := io.ReadFull(c.reader, header[:]); err != nil {
		return "", err
	}

	// Decode as uint32: the old int-based decode could overflow to a
	// negative length on 32-bit platforms and make make() panic.
	length := uint32(header[0])<<24 | uint32(header[1])<<16 |
		uint32(header[2])<<8 | uint32(header[3])
	if length > maxMessageSize {
		return "", fmt.Errorf("消息过长: %d 字节 (上限 %d)", length, maxMessageSize)
	}

	payload := make([]byte, length)
	if _, err := io.ReadFull(c.reader, payload); err != nil {
		return "", err
	}
	return string(payload), nil
}

// Close closes the underlying connection. Unflushed data is discarded, but
// SendMessage always flushes, so none should be pending.
func (c *TCPClient) Close() error {
	return c.conn.Close()
}
// demonstrateAdvancedClient performs one framed request/response exchange
// with the server at localhost:8080 through TCPClient and prints the reply.
// Each failure is printed and ends the demo; the client is always closed.
func demonstrateAdvancedClient() {
	client, err := NewTCPClient("localhost:8080")
	if err != nil {
		fmt.Printf("创建客户端失败: %v\n", err)
		return
	}
	defer client.Close()

	if sendErr := client.SendMessage("Hello, Advanced TCP!"); sendErr != nil {
		fmt.Printf("发送消息失败: %v\n", sendErr)
		return
	}

	reply, recvErr := client.ReceiveMessage()
	if recvErr != nil {
		fmt.Printf("接收消息失败: %v\n", recvErr)
		return
	}
	fmt.Printf("收到响应: %s\n", reply)
}
TCP 服务器编程 #
基础 TCP 服务器 #
// basicTCPServer listens on TCP port 8080 on all interfaces and serves each
// accepted connection with handleConnection on its own goroutine. It returns
// only if Listen fails; Accept errors are printed and the loop carries on.
func basicTCPServer() {
	ln, err := net.Listen("tcp", ":8080")
	if err != nil {
		fmt.Printf("监听失败: %v\n", err)
		return
	}
	defer ln.Close()

	fmt.Println("TCP 服务器启动,监听端口 8080")

	for {
		conn, acceptErr := ln.Accept()
		if acceptErr == nil {
			go handleConnection(conn)
			continue
		}
		fmt.Printf("接受连接失败: %v\n", acceptErr)
	}
}
// handleConnection echoes every chunk read from conn back to the sender,
// prefixed with "Echo: ". It stops when the peer closes the connection
// (io.EOF, not reported as an error), or when a read or write fails (the
// error is printed). conn is always closed before returning.
func handleConnection(conn net.Conn) {
	defer conn.Close()

	peer := conn.RemoteAddr()
	fmt.Printf("新连接: %s\n", peer)

	buf := make([]byte, 1024)
	for {
		n, err := conn.Read(buf)
		if err == io.EOF {
			break
		}
		if err != nil {
			fmt.Printf("读取错误: %v\n", err)
			break
		}

		text := string(buf[:n])
		fmt.Printf("收到消息: %s\n", text)

		if _, err := conn.Write([]byte("Echo: " + text)); err != nil {
			fmt.Printf("发送响应失败: %v\n", err)
			break
		}
	}

	fmt.Printf("连接关闭: %s\n", peer)
}
并发 TCP 服务器 #
import (
"context"
"fmt"
"net"
"sync"
"time"
)
// TCPServer is a concurrent TCP server that tracks its connected clients
// and re-broadcasts every chunk it receives from one client to all others.
type TCPServer struct {
	address  string
	listener net.Listener
	clients  map[net.Conn]bool // guarded by mutex
	mutex    sync.RWMutex
	ctx      context.Context // cancelled by Stop
	cancel   context.CancelFunc
	// wg counts the accept loop and every client handler so Stop can wait
	// until all of them have exited.
	wg sync.WaitGroup
}

// NewTCPServer returns a server that will listen on address once started.
func NewTCPServer(address string) *TCPServer {
	ctx, cancel := context.WithCancel(context.Background())
	return &TCPServer{
		address: address,
		clients: make(map[net.Conn]bool),
		ctx:     ctx,
		cancel:  cancel,
	}
}

// Start begins listening and accepts connections on a background goroutine.
// It returns the Listen error, if any.
func (s *TCPServer) Start() error {
	listener, err := net.Listen("tcp", s.address)
	if err != nil {
		return err
	}
	s.listener = listener
	fmt.Printf("TCP 服务器启动: %s\n", s.address)

	s.wg.Add(1)
	go s.acceptConnections()
	return nil
}

// acceptConnections registers each accepted connection and serves it on its
// own goroutine until the server is stopped.
//
// No polling deadline is needed: Stop cancels the context and then closes
// the listener, which unblocks Accept with an error; seeing the context
// cancelled at that point means "shut down", not "report a failure".
func (s *TCPServer) acceptConnections() {
	defer s.wg.Done()
	for {
		conn, err := s.listener.Accept()
		if err != nil {
			if s.ctx.Err() != nil {
				return
			}
			fmt.Printf("接受连接失败: %v\n", err)
			continue
		}
		if !s.addClient(conn) {
			// Stop ran between Accept and registration; Stop would never
			// see this connection, so close it here.
			conn.Close()
			return
		}
		// Safe while Stop may be in wg.Wait: this goroutine's own count
		// keeps the counter above zero.
		s.wg.Add(1)
		go s.handleClient(conn)
	}
}

// handleClient reads from conn and broadcasts each chunk to the other
// clients until the peer disconnects, a read fails, or the server stops.
// Idle read timeouts (30s) are not errors; they just re-check for shutdown.
func (s *TCPServer) handleClient(conn net.Conn) {
	defer s.wg.Done()
	defer func() {
		s.removeClient(conn)
		conn.Close()
	}()

	fmt.Printf("客户端连接: %s\n", conn.RemoteAddr())

	buffer := make([]byte, 1024)
	for s.ctx.Err() == nil {
		conn.SetReadDeadline(time.Now().Add(30 * time.Second))
		n, err := conn.Read(buffer)
		if err != nil {
			if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
				continue // idle; loop re-checks for shutdown
			}
			// EOF is a normal disconnect, and errors after cancellation
			// come from Stop closing the connection; neither is reported.
			if err != io.EOF && s.ctx.Err() == nil {
				fmt.Printf("读取错误: %v\n", err)
			}
			return
		}
		message := string(buffer[:n])
		fmt.Printf("收到消息 [%s]: %s\n", conn.RemoteAddr(), message)
		s.broadcast(message, conn)
	}
}

// addClient registers conn. It reports false, leaving conn unregistered,
// once the server has been stopped. The check happens under the same lock
// Stop uses to close clients, so every conn is either closed by Stop or
// refused here.
func (s *TCPServer) addClient(conn net.Conn) bool {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	if s.ctx.Err() != nil {
		return false
	}
	s.clients[conn] = true
	return true
}

// removeClient unregisters conn; it does not close it.
func (s *TCPServer) removeClient(conn net.Conn) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	delete(s.clients, conn)
}

// broadcast sends message, tagged with the sender's address, to every
// client except sender. Each write has a 5-second deadline; failures are
// printed and skipped.
//
// The recipients are snapshotted under the read lock and written to after
// releasing it, so a slow client cannot hold the lock (and block
// connects/disconnects) for the duration of its write deadline.
func (s *TCPServer) broadcast(message string, sender net.Conn) {
	broadcastMsg := []byte(fmt.Sprintf("广播 [%s]: %s", sender.RemoteAddr(), message))

	s.mutex.RLock()
	targets := make([]net.Conn, 0, len(s.clients))
	for client := range s.clients {
		if client != sender {
			targets = append(targets, client)
		}
	}
	s.mutex.RUnlock()

	for _, client := range targets {
		client.SetWriteDeadline(time.Now().Add(5 * time.Second))
		if _, err := client.Write(broadcastMsg); err != nil {
			fmt.Printf("广播失败 [%s]: %v\n", client.RemoteAddr(), err)
		}
	}
}

// Stop shuts the server down. It stops accepting, closes the listener and
// every client connection, and waits for the accept loop and all handlers
// to exit. It returns the listener's Close error.
func (s *TCPServer) Stop() error {
	// Cancel first: the accept loop and addClient use it to tell
	// shutdown apart from failure.
	s.cancel()

	var err error
	if s.listener != nil {
		err = s.listener.Close()
	}

	// Closing the connections unblocks handlers stuck in Read or Write.
	s.mutex.Lock()
	for conn := range s.clients {
		conn.Close()
	}
	s.mutex.Unlock()

	s.wg.Wait()
	return err
}

// GetClientCount returns the number of currently registered clients.
func (s *TCPServer) GetClientCount() int {
	s.mutex.RLock()
	defer s.mutex.RUnlock()
	return len(s.clients)
}
// demonstrateConcurrentServer starts a TCPServer on port 8080 and then prints
// the number of connected clients every 10 seconds. Once the server is up
// the function never returns; the process is expected to be stopped from
// outside (e.g. Ctrl+C).
func demonstrateConcurrentServer() {
	server := NewTCPServer(":8080")
	if startErr := server.Start(); startErr != nil {
		fmt.Printf("启动服务器失败: %v\n", startErr)
		return
	}

	fmt.Println("服务器运行中,按 Ctrl+C 停止")

	report := time.NewTicker(10 * time.Second)
	defer report.Stop()
	for range report.C {
		fmt.Printf("当前客户端数量: %d\n", server.GetClientCount())
	}
}
TCP 连接管理 #
连接状态监控 #
// ConnectionMonitor keeps activity statistics for a set of connections,
// keyed by the string form of each connection's remote address. All methods
// lock mutex, so they may be called from multiple goroutines.
type ConnectionMonitor struct {
	connections map[string]*ConnectionInfo
	mutex       sync.RWMutex
}

// ConnectionInfo is the bookkeeping kept for one monitored connection.
type ConnectionInfo struct {
	Conn        net.Conn
	ConnectedAt time.Time // set by AddConnection
	LastActive  time.Time // set by AddConnection, refreshed by UpdateActivity
	BytesSent   int64     // running sum of UpdateActivity's bytesSent
	BytesRecv   int64     // running sum of UpdateActivity's bytesRecv
}

// NewConnectionMonitor returns an empty monitor.
func NewConnectionMonitor() *ConnectionMonitor {
	cm := new(ConnectionMonitor)
	cm.connections = map[string]*ConnectionInfo{}
	return cm
}

// keyFor returns the map key under which conn is tracked.
func (*ConnectionMonitor) keyFor(conn net.Conn) string {
	return conn.RemoteAddr().String()
}

// AddConnection starts tracking conn, replacing any entry with the same
// remote address.
func (cm *ConnectionMonitor) AddConnection(conn net.Conn) {
	cm.mutex.Lock()
	defer cm.mutex.Unlock()

	entry := &ConnectionInfo{Conn: conn}
	entry.ConnectedAt = time.Now()
	entry.LastActive = time.Now()
	cm.connections[cm.keyFor(conn)] = entry
}

// UpdateActivity marks conn as active now and adds to its byte counters.
// Untracked connections are ignored.
func (cm *ConnectionMonitor) UpdateActivity(conn net.Conn, bytesSent, bytesRecv int64) {
	cm.mutex.Lock()
	defer cm.mutex.Unlock()

	entry, tracked := cm.connections[cm.keyFor(conn)]
	if !tracked {
		return
	}
	entry.LastActive = time.Now()
	entry.BytesSent += bytesSent
	entry.BytesRecv += bytesRecv
}

// RemoveConnection stops tracking conn without closing it.
func (cm *ConnectionMonitor) RemoveConnection(conn net.Conn) {
	cm.mutex.Lock()
	defer cm.mutex.Unlock()
	delete(cm.connections, cm.keyFor(conn))
}

// GetStats returns a snapshot of every tracked entry, copied by value so the
// caller can read it without holding the lock.
func (cm *ConnectionMonitor) GetStats() map[string]ConnectionInfo {
	cm.mutex.RLock()
	defer cm.mutex.RUnlock()

	snapshot := make(map[string]ConnectionInfo)
	for addr, entry := range cm.connections {
		snapshot[addr] = *entry
	}
	return snapshot
}

// CleanupIdleConnections closes and stops tracking every connection whose
// last activity is more than timeout ago.
func (cm *ConnectionMonitor) CleanupIdleConnections(timeout time.Duration) {
	cm.mutex.Lock()
	defer cm.mutex.Unlock()

	reference := time.Now()
	for addr, entry := range cm.connections {
		if reference.Sub(entry.LastActive) <= timeout {
			continue
		}
		fmt.Printf("关闭空闲连接: %s\n", addr)
		entry.Conn.Close()
		// Deleting the current key during range is permitted in Go.
		delete(cm.connections, addr)
	}
}
连接池实现 #
// TCPConnectionPool keeps up to maxSize idle TCP connections to address for
// reuse. Get hands out a healthy idle connection or dials a new one; Put
// returns a connection to the pool, or closes it when the pool is full or
// closed. It is safe for concurrent use.
type TCPConnectionPool struct {
	address     string
	maxSize     int
	connections chan net.Conn // idle connections; capacity maxSize
	factory     func() (net.Conn, error)
	// mutex guards closed and is held around every channel operation on
	// connections, so Get and Put can never touch the channel after Close
	// has closed it (a send would panic; a receive would yield nil).
	mutex  sync.Mutex
	closed bool
}

// NewTCPConnectionPool returns an empty pool that dials address on demand
// with a 10-second connect timeout.
func NewTCPConnectionPool(address string, maxSize int) *TCPConnectionPool {
	pool := &TCPConnectionPool{
		address:     address,
		maxSize:     maxSize,
		connections: make(chan net.Conn, maxSize),
	}
	pool.factory = func() (net.Conn, error) {
		return net.DialTimeout("tcp", address, 10*time.Second)
	}
	return pool
}

// Get returns a healthy idle connection if one is available, otherwise a
// newly dialed one. An idle connection that fails the health check is
// closed (rather than leaked) and replaced by a fresh dial. Get fails once
// the pool is closed.
func (p *TCPConnectionPool) Get() (net.Conn, error) {
	conn, err := p.takeIdle()
	if err != nil {
		return nil, err
	}
	if conn == nil {
		// Pool empty: dial outside the lock so dials don't serialize.
		return p.factory()
	}
	if p.isConnValid(conn) {
		return conn, nil
	}
	conn.Close()
	return p.factory()
}

// takeIdle pops an idle connection without blocking. It returns (nil, nil)
// when no idle connection is available and an error if the pool is closed.
func (p *TCPConnectionPool) takeIdle() (net.Conn, error) {
	p.mutex.Lock()
	defer p.mutex.Unlock()
	if p.closed {
		return nil, fmt.Errorf("连接池已关闭")
	}
	select {
	case conn := <-p.connections:
		return conn, nil
	default:
		return nil, nil
	}
}

// Put returns conn to the pool. If the pool is closed or already holds
// maxSize idle connections, conn is closed instead. A nil conn is ignored.
func (p *TCPConnectionPool) Put(conn net.Conn) {
	if conn == nil {
		return
	}

	pooled := false
	p.mutex.Lock()
	if !p.closed {
		// Non-blocking send under the lock: Close cannot close the
		// channel in between, so this can never panic.
		select {
		case p.connections <- conn:
			pooled = true
		default:
		}
	}
	p.mutex.Unlock()

	if !pooled {
		conn.Close()
	}
}

// isConnValid probes conn by attempting a 1-byte read with a 1ms deadline.
// A timeout means the peer is still connected and simply idle, so the
// connection is considered healthy. Any other outcome (EOF, reset, or
// unexpected pending data, one byte of which is consumed here) marks it as
// unusable. The read deadline is cleared again before returning.
func (p *TCPConnectionPool) isConnValid(conn net.Conn) bool {
	conn.SetReadDeadline(time.Now().Add(1 * time.Millisecond))
	defer conn.SetReadDeadline(time.Time{})

	buffer := make([]byte, 1)
	_, err := conn.Read(buffer)
	if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
		return true
	}
	return false
}

// Close marks the pool closed and closes every idle connection. Connections
// currently checked out are unaffected; Put closes them when they come
// back. Calling Close more than once is a no-op.
func (p *TCPConnectionPool) Close() {
	p.mutex.Lock()
	defer p.mutex.Unlock()
	if p.closed {
		return
	}
	p.closed = true
	close(p.connections)

	// Ranging over the closed channel drains the remaining idle conns.
	for conn := range p.connections {
		conn.Close()
	}
}
// demonstrateConnectionPool runs five concurrent request/response exchanges
// against localhost:8080 through one shared TCPConnectionPool (up to 10
// idle connections) and then closes the pool.
//
// It waits for every worker before returning, instead of sleeping for a
// fixed time, so no worker can touch the pool after the deferred Close.
// Each exchange has a 5-second deadline so an unresponsive server cannot
// hang the demo.
func demonstrateConnectionPool() {
	pool := NewTCPConnectionPool("localhost:8080", 10)
	defer pool.Close()

	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()

			conn, err := pool.Get()
			if err != nil {
				fmt.Printf("获取连接失败 [%d]: %v\n", id, err)
				return
			}
			conn.SetDeadline(time.Now().Add(5 * time.Second))

			message := fmt.Sprintf("Hello from client %d", id)
			if _, err := conn.Write([]byte(message)); err != nil {
				fmt.Printf("发送失败 [%d]: %v\n", id, err)
				conn.Close()
				return
			}

			buffer := make([]byte, 1024)
			n, err := conn.Read(buffer)
			if err != nil {
				fmt.Printf("读取失败 [%d]: %v\n", id, err)
				conn.Close()
				return
			}
			fmt.Printf("客户端 %d 收到: %s\n", id, string(buffer[:n]))

			// Clear the deadline so the next user of this pooled connection
			// does not inherit one that has already expired.
			conn.SetDeadline(time.Time{})
			pool.Put(conn)
		}(i)
	}
	wg.Wait()
}
错误处理和重连机制 #
自动重连客户端 #
// ReconnectClient wraps one TCP connection to address. When a Send or
// Receive on the current connection fails, it re-dials in the background:
// up to maxRetries attempts, retryInterval apart. Send and Receive never
// wait for a reconnect; while disconnected they fail fast.
type ReconnectClient struct {
	address       string
	conn          net.Conn // guarded by mutex
	maxRetries    int
	retryInterval time.Duration
	connected     bool // guarded by mutex
	mutex         sync.RWMutex
	stopChan      chan struct{} // closed by Close to stop background work
	reconnectChan chan struct{} // capacity 1: coalesces reconnect requests
	closeOnce     sync.Once     // makes Close idempotent
	monitorOnce   sync.Once     // at most one monitor goroutine per client
}

// NewReconnectClient returns an unconnected client for address with
// 5 retries spaced 2 seconds apart.
func NewReconnectClient(address string) *ReconnectClient {
	return &ReconnectClient{
		address:       address,
		maxRetries:    5,
		retryInterval: 2 * time.Second,
		stopChan:      make(chan struct{}),
		reconnectChan: make(chan struct{}, 1),
	}
}

// Connect dials address (10-second timeout), installs the connection and
// starts the background reconnect monitor (only on the first successful
// call). It fails if the dial fails or if Close has already been called.
func (rc *ReconnectClient) Connect() error {
	conn, err := net.DialTimeout("tcp", rc.address, 10*time.Second)
	if err != nil {
		return err
	}
	if !rc.install(conn) {
		conn.Close()
		return fmt.Errorf("客户端已关闭")
	}
	fmt.Printf("连接成功: %s\n", rc.address)

	rc.monitorOnce.Do(func() { go rc.monitorConnection() })
	return nil
}

// isStopped reports whether Close has been called.
func (rc *ReconnectClient) isStopped() bool {
	select {
	case <-rc.stopChan:
		return true
	default:
		return false
	}
}

// install makes conn the current connection, closing the one it replaces.
// After Close it refuses and returns false; because this check and Close's
// cleanup both run under mutex, a dial that races with Close cannot leave a
// live connection behind.
func (rc *ReconnectClient) install(conn net.Conn) bool {
	rc.mutex.Lock()
	defer rc.mutex.Unlock()
	if rc.isStopped() {
		return false
	}
	if rc.conn != nil {
		rc.conn.Close()
	}
	rc.conn = conn
	rc.connected = true
	return true
}

// monitorConnection serves reconnect requests until Close is called.
func (rc *ReconnectClient) monitorConnection() {
	for {
		select {
		case <-rc.stopChan:
			return
		case <-rc.reconnectChan:
			rc.attemptReconnect()
		}
	}
}

// attemptReconnect drops the current connection and re-dials up to
// maxRetries times. It gives up early, including mid-backoff, if Close
// is called.
func (rc *ReconnectClient) attemptReconnect() {
	rc.mutex.Lock()
	rc.connected = false
	if rc.conn != nil {
		rc.conn.Close()
		rc.conn = nil
	}
	rc.mutex.Unlock()

	fmt.Println("开始重连...")
	for attempt := 1; attempt <= rc.maxRetries; attempt++ {
		if rc.isStopped() {
			return
		}
		fmt.Printf("重连尝试 %d/%d\n", attempt, rc.maxRetries)

		conn, err := net.DialTimeout("tcp", rc.address, 10*time.Second)
		if err != nil {
			fmt.Printf("重连失败: %v\n", err)
			if attempt < rc.maxRetries && !rc.waitRetry() {
				return
			}
			continue
		}
		if !rc.install(conn) {
			conn.Close()
			return
		}
		fmt.Println("重连成功")
		return
	}
	fmt.Println("重连失败,已达到最大重试次数")
}

// waitRetry sleeps for retryInterval. It returns false early if Close is
// called, so Close never waits out a backoff.
func (rc *ReconnectClient) waitRetry() bool {
	timer := time.NewTimer(rc.retryInterval)
	defer timer.Stop()
	select {
	case <-rc.stopChan:
		return false
	case <-timer.C:
		return true
	}
}

// current returns the live connection, or an error while disconnected.
func (rc *ReconnectClient) current() (net.Conn, error) {
	rc.mutex.RLock()
	defer rc.mutex.RUnlock()
	if !rc.connected || rc.conn == nil {
		return nil, fmt.Errorf("连接未建立")
	}
	return rc.conn, nil
}

// requestReconnect asks the monitor to re-dial, but only if failed is
// still the current connection. A late error from a connection that has
// already been replaced must not tear down its healthy successor. The send
// is non-blocking; a request already pending covers this failure too.
func (rc *ReconnectClient) requestReconnect(failed net.Conn) {
	rc.mutex.RLock()
	stillCurrent := rc.connected && rc.conn == failed
	rc.mutex.RUnlock()
	if !stillCurrent {
		return
	}
	select {
	case rc.reconnectChan <- struct{}{}:
	default:
	}
}

// Send writes data on the current connection. A write error triggers a
// background reconnect and is returned to the caller.
func (rc *ReconnectClient) Send(data []byte) error {
	conn, err := rc.current()
	if err != nil {
		return err
	}
	if _, err := conn.Write(data); err != nil {
		rc.requestReconnect(conn)
		return err
	}
	return nil
}

// Receive performs one Read (up to 1024 bytes) on the current connection.
// A read error triggers a background reconnect and is returned.
func (rc *ReconnectClient) Receive() ([]byte, error) {
	conn, err := rc.current()
	if err != nil {
		return nil, err
	}
	buffer := make([]byte, 1024)
	n, err := conn.Read(buffer)
	if err != nil {
		rc.requestReconnect(conn)
		return nil, err
	}
	return buffer[:n], nil
}

// Close stops background reconnects and closes the current connection.
// It is safe to call more than once.
func (rc *ReconnectClient) Close() {
	rc.closeOnce.Do(func() { close(rc.stopChan) })

	rc.mutex.Lock()
	defer rc.mutex.Unlock()
	if rc.conn != nil {
		rc.conn.Close()
		rc.conn = nil
	}
	rc.connected = false
}
// demonstrateReconnectClient connects a ReconnectClient to localhost:8080
// and sends ten numbered messages two seconds apart, printing the outcome
// of each send. Failed sends are reported but do not stop the loop.
func demonstrateReconnectClient() {
	client := NewReconnectClient("localhost:8080")
	defer client.Close()

	if connErr := client.Connect(); connErr != nil {
		fmt.Printf("初始连接失败: %v\n", connErr)
		return
	}

	const rounds = 10
	for seq := 0; seq < rounds; seq++ {
		payload := fmt.Sprintf("Message %d", seq)
		sendErr := client.Send([]byte(payload))
		if sendErr == nil {
			fmt.Printf("发送成功: %s\n", payload)
		} else {
			fmt.Printf("发送失败: %v\n", sendErr)
		}
		time.Sleep(2 * time.Second)
	}
}
性能优化 #
零拷贝优化 #
import (
"net"
"os"
"syscall"
)
// sendFileZeroCopy streams the entire contents of filename over conn.
//
// For a *net.TCPConn it delegates to (*net.TCPConn).ReadFrom with the open
// *os.File. Where the Go runtime supports it (e.g. Linux), that path uses
// sendfile(2)/splice(2), so data moves from the page cache to the socket
// without a userspace copy; elsewhere it falls back to a regular copy.
//
// The previous version issued one raw sendfile call inside
// RawConn.Control. That call runs on a non-blocking socket outside the
// network poller, and its byte count was ignored. As a result, a file
// larger than the socket send buffer was silently truncated, or the call
// failed with EAGAIN. ReadFrom loops until EOF and waits on the poller
// for writability.
//
// Connections that are not *net.TCPConn are rejected with an error.
func sendFileZeroCopy(conn net.Conn, filename string) error {
	file, err := os.Open(filename)
	if err != nil {
		return err
	}
	defer file.Close()

	tcpConn, ok := conn.(*net.TCPConn)
	if !ok {
		return fmt.Errorf("不支持零拷贝传输")
	}

	_, err = tcpConn.ReadFrom(file)
	return err
}
批量处理优化 #
// BatchProcessor coalesces small writes into fewer, larger conn.Write
// calls. Pending data is written once at least batchSize bytes are
// buffered, or 100ms after the most recent Write that did not trigger a
// flush, whichever comes first.
//
// The 100ms flush runs on a timer goroutine, so every method takes mu.
// Without it, the timer's flush races with Write on buffer and flushTimer.
// Close flushes pending data but does not close conn.
type BatchProcessor struct {
	conn       net.Conn
	buffer     []byte
	batchSize  int
	flushTimer *time.Timer

	mu       sync.Mutex // guards buffer, flushTimer and asyncErr
	asyncErr error      // failure of a timer-driven flush, reported once by the next Write or Close
}

// NewBatchProcessor returns a BatchProcessor that writes to conn in
// batches of at least batchSize bytes (or smaller, when the timer fires).
func NewBatchProcessor(conn net.Conn, batchSize int) *BatchProcessor {
	return &BatchProcessor{
		conn:      conn,
		buffer:    make([]byte, 0, batchSize),
		batchSize: batchSize,
	}
}

// Write appends data to the pending batch. It flushes immediately when the
// batch reaches batchSize; otherwise it (re)arms the 100ms flush timer.
// If an earlier timer-driven flush failed, Write returns that error once
// and does not buffer data.
func (bp *BatchProcessor) Write(data []byte) error {
	bp.mu.Lock()
	defer bp.mu.Unlock()

	if err := bp.asyncErr; err != nil {
		bp.asyncErr = nil
		return err
	}

	bp.buffer = append(bp.buffer, data...)
	if len(bp.buffer) >= bp.batchSize {
		return bp.flushLocked()
	}

	// Debounce: each Write that leaves data pending restarts the timer.
	if bp.flushTimer != nil {
		bp.flushTimer.Stop()
	}
	bp.flushTimer = time.AfterFunc(100*time.Millisecond, bp.flushFromTimer)
	return nil
}

// flushFromTimer is the timer callback. Its error has no caller to return
// to, so it is parked in asyncErr (keeping the first such error) for the
// next Write or Close to report.
func (bp *BatchProcessor) flushFromTimer() {
	bp.mu.Lock()
	defer bp.mu.Unlock()
	if err := bp.flushLocked(); err != nil && bp.asyncErr == nil {
		bp.asyncErr = err
	}
}

// flush writes any pending data immediately.
func (bp *BatchProcessor) flush() error {
	bp.mu.Lock()
	defer bp.mu.Unlock()
	return bp.flushLocked()
}

// flushLocked disarms the timer and writes all pending data in one
// conn.Write. The buffer is reset (keeping its capacity) even when the
// write fails. bp.mu must be held.
func (bp *BatchProcessor) flushLocked() error {
	if bp.flushTimer != nil {
		bp.flushTimer.Stop()
		bp.flushTimer = nil
	}
	if len(bp.buffer) == 0 {
		return nil
	}
	_, err := bp.conn.Write(bp.buffer)
	bp.buffer = bp.buffer[:0]
	return err
}

// Close flushes pending data. It returns an unreported timer-driven flush
// failure if there is one (it happened first), otherwise the result of
// this final flush. conn is left open.
func (bp *BatchProcessor) Close() error {
	bp.mu.Lock()
	defer bp.mu.Unlock()

	flushErr := bp.flushLocked()
	if asyncErr := bp.asyncErr; asyncErr != nil {
		bp.asyncErr = nil
		return asyncErr
	}
	return flushErr
}
小结 #
本节详细介绍了 Go 语言中的 TCP 编程基础,包括:
- TCP 协议特性 - 面向连接、可靠传输的特点
- TCP 客户端 - 基础客户端、超时控制、高级客户端实现
- TCP 服务器 - 基础服务器、并发服务器、连接管理
- 连接管理 - 连接监控、连接池、状态管理
- 错误处理 - 重连机制、错误恢复策略
- 性能优化 - 零拷贝、批量处理等优化技术
掌握这些 TCP 编程技术后,你就能够构建可靠、高性能的网络应用程序。在下一节中,我们将学习 UDP 编程的相关知识。