3.4.2 GORM 基础操作

3.4.2 GORM 基础操作

GORM 是 Go 语言最受欢迎的 ORM 框架,提供了丰富的功能和优雅的 API 设计。本节将深入介绍 GORM 的基础概念、模型定义、关联关系以及基础的 CRUD 操作。

GORM 框架介绍

GORM 特性

GORM 提供了以下核心特性:

  • 全功能 ORM:支持关联、钩子、事务、迁移等
  • 链式 API:直观的查询构建器
  • 自动迁移:根据模型自动创建和更新表结构
  • 多数据库支持:MySQL、PostgreSQL、SQLite、SQL Server
  • 插件系统:可扩展的插件架构
  • 开发者友好:详细的错误信息和日志

安装和初始化

// 安装 GORM 和数据库驱动
// go get -u gorm.io/gorm
// go get -u gorm.io/driver/mysql
// go get -u gorm.io/driver/postgres
// go get -u gorm.io/driver/sqlite

package main

import (
    "fmt"
    "log"
    "os"
    "time"

    "gorm.io/driver/mysql"
    "gorm.io/driver/postgres"
    "gorm.io/driver/sqlite"
    "gorm.io/gorm"
    "gorm.io/gorm/logger"
)

// GormConfig groups the GORM settings this section tunes. Each field mirrors
// the same-named field of logger.Config or gorm.Config.
//
// NOTE(review): nothing in this section reads GormConfig yet. DefaultGormConfig
// hard-codes equivalent values, so consider wiring this struct into it.
type GormConfig struct {
    LogLevel                                 logger.LogLevel // mirrors logger.Config.LogLevel
    SlowThreshold                            time.Duration   // mirrors logger.Config.SlowThreshold
    IgnoreRecordNotFoundError                bool            // mirrors logger.Config.IgnoreRecordNotFoundError
    DisableForeignKeyConstraintWhenMigrating bool            // mirrors gorm.Config.DisableForeignKeyConstraintWhenMigrating
}

// DefaultGormConfig returns the *gorm.Config used by the Init* helpers.
//
// Logging goes to os.Stdout at logger.Info (GORM's most verbose level). Queries
// slower than one second are reported as slow. gorm.ErrRecordNotFound is not
// suppressed in the log. Foreign-key constraints are not created when
// migrating.
func DefaultGormConfig() *gorm.Config {
    return &gorm.Config{
        Logger: logger.New(
            // The "\r\n" prefix starts every log entry on a fresh line.
            log.New(os.Stdout, "\r\n", log.LstdFlags),
            logger.Config{
                SlowThreshold:             time.Second,
                LogLevel:                  logger.Info,
                IgnoreRecordNotFoundError: false,
                Colorful:                  true,
            },
        ),
        DisableForeignKeyConstraintWhenMigrating: true,
    }
}

// InitMySQL opens a GORM connection to MySQL using dsn and DefaultGormConfig.
// It then tunes the underlying *sql.DB pool: 10 idle connections, at most 100
// open, each recycled after one hour.
//
// Both failure points wrap their cause with %w.
func InitMySQL(dsn string) (*gorm.DB, error) {
    db, openErr := gorm.Open(mysql.Open(dsn), DefaultGormConfig())
    if openErr != nil {
        return nil, fmt.Errorf("failed to connect to MySQL: %w", openErr)
    }

    // Pool settings live on the database/sql handle, not on *gorm.DB.
    pool, poolErr := db.DB()
    if poolErr != nil {
        return nil, fmt.Errorf("failed to get underlying sql.DB: %w", poolErr)
    }
    pool.SetMaxIdleConns(10)
    pool.SetMaxOpenConns(100)
    pool.SetConnMaxLifetime(time.Hour)

    return db, nil
}

// InitPostgreSQL opens a GORM connection to PostgreSQL using dsn and
// DefaultGormConfig. It applies the same pool limits as InitMySQL: 10 idle,
// 100 open, one-hour connection lifetime.
//
// Errors from both steps are wrapped with %w so callers can tell which step
// failed and still inspect the driver error.
func InitPostgreSQL(dsn string) (*gorm.DB, error) {
    db, err := gorm.Open(postgres.Open(dsn), DefaultGormConfig())
    if err != nil {
        return nil, fmt.Errorf("failed to connect to PostgreSQL: %w", err)
    }

    sqlDB, err := db.DB()
    if err != nil {
        return nil, fmt.Errorf("failed to get underlying sql.DB: %w", err)
    }

    sqlDB.SetMaxIdleConns(10)
    sqlDB.SetMaxOpenConns(100)
    sqlDB.SetConnMaxLifetime(time.Hour)

    return db, nil
}

