3.7.5 gRPC 拦截器与中间件

3.7.5 gRPC 拦截器与中间件 #

gRPC 拦截器(Interceptor)是一种强大的机制,允许在 RPC 调用的执行过程中插入自定义逻辑。它们类似于 HTTP 中间件,可以用于实现认证、日志记录、监控、限流等横切关注点(cross-cutting concerns),而无需修改每个服务方法的业务代码。

拦截器基础 #

拦截器类型 #

gRPC 按调用方式(一元 / 流式)和所在位置(服务端 / 客户端)两个维度划分,提供了四种类型的拦截器:

  1. 一元服务端拦截器(Unary Server Interceptor)
  2. 流式服务端拦截器(Stream Server Interceptor)
  3. 一元客户端拦截器(Unary Client Interceptor)
  4. 流式客户端拦截器(Stream Client Interceptor)

拦截器签名 #

// UnaryServerInterceptor (as declared in package grpc) runs on the server
// around every unary RPC. info.FullMethod holds the full method name, such as
// "/user.v1.UserService/GetUser". Calling handler continues the chain; returning
// without calling it short-circuits the call.
type UnaryServerInterceptor func(
    ctx context.Context,
    req interface{},
    info *UnaryServerInfo,
    handler UnaryHandler,
) (resp interface{}, err error)

// StreamServerInterceptor (as declared in package grpc) runs on the server
// around every streaming RPC. ss is the server side of the stream; wrapping it
// lets an interceptor observe or alter SendMsg/RecvMsg/Context. Calling handler
// continues the chain.
type StreamServerInterceptor func(
    srv interface{},
    ss ServerStream,
    info *StreamServerInfo,
    handler StreamHandler,
) error

// UnaryClientInterceptor (as declared in package grpc) runs on the client
// around every unary call. req is the request and reply the message the
// response is decoded into. invoker performs the call, or invokes the next
// interceptor in the chain.
type UnaryClientInterceptor func(
    ctx context.Context,
    method string,
    req, reply interface{},
    cc *ClientConn,
    invoker UnaryInvoker,
    opts ...CallOption,
) error

// StreamClientInterceptor (as declared in package grpc) runs on the client when
// a stream is opened. streamer creates the stream, or invokes the next
// interceptor. The returned ClientStream may be wrapped to observe messages.
type StreamClientInterceptor func(
    ctx context.Context,
    desc *StreamDesc,
    cc *ClientConn,
    method string,
    streamer Streamer,
    opts ...CallOption,
) (ClientStream, error)

服务端拦截器 #

1. 日志拦截器 #

// internal/interceptor/logging.go
package interceptor

import (
    "context"
    "log"
    "time"

    "google.golang.org/grpc"
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"
)

// LoggingUnaryInterceptor logs the start and the end of every unary RPC. The
// end line includes the elapsed time, the gRPC status code and the error.
// Errors that are not gRPC statuses are reported with code Unknown.
func LoggingUnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    const tsLayout = "2006-01-02 15:04:05"
    method := info.FullMethod
    begin := time.Now()

    log.Printf("[%s] Request started - Method: %s", time.Now().Format(tsLayout), method)

    resp, err := handler(ctx, req)
    elapsed := time.Since(begin)

    code := codes.OK
    if err != nil {
        code = codes.Unknown
        if st, ok := status.FromError(err); ok {
            code = st.Code()
        }
    }

    log.Printf("[%s] Request completed - Method: %s, Duration: %v, Code: %s, Error: %v",
        time.Now().Format(tsLayout), method, elapsed, code, err)

    return resp, err
}

// LoggingStreamInterceptor logs the start and the end of every streaming RPC.
// The start line includes the stream kind. The stream is wrapped in
// loggingServerStream so that individual messages are logged too.
func LoggingStreamInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    const tsLayout = "2006-01-02 15:04:05"
    method := info.FullMethod
    begin := time.Now()

    log.Printf("[%s] Stream started - Method: %s, Type: %s",
        time.Now().Format(tsLayout), method, getStreamType(info))

    err := handler(srv, &loggingServerStream{ServerStream: stream, method: method})
    elapsed := time.Since(begin)

    code := codes.OK
    if err != nil {
        code = codes.Unknown
        if st, ok := status.FromError(err); ok {
            code = st.Code()
        }
    }

    log.Printf("[%s] Stream completed - Method: %s, Duration: %v, Code: %s, Error: %v",
        time.Now().Format(tsLayout), method, elapsed, code, err)

    return err
}

// loggingServerStream wraps a grpc.ServerStream. It logs every message sent and
// every message successfully received, tagged with the RPC's full method name.
type loggingServerStream struct {
    grpc.ServerStream
    method string
}

// SendMsg logs the send, then forwards m to the wrapped stream.
func (s *loggingServerStream) SendMsg(m interface{}) error {
    log.Printf("[%s] Stream send - Method: %s", time.Now().Format("2006-01-02 15:04:05"), s.method)
    return s.ServerStream.SendMsg(m)
}

// RecvMsg reads the next message into m. Only successful receives are logged.
func (s *loggingServerStream) RecvMsg(m interface{}) error {
    if err := s.ServerStream.RecvMsg(m); err != nil {
        return err
    }
    log.Printf("[%s] Stream receive - Method: %s", time.Now().Format("2006-01-02 15:04:05"), s.method)
    return nil
}

// getStreamType names the streaming mode of an RPC based on which sides send
// message streams. It returns "unknown" when neither flag is set.
func getStreamType(info *grpc.StreamServerInfo) string {
    switch {
    case info.IsClientStream && info.IsServerStream:
        return "bidirectional"
    case info.IsClientStream:
        return "client-stream"
    case info.IsServerStream:
        return "server-stream"
    default:
        return "unknown"
    }
}

2. 认证拦截器 #

// internal/interceptor/auth.go
package interceptor

import (
    "context"
    "strings"

    "google.golang.org/grpc"
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/metadata"
    "google.golang.org/grpc/status"
)

