5.4.3 客户端负载均衡

5.4.3 客户端负载均衡 #

客户端负载均衡是一种将负载均衡逻辑嵌入到客户端应用程序中的架构模式。与传统的服务端负载均衡不同，客户端负载均衡让客户端直接负责选择合适的服务实例，从而减少了网络跳转，提高了性能，并增强了系统的弹性。

客户端负载均衡原理 #

架构对比 #

传统服务端负载均衡

Client → Load Balancer → Service Instance

客户端负载均衡

Client (with LB logic) → Service Instance

优势与挑战 #

优势

  • 减少网络延迟(少一跳)
  • 避免负载均衡器单点故障
  • 更好的性能和可扩展性
  • 客户端可以实现更智能的路由策略

挑战

  • 客户端复杂性增加
  • 服务发现和健康检查的复杂性
  • 多语言环境下的一致性问题
  • 配置管理的复杂性

客户端负载均衡器设计 #

核心组件架构 #

// client/load_balancer.go
package client

import (
    "context"
    "fmt"
    "sync"
    "time"

    "github.com/example/client-lb/balancer"
    "github.com/example/client-lb/discovery"
    "github.com/example/client-lb/health"
)

// ClientLoadBalancer is a client-side load balancer for a single named
// service. It combines a service-discovery client, a balancing strategy and a
// health checker, and keeps its own map of the instances it currently knows.
type ClientLoadBalancer struct {
    serviceName string

    // Core components.
    discoveryClient discovery.ServiceDiscovery
    balancer        balancer.LoadBalancer
    healthChecker   health.HealthChecker

    // Instance bookkeeping: instances is keyed by instance ID and is guarded
    // by mutex.
    instances map[string]*balancer.Instance
    mutex     sync.RWMutex

    // Configuration supplied to NewClientLoadBalancer.
    config ClientLBConfig

    // Lifecycle control: cancelling ctx makes the background discovery and
    // health-check loops return.
    ctx    context.Context
    cancel context.CancelFunc

    // Request statistics maintained by Call.
    stats *ClientLBStats
}

// ClientLBConfig holds the settings for a ClientLoadBalancer.
type ClientLBConfig struct {
    // ServiceName is the name handed to the discovery client.
    ServiceName           string
    // BalancerType selects the strategy created by the balancer factory.
    BalancerType         balancer.BalancerType
    // HealthCheckInterval is the period of the background health-check ticker.
    HealthCheckInterval  time.Duration
    // DiscoveryInterval is the period of the background discovery ticker.
    DiscoveryInterval    time.Duration
    // MaxRetries is the number of additional attempts Call makes after the
    // first one fails.
    MaxRetries           int
    // RetryTimeout is the timeout applied to each individual attempt.
    RetryTimeout         time.Duration
    // CircuitBreakerConfig carries circuit-breaker settings.
    // NOTE(review): no code visible in this file reads it — confirm where it
    // is consumed.
    CircuitBreakerConfig CircuitBreakerConfig
}

// CircuitBreakerConfig describes circuit-breaker settings carried inside
// ClientLBConfig.
// NOTE(review): none of these fields is read by the code visible in this file;
// the meanings below are taken from the field names only — confirm.
type CircuitBreakerConfig struct {
    // Enabled presumably switches the breaker on or off.
    Enabled              bool
    // FailureThreshold presumably is the failure count that opens the breaker.
    FailureThreshold     int
    // RecoveryTimeout presumably is how long the breaker stays open.
    RecoveryTimeout      time.Duration
    // HalfOpenMaxRequests presumably caps probe requests while half-open.
    HalfOpenMaxRequests  int
}

// ClientLBStats accumulates request counters for a ClientLoadBalancer.
// All fields are guarded by mutex.
type ClientLBStats struct {
    // TotalRequests is incremented once per Call, whatever the outcome.
    TotalRequests    int64
    // SuccessRequests counts calls where some attempt succeeded.
    SuccessRequests  int64
    // FailedRequests counts calls where every attempt failed.
    FailedRequests   int64
    // RetryRequests counts individual retry attempts (attempts after the first).
    RetryRequests    int64
    // AvgResponseTime is the response-time figure maintained by Call.
    AvgResponseTime  time.Duration

    mutex sync.RWMutex
}

// NewClientLoadBalancer builds a ClientLoadBalancer for config.ServiceName and
// starts its background discovery and health-check goroutines. Call Stop to
// terminate them.
//
// It returns an error when discoveryClient is nil, when either
// config.DiscoveryInterval or config.HealthCheckInterval is not positive, or
// when the balancer factory cannot create the requested balancer type.
func NewClientLoadBalancer(
    discoveryClient discovery.ServiceDiscovery,
    config ClientLBConfig,
) (*ClientLoadBalancer, error) {
    // The background goroutines dereference the discovery client and feed the
    // intervals to time.NewTicker, which panics on a non-positive duration.
    // A panic there would crash the process where the caller cannot recover
    // it, so reject bad input here instead.
    if discoveryClient == nil {
        return nil, fmt.Errorf("discovery client must not be nil")
    }
    if config.DiscoveryInterval <= 0 {
        return nil, fmt.Errorf("invalid discovery interval %v: must be positive", config.DiscoveryInterval)
    }
    if config.HealthCheckInterval <= 0 {
        return nil, fmt.Errorf("invalid health check interval %v: must be positive", config.HealthCheckInterval)
    }

    // Create the balancing strategy.
    factory := balancer.NewBalancerFactory()
    lb, err := factory.CreateBalancer(balancer.BalancerConfig{
        Type: config.BalancerType,
    })
    if err != nil {
        // %w keeps the cause available to errors.Is / errors.As.
        return nil, fmt.Errorf("failed to create balancer: %w", err)
    }

    // Create the health checker.
    healthChecker := health.NewHTTPHealthChecker(time.Second * 5)

    ctx, cancel := context.WithCancel(context.Background())

    clb := &ClientLoadBalancer{
        serviceName:     config.ServiceName,
        discoveryClient: discoveryClient,
        balancer:        lb,
        healthChecker:   healthChecker,
        instances:       make(map[string]*balancer.Instance),
        config:          config,
        ctx:             ctx,
        cancel:          cancel,
        stats:           &ClientLBStats{},
    }

    // Start the background tasks; both stop when ctx is cancelled.
    go clb.startDiscovery()
    go clb.startHealthCheck()

    return clb, nil
}

