1.9.2 泛型函数与类型

1.9.2 泛型函数与类型 #

在掌握了泛型的基础概念后,本节将深入探讨如何定义和使用泛型函数、泛型类型以及泛型方法。我们将学习类型推断机制,并通过实际示例来理解泛型在不同场景下的应用。

泛型函数的定义与使用 #

基本语法 #

泛型函数的定义语法如下:

func FunctionName[TypeParam TypeConstraint](parameters) ReturnType {
    // 函数体
}

让我们从简单的例子开始:

package main

import "fmt"

// 基础泛型函数:交换两个值
func Swap[T any](a, b T) (T, T) {
    return b, a
}

// 泛型函数:获取切片第一个元素
func First[T any](slice []T) (T, bool) {
    if len(slice) == 0 {
        var zero T
        return zero, false
    }
    return slice[0], true
}

// 泛型函数:获取切片最后一个元素
func Last[T any](slice []T) (T, bool) {
    if len(slice) == 0 {
        var zero T
        return zero, false
    }
    return slice[len(slice)-1], true
}

func main() {
    // 交换整数
    x, y := Swap(10, 20)
    fmt.Printf("Swapped integers: %d, %d\n", x, y)

    // 交换字符串
    s1, s2 := Swap("hello", "world")
    fmt.Printf("Swapped strings: %s, %s\n", s1, s2)

    // 获取切片元素
    numbers := []int{1, 2, 3, 4, 5}
    if first, ok := First(numbers); ok {
        fmt.Printf("First number: %d\n", first)
    }

    if last, ok := Last(numbers); ok {
        fmt.Printf("Last number: %d\n", last)
    }

    // 空切片处理
    var empty []string
    if _, ok := First(empty); !ok {
        fmt.Println("Empty slice has no first element")
    }
}

多类型参数函数 #

泛型函数可以有多个类型参数:

package main

import "fmt"

// 多类型参数的泛型函数
func Transform[T, U any](input T, transformer func(T) U) U {
    return transformer(input)
}

// 键值对结构
type Pair[K, V any] struct {
    Key   K
    Value V
}

// 创建键值对的泛型函数
func MakePair[K, V any](key K, value V) Pair[K, V] {
    return Pair[K, V]{Key: key, Value: value}
}

// 映射函数:将一个类型的切片转换为另一个类型的切片
func MapSlice[T, U any](slice []T, mapper func(T) U) []U {
    result := make([]U, len(slice))
    for i, v := range slice {
        result[i] = mapper(v)
    }
    return result
}

func main() {
    // 使用 Transform 函数
    result1 := Transform(42, func(n int) string {
        return fmt.Sprintf("Number: %d", n)
    })
    fmt.Println(result1) // 输出: Number: 42

    result2 := Transform("hello", func(s string) int {
        return len(s)
    })
    fmt.Println(result2) // 输出: 5

    // 创建不同类型的键值对
    intStringPair := MakePair(1, "one")
    fmt.Printf("Int-String pair: %+v\n", intStringPair)

    stringBoolPair := MakePair("active", true)
    fmt.Printf("String-Bool pair: %+v\n", stringBoolPair)

    // 使用 MapSlice 进行类型转换
    numbers := []int{1, 2, 3, 4, 5}
    strings := MapSlice(numbers, func(n int) string {
        return fmt.Sprintf("item_%d", n)
    })
    fmt.Println("Mapped strings:", strings)

    lengths := MapSlice([]string{"hello", "world", "go"}, func(s string) int {
        return len(s)
    })
    fmt.Println("String lengths:", lengths)
}

带约束的泛型函数 #

使用类型约束可以限制类型参数的范围,并允许在函数中使用特定的操作:

package main

import (
    "fmt"
    "golang.org/x/exp/constraints"
)

// 数值类型的加法函数
func Add[T constraints.Ordered](a, b T) T {
    return a + b
}

// 查找最小值
func Min[T constraints.Ordered](values ...T) T {
    if len(values) == 0 {
        var zero T
        return zero
    }

    min := values[0]
    for _, v := range values[1:] {
        if v < min {
            min = v
        }
    }
    return min
}

// 查找最大值
func Max[T constraints.Ordered](values ...T) T {
    if len(values) == 0 {
        var zero T
        return zero
    }

    max := values[0]
    for _, v := range values[1:] {
        if v > max {
            max = v
        }
    }
    return max
}