// AuthInterceptor enforces bearer-token authentication on incoming RPCs.
// Methods listed in publicMethods (full method names such as
// "/user.v1.UserService/GetUser") skip the check. For authenticated calls the
// caller's identity is attached to the context seen by the handler; read it
// with UserIDFromContext / UsernameFromContext.
type AuthInterceptor struct {
    jwtManager    *JWTManager
    publicMethods map[string]bool
}

// authContextKey is an unexported key type for request-context values. Using
// it instead of a plain string avoids collisions with keys from other packages.
type authContextKey int

const (
    userIDContextKey authContextKey = iota
    usernameContextKey
)

// UserIDFromContext returns the user ID stored by AuthInterceptor. ok is false
// when ctx did not pass through the auth check (e.g. a public method).
func UserIDFromContext(ctx context.Context) (userID int64, ok bool) {
    userID, ok = ctx.Value(userIDContextKey).(int64)
    return userID, ok
}

// UsernameFromContext returns the username stored by AuthInterceptor. ok is
// false when ctx did not pass through the auth check.
func UsernameFromContext(ctx context.Context) (username string, ok bool) {
    username, ok = ctx.Value(usernameContextKey).(string)
    return username, ok
}

// NewAuthInterceptor creates an AuthInterceptor that verifies tokens with
// jwtManager and lets the listed publicMethods through unauthenticated.
func NewAuthInterceptor(jwtManager *JWTManager, publicMethods []string) *AuthInterceptor {
    publicMethodsMap := make(map[string]bool, len(publicMethods))
    for _, method := range publicMethods {
        publicMethodsMap[method] = true
    }

    return &AuthInterceptor{
        jwtManager:    jwtManager,
        publicMethods: publicMethodsMap,
    }
}

// Unary returns the unary server interceptor. It rejects unauthenticated calls
// with codes.Unauthenticated.
func (interceptor *AuthInterceptor) Unary() grpc.UnaryServerInterceptor {
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
        if interceptor.publicMethods[info.FullMethod] {
            return handler(ctx, req)
        }

        authCtx, err := interceptor.authorize(ctx)
        if err != nil {
            return nil, err
        }

        // Hand the enriched context on so the handler can see who is calling.
        return handler(authCtx, req)
    }
}

// Stream returns the streaming server interceptor. It performs the same check
// once, when the stream is opened.
func (interceptor *AuthInterceptor) Stream() grpc.StreamServerInterceptor {
    return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
        if interceptor.publicMethods[info.FullMethod] {
            return handler(srv, stream)
        }

        authCtx, err := interceptor.authorize(stream.Context())
        if err != nil {
            return err
        }

        // A ServerStream's context cannot be replaced, so wrap the stream to
        // expose the enriched context to the handler.
        return handler(srv, &authServerStream{ServerStream: stream, ctx: authCtx})
    }
}

// authServerStream overrides Context() of the wrapped stream.
type authServerStream struct {
    grpc.ServerStream
    ctx context.Context
}

// Context returns the context carrying the authenticated identity.
func (s *authServerStream) Context() context.Context {
    return s.ctx
}

// authorize validates the "authorization: Bearer <token>" metadata of ctx. On
// success it returns a context derived from ctx that carries the user ID and
// username. Every failure is reported as codes.Unauthenticated.
func (interceptor *AuthInterceptor) authorize(ctx context.Context) (context.Context, error) {
    md, ok := metadata.FromIncomingContext(ctx)
    if !ok {
        return nil, status.Errorf(codes.Unauthenticated, "metadata is not provided")
    }

    values := md["authorization"]
    if len(values) == 0 {
        return nil, status.Errorf(codes.Unauthenticated, "authorization token is not provided")
    }

    // The auth-scheme name is case-insensitive (RFC 7235), so "bearer x" is accepted too.
    const bearerPrefix = "Bearer "
    header := values[0]
    if len(header) < len(bearerPrefix) || !strings.EqualFold(header[:len(bearerPrefix)], bearerPrefix) {
        return nil, status.Errorf(codes.Unauthenticated, "invalid authorization format")
    }
    accessToken := header[len(bearerPrefix):]

    claims, err := interceptor.jwtManager.Verify(accessToken)
    if err != nil {
        return nil, status.Errorf(codes.Unauthenticated, "access token is invalid: %v", err)
    }

    ctx = context.WithValue(ctx, userIDContextKey, claims.UserID)
    ctx = context.WithValue(ctx, usernameContextKey, claims.Username)
    return ctx, nil
}

// JWTManager verifies access tokens.
//
// Example stub: secretKey is stored but never used, and Verify accepts only the
// literal token "valid-token". Replace Verify with real JWT parsing and
// signature verification before using this outside a demo.
type JWTManager struct {
    secretKey string
}

// UserClaims is the identity extracted from a verified token.
type UserClaims struct {
    UserID   int64  `json:"user_id"`
    Username string `json:"username"`
}

// NewJWTManager returns a JWTManager configured with secretKey.
func NewJWTManager(secretKey string) *JWTManager {
    manager := new(JWTManager)
    manager.secretKey = secretKey
    return manager
}

// Verify returns fixed demo claims (user 1, "testuser") for "valid-token". For
// any other token it returns a codes.Unauthenticated error.
func (manager *JWTManager) Verify(accessToken string) (*UserClaims, error) {
    if accessToken != "valid-token" {
        return nil, status.Errorf(codes.Unauthenticated, "invalid token")
    }
    claims := &UserClaims{UserID: 1, Username: "testuser"}
    return claims, nil
}

3. 限流拦截器 #

// internal/interceptor/ratelimit.go
package interceptor

import (
    "context"
    "sync"
    "time"

    "google.golang.org/grpc"
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"
    "golang.org/x/time/rate"
)

// RateLimiter enforces a token-bucket limit per key. Buckets are created
// lazily. The interceptors key by info.FullMethod, so each RPC method gets one
// bucket shared by all callers.
type RateLimiter struct {
    mu       sync.RWMutex
    limiters map[string]*rate.Limiter
    rate     rate.Limit
    burst    int
}