// Select asks the configured balancing strategy for the instance that should
// serve the next request and returns its answer unchanged.
func (c *ClientLoadBalancer) Select(ctx context.Context) (*balancer.Instance, error) {
    chosen, err := c.balancer.Select(ctx)
    return chosen, err
}

// Call sends request to an instance chosen by the balancer, retrying on
// failure up to config.MaxRetries additional times with a linearly growing
// delay (100ms, 200ms, ...).
//
// A negative MaxRetries is treated as zero, so at least one attempt is always
// made. When request.Body supports Seek it is rewound to its starting offset
// before every retry; a body that cannot seek is passed on as-is and may
// already be consumed on retries.
//
// Every call updates the statistics: TotalRequests and AvgResponseTime always,
// RetryRequests once per retry, and exactly one of SuccessRequests or
// FailedRequests.
//
// It returns ctx.Err() when ctx is cancelled while waiting to retry, and
// otherwise an error wrapping the last attempt's failure.
func (c *ClientLoadBalancer) Call(ctx context.Context, request CallRequest) (*CallResponse, error) {
    // Whence values for Seek. They mirror io.SeekStart and io.SeekCurrent and
    // are spelled out because this file does not import io.
    const (
        seekStart   = 0
        seekCurrent = 1
    )
    type seeker interface {
        Seek(offset int64, whence int) (int64, error)
    }

    startTime := time.Now()

    defer func() {
        elapsed := time.Since(startTime)

        c.stats.mutex.Lock()
        defer c.stats.mutex.Unlock()

        c.stats.TotalRequests++
        // Incremental arithmetic mean over all calls: avg += (x - avg) / n.
        // Averaging only the previous mean with the latest sample would give
        // the newest call half of the total weight.
        c.stats.AvgResponseTime += (elapsed - c.stats.AvgResponseTime) / time.Duration(c.stats.TotalRequests)
    }()

    // bump increments one statistics counter under the stats lock.
    bump := func(counter *int64) {
        c.stats.mutex.Lock()
        *counter++
        c.stats.mutex.Unlock()
    }

    // Remember where the body starts so that a retry can replay it.
    var (
        rewind    seeker
        bodyStart int64
    )
    if s, ok := request.Body.(seeker); ok {
        if pos, err := s.Seek(0, seekCurrent); err == nil {
            rewind, bodyStart = s, pos
        }
    }

    maxRetries := c.config.MaxRetries
    if maxRetries < 0 {
        maxRetries = 0
    }

    var lastErr error

    for attempt := 0; attempt <= maxRetries; attempt++ {
        if attempt > 0 {
            bump(&c.stats.RetryRequests)

            // Back off before retrying, but give up as soon as ctx is done.
            select {
            case <-ctx.Done():
                bump(&c.stats.FailedRequests)
                return nil, ctx.Err()
            case <-time.After(time.Millisecond * time.Duration(attempt*100)):
            }

            if rewind != nil {
                if _, err := rewind.Seek(bodyStart, seekStart); err != nil {
                    bump(&c.stats.FailedRequests)
                    return nil, fmt.Errorf("failed to rewind request body for retry: %w", err)
                }
            }
        }

        // Pick an instance.
        instance, err := c.Select(ctx)
        if err != nil {
            lastErr = err
            continue
        }

        // Perform the call and report the outcome to the balancer.
        response, err := c.executeCall(ctx, instance, request)
        if err != nil {
            lastErr = err
            c.balancer.MarkFailure(instance.ID)
            continue
        }

        c.balancer.MarkSuccess(instance.ID)
        bump(&c.stats.SuccessRequests)

        return response, nil
    }

    // Every attempt failed.
    bump(&c.stats.FailedRequests)

    return nil, fmt.Errorf("all retries failed, last error: %w", lastErr)
}

// executeCall performs a single attempt of request against instance. The
// attempt runs under a context derived from ctx with config.RetryTimeout as
// its timeout, and targets http://<address>:<port><path>.
func (c *ClientLoadBalancer) executeCall(
    ctx context.Context,
    instance *balancer.Instance,
    request CallRequest,
) (*CallResponse, error) {
    attemptCtx, release := context.WithTimeout(ctx, c.config.RetryTimeout)
    defer release()

    target := fmt.Sprintf("http://%s:%d%s", instance.Address, instance.Port, request.Path)

    // Simplified HTTP execution.
    result, callErr := c.makeHTTPRequest(attemptCtx, request.Method, target, request.Body, request.Headers)
    if callErr == nil {
        return result, nil
    }
    return nil, callErr
}

// startDiscovery refreshes the instance list once immediately and then on
// every tick of config.DiscoveryInterval until the balancer's context is
// cancelled. It is meant to run in its own goroutine.
func (c *ClientLoadBalancer) startDiscovery() {
    refresh := time.NewTicker(c.config.DiscoveryInterval)
    defer refresh.Stop()

    // Do not wait a full interval for the first instance list.
    c.updateInstances()

    stopped := c.ctx.Done()
    for {
        select {
        case <-refresh.C:
            c.updateInstances()
        case <-stopped:
            return
        }
    }
}

