1.9.2 泛型函数与类型 #
在掌握了泛型的基础概念后,本节将深入探讨如何定义和使用泛型函数、泛型类型以及泛型方法。我们将学习类型推断机制,并通过实际示例来理解泛型在不同场景下的应用。
泛型函数的定义与使用 #
基本语法 #
泛型函数的定义语法如下:
func FunctionName[TypeParam TypeConstraint](parameters) ReturnType {
// 函数体
}
让我们从简单的例子开始:
package main
import "fmt"
// 基础泛型函数:交换两个值
func Swap[T any](a, b T) (T, T) {
return b, a
}
// 泛型函数:获取切片第一个元素
func First[T any](slice []T) (T, bool) {
if len(slice) == 0 {
var zero T
return zero, false
}
return slice[0], true
}
// 泛型函数:获取切片最后一个元素
func Last[T any](slice []T) (T, bool) {
if len(slice) == 0 {
var zero T
return zero, false
}
return slice[len(slice)-1], true
}
func main() {
// 交换整数
x, y := Swap(10, 20)
fmt.Printf("Swapped integers: %d, %d\n", x, y)
// 交换字符串
s1, s2 := Swap("hello", "world")
fmt.Printf("Swapped strings: %s, %s\n", s1, s2)
// 获取切片元素
numbers := []int{1, 2, 3, 4, 5}
if first, ok := First(numbers); ok {
fmt.Printf("First number: %d\n", first)
}
if last, ok := Last(numbers); ok {
fmt.Printf("Last number: %d\n", last)
}
// 空切片处理
var empty []string
if _, ok := First(empty); !ok {
fmt.Println("Empty slice has no first element")
}
}
多类型参数函数 #
泛型函数可以有多个类型参数:
package main
import "fmt"
// 多类型参数的泛型函数
func Transform[T, U any](input T, transformer func(T) U) U {
return transformer(input)
}
// 键值对结构
type Pair[K, V any] struct {
Key K
Value V
}
// 创建键值对的泛型函数
func MakePair[K, V any](key K, value V) Pair[K, V] {
return Pair[K, V]{Key: key, Value: value}
}
// 映射函数:将一个类型的切片转换为另一个类型的切片
func MapSlice[T, U any](slice []T, mapper func(T) U) []U {
result := make([]U, len(slice))
for i, v := range slice {
result[i] = mapper(v)
}
return result
}
func main() {
// 使用 Transform 函数
result1 := Transform(42, func(n int) string {
return fmt.Sprintf("Number: %d", n)
})
fmt.Println(result1) // 输出: Number: 42
result2 := Transform("hello", func(s string) int {
return len(s)
})
fmt.Println(result2) // 输出: 5
// 创建不同类型的键值对
intStringPair := MakePair(1, "one")
fmt.Printf("Int-String pair: %+v\n", intStringPair)
stringBoolPair := MakePair("active", true)
fmt.Printf("String-Bool pair: %+v\n", stringBoolPair)
// 使用 MapSlice 进行类型转换
numbers := []int{1, 2, 3, 4, 5}
strings := MapSlice(numbers, func(n int) string {
return fmt.Sprintf("item_%d", n)
})
fmt.Println("Mapped strings:", strings)
lengths := MapSlice([]string{"hello", "world", "go"}, func(s string) int {
return len(s)
})
fmt.Println("String lengths:", lengths)
}
带约束的泛型函数 #
使用类型约束可以限制类型参数的范围,并允许在函数中使用特定的操作:
package main
import (
"fmt"
"golang.org/x/exp/constraints"
)
// 数值类型的加法函数
func Add[T constraints.Ordered](a, b T) T {
return a + b
}
// 查找最小值
func Min[T constraints.Ordered](values ...T) T {
if len(values) == 0 {
var zero T
return zero
}
min := values[0]
for _, v := range values[1:] {
if v < min {
min = v
}
}
return min
}
// 查找最大值
func Max[T constraints.Ordered](values ...T) T {
if len(values) == 0 {
var zero T
return zero
}
max := values[0]
for _, v := range values[1:] {
if v > max {
max = v
}
}
return max
}
// 数值范围检查
func InRange[T constraints.Ordered](value, min, max T) bool {
return value >= min && value <= max
}
// 绝对值函数(仅适用于有符号数值类型)
func Abs[T constraints.Signed](value T) T {
if value < 0 {
return -value
}
return value
}
func main() {
// 整数运算
fmt.Println("Add integers:", Add(10, 20))
fmt.Println("Min integers:", Min(5, 2, 8, 1, 9))
fmt.Println("Max integers:", Max(5, 2, 8, 1, 9))
// 浮点数运算
fmt.Println("Add floats:", Add(3.14, 2.86))
fmt.Println("Min floats:", Min(3.14, 2.71, 1.41, 1.73))
// 字符串比较
fmt.Println("Min strings:", Min("zebra", "apple", "banana"))
fmt.Println("Max strings:", Max("zebra", "apple", "banana"))
// 范围检查
fmt.Println("10 in range [5, 15]:", InRange(10, 5, 15))
fmt.Println("20 in range [5, 15]:", InRange(20, 5, 15))
// 绝对值
fmt.Println("Abs(-42):", Abs(-42))
fmt.Println("Abs(3.14):", Abs(-3.14))
}
泛型类型的声明 #
基本泛型类型 #
泛型类型允许我们创建可以处理多种数据类型的结构:
package main
import "fmt"
// 泛型容器类型
type Container[T any] struct {
Value T
}
// 泛型容器的方法
func (c *Container[T]) Set(value T) {
c.Value = value
}
func (c *Container[T]) Get() T {
return c.Value
}
func (c *Container[T]) IsZero() bool {
var zero T
return fmt.Sprintf("%v", c.Value) == fmt.Sprintf("%v", zero)
}
// 泛型可选类型(类似于其他语言的 Optional)
type Optional[T any] struct {
value T
present bool
}
// 创建有值的 Optional
func Some[T any](value T) Optional[T] {
return Optional[T]{value: value, present: true}
}
// 创建空的 Optional
func None[T any]() Optional[T] {
return Optional[T]{present: false}
}
// Optional 的方法
func (o Optional[T]) IsPresent() bool {
return o.present
}
func (o Optional[T]) Get() (T, bool) {
return o.value, o.present
}
func (o Optional[T]) OrElse(defaultValue T) T {
if o.present {
return o.value
}
return defaultValue
}
func main() {
// 使用泛型容器
intContainer := &Container[int]{Value: 42}
fmt.Println("Int container:", intContainer.Get())
stringContainer := &Container[string]{}
stringContainer.Set("Hello, Generics!")
fmt.Println("String container:", stringContainer.Get())
fmt.Println("Is zero:", stringContainer.IsZero())
// 使用泛型可选类型
someInt := Some(100)
if value, ok := someInt.Get(); ok {
fmt.Println("Some int value:", value)
}
noneInt := None[int]()
fmt.Println("None int value:", noneInt.OrElse(-1))
someString := Some("Hello")
fmt.Println("Some string value:", someString.OrElse("Default"))
noneString := None[string]()
fmt.Println("None string value:", noneString.OrElse("Default"))
}
复杂泛型数据结构 #
让我们实现一些更复杂的泛型数据结构:
package main
import (
"fmt"
"strings"
)
// 泛型动态数组
type DynamicArray[T any] struct {
items []T
capacity int
}
// 创建新的动态数组
func NewDynamicArray[T any](initialCapacity int) *DynamicArray[T] {
return &DynamicArray[T]{
items: make([]T, 0, initialCapacity),
capacity: initialCapacity,
}
}
// 添加元素
func (da *DynamicArray[T]) Add(item T) {
da.items = append(da.items, item)
}
// 获取元素
func (da *DynamicArray[T]) Get(index int) (T, bool) {
if index < 0 || index >= len(da.items) {
var zero T
return zero, false
}
return da.items[index], true
}
// 获取长度
func (da *DynamicArray[T]) Size() int {
return len(da.items)
}
// 转换为切片
func (da *DynamicArray[T]) ToSlice() []T {
result := make([]T, len(da.items))
copy(result, da.items)
return result
}
// 泛型映射表
type Map[K comparable, V any] struct {
data map[K]V
}
// 创建新的映射表
func NewMap[K comparable, V any]() *Map[K, V] {
return &Map[K, V]{
data: make(map[K]V),
}
}
// 设置键值对
func (m *Map[K, V]) Set(key K, value V) {
m.data[key] = value
}
// 获取值
func (m *Map[K, V]) Get(key K) (V, bool) {
value, exists := m.data[key]
return value, exists
}
// 删除键值对
func (m *Map[K, V]) Delete(key K) {
delete(m.data, key)
}
// 获取所有键
func (m *Map[K, V]) Keys() []K {
keys := make([]K, 0, len(m.data))
for k := range m.data {
keys = append(keys, k)
}
return keys
}
// 获取所有值
func (m *Map[K, V]) Values() []V {
values := make([]V, 0, len(m.data))
for _, v := range m.data {
values = append(values, v)
}
return values
}
// 大小
func (m *Map[K, V]) Size() int {
return len(m.data)
}
func main() {
// 使用泛型动态数组
intArray := NewDynamicArray[int](5)
intArray.Add(1)
intArray.Add(2)
intArray.Add(3)
fmt.Printf("Int array size: %d\n", intArray.Size())
if value, ok := intArray.Get(1); ok {
fmt.Printf("Element at index 1: %d\n", value)
}
fmt.Println("Int array contents:", intArray.ToSlice())
// 字符串数组
stringArray := NewDynamicArray[string](3)
stringArray.Add("hello")
stringArray.Add("world")
stringArray.Add("generics")
fmt.Println("String array contents:", stringArray.ToSlice())
// 使用泛型映射表
userMap := NewMap[string, int]()
userMap.Set("alice", 25)
userMap.Set("bob", 30)
userMap.Set("charlie", 35)
if age, exists := userMap.Get("alice"); exists {
fmt.Printf("Alice's age: %d\n", age)
}
fmt.Println("All users:", userMap.Keys())
fmt.Println("All ages:", userMap.Values())
fmt.Printf("Map size: %d\n", userMap.Size())
// 复杂类型的映射
type Person struct {
Name string
Age int
}
personMap := NewMap[int, Person]()
personMap.Set(1, Person{Name: "Alice", Age: 25})
personMap.Set(2, Person{Name: "Bob", Age: 30})
if person, exists := personMap.Get(1); exists {
fmt.Printf("Person 1: %+v\n", person)
}
}
泛型方法的实现 #
结构体的泛型方法 #
泛型类型可以有自己的方法,这些方法可以使用类型参数:
package main
import (
"fmt"
"strings"
)
// 泛型结果类型(类似于 Rust 的 Result)
type Result[T, E any] struct {
value T
err E
isOk bool
}
// 创建成功结果
func Ok[T, E any](value T) Result[T, E] {
return Result[T, E]{value: value, isOk: true}
}
// 创建错误结果
func Err[T, E any](err E) Result[T, E] {
return Result[T, E]{err: err, isOk: false}
}
// Result 的方法
func (r Result[T, E]) IsOk() bool {
return r.isOk
}
func (r Result[T, E]) IsErr() bool {
return !r.isOk
}
func (r Result[T, E]) Unwrap() T {
if !r.isOk {
panic("called Unwrap on an Err value")
}
return r.value
}
func (r Result[T, E]) UnwrapOr(defaultValue T) T {
if r.isOk {
return r.value
}
return defaultValue
}
func (r Result[T, E]) UnwrapErr() E {
if r.isOk {
panic("called UnwrapErr on an Ok value")
}
return r.err
}
// Map 方法:转换成功值的类型
func (r Result[T, E]) Map(f func(T) interface{}) Result[interface{}, E] {
if r.isOk {
return Ok[interface{}, E](f(r.value))
}
return Err[interface{}, E](r.err)
}
// 泛型链表节点
type ListNode[T any] struct {
Value T
Next *ListNode[T]
}
// 泛型链表
type LinkedList[T any] struct {
head *ListNode[T]
tail *ListNode[T]
size int
}
// 创建新链表
func NewLinkedList[T any]() *LinkedList[T] {
return &LinkedList[T]{}
}
// 添加到头部
func (ll *LinkedList[T]) AddFirst(value T) {
newNode := &ListNode[T]{Value: value, Next: ll.head}
ll.head = newNode
if ll.tail == nil {
ll.tail = newNode
}
ll.size++
}
// 添加到尾部
func (ll *LinkedList[T]) AddLast(value T) {
newNode := &ListNode[T]{Value: value}
if ll.tail == nil {
ll.head = newNode
ll.tail = newNode
} else {
ll.tail.Next = newNode
ll.tail = newNode
}
ll.size++
}
// 删除第一个元素
func (ll *LinkedList[T]) RemoveFirst() (T, bool) {
if ll.head == nil {
var zero T
return zero, false
}
value := ll.head.Value
ll.head = ll.head.Next
if ll.head == nil {
ll.tail = nil
}
ll.size--
return value, true
}
// 获取大小
func (ll *LinkedList[T]) Size() int {
return ll.size
}
// 转换为切片
func (ll *LinkedList[T]) ToSlice() []T {
result := make([]T, 0, ll.size)
current := ll.head
for current != nil {
result = append(result, current.Value)
current = current.Next
}
return result
}
// 查找元素
func (ll *LinkedList[T]) Find(predicate func(T) bool) (T, bool) {
current := ll.head
for current != nil {
if predicate(current.Value) {
return current.Value, true
}
current = current.Next
}
var zero T
return zero, false
}
// 过滤元素
func (ll *LinkedList[T]) Filter(predicate func(T) bool) *LinkedList[T] {
newList := NewLinkedList[T]()
current := ll.head
for current != nil {
if predicate(current.Value) {
newList.AddLast(current.Value)
}
current = current.Next
}
return newList
}
func main() {
// 使用 Result 类型
successResult := Ok[string, error]("Success!")
fmt.Println("Is success ok?", successResult.IsOk())
fmt.Println("Success value:", successResult.UnwrapOr("Default"))
errorResult := Err[string, error](fmt.Errorf("something went wrong"))
fmt.Println("Is error ok?", errorResult.IsOk())
fmt.Println("Error value:", errorResult.UnwrapOr("Default"))
// Map 操作
mappedResult := successResult.Map(func(s string) interface{} {
return strings.ToUpper(s)
})
if mappedResult.IsOk() {
fmt.Println("Mapped result:", mappedResult.Unwrap())
}
// 使用泛型链表
intList := NewLinkedList[int]()
intList.AddLast(1)
intList.AddLast(2)
intList.AddLast(3)
intList.AddFirst(0)
fmt.Println("Int list:", intList.ToSlice())
fmt.Println("List size:", intList.Size())
// 查找元素
if value, found := intList.Find(func(n int) bool { return n > 2 }); found {
fmt.Println("Found value > 2:", value)
}
// 过滤元素
evenList := intList.Filter(func(n int) bool { return n%2 == 0 })
fmt.Println("Even numbers:", evenList.ToSlice())
// 字符串链表
stringList := NewLinkedList[string]()
stringList.AddLast("apple")
stringList.AddLast("banana")
stringList.AddLast("cherry")
fmt.Println("String list:", stringList.ToSlice())
// 过滤长度大于5的字符串
longStrings := stringList.Filter(func(s string) bool { return len(s) > 5 })
fmt.Println("Long strings:", longStrings.ToSlice())
}
类型推断机制 #
Go 编译器具有强大的类型推断能力,可以在很多情况下自动推断类型参数:
package main
import "fmt"
// 泛型函数用于演示类型推断
func Process[T any](value T, processor func(T) T) T {
return processor(value)
}
func Combine[T any](a, b T, combiner func(T, T) T) T {
return combiner(a, b)
}
func CreateSlice[T any](values ...T) []T {
return values
}
func main() {
// 类型推断:编译器可以从参数推断出 T 是 int
result1 := Process(42, func(n int) int {
return n * 2
})
fmt.Println("Processed int:", result1)
// 类型推断:编译器可以从参数推断出 T 是 string
result2 := Process("hello", func(s string) string {
return s + " world"
})
fmt.Println("Processed string:", result2)
// 类型推断:从多个参数推断类型
sum := Combine(10, 20, func(a, b int) int {
return a + b
})
fmt.Println("Combined int:", sum)
concat := Combine("hello", "world", func(a, b string) string {
return a + " " + b
})
fmt.Println("Combined string:", concat)
// 类型推断:从可变参数推断类型
intSlice := CreateSlice(1, 2, 3, 4, 5)
fmt.Println("Int slice:", intSlice)
stringSlice := CreateSlice("a", "b", "c")
fmt.Println("String slice:", stringSlice)
// 显式指定类型参数(当推断不明确时)
emptyIntSlice := CreateSlice[int]()
fmt.Println("Empty int slice:", emptyIntSlice)
// 混合类型需要显式指定
interfaceSlice := CreateSlice[interface{}](1, "hello", 3.14, true)
fmt.Println("Interface slice:", interfaceSlice)
}
类型推断的限制 #
有些情况下,编译器无法推断类型,需要显式指定:
package main
import "fmt"
// 返回类型无法从参数推断的情况
func Convert[From, To any](value From, converter func(From) To) To {
return converter(value)
}
// 创建零值的函数
func Zero[T any]() T {
var zero T
return zero
}
// 类型转换函数
func Cast[T any](value interface{}) (T, bool) {
if v, ok := value.(T); ok {
return v, true
}
var zero T
return zero, false
}
func main() {
// 需要显式指定目标类型
stringResult := Convert[int, string](42, func(n int) string {
return fmt.Sprintf("Number: %d", n)
})
fmt.Println("Converted to string:", stringResult)
// 无法从上下文推断,需要显式指定
zeroInt := Zero[int]()
zeroString := Zero[string]()
fmt.Printf("Zero int: %d, Zero string: '%s'\n", zeroInt, zeroString)
// 类型转换需要显式指定目标类型
var value interface{} = 42
if intValue, ok := Cast[int](value); ok {
fmt.Println("Cast to int:", intValue)
}
if stringValue, ok := Cast[string](value); ok {
fmt.Println("Cast to string:", stringValue)
} else {
fmt.Println("Cannot cast to string")
}
}
实际应用示例 #
让我们通过一个完整的示例来展示泛型函数和类型的实际应用:
package main
import (
"fmt"
"sort"
"strings"
)
// 泛型集合类型
type Set[T comparable] struct {
items map[T]struct{}
}
// 创建新集合
func NewSet[T comparable]() *Set[T] {
return &Set[T]{
items: make(map[T]struct{}),
}
}
// 添加元素
func (s *Set[T]) Add(item T) {
s.items[item] = struct{}{}
}
// 删除元素
func (s *Set[T]) Remove(item T) {
delete(s.items, item)
}
// 检查元素是否存在
func (s *Set[T]) Contains(item T) bool {
_, exists := s.items[item]
return exists
}
// 获取大小
func (s *Set[T]) Size() int {
return len(s.items)
}
// 转换为切片
func (s *Set[T]) ToSlice() []T {
result := make([]T, 0, len(s.items))
for item := range s.items {
result = append(result, item)
}
return result
}
// 集合并集
func (s *Set[T]) Union(other *Set[T]) *Set[T] {
result := NewSet[T]()
for item := range s.items {
result.Add(item)
}
for item := range other.items {
result.Add(item)
}
return result
}
// 集合交集
func (s *Set[T]) Intersection(other *Set[T]) *Set[T] {
result := NewSet[T]()
for item := range s.items {
if other.Contains(item) {
result.Add(item)
}
}
return result
}
// 泛型缓存类型
type Cache[K comparable, V any] struct {
data map[K]V
maxSize int
keyOrder []K
}
// 创建新缓存
func NewCache[K comparable, V any](maxSize int) *Cache[K, V] {
return &Cache[K, V]{
data: make(map[K]V),
maxSize: maxSize,
keyOrder: make([]K, 0, maxSize),
}
}
// 设置缓存项
func (c *Cache[K, V]) Set(key K, value V) {
if _, exists := c.data[key]; !exists {
// 新键,检查是否需要淘汰
if len(c.keyOrder) >= c.maxSize {
// 淘汰最旧的键
oldestKey := c.keyOrder[0]
delete(c.data, oldestKey)
c.keyOrder = c.keyOrder[1:]
}
c.keyOrder = append(c.keyOrder, key)
}
c.data[key] = value
}
// 获取缓存项
func (c *Cache[K, V]) Get(key K) (V, bool) {
value, exists := c.data[key]
return value, exists
}
// 删除缓存项
func (c *Cache[K, V]) Delete(key K) {
if _, exists := c.data[key]; exists {
delete(c.data, key)
// 从顺序列表中移除
for i, k := range c.keyOrder {
if k == key {
c.keyOrder = append(c.keyOrder[:i], c.keyOrder[i+1:]...)
break
}
}
}
}
// 获取缓存大小
func (c *Cache[K, V]) Size() int {
return len(c.data)
}
func main() {
// 使用泛型集合
intSet := NewSet[int]()
intSet.Add(1)
intSet.Add(2)
intSet.Add(3)
intSet.Add(2) // 重复元素
fmt.Println("Int set:", intSet.ToSlice())
fmt.Println("Contains 2:", intSet.Contains(2))
fmt.Println("Set size:", intSet.Size())
// 另一个集合
otherIntSet := NewSet[int]()
otherIntSet.Add(2)
otherIntSet.Add(3)
otherIntSet.Add(4)
// 集合运算
union := intSet.Union(otherIntSet)
intersection := intSet.Intersection(otherIntSet)
unionSlice := union.ToSlice()
sort.Ints(unionSlice)
fmt.Println("Union:", unionSlice)
intersectionSlice := intersection.ToSlice()
sort.Ints(intersectionSlice)
fmt.Println("Intersection:", intersectionSlice)
// 字符串集合
stringSet := NewSet[string]()
words := []string{"hello", "world", "go", "generics", "hello"}
for _, word := range words {
stringSet.Add(word)
}
uniqueWords := stringSet.ToSlice()
sort.Strings(uniqueWords)
fmt.Println("Unique words:", uniqueWords)
// 使用泛型缓存
cache := NewCache[string, int](3)
cache.Set("one", 1)
cache.Set("two", 2)
cache.Set("three", 3)
fmt.Printf("Cache size: %d\n", cache.Size())
if value, exists := cache.Get("two"); exists {
fmt.Printf("Cached value for 'two': %d\n", value)
}
// 添加第四个元素,应该淘汰第一个
cache.Set("four", 4)
fmt.Printf("Cache size after adding fourth: %d\n", cache.Size())
if _, exists := cache.Get("one"); !exists {
fmt.Println("'one' was evicted from cache")
}
// 复杂类型的缓存
type User struct {
ID int
Name string
}
userCache := NewCache[int, User](2)
userCache.Set(1, User{ID: 1, Name: "Alice"})
userCache.Set(2, User{ID: 2, Name: "Bob"})
if user, exists := userCache.Get(1); exists {
fmt.Printf("Cached user: %+v\n", user)
}
}
通过本节的学习,您应该已经掌握了:
- 泛型函数的定义和使用:包括单类型参数和多类型参数函数
- 泛型类型的声明:创建可复用的泛型数据结构
- 泛型方法的实现:为泛型类型添加方法
- 类型推断机制:理解编译器如何自动推断类型参数
- 实际应用:通过完整示例了解泛型在实际项目中的应用
在下一节中,我们将深入学习泛型约束与接口,了解如何通过约束来限制和扩展泛型的功能。