4.2.3 HTTP 服务器开发

4.2.3 HTTP 服务器开发 #

HTTP(HyperText Transfer Protocol)是现代 Web 应用的基础协议。Go 语言的 net/http 包提供了强大而灵活的 HTTP 服务器实现,使得构建高性能的 Web 服务变得简单高效。本节将深入介绍如何使用 Go 语言开发 HTTP 服务器。

HTTP 协议基础 #

HTTP 请求响应模型 #

HTTP 是基于请求-响应模型的协议:

客户端                    服务器
   |                        |
   |------ HTTP 请求 ------>|
   |                        |
   |<----- HTTP 响应 -------|
   |                        |

HTTP 消息结构 #

HTTP 请求结构:

GET /api/users HTTP/1.1
Host: example.com
User-Agent: Go-http-client/1.1
Accept: application/json

[请求体(可选,GET 请求通常没有请求体)]

HTTP 响应结构:

HTTP/1.1 200 OK
Content-Type: application/json
Content-Length: 123

[响应体]

基础 HTTP 服务器 #

最简单的 HTTP 服务器 #

package main

import (
    "fmt"
    "net/http"
)

func main() {
    // Register a catch-all handler on the default ServeMux: every path that
    // no more specific pattern matches is answered with the greeting.
    http.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
        fmt.Fprintf(w, "Hello, World!")
    })

    const addr = ":8080"
    fmt.Println("服务器启动在 " + addr)

    // ListenAndServe blocks and only returns when serving stops, e.g. when
    // the port is already in use.
    if err := http.ListenAndServe(addr, nil); err != nil {
        fmt.Printf("服务器启动失败: %v\n", err)
    }
}

处理不同的 HTTP 方法 #

// methodHandler answers GET, POST, PUT and DELETE with a plain-text echo of
// the method. Any other method gets 405 Method Not Allowed together with an
// Allow header listing the supported methods, which RFC 9110 requires for a
// 405 response.
func methodHandler(w http.ResponseWriter, r *http.Request) {
    switch r.Method {
    case http.MethodGet:
        fmt.Fprintf(w, "GET 请求")
    case http.MethodPost:
        fmt.Fprintf(w, "POST 请求")
    case http.MethodPut:
        fmt.Fprintf(w, "PUT 请求")
    case http.MethodDelete:
        fmt.Fprintf(w, "DELETE 请求")
    default:
        w.Header().Set("Allow", "GET, POST, PUT, DELETE")
        http.Error(w, "不支持的方法", http.StatusMethodNotAllowed)
    }
}

// setupBasicServer registers methodHandler on /api and a file server for the
// ./static directory under /static/. It then blocks serving on :8080 with the
// default ServeMux. Because it registers on http.DefaultServeMux, calling it
// twice in one process panics on the duplicate patterns.
func setupBasicServer() {
    http.HandleFunc("/api", methodHandler)

    // Serve static files. The /static/ prefix is stripped so that a request
    // for /static/css/a.css is served from ./static/css/a.css.
    http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("./static/"))))

    fmt.Println("服务器启动在 :8080")
    // Report the error instead of silently dropping it; for example, the
    // port may already be in use.
    if err := http.ListenAndServe(":8080", nil); err != nil {
        fmt.Printf("服务器启动失败: %v\n", err)
    }
}

读取请求数据 #

import (
    "encoding/json"
    "io"
    "net/http"
    "net/url"
)

// User is the JSON payload decoded by handleUserData. The struct tags map
// each field to its lower-case JSON key.
type User struct {
    Name  string `json:"name"`
    Email string `json:"email"`
    Age   int    `json:"age"`
}

// handleUserData demonstrates the common ways of reading request data:
//   - GET: query-string parameters ?name=...&age=...
//   - POST with Content-Type application/x-www-form-urlencoded: form fields
//   - POST with Content-Type application/json: a User object, echoed back as JSON
//
// Content-Type is matched on its media type only, so a value such as
// "application/json; charset=utf-8" is accepted. A POST with any other content
// type gets 415 Unsupported Media Type. Methods other than GET and POST get
// 405 with an Allow header. Previously both cases produced an empty 200.
func handleUserData(w http.ResponseWriter, r *http.Request) {
    switch r.Method {
    case http.MethodGet:
        // Query parameters; missing ones come back as "".
        query := r.URL.Query()
        fmt.Fprintf(w, "查询参数 - 姓名: %s, 年龄: %s", query.Get("name"), query.Get("age"))

    case http.MethodPost:
        switch mediaTypeOf(r.Header.Get("Content-Type")) {
        case "application/x-www-form-urlencoded":
            if err := r.ParseForm(); err != nil {
                http.Error(w, "解析表单失败", http.StatusBadRequest)
                return
            }
            fmt.Fprintf(w, "表单数据 - 姓名: %s, 邮箱: %s", r.FormValue("name"), r.FormValue("email"))

        case "application/json":
            // Cap the body at 1 MiB so a client cannot make ReadAll buffer
            // arbitrarily large input. The server closes r.Body itself.
            body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 1<<20))
            if err != nil {
                http.Error(w, "读取请求体失败", http.StatusBadRequest)
                return
            }

            var user User
            if err := json.Unmarshal(body, &user); err != nil {
                http.Error(w, "解析 JSON 失败", http.StatusBadRequest)
                return
            }

            w.Header().Set("Content-Type", "application/json")
            json.NewEncoder(w).Encode(map[string]interface{}{
                "message": "用户创建成功",
                "user":    user,
            })

        default:
            http.Error(w, "不支持的内容类型", http.StatusUnsupportedMediaType)
        }

    default:
        w.Header().Set("Allow", "GET, POST")
        http.Error(w, "方法不允许", http.StatusMethodNotAllowed)
    }
}