// NewRateLimiter returns a RateLimiter whose buckets allow r events per second,
// with bursts of up to b.
func NewRateLimiter(r rate.Limit, b int) *RateLimiter {
    rl := &RateLimiter{rate: r, burst: b}
    rl.limiters = make(map[string]*rate.Limiter)
    return rl
}

// getLimiter returns the bucket for key, creating it on first use.
func (rl *RateLimiter) getLimiter(key string) *rate.Limiter {
    // Fast path: an existing bucket only needs the read lock.
    rl.mu.RLock()
    if l, ok := rl.limiters[key]; ok {
        rl.mu.RUnlock()
        return l
    }
    rl.mu.RUnlock()

    // Slow path: check again under the write lock, because another goroutine
    // may have created the bucket between the two lock acquisitions.
    rl.mu.Lock()
    defer rl.mu.Unlock()
    if l, ok := rl.limiters[key]; ok {
        return l
    }
    l := rate.NewLimiter(rl.rate, rl.burst)
    rl.limiters[key] = l
    return l
}

// UnaryServerInterceptor rejects a unary call with codes.ResourceExhausted when
// its method's bucket has no token available.
func (rl *RateLimiter) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
        if rl.getLimiter(info.FullMethod).Allow() {
            return handler(ctx, req)
        }
        return nil, status.Errorf(codes.ResourceExhausted, "rate limit exceeded for method %s", info.FullMethod)
    }
}

// StreamServerInterceptor applies the same per-method limit when a stream is
// opened. Individual messages are not limited.
func (rl *RateLimiter) StreamServerInterceptor() grpc.StreamServerInterceptor {
    return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
        if rl.getLimiter(info.FullMethod).Allow() {
            return handler(srv, stream)
        }
        return status.Errorf(codes.ResourceExhausted, "rate limit exceeded for method %s", info.FullMethod)
    }
}

// IPRateLimiter enforces an independent token-bucket limit per client key
// (nominally the client IP) and evicts idle buckets in the background. Create
// it with NewIPRateLimiter, and call Stop when it is no longer needed.
//
// NOTE: getClientIP is still a placeholder that returns "127.0.0.1" for every
// request. Until it is implemented, all clients share a single bucket.
type IPRateLimiter struct {
    mu       sync.RWMutex
    limiters map[string]*rate.Limiter
    rate     rate.Limit
    burst    int

    stop     chan struct{} // closed by Stop to end the cleanup goroutine
    stopOnce sync.Once
}

// NewIPRateLimiter returns a limiter that allows r events per second, with
// bursts of up to b, per client. It also starts the background cleanup
// goroutine; call Stop to end it.
func NewIPRateLimiter(r rate.Limit, b int) *IPRateLimiter {
    limiter := &IPRateLimiter{
        limiters: make(map[string]*rate.Limiter),
        rate:     r,
        burst:    b,
        stop:     make(chan struct{}),
    }

    go limiter.cleanup()

    return limiter
}

// Stop ends the cleanup goroutine started by NewIPRateLimiter. It is safe to
// call more than once. The limiter keeps working afterwards, but idle buckets
// are no longer evicted.
func (rl *IPRateLimiter) Stop() {
    rl.stopOnce.Do(func() { close(rl.stop) })
}

// getClientIP returns the rate-limit key for the caller.
//
// TODO: placeholder. Derive the key from the connection instead, e.g. with
// peer.FromContext(ctx) (google.golang.org/grpc/peer) followed by
// net.SplitHostPort(p.Addr.String()).
func (rl *IPRateLimiter) getClientIP(ctx context.Context) string {
    return "127.0.0.1"
}

// UnaryServerInterceptor rejects a unary call with codes.ResourceExhausted once
// the caller's bucket is empty.
func (rl *IPRateLimiter) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
        clientIP := rl.getClientIP(ctx)
        limiter := rl.getLimiter(clientIP)

        if !limiter.Allow() {
            return nil, status.Errorf(codes.ResourceExhausted, "rate limit exceeded for IP %s", clientIP)
        }

        return handler(ctx, req)
    }
}

// StreamServerInterceptor applies the same per-client limit when a stream is
// opened. It mirrors RateLimiter, which offers both interceptor kinds.
func (rl *IPRateLimiter) StreamServerInterceptor() grpc.StreamServerInterceptor {
    return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
        clientIP := rl.getClientIP(stream.Context())
        limiter := rl.getLimiter(clientIP)

        if !limiter.Allow() {
            return status.Errorf(codes.ResourceExhausted, "rate limit exceeded for IP %s", clientIP)
        }

        return handler(srv, stream)
    }
}

// getLimiter returns the bucket for ip, creating it on first use.
func (rl *IPRateLimiter) getLimiter(ip string) *rate.Limiter {
    rl.mu.RLock()
    limiter, exists := rl.limiters[ip]
    rl.mu.RUnlock()

    if !exists {
        rl.mu.Lock()
        // Check again: another goroutine may have created it meanwhile.
        if limiter, exists = rl.limiters[ip]; !exists {
            limiter = rate.NewLimiter(rl.rate, rl.burst)
            rl.limiters[ip] = limiter
        }
        rl.mu.Unlock()
    }

    return limiter
}

// cleanup runs until Stop is called. Once an hour it drops buckets that have
// refilled to full capacity, i.e. clients with no recent activity. Dropping
// them is harmless: a fresh bucket for the same key also starts full.
func (rl *IPRateLimiter) cleanup() {
    ticker := time.NewTicker(time.Hour)
    defer ticker.Stop()

    for {
        select {
        case <-rl.stop:
            return
        case <-ticker.C:
            rl.mu.Lock()
            for ip, limiter := range rl.limiters {
                if limiter.Tokens() == float64(rl.burst) {
                    delete(rl.limiters, ip)
                }
            }
            rl.mu.Unlock()
        }
    }
}

4. 监控拦截器 #

// internal/interceptor/metrics.go
package interceptor

import (
    "context"
    "time"

    "github.com/prometheus/client_golang/prometheus"
    "github.com/prometheus/client_golang/prometheus/promauto"
    "google.golang.org/grpc"
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"
)

