1.3.4 函数高级特性

1.3.4 函数高级特性 #

Go 语言的函数具有许多高级特性,包括递归、defer 语句、panic 和 recover 机制等。这些特性使得 Go 语言的函数更加强大和灵活。本节将深入探讨这些高级特性的使用方法和应用场景。

递归函数 #

1. 递归的基本概念 #

递归是函数调用自身的编程技术,适用于解决可以分解为相似子问题的问题:

package main

import "fmt"

// 计算阶乘
func factorial(n int) int {
    // 基础情况
    if n <= 1 {
        return 1
    }
    // 递归情况
    return n * factorial(n-1)
}

// 计算斐波那契数列
func fibonacci(n int) int {
    if n <= 1 {
        return n
    }
    return fibonacci(n-1) + fibonacci(n-2)
}

// 计算最大公约数(欧几里得算法)
func gcd(a, b int) int {
    if b == 0 {
        return a
    }
    return gcd(b, a%b)
}

// 数字求和
func sumDigits(n int) int {
    if n < 10 {
        return n
    }
    return n%10 + sumDigits(n/10)
}

func main() {
    // 测试阶乘
    fmt.Println("=== 阶乘计算 ===")
    for i := 0; i <= 10; i++ {
        fmt.Printf("%d! = %d\n", i, factorial(i))
    }

    // 测试斐波那契数列
    fmt.Println("\n=== 斐波那契数列 ===")
    for i := 0; i <= 10; i++ {
        fmt.Printf("fib(%d) = %d\n", i, fibonacci(i))
    }

    // 测试最大公约数
    fmt.Println("\n=== 最大公约数 ===")
    pairs := [][2]int{{48, 18}, {100, 25}, {17, 13}}
    for _, pair := range pairs {
        fmt.Printf("gcd(%d, %d) = %d\n", pair[0], pair[1], gcd(pair[0], pair[1]))
    }

    // 测试数字求和
    fmt.Println("\n=== 数字各位求和 ===")
    numbers := []int{123, 456, 789, 1234}
    for _, num := range numbers {
        fmt.Printf("sumDigits(%d) = %d\n", num, sumDigits(num))
    }
}

2. 复杂递归应用 #

package main

import "fmt"

// 二叉树节点
type TreeNode struct {
    Val   int
    Left  *TreeNode
    Right *TreeNode
}

// 树的深度优先遍历
func inorderTraversal(root *TreeNode) []int {
    if root == nil {
        return []int{}
    }

    result := []int{}
    result = append(result, inorderTraversal(root.Left)...)
    result = append(result, root.Val)
    result = append(result, inorderTraversal(root.Right)...)

    return result
}

// 计算树的最大深度
func maxDepth(root *TreeNode) int {
    if root == nil {
        return 0
    }

    leftDepth := maxDepth(root.Left)
    rightDepth := maxDepth(root.Right)

    if leftDepth > rightDepth {
        return leftDepth + 1
    }
    return rightDepth + 1
}

// 汉诺塔问题
func hanoi(n int, from, to, aux string) {
    if n == 1 {
        fmt.Printf("移动盘子从 %s 到 %s\n", from, to)
        return
    }

    hanoi(n-1, from, aux, to)
    fmt.Printf("移动盘子从 %s 到 %s\n", from, to)
    hanoi(n-1, aux, to, from)
}

// 快速排序
func quickSort(arr []int, low, high int) {
    if low < high {
        pi := partition(arr, low, high)
        quickSort(arr, low, pi-1)
        quickSort(arr, pi+1, high)
    }
}

func partition(arr []int, low, high int) int {
    pivot := arr[high]
    i := low - 1

    for j := low; j < high; j++ {
        if arr[j] < pivot {
            i++
            arr[i], arr[j] = arr[j], arr[i]
        }
    }

    arr[i+1], arr[high] = arr[high], arr[i+1]
    return i + 1
}