// mediaTypeOf returns the media type of a Content-Type header value. The
// result is lower-cased, with any parameters (e.g. "; charset=utf-8") removed.
func mediaTypeOf(contentType string) string {
    if i := strings.IndexByte(contentType, ';'); i >= 0 {
        contentType = contentType[:i]
    }
    return strings.ToLower(strings.TrimSpace(contentType))
}

高级 HTTP 服务器 #

自定义 HTTP 服务器 #

import (
    "context"
    "fmt"
    "net/http"
    "time"
)

// HTTPServer pairs an *http.Server with the *http.ServeMux it dispatches to,
// so that routes can be registered on the wrapper before it is started.
type HTTPServer struct {
    server *http.Server
    mux    *http.ServeMux
}

// NewHTTPServer returns an HTTPServer for addr. It uses 15s read and write
// timeouts and a 60s idle (keep-alive) timeout. Nothing listens until Start
// is called.
func NewHTTPServer(addr string) *HTTPServer {
    const (
        ioTimeout   = 15 * time.Second
        idleTimeout = 60 * time.Second
    )

    srv := &HTTPServer{mux: http.NewServeMux()}
    srv.server = &http.Server{
        Addr:         addr,
        Handler:      srv.mux,
        ReadTimeout:  ioTimeout,
        WriteTimeout: ioTimeout,
        IdleTimeout:  idleTimeout,
    }
    return srv
}

// HandleFunc registers handler for pattern on the server's mux.
func (srv *HTTPServer) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
    srv.mux.HandleFunc(pattern, handler)
}

// Handle registers handler for pattern on the server's mux.
func (srv *HTTPServer) Handle(pattern string, handler http.Handler) {
    srv.mux.Handle(pattern, handler)
}

// Start prints the listen address and then blocks serving HTTP. Like
// http.Server.ListenAndServe it always returns a non-nil error;
// http.ErrServerClosed means Shutdown was called.
func (srv *HTTPServer) Start() error {
    addr := srv.server.Addr
    fmt.Printf("HTTP 服务器启动在 %s\n", addr)
    return srv.server.ListenAndServe()
}

// Shutdown stops the server gracefully by delegating to
// http.Server.Shutdown, which waits for in-flight requests until ctx is done.
func (srv *HTTPServer) Shutdown(ctx context.Context) error {
    fmt.Println("正在关闭 HTTP 服务器...")
    return srv.server.Shutdown(ctx)
}

// demonstrateCustomServer wires two routes onto an HTTPServer and blocks
// serving on :8080:
//   - "/"       plain-text banner (also the catch-all for unmatched paths)
//   - "/health" JSON liveness probe carrying the current RFC 3339 timestamp
func demonstrateCustomServer() {
    app := NewHTTPServer(":8080")

    app.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
        fmt.Fprintf(w, "自定义 HTTP 服务器")
    })

    app.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) {
        w.Header().Set("Content-Type", "application/json")
        now := time.Now().Format(time.RFC3339)
        fmt.Fprintf(w, `{"status": "healthy", "timestamp": "%s"}`, now)
    })

    // Start blocks until the server stops. ErrServerClosed marks a deliberate
    // Shutdown, not a failure.
    if err := app.Start(); err != nil && err != http.ErrServerClosed {
        fmt.Printf("服务器启动失败: %v\n", err)
    }
}

中间件实现 #

// Middleware takes an http.Handler and returns a new http.Handler, normally
// one that runs extra logic around the handler it was given.
type Middleware func(http.Handler) http.Handler

// LoggingMiddleware logs each request once next has finished, as
// "[<local time>] <METHOD> <path> <status> <duration>".
// The status is captured by a responseWriter wrapper. The wrapper starts at
// 200 because net/http sends 200 when a handler never calls WriteHeader.
func LoggingMiddleware(next http.Handler) http.Handler {
    logRequest := func(w http.ResponseWriter, r *http.Request) {
        began := time.Now()
        rec := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}

        next.ServeHTTP(rec, r)

        elapsed := time.Since(began)
        stamp := time.Now().Format("2006-01-02 15:04:05")
        fmt.Printf("[%s] %s %s %d %v\n", stamp, r.Method, r.URL.Path, rec.statusCode, elapsed)
    }
    return http.HandlerFunc(logRequest)
}

// responseWriter wraps an http.ResponseWriter so that LoggingMiddleware can
// see the status code that was actually sent to the client.
type responseWriter struct {
    http.ResponseWriter
    statusCode  int
    wroteHeader bool // set once the final status has been committed
}

// WriteHeader records code and forwards it. Only the first final status
// counts: net/http ignores later calls ("superfluous WriteHeader"), so
// recording them would make the log disagree with the real response.
// 1xx informational codes other than 101 may precede the final status, so
// they are forwarded without being recorded.
func (rw *responseWriter) WriteHeader(code int) {
    informational := code >= 100 && code < 200 && code != http.StatusSwitchingProtocols
    if !rw.wroteHeader && !informational {
        rw.statusCode = code
        rw.wroteHeader = true
    }
    rw.ResponseWriter.WriteHeader(code)
}

// Write marks the status as committed. A first Write without a prior
// WriteHeader implicitly sends 200, which is the initial statusCode used by
// LoggingMiddleware.
func (rw *responseWriter) Write(b []byte) (int, error) {
    rw.wroteHeader = true
    return rw.ResponseWriter.Write(b)
}

// Unwrap exposes the wrapped writer so http.ResponseController can still
// reach optional interfaces (Flusher, Hijacker, ...) that this wrapper hides.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
    return rw.ResponseWriter
}

