4.7.1 CGO 基础与使用

4.7.1 CGO 基础与使用 #

CGO 是 Go 语言的一个重要特性,它提供了 Go 程序与 C 代码互操作的能力。通过 CGO,我们可以在 Go 程序中调用 C 函数、使用 C 库,也可以让 C 代码调用 Go 函数。本节将详细介绍 CGO 的基础知识和使用方法。

CGO 概述 #

什么是 CGO #

CGO 是 Go 语言提供的一个工具,它允许 Go 包调用 C 代码。CGO 不是一个单独的编译器,而是 Go 工具链的一部分,它会在编译时处理 Go 代码中的 C 代码片段。

CGO 的优势和劣势 #

优势:

  • 可以重用现有的 C 库
  • 访问系统底层 API
  • 在性能关键的场景中使用优化的 C 代码
  • 与其他语言编写的库进行集成

劣势:

  • 增加了编译复杂性
  • 破坏了 Go 的跨平台特性
  • 可能引入内存安全问题
  • 调试变得更加困难
  • 性能开销(函数调用边界)

CGO 基本语法 #

启用 CGO #

要在 Go 代码中使用 CGO,需要导入特殊的 "C" 包:

package main

import "C"

func main() {
    // CGO 代码
}

嵌入 C 代码 #

可以通过注释的方式在 Go 文件中嵌入 C 代码:

package main

/*
#include <stdio.h>
#include <stdlib.h>

void hello() {
    printf("Hello from C!\n");
}

int add(int a, int b) {
    return a + b;
}
*/
import "C"
import "fmt"

func main() {
    // 调用 C 函数
    C.hello()

    result := C.add(10, 20)
    fmt.Printf("10 + 20 = %d\n", int(result))
}

外部 C 文件 #

也可以将 C 代码放在单独的文件中:

math.h

#ifndef MATH_H
#define MATH_H

int multiply(int a, int b);
double divide(double a, double b);

#endif

math.c

#include "math.h"

int multiply(int a, int b) {
    return a * b;
}

double divide(double a, double b) {
    if (b == 0) {
        return 0;
    }
    return a / b;
}

main.go

package main

/*
#include "math.h"
*/
import "C"
import "fmt"

func main() {
    result1 := C.multiply(5, 6)
    fmt.Printf("5 * 6 = %d\n", int(result1))

    result2 := C.divide(10.0, 3.0)
    fmt.Printf("10.0 / 3.0 = %f\n", float64(result2))
}

数据类型转换 #

基本类型映射 #

CGO 提供了 Go 和 C 类型之间的自动映射:

package main

/*
#include <stdint.h>

// C 类型示例
char get_char() { return 'A'; }
int get_int() { return 42; }
long get_long() { return 1234567890L; }
float get_float() { return 3.14f; }
double get_double() { return 2.718281828; }
*/
import "C"
import "fmt"

func main() {
    // 基本类型转换
    var c C.char = C.get_char()
    var i C.int = C.get_int()
    var l C.long = C.get_long()
    var f C.float = C.get_float()
    var d C.double = C.get_double()

    fmt.Printf("char: %c\n", byte(c))
    fmt.Printf("int: %d\n", int(i))
    fmt.Printf("long: %d\n", int64(l))
    fmt.Printf("float: %f\n", float32(f))
    fmt.Printf("double: %f\n", float64(d))
}

字符串处理 #

字符串在 Go 和 C 之间的转换需要特别注意:

package main

/*
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

void print_string(char* str) {
    printf("C received: %s\n", str);
}

char* create_string() {
    char* str = malloc(20);
    strcpy(str, "Hello from C");
    return str;
}

int string_length(char* str) {
    return strlen(str);
}
*/
import "C"
import (
    "fmt"
    "unsafe"
)

func main() {
    // Go 字符串转 C 字符串
    goStr := "Hello from Go"
    cStr := C.CString(goStr)
    defer C.free(unsafe.Pointer(cStr)) // 必须释放内存

    C.print_string(cStr)

    // C 字符串转 Go 字符串
    cResult := C.create_string()
    defer C.free(unsafe.Pointer(cResult)) // 必须释放内存

    goResult := C.GoString(cResult)
    fmt.Printf("Go received: %s\n", goResult)

    // 获取字符串长度
    length := C.string_length(cStr)
    fmt.Printf("String length: %d\n", int(length))
}