// 数值范围检查
func InRange[T constraints.Ordered](value, min, max T) bool {
    return value >= min && value <= max
}

// 绝对值函数(仅适用于有符号数值类型)
func Abs[T constraints.Signed](value T) T {
    if value < 0 {
        return -value
    }
    return value
}

func main() {
    // 整数运算
    fmt.Println("Add integers:", Add(10, 20))
    fmt.Println("Min integers:", Min(5, 2, 8, 1, 9))
    fmt.Println("Max integers:", Max(5, 2, 8, 1, 9))

    // 浮点数运算
    fmt.Println("Add floats:", Add(3.14, 2.86))
    fmt.Println("Min floats:", Min(3.14, 2.71, 1.41, 1.73))

    // 字符串比较
    fmt.Println("Min strings:", Min("zebra", "apple", "banana"))
    fmt.Println("Max strings:", Max("zebra", "apple", "banana"))

    // 范围检查
    fmt.Println("10 in range [5, 15]:", InRange(10, 5, 15))
    fmt.Println("20 in range [5, 15]:", InRange(20, 5, 15))

    // 绝对值
    fmt.Println("Abs(-42):", Abs(-42))
    fmt.Println("Abs(3.14):", Abs(-3.14))
}

泛型类型的声明 #

基本泛型类型 #

泛型类型允许我们创建可以处理多种数据类型的结构:

package main

import "fmt"

// 泛型容器类型
type Container[T any] struct {
    Value T
}

// 泛型容器的方法
func (c *Container[T]) Set(value T) {
    c.Value = value
}

func (c *Container[T]) Get() T {
    return c.Value
}

func (c *Container[T]) IsZero() bool {
    var zero T
    return fmt.Sprintf("%v", c.Value) == fmt.Sprintf("%v", zero)
}

// 泛型可选类型(类似于其他语言的 Optional)
type Optional[T any] struct {
    value   T
    present bool
}

// 创建有值的 Optional
func Some[T any](value T) Optional[T] {
    return Optional[T]{value: value, present: true}
}

// 创建空的 Optional
func None[T any]() Optional[T] {
    return Optional[T]{present: false}
}

// Optional 的方法
func (o Optional[T]) IsPresent() bool {
    return o.present
}

func (o Optional[T]) Get() (T, bool) {
    return o.value, o.present
}

func (o Optional[T]) OrElse(defaultValue T) T {
    if o.present {
        return o.value
    }
    return defaultValue
}

func main() {
    // 使用泛型容器
    intContainer := &Container[int]{Value: 42}
    fmt.Println("Int container:", intContainer.Get())

    stringContainer := &Container[string]{}
    stringContainer.Set("Hello, Generics!")
    fmt.Println("String container:", stringContainer.Get())
    fmt.Println("Is zero:", stringContainer.IsZero())

    // 使用泛型可选类型
    someInt := Some(100)
    if value, ok := someInt.Get(); ok {
        fmt.Println("Some int value:", value)
    }

    noneInt := None[int]()
    fmt.Println("None int value:", noneInt.OrElse(-1))

    someString := Some("Hello")
    fmt.Println("Some string value:", someString.OrElse("Default"))

    noneString := None[string]()
    fmt.Println("None string value:", noneString.OrElse("Default"))
}

复杂泛型数据结构 #

让我们实现一些更复杂的泛型数据结构:

package main

import (
    "fmt"
    "strings"
)

// 泛型动态数组
type DynamicArray[T any] struct {
    items    []T
    capacity int
}

// 创建新的动态数组
func NewDynamicArray[T any](initialCapacity int) *DynamicArray[T] {
    return &DynamicArray[T]{
        items:    make([]T, 0, initialCapacity),
        capacity: initialCapacity,
    }
}

// 添加元素
func (da *DynamicArray[T]) Add(item T) {
    da.items = append(da.items, item)
}

// 获取元素
func (da *DynamicArray[T]) Get(index int) (T, bool) {
    if index < 0 || index >= len(da.items) {
        var zero T
        return zero, false
    }
    return da.items[index], true
}

// 获取长度
func (da *DynamicArray[T]) Size() int {
    return len(da.items)
}

