diff --git a/http_server.go b/http_server.go
new file mode 100644
index 0000000000000000000000000000000000000000..8fadb97ba442efaafaa821ac4dbf56c83b3e1d0c
--- /dev/null
+++ b/http_server.go
@@ -0,0 +1,449 @@
+package labkit
+
+import (
+	"context"
+	"errors"
+	"net"
+	"net/http"
+	"os"
+	"runtime/debug"
+	"sync"
+	"time"
+
+	"github.com/prometheus/client_golang/prometheus"
+	"gitlab.com/gitlab-org/labkit/correlation"
+	"gitlab.com/gitlab-org/labkit/log"
+	"gitlab.com/gitlab-org/labkit/metrics"
+	"gitlab.com/gitlab-org/labkit/monitoring"
+	"gitlab.com/gitlab-org/labkit/tracing"
+)
+
+// httpServerShutdownTimeout bounds how long in-flight requests may take to
+// complete once the context passed to HTTPServer.Start is cancelled.
+const httpServerShutdownTimeout = 30 * time.Second
+
+// httpHandlerOptions toggles the middleware that Handle wraps around a handler.
+// The zero value enables metrics, tracing, and correlation IDs.
+type httpHandlerOptions struct {
+	withoutMetrics       bool
+	withoutTracing       bool
+	withoutCorrelationID bool
+}
+
+// HTTPServer represents an HTTP server with built-in monitoring, metrics, tracing,
+// and correlation ID support. It embeds the standard http.Server and provides
+// additional middleware capabilities for observability.
+type HTTPServer struct {
+	*http.Server
+
+	// Server options
+	serviceName        string
+	monitoringAddress  string
+	metricRegisterer   prometheus.Registerer
+	httpListener       net.Listener
+	monitoringListener net.Listener
+
+	// Handler options
+	handlerOptions httpHandlerOptions
+
+	metricHandler metrics.HandlerFactory
+}
+
+// NewHTTPServer creates a new HTTPServer with the provided options.
+//
+// By default, it listens on :8080 and serves the monitoring endpoint on :8082.
+// These defaults are then overridden from the environment:
+//   - PORT sets the listening address to ":$PORT".
+//   - OTEL_EXPORTER_PROMETHEUS_PORT sets the monitoring address to
+//     "$OTEL_EXPORTER_PROMETHEUS_HOST:$OTEL_EXPORTER_PROMETHEUS_PORT".
+//   - GITLAB_SERVICE_NAME, or else OTEL_SERVICE_NAME, sets the service name
+//     (default "unnamed-service").
+//
+// Explicitly provided options are applied last and take precedence.
+//
+// The HTTP request metrics are registered with the configured Prometheus
+// registerer (prometheus.DefaultRegisterer by default) via MustRegister, which
+// panics on duplicate registration, so use a distinct registerer per HTTPServer.
+func NewHTTPServer(opts ...HTTPServerOption) *HTTPServer {
+	// 1. Create HTTPServer with constant default values.
+	s := &HTTPServer{
+		Server: &http.Server{
+			Addr:              ":8080",
+			Handler:           http.NewServeMux(),
+			ReadHeaderTimeout: 10 * time.Second,
+		},
+
+		serviceName:       "unnamed-service",
+		monitoringAddress: ":8082",
+		metricRegisterer:  prometheus.DefaultRegisterer,
+	}
+
+	// 2. Populate fields based on environment variables.
+	s.initFromEnv()
+
+	// 3. Apply explicitly provided options.
+	for _, opt := range opts {
+		opt.applyHTTP(s)
+	}
+
+	var metricFactoryOptions []metrics.HandlerFactoryOption
+
+	if s.metricRegisterer != nil {
+		metricFactoryOptions = append(metricFactoryOptions, metrics.WithPrometheusRegisterer(s.metricRegisterer))
+	}
+
+	s.metricHandler = metrics.NewHandlerFactory(metricFactoryOptions...)
+
+	return s
+}
+
+// Handle registers the handler for the given pattern. The handler will be wrapped
+// with middleware for metrics, tracing, and correlation IDs unless explicitly
+// disabled via HTTPHandlerOption. The pattern follows the same rules as http.ServeMux.Handle.
+//
+// Handle panics if the embedded http.Server's Handler has been replaced with
+// something other than the *http.ServeMux installed by NewHTTPServer.
+func (s *HTTPServer) Handle(pattern string, handler Handler, opts ...HTTPHandlerOption) {
+	handlerOptions := s.handlerOptions
+	for _, opt := range opts {
+		opt(&handlerOptions)
+	}
+
+	// Wrap inside-out: the metrics middleware is innermost and the correlation ID
+	// middleware is outermost.
+	var httpHandler http.Handler = contextHandler{hndl: handler}
+
+	if !handlerOptions.withoutMetrics {
+		httpHandler = s.metricHandler(httpHandler)
+	}
+
+	if !handlerOptions.withoutTracing {
+		httpHandler = tracing.Handler(httpHandler)
+	}
+
+	if !handlerOptions.withoutCorrelationID {
+		httpHandler = correlation.InjectCorrelationID(httpHandler,
+			correlation.WithPropagation(),
+			correlation.WithSetResponseHeader(),
+		)
+	}
+
+	mux, ok := s.Server.Handler.(*http.ServeMux)
+	if !ok {
+		panic("labkit: HTTPServer.Handle requires Server.Handler to be the *http.ServeMux created by NewHTTPServer")
+	}
+
+	mux.Handle(pattern, httpHandler)
+}
+
+// HandleFunc registers the handler function for the given pattern.
+// This is a convenience method that wraps the function as a Handler.
+// The pattern follows the same rules as http.ServeMux.HandleFunc.
+func (s *HTTPServer) HandleFunc(pattern string, handler func(context.Context, http.ResponseWriter, *http.Request) error, opts ...HTTPHandlerOption) {
+	s.Handle(pattern, HandlerFunc(handler), opts...)
+}
+
+// Start starts both the HTTP server and the monitoring endpoint and blocks until
+// both have stopped.
+//
+// When ctx is cancelled, both servers are shut down (in-flight requests get up to
+// httpServerShutdownTimeout to complete) and Start returns nil. If either server
+// fails to start or encounters an error during operation, both servers are shut
+// down and the errors are returned as a joined error.
+func (s *HTTPServer) Start(ctx context.Context) error {
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	errCh := make(chan error)
+	var wg sync.WaitGroup
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+
+		if err := s.startHTTP(ctx); err != nil {
+			log.WithError(err).Error("startHTTP failed")
+			errCh <- err
+		}
+	}()
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+
+		if err := s.startMonitoring(ctx); err != nil {
+			log.WithError(err).Error("startMonitoring failed")
+			errCh <- err
+		}
+	}()
+
+	go func() {
+		wg.Wait()
+		close(errCh)
+	}()
+
+	var errs error
+
+	for err := range errCh {
+		if err == nil {
+			continue
+		}
+
+		// If any Goroutine fails, signal all other Goroutines to also return.
+		cancel()
+
+		errs = errors.Join(errs, err)
+	}
+
+	return errs
+}
+
+// startHTTP serves s.Server on s.httpListener (creating it from s.Addr if no
+// listener was provided) until ctx is cancelled, then shuts the server down and
+// waits for the shutdown to finish. A shutdown triggered by ctx returns nil.
+func (s *HTTPServer) startHTTP(ctx context.Context) error {
+	if s.httpListener == nil {
+		var err error
+		s.httpListener, err = (&net.ListenConfig{}).Listen(ctx, "tcp", s.Addr)
+		if err != nil {
+			return err
+		}
+	}
+
+	// Serve returns as soon as Shutdown is called, before in-flight requests have
+	// completed, so shutdownDone is used to wait for Shutdown itself to return.
+	shutdownDone := make(chan struct{})
+
+	go func() {
+		defer close(shutdownDone)
+
+		<-ctx.Done()
+
+		// ctx is already cancelled here; passing it to Shutdown would abort the
+		// graceful drain immediately. Use a detached context with a deadline.
+		shutdownCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), httpServerShutdownTimeout)
+		defer cancel()
+
+		if err := s.Server.Shutdown(shutdownCtx); err != nil {
+			log.WithError(err).Error("failed to shut down HTTP server")
+		}
+	}()
+
+	err := s.Server.Serve(s.httpListener)
+	if !errors.Is(err, http.ErrServerClosed) {
+		return err
+	}
+
+	<-shutdownDone
+
+	log.Info("HTTP server gracefully shut down")
+	return nil
+}
+
+// startMonitoring runs the monitoring endpoint on s.monitoringListener (creating
+// it from s.monitoringAddress if no listener was provided) until ctx is
+// cancelled, then shuts it down and waits for the shutdown to finish.
+func (s *HTTPServer) startMonitoring(ctx context.Context) error {
+	var opts []monitoring.Option
+
+	if s.monitoringListener == nil {
+		var err error
+		s.monitoringListener, err = (&net.ListenConfig{}).Listen(ctx, "tcp", s.monitoringAddress)
+		if err != nil {
+			return err
+		}
+	}
+
+	opts = append(opts, monitoring.WithListener(s.monitoringListener))
+
+	if s.metricRegisterer != nil {
+		opts = append(opts, monitoring.WithPrometheusRegisterer(s.metricRegisterer))
+	}
+
+	if s.handlerOptions.withoutMetrics {
+		opts = append(opts, monitoring.WithoutMetrics())
+	}
+
+	if buildInfo, ok := debug.ReadBuildInfo(); ok {
+		opts = append(opts, monitoring.WithGoBuildInformation(buildInfo))
+	}
+
+	server := monitoring.NewServer(opts...)
+
+	// server.Start returns as soon as Shutdown is called; shutdownDone lets us
+	// wait for Shutdown itself, which gets a detached context with a deadline
+	// because ctx is already cancelled at that point.
+	shutdownDone := make(chan struct{})
+
+	go func() {
+		defer close(shutdownDone)
+
+		<-ctx.Done()
+
+		shutdownCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), httpServerShutdownTimeout)
+		defer cancel()
+
+		if err := server.Shutdown(shutdownCtx); err != nil {
+			log.WithError(err).Error("failed to shut down monitoring endpoint")
+		}
+	}()
+
+	err := server.Start()
+	if !errors.Is(err, http.ErrServerClosed) {
+		return err
+	}
+
+	<-shutdownDone
+
+	log.Info("monitoring endpoint gracefully shut down")
+	return nil
+}
+
+// initFromEnv sets HTTPServer options from environment variables.
+func (s *HTTPServer) initFromEnv() {
+	if port, ok := os.LookupEnv("PORT"); ok {
+		s.Server.Addr = ":" + port
+	}
+
+	if port, ok := os.LookupEnv("OTEL_EXPORTER_PROMETHEUS_PORT"); ok {
+		s.monitoringAddress = os.Getenv("OTEL_EXPORTER_PROMETHEUS_HOST") + ":" + port
+	}
+
+	var (
+		gitlabServiceName = os.Getenv("GITLAB_SERVICE_NAME")
+		otelServiceName   = os.Getenv("OTEL_SERVICE_NAME")
+	)
+
+	switch {
+	case gitlabServiceName != "":
+		s.serviceName = gitlabServiceName
+	case otelServiceName != "":
+		s.serviceName = otelServiceName
+	}
+}
+
+// Handler is an interface for HTTP handlers that accept a context and return an error.
+// This allows handlers to propagate errors up to the server, which will handle them
+// appropriately by logging and returning an HTTP 500 status.
+type Handler interface {
+	ServeHTTP(ctx context.Context, w http.ResponseWriter, req *http.Request) error
+}
+
+// HandlerFunc is an adapter to allow the use of ordinary functions as Handler instances.
+// If f is a function with the appropriate signature, HandlerFunc(f) is a Handler that calls f.
+type HandlerFunc func(context.Context, http.ResponseWriter, *http.Request) error
+
+// ServeHTTP calls f(ctx, w, req).
+func (f HandlerFunc) ServeHTTP(ctx context.Context, w http.ResponseWriter, req *http.Request) error {
+	return f(ctx, w, req)
+}
+
+// contextHandler adapts a Handler to http.Handler, logging any error it returns.
+type contextHandler struct {
+	hndl Handler
+}
+
+func (c contextHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+	if err := c.hndl.ServeHTTP(req.Context(), w, req); err != nil {
+		log.WithError(err).Error("handler returned an error")
+		// err is logged above; only the generic status text reaches the client.
+		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
+	}
+}
+
+// HTTPServerOption is an interface for options that can be applied to an HTTPServer.
+type HTTPServerOption interface {
+	applyHTTP(srv *HTTPServer)
+}
+
+// HTTPServerOptionFunc is a function type that implements HTTPServerOption.
+type HTTPServerOptionFunc func(*HTTPServer)
+
+func (f HTTPServerOptionFunc) applyHTTP(srv *HTTPServer) {
+	f(srv)
+}
+
+// ServerOption represents an option that can be applied to both HTTP and gRPC servers.
+// Currently only HTTPServerOption is implemented.
+type ServerOption struct {
+	HTTPServerOption
+	// GRPCServerOption
+}
+
+// WithAddress sets the listening address ("host:port" or ":port") for the server.
+func WithAddress(addr string) ServerOption {
+	return ServerOption{
+		HTTPServerOption: HTTPServerOptionFunc(func(s *HTTPServer) {
+			s.Addr = addr
+		}),
+	}
+}
+
+// WithListener sets the listener for the server.
+func WithListener(l net.Listener) ServerOption {
+	return ServerOption{
+		HTTPServerOption: HTTPServerOptionFunc(func(s *HTTPServer) {
+			s.httpListener = l
+		}),
+	}
+}
+
+// WithMonitoringAddress sets the listening address ("host:port" or ":port") for
+// the monitoring endpoint, which exposes metrics and health checks.
+func WithMonitoringAddress(addr string) HTTPServerOption {
+	return HTTPServerOptionFunc(func(s *HTTPServer) {
+		s.monitoringAddress = addr
+	})
+}
+
+// WithServiceName sets the service name for the server.
+// It takes precedence over the GITLAB_SERVICE_NAME and OTEL_SERVICE_NAME environment variables.
+func WithServiceName(name string) ServerOption {
+	return ServerOption{
+		HTTPServerOption: HTTPServerOptionFunc(func(s *HTTPServer) {
+			s.serviceName = name
+		}),
+	}
+}
+
+// WithPrometheusRegisterer sets the Prometheus registerer used for both the HTTP request
+// metrics and the monitoring endpoint. If not provided, prometheus.DefaultRegisterer is used.
+func WithPrometheusRegisterer(r prometheus.Registerer) ServerOption {
+	return ServerOption{
+		HTTPServerOption: HTTPServerOptionFunc(func(s *HTTPServer) {
+			s.metricRegisterer = r
+		}),
+	}
+}
+
+// HTTPHandlerOption is a function type for options that can be applied to individual handlers.
+// Each HTTPHandlerOption is also an HTTPServerOption, which sets the default for HTTP handlers of the server.
+type HTTPHandlerOption func(opts *httpHandlerOptions)
+
+func (o HTTPHandlerOption) applyHTTP(srv *HTTPServer) {
+	o(&srv.handlerOptions)
+}
+
+// WithMetrics enables or disables request metrics for handlers (enabled by default).
+// Passed to NewHTTPServer with false, it also disables the monitoring endpoint's metrics.
+func WithMetrics(v bool) HTTPHandlerOption {
+	return func(opts *httpHandlerOptions) {
+		opts.withoutMetrics = !v
+	}
+}
+
+// WithTracing enables or disables distributed tracing for handlers.
+// When enabled (default), handlers will participate in distributed traces.
+func WithTracing(v bool) HTTPHandlerOption {
+	return func(opts *httpHandlerOptions) {
+		opts.withoutTracing = !v
+	}
+}
+
+// WithCorrelationID enables or disables correlation ID injection for handlers.
+// When enabled (default), handlers will propagate and set correlation IDs in requests and responses.
+func WithCorrelationID(v bool) HTTPHandlerOption {
+	return func(opts *httpHandlerOptions) {
+		opts.withoutCorrelationID = !v
+	}
+}
diff --git a/http_server_test.go b/http_server_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..732552607e5073af7fbc09691a587a1ca2dd40dd
--- /dev/null
+++ b/http_server_test.go
@@ -0,0 +1,216 @@
+package labkit_test
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"net"
+	"net/http"
+	"testing"
+	"time"
+
+	"github.com/prometheus/client_golang/prometheus"
+	"gitlab.com/gitlab-org/labkit"
+)
+
+func ExampleHTTPServer() {
+	ctx := context.Background()
+
+	srv := labkit.NewHTTPServer(
+		labkit.WithAddress(":8080"),
+		labkit.WithMetrics(true),
+	)
+
+	srv.HandleFunc("/", func(_ context.Context, w http.ResponseWriter, _ *http.Request) error {
+		w.WriteHeader(http.StatusOK)
+
+		fmt.Fprintln(w, "Hello, World!")
+
+		return nil
+	}, labkit.WithCorrelationID(false))
+
+	// Start returns nil after a graceful shutdown, so only a non-nil error is fatal.
+	if err := srv.Start(ctx); err != nil {
+		panic(err)
+	}
+}
+
+func TestHTTPServer(t *testing.T) {
+	// Create a cancellable context
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	listener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("Failed to create listener: %v", err)
+	}
+	defer listener.Close()
+
+	// Create an HTTP Server with a dummy endpoint
+	server := labkit.NewHTTPServer(
+		// listener is bound to localhost:0 (a random free port), so tests can run in parallel.
+		labkit.WithListener(listener),
+		labkit.WithMonitoringAddress("localhost:0"),
+		labkit.WithPrometheusRegisterer(prometheus.NewRegistry()),
+	)
+
+	// Add a dummy endpoint that serves "hello world"
+	server.HandleFunc("/hello", func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
+		_, err := w.Write([]byte("hello world"))
+		return err
+	})
+
+	// Channel to capture any error from the server
+	serverErrCh := make(chan error, 1)
+
+	// Start the HTTP Server in a Goroutine
+	go func() {
+		serverErrCh <- server.Start(ctx)
+	}()
+
+	// Give the server a moment to start up
+	time.Sleep(10 * time.Millisecond)
+
+	// Try calling the dummy endpoint until it succeeds
+	// The test will automatically abort when its timeout is reached
+	client := &http.Client{Timeout: 1 * time.Second}
+
+	var lastErr error
+	maxAttempts := 50 // Maximum attempts to prevent infinite loop
+	for i := range maxAttempts {
+		resp, err := client.Get(fmt.Sprintf("http://%s/hello", listener.Addr()))
+		if err != nil {
+			lastErr = err
+			time.Sleep(100 * time.Millisecond)
+			continue
+		}
+
+		// Read the response body
+		body, err := io.ReadAll(resp.Body)
+		resp.Body.Close()
+
+		if err != nil {
+			t.Fatalf("Failed to read response body: %v", err)
+		}
+
+		// Verify the response
+		if string(body) != "hello world" {
+			t.Fatalf("Unexpected response: got %q, want %q", string(body), "hello world")
+		}
+
+		// Verify status code
+		if resp.StatusCode != http.StatusOK {
+			t.Fatalf("Unexpected status code: got %d, want %d", resp.StatusCode, http.StatusOK)
+		}
+
+		// Endpoint successfully reached
+		t.Logf("Successfully reached the dummy endpoint on attempt %d", i+1)
+		lastErr = nil
+		break
+	}
+
+	if lastErr != nil {
+		t.Fatalf("Failed to reach the endpoint after %d attempts: %v", maxAttempts, lastErr)
+	}
+
+	// Once the endpoint was reached, cancel the context
+	cancel()
+
+	// Wait for the server goroutine to shut down and check for errors
+	select {
+	case err := <-serverErrCh:
+		// Ensure that the HTTP Server Goroutine shuts down without error
+		if err != nil {
+			t.Fatalf("Server returned an error on shutdown: %v", err)
+		}
+		t.Log("Server shut down gracefully without error")
+	case <-time.After(5 * time.Second):
+		t.Fatal("Server did not shut down within the expected time")
+	}
+}
+
+func TestHTTPServerWithHandlerError(t *testing.T) {
+	// Create a cancellable context
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	listener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("Failed to create listener: %v", err)
+	}
+	defer listener.Close()
+
+	// Create an HTTP Server
+	server := labkit.NewHTTPServer(
+		labkit.WithListener(listener),
+		labkit.WithMonitoringAddress("localhost:0"),
+		labkit.WithPrometheusRegisterer(prometheus.NewRegistry()),
+	)
+
+	// Add an endpoint that returns an error
+	server.HandleFunc("/error", func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
+		return http.ErrBodyNotAllowed
+	})
+
+	// Start the server
+	serverErrCh := make(chan error, 1)
+	go func() {
+		err := server.Start(ctx)
+		serverErrCh <- err
+	}()
+
+	// Give the server a moment to start up
+	time.Sleep(100 * time.Millisecond)
+
+	// Try calling the error endpoint
+	client := &http.Client{Timeout: 1 * time.Second}
+
+	var connected bool
+	maxAttempts := 50
+	for range maxAttempts {
+		resp, err := client.Get(fmt.Sprintf("http://%s/error", listener.Addr()))
+		if err != nil {
+			time.Sleep(100 * time.Millisecond)
+			continue
+		}
+
+		// Should get a 500 Internal Server Error
+		if resp.StatusCode != http.StatusInternalServerError {
+			t.Fatalf("Expected status code 500, got %d", resp.StatusCode)
+		}
+
+		body, err := io.ReadAll(resp.Body)
+		resp.Body.Close()
+
+		if err != nil {
+			t.Fatalf("Failed to read response body: %v", err)
+		}
+
+		// The error message should be in the body
+		if len(body) == 0 {
+			t.Fatal("Expected error message in response body")
+		}
+
+		connected = true
+		t.Log("Successfully verified error handling")
+		break
+	}
+
+	if !connected {
+		t.Fatal("Failed to connect to the server")
+	}
+
+	// Shutdown the server
+	cancel()
+
+	// Wait for graceful shutdown
+	select {
+	case err := <-serverErrCh:
+		if err != nil {
+			t.Fatalf("Server returned an error on shutdown: %v", err)
+		}
+		t.Log("Server shut down gracefully after handling errors")
+	case <-time.After(5 * time.Second):
+		t.Fatal("Server did not shut down within the expected time")
+	}
+}
diff --git a/metrics/handler.go b/metrics/handler.go
index 5af354330cc24efc03ae91ce8d646e34e7131021..0184ba15adc7559f6fc0880f416ea85da87d4343 100644
--- a/metrics/handler.go
+++ b/metrics/handler.go
@@ -92,12 +92,17 @@ func NewHandlerFactory(opts ...HandlerFactoryOption) HandlerFactory {
 		)
 	)
 