func main() {
    // 创建二叉树
    root := &TreeNode{
        Val: 1,
        Left: &TreeNode{
            Val:   2,
            Left:  &TreeNode{Val: 4},
            Right: &TreeNode{Val: 5},
        },
        Right: &TreeNode{
            Val:   3,
            Right: &TreeNode{Val: 6},
        },
    }

    // 测试树遍历
    fmt.Println("=== 二叉树中序遍历 ===")
    result := inorderTraversal(root)
    fmt.Printf("遍历结果: %v\n", result)

    // 测试树深度
    fmt.Printf("树的最大深度: %d\n", maxDepth(root))

    // 汉诺塔问题
    fmt.Println("\n=== 汉诺塔问题 (3个盘子) ===")
    hanoi(3, "A", "C", "B")

    // 快速排序
    fmt.Println("\n=== 快速排序 ===")
    arr := []int{64, 34, 25, 12, 22, 11, 90}
    fmt.Printf("排序前: %v\n", arr)
    quickSort(arr, 0, len(arr)-1)
    fmt.Printf("排序后: %v\n", arr)
}

defer 语句 #

1. defer 的基本使用 #

defer 语句用于延迟函数调用,直到包含它的函数返回时才执行:

package main

import (
    "fmt"
    "os"
    "time"
)

// 基本 defer 使用
func basicDefer() {
    fmt.Println("函数开始")

    defer fmt.Println("defer 1")
    defer fmt.Println("defer 2")
    defer fmt.Println("defer 3")

    fmt.Println("函数中间")
    fmt.Println("函数结束")
}

// defer 与资源管理
func fileOperation() error {
    file, err := os.Create("test.txt")
    if err != nil {
        return err
    }
    defer file.Close() // 确保文件被关闭

    defer fmt.Println("文件操作完成")

    // 写入文件
    _, err = file.WriteString("Hello, World!")
    if err != nil {
        return err
    }

    fmt.Println("文件写入成功")
    return nil
}

// defer 与函数参数求值
func deferWithArgs() {
    x := 10
    defer fmt.Printf("defer 中的 x: %d\n", x) // x 的值在 defer 时确定

    x = 20
    fmt.Printf("函数中的 x: %d\n", x)
}

// defer 与返回值
func deferWithReturn() (result int) {
    defer func() {
        result *= 2 // 修改命名返回值
        fmt.Printf("defer 中修改返回值: %d\n", result)
    }()

    result = 10
    fmt.Printf("函数中的返回值: %d\n", result)
    return result
}

// 计时器示例
func timeTracker(name string) func() {
    start := time.Now()
    fmt.Printf("%s 开始执行\n", name)

    return func() {
        fmt.Printf("%s 执行完成,耗时: %v\n", name, time.Since(start))
    }
}

func expensiveOperation() {
    defer timeTracker("昂贵操作")()

    // 模拟耗时操作
    time.Sleep(100 * time.Millisecond)
    fmt.Println("执行复杂计算...")
}

func main() {
    // 基本 defer 使用
    fmt.Println("=== 基本 defer 使用 ===")
    basicDefer()

    // 文件操作
    fmt.Println("\n=== defer 与资源管理 ===")
    if err := fileOperation(); err != nil {
        fmt.Printf("文件操作错误: %v\n", err)
    }

    // defer 参数求值
    fmt.Println("\n=== defer 参数求值 ===")
    deferWithArgs()

    // defer 与返回值
    fmt.Println("\n=== defer 与返回值 ===")
    result := deferWithReturn()
    fmt.Printf("最终返回值: %d\n", result)

    // 计时器示例
    fmt.Println("\n=== 计时器示例 ===")
    expensiveOperation()
}

2. defer 的高级应用 #

package main

import (
    "fmt"
    "sync"
    "time"
)

// 互斥锁管理
type SafeCounter struct {
    mu    sync.Mutex
    count int
}

func (c *SafeCounter) Increment() {
    c.mu.Lock()
    defer c.mu.Unlock() // 确保锁被释放

    c.count++
    fmt.Printf("计数器增加到: %d\n", c.count)
}

func (c *SafeCounter) Get() int {
    c.mu.Lock()
    defer c.mu.Unlock()

    return c.count
}

// 数据库连接模拟
type Database struct {
    connected bool
}

func (db *Database) Connect() error {
    fmt.Println("连接数据库...")
    db.connected = true
    return nil
}

func (db *Database) Disconnect() {
    if db.connected {
        fmt.Println("断开数据库连接")
        db.connected = false
    }
}