// updateInstances queries service discovery and replaces both the local
// instance map and the balancer's instance list with the result. Instances
// whose ID was already known keep their Active flag, connection count, error
// count and response time. A discovery error leaves the current state as is.
func (c *ClientLoadBalancer) updateInstances() {
    discovered, err := c.discoveryClient.Discover(c.ctx, c.serviceName)
    if err != nil {
        // Keep serving with the previous list rather than failing.
        return
    }

    // carryOver copies the runtime state of a previously known instance with
    // the same ID into fresh, if there is one.
    carryOver := func(fresh *balancer.Instance) {
        c.mutex.RLock()
        defer c.mutex.RUnlock()

        previous, known := c.instances[fresh.ID]
        if !known {
            return
        }
        fresh.Active = previous.Active
        fresh.Connections = previous.Connections
        fresh.ErrorCount = previous.ErrorCount
        fresh.ResponseTime = previous.ResponseTime
    }

    ordered := make([]*balancer.Instance, 0, len(discovered))
    byID := make(map[string]*balancer.Instance)

    for _, found := range discovered {
        entry := &balancer.Instance{
            ID:       found.ID,
            Address:  found.Address,
            Port:     found.Port,
            Weight:   1, // default weight
            Active:   true,
            Metadata: found.Metadata,
        }
        carryOver(entry)

        ordered = append(ordered, entry)
        byID[found.ID] = entry
    }

    // Swap in the new map.
    c.mutex.Lock()
    c.instances = byID
    c.mutex.Unlock()

    // Hand the same instances to the balancing strategy.
    c.balancer.UpdateInstances(ordered)
}

// startHealthCheck runs a health-check round on every tick of
// config.HealthCheckInterval until the balancer's context is cancelled. It is
// meant to run in its own goroutine.
func (c *ClientLoadBalancer) startHealthCheck() {
    probe := time.NewTicker(c.config.HealthCheckInterval)
    defer probe.Stop()

    stopped := c.ctx.Done()
    for {
        select {
        case <-probe.C:
            c.performHealthCheck()
        case <-stopped:
            return
        }
    }
}

// performHealthCheck checks every currently known instance concurrently and
// returns once all checks have finished.
func (c *ClientLoadBalancer) performHealthCheck() {
    // Snapshot the instances so the lock is not held while probing.
    c.mutex.RLock()
    snapshot := make([]*balancer.Instance, 0, len(c.instances))
    for _, known := range c.instances {
        snapshot = append(snapshot, known)
    }
    c.mutex.RUnlock()

    var pending sync.WaitGroup
    pending.Add(len(snapshot))
    for i := range snapshot {
        target := snapshot[i]
        go func() {
            defer pending.Done()
            c.checkInstanceHealth(target)
        }()
    }
    pending.Wait()
}

// checkInstanceHealth probes one instance with a 5-second timeout and records
// the outcome on it: a failed check marks it inactive and increments its
// error count, a successful check marks it active and resets the count. The
// probe path comes from the "health_check_path" metadata entry and defaults
// to "/health".
func (c *ClientLoadBalancer) checkInstanceHealth(instance *balancer.Instance) {
    probeCtx, release := context.WithTimeout(c.ctx, time.Second*5)
    defer release()

    path := instance.Metadata["health_check_path"]
    if path == "" {
        path = "/health"
    }

    healthy := c.healthChecker.Check(probeCtx, instance.Address, instance.Port, path) == nil

    c.mutex.Lock()
    defer c.mutex.Unlock()

    instance.Active = healthy
    if healthy {
        instance.ErrorCount = 0
        return
    }
    instance.ErrorCount++
}

// Stop shuts the load balancer down by cancelling its internal context, which
// makes the background discovery and health-check loops return.
func (c *ClientLoadBalancer) Stop() {
    c.cancel()
}

// GetStats returns a snapshot of the request counters, the instance counts
// and the balancer's own statistics. "success_rate" is a percentage and
// "avg_response_time" is in milliseconds.
func (c *ClientLoadBalancer) GetStats() map[string]interface{} {
    c.stats.mutex.RLock()
    defer c.stats.mutex.RUnlock()

    // Read everything derived from c.instances while holding c.mutex. Taking
    // len(c.instances) after the unlock would race with updateInstances,
    // which replaces the map.
    c.mutex.RLock()
    totalInstances := len(c.instances)
    activeInstances := 0
    for _, instance := range c.instances {
        if instance.Active {
            activeInstances++
        }
    }
    c.mutex.RUnlock()

    successRate := float64(0)
    if c.stats.TotalRequests > 0 {
        successRate = float64(c.stats.SuccessRequests) / float64(c.stats.TotalRequests) * 100
    }

    return map[string]interface{}{
        "service_name":       c.serviceName,
        "total_instances":    totalInstances,
        "active_instances":   activeInstances,
        "total_requests":     c.stats.TotalRequests,
        "success_requests":   c.stats.SuccessRequests,
        "failed_requests":    c.stats.FailedRequests,
        "retry_requests":     c.stats.RetryRequests,
        "success_rate":       successRate,
        "avg_response_time":  c.stats.AvgResponseTime.Milliseconds(),
        "balancer_stats":     c.balancer.GetStats(),
    }
}

请求和响应定义 #

// client/types.go
package client

import (
    "context"
    "io"
    "net/http"
    "time"
)

// CallRequest describes one logical request to be sent through the load
// balancer.
type CallRequest struct {
    // Method is the HTTP method.
    Method  string
    // Path is appended to http://<address>:<port> of the chosen instance.
    Path    string
    // Body is the request payload; nil means no body.
    Body    io.Reader
    // Headers are set on the outgoing request, one value per name.
    Headers map[string]string
    // Timeout is a per-request timeout.
    // NOTE(review): no code visible in this file reads it — confirm.
    Timeout time.Duration
}