// Prometheus collectors shared by the metrics interceptors. promauto registers
// each one with the default Prometheus registerer at package initialization.
var (
    // grpcRequestsTotal counts completed RPCs, labeled by full method name and
    // gRPC status code.
    grpcRequestsTotal = promauto.NewCounterVec(
        prometheus.CounterOpts{
            Name: "grpc_requests_total",
            Help: "Total number of gRPC requests",
        },
        []string{"method", "code"},
    )

    // grpcRequestDuration records RPC latency in seconds per method, using
    // Prometheus' default buckets.
    grpcRequestDuration = promauto.NewHistogramVec(
        prometheus.HistogramOpts{
            Name:    "grpc_request_duration_seconds",
            Help:    "Duration of gRPC requests",
            Buckets: prometheus.DefBuckets,
        },
        []string{"method"},
    )

    // grpcActiveConnections is meant to track open connections.
    // NOTE(review): nothing in this section ever calls Inc/Dec/Set on it, so as
    // written it is exported with a constant value of 0. Either wire it up
    // (e.g. from a connection-level stats handler) or remove it.
    grpcActiveConnections = promauto.NewGauge(
        prometheus.GaugeOpts{
            Name: "grpc_active_connections",
            Help: "Number of active gRPC connections",
        },
    )

    // grpcMessageSize records message sizes in bytes per method. The "type"
    // label is "request" or "response". Buckets grow 4x from 64 B to 1 MiB.
    grpcMessageSize = promauto.NewHistogramVec(
        prometheus.HistogramOpts{
            Name:    "grpc_message_size_bytes",
            Help:    "Size of gRPC messages",
            Buckets: []float64{64, 256, 1024, 4096, 16384, 65536, 262144, 1048576},
        },
        []string{"method", "type"},
    )
)

// MetricsUnaryInterceptor records per-method metrics for every unary RPC:
// request and response sizes (when getMessageSize reports a positive value),
// latency, and a request count labeled with the resulting status code.
func MetricsUnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    method := info.FullMethod
    began := time.Now()

    if n := getMessageSize(req); n > 0 {
        grpcMessageSize.WithLabelValues(method, "request").Observe(float64(n))
    }

    resp, err := handler(ctx, req)
    grpcRequestDuration.WithLabelValues(method).Observe(time.Since(began).Seconds())

    // Errors that are not gRPC statuses are counted as Unknown.
    code := codes.OK
    if err != nil {
        code = codes.Unknown
        if st, ok := status.FromError(err); ok {
            code = st.Code()
        }
    }
    grpcRequestsTotal.WithLabelValues(method, code.String()).Inc()

    if resp == nil {
        return resp, err
    }
    if n := getMessageSize(resp); n > 0 {
        grpcMessageSize.WithLabelValues(method, "response").Observe(float64(n))
    }
    return resp, err
}

// MetricsStreamInterceptor records per-method latency and a status-code-labeled
// count for every streaming RPC. Per-message sizes are recorded by wrapping the
// stream in metricsServerStream.
func MetricsStreamInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    method := info.FullMethod
    began := time.Now()

    err := handler(srv, &metricsServerStream{ServerStream: stream, method: method})

    grpcRequestDuration.WithLabelValues(method).Observe(time.Since(began).Seconds())

    code := codes.OK
    if err != nil {
        code = codes.Unknown
        if st, ok := status.FromError(err); ok {
            code = st.Code()
        }
    }
    grpcRequestsTotal.WithLabelValues(method, code.String()).Inc()

    return err
}

// metricsServerStream wraps a grpc.ServerStream to record message sizes. Sent
// messages use label "response"; successfully received messages use label
// "request".
type metricsServerStream struct {
    grpc.ServerStream
    method string
}

// SendMsg records the size of m, then forwards it to the wrapped stream.
func (s *metricsServerStream) SendMsg(m interface{}) error {
    s.observeSize(m, "response")
    return s.ServerStream.SendMsg(m)
}

// RecvMsg reads the next message into m. Its size is recorded only on success.
func (s *metricsServerStream) RecvMsg(m interface{}) error {
    if err := s.ServerStream.RecvMsg(m); err != nil {
        return err
    }
    s.observeSize(m, "request")
    return nil
}

// observeSize records the size of m under the given direction label, when
// getMessageSize reports a positive size.
func (s *metricsServerStream) observeSize(m interface{}, direction string) {
    if size := getMessageSize(m); size > 0 {
        grpcMessageSize.WithLabelValues(s.method, direction).Observe(float64(size))
    }
}

// getMessageSize reports the serialized size of msg in bytes.
//
// Placeholder: it always returns 0, and the callers only record sizes greater
// than 0, so size metrics are currently never observed. A real implementation
// could use proto.Size from google.golang.org/protobuf/proto for proto.Message
// values.
func getMessageSize(msg interface{}) int {
    const unknownSize = 0
    _ = msg // not inspected yet; see the placeholder note above
    return unknownSize
}

5. 恢复拦截器 #

// internal/interceptor/recovery.go
package interceptor

import (
    "context"
    "log"
    "runtime/debug"

    "google.golang.org/grpc"
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"
)

// RecoveryUnaryInterceptor turns a panic in any later interceptor or in the
// handler into a codes.Internal error. The panic value and stack trace are
// logged. Panics in goroutines started by the handler are not caught.
func RecoveryUnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
    defer func() {
        r := recover()
        if r == nil {
            return
        }
        log.Printf("Panic recovered in %s: %v\n%s", info.FullMethod, r, debug.Stack())
        err = status.Errorf(codes.Internal, "internal server error")
    }()

    resp, err = handler(ctx, req)
    return resp, err
}

// RecoveryStreamInterceptor is the streaming counterpart of the unary recovery
// interceptor. A panic while handling the stream is logged with its stack trace
// and reported to the client as codes.Internal.
func RecoveryStreamInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) {
    defer func() {
        r := recover()
        if r == nil {
            return
        }
        log.Printf("Panic recovered in stream %s: %v\n%s", info.FullMethod, r, debug.Stack())
        err = status.Errorf(codes.Internal, "internal server error")
    }()

    err = handler(srv, stream)
    return err
}

