3.7.5 gRPC 拦截器与中间件 #
gRPC 拦截器(Interceptor)是一种强大的机制,允许在 RPC 调用的执行过程中插入自定义逻辑。它们类似于 HTTP 中间件,可以用于实现认证、日志记录、监控、限流等横切关注点功能。
拦截器基础 #
拦截器类型 #
gRPC 提供了四种类型的拦截器:
- 一元服务端拦截器(Unary Server Interceptor)
- 流式服务端拦截器(Stream Server Interceptor)
- 一元客户端拦截器(Unary Client Interceptor)
- 流式客户端拦截器(Stream Client Interceptor)
拦截器签名 #
// The four interceptor function types declared by package google.golang.org/grpc,
// reproduced for reference. Unqualified names (UnaryServerInfo, UnaryHandler,
// ServerStream, StreamDesc, ClientConn, CallOption, ...) refer to that package.

// UnaryServerInterceptor (server, unary RPC): receives the request and method
// info; calling handler continues to the next interceptor / the service method.
type UnaryServerInterceptor func(
ctx context.Context,
req interface{},
info *UnaryServerInfo,
handler UnaryHandler,
) (resp interface{}, err error)
// StreamServerInterceptor (server, streaming RPC): receives the server stream
// and method info; calling handler continues the chain.
type StreamServerInterceptor func(
srv interface{},
ss ServerStream,
info *StreamServerInfo,
handler StreamHandler,
) error
// UnaryClientInterceptor (client, unary RPC): calling invoker performs the
// actual call (or the next client interceptor).
type UnaryClientInterceptor func(
ctx context.Context,
method string,
req, reply interface{},
cc *ClientConn,
invoker UnaryInvoker,
opts ...CallOption,
) error
// StreamClientInterceptor (client, streaming RPC): calling streamer opens the
// stream (or invokes the next client interceptor); the returned ClientStream
// may be wrapped.
type StreamClientInterceptor func(
ctx context.Context,
desc *StreamDesc,
cc *ClientConn,
method string,
streamer Streamer,
opts ...CallOption,
) (ClientStream, error)
服务端拦截器 #
1. 日志拦截器 #
// internal/interceptor/logging.go
package interceptor
import (
"context"
"log"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// LoggingUnaryInterceptor logs when a unary RPC starts and when it finishes.
// The completion line includes the elapsed time, the gRPC status code
// (OK for a nil error, Unknown for errors that carry no gRPC status) and the
// error itself. The handler's response and error are returned unchanged.
func LoggingUnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	const stampLayout = "2006-01-02 15:04:05"
	began := time.Now()
	method := info.FullMethod

	log.Printf("[%s] Request started - Method: %s", time.Now().Format(stampLayout), method)

	resp, err := handler(ctx, req)
	elapsed := time.Since(began)

	code := codes.OK
	if err != nil {
		// status.Code maps errors without a gRPC status to codes.Unknown.
		code = status.Code(err)
	}

	log.Printf("[%s] Request completed - Method: %s, Duration: %v, Code: %s, Error: %v",
		time.Now().Format(stampLayout), method, elapsed, code, err)
	return resp, err
}
// LoggingStreamInterceptor logs the start of a streaming RPC (with its stream
// kind), hands the handler a stream wrapper that logs every message, and logs
// completion with elapsed time, gRPC status code and error.
func LoggingStreamInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
	const stampLayout = "2006-01-02 15:04:05"
	began := time.Now()
	method := info.FullMethod

	log.Printf("[%s] Stream started - Method: %s, Type: %s",
		time.Now().Format(stampLayout), method, getStreamType(info))

	err := handler(srv, &loggingServerStream{ServerStream: stream, method: method})
	elapsed := time.Since(began)

	code := codes.OK
	if err != nil {
		// status.Code maps errors without a gRPC status to codes.Unknown.
		code = status.Code(err)
	}

	log.Printf("[%s] Stream completed - Method: %s, Duration: %v, Code: %s, Error: %v",
		time.Now().Format(stampLayout), method, elapsed, code, err)
	return err
}
// loggingServerStream decorates a grpc.ServerStream so that every outgoing
// message and every successfully received message is logged with the method.
type loggingServerStream struct {
	grpc.ServerStream
	method string // full method name used in log lines
}

// SendMsg logs before forwarding m to the wrapped stream.
func (s *loggingServerStream) SendMsg(m interface{}) error {
	stamp := time.Now().Format("2006-01-02 15:04:05")
	log.Printf("[%s] Stream send - Method: %s", stamp, s.method)
	return s.ServerStream.SendMsg(m)
}