// AuthMiddleware only lets a request reach next when its Authorization
// header is exactly "Bearer valid-token". A missing header and a wrong token
// are both rejected with 401, each with its own message.
// NOTE(review): the token is hard-coded for the demo. Real verification
// should use a proper token check, ideally with a constant-time comparison.
func AuthMiddleware(next http.Handler) http.Handler {
    const validToken = "Bearer valid-token"
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        switch r.Header.Get("Authorization") {
        case "":
            http.Error(w, "缺少认证令牌", http.StatusUnauthorized)
        case validToken:
            next.ServeHTTP(w, r)
        default:
            http.Error(w, "无效的认证令牌", http.StatusUnauthorized)
        }
    })
}

// CORSMiddleware adds permissive CORS headers (any origin; GET, POST, PUT,
// DELETE and OPTIONS; Content-Type and Authorization request headers) to
// every response. OPTIONS preflight requests are answered directly with 200
// and never reach next.
func CORSMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        headers := w.Header()
        headers.Set("Access-Control-Allow-Origin", "*")
        headers.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
        headers.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")

        if r.Method != http.MethodOptions {
            next.ServeHTTP(w, r)
            return
        }
        // Preflight: the headers above are the whole answer.
        w.WriteHeader(http.StatusOK)
    })
}

// RateLimitMiddleware returns a Middleware that lets each client make at most
// requestsPerMinute requests in any sliding one-minute window. Further
// requests get 429 Too Many Requests. A limit <= 0 rejects every request.
//
// Clients are identified by the host part of r.RemoteAddr. Keying on the full
// RemoteAddr (host:port) would give every new TCP connection a fresh budget,
// which makes the limit trivial to bypass.
// NOTE(review): behind a reverse proxy the host is the proxy's address.
// Trusting X-Forwarded-For needs a separate, proxy-aware policy.
//
// The mutex only guards the bookkeeping and is released before next runs, so
// slow handlers no longer serialize all traffic through this middleware.
// About once a minute, clients whose newest request has left the window are
// removed, so the map does not grow without bound.
func RateLimitMiddleware(requestsPerMinute int) Middleware {
    var (
        mu        sync.Mutex
        clients   = make(map[string][]time.Time)
        lastSweep = time.Now()
    )

    // allow reports whether client is within its limit, and if so records
    // the request.
    allow := func(client string) bool {
        mu.Lock()
        defer mu.Unlock()

        // Reading the clock under the lock keeps each client's timestamps in
        // ascending order, which the sweep below relies on.
        now := time.Now()

        if now.Sub(lastSweep) >= time.Minute {
            for ip, times := range clients {
                if len(times) == 0 || now.Sub(times[len(times)-1]) >= time.Minute {
                    delete(clients, ip)
                }
            }
            lastSweep = now
        }

        // Drop timestamps that have left the window, filtering in place.
        history := clients[client]
        recent := history[:0]
        for _, ts := range history {
            if now.Sub(ts) < time.Minute {
                recent = append(recent, ts)
            }
        }

        if len(recent) >= requestsPerMinute {
            clients[client] = recent
            return false
        }
        clients[client] = append(recent, now)
        return true
    }

    return func(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            if !allow(clientHost(r.RemoteAddr)) {
                http.Error(w, "请求过于频繁", http.StatusTooManyRequests)
                return
            }
            next.ServeHTTP(w, r)
        })
    }
}

// clientHost returns the host part of a RemoteAddr such as "1.2.3.4:5678" or
// "[::1]:5678". Values without a port (including bare IPv6 literals) are
// returned unchanged.
func clientHost(remoteAddr string) string {
    if strings.HasPrefix(remoteAddr, "[") {
        if end := strings.IndexByte(remoteAddr, ']'); end > 0 {
            return remoteAddr[1:end]
        }
        return remoteAddr
    }
    if strings.Count(remoteAddr, ":") == 1 {
        return remoteAddr[:strings.IndexByte(remoteAddr, ':')]
    }
    return remoteAddr
}

// ChainMiddleware composes middlewares into a single Middleware. The first
// argument becomes the outermost layer: ChainMiddleware(a, b)(h) is a(b(h)).
// A request therefore passes through a, then b, and then reaches h. With no
// arguments the handler is returned unchanged.
func ChainMiddleware(middlewares ...Middleware) Middleware {
    return func(final http.Handler) http.Handler {
        wrapped := final
        for i := range middlewares {
            // Walk from the last middleware to the first so the first ends up outermost.
            wrapped = middlewares[len(middlewares)-1-i](wrapped)
        }
        return wrapped
    }
}

// demonstrateMiddleware serves two routes on :8080:
//   - /protected: CORS -> rate limit (10 req/min per client) -> bearer auth -> handler
//   - /public:    no extra middleware
//
// LoggingMiddleware wraps the whole mux, so every request on either route is
// logged exactly once. It previously also appeared inside the /protected
// chain, which produced two log lines per protected request. The server uses
// the same timeouts as NewHTTPServer so slow clients cannot hold connections
// indefinitely.
func demonstrateMiddleware() {
    mux := http.NewServeMux()

    // Protected route.
    protectedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        fmt.Fprintf(w, "这是受保护的资源")
    })

    chain := ChainMiddleware(
        CORSMiddleware,
        RateLimitMiddleware(10), // at most 10 requests per minute per client
        AuthMiddleware,
    )
    mux.Handle("/protected", chain(protectedHandler))

    // Public route.
    mux.HandleFunc("/public", func(w http.ResponseWriter, r *http.Request) {
        fmt.Fprintf(w, "这是公开资源")
    })

    server := &http.Server{
        Addr:         ":8080",
        Handler:      LoggingMiddleware(mux),
        ReadTimeout:  15 * time.Second,
        WriteTimeout: 15 * time.Second,
        IdleTimeout:  60 * time.Second,
    }

    fmt.Println("服务器启动在 :8080")
    if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
        fmt.Printf("服务器启动失败: %v\n", err)
    }
}

RESTful API 实现 #

用户管理 API #