// 转换为切片
func (da *DynamicArray[T]) ToSlice() []T {
    result := make([]T, len(da.items))
    copy(result, da.items)
    return result
}

// 泛型映射表
type Map[K comparable, V any] struct {
    data map[K]V
}

// 创建新的映射表
func NewMap[K comparable, V any]() *Map[K, V] {
    return &Map[K, V]{
        data: make(map[K]V),
    }
}

// 设置键值对
func (m *Map[K, V]) Set(key K, value V) {
    m.data[key] = value
}

// 获取值
func (m *Map[K, V]) Get(key K) (V, bool) {
    value, exists := m.data[key]
    return value, exists
}

// 删除键值对
func (m *Map[K, V]) Delete(key K) {
    delete(m.data, key)
}

// 获取所有键
func (m *Map[K, V]) Keys() []K {
    keys := make([]K, 0, len(m.data))
    for k := range m.data {
        keys = append(keys, k)
    }
    return keys
}

// 获取所有值
func (m *Map[K, V]) Values() []V {
    values := make([]V, 0, len(m.data))
    for _, v := range m.data {
        values = append(values, v)
    }
    return values
}

// 大小
func (m *Map[K, V]) Size() int {
    return len(m.data)
}

func main() {
    // 使用泛型动态数组
    intArray := NewDynamicArray[int](5)
    intArray.Add(1)
    intArray.Add(2)
    intArray.Add(3)

    fmt.Printf("Int array size: %d\n", intArray.Size())
    if value, ok := intArray.Get(1); ok {
        fmt.Printf("Element at index 1: %d\n", value)
    }
    fmt.Println("Int array contents:", intArray.ToSlice())

    // 字符串数组
    stringArray := NewDynamicArray[string](3)
    stringArray.Add("hello")
    stringArray.Add("world")
    stringArray.Add("generics")

    fmt.Println("String array contents:", stringArray.ToSlice())

    // 使用泛型映射表
    userMap := NewMap[string, int]()
    userMap.Set("alice", 25)
    userMap.Set("bob", 30)
    userMap.Set("charlie", 35)

    if age, exists := userMap.Get("alice"); exists {
        fmt.Printf("Alice's age: %d\n", age)
    }

    fmt.Println("All users:", userMap.Keys())
    fmt.Println("All ages:", userMap.Values())
    fmt.Printf("Map size: %d\n", userMap.Size())

    // 复杂类型的映射
    type Person struct {
        Name string
        Age  int
    }

    personMap := NewMap[int, Person]()
    personMap.Set(1, Person{Name: "Alice", Age: 25})
    personMap.Set(2, Person{Name: "Bob", Age: 30})

    if person, exists := personMap.Get(1); exists {
        fmt.Printf("Person 1: %+v\n", person)
    }
}

泛型方法的实现 #

结构体的泛型方法 #

泛型类型可以有自己的方法,这些方法可以使用类型参数:

package main

import (
    "fmt"
    "strings"
)

// 泛型结果类型(类似于 Rust 的 Result)
type Result[T, E any] struct {
    value T
    err   E
    isOk  bool
}

// 创建成功结果
func Ok[T, E any](value T) Result[T, E] {
    return Result[T, E]{value: value, isOk: true}
}

// 创建错误结果
func Err[T, E any](err E) Result[T, E] {
    return Result[T, E]{err: err, isOk: false}
}

// Result 的方法
func (r Result[T, E]) IsOk() bool {
    return r.isOk
}

func (r Result[T, E]) IsErr() bool {
    return !r.isOk
}

func (r Result[T, E]) Unwrap() T {
    if !r.isOk {
        panic("called Unwrap on an Err value")
    }
    return r.value
}

func (r Result[T, E]) UnwrapOr(defaultValue T) T {
    if r.isOk {
        return r.value
    }
    return defaultValue
}

func (r Result[T, E]) UnwrapErr() E {
    if r.isOk {
        panic("called UnwrapErr on an Ok value")
    }
    return r.err
}

// Map 方法:转换成功值的类型
func (r Result[T, E]) Map(f func(T) interface{}) Result[interface{}, E] {
    if r.isOk {
        return Ok[interface{}, E](f(r.value))
    }
    return Err[interface{}, E](r.err)
}