// RecvMsg reads into m from the wrapped stream; only successful reads are logged.
func (s *loggingServerStream) RecvMsg(m interface{}) error {
	if err := s.ServerStream.RecvMsg(m); err != nil {
		return err
	}
	stamp := time.Now().Format("2006-01-02 15:04:05")
	log.Printf("[%s] Stream receive - Method: %s", stamp, s.method)
	return nil
}
// getStreamType names the streaming kind of an RPC from its server info:
// "bidirectional", "client-stream", "server-stream", or "unknown" when
// neither side streams.
func getStreamType(info *grpc.StreamServerInfo) string {
	switch {
	case info.IsClientStream && info.IsServerStream:
		return "bidirectional"
	case info.IsClientStream:
		return "client-stream"
	case info.IsServerStream:
		return "server-stream"
	default:
		return "unknown"
	}
}
2. 认证拦截器 #
// internal/interceptor/auth.go
package interceptor
import (
"context"
"strings"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// contextKey is an unexported type for values this package stores in a
// context.Context. A dedicated key type (rather than a plain string such as
// "user_id") cannot collide with keys defined by other packages.
type contextKey string

const (
	userIDContextKey   contextKey = "user_id"
	usernameContextKey contextKey = "username"
)

// UserIDFromContext returns the user ID stored by AuthInterceptor after a
// successful token check. ok is false when no authentication took place
// (e.g. for public methods).
func UserIDFromContext(ctx context.Context) (userID int64, ok bool) {
	userID, ok = ctx.Value(userIDContextKey).(int64)
	return userID, ok
}

// UsernameFromContext returns the username stored by AuthInterceptor after a
// successful token check. ok is false when no authentication took place.
func UsernameFromContext(ctx context.Context) (username string, ok bool) {
	username, ok = ctx.Value(usernameContextKey).(string)
	return username, ok
}

// AuthInterceptor requires an "authorization: Bearer <token>" metadata entry
// on every RPC except the configured public methods. It verifies the token
// with a JWTManager and exposes the resulting claims to handlers through the
// request context (see UserIDFromContext / UsernameFromContext).
type AuthInterceptor struct {
	jwtManager    *JWTManager
	publicMethods map[string]bool // full method names ("/pkg.Service/Method") that skip authentication
}

// NewAuthInterceptor builds an AuthInterceptor. publicMethods lists full
// method names that may be called without a token; nil means none.
func NewAuthInterceptor(jwtManager *JWTManager, publicMethods []string) *AuthInterceptor {
	publicMethodsMap := make(map[string]bool, len(publicMethods))
	for _, method := range publicMethods {
		publicMethodsMap[method] = true
	}
	return &AuthInterceptor{
		jwtManager:    jwtManager,
		publicMethods: publicMethodsMap,
	}
}

// Unary returns a unary server interceptor that authenticates non-public calls
// and passes the authenticated context to the handler.
func (interceptor *AuthInterceptor) Unary() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		if interceptor.publicMethods[info.FullMethod] {
			return handler(ctx, req)
		}
		authCtx, err := interceptor.authorize(ctx)
		if err != nil {
			return nil, err
		}
		// Hand the enriched context on; otherwise handlers never see the user.
		return handler(authCtx, req)
	}
}

// Stream returns a stream server interceptor that authenticates non-public
// calls and exposes the authenticated context through stream.Context().
func (interceptor *AuthInterceptor) Stream() grpc.StreamServerInterceptor {
	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		if interceptor.publicMethods[info.FullMethod] {
			return handler(srv, stream)
		}
		authCtx, err := interceptor.authorize(stream.Context())
		if err != nil {
			return err
		}
		// A ServerStream only exposes its context via Context(), so wrap the
		// stream to make the authenticated context visible to the handler.
		return handler(srv, &authenticatedServerStream{ServerStream: stream, ctx: authCtx})
	}
}

// authenticatedServerStream overrides Context() of the wrapped stream with a
// context that carries the authenticated user's claims.
type authenticatedServerStream struct {
	grpc.ServerStream
	ctx context.Context
}

// Context returns the authenticated context.
func (s *authenticatedServerStream) Context() context.Context {
	return s.ctx
}

// authorize validates the Bearer token in ctx's incoming metadata and returns
// a derived context holding the user ID and username from the token claims.
// All failures are reported as codes.Unauthenticated.
func (interceptor *AuthInterceptor) authorize(ctx context.Context) (context.Context, error) {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return nil, status.Errorf(codes.Unauthenticated, "metadata is not provided")
	}
	values := md["authorization"]
	if len(values) == 0 {
		return nil, status.Errorf(codes.Unauthenticated, "authorization token is not provided")
	}
	accessToken := values[0]
	if !strings.HasPrefix(accessToken, "Bearer ") {
		return nil, status.Errorf(codes.Unauthenticated, "invalid authorization format")
	}
	accessToken = strings.TrimPrefix(accessToken, "Bearer ")
	claims, err := interceptor.jwtManager.Verify(accessToken)
	if err != nil {
		return nil, status.Errorf(codes.Unauthenticated, "access token is invalid: %v", err)
	}
	ctx = context.WithValue(ctx, userIDContextKey, claims.UserID)
	ctx = context.WithValue(ctx, usernameContextKey, claims.Username)
	return ctx, nil
}
// JWTManager is an example token verifier. secretKey is stored but not used
// by the placeholder Verify below.
type JWTManager struct {
	secretKey string
}

// UserClaims is the identity extracted from a verified token.
type UserClaims struct {
	UserID   int64  `json:"user_id"`
	Username string `json:"username"`
}

// NewJWTManager returns a JWTManager holding secretKey.
func NewJWTManager(secretKey string) *JWTManager {
	manager := new(JWTManager)
	manager.secretKey = secretKey
	return manager
}

// Verify checks accessToken and returns its claims.
//
// WARNING: placeholder only — it accepts exactly the literal "valid-token"
// (yielding user 1 / "testuser") and rejects everything else with
// codes.Unauthenticated. Replace with real signature and expiry validation
// before any real use.
func (manager *JWTManager) Verify(accessToken string) (*UserClaims, error) {
	if accessToken != "valid-token" {
		return nil, status.Errorf(codes.Unauthenticated, "invalid token")
	}
	claims := &UserClaims{UserID: 1, Username: "testuser"}
	return claims, nil
}
3. 限流拦截器 #
// internal/interceptor/ratelimit.go
package interceptor
import (
"context"
"sync"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"golang.org/x/time/rate"
)
// RateLimiter keeps one token-bucket limiter per key (the interceptors below
// use the full method name), each created lazily with the same rate and burst.
type RateLimiter struct {
	mu       sync.RWMutex
	limiters map[string]*rate.Limiter
	rate     rate.Limit
	burst    int
}

