diff --git a/log/logger.go b/log/logger.go
index 7781095775c48da72a5742c3f06bfd048a9fbee1..c2c15b052f69ee7f15b66f7134a511516eb080b6 100644
--- a/log/logger.go
+++ b/log/logger.go
@@ -6,6 +6,7 @@ import (
 	"github.com/sirupsen/logrus"
 
 	"gitlab.com/gitlab-org/labkit/correlation"
+	"gitlab.com/gitlab-org/labkit/request_context"
 )
 
 var logger = logrus.StandardLogger()
@@ -32,5 +33,13 @@ func WithContextFields(ctx context.Context, fields Fields) *logrus.Entry {
 func ContextFields(ctx context.Context) Fields {
 	correlationID := correlation.ExtractFromContext(ctx)
 
-	return logrus.Fields{correlation.FieldName: correlationID}
+	fields := logrus.Fields{
+		correlation.FieldName: correlationID,
+	}
+
+	if path := request_context.ExtractFromContext(ctx, request_context.RequestPath); path != "" {
+		fields[request_context.RequestPath] = path
+	}
+
+	return fields
 }
diff --git a/log/logger_test.go b/log/logger_test.go
index 3c72af6bfee71044597fbef2355dfc8e70d10d91..a5df50b821a43f656897e70afb164297f2b3b1ee 100644
--- a/log/logger_test.go
+++ b/log/logger_test.go
@@ -7,6 +7,7 @@ import (
 	"github.com/stretchr/testify/require"
 
 	"gitlab.com/gitlab-org/labkit/correlation"
+	"gitlab.com/gitlab-org/labkit/request_context"
 )
 
 func TestContextLogger(t *testing.T) {
@@ -49,26 +50,30 @@ func TestWithContextFields(t *testing.T) {
 	tests := []struct {
 		name          string
 		correlationID string
+		requestPath   string
 		fields        Fields
 		matchRegExp   string
 	}{
 		{
 			name:          "none",
 			correlationID: "",
+			requestPath:   "",
 			fields:        nil,
 			matchRegExp:   `\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.*level=info msg=Hello correlation_id=\n$`,
 		},
 		{
 			name:          "single",
 			correlationID: "123456",
+			requestPath:   "",
 			fields:        Fields{"field": "value"},
 			matchRegExp:   `\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.*level=info msg=Hello correlation_id=123456 field=value\n$`,
 		},
 		{
 			name:          "multiple",
 			correlationID: "123456",
+			requestPath:   "/path/to/resource",
 			fields:        Fields{"field": "value", "field2": "value2"},
-			matchRegExp:   `\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.*level=info msg=Hello correlation_id=123456 field=value field2=value2\n$`,
+			matchRegExp:   `\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.*level=info msg=Hello correlation_id=123456 field=value field2=value2 path=/path/to/resource\n$`,
 		},
 	}
 
@@ -83,6 +88,9 @@ func TestWithContextFields(t *testing.T) {
 			if tt.correlationID != "" {
 				ctx = correlation.ContextWithCorrelation(ctx, tt.correlationID)
 			}
+			if tt.requestPath != "" {
+				ctx = request_context.WithContextField(ctx, request_context.RequestPath, tt.requestPath)
+			}
 
 			WithContextFields(ctx, tt.fields).Info("Hello")
 			require.Regexp(t, tt.matchRegExp, buf.String())
diff --git a/request_context/context_fields.go b/request_context/context_fields.go
new file mode 100644
index 0000000000000000000000000000000000000000..25e0156efcf96a51673dfbd9d2447b4fe305a8f8
--- /dev/null
+++ b/request_context/context_fields.go
@@ -0,0 +1,61 @@
+// Package request_context stores per-request string values in a
+// context.Context, for example so that log.ContextFields can add the
+// request path to log entries.
+package request_context
+
+import (
+	"context"
+	"net/http"
+)
+
+// keyType is unexported so that context keys set by this package cannot
+// collide with keys defined in other packages.
+type keyType string
+
+// ContextField is a single key/value pair to store in a request context.
+type ContextField struct {
+	Key   string
+	Value string
+}
+
+// FieldProvider computes the fields to attach to the context of an incoming
+// request.
+type FieldProvider func(r *http.Request) []ContextField
+
+const (
+	// RequestPath is the key under which the request path is stored; the log
+	// package also uses it as the name of the corresponding log field.
+	RequestPath = "path"
+)
+
+// WithContextFieldsMiddleware returns a handler that, for every request,
+// calls provideFields and stores each returned field in the request context
+// (see WithContextField) before delegating to next. If several fields share
+// a key, the one returned last wins.
+func WithContextFieldsMiddleware(next http.Handler, provideFields FieldProvider) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		ctx := r.Context()
+
+		fields := provideFields(r)
+		for _, f := range fields {
+			ctx = WithContextField(ctx, f.Key, f.Value)
+		}
+
+		next.ServeHTTP(w, r.WithContext(ctx))
+	})
+}
+
+// WithContextField returns a copy of ctx in which key is associated with
+// value. Read it back with ExtractFromContext.
+func WithContextField(ctx context.Context, key string, value string) context.Context {
+	return context.WithValue(ctx, keyType(key), value)
+}
+
+// ExtractFromContext returns the value stored under key by WithContextField,
+// or "" when the key is absent or its value is not a string.
+func ExtractFromContext(ctx context.Context, key string) string {
+	// A failed comma-ok type assertion (which includes a nil, i.e. missing,
+	// value) yields the zero value "", which is the documented fallback.
+	value, _ := ctx.Value(keyType(key)).(string)
+	return value
+}
diff --git a/request_context/context_fields_test.go b/request_context/context_fields_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..a3cb5ec88e779fb66426e617667950a55371877c
--- 
/dev/null
+++ b/request_context/context_fields_test.go
@@ -0,0 +1,30 @@
+package request_context
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestExtractFromContextReturnsFieldIfExists(t *testing.T) {
+	ctx := WithContextField(context.TODO(), "a", "42")
+
+	value := ExtractFromContext(ctx, "a")
+
+	require.Equal(t, "42", value)
+}
+
+func TestExtractFromContextReturnsEmptyStringIfNotString(t *testing.T) {
+	ctx := context.WithValue(context.TODO(), keyType("a"), 42)
+
+	value := ExtractFromContext(ctx, "a")
+
+	require.Equal(t, "", value)
+}
+
+func TestExtractFromContextReturnsEmptyStringIfNotExists(t *testing.T) {
+	value := ExtractFromContext(context.TODO(), "a")
+
+	require.Equal(t, "", value)
+}
diff --git a/request_context/examples_test.go b/request_context/examples_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..d609366c50002eff6e53a4386c242bb128725fbd
--- /dev/null
+++ b/request_context/examples_test.go
@@ -0,0 +1,43 @@
+package request_context_test
+
+import (
+	"fmt"
+	"log"
+	"net/http"
+	"net/http/httptest"
+
+	"gitlab.com/gitlab-org/labkit/request_context"
+)
+
+func ExampleWithContextFieldsMiddleware() {
+	handler := request_context.WithContextFieldsMiddleware(
+		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			a := request_context.ExtractFromContext(r.Context(), "a")
+			b := request_context.ExtractFromContext(r.Context(), "b")
+			fmt.Println("a:", a)
+			fmt.Println("b:", b)
+		}),
+		func(r *http.Request) []request_context.ContextField {
+			return []request_context.ContextField{
+				{Key: "a", Value: "1"},
+				{Key: "b", Value: "2"},
+			}
+		},
+	)
+
+	// httptest.NewServer binds an ephemeral loopback port and is accepting
+	// connections once it returns, so this example neither races server
+	// start-up nor depends on a fixed port such as :8080 being free.
+	server := httptest.NewServer(handler)
+	defer server.Close()
+
+	response, err := server.Client().Get(server.URL)
+	if err != nil {
+		log.Fatalf("unable to send request: %v", err)
+	}
+	defer response.Body.Close()
+
+	// Output:
+	// a: 1
+	// b: 2
+}