// InitSQLite opens the SQLite database at path with DefaultGormConfig.
// Unlike the MySQL/PostgreSQL helpers, it applies no connection-pool tuning.
func InitSQLite(path string) (*gorm.DB, error) {
    db, err := gorm.Open(sqlite.Open(path), DefaultGormConfig())
    if err == nil {
        return db, nil
    }
    return nil, fmt.Errorf("failed to connect to SQLite: %w", err)
}

模型定义与关联

基础模型定义

import (
    "time"
    "gorm.io/gorm"
)

// BaseModel holds the columns shared by most models: an ID primary key,
// timestamps that GORM maintains by naming convention, and a soft-delete
// marker. It is meant to be embedded.
//
// Because DeletedAt is a gorm.DeletedAt, ordinary queries skip soft-deleted
// rows and Delete sets deleted_at instead of removing the row.
type BaseModel struct {
    ID        uint           `gorm:"primaryKey" json:"id"`
    CreatedAt time.Time      `json:"created_at"`
    UpdatedAt time.Time      `json:"updated_at"`
    DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
}

// User is an application account with optional profile, posts and comments.
//
// NOTE: Username and Email carry plain unique indexes while BaseModel provides
// soft delete. A soft-deleted row still occupies the index, so re-creating a
// user with the same username or email fails until the old row is hard-deleted.
type User struct {
    BaseModel
    Username  string     `gorm:"uniqueIndex;size:50;not null" json:"username"`
    Email     string     `gorm:"uniqueIndex;size:100;not null" json:"email"`
    Password  string     `gorm:"size:255;not null" json:"-"` // never serialized to JSON
    FirstName string     `gorm:"size:50" json:"first_name"`
    LastName  string     `gorm:"size:50" json:"last_name"`
    Avatar    string     `gorm:"size:255" json:"avatar"`
    Status    string     `gorm:"size:20;default:active" json:"status"`
    LastLogin *time.Time `json:"last_login,omitempty"`

    // Associations, populated only when preloaded.
    Profile  *UserProfile `gorm:"foreignKey:UserID" json:"profile,omitempty"`
    Posts    []Post       `gorm:"foreignKey:AuthorID" json:"posts,omitempty"`
    Comments []Comment    `gorm:"foreignKey:UserID" json:"comments,omitempty"`
}