// 泛型链表节点
type ListNode[T any] struct {
    Value T
    Next  *ListNode[T]
}

// 泛型链表
type LinkedList[T any] struct {
    head *ListNode[T]
    tail *ListNode[T]
    size int
}

// 创建新链表
func NewLinkedList[T any]() *LinkedList[T] {
    return &LinkedList[T]{}
}

// 添加到头部
func (ll *LinkedList[T]) AddFirst(value T) {
    newNode := &ListNode[T]{Value: value, Next: ll.head}
    ll.head = newNode
    if ll.tail == nil {
        ll.tail = newNode
    }
    ll.size++
}

// 添加到尾部
func (ll *LinkedList[T]) AddLast(value T) {
    newNode := &ListNode[T]{Value: value}
    if ll.tail == nil {
        ll.head = newNode
        ll.tail = newNode
    } else {
        ll.tail.Next = newNode
        ll.tail = newNode
    }
    ll.size++
}

// 删除第一个元素
func (ll *LinkedList[T]) RemoveFirst() (T, bool) {
    if ll.head == nil {
        var zero T
        return zero, false
    }

    value := ll.head.Value
    ll.head = ll.head.Next
    if ll.head == nil {
        ll.tail = nil
    }
    ll.size--
    return value, true
}

// 获取大小
func (ll *LinkedList[T]) Size() int {
    return ll.size
}

// 转换为切片
func (ll *LinkedList[T]) ToSlice() []T {
    result := make([]T, 0, ll.size)
    current := ll.head
    for current != nil {
        result = append(result, current.Value)
        current = current.Next
    }
    return result
}

// 查找元素
func (ll *LinkedList[T]) Find(predicate func(T) bool) (T, bool) {
    current := ll.head
    for current != nil {
        if predicate(current.Value) {
            return current.Value, true
        }
        current = current.Next
    }
    var zero T
    return zero, false
}

// 过滤元素
func (ll *LinkedList[T]) Filter(predicate func(T) bool) *LinkedList[T] {
    newList := NewLinkedList[T]()
    current := ll.head
    for current != nil {
        if predicate(current.Value) {
            newList.AddLast(current.Value)
        }
        current = current.Next
    }
    return newList
}

func main() {
    // 使用 Result 类型
    successResult := Ok[string, error]("Success!")
    fmt.Println("Is success ok?", successResult.IsOk())
    fmt.Println("Success value:", successResult.UnwrapOr("Default"))

    errorResult := Err[string, error](fmt.Errorf("something went wrong"))
    fmt.Println("Is error ok?", errorResult.IsOk())
    fmt.Println("Error value:", errorResult.UnwrapOr("Default"))

    // Map 操作
    mappedResult := successResult.Map(func(s string) interface{} {
        return strings.ToUpper(s)
    })
    if mappedResult.IsOk() {
        fmt.Println("Mapped result:", mappedResult.Unwrap())
    }

    // 使用泛型链表
    intList := NewLinkedList[int]()
    intList.AddLast(1)
    intList.AddLast(2)
    intList.AddLast(3)
    intList.AddFirst(0)

    fmt.Println("Int list:", intList.ToSlice())
    fmt.Println("List size:", intList.Size())

    // 查找元素
    if value, found := intList.Find(func(n int) bool { return n > 2 }); found {
        fmt.Println("Found value > 2:", value)
    }

    // 过滤元素
    evenList := intList.Filter(func(n int) bool { return n%2 == 0 })
    fmt.Println("Even numbers:", evenList.ToSlice())

    // 字符串链表
    stringList := NewLinkedList[string]()
    stringList.AddLast("apple")
    stringList.AddLast("banana")
    stringList.AddLast("cherry")

    fmt.Println("String list:", stringList.ToSlice())

    // 过滤长度大于5的字符串
    longStrings := stringList.Filter(func(s string) bool { return len(s) > 5 })
    fmt.Println("Long strings:", longStrings.ToSlice())
}

类型推断机制 #

Go 编译器具有强大的类型推断能力,可以在很多情况下自动推断类型参数:

package main

import "fmt"

// 泛型函数用于演示类型推断
func Process[T any](value T, processor func(T) T) T {
    return processor(value)
}

func Combine[T any](a, b T, combiner func(T, T) T) T {
    return combiner(a, b)
}