// NewRateLimiter returns a RateLimiter whose per-key limiters allow r events
// per second with bursts of up to b.
func NewRateLimiter(r rate.Limit, b int) *RateLimiter {
	rl := new(RateLimiter)
	rl.limiters = make(map[string]*rate.Limiter)
	rl.rate, rl.burst = r, b
	return rl
}

// getLimiter returns the limiter for key, creating it on first use. The common
// case takes only the read lock; creation re-checks under the write lock
// because another goroutine may have inserted the key in between.
func (rl *RateLimiter) getLimiter(key string) *rate.Limiter {
	rl.mu.RLock()
	existing, found := rl.limiters[key]
	rl.mu.RUnlock()
	if found {
		return existing
	}

	rl.mu.Lock()
	defer rl.mu.Unlock()
	if existing, found = rl.limiters[key]; found {
		return existing
	}
	created := rate.NewLimiter(rl.rate, rl.burst)
	rl.limiters[key] = created
	return created
}
// UnaryServerInterceptor rejects unary calls with codes.ResourceExhausted once
// the limiter for the call's full method name has no tokens left.
func (rl *RateLimiter) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		method := info.FullMethod
		if rl.getLimiter(method).Allow() {
			return handler(ctx, req)
		}
		return nil, status.Errorf(codes.ResourceExhausted, "rate limit exceeded for method %s", method)
	}
}

// StreamServerInterceptor applies the same per-method limit when a stream is
// opened; messages inside an accepted stream are not limited.
func (rl *RateLimiter) StreamServerInterceptor() grpc.StreamServerInterceptor {
	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		method := info.FullMethod
		if rl.getLimiter(method).Allow() {
			return handler(srv, stream)
		}
		return status.Errorf(codes.ResourceExhausted, "rate limit exceeded for method %s", method)
	}
}
// IPRateLimiter keeps one token-bucket limiter per client key and rejects
// calls with codes.ResourceExhausted once that client's bucket is empty.
//
// NewIPRateLimiter starts a background goroutine that hourly evicts buckets
// which have refilled completely. Call Stop when the limiter is no longer
// needed so that goroutine exits instead of leaking.
type IPRateLimiter struct {
	mu       sync.RWMutex
	limiters map[string]*rate.Limiter
	rate     rate.Limit
	burst    int

	done     chan struct{} // closed by Stop to end the cleanup goroutine
	stopOnce sync.Once     // makes Stop idempotent
}

// NewIPRateLimiter returns a limiter that allows r events per second with
// bursts of up to b per client, and starts its cleanup goroutine.
func NewIPRateLimiter(r rate.Limit, b int) *IPRateLimiter {
	limiter := &IPRateLimiter{
		limiters: make(map[string]*rate.Limiter),
		rate:     r,
		burst:    b,
		done:     make(chan struct{}),
	}
	go limiter.cleanup()
	return limiter
}

// Stop ends the background cleanup goroutine. It is safe to call more than
// once. The interceptors keep working afterwards; idle buckets are simply no
// longer evicted.
func (rl *IPRateLimiter) Stop() {
	rl.stopOnce.Do(func() {
		if rl.done != nil {
			close(rl.done)
		}
	})
}

// getClientIP returns the key used to select a client's bucket.
//
// NOTE(review): placeholder — it always returns "127.0.0.1", so every client
// shares one bucket. A real implementation should read the caller's address
// with peer.FromContext (google.golang.org/grpc/peer) and strip the port with
// net.SplitHostPort.
func (rl *IPRateLimiter) getClientIP(ctx context.Context) string {
	return "127.0.0.1"
}

// UnaryServerInterceptor rejects unary calls from clients over their rate.
func (rl *IPRateLimiter) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		clientIP := rl.getClientIP(ctx)
		limiter := rl.getLimiter(clientIP)
		if !limiter.Allow() {
			return nil, status.Errorf(codes.ResourceExhausted, "rate limit exceeded for IP %s", clientIP)
		}
		return handler(ctx, req)
	}
}

// StreamServerInterceptor applies the same per-client limit when a stream is
// opened; messages inside an accepted stream are not limited.
func (rl *IPRateLimiter) StreamServerInterceptor() grpc.StreamServerInterceptor {
	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		clientIP := rl.getClientIP(stream.Context())
		if !rl.getLimiter(clientIP).Allow() {
			return status.Errorf(codes.ResourceExhausted, "rate limit exceeded for IP %s", clientIP)
		}
		return handler(srv, stream)
	}
}

// getLimiter returns the limiter for ip, creating it on first use (re-checked
// under the write lock in case another goroutine created it concurrently).
func (rl *IPRateLimiter) getLimiter(ip string) *rate.Limiter {
	rl.mu.RLock()
	limiter, exists := rl.limiters[ip]
	rl.mu.RUnlock()
	if !exists {
		rl.mu.Lock()
		if limiter, exists = rl.limiters[ip]; !exists {
			limiter = rate.NewLimiter(rl.rate, rl.burst)
			rl.limiters[ip] = limiter
		}
		rl.mu.Unlock()
	}
	return limiter
}

// cleanup evicts full buckets once an hour until Stop is called.
func (rl *IPRateLimiter) cleanup() {
	ticker := time.NewTicker(time.Hour)
	defer ticker.Stop()
	for {
		select {
		case <-rl.done:
			return
		case <-ticker.C:
			rl.evictFullBuckets()
		}
	}
}

