1.9.4 泛型实战应用

1.9.4 泛型实战应用 #

在掌握了泛型的基础概念、函数类型和约束机制后,本节将通过实际项目示例来展示泛型在真实开发场景中的应用。我们将构建一些实用的泛型数据结构和算法,并探讨性能优化技巧和最佳实践。

泛型数据结构设计 #

1. 泛型优先队列 #

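优先队列是一种常用的数据结构,下面用泛型实现一个类型安全且高效的版本。注意:本节部分示例依赖 golang.org/x/exp/constraints 包,运行前需先执行 go get golang.org/x/exp;在 Go 1.21 及以上版本中,也可以改用标准库的 cmp.Ordered 约束。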

package main

import (
    "fmt"
    "golang.org/x/exp/constraints"
)

// PriorityQueue is a container that always yields its highest-priority
// element first.
//
// The second result of Pop and Peek reports whether an element was available.
type PriorityQueue[T any] interface {
    // Size returns the number of stored elements.
    Size() int
    // IsEmpty reports whether the queue holds no elements.
    IsEmpty() bool
    // Push inserts item into the queue.
    Push(item T)
    // Peek returns the highest-priority element without removing it.
    Peek() (T, bool)
    // Pop removes and returns the highest-priority element.
    Pop() (T, bool)
}

// HeapPriorityQueue is a binary-heap priority queue; its methods (Push, Pop,
// Peek, Size, IsEmpty) match the PriorityQueue interface.
//
// The element type is deliberately unconstrained (any): ordering comes
// entirely from the less function, so types with no natural order (such as a
// struct) can be used through NewCustomHeap. Constraining T to
// constraints.Ordered here would make NewCustomHeap[T any] fail to compile.
// NewMinHeap and NewMaxHeap still restrict their own T to ordered types
// because they use < and >.
type HeapPriorityQueue[T any] struct {
    items []T             // heap-ordered storage; items[0] is the next element out
    less  func(T, T) bool // less(a, b) reports whether a should come out before b
}

// NewMinHeap returns an empty queue that pops the smallest element first.
func NewMinHeap[T constraints.Ordered]() *HeapPriorityQueue[T] {
    ascending := func(a, b T) bool { return a < b }
    return NewCustomHeap(ascending)
}

// NewMaxHeap returns an empty queue that pops the largest element first.
func NewMaxHeap[T constraints.Ordered]() *HeapPriorityQueue[T] {
    descending := func(a, b T) bool { return a > b }
    return NewCustomHeap(descending)
}

// NewCustomHeap returns an empty queue ordered by less: an element a is
// popped before b when less(a, b) is true.
func NewCustomHeap[T any](less func(T, T) bool) *HeapPriorityQueue[T] {
    pq := new(HeapPriorityQueue[T])
    pq.items = make([]T, 0)
    pq.less = less
    return pq
}

// Push inserts item and sifts it up until the heap property holds again.
func (pq *HeapPriorityQueue[T]) Push(item T) {
    pq.items = append(pq.items, item)
    pq.heapifyUp(len(pq.items) - 1)
}

// Pop removes and returns the element that orders first according to less.
// The boolean is false when the queue is empty, in which case the zero value
// of T is returned.
func (pq *HeapPriorityQueue[T]) Pop() (T, bool) {
    var zero T
    if len(pq.items) == 0 {
        return zero, false
    }

    root := pq.items[0]
    lastIndex := len(pq.items) - 1
    pq.items[0] = pq.items[lastIndex]
    // Clear the vacated slot before shrinking the slice. Otherwise the backing
    // array keeps a stale copy of the moved element, which keeps anything it
    // points to reachable until a later Push overwrites the slot.
    pq.items[lastIndex] = zero
    pq.items = pq.items[:lastIndex]

    if len(pq.items) > 0 {
        pq.heapifyDown(0)
    }

    return root, true
}

// Peek returns the element Pop would return, without removing it. The boolean
// is false (and the value is T's zero value) when the queue is empty.
func (pq *HeapPriorityQueue[T]) Peek() (T, bool) {
    if len(pq.items) == 0 {
        var zero T
        return zero, false
    }
    return pq.items[0], true
}

// Size returns the number of queued elements.
func (pq *HeapPriorityQueue[T]) Size() int {
    return len(pq.items)
}

// IsEmpty reports whether the queue holds no elements.
func (pq *HeapPriorityQueue[T]) IsEmpty() bool {
    return len(pq.items) == 0
}

// heapifyUp moves the element at index toward the root while it should come
// out before its parent.
func (pq *HeapPriorityQueue[T]) heapifyUp(index int) {
    for index > 0 {
        parent := (index - 1) / 2
        if !pq.less(pq.items[index], pq.items[parent]) {
            return
        }
        pq.items[index], pq.items[parent] = pq.items[parent], pq.items[index]
        index = parent
    }
}

// heapifyDown moves the element at index toward the leaves, swapping it with
// whichever child should come out first, until neither child precedes it.
func (pq *HeapPriorityQueue[T]) heapifyDown(index int) {
    n := len(pq.items)
    for {
        left := 2*index + 1
        right := 2*index + 2
        // "best" rather than "smallest": with a max-heap comparator it is the largest.
        best := index

        if left < n && pq.less(pq.items[left], pq.items[best]) {
            best = left
        }
        if right < n && pq.less(pq.items[right], pq.items[best]) {
            best = right
        }
        if best == index {
            return
        }

        pq.items[index], pq.items[best] = pq.items[best], pq.items[index]
        index = best
    }
}

// Task is a sample work item used to demonstrate a priority queue with a
// custom comparison function.
type Task struct {
    ID       int
    Priority int
    Name     string
}

// String implements fmt.Stringer.
func (t Task) String() string {
    const layout = "Task{ID: %d, Priority: %d, Name: %s}"
    return fmt.Sprintf(layout, t.ID, t.Priority, t.Name)
}

func main() {
    // Integer min-heap: values come out in ascending order.
    minHeap := NewMinHeap[int]()
    numbers := []int{5, 2, 8, 1, 9, 3}

    fmt.Println("Adding numbers to min heap:", numbers)
    for _, num := range numbers {
        minHeap.Push(num)
    }

    fmt.Println("Popping from min heap:")
    for !minHeap.IsEmpty() {
        if value, ok := minHeap.Pop(); ok {
            fmt.Printf("%d ", value)
        }
    }
    fmt.Println()

    // String max-heap: values come out in descending lexical order.
    maxHeap := NewMaxHeap[string]()
    words := []string{"apple", "banana", "cherry", "date", "elderberry"}

    fmt.Println("Adding words to max heap:", words)
    for _, word := range words {
        maxHeap.Push(word)
    }

    fmt.Println("Popping from max heap:")
    for !maxHeap.IsEmpty() {
        if value, ok := maxHeap.Pop(); ok {
            fmt.Printf("%s ", value)
        }
    }
    fmt.Println()

    // Task queue with a custom comparison. In the sample data a smaller
    // Priority number means a more urgent task (Priority 1 is labelled
    // "High priority task"), so tasks are ordered by ascending Priority.
    // Tasks with equal Priority may come out in either order: a binary heap
    // is not stable.
    taskQueue := NewCustomHeap(func(a, b Task) bool {
        return a.Priority < b.Priority // lower number = higher priority, popped first
    })

    tasks := []Task{
        {ID: 1, Priority: 3, Name: "Low priority task"},
        {ID: 2, Priority: 1, Name: "High priority task"},
        {ID: 3, Priority: 2, Name: "Medium priority task"},
        {ID: 4, Priority: 1, Name: "Another high priority task"},
    }

    fmt.Println("Adding tasks to priority queue:")
    for _, task := range tasks {
        fmt.Println(" ", task)
        taskQueue.Push(task)
    }

    fmt.Println("Processing tasks by priority:")
    for !taskQueue.IsEmpty() {
        if task, ok := taskQueue.Pop(); ok {
            fmt.Println(" ", task)
        }
    }
}

2. 泛型 LRU 缓存 #

LRU(Least Recently Used)缓存是另一个常见的数据结构,让我们实现一个泛型版本:

package main

import (
    "fmt"
)

// LRUNode is one entry in the cache's doubly linked recency list.
type LRUNode[K comparable, V any] struct {
    key   K // stored so that an evicted node can also be removed from the map
    value V
    prev  *LRUNode[K, V]
    next  *LRUNode[K, V]
}

// LRUCache is a fixed-capacity least-recently-used cache.
//
// Entries form a doubly linked list framed by two sentinel nodes: the node
// right after head is the most recently used entry, the node right before tail
// the least recently used. The map gives direct key -> node lookup. The type
// contains no locking of its own.
type LRUCache[K comparable, V any] struct {
    capacity int
    cache    map[K]*LRUNode[K, V]
    head     *LRUNode[K, V] // sentinel before the most recently used entry
    tail     *LRUNode[K, V] // sentinel after the least recently used entry
}

// NewLRUCache creates an empty cache that holds at most capacity entries.
// It panics if capacity is not positive.
func NewLRUCache[K comparable, V any](capacity int) *LRUCache[K, V] {
    if capacity <= 0 {
        panic("capacity must be positive")
    }

    head, tail := &LRUNode[K, V]{}, &LRUNode[K, V]{}
    head.next, tail.prev = tail, head

    return &LRUCache[K, V]{
        capacity: capacity,
        cache:    make(map[K]*LRUNode[K, V]),
        head:     head,
        tail:     tail,
    }
}

// Get returns the value cached for key and marks the entry as most recently
// used. When key is absent it returns V's zero value and false.
func (c *LRUCache[K, V]) Get(key K) (V, bool) {
    node, found := c.cache[key]
    if !found {
        var zero V
        return zero, false
    }
    c.moveToHead(node)
    return node.value, true
}

// Put stores value under key and marks the entry as most recently used. If a
// new key pushes the size past capacity, the least recently used entry is
// evicted.
func (c *LRUCache[K, V]) Put(key K, value V) {
    if node, found := c.cache[key]; found {
        node.value = value
        c.moveToHead(node)
        return
    }

    node := &LRUNode[K, V]{key: key, value: value}
    c.cache[key] = node
    c.addToHead(node)

    if len(c.cache) <= c.capacity {
        return
    }
    evicted := c.removeTail()
    delete(c.cache, evicted.key)
}

// Delete removes key and reports whether it was present.
func (c *LRUCache[K, V]) Delete(key K) bool {
    node, found := c.cache[key]
    if !found {
        return false
    }
    c.removeNode(node)
    delete(c.cache, key)
    return true
}

// Size returns the number of cached entries.
func (c *LRUCache[K, V]) Size() int {
    return len(c.cache)
}

// Capacity returns the maximum number of entries the cache holds.
func (c *LRUCache[K, V]) Capacity() int {
    return c.capacity
}

// Clear drops every entry, keeping the configured capacity.
func (c *LRUCache[K, V]) Clear() {
    c.cache = make(map[K]*LRUNode[K, V])
    c.head.next, c.tail.prev = c.tail, c.head
}

// Keys returns the cached keys ordered from most to least recently used.
func (c *LRUCache[K, V]) Keys() []K {
    keys := make([]K, 0, len(c.cache))
    for n := c.head.next; n != c.tail; n = n.next {
        keys = append(keys, n.key)
    }
    return keys
}

// addToHead links node in directly after the head sentinel.
func (c *LRUCache[K, V]) addToHead(node *LRUNode[K, V]) {
    first := c.head.next
    node.prev, node.next = c.head, first
    first.prev = node
    c.head.next = node
}

// removeNode unlinks node from its neighbours.
func (c *LRUCache[K, V]) removeNode(node *LRUNode[K, V]) {
    before, after := node.prev, node.next
    before.next = after
    after.prev = before
}

// moveToHead marks an already-linked node as most recently used.
func (c *LRUCache[K, V]) moveToHead(node *LRUNode[K, V]) {
    c.removeNode(node)
    c.addToHead(node)
}

// removeTail unlinks and returns the least recently used node. Callers only
// invoke it right after an insertion, so the list is never empty here.
func (c *LRUCache[K, V]) removeTail() *LRUNode[K, V] {
    victim := c.tail.prev
    c.removeNode(victim)
    return victim
}

// UserInfo is sample user data stored in the cache demo.
type UserInfo struct {
    ID    int
    Name  string
    Email string
}

// String implements fmt.Stringer.
func (u UserInfo) String() string {
    return fmt.Sprintf(
        "User{ID: %d, Name: %s, Email: %s}",
        u.ID, u.Name, u.Email,
    )
}

func main() {
    // Cache mapping string keys to int values.
    intCache := NewLRUCache[string, int](3)

    fmt.Println("=== String to Int Cache ===")
    intCache.Put("one", 1)
    intCache.Put("two", 2)
    intCache.Put("three", 3)

    fmt.Println("Keys after adding three items:", intCache.Keys())

    // Access "one" so it becomes the most recently used entry.
    if value, ok := intCache.Get("one"); ok {
        fmt.Printf("Got 'one': %d\n", value)
    }

    fmt.Println("Keys after accessing 'one':", intCache.Keys())

    // Adding a fourth entry evicts the least recently used one, "two".
    intCache.Put("four", 4)
    fmt.Println("Keys after adding 'four':", intCache.Keys())

    // The evicted key is no longer present.
    if _, ok := intCache.Get("two"); !ok {
        fmt.Println("'two' was evicted from cache")
    }

    // Cache of user records keyed by ID.
    fmt.Println("\n=== User Info Cache ===")
    userCache := NewLRUCache[int, UserInfo](2)

    users := []UserInfo{
        {ID: 1, Name: "Alice", Email: "alice@example.com"},
        {ID: 2, Name: "Bob", Email: "bob@example.com"},
        {ID: 3, Name: "Charlie", Email: "charlie@example.com"},
    }

    // Add the users one by one; with capacity 2, adding user 3 evicts user 1.
    for _, user := range users {
        userCache.Put(user.ID, user)
        fmt.Printf("Added user: %s\n", user)
        fmt.Printf("Current cache size: %d/%d\n", userCache.Size(), userCache.Capacity())
        fmt.Println("Cached user IDs:", userCache.Keys())
        fmt.Println()
    }

    // Look up user 1. It was evicted above, so this demonstrates a cache miss;
    // a hit would also have made the entry the most recently used.
    if user, ok := userCache.Get(1); ok {
        fmt.Printf("Retrieved user 1: %s\n", user)
        fmt.Println("Keys after accessing user 1:", userCache.Keys())
    } else {
        fmt.Println("User 1 is not cached (evicted as least recently used)")
    }

    // Remove user 3 explicitly.
    if userCache.Delete(3) {
        fmt.Println("Deleted user 3")
        fmt.Println("Keys after deletion:", userCache.Keys())
    }

    // Struct keys work because all their fields are comparable.
    fmt.Println("\n=== Complex Type Cache ===")
    type CacheKey struct {
        UserID int
        Action string
    }

    type CacheValue struct {
        Result    string
        Timestamp int64
    }

    complexCache := NewLRUCache[CacheKey, CacheValue](3)

    complexCache.Put(
        CacheKey{UserID: 1, Action: "login"},
        CacheValue{Result: "success", Timestamp: 1640995200},
    )

    complexCache.Put(
        CacheKey{UserID: 2, Action: "logout"},
        CacheValue{Result: "success", Timestamp: 1640995300},
    )

    key := CacheKey{UserID: 1, Action: "login"}
    if value, ok := complexCache.Get(key); ok {
        fmt.Printf("Complex cache result: %+v\n", value)
    }

    fmt.Printf("Complex cache size: %d\n", complexCache.Size())
}

泛型算法实现 #

1. 通用排序算法 #

让我们实现一些通用的排序算法:

package main

import (
    "fmt"
    "golang.org/x/exp/constraints"
    "math/rand"
    "time"
)

// Sortable is the index-based contract used by the sorting functions in this
// example; its methods mirror sort.Interface. The type parameter T is not used
// by any method — it only lets the sort functions spell the self-referential
// constraint [T Sortable[T]].
type Sortable[T any] interface {
    // Swap exchanges the elements at i and j.
    Swap(i, j int)
    // Less reports whether the element at i sorts before the element at j.
    Less(i, j int) bool
    // Len returns the number of elements.
    Len() int
}

// SliceWrapper adapts a slice of an ordered type to Sortable, comparing
// elements with the built-in < operator. Converting a slice to SliceWrapper
// does not copy it, so sorting the wrapper reorders the original slice.
type SliceWrapper[T constraints.Ordered] []T

// Len returns the number of elements.
func (w SliceWrapper[T]) Len() int {
    return len(w)
}

// Less reports whether element i is smaller than element j.
func (w SliceWrapper[T]) Less(i, j int) bool {
    return w[i] < w[j]
}

// Swap exchanges elements i and j in the shared backing array.
func (w SliceWrapper[T]) Swap(i, j int) {
    w[j], w[i] = w[i], w[j]
}

// CustomSliceWrapper adapts a slice of any element type to Sortable by
// delegating element comparison to a caller-supplied less function.
type CustomSliceWrapper[T any] struct {
    data []T
    less func(T, T) bool
}

// NewCustomSliceWrapper wraps data (without copying it), ordered by less.
func NewCustomSliceWrapper[T any](data []T, less func(T, T) bool) *CustomSliceWrapper[T] {
    w := new(CustomSliceWrapper[T])
    w.data, w.less = data, less
    return w
}

// Len returns the number of wrapped elements.
func (w *CustomSliceWrapper[T]) Len() int {
    return len(w.data)
}

// Less compares the elements at i and j with the wrapped less function.
func (w *CustomSliceWrapper[T]) Less(i, j int) bool {
    a, b := w.data[i], w.data[j]
    return w.less(a, b)
}

// Swap exchanges the elements at i and j in the wrapped slice.
func (w *CustomSliceWrapper[T]) Swap(i, j int) {
    w.data[j], w.data[i] = w.data[i], w.data[j]
}

// BubbleSort sorts data in place in ascending order (as defined by data.Less).
//
// The sort is stable: neighbours are swapped only when the right one is
// strictly less than the left one. It stops as soon as a full pass makes no
// swaps, so already-sorted input costs a single O(n) pass.
func BubbleSort[T Sortable[T]](data T) {
    n := data.Len()
    for pass := 0; pass < n-1; pass++ {
        swapped := false
        for j := 0; j < n-1-pass; j++ {
            // Swap only on a strict inversion; swapping equal neighbours would
            // break stability and defeat the early-exit check below.
            if data.Less(j+1, j) {
                data.Swap(j, j+1)
                swapped = true
            }
        }
        if !swapped {
            return
        }
    }
}

// QuickSort sorts data in place in ascending order (as defined by data.Less)
// using quicksort with Lomuto partitioning around the last element.
func QuickSort[T Sortable[T]](data T) {
    quickSortRecursive(data, 0, data.Len()-1)
}

// quickSortRecursive sorts the inclusive range [low, high]. It recurses into
// the left partition and loops over the right one.
func quickSortRecursive[T Sortable[T]](data T, low, high int) {
    for low < high {
        p := partition(data, low, high)
        quickSortRecursive(data, low, p-1)
        low = p + 1
    }
}

// partition uses data[high] as the pivot, moves every element that is less
// than the pivot to the front of [low, high], places the pivot right after
// them and returns the pivot's final index.
func partition[T Sortable[T]](data T, low, high int) int {
    store := low // next slot for an element smaller than the pivot
    for cur := low; cur < high; cur++ {
        if data.Less(cur, high) {
            data.Swap(store, cur)
            store++
        }
    }
    data.Swap(store, high)
    return store
}

// MergeSort sorts data in place in ascending order (as defined by data.Less)
// using top-down merge sort. Sortable exposes only Len, Less and Swap, so the
// merge step works in place by rotating elements with adjacent swaps; this
// keeps the sort stable but makes the worst case O(n^2) swaps.
func MergeSort[T Sortable[T]](data T) {
    if data.Len() <= 1 {
        return
    }
    mergeSortRecursive(data, 0, data.Len()-1)
}

// mergeSortRecursive sorts the inclusive range [left, right].
func mergeSortRecursive[T Sortable[T]](data T, left, right int) {
    if left < right {
        mid := left + (right-left)/2
        mergeSortRecursive(data, left, mid)
        mergeSortRecursive(data, mid+1, right)
        merge(data, left, mid, right)
    }
}

// merge combines the sorted runs [left, mid] and [mid+1, right] into one
// sorted run, in place.
//
// Invariant: everything before i is in final order, [i, mid] is the rest of
// the left run and [j, right] the rest of the right run, with j == mid+1.
// When the right head is strictly smaller than the left head, it is rotated
// down to position i by adjacent swaps; that shifts the remaining left run one
// slot to the right, so mid and j advance together with i.
func merge[T Sortable[T]](data T, left, mid, right int) {
    i, j := left, mid+1
    for i <= mid && j <= right {
        if !data.Less(j, i) {
            // Left head <= right head: it is already in place. Preferring the
            // left element on ties keeps the sort stable.
            i++
            continue
        }
        for k := j; k > i; k-- {
            data.Swap(k, k-1)
        }
        i++
        mid++
        j++
    }
}

// SimpleMergeSort returns the elements of slice in ascending order. Inputs of
// length 0 or 1 are returned as-is (the same slice); otherwise a newly
// allocated slice is returned and the input is left unmodified.
func SimpleMergeSort[T constraints.Ordered](slice []T) []T {
    n := len(slice)
    if n < 2 {
        return slice
    }
    half := n / 2
    return simpleMerge(SimpleMergeSort(slice[:half]), SimpleMergeSort(slice[half:]))
}

// simpleMerge merges two ascending slices into a new ascending slice, taking
// from left on ties so that equal elements keep their relative order.
func simpleMerge[T constraints.Ordered](left, right []T) []T {
    merged := make([]T, 0, len(left)+len(right))
    for len(left) > 0 && len(right) > 0 {
        if left[0] <= right[0] {
            merged = append(merged, left[0])
            left = left[1:]
        } else {
            merged = append(merged, right[0])
            right = right[1:]
        }
    }
    merged = append(merged, left...)
    return append(merged, right...)
}

// benchmarkSort calls sortFunc on data once and prints the elapsed wall-clock
// time, labelled with name. A single run is only a rough indication; use a
// testing.B benchmark for reliable numbers.
func benchmarkSort[T Sortable[T]](name string, sortFunc func(T), data T) {
    began := time.Now()
    sortFunc(data)
    fmt.Printf("%s took: %v\n", name, time.Since(began))
}

// Student is sample data used to demonstrate sorting with a custom comparator.
type Student struct {
    ID    int
    Name  string
    Grade float64
}

// String implements fmt.Stringer; the grade is shown with two decimals.
func (s Student) String() string {
    const layout = "Student{ID: %d, Name: %s, Grade: %.2f}"
    return fmt.Sprintf(layout, s.ID, s.Name, s.Grade)
}

func main() {
    // Integer sorting. Each algorithm works on its own copy of intData.
    fmt.Println("=== Integer Sorting ===")
    intData := []int{64, 34, 25, 12, 22, 11, 90, 5, 77, 30}
    fmt.Println("Original:", intData)

    // Bubble sort.
    bubbleData := make([]int, len(intData))
    copy(bubbleData, intData)
    wrapper := SliceWrapper[int](bubbleData)
    BubbleSort(&wrapper)
    fmt.Println("Bubble Sort:", bubbleData)

    // Quick sort.
    quickData := make([]int, len(intData))
    copy(quickData, intData)
    wrapper = SliceWrapper[int](quickData)
    QuickSort(&wrapper)
    fmt.Println("Quick Sort:", quickData)

    // Slice-based merge sort; it returns a new sorted slice.
    mergeData := make([]int, len(intData))
    copy(mergeData, intData)
    sorted := SimpleMergeSort(mergeData)
    fmt.Println("Merge Sort:", sorted)

    // String sorting. The wrapper shares stringData's backing array, so the
    // result is visible through stringData.
    fmt.Println("\n=== String Sorting ===")
    stringData := []string{"banana", "apple", "cherry", "date", "elderberry"}
    fmt.Println("Original:", stringData)

    stringWrapper := SliceWrapper[string](stringData)
    QuickSort(&stringWrapper)
    fmt.Println("Quick Sort:", stringData)

    // Custom ordering: students by grade, then by name.
    fmt.Println("\n=== Custom Sorting (Students by Grade) ===")
    students := []Student{
        {ID: 1, Name: "Alice", Grade: 85.5},
        {ID: 2, Name: "Bob", Grade: 92.0},
        {ID: 3, Name: "Charlie", Grade: 78.5},
        {ID: 4, Name: "Diana", Grade: 96.5},
        {ID: 5, Name: "Eve", Grade: 88.0},
    }

    fmt.Println("Original students:")
    for _, student := range students {
        fmt.Println(" ", student)
    }

    // Ascending by grade.
    studentWrapper := NewCustomSliceWrapper(students, func(a, b Student) bool {
        return a.Grade < b.Grade
    })
    QuickSort(studentWrapper)

    fmt.Println("Sorted by grade (ascending):")
    for _, student := range students {
        fmt.Println(" ", student)
    }

    // Alphabetical by name.
    studentWrapper = NewCustomSliceWrapper(students, func(a, b Student) bool {
        return a.Name < b.Name
    })
    QuickSort(studentWrapper)

    fmt.Println("Sorted by name (alphabetical):")
    for _, student := range students {
        fmt.Println(" ", student)
    }

    // Rough timing of each algorithm on the same random input.
    fmt.Println("\n=== Performance Test ===")
    // No rand.Seed call: it is deprecated, and since Go 1.20 the global
    // math/rand source is seeded randomly at program start.

    size := 10000
    testData := make([]int, size)
    for i := range testData {
        testData[i] = rand.Intn(10000)
    }

    bubbleTestData := make([]int, len(testData))
    copy(bubbleTestData, testData)
    bubbleWrapper := SliceWrapper[int](bubbleTestData)
    benchmarkSort("Bubble Sort", func(data SliceWrapper[int]) { BubbleSort(&data) }, bubbleWrapper)

    quickTestData := make([]int, len(testData))
    copy(quickTestData, testData)
    quickWrapper := SliceWrapper[int](quickTestData)
    benchmarkSort("Quick Sort", func(data SliceWrapper[int]) { QuickSort(&data) }, quickWrapper)

    // benchmarkSort needs a Sortable argument and a plain []int is not one, so
    // wrap the data and convert it back to []int for SimpleMergeSort. The
    // returned sorted slice is discarded: only the elapsed time matters here.
    mergeTestData := make([]int, len(testData))
    copy(mergeTestData, testData)
    mergeWrapper := SliceWrapper[int](mergeTestData)
    benchmarkSort("Merge Sort", func(data SliceWrapper[int]) { SimpleMergeSort([]int(data)) }, mergeWrapper)
}

2. 泛型搜索算法 #

实现一些常用的搜索算法:

package main

import (
    "fmt"
    "strings"
    "time"

    "golang.org/x/exp/constraints"
)

// LinearSearch returns the index of the first element equal to target, or -1
// if there is none.
func LinearSearch[T comparable](slice []T, target T) int {
    return LinearSearchWith(slice, func(v T) bool { return v == target })
}

// LinearSearchWith returns the index of the first element for which predicate
// returns true, or -1 if there is none.
func LinearSearchWith[T any](slice []T, predicate func(T) bool) int {
    for i := 0; i < len(slice); i++ {
        if predicate(slice[i]) {
            return i
        }
    }
    return -1
}

// BinarySearch returns an index of target in slice, which must be sorted in
// ascending order, or -1 if target is not present. With duplicates, any
// matching index may be returned.
func BinarySearch[T constraints.Ordered](slice []T, target T) int {
    lo, hi := 0, len(slice)-1
    for lo <= hi {
        mid := lo + (hi-lo)/2 // avoids overflow of lo+hi
        switch v := slice[mid]; {
        case v == target:
            return mid
        case v < target:
            lo = mid + 1
        default:
            hi = mid - 1
        }
    }
    return -1
}

// BinarySearchWith returns an index of target in slice, or -1 if absent.
// compare(a, b) must return a negative number when a sorts before b, zero when
// they are equal and a positive number otherwise, and slice must be sorted
// consistently with it.
func BinarySearchWith[T any](slice []T, target T, compare func(T, T) int) int {
    lo, hi := 0, len(slice)-1
    for lo <= hi {
        mid := lo + (hi-lo)/2
        switch c := compare(slice[mid], target); {
        case c == 0:
            return mid
        case c < 0:
            lo = mid + 1
        default:
            hi = mid - 1
        }
    }
    return -1
}

// FindFirst returns the first element satisfying predicate together with its
// index and true; if none matches it returns T's zero value, -1 and false.
func FindFirst[T any](slice []T, predicate func(T) bool) (T, int, bool) {
    for i := 0; i < len(slice); i++ {
        if v := slice[i]; predicate(v) {
            return v, i, true
        }
    }
    var none T
    return none, -1, false
}

// FindLast is like FindFirst but scans from the end, returning the last
// matching element.
func FindLast[T any](slice []T, predicate func(T) bool) (T, int, bool) {
    i := len(slice)
    for i > 0 {
        i--
        if predicate(slice[i]) {
            return slice[i], i, true
        }
    }
    var none T
    return none, -1, false
}

// FindAll returns, in order, every element satisfying predicate. It returns
// nil when nothing matches.
func FindAll[T any](slice []T, predicate func(T) bool) []T {
    var matches []T
    for i := range slice {
        if predicate(slice[i]) {
            matches = append(matches, slice[i])
        }
    }
    return matches
}

// FindAllIndices returns, in ascending order, the index of every element
// satisfying predicate. It returns nil when nothing matches.
func FindAllIndices[T any](slice []T, predicate func(T) bool) []int {
    var positions []int
    for i := range slice {
        if !predicate(slice[i]) {
            continue
        }
        positions = append(positions, i)
    }
    return positions
}

// Any reports whether at least one element satisfies predicate. It stops at
// the first match and returns false for an empty slice.
func Any[T any](slice []T, predicate func(T) bool) bool {
    for i := range slice {
        if predicate(slice[i]) {
            return true
        }
    }
    return false
}

// All reports whether every element satisfies predicate, i.e. no element
// fails it. It stops at the first failure and returns true for an empty slice.
func All[T any](slice []T, predicate func(T) bool) bool {
    return !Any(slice, func(v T) bool { return !predicate(v) })
}

// Count returns how many elements satisfy predicate.
func Count[T any](slice []T, predicate func(T) bool) int {
    n := 0
    for i := range slice {
        if predicate(slice[i]) {
            n++
        }
    }
    return n
}

// Product is sample catalogue data used by the search demos.
type Product struct {
    ID       int
    Name     string
    Price    float64
    Category string
    InStock  bool
}

// String implements fmt.Stringer; the price is shown with two decimals.
func (p Product) String() string {
    const layout = "Product{ID: %d, Name: %s, Price: %.2f, Category: %s, InStock: %t}"
    return fmt.Sprintf(layout, p.ID, p.Name, p.Price, p.Category, p.InStock)
}

func main() {
    // Basic searches over a sorted int slice.
    fmt.Println("=== Basic Search Tests ===")
    numbers := []int{1, 3, 5, 7, 9, 11, 13, 15, 17, 19}
    fmt.Println("Numbers:", numbers)

    // Linear search.
    target := 7
    if index := LinearSearch(numbers, target); index != -1 {
        fmt.Printf("Linear search: found %d at index %d\n", target, index)
    } else {
        fmt.Printf("Linear search: %d not found\n", target)
    }

    // Binary search (numbers is sorted, as BinarySearch requires).
    if index := BinarySearch(numbers, target); index != -1 {
        fmt.Printf("Binary search: found %d at index %d\n", target, index)
    } else {
        fmt.Printf("Binary search: %d not found\n", target)
    }

    // Predicate-based search.
    if index := LinearSearchWith(numbers, func(n int) bool { return n > 10 }); index != -1 {
        fmt.Printf("First number > 10: %d at index %d\n", numbers[index], index)
    }

    // String search.
    fmt.Println("\n=== String Search ===")
    words := []string{"apple", "banana", "cherry", "date", "elderberry", "fig", "grape"}
    fmt.Println("Words:", words)

    wordTarget := "cherry"
    if index := BinarySearch(words, wordTarget); index != -1 {
        fmt.Printf("Found '%s' at index %d\n", wordTarget, index)
    }

    // Case-insensitive binary search with a custom comparator. This relies on
    // words being sorted case-insensitively, which holds because every word is
    // lowercase.
    if index := BinarySearchWith(words, "CHERRY", func(a, b string) int {
        aLower := strings.ToLower(a)
        bLower := strings.ToLower(b)
        if aLower < bLower {
            return -1
        } else if aLower > bLower {
            return 1
        }
        return 0
    }); index != -1 {
        fmt.Printf("Case-insensitive search: found 'CHERRY' as '%s' at index %d\n", words[index], index)
    }

    // Searches over structs.
    fmt.Println("\n=== Product Search ===")
    products := []Product{
        {ID: 1, Name: "Laptop", Price: 999.99, Category: "Electronics", InStock: true},
        {ID: 2, Name: "Mouse", Price: 29.99, Category: "Electronics", InStock: true},
        {ID: 3, Name: "Keyboard", Price: 79.99, Category: "Electronics", InStock: false},
        {ID: 4, Name: "Book", Price: 19.99, Category: "Books", InStock: true},
        {ID: 5, Name: "Pen", Price: 2.99, Category: "Stationery", InStock: true},
        {ID: 6, Name: "Notebook", Price: 5.99, Category: "Stationery", InStock: false},
    }

    fmt.Println("Products:")
    for _, product := range products {
        fmt.Println(" ", product)
    }

    // First electronics product.
    if product, index, found := FindFirst(products, func(p Product) bool {
        return p.Category == "Electronics"
    }); found {
        fmt.Printf("\nFirst electronics product: %s at index %d\n", product.Name, index)
    }

    // Last product that is in stock.
    if product, index, found := FindLast(products, func(p Product) bool {
        return p.InStock
    }); found {
        fmt.Printf("Last in-stock product: %s at index %d\n", product.Name, index)
    }

    // All electronics products.
    electronics := FindAll(products, func(p Product) bool {
        return p.Category == "Electronics"
    })
    fmt.Printf("\nAll electronics products (%d found):\n", len(electronics))
    for _, product := range electronics {
        fmt.Println(" ", product)
    }

    // Indices of products priced above 50.
    expensiveIndices := FindAllIndices(products, func(p Product) bool {
        return p.Price > 50
    })
    fmt.Printf("\nIndices of products > $50: %v\n", expensiveIndices)

    // Is anything out of stock?
    hasOutOfStock := Any(products, func(p Product) bool {
        return !p.InStock
    })
    fmt.Printf("Has out-of-stock products: %t\n", hasOutOfStock)

    // Is everything in stock?
    allInStock := All(products, func(p Product) bool {
        return p.InStock
    })
    fmt.Printf("All products in stock: %t\n", allInStock)

    // Number of electronics products.
    electronicsCount := Count(products, func(p Product) bool {
        return p.Category == "Electronics"
    })
    fmt.Printf("Number of electronics products: %d\n", electronicsCount)

    // Number of in-stock products.
    inStockCount := Count(products, func(p Product) bool {
        return p.InStock
    })
    fmt.Printf("Number of in-stock products: %d\n", inStockCount)

    // Rough one-shot timing comparison (use testing.B for reliable numbers).
    fmt.Println("\n=== Performance Comparison ===")
    largeNumbers := make([]int, 100000)
    for i := range largeNumbers {
        largeNumbers[i] = i * 2 // sorted even numbers
    }

    searchTarget := 99998

    start := time.Now()
    linearResult := LinearSearch(largeNumbers, searchTarget)
    linearDuration := time.Since(start)

    start = time.Now()
    binaryResult := BinarySearch(largeNumbers, searchTarget)
    binaryDuration := time.Since(start)

    fmt.Printf("Searching for %d in %d elements:\n", searchTarget, len(largeNumbers))
    fmt.Printf("Linear search: index %d, took %v\n", linearResult, linearDuration)
    fmt.Printf("Binary search: index %d, took %v\n", binaryResult, binaryDuration)
    // One binary search can measure as 0 on a coarse clock; report that
    // instead of printing a meaningless +Inf ratio.
    if binaryDuration > 0 {
        fmt.Printf("Binary search is %.2fx faster\n", float64(linearDuration)/float64(binaryDuration))
    } else {
        fmt.Println("Binary search finished too quickly to compute a speed-up ratio")
    }
}

性能优化技巧 #

1. 避免不必要的 interface{} 装箱与类型断言 #

package main

import (
    "fmt"
    "time"
)

// ProcessInterfaceSlice sums the elements of items that hold an int; other
// dynamic types are ignored. Each element needs a runtime type assertion,
// which is the overhead the generic version avoids.
func ProcessInterfaceSlice(items []interface{}) int {
    total := 0
    for i := range items {
        n, isInt := items[i].(int)
        if !isInt {
            continue
        }
        total += n
    }
    return total
}

// ProcessGenericSlice returns the sum of items. T may be int or any type whose
// underlying type is int; no boxing or type assertions are involved.
func ProcessGenericSlice[T ~int](items []T) T {
    var total T
    for i := range items {
        total += items[i]
    }
    return total
}

// benchmarkProcessing sums the integers 0..size-1 once through
// ProcessInterfaceSlice and once through ProcessGenericSlice, then prints both
// results with their elapsed wall-clock times. A single run is noisy; use a
// testing.B benchmark for reliable numbers.
func benchmarkProcessing() {
    size := 1000000

    // The same values stored two ways: boxed in interface{} and as plain ints.
    interfaceSlice := make([]interface{}, size)
    genericSlice := make([]int, size)
    for i := 0; i < size; i++ {
        interfaceSlice[i] = i
        genericSlice[i] = i
    }

    // interface{} version.
    start := time.Now()
    result1 := ProcessInterfaceSlice(interfaceSlice)
    duration1 := time.Since(start)

    // Generic version.
    start = time.Now()
    result2 := ProcessGenericSlice(genericSlice)
    duration2 := time.Since(start)

    fmt.Printf("Interface version: result=%d, time=%v\n", result1, duration1)
    fmt.Printf("Generic version: result=%d, time=%v\n", result2, duration2)
    // A timing can measure as 0 on a coarse clock; avoid printing +Inf then.
    if duration2 > 0 {
        fmt.Printf("Generic version is %.2fx faster\n", float64(duration1)/float64(duration2))
    } else {
        fmt.Println("Generic version finished too quickly to compute a speed-up ratio")
    }
}

// main runs the interface{}-versus-generics timing comparison.
func main() { benchmarkProcessing() }

2. 内存分配优化 #

package main

import (
    "fmt"
    "runtime"
    "time"
)

// IneffientMap applies mapper to every element of slice and returns the
// results in order, growing the output with append (so it may reallocate
// several times). It returns nil for an empty input. (The misspelt name is
// kept because callers use it.)
func IneffientMap[T, U any](slice []T, mapper func(T) U) []U {
    var out []U
    for i := range slice {
        out = append(out, mapper(slice[i]))
    }
    return out
}

// EfficientMap applies mapper to every element of slice and returns the
// results in order. The output is allocated once at its final length; for an
// empty input it is an empty, non-nil slice.
func EfficientMap[T, U any](slice []T, mapper func(T) U) []U {
    out := make([]U, len(slice))
    for i := range out {
        out[i] = mapper(slice[i])
    }
    return out
}

// benchmarkMemoryAllocation maps 100000 ints through IneffientMap and
// EfficientMap and prints, for each, the elapsed time and the growth of
// runtime.MemStats.TotalAlloc (cumulative bytes allocated) across the call.
// It then checks that both versions produced the same values.
func benchmarkMemoryAllocation() {
    size := 100000
    input := make([]int, size)
    for i := range input {
        input[i] = i
    }

    mapper := func(n int) int { return n * 2 }

    // Inefficient version.
    var m1, m2 runtime.MemStats
    runtime.GC()
    runtime.ReadMemStats(&m1)

    start := time.Now()
    result1 := IneffientMap(input, mapper)
    duration1 := time.Since(start)

    runtime.ReadMemStats(&m2)
    allocations1 := m2.TotalAlloc - m1.TotalAlloc

    // Efficient version.
    runtime.GC()
    runtime.ReadMemStats(&m1)

    start = time.Now()
    result2 := EfficientMap(input, mapper)
    duration2 := time.Since(start)

    runtime.ReadMemStats(&m2)
    allocations2 := m2.TotalAlloc - m1.TotalAlloc

    // Compare element by element; equal lengths alone would not show that the
    // two implementations agree.
    resultsEqual := len(result1) == len(result2)
    for i := 0; resultsEqual && i < len(result1); i++ {
        resultsEqual = result1[i] == result2[i]
    }

    fmt.Printf("Inefficient version: time=%v, allocations=%d bytes\n", duration1, allocations1)
    fmt.Printf("Efficient version: time=%v, allocations=%d bytes\n", duration2, allocations2)
    fmt.Printf("Results equal: %t\n", resultsEqual)
    if allocations2 > 0 {
        fmt.Printf("Memory reduction: %.2fx\n", float64(allocations1)/float64(allocations2))
    } else {
        fmt.Println("Memory reduction: n/a (no allocations measured for the efficient version)")
    }
}

// main runs the allocation comparison between IneffientMap and EfficientMap.
func main() { benchmarkMemoryAllocation() }

最佳实践与注意事项 #

1. 合理使用泛型 #

package main

import "fmt"

// BadExample (counter-example): prints value and returns it unchanged. The
// type parameter adds little here because the body only hands value to
// fmt.Println.
func BadExample[T any](value T) T {
    fmt.Println(value)
    return value
}

// GoodExample (recommended): reports whether target occurs in slice. The
// comparable constraint lets one type-safe implementation serve every element
// type that supports ==.
func GoodExample[T comparable](slice []T, target T) bool {
    for i := range slice {
        if slice[i] == target {
            return true
        }
    }
    return false
}

// BadPrint (counter-example): a type parameter used only to print a value.
func BadPrint[T any](value T) {
    fmt.Println(value)
}

// GoodPrint (recommended): for simply printing, a plain interface{}
// parameter is sufficient.
func GoodPrint(value interface{}) {
    fmt.Println(value)
}

func main() {
    // GoodExample works for any comparable element type.
    numbers := []int{1, 2, 3, 4, 5}
    words := []string{"hello", "world", "go"}

    fmt.Println("Contains 3:", GoodExample(numbers, 3))
    fmt.Println("Contains 'go':", GoodExample(words, "go"))
}

2. 错误处理模式 #

package main

import (
    "fmt"
    "errors"
)

// Result pairs a value with an error; the value is meaningful only when the
// error is nil.
type Result[T any] struct {
    value T
    err   error
}

// Ok wraps a successful value.
func Ok[T any](value T) Result[T] {
    return Result[T]{value: value}
}

// Err wraps a failure; the stored value is T's zero value.
func Err[T any](err error) Result[T] {
    return Result[T]{err: err}
}

// IsOk reports whether r holds a value (no error).
func (r Result[T]) IsOk() bool {
    return r.err == nil
}

// IsErr reports whether r holds an error.
func (r Result[T]) IsErr() bool {
    return r.err != nil
}

// Unwrap returns the stored value and error in the usual Go (value, error) form.
func (r Result[T]) Unwrap() (T, error) {
    return r.value, r.err
}

// UnwrapOr returns the stored value, or defaultValue when r holds an error.
func (r Result[T]) UnwrapOr(defaultValue T) T {
    if r.err != nil {
        return defaultValue
    }
    return r.value
}

// Map applies f to the value of a successful result; an error result is
// passed through unchanged. Go methods cannot declare their own type
// parameters, which is why the output is Result[interface{}] rather than a
// Result of a new type parameter.
func (r Result[T]) Map(f func(T) interface{}) Result[interface{}] {
    if r.err != nil {
        return Err[interface{}](r.err)
    }
    return Ok[interface{}](f(r.value))
}

// Divide returns a / b, or an error result when b is zero.
func Divide(a, b float64) Result[float64] {
    if b == 0 {
        return Err[float64](errors.New("division by zero"))
    }
    return Ok(a / b)
}

// Sqrt returns the square root of x, or an error result when x is negative.
//
// The root is computed with Newton's method so this example needs no extra
// imports; real code should simply call math.Sqrt. Starting from
// z = max(x, 1), which is never below the true root, each Newton step moves z
// down toward sqrt(x), so iteration stops as soon as a step no longer
// decreases z.
func Sqrt(x float64) Result[float64] {
    if x < 0 {
        return Err[float64](errors.New("negative number"))
    }
    // Zero and NaN (both fail x > 0) and +Inf (x+x == x) are their own square
    // roots. They are returned directly because the iteration below would not
    // terminate for zero or NaN and would yield NaN for +Inf.
    if !(x > 0) || x+x == x {
        return Ok(x)
    }

    z := x
    if z < 1 {
        z = 1
    }
    for {
        next := 0.5 * (z + x/z)
        if next >= z {
            break
        }
        z = next
    }
    return Ok(z)
}

func main() {
    // Success path: Divide returns an Ok result.
    result := Divide(10, 2)
    if result.IsOk() {
        quotient, _ := result.Unwrap() // error is nil: IsOk was checked
        fmt.Printf("10 / 2 = %.2f\n", quotient)
    }

    // Failure path: dividing by zero yields an Err result.
    result = Divide(10, 0)
    if result.IsErr() {
        _, err := result.Unwrap()
        fmt.Printf("Error: %v\n", err)
    }

    // UnwrapOr substitutes a fallback for an error result.
    fallback := Divide(10, 0).UnwrapOr(-1)
    fmt.Printf("With default: %.2f\n", fallback)

    // Map transforms the value of a successful result.
    toText := func(x float64) interface{} {
        return fmt.Sprintf("Result: %.2f", x)
    }
    if mapped := Divide(16, 4).Map(toText); mapped.IsOk() {
        text, _ := mapped.Unwrap()
        fmt.Println(text)
    }
}

通过本节的学习,您应该已经掌握了:

  1. 泛型数据结构设计:如何构建实用的泛型数据结构
  2. 泛型算法实现:掌握通用算法的泛型实现方法
  3. 性能优化技巧:了解如何优化泛型代码的性能
  4. 最佳实践:知道何时使用泛型以及如何正确使用

泛型是 Go 语言的强大特性,但需要在合适的场景下使用。通过合理的设计和实现,泛型可以显著提高代码的复用性和类型安全性;在能够避免 interface{} 装箱和类型断言的场景中,泛型还能带来性能提升,但并非所有场景都会更快,应以基准测试结果为准。在实际项目中,建议根据具体需求选择是否使用泛型,避免过度设计。