// UserProfile stores optional per-user details. It has one row per user,
// enforced by the unique index on UserID.
type UserProfile struct {
    ID       uint       `gorm:"primaryKey" json:"id"`
    UserID   uint       `gorm:"uniqueIndex;not null" json:"user_id"`
    Bio      string     `gorm:"type:text" json:"bio"`
    Website  string     `gorm:"size:255" json:"website"`
    Location string     `gorm:"size:100" json:"location"`
    Birthday *time.Time `json:"birthday,omitempty"`

    CreatedAt time.Time      `json:"created_at"`
    UpdatedAt time.Time      `json:"updated_at"`
    DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`

    // User is the belongs-to back-reference. encoding/json ignores omitempty
    // on non-pointer struct fields, so a zero-valued User is still emitted
    // when it was not preloaded.
    User User `gorm:"foreignKey:UserID" json:"user,omitempty"`
}

// Post is an article written by a User (AuthorID). It can have comments and
// tags. Slug is unique across all posts.
//
// NOTE(review): "longtext" is a MySQL column type; PostgreSQL has no such type,
// so AutoMigrate on PostgreSQL would need a different Content type.
type Post struct {
    BaseModel
    Title     string `gorm:"size:200;not null" json:"title"`
    Slug      string `gorm:"uniqueIndex;size:200;not null" json:"slug"`
    Content   string `gorm:"type:longtext" json:"content"`
    Summary   string `gorm:"type:text" json:"summary"`
    Status    string `gorm:"size:20;default:draft" json:"status"`
    ViewCount int    `gorm:"default:0" json:"view_count"`
    AuthorID  uint   `gorm:"not null;index" json:"author_id"`

    // Author is belongs-to via AuthorID. Comments is has-many via
    // Comment.PostID. Tags goes through the post_tags join table.
    Author   User      `gorm:"foreignKey:AuthorID" json:"author,omitempty"`
    Comments []Comment `gorm:"foreignKey:PostID" json:"comments,omitempty"`
    Tags     []Tag     `gorm:"many2many:post_tags;" json:"tags,omitempty"`
}

// Comment is a user's comment on a post. A nil ParentID marks a top-level
// comment; otherwise it points at the comment being replied to.
type Comment struct {
    BaseModel
    Content  string `gorm:"type:text;not null" json:"content"`
    PostID   uint   `gorm:"not null;index" json:"post_id"`
    UserID   uint   `gorm:"not null;index" json:"user_id"`
    ParentID *uint  `gorm:"index" json:"parent_id,omitempty"`

    // Post and User are belongs-to. Parent and Children form a
    // self-referencing thread keyed by ParentID.
    Post     Post      `gorm:"foreignKey:PostID" json:"post,omitempty"`
    User     User      `gorm:"foreignKey:UserID" json:"user,omitempty"`
    Parent   *Comment  `gorm:"foreignKey:ParentID" json:"parent,omitempty"`
    Children []Comment `gorm:"foreignKey:ParentID" json:"children,omitempty"`
}

// Tag labels posts. Color is a hex string such as "#007bff", which is also
// the column default.
//
// NOTE: Name has a plain unique index while BaseModel adds soft delete, so a
// soft-deleted tag still blocks re-creating a tag with the same name.
type Tag struct {
    BaseModel
    Name        string `gorm:"uniqueIndex;size:50;not null" json:"name"`
    Description string `gorm:"type:text" json:"description"`
    Color       string `gorm:"size:7;default:#007bff" json:"color"`

    // Posts is the other side of Post.Tags, sharing the post_tags join table.
    Posts []Post `gorm:"many2many:post_tags;" json:"posts,omitempty"`
}

// PostTag is an optional explicit model for the post_tags join table, adding
// a CreatedAt column. GORM creates a plain post_tags table by itself. Per GORM
// docs, this struct is used for post_tags only after registering it with
// db.SetupJoinTable(&Post{}, "Tags", &PostTag{}).
type PostTag struct {
    PostID uint `gorm:"primaryKey"`
    TagID  uint `gorm:"primaryKey"`

    CreatedAt time.Time
}

模型标签详解

// Product demonstrates commonly used GORM struct tags.
//
// NOTE: IsActive combines a bool with default:true. Per GORM docs, zero values
// of fields with a default are not inserted, so creating a Product with
// IsActive=false stores true. Use *bool if false must be persisted on create.
//
// NOTE(review): Price is a float64 over decimal(10,2); binary floats cannot
// represent most decimal amounts exactly.
type Product struct {
    ID          uint      `gorm:"primaryKey;autoIncrement" json:"id"`
    Code        string    `gorm:"uniqueIndex;size:50;not null" json:"code"`
    Name        string    `gorm:"size:100;not null" json:"name"`
    Description string    `gorm:"type:text" json:"description"`
    Price       float64   `gorm:"type:decimal(10,2);not null" json:"price"`
    Stock       int       `gorm:"default:0;check:stock >= 0" json:"stock"`
    CategoryID  uint      `gorm:"index;not null" json:"category_id"`
    IsActive    bool      `gorm:"default:true" json:"is_active"`
    Tags        string    `gorm:"type:json" json:"tags"` // raw JSON text in a json column

    CreatedAt time.Time      `gorm:"autoCreateTime" json:"created_at"`
    UpdatedAt time.Time      `gorm:"autoUpdateTime" json:"updated_at"`
    DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`

    // Category is belongs-to via CategoryID. CategoryID is NOT NULL, so
    // ON DELETE SET NULL could never succeed. RESTRICT refuses to delete a
    // category that still has products.
    Category Category `gorm:"foreignKey:CategoryID;constraint:OnUpdate:CASCADE,OnDelete:RESTRICT" json:"category,omitempty"`
}

// 常用 GORM 标签说明:
// primaryKey: 主键
// autoIncrement: 自增
// uniqueIndex: 唯一索引
// index: 普通索引
// size: 字段大小(如字符串长度)
// type: 数据库字段类型(如 text、decimal(10,2),部分类型依赖具体数据库)
// not null: 非空约束
// default: 默认值
// check: 检查约束
// foreignKey: 外键字段
// constraint: 外键约束行为(如 OnUpdate:CASCADE)
// many2many: 多对多关联使用的连接表名
// autoCreateTime: 自动设置创建时间
// autoUpdateTime: 自动设置更新时间
// -: 忽略该字段,不读写数据库

关联关系详解

// User, has-one example: each User owns at most one UserProfile, matched via
// UserProfile.UserID.
type User struct {
    ID      uint
    Name    string
    Profile UserProfile `gorm:"foreignKey:UserID"`
}

// UserProfile, has-one example: UserID refers back to the owning User.
type UserProfile struct {
    ID     uint
    UserID uint
    Bio    string
}

// User, has-many example: a User owns any number of Posts, matched via
// Post.UserID.
type User struct {
    ID    uint
    Name  string
    Posts []Post `gorm:"foreignKey:UserID"`
}

// Post, has-many example: UserID identifies the owning User.
type Post struct {
    ID     uint
    UserID uint
    Title  string
}

// User, many-to-many example: users and roles are linked through the
// user_roles join table, which GORM creates automatically.
type User struct {
    ID    uint
    Name  string
    Roles []Role `gorm:"many2many:user_roles;"`
}

// Role is the other side of User.Roles. Declaring the same many2many table
// name makes the association navigable from both ends.
type Role struct {
    ID    uint
    Name  string
    Users []User `gorm:"many2many:user_roles;"`
}

