// Copyright 2022 Harness Inc. All rights reserved. // Use of this source code is governed by the Polyform Free Trial License // that can be found in the LICENSE.md file for this repository. package middleware import ( "context" "strings" "time" "github.com/harness/gitness/gitrpc/rpc" "github.com/rs/xid" "github.com/rs/zerolog/log" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" ) const ( RequestIDNone string = "gitrpc_none" ) // requestIDKey is context key for storing and retrieving the request ID to and from a context. type requestIDKey struct{} // LogInterceptor injects a zerolog logger with common grpc related annotations and logs the completion of the call. // If the metadata contains a request id, the logger is annotated with the same request ID, otherwise with a new one. type LogInterceptor struct { } func NewLogInterceptor() LogInterceptor { return LogInterceptor{} } func (i LogInterceptor) UnaryInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { ctx = injectLogging(ctx, info.FullMethod) // measure execution time start := time.Now() value, err := handler(ctx, req) logCompletion(ctx, start, err) return value, err } } func (i LogInterceptor) StreamInterceptor() grpc.StreamServerInterceptor { return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { ctx := injectLogging(stream.Context(), info.FullMethod) // wrap stream with updated context stream = &logServerStream{ ServerStream: stream, ctx: ctx, } // measure execution time start := time.Now() err := handler(srv, stream) logCompletion(ctx, start, err) return err } } // WithRequestID returns a copy of parent in which the request id value is set. func WithRequestID(parent context.Context, v string) context.Context { return context.WithValue(parent, requestIDKey{}, v) } // RequestIDFrom retrieves the request id from the context. // If no request id exists, RequestIDNone is returned. func RequestIDFrom(ctx context.Context) string { if v, ok := ctx.Value(requestIDKey{}).(string); ok { return v } return RequestIDNone } func injectLogging(ctx context.Context, fullMethod string) context.Context { // split fullMethod into service and method (expected format: "/package.service/method...") // If it doesn't match the expected format, the full string is put into method. service, method := "", fullMethod if len(fullMethod) > 0 && fullMethod[0] == '/' { if s, m, ok := strings.Cut(fullMethod[1:], "/"); ok { service, method = s, m } } // get request id (or create a new one) and inject it for later usage (git env variables) requestID := getOrCreateRequestID(ctx) ctx = WithRequestID(ctx, requestID) // create new logCtx with injected info logCtx := log.Logger.With(). Str("grpc.service", service). Str("grpc.method", method). Str("request_id", requestID) // add peer information if available if p, ok := peer.FromContext(ctx); ok && p.Addr != nil { logCtx = logCtx.Str("grpc.peer", p.Addr.String()) } // inject logger in context logger := logCtx.Logger() return logger.WithContext(ctx) } func logCompletion(ctx context.Context, start time.Time, err error) { logCtx := log.Ctx(ctx).Info(). Dur("grpc.elapsed_ms", time.Since(start)) // try to get grpc status code if status, ok := status.FromError(err); ok { logCtx.Str("grpc.status_code", status.Code().String()) } logCtx.Msg("grpc request completed.") } func getOrCreateRequestID(ctx context.Context) string { // check if request id was passed as part of grpc metadata if md, ok := metadata.FromIncomingContext(ctx); ok { if ids := md.Get(rpc.MetadataKeyRequestID); len(ids) > 0 { return ids[0] } } // use same type of request IDs as used by zerolog return xid.New().String() } // logServerStream is used to modify the stream context. // In order to modify the stream context we have to create a new struct and overshadow the `Context()` method. type logServerStream struct { grpc.ServerStream ctx context.Context } func (s *logServerStream) Context() context.Context { return s.ctx }