// CallResponse is the fully read result of one HTTP attempt.
type CallResponse struct {
    // StatusCode is the HTTP status code of the response.
    StatusCode int
    // Headers holds the response headers.
    Headers    map[string][]string
    // Body is the complete response body.
    Body       []byte
    // Duration is the time from building the request to finishing the read
    // of the body.
    Duration   time.Duration
}

// makeHTTPRequest sends one HTTP request bound to ctx and reads the whole
// response body into memory. It returns an error when the request cannot be
// built, the round trip fails, or the body cannot be read; an HTTP error
// status is not treated as an error.
func (c *ClientLoadBalancer) makeHTTPRequest(
    ctx context.Context,
    method, url string,
    body io.Reader,
    headers map[string]string,
) (*CallResponse, error) {
    began := time.Now()

    httpReq, err := http.NewRequestWithContext(ctx, method, url, body)
    if err != nil {
        return nil, err
    }
    for name, value := range headers {
        httpReq.Header.Set(name, value)
    }

    httpResp, err := (&http.Client{}).Do(httpReq)
    if err != nil {
        return nil, err
    }
    defer httpResp.Body.Close()

    payload, err := io.ReadAll(httpResp.Body)
    if err != nil {
        return nil, err
    }

    result := &CallResponse{
        StatusCode: httpResp.StatusCode,
        Headers:    httpResp.Header,
        Body:       payload,
    }
    result.Duration = time.Since(began)
    return result, nil
}

熔断器实现 #

熔断器状态机 #

// circuit/breaker.go
package circuit

import (
    "context"
    "errors"
    "sync"
    "time"
)

var (
    // ErrCircuitOpen is returned when a request is rejected because the
    // breaker is in the open state.
    ErrCircuitOpen     = errors.New("circuit breaker is open")
    // ErrTooManyRequests is returned when a request is rejected because the
    // breaker is half-open and already admitted Config.MaxRequests requests.
    ErrTooManyRequests = errors.New("too many requests")
)

// State identifies which of its three states a circuit breaker is in.
type State int

// The breaker states, numbered from zero in this order.
const (
    StateClosed State = iota
    StateHalfOpen
    StateOpen
)

// String returns the lower-case name of the state, or "unknown" for a value
// that is not one of the declared states.
func (s State) String() string {
    names := [...]string{
        StateClosed:   "closed",
        StateHalfOpen: "half-open",
        StateOpen:     "open",
    }
    if s < 0 || int(s) >= len(names) {
        return "unknown"
    }
    return names[s]
}

// CircuitBreaker is a closed / open / half-open state machine that decides
// whether a request may run, based on the outcomes of earlier requests.
type CircuitBreaker struct {
    config Config

    // State management. generation is incremented whenever the counters are
    // reset, so results of requests started in an older generation can be
    // ignored. expiry is when the current closed-state counting window or
    // open-state timeout ends; the zero time means "no expiry".
    state         State
    generation    uint64
    expiry        time.Time

    // Request counters for the current generation.
    counts        Counts

    mutex sync.RWMutex
}

// Config holds the settings of a CircuitBreaker.
type Config struct {
    MaxRequests         uint32        // maximum number of requests admitted while half-open
    Interval            time.Duration // length of the closed-state counting window
    Timeout             time.Duration // how long the breaker stays open before going half-open
    ReadyToTrip         func(counts Counts) bool // reports whether the breaker should open
    OnStateChange       func(name string, from State, to State) // called on every state change
    IsSuccessful        func(err error) bool // reports whether a request result counts as success
}

// Counts holds the request counters of the current generation. All of them
// are reset to zero by clear.
type Counts struct {
    // Requests is the number of admitted requests.
    Requests             uint32
    // TotalSuccesses is the number of requests recorded as successful.
    TotalSuccesses       uint32
    // TotalFailures is the number of requests recorded as failed.
    TotalFailures        uint32
    // ConsecutiveSuccesses is the length of the current run of successes; a
    // failure resets it.
    ConsecutiveSuccesses uint32
    // ConsecutiveFailures is the length of the current run of failures; a
    // success resets it.
    ConsecutiveFailures  uint32
}

// NewCircuitBreaker returns a breaker in the closed state with its first
// counting window already started.
//
// Defaults applied to config:
//   - MaxRequests 0 becomes 1. With 0, every half-open request would be
//     rejected and the breaker could never close again.
//   - A non-positive Interval or Timeout becomes one minute.
//   - A nil ReadyToTrip trips after at least 5 requests with at least 3
//     failures.
//   - A nil IsSuccessful treats a nil error as success.
func NewCircuitBreaker(config Config) *CircuitBreaker {
    cb := &CircuitBreaker{
        config: config,
        state:  StateClosed,
    }

    if cb.config.MaxRequests == 0 {
        cb.config.MaxRequests = 1
    }
    if cb.config.Interval <= 0 {
        cb.config.Interval = time.Minute
    }
    if cb.config.Timeout <= 0 {
        cb.config.Timeout = time.Minute
    }
    if cb.config.ReadyToTrip == nil {
        cb.config.ReadyToTrip = func(counts Counts) bool {
            return counts.Requests >= 5 && counts.TotalFailures >= 3
        }
    }
    if cb.config.IsSuccessful == nil {
        cb.config.IsSuccessful = func(err error) bool {
            return err == nil
        }
    }

    // Start the first generation so the closed-state window has an expiry.
    // Otherwise expiry stays zero and the counters are never reset until the
    // breaker trips for the first time.
    cb.toNewGeneration(time.Now())

    return cb
}