func CreateSlice[T any](values ...T) []T {
    return values
}

func main() {
    // 类型推断:编译器可以从参数推断出 T 是 int
    result1 := Process(42, func(n int) int {
        return n * 2
    })
    fmt.Println("Processed int:", result1)

    // 类型推断:编译器可以从参数推断出 T 是 string
    result2 := Process("hello", func(s string) string {
        return s + " world"
    })
    fmt.Println("Processed string:", result2)

    // 类型推断:从多个参数推断类型
    sum := Combine(10, 20, func(a, b int) int {
        return a + b
    })
    fmt.Println("Combined int:", sum)

    concat := Combine("hello", "world", func(a, b string) string {
        return a + " " + b
    })
    fmt.Println("Combined string:", concat)

    // 类型推断:从可变参数推断类型
    intSlice := CreateSlice(1, 2, 3, 4, 5)
    fmt.Println("Int slice:", intSlice)

    stringSlice := CreateSlice("a", "b", "c")
    fmt.Println("String slice:", stringSlice)

    // 显式指定类型参数(当推断不明确时)
    emptyIntSlice := CreateSlice[int]()
    fmt.Println("Empty int slice:", emptyIntSlice)

    // 混合类型需要显式指定
    interfaceSlice := CreateSlice[interface{}](1, "hello", 3.14, true)
    fmt.Println("Interface slice:", interfaceSlice)
}

类型推断的限制 #

有些情况下,编译器无法推断类型,需要显式指定:

package main

import "fmt"

// 返回类型无法从参数推断的情况
func Convert[From, To any](value From, converter func(From) To) To {
    return converter(value)
}

// 创建零值的函数
func Zero[T any]() T {
    var zero T
    return zero
}

// 类型转换函数
func Cast[T any](value interface{}) (T, bool) {
    if v, ok := value.(T); ok {
        return v, true
    }
    var zero T
    return zero, false
}

func main() {
    // 需要显式指定目标类型
    stringResult := Convert[int, string](42, func(n int) string {
        return fmt.Sprintf("Number: %d", n)
    })
    fmt.Println("Converted to string:", stringResult)

    // 无法从上下文推断,需要显式指定
    zeroInt := Zero[int]()
    zeroString := Zero[string]()
    fmt.Printf("Zero int: %d, Zero string: '%s'\n", zeroInt, zeroString)

    // 类型转换需要显式指定目标类型
    var value interface{} = 42
    if intValue, ok := Cast[int](value); ok {
        fmt.Println("Cast to int:", intValue)
    }

    if stringValue, ok := Cast[string](value); ok {
        fmt.Println("Cast to string:", stringValue)
    } else {
        fmt.Println("Cannot cast to string")
    }
}

实际应用示例 #

让我们通过一个完整的示例来展示泛型函数和类型的实际应用:

package main

import (
    "fmt"
    "sort"
    "strings"
)

// 泛型集合类型
type Set[T comparable] struct {
    items map[T]struct{}
}

// 创建新集合
func NewSet[T comparable]() *Set[T] {
    return &Set[T]{
        items: make(map[T]struct{}),
    }
}

// 添加元素
func (s *Set[T]) Add(item T) {
    s.items[item] = struct{}{}
}

// 删除元素
func (s *Set[T]) Remove(item T) {
    delete(s.items, item)
}

// 检查元素是否存在
func (s *Set[T]) Contains(item T) bool {
    _, exists := s.items[item]
    return exists
}

// 获取大小
func (s *Set[T]) Size() int {
    return len(s.items)
}

// 转换为切片
func (s *Set[T]) ToSlice() []T {
    result := make([]T, 0, len(s.items))
    for item := range s.items {
        result = append(result, item)
    }
    return result
}

// 集合并集
func (s *Set[T]) Union(other *Set[T]) *Set[T] {
    result := NewSet[T]()
    for item := range s.items {
        result.Add(item)
    }
    for item := range other.items {
        result.Add(item)
    }
    return result
}

// 集合交集
func (s *Set[T]) Intersection(other *Set[T]) *Set[T] {
    result := NewSet[T]()
    for item := range s.items {
        if other.Contains(item) {
            result.Add(item)
        }
    }
    return result
}

// 泛型缓存类型
type Cache[K comparable, V any] struct {
    data     map[K]V
    maxSize  int
    keyOrder []K
}