// UserRole is a hand-written user_roles join table that adds audit columns.
// UserID and RoleID form the composite primary key. Per GORM docs, it replaces
// the auto-generated table only after
// db.SetupJoinTable(&User{}, "Roles", &UserRole{}).
//
// NOTE(review): CreatedBy presumably holds the acting user's ID; confirm.
type UserRole struct {
    UserID    uint      `gorm:"primaryKey"`
    RoleID    uint      `gorm:"primaryKey"`
    CreatedAt time.Time
    CreatedBy uint
}

// 使用自定义关联表
type User struct {
    ID        uint
    Name      string
    UserRoles []UserRole `gorm:"foreignKey:UserID"`
}

基础 CRUD 操作

创建记录

// UserRepository is the data-access layer for User rows. It wraps a *gorm.DB
// and returns errors annotated with the failing operation.
type UserRepository struct {
    db *gorm.DB // handle every query starts from
}

// NewUserRepository returns a UserRepository that issues its queries on db.
func NewUserRepository(db *gorm.DB) *UserRepository {
    repo := &UserRepository{}
    repo.db = db
    return repo
}

// Create inserts user. On success GORM writes the generated primary key back
// into user.ID, and the new ID and affected-row count are logged.
func (r *UserRepository) Create(user *User) error {
    res := r.db.Create(user)
    if err := res.Error; err != nil {
        return fmt.Errorf("failed to create user: %w", err)
    }
    log.Printf("Created user with ID: %d, affected rows: %d", user.ID, res.RowsAffected)
    return nil
}

// CreateInBatches inserts users in chunks of batchSize rows per INSERT and
// logs the total number of rows created.
//
// users is passed to GORM as-is, so generated IDs are written into the
// caller's slice elements. A non-positive batchSize cannot form a batch and is
// rejected before touching the database.
func (r *UserRepository) CreateInBatches(users []User, batchSize int) error {
    if batchSize <= 0 {
        return fmt.Errorf("failed to create users in batches: invalid batch size %d", batchSize)
    }

    result := r.db.CreateInBatches(users, batchSize)
    if result.Error != nil {
        return fmt.Errorf("failed to create users in batches: %w", result.Error)
    }

    log.Printf("Created %d users in batches", result.RowsAffected)
    return nil
}

// CreateOrUpdate upserts user keyed on email. If a row with the same email
// exists, its username, first_name, last_name and updated_at are overwritten.
// Otherwise a new row is inserted.
//
// Requires the gorm.io/gorm/clause package to be imported.
func (r *UserRepository) CreateOrUpdate(user *User) error {
    onEmailConflict := clause.OnConflict{
        Columns:   []clause.Column{{Name: "email"}},
        DoUpdates: clause.AssignmentColumns([]string{"username", "first_name", "last_name", "updated_at"}),
    }

    if err := r.db.Clauses(onEmailConflict).Create(user).Error; err != nil {
        return fmt.Errorf("failed to create or update user: %w", err)
    }
    return nil
}

// CreateFromMap inserts one users row from column/value pairs in userData.
//
// Per GORM docs, map-based creation skips model hooks, does not save
// associations and does not back-fill the primary key.
func (r *UserRepository) CreateFromMap(userData map[string]interface{}) error {
    if err := r.db.Model(&User{}).Create(userData).Error; err != nil {
        return fmt.Errorf("failed to create user from map: %w", err)
    }
    return nil
}

查询记录

// GetByID returns the non-deleted user with the given primary key.
//
// A missing user yields an error that still wraps gorm.ErrRecordNotFound, so
// callers can tell "not found" apart from a database failure with errors.Is.
func (r *UserRepository) GetByID(id uint) (*User, error) {
    var user User
    result := r.db.First(&user, id)

    if result.Error != nil {
        if errors.Is(result.Error, gorm.ErrRecordNotFound) {
            return nil, fmt.Errorf("user with ID %d not found: %w", id, result.Error)
        }
        return nil, fmt.Errorf("failed to get user: %w", result.Error)
    }

    return &user, nil
}

// GetByEmail returns the non-deleted user whose email equals email.
//
// A missing user yields an error that still wraps gorm.ErrRecordNotFound, so
// callers can distinguish "not found" from other failures with errors.Is.
func (r *UserRepository) GetByEmail(email string) (*User, error) {
    var user User
    result := r.db.Where("email = ?", email).First(&user)

    if result.Error != nil {
        if errors.Is(result.Error, gorm.ErrRecordNotFound) {
            return nil, fmt.Errorf("user with email %s not found: %w", email, result.Error)
        }
        return nil, fmt.Errorf("failed to get user by email: %w", result.Error)
    }

    return &user, nil
}

