1.6.2 自定义错误类型

1.6.2 自定义错误类型 #

虽然 Go 语言的内置错误类型已经能够满足大部分需求,但在复杂的应用程序中,我们往往需要创建自定义错误类型来提供更丰富的错误信息、更好的错误分类和更精确的错误处理。本节将深入探讨如何设计和实现自定义错误类型。

自定义错误类型的基础 #

实现 error 接口 #

任何实现了 Error() string 方法的类型都可以作为错误使用:

package main

import "fmt"

// 简单的自定义错误类型
type ValidationError struct {
    Field   string
    Value   interface{}
    Message string
}

func (e ValidationError) Error() string {
    return fmt.Sprintf("字段 '%s' 的值 '%v' 验证失败: %s", e.Field, e.Value, e.Message)
}

func validateAge(age int) error {
    if age < 0 {
        return ValidationError{
            Field:   "age",
            Value:   age,
            Message: "年龄不能为负数",
        }
    }
    if age > 150 {
        return ValidationError{
            Field:   "age",
            Value:   age,
            Message: "年龄不能超过150岁",
        }
    }
    return nil
}

func main() {
    ages := []int{25, -5, 200}

    for _, age := range ages {
        err := validateAge(age)
        if err != nil {
            fmt.Printf("验证失败: %v\n", err)

            // 类型断言获取详细信息
            if validationErr, ok := err.(ValidationError); ok {
                fmt.Printf("  字段: %s\n", validationErr.Field)
                fmt.Printf("  值: %v\n", validationErr.Value)
                fmt.Printf("  消息: %s\n", validationErr.Message)
            }
        } else {
            fmt.Printf("年龄 %d 验证通过\n", age)
        }
        fmt.Println()
    }
}

错误类型的分类 #

根据不同的使用场景,我们可以设计不同类型的自定义错误:

package main

import (
    "fmt"
    "time"
)

// 1. 业务逻辑错误
type BusinessError struct {
    Code    string
    Message string
    Details map[string]interface{}
}

func (e BusinessError) Error() string {
    return fmt.Sprintf("[%s] %s", e.Code, e.Message)
}

// 2. 系统错误
type SystemError struct {
    Component string
    Operation string
    Cause     error
    Timestamp time.Time
}

func (e SystemError) Error() string {
    return fmt.Sprintf("系统错误 [%s:%s] %v (发生时间: %s)",
        e.Component, e.Operation, e.Cause, e.Timestamp.Format("2006-01-02 15:04:05"))
}

func (e SystemError) Unwrap() error {
    return e.Cause
}

// 3. 网络错误
type NetworkError struct {
    Host      string
    Port      int
    Operation string
    Timeout   time.Duration
    Retries   int
}

func (e NetworkError) Error() string {
    return fmt.Sprintf("网络错误: %s操作失败 %s:%d (超时: %v, 重试: %d次)",
        e.Operation, e.Host, e.Port, e.Timeout, e.Retries)
}

func (e NetworkError) IsTimeout() bool {
    return e.Timeout > 0
}

func (e NetworkError) IsRetryable() bool {
    return e.Retries < 3
}

// 模拟不同类型的错误
func processOrder(orderID string) error {
    // 模拟业务逻辑错误
    if orderID == "INVALID" {
        return BusinessError{
            Code:    "ORDER_NOT_FOUND",
            Message: "订单不存在",
            Details: map[string]interface{}{
                "orderID": orderID,
                "action":  "process",
            },
        }
    }

    // 模拟系统错误
    if orderID == "SYSTEM_ERROR" {
        return SystemError{
            Component: "database",
            Operation: "query",
            Cause:     fmt.Errorf("连接池耗尽"),
            Timestamp: time.Now(),
        }
    }

    // 模拟网络错误
    if orderID == "NETWORK_ERROR" {
        return NetworkError{
            Host:      "api.payment.com",
            Port:      443,
            Operation: "payment_process",
            Timeout:   time.Second * 30,
            Retries:   2,
        }
    }

    return nil
}