-	prometheus.MustRegister(httpInFlightRequests)
-	prometheus.MustRegister(httpRequestsTotal)
-	prometheus.MustRegister(httpRequestDurationSeconds)
-	prometheus.MustRegister(httpRequestSizeBytes)
-	prometheus.MustRegister(httpResponseSizeBytes)
-	prometheus.MustRegister(httpTimeToWriteHeaderSeconds)
+	r := prometheus.DefaultRegisterer
+	if factoryConfig.metricRegisterer != nil {
+		r = factoryConfig.metricRegisterer
+	}
+
+	r.MustRegister(httpInFlightRequests)
+	r.MustRegister(httpRequestsTotal)
+	r.MustRegister(httpRequestDurationSeconds)
+	r.MustRegister(httpRequestSizeBytes)
+	r.MustRegister(httpResponseSizeBytes)
+	r.MustRegister(httpTimeToWriteHeaderSeconds)
 
 	return func(next http.Handler, handlerOpts ...HandlerOption) http.Handler {
 		handlerConfig, promOpts := applyHandlerOptions(handlerOpts)
diff --git a/metrics/handler_factory_options.go b/metrics/handler_factory_options.go
index 726d272d8a379ae3d5ed0d4b8945acaed7f7e7cc..972a004c093dd684ceb2b24a15d20f50c4bf2b42 100644
--- a/metrics/handler_factory_options.go
+++ b/metrics/handler_factory_options.go
@@ -1,5 +1,7 @@
 package metrics
 
+import "github.com/prometheus/client_golang/prometheus"
+
 type handlerFactoryConfig struct {
 	namespace                        string
 	subsystem                        string
@@ -7,6 +9,7 @@ type handlerFactoryConfig struct {
 	timeToWriteHeaderDurationBuckets []float64
 	byteSizeBuckets                  []float64
 	labels                           []string
+	metricRegisterer                 prometheus.Registerer
 }
 
 // HandlerFactoryOption is used to pass options in NewHandlerFactory.
@@ -91,3 +94,11 @@ func WithByteSizeBuckets(buckets []float64) HandlerFactoryOption {
 		config.byteSizeBuckets = buckets
 	}
 }
+
+// WithPrometheusRegisterer will configure the prometheus registerer to use when creating
+// the metrics handler. If not provided, prometheus.DefaultRegisterer will be used.
+func WithPrometheusRegisterer(r prometheus.Registerer) HandlerFactoryOption {
+	return func(config *handlerFactoryConfig) {
+		config.metricRegisterer = r
+	}
+}
diff --git a/monitoring/start.go b/monitoring/start.go
index 2a572474ed9d314f31a2c153396408c47d9e58fc..156421f0994121305b27289496a98f9637fab2f9 100644
--- a/monitoring/start.go
+++ b/monitoring/start.go
@@ -1,32 +1,58 @@
 package monitoring
 
 import (
+	"context"
 	"log"
 	"net/http/pprof"
 
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 )
 
-// Start will start a new monitoring service listening on the address
+// Start starts a new monitoring service.
+//
+// Deprecated: use NewServer(options...).Start() instead.
+func Start(options ...Option) error {
+	s := NewServer(options...)
+
+	return s.Start()
+}
+
+// Server struct wraps the monitoring service and its configuration.
+type Server struct {
+	config optionsConfig
+}
+
+// NewServer initializes a new Server struct with the given options.
+func NewServer(opts ...Option) *Server {
+	s := Server{
+		config: defaultOptions(),
+	}
+
+	for _, opt := range opts {
+		opt(&s.config)
+	}
+
+	return &s
+}
+
+// Start starts a new monitoring service listening on the address
 // configured through the option arguments. Additionally, it'll start
 // a Continuous Profiler configured through environment variables
 // (see more at https://gitlab.com/gitlab-org/labkit/-/blob/master/monitoring/doc.go).
 //
 // If `WithListenerAddress` option is provided, Start will block or return a non-nil error,
 // similar to `http.ListenAndServe` (for instance).
-func Start(options ...Option) error {
-	config := applyOptions(options)
-
-	listener, err := config.listenerFactory()
+func (s *Server) Start() error {
+	listener, err := s.config.listenerFactory()
 	if err != nil {
 		return err
 	}
 
 	// Initialize the Continuous Profiler.
-	if !config.continuousProfilingDisabled {
+	if !s.config.continuousProfilingDisabled {
 		profOpts := profilerOpts{
-			ServiceVersion:  config.version,
-			CredentialsFile: config.profilerCredentialsFile,
+			ServiceVersion:  s.config.version,
+			CredentialsFile: s.config.profilerCredentialsFile,
 		}
 		initProfiler(profOpts)
 	}
@@ -36,12 +62,18 @@ func Start(options ...Option) error {
 		return nil
 	}
 
-	metricsHandler(config)
-	pprofHandlers(config)
+	metricsHandler(s.config)
+	pprofHandlers(s.config)
 
-	config.server.Handler = config.serveMux
+	s.config.server.Handler = s.config.serveMux
+
+	return s.config.server.Serve(listener)
+}
 
-	return config.server.Serve(listener)
+// Shutdown stops the monitoring server by calling Shutdown(ctx) on the underlying
+// server, which causes a running Start to return.
+func (s *Server) Shutdown(ctx context.Context) error {
+	return s.config.server.Shutdown(ctx)
 }
 
 func metricsHandler(cfg optionsConfig) {