// Execute runs fn if the breaker admits the request and records the outcome.
//
// It returns ErrCircuitOpen or ErrTooManyRequests without calling fn when the
// request is rejected; otherwise it returns fn's error unchanged. If fn
// panics, the request is recorded as a failure and the panic is re-raised.
func (cb *CircuitBreaker) Execute(fn func() error) error {
    generation, err := cb.beforeRequest()
    if err != nil {
        return err
    }

    // Without this, a panicking fn would leave err nil and the request would
    // be counted as a success.
    defer func() {
        if r := recover(); r != nil {
            cb.afterRequest(generation, false)
            panic(r)
        }
    }()

    err = fn()
    cb.afterRequest(generation, cb.config.IsSuccessful(err))
    return err
}

// beforeRequest decides whether a request may proceed. It returns the current
// generation together with ErrCircuitOpen when the breaker is open, or
// ErrTooManyRequests when it is half-open and MaxRequests requests were
// already admitted. Otherwise it counts the request and returns a nil error.
func (cb *CircuitBreaker) beforeRequest() (uint64, error) {
    cb.mutex.Lock()
    defer cb.mutex.Unlock()

    state, generation := cb.currentState(time.Now())

    switch {
    case state == StateOpen:
        return generation, ErrCircuitOpen
    case state == StateHalfOpen && cb.counts.Requests >= cb.config.MaxRequests:
        return generation, ErrTooManyRequests
    }

    cb.counts.onRequest()
    return generation, nil
}

// afterRequest records the outcome of a request that was admitted in
// generation before. The outcome is dropped if the breaker has moved to
// another generation in the meantime.
func (cb *CircuitBreaker) afterRequest(before uint64, success bool) {
    cb.mutex.Lock()
    defer cb.mutex.Unlock()

    now := time.Now()
    state, current := cb.currentState(now)
    if current != before {
        // The result belongs to counters that were already reset.
        return
    }

    if !success {
        cb.onFailure(state, now)
        return
    }
    cb.onSuccess(state, now)
}

// currentState applies any time-based transition that is due at now and then
// returns the state and generation. A closed breaker whose window has expired
// starts a new generation; an open breaker whose timeout has expired becomes
// half-open. Because it may modify the breaker, callers must hold the lock.
func (cb *CircuitBreaker) currentState(now time.Time) (State, uint64) {
    expired := cb.expiry.Before(now)

    switch {
    case cb.state == StateClosed && expired && !cb.expiry.IsZero():
        cb.toNewGeneration(now)
    case cb.state == StateOpen && expired:
        cb.setState(StateHalfOpen, now)
    }

    return cb.state, cb.generation
}

// onSuccess records a successful request. In the half-open state the breaker
// closes once the run of consecutive successes reaches MaxRequests. In the
// open state nothing is recorded.
func (cb *CircuitBreaker) onSuccess(state State, now time.Time) {
    if state != StateClosed && state != StateHalfOpen {
        return
    }

    cb.counts.onSuccess()

    if state == StateHalfOpen && cb.counts.ConsecutiveSuccesses >= cb.config.MaxRequests {
        cb.setState(StateClosed, now)
    }
}

// onFailure records a failed request. A failure while half-open re-opens the
// breaker immediately; a failure while closed is counted and opens the
// breaker if ReadyToTrip says so. In the open state nothing is recorded.
func (cb *CircuitBreaker) onFailure(state State, now time.Time) {
    if state == StateHalfOpen {
        cb.setState(StateOpen, now)
        return
    }
    if state != StateClosed {
        return
    }

    cb.counts.onFailure()
    if cb.config.ReadyToTrip(cb.counts) {
        cb.setState(StateOpen, now)
    }
}

// setState moves the breaker to state, starts a new generation and invokes
// the OnStateChange callback if one is configured. It does nothing when the
// breaker is already in state.
func (cb *CircuitBreaker) setState(state State, now time.Time) {
    previous := cb.state
    if previous == state {
        return
    }

    cb.state = state
    cb.toNewGeneration(now)

    if notify := cb.config.OnStateChange; notify != nil {
        notify("circuit-breaker", previous, state)
    }
}

// toNewGeneration starts a new counting period: it bumps the generation,
// clears the counters and recomputes expiry for the current state. Closed
// expires after Interval (never, if Interval is zero), open expires after
// Timeout, and half-open has no expiry.
func (cb *CircuitBreaker) toNewGeneration(now time.Time) {
    cb.generation++
    cb.counts.clear()

    // Start from "no expiry" and override it where the state has one.
    cb.expiry = time.Time{}
    switch {
    case cb.state == StateClosed && cb.config.Interval != 0:
        cb.expiry = now.Add(cb.config.Interval)
    case cb.state == StateOpen:
        cb.expiry = now.Add(cb.config.Timeout)
    }
}

// State returns the breaker's state as of now, after applying any time-based
// transition that is due (for example open to half-open once the timeout has
// passed).
func (cb *CircuitBreaker) State() State {
    // currentState may change the state, generation, counters and expiry, so
    // the write lock is required. A read lock would let concurrent callers
    // mutate the breaker at the same time.
    cb.mutex.Lock()
    defer cb.mutex.Unlock()

    state, _ := cb.currentState(time.Now())
    return state
}

// Counts returns a copy of the counters of the current generation.
func (cb *CircuitBreaker) Counts() Counts {
    cb.mutex.RLock()
    snapshot := cb.counts
    cb.mutex.RUnlock()

    return snapshot
}

// Counter methods.

// onRequest counts one admitted request.
func (c *Counts) onRequest() {
    c.Requests += 1
}

// onSuccess counts one success and ends any run of consecutive failures.
func (c *Counts) onSuccess() {
    c.ConsecutiveFailures = 0
    c.ConsecutiveSuccesses += 1
    c.TotalSuccesses += 1
}

// onFailure counts one failure and ends any run of consecutive successes.
func (c *Counts) onFailure() {
    c.ConsecutiveSuccesses = 0
    c.ConsecutiveFailures += 1
    c.TotalFailures += 1
}