func main() {
    orders := []string{"ORDER123", "INVALID", "SYSTEM_ERROR", "NETWORK_ERROR"}

    for _, orderID := range orders {
        fmt.Printf("处理订单: %s\n", orderID)
        err := processOrder(orderID)

        if err != nil {
            fmt.Printf("错误: %v\n", err)

            // 根据错误类型进行不同的处理
            switch e := err.(type) {
            case BusinessError:
                fmt.Printf("  业务错误代码: %s\n", e.Code)
                fmt.Printf("  详细信息: %v\n", e.Details)

            case SystemError:
                fmt.Printf("  系统组件: %s\n", e.Component)
                fmt.Printf("  操作: %s\n", e.Operation)
                fmt.Printf("  发生时间: %s\n", e.Timestamp.Format("15:04:05"))

            case NetworkError:
                fmt.Printf("  目标主机: %s:%d\n", e.Host, e.Port)
                fmt.Printf("  是否可重试: %t\n", e.IsRetryable())
                fmt.Printf("  是否超时: %t\n", e.IsTimeout())
            }
        } else {
            fmt.Println("  处理成功")
        }
        fmt.Println()
    }
}

高级自定义错误模式 #

错误链和错误包装 #

package main

import (
    "errors"
    "fmt"
)

// 可包装的错误类型
type DatabaseError struct {
    Query     string
    Table     string
    Operation string
    Cause     error
}

func (e DatabaseError) Error() string {
    return fmt.Sprintf("数据库错误 [%s:%s] 查询: %s", e.Table, e.Operation, e.Query)
}

func (e DatabaseError) Unwrap() error {
    return e.Cause
}

// 服务层错误
type ServiceError struct {
    Service string
    Method  string
    Cause   error
}

func (e ServiceError) Error() string {
    return fmt.Sprintf("服务错误 [%s.%s]", e.Service, e.Method)
}

func (e ServiceError) Unwrap() error {
    return e.Cause
}

// 模拟数据库操作
func queryUser(userID int) error {
    // 模拟底层数据库错误
    baseErr := errors.New("连接超时")

    return DatabaseError{
        Query:     fmt.Sprintf("SELECT * FROM users WHERE id = %d", userID),
        Table:     "users",
        Operation: "SELECT",
        Cause:     baseErr,
    }
}

// 模拟服务层操作
func getUserProfile(userID int) error {
    err := queryUser(userID)
    if err != nil {
        return ServiceError{
            Service: "UserService",
            Method:  "GetProfile",
            Cause:   err,
        }
    }
    return nil
}

func main() {
    err := getUserProfile(123)
    if err != nil {
        fmt.Printf("顶层错误: %v\n", err)

        // 检查是否包含特定类型的错误
        var dbErr DatabaseError
        if errors.As(err, &dbErr) {
            fmt.Printf("数据库错误详情:\n")
            fmt.Printf("  表: %s\n", dbErr.Table)
            fmt.Printf("  操作: %s\n", dbErr.Operation)
            fmt.Printf("  查询: %s\n", dbErr.Query)
        }

        var serviceErr ServiceError
        if errors.As(err, &serviceErr) {
            fmt.Printf("服务错误详情:\n")
            fmt.Printf("  服务: %s\n", serviceErr.Service)
            fmt.Printf("  方法: %s\n", serviceErr.Method)
        }

        // 检查根本原因
        if errors.Is(err, errors.New("连接超时")) {
            fmt.Println("根本原因是连接超时")
        }
    }
}

错误状态和恢复 #

package main

import (
    "fmt"
    "time"
)

// 可恢复的错误接口
type RecoverableError interface {
    error
    IsRecoverable() bool
    GetRetryDelay() time.Duration
}

// 临时错误类型
type TemporaryError struct {
    Message    string
    RetryDelay time.Duration
    Attempts   int
    MaxRetries int
}