// evictFullBuckets removes limiters whose bucket has refilled to burst
// capacity: such a limiter holds no state a freshly created one would not.
func (rl *IPRateLimiter) evictFullBuckets() {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	for ip, limiter := range rl.limiters {
		if limiter.Tokens() >= float64(rl.burst) {
			delete(rl.limiters, ip)
		}
	}
}
4. 监控拦截器 #
// internal/interceptor/metrics.go
package interceptor
import (
"context"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
var (
// grpcRequestsTotal counts completed RPCs, labelled by full method name and
// the status-code string (incremented by the Metrics*Interceptor functions).
grpcRequestsTotal = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "grpc_requests_total",
Help: "Total number of gRPC requests",
},
[]string{"method", "code"},
)
// grpcRequestDuration records RPC latency in seconds per full method name,
// using Prometheus' default buckets.
grpcRequestDuration = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "grpc_request_duration_seconds",
Help: "Duration of gRPC requests",
Buckets: prometheus.DefBuckets,
},
[]string{"method"},
)
// grpcActiveConnections is created here but never updated by the
// interceptors in this file.
// NOTE(review): wire it up (e.g. from a connection-level hook) or remove it.
grpcActiveConnections = promauto.NewGauge(
prometheus.GaugeOpts{
Name: "grpc_active_connections",
Help: "Number of active gRPC connections",
},
)
// grpcMessageSize records message sizes in bytes per method and direction
// ("request"/"response"). Observations are skipped when getMessageSize
// returns 0, which the current placeholder always does.
grpcMessageSize = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "grpc_message_size_bytes",
Help: "Size of gRPC messages",
Buckets: []float64{64, 256, 1024, 4096, 16384, 65536, 262144, 1048576},
},
[]string{"method", "type"},
)
)
// MetricsUnaryInterceptor records Prometheus metrics for a unary RPC:
// request/response sizes (when known), latency per method, and a request
// count labelled by method and resulting gRPC status code.
func MetricsUnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	method := info.FullMethod
	began := time.Now()

	observeSize := func(msg interface{}, direction string) {
		if n := getMessageSize(msg); n > 0 {
			grpcMessageSize.WithLabelValues(method, direction).Observe(float64(n))
		}
	}

	observeSize(req, "request")

	resp, err := handler(ctx, req)
	grpcRequestDuration.WithLabelValues(method).Observe(time.Since(began).Seconds())

	code := codes.OK
	if err != nil {
		// status.Code maps errors without a gRPC status to codes.Unknown.
		code = status.Code(err)
	}
	grpcRequestsTotal.WithLabelValues(method, code.String()).Inc()

	if resp != nil {
		observeSize(resp, "response")
	}
	return resp, err
}
// MetricsStreamInterceptor records latency and a status-labelled request count
// for a streaming RPC. Per-message sizes are recorded by the metricsServerStream
// wrapper handed to the handler.
func MetricsStreamInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
	method := info.FullMethod
	began := time.Now()

	err := handler(srv, &metricsServerStream{ServerStream: stream, method: method})
	grpcRequestDuration.WithLabelValues(method).Observe(time.Since(began).Seconds())

	code := codes.OK
	if err != nil {
		// status.Code maps errors without a gRPC status to codes.Unknown.
		code = status.Code(err)
	}
	grpcRequestsTotal.WithLabelValues(method, code.String()).Inc()
	return err
}
// metricsServerStream decorates a grpc.ServerStream to record the size of each
// sent ("response") and successfully received ("request") message.
type metricsServerStream struct {
	grpc.ServerStream
	method string // full method name used as the metric label
}

// SendMsg records m's size (if known) and forwards it to the wrapped stream.
func (s *metricsServerStream) SendMsg(m interface{}) error {
	size := getMessageSize(m)
	if size > 0 {
		grpcMessageSize.WithLabelValues(s.method, "response").Observe(float64(size))
	}
	return s.ServerStream.SendMsg(m)
}

// RecvMsg reads into m and records its size only when the read succeeded.
func (s *metricsServerStream) RecvMsg(m interface{}) error {
	if err := s.ServerStream.RecvMsg(m); err != nil {
		return err
	}
	size := getMessageSize(m)
	if size > 0 {
		grpcMessageSize.WithLabelValues(s.method, "request").Observe(float64(size))
	}
	return nil
}
// getMessageSize is meant to report the encoded size of msg in bytes.
//
// Placeholder: it always returns 0. Callers only record sizes greater than 0,
// so the grpc_message_size_bytes histogram stays empty until this is
// implemented (for protobuf messages, e.g. with proto.Size).
func getMessageSize(msg interface{}) int {
	const unknownSize = 0
	_ = msg // not inspected yet
	return unknownSize
}
5. 恢复拦截器 #
// internal/interceptor/recovery.go
package interceptor
import (
"context"
"log"
"runtime/debug"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// RecoveryUnaryInterceptor turns a panic in the rest of the chain into a
// codes.Internal error, logging the panic value and stack trace. It should be
// the outermost interceptor so it covers everything beneath it.
func RecoveryUnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		log.Printf("Panic recovered in %s: %v\n%s", info.FullMethod, r, debug.Stack())
		// Overwrites the named result so the client sees a generic error.
		err = status.Errorf(codes.Internal, "internal server error")
	}()

	resp, err = handler(ctx, req)
	return resp, err
}
// RecoveryStreamInterceptor is the streaming counterpart of
// RecoveryUnaryInterceptor: a panic in the handler chain is logged with its
// stack trace and reported to the client as codes.Internal.
func RecoveryStreamInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		log.Printf("Panic recovered in stream %s: %v\n%s", info.FullMethod, r, debug.Stack())
		err = status.Errorf(codes.Internal, "internal server error")
	}()

	err = handler(srv, stream)
	return err
}
// RecoveryHandler converts a recovered panic value into the error returned to
// the client.
type RecoveryHandler func(p interface{}) error