数组和切片 #

处理数组和切片需要使用 unsafe 包:

package main

/*
#include <stdio.h>

void print_array(int* arr, int size) {
    printf("Array: ");
    for (int i = 0; i < size; i++) {
        printf("%d ", arr[i]);
    }
    printf("\n");
}

void fill_array(int* arr, int size) {
    for (int i = 0; i < size; i++) {
        arr[i] = i * i;
    }
}
*/
import "C"
import (
    "fmt"
    "unsafe"
)

func main() {
    // Go 切片传递给 C
    goSlice := []int{1, 2, 3, 4, 5}
    C.print_array((*C.int)(unsafe.Pointer(&goSlice[0])), C.int(len(goSlice)))

    // C 修改 Go 切片
    C.fill_array((*C.int)(unsafe.Pointer(&goSlice[0])), C.int(len(goSlice)))
    fmt.Printf("Modified slice: %v\n", goSlice)

    // 创建 C 数组
    size := 10
    cArray := (*C.int)(C.malloc(C.size_t(size * 4))) // int 通常是 4 字节
    defer C.free(unsafe.Pointer(cArray))

    C.fill_array(cArray, C.int(size))

    // 将 C 数组转换为 Go 切片
    goArray := (*[10]C.int)(unsafe.Pointer(cArray))[:size:size]
    fmt.Printf("C array as Go slice: %v\n", goArray)
}

结构体处理 #

简单结构体 #

package main

/*
#include <stdio.h>

typedef struct {
    int x;
    int y;
} Point;

typedef struct {
    char name[50];
    int age;
    double salary;
} Person;

void print_point(Point p) {
    printf("Point: (%d, %d)\n", p.x, p.y);
}

Point create_point(int x, int y) {
    Point p = {x, y};
    return p;
}

void print_person(Person* p) {
    printf("Person: %s, age %d, salary %.2f\n", p->name, p->age, p->salary);
}
*/
import "C"
import (
    "fmt"
    "unsafe"
)

func main() {
    // 创建 C 结构体
    var point C.Point
    point.x = 10
    point.y = 20

    C.print_point(point)

    // 从 C 函数获取结构体
    newPoint := C.create_point(30, 40)
    fmt.Printf("New point: (%d, %d)\n", int(newPoint.x), int(newPoint.y))

    // 复杂结构体
    var person C.Person

    // 设置字符串字段
    name := C.CString("John Doe")
    defer C.free(unsafe.Pointer(name))

    // 复制字符串到结构体中的数组
    C.strcpy(&person.name[0], name)
    person.age = 30
    person.salary = 50000.0

    C.print_person(&person)
}

包含指针的结构体 #

package main

/*
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

typedef struct {
    char* name;
    int* values;
    int count;
} Data;

Data* create_data(char* name, int* values, int count) {
    Data* data = malloc(sizeof(Data));
    data->name = malloc(strlen(name) + 1);
    strcpy(data->name, name);

    data->values = malloc(count * sizeof(int));
    memcpy(data->values, values, count * sizeof(int));
    data->count = count;

    return data;
}

void free_data(Data* data) {
    if (data) {
        free(data->name);
        free(data->values);
        free(data);
    }
}

void print_data(Data* data) {
    printf("Data: %s, values: ", data->name);
    for (int i = 0; i < data->count; i++) {
        printf("%d ", data->values[i]);
    }
    printf("\n");
}
*/
import "C"
import (
    "fmt"
    "unsafe"
)

func main() {
    // 准备数据
    name := C.CString("Sample Data")
    defer C.free(unsafe.Pointer(name))

    values := []C.int{1, 2, 3, 4, 5}

    // 创建包含指针的结构体
    data := C.create_data(name, &values[0], C.int(len(values)))
    defer C.free_data(data)

    C.print_data(data)

    // 访问结构体字段
    goName := C.GoString(data.name)
    fmt.Printf("Name from Go: %s\n", goName)
    fmt.Printf("Count: %d\n", int(data.count))
}