import (
    "encoding/json"
    "fmt"
    "net/http"
    "strconv"
    "strings"
    "sync"
)

// User is the resource stored by UserService and exchanged as JSON by
// UserHandler. ID is assigned by UserService.CreateUser and UpdateUser, which
// overwrite any client-supplied value.
type User struct {
    ID    int    `json:"id"`
    Name  string `json:"name"`
    Email string `json:"email"`
    Age   int    `json:"age"`
}

// UserService is an in-memory user store guarded by a read/write mutex. IDs
// are handed out sequentially starting at 1 and are never reused, even after
// a delete. The store keeps the *User pointers it is given and hands the same
// pointers back.
type UserService struct {
    users  map[int]*User
    nextID int
    mutex  sync.RWMutex
}

// NewUserService returns an empty store whose first user will get ID 1.
func NewUserService() *UserService {
    svc := new(UserService)
    svc.users = map[int]*User{}
    svc.nextID = 1
    return svc
}

// CreateUser assigns the next ID to user (mutating the caller's struct),
// stores the pointer and returns it.
func (s *UserService) CreateUser(user *User) *User {
    s.mutex.Lock()
    defer s.mutex.Unlock()

    id := s.nextID
    s.nextID = id + 1
    user.ID = id
    s.users[id] = user
    return user
}

// GetUser returns the stored user with the given id and whether it exists.
func (s *UserService) GetUser(id int) (*User, bool) {
    s.mutex.RLock()
    user, found := s.users[id]
    s.mutex.RUnlock()
    return user, found
}

// GetAllUsers returns every stored user in a freshly allocated, non-nil
// slice, so it JSON-encodes as [] when empty. The order follows Go map
// iteration and is therefore unspecified.
func (s *UserService) GetAllUsers() []*User {
    s.mutex.RLock()
    defer s.mutex.RUnlock()

    all := make([]*User, len(s.users))
    i := 0
    for _, u := range s.users {
        all[i] = u
        i++
    }
    return all
}

// UpdateUser replaces the user stored under id with updatedUser, after
// setting its ID. It returns (nil, false) if no such user exists.
func (s *UserService) UpdateUser(id int, updatedUser *User) (*User, bool) {
    s.mutex.Lock()
    defer s.mutex.Unlock()

    _, found := s.users[id]
    if !found {
        return nil, false
    }
    updatedUser.ID = id
    s.users[id] = updatedUser
    return updatedUser, true
}

// DeleteUser removes the user with the given id and reports whether it
// existed.
func (s *UserService) DeleteUser(id int) bool {
    s.mutex.Lock()
    defer s.mutex.Unlock()

    if _, found := s.users[id]; found {
        delete(s.users, id)
        return true
    }
    return false
}

// UserHandler is an http.Handler that exposes a UserService as the REST
// resource /api/users.
type UserHandler struct {
    service *UserService
}

// NewUserHandler returns a UserHandler backed by service.
func NewUserHandler(service *UserService) *UserHandler {
    h := &UserHandler{}
    h.service = service
    return h
}

// ServeHTTP routes requests by what remains after the "/api/users" prefix is
// removed from the path:
//   - an empty or "/" remainder goes to the collection handler (handleUsers);
//   - a remainder starting with "/" goes to handleUserByID, with the leading
//     slash stripped;
//   - anything else gets 404.
func (uh *UserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    rest := strings.TrimPrefix(r.URL.Path, "/api/users")

    if rest == "" || rest == "/" {
        uh.handleUsers(w, r)
        return
    }
    if !strings.HasPrefix(rest, "/") {
        http.NotFound(w, r)
        return
    }
    uh.handleUserByID(w, r, rest[1:])
}

// handleUsers serves the collection: GET lists all users and POST creates
// one. Other methods get 405 with an Allow header, which RFC 9110 requires
// for that status.
func (uh *UserHandler) handleUsers(w http.ResponseWriter, r *http.Request) {
    switch r.Method {
    case http.MethodGet:
        uh.getAllUsers(w, r)
    case http.MethodPost:
        uh.createUser(w, r)
    default:
        w.Header().Set("Allow", "GET, POST")
        http.Error(w, "方法不允许", http.StatusMethodNotAllowed)
    }
}

// handleUserByID serves a single user. idStr is the path segment after
// /api/users/ and must parse as a decimal integer, otherwise the response is
// 400. GET reads the user, PUT replaces it and DELETE removes it. Other
// methods get 405 with an Allow header, which RFC 9110 requires for that
// status.
func (uh *UserHandler) handleUserByID(w http.ResponseWriter, r *http.Request, idStr string) {
    id, err := strconv.Atoi(idStr)
    if err != nil {
        http.Error(w, "无效的用户 ID", http.StatusBadRequest)
        return
    }

    switch r.Method {
    case http.MethodGet:
        uh.getUserByID(w, r, id)
    case http.MethodPut:
        uh.updateUser(w, r, id)
    case http.MethodDelete:
        uh.deleteUser(w, r, id)
    default:
        w.Header().Set("Allow", "GET, PUT, DELETE")
        http.Error(w, "方法不允许", http.StatusMethodNotAllowed)
    }
}

// getAllUsers writes {"count": n, "users": [...]} as JSON.
func (uh *UserHandler) getAllUsers(w http.ResponseWriter, _ *http.Request) {
    list := uh.service.GetAllUsers()
    payload := map[string]interface{}{
        "count": len(list),
        "users": list,
    }

    w.Header().Set("Content-Type", "application/json")
    json.NewEncoder(w).Encode(payload)
}

