diff --git a/log/examples_test.go b/log/examples_test.go index 48d5430288c7483ef5b3a7d6f0a04f80f7116cf7..eaf4eeccd8b9dc83b1287e542ff329dbeb0df6d3 100644 --- a/log/examples_test.go +++ b/log/examples_test.go @@ -3,7 +3,9 @@ package log_test import ( "context" "fmt" + "log/slog" "net/http" + "os" "gitlab.com/gitlab-org/labkit/log" "google.golang.org/grpc" @@ -17,6 +19,7 @@ func ExampleInitialize() { log.WithLogLevel("info"), // Use info level log.WithOutputName("stderr"), // Output to stderr ) + //nolint defer closer.Close() log.WithError(err).Info("This has been logged") @@ -29,6 +32,7 @@ func ExampleInitialize() { log.WithLogLevel("debug"), // Use debug level log.WithOutputName("/var/log/labkit.log"), // Output to `/var/log/labkit.log` ) + //nolint defer closer2.Close() log.WithError(err).Info("This has been logged") @@ -52,6 +56,26 @@ func ExampleAccessLogger() { ) } +func ExampleAccessLogger_slog() { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Hello world") + }) + + // This func is used by WithExtraFields to add additional fields to the logger + extraFieldGenerator := func(r *http.Request) log.Fields { + return log.Fields{"header": r.Header.Get("X-Magical-Header")} + } + + slogger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) + http.ListenAndServe(":8080", + log.AccessLogger(handler, + log.WithSlogAccessLogger(slogger), + log.WithExtraFields(extraFieldGenerator), // Include custom fields into the logs + log.WithFieldsExcluded(log.HTTPRequestReferrer|log.HTTPUserAgent), // Exclude user-agent and referrer fields from the logs + ), + ) +} + func ExampleUnaryServerInterceptor() { // This func is used by WithGrpcExtraFields to add additional fields to the logger extraFieldGenerator := func(ctx context.Context) log.Fields { diff --git a/log/http_access_logger.go b/log/http_access_logger.go index e054ba030e900f4ebac2365fd1767ef1cac168a2..305aba1942104a05b64350285ab37dbe3c4c57cb 100644 --- a/log/http_access_logger.go +++ b/log/http_access_logger.go @@ -2,6 +2,7 @@ package log import ( "bufio" + "log/slog" "net" "net/http" "sync/atomic" @@ -191,7 +192,24 @@ func (l *loggingResponseWriter) accessLogFields(r *http.Request) logrus.Fields { return fields } +// toSlogAttrs - a helper function that allows us to lean on the original +// logrus-based implementation and easily convert those fields into +// slog.Attr values. +// This helps to ensure consistency across both logger setups. +func (l *loggingResponseWriter) toSlogAttrs(fields logrus.Fields) []any { + slogFields := []any{} + + for k, v := range fields { + slogFields = append(slogFields, slog.Attr{Key: k, Value: slog.AnyValue(v)}) + } + + return slogFields +} + func (l *loggingResponseWriter) requestFinished(r *http.Request) { + if l.config.slogger != nil { + l.config.slogger.Info("access", l.toSlogAttrs(l.accessLogFields(r))...) + } l.config.logger.WithFields(l.accessLogFields(r)).Info("access") } diff --git a/log/http_access_logger_go1_20_test.go b/log/http_access_logger_go1_20_test.go deleted file mode 100644 index 4f8702ed12a9f15bd9d72fb58c79270808bc88fb..0000000000000000000000000000000000000000 --- a/log/http_access_logger_go1_20_test.go +++ /dev/null @@ -1,20 +0,0 @@ -//go:build go1.20 - -package log - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestAccessLoggerFlushable(t *testing.T) { - rw := httptest.NewRecorder() - lrw := &loggingResponseWriter{rw: rw} - rc := http.NewResponseController(lrw) - - err := rc.Flush() - require.NoError(t, err, "the underlying response writer is not flushable") -} diff --git a/log/http_access_logger_options.go b/log/http_access_logger_options.go index f703db118e849d0a037b96a7e0e9ff1aaabe9bce..890ffd6159a5750812aef632bbfe92df7f495bc5 100644 --- a/log/http_access_logger_options.go +++ b/log/http_access_logger_options.go @@ -1,6 +1,7 @@ package log import ( + "log/slog" "net/http" "github.com/sirupsen/logrus" @@ -15,6 +16,7 @@ type XFFAllowedFunc func(ip string) bool // The configuration for an access logger. type accessLoggerConfig struct { logger *logrus.Logger + slogger *slog.Logger extraFields ExtraFieldsGeneratorFunc fields AccessLogField xffAllowed XFFAllowedFunc @@ -63,6 +65,12 @@ func WithAccessLogger(logger *logrus.Logger) AccessLoggerOption { } } +func WithSlogAccessLogger(logger *slog.Logger) AccessLoggerOption { + return func(config *accessLoggerConfig) { + config.slogger = logger + } +} + // WithXFFAllowed decides whether to trust X-Forwarded-For headers. func WithXFFAllowed(xffAllowed XFFAllowedFunc) AccessLoggerOption { return func(config *accessLoggerConfig) { diff --git a/log/http_access_logger_slog_test.go b/log/http_access_logger_slog_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c46e7927327486ecdd09645208ff1994ed8691eb --- /dev/null +++ b/log/http_access_logger_slog_test.go @@ -0,0 +1,332 @@ +package log + +import ( + "bytes" + "fmt" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestSlogAccessLogger(t *testing.T) { + tests := []struct { + name string + urlSuffix string + body string + logMatchers []string + options []AccessLoggerOption + requestHeaders map[string]string + responseHeaders map[string]string + handler http.Handler + }{ + { + name: "trivial", + body: "hello", + logMatchers: []string{ + `\btime=\"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}`, + `\blevel=info`, + `\bmsg=access`, + `\bcorrelation_id=\s+`, + `\bduration_ms=\d+`, + `\bhost="127.0.0.1:\d+"`, + `\bmethod=GET`, + `\bproto=HTTP/1.1`, + `\breferrer=\s+`, + `\bremote_addr="127.0.0.1:\d+"`, + `\bremote_ip=127.0.0.1`, + `\bstatus=200`, + `\bsystem=http`, + `\buri=/`, + `\buser_agent=Go`, + `\bwritten_bytes=5`, + `\bcontent_type=\s+`, + `\bread_bytes=96\b`, + }, + }, + { + name: "senstitive_params", + urlSuffix: "?password=123456", + logMatchers: []string{ + `\buri=\"/\?password=\[FILTERED\]\"`, + }, + }, + { + name: "extra_fields", + options: []AccessLoggerOption{ + WithExtraFields(func(r *http.Request) Fields { + return Fields{"testfield": "testvalue"} + }), + }, + logMatchers: []string{ + `\btestfield=testvalue\b`, + }, + }, + { + name: "excluded_fields", + options: []AccessLoggerOption{ + WithFieldsExcluded(defaultEnabledFields), + }, + logMatchers: []string{ + `^time=\"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.*level=info msg=access\n$`, + }, + }, + { + name: "x_forwarded_for", + requestHeaders: map[string]string{ + "X-Forwarded-For": "196.7.0.238", + }, + logMatchers: []string{ + `\bremote_ip=196.7.0.238\b`, + }, + }, + { + name: "x_forwarded_for_with_trusted_proxies", + requestHeaders: map[string]string{ + "X-Forwarded-For": "196.7.0.238, 197.7.8.9", + }, + options: []AccessLoggerOption{ + WithTrustedProxies([]string{"196.7.8.0/24"}), + }, + logMatchers: []string{ + `\bremote_ip=196.7.0.238\b`, + }, + }, + { + name: "x_forwarded_for_incorrect", + requestHeaders: map[string]string{ + "X-Forwarded-For": "gitlab.com", + }, + logMatchers: []string{ + `\bremote_ip=127.0.0.1\b`, + }, + }, + { + name: "x_forwarded_for_incorrect", + requestHeaders: map[string]string{ + "X-Forwarded-For": "196.7.238, 197.7.8.9", + }, + logMatchers: []string{ + `\bremote_ip=197.7.8.9\b`, + }, + }, + { + name: "x_forwarded_for_not_allowed", + options: []AccessLoggerOption{ + WithXFFAllowed(func(sip string) bool { return false }), + }, + requestHeaders: map[string]string{ + "X-Forwarded-For": "196.7.0.238", + }, + logMatchers: []string{ + `\bremote_ip=127.0.0.1\b`, + }, + }, + { + name: "empty body", + logMatchers: []string{ + `\bstatus=200\b`, + }, + }, + { + name: "content type", + body: "hello", + responseHeaders: map[string]string{ + "Content-Type": "text/plain", + }, + logMatchers: []string{ + `\bcontent_type=text/plain`, + }, + }, + { + name: "time to the first byte", + body: "ok", + logMatchers: []string{ + `\bttfb_ms=\d+`, + }, + }, + { + name: "time to the first byte, with long delay", + body: "yo", + logMatchers: []string{ + // we expect the delay to be around `10ms` + `\bttfb_ms=1\d\b`, + }, + handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(10 * time.Millisecond) + fmt.Fprint(w, "yo") + }), + }, + { + name: "time to the first byte, with a slow data transfer", + body: "yo", + logMatchers: []string{ + // we expect the delay to be lower than `10ms` + `\bttfb_ms=\d\b`, + }, + handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + time.Sleep(10 * time.Millisecond) + fmt.Fprint(w, "yo") + }), + }, + { + name: "time to the first byte, with a long processing and slow data transfer", + body: "yo", + logMatchers: []string{ + // we expect the delay to be around `10ms` + `\bttfb_ms=1\d\b`, + }, + handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(10 * time.Millisecond) + w.WriteHeader(http.StatusInternalServerError) + time.Sleep(20 * time.Millisecond) + fmt.Fprint(w, "yo") + }), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := &bytes.Buffer{} + + slogger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) + _, err := Initialize(WithSlogLogger(slogger), WithWriter(buf)) + require.NoError(t, err) + + handler := tt.handler + + if handler == nil { + handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for k, v := range tt.responseHeaders { + w.Header().Add(k, v) + } + // This if-statement provides test coverage for the case where the + // handler never calls Write or WriteHeader. + if len(tt.body) > 0 { + fmt.Fprint(w, tt.body) + } + }) + } + + opts := []AccessLoggerOption{WithSlogAccessLogger(slogger)} + opts = append(opts, tt.options...) + handler = AccessLogger(handler, opts...) + + ts := httptest.NewTLSServer(handler) + defer ts.Close() + + client := ts.Client() + req, err := http.NewRequest(http.MethodGet, ts.URL+tt.urlSuffix, nil) + require.NoError(t, err) + + for k, v := range tt.requestHeaders { + req.Header.Add(k, v) + } + + res, err := client.Do(req) + require.NoError(t, err) + + gotBody, err := io.ReadAll(res.Body) + res.Body.Close() + + require.NoError(t, err) + require.Equal(t, tt.body, string(gotBody)) + + logString := buf.String() + for _, v := range tt.logMatchers { + require.Regexp(t, v, logString) + } + }) + } +} + +func TestSlogAccessLoggerPanic(t *testing.T) { + buf := &bytes.Buffer{} + + slogger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) + _, err := Initialize(WithSlogLogger(slogger), WithWriter(buf)) + require.NoError(t, err) + + var handler http.Handler + handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("see how the logger handles a panic") + }) + + opts := []AccessLoggerOption{WithSlogAccessLogger(slogger)} + handler = AccessLogger(handler, opts...) + + ts := httptest.NewTLSServer(handler) + defer ts.Close() + + client := ts.Client() + req, err := http.NewRequest(http.MethodGet, ts.URL+"/", nil) + require.NoError(t, err) + + _, err = client.Do(req) + require.Error(t, err, "panic should cause the request to fail with a closed connection") + + require.Regexp(t, `\bstatus=0\b`, buf.String(), "if the handler panics before writing a response header, the status code is undefined, so we expect code 0") +} + +func TestSlogHandlerWithChunkedTransfer(t *testing.T) { + buf := &bytes.Buffer{} + + slogger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) + _, err := Initialize(WithSlogLogger(slogger), WithWriter(buf)) + require.NoError(t, err) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Echo back the received body + io.Copy(w, r.Body) + }) + + // Create a test server + server := httptest.NewServer(AccessLogger(handler, WithSlogAccessLogger(slogger))) + defer server.Close() + + // Prepare the chunked request body + chunks := []string{ + "First chunk of data\n", + "Second chunk of data\n", + "Third and final chunk", + } + + // Create a pipe to write our chunks + pr, pw := io.Pipe() + + // Start a goroutine to write chunks + go func() { + defer pw.Close() + + for _, chunk := range chunks { + _, err := pw.Write([]byte(chunk)) + require.NoError(t, err) + } + }() + + // Create the request + req, err := http.NewRequest(http.MethodPost, server.URL, pr) + require.NoError(t, err) + + // Set necessary headers for chunked transfer + req.Header.Set("Transfer-Encoding", "chunked") + req.Header.Set("Content-Type", "text/plain") + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + + defer resp.Body.Close() + + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + require.Regexp(t, `\bwritten_bytes=62\b`, buf.String()) + require.Regexp(t, `\bread_bytes=185\b`, buf.String()) +} diff --git a/log/logger_options.go b/log/logger_options.go index aee67b25a586e3b3f8194b6a14a7c22d00e9f0d8..2a366e83641018f9a2b28a6535790207900cc488 100644 --- a/log/logger_options.go +++ b/log/logger_options.go @@ -3,6 +3,7 @@ package log import ( "fmt" "io" + "log/slog" "os" "time" @@ -15,6 +16,7 @@ const ( ) type loggerConfig struct { + slogger *slog.Logger logger *logrus.Logger level logrus.Level formatter logrus.Formatter @@ -160,3 +162,9 @@ func WithLogger(logger *logrus.Logger) LoggerOption { conf.logger = logger } } + +func WithSlogLogger(slogger *slog.Logger) LoggerOption { + return func(conf *loggerConfig) { + conf.slogger = slogger + } +}