// RecoveryHandler converts a value recovered from a panic into the error that
// is returned to the client.
type RecoveryHandler func(p interface{}) error

// RecoveryUnaryInterceptorWithHandler is like RecoveryUnaryInterceptor, but the
// caller decides which error a recovered panic becomes.
//
// A nil recoveryHandler falls back to DefaultRecoveryHandler. Calling a nil
// function inside the deferred recover would itself panic and crash the server.
// If the handler returns nil, DefaultRecoveryHandler is used as well; otherwise
// the client would see a "successful" call with no response.
func RecoveryUnaryInterceptorWithHandler(recoveryHandler RecoveryHandler) grpc.UnaryServerInterceptor {
    if recoveryHandler == nil {
        recoveryHandler = DefaultRecoveryHandler
    }
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
        defer func() {
            if r := recover(); r != nil {
                log.Printf("Panic recovered in %s: %v\n%s", info.FullMethod, r, debug.Stack())
                if err = recoveryHandler(r); err == nil {
                    err = DefaultRecoveryHandler(r)
                }
            }
        }()

        return handler(ctx, req)
    }
}

// RecoveryStreamInterceptorWithHandler is the streaming counterpart of
// RecoveryUnaryInterceptorWithHandler, with the same nil-handler fallbacks.
func RecoveryStreamInterceptorWithHandler(recoveryHandler RecoveryHandler) grpc.StreamServerInterceptor {
    if recoveryHandler == nil {
        recoveryHandler = DefaultRecoveryHandler
    }
    return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) {
        defer func() {
            if r := recover(); r != nil {
                log.Printf("Panic recovered in stream %s: %v\n%s", info.FullMethod, r, debug.Stack())
                if err = recoveryHandler(r); err == nil {
                    err = DefaultRecoveryHandler(r)
                }
            }
        }()

        return handler(srv, stream)
    }
}

// DefaultRecoveryHandler reports a generic codes.Internal error and does not
// include the panic value, so internal details are not exposed to clients.
func DefaultRecoveryHandler(p interface{}) error {
    return status.Error(codes.Internal, "internal server error")
}

// DetailedRecoveryHandler logs the panic value and includes it in the
// codes.Internal error returned to the client.
//
// WARNING: the panic text is sent to the caller and can leak internal details.
// Prefer DefaultRecoveryHandler outside development.
func DetailedRecoveryHandler(p interface{}) error {
    log.Printf("Detailed panic info: %+v", p)
    err := status.Errorf(codes.Internal, "internal server error: %v", p)
    return err
}

客户端拦截器 #

1. 客户端日志拦截器 #

// internal/interceptor/client_logging.go
package interceptor

import (
    "context"
    "log"
    "time"

    "google.golang.org/grpc"
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"
)

// ClientLoggingUnaryInterceptor logs each outgoing unary call before it is sent
// and after it returns. The second line includes the duration and the status
// code; errors that are not gRPC statuses are reported as Unknown.
func ClientLoggingUnaryInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    began := time.Now()
    log.Printf("[Client] Request started - Method: %s", method)

    err := invoker(ctx, method, req, reply, cc, opts...)
    elapsed := time.Since(began)

    code := codes.OK
    if err != nil {
        code = codes.Unknown
        if st, ok := status.FromError(err); ok {
            code = st.Code()
        }
    }

    log.Printf("[Client] Request completed - Method: %s, Duration: %v, Code: %s, Error: %v",
        method, elapsed, code, err)
    return err
}

// ClientLoggingStreamInterceptor logs three things: the opening of a client
// stream, each message sent or received, and the end of the RPC with its total
// duration.
func ClientLoggingStreamInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
    start := time.Now()

    log.Printf("[Client] Stream started - Method: %s", method)

    stream, err := streamer(ctx, desc, cc, method, opts...)
    if err != nil {
        log.Printf("[Client] Stream creation failed - Method: %s, Error: %v", method, err)
        return nil, err
    }

    // grpc always passes a desc. If none is given, assume server streaming, so
    // completion is only logged when RecvMsg returns an error.
    serverStreams := desc == nil || desc.ServerStreams

    return &loggingClientStream{
        ClientStream:  stream,
        method:        method,
        startTime:     start,
        serverStreams: serverStreams,
    }, nil
}

// loggingClientStream wraps a grpc.ClientStream and logs its traffic.
type loggingClientStream struct {
    grpc.ClientStream
    method        string
    startTime     time.Time
    serverStreams bool // the server may send more than one message
    finished      bool // end of RPC already logged; only touched from RecvMsg
}

// SendMsg logs the send, then forwards m to the wrapped stream.
func (s *loggingClientStream) SendMsg(m interface{}) error {
    log.Printf("[Client] Stream send - Method: %s", s.method)
    return s.ClientStream.SendMsg(m)
}

// RecvMsg reads the next message into m. Per the grpc.ClientStream contract, a
// non-nil error ends the RPC: io.EOF on success, the RPC status otherwise. It
// is therefore logged as completion rather than as a receive failure. When the
// server does not stream, the single successful receive also ends the RPC.
func (s *loggingClientStream) RecvMsg(m interface{}) error {
    err := s.ClientStream.RecvMsg(m)
    if err != nil {
        s.logFinished(err)
        return err
    }

    log.Printf("[Client] Stream receive - Method: %s", s.method)
    if !s.serverStreams {
        s.logFinished(nil)
    }
    return nil
}

// CloseSend half-closes the stream: the client sends no more messages. The
// server may still be sending, so the end of the RPC is logged from RecvMsg,
// not here.
func (s *loggingClientStream) CloseSend() error {
    log.Printf("[Client] Stream send closed - Method: %s", s.method)
    return s.ClientStream.CloseSend()
}

// logFinished logs the end of the RPC once. The duration is measured from when
// the stream was opened.
func (s *loggingClientStream) logFinished(err error) {
    if s.finished {
        return
    }
    s.finished = true
    log.Printf("[Client] Stream closed - Method: %s, Duration: %v, Result: %v",
        s.method, time.Since(s.startTime), err)
}