// clear resets every counter to zero.
func (c *Counts) clear() {
    *c = Counts{}
}

HTTP 客户端实现 #

带负载均衡的 HTTP 客户端 #

// http/client.go
package http

import (
    "bytes"
    "context"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
    "time"

    "github.com/example/client-lb/client"
    "github.com/example/client-lb/circuit"
)

// LBHTTPClient is an HTTP client that routes every request through a
// client-side load balancer, guarded by a circuit breaker.
type LBHTTPClient struct {
    // loadBalancer selects the instance and performs the call.
    loadBalancer   *client.ClientLoadBalancer
    // circuitBreaker wraps every call made through Request.
    circuitBreaker *circuit.CircuitBreaker
    // httpClient is created with the timeout given to NewLBHTTPClient.
    // NOTE(review): no method visible in this file sends requests through
    // it — confirm whether it is still needed.
    httpClient     *http.Client
}

// NewLBHTTPClient returns an LBHTTPClient that sends requests through
// loadBalancer. Its circuit breaker uses a one-minute counting window and a
// one-minute open timeout, admits 3 requests while half-open, and trips once
// at least 3 requests were made and at least 60% of them failed. timeout
// configures the embedded http.Client.
func NewLBHTTPClient(
    loadBalancer *client.ClientLoadBalancer,
    timeout time.Duration,
) *LBHTTPClient {
    tripPolicy := func(counts circuit.Counts) bool {
        if counts.Requests < 3 {
            return false
        }
        failureRatio := float64(counts.TotalFailures) / float64(counts.Requests)
        return failureRatio >= 0.6
    }

    breaker := circuit.NewCircuitBreaker(circuit.Config{
        MaxRequests: 3,
        Interval:    time.Minute,
        Timeout:     time.Minute,
        ReadyToTrip: tripPolicy,
    })

    lbClient := &LBHTTPClient{
        loadBalancer:   loadBalancer,
        circuitBreaker: breaker,
    }
    lbClient.httpClient = &http.Client{Timeout: timeout}
    return lbClient
}

// Get sends a GET request for path with the given headers and no body.
func (c *LBHTTPClient) Get(ctx context.Context, path string, headers map[string]string) (*client.CallResponse, error) {
    return c.Request(ctx, http.MethodGet, path, nil, headers)
}

// Post sends a POST request for path. A non-nil body is marshalled to JSON
// and sent with "Content-Type: application/json"; a nil body sends no payload
// and leaves the headers as given. The caller's headers map is never modified.
//
// It returns an error wrapping the marshalling failure when body cannot be
// encoded, and otherwise whatever Request returns.
func (c *LBHTTPClient) Post(ctx context.Context, path string, body interface{}, headers map[string]string) (*client.CallResponse, error) {
    if body == nil {
        return c.Request(ctx, "POST", path, nil, headers)
    }

    jsonData, err := json.Marshal(body)
    if err != nil {
        return nil, fmt.Errorf("failed to marshal request body: %w", err)
    }

    // Work on a copy. Writing Content-Type into the caller's map would leak
    // into the caller's later requests, and would be a data race if the map
    // is shared between goroutines.
    merged := make(map[string]string, len(headers)+1)
    for name, value := range headers {
        merged[name] = value
    }
    merged["Content-Type"] = "application/json"

    return c.Request(ctx, "POST", path, bytes.NewReader(jsonData), merged)
}

// Put sends a PUT request for path. A non-nil body is marshalled to JSON and
// sent with "Content-Type: application/json"; a nil body sends no payload and
// leaves the headers as given. The caller's headers map is never modified.
//
// It returns an error wrapping the marshalling failure when body cannot be
// encoded, and otherwise whatever Request returns.
func (c *LBHTTPClient) Put(ctx context.Context, path string, body interface{}, headers map[string]string) (*client.CallResponse, error) {
    if body == nil {
        return c.Request(ctx, "PUT", path, nil, headers)
    }

    jsonData, err := json.Marshal(body)
    if err != nil {
        return nil, fmt.Errorf("failed to marshal request body: %w", err)
    }

    // Work on a copy. Writing Content-Type into the caller's map would leak
    // into the caller's later requests, and would be a data race if the map
    // is shared between goroutines.
    merged := make(map[string]string, len(headers)+1)
    for name, value := range headers {
        merged[name] = value
    }
    merged["Content-Type"] = "application/json"

    return c.Request(ctx, "PUT", path, bytes.NewReader(jsonData), merged)
}

// Delete sends a DELETE request for path with the given headers and no body.
func (c *LBHTTPClient) Delete(ctx context.Context, path string, headers map[string]string) (*client.CallResponse, error) {
    return c.Request(ctx, http.MethodDelete, path, nil, headers)
}

// Request sends an arbitrary request through the load balancer under the
// protection of the circuit breaker. It returns the breaker's rejection error
// when the request is not admitted, the load balancer's error when the call
// fails, and the response otherwise.
func (c *LBHTTPClient) Request(
    ctx context.Context,
    method, path string,
    body io.Reader,
    headers map[string]string,
) (*client.CallResponse, error) {
    var response *client.CallResponse

    guarded := func() error {
        var callErr error
        response, callErr = c.loadBalancer.Call(ctx, client.CallRequest{
            Method:  method,
            Path:    path,
            Body:    body,
            Headers: headers,
        })
        return callErr
    }

    // Execute hands back either its own rejection or guarded's error.
    if err := c.circuitBreaker.Execute(guarded); err != nil {
        return nil, err
    }
    return response, nil
}