// RecoveryUnaryInterceptorWithHandler is like RecoveryUnaryInterceptor but lets
// recoveryHandler choose the error returned after a panic.
//
// A nil recoveryHandler falls back to DefaultRecoveryHandler. Calling a nil
// handler inside the deferred recover would otherwise panic again and crash
// the server. If the handler returns nil, DefaultRecoveryHandler's error is
// used instead, so a panicking call is never reported as a successful empty
// response.
func RecoveryUnaryInterceptorWithHandler(recoveryHandler RecoveryHandler) grpc.UnaryServerInterceptor {
	if recoveryHandler == nil {
		recoveryHandler = DefaultRecoveryHandler
	}
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
		defer func() {
			if r := recover(); r != nil {
				log.Printf("Panic recovered in %s: %v\n%s", info.FullMethod, r, debug.Stack())
				if err = recoveryHandler(r); err == nil {
					err = DefaultRecoveryHandler(r)
				}
			}
		}()
		return handler(ctx, req)
	}
}
// DefaultRecoveryHandler ignores the panic value and returns a generic
// codes.Internal error, so no internal details reach the client.
func DefaultRecoveryHandler(_ interface{}) error {
	return status.Errorf(codes.Internal, "internal server error")
}

// DetailedRecoveryHandler logs the panic value and also embeds it in the
// codes.Internal error message sent to the client.
//
// WARNING: this exposes internal details (panic messages, possibly data) to
// callers; prefer DefaultRecoveryHandler outside development.
func DetailedRecoveryHandler(panicValue interface{}) error {
	log.Printf("Detailed panic info: %+v", panicValue)
	return status.Errorf(codes.Internal, "internal server error: %v", panicValue)
}
客户端拦截器 #
1. 客户端日志拦截器 #
// internal/interceptor/client_logging.go
package interceptor
import (
"context"
"log"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// ClientLoggingUnaryInterceptor logs each outgoing unary call before it is
// sent and after it returns, with elapsed time, gRPC status code and error.
func ClientLoggingUnaryInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
	began := time.Now()
	log.Printf("[Client] Request started - Method: %s", method)

	err := invoker(ctx, method, req, reply, cc, opts...)
	elapsed := time.Since(began)

	code := codes.OK
	if err != nil {
		// status.Code maps errors without a gRPC status to codes.Unknown.
		code = status.Code(err)
	}

	log.Printf("[Client] Request completed - Method: %s, Duration: %v, Code: %s, Error: %v",
		method, elapsed, code, err)
	return err
}
// ClientLoggingStreamInterceptor logs the opening of an outgoing stream and,
// on success, returns a loggingClientStream that logs subsequent traffic.
// A failure to open the stream is logged and returned unchanged.
func ClientLoggingStreamInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
	began := time.Now()
	log.Printf("[Client] Stream started - Method: %s", method)

	inner, err := streamer(ctx, desc, cc, method, opts...)
	if err == nil {
		return &loggingClientStream{ClientStream: inner, method: method, startTime: began}, nil
	}
	log.Printf("[Client] Stream creation failed - Method: %s, Error: %v", method, err)
	return nil, err
}
// loggingClientStream wraps a grpc.ClientStream and logs sends, receives and
// CloseSend for one client-side stream.
type loggingClientStream struct {
	grpc.ClientStream
	method    string
	startTime time.Time // when the stream was opened; used for the duration logged by CloseSend
}

// SendMsg logs and forwards m to the wrapped stream.
func (s *loggingClientStream) SendMsg(m interface{}) error {
	log.Printf("[Client] Stream send - Method: %s", s.method)
	return s.ClientStream.SendMsg(m)
}

// RecvMsg reads into m and logs the outcome.
//
// NOTE(review): every non-nil error is logged as a "receive error", including
// io.EOF, which grpc-go uses to signal normal end of stream; consider treating
// io.EOF as completion instead.
func (s *loggingClientStream) RecvMsg(m interface{}) error {
	err := s.ClientStream.RecvMsg(m)
	if err != nil {
		log.Printf("[Client] Stream receive error - Method: %s, Error: %v", s.method, err)
		return err
	}
	log.Printf("[Client] Stream receive - Method: %s", s.method)
	return nil
}

// CloseSend logs the time since the stream was opened and half-closes the
// send direction.
//
// NOTE(review): CloseSend ends only the client's sending side; responses may
// still arrive afterwards, so the logged "closed"/duration is not the total
// stream lifetime.
func (s *loggingClientStream) CloseSend() error {
	log.Printf("[Client] Stream closed - Method: %s, Duration: %v", s.method, time.Since(s.startTime))
	return s.ClientStream.CloseSend()
}
2. 客户端认证拦截器 #
// internal/interceptor/client_auth.go
package interceptor
import (
"context"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
// ClientAuthInterceptor attaches "authorization: Bearer <token>" metadata to
// every outgoing unary and streaming call. The token is fetched from the
// TokenProvider on each call.
type ClientAuthInterceptor struct {
	tokenProvider TokenProvider
}

// TokenProvider supplies the bearer token used by ClientAuthInterceptor.
type TokenProvider interface {
	GetToken() (string, error)
}

// NewClientAuthInterceptor returns an interceptor that uses tokenProvider.
func NewClientAuthInterceptor(tokenProvider TokenProvider) *ClientAuthInterceptor {
	return &ClientAuthInterceptor{tokenProvider: tokenProvider}
}

// withToken fetches a token and returns ctx with the authorization header
// appended to its outgoing metadata. Provider errors are returned unchanged.
func (interceptor *ClientAuthInterceptor) withToken(ctx context.Context) (context.Context, error) {
	token, err := interceptor.tokenProvider.GetToken()
	if err != nil {
		return nil, err
	}
	return metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+token), nil
}