// GetAll loads every non-deleted user in one unpaginated query. Intended for
// small tables only.
func (r *UserRepository) GetAll() ([]User, error) {
    var all []User
    if err := r.db.Find(&all).Error; err != nil {
        return nil, fmt.Errorf("failed to get all users: %w", err)
    }
    return all, nil
}

// GetByStatus loads all non-deleted users whose status equals status.
func (r *UserRepository) GetByStatus(status string) ([]User, error) {
    var matched []User
    if err := r.db.Where("status = ?", status).Find(&matched).Error; err != nil {
        return nil, fmt.Errorf("failed to get users by status: %w", err)
    }
    return matched, nil
}

// Search returns one page of users plus the total number of matches.
//
// A non-empty keyword matches users whose username, email, first_name or
// last_name contains it. A non-empty status restricts to that status. limit
// and offset are passed to LIMIT/OFFSET. Rows are ordered by id so
// consecutive pages are stable.
func (r *UserRepository) Search(keyword string, status string, limit, offset int) ([]User, int64, error) {
    var users []User
    var total int64

    query := r.db.Model(&User{})

    if keyword != "" {
        // '%' and '_' inside keyword are not escaped and act as LIKE wildcards.
        pattern := "%" + keyword + "%"
        query = query.Where("username LIKE ? OR email LIKE ? OR first_name LIKE ? OR last_name LIKE ?",
            pattern, pattern, pattern, pattern)
    }

    if status != "" {
        query = query.Where("status = ?", status)
    }

    // Count before adding ORDER BY / LIMIT / OFFSET so the total covers all matches.
    if err := query.Count(&total).Error; err != nil {
        return nil, 0, fmt.Errorf("failed to count users: %w", err)
    }

    // Without ORDER BY, SQL leaves row order undefined and LIMIT/OFFSET pages
    // may repeat or skip rows.
    result := query.Order("id").Limit(limit).Offset(offset).Find(&users)
    if result.Error != nil {
        return nil, 0, fmt.Errorf("failed to search users: %w", result.Error)
    }

    return users, total, nil
}

// GetWithProfile returns user id with its Profile association preloaded.
//
// A missing user yields an error that still wraps gorm.ErrRecordNotFound, so
// callers can detect it with errors.Is.
func (r *UserRepository) GetWithProfile(id uint) (*User, error) {
    var user User
    result := r.db.Preload("Profile").First(&user, id)

    if result.Error != nil {
        if errors.Is(result.Error, gorm.ErrRecordNotFound) {
            return nil, fmt.Errorf("user with ID %d not found: %w", id, result.Error)
        }
        return nil, fmt.Errorf("failed to get user with profile: %w", result.Error)
    }

    return &user, nil
}

// GetWithAllRelations returns user id with Profile, Posts and Comments
// preloaded. Every failure, including "not found", is wrapped with %w.
func (r *UserRepository) GetWithAllRelations(id uint) (*User, error) {
    var user User
    err := r.db.
        Preload("Profile").
        Preload("Posts").
        Preload("Comments").
        First(&user, id).
        Error
    if err != nil {
        return nil, fmt.Errorf("failed to get user with relations: %w", err)
    }
    return &user, nil
}

更新记录

// UpdateStatus sets the status column of user id. It reports "not found"
// when no non-deleted row was affected.
func (r *UserRepository) UpdateStatus(id uint, status string) error {
    res := r.db.Model(&User{}).Where("id = ?", id).Update("status", status)
    switch {
    case res.Error != nil:
        return fmt.Errorf("failed to update user status: %w", res.Error)
    case res.RowsAffected == 0:
        return fmt.Errorf("user with ID %d not found", id)
    default:
        return nil
    }
}

// UpdateFields applies the column/value pairs in updates to user id and
// reports "not found" when no row was affected.
//
// NOTE: the keys name columns directly. Do not pass a caller-supplied map
// unfiltered, or any column (e.g. password) could be overwritten.
func (r *UserRepository) UpdateFields(id uint, updates map[string]interface{}) error {
    res := r.db.Model(&User{}).Where("id = ?", id).Updates(updates)
    switch {
    case res.Error != nil:
        return fmt.Errorf("failed to update user fields: %w", res.Error)
    case res.RowsAffected == 0:
        return fmt.Errorf("user with ID %d not found", id)
    default:
        return nil
    }
}

// Update persists the whole user struct with gorm's Save.
//
// Per GORM docs, Save writes every field including zero values, and inserts
// when the primary key is zero.
func (r *UserRepository) Update(user *User) error {
    if err := r.db.Save(user).Error; err != nil {
        return fmt.Errorf("failed to update user: %w", err)
    }
    return nil
}

// UpdateBatch sets status on every user whose id is in ids and logs how many
// rows changed. An unmatched ID is not an error.
func (r *UserRepository) UpdateBatch(status string, ids []uint) error {
    res := r.db.Model(&User{}).Where("id IN ?", ids).Update("status", status)
    if err := res.Error; err != nil {
        return fmt.Errorf("failed to batch update users: %w", err)
    }
    log.Printf("Updated %d users", res.RowsAffected)
    return nil
}