// createUser decodes a JSON User from the request body, which is capped at
// 1 MiB. Name and email must be non-empty. The user is stored and the reply
// is 201 Created with the stored user, including its assigned ID. Any "id"
// sent by the client is overwritten by CreateUser.
func (uh *UserHandler) createUser(w http.ResponseWriter, r *http.Request) {
    // Bound the body so a client cannot make the decoder consume unbounded
    // input. Exceeding the cap surfaces as a decode error, which returns 400.
    r.Body = http.MaxBytesReader(w, r.Body, 1<<20)

    var user User
    if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
        http.Error(w, "无效的 JSON 数据", http.StatusBadRequest)
        return
    }

    // Minimal validation.
    if user.Name == "" || user.Email == "" {
        http.Error(w, "姓名和邮箱不能为空", http.StatusBadRequest)
        return
    }

    createdUser := uh.service.CreateUser(&user)

    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(http.StatusCreated)
    json.NewEncoder(w).Encode(createdUser)
}

// getUserByID writes the user with the given id as JSON, or 404 if there is
// no such user.
func (uh *UserHandler) getUserByID(w http.ResponseWriter, _ *http.Request, id int) {
    if user, found := uh.service.GetUser(id); found {
        w.Header().Set("Content-Type", "application/json")
        json.NewEncoder(w).Encode(user)
        return
    }
    http.Error(w, "用户不存在", http.StatusNotFound)
}

// updateUser replaces the user with the given id by the JSON body, which is
// capped at 1 MiB. The reply is the stored user, or 404 if the id is unknown.
// As in createUser, name and email are required. A PUT is a full
// replacement, so accepting empty values would store an invalid user.
func (uh *UserHandler) updateUser(w http.ResponseWriter, r *http.Request, id int) {
    // Bound the body so a client cannot make the decoder consume unbounded input.
    r.Body = http.MaxBytesReader(w, r.Body, 1<<20)

    var user User
    if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
        http.Error(w, "无效的 JSON 数据", http.StatusBadRequest)
        return
    }

    if user.Name == "" || user.Email == "" {
        http.Error(w, "姓名和邮箱不能为空", http.StatusBadRequest)
        return
    }

    updatedUser, exists := uh.service.UpdateUser(id, &user)
    if !exists {
        http.Error(w, "用户不存在", http.StatusNotFound)
        return
    }

    w.Header().Set("Content-Type", "application/json")
    json.NewEncoder(w).Encode(updatedUser)
}

// deleteUser removes the user with the given id. It replies 204 No Content
// on success and 404 if there was no such user.
func (uh *UserHandler) deleteUser(w http.ResponseWriter, _ *http.Request, id int) {
    if uh.service.DeleteUser(id) {
        w.WriteHeader(http.StatusNoContent)
        return
    }
    http.Error(w, "用户不存在", http.StatusNotFound)
}

// demonstrateRESTAPI seeds two users and serves these routes on :8080:
//   - /api/users and /api/users/{id}: the UserHandler REST resource. Both
//     patterns are registered so the bare collection path is served directly
//     instead of being redirected to /api/users/ by the ServeMux.
//   - /health: JSON liveness probe.
//
// LoggingMiddleware logs every request.
func demonstrateRESTAPI() {
    userService := NewUserService()
    userHandler := NewUserHandler(userService)

    // Seed data. The e-mail addresses were garbled to "[email protected]"
    // during extraction and are restored here.
    userService.CreateUser(&User{Name: "张三", Email: "zhangsan@example.com", Age: 25})
    userService.CreateUser(&User{Name: "李四", Email: "lisi@example.com", Age: 30})

    mux := http.NewServeMux()
    mux.Handle("/api/users", userHandler)
    mux.Handle("/api/users/", userHandler)

    // Health-check endpoint.
    mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
        w.Header().Set("Content-Type", "application/json")
        json.NewEncoder(w).Encode(map[string]string{
            "status": "healthy",
            "time":   time.Now().Format(time.RFC3339),
        })
    })

    server := &http.Server{
        Addr:    ":8080",
        Handler: LoggingMiddleware(mux),
    }

    fmt.Println("RESTful API 服务器启动在 :8080")
    if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
        fmt.Printf("服务器启动失败: %v\n", err)
    }
}

文件上传和下载 #

文件上传处理 #

import (
    "crypto/md5"
    "fmt"
    "io"
    "mime/multipart"
    "net/http"
    "os"
    "path/filepath"
    "time"
)

// handleFileUpload accepts a multipart/form-data POST carrying the file in the
// "file" field. It stores the file under ./uploads with a unique name and
// replies with JSON {message, filename, size, hash}. hash is the MD5 of the
// stored bytes: an integrity checksum, not a security measure.
//
// Limits and checks:
//   - The whole request body is capped at 10 MiB with http.MaxBytesReader.
//     The argument to ParseMultipartForm only bounds how much is kept in
//     memory; larger parts would spill to temporary files, not be rejected.
//   - The file type comes from the part's Content-Type header, which the
//     client controls. Treat it as a hint, not a guarantee.
//   - The destination is opened with O_EXCL, so a name collision fails
//     instead of silently overwriting an earlier upload.
func handleFileUpload(w http.ResponseWriter, r *http.Request) {
    if r.Method != http.MethodPost {
        w.Header().Set("Allow", http.MethodPost)
        http.Error(w, "只支持 POST 方法", http.StatusMethodNotAllowed)
        return
    }

    const maxUploadBytes = 10 << 20
    r.Body = http.MaxBytesReader(w, r.Body, maxUploadBytes)
    if err := r.ParseMultipartForm(maxUploadBytes); err != nil {
        http.Error(w, "解析上传表单失败", http.StatusBadRequest)
        return
    }

    file, header, err := r.FormFile("file")
    if err != nil {
        http.Error(w, "获取上传文件失败", http.StatusBadRequest)
        return
    }
    defer file.Close()

    fmt.Printf("上传文件: %s, 大小: %d 字节\n", header.Filename, header.Size)

    if !isValidFileType(header.Header.Get("Content-Type")) {
        http.Error(w, "不支持的文件类型", http.StatusBadRequest)
        return
    }

    if err := os.MkdirAll("uploads", 0755); err != nil {
        http.Error(w, "创建上传目录失败", http.StatusInternalServerError)
        return
    }

    filename := generateUniqueFilename(header.Filename)
    dstPath := filepath.Join("uploads", filename)

    dst, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0666)
    if err != nil {
        http.Error(w, "创建文件失败", http.StatusInternalServerError)
        return
    }

    // Hash while copying so the upload is read only once. Close explicitly
    // and check the error: a failed Close can mean the data never hit disk.
    hasher := md5.New()
    _, copyErr := io.Copy(io.MultiWriter(dst, hasher), file)
    closeErr := dst.Close()
    if copyErr != nil || closeErr != nil {
        os.Remove(dstPath) // best effort: don't leave a truncated file behind
        http.Error(w, "保存文件失败", http.StatusInternalServerError)
        return
    }

    w.Header().Set("Content-Type", "application/json")
    json.NewEncoder(w).Encode(map[string]interface{}{
        "message":  "文件上传成功",
        "filename": filename,
        "size":     header.Size,
        "hash":     fmt.Sprintf("%x", hasher.Sum(nil)),
    })
}