函数指针和回调 #

Go 函数作为 C 回调 #

package main

/*
#include <stdio.h>

typedef int (*callback_func)(int);

int call_callback(callback_func cb, int value) {
    printf("Calling callback with value: %d\n", value);
    return cb(value);
}

// 声明外部函数(由 Go 导出)
extern int go_callback(int);

int test_go_callback(int value) {
    return call_callback(go_callback, value);
}
*/
import "C"
import "fmt"

//export go_callback
func go_callback(value C.int) C.int {
    fmt.Printf("Go callback called with: %d\n", int(value))
    return value * 2
}

func main() {
    result := C.test_go_callback(42)
    fmt.Printf("Callback result: %d\n", int(result))
}

更复杂的回调示例 #

package main

/*
#include <stdio.h>
#include <stdlib.h>

typedef struct {
    int (*process)(int);
    void (*notify)(char*);
} Callbacks;

void process_data(int* data, int size, Callbacks* cb) {
    for (int i = 0; i < size; i++) {
        data[i] = cb->process(data[i]);
    }
    cb->notify("Processing completed");
}

// 外部函数声明
extern int go_process(int);
extern void go_notify(char*);

Callbacks* create_callbacks() {
    Callbacks* cb = malloc(sizeof(Callbacks));
    cb->process = go_process;
    cb->notify = go_notify;
    return cb;
}
*/
import "C"
import (
    "fmt"
    "unsafe"
)

//export go_process
func go_process(value C.int) C.int {
    // 简单的数据处理:平方
    result := int(value) * int(value)
    return C.int(result)
}

//export go_notify
func go_notify(message *C.char) {
    goMessage := C.GoString(message)
    fmt.Printf("Notification: %s\n", goMessage)
}

func main() {
    // 创建回调结构体
    callbacks := C.create_callbacks()
    defer C.free(unsafe.Pointer(callbacks))

    // 准备数据
    data := []C.int{1, 2, 3, 4, 5}
    fmt.Printf("Original data: %v\n", data)

    // 处理数据
    C.process_data(&data[0], C.int(len(data)), callbacks)

    fmt.Printf("Processed data: %v\n", data)
}

错误处理 #

C 函数的错误处理 #

package main

/*
#include <stdio.h>
#include <stdlib.h>
#include <errno.h>
#include <string.h>

int divide_safe(int a, int b, int* result) {
    if (b == 0) {
        errno = EINVAL;
        return -1;
    }
    *result = a / b;
    return 0;
}

char* read_file(char* filename) {
    FILE* file = fopen(filename, "r");
    if (!file) {
        return NULL;
    }

    fseek(file, 0, SEEK_END);
    long size = ftell(file);
    fseek(file, 0, SEEK_SET);

    char* content = malloc(size + 1);
    if (!content) {
        fclose(file);
        return NULL;
    }

    fread(content, 1, size, file);
    content[size] = '\0';
    fclose(file);

    return content;
}
*/
import "C"
import (
    "errors"
    "fmt"
    "os"
    "unsafe"
)

func safeDivide(a, b int) (int, error) {
    var result C.int
    ret := C.divide_safe(C.int(a), C.int(b), &result)

    if ret != 0 {
        return 0, errors.New("division by zero")
    }

    return int(result), nil
}

func readFileC(filename string) (string, error) {
    cFilename := C.CString(filename)
    defer C.free(unsafe.Pointer(cFilename))

    content := C.read_file(cFilename)
    if content == nil {
        return "", fmt.Errorf("failed to read file: %s", filename)
    }
    defer C.free(unsafe.Pointer(content))

    return C.GoString(content), nil
}