// UpdateByCondition applies updates to every user matching all column = value
// pairs in condition.
//
// condition is handed to GORM as a map rather than concatenated into SQL.
// GORM quotes each key as a column name and ANDs the equalities, so keys can
// no longer inject SQL, and slice values become IN (...). With an empty
// condition no WHERE clause is added. GORM then refuses the update
// (ErrMissingWhereClause) unless AllowGlobalUpdate is enabled.
func (r *UserRepository) UpdateByCondition(condition map[string]interface{}, updates map[string]interface{}) error {
    query := r.db.Model(&User{})

    if len(condition) > 0 {
        query = query.Where(condition)
    }

    result := query.Updates(updates)
    if result.Error != nil {
        return fmt.Errorf("failed to update users by condition: %w", result.Error)
    }

    return nil
}

删除记录

// Delete soft-deletes user id by setting its deleted_at column, because User
// embeds a gorm.DeletedAt. It reports "not found" when no row was affected.
func (r *UserRepository) Delete(id uint) error {
    res := r.db.Delete(&User{}, id)
    switch {
    case res.Error != nil:
        return fmt.Errorf("failed to delete user: %w", res.Error)
    case res.RowsAffected == 0:
        return fmt.Errorf("user with ID %d not found", id)
    default:
        return nil
    }
}

// DeleteBatch soft-deletes every user whose primary key is in ids and logs
// the affected-row count. Unmatched IDs are not an error.
func (r *UserRepository) DeleteBatch(ids []uint) error {
    res := r.db.Delete(&User{}, ids)
    if err := res.Error; err != nil {
        return fmt.Errorf("failed to batch delete users: %w", err)
    }
    log.Printf("Deleted %d users", res.RowsAffected)
    return nil
}

// DeleteByStatus soft-deletes all users whose status equals status.
func (r *UserRepository) DeleteByStatus(status string) error {
    if err := r.db.Where("status = ?", status).Delete(&User{}).Error; err != nil {
        return fmt.Errorf("failed to delete users by status: %w", err)
    }
    return nil
}

// HardDelete permanently removes user id, including an already soft-deleted
// row. Unscoped bypasses the soft-delete behaviour, so GORM issues a real
// DELETE.
//
// Like Delete, it reports "not found" when no row was removed.
func (r *UserRepository) HardDelete(id uint) error {
    result := r.db.Unscoped().Delete(&User{}, id)

    if result.Error != nil {
        return fmt.Errorf("failed to hard delete user: %w", result.Error)
    }

    if result.RowsAffected == 0 {
        return fmt.Errorf("user with ID %d not found", id)
    }

    return nil
}

// Restore undoes a soft delete of user id by clearing its deleted_at column.
//
// Unscoped is required because soft-deleted rows are otherwise excluded from
// the UPDATE. Like Delete, it reports "not found" when no row has that ID.
func (r *UserRepository) Restore(id uint) error {
    result := r.db.Unscoped().Model(&User{}).Where("id = ?", id).Update("deleted_at", nil)

    if result.Error != nil {
        return fmt.Errorf("failed to restore user: %w", result.Error)
    }

    if result.RowsAffected == 0 {
        return fmt.Errorf("user with ID %d not found", id)
    }

    return nil
}

查询构建器使用

链式查询

// UserQueryBuilder builds a User query from chained filter, sort, pagination
// and preload calls, then runs it with Find, First or Count.
//
// Each chained call replaces qb.db with the narrowed query and returns the
// same builder. Conditions therefore accumulate, so use one builder per query.
type UserQueryBuilder struct {
    db *gorm.DB // query under construction
}

// NewUserQueryBuilder starts a builder on db with no conditions applied.
func NewUserQueryBuilder(db *gorm.DB) *UserQueryBuilder {
    qb := &UserQueryBuilder{}
    qb.db = db
    return qb
}

// Query exposes the builder's current *gorm.DB scoped to the User model, for
// GORM calls the builder does not wrap.
func (qb *UserQueryBuilder) Query() *gorm.DB {
    scoped := qb.db.Model(&User{})
    return scoped
}

// WithStatus keeps only users whose status equals status.
func (qb *UserQueryBuilder) WithStatus(status string) *UserQueryBuilder {
    narrowed := qb.db.Where("status = ?", status)
    qb.db = narrowed
    return qb
}

// CreatedBetween keeps users whose created_at lies within [start, end].
// SQL BETWEEN includes both endpoints.
func (qb *UserQueryBuilder) CreatedBetween(start, end time.Time) *UserQueryBuilder {
    narrowed := qb.db.Where("created_at BETWEEN ? AND ?", start, end)
    qb.db = narrowed
    return qb
}