func (e TemporaryError) Error() string {
    return fmt.Sprintf("临时错误: %s (尝试 %d/%d)", e.Message, e.Attempts, e.MaxRetries)
}

func (e TemporaryError) IsRecoverable() bool {
    return e.Attempts < e.MaxRetries
}

func (e TemporaryError) GetRetryDelay() time.Duration {
    return e.RetryDelay
}

// 永久错误类型
type PermanentError struct {
    Message string
    Code    int
}

func (e PermanentError) Error() string {
    return fmt.Sprintf("永久错误 [%d]: %s", e.Code, e.Message)
}

func (e PermanentError) IsRecoverable() bool {
    return false
}

func (e PermanentError) GetRetryDelay() time.Duration {
    return 0
}

// 模拟可能失败的操作
func unreliableOperation(attempt int) error {
    switch attempt {
    case 1, 2:
        return TemporaryError{
            Message:    "服务暂时不可用",
            RetryDelay: time.Second * 2,
            Attempts:   attempt,
            MaxRetries: 3,
        }
    case 3:
        return nil // 成功
    default:
        return PermanentError{
            Message: "服务已停用",
            Code:    503,
        }
    }
}

// 带重试机制的操作执行器
func executeWithRetry(operation func(int) error) error {
    attempt := 1

    for {
        err := operation(attempt)
        if err == nil {
            fmt.Printf("操作在第 %d 次尝试后成功\n", attempt)
            return nil
        }

        fmt.Printf("第 %d 次尝试失败: %v\n", attempt, err)

        // 检查是否为可恢复错误
        if recoverableErr, ok := err.(RecoverableError); ok {
            if recoverableErr.IsRecoverable() {
                delay := recoverableErr.GetRetryDelay()
                fmt.Printf("等待 %v 后重试...\n", delay)
                time.Sleep(delay)
                attempt++
                continue
            }
        }

        // 不可恢复的错误
        return fmt.Errorf("操作最终失败: %w", err)
    }
}

func main() {
    fmt.Println("测试可恢复错误:")
    err := executeWithRetry(unreliableOperation)
    if err != nil {
        fmt.Printf("最终错误: %v\n", err)
    }

    fmt.Println("\n测试永久错误:")
    err = executeWithRetry(func(attempt int) error {
        return PermanentError{
            Message: "权限不足",
            Code:    403,
        }
    })
    if err != nil {
        fmt.Printf("最终错误: %v\n", err)
    }
}

实际应用:HTTP API 错误处理 #

让我们通过一个 HTTP API 的例子来展示自定义错误类型的实际应用:

package main

import (
    "encoding/json"
    "fmt"
    "net/http"
    "strconv"
    "time"
)

// API 错误响应结构
type APIError struct {
    Code      int                    `json:"code"`
    Message   string                 `json:"message"`
    Details   map[string]interface{} `json:"details,omitempty"`
    Timestamp time.Time              `json:"timestamp"`
    RequestID string                 `json:"request_id,omitempty"`
}

func (e APIError) Error() string {
    return fmt.Sprintf("API错误 [%d]: %s", e.Code, e.Message)
}

// HTTP 状态码映射
func (e APIError) HTTPStatus() int {
    switch e.Code {
    case 1001, 1002, 1003: // 验证错误
        return http.StatusBadRequest
    case 2001: // 未找到
        return http.StatusNotFound
    case 2002: // 未授权
        return http.StatusUnauthorized
    case 3001, 3002: // 服务器错误
        return http.StatusInternalServerError
    default:
        return http.StatusInternalServerError
    }
}

// 预定义的错误类型
var (
    ErrInvalidInput = APIError{
        Code:    1001,
        Message: "输入参数无效",
    }

    ErrMissingField = APIError{
        Code:    1002,
        Message: "缺少必需字段",
    }

    ErrValidationFailed = APIError{
        Code:    1003,
        Message: "数据验证失败",
    }

    ErrUserNotFound = APIError{
        Code:    2001,
        Message: "用户不存在",
    }

    ErrUnauthorized = APIError{
        Code:    2002,
        Message: "未授权访问",
    }

    ErrDatabaseConnection = APIError{
        Code:    3001,
        Message: "数据库连接失败",
    }

    ErrInternalServer = APIError{
        Code:    3002,
        Message: "内部服务器错误",
    }
)