func main() {
    // 测试安全除法
    result, err := safeDivide(10, 2)
    if err != nil {
        fmt.Printf("Error: %v\n", err)
    } else {
        fmt.Printf("10 / 2 = %d\n", result)
    }

    result, err = safeDivide(10, 0)
    if err != nil {
        fmt.Printf("Error: %v\n", err)
    } else {
        fmt.Printf("10 / 0 = %d\n", result)
    }

    // 创建测试文件
    testFile := "test.txt"
    err = os.WriteFile(testFile, []byte("Hello, CGO!"), 0644)
    if err != nil {
        fmt.Printf("Failed to create test file: %v\n", err)
        return
    }
    defer os.Remove(testFile)

    // 测试文件读取
    content, err := readFileC(testFile)
    if err != nil {
        fmt.Printf("Error reading file: %v\n", err)
    } else {
        fmt.Printf("File content: %s\n", content)
    }

    // 测试读取不存在的文件
    _, err = readFileC("nonexistent.txt")
    if err != nil {
        fmt.Printf("Expected error: %v\n", err)
    }
}

内存管理 #

基本内存管理原则 #

package main

/*
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

typedef struct {
    char* data;
    int size;
} Buffer;

Buffer* create_buffer(int size) {
    Buffer* buf = malloc(sizeof(Buffer));
    if (!buf) return NULL;

    buf->data = malloc(size);
    if (!buf->data) {
        free(buf);
        return NULL;
    }

    buf->size = size;
    memset(buf->data, 0, size);
    return buf;
}

void free_buffer(Buffer* buf) {
    if (buf) {
        free(buf->data);
        free(buf);
    }
}

void fill_buffer(Buffer* buf, char c) {
    if (buf && buf->data) {
        memset(buf->data, c, buf->size);
    }
}
*/
import "C"
import (
    "fmt"
    "runtime"
    "unsafe"
)

// BufferWrapper Go 包装器
type BufferWrapper struct {
    cBuffer *C.Buffer
}

// NewBuffer 创建新的缓冲区
func NewBuffer(size int) *BufferWrapper {
    cBuf := C.create_buffer(C.int(size))
    if cBuf == nil {
        return nil
    }

    wrapper := &BufferWrapper{cBuffer: cBuf}

    // 设置终结器,确保 C 内存被释放
    runtime.SetFinalizer(wrapper, (*BufferWrapper).finalize)

    return wrapper
}

// Fill 填充缓冲区
func (b *BufferWrapper) Fill(c byte) {
    if b.cBuffer != nil {
        C.fill_buffer(b.cBuffer, C.char(c))
    }
}

// Size 获取缓冲区大小
func (b *BufferWrapper) Size() int {
    if b.cBuffer != nil {
        return int(b.cBuffer.size)
    }
    return 0
}

// Data 获取缓冲区数据
func (b *BufferWrapper) Data() []byte {
    if b.cBuffer == nil || b.cBuffer.data == nil {
        return nil
    }

    // 将 C 数组转换为 Go 切片
    size := int(b.cBuffer.size)
    return C.GoBytes(unsafe.Pointer(b.cBuffer.data), C.int(size))
}

// Close 手动释放资源
func (b *BufferWrapper) Close() {
    if b.cBuffer != nil {
        C.free_buffer(b.cBuffer)
        b.cBuffer = nil
        runtime.SetFinalizer(b, nil)
    }
}

// finalize 终结器函数
func (b *BufferWrapper) finalize() {
    b.Close()
}

func main() {
    // 创建缓冲区
    buf := NewBuffer(1024)
    if buf == nil {
        fmt.Println("Failed to create buffer")
        return
    }
    defer buf.Close() // 确保资源被释放

    fmt.Printf("Buffer size: %d\n", buf.Size())

    // 填充缓冲区
    buf.Fill('A')

    // 获取数据
    data := buf.Data()
    fmt.Printf("First 10 bytes: %v\n", data[:10])

    // 演示内存管理
    for i := 0; i < 100; i++ {
        tempBuf := NewBuffer(1024)
        if tempBuf != nil {
            tempBuf.Fill(byte(i % 256))
            // 不调用 Close(),依赖终结器
        }
    }

    // 强制垃圾回收
    runtime.GC()
    runtime.GC()

    fmt.Println("Memory management test completed")
}

编译和构建 #

编译指令 #

可以使用编译指令来控制 CGO 的行为:

package main