// GetStats returns the load balancer's statistics under "load_balancer" and
// the breaker's state and counters under "circuit_breaker".
func (c *LBHTTPClient) GetStats() map[string]interface{} {
    balancerStats := c.loadBalancer.GetStats()
    counts := c.circuitBreaker.Counts()

    breakerStats := map[string]interface{}{
        "state":                 c.circuitBreaker.State().String(),
        "requests":              counts.Requests,
        "total_successes":       counts.TotalSuccesses,
        "total_failures":        counts.TotalFailures,
        "consecutive_successes": counts.ConsecutiveSuccesses,
        "consecutive_failures":  counts.ConsecutiveFailures,
    }

    return map[string]interface{}{
        "load_balancer":   balancerStats,
        "circuit_breaker": breakerStats,
    }
}

gRPC 客户端实现 #

带负载均衡的 gRPC 客户端 #

// grpc/client.go
package grpc

import (
    "context"
    "fmt"
    "sync"
    "time"

    "google.golang.org/grpc"
    "google.golang.org/grpc/connectivity"
    "google.golang.org/grpc/credentials/insecure"

    "github.com/example/client-lb/balancer"
    "github.com/example/client-lb/discovery"
)

// LBGRPCClient hands out gRPC connections to instances of one service. The
// instance is chosen by a load balancer that is fed from service discovery.
type LBGRPCClient struct {
    serviceName     string
    discoveryClient discovery.ServiceDiscovery
    loadBalancer    balancer.LoadBalancer

    // Connection pool, keyed by "address:port" and guarded by connMutex.
    connections map[string]*grpc.ClientConn
    connMutex   sync.RWMutex

    // Dial options used for every new connection.
    dialOptions []grpc.DialOption

    // Lifecycle control: cancelling ctx stops discovery and the connection
    // monitors.
    ctx    context.Context
    cancel context.CancelFunc
}

// NewLBGRPCClient creates a client for serviceName and starts its background
// discovery goroutine. When no dial options are given, connections use
// insecure transport credentials and blocking dial.
func NewLBGRPCClient(
    serviceName string,
    discoveryClient discovery.ServiceDiscovery,
    loadBalancer balancer.LoadBalancer,
    dialOptions ...grpc.DialOption,
) *LBGRPCClient {
    ctx, cancel := context.WithCancel(context.Background())

    opts := dialOptions
    if len(opts) == 0 {
        // Default dial options.
        opts = []grpc.DialOption{
            grpc.WithTransportCredentials(insecure.NewCredentials()),
            grpc.WithBlock(),
        }
    }

    lbClient := &LBGRPCClient{
        serviceName:     serviceName,
        discoveryClient: discoveryClient,
        loadBalancer:    loadBalancer,
        connections:     make(map[string]*grpc.ClientConn),
        dialOptions:     opts,
        ctx:             ctx,
        cancel:          cancel,
    }

    go lbClient.startDiscovery()

    return lbClient
}

// GetConnection lets the load balancer pick an instance and returns a pooled
// or newly dialled connection to it.
func (c *LBGRPCClient) GetConnection(ctx context.Context) (*grpc.ClientConn, error) {
    chosen, err := c.loadBalancer.Select(ctx)
    if err == nil {
        return c.getOrCreateConnection(chosen)
    }
    return nil, err
}

// getOrCreateConnection returns the pooled connection for instance if there
// is one that is not shut down. Otherwise it dials a new one with a
// 10-second timeout, stores it in the pool and starts a goroutine that
// monitors it.
func (c *LBGRPCClient) getOrCreateConnection(instance *balancer.Instance) (*grpc.ClientConn, error) {
    target := fmt.Sprintf("%s:%d", instance.Address, instance.Port)

    // Fast path: look the connection up under the read lock.
    c.connMutex.RLock()
    cached, found := c.connections[target]
    reusable := found && cached.GetState() != connectivity.Shutdown
    c.connMutex.RUnlock()
    if reusable {
        return cached, nil
    }

    c.connMutex.Lock()
    defer c.connMutex.Unlock()

    // Check again: another goroutine may have dialled in the meantime.
    if existing, present := c.connections[target]; present {
        if existing.GetState() != connectivity.Shutdown {
            return existing, nil
        }
        // Dispose of the dead connection before replacing it.
        existing.Close()
    }

    dialCtx, release := context.WithTimeout(c.ctx, time.Second*10)
    defer release()

    fresh, err := grpc.DialContext(dialCtx, target, c.dialOptions...)
    if err != nil {
        return nil, fmt.Errorf("failed to dial %s: %v", target, err)
    }

    c.connections[target] = fresh
    go c.monitorConnection(target, fresh)

    return fresh, nil
}

// monitorConnection watches conn until it shuts down or the client's context
// is cancelled. On shutdown it removes conn from the pool, provided the pool
// entry for address still refers to this connection.
func (c *LBGRPCClient) monitorConnection(address string, conn *grpc.ClientConn) {
    // WaitForStateChange returns false once the context is done.
    for conn.WaitForStateChange(c.ctx, conn.GetState()) {
        if conn.GetState() != connectivity.Shutdown {
            continue
        }

        c.connMutex.Lock()
        if c.connections[address] == conn {
            delete(c.connections, address)
        }
        c.connMutex.Unlock()
        return
    }
}

// startDiscovery refreshes the instance list once immediately and then every
// 30 seconds until the client's context is cancelled. It is meant to run in
// its own goroutine.
func (c *LBGRPCClient) startDiscovery() {
    refresh := time.NewTicker(time.Second * 30)
    defer refresh.Stop()

    c.updateInstances()

    stopped := c.ctx.Done()
    for {
        select {
        case <-refresh.C:
            c.updateInstances()
        case <-stopped:
            return
        }
    }
}