// 用户数据结构
type User struct {
    ID    int    `json:"id"`
    Name  string `json:"name"`
    Email string `json:"email"`
    Age   int    `json:"age"`
}

// 模拟用户数据库
var users = map[int]User{
    1: {ID: 1, Name: "张三", Email: "[email protected]", Age: 25},
    2: {ID: 2, Name: "李四", Email: "[email protected]", Age: 30},
}

// 用户服务
type UserService struct{}

func (s *UserService) GetUser(id int) (*User, error) {
    if id <= 0 {
        err := ErrInvalidInput
        err.Details = map[string]interface{}{
            "field": "id",
            "value": id,
            "rule":  "must be positive",
        }
        err.Timestamp = time.Now()
        return nil, err
    }

    user, exists := users[id]
    if !exists {
        err := ErrUserNotFound
        err.Details = map[string]interface{}{
            "user_id": id,
        }
        err.Timestamp = time.Now()
        return nil, err
    }

    return &user, nil
}

func (s *UserService) CreateUser(user User) (*User, error) {
    // 验证用户数据
    if user.Name == "" {
        err := ErrMissingField
        err.Details = map[string]interface{}{
            "field": "name",
        }
        err.Timestamp = time.Now()
        return nil, err
    }

    if user.Age < 0 || user.Age > 150 {
        err := ErrValidationFailed
        err.Details = map[string]interface{}{
            "field": "age",
            "value": user.Age,
            "rule":  "must be between 0 and 150",
        }
        err.Timestamp = time.Now()
        return nil, err
    }

    // 生成新ID
    newID := len(users) + 1
    user.ID = newID
    users[newID] = user

    return &user, nil
}

// HTTP 处理器
type UserHandler struct {
    service *UserService
}

func NewUserHandler() *UserHandler {
    return &UserHandler{
        service: &UserService{},
    }
}

func (h *UserHandler) GetUser(w http.ResponseWriter, r *http.Request) {
    // 解析用户ID
    idStr := r.URL.Query().Get("id")
    if idStr == "" {
        h.writeError(w, r, ErrMissingField, "id")
        return
    }

    id, err := strconv.Atoi(idStr)
    if err != nil {
        apiErr := ErrInvalidInput
        apiErr.Details = map[string]interface{}{
            "field": "id",
            "value": idStr,
            "error": err.Error(),
        }
        apiErr.Timestamp = time.Now()
        h.writeError(w, r, apiErr, "")
        return
    }

    // 获取用户
    user, err := h.service.GetUser(id)
    if err != nil {
        if apiErr, ok := err.(APIError); ok {
            h.writeError(w, r, apiErr, "")
        } else {
            h.writeError(w, r, ErrInternalServer, "")
        }
        return
    }

    // 返回成功响应
    w.Header().Set("Content-Type", "application/json")
    json.NewEncoder(w).Encode(map[string]interface{}{
        "success": true,
        "data":    user,
    })
}

func (h *UserHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
    var user User
    err := json.NewDecoder(r.Body).Decode(&user)
    if err != nil {
        apiErr := ErrInvalidInput
        apiErr.Details = map[string]interface{}{
            "error": "invalid JSON format",
        }
        apiErr.Timestamp = time.Now()
        h.writeError(w, r, apiErr, "")
        return
    }

    // 创建用户
    createdUser, err := h.service.CreateUser(user)
    if err != nil {
        if apiErr, ok := err.(APIError); ok {
            h.writeError(w, r, apiErr, "")
        } else {
            h.writeError(w, r, ErrInternalServer, "")
        }
        return
    }

    // 返回成功响应
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(http.StatusCreated)
    json.NewEncoder(w).Encode(map[string]interface{}{
        "success": true,
        "data":    createdUser,
    })
}