// isValidFileType reports whether contentType is exactly one of the allowed
// upload types: JPEG, PNG, GIF, plain text or PDF. The match is exact and
// case-sensitive, so a value with parameters such as "; charset=utf-8" is
// rejected.
func isValidFileType(contentType string) bool {
    switch contentType {
    case "image/jpeg", "image/png", "image/gif", "text/plain", "application/pdf":
        return true
    default:
        return false
    }
}

// generateUniqueFilename turns "name.ext" into "name_<timestamp>.ext", keeping
// the original extension. The timestamp is in nanoseconds. With second
// resolution, two uploads of the same name within one second would get the
// same name, and the later one would overwrite the earlier.
func generateUniqueFilename(originalName string) string {
    ext := filepath.Ext(originalName)
    base := strings.TrimSuffix(originalName, ext)
    return fmt.Sprintf("%s_%d%s", base, time.Now().UnixNano(), ext)
}

// calculateFileHash returns the lower-case hex MD5 digest of the file at
// path. The contents are streamed through the hash, so the file is not loaded
// into memory.
func calculateFileHash(path string) (string, error) {
    f, err := os.Open(path)
    if err != nil {
        return "", err
    }
    defer f.Close()

    digest := md5.New()
    if _, err := io.Copy(digest, f); err != nil {
        return "", err
    }
    return fmt.Sprintf("%x", digest.Sum(nil)), nil
}

文件下载处理 #

