2.8.2 高并发服务器设计

2.8.2 高并发服务器设计 #

高并发服务器是现代互联网应用的核心组件,需要能够同时处理大量客户端连接和请求。本节将深入探讨如何使用 Go 语言设计和实现一个高性能的并发服务器,涵盖连接管理、负载均衡、性能优化等关键技术。

服务器架构设计 #

整体架构 #

一个高并发服务器通常采用以下架构:

┌─────────────┐    ┌─────────────┐    ┌─────────────┐
│   客户端    │───▶│  负载均衡   │───▶│  服务器集群  │
└─────────────┘    └─────────────┘    └─────────────┘
                          │                   │
                          ▼                   ▼
                   ┌─────────────┐    ┌─────────────┐
                   │  连接池管理  │    │  请求处理器  │
                   └─────────────┘    └─────────────┘
                          │                   │
                          ▼                   ▼
                   ┌─────────────┐    ┌─────────────┐
                   │  监控统计   │    │  数据存储   │
                   └─────────────┘    └─────────────┘

核心组件 #

  1. 连接管理器:管理客户端连接的生命周期
  2. 请求路由器:根据请求类型分发到相应处理器
  3. 工作池:管理工作 Goroutine,控制并发数量
  4. 负载均衡器:在多个后端服务间分配请求
  5. 监控系统:收集性能指标和健康状态

基础服务器框架 #

首先定义服务器的基础结构:

package server

import (
    "bytes"
    "context"
    "fmt"
    "io"
    "log"
    "net"
    "net/http"
    "net/textproto"
    "net/url"
    "os"
    "strings"
    "sync"
    "sync/atomic"
    "time"
)

// Server is the high-concurrency TCP server. It accepts connections, parses
// simplified HTTP/1.1 requests and runs them on a bounded worker pool.
type Server struct {
    addr            string
    listener        net.Listener
    connManager     *ConnectionManager
    router          *Router
    workerPool      *WorkerPool
    middleware      []Middleware
    shutdownTimeout time.Duration

    // Timeouts copied from Config. Config has always declared them, but
    // NewServer used to drop them; keeping them here lets connection
    // handlers honour them (the keep-alive loop reads readTimeout).
    readTimeout  time.Duration
    writeTimeout time.Duration
    idleTimeout  time.Duration

    // Runtime counters; see ServerStats.
    stats *ServerStats

    // Lifecycle control: ctx is cancelled by Shutdown, wg tracks background
    // goroutines, and running is 1 while started (changed only via CAS).
    ctx     context.Context
    cancel  context.CancelFunc
    wg      sync.WaitGroup
    running int32
}

// ServerStats holds the server's runtime counters.
//
// The int64 counters are updated with sync/atomic and are therefore declared
// first: on 32-bit platforms (386, ARM) sync/atomic only guarantees 64-bit
// alignment for the first word of an allocated struct, and the previous
// layout placed them after a time.Time, which can leave them misaligned.
type ServerStats struct {
    ActiveConns     int64 // connections currently registered
    TotalConns      int64 // connections accepted since start
    RequestsTotal   int64
    RequestsSuccess int64
    RequestsFailed  int64
    BytesRead       int64
    BytesWritten    int64
    StartTime       time.Time
}

// Config is the server configuration accepted by NewServer.
type Config struct {
    Addr            string        // listen address, e.g. ":8080"
    MaxConnections  int           // max concurrently registered connections; <= 0 selects a default
    WorkerPoolSize  int           // number of worker goroutines; <= 0 selects a default
    ReadTimeout     time.Duration // per-read deadline used by the keep-alive loop; 0 disables it
    WriteTimeout    time.Duration
    IdleTimeout     time.Duration
    ShutdownTimeout time.Duration // how long Shutdown waits for goroutines; <= 0 selects a default
}

// Defaults applied by NewServer to Config fields left at a non-positive value.
const (
    defaultMaxConnections  = 10000
    defaultWorkerPoolSize  = 100
    defaultShutdownTimeout = 30 * time.Second
)

// NewServer creates a Server from config; the caller's Config is not modified.
//
// A nil config, or a non-positive MaxConnections, WorkerPoolSize or
// ShutdownTimeout, falls back to the defaults above. Without this, a zero
// Config produced a server that rejected every connection
// (MaxConnections == 0), never executed a job (WorkerPoolSize == 0) or gave
// up waiting during Shutdown immediately (ShutdownTimeout == 0).
func NewServer(config *Config) *Server {
    cfg := Config{}
    if config != nil {
        cfg = *config
    }
    if cfg.MaxConnections <= 0 {
        cfg.MaxConnections = defaultMaxConnections
    }
    if cfg.WorkerPoolSize <= 0 {
        cfg.WorkerPoolSize = defaultWorkerPoolSize
    }
    if cfg.ShutdownTimeout <= 0 {
        cfg.ShutdownTimeout = defaultShutdownTimeout
    }

    ctx, cancel := context.WithCancel(context.Background())

    return &Server{
        addr:            cfg.Addr,
        connManager:     NewConnectionManager(cfg.MaxConnections),
        router:          NewRouter(),
        workerPool:      NewWorkerPool(cfg.WorkerPoolSize),
        shutdownTimeout: cfg.ShutdownTimeout,
        readTimeout:     cfg.ReadTimeout,
        writeTimeout:    cfg.WriteTimeout,
        idleTimeout:     cfg.IdleTimeout,
        stats:           &ServerStats{StartTime: time.Now()},
        ctx:             ctx,
        cancel:          cancel,
    }
}

连接管理器实现 #

连接管理器负责管理所有客户端连接:

// Connection wraps one accepted net.Conn together with its per-connection
// context and bookkeeping.
type Connection struct {
    // lastActive is the time of the last successful Read/Write, as Unix
    // nanoseconds. It is written by the goroutines that read and write the
    // connection and read by CleanupIdleConnections, so it is accessed only
    // through sync/atomic (the previous time.Time field was a data race).
    // It is the first field so it stays 64-bit aligned on 32-bit platforms,
    // as sync/atomic requires.
    lastActive int64
    id         uint64
    conn       net.Conn
    server     *Server
    closed     int32 // 0 = open, 1 = closed; set once by Close via CAS
    ctx        context.Context
    cancel     context.CancelFunc
}

// ConnectionManager tracks live connections and enforces the connection
// limit. Its mutable fields are guarded by mu.
type ConnectionManager struct {
    mu          sync.RWMutex
    connections map[uint64]*Connection
    maxConns    int
    nextID      uint64 // last assigned connection ID
    closed      bool   // set by CloseAll; further AddConnection calls fail
}

// NewConnectionManager creates a manager that admits at most maxConns
// simultaneous connections.
func NewConnectionManager(maxConns int) *ConnectionManager {
    return &ConnectionManager{
        connections: make(map[uint64]*Connection),
        maxConns:    maxConns,
    }
}

// AddConnection registers conn and returns its wrapper. It fails with
// ErrServerClosed after CloseAll and with ErrTooManyConnections when the
// limit is reached. The connection's context is derived from server.ctx, so
// cancelling the server cancels every connection.
func (cm *ConnectionManager) AddConnection(conn net.Conn, server *Server) (*Connection, error) {
    cm.mu.Lock()
    defer cm.mu.Unlock()

    if cm.closed {
        return nil, ErrServerClosed
    }
    if len(cm.connections) >= cm.maxConns {
        return nil, ErrTooManyConnections
    }

    // nextID is only touched while holding mu, so a plain increment suffices.
    cm.nextID++
    ctx, cancel := context.WithCancel(server.ctx)

    c := &Connection{
        lastActive: time.Now().UnixNano(),
        id:         cm.nextID,
        conn:       conn,
        server:     server,
        ctx:        ctx,
        cancel:     cancel,
    }

    cm.connections[c.id] = c
    atomic.AddInt64(&server.stats.ActiveConns, 1)
    atomic.AddInt64(&server.stats.TotalConns, 1)

    return c, nil
}

// RemoveConnection unregisters and closes the connection with the given id.
// Unknown ids (e.g. already evicted as idle) are ignored, so ActiveConns is
// decremented exactly once per registered connection.
func (cm *ConnectionManager) RemoveConnection(id uint64) {
    cm.mu.Lock()
    c, ok := cm.connections[id]
    if ok {
        delete(cm.connections, id)
    }
    cm.mu.Unlock()

    if !ok {
        return
    }
    // Closing happens outside the lock so a slow Close cannot stall other
    // Add/Remove calls. The error is ignored: the connection is being discarded.
    _ = c.Close()
    atomic.AddInt64(&c.server.stats.ActiveConns, -1)
}

// GetConnection looks up a registered connection by id.
func (cm *ConnectionManager) GetConnection(id uint64) (*Connection, bool) {
    cm.mu.RLock()
    defer cm.mu.RUnlock()

    conn, exists := cm.connections[id]
    return conn, exists
}

// CloseAll closes every registered connection and makes subsequent
// AddConnection calls fail. ActiveConns is decremented for each connection
// closed here (previously it was left unchanged, so the statistic stayed
// inflated after shutdown).
func (cm *ConnectionManager) CloseAll() {
    cm.mu.Lock()
    cm.closed = true
    conns := cm.connections
    cm.connections = make(map[uint64]*Connection)
    cm.mu.Unlock()

    for _, c := range conns {
        _ = c.Close()
        atomic.AddInt64(&c.server.stats.ActiveConns, -1)
    }
}

// CleanupIdleConnections unregisters and closes every connection whose last
// successful Read/Write happened more than idleTimeout ago.
func (cm *ConnectionManager) CleanupIdleConnections(idleTimeout time.Duration) {
    cutoff := time.Now().Add(-idleTimeout).UnixNano()

    var idle []*Connection
    cm.mu.Lock()
    for id, c := range cm.connections {
        if atomic.LoadInt64(&c.lastActive) < cutoff {
            delete(cm.connections, id)
            idle = append(idle, c)
        }
    }
    cm.mu.Unlock()

    for _, c := range idle {
        _ = c.Close()
        atomic.AddInt64(&c.server.stats.ActiveConns, -1)
    }
}

// Connection methods

// Close cancels the connection's context and closes the underlying
// net.Conn. Only the first call has any effect; later calls return nil.
func (c *Connection) Close() error {
    if atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
        c.cancel()
        return c.conn.Close()
    }
    return nil
}

// IsClosed reports whether Close has been called.
func (c *Connection) IsClosed() bool {
    return atomic.LoadInt32(&c.closed) == 1
}

// UpdateActivity records the current time as the last activity time.
func (c *Connection) UpdateActivity() {
    atomic.StoreInt64(&c.lastActive, time.Now().UnixNano())
}

// Read reads from the connection, counting bytes into BytesRead and
// refreshing the activity time when data arrives.
func (c *Connection) Read(b []byte) (int, error) {
    if c.IsClosed() {
        return 0, ErrConnectionClosed
    }

    n, err := c.conn.Read(b)
    if n > 0 {
        atomic.AddInt64(&c.server.stats.BytesRead, int64(n))
        c.UpdateActivity()
    }
    return n, err
}

// Write writes to the connection, counting bytes into BytesWritten and
// refreshing the activity time when data is sent.
func (c *Connection) Write(b []byte) (int, error) {
    if c.IsClosed() {
        return 0, ErrConnectionClosed
    }

    n, err := c.conn.Write(b)
    if n > 0 {
        atomic.AddInt64(&c.server.stats.BytesWritten, int64(n))
        c.UpdateActivity()
    }
    return n, err
}

工作池实现 #

工作池管理处理请求的 Goroutine:

// WorkerPool runs submitted Jobs on a fixed number of worker goroutines.
//
// Jobs are buffered in jobQueue (capacity 2*workers). A dispatcher goroutine
// hands each queued job to the next idle worker. Jobs still queued when Stop
// is called are dropped without their Callback being invoked.
type WorkerPool struct {
    workers    int
    jobQueue   chan Job
    workerPool chan chan Job  // idle workers advertise their job channel here
    quit       chan bool      // closed by Stop to tell all goroutines to exit
    wg         sync.WaitGroup // tracks the worker and dispatcher goroutines
    stopOnce   sync.Once      // makes Stop safe to call more than once
}

// Job is one unit of work. Handler is run on a worker. Callback (optional)
// receives Handler's error, ErrJobTimeout if Timeout > 0 elapsed first, or
// an error describing a panic in Handler.
type Job struct {
    ID       uint64
    Handler  func() error
    Callback func(error)
    Timeout  time.Duration
}

// NewWorkerPool creates a pool with the given number of workers; call Start
// to launch them.
func NewWorkerPool(workers int) *WorkerPool {
    return &WorkerPool{
        workers:    workers,
        jobQueue:   make(chan Job, workers*2),
        workerPool: make(chan chan Job, workers),
        quit:       make(chan bool),
    }
}

// Start launches the workers and the dispatcher. Call it at most once.
func (wp *WorkerPool) Start() {
    for i := 0; i < wp.workers; i++ {
        worker := NewWorker(i, wp.workerPool, wp.quit)
        // Add before the goroutine starts so Stop's Wait cannot miss it, and
        // let the worker call Done when it exits. Previously nothing called
        // Done, so Stop blocked forever.
        wp.wg.Add(1)
        worker.onExit = wp.wg.Done
        worker.Start()
    }

    wp.wg.Add(1)
    go wp.dispatch()
}

// dispatch moves jobs from jobQueue to idle workers until quit is closed.
// Every blocking step also watches quit, so the dispatcher cannot hang on a
// worker that has already exited.
func (wp *WorkerPool) dispatch() {
    defer wp.wg.Done()

    for {
        select {
        case job := <-wp.jobQueue:
            var jobChannel chan Job
            select {
            case jobChannel = <-wp.workerPool:
            case <-wp.quit:
                return
            }
            select {
            case jobChannel <- job:
            case <-wp.quit:
                return
            }
        case <-wp.quit:
            return
        }
    }
}

// Submit enqueues job without blocking; it returns ErrWorkerPoolFull if the
// queue is full.
func (wp *WorkerPool) Submit(job Job) error {
    select {
    case wp.jobQueue <- job:
        return nil
    default:
        return ErrWorkerPoolFull
    }
}

// Stop signals all goroutines to exit and waits until they have. Calling
// it again is a no-op apart from the (immediate) wait.
func (wp *WorkerPool) Stop() {
    wp.stopOnce.Do(func() { close(wp.quit) })
    wp.wg.Wait()
}

// Worker is a single worker goroutine of a WorkerPool.
type Worker struct {
    id         int
    jobChannel chan Job
    workerPool chan chan Job
    quit       chan bool
    // onExit, if set, is called when the worker goroutine returns; the pool
    // uses it to mark the worker done in its WaitGroup.
    onExit func()
}

// NewWorker creates a worker that advertises itself on workerPool and stops
// when quit is closed.
func NewWorker(id int, workerPool chan chan Job, quit chan bool) *Worker {
    return &Worker{
        id:         id,
        jobChannel: make(chan Job),
        workerPool: workerPool,
        quit:       quit,
    }
}

// Start runs the worker loop in a new goroutine: register as idle, wait for
// a job, run it, repeat, until quit is closed. Panics inside a job are
// contained per job (see runHandler), so a bad job no longer kills the worker.
func (w *Worker) Start() {
    go func() {
        if w.onExit != nil {
            defer w.onExit()
        }

        for {
            // Register this worker's channel as idle.
            select {
            case w.workerPool <- w.jobChannel:
            case <-w.quit:
                return
            }

            select {
            case job := <-w.jobChannel:
                w.executeJob(job)
            case <-w.quit:
                return
            }
        }
    }()
}

// executeJob runs job (with its timeout, if any) and then its Callback.
func (w *Worker) executeJob(job Job) {
    var err error

    if job.Timeout > 0 {
        // Buffered so the handler goroutine never blocks, even after a timeout.
        done := make(chan error, 1)
        go func() {
            done <- w.runHandler(job)
        }()

        timer := time.NewTimer(job.Timeout)
        select {
        case err = <-done:
            timer.Stop()
        case <-timer.C:
            // The handler cannot be interrupted; it keeps running in its
            // goroutine and its result is discarded.
            err = ErrJobTimeout
        }
    } else {
        err = w.runHandler(job)
    }

    if job.Callback != nil {
        w.runCallback(job, err)
    }
}

// runHandler calls job.Handler and converts a panic into an error. This keeps
// the worker alive and, when the handler runs on the timeout goroutine, stops
// the panic from crashing the whole process (the worker's own recover cannot
// see panics on other goroutines).
func (w *Worker) runHandler(job Job) (err error) {
    defer func() {
        if r := recover(); r != nil {
            log.Printf("Worker %d panic: %v", w.id, r)
            err = fmt.Errorf("job %d panicked: %v", job.ID, r)
        }
    }()
    return job.Handler()
}

// runCallback invokes job.Callback(err), logging a panic instead of letting
// it terminate the worker.
func (w *Worker) runCallback(job Job, err error) {
    defer func() {
        if r := recover(); r != nil {
            log.Printf("Worker %d callback panic: %v", w.id, r)
        }
    }()
    job.Callback(err)
}

请求路由器实现 #

路由器负责将请求分发到相应的处理器:

type (
    // Router maps an exact request path to its Handler and falls back to
    // notFound for unregistered paths. mu protects routes and notFound.
    Router struct {
        mu       sync.RWMutex
        routes   map[string]Handler
        notFound Handler
    }

    // Handler processes a request by filling in ctx.Response and returns a
    // non-nil error if processing failed.
    Handler func(*Context) error

    // Context carries a single request through middleware and handlers.
    Context struct {
        Conn     *Connection
        Request  *Request
        Response *Response
        Data     map[string]interface{} // request-scoped values; access via Set/Get
        mu       sync.RWMutex           // protects Data
    }

    // Request is a parsed (simplified) HTTP request.
    Request struct {
        Method  string
        Path    string
        Headers map[string]string
        Body    []byte
        Params  map[string]string
    }

    // Response is the HTTP response built by handlers and serialised by
    // Context.WriteResponse.
    Response struct {
        StatusCode int
        Headers    map[string]string
        Body       []byte
    }
)

// NewRouter returns a Router with no routes whose fallback handler replies
// 404 with the body "Not Found".
func NewRouter() *Router {
    fallback := func(ctx *Context) error {
        ctx.Response.StatusCode = 404
        ctx.Response.Body = []byte("Not Found")
        return nil
    }
    return &Router{
        routes:   make(map[string]Handler),
        notFound: fallback,
    }
}

// Handle registers handler for requests whose path equals pattern exactly,
// replacing any earlier registration for the same pattern.
func (r *Router) Handle(pattern string, handler Handler) {
    r.mu.Lock()
    defer r.mu.Unlock()

    r.routes[pattern] = handler
}

// HandleNotFound replaces the handler used for unregistered paths.
func (r *Router) HandleNotFound(handler Handler) {
    r.mu.Lock()
    defer r.mu.Unlock()

    r.notFound = handler
}

// Route dispatches ctx to the handler registered for ctx.Request.Path, or to
// the not-found handler. The router lock is released before the handler runs.
func (r *Router) Route(ctx *Context) error {
    return r.lookup(ctx.Request.Path)(ctx)
}

// lookup returns the handler for path, or the not-found handler, while
// holding the read lock.
func (r *Router) lookup(path string) Handler {
    r.mu.RLock()
    defer r.mu.RUnlock()

    if h, ok := r.routes[path]; ok {
        return h
    }
    return r.notFound
}

// Context 方法

// Set stores value under key in the request-scoped data map, allocating the
// map on first use. Safe for concurrent use.
func (ctx *Context) Set(key string, value interface{}) {
    ctx.mu.Lock()
    defer ctx.mu.Unlock()

    data := ctx.Data
    if data == nil {
        data = make(map[string]interface{})
        ctx.Data = data
    }
    data[key] = value
}

// Get returns the value stored under key and whether it was present.
// Safe for concurrent use.
func (ctx *Context) Get(key string) (value interface{}, ok bool) {
    ctx.mu.RLock()
    defer ctx.mu.RUnlock()

    // Indexing a nil map is valid in Go and yields (nil, false).
    value, ok = ctx.Data[key]
    return value, ok
}

// WriteResponse serialises ctx.Response as an HTTP/1.1 response and sends it
// on ctx.Conn.
//
// Changes from the original implementation:
//   - the reason phrase matches the status code (http.StatusText); the old
//     code always wrote "OK", producing status lines such as "404 OK";
//   - a zero StatusCode is sent as 200, like net/http's implicit 200;
//   - Content-Length is always computed from the body, so a Content-Length
//     entry in Response.Headers is skipped instead of being sent twice;
//   - status line, headers and body are assembled in one buffer and passed
//     to a single Write, instead of two Writes that another goroutine's
//     Write on the same connection could fall between.
//
// Header order follows map iteration and is therefore unspecified.
func (ctx *Context) WriteResponse() error {
    resp := ctx.Response

    status := resp.StatusCode
    if status == 0 {
        status = http.StatusOK
    }
    reason := http.StatusText(status)
    if reason == "" {
        reason = "Unknown"
    }

    var buf bytes.Buffer
    buf.Grow(128 + len(resp.Body))
    fmt.Fprintf(&buf, "HTTP/1.1 %d %s\r\n", status, reason)
    for key, value := range resp.Headers {
        if strings.EqualFold(key, "Content-Length") {
            continue
        }
        fmt.Fprintf(&buf, "%s: %s\r\n", key, value)
    }
    fmt.Fprintf(&buf, "Content-Length: %d\r\n\r\n", len(resp.Body))
    buf.Write(resp.Body)

    _, err := ctx.Conn.Write(buf.Bytes())
    return err
}

中间件系统 #

中间件系统提供请求预处理和后处理功能:

// Middleware wraps a Handler to add behaviour before and/or after it.
type Middleware func(Handler) Handler

// Use appends middleware to the server's chain. Middleware registered first
// becomes the outermost wrapper, i.e. it runs first on the way in.
func (s *Server) Use(middleware Middleware) {
    s.middleware = append(s.middleware, middleware)
}

// applyMiddleware wraps handler with every registered middleware, starting
// from the most recently registered one so that the first one ends up outermost.
func (s *Server) applyMiddleware(handler Handler) Handler {
    wrapped := handler
    count := len(s.middleware)
    for i := range s.middleware {
        wrapped = s.middleware[count-1-i](wrapped)
    }
    return wrapped
}

// 常用中间件实现

// LoggingMiddleware logs remote address, method, path, final status code and
// handling time for every request after the wrapped handler returns. The
// handler's error is passed through unchanged.
func LoggingMiddleware(next Handler) Handler {
    return func(ctx *Context) error {
        began := time.Now()
        err := next(ctx)
        elapsed := time.Since(began)

        remote := ctx.Conn.conn.RemoteAddr()
        req, resp := ctx.Request, ctx.Response
        log.Printf("[%s] %s %s - %d - %v", remote, req.Method, req.Path, resp.StatusCode, elapsed)

        return err
    }
}

// RecoveryMiddleware converts a panic in the wrapped handler into a 500
// "Internal Server Error" response. After a recovered panic the middleware
// returns a nil error, so the 500 response is still written to the client.
func RecoveryMiddleware(next Handler) Handler {
    return func(ctx *Context) error {
        defer func() {
            r := recover()
            if r == nil {
                return
            }
            log.Printf("Panic recovered: %v", r)
            resp := ctx.Response
            resp.StatusCode = 500
            resp.Body = []byte("Internal Server Error")
        }()

        return next(ctx)
    }
}

// RateLimitMiddleware returns a Middleware that admits requests through one
// token bucket shared by all connections. The bucket holds at most burst
// tokens, starts full, and refills at limit tokens per second. A request that
// finds no whole token gets 429 "Too Many Requests" and next is not called.
// A non-positive limit disables refilling, so only the initial burst is admitted.
//
// The original parameter was named rate, which shadowed the
// golang.org/x/time/rate package and made rate.NewLimiter fail to compile.
// That package was never imported either. The bucket below uses only the
// standard library. Parameter names do not affect callers.
func RateLimitMiddleware(limit int, burst int) Middleware {
    bucket := newTokenBucket(float64(limit), burst)

    return func(next Handler) Handler {
        return func(ctx *Context) error {
            if !bucket.allow(time.Now()) {
                ctx.Response.StatusCode = 429
                ctx.Response.Body = []byte("Too Many Requests")
                return nil
            }
            return next(ctx)
        }
    }
}

// tokenBucket is a minimal mutex-guarded token-bucket rate limiter.
type tokenBucket struct {
    mu     sync.Mutex
    rate   float64   // tokens added per second (>= 0)
    burst  float64   // capacity (>= 0)
    tokens float64   // tokens currently available
    last   time.Time // time of the last refill
}

// newTokenBucket creates a full bucket; negative arguments are treated as 0.
func newTokenBucket(perSecond float64, burst int) *tokenBucket {
    if perSecond < 0 {
        perSecond = 0
    }
    capacity := float64(burst)
    if capacity < 0 {
        capacity = 0
    }
    return &tokenBucket{rate: perSecond, burst: capacity, tokens: capacity, last: time.Now()}
}

// allow refills the bucket for the time elapsed since the last refill and
// consumes one token if a whole token is available.
func (tb *tokenBucket) allow(now time.Time) bool {
    tb.mu.Lock()
    defer tb.mu.Unlock()

    if elapsed := now.Sub(tb.last); elapsed > 0 {
        tb.tokens += elapsed.Seconds() * tb.rate
        if tb.tokens > tb.burst {
            tb.tokens = tb.burst
        }
        tb.last = now
    }
    if tb.tokens < 1 {
        return false
    }
    tb.tokens--
    return true
}

// CORSMiddleware adds CORS response headers when the request's Origin header
// matches one of origins (an entry "*" matches any origin). The matching
// origin is echoed in Access-Control-Allow-Origin. Because the value depends
// on the request, "Vary: Origin" is also added so shared caches key on it.
// Requests without an Origin header (non-CORS requests) are left untouched;
// previously the "*" case emitted an empty Access-Control-Allow-Origin
// for them. next is always called.
func CORSMiddleware(origins []string) Middleware {
    return func(next Handler) Handler {
        return func(ctx *Context) error {
            origin := ctx.Request.Headers["Origin"]

            if origin != "" && originAllowed(origins, origin) {
                if ctx.Response.Headers == nil {
                    ctx.Response.Headers = make(map[string]string)
                }
                h := ctx.Response.Headers
                h["Access-Control-Allow-Origin"] = origin
                h["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
                h["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
                if vary := h["Vary"]; vary != "" {
                    h["Vary"] = vary + ", Origin"
                } else {
                    h["Vary"] = "Origin"
                }
            }

            return next(ctx)
        }
    }
}

// originAllowed reports whether origin equals an entry of origins or origins
// contains the wildcard "*".
func originAllowed(origins []string, origin string) bool {
    for _, allowed := range origins {
        if allowed == "*" || allowed == origin {
            return true
        }
    }
    return false
}

负载均衡器实现 #

负载均衡器在多个后端服务间分配请求:

type (
    // LoadBalancer distributes requests across backends using a pluggable
    // selection algorithm. mu protects backends (the health checker copies the
    // slice under a read lock).
    LoadBalancer struct {
        mu        sync.RWMutex
        backends  []*Backend
        algorithm BalanceAlgorithm
        health    *HealthChecker
    }

    // Backend describes one upstream service.
    Backend struct {
        ID       string
        Address  string    // host:port probed by the health checker
        Weight   int       // relative weight for weighted round-robin
        Active   bool      // last health-check result; inactive backends are skipped
        Requests int64     // load counter read atomically by LeastConnectionsAlgorithm
        Failures int64     // failed health checks, incremented atomically
        LastSeen time.Time // time of the last successful health check
    }

    // BalanceAlgorithm picks a backend from backends, or returns nil if none
    // is eligible.
    BalanceAlgorithm interface {
        Select(backends []*Backend) *Backend
    }
)

// RoundRobinAlgorithm cycles through the currently active backends in order.
// The zero value is ready to use and safe for concurrent Select calls.
type RoundRobinAlgorithm struct {
    current uint64 // selections made so far; updated atomically, first field for 64-bit alignment
}

// Select returns the next active backend, or nil if none is active.
//
// The first call now returns the first active backend (the original
// incremented before indexing and so started at index 1). The counter is
// unsigned, so wrap-around can never produce a negative index and panic.
func (rr *RoundRobinAlgorithm) Select(backends []*Backend) *Backend {
    active := make([]*Backend, 0, len(backends))
    for _, b := range backends {
        if b.Active {
            active = append(active, b)
        }
    }
    if len(active) == 0 {
        return nil
    }

    n := atomic.AddUint64(&rr.current, 1) - 1
    return active[n%uint64(len(active))]
}

// WeightedRoundRobinAlgorithm implements smooth weighted round-robin, the
// scheme used by nginx. On every call each active backend's current weight
// grows by its Weight. The backend with the largest current weight is chosen
// (ties go to the earlier one), and the total active weight is subtracted
// from it. The zero value is ready to use; mu makes Select safe for
// concurrent use.
type WeightedRoundRobinAlgorithm struct {
    mu      sync.Mutex
    weights map[string]int // current weight per backend ID
}

// Select returns the next backend by smooth weighted round-robin, or nil if
// no backend is active.
//
// Fixes relative to the original: an active backend is always returned when
// one exists. Previously, if every current weight was <= 0 (for example all
// Weight fields left at 0), nil was returned. Entries for IDs no longer present
// in backends are pruned, so the map does not grow without bound as backends
// are replaced.
func (wrr *WeightedRoundRobinAlgorithm) Select(backends []*Backend) *Backend {
    wrr.mu.Lock()
    defer wrr.mu.Unlock()

    if wrr.weights == nil {
        wrr.weights = make(map[string]int)
    }
    wrr.pruneLocked(backends)

    var selected *Backend
    total := 0
    for _, b := range backends {
        if !b.Active {
            continue
        }
        total += b.Weight
        wrr.weights[b.ID] += b.Weight
        if selected == nil || wrr.weights[b.ID] > wrr.weights[selected.ID] {
            selected = b
        }
    }

    if selected != nil {
        wrr.weights[selected.ID] -= total
    }
    return selected
}

// pruneLocked drops current weights of IDs absent from backends. The caller
// must hold wrr.mu.
func (wrr *WeightedRoundRobinAlgorithm) pruneLocked(backends []*Backend) {
    if len(wrr.weights) <= len(backends) {
        return
    }
    present := make(map[string]struct{}, len(backends))
    for _, b := range backends {
        present[b.ID] = struct{}{}
    }
    for id := range wrr.weights {
        if _, ok := present[id]; !ok {
            delete(wrr.weights, id)
        }
    }
}

// getTotalWeight returns the sum of Weight over active backends.
func (wrr *WeightedRoundRobinAlgorithm) getTotalWeight(backends []*Backend) int {
    total := 0
    for _, backend := range backends {
        if backend.Active {
            total += backend.Weight
        }
    }
    return total
}

// LeastConnectionsAlgorithm picks the active backend with the smallest
// Requests counter (read atomically). Ties go to the earlier backend.
type LeastConnectionsAlgorithm struct{}

// Select returns the least-loaded active backend, or nil if none is active.
//
// The original used -1 as a "nothing selected yet" sentinel. A backend
// whose counter really was -1 was then overtaken by any later backend.
// Tracking the selection with a nil check removes that ambiguity.
func (lc *LeastConnectionsAlgorithm) Select(backends []*Backend) *Backend {
    var (
        best     *Backend
        bestLoad int64
    )

    for _, candidate := range backends {
        if !candidate.Active {
            continue
        }
        load := atomic.LoadInt64(&candidate.Requests)
        if best == nil || load < bestLoad {
            best, bestLoad = candidate, load
        }
    }

    return best
}

// HealthChecker periodically probes every backend of a LoadBalancer with an
// HTTP GET to path. A backend is marked active on a status below 400 and
// inactive otherwise or on error.
type HealthChecker struct {
    interval time.Duration
    timeout  time.Duration
    path     string
    quit     chan bool
    stopOnce sync.Once    // makes Stop safe to call more than once
    client   *http.Client // shared by all probes so keep-alive connections are reused
}

// NewHealthChecker creates a checker that probes every interval, with each
// probe limited to timeout.
func NewHealthChecker(interval, timeout time.Duration, path string) *HealthChecker {
    return &HealthChecker{
        interval: interval,
        timeout:  timeout,
        path:     path,
        quit:     make(chan bool),
        client:   &http.Client{Timeout: timeout},
    }
}

// Start runs the check loop until Stop is called. It blocks, so callers
// normally run it in its own goroutine.
func (hc *HealthChecker) Start(lb *LoadBalancer) {
    ticker := time.NewTicker(hc.interval)
    defer ticker.Stop()

    for {
        select {
        case <-ticker.C:
            hc.checkBackends(lb)
        case <-hc.quit:
            return
        }
    }
}

// checkBackends probes all backends concurrently. It then applies the results
// while holding lb.mu's write lock, so Active, LastSeen and Failures are not
// mutated while code holding lb.mu for reading inspects them (previously each
// probe goroutine wrote them with no lock at all).
func (hc *HealthChecker) checkBackends(lb *LoadBalancer) {
    lb.mu.RLock()
    backends := make([]*Backend, len(lb.backends))
    copy(backends, lb.backends)
    lb.mu.RUnlock()

    healthy := make([]bool, len(backends))
    var wg sync.WaitGroup
    for i, backend := range backends {
        wg.Add(1)
        go func(i int, b *Backend) {
            defer wg.Done()
            healthy[i] = hc.probe(b)
        }(i, backend)
    }
    wg.Wait()

    now := time.Now()
    lb.mu.Lock()
    for i, b := range backends {
        hc.apply(b, healthy[i], now)
    }
    lb.mu.Unlock()
}

// checkBackend probes a single backend and records the result. It does not
// take any LoadBalancer lock.
func (hc *HealthChecker) checkBackend(backend *Backend) {
    hc.apply(backend, hc.probe(backend), time.Now())
}

// probe reports whether GET http://<Address><path> succeeded with a status
// code below 400.
func (hc *HealthChecker) probe(backend *Backend) bool {
    client := hc.client
    if client == nil { // HealthChecker built without NewHealthChecker
        client = &http.Client{Timeout: hc.timeout}
    }

    resp, err := client.Get(fmt.Sprintf("http://%s%s", backend.Address, hc.path))
    if err != nil {
        return false
    }
    defer resp.Body.Close()
    // Drain a bounded amount so the connection can be reused by the client.
    _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 64<<10))

    return resp.StatusCode < 400
}

// apply records one probe result on b.
func (hc *HealthChecker) apply(b *Backend, healthy bool, now time.Time) {
    if healthy {
        b.Active = true
        b.LastSeen = now
        return
    }
    b.Active = false
    atomic.AddInt64(&b.Failures, 1)
}

// Stop ends the Start loop. Calling it more than once is safe.
func (hc *HealthChecker) Stop() {
    hc.stopOnce.Do(func() { close(hc.quit) })
}

服务器主循环实现 #

// Start binds the listening socket and launches the background goroutines:
// worker pool, idle-connection cleanup, periodic stats and the accept loop.
// It returns once the server is listening, or ErrServerAlreadyRunning if it
// was already started.
func (s *Server) Start() error {
    if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
        return ErrServerAlreadyRunning
    }

    ln, err := net.Listen("tcp", s.addr)
    if err != nil {
        // Roll back the running flag so a later Start may retry.
        atomic.StoreInt32(&s.running, 0)
        return err
    }
    s.listener = ln

    s.workerPool.Start()

    // Housekeeping goroutines; each calls s.wg.Done when it exits.
    for _, routine := range []func(){s.cleanupRoutine, s.statsRoutine} {
        s.wg.Add(1)
        go routine()
    }

    log.Printf("Server started on %s", s.addr)

    s.wg.Add(1)
    go s.acceptLoop()

    return nil
}

// acceptLoop accepts connections until the server shuts down. It registers
// each connection with the connection manager and serves it on its own
// goroutine.
//
// Shutdown closes the listener, after which Accept fails. The loop exits as
// soon as the context is cancelled or the running flag is cleared (Shutdown
// clears it before closing the listener). Previously it logged and retried in
// a tight loop during the window between listener.Close and cancel. Any other
// Accept error (e.g. running out of file descriptors) is retried with
// exponential backoff from 5ms up to 1s instead of busy-spinning.
func (s *Server) acceptLoop() {
    defer s.wg.Done()

    const (
        minBackoff = 5 * time.Millisecond
        maxBackoff = time.Second
    )
    var backoff time.Duration

    for {
        conn, err := s.listener.Accept()
        if err != nil {
            if s.ctx.Err() != nil || atomic.LoadInt32(&s.running) == 0 {
                return
            }

            if backoff == 0 {
                backoff = minBackoff
            } else {
                backoff *= 2
            }
            if backoff > maxBackoff {
                backoff = maxBackoff
            }
            log.Printf("Accept error: %v; retrying in %v", err, backoff)

            select {
            case <-time.After(backoff):
            case <-s.ctx.Done():
                return
            }
            continue
        }
        backoff = 0

        connection, err := s.connManager.AddConnection(conn, s)
        if err != nil {
            // Over the connection limit or shutting down: refuse the client.
            conn.Close()
            continue
        }

        s.wg.Add(1)
        go s.handleConnection(connection)
    }
}

// handleConnection serves requests on one connection until the client
// disconnects, a read or processing error occurs, or the server shuts down.
//
// Each request still runs on the worker pool, but the next request is read
// only after the current one has completed. Responses are therefore written
// in request order, and a queued job can no longer lose its response because
// the connection was torn down underneath it. Previously the loop submitted
// the job and immediately went back to reading.
//
// If the pool is full, the client gets a 503 instead of a silently dropped
// connection. After a failed request (handler error, write error or timeout)
// the response may be missing or partial, so the connection is closed rather
// than left waiting or desynchronised. RequestsTotal counts every request
// read. Each request increments exactly one of RequestsSuccess and
// RequestsFailed.
func (s *Server) handleConnection(conn *Connection) {
    defer s.wg.Done()
    defer s.connManager.RemoveConnection(conn.id)

    // Unblock a pending Read when the connection's context is cancelled
    // (server shutdown or Close). Otherwise Shutdown would wait for idle
    // keep-alive clients until its timeout. Close always cancels conn.ctx,
    // so this goroutine always exits.
    go func() {
        <-conn.ctx.Done()
        _ = conn.conn.SetReadDeadline(time.Now())
    }()

    for {
        if conn.ctx.Err() != nil {
            return
        }

        request, err := s.readRequest(conn)
        if err != nil {
            // EOF is a normal client close; errors caused by our own
            // cancellation/close are not worth logging either.
            if err != io.EOF && conn.ctx.Err() == nil {
                log.Printf("Read request error: %v", err)
            }
            return
        }

        ctx := &Context{
            Conn:     conn,
            Request:  request,
            Response: &Response{StatusCode: 200, Headers: make(map[string]string)},
        }

        atomic.AddInt64(&s.stats.RequestsTotal, 1)

        // Buffered so the callback never blocks, even if we have returned.
        done := make(chan error, 1)
        job := Job{
            ID: conn.id,
            Handler: func() error {
                return s.processRequest(ctx)
            },
            Callback: func(err error) {
                if err != nil {
                    atomic.AddInt64(&s.stats.RequestsFailed, 1)
                    log.Printf("Process request error: %v", err)
                } else {
                    atomic.AddInt64(&s.stats.RequestsSuccess, 1)
                }
                done <- err
            },
            Timeout: 30 * time.Second,
        }

        if err := s.workerPool.Submit(job); err != nil {
            log.Printf("Submit job error: %v", err)
            atomic.AddInt64(&s.stats.RequestsFailed, 1)
            ctx.Response.StatusCode = 503
            ctx.Response.Body = []byte("Service Unavailable")
            _ = ctx.WriteResponse()
            return
        }

        select {
        case err := <-done:
            if err != nil {
                return
            }
        case <-conn.ctx.Done():
            return
        }
    }
}

// processRequest runs ctx through the middleware chain and the router. If
// that succeeds, it writes the response. On a handler error nothing is
// written and the error is returned.
func (s *Server) processRequest(ctx *Context) error {
    chain := s.applyMiddleware(s.router.Route)

    err := chain(ctx)
    if err == nil {
        err = ctx.WriteResponse()
    }
    return err
}

// readRequest reads one simplified HTTP/1.x request from conn.
//
// Known limitations (unchanged): it performs a single Read of at most 4096
// bytes. Requests split across TCP segments, larger requests, and the bytes of
// a pipelined follow-up request are not handled. Body holds whatever followed
// the header terminator in that one read; Content-Length is not honoured.
func (s *Server) readRequest(conn *Connection) (*Request, error) {
    buffer := make([]byte, 4096)
    n, err := conn.Read(buffer)
    if err != nil {
        return nil, err
    }
    return parseRequest(buffer[:n])
}

// parseRequest parses the request line, headers and (partial) body in raw.
//
// Compared with the original parser it:
//   - canonicalises header names (textproto.CanonicalMIMEHeaderKey), so
//     lookups such as Headers["Origin"] also match a client sending
//     "origin:", and accepts "Name:value" with no space after the colon;
//   - strips the query string from Path, so exact-path routing matches
//     "/api/status?x=1", and stores query values in Params (first value wins);
//   - fills Body with the bytes after the blank line.
//
// It returns ErrInvalidRequest when the request line has fewer than three
// space-separated parts.
func parseRequest(raw []byte) (*Request, error) {
    text := string(raw)

    head, body := text, ""
    if i := strings.Index(text, "\r\n\r\n"); i >= 0 {
        head, body = text[:i], text[i+4:]
    }
    lines := strings.Split(head, "\r\n")

    parts := strings.Split(lines[0], " ")
    if len(parts) < 3 {
        return nil, ErrInvalidRequest
    }

    path, rawQuery := parts[1], ""
    if i := strings.IndexByte(path, '?'); i >= 0 {
        path, rawQuery = path[:i], path[i+1:]
    }

    request := &Request{
        Method:  parts[0],
        Path:    path,
        Headers: make(map[string]string),
        Params:  make(map[string]string),
    }

    if rawQuery != "" {
        // ParseQuery returns every pair it could decode alongside the first
        // error; malformed pairs are simply skipped (best effort).
        values, _ := url.ParseQuery(rawQuery)
        for key, vals := range values {
            if len(vals) > 0 {
                request.Params[key] = vals[0]
            }
        }
    }

    for _, line := range lines[1:] {
        if line == "" {
            break
        }
        colon := strings.IndexByte(line, ':')
        if colon <= 0 {
            continue // not a "Name: value" line
        }
        name := textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(line[:colon]))
        request.Headers[name] = strings.TrimSpace(line[colon+1:])
    }

    if body != "" {
        request.Body = []byte(body)
    }

    return request, nil
}

// cleanupRoutine evicts connections idle for more than five minutes, checking
// every 30 seconds, until the server context is cancelled.
func (s *Server) cleanupRoutine() {
    defer s.wg.Done()

    const (
        sweepEvery = 30 * time.Second
        maxIdle    = 5 * time.Minute
    )

    ticker := time.NewTicker(sweepEvery)
    defer ticker.Stop()

    for {
        select {
        case <-s.ctx.Done():
            return
        case <-ticker.C:
            s.connManager.CleanupIdleConnections(maxIdle)
        }
    }
}

// statsRoutine logs a statistics snapshot every 10 seconds until the server
// context is cancelled.
func (s *Server) statsRoutine() {
    defer s.wg.Done()

    const reportEvery = 10 * time.Second
    ticker := time.NewTicker(reportEvery)
    defer ticker.Stop()

    done := s.ctx.Done()
    for {
        select {
        case <-done:
            return
        case <-ticker.C:
            s.printStats()
        }
    }
}

// printStats logs the connection and request counters from a GetStats snapshot.
func (s *Server) printStats() {
    snap := s.GetStats()
    log.Printf(
        "Stats - Active: %d, Total: %d, Requests: %d (Success: %d, Failed: %d)",
        snap.ActiveConns, snap.TotalConns,
        snap.RequestsTotal, snap.RequestsSuccess, snap.RequestsFailed,
    )
}

// GetStats returns a copy of the server statistics. Each counter is loaded
// atomically, but the snapshot as a whole is not taken at a single instant.
func (s *Server) GetStats() ServerStats {
    src := s.stats

    var snap ServerStats
    snap.StartTime = src.StartTime
    snap.ActiveConns = atomic.LoadInt64(&src.ActiveConns)
    snap.TotalConns = atomic.LoadInt64(&src.TotalConns)
    snap.RequestsTotal = atomic.LoadInt64(&src.RequestsTotal)
    snap.RequestsSuccess = atomic.LoadInt64(&src.RequestsSuccess)
    snap.RequestsFailed = atomic.LoadInt64(&src.RequestsFailed)
    snap.BytesRead = atomic.LoadInt64(&src.BytesRead)
    snap.BytesWritten = atomic.LoadInt64(&src.BytesWritten)
    return snap
}

// Shutdown stops the server gracefully. It cancels the server context (and
// with it every connection context), closes the listener, and waits up to
// shutdownTimeout for background goroutines to finish. It then force-closes
// the remaining connections and stops the worker pool.
//
// It returns ErrServerNotRunning if the server was not running,
// context.DeadlineExceeded if the wait timed out (cleanup is still
// performed), and nil otherwise. Previously a timeout was reported as success.
func (s *Server) Shutdown() error {
    if !atomic.CompareAndSwapInt32(&s.running, 1, 0) {
        return ErrServerNotRunning
    }

    log.Println("Shutting down server...")

    // Cancel before closing the listener: acceptLoop checks s.ctx when Accept
    // fails, so it exits at once instead of logging the error and retrying.
    s.cancel()
    if s.listener != nil {
        s.listener.Close()
    }

    done := make(chan struct{})
    go func() {
        s.wg.Wait()
        close(done)
    }()

    var result error
    timer := time.NewTimer(s.shutdownTimeout)
    defer timer.Stop()

    select {
    case <-done:
        log.Println("Server shutdown completed")
    case <-timer.C:
        log.Println("Server shutdown timeout")
        result = context.DeadlineExceeded
    }

    s.connManager.CloseAll()
    s.workerPool.Stop()

    return result
}

使用示例 #

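下面是一个完整的使用示例。为简洁起见,示例假设 main 与上面的服务器代码位于同一个包中(因此可以直接使用 NewServer、Config 等标识符,并访问未导出的 router 字段);如果把服务器代码放到独立的 server 包中,需要导入该包并通过导出的方法注册路由: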

package main

import (
    "encoding/json"
    "log"
    "os"
    "os/signal"
    "syscall"
    "time"
)

func main() {
    // 创建服务器配置
    config := &Config{
        Addr:            ":8080",
        MaxConnections:  1000,
        WorkerPoolSize:  100,
        ReadTimeout:     30 * time.Second,
        WriteTimeout:    30 * time.Second,
        IdleTimeout:     5 * time.Minute,
        ShutdownTimeout: 30 * time.Second,
    }

    // 创建服务器
    server := NewServer(config)

    // 添加中间件
    server.Use(LoggingMiddleware)
    server.Use(RecoveryMiddleware)
    server.Use(RateLimitMiddleware(100, 10))
    server.Use(CORSMiddleware([]string{"*"}))

    // 注册路由
    server.router.Handle("/", homeHandler)
    server.router.Handle("/api/status", statusHandler)
    server.router.Handle("/api/stats", statsHandler(server))

    // 启动服务器
    if err := server.Start(); err != nil {
        log.Fatal(err)
    }

    // 等待信号
    sigChan := make(chan os.Signal, 1)
    signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
    <-sigChan

    // 优雅关闭
    if err := server.Shutdown(); err != nil {
        log.Printf("Shutdown error: %v", err)
    }
}

// 处理器示例

// homeHandler serves a static HTML greeting.
func homeHandler(ctx *Context) error {
    ctx.Response.Headers["Content-Type"] = "text/html"
    ctx.Response.Body = []byte("<h1>Welcome to High Concurrency Server</h1>")
    return nil
}

// statusHandler reports liveness as JSON: {"status":"ok","timestamp":<unix seconds>}.
func statusHandler(ctx *Context) error {
    status := map[string]interface{}{
        "status":    "ok",
        "timestamp": time.Now().Unix(),
    }
    return writeJSON(ctx, status)
}

// statsHandler returns a Handler that serves server.GetStats() as JSON.
func statsHandler(server *Server) Handler {
    return func(ctx *Context) error {
        return writeJSON(ctx, server.GetStats())
    }
}

// writeJSON marshals v into the response body and sets a JSON content type.
// A marshalling error is returned to the caller. The original handlers
// discarded it and would have replied 200 with an empty body.
func writeJSON(ctx *Context, v interface{}) error {
    data, err := json.Marshal(v)
    if err != nil {
        return err
    }
    ctx.Response.Headers["Content-Type"] = "application/json"
    ctx.Response.Body = data
    return nil
}

性能优化技巧 #

1. 内存池优化 #

// Object pools for Request and Response values, used to reduce per-request
// allocations and GC pressure.
var (
    requestPool = sync.Pool{
        New: func() interface{} {
            return &Request{
                Headers: make(map[string]string),
                Params:  make(map[string]string),
            }
        },
    }

    responsePool = sync.Pool{
        New: func() interface{} {
            return &Response{
                Headers: make(map[string]string),
            }
        },
    }
)

// getRequest returns an empty Request with non-nil Headers and Params maps.
func getRequest() *Request {
    return requestPool.Get().(*Request)
}

// putRequest clears req and returns it to the pool; nil is ignored. The maps
// are emptied in place to keep their storage, or allocated if nil. A Request
// built outside the pool therefore cannot come back from getRequest with nil
// maps.
func putRequest(req *Request) {
    if req == nil {
        return
    }
    req.Method = ""
    req.Path = ""
    req.Body = nil
    req.Headers = resetStringMap(req.Headers)
    req.Params = resetStringMap(req.Params)
    requestPool.Put(req)
}

// getResponse returns an empty Response (StatusCode 0) with a non-nil
// Headers map. (responsePool previously had no accessor.)
func getResponse() *Response {
    return responsePool.Get().(*Response)
}

// putResponse clears resp and returns it to the pool; nil is ignored.
func putResponse(resp *Response) {
    if resp == nil {
        return
    }
    resp.StatusCode = 0
    resp.Body = nil
    resp.Headers = resetStringMap(resp.Headers)
    responsePool.Put(resp)
}

// resetStringMap empties m in place, or returns a new map if m is nil.
func resetStringMap(m map[string]string) map[string]string {
    if m == nil {
        return make(map[string]string)
    }
    for k := range m {
        delete(m, k)
    }
    return m
}

2. 零拷贝优化 #

// SendFile streams the named file to the client.
//
// On a *net.TCPConn the copy goes through TCPConn.ReadFrom, which lets the
// runtime use a zero-copy path such as sendfile(2) on supported platforms.
// Other connections fall back to io.Copy. The original asserted
// file.(*os.File) on a value that is already an *os.File, which does not
// compile. Like Write, SendFile refuses closed connections and counts sent
// bytes into BytesWritten and the activity time.
func (c *Connection) SendFile(filename string) error {
    if c.IsClosed() {
        return ErrConnectionClosed
    }

    file, err := os.Open(filename)
    if err != nil {
        return err
    }
    defer file.Close()

    var n int64
    if tcpConn, ok := c.conn.(*net.TCPConn); ok {
        n, err = tcpConn.ReadFrom(file)
    } else {
        n, err = io.Copy(c.conn, file)
    }

    if n > 0 {
        atomic.AddInt64(&c.server.stats.BytesWritten, n)
        c.UpdateActivity()
    }
    return err
}

3. 连接复用 #

// handleHTTPConnection serves sequential requests on one keep-alive
// connection, synchronously on the calling goroutine (no worker pool). It
// stops when the client does not ask to keep the connection alive, a read
// fails or times out, or a request fails.
//
// Fixes relative to the original sketch:
//   - the Context now gets a Response; handlers write ctx.Response fields
//     and previously dereferenced a nil pointer;
//   - the connection is also removed from the ConnectionManager on exit
//     (it used to be only closed, leaking its map entry and ActiveConns);
//   - a failed request ends the connection and is counted in the stats;
//   - a zero readTimeout means "no deadline" rather than an already
//     expired one.
func (s *Server) handleHTTPConnection(conn *Connection) {
    defer s.connManager.RemoveConnection(conn.id)
    defer conn.Close() // covers connections that were never registered

    for {
        if s.readTimeout > 0 {
            if err := conn.conn.SetReadDeadline(time.Now().Add(s.readTimeout)); err != nil {
                return
            }
        }

        request, err := s.readRequest(conn)
        if err != nil {
            return
        }

        ctx := &Context{
            Conn:     conn,
            Request:  request,
            Response: &Response{StatusCode: 200, Headers: make(map[string]string)},
        }

        atomic.AddInt64(&s.stats.RequestsTotal, 1)
        if err := s.processRequest(ctx); err != nil {
            atomic.AddInt64(&s.stats.RequestsFailed, 1)
            log.Printf("Process request error: %v", err)
            return
        }
        atomic.AddInt64(&s.stats.RequestsSuccess, 1)

        if !s.shouldKeepAlive(request) {
            return
        }
    }
}

// shouldKeepAlive reports whether the client asked to keep the connection
// open. That is the case when a Connection header lists the "keep-alive"
// option and does not list "close".
//
// The header name and its options are matched case-insensitively, and
// comma-separated option lists such as "keep-alive, Upgrade" are supported.
// The original compared the whole value to "keep-alive" under the exact key
// "Connection", so such lists and lowercase header names were rejected.
//
// NOTE(review): Request does not record the protocol version, so the
// HTTP/1.1 default of persistent connections cannot be applied safely here;
// an explicit keep-alive is still required.
func (s *Server) shouldKeepAlive(req *Request) bool {
    keepAlive := false
    for name, value := range req.Headers {
        if !strings.EqualFold(name, "Connection") {
            continue
        }
        for _, option := range strings.Split(value, ",") {
            option = strings.TrimSpace(option)
            switch {
            case strings.EqualFold(option, "close"):
                return false
            case strings.EqualFold(option, "keep-alive"):
                keepAlive = true
            }
        }
    }
    return keepAlive
}

小结 #

本节我们构建了一个功能较为完整的高并发服务器原型,涵盖了以下关键技术:

  1. 连接管理:高效管理大量并发连接
  2. 工作池:控制并发处理数量,避免资源耗尽
  3. 请求路由:灵活的请求分发机制
  4. 中间件系统:可扩展的请求处理管道
  5. 负载均衡:多种负载均衡算法实现
  6. 性能优化:内存池、零拷贝等优化技术

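这个高并发服务器展示了 Go 语言在构建高性能网络服务方面的优势。需要注意的是,示例中的 HTTP 解析做了大幅简化(单次读取、不处理跨 TCP 分段的请求,也不处理 Content-Length),生产环境应优先使用标准库 net/http 等成熟实现,再借鉴本节的连接管理、工作池与负载均衡思路,为实际项目开发打下坚实的基础。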