func (h *UserHandler) writeError(w http.ResponseWriter, r *http.Request, apiErr APIError, field string) {
    if field != "" {
        if apiErr.Details == nil {
            apiErr.Details = make(map[string]interface{})
        }
        apiErr.Details["field"] = field
    }

    if apiErr.Timestamp.IsZero() {
        apiErr.Timestamp = time.Now()
    }

    // 添加请求ID(实际应用中可能从中间件获取)
    apiErr.RequestID = fmt.Sprintf("req_%d", time.Now().UnixNano())

    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(apiErr.HTTPStatus())

    response := map[string]interface{}{
        "success": false,
        "error":   apiErr,
    }

    json.NewEncoder(w).Encode(response)
}

// 演示函数
func demonstrateAPI() {
    handler := NewUserHandler()

    // 模拟不同的请求场景
    scenarios := []struct {
        name        string
        method      string
        path        string
        body        string
        description string
    }{
        {
            name:        "获取存在的用户",
            method:      "GET",
            path:        "/user?id=1",
            description: "正常获取用户信息",
        },
        {
            name:        "获取不存在的用户",
            method:      "GET",
            path:        "/user?id=999",
            description: "用户不存在错误",
        },
        {
            name:        "无效的用户ID",
            method:      "GET",
            path:        "/user?id=abc",
            description: "参数验证错误",
        },
        {
            name:        "缺少用户ID",
            method:      "GET",
            path:        "/user",
            description: "缺少必需参数",
        },
        {
            name:        "创建有效用户",
            method:      "POST",
            path:        "/user",
            body:        `{"name":"王五","email":"[email protected]","age":28}`,
            description: "成功创建用户",
        },
        {
            name:        "创建无效用户",
            method:      "POST",
            path:        "/user",
            body:        `{"name":"","email":"invalid","age":-5}`,
            description: "数据验证失败",
        },
    }

    fmt.Println("API 错误处理演示:")
    fmt.Println(strings.Repeat("=", 50))

    for _, scenario := range scenarios {
        fmt.Printf("\n场景: %s\n", scenario.name)
        fmt.Printf("描述: %s\n", scenario.description)
        fmt.Printf("请求: %s %s\n", scenario.method, scenario.path)
        if scenario.body != "" {
            fmt.Printf("请求体: %s\n", scenario.body)
        }

        // 这里只是演示错误类型的创建和处理逻辑
        // 实际的HTTP请求处理需要启动HTTP服务器
        fmt.Printf("响应: [模拟] 根据场景返回相应的成功或错误响应\n")
    }
}

func main() {
    demonstrateAPI()
}

错误类型设计的最佳实践 #

1. 错误类型的层次结构 #

package main

import (
    "fmt"
    "time"
)

// 基础错误接口
type BaseError interface {
    error
    Code() string
    Timestamp() time.Time
}

// 基础错误实现
type baseError struct {
    code      string
    message   string
    timestamp time.Time
}

func (e baseError) Error() string {
    return e.message
}

func (e baseError) Code() string {
    return e.code
}

func (e baseError) Timestamp() time.Time {
    return e.timestamp
}

// 业务错误
type BusinessError struct {
    baseError
    Domain string
    Action string
}

func NewBusinessError(domain, action, code, message string) BusinessError {
    return BusinessError{
        baseError: baseError{
            code:      code,
            message:   message,
            timestamp: time.Now(),
        },
        Domain: domain,
        Action: action,
    }
}

// 技术错误
type TechnicalError struct {
    baseError
    Component string
    Cause     error
}

func NewTechnicalError(component, code, message string, cause error) TechnicalError {
    return TechnicalError{
        baseError: baseError{
            code:      code,
            message:   message,
            timestamp: time.Now(),
        },
        Component: component,
        Cause:     cause,
    }
}

func (e TechnicalError) Unwrap() error {
    return e.Cause
}