// Unary returns the unary client interceptor.
func (interceptor *ClientAuthInterceptor) Unary() grpc.UnaryClientInterceptor {
	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		authCtx, err := interceptor.withToken(ctx)
		if err != nil {
			return err
		}
		return invoker(authCtx, method, req, reply, cc, opts...)
	}
}

// Stream returns the stream client interceptor.
func (interceptor *ClientAuthInterceptor) Stream() grpc.StreamClientInterceptor {
	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		authCtx, err := interceptor.withToken(ctx)
		if err != nil {
			return nil, err
		}
		return streamer(authCtx, desc, cc, method, opts...)
	}
}
// StaticTokenProvider always returns the same token and never fails.
type StaticTokenProvider struct {
	token string
}

// NewStaticTokenProvider returns a provider for the fixed token.
func NewStaticTokenProvider(token string) *StaticTokenProvider {
	provider := new(StaticTokenProvider)
	provider.token = token
	return provider
}

// GetToken returns the configured token with a nil error.
func (p *StaticTokenProvider) GetToken() (string, error) {
	tok := p.token
	return tok, nil
}

// DynamicTokenProvider delegates every GetToken call to a user-supplied
// function, whose result (token and error) is returned as-is. The function
// must be non-nil.
type DynamicTokenProvider struct {
	getTokenFunc func() (string, error)
}

// NewDynamicTokenProvider wraps getTokenFunc as a TokenProvider.
func NewDynamicTokenProvider(getTokenFunc func() (string, error)) *DynamicTokenProvider {
	provider := &DynamicTokenProvider{}
	provider.getTokenFunc = getTokenFunc
	return provider
}

// GetToken calls the wrapped function.
func (p *DynamicTokenProvider) GetToken() (string, error) {
	fetch := p.getTokenFunc
	return fetch()
}
拦截器链 #
拦截器组合 #
// internal/interceptor/chain.go
package interceptor
import (
	"context"

	"google.golang.org/grpc"
)
// Interceptor chaining helpers.
//
// grpc.ChainUnaryInterceptor / grpc.ChainStreamInterceptor return a
// grpc.ServerOption, and the client equivalents are
// grpc.WithChainUnaryInterceptor / grpc.WithChainStreamInterceptor, which
// return a grpc.DialOption. None of them can be returned as an interceptor,
// so these helpers compose interceptors by hand. In every helper the first
// interceptor passed is the outermost one, and the input slice is copied so
// later changes by the caller do not affect the chain.

// ChainUnaryServerInterceptors combines interceptors into one unary server
// interceptor. With no interceptors it simply calls the handler.
func ChainUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
	chain := append([]grpc.UnaryServerInterceptor(nil), interceptors...)
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		return unaryServerHandlerAt(chain, 0, info, handler)(ctx, req)
	}
}

// unaryServerHandlerAt returns a handler that runs chain[i] with the rest of
// the chain (ending in final) as its next handler.
func unaryServerHandlerAt(chain []grpc.UnaryServerInterceptor, i int, info *grpc.UnaryServerInfo, final grpc.UnaryHandler) grpc.UnaryHandler {
	if i == len(chain) {
		return final
	}
	return func(ctx context.Context, req interface{}) (interface{}, error) {
		return chain[i](ctx, req, info, unaryServerHandlerAt(chain, i+1, info, final))
	}
}

// ChainStreamServerInterceptors combines interceptors into one stream server
// interceptor. With no interceptors it simply calls the handler.
func ChainStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
	chain := append([]grpc.StreamServerInterceptor(nil), interceptors...)
	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		return streamServerHandlerAt(chain, 0, info, handler)(srv, ss)
	}
}

// streamServerHandlerAt returns a handler that runs chain[i] with the rest of
// the chain (ending in final) as its next handler.
func streamServerHandlerAt(chain []grpc.StreamServerInterceptor, i int, info *grpc.StreamServerInfo, final grpc.StreamHandler) grpc.StreamHandler {
	if i == len(chain) {
		return final
	}
	return func(srv interface{}, ss grpc.ServerStream) error {
		return chain[i](srv, ss, info, streamServerHandlerAt(chain, i+1, info, final))
	}
}

// ChainUnaryClientInterceptors combines interceptors into one unary client
// interceptor. With no interceptors it simply calls the invoker.
func ChainUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor {
	chain := append([]grpc.UnaryClientInterceptor(nil), interceptors...)
	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		return unaryClientInvokerAt(chain, 0, invoker)(ctx, method, req, reply, cc, opts...)
	}
}

// unaryClientInvokerAt returns an invoker that runs chain[i] with the rest of
// the chain (ending in final) as its next invoker.
func unaryClientInvokerAt(chain []grpc.UnaryClientInterceptor, i int, final grpc.UnaryInvoker) grpc.UnaryInvoker {
	if i == len(chain) {
		return final
	}
	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
		return chain[i](ctx, method, req, reply, cc, unaryClientInvokerAt(chain, i+1, final), opts...)
	}
}