// Search keeps users whose username, email, first_name or last_name contains
// keyword. An empty keyword leaves the query unchanged.
func (qb *UserQueryBuilder) Search(keyword string) *UserQueryBuilder {
    if keyword == "" {
        return qb
    }
    like := "%" + keyword + "%"
    qb.db = qb.db.Where("username LIKE ? OR email LIKE ? OR first_name LIKE ? OR last_name LIKE ?",
        like, like, like, like)
    return qb
}

// OrderBy sorts results by field in direction.
//
// GORM inserts an Order string verbatim into the SQL, so field and direction
// are never concatenated unchecked. field must be one of the sortable users
// columns listed below; any other value leaves the ordering unchanged.
// direction "DESC" or "desc" sorts descending; anything else sorts ascending.
func (qb *UserQueryBuilder) OrderBy(field, direction string) *UserQueryBuilder {
    switch field {
    case "id", "username", "email", "first_name", "last_name",
        "status", "last_login", "created_at", "updated_at":
        // allowed column
    default:
        return qb
    }

    dir := "ASC"
    if direction == "DESC" || direction == "desc" {
        dir = "DESC"
    }

    qb.db = qb.db.Order(field + " " + dir)
    return qb
}

// Paginate limits results to page number page (1-based) of pageSize rows.
//
// NOTE(review): page < 1 yields a negative OFFSET, and pageSize <= 0 a
// non-positive LIMIT. Confirm how the GORM version in use treats those.
func (qb *UserQueryBuilder) Paginate(page, pageSize int) *UserQueryBuilder {
    skip := (page - 1) * pageSize
    qb.db = qb.db.Limit(pageSize).Offset(skip)
    return qb
}

// WithProfile preloads each user's Profile association.
func (qb *UserQueryBuilder) WithProfile() *UserQueryBuilder {
    withAssoc := qb.db.Preload("Profile")
    qb.db = withAssoc
    return qb
}

// WithPosts preloads each user's Posts association.
func (qb *UserQueryBuilder) WithPosts() *UserQueryBuilder {
    withAssoc := qb.db.Preload("Posts")
    qb.db = withAssoc
    return qb
}

// Terminal methods: Find, First and Count execute the built query.

// Find runs the query and returns all matching users along with GORM's error.
func (qb *UserQueryBuilder) Find() ([]User, error) {
    var found []User
    err := qb.db.Find(&found).Error
    return found, err
}

// First returns the first matching user (GORM orders by primary key) or the
// raw GORM error, e.g. gorm.ErrRecordNotFound.
func (qb *UserQueryBuilder) First() (*User, error) {
    user := new(User)
    if err := qb.db.First(user).Error; err != nil {
        return nil, err
    }
    return user, nil
}

// Count returns how many users match the builder's filters.
//
// It runs on a cloned session scoped to the User model. Without the model,
// a builder with no prior calls has no table to count. LIMIT and OFFSET are
// cancelled with -1, so a preceding Paginate does not shrink the total. The
// clone leaves the builder's own statement untouched for a later Find.
func (qb *UserQueryBuilder) Count() (int64, error) {
    var count int64
    result := qb.db.Session(&gorm.Session{}).
        Model(&User{}).
        Limit(-1).
        Offset(-1).
        Count(&count)
    return count, result.Error
}

// ExampleQueryBuilder shows a typical chain: active users matching "john",
// created within the last month, with profiles preloaded, newest first,
// first page of 10. It logs the outcome.
func ExampleQueryBuilder(db *gorm.DB) {
    builder := NewUserQueryBuilder(db).
        WithStatus("active").
        Search("john").
        CreatedBetween(time.Now().AddDate(0, -1, 0), time.Now()).
        WithProfile().
        OrderBy("created_at", "DESC").
        Paginate(1, 10)

    users, err := builder.Find()
    if err != nil {
        log.Printf("Query failed: %v", err)
        return
    }

    log.Printf("Found %d users", len(users))
}

完整的服务层示例

// UserService holds user business rules on top of UserRepository.
type UserService struct {
    repo *UserRepository // persistence layer
}

// NewUserService builds a UserService backed by a repository on db.
func NewUserService(db *gorm.DB) *UserService {
    repo := NewUserRepository(db)
    return &UserService{repo: repo}
}