// 创建新缓存
func NewCache[K comparable, V any](maxSize int) *Cache[K, V] {
    return &Cache[K, V]{
        data:     make(map[K]V),
        maxSize:  maxSize,
        keyOrder: make([]K, 0, maxSize),
    }
}

// 设置缓存项
func (c *Cache[K, V]) Set(key K, value V) {
    if _, exists := c.data[key]; !exists {
        // 新键,检查是否需要淘汰
        if len(c.keyOrder) >= c.maxSize {
            // 淘汰最旧的键
            oldestKey := c.keyOrder[0]
            delete(c.data, oldestKey)
            c.keyOrder = c.keyOrder[1:]
        }
        c.keyOrder = append(c.keyOrder, key)
    }
    c.data[key] = value
}

// 获取缓存项
func (c *Cache[K, V]) Get(key K) (V, bool) {
    value, exists := c.data[key]
    return value, exists
}

// 删除缓存项
func (c *Cache[K, V]) Delete(key K) {
    if _, exists := c.data[key]; exists {
        delete(c.data, key)
        // 从顺序列表中移除
        for i, k := range c.keyOrder {
            if k == key {
                c.keyOrder = append(c.keyOrder[:i], c.keyOrder[i+1:]...)
                break
            }
        }
    }
}

// 获取缓存大小
func (c *Cache[K, V]) Size() int {
    return len(c.data)
}

func main() {
    // 使用泛型集合
    intSet := NewSet[int]()
    intSet.Add(1)
    intSet.Add(2)
    intSet.Add(3)
    intSet.Add(2) // 重复元素

    fmt.Println("Int set:", intSet.ToSlice())
    fmt.Println("Contains 2:", intSet.Contains(2))
    fmt.Println("Set size:", intSet.Size())

    // 另一个集合
    otherIntSet := NewSet[int]()
    otherIntSet.Add(2)
    otherIntSet.Add(3)
    otherIntSet.Add(4)

    // 集合运算
    union := intSet.Union(otherIntSet)
    intersection := intSet.Intersection(otherIntSet)

    unionSlice := union.ToSlice()
    sort.Ints(unionSlice)
    fmt.Println("Union:", unionSlice)

    intersectionSlice := intersection.ToSlice()
    sort.Ints(intersectionSlice)
    fmt.Println("Intersection:", intersectionSlice)

    // 字符串集合
    stringSet := NewSet[string]()
    words := []string{"hello", "world", "go", "generics", "hello"}
    for _, word := range words {
        stringSet.Add(word)
    }

    uniqueWords := stringSet.ToSlice()
    sort.Strings(uniqueWords)
    fmt.Println("Unique words:", uniqueWords)

    // 使用泛型缓存
    cache := NewCache[string, int](3)
    cache.Set("one", 1)
    cache.Set("two", 2)
    cache.Set("three", 3)

    fmt.Printf("Cache size: %d\n", cache.Size())

    if value, exists := cache.Get("two"); exists {
        fmt.Printf("Cached value for 'two': %d\n", value)
    }

    // 添加第四个元素,应该淘汰第一个
    cache.Set("four", 4)
    fmt.Printf("Cache size after adding fourth: %d\n", cache.Size())

    if _, exists := cache.Get("one"); !exists {
        fmt.Println("'one' was evicted from cache")
    }

    // 复杂类型的缓存
    type User struct {
        ID   int
        Name string
    }

    userCache := NewCache[int, User](2)
    userCache.Set(1, User{ID: 1, Name: "Alice"})
    userCache.Set(2, User{ID: 2, Name: "Bob"})

    if user, exists := userCache.Get(1); exists {
        fmt.Printf("Cached user: %+v\n", user)
    }
}

通过本节的学习,您应该已经掌握了:

  1. 泛型函数的定义和使用:包括单类型参数和多类型参数函数
  2. 泛型类型的声明:创建可复用的泛型数据结构
  3. 泛型方法的实现:为泛型类型添加方法
  4. 类型推断机制:理解编译器如何自动推断类型参数
  5. 实际应用:通过完整示例了解泛型在实际项目中的应用

在下一节中,我们将深入学习泛型约束与接口,了解如何通过约束来限制和扩展泛型的功能。