2. 客户端认证拦截器 #

// internal/interceptor/client_auth.go
package interceptor

import (
    "context"

    "google.golang.org/grpc"
    "google.golang.org/grpc/metadata"
)

// ClientAuthInterceptor adds an "authorization: Bearer <token>" metadata entry
// to every outgoing RPC. The token is fetched from tokenProvider on each call.
// If fetching fails, the call is not made and the provider's error is returned
// unchanged.
//
// NOTE: the token travels as ordinary metadata. Over an insecure connection it
// is sent in plaintext; grpc's credentials.PerRPCCredentials can require
// transport security instead.
type ClientAuthInterceptor struct {
    tokenProvider TokenProvider
}

// TokenProvider supplies the bearer token for outgoing calls.
type TokenProvider interface {
    GetToken() (string, error)
}

// NewClientAuthInterceptor creates an interceptor that authenticates calls with
// tokens from tokenProvider.
func NewClientAuthInterceptor(tokenProvider TokenProvider) *ClientAuthInterceptor {
    ci := &ClientAuthInterceptor{}
    ci.tokenProvider = tokenProvider
    return ci
}

// Unary returns the unary client interceptor.
func (interceptor *ClientAuthInterceptor) Unary() grpc.UnaryClientInterceptor {
    return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
        token, err := interceptor.tokenProvider.GetToken()
        if err != nil {
            return err
        }
        authCtx := metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+token)
        return invoker(authCtx, method, req, reply, cc, opts...)
    }
}

// Stream returns the stream client interceptor. The token is attached once,
// when the stream is opened.
func (interceptor *ClientAuthInterceptor) Stream() grpc.StreamClientInterceptor {
    return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
        token, err := interceptor.tokenProvider.GetToken()
        if err != nil {
            return nil, err
        }
        authCtx := metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+token)
        return streamer(authCtx, desc, cc, method, opts...)
    }
}

// StaticTokenProvider always returns the same fixed token.
type StaticTokenProvider struct {
    token string
}

// NewStaticTokenProvider returns a provider that always yields token.
func NewStaticTokenProvider(token string) *StaticTokenProvider {
    p := new(StaticTokenProvider)
    p.token = token
    return p
}

// GetToken returns the configured token and never fails.
func (p *StaticTokenProvider) GetToken() (string, error) {
    token := p.token
    return token, nil
}

// DynamicTokenProvider obtains the token by calling a user-supplied function
// on every GetToken call. This allows, for example, returning a freshly
// refreshed token.
type DynamicTokenProvider struct {
    getTokenFunc func() (string, error)
}

// NewDynamicTokenProvider returns a provider backed by getTokenFunc.
func NewDynamicTokenProvider(getTokenFunc func() (string, error)) *DynamicTokenProvider {
    p := new(DynamicTokenProvider)
    p.getTokenFunc = getTokenFunc
    return p
}

// GetToken calls the configured function and returns its result unchanged.
func (p *DynamicTokenProvider) GetToken() (string, error) {
    fetch := p.getTokenFunc
    return fetch()
}

拦截器链 #

拦截器组合 #

// internal/interceptor/chain.go
package interceptor

import (
    "context"

    "google.golang.org/grpc"
)

// ChainUnaryServerInterceptors composes interceptors into a single
// grpc.UnaryServerInterceptor. interceptors[0] is the outermost: it runs first
// on the way in and last on the way out. With no interceptors, the handler is
// called directly.
//
// grpc.ChainUnaryInterceptor cannot be used for this, because it returns a
// grpc.ServerOption rather than an interceptor. Pass that option straight to
// grpc.NewServer when chaining is only needed at server construction.
func ChainUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
    // Copy so later changes to a caller-owned slice cannot alter the chain.
    chain := append([]grpc.UnaryServerInterceptor(nil), interceptors...)
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
        next := handler
        // Wrap from the innermost interceptor outwards, so chain[0] ends up on top.
        for i := len(chain) - 1; i >= 0; i-- {
            ic, inner := chain[i], next
            next = func(ctx context.Context, req interface{}) (interface{}, error) {
                return ic(ctx, req, info, inner)
            }
        }
        return next(ctx, req)
    }
}

// ChainStreamServerInterceptors composes interceptors into a single
// grpc.StreamServerInterceptor. interceptors[0] is the outermost. With no
// interceptors, the handler is called directly.
//
// grpc.ChainStreamInterceptor cannot be used for this, because it returns a
// grpc.ServerOption rather than an interceptor.
func ChainStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
    // Copy so later changes to a caller-owned slice cannot alter the chain.
    chain := append([]grpc.StreamServerInterceptor(nil), interceptors...)
    return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
        next := handler
        // Wrap from the innermost interceptor outwards, so chain[0] ends up on top.
        for i := len(chain) - 1; i >= 0; i-- {
            ic, inner := chain[i], next
            next = func(srv interface{}, ss grpc.ServerStream) error {
                return ic(srv, ss, info, inner)
            }
        }
        return next(srv, ss)
    }
}

// ChainUnaryClientInterceptors composes client interceptors into a single
// grpc.UnaryClientInterceptor. interceptors[0] is the outermost. With no
// interceptors, the invoker is called directly.
//
// grpc.ChainUnaryInterceptor is a server option and does not apply to clients.
// When chaining only at dial time, grpc.WithChainUnaryInterceptor (a
// DialOption) is the built-in alternative.
func ChainUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor {
    // Copy so later changes to a caller-owned slice cannot alter the chain.
    chain := append([]grpc.UnaryClientInterceptor(nil), interceptors...)
    return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
        next := invoker
        // Wrap from the innermost interceptor outwards, so chain[0] ends up on top.
        for i := len(chain) - 1; i >= 0; i-- {
            ic, inner := chain[i], next
            next = func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
                return ic(ctx, method, req, reply, cc, inner, opts...)
            }
        }
        return next(ctx, method, req, reply, cc, opts...)
    }
}