// CreateUser registers a new, active user from req.
//
// The request is rejected when a user with req.Email already exists. The
// lookup error is discarded on purpose: GetByEmail reports "not found" as an
// error, so only a non-nil user means a duplicate. If the lookup fails for
// another reason, the insert is still attempted and the unique index on
// email remains the final guard.
//
// NOTE(security): hashPassword is currently a placeholder that returns its
// input, so passwords are stored in plaintext.
func (s *UserService) CreateUser(req CreateUserRequest) (*User, error) {
    if existingUser, _ := s.repo.GetByEmail(req.Email); existingUser != nil {
        return nil, fmt.Errorf("email %s already exists", req.Email)
    }

    user := &User{
        Username:  req.Username,
        Email:     req.Email,
        Password:  hashPassword(req.Password),
        FirstName: req.FirstName,
        LastName:  req.LastName,
        Status:    "active",
    }

    // repo.Create already prefixes its error with "failed to create user";
    // wrapping again would repeat that prefix.
    if err := s.repo.Create(user); err != nil {
        return nil, err
    }

    return user, nil
}

// GetUsers returns page number page (1-based) of users filtered by status
// and keyword, with pagination metadata.
//
// pageSize must be positive because it divides the total to compute
// TotalPage; zero would panic with an integer divide by zero. page values
// below 1 are treated as page 1.
func (s *UserService) GetUsers(page, pageSize int, status, keyword string) (*UserListResponse, error) {
    if pageSize <= 0 {
        return nil, fmt.Errorf("invalid page size %d: must be positive", pageSize)
    }
    if page < 1 {
        page = 1
    }

    users, total, err := s.repo.Search(keyword, status, pageSize, (page-1)*pageSize)
    if err != nil {
        return nil, fmt.Errorf("failed to get users: %w", err)
    }

    size := int64(pageSize)
    return &UserListResponse{
        Users: users,
        Pagination: PaginationInfo{
            Page:      page,
            PageSize:  pageSize,
            Total:     total,
            TotalPage: (total + size - 1) / size, // ceiling division
        },
    }, nil
}

// UpdateUser applies the non-empty fields of req to user id and returns the
// reloaded record.
//
// Empty strings in req mean "leave unchanged". When every field is empty, no
// UPDATE is issued and the already-loaded user is returned as-is. A missing
// user yields GetByID's error unchanged.
func (s *UserService) UpdateUser(id uint, req UpdateUserRequest) (*User, error) {
    user, err := s.repo.GetByID(id)
    if err != nil {
        return nil, err
    }

    updates := make(map[string]interface{}, 3)
    if req.FirstName != "" {
        updates["first_name"] = req.FirstName
    }
    if req.LastName != "" {
        updates["last_name"] = req.LastName
    }
    if req.Avatar != "" {
        updates["avatar"] = req.Avatar
    }

    if len(updates) == 0 {
        return user, nil
    }

    if err := s.repo.UpdateFields(id, updates); err != nil {
        return nil, fmt.Errorf("failed to update user: %w", err)
    }

    // Reload so the response reflects DB-maintained columns such as updated_at.
    return s.repo.GetByID(id)
}

// DeleteUser soft-deletes user id. It looks the user up first so a missing
// ID surfaces GetByID's error rather than Delete's.
func (s *UserService) DeleteUser(id uint) error {
    _, lookupErr := s.repo.GetByID(id)
    if lookupErr != nil {
        return lookupErr
    }
    return s.repo.Delete(id)
}

// Request and response types.

// CreateUserRequest is the payload for creating a user. The validate tags
// presumably target a struct validator such as go-playground/validator;
// nothing in this file evaluates them.
type CreateUserRequest struct {
    Username  string `json:"username" validate:"required,min=3,max=50"`
    Email     string `json:"email" validate:"required,email"`
    Password  string `json:"password" validate:"required,min=8"`
    FirstName string `json:"first_name" validate:"required"`
    LastName  string `json:"last_name" validate:"required"`
}

// UpdateUserRequest is the payload for UpdateUser. An empty string means
// "leave this field unchanged".
type UpdateUserRequest struct {
    FirstName string `json:"first_name"`
    LastName  string `json:"last_name"`
    Avatar    string `json:"avatar"`
}

// UserListResponse is one page of users plus its pagination metadata.
type UserListResponse struct {
    Users      []User         `json:"users"`
    Pagination PaginationInfo `json:"pagination"`
}

// PaginationInfo describes a page: its 1-based number, its size, the total
// number of matching rows, and the resulting page count.
type PaginationInfo struct {
    Page      int   `json:"page"`
    PageSize  int   `json:"page_size"`
    Total     int64 `json:"total"`
    TotalPage int64 `json:"total_page"`
}

// hashPassword is a PLACEHOLDER that returns password unchanged, so callers
// currently persist plaintext.
//
// TODO(security): replace with a real password-hashing function (e.g. bcrypt
// or argon2) before this code handles real credentials.
func hashPassword(password string) string {
    hashed := password // no hashing is performed yet
    return hashed
}

通过本节的学习,你已经掌握了 GORM 框架的基础使用方法,包括模型定义、关联关系以及基础的 CRUD 操作。这些知识为构建数据访问层提供了坚实的基础。在下一节中,我们将学习 GORM 的高级查询功能。