// ChainStreamClientInterceptors combines interceptors into one stream client
// interceptor. With no interceptors it simply calls the streamer.
func ChainStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.StreamClientInterceptor {
	chain := append([]grpc.StreamClientInterceptor(nil), interceptors...)
	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		return streamClientStreamerAt(chain, 0, streamer)(ctx, desc, cc, method, opts...)
	}
}

// streamClientStreamerAt returns a streamer that runs chain[i] with the rest
// of the chain (ending in final) as its next streamer.
func streamClientStreamerAt(chain []grpc.StreamClientInterceptor, i int, final grpc.Streamer) grpc.Streamer {
	if i == len(chain) {
		return final
	}
	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		return chain[i](ctx, desc, cc, method, streamClientStreamerAt(chain, i+1, final), opts...)
	}
}
服务器配置示例 #
// cmd/server/main.go
package main
import (
"log"
"net"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection"
"github.com/example/grpc-server/internal/interceptor"
pb "github.com/example/grpc-server/proto/user/v1"
)
// main starts a gRPC server on :50051 with the interceptor stack from this
// section, registers UserService and server reflection, and serves until a
// fatal error occurs.
func main() {
	lis, err := net.Listen("tcp", ":50051")
	if err != nil {
		log.Fatalf("Failed to listen: %v", err)
	}

	// Every method except the listed public ones requires a Bearer token.
	authInterceptor := interceptor.NewAuthInterceptor(
		interceptor.NewJWTManager("secret-key"),
		[]string{"/user.v1.UserService/GetUser"}, // public method
	)

	// Per-method token bucket: 10 requests per second, bursts of up to 20.
	rateLimiter := interceptor.NewRateLimiter(10, 20)

	// Interceptors are listed outermost first, so recovery wraps everything
	// below it and authentication runs closest to the service method.
	unaryChain := grpc.ChainUnaryInterceptor(
		interceptor.RecoveryUnaryInterceptor,
		interceptor.LoggingUnaryInterceptor,
		interceptor.MetricsUnaryInterceptor,
		rateLimiter.UnaryServerInterceptor(),
		authInterceptor.Unary(),
	)
	streamChain := grpc.ChainStreamInterceptor(
		interceptor.RecoveryStreamInterceptor,
		interceptor.LoggingStreamInterceptor,
		interceptor.MetricsStreamInterceptor,
		rateLimiter.StreamServerInterceptor(),
		authInterceptor.Stream(),
	)
	s := grpc.NewServer(unaryChain, streamChain)

	pb.RegisterUserServiceServer(s, &userService{})
	reflection.Register(s) // enable the gRPC server reflection service

	log.Println("gRPC server listening on :50051")
	if err := s.Serve(lis); err != nil {
		log.Fatalf("Failed to serve: %v", err)
	}
}
客户端配置示例 #
// cmd/client/main.go
package main
import (
"context"
"log"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/example/grpc-client/internal/interceptor"
pb "github.com/example/grpc-client/proto/user/v1"
)
// main dials localhost:50051 without TLS, installs the logging and auth client
// interceptors (logging outermost), and performs one GetUser call with a
// 5-second timeout, logging the result.
func main() {
	authInterceptor := interceptor.NewClientAuthInterceptor(
		interceptor.NewStaticTokenProvider("valid-token"),
	)

	dialOpts := []grpc.DialOption{
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithChainUnaryInterceptor(
			interceptor.ClientLoggingUnaryInterceptor,
			authInterceptor.Unary(),
		),
		grpc.WithChainStreamInterceptor(
			interceptor.ClientLoggingStreamInterceptor,
			authInterceptor.Stream(),
		),
	}
	conn, err := grpc.Dial("localhost:50051", dialOpts...)
	if err != nil {
		log.Fatalf("Failed to connect: %v", err)
	}
	defer conn.Close()

	client := pb.NewUserServiceClient(conn)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	resp, err := client.GetUser(ctx, &pb.GetUserRequest{UserId: 1})
	if err != nil {
		log.Printf("GetUser failed: %v", err)
		return
	}
	log.Printf("GetUser success: %+v", resp.User)
}
高级拦截器模式 #
1. 条件拦截器 #
// internal/interceptor/conditional.go
package interceptor
import (
"context"
"strings"
"google.golang.org/grpc"
)
// ConditionalInterceptor runs a wrapped unary interceptor only for methods
// where condition(fullMethod) reports true; other calls go straight to the
// handler.
type ConditionalInterceptor struct {
	condition   func(method string) bool
	interceptor grpc.UnaryServerInterceptor
}

// NewConditionalInterceptor pairs a method predicate with the interceptor it gates.
func NewConditionalInterceptor(condition func(string) bool, interceptor grpc.UnaryServerInterceptor) *ConditionalInterceptor {
	ci := new(ConditionalInterceptor)
	ci.condition = condition
	ci.interceptor = interceptor
	return ci
}