func (db *Database) Query(sql string) error {
    if !db.connected {
        return fmt.Errorf("数据库未连接")
    }
    fmt.Printf("执行查询: %s\n", sql)
    return nil
}

func databaseOperation() error {
    db := &Database{}

    if err := db.Connect(); err != nil {
        return err
    }
    defer db.Disconnect() // 确保连接被关闭

    // 执行多个查询
    queries := []string{
        "SELECT * FROM users",
        "SELECT * FROM orders",
        "SELECT * FROM products",
    }

    for _, query := range queries {
        if err := db.Query(query); err != nil {
            return err
        }
    }

    return nil
}

// 栈结构实现
type Stack struct {
    items []int
}

func (s *Stack) Push(item int) {
    s.items = append(s.items, item)
}

func (s *Stack) Pop() (int, bool) {
    if len(s.items) == 0 {
        return 0, false
    }

    index := len(s.items) - 1
    item := s.items[index]
    s.items = s.items[:index]
    return item, true
}

func (s *Stack) Size() int {
    return len(s.items)
}

// 使用 defer 实现自动清理
func stackOperation() {
    stack := &Stack{}

    // 使用 defer 确保栈被清理
    defer func() {
        fmt.Printf("清理栈,最终大小: %d\n", stack.Size())
        for stack.Size() > 0 {
            if item, ok := stack.Pop(); ok {
                fmt.Printf("清理项目: %d\n", item)
            }
        }
    }()

    // 向栈中添加元素
    for i := 1; i <= 5; i++ {
        stack.Push(i)
        fmt.Printf("推入: %d, 栈大小: %d\n", i, stack.Size())
    }

    // 弹出一些元素
    for i := 0; i < 2; i++ {
        if item, ok := stack.Pop(); ok {
            fmt.Printf("弹出: %d, 栈大小: %d\n", item, stack.Size())
        }
    }
}

func main() {
    // 互斥锁管理
    fmt.Println("=== 互斥锁管理 ===")
    counter := &SafeCounter{}

    var wg sync.WaitGroup
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            counter.Increment()
        }()
    }
    wg.Wait()

    fmt.Printf("最终计数: %d\n", counter.Get())

    // 数据库操作
    fmt.Println("\n=== 数据库操作 ===")
    if err := databaseOperation(); err != nil {
        fmt.Printf("数据库操作错误: %v\n", err)
    }

    // 栈操作
    fmt.Println("\n=== 栈操作 ===")
    stackOperation()
}

panic 和 recover #

1. panic 和 recover 基础 #

panic 用于引发运行时恐慌,recover 用于捕获和处理恐慌:

package main

import "fmt"

// 基本 panic 和 recover
func basicPanicRecover() {
    defer func() {
        if r := recover(); r != nil {
            fmt.Printf("捕获到 panic: %v\n", r)
        }
    }()

    fmt.Println("正常执行")
    panic("发生了错误!")
    fmt.Println("这行不会执行")
}

// 数组越界保护
func safeArrayAccess(arr []int, index int) (value int, err error) {
    defer func() {
        if r := recover(); r != nil {
            err = fmt.Errorf("数组访问错误: %v", r)
        }
    }()

    value = arr[index] // 可能引发 panic
    return value, nil
}

// 除零保护
func safeDivide(a, b float64) (result float64, err error) {
    defer func() {
        if r := recover(); r != nil {
            err = fmt.Errorf("除法错误: %v", r)
        }
    }()

    if b == 0 {
        panic("除数不能为零")
    }

    result = a / b
    return result, nil
}

// 类型断言保护
func safeTypeAssertion(value interface{}) (str string, err error) {
    defer func() {
        if r := recover(); r != nil {
            err = fmt.Errorf("类型断言错误: %v", r)
        }
    }()

    str = value.(string) // 可能引发 panic
    return str, nil
}

