1.3.4 函数高级特性 #
Go 语言的函数具有许多高级特性,包括递归、defer 语句、panic 和 recover 机制等。这些特性使得 Go 语言的函数更加强大和灵活。本节将深入探讨这些高级特性的使用方法和应用场景。
递归函数 #
1. 递归的基本概念 #
递归是函数调用自身的编程技术,适用于解决可以分解为相似子问题的问题:
package main
import "fmt"
// 计算阶乘
func factorial(n int) int {
// 基础情况
if n <= 1 {
return 1
}
// 递归情况
return n * factorial(n-1)
}
// 计算斐波那契数列
func fibonacci(n int) int {
if n <= 1 {
return n
}
return fibonacci(n-1) + fibonacci(n-2)
}
// 计算最大公约数(欧几里得算法)
func gcd(a, b int) int {
if b == 0 {
return a
}
return gcd(b, a%b)
}
// 数字求和
func sumDigits(n int) int {
if n < 10 {
return n
}
return n%10 + sumDigits(n/10)
}
func main() {
// 测试阶乘
fmt.Println("=== 阶乘计算 ===")
for i := 0; i <= 10; i++ {
fmt.Printf("%d! = %d\n", i, factorial(i))
}
// 测试斐波那契数列
fmt.Println("\n=== 斐波那契数列 ===")
for i := 0; i <= 10; i++ {
fmt.Printf("fib(%d) = %d\n", i, fibonacci(i))
}
// 测试最大公约数
fmt.Println("\n=== 最大公约数 ===")
pairs := [][2]int{{48, 18}, {100, 25}, {17, 13}}
for _, pair := range pairs {
fmt.Printf("gcd(%d, %d) = %d\n", pair[0], pair[1], gcd(pair[0], pair[1]))
}
// 测试数字求和
fmt.Println("\n=== 数字各位求和 ===")
numbers := []int{123, 456, 789, 1234}
for _, num := range numbers {
fmt.Printf("sumDigits(%d) = %d\n", num, sumDigits(num))
}
}
2. 复杂递归应用 #
package main
import "fmt"
// 二叉树节点
type TreeNode struct {
Val int
Left *TreeNode
Right *TreeNode
}
// 树的深度优先遍历
func inorderTraversal(root *TreeNode) []int {
if root == nil {
return []int{}
}
result := []int{}
result = append(result, inorderTraversal(root.Left)...)
result = append(result, root.Val)
result = append(result, inorderTraversal(root.Right)...)
return result
}
// 计算树的最大深度
func maxDepth(root *TreeNode) int {
if root == nil {
return 0
}
leftDepth := maxDepth(root.Left)
rightDepth := maxDepth(root.Right)
if leftDepth > rightDepth {
return leftDepth + 1
}
return rightDepth + 1
}
// 汉诺塔问题
func hanoi(n int, from, to, aux string) {
if n == 1 {
fmt.Printf("移动盘子从 %s 到 %s\n", from, to)
return
}
hanoi(n-1, from, aux, to)
fmt.Printf("移动盘子从 %s 到 %s\n", from, to)
hanoi(n-1, aux, to, from)
}
// 快速排序
func quickSort(arr []int, low, high int) {
if low < high {
pi := partition(arr, low, high)
quickSort(arr, low, pi-1)
quickSort(arr, pi+1, high)
}
}
func partition(arr []int, low, high int) int {
pivot := arr[high]
i := low - 1
for j := low; j < high; j++ {
if arr[j] < pivot {
i++
arr[i], arr[j] = arr[j], arr[i]
}
}
arr[i+1], arr[high] = arr[high], arr[i+1]
return i + 1
}
func main() {
// 创建二叉树
root := &TreeNode{
Val: 1,
Left: &TreeNode{
Val: 2,
Left: &TreeNode{Val: 4},
Right: &TreeNode{Val: 5},
},
Right: &TreeNode{
Val: 3,
Right: &TreeNode{Val: 6},
},
}
// 测试树遍历
fmt.Println("=== 二叉树中序遍历 ===")
result := inorderTraversal(root)
fmt.Printf("遍历结果: %v\n", result)
// 测试树深度
fmt.Printf("树的最大深度: %d\n", maxDepth(root))
// 汉诺塔问题
fmt.Println("\n=== 汉诺塔问题 (3个盘子) ===")
hanoi(3, "A", "C", "B")
// 快速排序
fmt.Println("\n=== 快速排序 ===")
arr := []int{64, 34, 25, 12, 22, 11, 90}
fmt.Printf("排序前: %v\n", arr)
quickSort(arr, 0, len(arr)-1)
fmt.Printf("排序后: %v\n", arr)
}
defer 语句 #
1. defer 的基本使用 #
defer
语句用于延迟函数调用,直到包含它的函数返回时才执行:
package main
import (
"fmt"
"os"
"time"
)
// 基本 defer 使用
func basicDefer() {
fmt.Println("函数开始")
defer fmt.Println("defer 1")
defer fmt.Println("defer 2")
defer fmt.Println("defer 3")
fmt.Println("函数中间")
fmt.Println("函数结束")
}
// defer 与资源管理
func fileOperation() error {
file, err := os.Create("test.txt")
if err != nil {
return err
}
defer file.Close() // 确保文件被关闭
defer fmt.Println("文件操作完成")
// 写入文件
_, err = file.WriteString("Hello, World!")
if err != nil {
return err
}
fmt.Println("文件写入成功")
return nil
}
// defer 与函数参数求值
func deferWithArgs() {
x := 10
defer fmt.Printf("defer 中的 x: %d\n", x) // x 的值在 defer 时确定
x = 20
fmt.Printf("函数中的 x: %d\n", x)
}
// defer 与返回值
func deferWithReturn() (result int) {
defer func() {
result *= 2 // 修改命名返回值
fmt.Printf("defer 中修改返回值: %d\n", result)
}()
result = 10
fmt.Printf("函数中的返回值: %d\n", result)
return result
}
// 计时器示例
func timeTracker(name string) func() {
start := time.Now()
fmt.Printf("%s 开始执行\n", name)
return func() {
fmt.Printf("%s 执行完成,耗时: %v\n", name, time.Since(start))
}
}
func expensiveOperation() {
defer timeTracker("昂贵操作")()
// 模拟耗时操作
time.Sleep(100 * time.Millisecond)
fmt.Println("执行复杂计算...")
}
func main() {
// 基本 defer 使用
fmt.Println("=== 基本 defer 使用 ===")
basicDefer()
// 文件操作
fmt.Println("\n=== defer 与资源管理 ===")
if err := fileOperation(); err != nil {
fmt.Printf("文件操作错误: %v\n", err)
}
// defer 参数求值
fmt.Println("\n=== defer 参数求值 ===")
deferWithArgs()
// defer 与返回值
fmt.Println("\n=== defer 与返回值 ===")
result := deferWithReturn()
fmt.Printf("最终返回值: %d\n", result)
// 计时器示例
fmt.Println("\n=== 计时器示例 ===")
expensiveOperation()
}
2. defer 的高级应用 #
package main
import (
"fmt"
"sync"
"time"
)
// 互斥锁管理
type SafeCounter struct {
mu sync.Mutex
count int
}
func (c *SafeCounter) Increment() {
c.mu.Lock()
defer c.mu.Unlock() // 确保锁被释放
c.count++
fmt.Printf("计数器增加到: %d\n", c.count)
}
func (c *SafeCounter) Get() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.count
}
// 数据库连接模拟
type Database struct {
connected bool
}
func (db *Database) Connect() error {
fmt.Println("连接数据库...")
db.connected = true
return nil
}
func (db *Database) Disconnect() {
if db.connected {
fmt.Println("断开数据库连接")
db.connected = false
}
}
func (db *Database) Query(sql string) error {
if !db.connected {
return fmt.Errorf("数据库未连接")
}
fmt.Printf("执行查询: %s\n", sql)
return nil
}
func databaseOperation() error {
db := &Database{}
if err := db.Connect(); err != nil {
return err
}
defer db.Disconnect() // 确保连接被关闭
// 执行多个查询
queries := []string{
"SELECT * FROM users",
"SELECT * FROM orders",
"SELECT * FROM products",
}
for _, query := range queries {
if err := db.Query(query); err != nil {
return err
}
}
return nil
}
// 栈结构实现
type Stack struct {
items []int
}
func (s *Stack) Push(item int) {
s.items = append(s.items, item)
}
func (s *Stack) Pop() (int, bool) {
if len(s.items) == 0 {
return 0, false
}
index := len(s.items) - 1
item := s.items[index]
s.items = s.items[:index]
return item, true
}
func (s *Stack) Size() int {
return len(s.items)
}
// 使用 defer 实现自动清理
func stackOperation() {
stack := &Stack{}
// 使用 defer 确保栈被清理
defer func() {
fmt.Printf("清理栈,最终大小: %d\n", stack.Size())
for stack.Size() > 0 {
if item, ok := stack.Pop(); ok {
fmt.Printf("清理项目: %d\n", item)
}
}
}()
// 向栈中添加元素
for i := 1; i <= 5; i++ {
stack.Push(i)
fmt.Printf("推入: %d, 栈大小: %d\n", i, stack.Size())
}
// 弹出一些元素
for i := 0; i < 2; i++ {
if item, ok := stack.Pop(); ok {
fmt.Printf("弹出: %d, 栈大小: %d\n", item, stack.Size())
}
}
}
func main() {
// 互斥锁管理
fmt.Println("=== 互斥锁管理 ===")
counter := &SafeCounter{}
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
counter.Increment()
}()
}
wg.Wait()
fmt.Printf("最终计数: %d\n", counter.Get())
// 数据库操作
fmt.Println("\n=== 数据库操作 ===")
if err := databaseOperation(); err != nil {
fmt.Printf("数据库操作错误: %v\n", err)
}
// 栈操作
fmt.Println("\n=== 栈操作 ===")
stackOperation()
}
panic 和 recover #
1. panic 和 recover 基础 #
panic
用于引发运行时恐慌,recover
用于捕获和处理恐慌:
package main
import "fmt"
// 基本 panic 和 recover
func basicPanicRecover() {
defer func() {
if r := recover(); r != nil {
fmt.Printf("捕获到 panic: %v\n", r)
}
}()
fmt.Println("正常执行")
panic("发生了错误!")
fmt.Println("这行不会执行")
}
// 数组越界保护
func safeArrayAccess(arr []int, index int) (value int, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("数组访问错误: %v", r)
}
}()
value = arr[index] // 可能引发 panic
return value, nil
}
// 除零保护
func safeDivide(a, b float64) (result float64, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("除法错误: %v", r)
}
}()
if b == 0 {
panic("除数不能为零")
}
result = a / b
return result, nil
}
// 类型断言保护
func safeTypeAssertion(value interface{}) (str string, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("类型断言错误: %v", r)
}
}()
str = value.(string) // 可能引发 panic
return str, nil
}
func main() {
// 基本 panic 和 recover
fmt.Println("=== 基本 panic 和 recover ===")
basicPanicRecover()
fmt.Println("程序继续执行")
// 数组越界保护
fmt.Println("\n=== 数组越界保护 ===")
arr := []int{1, 2, 3, 4, 5}
if value, err := safeArrayAccess(arr, 2); err != nil {
fmt.Printf("错误: %v\n", err)
} else {
fmt.Printf("arr[2] = %d\n", value)
}
if value, err := safeArrayAccess(arr, 10); err != nil {
fmt.Printf("错误: %v\n", err)
} else {
fmt.Printf("arr[10] = %d\n", value)
}
// 除零保护
fmt.Println("\n=== 除零保护 ===")
if result, err := safeDivide(10, 2); err != nil {
fmt.Printf("错误: %v\n", err)
} else {
fmt.Printf("10 / 2 = %.2f\n", result)
}
if result, err := safeDivide(10, 0); err != nil {
fmt.Printf("错误: %v\n", err)
} else {
fmt.Printf("10 / 0 = %.2f\n", result)
}
// 类型断言保护
fmt.Println("\n=== 类型断言保护 ===")
values := []interface{}{"hello", 42, true, "world"}
for i, value := range values {
if str, err := safeTypeAssertion(value); err != nil {
fmt.Printf("值 %d 错误: %v\n", i, err)
} else {
fmt.Printf("值 %d 是字符串: %s\n", i, str)
}
}
}
2. 高级错误处理模式 #
package main
import (
"fmt"
"log"
"runtime"
)
// 错误类型定义
type CustomError struct {
Code int
Message string
Cause error
}
func (e *CustomError) Error() string {
if e.Cause != nil {
return fmt.Sprintf("错误 %d: %s (原因: %v)", e.Code, e.Message, e.Cause)
}
return fmt.Sprintf("错误 %d: %s", e.Code, e.Message)
}
// 带错误恢复的函数
func robustFunction(input int) (result int, err error) {
defer func() {
if r := recover(); r != nil {
// 获取调用栈信息
buf := make([]byte, 1024)
n := runtime.Stack(buf, false)
err = &CustomError{
Code: 500,
Message: fmt.Sprintf("函数执行失败: %v", r),
Cause: fmt.Errorf("调用栈: %s", buf[:n]),
}
}
}()
if input < 0 {
panic("输入不能为负数")
}
if input == 0 {
panic("输入不能为零")
}
result = 100 / input
return result, nil
}
// 链式错误处理
func processData(data []int) error {
defer func() {
if r := recover(); r != nil {
log.Printf("数据处理发生 panic: %v", r)
}
}()
for i, value := range data {
if err := validateData(value, i); err != nil {
return fmt.Errorf("数据验证失败在位置 %d: %w", i, err)
}
if err := transformData(value); err != nil {
return fmt.Errorf("数据转换失败在位置 %d: %w", i, err)
}
}
return nil
}
func validateData(value, index int) error {
if value < 0 {
return &CustomError{
Code: 400,
Message: fmt.Sprintf("负数值不允许: %d", value),
}
}
if value > 1000 {
return &CustomError{
Code: 400,
Message: fmt.Sprintf("值太大: %d", value),
}
}
return nil
}
func transformData(value int) error {
defer func() {
if r := recover(); r != nil {
panic(fmt.Sprintf("数据转换失败: %v", r))
}
}()
if value == 13 { // 模拟特殊情况
panic("不吉利的数字")
}
// 模拟转换过程
_ = value * 2
return nil
}
// 资源管理与错误处理
type Resource struct {
name string
acquired bool
}
func (r *Resource) Acquire() error {
if r.acquired {
return fmt.Errorf("资源 %s 已被获取", r.name)
}
fmt.Printf("获取资源: %s\n", r.name)
r.acquired = true
return nil
}
func (r *Resource) Release() {
if r.acquired {
fmt.Printf("释放资源: %s\n", r.name)
r.acquired = false
}
}
func useResource(resourceName string) (err error) {
resource := &Resource{name: resourceName}
defer func() {
resource.Release()
if r := recover(); r != nil {
err = fmt.Errorf("使用资源 %s 时发生 panic: %v", resourceName, r)
}
}()
if err := resource.Acquire(); err != nil {
return err
}
// 模拟使用资源
fmt.Printf("使用资源: %s\n", resourceName)
if resourceName == "危险资源" {
panic("资源使用失败")
}
return nil
}
func main() {
// 带错误恢复的函数测试
fmt.Println("=== 带错误恢复的函数 ===")
testInputs := []int{10, 5, 0, -1}
for _, input := range testInputs {
if result, err := robustFunction(input); err != nil {
fmt.Printf("输入 %d 错误: %v\n", input, err)
} else {
fmt.Printf("输入 %d 结果: %d\n", input, result)
}
}
// 链式错误处理测试
fmt.Println("\n=== 链式错误处理 ===")
testData := [][]int{
{1, 2, 3, 4, 5},
{10, 20, -5, 30},
{100, 200, 13, 400},
{1, 2000, 3},
}
for i, data := range testData {
fmt.Printf("处理数据集 %d: %v\n", i+1, data)
if err := processData(data); err != nil {
fmt.Printf(" 错误: %v\n", err)
} else {
fmt.Printf(" 处理成功\n")
}
}
// 资源管理测试
fmt.Println("\n=== 资源管理与错误处理 ===")
resources := []string{"正常资源", "危险资源", "另一个资源"}
for _, resourceName := range resources {
if err := useResource(resourceName); err != nil {
fmt.Printf("错误: %v\n", err)
} else {
fmt.Printf("成功使用资源: %s\n", resourceName)
}
}
}
函数式编程特性 #
1. 高阶函数 #
package main
import (
"fmt"
"sort"
"strings"
)
// 函数类型定义
type Predicate[T any] func(T) bool
type Mapper[T, R any] func(T) R
type Reducer[T, R any] func(R, T) R
// 泛型过滤函数
func Filter[T any](slice []T, predicate Predicate[T]) []T {
result := make([]T, 0)
for _, item := range slice {
if predicate(item) {
result = append(result, item)
}
}
return result
}
// 泛型映射函数
func Map[T, R any](slice []T, mapper Mapper[T, R]) []R {
result := make([]R, len(slice))
for i, item := range slice {
result[i] = mapper(item)
}
return result
}
// 泛型归约函数
func Reduce[T, R any](slice []T, initial R, reducer Reducer[T, R]) R {
result := initial
for _, item := range slice {
result = reducer(result, item)
}
return result
}
// 函数组合
func Compose[T, U, V any](f func(U) V, g func(T) U) func(T) V {
return func(x T) V {
return f(g(x))
}
}
// 柯里化示例
func Add(a int) func(int) int {
return func(b int) int {
return a + b
}
}
func Multiply(a int) func(int) int {
return func(b int) int {
return a * b
}
}
// 部分应用
func PartialApply[T, U, V any](f func(T, U) V, arg T) func(U) V {
return func(u U) V {
return f(arg, u)
}
}
func main() {
// 测试数据
numbers := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
words := []string{"hello", "world", "go", "programming", "functional"}
// 过滤示例
fmt.Println("=== 过滤示例 ===")
evenNumbers := Filter(numbers, func(n int) bool { return n%2 == 0 })
fmt.Printf("偶数: %v\n", evenNumbers)
longWords := Filter(words, func(s string) bool { return len(s) > 5 })
fmt.Printf("长单词: %v\n", longWords)
// 映射示例
fmt.Println("\n=== 映射示例 ===")
squares := Map(numbers, func(n int) int { return n * n })
fmt.Printf("平方: %v\n", squares)
upperWords := Map(words, func(s string) string { return strings.ToUpper(s) })
fmt.Printf("大写单词: %v\n", upperWords)
wordLengths := Map(words, func(s string) int { return len(s) })
fmt.Printf("单词长度: %v\n", wordLengths)
// 归约示例
fmt.Println("\n=== 归约示例 ===")
sum := Reduce(numbers, 0, func(acc, n int) int { return acc + n })
fmt.Printf("数字总和: %d\n", sum)
product := Reduce(numbers, 1, func(acc, n int) int { return acc * n })
fmt.Printf("数字乘积: %d\n", product)
concatenated := Reduce(words, "", func(acc, s string) string {
if acc == "" {
return s
}
return acc + " " + s
})
fmt.Printf("连接字符串: %s\n", concatenated)
// 函数组合示例
fmt.Println("\n=== 函数组合示例 ===")
double := func(x int) int { return x * 2 }
addOne := func(x int) int { return x + 1 }
doubleAndAddOne := Compose(addOne, double)
fmt.Printf("5 * 2 + 1 = %d\n", doubleAndAddOne(5))
// 柯里化示例
fmt.Println("\n=== 柯里化示例 ===")
add5 := Add(5)
multiply3 := Multiply(3)
fmt.Printf("5 + 10 = %d\n", add5(10))
fmt.Printf("3 * 7 = %d\n", multiply3(7))
// 部分应用示例
fmt.Println("\n=== 部分应用示例 ===")
power := func(base, exp int) int {
result := 1
for i := 0; i < exp; i++ {
result *= base
}
return result
}
square := PartialApply(power, 2)
cube := PartialApply(func(exp, base int) int { return power(base, exp) }, 3)
fmt.Printf("2^8 = %d\n", square(8))
fmt.Printf("5^3 = %d\n", cube(5))
// 复杂的函数式编程示例
fmt.Println("\n=== 复杂示例:处理学生数据 ===")
type Student struct {
Name string
Age int
Grade float64
}
students := []Student{
{"Alice", 20, 85.5},
{"Bob", 19, 92.0},
{"Charlie", 21, 78.5},
{"Diana", 20, 96.5},
{"Eve", 18, 88.0},
}
// 链式操作:过滤 -> 映射 -> 排序
excellentStudents := Filter(students, func(s Student) bool {
return s.Grade >= 90
})
excellentNames := Map(excellentStudents, func(s Student) string {
return s.Name
})
sort.Strings(excellentNames)
fmt.Printf("优秀学生 (成绩 >= 90): %v\n", excellentNames)
// 计算平均成绩
totalGrade := Reduce(students, 0.0, func(acc float64, s Student) float64 {
return acc + s.Grade
})
averageGrade := totalGrade / float64(len(students))
fmt.Printf("平均成绩: %.2f\n", averageGrade)
}
小结 #
本节详细介绍了 Go 语言函数的高级特性,主要内容包括:
递归函数 #
- 基本递归:阶乘、斐波那契、最大公约数
- 复杂递归:树遍历、汉诺塔、快速排序
- 递归优化:尾递归、记忆化
defer 语句 #
- 基本用法:延迟执行、LIFO 顺序
- 资源管理:文件关闭、锁释放、连接断开
- 参数求值:defer 时确定参数值
- 返回值修改:通过命名返回值
panic 和 recover #
- 错误恢复:捕获运行时恐慌
- 资源保护:数组越界、除零、类型断言
- 错误传播:自定义错误类型、错误链
- 调用栈信息:运行时调试信息
函数式编程 #
- 高阶函数:函数作为参数和返回值
- 泛型支持:类型安全的函数式操作
- 函数组合:组合简单函数构建复杂逻辑
- 柯里化和部分应用:函数参数的灵活处理
最佳实践 #
- 合理使用递归,注意栈溢出
- 用 defer 确保资源清理
- 谨慎使用 panic,优先返回错误
- 利用函数式编程提高代码可读性
掌握这些高级特性能够让你编写更加健壮、优雅和高效的 Go 代码。
练习题:
- 实现一个递归的目录遍历函数,统计文件数量和总大小
- 编写一个带有完整错误处理的文件操作库
- 创建一个函数式编程工具库,实现常用的高阶函数
- 实现一个带有 panic 恢复机制的 Web 服务器中间件
- 编写一个递归下降解析器,解析简单的数学表达式