// Unary returns the gated unary server interceptor.
func (ci *ConditionalInterceptor) Unary() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		if !ci.condition(info.FullMethod) {
			return handler(ctx, req)
		}
		return ci.interceptor(ctx, req, info, handler)
	}
}
// 使用示例
func main() {
// 只对特定方法应用认证
authCondition := func(method string) bool {
return !strings.HasSuffix(method, "/Health")
}
conditionalAuth := NewConditionalInterceptor(authCondition, authInterceptor.Unary())
s := grpc.NewServer(
grpc.UnaryInterceptor(conditionalAuth.Unary()),
)
}
2. 配置驱动的拦截器 #
// internal/interceptor/configurable.go
package interceptor
import (
"context"
"time"
"google.golang.org/grpc"
)
// InterceptorConfig selects which optional interceptors BuildInterceptorChain
// installs; the yaml tags give the keys expected in a config file.
type InterceptorConfig struct {
EnableLogging bool `yaml:"enable_logging"`
EnableMetrics bool `yaml:"enable_metrics"`
// EnableAuth is accepted, but BuildInterceptorChain does not act on it yet.
EnableAuth bool `yaml:"enable_auth"`
EnableRateLimit bool `yaml:"enable_rate_limit"`
RateLimit RateLimitConfig `yaml:"rate_limit"`
// LogLevel is not read by BuildInterceptorChain.
LogLevel string `yaml:"log_level"`
}
// RateLimitConfig parameterises the per-method RateLimiter built when
// EnableRateLimit is set.
type RateLimitConfig struct {
RequestsPerSecond int `yaml:"requests_per_second"`
BurstSize int `yaml:"burst_size"`
// CleanupInterval is not read by BuildInterceptorChain.
// NOTE(review): presumably intended for IPRateLimiter's cleanup period — confirm.
CleanupInterval time.Duration `yaml:"cleanup_interval"`
}
// BuildInterceptorChain assembles the server interceptors selected by config.
// Both slices are ordered outermost first (recovery, logging, metrics, rate
// limiting) and are intended for grpc.ChainUnaryInterceptor /
// grpc.ChainStreamInterceptor.
//
// NOTE(review): this function uses rate.Limit, but configurable.go does not
// import "golang.org/x/time/rate" (and imports "context" without using it);
// the file's import block must be fixed for it to compile.
// NOTE(review): EnableAuth is accepted but currently adds nothing, so enabling
// it silently leaves the server unauthenticated.
func BuildInterceptorChain(config InterceptorConfig) ([]grpc.UnaryServerInterceptor, []grpc.StreamServerInterceptor) {
	// Recovery is unconditional and first, so it wraps every other interceptor.
	unary := []grpc.UnaryServerInterceptor{RecoveryUnaryInterceptor}
	stream := []grpc.StreamServerInterceptor{RecoveryStreamInterceptor}

	if config.EnableLogging {
		unary = append(unary, LoggingUnaryInterceptor)
		stream = append(stream, LoggingStreamInterceptor)
	}
	if config.EnableMetrics {
		unary = append(unary, MetricsUnaryInterceptor)
		stream = append(stream, MetricsStreamInterceptor)
	}
	if config.EnableRateLimit {
		limits := config.RateLimit
		limiter := NewRateLimiter(rate.Limit(limits.RequestsPerSecond), limits.BurstSize)
		unary = append(unary, limiter.UnaryServerInterceptor())
		stream = append(stream, limiter.StreamServerInterceptor())
	}
	// TODO: when config.EnableAuth is set, build an AuthInterceptor here and
	// append its Unary() and Stream() interceptors.

	return unary, stream
}
最佳实践 #
1. 拦截器顺序 #
// recommendedOrder lists interceptor kinds from outermost to innermost. With
// grpc.ChainUnaryInterceptor the first argument is the outermost, so pass them
// in this order.
var recommendedOrder = []string{
"Recovery", // outermost: catches panics from everything inside it
"Logging", // request logging
"Metrics", // metrics collection
"Tracing", // distributed tracing
"RateLimit", // rate limiting
"Auth", // authentication / authorization
"Validation", // request validation
"Business", // business logic (the service method itself)
}
2. 错误处理 #
// SafeInterceptor wraps next so that a panic raised while it runs, including
// one from the handler it invokes, is logged and reported to the client as
// codes.Internal instead of crashing the server.
//
// The results are named so the deferred recover can set them. Without that,
// a recovered panic made the call return (nil, nil), i.e. a "successful"
// empty response.
func SafeInterceptor(next grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
		defer func() {
			if r := recover(); r != nil {
				log.Printf("Interceptor panic: %v", r)
				resp = nil
				err = status.Errorf(codes.Internal, "internal server error")
			}
		}()
		return next(ctx, req, info, handler)
	}
}
3. 性能考虑 #
// Avoid repeating expensive work inside interceptors.
//
// OptimizedInterceptor sketches two performance ideas: serving a cached
// response and moving logging off the request path. getFromCache and logAsync
// are not defined in this section.
//
// NOTE(review): the cache is keyed only by info.FullMethod, so whatever is
// cached is returned for every request to that method regardless of req or
// caller. That is correct only for responses that are truly request- and
// user-independent; otherwise the key must include the request identity.
// NOTE(review): one unbounded goroutine is started per request with no way
// to wait for it, and it reads req while the handler may still be using it;
// if anything mutates req this is a data race. Prefer a bounded worker or
// buffered channel, and pass an immutable copy of what is logged.
func OptimizedInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
// Use a cache to avoid recomputation
if cached := getFromCache(info.FullMethod); cached != nil {
return cached, nil
}
// Handle non-critical work asynchronously
go func() {
// Asynchronous logging
logAsync(info.FullMethod, req)
}()
return handler(ctx, req)
}
小结 #
本节详细介绍了 gRPC 拦截器和中间件的开发:
- 拦截器类型:一元和流式、服务端和客户端拦截器
- 常用拦截器:日志、认证、限流、监控、恢复
- 拦截器链:组合多个拦截器的方法
- 高级模式:条件拦截器、配置驱动的拦截器
- 最佳实践:拦截器顺序、错误处理、性能优化
通过合理使用拦截器,可以实现横切关注点的分离,提高代码的可维护性和可扩展性。拦截器是构建企业级 gRPC 应用的重要工具。