diff --git a/httpcl/doc.go b/httpcl/doc.go
new file mode 100644
index 0000000000000000000000000000000000000000..5c75511977ceedc67a24a8e7965f48542e69409d
--- /dev/null
+++ b/httpcl/doc.go
@@ -0,0 +1,9 @@
+// Package httpcl provides an HTTP client built on the standard library's
+// net/http client, with LabKit tracing and correlation instrumentation
+// always applied to its transport.
+//
+// The intention of this package is that we'll be able to enrich the standard
+// client experience with a host of GitLab-specific standards that developers
+// will benefit from without necessarily having to think about these things
+// themselves.
+package httpcl
diff --git a/httpcl/http.go b/httpcl/http.go
new file mode 100644
index 0000000000000000000000000000000000000000..203aa769d82d875c7d86225ce192d25cd512cd31
--- /dev/null
+++ b/httpcl/http.go
@@ -0,0 +1,89 @@
+package httpcl
+
+import (
+	"net/http"
+	"time"
+
+	"gitlab.com/gitlab-org/labkit/correlation"
+	"gitlab.com/gitlab-org/labkit/tracing"
+)
+
+const (
+	// RequestID is the "X-Request-Id" HTTP header name.
+	RequestID = "X-Request-Id"
+
+	// CorrelationID is the "correlation_id" key.
+	// NOTE(review): not referenced in this package; confirm where it is used.
+	CorrelationID = "correlation_id"
+)
+
+// Client is a thin wrapper around an *http.Client whose transport is always
+// instrumented with LabKit tracing and correlation round trippers. Create
+// one with NewClient rather than as a zero value.
+type Client struct {
+	// HttpClient is the underlying client used to send requests.
+	HttpClient *http.Client
+}
+
+// Option customizes a Client. Options are applied by NewClient in the order
+// given, before the transport is wrapped with instrumentation.
+type Option func(c *Client)
+
+// NewClient returns a Client backed by http.DefaultTransport, customized by
+// opts. After all options have been applied, the transport is always wrapped
+// with tracing and correlation round trippers, so a transport supplied via
+// WithTransport is instrumented as well. Nil options are skipped.
+//
+// No request timeout is configured unless WithTimeout is passed.
+func NewClient(opts ...Option) *Client {
+	client := &Client{
+		HttpClient: &http.Client{
+			Transport: http.DefaultTransport,
+		},
+	}
+
+	for _, opt := range opts {
+		if opt != nil {
+			opt(client)
+		}
+	}
+
+	// An option may have cleared the transport. net/http treats a nil
+	// Transport as http.DefaultTransport, so mirror that here rather than
+	// handing a nil RoundTripper to the instrumentation wrappers.
+	transport := client.HttpClient.Transport
+	if transport == nil {
+		transport = http.DefaultTransport
+	}
+
+	// Wrap last so that no option can bypass the instrumentation.
+	client.HttpClient.Transport = correlation.NewInstrumentedRoundTripper(
+		tracing.NewRoundTripper(transport),
+	)
+
+	return client
+}
+
+// WithTransport returns an Option that replaces the client's default
+// transport (http.DefaultTransport) with t. NewClient still wraps t with
+// tracing and correlation instrumentation. A nil t is ignored.
+func WithTransport(t *http.Transport) Option {
+	return func(c *Client) {
+		// Storing a nil *http.Transport in the RoundTripper interface yields a
+		// non-nil interface value that would dereference nil on the first
+		// request, so leave the current transport in place instead.
+		if t == nil {
+			return
+		}
+		c.HttpClient.Transport = t
+	}
+}
+
+// WithTimeout returns an Option that sets the http.Client Timeout: the time
+// limit for a whole request, including reading the response body. Without
+// this option the timeout is zero, meaning the client imposes no overall limit.
+func WithTimeout(timeout time.Duration) Option {
+	return func(c *Client) {
+		c.HttpClient.Timeout = timeout
+	}
+}
diff --git a/httpcl/http_test.go b/httpcl/http_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..58778b24b0bf8a4179e77b4a161a15ee9e4346ec
--- /dev/null
+++ b/httpcl/http_test.go
@@ -0,0 +1,110 @@
+package httpcl_test
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+	"gitlab.com/gitlab-org/labkit/httpcl"
+)
+
+// sendWithHeader sends a GET request carrying the header name=value through
+// client to a fresh test server and returns the value the server received.
+func sendWithHeader(t *testing.T, client *httpcl.Client, name, value string) string {
+	t.Helper()
+
+	received := make(chan string, 1)
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// Hand the header back instead of asserting here: require's FailNow
+		// must only be called from the test goroutine, not a handler goroutine.
+		received <- r.Header.Get(name)
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer server.Close()
+
+	req, err := http.NewRequest(http.MethodGet, server.URL+"/", nil)
+	require.NoError(t, err)
+	req.Header.Set(name, value)
+
+	resp, err := client.HttpClient.Do(req)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+
+	return <-received
+}
+
+func TestCorrelationIDPropagationViaRoundTripper(t *testing.T) {
+	client := httpcl.NewClient()
+	tests := []struct {
+		name          string
+		correlationID string
+	}{
+		{
+			name:          "the correlation_id is successfully retrieved from ctx and appended to outbound request",
+			correlationID: "some-correlation-id",
+		},
+		{
+			name:          "in situations where the correlation_id is empty, the client still operates",
+			correlationID: "",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := sendWithHeader(t, client, "correlation_id", tt.correlationID)
+			require.Equal(t, tt.correlationID, got)
+		})
+	}
+}
+
+func TestWithTransport(t *testing.T) {
+	t.Run("ensure that instrumented round tripper is always applied with a custom http.Transport",
+		func(t *testing.T) {
+			client := httpcl.NewClient(httpcl.WithTransport(&http.Transport{}))
+
+			// The custom transport must have been wrapped, not installed as-is.
+			_, unwrapped := client.HttpClient.Transport.(*http.Transport)
+			require.False(t, unwrapped)
+
+			got := sendWithHeader(t, client, "correlation_id", "some-correlation-id")
+			require.Equal(t, "some-correlation-id", got)
+		})
+
+	t.Run("a nil http.Transport is ignored and the client still operates", func(t *testing.T) {
+		client := httpcl.NewClient(httpcl.WithTransport(nil))
+
+		got := sendWithHeader(t, client, "correlation_id", "some-correlation-id")
+		require.Equal(t, "some-correlation-id", got)
+	})
+}
+
+func TestTimeout(t *testing.T) {
+	client := httpcl.NewClient(httpcl.WithTimeout(20 * time.Millisecond))
+
+	release := make(chan struct{})
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// Outlast the client timeout, but return as soon as the test ends so
+		// that server.Close does not have to wait for the handler.
+		select {
+		case <-release:
+		case <-time.After(500 * time.Millisecond):
+		}
+	}))
+	defer server.Close()
+	defer close(release) // runs before server.Close (defers are LIFO)
+
+	req, err := http.NewRequest(http.MethodGet, server.URL, nil)
+	require.NoError(t, err)
+
+	resp, err := client.HttpClient.Do(req)
+	if err == nil {
+		resp.Body.Close()
+	}
+	require.Error(t, err)
+	require.IsType(t, &url.Error{}, err)
+	require.True(t, err.(*url.Error).Timeout(), "expected a timeout error, got: %v", err)
+}