func main() {
    // 基本 panic 和 recover
    fmt.Println("=== 基本 panic 和 recover ===")
    basicPanicRecover()
    fmt.Println("程序继续执行")

    // 数组越界保护
    fmt.Println("\n=== 数组越界保护 ===")
    arr := []int{1, 2, 3, 4, 5}

    if value, err := safeArrayAccess(arr, 2); err != nil {
        fmt.Printf("错误: %v\n", err)
    } else {
        fmt.Printf("arr[2] = %d\n", value)
    }

    if value, err := safeArrayAccess(arr, 10); err != nil {
        fmt.Printf("错误: %v\n", err)
    } else {
        fmt.Printf("arr[10] = %d\n", value)
    }

    // 除零保护
    fmt.Println("\n=== 除零保护 ===")
    if result, err := safeDivide(10, 2); err != nil {
        fmt.Printf("错误: %v\n", err)
    } else {
        fmt.Printf("10 / 2 = %.2f\n", result)
    }

    if result, err := safeDivide(10, 0); err != nil {
        fmt.Printf("错误: %v\n", err)
    } else {
        fmt.Printf("10 / 0 = %.2f\n", result)
    }

    // 类型断言保护
    fmt.Println("\n=== 类型断言保护 ===")
    values := []interface{}{"hello", 42, true, "world"}

    for i, value := range values {
        if str, err := safeTypeAssertion(value); err != nil {
            fmt.Printf("值 %d 错误: %v\n", i, err)
        } else {
            fmt.Printf("值 %d 是字符串: %s\n", i, str)
        }
    }
}

2. 高级错误处理模式 #

package main

import (
    "fmt"
    "log"
    "runtime"
)

// 错误类型定义
type CustomError struct {
    Code    int
    Message string
    Cause   error
}

func (e *CustomError) Error() string {
    if e.Cause != nil {
        return fmt.Sprintf("错误 %d: %s (原因: %v)", e.Code, e.Message, e.Cause)
    }
    return fmt.Sprintf("错误 %d: %s", e.Code, e.Message)
}

// 带错误恢复的函数
func robustFunction(input int) (result int, err error) {
    defer func() {
        if r := recover(); r != nil {
            // 获取调用栈信息
            buf := make([]byte, 1024)
            n := runtime.Stack(buf, false)

            err = &CustomError{
                Code:    500,
                Message: fmt.Sprintf("函数执行失败: %v", r),
                Cause:   fmt.Errorf("调用栈: %s", buf[:n]),
            }
        }
    }()

    if input < 0 {
        panic("输入不能为负数")
    }

    if input == 0 {
        panic("输入不能为零")
    }

    result = 100 / input
    return result, nil
}

// 链式错误处理
func processData(data []int) error {
    defer func() {
        if r := recover(); r != nil {
            log.Printf("数据处理发生 panic: %v", r)
        }
    }()

    for i, value := range data {
        if err := validateData(value, i); err != nil {
            return fmt.Errorf("数据验证失败在位置 %d: %w", i, err)
        }

        if err := transformData(value); err != nil {
            return fmt.Errorf("数据转换失败在位置 %d: %w", i, err)
        }
    }

    return nil
}

func validateData(value, index int) error {
    if value < 0 {
        return &CustomError{
            Code:    400,
            Message: fmt.Sprintf("负数值不允许: %d", value),
        }
    }

    if value > 1000 {
        return &CustomError{
            Code:    400,
            Message: fmt.Sprintf("值太大: %d", value),
        }
    }

    return nil
}

func transformData(value int) error {
    defer func() {
        if r := recover(); r != nil {
            panic(fmt.Sprintf("数据转换失败: %v", r))
        }
    }()

    if value == 13 { // 模拟特殊情况
        panic("不吉利的数字")
    }

    // 模拟转换过程
    _ = value * 2
    return nil
}

// 资源管理与错误处理
type Resource struct {
    name     string
    acquired bool
}

func (r *Resource) Acquire() error {
    if r.acquired {
        return fmt.Errorf("资源 %s 已被获取", r.name)
    }

    fmt.Printf("获取资源: %s\n", r.name)
    r.acquired = true
    return nil
}

func (r *Resource) Release() {
    if r.acquired {
        fmt.Printf("释放资源: %s\n", r.name)
        r.acquired = false
    }
}