// ChainStreamClientInterceptors composes client stream interceptors into a
// single grpc.StreamClientInterceptor. interceptors[0] is the outermost. With
// no interceptors, the streamer is called directly.
//
// grpc.ChainStreamInterceptor is a server option and does not apply to
// clients. When chaining only at dial time, grpc.WithChainStreamInterceptor (a
// DialOption) is the built-in alternative.
func ChainStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.StreamClientInterceptor {
    // Copy so later changes to a caller-owned slice cannot alter the chain.
    chain := append([]grpc.StreamClientInterceptor(nil), interceptors...)
    return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
        next := streamer
        // Wrap from the innermost interceptor outwards, so chain[0] ends up on top.
        for i := len(chain) - 1; i >= 0; i-- {
            ic, inner := chain[i], next
            next = func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
                return ic(ctx, desc, cc, method, inner, opts...)
            }
        }
        return next(ctx, desc, cc, method, opts...)
    }
}

服务器配置示例 #

// cmd/server/main.go
package main

import (
    "log"
    "net"

    "google.golang.org/grpc"
    "google.golang.org/grpc/reflection"

    "github.com/example/grpc-server/internal/interceptor"
    pb "github.com/example/grpc-server/proto/user/v1"
)

// main starts the user service on :50051. It installs recovery, logging,
// metrics, rate-limiting and authentication interceptors and enables server
// reflection.
func main() {
    listener, err := net.Listen("tcp", ":50051")
    if err != nil {
        log.Fatalf("Failed to listen: %v", err)
    }

    // Demo values: the JWT secret is hard-coded, and GetUser is public (no token needed).
    jwtManager := interceptor.NewJWTManager("secret-key")
    publicMethods := []string{"/user.v1.UserService/GetUser"}
    authInterceptor := interceptor.NewAuthInterceptor(jwtManager, publicMethods)

    // 10 tokens per second with bursts of up to 20, tracked per RPC method.
    rateLimiter := interceptor.NewRateLimiter(10, 20)

    // The first interceptor in each chain is the outermost. Recovery therefore
    // wraps every later interceptor as well as the handler.
    unaryChain := grpc.ChainUnaryInterceptor(
        interceptor.RecoveryUnaryInterceptor,
        interceptor.LoggingUnaryInterceptor,
        interceptor.MetricsUnaryInterceptor,
        rateLimiter.UnaryServerInterceptor(),
        authInterceptor.Unary(),
    )
    streamChain := grpc.ChainStreamInterceptor(
        interceptor.RecoveryStreamInterceptor,
        interceptor.LoggingStreamInterceptor,
        interceptor.MetricsStreamInterceptor,
        rateLimiter.StreamServerInterceptor(),
        authInterceptor.Stream(),
    )
    server := grpc.NewServer(unaryChain, streamChain)

    pb.RegisterUserServiceServer(server, &userService{})

    // Reflection lets tools such as grpcurl discover the registered services.
    reflection.Register(server)

    log.Println("gRPC server listening on :50051")
    if err := server.Serve(listener); err != nil {
        log.Fatalf("Failed to serve: %v", err)
    }
}

客户端配置示例 #

// cmd/client/main.go
package main

import (
    "context"
    "log"
    "time"

    "google.golang.org/grpc"
    "google.golang.org/grpc/credentials/insecure"

    "github.com/example/grpc-client/internal/interceptor"
    pb "github.com/example/grpc-client/proto/user/v1"
)

// main connects to the server, attaching logging and bearer-token interceptors
// to every call, and performs a single GetUser request with a 5s timeout.
func main() {
    // "valid-token" is the token accepted by the example JWTManager.
    authInterceptor := interceptor.NewClientAuthInterceptor(
        interceptor.NewStaticTokenProvider("valid-token"),
    )

    // Interceptors run in the order listed, so logging wraps authentication.
    // Note: recent grpc-go releases deprecate grpc.Dial in favour of grpc.NewClient.
    conn, err := grpc.Dial(
        "localhost:50051",
        grpc.WithTransportCredentials(insecure.NewCredentials()),
        grpc.WithChainUnaryInterceptor(
            interceptor.ClientLoggingUnaryInterceptor,
            authInterceptor.Unary(),
        ),
        grpc.WithChainStreamInterceptor(
            interceptor.ClientLoggingStreamInterceptor,
            authInterceptor.Stream(),
        ),
    )
    if err != nil {
        log.Fatalf("Failed to connect: %v", err)
    }
    defer conn.Close()

    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()

    resp, err := pb.NewUserServiceClient(conn).GetUser(ctx, &pb.GetUserRequest{UserId: 1})
    if err != nil {
        log.Printf("GetUser failed: %v", err)
        return
    }
    log.Printf("GetUser success: %+v", resp.User)
}

高级拦截器模式 #

1. 条件拦截器 #

// internal/interceptor/conditional.go
package interceptor

import (
    "context"
    "strings"

    "google.golang.org/grpc"
)

// ConditionalInterceptor applies a wrapped unary interceptor only to methods
// for which condition returns true. All other calls go straight to the handler.
type ConditionalInterceptor struct {
    condition   func(method string) bool
    interceptor grpc.UnaryServerInterceptor
}

// NewConditionalInterceptor wraps interceptor so that it only runs when
// condition(info.FullMethod) is true.
func NewConditionalInterceptor(condition func(string) bool, interceptor grpc.UnaryServerInterceptor) *ConditionalInterceptor {
    ci := &ConditionalInterceptor{}
    ci.condition = condition
    ci.interceptor = interceptor
    return ci
}

// Unary returns the conditional unary server interceptor.
func (ci *ConditionalInterceptor) Unary() grpc.UnaryServerInterceptor {
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
        if !ci.condition(info.FullMethod) {
            return handler(ctx, req)
        }
        return ci.interceptor(ctx, req, info, handler)
    }
}

// 使用示例
func main() {
    // 只对特定方法应用认证
    authCondition := func(method string) bool {
        return !strings.HasSuffix(method, "/Health")
    }

    conditionalAuth := NewConditionalInterceptor(authCondition, authInterceptor.Unary())

    s := grpc.NewServer(
        grpc.UnaryInterceptor(conditionalAuth.Unary()),
    )
}