/*
#cgo CFLAGS: -Wall -O2
#cgo LDFLAGS: -lm
#cgo pkg-config: sqlite3

#include <math.h>
#include <sqlite3.h>

double calculate_sqrt(double x) {
    return sqrt(x);
}
*/
import "C"
import "fmt"

func main() {
    result := C.calculate_sqrt(16.0)
    fmt.Printf("sqrt(16) = %f\n", float64(result))
}

条件编译 #

// +build cgo

package main

/*
#include <stdio.h>

void cgo_function() {
    printf("CGO is enabled\n");
}
*/
import "C"

func main() {
    C.cgo_function()
}

对应的非 CGO 版本:

// +build !cgo

package main

import "fmt"

func main() {
    fmt.Println("CGO is disabled")
}

性能考虑 #

函数调用开销测试 #

package main

/*
#include <stdio.h>

int c_add(int a, int b) {
    return a + b;
}

void c_noop() {
    // 空函数
}
*/
import "C"
import (
    "fmt"
    "time"
)

func goAdd(a, b int) int {
    return a + b
}

func goNoop() {
    // 空函数
}

func main() {
    const iterations = 10000000

    // 测试 Go 函数调用
    start := time.Now()
    for i := 0; i < iterations; i++ {
        goAdd(i, i+1)
    }
    goDuration := time.Since(start)

    // 测试 CGO 函数调用
    start = time.Now()
    for i := 0; i < iterations; i++ {
        C.c_add(C.int(i), C.int(i+1))
    }
    cgoDuration := time.Since(start)

    fmt.Printf("Go function calls: %v\n", goDuration)
    fmt.Printf("CGO function calls: %v\n", cgoDuration)
    fmt.Printf("CGO overhead: %.2fx\n", float64(cgoDuration)/float64(goDuration))

    // 测试空函数调用开销
    start = time.Now()
    for i := 0; i < iterations; i++ {
        goNoop()
    }
    goNoopDuration := time.Since(start)

    start = time.Now()
    for i := 0; i < iterations; i++ {
        C.c_noop()
    }
    cgoNoopDuration := time.Since(start)

    fmt.Printf("Go noop calls: %v\n", goNoopDuration)
    fmt.Printf("CGO noop calls: %v\n", cgoNoopDuration)
    fmt.Printf("CGO noop overhead: %.2fx\n", float64(cgoNoopDuration)/float64(goNoopDuration))
}

实际应用示例 #

使用 C 库进行图像处理 #

package main

/*
#cgo pkg-config: libpng
#include <png.h>
#include <stdio.h>
#include <stdlib.h>

typedef struct {
    int width;
    int height;
    png_byte color_type;
    png_byte bit_depth;
    png_bytep *row_pointers;
} image_data;

image_data* read_png_file(char* filename) {
    FILE *fp = fopen(filename, "rb");
    if (!fp) return NULL;

    png_structp png = png_create_read_struct(PNG_LIBPNG_VER_STRING, NULL, NULL, NULL);
    if (!png) {
        fclose(fp);
        return NULL;
    }

    png_infop info = png_create_info_struct(png);
    if (!info) {
        png_destroy_read_struct(&png, NULL, NULL);
        fclose(fp);
        return NULL;
    }

    if (setjmp(png_jmpbuf(png))) {
        png_destroy_read_struct(&png, &info, NULL);
        fclose(fp);
        return NULL;
    }

    png_init_io(png, fp);
    png_read_info(png, info);

    image_data* img = malloc(sizeof(image_data));
    img->width = png_get_image_width(png, info);
    img->height = png_get_image_height(png, info);
    img->color_type = png_get_color_type(png, info);
    img->bit_depth = png_get_bit_depth(png, info);

    if (img->bit_depth == 16)
        png_set_strip_16(png);

    if (img->color_type == PNG_COLOR_TYPE_PALETTE)
        png_set_palette_to_rgb(png);

    if (img->color_type == PNG_COLOR_TYPE_GRAY && img->bit_depth < 8)
        png_set_expand_gray_1_2_4_to_8(png);

    if (png_get_valid(png, info, PNG_INFO_tRNS))
        png_set_tRNS_to_alpha(png);

    if (img->color_type == PNG_COLOR_TYPE_RGB ||
        img->color_type == PNG_COLOR_TYPE_GRAY ||
        img->color_type == PNG_COLOR_TYPE_PALETTE)
        png_set_filler(png, 0xFF, PNG_FILLER_AFTER);

    if (img->color_type == PNG_COLOR_TYPE_GRAY ||
        img->color_type == PNG_COLOR_TYPE_GRAY_ALPHA)
        png_set_gray_to_rgb(png);

    png_read_update_info(png, info);

    img->row_pointers = malloc(sizeof(png_bytep) * img->height);
    for (int y = 0; y < img->height; y++) {
        img->row_pointers[y] = malloc(png_get_rowbytes(png, info));
    }

    png_read_image(png, img->row_pointers);

    fclose(fp);
    png_destroy_read_struct(&png, &info, NULL);

    return img;
}

void free_image_data(image_data* img) {
    if (img) {
        for (int y = 0; y < img->height; y++) {
            free(img->row_pointers[y]);
        }
        free(img->row_pointers);
        free(img);
    }
}

void apply_grayscale(image_data* img) {
    for (int y = 0; y < img->height; y++) {
        png_bytep row = img->row_pointers[y];
        for (int x = 0; x < img->width; x++) {
            png_bytep px = &(row[x * 4]);
            // 计算灰度值
            int gray = (int)(0.299 * px[0] + 0.587 * px[1] + 0.114 * px[2]);
            px[0] = px[1] = px[2] = gray;
        }
    }
}
*/
import "C"
import (
    "fmt"
    "unsafe"
)