// updateInstances queries service discovery and hands the result to the load
// balancer. Every instance is marked active with weight 1. A discovery error
// leaves the balancer's current list untouched.
func (c *LBGRPCClient) updateInstances() {
    discovered, err := c.discoveryClient.Discover(c.ctx, c.serviceName)
    if err != nil {
        return
    }

    converted := make([]*balancer.Instance, 0, len(discovered))
    for _, found := range discovered {
        entry := &balancer.Instance{
            ID:       found.ID,
            Address:  found.Address,
            Port:     found.Port,
            Weight:   1,
            Active:   true,
            Metadata: found.Metadata,
        }
        converted = append(converted, entry)
    }

    c.loadBalancer.UpdateInstances(converted)
}

// Close stops discovery and the connection monitors, closes every pooled
// connection and empties the pool. All connections are closed even if some
// of them fail. The returned error wraps the first close failure, or is nil
// when every connection closed cleanly.
func (c *LBGRPCClient) Close() error {
    c.cancel()

    c.connMutex.Lock()
    defer c.connMutex.Unlock()

    var firstErr error
    for address, conn := range c.connections {
        // Keep going after a failure so no connection is left open, but
        // report the first problem instead of discarding it.
        if err := conn.Close(); err != nil && firstErr == nil {
            firstErr = fmt.Errorf("failed to close connection to %s: %w", address, err)
        }
    }
    c.connections = make(map[string]*grpc.ClientConn)

    return firstErr
}

使用示例 #

HTTP 客户端使用示例 #

// example/http_client_example.go
package main

import (
    "context"
    "log"
    "time"

    "github.com/example/client-lb/balancer"
    "github.com/example/client-lb/client"
    "github.com/example/client-lb/discovery"
    "github.com/example/client-lb/http"
)

// main shows how to use the load-balanced HTTP client: build a client-side
// load balancer on top of a mock discovery client, wrap it in an
// LBHTTPClient, send a GET and a POST, and print the collected statistics.
func main() {
    // Service discovery client (a mock implementation is used here).
    discoveryClient := discovery.NewMockDiscovery()

    // Configure the client-side load balancer.
    config := client.ClientLBConfig{
        ServiceName:          "user-service",
        BalancerType:         balancer.WeightedRoundRobin,
        HealthCheckInterval:  time.Second * 10,
        DiscoveryInterval:    time.Second * 30,
        MaxRetries:           3,
        RetryTimeout:         time.Second * 5,
    }

    // Create the client-side load balancer.
    clb, err := client.NewClientLoadBalancer(discoveryClient, config)
    if err != nil {
        // Fatalf with an explicit verb: log.Fatal would print the error
        // glued to the colon with no space in between.
        log.Fatalf("Failed to create client load balancer: %v", err)
    }
    defer clb.Stop()

    // Create the HTTP client.
    httpClient := http.NewLBHTTPClient(clb, time.Second*10)

    ctx := context.Background()

    // GET request.
    response, err := httpClient.Get(ctx, "/users/123", nil)
    if err != nil {
        log.Printf("GET request failed: %v", err)
    } else {
        log.Printf("GET response: %s", string(response.Body))
    }

    // POST request.
    userData := map[string]interface{}{
        "name":  "John Doe",
        "email": "john.doe@example.com",
    }

    response, err = httpClient.Post(ctx, "/users", userData, nil)
    if err != nil {
        log.Printf("POST request failed: %v", err)
    } else {
        log.Printf("POST response: %s", string(response.Body))
    }

    // Print the statistics.
    stats := httpClient.GetStats()
    log.Printf("Client stats: %+v", stats)
}

gRPC 客户端使用示例 #

// example/grpc_client_example.go
package main

import (
    "context"
    "log"
    "time"

    "google.golang.org/grpc"

    "github.com/example/client-lb/balancer"
    "github.com/example/client-lb/discovery"
    grpcclient "github.com/example/client-lb/grpc"
    pb "github.com/example/proto/user"
)

// main shows how to use the load-balanced gRPC client: create a round-robin
// balancer, obtain a connection to a "user-service" instance and call two
// RPCs on it.
func main() {
    // Service discovery client.
    discoveryClient := discovery.NewMockDiscovery()

    // Create the load balancer.
    factory := balancer.NewBalancerFactory()
    lb, err := factory.CreateBalancer(balancer.BalancerConfig{
        Type: balancer.RoundRobin,
    })
    if err != nil {
        // Fatalf with an explicit verb: log.Fatal would print the error
        // glued to the colon with no space in between.
        log.Fatalf("Failed to create load balancer: %v", err)
    }

    // Create the gRPC client.
    grpcClient := grpcclient.NewLBGRPCClient(
        "user-service",
        discoveryClient,
        lb,
        grpc.WithInsecure(),
        grpc.WithBlock(),
    )
    defer grpcClient.Close()

    // Get a connection and create the service client.
    ctx := context.Background()
    conn, err := grpcClient.GetConnection(ctx)
    if err != nil {
        // log.Fatalf exits without running deferred calls, so release the
        // client explicitly first.
        grpcClient.Close()
        log.Fatalf("Failed to get connection: %v", err)
    }

    userClient := pb.NewUserServiceClient(conn)

    // Call the service.
    response, err := userClient.GetUser(ctx, &pb.GetUserRequest{
        UserId: "123",
    })
    if err != nil {
        log.Printf("gRPC call failed: %v", err)
    } else {
        log.Printf("User: %+v", response.User)
    }

    // Create a user.
    createResponse, err := userClient.CreateUser(ctx, &pb.CreateUserRequest{
        User: &pb.User{
            Name:  "Jane Doe",
            Email: "jane.doe@example.com",
        },
    })
    if err != nil {
        log.Printf("Create user failed: %v", err)
    } else {
        log.Printf("Created user ID: %s", createResponse.UserId)
    }
}

通过实现这个完整的客户端负载均衡系统，我们为微服务架构提供了高效、可靠的服务调用能力。系统包含了负载均衡、熔断保护、健康检查、重试机制等关键特性，能够有效提高系统的可用性和性能。