func main() {
    // 业务错误示例
    bizErr := NewBusinessError("user", "create", "USER_EXISTS", "用户已存在")
    fmt.Printf("业务错误: %v\n", bizErr)
    fmt.Printf("  领域: %s\n", bizErr.Domain)
    fmt.Printf("  操作: %s\n", bizErr.Action)
    fmt.Printf("  代码: %s\n", bizErr.Code())

    // 技术错误示例
    techErr := NewTechnicalError("database", "DB_CONN_FAILED", "数据库连接失败",
        fmt.Errorf("连接超时"))
    fmt.Printf("\n技术错误: %v\n", techErr)
    fmt.Printf("  组件: %s\n", techErr.Component)
    fmt.Printf("  代码: %s\n", techErr.Code())
    fmt.Printf("  原因: %v\n", techErr.Cause)
}

2. 错误上下文和追踪 #

package main

import (
    "context"
    "fmt"
    "runtime"
    "time"
)

// 带上下文的错误
type ContextualError struct {
    Message   string
    Code      string
    Timestamp time.Time
    Context   map[string]interface{}
    Stack     []StackFrame
}

type StackFrame struct {
    Function string
    File     string
    Line     int
}

func (e ContextualError) Error() string {
    return fmt.Sprintf("[%s] %s", e.Code, e.Message)
}

func (e ContextualError) WithContext(key string, value interface{}) ContextualError {
    if e.Context == nil {
        e.Context = make(map[string]interface{})
    }
    e.Context[key] = value
    return e
}

func NewContextualError(code, message string) ContextualError {
    // 获取调用栈
    var stack []StackFrame
    for i := 1; i < 10; i++ {
        pc, file, line, ok := runtime.Caller(i)
        if !ok {
            break
        }

        fn := runtime.FuncForPC(pc)
        stack = append(stack, StackFrame{
            Function: fn.Name(),
            File:     file,
            Line:     line,
        })
    }

    return ContextualError{
        Message:   message,
        Code:      code,
        Timestamp: time.Now(),
        Context:   make(map[string]interface{}),
        Stack:     stack,
    }
}

// 错误处理中间件
func withErrorContext(ctx context.Context, operation string) context.Context {
    return context.WithValue(ctx, "operation", operation)
}

func processWithContext(ctx context.Context, userID int) error {
    if userID <= 0 {
        err := NewContextualError("INVALID_USER_ID", "用户ID无效")
        err = err.WithContext("user_id", userID)
        err = err.WithContext("operation", ctx.Value("operation"))
        return err
    }

    return nil
}

func main() {
    ctx := withErrorContext(context.Background(), "user_validation")

    err := processWithContext(ctx, -1)
    if err != nil {
        if contextErr, ok := err.(ContextualError); ok {
            fmt.Printf("错误: %v\n", contextErr)
            fmt.Printf("代码: %s\n", contextErr.Code)
            fmt.Printf("时间: %s\n", contextErr.Timestamp.Format("2006-01-02 15:04:05"))
            fmt.Printf("上下文: %v\n", contextErr.Context)

            fmt.Println("调用栈:")
            for i, frame := range contextErr.Stack[:3] { // 只显示前3层
                fmt.Printf("  %d. %s (%s:%d)\n", i+1, frame.Function, frame.File, frame.Line)
            }
        }
    }
}

小结 #

本节详细介绍了 Go 语言中自定义错误类型的设计和实现,包括:

  • 自定义错误类型的基础概念和实现方法
  • 不同场景下的错误类型分类和设计
  • 错误包装和错误链的高级用法
  • 可恢复错误和错误状态管理
  • HTTP API 中的错误处理实践
  • 错误类型设计的最佳实践

通过合理设计自定义错误类型,我们可以:

  • 提供更丰富的错误信息
  • 实现更精确的错误分类和处理
  • 改善错误的可调试性和可维护性
  • 构建更健壮的应用程序

掌握这些自定义错误类型的设计技巧,将有助于编写更加专业和可维护的 Go 代码。