diff --git a/correlation/context.go b/correlation/context.go index 34559747b2af6643c6f47c6ee854a3f5233a7666..cf9cf6d6842d1e58d134930b505d95038ef882fc 100644 --- a/correlation/context.go +++ b/correlation/context.go @@ -9,6 +9,9 @@ type ctxKey int const ( keyCorrelationID ctxKey = iota keyClientName + keyUserID + keyUsername + keyRemoteIP ) func extractFromContextByKey(ctx context.Context, key ctxKey) string { @@ -55,3 +58,33 @@ func ExtractClientNameFromContext(ctx context.Context) string { func ContextWithClientName(ctx context.Context, clientName string) context.Context { return context.WithValue(ctx, keyClientName, clientName) } + +// ExtractUserIDFromContext extracts user ID from incoming context. +func ExtractUserIDFromContext(ctx context.Context) string { + return extractFromContextByKey(ctx, keyUserID) +} + +// ContextWithUserID will create a new context containing user ID metadata. +func ContextWithUserID(ctx context.Context, userID string) context.Context { + return context.WithValue(ctx, keyUserID, userID) +} + +// ExtractUsernameFromContext extracts username from incoming context. +func ExtractUsernameFromContext(ctx context.Context) string { + return extractFromContextByKey(ctx, keyUsername) +} + +// ContextWithUsername will create a new context containing username metadata. +func ContextWithUsername(ctx context.Context, username string) context.Context { + return context.WithValue(ctx, keyUsername, username) +} + +// ExtractRemoteIPFromContext extracts remote IP from incoming context. +func ExtractRemoteIPFromContext(ctx context.Context) string { + return extractFromContextByKey(ctx, keyRemoteIP) +} + +// ContextWithRemoteIP will create a new context containing remote IP metadata. 
+func ContextWithRemoteIP(ctx context.Context, remoteIP string) context.Context { + return context.WithValue(ctx, keyRemoteIP, remoteIP) +} diff --git a/correlation/grpc/client_interceptors.go b/correlation/grpc/client_interceptors.go index aab8f210ad23a2f9febcf737c476a168a23050dc..f5c1f979f23234095db5ba151d3439d77b441c78 100644 --- a/correlation/grpc/client_interceptors.go +++ b/correlation/grpc/client_interceptors.go @@ -8,13 +8,24 @@ import ( "google.golang.org/grpc/metadata" ) -func appendToOutgoingContext(ctx context.Context, clientName string) context.Context { +// appendToOutgoingContext adds the correlation ID found in ctx, plus each non-empty +// client name, user ID, username and remote IP from config, to the outgoing gRPC metadata. +func appendToOutgoingContext(ctx context.Context, config clientInterceptConfig) context.Context { correlationID := correlation.ExtractFromContext(ctx) if correlationID != "" { ctx = metadata.AppendToOutgoingContext(ctx, metadataCorrelatorKey, correlationID) } - if clientName != "" { - ctx = metadata.AppendToOutgoingContext(ctx, metadataClientNameKey, clientName) + if config.clientName != "" { + ctx = metadata.AppendToOutgoingContext(ctx, metadataClientNameKey, config.clientName) + } + if config.userID != "" { + ctx = metadata.AppendToOutgoingContext(ctx, metadataUserIDKey, config.userID) + } + if config.username != "" { + ctx = metadata.AppendToOutgoingContext(ctx, metadataUsernameKey, config.username) + } + if config.remoteIP != "" { + ctx = metadata.AppendToOutgoingContext(ctx, metadataRemoteIPKey, config.remoteIP) } return ctx @@ -25,7 +36,7 @@ func UnaryClientCorrelationInterceptor(opts ...ClientCorrelationInterceptorOptio config := applyClientCorrelationInterceptorOptions(opts) return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - ctx = appendToOutgoingContext(ctx, config.clientName) + ctx = appendToOutgoingContext(ctx, config) return invoker(ctx, method, req, reply, cc, opts...)
} } @@ -35,7 +46,7 @@ func StreamClientCorrelationInterceptor(opts ...ClientCorrelationInterceptorOpti config := applyClientCorrelationInterceptorOptions(opts) return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { - ctx = appendToOutgoingContext(ctx, config.clientName) + ctx = appendToOutgoingContext(ctx, config) return streamer(ctx, desc, cc, method, opts...) } } diff --git a/correlation/grpc/client_interceptors_options.go b/correlation/grpc/client_interceptors_options.go index 1d618935049c41327e0031464a2b289b52a65fac..0ee9e9fe9d45922ef73965a6f6f76337e5c7c905 100644 --- a/correlation/grpc/client_interceptors_options.go +++ b/correlation/grpc/client_interceptors_options.go @@ -3,6 +3,9 @@ package grpccorrelation // The configuration for InjectCorrelationID. type clientInterceptConfig struct { clientName string + userID string + username string + remoteIP string } // ClientCorrelationInterceptorOption configures client correlation interceptors.
@@ -24,3 +27,24 @@ func WithClientName(clientName string) ClientCorrelationInterceptorOption { config.clientName = clientName } } + +// WithUserID propagates the user ID +func WithUserID(userID string) ClientCorrelationInterceptorOption { + return func(config *clientInterceptConfig) { + config.userID = userID + } +} + +// WithUsername propagates the username +func WithUsername(username string) ClientCorrelationInterceptorOption { + return func(config *clientInterceptConfig) { + config.username = username + } +} + +// WithRemoteIP propagates the remote IP +func WithRemoteIP(remoteIP string) ClientCorrelationInterceptorOption { + return func(config *clientInterceptConfig) { + config.remoteIP = remoteIP + } +} diff --git a/correlation/grpc/key.go b/correlation/grpc/key.go index 3af234b5bcadb9d28bc2381cbcfcc3bcbea1ddb6..e22e46e23b98c1920baf9b9879427aac110cce9f 100644 --- a/correlation/grpc/key.go +++ b/correlation/grpc/key.go @@ -3,4 +3,7 @@ package grpccorrelation const ( metadataCorrelatorKey = "X-GitLab-Correlation-ID" metadataClientNameKey = "X-GitLab-Client-Name" + metadataUserIDKey = "X-GitLab-User-ID" + metadataUsernameKey = "X-GitLab-Username" + metadataRemoteIPKey = "X-GitLab-Remote-IP" ) diff --git a/correlation/grpc/server_interceptors.go b/correlation/grpc/server_interceptors.go index 32edecd842a4d85f375b845747105f6c39cadef3..a215809854734f1089fdfc946b1f03540889f0aa 100644 --- a/correlation/grpc/server_interceptors.go +++ b/correlation/grpc/server_interceptors.go @@ -27,6 +27,21 @@ func extractFromContext(ctx context.Context, propagateIncomingCorrelationID bool if len(clientNames) > 0 { ctx = correlation.ContextWithClientName(ctx, clientNames[0]) } + + userIDs := md.Get(metadataUserIDKey) + if len(userIDs) > 0 { + ctx = correlation.ContextWithUserID(ctx, userIDs[0]) + } + + usernames := md.Get(metadataUsernameKey) + if len(usernames) > 0 { + ctx = correlation.ContextWithUsername(ctx, usernames[0]) + } + + remoteIP := md.Get(metadataRemoteIPKey) + if
len(remoteIP) > 0 { + ctx = correlation.ContextWithRemoteIP(ctx, remoteIP[0]) + } } if generateCorrelationID { ctx = correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID()) diff --git a/correlation/inbound_http.go b/correlation/inbound_http.go index e031a8fc716f8b2371fd9c292ce0ec3f38b62d2e..a5c76e93e574f838b8553a4c0004ad4af667924e 100644 --- a/correlation/inbound_http.go +++ b/correlation/inbound_http.go @@ -17,19 +17,31 @@ func InjectCorrelationID(h http.Handler, opts ...InboundHandlerOption) http.Hand parent := r.Context() correlationID := "" - clientName := "" if config.propagation { - correlationID, clientName = extractFromRequest(r) + correlationID = r.Header.Get(propagationHeader) } - if correlationID == "" { correlationID = SafeRandomID() } - ctx := ContextWithCorrelation(parent, correlationID) - if clientName != "" { - ctx = ContextWithClientName(ctx, clientName) + if config.propagation { + clientName := r.Header.Get(clientNameHeader) + if clientName != "" { + ctx = ContextWithClientName(ctx, clientName) + } + userID := r.Header.Get(userIDHeader) + if userID != "" { + ctx = ContextWithUserID(ctx, userID) + } + username := r.Header.Get(usernameHeader) + if username != "" { + ctx = ContextWithUsername(ctx, username) + } + remoteIP := r.Header.Get(remoteIPHeader) + if remoteIP != "" { + ctx = ContextWithRemoteIP(ctx, remoteIP) + } } h.ServeHTTP(w, r.WithContext(ctx)) @@ -40,10 +52,6 @@ func InjectCorrelationID(h http.Handler, opts ...InboundHandlerOption) http.Hand }) } -func extractFromRequest(r *http.Request) (string, string) { - return r.Header.Get(propagationHeader), r.Header.Get(clientNameHeader) -} - // setResponseHeader will set the response header, if it has not already // been set by an downstream response.
func setResponseHeader(w http.ResponseWriter, correlationID string) { diff --git a/correlation/outbound_http.go b/correlation/outbound_http.go index 98f4d9ffffe919ed0874e95ae92d8643b512ce4a..5aa66422b750e758fe97682c4aac82fcb9225ed0 100644 --- a/correlation/outbound_http.go +++ b/correlation/outbound_http.go @@ -7,6 +7,9 @@ import ( const ( propagationHeader = "X-Request-ID" clientNameHeader = "X-GitLab-Client-Name" + userIDHeader = "X-GitLab-User-ID" + usernameHeader = "X-GitLab-Username" + remoteIPHeader = "X-Forwarded-For" ) type instrumentedRoundTripper struct { diff --git a/correlation/outbound_http_options.go b/correlation/outbound_http_options.go index 36573e0c0ed10ce09004e3eeae7a9d571f4edbae..996086e812fd9353202f07f743c252f4b28faac2 100644 --- a/correlation/outbound_http_options.go +++ b/correlation/outbound_http_options.go @@ -3,6 +3,9 @@ package correlation // The configuration for InjectCorrelationID. type instrumentedRoundTripperConfig struct { clientName string + userID string + username string + remoteIP string } // InstrumentedRoundTripperOption will configure a correlation handler @@ -26,3 +29,24 @@ func WithClientName(clientName string) InstrumentedRoundTripperOption { config.clientName = clientName } } + +// WithUserID will configure user ID propagation +func WithUserID(userID string) InstrumentedRoundTripperOption { + return func(config *instrumentedRoundTripperConfig) { + config.userID = userID + } +} + +// WithUsername will configure username propagation +func WithUsername(username string) InstrumentedRoundTripperOption { + return func(config *instrumentedRoundTripperConfig) { + config.username = username + } +} + +// WithRemoteIP will configure remote IP propagation +func WithRemoteIP(remoteIP string) InstrumentedRoundTripperOption { + return func(config *instrumentedRoundTripperConfig) { + config.remoteIP = remoteIP + } +} diff --git a/log/access_logger.go b/log/access_logger.go index
0bbe876d1712e32586dc4a8812797c4f173fd991..e58fdd98c4eef84a5cf8b72a78d46dbe5c3ef9d8 100644 --- a/log/access_logger.go +++ b/log/access_logger.go @@ -165,6 +165,14 @@ func (l *loggingResponseWriter) accessLogFields(r *http.Request) logrus.Fields { fields[httpResponseContentTypeField] = l.contentType } + if fieldsBitMask&HTTPUserID != 0 { + fields[httpUserIDField] = correlation.ExtractUserIDFromContext(r.Context()) + } + + if fieldsBitMask&HTTPUsername != 0 { + fields[httpUsernameField] = correlation.ExtractUsernameFromContext(r.Context()) + } + return fields } @@ -185,11 +193,26 @@ func (l *hijackingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) } +// getRemoteIP returns the client IP recorded in the access log: the address parsed from the +// propagated remote-IP value or the request's X-Forwarded-For header, else r.RemoteAddr without its port. func getRemoteIP(r *http.Request) string { - remoteAddr := xff.GetRemoteAddr(r) - host, _, err := net.SplitHostPort(remoteAddr) + if xffh := correlation.ExtractRemoteIPFromContext(r.Context()); xffh != "" { + // xff.Parse may return "" (no usable address in the list); fall through in that case. + if ip := xff.Parse(xffh); ip != "" { + return ip + } + } + + if xffh := r.Header.Get("X-Forwarded-For"); xffh != "" { + if ip := xff.Parse(xffh); ip != "" { + return ip + } + } + + // Strip the port so every branch yields a bare host. + host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return r.RemoteAddr } return host } diff --git a/log/access_logger_fields.go b/log/access_logger_fields.go index f542353d26e2b38f4788d9c94335a7516d979052..64611e2acb16b365137fad30dea5cba8754e0f89 100644 --- a/log/access_logger_fields.go +++ b/log/access_logger_fields.go @@ -1,7 +1,7 @@ package log // AccessLogField is used to select which fields are recorded in the access log. See WithoutFields. -type AccessLogField uint16 +type AccessLogField uint32 const ( // CorrelationID field will record the Correlation-ID in the access log. @@ -50,6 +50,12 @@ const ( // in the access log. Time is recorded before an actual Write happens to ensure that this metric // is not affected by a slow client receiving data.
RequestTTFB + + // HTTPUserID will record the ID of the user making the request in the access log. + HTTPUserID + + // HTTPUsername will record the name of the user making the request in the access log. + HTTPUsername ) const defaultEnabledFields = ^AccessLogField(0) // By default, all fields are enabled @@ -71,4 +77,6 @@ const ( requestTTFBField = "ttfb_ms" // ESC: no mapping systemField = "system" // ESC: no mapping httpResponseContentTypeField = "content_type" // ESC: no mapping + httpUserIDField = "user_id" // ESC: user.id + httpUsernameField = "username" // ESC: user.name )