func useResource(resourceName string) (err error) {
    resource := &Resource{name: resourceName}

    defer func() {
        resource.Release()
        if r := recover(); r != nil {
            err = fmt.Errorf("使用资源 %s 时发生 panic: %v", resourceName, r)
        }
    }()

    if err := resource.Acquire(); err != nil {
        return err
    }

    // 模拟使用资源
    fmt.Printf("使用资源: %s\n", resourceName)

    if resourceName == "危险资源" {
        panic("资源使用失败")
    }

    return nil
}

func main() {
    // 带错误恢复的函数测试
    fmt.Println("=== 带错误恢复的函数 ===")
    testInputs := []int{10, 5, 0, -1}

    for _, input := range testInputs {
        if result, err := robustFunction(input); err != nil {
            fmt.Printf("输入 %d 错误: %v\n", input, err)
        } else {
            fmt.Printf("输入 %d 结果: %d\n", input, result)
        }
    }

    // 链式错误处理测试
    fmt.Println("\n=== 链式错误处理 ===")
    testData := [][]int{
        {1, 2, 3, 4, 5},
        {10, 20, -5, 30},
        {100, 200, 13, 400},
        {1, 2000, 3},
    }

    for i, data := range testData {
        fmt.Printf("处理数据集 %d: %v\n", i+1, data)
        if err := processData(data); err != nil {
            fmt.Printf("  错误: %v\n", err)
        } else {
            fmt.Printf("  处理成功\n")
        }
    }

    // 资源管理测试
    fmt.Println("\n=== 资源管理与错误处理 ===")
    resources := []string{"正常资源", "危险资源", "另一个资源"}

    for _, resourceName := range resources {
        if err := useResource(resourceName); err != nil {
            fmt.Printf("错误: %v\n", err)
        } else {
            fmt.Printf("成功使用资源: %s\n", resourceName)
        }
    }
}

函数式编程特性 #

1. 高阶函数 #

package main

import (
    "fmt"
    "sort"
    "strings"
)

// 函数类型定义
type Predicate[T any] func(T) bool
type Mapper[T, R any] func(T) R
type Reducer[T, R any] func(R, T) R

// 泛型过滤函数
func Filter[T any](slice []T, predicate Predicate[T]) []T {
    result := make([]T, 0)
    for _, item := range slice {
        if predicate(item) {
            result = append(result, item)
        }
    }
    return result
}

// 泛型映射函数
func Map[T, R any](slice []T, mapper Mapper[T, R]) []R {
    result := make([]R, len(slice))
    for i, item := range slice {
        result[i] = mapper(item)
    }
    return result
}

// 泛型归约函数
func Reduce[T, R any](slice []T, initial R, reducer Reducer[T, R]) R {
    result := initial
    for _, item := range slice {
        result = reducer(result, item)
    }
    return result
}

// 函数组合
func Compose[T, U, V any](f func(U) V, g func(T) U) func(T) V {
    return func(x T) V {
        return f(g(x))
    }
}

// 柯里化示例
func Add(a int) func(int) int {
    return func(b int) int {
        return a + b
    }
}

func Multiply(a int) func(int) int {
    return func(b int) int {
        return a * b
    }
}

// 部分应用
func PartialApply[T, U, V any](f func(T, U) V, arg T) func(U) V {
    return func(u U) V {
        return f(arg, u)
    }
}

