4.7.2 CGO 高级特性与陷阱

4.7.2 CGO 高级特性与陷阱 #

在掌握了 CGO 的基础知识后,本节将深入探讨 CGO 的高级特性,包括复杂的内存管理、性能优化技巧,以及在实际开发中需要避免的常见陷阱。这些知识对于编写稳定、高效的 CGO 代码至关重要。

高级内存管理 #

内存所有权模型 #

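在 CGO 中,内存管理是最复杂也是最容易出错的部分。理解内存所有权模型是关键:每一块内存都必须有唯一、明确的释放方——由 C 分配的内存只能由 C 侧(`free` 或对应的释放函数)释放,由 Go 分配的内存则交给 Go 垃圾回收器管理,两者绝不能混用。下面的示例展示了几种常见的所有权模式: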

package main

/*
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// 内存所有权示例

// 1. Allocated by C, freed by C.

// create_c_string returns a heap copy of input that the caller must release
// with free_c_string. Returns NULL if input is NULL or the allocation fails.
char* create_c_string(const char* input) {
    if (input == NULL) {
        return NULL;
    }
    size_t len = strlen(input);
    char* result = malloc(len + 1);
    if (result == NULL) {
        return NULL;
    }
    memcpy(result, input, len + 1); // len + 1 also copies the terminating NUL
    return result;
}

// free_c_string releases a string returned by create_c_string (NULL is allowed).
void free_c_string(char* str) {
    free(str);
}

// 2. Allocated by C, lifetime managed from Go.
typedef struct {
    char* data;   // storage of `capacity` bytes
    int size;     // number of valid bytes currently in data
    int capacity; // allocated size of data in bytes
} managed_buffer;

// create_managed_buffer allocates a descriptor plus `capacity` bytes of
// storage. Returns NULL for a negative capacity or when either allocation
// fails; nothing is leaked on failure.
managed_buffer* create_managed_buffer(int capacity) {
    if (capacity < 0) {
        return NULL;
    }
    managed_buffer* buf = malloc(sizeof(managed_buffer));
    if (buf == NULL) {
        return NULL;
    }
    // malloc(0) may legally return NULL, so always request at least one byte.
    buf->data = malloc(capacity > 0 ? (size_t)capacity : 1);
    if (buf->data == NULL) {
        free(buf);
        return NULL;
    }
    buf->size = 0;
    buf->capacity = capacity;
    return buf;
}

// free_managed_buffer releases the storage and the descriptor (NULL is allowed).
void free_managed_buffer(managed_buffer* buf) {
    if (buf) {
        free(buf->data);
        free(buf);
    }
}

// 3. Shared static region: it lives for the whole process and belongs to the
// C side, so neither Go nor C code may free it.
static char shared_memory[1024];

// get_shared_memory returns the start of the 1024-byte static region above;
// every call returns the same address.
char* get_shared_memory() {
    char* base = &shared_memory[0];
    return base;
}

// 4. Reference counting. Every live object is also linked into ref_list.
// None of these functions lock; callers must serialize access themselves.
typedef struct ref_counted {
    int ref_count;
    char* data;
    struct ref_counted* next;
} ref_counted;

static ref_counted* ref_list = NULL;

// create_ref_counted allocates an object holding a private copy of data with
// a reference count of 1 and pushes it onto ref_list. Returns NULL if data is
// NULL or an allocation fails; nothing is leaked on failure.
ref_counted* create_ref_counted(const char* data) {
    if (data == NULL) {
        return NULL;
    }
    ref_counted* obj = malloc(sizeof(ref_counted));
    if (obj == NULL) {
        return NULL;
    }
    size_t len = strlen(data);
    obj->data = malloc(len + 1);
    if (obj->data == NULL) {
        free(obj);
        return NULL;
    }
    memcpy(obj->data, data, len + 1);
    obj->ref_count = 1;
    obj->next = ref_list;
    ref_list = obj;
    return obj;
}

// retain_ref_counted adds one reference (NULL is ignored).
void retain_ref_counted(ref_counted* obj) {
    if (obj) {
        obj->ref_count++;
    }
}

// release_ref_counted drops one reference. When the count reaches exactly
// zero the object is unlinked from ref_list and freed.
void release_ref_counted(ref_counted* obj) {
    if (obj == NULL || --obj->ref_count != 0) {
        return;
    }
    // Walk the links themselves so removing the head needs no special case.
    for (ref_counted** link = &ref_list; *link != NULL; link = &(*link)->next) {
        if (*link == obj) {
            *link = obj->next;
            break;
        }
    }
    free(obj->data);
    free(obj);
}
*/
import "C"
import (
    "fmt"
    "runtime"
    "sync"
    "unsafe"
)

// ManagedBuffer is a Go wrapper around a C-allocated managed_buffer.
// The C memory belongs to this wrapper. Close releases it, and a finalizer
// does the same as a safety net if Close is never called. All access to the
// C struct is serialized through mutex.
type ManagedBuffer struct {
    cBuffer *C.managed_buffer
    mutex   sync.RWMutex
}

// NewManagedBuffer allocates a C buffer that can hold up to capacity bytes.
// It returns nil in three cases: capacity is negative, capacity does not fit
// in a C int, or the C allocation fails.
func NewManagedBuffer(capacity int) *ManagedBuffer {
    // Reject values that would silently wrap when converted to C.int.
    if capacity < 0 || int(C.int(capacity)) != capacity {
        return nil
    }
    cBuf := C.create_managed_buffer(C.int(capacity))
    if cBuf == nil {
        return nil
    }

    mb := &ManagedBuffer{cBuffer: cBuf}
    runtime.SetFinalizer(mb, (*ManagedBuffer).finalize)
    return mb
}

// Write replaces the buffer contents with a copy of data. An empty data slice
// empties the buffer. It returns an error if the buffer is closed or if data
// exceeds the capacity.
func (mb *ManagedBuffer) Write(data []byte) error {
    mb.mutex.Lock()
    defer mb.mutex.Unlock()

    if mb.cBuffer == nil {
        return fmt.Errorf("buffer is closed")
    }

    if len(data) > int(mb.cBuffer.capacity) {
        return fmt.Errorf("data too large")
    }

    // &data[0] panics on an empty slice, so only copy when there is data.
    if len(data) > 0 {
        C.memcpy(unsafe.Pointer(mb.cBuffer.data),
                 unsafe.Pointer(&data[0]),
                 C.size_t(len(data)))
    }
    mb.cBuffer.size = C.int(len(data))

    return nil
}

// Read returns a Go-owned copy of the buffer contents. It returns nil if the
// buffer is closed or empty.
func (mb *ManagedBuffer) Read() []byte {
    mb.mutex.RLock()
    defer mb.mutex.RUnlock()

    if mb.cBuffer == nil || mb.cBuffer.size == 0 {
        return nil
    }

    // GoBytes copies, so the result stays valid after Close.
    return C.GoBytes(unsafe.Pointer(mb.cBuffer.data), mb.cBuffer.size)
}

// Close frees the C memory and disarms the finalizer. It is safe to call
// more than once.
func (mb *ManagedBuffer) Close() {
    mb.mutex.Lock()
    defer mb.mutex.Unlock()

    if mb.cBuffer != nil {
        C.free_managed_buffer(mb.cBuffer)
        mb.cBuffer = nil
        runtime.SetFinalizer(mb, nil)
    }
}

// finalize is the GC fallback for callers that never call Close.
func (mb *ManagedBuffer) finalize() {
    mb.Close()
}

// refCountedMu serializes every call into the ref_counted C API. That API
// updates a plain int counter and a global linked list without any locking.
// Release can also run from a finalizer, which executes on a separate
// goroutine, so unsynchronized calls could race.
var refCountedMu sync.Mutex

// RefCounted is a Go handle that owns exactly one reference to a C
// ref_counted object. Each handle gives its reference back exactly once,
// either through Release or, as a fallback, through its finalizer.
type RefCounted struct {
    cObj *C.ref_counted
}

// NewRefCounted creates a C object holding a copy of data. The returned
// handle owns the object's initial reference. It returns nil if the C
// allocation fails.
func NewRefCounted(data string) *RefCounted {
    cData := C.CString(data)
    defer C.free(unsafe.Pointer(cData))

    refCountedMu.Lock()
    cObj := C.create_ref_counted(cData)
    refCountedMu.Unlock()
    if cObj == nil {
        return nil
    }

    rc := &RefCounted{cObj: cObj}
    runtime.SetFinalizer(rc, (*RefCounted).finalize)
    return rc
}

// Retain takes an additional reference and returns it as a new, independent
// handle. It returns nil if rc is nil or already released.
func (rc *RefCounted) Retain() *RefCounted {
    if rc == nil || rc.cObj == nil {
        return nil
    }

    refCountedMu.Lock()
    C.retain_ref_counted(rc.cObj)
    refCountedMu.Unlock()

    newRC := &RefCounted{cObj: rc.cObj}
    runtime.SetFinalizer(newRC, (*RefCounted).finalize)
    return newRC
}

// Data returns a Go copy of the object's string, or "" if rc is nil or
// released.
func (rc *RefCounted) Data() string {
    if rc == nil || rc.cObj == nil {
        return ""
    }
    s := C.GoString(rc.cObj.data)
    // rc's reference is what keeps the C object alive. Ensure rc cannot be
    // finalized (dropping that reference) before the copy above completes.
    runtime.KeepAlive(rc)
    return s
}

// Release gives back this handle's reference. The C object is freed when its
// last reference is released. Calling Release more than once is safe.
func (rc *RefCounted) Release() {
    if rc == nil || rc.cObj == nil {
        return
    }
    refCountedMu.Lock()
    C.release_ref_counted(rc.cObj)
    refCountedMu.Unlock()
    rc.cObj = nil
    runtime.SetFinalizer(rc, nil)
}

// finalize is the GC fallback for handles that were never released.
func (rc *RefCounted) finalize() {
    rc.Release()
}

// demonstrateMemoryOwnership walks through three ownership patterns:
//   - a string allocated and freed by C
//   - a C buffer owned by a Go wrapper
//   - a reference-counted C object shared by two Go handles
func demonstrateMemoryOwnership() {
    fmt.Println("=== Memory Ownership Demonstration ===")

    // 1. Allocated by C, freed by C. The C.CString argument is also C memory
    // and must be freed as well.
    cInput := C.CString("Hello, C!")
    defer C.free(unsafe.Pointer(cInput))
    cStr := C.create_c_string(cInput)
    goStr := C.GoString(cStr)
    fmt.Printf("C string: %s\n", goStr)
    C.free_c_string(cStr)

    // 2. C memory owned by a Go wrapper.
    if buf := NewManagedBuffer(1024); buf == nil {
        fmt.Println("Failed to allocate managed buffer")
    } else {
        defer buf.Close()
        if err := buf.Write([]byte("Hello, managed buffer!")); err != nil {
            fmt.Printf("Buffer write failed: %v\n", err)
        } else {
            fmt.Printf("Buffer data: %s\n", string(buf.Read()))
        }
    }

    // 3. Reference counting: two Go handles share one C object.
    rc1 := NewRefCounted("Reference counted data")
    if rc1 == nil {
        fmt.Println("Failed to create reference counted object")
        return
    }
    fmt.Printf("RC1 data: %s\n", rc1.Data())

    rc2 := rc1.Retain()
    fmt.Printf("RC2 data: %s\n", rc2.Data())

    rc1.Release()
    // rc2 still holds a reference, so the C object is still alive here.
    fmt.Printf("RC2 data after RC1 release: %s\n", rc2.Data())
    rc2.Release()
}

内存池和对象池 #

package main

/*
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define POOL_SIZE 100
#define OBJECT_SIZE 256

// One fixed-size slot. While free, `next` links it into the free list.
// `in_use` guards against double free and access to returned slots.
typedef struct pool_object {
    char data[OBJECT_SIZE];
    struct pool_object* next;
    int in_use;
} pool_object;

// Statically allocated pool of POOL_SIZE slots plus usage statistics.
typedef struct {
    pool_object objects[POOL_SIZE];
    pool_object* free_list;
    int allocated_count;
    int total_allocations;
    int total_deallocations;
} memory_pool;

// Single process-wide pool. None of the functions below lock; callers must
// serialize access (the Go ObjectPool wrapper uses a mutex for this).
static memory_pool global_pool = {0};
static int pool_initialized = 0;

// init_memory_pool threads every slot onto the free list. Only the first
// call does anything.
void init_memory_pool(void) {
    if (pool_initialized) return;

    for (int i = 0; i < POOL_SIZE - 1; i++) {
        global_pool.objects[i].next = &global_pool.objects[i + 1];
        global_pool.objects[i].in_use = 0;
    }
    global_pool.objects[POOL_SIZE - 1].next = NULL;
    global_pool.objects[POOL_SIZE - 1].in_use = 0;

    global_pool.free_list = &global_pool.objects[0];
    global_pool.allocated_count = 0;
    global_pool.total_allocations = 0;
    global_pool.total_deallocations = 0;

    pool_initialized = 1;
}

// pool_alloc takes a zeroed slot from the free list, or returns NULL when
// every slot is in use. It lazily initializes the pool.
pool_object* pool_alloc(void) {
    if (!pool_initialized) init_memory_pool();

    if (global_pool.free_list == NULL) {
        return NULL; // pool exhausted
    }

    pool_object* obj = global_pool.free_list;
    global_pool.free_list = obj->next;
    obj->next = NULL;
    obj->in_use = 1;

    global_pool.allocated_count++;
    global_pool.total_allocations++;

    memset(obj->data, 0, OBJECT_SIZE);
    return obj;
}

// pool_free returns a slot to the free list. NULL and already-free slots are
// ignored.
void pool_free(pool_object* obj) {
    if (!obj || !obj->in_use) return;

    obj->in_use = 0;
    obj->next = global_pool.free_list;
    global_pool.free_list = obj;

    global_pool.allocated_count--;
    global_pool.total_deallocations++;
}

// get_pool_stats reports usage counters. Any NULL output pointer is skipped.
void get_pool_stats(int* allocated, int* total_allocs, int* total_deallocs) {
    if (allocated) *allocated = global_pool.allocated_count;
    if (total_allocs) *total_allocs = global_pool.total_allocations;
    if (total_deallocs) *total_deallocs = global_pool.total_deallocations;
}

// set_object_data copies data into an in-use slot. The copy is truncated to
// OBJECT_SIZE - 1 bytes and always NUL-terminated. It is a no-op for a NULL
// or free slot, or for NULL data.
void set_object_data(pool_object* obj, const char* data) {
    if (obj == NULL || !obj->in_use || data == NULL) return;
    strncpy(obj->data, data, OBJECT_SIZE - 1);
    obj->data[OBJECT_SIZE - 1] = '\0';
}

// get_object_data returns the slot's buffer, or NULL for a NULL or free slot.
char* get_object_data(pool_object* obj) {
    if (obj && obj->in_use) {
        return obj->data;
    }
    return NULL;
}
*/
import "C"
import (
    "fmt"
    "runtime"
    "sync"
    "time"
    "unsafe"
)

// PoolObject is a Go handle to one slot borrowed from the C memory pool.
type PoolObject struct {
    cObj *C.pool_object // nil once the slot has been handed back
    pool *ObjectPool    // pool whose lock guards pool_alloc/pool_free
}

// ObjectPool serializes access to the process-wide C pool.
type ObjectPool struct {
    mutex sync.Mutex
}

var globalPool = &ObjectPool{}

// Get borrows a slot from the C pool. It returns nil when every slot is taken.
func (pool *ObjectPool) Get() *PoolObject {
    pool.mutex.Lock()
    slot := C.pool_alloc()
    pool.mutex.Unlock()

    if slot == nil {
        return nil
    }
    handle := &PoolObject{cObj: slot, pool: pool}
    runtime.SetFinalizer(handle, (*PoolObject).finalize)
    return handle
}

// SetData stores data in the slot; the C side truncates it to fit the slot.
func (obj *PoolObject) SetData(data string) {
    if obj.cObj == nil {
        return
    }
    cData := C.CString(data)
    C.set_object_data(obj.cObj, cData)
    C.free(unsafe.Pointer(cData))
}

// GetData returns a Go copy of the slot contents, or "" if the handle was
// released.
func (obj *PoolObject) GetData() string {
    if obj.cObj == nil {
        return ""
    }
    raw := C.get_object_data(obj.cObj)
    if raw == nil {
        return ""
    }
    return C.GoString(raw)
}

// Release hands the slot back to the pool. Later calls are no-ops.
func (obj *PoolObject) Release() {
    if obj.cObj == nil {
        return
    }
    obj.pool.mutex.Lock()
    C.pool_free(obj.cObj)
    obj.pool.mutex.Unlock()

    obj.cObj = nil
    runtime.SetFinalizer(obj, nil)
}

// finalize returns the slot if the handle is collected without Release.
func (obj *PoolObject) finalize() {
    obj.Release()
}

// GetStats reports three counters: slots currently in use, total
// allocations, and total deallocations.
func (pool *ObjectPool) GetStats() (allocated, totalAllocs, totalDeallocs int) {
    var live, allocs, frees C.int

    pool.mutex.Lock()
    C.get_pool_stats(&live, &allocs, &frees)
    pool.mutex.Unlock()

    allocated, totalAllocs, totalDeallocs = int(live), int(allocs), int(frees)
    return allocated, totalAllocs, totalDeallocs
}

func demonstrateObjectPool() {
    fmt.Println("=== Object Pool Demonstration ===")

    // 分配一些对象
    objects := make([]*PoolObject, 10)
    for i := 0; i < 10; i++ {
        obj := globalPool.Get()
        if obj != nil {
            obj.SetData(fmt.Sprintf("Object %d", i))
            objects[i] = obj
        }
    }

    allocated, totalAllocs, totalDeallocs := globalPool.GetStats()
    fmt.Printf("After allocation - Allocated: %d, Total allocs: %d, Total deallocs: %d\n",
               allocated, totalAllocs, totalDeallocs)

    // 释放一些对象
    for i := 0; i < 5; i++ {
        if objects[i] != nil {
            fmt.Printf("Releasing object with data: %s\n", objects[i].GetData())
            objects[i].Release()
            objects[i] = nil
        }
    }

    allocated, totalAllocs, totalDeallocs = globalPool.GetStats()
    fmt.Printf("After partial release - Allocated: %d, Total allocs: %d, Total deallocs: %d\n",
               allocated, totalAllocs, totalDeallocs)

    // 重新分配
    for i := 0; i < 5; i++ {
        obj := globalPool.Get()
        if obj != nil {
            obj.SetData(fmt.Sprintf("Reused Object %d", i))
            objects[i] = obj
        }
    }

    allocated, totalAllocs, totalDeallocs = globalPool.GetStats()
    fmt.Printf("After reallocation - Allocated: %d, Total allocs: %d, Total deallocs: %d\n",
               allocated, totalAllocs, totalDeallocs)

    // 清理剩余对象
    for i := 0; i < 10; i++ {
        if objects[i] != nil {
            objects[i].Release()
        }
    }

    allocated, totalAllocs, totalDeallocs = globalPool.GetStats()
    fmt.Printf("After cleanup - Allocated: %d, Total allocs: %d, Total deallocs: %d\n",
               allocated, totalAllocs, totalDeallocs)
}

高级类型转换 #

复杂数据结构转换 #

package main

/*
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// N-ary tree node. A parent owns its children: free_node_tree releases the
// whole subtree, including every name and children array.
typedef struct node {
    int id;
    char* name;
    struct node** children;
    int child_count;
    struct node* parent;
} node_t;

// create_node allocates a detached node holding a private copy of name.
// Returns NULL if name is NULL or an allocation fails; nothing leaks.
node_t* create_node(int id, const char* name) {
    if (name == NULL) {
        return NULL;
    }
    node_t* node = malloc(sizeof(node_t));
    if (node == NULL) {
        return NULL;
    }
    size_t len = strlen(name);
    node->name = malloc(len + 1);
    if (node->name == NULL) {
        free(node);
        return NULL;
    }
    memcpy(node->name, name, len + 1);
    node->id = id;
    node->children = NULL;
    node->child_count = 0;
    node->parent = NULL;
    return node;
}

// add_child appends child to parent's children and sets child->parent.
// On NULL arguments or allocation failure the tree is left unchanged.
void add_child(node_t* parent, node_t* child) {
    if (parent == NULL || child == NULL) {
        return;
    }
    // Grow through a temporary so the existing array survives a failed realloc.
    node_t** grown = realloc(parent->children,
                             (size_t)(parent->child_count + 1) * sizeof(node_t*));
    if (grown == NULL) {
        return;
    }
    parent->children = grown;
    parent->children[parent->child_count] = child;
    parent->child_count++;
    child->parent = parent;
}

// free_node_tree recursively frees node, its descendants, their names and
// children arrays. NULL is ignored.
void free_node_tree(node_t* node) {
    if (!node) return;

    for (int i = 0; i < node->child_count; i++) {
        free_node_tree(node->children[i]);
    }

    free(node->children);
    free(node->name);
    free(node);
}

// print_node_tree prints the subtree to stdout, indenting two spaces per level.
void print_node_tree(node_t* node, int depth) {
    if (!node) return;

    for (int i = 0; i < depth; i++) {
        printf("  ");
    }
    printf("Node %d: %s\n", node->id, node->name);

    for (int i = 0; i < node->child_count; i++) {
        print_node_tree(node->children[i], depth + 1);
    }
}

// A 4-byte value viewed as an int, a float, or up to 4 raw chars. The chars
// are not NUL-terminated when all 4 are used.
typedef union {
    int i;
    float f;
    char c[4];
} value_union;

// Tagged value: `type` selects the active member of `value`.
typedef struct {
    int type; // 0=int, 1=float, 2=string
    value_union value;
} variant_t;

// create_int_variant builds a variant tagged 0 holding value.
variant_t create_int_variant(int value) {
    variant_t out = { .type = 0 };
    out.value.i = value;
    return out;
}

// create_float_variant builds a variant tagged 1 holding value.
variant_t create_float_variant(float value) {
    variant_t out = { .type = 1 };
    out.value.f = value;
    return out;
}

// create_string_variant builds a variant tagged 2 holding the first 4 chars
// of value. strncpy zero-pads shorter strings.
variant_t create_string_variant(const char* value) {
    variant_t out = { .type = 2 };
    strncpy(out.value.c, value, sizeof out.value.c);
    return out;
}

// print_variant prints the active member. Unknown tags print nothing.
void print_variant(variant_t* v) {
    if (v->type == 0) {
        printf("Int: %d\n", v->value.i);
    } else if (v->type == 1) {
        printf("Float: %f\n", v->value.f);
    } else if (v->type == 2) {
        printf("String: %.4s\n", v->value.c);
    }
}
*/
import "C"
import (
    "fmt"
    "runtime"
    "unsafe"
)

// Node wraps a C node_t.
//
// Ownership:
//   - A detached node owns its own C memory.
//   - Once a node is attached with AddChild, the root of its tree owns it.
//     Closing (or finalizing) the root frees the whole C tree with
//     free_node_tree and invalidates every wrapper in it.
//   - Calling Close on a non-root node is a no-op.
//
// NOTE: the children and parent pointers form a reference cycle. According
// to the runtime.SetFinalizer documentation, a finalizer inside such a cycle
// is not guaranteed to run, so always Close the root explicitly.
type Node struct {
    cNode    *C.node_t
    children []*Node
    parent   *Node
}

// NewNode creates a detached node. It returns nil if the C allocation fails.
func NewNode(id int, name string) *Node {
    cName := C.CString(name)
    defer C.free(unsafe.Pointer(cName))

    cNode := C.create_node(C.int(id), cName)
    if cNode == nil {
        return nil
    }

    node := &Node{cNode: cNode}
    runtime.SetFinalizer(node, (*Node).finalize)
    return node
}

// AddChild attaches child under n and transfers ownership of child's C memory
// to n's tree. It is a no-op in any of these cases:
//   - either node is nil or closed
//   - child already has a parent
//   - attaching would create a cycle
func (n *Node) AddChild(child *Node) {
    if n == nil || child == nil || n.cNode == nil || child.cNode == nil {
        return
    }
    // A node listed under two parents would be freed twice by free_node_tree.
    if child.parent != nil {
        return
    }
    // Attaching n itself or one of its ancestors would make the C tree cyclic.
    for p := n; p != nil; p = p.parent {
        if p == child {
            return
        }
    }

    C.add_child(n.cNode, child.cNode)
    n.children = append(n.children, child)
    child.parent = n
    // The root now frees child's C memory, so child must not free it again
    // from its own finalizer.
    runtime.SetFinalizer(child, nil)
}

// ID returns the node id, or 0 once the node has been freed.
func (n *Node) ID() int {
    if n.cNode != nil {
        return int(n.cNode.id)
    }
    return 0
}

// Name returns a Go copy of the node name, or "" once the node has been freed.
func (n *Node) Name() string {
    if n.cNode != nil && n.cNode.name != nil {
        return C.GoString(n.cNode.name)
    }
    return ""
}

// Children returns the Go wrappers of the direct children.
func (n *Node) Children() []*Node {
    return n.children
}

// Parent returns the Go wrapper of the parent, or nil for a root.
func (n *Node) Parent() *Node {
    return n.parent
}

// Print prints the subtree rooted at n via the C helper.
func (n *Node) Print() {
    if n.cNode != nil {
        C.print_node_tree(n.cNode, 0)
    }
}

// Close frees the C memory of the whole tree rooted at n and invalidates
// every wrapper in it. It is a no-op for an already-closed node, and for a
// node that has a parent, because that node's memory belongs to its root.
func (n *Node) Close() {
    if n.cNode == nil || n.parent != nil {
        return
    }
    C.free_node_tree(n.cNode)
    n.invalidate()
    runtime.SetFinalizer(n, nil)
}

// invalidate clears the C pointer and Go links of n and all its descendants
// after the C tree has been freed. This stops their accessors or finalizers
// from touching freed memory, and it breaks the parent/child reference cycles.
func (n *Node) invalidate() {
    for _, c := range n.children {
        c.invalidate()
    }
    n.cNode = nil
    n.children = nil
    n.parent = nil
}

// finalize is the GC fallback for a root that was never closed.
func (n *Node) finalize() {
    n.Close()
}

// Variant wraps a C variant_t by value.
//
// cgo represents C unions as Go byte arrays of the same size, so
// cVariant.value is a [4]byte. The active member therefore has to be
// reinterpreted through unsafe.Pointer; accessing it as .i, .f or .c does
// not compile.
type Variant struct {
    cVariant C.variant_t
}

// NewIntVariant builds a variant holding an int (tag 0).
func NewIntVariant(value int) *Variant {
    return &Variant{cVariant: C.create_int_variant(C.int(value))}
}

// NewFloatVariant builds a variant holding a float (tag 1).
func NewFloatVariant(value float32) *Variant {
    return &Variant{cVariant: C.create_float_variant(C.float(value))}
}

// NewStringVariant builds a variant holding at most the first 4 bytes of
// value (tag 2).
func NewStringVariant(value string) *Variant {
    cValue := C.CString(value)
    defer C.free(unsafe.Pointer(cValue))
    return &Variant{cVariant: C.create_string_variant(cValue)}
}

// Type returns the C tag: 0 = int, 1 = float, 2 = string.
func (v *Variant) Type() int {
    // `type` is a Go keyword, so cgo exposes the C field as _type.
    return int(v.cVariant._type)
}

// IntValue returns the int payload, or 0 if the variant is not an int.
func (v *Variant) IntValue() int {
    if v.cVariant._type != 0 {
        return 0
    }
    return int(*(*C.int)(unsafe.Pointer(&v.cVariant.value)))
}

// FloatValue returns the float payload, or 0 if the variant is not a float.
func (v *Variant) FloatValue() float32 {
    if v.cVariant._type != 1 {
        return 0
    }
    return float32(*(*C.float)(unsafe.Pointer(&v.cVariant.value)))
}

// StringValue returns the string payload, up to the first NUL byte. It
// returns "" if the variant is not a string. The C char array is not
// NUL-terminated when all 4 bytes are used, so the scan is bounded by the
// array length.
func (v *Variant) StringValue() string {
    if v.cVariant._type != 2 {
        return ""
    }
    raw := v.cVariant.value
    n := 0
    for n < len(raw) && raw[n] != 0 {
        n++
    }
    return string(raw[:n])
}

// Print prints the variant through the C helper.
func (v *Variant) Print() {
    C.print_variant(&v.cVariant)
}

// demonstrateComplexTypes builds a small C-backed tree, prints it, then shows
// the three variant kinds, both through Go accessors and through the C printer.
func demonstrateComplexTypes() {
    fmt.Println("=== Complex Type Conversion Demonstration ===")

    // Tree shape: Root -> {Child 1 -> {Grandchild}, Child 2}.
    root := NewNode(1, "Root")
    defer root.Close()

    child1 := NewNode(2, "Child 1")
    child2 := NewNode(3, "Child 2")
    grandchild := NewNode(4, "Grandchild")

    edges := []struct{ parent, child *Node }{
        {root, child1},
        {root, child2},
        {child1, grandchild},
    }
    for _, e := range edges {
        e.parent.AddChild(e.child)
    }

    fmt.Println("Node tree:")
    root.Print()

    fmt.Println("\nVariant types:")

    intVar, floatVar, stringVar := NewIntVariant(42), NewFloatVariant(3.14), NewStringVariant("Hi!")

    fmt.Printf("Int variant (type %d): %d\n", intVar.Type(), intVar.IntValue())
    fmt.Printf("Float variant (type %d): %f\n", floatVar.Type(), floatVar.FloatValue())
    fmt.Printf("String variant (type %d): %s\n", stringVar.Type(), stringVar.StringValue())

    fmt.Println("\nC print variants:")
    for _, v := range []*Variant{intVar, floatVar, stringVar} {
        v.Print()
    }
}

性能优化技巧 #

批量操作优化 #

package main

/*
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

// eval_poly computes value*value + value + 1 using unsigned arithmetic.
// Signed int overflow is undefined behavior; unsigned overflow wraps. Converting
// the wrapped result back to int is implementation-defined (GCC and Clang
// reduce it modulo 2^32). For inputs that do not overflow, the result is
// the ordinary signed result.
static inline int eval_poly(int value) {
    unsigned int u = (unsigned int)value;
    return (int)(u * u + u + 1u);
}

// process_single evaluates the polynomial for one value (one cgo call each).
int process_single(int value) {
    return eval_poly(value);
}

// process_batch evaluates the polynomial for count values in one call.
void process_batch(int* input, int* output, int count) {
    for (int i = 0; i < count; i++) {
        output[i] = eval_poly(input[i]);
    }
}

// process_batch_optimized computes the same results as process_batch. The
// main loop is unrolled by 4 (a stand-in for real SIMD), and a tail loop
// handles the remaining elements.
void process_batch_optimized(int* input, int* output, int count) {
    int i;
    for (i = 0; i < count - 3; i += 4) {
        output[i]     = eval_poly(input[i]);
        output[i + 1] = eval_poly(input[i + 1]);
        output[i + 2] = eval_poly(input[i + 2]);
        output[i + 3] = eval_poly(input[i + 3]);
    }
    // Remaining 0-3 elements.
    for (; i < count; i++) {
        output[i] = eval_poly(input[i]);
    }
}

// process_strings_batch stores an ASCII-uppercased heap copy of inputs[i] in
// outputs[i] for each i < count. Release the copies with free_string_array.
// If an input is NULL or an allocation fails, outputs[i] is set to NULL;
// freeing NULL is harmless.
void process_strings_batch(char** inputs, char** outputs, int count) {
    for (int i = 0; i < count; i++) {
        outputs[i] = NULL;
        if (inputs[i] == NULL) {
            continue;
        }
        size_t len = strlen(inputs[i]);
        char* out = malloc(len + 1);
        if (out == NULL) {
            continue;
        }
        for (size_t j = 0; j < len; j++) {
            char ch = inputs[i][j];
            out[j] = (ch >= 'a' && ch <= 'z') ? (char)(ch - 'a' + 'A') : ch;
        }
        out[len] = '\0';
        outputs[i] = out;
    }
}

// free_string_array frees the first count strings and sets each slot to NULL
// so an accidental second call cannot double free.
void free_string_array(char** strings, int count) {
    for (int i = 0; i < count; i++) {
        free(strings[i]);
        strings[i] = NULL;
    }
}

// Growable int array, with capacity reserved ahead of time.
#include <limits.h>

typedef struct {
    int* buffer;
    int capacity;
    int size;
} int_vector;

// create_int_vector returns an empty vector with room for initial_capacity
// ints. Returns NULL for a negative capacity or on allocation failure.
// A capacity of 0 starts with no buffer.
int_vector* create_int_vector(int initial_capacity) {
    if (initial_capacity < 0) {
        return NULL;
    }
    int_vector* vec = malloc(sizeof(int_vector));
    if (vec == NULL) {
        return NULL;
    }
    vec->buffer = NULL;
    vec->capacity = 0;
    vec->size = 0;
    if (initial_capacity > 0) {
        vec->buffer = malloc((size_t)initial_capacity * sizeof(int));
        if (vec->buffer == NULL) {
            free(vec);
            return NULL;
        }
        vec->capacity = initial_capacity;
    }
    return vec;
}

// int_vector_push_batch appends count values. When space runs out, it grows
// to twice the required size (clamped to INT_MAX). The vector is left
// unchanged in three cases: invalid arguments, size overflow, or a failed
// allocation.
void int_vector_push_batch(int_vector* vec, int* values, int count) {
    if (vec == NULL || values == NULL || count <= 0) {
        return;
    }
    if (count > INT_MAX - vec->size) {
        return; // size would overflow int
    }
    int needed = vec->size + count;

    if (needed > vec->capacity) {
        int new_capacity = needed > INT_MAX / 2 ? INT_MAX : needed * 2;
        // Use a temporary so the old buffer is not lost if realloc fails.
        int* grown = realloc(vec->buffer, (size_t)new_capacity * sizeof(int));
        if (grown == NULL) {
            return;
        }
        vec->buffer = grown;
        vec->capacity = new_capacity;
    }

    memcpy(vec->buffer + vec->size, values, (size_t)count * sizeof(int));
    vec->size = needed;
}

// free_int_vector releases the buffer and the vector (NULL is allowed).
void free_int_vector(int_vector* vec) {
    if (vec) {
        free(vec->buffer);
        free(vec);
    }
}
*/
import "C"
import (
    "fmt"
    "runtime"
    "time"
    "unsafe"
)

// benchmarkSingleVsBatch times three ways of processing the same input:
//   - one cgo call per element
//   - a single batched call
//   - a single batched call with the unrolled loop
// It then checks that all three produce identical results.
//
// The buffers are []C.int rather than []int. Go's int is typically 64 bits
// while C's int is typically 32 bits, so reinterpreting &[]int[0] as *C.int
// would make C read and write the wrong elements.
func benchmarkSingleVsBatch() {
    fmt.Println("=== Performance: Single vs Batch Operations ===")

    const count = 1000000
    // Keep values below 10000 so value*value + value + 1 fits in a 32-bit int.
    input := make([]C.int, count)
    for i := range input {
        input[i] = C.int(i % 10000)
    }

    // One cgo call per element.
    start := time.Now()
    output1 := make([]C.int, count)
    for i, v := range input {
        output1[i] = C.process_single(v)
    }
    singleDuration := time.Since(start)

    // One cgo call for the whole slice.
    start = time.Now()
    output2 := make([]C.int, count)
    C.process_batch(&input[0], &output2[0], C.int(count))
    batchDuration := time.Since(start)

    // Same, using the unrolled C loop.
    start = time.Now()
    output3 := make([]C.int, count)
    C.process_batch_optimized(&input[0], &output3[0], C.int(count))
    optimizedDuration := time.Since(start)

    fmt.Printf("Single operations: %v\n", singleDuration)
    fmt.Printf("Batch operations: %v (%.2fx faster)\n",
               batchDuration, float64(singleDuration)/float64(batchDuration))
    fmt.Printf("Optimized batch: %v (%.2fx faster than single)\n",
               optimizedDuration, float64(singleDuration)/float64(optimizedDuration))

    // Check every element; this is cheap next to the benchmark itself.
    consistent := true
    for i := range output1 {
        if output1[i] != output2[i] || output2[i] != output3[i] {
            consistent = false
            break
        }
    }
    fmt.Printf("Results consistent: %v\n", consistent)
}

// demonstrateStringBatch uppercases several strings with one C call and
// prints the results. It frees both the C copies of the inputs and the
// C-allocated outputs.
func demonstrateStringBatch() {
    fmt.Println("=== String Batch Processing ===")

    inputs := []string{"hello", "world", "golang", "cgo", "optimization"}
    n := len(inputs)

    // C-owned copies of the inputs, released when the function returns.
    cInputs := make([]*C.char, 0, n)
    for _, s := range inputs {
        cInputs = append(cInputs, C.CString(s))
    }
    defer func() {
        for i := range cInputs {
            C.free(unsafe.Pointer(cInputs[i]))
        }
    }()

    // Slots that the C side fills with malloc'd strings.
    cOutputs := make([]*C.char, n)
    C.process_strings_batch(&cInputs[0], &cOutputs[0], C.int(n))

    // Copy the results into Go memory before releasing the C strings.
    outputs := make([]string, 0, n)
    for _, p := range cOutputs {
        outputs = append(outputs, C.GoString(p))
    }
    C.free_string_array(&cOutputs[0], C.int(n))

    fmt.Println("String processing results:")
    for i := range outputs {
        fmt.Printf("  %s -> %s\n", inputs[i], outputs[i])
    }
}

// IntVector wraps a C int_vector, a growable array of C ints. The C memory
// is freed by Close, with a finalizer as a fallback.
type IntVector struct {
    cVec *C.int_vector
}

// NewIntVector creates a vector with room for capacity ints. It returns nil
// if the C allocation fails.
func NewIntVector(capacity int) *IntVector {
    cVec := C.create_int_vector(C.int(capacity))
    if cVec == nil {
        return nil
    }

    iv := &IntVector{cVec: cVec}
    runtime.SetFinalizer(iv, (*IntVector).finalize)
    return iv
}

// PushBatch appends values to the vector. Each value is converted to a C int
// (typically 32 bits), so values outside that range are truncated.
func (iv *IntVector) PushBatch(values []int) {
    if iv.cVec == nil || len(values) == 0 {
        return
    }
    // Go int and C int usually differ in size, so []int cannot be
    // reinterpreted as a C int array; build a []C.int copy instead.
    cValues := make([]C.int, len(values))
    for i, v := range values {
        cValues[i] = C.int(v)
    }
    C.int_vector_push_batch(iv.cVec, &cValues[0], C.int(len(cValues)))
    // Keep iv (and thus its C memory) from being finalized during the call.
    runtime.KeepAlive(iv)
}

// Size returns the number of stored elements, or 0 after Close.
func (iv *IntVector) Size() int {
    if iv.cVec != nil {
        return int(iv.cVec.size)
    }
    return 0
}

// Capacity returns the allocated element count, or 0 after Close.
func (iv *IntVector) Capacity() int {
    if iv.cVec != nil {
        return int(iv.cVec.capacity)
    }
    return 0
}

// ToSlice returns a Go copy of the elements, or nil if the vector is empty or
// closed.
func (iv *IntVector) ToSlice() []int {
    if iv.cVec == nil || iv.cVec.size == 0 {
        return nil
    }

    size := int(iv.cVec.size)
    // View the C buffer as a Go slice of C ints without copying.
    cArray := (*[1 << 30]C.int)(unsafe.Pointer(iv.cVec.buffer))[:size:size]

    result := make([]int, size)
    for i, v := range cArray {
        result[i] = int(v)
    }
    // cArray points into C memory owned by iv; keep iv reachable until the
    // copy is complete.
    runtime.KeepAlive(iv)
    return result
}

// Close frees the C vector. It is safe to call more than once.
func (iv *IntVector) Close() {
    if iv.cVec != nil {
        C.free_int_vector(iv.cVec)
        iv.cVec = nil
        runtime.SetFinalizer(iv, nil)
    }
}

// finalize is the GC fallback for vectors that were never closed.
func (iv *IntVector) finalize() {
    iv.Close()
}

// demonstrateIntVector pushes two batches into a C-backed vector, printing
// its size and capacity after each step and its contents at the end.
func demonstrateIntVector() {
    fmt.Println("=== Int Vector Demonstration ===")

    vec := NewIntVector(10)
    defer vec.Close()

    report := func(label string) {
        fmt.Printf("%s - Size: %d, Capacity: %d\n", label, vec.Size(), vec.Capacity())
    }

    report("Initial")

    // The second batch exceeds the initial capacity and forces a regrow.
    vec.PushBatch([]int{1, 2, 3, 4, 5})
    report("After first batch")

    vec.PushBatch([]int{6, 7, 8, 9, 10, 11, 12, 13, 14, 15})
    report("After second batch")

    fmt.Printf("Vector data: %v\n", vec.ToSlice())
}

常见陷阱和最佳实践 #

内存泄漏陷阱 #

package main

/*
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// Trap 1 (caller side): forgetting to free C.CString memory. The result of
// this function is malloc'd and must also be freed by the caller.
// Returns "Processed: <input>", or NULL if input is NULL or allocation fails.
char* process_string_bad(char* input) {
    static const char prefix[] = "Processed: ";
    if (input == NULL) {
        return NULL;
    }
    // sizeof(prefix) - 1 excludes the prefix's NUL; the final + 1 adds one NUL.
    size_t size = (sizeof(prefix) - 1) + strlen(input) + 1;
    char* result = malloc(size);
    if (result == NULL) {
        return NULL;
    }
    snprintf(result, size, "%s%s", prefix, input);
    return result;
}

// Trap 2 (double free): global_buffer always holds either NULL or memory owned
// by this module, and it is reset after every free.
static char* global_buffer = NULL;

// set_global_buffer replaces global_buffer with a private copy of data.
// NULL data clears the buffer. If the allocation fails, the previous
// contents are kept.
void set_global_buffer(char* data) {
    char* copy = NULL;
    if (data != NULL) {
        size_t len = strlen(data);
        copy = malloc(len + 1);
        if (copy == NULL) {
            return;
        }
        memcpy(copy, data, len + 1);
    }
    // Copy first, then free, so data may safely alias global_buffer.
    free(global_buffer);
    global_buffer = copy;
}

// get_global_buffer returns the current buffer (owned by this module) or NULL.
char* get_global_buffer() {
    return global_buffer;
}

// clear_global_buffer frees the buffer and resets it; repeated calls are safe.
void clear_global_buffer() {
    if (global_buffer) {
        free(global_buffer);
        global_buffer = NULL;
    }
}

// Trap 3: returning stack memory. DELIBERATE counter-example; do not call it.
// buffer is a local array, so the returned pointer refers to storage that
// no longer exists once the function returns.
char* get_stack_string_bad() {
    char buffer[100];
    strcpy(buffer, "This is dangerous!");
    return buffer; // returns the address of stack memory!
}

// The correct approach: return heap memory that the caller releases with free().
// Returns a 100-byte buffer holding "This is safe!", or NULL if the
// allocation fails.
char* get_heap_string_good() {
    char* buffer = malloc(100);
    if (buffer == NULL) {
        return NULL;
    }
    strcpy(buffer, "This is safe!");
    return buffer;
}

// Trap 4: buffer overflow. DELIBERATE counter-example. strcpy copies all of
// src plus its NUL without knowing dest's size, so a longer src overruns dest.
// safe_copy below is the bounded alternative.
void unsafe_copy(char* dest, char* src) {
    strcpy(dest, src); // no check of the destination buffer size
}

// safe_copy copies at most dest_size - 1 bytes of src into dest and always
// NUL-terminates. It does nothing for a NULL dest or a non-positive
// dest_size. A NULL src produces an empty string.
void safe_copy(char* dest, char* src, int dest_size) {
    if (dest == NULL || dest_size <= 0) {
        return; // there is no room even for the terminator
    }
    if (src == NULL) {
        dest[0] = '\0';
        return;
    }
    strncpy(dest, src, (size_t)dest_size - 1);
    dest[dest_size - 1] = '\0';
}

// Trap 5: race condition. global_counter is shared, unsynchronized state.
// Concurrent callers must serialize access themselves; the Go demo uses a
// mutex for this.
static int global_counter = 0;

// increment_counter_unsafe adds one and returns the new value. The
// read-modify-write is not atomic.
int increment_counter_unsafe() {
    global_counter += 1;
    return global_counter;
}

// get_counter returns the current value (also unsynchronized).
int get_counter() {
    int current = global_counter;
    return current;
}

// reset_counter sets the counter back to zero.
void reset_counter() {
    global_counter = 0;
}
*/
import "C"
import (
    "fmt"
    "runtime"
    "sync"
    "unsafe"
)

// demonstrateMemoryLeaks first contrasts a deliberately leaking call sequence
// with a leak-free one around process_string_bad. It then replaces and clears
// the C global buffer.
func demonstrateMemoryLeaks() {
    fmt.Println("=== Memory Leak Traps ===")

    // Trap 1: forgetting to free C.CString
    fmt.Println("1. C.CString memory leak trap:")

    // Wrong approach (leaks memory on purpose, as a counter-example)
    badExample := func() {
        input := "test string"
        cInput := C.CString(input)
        // Missing: C.free(unsafe.Pointer(cInput))

        result := C.process_string_bad(cInput)
        goResult := C.GoString(result)
        fmt.Printf("  Result: %s\n", goResult)

        // result is never freed either.
        // That makes two memory leaks here!
        _ = goResult
    }

    // Correct approach
    goodExample := func() {
        input := "test string"
        cInput := C.CString(input)
        defer C.free(unsafe.Pointer(cInput)) // freed correctly

        result := C.process_string_bad(cInput)
        defer C.free(unsafe.Pointer(result)) // freed correctly

        goResult := C.GoString(result)
        fmt.Printf("  Result (correct): %s\n", goResult)
    }

    fmt.Println("  Bad example (leaks memory):")
    badExample()

    fmt.Println("  Good example (no leaks):")
    goodExample()

    // Trap 2: managing a global C buffer. set_global_buffer stores its own
    // malloc'd copy, so data1/data2 can be freed independently by the defers.
    fmt.Println("\n2. Global buffer management:")

    data1 := C.CString("First data")
    defer C.free(unsafe.Pointer(data1))
    C.set_global_buffer(data1)

    buffer := C.get_global_buffer()
    fmt.Printf("  Global buffer: %s\n", C.GoString(buffer))

    data2 := C.CString("Second data")
    defer C.free(unsafe.Pointer(data2))
    C.set_global_buffer(data2) // frees the previous buffer correctly

    buffer = C.get_global_buffer()
    fmt.Printf("  Global buffer: %s\n", C.GoString(buffer))

    C.clear_global_buffer() // cleanup
}

// demonstrateBufferOverflow copies a long C string into a 20-byte Go buffer
// with the bounded safe_copy helper and prints the truncated result.
func demonstrateBufferOverflow() {
    fmt.Println("=== Buffer Overflow Traps ===")

    src := C.CString("This is a long string that might overflow")
    defer C.free(unsafe.Pointer(src))

    // Go-owned destination. safe_copy writes at most len(dest) bytes,
    // including the NUL.
    var dest [20]byte
    destPtr := (*C.char)(unsafe.Pointer(&dest[0]))
    C.safe_copy(destPtr, src, C.int(len(dest)))

    fmt.Printf("Safely copied string: %s\n", C.GoString(destPtr))
}

// demonstrateRaceCondition has 10 goroutines each increment the
// unsynchronized C counter 100 times, serializing every call with a Go
// mutex. The final value is therefore exactly 1000.
func demonstrateRaceCondition() {
    fmt.Println("=== Race Condition Traps ===")

    const (
        workers    = 10
        iterations = 100
    )
    var (
        mu sync.Mutex
        wg sync.WaitGroup
    )

    C.reset_counter()

    // The C counter has no locking of its own, so every access goes through mu.
    increment := func() int {
        mu.Lock()
        defer mu.Unlock()
        return int(C.increment_counter_unsafe())
    }

    wg.Add(workers)
    for w := 0; w < workers; w++ {
        go func(id int) {
            defer wg.Done()
            var last int
            for j := 0; j < iterations; j++ {
                last = increment()
            }
            fmt.Printf("  Goroutine %d final count: %d\n", id, last)
        }(w)
    }
    wg.Wait()

    fmt.Printf("Final counter value: %d (should be 1000)\n", int(C.get_counter()))
}

// SafeStringProcessor wraps process_string_bad with three safeguards: input
// validation, a NULL-result check, and guaranteed release of both C
// allocations. Calls on the same processor are serialized.
type SafeStringProcessor struct {
    mutex sync.Mutex
}

// ProcessString returns "Processed: " + input. It returns an error for empty
// input or if the C call yields NULL.
func (ssp *SafeStringProcessor) ProcessString(input string) (string, error) {
    ssp.mutex.Lock()
    defer ssp.mutex.Unlock()

    if input == "" {
        return "", fmt.Errorf("empty input")
    }

    cInput := C.CString(input)
    defer C.free(unsafe.Pointer(cInput))

    cResult := C.process_string_bad(cInput)
    if cResult == nil {
        return "", fmt.Errorf("C function returned NULL")
    }

    // Copy into Go memory, then release the malloc'd C result.
    out := C.GoString(cResult)
    C.free(unsafe.Pointer(cResult))
    return out, nil
}

// ResourceManager collects C strings so they can all be freed together.
//
// Pointers returned by AllocateString stay valid only until Close runs. A
// finalizer also calls Close once the manager becomes unreachable, so keep
// the manager alive while those pointers are in use (for example with
// `defer rm.Close()` or runtime.KeepAlive).
type ResourceManager struct {
    resources []unsafe.Pointer
    mutex     sync.Mutex
}

// NewResourceManager returns an empty manager. Its finalizer frees any
// strings still tracked when the manager is garbage collected.
func NewResourceManager() *ResourceManager {
    rm := &ResourceManager{}
    runtime.SetFinalizer(rm, (*ResourceManager).finalize)
    return rm
}

// AllocateString copies s into C memory tracked by rm and returns it.
func (rm *ResourceManager) AllocateString(s string) *C.char {
    rm.mutex.Lock()
    defer rm.mutex.Unlock()

    cStr := C.CString(s)
    rm.resources = append(rm.resources, unsafe.Pointer(cStr))
    return cStr
}

// Close frees every string allocated so far. It is idempotent, and rm stays
// usable afterwards. The finalizer is deliberately left installed so that
// strings allocated after Close are still reclaimed even if Close is not
// called again.
func (rm *ResourceManager) Close() {
    rm.mutex.Lock()
    defer rm.mutex.Unlock()

    for _, ptr := range rm.resources {
        C.free(ptr)
    }
    // Drop the backing array too, so freed pointers are not kept around.
    rm.resources = nil
}

// finalize frees whatever is still tracked when rm is collected.
func (rm *ResourceManager) finalize() {
    rm.Close()
}

// demonstrateBestPractices shows two patterns: a validated, leak-free C call
// through SafeStringProcessor, and bulk cleanup of C strings through
// ResourceManager.
func demonstrateBestPractices() {
    fmt.Println("=== Best Practices ===")

    processor := &SafeStringProcessor{}
    if result, err := processor.ProcessString("Hello, CGO!"); err != nil {
        fmt.Printf("Error: %v\n", err)
    } else {
        fmt.Printf("Processed: %s\n", result)
    }

    // Every string below is freed by rm.Close when this function returns.
    rm := NewResourceManager()
    defer rm.Close()

    managed := []*C.char{
        rm.AllocateString("Resource 1"),
        rm.AllocateString("Resource 2"),
    }
    for i, s := range managed {
        fmt.Printf("Managed string %d: %s\n", i+1, C.GoString(s))
    }
}

// main runs every demonstration in order.
func main() {
    demonstrateMemoryOwnership()
    demonstrateObjectPool()
    demonstrateComplexTypes()
    benchmarkSingleVsBatch()
    demonstrateStringBatch()
    demonstrateIntVector()
    demonstrateMemoryLeaks()
    demonstrateBufferOverflow()
    demonstrateRaceCondition()
    demonstrateBestPractices()

    // Run garbage collection so finalizers of unreachable wrappers get queued.
    // NOTE: the runtime.SetFinalizer docs say finalizers run on a separate
    // goroutine, so some may not have executed before main returns and the
    // process exits. Explicit Close/Release calls, not these GC calls, are what
    // guarantee cleanup.
    runtime.GC()
    runtime.GC()
}

小结 #

本节深入探讨了 CGO 的高级特性和常见陷阱,包括:

  1. 高级内存管理:内存所有权模型、对象池、引用计数等
  2. 复杂类型转换:嵌套结构体、联合体、变体类型的处理
  3. 性能优化:批量操作(减少跨越 CGO 边界的调用次数)、内存预分配、循环展开(为 SIMD 优化做铺垫)等技巧
  4. 常见陷阱:内存泄漏、缓冲区溢出、竞态条件等问题的避免
  5. 最佳实践:安全的资源管理、错误处理、并发控制等

掌握这些高级特性和避免常见陷阱对于编写稳定、高效的 CGO 代码至关重要。在下一节中,我们将学习如何设计灵活的插件架构。