// handleFileDownload sends ./uploads/<file> as an attachment, where <file>
// comes from the "file" query parameter. Only GET is allowed; other methods
// get 405 with an Allow header.
//
// The name must be a single plain file name. Separators of either OS flavour
// and any ".." are rejected, so the path cannot escape ./uploads. The file is
// opened first and then stat'ed through the open handle. This avoids a
// Stat/Open race, guarantees a non-nil FileInfo (a failed os.Stat used to
// leave it nil and panic on Size), and lets directories and other non-regular
// files be refused with 404.
func handleFileDownload(w http.ResponseWriter, r *http.Request) {
    if r.Method != http.MethodGet {
        w.Header().Set("Allow", http.MethodGet)
        http.Error(w, "只支持 GET 方法", http.StatusMethodNotAllowed)
        return
    }

    filename := r.URL.Query().Get("file")
    if filename == "" {
        http.Error(w, "缺少文件名参数", http.StatusBadRequest)
        return
    }

    // Security check: prevent path traversal.
    if filename == "." || strings.Contains(filename, "..") || strings.ContainsAny(filename, `/\`) {
        http.Error(w, "无效的文件名", http.StatusBadRequest)
        return
    }

    file, err := os.Open(filepath.Join("uploads", filename))
    if err != nil {
        if os.IsNotExist(err) {
            http.Error(w, "文件不存在", http.StatusNotFound)
            return
        }
        http.Error(w, "打开文件失败", http.StatusInternalServerError)
        return
    }
    defer file.Close()

    info, err := file.Stat()
    if err != nil {
        http.Error(w, "获取文件信息失败", http.StatusInternalServerError)
        return
    }
    if !info.Mode().IsRegular() {
        http.Error(w, "文件不存在", http.StatusNotFound)
        return
    }

    // %q escapes embedded quotes and backslashes, which would otherwise break
    // the quoted filename parameter.
    w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))
    w.Header().Set("Content-Type", "application/octet-stream")
    w.Header().Set("Content-Length", fmt.Sprintf("%d", info.Size()))

    if _, err := io.Copy(w, file); err != nil {
        // Headers are already sent, so logging is all that is left.
        fmt.Printf("发送文件失败: %v\n", err)
    }
}

// handleRangeDownload serves ./uploads/<file> (the "file" query parameter)
// with single-range support, so interrupted downloads can resume.
//   - No Range header: 200 with the whole file.
//   - One satisfiable range: 206 with Content-Range and just those bytes.
//   - Several ranges: the whole file with 200. That is allowed, since a
//     server may ignore Range, and avoids multipart/byteranges.
//   - Bad unit or nothing satisfiable: 416 with "Content-Range: bytes */size",
//     as RFC 9110 requires for that status.
//
// The file name gets the same traversal guard as handleFileDownload. Without
// it, ?file=../../etc/passwd would be served.
func handleRangeDownload(w http.ResponseWriter, r *http.Request) {
    filename := r.URL.Query().Get("file")
    if filename == "" {
        http.Error(w, "缺少文件名参数", http.StatusBadRequest)
        return
    }
    if filename == "." || strings.Contains(filename, "..") || strings.ContainsAny(filename, `/\`) {
        http.Error(w, "无效的文件名", http.StatusBadRequest)
        return
    }

    file, err := os.Open(filepath.Join("uploads", filename))
    if err != nil {
        http.Error(w, "文件不存在", http.StatusNotFound)
        return
    }
    defer file.Close()

    info, err := file.Stat()
    if err != nil {
        http.Error(w, "获取文件信息失败", http.StatusInternalServerError)
        return
    }
    if !info.Mode().IsRegular() {
        http.Error(w, "文件不存在", http.StatusNotFound)
        return
    }

    fileSize := info.Size()
    w.Header().Set("Accept-Ranges", "bytes")

    // Handle a Range request.
    if rangeHeader := r.Header.Get("Range"); rangeHeader != "" {
        ranges, err := parseRange(rangeHeader, fileSize)
        if err != nil || len(ranges) == 0 {
            w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", fileSize))
            http.Error(w, "无效的 Range 请求", http.StatusRequestedRangeNotSatisfiable)
            return
        }

        if len(ranges) == 1 {
            start, end := ranges[0].start, ranges[0].end
            if _, err := file.Seek(start, io.SeekStart); err != nil {
                http.Error(w, "读取文件失败", http.StatusInternalServerError)
                return
            }

            length := end - start + 1 // both ends are inclusive
            w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, fileSize))
            w.Header().Set("Content-Length", fmt.Sprintf("%d", length))
            w.WriteHeader(http.StatusPartialContent)

            if _, err := io.CopyN(w, file, length); err != nil {
                fmt.Printf("发送文件失败: %v\n", err)
            }
            return
        }
    }

    // Plain (or multi-range fallback) download.
    w.Header().Set("Content-Length", fmt.Sprintf("%d", fileSize))
    if _, err := io.Copy(w, file); err != nil {
        fmt.Printf("发送文件失败: %v\n", err)
    }
}

// httpRange is one byte range parsed from a Range header. start and end are
// both inclusive offsets; handleRangeDownload sends end-start+1 bytes.
type httpRange struct {
    start, end int64
}

// parseRange parses a "bytes=" Range header against a representation of size
// bytes. It returns the satisfiable ranges, clamped to the file, with
// inclusive end offsets. Supported range specs (RFC 9110 §14.1.2):
//
//   "a-b"  bytes a..b (b clamped to size-1)
//   "a-"   bytes a..size-1
//   "-n"   the final n bytes (the whole file if n >= size)
//
// Malformed or unsatisfiable specs are skipped. An error is returned only when
// the unit is not "bytes", so an empty, error-free result means nothing was
// satisfiable. A suffix spec "-n" previously yielded bytes 0..n, and a lone
// "-" yielded the whole file.
func parseRange(rangeHeader string, size int64) ([]httpRange, error) {
    if !strings.HasPrefix(rangeHeader, "bytes=") {
        return nil, fmt.Errorf("无效的 Range 头")
    }

    var result []httpRange
    for _, spec := range strings.Split(strings.TrimPrefix(rangeHeader, "bytes="), ",") {
        spec = strings.TrimSpace(spec)
        dash := strings.IndexByte(spec, '-')
        if dash < 0 {
            continue
        }
        first, last := spec[:dash], spec[dash+1:]

        var start, end int64
        if first == "" {
            // Suffix range "-n": the last n bytes.
            n, err := strconv.ParseInt(last, 10, 64)
            if err != nil || n <= 0 || size == 0 {
                continue
            }
            if n > size {
                n = size
            }
            start, end = size-n, size-1
        } else {
            var err error
            start, err = strconv.ParseInt(first, 10, 64)
            if err != nil || start < 0 || start >= size {
                continue
            }
            end = size - 1
            if last != "" {
                e, err := strconv.ParseInt(last, 10, 64)
                if err != nil || e < start {
                    continue
                }
                if e < end {
                    end = e
                }
            }
        }

        result = append(result, httpRange{start: start, end: end})
    }

    return result, nil
}

性能优化 #

超时、Keep-Alive 与客户端连接池 #

// optimizedHTTPServer shows the server-side settings that govern connection
// reuse and slow clients. It also shows a client configured for connection
// pooling.
//
// HTTP/2: net/http negotiates h2 transparently when the server is started
// with ListenAndServeTLS (or ServeTLS). Over plain ListenAndServe, as used
// here, only HTTP/1.1 is spoken. The previous TLSConfig{NextProtos: ...} had
// no effect on this server, and it needed crypto/tls, which this file never
// imports, so it has been removed.
func optimizedHTTPServer() {
    server := &http.Server{
        Addr:              ":8080",
        ReadHeaderTimeout: 5 * time.Second, // bound clients that send headers slowly (Slowloris)
        ReadTimeout:       15 * time.Second,
        WriteTimeout:      15 * time.Second,
        IdleTimeout:       60 * time.Second, // keep-alive: how long an idle connection stays open
    }

    // Connection pooling in net/http lives on the client side, in
    // http.Transport. This client is unrelated to the server above and only
    // shows the settings. A real program creates one client and reuses it.
    transport := &http.Transport{
        MaxIdleConns:        100,
        MaxIdleConnsPerHost: 10,
        IdleConnTimeout:     90 * time.Second,
        DisableCompression:  false,
    }

    client := &http.Client{
        Transport: transport,
        Timeout:   30 * time.Second,
    }

    _ = client // the configured client would be used here

    if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
        fmt.Printf("服务器启动失败: %v\n", err)
    }
}

响应压缩 #

import (
    "compress/gzip"
    "strings"
)

// GzipMiddleware gzip-compresses response bodies for clients whose
// Accept-Encoding mentions gzip. This is a plain substring check, so an
// explicit "gzip;q=0" is not honoured. Every response carries
// "Vary: Accept-Encoding" so shared caches keep compressed and uncompressed
// variants apart.
func GzipMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        w.Header().Add("Vary", "Accept-Encoding")

        // Check whether the client accepts gzip.
        if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
            next.ServeHTTP(w, r)
            return
        }

        w.Header().Set("Content-Encoding", "gzip")
        gz := gzip.NewWriter(w)
        defer gz.Close() // flushes buffered data and writes the gzip trailer

        next.ServeHTTP(&gzipResponseWriter{Writer: gz, ResponseWriter: w}, r)
    })
}

// gzipResponseWriter sends body bytes through a gzip.Writer. Headers and
// status go straight to the underlying ResponseWriter.
type gzipResponseWriter struct {
    io.Writer
    http.ResponseWriter
}

// WriteHeader removes any Content-Length set by the handler before the header
// goes out. That length describes the uncompressed body and would be wrong.
func (gw *gzipResponseWriter) WriteHeader(code int) {
    gw.Header().Del("Content-Length")
    gw.ResponseWriter.WriteHeader(code)
}

// Write compresses b. If the handler has not set a Content-Type, it is
// sniffed here from the uncompressed bytes. Otherwise net/http would sniff
// the compressed stream and label the response application/x-gzip. Sniffing
// only helps while the header has not yet been sent, i.e. when the handler
// did not call WriteHeader first.
func (gw *gzipResponseWriter) Write(b []byte) (int, error) {
    h := gw.Header()
    if h.Get("Content-Type") == "" {
        h.Set("Content-Type", http.DetectContentType(b))
    }
    h.Del("Content-Length")
    return gw.Writer.Write(b)
}

缓存控制 #

// CacheMiddleware returns a Middleware that marks responses as publicly
// cacheable for maxAge seconds, via Cache-Control and Expires. Conditional
// GET/HEAD requests are answered with 304 Not Modified.
//
// This is a demo: the resource is assumed to have been modified one hour
// before the request. That same instant is sent as Last-Modified, so the
// header and the If-Modified-Since check agree. A real implementation must use
// the resource's actual modification time.
//
// HTTP dates are always GMT (RFC 9110 §5.6.7). Times are therefore converted
// to UTC before formatting with http.TimeFormat; formatting local time would
// label a wrong time as GMT. They are also truncated to whole seconds, the
// precision of an HTTP date. http.ParseTime accepts all three HTTP date
// formats.
func CacheMiddleware(maxAge int) Middleware {
    return func(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            now := time.Now().UTC()
            lastModified := now.Add(-1 * time.Hour).Truncate(time.Second)

            h := w.Header()
            h.Set("Cache-Control", fmt.Sprintf("public, max-age=%d", maxAge))
            h.Set("Expires", now.Add(time.Duration(maxAge)*time.Second).Format(http.TimeFormat))
            h.Set("Last-Modified", lastModified.Format(http.TimeFormat))

            // If-Modified-Since only applies to GET and HEAD.
            if r.Method == http.MethodGet || r.Method == http.MethodHead {
                if ims := r.Header.Get("If-Modified-Since"); ims != "" {
                    if t, err := http.ParseTime(ims); err == nil && !lastModified.After(t) {
                        w.WriteHeader(http.StatusNotModified)
                        return
                    }
                }
            }

            next.ServeHTTP(w, r)
        })
    }
}

优雅关闭 #

服务器优雅关闭 #

import (
    "context"
    "os"
    "os/signal"
    "syscall"
)

// gracefulShutdown serves slow (5s) requests on :8080 until SIGINT or SIGTERM
// arrives. It then calls Shutdown, which stops accepting connections and
// waits up to 30s for in-flight requests to finish.
//
// If ListenAndServe fails (e.g. the port is in use), the error is reported and
// the function returns. Previously it kept waiting for a signal that might
// never come. signal.Stop is deferred so the channel stops receiving signals
// once the function exits.
func gracefulShutdown() {
    server := &http.Server{
        Addr: ":8080",
        Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            // Simulate a long-running request.
            time.Sleep(5 * time.Second)
            fmt.Fprintf(w, "请求处理完成")
        }),
    }

    // Buffered so the serving goroutine can always deliver its result and
    // exit, even once nobody is receiving.
    serveErr := make(chan error, 1)
    go func() {
        fmt.Println("服务器启动在 :8080")
        serveErr <- server.ListenAndServe()
    }()

    quit := make(chan os.Signal, 1)
    signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
    defer signal.Stop(quit)

    select {
    case err := <-serveErr:
        // The server stopped on its own; there is nothing to shut down.
        if err != nil && err != http.ErrServerClosed {
            fmt.Printf("服务器启动失败: %v\n", err)
        }
        return
    case <-quit:
    }

    fmt.Println("正在关闭服务器...")

    ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
    defer cancel()

    if err := server.Shutdown(ctx); err != nil {
        fmt.Printf("服务器关闭失败: %v\n", err)
        return
    }
    fmt.Println("服务器已优雅关闭")
}

小结 #

本节详细介绍了 Go 语言中的 HTTP 服务器开发,包括:

  1. HTTP 基础 - HTTP 协议和请求响应模型
  2. 基础服务器 - 简单的 HTTP 服务器实现
  3. 高级服务器 - 自定义服务器配置和中间件系统
  4. RESTful API - 完整的 REST API 实现
  5. 文件处理 - 文件上传下载和断点续传
  6. 性能优化 - 超时与 Keep-Alive 配置、客户端连接池、响应压缩和缓存控制
  7. 优雅关闭 - 服务器的优雅关闭机制

掌握这些 HTTP 服务器开发技术后,你就能够构建高性能、可扩展的 Web 服务。在下一节中,我们将学习 WebSocket 服务器的开发。