// Image Go 包装器
type Image struct {
    cImage *C.image_data
}

// LoadPNG 加载 PNG 图像
func LoadPNG(filename string) (*Image, error) {
    cFilename := C.CString(filename)
    defer C.free(unsafe.Pointer(cFilename))

    cImg := C.read_png_file(cFilename)
    if cImg == nil {
        return nil, fmt.Errorf("failed to load PNG file: %s", filename)
    }

    return &Image{cImage: cImg}, nil
}

// Close 释放图像资源
func (img *Image) Close() {
    if img.cImage != nil {
        C.free_image_data(img.cImage)
        img.cImage = nil
    }
}

// Width 获取图像宽度
func (img *Image) Width() int {
    if img.cImage != nil {
        return int(img.cImage.width)
    }
    return 0
}

// Height 获取图像高度
func (img *Image) Height() int {
    if img.cImage != nil {
        return int(img.cImage.height)
    }
    return 0
}

// ApplyGrayscale 应用灰度滤镜
func (img *Image) ApplyGrayscale() {
    if img.cImage != nil {
        C.apply_grayscale(img.cImage)
    }
}

func main() {
    // 注意:这个示例需要 libpng 库和一个 PNG 文件
    fmt.Println("PNG processing example")
    fmt.Println("Note: This example requires libpng and a PNG file to work")

    // 示例用法(需要实际的 PNG 文件)
    /*
    img, err := LoadPNG("example.png")
    if err != nil {
        fmt.Printf("Error: %v\n", err)
        return
    }
    defer img.Close()

    fmt.Printf("Image size: %dx%d\n", img.Width(), img.Height())

    // 应用灰度滤镜
    img.ApplyGrayscale()
    fmt.Println("Grayscale filter applied")
    */
}

小结 #

本节详细介绍了 CGO 的基础知识和使用方法,包括:

  1. CGO 概述:了解 CGO 的作用和优缺点
  2. 基本语法:掌握 CGO 的基本语法和代码组织方式
  3. 数据类型转换:学会 Go 和 C 类型之间的转换
  4. 结构体处理:处理复杂的数据结构
  5. 函数指针和回调:实现 Go 和 C 之间的双向调用
  6. 错误处理:正确处理 C 函数的错误
  7. 内存管理:避免内存泄漏和悬空指针
  8. 性能考虑:了解 CGO 的性能特点

CGO 是一个强大的工具,但使用时需要特别注意内存管理和类型安全。在下一节中,我们将深入探讨 CGO 的高级特性和常见陷阱。