2. 配置驱动的拦截器 #

// internal/interceptor/configurable.go
package interceptor

import (
    "context"
    "time"

    "google.golang.org/grpc"
)

// InterceptorConfig selects which server interceptors to install. It is
// typically loaded from YAML, using the keys in the struct tags.
type InterceptorConfig struct {
    EnableLogging   bool            `yaml:"enable_logging"`
    EnableMetrics   bool            `yaml:"enable_metrics"`
    EnableAuth      bool            `yaml:"enable_auth"`
    EnableRateLimit bool            `yaml:"enable_rate_limit"`
    RateLimit       RateLimitConfig `yaml:"rate_limit"`
    LogLevel        string          `yaml:"log_level"`
}

// RateLimitConfig parameterizes the rate-limiting interceptor. RequestsPerSecond
// is the sustained rate and BurstSize the bucket capacity. CleanupInterval is
// intended for evicting idle per-client buckets.
type RateLimitConfig struct {
    RequestsPerSecond int           `yaml:"requests_per_second"`
    BurstSize         int           `yaml:"burst_size"`
    CleanupInterval   time.Duration `yaml:"cleanup_interval"`
}

// BuildInterceptorChain returns the server interceptors selected by config,
// ordered outermost first so they can be passed straight to
// grpc.ChainUnaryInterceptor / grpc.ChainStreamInterceptor.
//
// Recovery is always included. Logging, metrics and rate limiting follow in
// that order when enabled. config.EnableAuth, config.LogLevel and
// config.RateLimit.CleanupInterval are currently not acted on.
//
// NOTE(review): this function uses rate.Limit, so its file must import
// "golang.org/x/time/rate". The import block shown for this snippet lacks that
// import, and its "context" import is unused.
func BuildInterceptorChain(config InterceptorConfig) ([]grpc.UnaryServerInterceptor, []grpc.StreamServerInterceptor) {
    unary := []grpc.UnaryServerInterceptor{RecoveryUnaryInterceptor}
    stream := []grpc.StreamServerInterceptor{RecoveryStreamInterceptor}

    if config.EnableLogging {
        unary = append(unary, LoggingUnaryInterceptor)
        stream = append(stream, LoggingStreamInterceptor)
    }

    if config.EnableMetrics {
        unary = append(unary, MetricsUnaryInterceptor)
        stream = append(stream, MetricsStreamInterceptor)
    }

    if config.EnableRateLimit {
        limiter := NewRateLimiter(rate.Limit(config.RateLimit.RequestsPerSecond), config.RateLimit.BurstSize)
        unary = append(unary, limiter.UnaryServerInterceptor())
        stream = append(stream, limiter.StreamServerInterceptor())
    }

    // TODO(auth): when config.EnableAuth is set, build an AuthInterceptor from
    // configuration and append its Unary() and Stream() interceptors here.

    return unary, stream
}

最佳实践 #

1. 拦截器顺序 #

// recommendedOrder lists interceptor kinds from outermost to innermost. This is
// the order in which to pass them to grpc.ChainUnaryInterceptor /
// grpc.ChainStreamInterceptor, where the first argument wraps all the others.
var recommendedOrder = []string{
    "Recovery",      // outermost, so it catches panics from everything inside it
    "Logging",       // request logging
    "Metrics",       // metric collection
    "Tracing",       // distributed tracing
    "RateLimit",     // rate limiting, before the more expensive auth work
    "Auth",          // authentication / authorization
    "Validation",    // request validation
    "Business",      // the service handler itself (innermost), not an interceptor
}

2. 错误处理 #

// SafeInterceptor wraps next so that a panic inside it, or inside anything it
// calls (including the handler), is logged and turned into an error instead of
// crashing the server.
//
// The results are named so the deferred recover can replace them. Without
// this, a recovered panic returned (nil, nil), which looks like a successful
// call with no response. gRPC reports non-status errors to clients as
// codes.Unknown; use status.Error(codes.Internal, ...) instead when the
// google.golang.org/grpc/status package is available.
func SafeInterceptor(next grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
        defer func() {
            if r := recover(); r != nil {
                log.Printf("Interceptor panic: %v", r)
                resp = nil
                err = errors.New("internal server error")
            }
        }()

        return next(ctx, req, info, handler)
    }
}

3. 性能考虑 #

// OptimizedInterceptor illustrates two ways to avoid repeating expensive work
// in an interceptor: answering from a cache, and moving non-critical work off
// the request path.
//
// NOTE(review): getFromCache and logAsync are not defined in this section;
// their contracts are assumptions.
// NOTE(review): the cache is keyed only by info.FullMethod, so every caller of
// a method gets the same cached response regardless of req or caller identity.
// That is only safe for request-independent, non-user-specific methods.
// NOTE(review): a cache hit returns without calling handler, so interceptors
// later in the chain (e.g. auth) are skipped for that call.
// NOTE(review): a new goroutine is started for every request, with no bound
// and no shutdown. Consider a bounded queue drained by a fixed worker instead.
func OptimizedInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    // Serve from the cache to avoid recomputation.
    if cached := getFromCache(info.FullMethod); cached != nil {
        return cached, nil
    }

    // Move non-critical work off the request path.
    go func() {
        // Asynchronous logging.
        logAsync(info.FullMethod, req)
    }()

    return handler(ctx, req)
}

小结 #

本节详细介绍了 gRPC 拦截器和中间件的开发:

  1. 拦截器类型:一元和流式、服务端和客户端拦截器
  2. 常用拦截器:日志、认证、限流、监控、恢复
  3. 拦截器链:组合多个拦截器的方法
  4. 高级模式:条件拦截器、配置驱动的拦截器
  5. 最佳实践:拦截器顺序、错误处理、性能优化

通过合理使用拦截器,可以实现横切关注点的分离,提高代码的可维护性和可扩展性。拦截器是构建企业级 gRPC 应用的重要工具。