func main() {
    // 测试数据
    numbers := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
    words := []string{"hello", "world", "go", "programming", "functional"}

    // 过滤示例
    fmt.Println("=== 过滤示例 ===")
    evenNumbers := Filter(numbers, func(n int) bool { return n%2 == 0 })
    fmt.Printf("偶数: %v\n", evenNumbers)

    longWords := Filter(words, func(s string) bool { return len(s) > 5 })
    fmt.Printf("长单词: %v\n", longWords)

    // 映射示例
    fmt.Println("\n=== 映射示例 ===")
    squares := Map(numbers, func(n int) int { return n * n })
    fmt.Printf("平方: %v\n", squares)

    upperWords := Map(words, func(s string) string { return strings.ToUpper(s) })
    fmt.Printf("大写单词: %v\n", upperWords)

    wordLengths := Map(words, func(s string) int { return len(s) })
    fmt.Printf("单词长度: %v\n", wordLengths)

    // 归约示例
    fmt.Println("\n=== 归约示例 ===")
    sum := Reduce(numbers, 0, func(acc, n int) int { return acc + n })
    fmt.Printf("数字总和: %d\n", sum)

    product := Reduce(numbers, 1, func(acc, n int) int { return acc * n })
    fmt.Printf("数字乘积: %d\n", product)

    concatenated := Reduce(words, "", func(acc, s string) string {
        if acc == "" {
            return s
        }
        return acc + " " + s
    })
    fmt.Printf("连接字符串: %s\n", concatenated)

    // 函数组合示例
    fmt.Println("\n=== 函数组合示例 ===")
    double := func(x int) int { return x * 2 }
    addOne := func(x int) int { return x + 1 }

    doubleAndAddOne := Compose(addOne, double)
    fmt.Printf("5 * 2 + 1 = %d\n", doubleAndAddOne(5))

    // 柯里化示例
    fmt.Println("\n=== 柯里化示例 ===")
    add5 := Add(5)
    multiply3 := Multiply(3)

    fmt.Printf("5 + 10 = %d\n", add5(10))
    fmt.Printf("3 * 7 = %d\n", multiply3(7))

    // 部分应用示例
    fmt.Println("\n=== 部分应用示例 ===")
    power := func(base, exp int) int {
        result := 1
        for i := 0; i < exp; i++ {
            result *= base
        }
        return result
    }

    square := PartialApply(power, 2)
    cube := PartialApply(func(exp, base int) int { return power(base, exp) }, 3)

    fmt.Printf("2^8 = %d\n", square(8))
    fmt.Printf("5^3 = %d\n", cube(5))

    // 复杂的函数式编程示例
    fmt.Println("\n=== 复杂示例:处理学生数据 ===")

    type Student struct {
        Name  string
        Age   int
        Grade float64
    }

    students := []Student{
        {"Alice", 20, 85.5},
        {"Bob", 19, 92.0},
        {"Charlie", 21, 78.5},
        {"Diana", 20, 96.5},
        {"Eve", 18, 88.0},
    }

    // 链式操作:过滤 -> 映射 -> 排序
    excellentStudents := Filter(students, func(s Student) bool {
        return s.Grade >= 90
    })

    excellentNames := Map(excellentStudents, func(s Student) string {
        return s.Name
    })

    sort.Strings(excellentNames)

    fmt.Printf("优秀学生 (成绩 >= 90): %v\n", excellentNames)

    // 计算平均成绩
    totalGrade := Reduce(students, 0.0, func(acc float64, s Student) float64 {
        return acc + s.Grade
    })
    averageGrade := totalGrade / float64(len(students))

    fmt.Printf("平均成绩: %.2f\n", averageGrade)
}

小结 #

本节详细介绍了 Go 语言函数的高级特性,主要内容包括:

递归函数 #

  • 基本递归:阶乘、斐波那契、最大公约数
  • 复杂递归:树遍历、汉诺塔、快速排序
  • 递归优化:尾递归、记忆化

defer 语句 #

  • 基本用法:延迟执行、LIFO 顺序
  • 资源管理:文件关闭、锁释放、连接断开
  • 参数求值:defer 时确定参数值
  • 返回值修改:通过命名返回值

panic 和 recover #

  • 错误恢复:捕获运行时恐慌
  • 资源保护:数组越界、除零、类型断言
  • 错误传播:自定义错误类型、错误链
  • 调用栈信息:运行时调试信息

函数式编程 #

  • 高阶函数:函数作为参数和返回值
  • 泛型支持:类型安全的函数式操作
  • 函数组合:组合简单函数构建复杂逻辑
  • 柯里化和部分应用:函数参数的灵活处理

最佳实践 #

  • 合理使用递归,注意栈溢出
  • 用 defer 确保资源清理
  • 谨慎使用 panic,优先返回错误
  • 利用函数式编程提高代码可读性

掌握这些高级特性能够让你编写更加健壮、优雅和高效的 Go 代码。


练习题:

  1. 实现一个递归的目录遍历函数,统计文件数量和总大小
  2. 编写一个带有完整错误处理的文件操作库
  3. 创建一个函数式编程工具库,实现常用的高阶函数
  4. 实现一个带有 panic 恢复机制的 Web 服务器中间件
  5. 编写一个递归下降解析器,解析简单的数学表达式