From 0966171cb49b57aa0290b6a838b8831052289e5c Mon Sep 17 00:00:00 2001 From: Dishon Date: Mon, 11 Jan 2021 11:35:45 +0000 Subject: [PATCH 01/17] Replace time.Sleep with a cancelable timer inside the cache retriever --- CHANGELOG | 27 +++ VERSION | 2 +- app.go | 25 ++- internal/auth/auth.go | 117 +++++++---- internal/auth/auth_code.go | 147 ++++++++++++++ internal/auth/auth_code_test.go | 99 ++++++++++ internal/auth/auth_test.go | 207 +++++++++++--------- internal/rejectmethods/middleware.go | 31 +++ internal/rejectmethods/middleware_test.go | 43 ++++ internal/source/gitlab/cache/retriever.go | 12 +- metrics/metrics.go | 7 + test/acceptance/acceptance_test.go | 29 +-- test/acceptance/artifacts_test.go | 2 +- test/acceptance/auth_test.go | 109 +++++++++-- test/acceptance/serving_test.go | 3 +- test/acceptance/unknown_http_method_test.go | 23 +++ 16 files changed, 722 insertions(+), 161 deletions(-) create mode 100644 internal/auth/auth_code.go create mode 100644 internal/auth/auth_code_test.go create mode 100644 internal/rejectmethods/middleware.go create mode 100644 internal/rejectmethods/middleware_test.go create mode 100644 test/acceptance/unknown_http_method_test.go diff --git a/CHANGELOG b/CHANGELOG index 9970bff88..e315ddcc0 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,12 @@ +v 1.34.0 + +- Allow DELETE HTTP method + +v 1.33.0 + +- Reject requests with unknown HTTP methods +- Encrypt OAuth code during auth flow + v 1.32.0 - Try to automatically use gitlab API as a source for domain information !402 @@ -10,6 +19,15 @@ v 1.31.0 - Add zip serving configuration flags !392 - Disable deprecated serverless serving and proxy !400 +v 1.30.2 + +- Allow DELETE HTTP method + +v 1.30.1 + +- Reject requests with unknown HTTP methods +- Encrypt OAuth code during auth flow + v 1.30.0 - Allow to refresh an existing cached archive when accessed !375 @@ -23,6 +41,15 @@ v 1.29.0 - Improve httprange timeouts !382 - Fix caching for errored ZIP VFS archives !384 +v 1.28.2 + +- 
Allow DELETE HTTP method + +v 1.28.1 + +- Reject requests with unknown HTTP methods +- Encrypt OAuth code during auth flow + v 1.28.0 - Implement basic redirects via _redirects text file !367 diff --git a/VERSION b/VERSION index 359c41089..2b17ffd50 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.32.0 +1.34.0 diff --git a/app.go b/app.go index ed06893e4..1352b630b 100644 --- a/app.go +++ b/app.go @@ -28,6 +28,7 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/logging" "gitlab.com/gitlab-org/gitlab-pages/internal/middleware" "gitlab.com/gitlab-org/gitlab-pages/internal/netutil" + "gitlab.com/gitlab-org/gitlab-pages/internal/rejectmethods" "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk/zip" "gitlab.com/gitlab-org/gitlab-pages/internal/source" @@ -337,6 +338,12 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { // Custom response headers handler = a.customHeadersMiddleware(handler) + // This MUST be the last handler! + // This handler blocks unknown HTTP methods, + // being the last means it will be evaluated first + // preventing any operation on bogus requests. 
+ handler = rejectmethods.NewMiddleware(handler) + return handler, nil } @@ -483,10 +490,7 @@ func runApp(config appConfig) { a.Artifact = artifact.New(config.ArtifactsServer, config.ArtifactsServerTimeout, config.Domain) } - if config.ClientID != "" { - a.Auth = auth.New(config.Domain, config.StoreSecret, config.ClientID, config.ClientSecret, - config.RedirectURI, config.GitLabServer) - } + a.setAuth(config) a.Handlers = handlers.New(a.Auth, a.Artifact) @@ -524,6 +528,19 @@ func runApp(config appConfig) { a.Run() } +func (a *theApp) setAuth(config appConfig) { + if config.ClientID == "" { + return + } + + var err error + a.Auth, err = auth.New(config.Domain, config.StoreSecret, config.ClientID, config.ClientSecret, + config.RedirectURI, config.GitLabServer) + if err != nil { + log.WithError(err).Fatal("could not initialize auth package") + } +} + // fatal will log a fatal error and exit. func fatal(err error, message string) { log.WithError(err).Fatal(message) diff --git a/internal/auth/auth.go b/internal/auth/auth.go index eaf3c25dd..252954a62 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -16,14 +16,14 @@ import ( "github.com/gorilla/securecookie" "github.com/gorilla/sessions" log "github.com/sirupsen/logrus" + "golang.org/x/crypto/hkdf" + "gitlab.com/gitlab-org/labkit/errortracking" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/httptransport" "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/source" - - "golang.org/x/crypto/hkdf" ) // nolint: gosec @@ -47,17 +47,23 @@ var ( errFailAuth = errors.New("Failed to authenticate request") errAuthNotConfigured = errors.New("Authentication is not configured") errQueryParameter = errors.New("Failed to parse domain query parameter") + + errGenerateKeys = errors.New("could not generate auth keys") ) // Auth handles authenticating users with GitLab API type Auth struct { - pagesDomain string - clientID 
string - clientSecret string - redirectURI string - gitLabServer string - apiClient *http.Client - store sessions.Store + pagesDomain string + clientID string + clientSecret string + redirectURI string + gitLabServer string + authSecret string + jwtSigningKey []byte + jwtExpiry time.Duration + apiClient *http.Client + store sessions.Store + now func() time.Time // allows to stub time.Now() easily in tests } type tokenResponse struct { @@ -111,7 +117,7 @@ func (a *Auth) checkSession(w http.ResponseWriter, r *http.Request) (*sessions.S return session, nil } -// TryAuthenticate tries to authenticate user and fetch access token if request is a callback to auth +// TryAuthenticate tries to authenticate user and fetch access token if request is a callback to /auth? func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, domains source.Source) bool { if a == nil { return false @@ -166,11 +172,18 @@ func (a *Auth) checkAuthenticationResponse(session *sessions.Session, w http.Res return } - // Fetch access token with authorization code - token, err := a.fetchAccessToken(r.URL.Query().Get("code")) + decryptedCode, err := a.DecryptCode(r.URL.Query().Get("code"), getRequestDomain(r)) + if err != nil { + logRequest(r).WithError(err).Error("failed to decrypt secure code") + errortracking.Capture(err, errortracking.WithRequest(r)) + httperrors.Serve500(w) + return + } - // Fetching token not OK + // Fetch access token with authorization code + token, err := a.fetchAccessToken(decryptedCode) if err != nil { + // Fetching token not OK logRequest(r).WithError(err).WithField( "redirect_uri", redirectURI, ).Error(errFetchAccessToken) @@ -216,8 +229,8 @@ func (a *Auth) domainAllowed(name string, domains source.Source) bool { } func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWriter, r *http.Request, domains source.Source) bool { - // If request is for authenticating via custom domain - if shouldProxyAuth(r) { + // handle auth callback e.g. 
https://gitlab.io/auth?domain&domain&state=state + if shouldProxyAuthToGitlab(r) { domain := r.URL.Query().Get("domain") state := r.URL.Query().Get("state") @@ -266,6 +279,7 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit } // If auth request callback should be proxied to custom domain + // redirect to originating domain set in the cookie as proxy_auth_domain if shouldProxyCallbackToCustomDomain(r, session) { // Get domain started auth process proxyDomain := session.Values["proxy_auth_domain"].(string) @@ -283,9 +297,30 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit return true } - // Redirect pages under custom domain - http.Redirect(w, r, proxyDomain+r.URL.Path+"?"+r.URL.RawQuery, 302) + query := r.URL.Query() + + // prevent https://tools.ietf.org/html/rfc6749#section-10.6 and + // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/262 by encrypting + // and signing the OAuth code + signedCode, err := a.EncryptAndSignCode(proxyDomain, query.Get("code")) + if err != nil { + logRequest(r).WithError(err).Error("failed to encrypt and sign OAuth code") + errortracking.Capture(err, errortracking.WithRequest(r)) + + httperrors.Serve503(w) + return true + } + + // prevent forwarding access token, more context on the security issue + // https://gitlab.com/gitlab-org/gitlab/-/issues/285244#note_451266051 + query.Del("token") + + // replace code with signed code + query.Set("code", signedCode) + // Redirect pages to originating domain with code and state to finish + // authentication process + http.Redirect(w, r, proxyDomain+r.URL.Path+"?"+query.Encode(), 302) return true } @@ -306,7 +341,7 @@ func getRequestDomain(r *http.Request) string { return "http://" + r.Host } -func shouldProxyAuth(r *http.Request) bool { +func shouldProxyAuthToGitlab(r *http.Request) bool { return r.URL.Query().Get("domain") != "" && r.URL.Query().Get("state") != "" } @@ -376,6 +411,7 @@ func (a *Auth) checkSessionIsValid(w http.ResponseWriter, r
*http.Request) *sess return nil } + // redirect to /auth?domain=%s&state=%s if a.checkTokenExists(session, w, r) { return nil } @@ -586,28 +622,37 @@ func logRequest(r *http.Request) *log.Entry { }) } -// generateKeyPair returns key pair for secure cookie: signing and encryption key -func generateKeyPair(storeSecret string) ([]byte, []byte) { - hash := sha256.New - hkdf := hkdf.New(hash, []byte(storeSecret), []byte{}, []byte("PAGES_SIGNING_AND_ENCRYPTION_KEY")) - var keys [][]byte - for i := 0; i < 2; i++ { +// generateKeys derives count hkdf keys from a secret, ensuring the key is +// the same for the same secret used across multiple instances +func generateKeys(secret string, count int) ([][]byte, error) { + keys := make([][]byte, count) + hkdfReader := hkdf.New(sha256.New, []byte(secret), []byte{}, []byte("PAGES_SIGNING_AND_ENCRYPTION_KEY")) + + for i := 0; i < count; i++ { key := make([]byte, 32) - if _, err := io.ReadFull(hkdf, key); err != nil { - log.WithError(err).Fatal("Can't generate key pair for secure cookies") + if _, err := io.ReadFull(hkdfReader, key); err != nil { + return nil, err } - keys = append(keys, key) + + keys[i] = key + } + + if len(keys) < count { + return nil, errGenerateKeys } - return keys[0], keys[1] -} -func createCookieStore(storeSecret string) sessions.Store { - return sessions.NewCookieStore(generateKeyPair(storeSecret)) + return keys, nil } // New when authentication supported this will be used to create authentication handler func New(pagesDomain string, storeSecret string, clientID string, clientSecret string, - redirectURI string, gitLabServer string) *Auth { + redirectURI string, gitLabServer string) (*Auth, error) { + // generate 3 keys, 2 for the cookie store and 1 for JWT signing + keys, err := generateKeys(storeSecret, 3) + if err != nil { + return nil, err + } + return &Auth{ pagesDomain: pagesDomain, clientID: clientID, @@ -618,6 +663,10 @@ func New(pagesDomain string, storeSecret string, clientID string, clientSecret s 
Timeout: 5 * time.Second, Transport: httptransport.InternalTransport, }, - store: createCookieStore(storeSecret), - } + store: sessions.NewCookieStore(keys[0], keys[1]), + authSecret: storeSecret, + jwtSigningKey: keys[2], + jwtExpiry: time.Minute, + now: time.Now, + }, nil } diff --git a/internal/auth/auth_code.go b/internal/auth/auth_code.go new file mode 100644 index 000000000..d2fea5a95 --- /dev/null +++ b/internal/auth/auth_code.go @@ -0,0 +1,147 @@ +package auth + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "io" + + "github.com/dgrijalva/jwt-go" + "github.com/gorilla/securecookie" + "golang.org/x/crypto/hkdf" +) + +var ( + errInvalidToken = errors.New("invalid token") + errEmptyDomainOrCode = errors.New("empty domain or code") + errInvalidNonce = errors.New("invalid nonce") + errInvalidCode = errors.New("invalid code") +) + +// EncryptAndSignCode encrypts the OAuth code deriving the key from the domain. +// It adds the code and domain as JWT token claims and signs it using signingKey derived from +// the Auth secret. 
+func (a *Auth) EncryptAndSignCode(domain, code string) (string, error) { + if domain == "" || code == "" { + return "", errEmptyDomainOrCode + } + + nonce := base64.URLEncoding.EncodeToString(securecookie.GenerateRandomKey(16)) + + aesGcm, err := a.newAesGcmCipher(domain, nonce) + if err != nil { + return "", err + } + + // encrypt code with a randomly generated nonce + encryptedCode := aesGcm.Seal(nil, []byte(nonce), []byte(code), nil) + + // generate JWT token claims with encrypted code + claims := jwt.MapClaims{ + // standard claims + "iss": "gitlab-pages", + "iat": a.now().Unix(), + "exp": a.now().Add(a.jwtExpiry).Unix(), + // custom claims + "domain": domain, // pass the domain so we can validate the signed domain matches the requested domain + "code": hex.EncodeToString(encryptedCode), + "nonce": nonce, + } + + return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(a.jwtSigningKey) +} + +// DecryptCode decodes the secureCode as a JWT token and validates its signature. +// It then decrypts the code from the token claims and returns it. 
+func (a *Auth) DecryptCode(jwt, domain string) (string, error) { + claims, err := a.parseJWTClaims(jwt) + if err != nil { + return "", err + } + + // get nonce and encryptedCode from the JWT claims + nonce, ok := claims["nonce"].(string) + if !ok { + return "", errInvalidNonce + } + + encryptedCode, ok := claims["code"].(string) + if !ok { + return "", errInvalidCode + } + + cipherText, err := hex.DecodeString(encryptedCode) + if err != nil { + return "", err + } + + aesGcm, err := a.newAesGcmCipher(domain, nonce) + if err != nil { + return "", err + } + + decryptedCode, err := aesGcm.Open(nil, []byte(nonce), cipherText, nil) + if err != nil { + return "", err + } + + return string(decryptedCode), nil +} + +func (a *Auth) codeKey(domain string) ([]byte, error) { + hkdfReader := hkdf.New(sha256.New, []byte(a.authSecret), []byte(domain), []byte("PAGES_AUTH_CODE_ENCRYPTION_KEY")) + + key := make([]byte, 32) + if _, err := io.ReadFull(hkdfReader, key); err != nil { + return nil, err + } + + return key, nil +} + +func (a *Auth) parseJWTClaims(secureCode string) (jwt.MapClaims, error) { + token, err := jwt.Parse(secureCode, a.getSigningKey) + if err != nil { + return nil, err + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + return nil, errInvalidToken + } + + return claims, nil +} + +func (a *Auth) getSigningKey(token *jwt.Token) (interface{}, error) { + // Don't forget to validate the alg is what you expect: + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + return a.jwtSigningKey, nil +} + +func (a *Auth) newAesGcmCipher(domain, nonce string) (cipher.AEAD, error) { + // get the same key for a domain + key, err := a.codeKey(domain) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + aesGcm, err := cipher.NewGCMWithNonceSize(block, len(nonce)) + if err != nil { + return nil, err 
+ } + + return aesGcm, nil +} diff --git a/internal/auth/auth_code_test.go b/internal/auth/auth_code_test.go new file mode 100644 index 000000000..d54fcc7ea --- /dev/null +++ b/internal/auth/auth_code_test.go @@ -0,0 +1,99 @@ +package auth + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestEncryptAndDecryptSignedCode(t *testing.T) { + auth := createTestAuth(t, "") + + tests := map[string]struct { + auth *Auth + encDomain string + code string + expectedEncErrMsg string + decDomain string + expectedDecErrMsg string + }{ + "happy_path": { + auth: auth, + encDomain: "domain", + decDomain: "domain", + code: "code", + }, + "empty_domain": { + auth: auth, + encDomain: "", + code: "code", + expectedEncErrMsg: "empty domain or code", + }, + "empty_code": { + auth: auth, + encDomain: "domain", + code: "", + expectedEncErrMsg: "empty domain or code", + }, + "different_dec_domain": { + auth: auth, + encDomain: "domain", + decDomain: "another", + code: "code", + expectedDecErrMsg: "cipher: message authentication failed", + }, + "expired_token": { + auth: func() *Auth { + newAuth := *auth + newAuth.jwtExpiry = time.Nanosecond + newAuth.now = func() time.Time { + return time.Time{} + } + + return &newAuth + }(), + encDomain: "domain", + code: "code", + decDomain: "domain", + expectedDecErrMsg: "Token is expired", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + encCode, err := test.auth.EncryptAndSignCode(test.encDomain, test.code) + if test.expectedEncErrMsg != "" { + require.EqualError(t, err, test.expectedEncErrMsg) + require.Empty(t, encCode) + return + } + + require.NoError(t, err) + require.NotEmpty(t, encCode) + + decCode, err := test.auth.DecryptCode(encCode, test.decDomain) + if test.expectedDecErrMsg != "" { + require.EqualError(t, err, test.expectedDecErrMsg) + require.Empty(t, decCode) + return + } + + require.NoError(t, err) + require.Equal(t, test.code, decCode) + }) + } +} + +func 
TestDecryptCodeWithInvalidJWT(t *testing.T) { + auth1 := createTestAuth(t, "") + auth2 := createTestAuth(t, "") + auth2.jwtSigningKey = []byte("another signing key") + + encCode, err := auth1.EncryptAndSignCode("domain", "code") + require.NoError(t, err) + + decCode, err := auth2.DecryptCode(encCode, "domain") + require.EqualError(t, err, "signature is invalid") + require.Empty(t, decCode) +} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 39a533b35..ce7d83207 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "github.com/gorilla/sessions" @@ -16,17 +17,19 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/source" ) -func createAuth(t *testing.T) *Auth { - return New("pages.gitlab-example.com", +func createTestAuth(t *testing.T, url string) *Auth { + t.Helper() + + a, err := New("pages.gitlab-example.com", "something-very-secret", "id", "secret", "http://pages.gitlab-example.com/auth", - "http://gitlab-example.com") -} + url) + + require.NoError(t, err) -func defaultCookieStore() sessions.Store { - return createCookieStore("something-very-secret") + return a } type domainMock struct { @@ -48,10 +51,13 @@ func (dm *domainMock) ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Req // Which leads to negative side effects: we can't test encryption, and cookie params // like max-age and secure are not being properly set // To avoid that we use fake request, and set only session cookie without copying context -func setSessionValues(r *http.Request, values map[interface{}]interface{}) { - tmpRequest, _ := http.NewRequest("GET", "/", nil) +func setSessionValues(t *testing.T, r *http.Request, store sessions.Store, values map[interface{}]interface{}) { + t.Helper() + + tmpRequest, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + result := httptest.NewRecorder() - store := defaultCookieStore() session, _ := 
store.Get(tmpRequest, "gitlab-pages") session.Values = values @@ -63,7 +69,7 @@ func setSessionValues(r *http.Request, values map[interface{}]interface{}) { } func TestTryAuthenticate(t *testing.T) { - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/something/else") @@ -75,11 +81,12 @@ func TestTryAuthenticate(t *testing.T) { } func TestTryAuthenticateWithError(t *testing.T) { - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?error=access_denied") require.NoError(t, err) + reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} @@ -88,8 +95,7 @@ func TestTryAuthenticateWithError(t *testing.T) { } func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=invalid") @@ -97,7 +103,9 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["state"] = "state" session.Save(r, result) @@ -105,7 +113,36 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { require.Equal(t, 401, result.Code) } +func TestTryAuthenticateRemoveTokenFromRedirect(t *testing.T) { + auth := createTestAuth(t, "") + + result := httptest.NewRecorder() + reqURL, err := url.Parse("/auth?code=1&state=state&token=secret") + require.NoError(t, err) + + require.Equal(t, reqURL.Query().Get("token"), "secret", "token is present before redirecting") + reqURL.Scheme = request.SchemeHTTPS + r := &http.Request{URL: reqURL} + + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + + session.Values["state"] = 
"state" + session.Values["proxy_auth_domain"] = "https://domain.com" + session.Save(r, result) + + require.Equal(t, true, auth.TryAuthenticate(result, r, source.NewMockSource())) + require.Equal(t, http.StatusFound, result.Code) + + redirect, err := url.Parse(result.Header().Get("Location")) + require.NoError(t, err) + + require.Empty(t, redirect.Query().Get("token"), "token is gone after redirecting") +} + func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { + t.Helper() + apiServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/oauth/token": @@ -125,14 +162,17 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { apiServer.Start() defer apiServer.Close() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) + + domain := apiServer.URL + if https { + domain = strings.Replace(apiServer.URL, "http://", "https://", -1) + } - r, err := http.NewRequest("GET", "/auth?code=1&state=state", nil) + code, err := auth.EncryptAndSignCode(domain, "1") + require.NoError(t, err) + + r, err := http.NewRequest("GET", "/auth?code="+code+"&state=state", nil) require.NoError(t, err) if https { r.URL.Scheme = request.SchemeHTTPS @@ -140,14 +180,16 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { r.URL.Scheme = request.SchemeHTTP } - setSessionValues(r, map[interface{}]interface{}{ + r.Host = strings.TrimPrefix(apiServer.URL, "http://") + + setSessionValues(t, r, auth.store, map[interface{}]interface{}{ "uri": "https://pages.gitlab-example.com/project/", "state": "state", }) result := httptest.NewRecorder() require.Equal(t, true, auth.TryAuthenticate(result, r, source.NewMockSource())) - require.Equal(t, 302, result.Code) + require.Equal(t, http.StatusFound, result.Code) require.Equal(t, 
"https://pages.gitlab-example.com/project/", result.Header().Get("Location")) require.Equal(t, 600, result.Result().Cookies()[0].MaxAge) require.Equal(t, https, result.Result().Cookies()[0].Secure) @@ -177,13 +219,7 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") @@ -191,7 +227,9 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) @@ -217,13 +255,7 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) w := httptest.NewRecorder() @@ -232,7 +264,9 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, w) @@ -265,22 +299,19 @@ func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - 
"http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" - session.Save(r, result) + err = session.Save(r, result) + require.NoError(t, err) contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) require.True(t, contentServed) @@ -303,13 +334,7 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") @@ -317,7 +342,9 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) @@ -343,19 +370,16 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) 
+ session.Values["access_token"] = "abc" session.Save(r, result) @@ -364,28 +388,31 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { require.Equal(t, 302, result.Code) } -func TestGenerateKeyPair(t *testing.T) { - signingSecret, encryptionSecret := generateKeyPair("something-very-secret") - require.NotEqual(t, fmt.Sprint(signingSecret), fmt.Sprint(encryptionSecret)) - require.Equal(t, len(signingSecret), 32) - require.Equal(t, len(encryptionSecret), 32) +func TestGenerateKeys(t *testing.T) { + keys, err := generateKeys("something-very-secret", 3) + require.NoError(t, err) + require.Len(t, keys, 3) + + require.NotEqual(t, fmt.Sprint(keys[0]), fmt.Sprint(keys[1])) + require.NotEqual(t, fmt.Sprint(keys[0]), fmt.Sprint(keys[2])) + require.NotEqual(t, fmt.Sprint(keys[1]), fmt.Sprint(keys[2])) + + require.Equal(t, len(keys[0]), 32) + require.Equal(t, len(keys[1]), 32) + require.Equal(t, len(keys[2]), 32) } func TestGetTokenIfExistsWhenTokenExists(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) @@ -395,20 +422,16 @@ func TestGetTokenIfExistsWhenTokenExists(t *testing.T) { } func TestGetTokenIfExistsWhenTokenDoesNotExist(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := 
url.Parse("http://pages.gitlab-example.com/test") require.NoError(t, err) r := &http.Request{URL: reqURL, Host: "pages.gitlab-example.com", RequestURI: "/test"} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Save(r, result) token, err := auth.GetTokenIfExists(result, r) @@ -417,12 +440,7 @@ func TestGetTokenIfExistsWhenTokenDoesNotExist(t *testing.T) { } func TestCheckResponseForInvalidTokenWhenInvalidToken(t *testing.T) { - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("http://pages.gitlab-example.com/test") @@ -437,12 +455,7 @@ func TestCheckResponseForInvalidTokenWhenInvalidToken(t *testing.T) { } func TestCheckResponseForInvalidTokenWhenNotInvalidToken(t *testing.T) { - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/something") diff --git a/internal/rejectmethods/middleware.go b/internal/rejectmethods/middleware.go new file mode 100644 index 000000000..e78a0ce59 --- /dev/null +++ b/internal/rejectmethods/middleware.go @@ -0,0 +1,31 @@ +package rejectmethods + +import ( + "net/http" + + "gitlab.com/gitlab-org/gitlab-pages/metrics" +) + +var acceptedMethods = map[string]bool{ + http.MethodGet: true, + http.MethodHead: true, + http.MethodPost: true, + http.MethodPut: true, + http.MethodPatch: true, + http.MethodDelete: true, + http.MethodConnect: true, + http.MethodOptions: true, + http.MethodTrace: true, +} + +// NewMiddleware returns middleware which rejects all unknown http methods +func NewMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if 
acceptedMethods[r.Method] { + handler.ServeHTTP(w, r) + } else { + metrics.RejectedRequestsCount.Inc() + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + } + }) +} diff --git a/internal/rejectmethods/middleware_test.go b/internal/rejectmethods/middleware_test.go new file mode 100644 index 000000000..2921975ae --- /dev/null +++ b/internal/rejectmethods/middleware_test.go @@ -0,0 +1,43 @@ +package rejectmethods + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewMiddleware(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "OK\n") + }) + + middleware := NewMiddleware(handler) + + acceptedMethods := []string{"GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "CONNECT", "OPTIONS", "TRACE"} + for _, method := range acceptedMethods { + t.Run(method, func(t *testing.T) { + tmpRequest, _ := http.NewRequest(method, "/", nil) + recorder := httptest.NewRecorder() + + middleware.ServeHTTP(recorder, tmpRequest) + + result := recorder.Result() + + require.Equal(t, http.StatusOK, result.StatusCode) + }) + } + + t.Run("UNKNOWN", func(t *testing.T) { + tmpRequest, _ := http.NewRequest("UNKNOWN", "/", nil) + recorder := httptest.NewRecorder() + + middleware.ServeHTTP(recorder, tmpRequest) + + result := recorder.Result() + + require.Equal(t, http.StatusMethodNotAllowed, result.StatusCode) + }) +} diff --git a/internal/source/gitlab/cache/retriever.go b/internal/source/gitlab/cache/retriever.go index de37c231a..be22d82df 100644 --- a/internal/source/gitlab/cache/retriever.go +++ b/internal/source/gitlab/cache/retriever.go @@ -56,7 +56,17 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha lookup = r.client.GetLookup(ctx, domain) if lookup.Error != nil { - time.Sleep(r.maxRetrievalInterval) + timer := time.NewTimer(r.maxRetrievalInterval) + select { + case <-timer.C: + // retry to GetLookup
+ continue + case <-ctx.Done(): + // context done: stop the timer and leave the retry loop via the break after this select (a break in here would only exit the select) + timer.Stop() + log.WithError(ctx.Err()).Debug("resolveWithBackoff context done") + } + break } else { break } diff --git a/metrics/metrics.go b/metrics/metrics.go index db7cae9a8..045ff26e0 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -199,6 +199,13 @@ var ( Help: "The number of files per zip archive total count over time", }, ) + + RejectedRequestsCount = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "gitlab_pages_unknown_method_rejected_requests", + Help: "The number of requests with unknown HTTP method which were rejected", + }, + ) ) // MustRegister collectors with the Prometheus client diff --git a/test/acceptance/acceptance_test.go b/test/acceptance/acceptance_test.go index 9921076ea..ba6528c10 100644 --- a/test/acceptance/acceptance_test.go +++ b/test/acceptance/acceptance_test.go @@ -17,24 +17,31 @@ const ( var ( pagesBinary = flag.String("gitlab-pages-binary", "../../gitlab-pages", "Path to the gitlab-pages binary") + httpPort = "36000" + httpsPort = "37000" + httpProxyPort = "38000" + httpProxyV2Port = "39000" + // TODO: Use TCP port 0 everywhere to avoid conflicts. The binary could output // the actual port (and type of listener) for us to read in place of the // hardcoded values below.
listeners = []ListenSpec{ - {"http", "127.0.0.1", "37000"}, - {"http", "::1", "37000"}, - {"https", "127.0.0.1", "37001"}, - {"https", "::1", "37001"}, - {"proxy", "127.0.0.1", "37002"}, - {"proxy", "::1", "37002"}, - {"https-proxyv2", "127.0.0.1", "37003"}, - {"https-proxyv2", "::1", "37003"}, + {"http", "127.0.0.1", httpPort}, + {"https", "127.0.0.1", httpsPort}, + {"proxy", "127.0.0.1", httpProxyPort}, + {"https-proxyv2", "127.0.0.1", httpProxyV2Port}, + // TODO: re-enable IPv6 listeners once https://gitlab.com/gitlab-com/gl-infra/infrastructure/-/issues/12258 is resolved + // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/528 + // {"http", "::1", httpPort}, + // {"https", "::1", httpsPort}, + // {"proxy", "::1", httpProxyPort}, + // {"https-proxyv2", "::1", httpProxyV2Port}, } httpListener = listeners[0] - httpsListener = listeners[2] - proxyListener = listeners[4] - httpsProxyv2Listener = listeners[6] + httpsListener = listeners[1] + proxyListener = listeners[2] + httpsProxyv2Listener = listeners[3] ) func TestMain(m *testing.M) { diff --git a/test/acceptance/artifacts_test.go b/test/acceptance/artifacts_test.go index 3440ef34f..57c7a02a9 100644 --- a/test/acceptance/artifacts_test.go +++ b/test/acceptance/artifacts_test.go @@ -245,7 +245,7 @@ func TestPrivateArtifactProxyRequest(t *testing.T) { ) defer teardown() - resp, err := GetRedirectPage(t, httpListener, tt.host, tt.path) + resp, err := GetRedirectPage(t, httpsListener, tt.host, tt.path) require.NoError(t, err) defer resp.Body.Close() diff --git a/test/acceptance/auth_test.go b/test/acceptance/auth_test.go index b2233591f..fa2d768d8 100644 --- a/test/acceptance/auth_test.go +++ b/test/acceptance/auth_test.go @@ -88,7 +88,7 @@ func TestWhenLoginCallbackWithWrongStateShouldFail(t *testing.T) { require.Equal(t, http.StatusUnauthorized, authrsp.StatusCode) } -func TestWhenLoginCallbackWithCorrectStateWithoutEndpoint(t *testing.T) { +func TestWhenLoginCallbackWithUnencryptedCode(t *testing.T) { 
skipUnlessEnabled(t) teardown := RunPagesProcessWithAuth(t, *pagesBinary, listeners, "") defer teardown() @@ -110,8 +110,8 @@ func TestWhenLoginCallbackWithCorrectStateWithoutEndpoint(t *testing.T) { require.NoError(t, err) defer authrsp.Body.Close() - // Will cause 503 because token endpoint is not available - require.Equal(t, http.StatusServiceUnavailable, authrsp.StatusCode) + // Will cause 500 because the code is not encrypted + require.Equal(t, http.StatusInternalServerError, authrsp.StatusCode) } func handleAccessControlArtifactRequests(t *testing.T, w http.ResponseWriter, r *http.Request) bool { @@ -234,11 +234,10 @@ func TestAccessControlUnderCustomDomain(t *testing.T) { // Will redirect to custom domain require.Equal(t, tt.domain, url.Host) - require.Equal(t, "1", url.Query().Get("code")) - require.Equal(t, state, url.Query().Get("state")) + code := url.Query().Get("code") + require.NotEqual(t, "1", code) - // Run auth callback in custom domain - authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code=1&state="+ + authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code="+code+"&state="+ state, cookie) require.NoError(t, err) @@ -336,11 +335,13 @@ func TestCustomErrorPageWithAuth(t *testing.T) { // Will redirect to custom domain require.Equal(t, tt.domain, url.Host) - require.Equal(t, "1", url.Query().Get("code")) + // code must have changed since it's encrypted + code := url.Query().Get("code") + require.NotEqual(t, "1", code) require.Equal(t, state, url.Query().Get("state")) // Run auth callback in custom domain - authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code=1&state="+ + authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code="+code+"&state="+ state, cookie) require.NoError(t, err) @@ -409,12 +410,14 @@ func TestAccessControlUnderCustomDomainWithHTTPSProxy(t *testing.T) { // Will redirect to custom domain require.Equal(t, "private.domain.com", 
url.Host) - require.Equal(t, "1", url.Query().Get("code")) + // code must have changed since it's encrypted + code := url.Query().Get("code") + require.NotEqual(t, "1", code) require.Equal(t, state, url.Query().Get("state")) // Run auth callback in custom domain authrsp, err = GetProxyRedirectPageWithCookie(t, proxyListener, "private.domain.com", - "/auth?code=1&state="+state, cookie, true) + "/auth?code="+code+"&state="+state, cookie, true) require.NoError(t, err) defer authrsp.Body.Close() @@ -641,3 +644,87 @@ func TestAccessControlWithSSLCertFile(t *testing.T) { func TestAccessControlWithSSLCertDir(t *testing.T) { testAccessControl(t, RunPagesProcessWithAuthServerWithSSLCertDir) } + +// This proves the fix for https://gitlab.com/gitlab-org/gitlab-pages/-/issues/262 +// Read the issue description if any changes to internal/auth/ break this test. +// Related to https://tools.ietf.org/html/rfc6749#section-10.6. +func TestHijackedCode(t *testing.T) { + skipUnlessEnabled(t, "not-inplace-chroot") + + testServer := makeGitLabPagesAccessStub(t) + testServer.Start() + defer testServer.Close() + + teardown := RunPagesProcessWithAuthServer(t, *pagesBinary, listeners, "", testServer.URL) + defer teardown() + + /****ATTACKER******/ + // get valid cookie for a different private project + targetDomain := "private.domain.com" + attackersDomain := "group.auth.gitlab-example.com" + attackerCookie, attackerState := getValidCookieAndState(t, targetDomain) + + /****TARGET******/ + // fool target to click on modified URL with attacker's domain for redirect with a valid state + hackedURL := fmt.Sprintf("/auth?domain=http://%s&state=%s", attackersDomain, "irrelevant") + maliciousResp, err := GetProxyRedirectPageWithCookie(t, proxyListener, "projects.gitlab-example.com", hackedURL, "", true) + require.NoError(t, err) + defer maliciousResp.Body.Close() + + pagesCookie := maliciousResp.Header.Get("Set-Cookie") + + /* + OAuth flow happens here... 
+ */ + maliciousRespURL, err := url.Parse(maliciousResp.Header.Get("Location")) + require.NoError(t, err) + maliciousState := maliciousRespURL.Query().Get("state") + + // Go to auth page with correct state and code "obtained" from GitLab + authrsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, + "projects.gitlab-example.com", "/auth?code=1&state="+maliciousState, + pagesCookie, true) + + require.NoError(t, err) + defer authrsp.Body.Close() + + /****ATTACKER******/ + // Target is redirected to attacker's domain and attacker receives the proper code + require.Equal(t, http.StatusFound, authrsp.StatusCode, "should redirect to attacker's domain") + authrspURL, err := url.Parse(authrsp.Header.Get("Location")) + require.NoError(t, err) + require.Contains(t, authrspURL.String(), attackersDomain) + + // attacker's got the code + hijackedCode := authrspURL.Query().Get("code") + require.NotEmpty(t, hijackedCode) + + // attacker tries to access private pages content + impersonatingRes, err := GetProxyRedirectPageWithCookie(t, proxyListener, targetDomain, + "/auth?code="+hijackedCode+"&state="+attackerState, attackerCookie, true) + require.NoError(t, err) + defer authrsp.Body.Close() + + require.Equal(t, impersonatingRes.StatusCode, http.StatusInternalServerError, "should fail to decode code") +} + +func getValidCookieAndState(t *testing.T, domain string) (string, string) { + t.Helper() + + // follow flow to get a valid cookie + // visit https:/// + rsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, domain, "/", "", true) + require.NoError(t, err) + defer rsp.Body.Close() + + cookie := rsp.Header.Get("Set-Cookie") + require.NotEmpty(t, cookie) + + redirectURL, err := url.Parse(rsp.Header.Get("Location")) + require.NoError(t, err) + + state := redirectURL.Query().Get("state") + require.NotEmpty(t, state) + + return cookie, state +} diff --git a/test/acceptance/serving_test.go b/test/acceptance/serving_test.go index 66b5fa477..becd6b8cd 100644 --- 
a/test/acceptance/serving_test.go +++ b/test/acceptance/serving_test.go @@ -515,7 +515,8 @@ func TestKnownHostInReverseProxySetupReturns200(t *testing.T) { var listeners = []ListenSpec{ {"proxy", "127.0.0.1", "37002"}, - {"proxy", "::1", "37002"}, + // TODO: re-enable https://gitlab.com/gitlab-org/gitlab-pages/-/issues/528 + // {"proxy", "::1", "37002"}, } teardown := RunPagesProcess(t, *pagesBinary, listeners, "") diff --git a/test/acceptance/unknown_http_method_test.go b/test/acceptance/unknown_http_method_test.go new file mode 100644 index 000000000..f6c5ffee5 --- /dev/null +++ b/test/acceptance/unknown_http_method_test.go @@ -0,0 +1,23 @@ +package acceptance_test + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUnknownHTTPMethod(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + req, err := http.NewRequest("UNKNOWN", listeners[0].URL(""), nil) + require.NoError(t, err) + req.Host = "" + + resp, err := DoPagesRequest(t, httpListener, req) + require.NoError(t, err) + + require.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) +} -- GitLab From 7f5da7e937de0e995d5c3f4b665957834bdd487d Mon Sep 17 00:00:00 2001 From: Dishon Date: Mon, 11 Jan 2021 11:35:45 +0000 Subject: [PATCH 02/17] Replace time.Sleep with a cancelable timer inside the cache retriever --- internal/source/gitlab/cache/retriever.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/internal/source/gitlab/cache/retriever.go b/internal/source/gitlab/cache/retriever.go index de37c231a..be22d82df 100644 --- a/internal/source/gitlab/cache/retriever.go +++ b/internal/source/gitlab/cache/retriever.go @@ -56,7 +56,17 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha lookup = r.client.GetLookup(ctx, domain) if lookup.Error != nil { - time.Sleep(r.maxRetrievalInterval) + timer := time.NewTimer(r.maxRetrievalInterval) + select { 
+ case <-timer.C: + // retry to GetLookup + continue + case <-ctx.Done(): + // when the context is done we stop the timer + timer.Stop() + log.WithError(ctx.Err()).Debug("resolveWithBackoff context done") + break + } } else { break } -- GitLab From 6e8e78699fe3e0897ff4f73070f5e7f5b69d6412 Mon Sep 17 00:00:00 2001 From: Dishon Date: Tue, 12 Jan 2021 09:28:39 +0000 Subject: [PATCH 03/17] Apply 2 suggestion(s) to 1 file(s) --- internal/source/gitlab/cache/retriever.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/source/gitlab/cache/retriever.go b/internal/source/gitlab/cache/retriever.go index be22d82df..2aa532e63 100644 --- a/internal/source/gitlab/cache/retriever.go +++ b/internal/source/gitlab/cache/retriever.go @@ -62,9 +62,9 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha // retry to GetLookup continue case <-ctx.Done(): - // when the context is done we stop the timer + // when the retrieval context is done we stop the timer timer.Stop() - log.WithError(ctx.Err()).Debug("resolveWithBackoff context done") + log.WithError(ctx.Err()).Debug("domain retrieval backoff canceled by context") break } } else { -- GitLab From da862b83f50786a9ad7f071a50a7e6711bb65632 Mon Sep 17 00:00:00 2001 From: Dishon Date: Mon, 18 Jan 2021 08:51:27 +0000 Subject: [PATCH 04/17] Add tests --- internal/source/gitlab/cache/retriever.go | 51 ++++++++++++++++--- .../source/gitlab/cache/retriever_test.go | 27 ++++++++++ 2 files changed, 72 insertions(+), 6 deletions(-) create mode 100644 internal/source/gitlab/cache/retriever_test.go diff --git a/internal/source/gitlab/cache/retriever.go b/internal/source/gitlab/cache/retriever.go index 2aa532e63..bfddaadba 100644 --- a/internal/source/gitlab/cache/retriever.go +++ b/internal/source/gitlab/cache/retriever.go @@ -4,6 +4,7 @@ import ( "context" "errors" "time" + "sync" log "github.com/sirupsen/logrus" @@ -13,15 +14,25 @@ import ( // Retriever is an utility type that performs an 
HTTP request with backoff in // case of errors type Retriever struct { + timer timer client api.Client retrievalTimeout time.Duration maxRetrievalInterval time.Duration maxRetrievalRetries int } +type timer struct { + mu *sync.Mutex + stopped bool + timer *time.Timer +} + // NewRetriever creates a Retriever with a client func NewRetriever(client api.Client, retrievalTimeout, maxRetrievalInterval time.Duration, maxRetrievalRetries int) *Retriever { return &Retriever{ + timer: timer{ + mu: &sync.Mutex{}, + }, client: client, retrievalTimeout: retrievalTimeout, maxRetrievalInterval: maxRetrievalInterval, @@ -52,20 +63,24 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha go func() { var lookup api.Lookup + Retry: for i := 1; i <= r.maxRetrievalRetries; i++ { lookup = r.client.GetLookup(ctx, domain) if lookup.Error != nil { - timer := time.NewTimer(r.maxRetrievalInterval) + r.timer.start(r.maxRetrievalInterval) select { - case <-timer.C: + case <-r.timer.timer.C: // retry to GetLookup - continue + continue Retry case <-ctx.Done(): // when the retrieval context is done we stop the timer - timer.Stop() - log.WithError(ctx.Err()).Debug("domain retrieval backoff canceled by context") - break + // timer.Stop() + // log.WithError(ctx.Err()).Debug("domain retrieval backoff canceled by context") + // when the retrieval context is done we stop the timerFunc + r.timer.stop() + break Retry + } } else { break @@ -78,3 +93,27 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha return response } + + + func (t *timer) start(d time.Duration) { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopped = false + t.timer = time.NewTimer(d) + } + + func (t *timer) stop() { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopped = t.timer.Stop() + } + + func (t *timer) hasStopped() bool { + t.mu.Lock() + defer t.mu.Unlock() + + return t.stopped + } + \ No newline at end of file diff --git a/internal/source/gitlab/cache/retriever_test.go 
b/internal/source/gitlab/cache/retriever_test.go new file mode 100644 index 000000000..4db4a639f --- /dev/null +++ b/internal/source/gitlab/cache/retriever_test.go @@ -0,0 +1,27 @@ +package cache + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRetrieveTimerStopsWhenContextIsDone(t *testing.T) { + retrievalTimeout := time.Millisecond // quick timeout that cancels inner context + maxRetrievalInterval := time.Minute // long sleep inside resolveWithBackoff + + resolver := &client{ + domain: make(chan string, 0), + lookups: make(chan uint64, 1), + failure: errors.New("500 error"), + } + + retriever := NewRetriever(resolver, retrievalTimeout, maxRetrievalInterval, 3) + require.False(t, retriever.timer.hasStopped(), "timer has not been stopped yet") + + lookup := retriever.Retrieve("my.gitlab.com") + require.Empty(t, lookup.Name) + require.Eventually(t, retriever.timer.hasStopped, time.Second, time.Millisecond, "timer must have been stopped") +} -- GitLab From 2978471b9e7955daeff2840f9c6a74fc3a6eec08 Mon Sep 17 00:00:00 2001 From: Dishon Date: Mon, 18 Jan 2021 10:01:46 +0000 Subject: [PATCH 05/17] Rebase --- CHANGELOG | 27 +++ VERSION | 2 +- app.go | 25 ++- internal/auth/auth.go | 117 +++++++--- internal/auth/auth_code.go | 147 +++++++++++++ internal/auth/auth_code_test.go | 99 +++++++++ internal/auth/auth_test.go | 207 ++++++++++-------- internal/rejectmethods/middleware.go | 31 +++ internal/rejectmethods/middleware_test.go | 43 ++++ internal/source/gitlab/cache/retriever.go | 54 ++++- .../source/gitlab/cache/retriever_test.go | 27 +++ metrics/metrics.go | 7 + test/acceptance/acceptance_test.go | 29 ++- test/acceptance/artifacts_test.go | 2 +- test/acceptance/auth_test.go | 109 ++++++++- test/acceptance/serving_test.go | 3 +- test/acceptance/unknown_http_method_test.go | 23 ++ 17 files changed, 789 insertions(+), 163 deletions(-) create mode 100644 internal/auth/auth_code.go create mode 100644 
internal/auth/auth_code_test.go create mode 100644 internal/rejectmethods/middleware.go create mode 100644 internal/rejectmethods/middleware_test.go create mode 100644 internal/source/gitlab/cache/retriever_test.go create mode 100644 test/acceptance/unknown_http_method_test.go diff --git a/CHANGELOG b/CHANGELOG index 9970bff88..e315ddcc0 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,12 @@ +v 1.34.0 + +- Allow DELETE HTTP method + +v 1.33.0 + +- Reject requests with unknown HTTP methods +- Encrypt OAuth code during auth flow + v 1.32.0 - Try to automatically use gitlab API as a source for domain information !402 @@ -10,6 +19,15 @@ v 1.31.0 - Add zip serving configuration flags !392 - Disable deprecated serverless serving and proxy !400 +v 1.30.2 + +- Allow DELETE HTTP method + +v 1.30.1 + +- Reject requests with unknown HTTP methods +- Encrypt OAuth code during auth flow + v 1.30.0 - Allow to refresh an existing cached archive when accessed !375 @@ -23,6 +41,15 @@ v 1.29.0 - Improve httprange timeouts !382 - Fix caching for errored ZIP VFS archives !384 +v 1.28.2 + +- Allow DELETE HTTP method + +v 1.28.1 + +- Reject requests with unknown HTTP methods +- Encrypt OAuth code during auth flow + v 1.28.0 - Implement basic redirects via _redirects text file !367 diff --git a/VERSION b/VERSION index 359c41089..2b17ffd50 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.32.0 +1.34.0 diff --git a/app.go b/app.go index ed06893e4..1352b630b 100644 --- a/app.go +++ b/app.go @@ -28,6 +28,7 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/logging" "gitlab.com/gitlab-org/gitlab-pages/internal/middleware" "gitlab.com/gitlab-org/gitlab-pages/internal/netutil" + "gitlab.com/gitlab-org/gitlab-pages/internal/rejectmethods" "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk/zip" "gitlab.com/gitlab-org/gitlab-pages/internal/source" @@ -337,6 +338,12 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { 
// Custom response headers handler = a.customHeadersMiddleware(handler) + // This MUST be the last handler! + // This handler blocks unknown HTTP methods, + // being the last means it will be evaluated first + // preventing any operation on bogus requests. + handler = rejectmethods.NewMiddleware(handler) + return handler, nil } @@ -483,10 +490,7 @@ func runApp(config appConfig) { a.Artifact = artifact.New(config.ArtifactsServer, config.ArtifactsServerTimeout, config.Domain) } - if config.ClientID != "" { - a.Auth = auth.New(config.Domain, config.StoreSecret, config.ClientID, config.ClientSecret, - config.RedirectURI, config.GitLabServer) - } + a.setAuth(config) a.Handlers = handlers.New(a.Auth, a.Artifact) @@ -524,6 +528,19 @@ func runApp(config appConfig) { a.Run() } +func (a *theApp) setAuth(config appConfig) { + if config.ClientID == "" { + return + } + + var err error + a.Auth, err = auth.New(config.Domain, config.StoreSecret, config.ClientID, config.ClientSecret, + config.RedirectURI, config.GitLabServer) + if err != nil { + log.WithError(err).Fatal("could not initialize auth package") + } +} + // fatal will log a fatal error and exit. 
func fatal(err error, message string) { log.WithError(err).Fatal(message) diff --git a/internal/auth/auth.go b/internal/auth/auth.go index eaf3c25dd..252954a62 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -16,14 +16,14 @@ import ( "github.com/gorilla/securecookie" "github.com/gorilla/sessions" log "github.com/sirupsen/logrus" + "golang.org/x/crypto/hkdf" + "gitlab.com/gitlab-org/labkit/errortracking" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/httptransport" "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/source" - - "golang.org/x/crypto/hkdf" ) // nolint: gosec @@ -47,17 +47,23 @@ var ( errFailAuth = errors.New("Failed to authenticate request") errAuthNotConfigured = errors.New("Authentication is not configured") errQueryParameter = errors.New("Failed to parse domain query parameter") + + errGenerateKeys = errors.New("could not generate auth keys") ) // Auth handles authenticating users with GitLab API type Auth struct { - pagesDomain string - clientID string - clientSecret string - redirectURI string - gitLabServer string - apiClient *http.Client - store sessions.Store + pagesDomain string + clientID string + clientSecret string + redirectURI string + gitLabServer string + authSecret string + jwtSigningKey []byte + jwtExpiry time.Duration + apiClient *http.Client + store sessions.Store + now func() time.Time // allows to stub time.Now() easily in tests } type tokenResponse struct { @@ -111,7 +117,7 @@ func (a *Auth) checkSession(w http.ResponseWriter, r *http.Request) (*sessions.S return session, nil } -// TryAuthenticate tries to authenticate user and fetch access token if request is a callback to auth +// TryAuthenticate tries to authenticate user and fetch access token if request is a callback to /auth? 
func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, domains source.Source) bool { if a == nil { return false @@ -166,11 +172,18 @@ func (a *Auth) checkAuthenticationResponse(session *sessions.Session, w http.Res return } - // Fetch access token with authorization code - token, err := a.fetchAccessToken(r.URL.Query().Get("code")) + decryptedCode, err := a.DecryptCode(r.URL.Query().Get("code"), getRequestDomain(r)) + if err != nil { + logRequest(r).WithError(err).Error("failed to decrypt secure code") + errortracking.Capture(err, errortracking.WithRequest(r)) + httperrors.Serve500(w) + return + } - // Fetching token not OK + // Fetch access token with authorization code + token, err := a.fetchAccessToken(decryptedCode) if err != nil { + // Fetching token not OK logRequest(r).WithError(err).WithField( "redirect_uri", redirectURI, ).Error(errFetchAccessToken) @@ -216,8 +229,8 @@ func (a *Auth) domainAllowed(name string, domains source.Source) bool { } func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWriter, r *http.Request, domains source.Source) bool { - // If request is for authenticating via custom domain - if shouldProxyAuth(r) { + // handle auth callback e.g. 
https://gitlab.io/auth?domain&domain&state=state + if shouldProxyAuthToGitlab(r) { domain := r.URL.Query().Get("domain") state := r.URL.Query().Get("state") @@ -266,6 +279,7 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit } // If auth request callback should be proxied to custom domain + // redirect to originating domain set in the cookie as proxy_auth_domain if shouldProxyCallbackToCustomDomain(r, session) { // Get domain started auth process proxyDomain := session.Values["proxy_auth_domain"].(string) @@ -283,9 +297,30 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit return true } - // Redirect pages under custom domain - http.Redirect(w, r, proxyDomain+r.URL.Path+"?"+r.URL.RawQuery, 302) + query := r.URL.Query() + + // prevent https://tools.ietf.org/html/rfc6749#section-10.6 and + // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/262 by encrypting + // and signing the OAuth code + signedCode, err := a.EncryptAndSignCode(proxyDomain, query.Get("code")) + if err != nil { + logRequest(r).WithError(err).Error(errSaveSession) + errortracking.Capture(err, errortracking.WithRequest(r)) + + httperrors.Serve503(w) + return true + } + + // prevent forwarding access token, more context on the security issue + // https://gitlab.com/gitlab-org/gitlab/-/issues/285244#note_451266051 + query.Del("token") + + // replace code with signed code + query.Set("code", signedCode) + // Redirect pages to originating domain with code and state to finish + // authentication process + http.Redirect(w, r, proxyDomain+r.URL.Path+"?"+query.Encode(), 302) return true } @@ -306,7 +341,7 @@ func getRequestDomain(r *http.Request) string { return "http://" + r.Host } -func shouldProxyAuth(r *http.Request) bool { +func shouldProxyAuthToGitlab(r *http.Request) bool { return r.URL.Query().Get("domain") != "" && r.URL.Query().Get("state") != "" } @@ -376,6 +411,7 @@ func (a *Auth) checkSessionIsValid(w http.ResponseWriter, r 
*http.Request) *sess return nil } + // redirect to /auth?domain=%s&state=%s if a.checkTokenExists(session, w, r) { return nil } @@ -586,28 +622,37 @@ func logRequest(r *http.Request) *log.Entry { }) } -// generateKeyPair returns key pair for secure cookie: signing and encryption key -func generateKeyPair(storeSecret string) ([]byte, []byte) { - hash := sha256.New - hkdf := hkdf.New(hash, []byte(storeSecret), []byte{}, []byte("PAGES_SIGNING_AND_ENCRYPTION_KEY")) - var keys [][]byte - for i := 0; i < 2; i++ { +// generateKeys derives count hkdf keys from a secret, ensuring the key is +// the same for the same secret used across multiple instances +func generateKeys(secret string, count int) ([][]byte, error) { + keys := make([][]byte, count) + hkdfReader := hkdf.New(sha256.New, []byte(secret), []byte{}, []byte("PAGES_SIGNING_AND_ENCRYPTION_KEY")) + + for i := 0; i < count; i++ { key := make([]byte, 32) - if _, err := io.ReadFull(hkdf, key); err != nil { - log.WithError(err).Fatal("Can't generate key pair for secure cookies") + if _, err := io.ReadFull(hkdfReader, key); err != nil { + return nil, err } - keys = append(keys, key) + + keys[i] = key + } + + if len(keys) < count { + return nil, errGenerateKeys } - return keys[0], keys[1] -} -func createCookieStore(storeSecret string) sessions.Store { - return sessions.NewCookieStore(generateKeyPair(storeSecret)) + return keys, nil } // New when authentication supported this will be used to create authentication handler func New(pagesDomain string, storeSecret string, clientID string, clientSecret string, - redirectURI string, gitLabServer string) *Auth { + redirectURI string, gitLabServer string) (*Auth, error) { + // generate 3 keys, 2 for the cookie store and 1 for JWT signing + keys, err := generateKeys(storeSecret, 3) + if err != nil { + return nil, err + } + return &Auth{ pagesDomain: pagesDomain, clientID: clientID, @@ -618,6 +663,10 @@ func New(pagesDomain string, storeSecret string, clientID string, clientSecret s 
Timeout: 5 * time.Second, Transport: httptransport.InternalTransport, }, - store: createCookieStore(storeSecret), - } + store: sessions.NewCookieStore(keys[0], keys[1]), + authSecret: storeSecret, + jwtSigningKey: keys[2], + jwtExpiry: time.Minute, + now: time.Now, + }, nil } diff --git a/internal/auth/auth_code.go b/internal/auth/auth_code.go new file mode 100644 index 000000000..d2fea5a95 --- /dev/null +++ b/internal/auth/auth_code.go @@ -0,0 +1,147 @@ +package auth + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "io" + + "github.com/dgrijalva/jwt-go" + "github.com/gorilla/securecookie" + "golang.org/x/crypto/hkdf" +) + +var ( + errInvalidToken = errors.New("invalid token") + errEmptyDomainOrCode = errors.New("empty domain or code") + errInvalidNonce = errors.New("invalid nonce") + errInvalidCode = errors.New("invalid code") +) + +// EncryptAndSignCode encrypts the OAuth code deriving the key from the domain. +// It adds the code and domain as JWT token claims and signs it using signingKey derived from +// the Auth secret. 
+func (a *Auth) EncryptAndSignCode(domain, code string) (string, error) { + if domain == "" || code == "" { + return "", errEmptyDomainOrCode + } + + nonce := base64.URLEncoding.EncodeToString(securecookie.GenerateRandomKey(16)) + + aesGcm, err := a.newAesGcmCipher(domain, nonce) + if err != nil { + return "", err + } + + // encrypt code with a randomly generated nonce + encryptedCode := aesGcm.Seal(nil, []byte(nonce), []byte(code), nil) + + // generate JWT token claims with encrypted code + claims := jwt.MapClaims{ + // standard claims + "iss": "gitlab-pages", + "iat": a.now().Unix(), + "exp": a.now().Add(a.jwtExpiry).Unix(), + // custom claims + "domain": domain, // pass the domain so we can validate the signed domain matches the requested domain + "code": hex.EncodeToString(encryptedCode), + "nonce": nonce, + } + + return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(a.jwtSigningKey) +} + +// DecryptCode decodes the secureCode as a JWT token and validates its signature. +// It then decrypts the code from the token claims and returns it. 
+func (a *Auth) DecryptCode(jwt, domain string) (string, error) { + claims, err := a.parseJWTClaims(jwt) + if err != nil { + return "", err + } + + // get nonce and encryptedCode from the JWT claims + nonce, ok := claims["nonce"].(string) + if !ok { + return "", errInvalidNonce + } + + encryptedCode, ok := claims["code"].(string) + if !ok { + return "", errInvalidCode + } + + cipherText, err := hex.DecodeString(encryptedCode) + if err != nil { + return "", err + } + + aesGcm, err := a.newAesGcmCipher(domain, nonce) + if err != nil { + return "", err + } + + decryptedCode, err := aesGcm.Open(nil, []byte(nonce), cipherText, nil) + if err != nil { + return "", err + } + + return string(decryptedCode), nil +} + +func (a *Auth) codeKey(domain string) ([]byte, error) { + hkdfReader := hkdf.New(sha256.New, []byte(a.authSecret), []byte(domain), []byte("PAGES_AUTH_CODE_ENCRYPTION_KEY")) + + key := make([]byte, 32) + if _, err := io.ReadFull(hkdfReader, key); err != nil { + return nil, err + } + + return key, nil +} + +func (a *Auth) parseJWTClaims(secureCode string) (jwt.MapClaims, error) { + token, err := jwt.Parse(secureCode, a.getSigningKey) + if err != nil { + return nil, err + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + return nil, errInvalidToken + } + + return claims, nil +} + +func (a *Auth) getSigningKey(token *jwt.Token) (interface{}, error) { + // Don't forget to validate the alg is what you expect: + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + return a.jwtSigningKey, nil +} + +func (a *Auth) newAesGcmCipher(domain, nonce string) (cipher.AEAD, error) { + // get the same key for a domain + key, err := a.codeKey(domain) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + aesGcm, err := cipher.NewGCMWithNonceSize(block, len(nonce)) + if err != nil { + return nil, err 
+ } + + return aesGcm, nil +} diff --git a/internal/auth/auth_code_test.go b/internal/auth/auth_code_test.go new file mode 100644 index 000000000..d54fcc7ea --- /dev/null +++ b/internal/auth/auth_code_test.go @@ -0,0 +1,99 @@ +package auth + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestEncryptAndDecryptSignedCode(t *testing.T) { + auth := createTestAuth(t, "") + + tests := map[string]struct { + auth *Auth + encDomain string + code string + expectedEncErrMsg string + decDomain string + expectedDecErrMsg string + }{ + "happy_path": { + auth: auth, + encDomain: "domain", + decDomain: "domain", + code: "code", + }, + "empty_domain": { + auth: auth, + encDomain: "", + code: "code", + expectedEncErrMsg: "empty domain or code", + }, + "empty_code": { + auth: auth, + encDomain: "domain", + code: "", + expectedEncErrMsg: "empty domain or code", + }, + "different_dec_domain": { + auth: auth, + encDomain: "domain", + decDomain: "another", + code: "code", + expectedDecErrMsg: "cipher: message authentication failed", + }, + "expired_token": { + auth: func() *Auth { + newAuth := *auth + newAuth.jwtExpiry = time.Nanosecond + newAuth.now = func() time.Time { + return time.Time{} + } + + return &newAuth + }(), + encDomain: "domain", + code: "code", + decDomain: "domain", + expectedDecErrMsg: "Token is expired", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + encCode, err := test.auth.EncryptAndSignCode(test.encDomain, test.code) + if test.expectedEncErrMsg != "" { + require.EqualError(t, err, test.expectedEncErrMsg) + require.Empty(t, encCode) + return + } + + require.NoError(t, err) + require.NotEmpty(t, encCode) + + decCode, err := test.auth.DecryptCode(encCode, test.decDomain) + if test.expectedDecErrMsg != "" { + require.EqualError(t, err, test.expectedDecErrMsg) + require.Empty(t, decCode) + return + } + + require.NoError(t, err) + require.Equal(t, test.code, decCode) + }) + } +} + +func 
TestDecryptCodeWithInvalidJWT(t *testing.T) { + auth1 := createTestAuth(t, "") + auth2 := createTestAuth(t, "") + auth2.jwtSigningKey = []byte("another signing key") + + encCode, err := auth1.EncryptAndSignCode("domain", "code") + require.NoError(t, err) + + decCode, err := auth2.DecryptCode(encCode, "domain") + require.EqualError(t, err, "signature is invalid") + require.Empty(t, decCode) +} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 39a533b35..ce7d83207 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "github.com/gorilla/sessions" @@ -16,17 +17,19 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/source" ) -func createAuth(t *testing.T) *Auth { - return New("pages.gitlab-example.com", +func createTestAuth(t *testing.T, url string) *Auth { + t.Helper() + + a, err := New("pages.gitlab-example.com", "something-very-secret", "id", "secret", "http://pages.gitlab-example.com/auth", - "http://gitlab-example.com") -} + url) + + require.NoError(t, err) -func defaultCookieStore() sessions.Store { - return createCookieStore("something-very-secret") + return a } type domainMock struct { @@ -48,10 +51,13 @@ func (dm *domainMock) ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Req // Which leads to negative side effects: we can't test encryption, and cookie params // like max-age and secure are not being properly set // To avoid that we use fake request, and set only session cookie without copying context -func setSessionValues(r *http.Request, values map[interface{}]interface{}) { - tmpRequest, _ := http.NewRequest("GET", "/", nil) +func setSessionValues(t *testing.T, r *http.Request, store sessions.Store, values map[interface{}]interface{}) { + t.Helper() + + tmpRequest, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + result := httptest.NewRecorder() - store := defaultCookieStore() session, _ := 
store.Get(tmpRequest, "gitlab-pages") session.Values = values @@ -63,7 +69,7 @@ func setSessionValues(r *http.Request, values map[interface{}]interface{}) { } func TestTryAuthenticate(t *testing.T) { - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/something/else") @@ -75,11 +81,12 @@ func TestTryAuthenticate(t *testing.T) { } func TestTryAuthenticateWithError(t *testing.T) { - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?error=access_denied") require.NoError(t, err) + reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} @@ -88,8 +95,7 @@ func TestTryAuthenticateWithError(t *testing.T) { } func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=invalid") @@ -97,7 +103,9 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["state"] = "state" session.Save(r, result) @@ -105,7 +113,36 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { require.Equal(t, 401, result.Code) } +func TestTryAuthenticateRemoveTokenFromRedirect(t *testing.T) { + auth := createTestAuth(t, "") + + result := httptest.NewRecorder() + reqURL, err := url.Parse("/auth?code=1&state=state&token=secret") + require.NoError(t, err) + + require.Equal(t, reqURL.Query().Get("token"), "secret", "token is present before redirecting") + reqURL.Scheme = request.SchemeHTTPS + r := &http.Request{URL: reqURL} + + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + + session.Values["state"] = 
"state" + session.Values["proxy_auth_domain"] = "https://domain.com" + session.Save(r, result) + + require.Equal(t, true, auth.TryAuthenticate(result, r, source.NewMockSource())) + require.Equal(t, http.StatusFound, result.Code) + + redirect, err := url.Parse(result.Header().Get("Location")) + require.NoError(t, err) + + require.Empty(t, redirect.Query().Get("token"), "token is gone after redirecting") +} + func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { + t.Helper() + apiServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/oauth/token": @@ -125,14 +162,17 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { apiServer.Start() defer apiServer.Close() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) + + domain := apiServer.URL + if https { + domain = strings.Replace(apiServer.URL, "http://", "https://", -1) + } - r, err := http.NewRequest("GET", "/auth?code=1&state=state", nil) + code, err := auth.EncryptAndSignCode(domain, "1") + require.NoError(t, err) + + r, err := http.NewRequest("GET", "/auth?code="+code+"&state=state", nil) require.NoError(t, err) if https { r.URL.Scheme = request.SchemeHTTPS @@ -140,14 +180,16 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { r.URL.Scheme = request.SchemeHTTP } - setSessionValues(r, map[interface{}]interface{}{ + r.Host = strings.TrimPrefix(apiServer.URL, "http://") + + setSessionValues(t, r, auth.store, map[interface{}]interface{}{ "uri": "https://pages.gitlab-example.com/project/", "state": "state", }) result := httptest.NewRecorder() require.Equal(t, true, auth.TryAuthenticate(result, r, source.NewMockSource())) - require.Equal(t, 302, result.Code) + require.Equal(t, http.StatusFound, result.Code) require.Equal(t, 
"https://pages.gitlab-example.com/project/", result.Header().Get("Location")) require.Equal(t, 600, result.Result().Cookies()[0].MaxAge) require.Equal(t, https, result.Result().Cookies()[0].Secure) @@ -177,13 +219,7 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") @@ -191,7 +227,9 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) @@ -217,13 +255,7 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) w := httptest.NewRecorder() @@ -232,7 +264,9 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, w) @@ -265,22 +299,19 @@ func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - 
"http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" - session.Save(r, result) + err = session.Save(r, result) + require.NoError(t, err) contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) require.True(t, contentServed) @@ -303,13 +334,7 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") @@ -317,7 +342,9 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) @@ -343,19 +370,16 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) 
+ session.Values["access_token"] = "abc" session.Save(r, result) @@ -364,28 +388,31 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { require.Equal(t, 302, result.Code) } -func TestGenerateKeyPair(t *testing.T) { - signingSecret, encryptionSecret := generateKeyPair("something-very-secret") - require.NotEqual(t, fmt.Sprint(signingSecret), fmt.Sprint(encryptionSecret)) - require.Equal(t, len(signingSecret), 32) - require.Equal(t, len(encryptionSecret), 32) +func TestGenerateKeys(t *testing.T) { + keys, err := generateKeys("something-very-secret", 3) + require.NoError(t, err) + require.Len(t, keys, 3) + + require.NotEqual(t, fmt.Sprint(keys[0]), fmt.Sprint(keys[1])) + require.NotEqual(t, fmt.Sprint(keys[0]), fmt.Sprint(keys[2])) + require.NotEqual(t, fmt.Sprint(keys[1]), fmt.Sprint(keys[2])) + + require.Equal(t, len(keys[0]), 32) + require.Equal(t, len(keys[1]), 32) + require.Equal(t, len(keys[2]), 32) } func TestGetTokenIfExistsWhenTokenExists(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) @@ -395,20 +422,16 @@ func TestGetTokenIfExistsWhenTokenExists(t *testing.T) { } func TestGetTokenIfExistsWhenTokenDoesNotExist(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := 
url.Parse("http://pages.gitlab-example.com/test") require.NoError(t, err) r := &http.Request{URL: reqURL, Host: "pages.gitlab-example.com", RequestURI: "/test"} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Save(r, result) token, err := auth.GetTokenIfExists(result, r) @@ -417,12 +440,7 @@ func TestGetTokenIfExistsWhenTokenDoesNotExist(t *testing.T) { } func TestCheckResponseForInvalidTokenWhenInvalidToken(t *testing.T) { - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("http://pages.gitlab-example.com/test") @@ -437,12 +455,7 @@ func TestCheckResponseForInvalidTokenWhenInvalidToken(t *testing.T) { } func TestCheckResponseForInvalidTokenWhenNotInvalidToken(t *testing.T) { - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/something") diff --git a/internal/rejectmethods/middleware.go b/internal/rejectmethods/middleware.go new file mode 100644 index 000000000..e78a0ce59 --- /dev/null +++ b/internal/rejectmethods/middleware.go @@ -0,0 +1,31 @@ +package rejectmethods + +import ( + "net/http" + + "gitlab.com/gitlab-org/gitlab-pages/metrics" +) + +var acceptedMethods = map[string]bool{ + http.MethodGet: true, + http.MethodHead: true, + http.MethodPost: true, + http.MethodPut: true, + http.MethodPatch: true, + http.MethodDelete: true, + http.MethodConnect: true, + http.MethodOptions: true, + http.MethodTrace: true, +} + +// NewMiddleware returns middleware which rejects all unknown http methods +func NewMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if 
acceptedMethods[r.Method] { + handler.ServeHTTP(w, r) + } else { + metrics.RejectedRequestsCount.Inc() + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + } + }) +} diff --git a/internal/rejectmethods/middleware_test.go b/internal/rejectmethods/middleware_test.go new file mode 100644 index 000000000..2921975ae --- /dev/null +++ b/internal/rejectmethods/middleware_test.go @@ -0,0 +1,43 @@ +package rejectmethods + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewMiddleware(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "OK\n") + }) + + middleware := NewMiddleware(handler) + + acceptedMethods := []string{"GET", "HEAD", "POST", "PUT", "PATCH", "CONNECT", "OPTIONS", "TRACE"} + for _, method := range acceptedMethods { + t.Run(method, func(t *testing.T) { + tmpRequest, _ := http.NewRequest(method, "/", nil) + recorder := httptest.NewRecorder() + + middleware.ServeHTTP(recorder, tmpRequest) + + result := recorder.Result() + + require.Equal(t, http.StatusOK, result.StatusCode) + }) + } + + t.Run("UNKNOWN", func(t *testing.T) { + tmpRequest, _ := http.NewRequest("UNKNOWN", "/", nil) + recorder := httptest.NewRecorder() + + middleware.ServeHTTP(recorder, tmpRequest) + + result := recorder.Result() + + require.Equal(t, http.StatusMethodNotAllowed, result.StatusCode) + }) +} diff --git a/internal/source/gitlab/cache/retriever.go b/internal/source/gitlab/cache/retriever.go index de37c231a..f38351ce4 100644 --- a/internal/source/gitlab/cache/retriever.go +++ b/internal/source/gitlab/cache/retriever.go @@ -1,10 +1,11 @@ package cache import ( + "time" "context" "errors" - "time" - + "sync" + log "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/gitlab-pages/internal/source/gitlab/api" @@ -13,15 +14,25 @@ import ( // Retriever is an utility type that performs an HTTP request with backoff in // case of 
errors type Retriever struct { + timer timer client api.Client retrievalTimeout time.Duration maxRetrievalInterval time.Duration maxRetrievalRetries int } +type timer struct { + mu *sync.Mutex + stopped bool + timer *time.Timer +} + // NewRetriever creates a Retriever with a client func NewRetriever(client api.Client, retrievalTimeout, maxRetrievalInterval time.Duration, maxRetrievalRetries int) *Retriever { return &Retriever{ + timer: timer{ + mu: &sync.Mutex{}, + }, client: client, retrievalTimeout: retrievalTimeout, maxRetrievalInterval: maxRetrievalInterval, @@ -52,11 +63,24 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha go func() { var lookup api.Lookup + Retry: for i := 1; i <= r.maxRetrievalRetries; i++ { lookup = r.client.GetLookup(ctx, domain) if lookup.Error != nil { - time.Sleep(r.maxRetrievalInterval) + r.timer.start(r.maxRetrievalInterval) + select { + case <-r.timer.timer.C: + // retry to GetLookup + continue Retry + case <-ctx.Done(): + // when the retrieval context is done we stop the timer + // timer.Stop() + // log.WithError(ctx.Err()).Debug("domain retrieval backoff canceled by context") + // when the retrieval context is done we stop the timerFunc + r.timer.stop() + break Retry + } } else { break } @@ -68,3 +92,27 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha return response } + + + func (t *timer) start(d time.Duration) { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopped = false + t.timer = time.NewTimer(d) + } + + func (t *timer) stop() { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopped = t.timer.Stop() + } + + func (t *timer) hasStopped() bool { + t.mu.Lock() + defer t.mu.Unlock() + + return t.stopped + } + \ No newline at end of file diff --git a/internal/source/gitlab/cache/retriever_test.go b/internal/source/gitlab/cache/retriever_test.go new file mode 100644 index 000000000..4db4a639f --- /dev/null +++ b/internal/source/gitlab/cache/retriever_test.go @@ -0,0 +1,27 
@@ +package cache + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRetrieveTimerStopsWhenContextIsDone(t *testing.T) { + retrievalTimeout := time.Millisecond // quick timeout that cancels inner context + maxRetrievalInterval := time.Minute // long sleep inside resolveWithBackoff + + resolver := &client{ + domain: make(chan string, 0), + lookups: make(chan uint64, 1), + failure: errors.New("500 error"), + } + + retriever := NewRetriever(resolver, retrievalTimeout, maxRetrievalInterval, 3) + require.False(t, retriever.timer.hasStopped(), "timer has not been stopped yet") + + lookup := retriever.Retrieve("my.gitlab.com") + require.Empty(t, lookup.Name) + require.Eventually(t, retriever.timer.hasStopped, time.Second, time.Millisecond, "timer must have been stopped") +} diff --git a/metrics/metrics.go b/metrics/metrics.go index db7cae9a8..045ff26e0 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -199,6 +199,13 @@ var ( Help: "The number of files per zip archive total count over time", }, ) + + RejectedRequestsCount = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "gitlab_pages_unknown_method_rejected_requests", + Help: "The number of requests with unknown HTTP method which were rejected", + }, + ) ) // MustRegister collectors with the Prometheus client diff --git a/test/acceptance/acceptance_test.go b/test/acceptance/acceptance_test.go index 9921076ea..ba6528c10 100644 --- a/test/acceptance/acceptance_test.go +++ b/test/acceptance/acceptance_test.go @@ -17,24 +17,31 @@ const ( var ( pagesBinary = flag.String("gitlab-pages-binary", "../../gitlab-pages", "Path to the gitlab-pages binary") + httpPort = "36000" + httpsPort = "37000" + httpProxyPort = "38000" + httpProxyV2Port = "39000" + // TODO: Use TCP port 0 everywhere to avoid conflicts. The binary could output // the actual port (and type of listener) for us to read in place of the // hardcoded values below. 
listeners = []ListenSpec{ - {"http", "127.0.0.1", "37000"}, - {"http", "::1", "37000"}, - {"https", "127.0.0.1", "37001"}, - {"https", "::1", "37001"}, - {"proxy", "127.0.0.1", "37002"}, - {"proxy", "::1", "37002"}, - {"https-proxyv2", "127.0.0.1", "37003"}, - {"https-proxyv2", "::1", "37003"}, + {"http", "127.0.0.1", httpPort}, + {"https", "127.0.0.1", httpsPort}, + {"proxy", "127.0.0.1", httpProxyPort}, + {"https-proxyv2", "127.0.0.1", httpProxyV2Port}, + // TODO: re-enable IPv6 listeners once https://gitlab.com/gitlab-com/gl-infra/infrastructure/-/issues/12258 is resolved + // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/528 + // {"http", "::1", httpPort}, + // {"https", "::1", httpsPort}, + // {"proxy", "::1", httpProxyPort}, + // {"https-proxyv2", "::1", httpProxyV2Port}, } httpListener = listeners[0] - httpsListener = listeners[2] - proxyListener = listeners[4] - httpsProxyv2Listener = listeners[6] + httpsListener = listeners[1] + proxyListener = listeners[2] + httpsProxyv2Listener = listeners[3] ) func TestMain(m *testing.M) { diff --git a/test/acceptance/artifacts_test.go b/test/acceptance/artifacts_test.go index 3440ef34f..57c7a02a9 100644 --- a/test/acceptance/artifacts_test.go +++ b/test/acceptance/artifacts_test.go @@ -245,7 +245,7 @@ func TestPrivateArtifactProxyRequest(t *testing.T) { ) defer teardown() - resp, err := GetRedirectPage(t, httpListener, tt.host, tt.path) + resp, err := GetRedirectPage(t, httpsListener, tt.host, tt.path) require.NoError(t, err) defer resp.Body.Close() diff --git a/test/acceptance/auth_test.go b/test/acceptance/auth_test.go index b2233591f..fa2d768d8 100644 --- a/test/acceptance/auth_test.go +++ b/test/acceptance/auth_test.go @@ -88,7 +88,7 @@ func TestWhenLoginCallbackWithWrongStateShouldFail(t *testing.T) { require.Equal(t, http.StatusUnauthorized, authrsp.StatusCode) } -func TestWhenLoginCallbackWithCorrectStateWithoutEndpoint(t *testing.T) { +func TestWhenLoginCallbackWithUnencryptedCode(t *testing.T) { 
skipUnlessEnabled(t) teardown := RunPagesProcessWithAuth(t, *pagesBinary, listeners, "") defer teardown() @@ -110,8 +110,8 @@ func TestWhenLoginCallbackWithCorrectStateWithoutEndpoint(t *testing.T) { require.NoError(t, err) defer authrsp.Body.Close() - // Will cause 503 because token endpoint is not available - require.Equal(t, http.StatusServiceUnavailable, authrsp.StatusCode) + // Will cause 500 because the code is not encrypted + require.Equal(t, http.StatusInternalServerError, authrsp.StatusCode) } func handleAccessControlArtifactRequests(t *testing.T, w http.ResponseWriter, r *http.Request) bool { @@ -234,11 +234,10 @@ func TestAccessControlUnderCustomDomain(t *testing.T) { // Will redirect to custom domain require.Equal(t, tt.domain, url.Host) - require.Equal(t, "1", url.Query().Get("code")) - require.Equal(t, state, url.Query().Get("state")) + code := url.Query().Get("code") + require.NotEqual(t, "1", code) - // Run auth callback in custom domain - authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code=1&state="+ + authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code="+code+"&state="+ state, cookie) require.NoError(t, err) @@ -336,11 +335,13 @@ func TestCustomErrorPageWithAuth(t *testing.T) { // Will redirect to custom domain require.Equal(t, tt.domain, url.Host) - require.Equal(t, "1", url.Query().Get("code")) + // code must have changed since it's encrypted + code := url.Query().Get("code") + require.NotEqual(t, "1", code) require.Equal(t, state, url.Query().Get("state")) // Run auth callback in custom domain - authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code=1&state="+ + authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code="+code+"&state="+ state, cookie) require.NoError(t, err) @@ -409,12 +410,14 @@ func TestAccessControlUnderCustomDomainWithHTTPSProxy(t *testing.T) { // Will redirect to custom domain require.Equal(t, "private.domain.com", 
url.Host) - require.Equal(t, "1", url.Query().Get("code")) + // code must have changed since it's encrypted + code := url.Query().Get("code") + require.NotEqual(t, "1", code) require.Equal(t, state, url.Query().Get("state")) // Run auth callback in custom domain authrsp, err = GetProxyRedirectPageWithCookie(t, proxyListener, "private.domain.com", - "/auth?code=1&state="+state, cookie, true) + "/auth?code="+code+"&state="+state, cookie, true) require.NoError(t, err) defer authrsp.Body.Close() @@ -641,3 +644,87 @@ func TestAccessControlWithSSLCertFile(t *testing.T) { func TestAccessControlWithSSLCertDir(t *testing.T) { testAccessControl(t, RunPagesProcessWithAuthServerWithSSLCertDir) } + +// This proves the fix for https://gitlab.com/gitlab-org/gitlab-pages/-/issues/262 +// Read the issue description if any changes to internal/auth/ break this test. +// Related to https://tools.ietf.org/html/rfc6749#section-10.6. +func TestHijackedCode(t *testing.T) { + skipUnlessEnabled(t, "not-inplace-chroot") + + testServer := makeGitLabPagesAccessStub(t) + testServer.Start() + defer testServer.Close() + + teardown := RunPagesProcessWithAuthServer(t, *pagesBinary, listeners, "", testServer.URL) + defer teardown() + + /****ATTACKER******/ + // get valid cookie for a different private project + targetDomain := "private.domain.com" + attackersDomain := "group.auth.gitlab-example.com" + attackerCookie, attackerState := getValidCookieAndState(t, targetDomain) + + /****TARGET******/ + // fool target to click on modified URL with attacker's domain for redirect with a valid state + hackedURL := fmt.Sprintf("/auth?domain=http://%s&state=%s", attackersDomain, "irrelevant") + maliciousResp, err := GetProxyRedirectPageWithCookie(t, proxyListener, "projects.gitlab-example.com", hackedURL, "", true) + require.NoError(t, err) + defer maliciousResp.Body.Close() + + pagesCookie := maliciousResp.Header.Get("Set-Cookie") + + /* + OAuth flow happens here... 
+ */ + maliciousRespURL, err := url.Parse(maliciousResp.Header.Get("Location")) + require.NoError(t, err) + maliciousState := maliciousRespURL.Query().Get("state") + + // Go to auth page with correct state and code "obtained" from GitLab + authrsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, + "projects.gitlab-example.com", "/auth?code=1&state="+maliciousState, + pagesCookie, true) + + require.NoError(t, err) + defer authrsp.Body.Close() + + /****ATTACKER******/ + // Target is redirected to attacker's domain and attacker receives the proper code + require.Equal(t, http.StatusFound, authrsp.StatusCode, "should redirect to attacker's domain") + authrspURL, err := url.Parse(authrsp.Header.Get("Location")) + require.NoError(t, err) + require.Contains(t, authrspURL.String(), attackersDomain) + + // attacker's got the code + hijackedCode := authrspURL.Query().Get("code") + require.NotEmpty(t, hijackedCode) + + // attacker tries to access private pages content + impersonatingRes, err := GetProxyRedirectPageWithCookie(t, proxyListener, targetDomain, + "/auth?code="+hijackedCode+"&state="+attackerState, attackerCookie, true) + require.NoError(t, err) + defer authrsp.Body.Close() + + require.Equal(t, impersonatingRes.StatusCode, http.StatusInternalServerError, "should fail to decode code") +} + +func getValidCookieAndState(t *testing.T, domain string) (string, string) { + t.Helper() + + // follow flow to get a valid cookie + // visit https:/// + rsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, domain, "/", "", true) + require.NoError(t, err) + defer rsp.Body.Close() + + cookie := rsp.Header.Get("Set-Cookie") + require.NotEmpty(t, cookie) + + redirectURL, err := url.Parse(rsp.Header.Get("Location")) + require.NoError(t, err) + + state := redirectURL.Query().Get("state") + require.NotEmpty(t, state) + + return cookie, state +} diff --git a/test/acceptance/serving_test.go b/test/acceptance/serving_test.go index 66b5fa477..becd6b8cd 100644 --- 
a/test/acceptance/serving_test.go +++ b/test/acceptance/serving_test.go @@ -515,7 +515,8 @@ func TestKnownHostInReverseProxySetupReturns200(t *testing.T) { var listeners = []ListenSpec{ {"proxy", "127.0.0.1", "37002"}, - {"proxy", "::1", "37002"}, + // TODO: re-enable https://gitlab.com/gitlab-org/gitlab-pages/-/issues/528 + // {"proxy", "::1", "37002"}, } teardown := RunPagesProcess(t, *pagesBinary, listeners, "") diff --git a/test/acceptance/unknown_http_method_test.go b/test/acceptance/unknown_http_method_test.go new file mode 100644 index 000000000..f6c5ffee5 --- /dev/null +++ b/test/acceptance/unknown_http_method_test.go @@ -0,0 +1,23 @@ +package acceptance_test + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUnknownHTTPMethod(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + req, err := http.NewRequest("UNKNOWN", listeners[0].URL(""), nil) + require.NoError(t, err) + req.Host = "" + + resp, err := DoPagesRequest(t, httpListener, req) + require.NoError(t, err) + + require.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) +} -- GitLab From f6038c39528b9a5728ab22a1763fab4e8134b0bf Mon Sep 17 00:00:00 2001 From: Dishon Date: Mon, 18 Jan 2021 10:19:07 +0000 Subject: [PATCH 06/17] fix failing ci pipeline --- internal/source/gitlab/cache/retriever.go | 56 ++++++++----------- .../source/gitlab/cache/retriever_test.go | 2 +- 2 files changed, 23 insertions(+), 35 deletions(-) diff --git a/internal/source/gitlab/cache/retriever.go b/internal/source/gitlab/cache/retriever.go index 8e5be44b5..b8c08cede 100644 --- a/internal/source/gitlab/cache/retriever.go +++ b/internal/source/gitlab/cache/retriever.go @@ -1,17 +1,11 @@ package cache import ( - "time" "context" "errors" -<<<<<<< HEAD "sync" - -======= "time" - "sync" ->>>>>>> da862b83f50786a9ad7f071a50a7e6711bb65632 log "github.com/sirupsen/logrus" 
"gitlab.com/gitlab-org/gitlab-pages/internal/source/gitlab/api" @@ -69,7 +63,7 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha go func() { var lookup api.Lookup - Retry: + Retry: for i := 1; i <= r.maxRetrievalRetries; i++ { lookup = r.client.GetLookup(ctx, domain) @@ -86,10 +80,6 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha // when the retrieval context is done we stop the timerFunc r.timer.stop() break Retry -<<<<<<< HEAD -======= - ->>>>>>> da862b83f50786a9ad7f071a50a7e6711bb65632 } } else { break @@ -103,26 +93,24 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha return response } - - func (t *timer) start(d time.Duration) { - t.mu.Lock() - defer t.mu.Unlock() - - t.stopped = false - t.timer = time.NewTimer(d) - } - - func (t *timer) stop() { - t.mu.Lock() - defer t.mu.Unlock() - - t.stopped = t.timer.Stop() - } - - func (t *timer) hasStopped() bool { - t.mu.Lock() - defer t.mu.Unlock() - - return t.stopped - } - \ No newline at end of file +func (t *timer) start(d time.Duration) { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopped = false + t.timer = time.NewTimer(d) +} + +func (t *timer) stop() { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopped = t.timer.Stop() +} + +func (t *timer) hasStopped() bool { + t.mu.Lock() + defer t.mu.Unlock() + + return t.stopped +} diff --git a/internal/source/gitlab/cache/retriever_test.go b/internal/source/gitlab/cache/retriever_test.go index 4db4a639f..774e9779e 100644 --- a/internal/source/gitlab/cache/retriever_test.go +++ b/internal/source/gitlab/cache/retriever_test.go @@ -13,7 +13,7 @@ func TestRetrieveTimerStopsWhenContextIsDone(t *testing.T) { maxRetrievalInterval := time.Minute // long sleep inside resolveWithBackoff resolver := &client{ - domain: make(chan string, 0), + domain: make(chan string), lookups: make(chan uint64, 1), failure: errors.New("500 error"), } -- GitLab From 
bd862b1a80582368c705b7c8487c00e2775753d3 Mon Sep 17 00:00:00 2001 From: Dishon Date: Mon, 18 Jan 2021 10:01:46 +0000 Subject: [PATCH 07/17] Rebase Encrypt and sign OAuth code Add AES GCM encryption/decryption to auth Add signing key to Auth Abstract key generation and Auth init to their own funcs. Cleanup and DRY unit tests. Use same code parameter in auth redirect Cleanup auth and add tests for enc/dec oauth code Add acceptance test for fix Apply suggestion from review Add missing test and apply feedback Fix unit test Simplify acceptance test Reject all unknown http methods Release 1.33.0 Allow DELETE HTTP method For some reason I forgot it the last time Disable IPv6 listeners for acceptance tests Replace time.Sleep with a cancelable timer inside the cache retriever Replace time.Sleep with a cancelable timer inside the cache retriever Add tests fix failing ci pipeline --- internal/rejectmethods/middleware.go | 3 +++ internal/source/gitlab/cache/retriever.go | 31 +++++++++++++++++++++++ test/acceptance/auth_test.go | 13 ++++++++++ 3 files changed, 47 insertions(+) diff --git a/internal/rejectmethods/middleware.go b/internal/rejectmethods/middleware.go index e78a0ce59..235c37f13 100644 --- a/internal/rejectmethods/middleware.go +++ b/internal/rejectmethods/middleware.go @@ -12,7 +12,10 @@ var acceptedMethods = map[string]bool{ http.MethodPost: true, http.MethodPut: true, http.MethodPatch: true, +<<<<<<< HEAD http.MethodDelete: true, +======= +>>>>>>> 263db26 (Rebase) http.MethodConnect: true, http.MethodOptions: true, http.MethodTrace: true, diff --git a/internal/source/gitlab/cache/retriever.go b/internal/source/gitlab/cache/retriever.go index b8c08cede..2e288309c 100644 --- a/internal/source/gitlab/cache/retriever.go +++ b/internal/source/gitlab/cache/retriever.go @@ -5,6 +5,7 @@ import ( "errors" "sync" "time" + "sync" log "github.com/sirupsen/logrus" @@ -63,7 +64,11 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha go func() { 
var lookup api.Lookup +<<<<<<< HEAD Retry: +======= + Retry: +>>>>>>> 263db26 (Rebase) for i := 1; i <= r.maxRetrievalRetries; i++ { lookup = r.client.GetLookup(ctx, domain) @@ -93,6 +98,7 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha return response } +<<<<<<< HEAD func (t *timer) start(d time.Duration) { t.mu.Lock() defer t.mu.Unlock() @@ -114,3 +120,28 @@ func (t *timer) hasStopped() bool { return t.stopped } +======= + + func (t *timer) start(d time.Duration) { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopped = false + t.timer = time.NewTimer(d) + } + + func (t *timer) stop() { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopped = t.timer.Stop() + } + + func (t *timer) hasStopped() bool { + t.mu.Lock() + defer t.mu.Unlock() + + return t.stopped + } + +>>>>>>> 263db26 (Rebase) diff --git a/test/acceptance/auth_test.go b/test/acceptance/auth_test.go index fa2d768d8..958c12b7e 100644 --- a/test/acceptance/auth_test.go +++ b/test/acceptance/auth_test.go @@ -232,6 +232,7 @@ func TestAccessControlUnderCustomDomain(t *testing.T) { url, err = url.Parse(authrsp.Header.Get("Location")) require.NoError(t, err) +<<<<<<< HEAD // Will redirect to custom domain require.Equal(t, tt.domain, url.Host) code := url.Query().Get("code") @@ -239,6 +240,18 @@ func TestAccessControlUnderCustomDomain(t *testing.T) { authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code="+code+"&state="+ state, cookie) +======= + // Will redirect to custom domain + require.Equal(t, "private.domain.com", url.Host) + // code must have changed since it's encrypted + code := url.Query().Get("code") + require.NotEqual(t, "1", code) + require.Equal(t, state, url.Query().Get("state")) + + // Run auth callback in custom domain + authrsp, err = GetRedirectPageWithCookie(t, httpListener, "private.domain.com", "/auth?code="+code+"&state="+ + state, cookie) +>>>>>>> cd780a1 (Encrypt and sign OAuth code) require.NoError(t, err) defer authrsp.Body.Close() 
-- GitLab From 72121cb1f78addaea49060fe84686e5111960570 Mon Sep 17 00:00:00 2001 From: Dishon Date: Mon, 18 Jan 2021 10:01:46 +0000 Subject: [PATCH 08/17] Replace time.Sleep with a cancelable timer inside the cache retrieverebase Encrypt and sign OAuth code Add AES GCM encryption/decryption to auth Add signing key to Auth Abstract key generation and Auth init to their own funcs. Cleanup and DRY unit tests. Use same code parameter in auth redirect Cleanup auth and add tests for enc/dec oauth code Add acceptance test for fix Apply suggestion from review Add missing test and apply feedback Fix unit test Simplify acceptance test Reject all unknown http methods Release 1.33.0 Allow DELETE HTTP method For some reason I forgot it the last time Disable IPv6 listeners for acceptance tests Replace time.Sleep with a cancelable timer inside the cache retriever Replace time.Sleep with a cancelable timer inside the cache retriever Apply 2 suggestion(s) to 1 file(s) Add tests fix failing ci pipeline Rebase Encrypt and sign OAuth code Add AES GCM encryption/decryption to auth Add signing key to Auth Abstract key generation and Auth init to their own funcs. Cleanup and DRY unit tests. 
Use same code parameter in auth redirect Cleanup auth and add tests for enc/dec oauth code Add acceptance test for fix Apply suggestion from review Add missing test and apply feedback Fix unit test Simplify acceptance test Reject all unknown http methods Release 1.33.0 Allow DELETE HTTP method For some reason I forgot it the last time Disable IPv6 listeners for acceptance tests Replace time.Sleep with a cancelable timer inside the cache retriever Replace time.Sleep with a cancelable timer inside the cache retriever Add tests fix failing ci pipeline --- CHANGELOG | 27 +++ VERSION | 2 +- app.go | 25 ++- internal/auth/auth.go | 117 +++++++--- internal/auth/auth_code.go | 147 +++++++++++++ internal/auth/auth_code_test.go | 99 +++++++++ internal/auth/auth_test.go | 207 ++++++++++-------- internal/rejectmethods/middleware.go | 40 ++++ internal/rejectmethods/middleware_test.go | 43 ++++ internal/source/gitlab/cache/retriever.go | 118 +++++++++- .../source/gitlab/cache/retriever_test.go | 27 +++ metrics/metrics.go | 7 + test/acceptance/acceptance_test.go | 29 ++- test/acceptance/artifacts_test.go | 2 +- test/acceptance/auth_test.go | 122 ++++++++++- test/acceptance/serving_test.go | 3 +- test/acceptance/unknown_http_method_test.go | 23 ++ 17 files changed, 877 insertions(+), 161 deletions(-) create mode 100644 internal/auth/auth_code.go create mode 100644 internal/auth/auth_code_test.go create mode 100644 internal/rejectmethods/middleware.go create mode 100644 internal/rejectmethods/middleware_test.go create mode 100644 internal/source/gitlab/cache/retriever_test.go create mode 100644 test/acceptance/unknown_http_method_test.go diff --git a/CHANGELOG b/CHANGELOG index 9970bff88..e315ddcc0 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,12 @@ +v 1.34.0 + +- Allow DELETE HTTP method + +v 1.33.0 + +- Reject requests with unknown HTTP methods +- Encrypt OAuth code during auth flow + v 1.32.0 - Try to automatically use gitlab API as a source for domain information !402 @@ 
-10,6 +19,15 @@ v 1.31.0 - Add zip serving configuration flags !392 - Disable deprecated serverless serving and proxy !400 +v 1.30.2 + +- Allow DELETE HTTP method + +v 1.30.1 + +- Reject requests with unknown HTTP methods +- Encrypt OAuth code during auth flow + v 1.30.0 - Allow to refresh an existing cached archive when accessed !375 @@ -23,6 +41,15 @@ v 1.29.0 - Improve httprange timeouts !382 - Fix caching for errored ZIP VFS archives !384 +v 1.28.2 + +- Allow DELETE HTTP method + +v 1.28.1 + +- Reject requests with unknown HTTP methods +- Encrypt OAuth code during auth flow + v 1.28.0 - Implement basic redirects via _redirects text file !367 diff --git a/VERSION b/VERSION index 359c41089..2b17ffd50 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.32.0 +1.34.0 diff --git a/app.go b/app.go index ed06893e4..1352b630b 100644 --- a/app.go +++ b/app.go @@ -28,6 +28,7 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/logging" "gitlab.com/gitlab-org/gitlab-pages/internal/middleware" "gitlab.com/gitlab-org/gitlab-pages/internal/netutil" + "gitlab.com/gitlab-org/gitlab-pages/internal/rejectmethods" "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk/zip" "gitlab.com/gitlab-org/gitlab-pages/internal/source" @@ -337,6 +338,12 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { // Custom response headers handler = a.customHeadersMiddleware(handler) + // This MUST be the last handler! + // This handler blocks unknown HTTP methods, + // being the last means it will be evaluated first + // preventing any operation on bogus requests. 
+ handler = rejectmethods.NewMiddleware(handler) + return handler, nil } @@ -483,10 +490,7 @@ func runApp(config appConfig) { a.Artifact = artifact.New(config.ArtifactsServer, config.ArtifactsServerTimeout, config.Domain) } - if config.ClientID != "" { - a.Auth = auth.New(config.Domain, config.StoreSecret, config.ClientID, config.ClientSecret, - config.RedirectURI, config.GitLabServer) - } + a.setAuth(config) a.Handlers = handlers.New(a.Auth, a.Artifact) @@ -524,6 +528,19 @@ func runApp(config appConfig) { a.Run() } +func (a *theApp) setAuth(config appConfig) { + if config.ClientID == "" { + return + } + + var err error + a.Auth, err = auth.New(config.Domain, config.StoreSecret, config.ClientID, config.ClientSecret, + config.RedirectURI, config.GitLabServer) + if err != nil { + log.WithError(err).Fatal("could not initialize auth package") + } +} + // fatal will log a fatal error and exit. func fatal(err error, message string) { log.WithError(err).Fatal(message) diff --git a/internal/auth/auth.go b/internal/auth/auth.go index eaf3c25dd..252954a62 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -16,14 +16,14 @@ import ( "github.com/gorilla/securecookie" "github.com/gorilla/sessions" log "github.com/sirupsen/logrus" + "golang.org/x/crypto/hkdf" + "gitlab.com/gitlab-org/labkit/errortracking" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/httptransport" "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/source" - - "golang.org/x/crypto/hkdf" ) // nolint: gosec @@ -47,17 +47,23 @@ var ( errFailAuth = errors.New("Failed to authenticate request") errAuthNotConfigured = errors.New("Authentication is not configured") errQueryParameter = errors.New("Failed to parse domain query parameter") + + errGenerateKeys = errors.New("could not generate auth keys") ) // Auth handles authenticating users with GitLab API type Auth struct { - pagesDomain string - clientID 
string - clientSecret string - redirectURI string - gitLabServer string - apiClient *http.Client - store sessions.Store + pagesDomain string + clientID string + clientSecret string + redirectURI string + gitLabServer string + authSecret string + jwtSigningKey []byte + jwtExpiry time.Duration + apiClient *http.Client + store sessions.Store + now func() time.Time // allows to stub time.Now() easily in tests } type tokenResponse struct { @@ -111,7 +117,7 @@ func (a *Auth) checkSession(w http.ResponseWriter, r *http.Request) (*sessions.S return session, nil } -// TryAuthenticate tries to authenticate user and fetch access token if request is a callback to auth +// TryAuthenticate tries to authenticate user and fetch access token if request is a callback to /auth? func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, domains source.Source) bool { if a == nil { return false @@ -166,11 +172,18 @@ func (a *Auth) checkAuthenticationResponse(session *sessions.Session, w http.Res return } - // Fetch access token with authorization code - token, err := a.fetchAccessToken(r.URL.Query().Get("code")) + decryptedCode, err := a.DecryptCode(r.URL.Query().Get("code"), getRequestDomain(r)) + if err != nil { + logRequest(r).WithError(err).Error("failed to decrypt secure code") + errortracking.Capture(err, errortracking.WithRequest(r)) + httperrors.Serve500(w) + return + } - // Fetching token not OK + // Fetch access token with authorization code + token, err := a.fetchAccessToken(decryptedCode) if err != nil { + // Fetching token not OK logRequest(r).WithError(err).WithField( "redirect_uri", redirectURI, ).Error(errFetchAccessToken) @@ -216,8 +229,8 @@ func (a *Auth) domainAllowed(name string, domains source.Source) bool { } func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWriter, r *http.Request, domains source.Source) bool { - // If request is for authenticating via custom domain - if shouldProxyAuth(r) { + // handle auth callback e.g. 
https://gitlab.io/auth?domain&domain&state=state + if shouldProxyAuthToGitlab(r) { domain := r.URL.Query().Get("domain") state := r.URL.Query().Get("state") @@ -266,6 +279,7 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit } // If auth request callback should be proxied to custom domain + // redirect to originating domain set in the cookie as proxy_auth_domain if shouldProxyCallbackToCustomDomain(r, session) { // Get domain started auth process proxyDomain := session.Values["proxy_auth_domain"].(string) @@ -283,9 +297,30 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit return true } - // Redirect pages under custom domain - http.Redirect(w, r, proxyDomain+r.URL.Path+"?"+r.URL.RawQuery, 302) + query := r.URL.Query() + + // prevent https://tools.ietf.org/html/rfc6749#section-10.6 and + // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/262 by encrypting + // and signing the OAuth code + signedCode, err := a.EncryptAndSignCode(proxyDomain, query.Get("code")) + if err != nil { + logRequest(r).WithError(err).Error(errSaveSession) + errortracking.Capture(err, errortracking.WithRequest(r)) + + httperrors.Serve503(w) + return true + } + + // prevent forwarding access token, more context on the security issue + // https://gitlab.com/gitlab-org/gitlab/-/issues/285244#note_451266051 + query.Del("token") + + // replace code with signed code + query.Set("code", signedCode) + // Redirect pages to originating domain with code and state to finish + // authentication process + http.Redirect(w, r, proxyDomain+r.URL.Path+"?"+query.Encode(), 302) return true } @@ -306,7 +341,7 @@ func getRequestDomain(r *http.Request) string { return "http://" + r.Host } -func shouldProxyAuth(r *http.Request) bool { +func shouldProxyAuthToGitlab(r *http.Request) bool { return r.URL.Query().Get("domain") != "" && r.URL.Query().Get("state") != "" } @@ -376,6 +411,7 @@ func (a *Auth) checkSessionIsValid(w http.ResponseWriter, r 
*http.Request) *sess return nil } + // redirect to /auth?domain=%s&state=%s if a.checkTokenExists(session, w, r) { return nil } @@ -586,28 +622,37 @@ func logRequest(r *http.Request) *log.Entry { }) } -// generateKeyPair returns key pair for secure cookie: signing and encryption key -func generateKeyPair(storeSecret string) ([]byte, []byte) { - hash := sha256.New - hkdf := hkdf.New(hash, []byte(storeSecret), []byte{}, []byte("PAGES_SIGNING_AND_ENCRYPTION_KEY")) - var keys [][]byte - for i := 0; i < 2; i++ { +// generateKeys derives count hkdf keys from a secret, ensuring the key is +// the same for the same secret used across multiple instances +func generateKeys(secret string, count int) ([][]byte, error) { + keys := make([][]byte, count) + hkdfReader := hkdf.New(sha256.New, []byte(secret), []byte{}, []byte("PAGES_SIGNING_AND_ENCRYPTION_KEY")) + + for i := 0; i < count; i++ { key := make([]byte, 32) - if _, err := io.ReadFull(hkdf, key); err != nil { - log.WithError(err).Fatal("Can't generate key pair for secure cookies") + if _, err := io.ReadFull(hkdfReader, key); err != nil { + return nil, err } - keys = append(keys, key) + + keys[i] = key + } + + if len(keys) < count { + return nil, errGenerateKeys } - return keys[0], keys[1] -} -func createCookieStore(storeSecret string) sessions.Store { - return sessions.NewCookieStore(generateKeyPair(storeSecret)) + return keys, nil } // New when authentication supported this will be used to create authentication handler func New(pagesDomain string, storeSecret string, clientID string, clientSecret string, - redirectURI string, gitLabServer string) *Auth { + redirectURI string, gitLabServer string) (*Auth, error) { + // generate 3 keys, 2 for the cookie store and 1 for JWT signing + keys, err := generateKeys(storeSecret, 3) + if err != nil { + return nil, err + } + return &Auth{ pagesDomain: pagesDomain, clientID: clientID, @@ -618,6 +663,10 @@ func New(pagesDomain string, storeSecret string, clientID string, clientSecret s 
Timeout: 5 * time.Second, Transport: httptransport.InternalTransport, }, - store: createCookieStore(storeSecret), - } + store: sessions.NewCookieStore(keys[0], keys[1]), + authSecret: storeSecret, + jwtSigningKey: keys[2], + jwtExpiry: time.Minute, + now: time.Now, + }, nil } diff --git a/internal/auth/auth_code.go b/internal/auth/auth_code.go new file mode 100644 index 000000000..d2fea5a95 --- /dev/null +++ b/internal/auth/auth_code.go @@ -0,0 +1,147 @@ +package auth + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "io" + + "github.com/dgrijalva/jwt-go" + "github.com/gorilla/securecookie" + "golang.org/x/crypto/hkdf" +) + +var ( + errInvalidToken = errors.New("invalid token") + errEmptyDomainOrCode = errors.New("empty domain or code") + errInvalidNonce = errors.New("invalid nonce") + errInvalidCode = errors.New("invalid code") +) + +// EncryptAndSignCode encrypts the OAuth code deriving the key from the domain. +// It adds the code and domain as JWT token claims and signs it using signingKey derived from +// the Auth secret. 
+func (a *Auth) EncryptAndSignCode(domain, code string) (string, error) { + if domain == "" || code == "" { + return "", errEmptyDomainOrCode + } + + nonce := base64.URLEncoding.EncodeToString(securecookie.GenerateRandomKey(16)) + + aesGcm, err := a.newAesGcmCipher(domain, nonce) + if err != nil { + return "", err + } + + // encrypt code with a randomly generated nonce + encryptedCode := aesGcm.Seal(nil, []byte(nonce), []byte(code), nil) + + // generate JWT token claims with encrypted code + claims := jwt.MapClaims{ + // standard claims + "iss": "gitlab-pages", + "iat": a.now().Unix(), + "exp": a.now().Add(a.jwtExpiry).Unix(), + // custom claims + "domain": domain, // pass the domain so we can validate the signed domain matches the requested domain + "code": hex.EncodeToString(encryptedCode), + "nonce": nonce, + } + + return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(a.jwtSigningKey) +} + +// DecryptCode decodes the secureCode as a JWT token and validates its signature. +// It then decrypts the code from the token claims and returns it. 
+func (a *Auth) DecryptCode(jwt, domain string) (string, error) { + claims, err := a.parseJWTClaims(jwt) + if err != nil { + return "", err + } + + // get nonce and encryptedCode from the JWT claims + nonce, ok := claims["nonce"].(string) + if !ok { + return "", errInvalidNonce + } + + encryptedCode, ok := claims["code"].(string) + if !ok { + return "", errInvalidCode + } + + cipherText, err := hex.DecodeString(encryptedCode) + if err != nil { + return "", err + } + + aesGcm, err := a.newAesGcmCipher(domain, nonce) + if err != nil { + return "", err + } + + decryptedCode, err := aesGcm.Open(nil, []byte(nonce), cipherText, nil) + if err != nil { + return "", err + } + + return string(decryptedCode), nil +} + +func (a *Auth) codeKey(domain string) ([]byte, error) { + hkdfReader := hkdf.New(sha256.New, []byte(a.authSecret), []byte(domain), []byte("PAGES_AUTH_CODE_ENCRYPTION_KEY")) + + key := make([]byte, 32) + if _, err := io.ReadFull(hkdfReader, key); err != nil { + return nil, err + } + + return key, nil +} + +func (a *Auth) parseJWTClaims(secureCode string) (jwt.MapClaims, error) { + token, err := jwt.Parse(secureCode, a.getSigningKey) + if err != nil { + return nil, err + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + return nil, errInvalidToken + } + + return claims, nil +} + +func (a *Auth) getSigningKey(token *jwt.Token) (interface{}, error) { + // Don't forget to validate the alg is what you expect: + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + return a.jwtSigningKey, nil +} + +func (a *Auth) newAesGcmCipher(domain, nonce string) (cipher.AEAD, error) { + // get the same key for a domain + key, err := a.codeKey(domain) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + aesGcm, err := cipher.NewGCMWithNonceSize(block, len(nonce)) + if err != nil { + return nil, err 
+ } + + return aesGcm, nil +} diff --git a/internal/auth/auth_code_test.go b/internal/auth/auth_code_test.go new file mode 100644 index 000000000..d54fcc7ea --- /dev/null +++ b/internal/auth/auth_code_test.go @@ -0,0 +1,99 @@ +package auth + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestEncryptAndDecryptSignedCode(t *testing.T) { + auth := createTestAuth(t, "") + + tests := map[string]struct { + auth *Auth + encDomain string + code string + expectedEncErrMsg string + decDomain string + expectedDecErrMsg string + }{ + "happy_path": { + auth: auth, + encDomain: "domain", + decDomain: "domain", + code: "code", + }, + "empty_domain": { + auth: auth, + encDomain: "", + code: "code", + expectedEncErrMsg: "empty domain or code", + }, + "empty_code": { + auth: auth, + encDomain: "domain", + code: "", + expectedEncErrMsg: "empty domain or code", + }, + "different_dec_domain": { + auth: auth, + encDomain: "domain", + decDomain: "another", + code: "code", + expectedDecErrMsg: "cipher: message authentication failed", + }, + "expired_token": { + auth: func() *Auth { + newAuth := *auth + newAuth.jwtExpiry = time.Nanosecond + newAuth.now = func() time.Time { + return time.Time{} + } + + return &newAuth + }(), + encDomain: "domain", + code: "code", + decDomain: "domain", + expectedDecErrMsg: "Token is expired", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + encCode, err := test.auth.EncryptAndSignCode(test.encDomain, test.code) + if test.expectedEncErrMsg != "" { + require.EqualError(t, err, test.expectedEncErrMsg) + require.Empty(t, encCode) + return + } + + require.NoError(t, err) + require.NotEmpty(t, encCode) + + decCode, err := test.auth.DecryptCode(encCode, test.decDomain) + if test.expectedDecErrMsg != "" { + require.EqualError(t, err, test.expectedDecErrMsg) + require.Empty(t, decCode) + return + } + + require.NoError(t, err) + require.Equal(t, test.code, decCode) + }) + } +} + +func 
TestDecryptCodeWithInvalidJWT(t *testing.T) { + auth1 := createTestAuth(t, "") + auth2 := createTestAuth(t, "") + auth2.jwtSigningKey = []byte("another signing key") + + encCode, err := auth1.EncryptAndSignCode("domain", "code") + require.NoError(t, err) + + decCode, err := auth2.DecryptCode(encCode, "domain") + require.EqualError(t, err, "signature is invalid") + require.Empty(t, decCode) +} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 39a533b35..ce7d83207 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "github.com/gorilla/sessions" @@ -16,17 +17,19 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/source" ) -func createAuth(t *testing.T) *Auth { - return New("pages.gitlab-example.com", +func createTestAuth(t *testing.T, url string) *Auth { + t.Helper() + + a, err := New("pages.gitlab-example.com", "something-very-secret", "id", "secret", "http://pages.gitlab-example.com/auth", - "http://gitlab-example.com") -} + url) + + require.NoError(t, err) -func defaultCookieStore() sessions.Store { - return createCookieStore("something-very-secret") + return a } type domainMock struct { @@ -48,10 +51,13 @@ func (dm *domainMock) ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Req // Which leads to negative side effects: we can't test encryption, and cookie params // like max-age and secure are not being properly set // To avoid that we use fake request, and set only session cookie without copying context -func setSessionValues(r *http.Request, values map[interface{}]interface{}) { - tmpRequest, _ := http.NewRequest("GET", "/", nil) +func setSessionValues(t *testing.T, r *http.Request, store sessions.Store, values map[interface{}]interface{}) { + t.Helper() + + tmpRequest, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + result := httptest.NewRecorder() - store := defaultCookieStore() session, _ := 
store.Get(tmpRequest, "gitlab-pages") session.Values = values @@ -63,7 +69,7 @@ func setSessionValues(r *http.Request, values map[interface{}]interface{}) { } func TestTryAuthenticate(t *testing.T) { - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/something/else") @@ -75,11 +81,12 @@ func TestTryAuthenticate(t *testing.T) { } func TestTryAuthenticateWithError(t *testing.T) { - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?error=access_denied") require.NoError(t, err) + reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} @@ -88,8 +95,7 @@ func TestTryAuthenticateWithError(t *testing.T) { } func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=invalid") @@ -97,7 +103,9 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["state"] = "state" session.Save(r, result) @@ -105,7 +113,36 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { require.Equal(t, 401, result.Code) } +func TestTryAuthenticateRemoveTokenFromRedirect(t *testing.T) { + auth := createTestAuth(t, "") + + result := httptest.NewRecorder() + reqURL, err := url.Parse("/auth?code=1&state=state&token=secret") + require.NoError(t, err) + + require.Equal(t, reqURL.Query().Get("token"), "secret", "token is present before redirecting") + reqURL.Scheme = request.SchemeHTTPS + r := &http.Request{URL: reqURL} + + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + + session.Values["state"] = 
"state" + session.Values["proxy_auth_domain"] = "https://domain.com" + session.Save(r, result) + + require.Equal(t, true, auth.TryAuthenticate(result, r, source.NewMockSource())) + require.Equal(t, http.StatusFound, result.Code) + + redirect, err := url.Parse(result.Header().Get("Location")) + require.NoError(t, err) + + require.Empty(t, redirect.Query().Get("token"), "token is gone after redirecting") +} + func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { + t.Helper() + apiServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/oauth/token": @@ -125,14 +162,17 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { apiServer.Start() defer apiServer.Close() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) + + domain := apiServer.URL + if https { + domain = strings.Replace(apiServer.URL, "http://", "https://", -1) + } - r, err := http.NewRequest("GET", "/auth?code=1&state=state", nil) + code, err := auth.EncryptAndSignCode(domain, "1") + require.NoError(t, err) + + r, err := http.NewRequest("GET", "/auth?code="+code+"&state=state", nil) require.NoError(t, err) if https { r.URL.Scheme = request.SchemeHTTPS @@ -140,14 +180,16 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { r.URL.Scheme = request.SchemeHTTP } - setSessionValues(r, map[interface{}]interface{}{ + r.Host = strings.TrimPrefix(apiServer.URL, "http://") + + setSessionValues(t, r, auth.store, map[interface{}]interface{}{ "uri": "https://pages.gitlab-example.com/project/", "state": "state", }) result := httptest.NewRecorder() require.Equal(t, true, auth.TryAuthenticate(result, r, source.NewMockSource())) - require.Equal(t, 302, result.Code) + require.Equal(t, http.StatusFound, result.Code) require.Equal(t, 
"https://pages.gitlab-example.com/project/", result.Header().Get("Location")) require.Equal(t, 600, result.Result().Cookies()[0].MaxAge) require.Equal(t, https, result.Result().Cookies()[0].Secure) @@ -177,13 +219,7 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") @@ -191,7 +227,9 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) @@ -217,13 +255,7 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) w := httptest.NewRecorder() @@ -232,7 +264,9 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, w) @@ -265,22 +299,19 @@ func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - 
"http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" - session.Save(r, result) + err = session.Save(r, result) + require.NoError(t, err) contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) require.True(t, contentServed) @@ -303,13 +334,7 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") @@ -317,7 +342,9 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) @@ -343,19 +370,16 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) 
+ session.Values["access_token"] = "abc" session.Save(r, result) @@ -364,28 +388,31 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { require.Equal(t, 302, result.Code) } -func TestGenerateKeyPair(t *testing.T) { - signingSecret, encryptionSecret := generateKeyPair("something-very-secret") - require.NotEqual(t, fmt.Sprint(signingSecret), fmt.Sprint(encryptionSecret)) - require.Equal(t, len(signingSecret), 32) - require.Equal(t, len(encryptionSecret), 32) +func TestGenerateKeys(t *testing.T) { + keys, err := generateKeys("something-very-secret", 3) + require.NoError(t, err) + require.Len(t, keys, 3) + + require.NotEqual(t, fmt.Sprint(keys[0]), fmt.Sprint(keys[1])) + require.NotEqual(t, fmt.Sprint(keys[0]), fmt.Sprint(keys[2])) + require.NotEqual(t, fmt.Sprint(keys[1]), fmt.Sprint(keys[2])) + + require.Equal(t, len(keys[0]), 32) + require.Equal(t, len(keys[1]), 32) + require.Equal(t, len(keys[2]), 32) } func TestGetTokenIfExistsWhenTokenExists(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) @@ -395,20 +422,16 @@ func TestGetTokenIfExistsWhenTokenExists(t *testing.T) { } func TestGetTokenIfExistsWhenTokenDoesNotExist(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := 
url.Parse("http://pages.gitlab-example.com/test") require.NoError(t, err) r := &http.Request{URL: reqURL, Host: "pages.gitlab-example.com", RequestURI: "/test"} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Save(r, result) token, err := auth.GetTokenIfExists(result, r) @@ -417,12 +440,7 @@ func TestGetTokenIfExistsWhenTokenDoesNotExist(t *testing.T) { } func TestCheckResponseForInvalidTokenWhenInvalidToken(t *testing.T) { - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("http://pages.gitlab-example.com/test") @@ -437,12 +455,7 @@ func TestCheckResponseForInvalidTokenWhenInvalidToken(t *testing.T) { } func TestCheckResponseForInvalidTokenWhenNotInvalidToken(t *testing.T) { - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/something") diff --git a/internal/rejectmethods/middleware.go b/internal/rejectmethods/middleware.go new file mode 100644 index 000000000..1240094e6 --- /dev/null +++ b/internal/rejectmethods/middleware.go @@ -0,0 +1,40 @@ +package rejectmethods + +import ( + "net/http" + + "gitlab.com/gitlab-org/gitlab-pages/metrics" +) + +var acceptedMethods = map[string]bool{ + http.MethodGet: true, + http.MethodHead: true, + http.MethodPost: true, + http.MethodPut: true, + http.MethodPatch: true, +<<<<<<< HEAD +<<<<<<< HEAD +======= + http.MethodDelete: true, +>>>>>>> 0966171 (Replace time.Sleep with a cancelable timer inside the cache retriever) +======= + http.MethodDelete: true, +======= +>>>>>>> 263db26 (Rebase) +>>>>>>> bd862b1 (Rebase) + http.MethodConnect: true, + http.MethodOptions: true, + http.MethodTrace: true, +} 
+ +// NewMiddleware returns middleware which rejects all unknown http methods +func NewMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if acceptedMethods[r.Method] { + handler.ServeHTTP(w, r) + } else { + metrics.RejectedRequestsCount.Inc() + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + } + }) +} diff --git a/internal/rejectmethods/middleware_test.go b/internal/rejectmethods/middleware_test.go new file mode 100644 index 000000000..2921975ae --- /dev/null +++ b/internal/rejectmethods/middleware_test.go @@ -0,0 +1,43 @@ +package rejectmethods + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewMiddleware(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "OK\n") + }) + + middleware := NewMiddleware(handler) + + acceptedMethods := []string{"GET", "HEAD", "POST", "PUT", "PATCH", "CONNECT", "OPTIONS", "TRACE"} + for _, method := range acceptedMethods { + t.Run(method, func(t *testing.T) { + tmpRequest, _ := http.NewRequest(method, "/", nil) + recorder := httptest.NewRecorder() + + middleware.ServeHTTP(recorder, tmpRequest) + + result := recorder.Result() + + require.Equal(t, http.StatusOK, result.StatusCode) + }) + } + + t.Run("UNKNOWN", func(t *testing.T) { + tmpRequest, _ := http.NewRequest("UNKNOWN", "/", nil) + recorder := httptest.NewRecorder() + + middleware.ServeHTTP(recorder, tmpRequest) + + result := recorder.Result() + + require.Equal(t, http.StatusMethodNotAllowed, result.StatusCode) + }) +} diff --git a/internal/source/gitlab/cache/retriever.go b/internal/source/gitlab/cache/retriever.go index de37c231a..0cc61c6fa 100644 --- a/internal/source/gitlab/cache/retriever.go +++ b/internal/source/gitlab/cache/retriever.go @@ -3,8 +3,14 @@ package cache import ( "context" "errors" + "sync" "time" + "sync" +<<<<<<< 
HEAD +>>>>>>> da862b8 (Add tests) +======= +>>>>>>> f6038c3 (fix failing ci pipeline) log "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/gitlab-pages/internal/source/gitlab/api" @@ -13,15 +19,25 @@ import ( // Retriever is an utility type that performs an HTTP request with backoff in // case of errors type Retriever struct { + timer timer client api.Client retrievalTimeout time.Duration maxRetrievalInterval time.Duration maxRetrievalRetries int } +type timer struct { + mu *sync.Mutex + stopped bool + timer *time.Timer +} + // NewRetriever creates a Retriever with a client func NewRetriever(client api.Client, retrievalTimeout, maxRetrievalInterval time.Duration, maxRetrievalRetries int) *Retriever { return &Retriever{ + timer: timer{ + mu: &sync.Mutex{}, + }, client: client, retrievalTimeout: retrievalTimeout, maxRetrievalInterval: maxRetrievalInterval, @@ -52,11 +68,63 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha go func() { var lookup api.Lookup +<<<<<<< HEAD + Retry: +======= + Retry: +>>>>>>> 263db26 (Rebase) for i := 1; i <= r.maxRetrievalRetries; i++ { lookup = r.client.GetLookup(ctx, domain) if lookup.Error != nil { - time.Sleep(r.maxRetrievalInterval) +<<<<<<< HEAD +<<<<<<< HEAD +<<<<<<< HEAD + r.timer.start(r.maxRetrievalInterval) + select { + case <-r.timer.timer.C: + // retry to GetLookup + continue Retry + case <-ctx.Done(): + // when the retrieval context is done we stop the timer + // timer.Stop() + // log.WithError(ctx.Err()).Debug("domain retrieval backoff canceled by context") + // when the retrieval context is done we stop the timerFunc + r.timer.stop() + break Retry +<<<<<<< HEAD +======= +======= +>>>>>>> 0966171 (Replace time.Sleep with a cancelable timer inside the cache retriever) + timer := time.NewTimer(r.maxRetrievalInterval) +======= + r.timer.start(r.maxRetrievalInterval) +>>>>>>> da862b8 (Add tests) + select { + case <-r.timer.timer.C: + // retry to GetLookup + continue Retry + case 
<-ctx.Done(): + // when the retrieval context is done we stop the timer +<<<<<<< HEAD + timer.Stop() + log.WithError(ctx.Err()).Debug("domain retrieval backoff canceled by context") + break +<<<<<<< HEAD +>>>>>>> 7f5da7e (Replace time.Sleep with a cancelable timer inside the cache retriever) +======= +>>>>>>> 0966171 (Replace time.Sleep with a cancelable timer inside the cache retriever) +======= + // timer.Stop() + // log.WithError(ctx.Err()).Debug("domain retrieval backoff canceled by context") + // when the retrieval context is done we stop the timerFunc + r.timer.stop() + break Retry + +>>>>>>> da862b8 (Add tests) +======= +>>>>>>> f6038c3 (fix failing ci pipeline) + } } else { break } @@ -68,3 +136,51 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha return response } + +<<<<<<< HEAD +func (t *timer) start(d time.Duration) { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopped = false + t.timer = time.NewTimer(d) +} + +func (t *timer) stop() { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopped = t.timer.Stop() +} + +func (t *timer) hasStopped() bool { + t.mu.Lock() + defer t.mu.Unlock() + + return t.stopped +} +======= + + func (t *timer) start(d time.Duration) { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopped = false + t.timer = time.NewTimer(d) + } + + func (t *timer) stop() { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopped = t.timer.Stop() + } + + func (t *timer) hasStopped() bool { + t.mu.Lock() + defer t.mu.Unlock() + + return t.stopped + } + +>>>>>>> 263db26 (Rebase) diff --git a/internal/source/gitlab/cache/retriever_test.go b/internal/source/gitlab/cache/retriever_test.go new file mode 100644 index 000000000..774e9779e --- /dev/null +++ b/internal/source/gitlab/cache/retriever_test.go @@ -0,0 +1,27 @@ +package cache + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRetrieveTimerStopsWhenContextIsDone(t *testing.T) { + retrievalTimeout := time.Millisecond // quick 
timeout that cancels inner context + maxRetrievalInterval := time.Minute // long sleep inside resolveWithBackoff + + resolver := &client{ + domain: make(chan string), + lookups: make(chan uint64, 1), + failure: errors.New("500 error"), + } + + retriever := NewRetriever(resolver, retrievalTimeout, maxRetrievalInterval, 3) + require.False(t, retriever.timer.hasStopped(), "timer has not been stopped yet") + + lookup := retriever.Retrieve("my.gitlab.com") + require.Empty(t, lookup.Name) + require.Eventually(t, retriever.timer.hasStopped, time.Second, time.Millisecond, "timer must have been stopped") +} diff --git a/metrics/metrics.go b/metrics/metrics.go index db7cae9a8..045ff26e0 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -199,6 +199,13 @@ var ( Help: "The number of files per zip archive total count over time", }, ) + + RejectedRequestsCount = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "gitlab_pages_unknown_method_rejected_requests", + Help: "The number of requests with unknown HTTP method which were rejected", + }, + ) ) // MustRegister collectors with the Prometheus client diff --git a/test/acceptance/acceptance_test.go b/test/acceptance/acceptance_test.go index 9921076ea..ba6528c10 100644 --- a/test/acceptance/acceptance_test.go +++ b/test/acceptance/acceptance_test.go @@ -17,24 +17,31 @@ const ( var ( pagesBinary = flag.String("gitlab-pages-binary", "../../gitlab-pages", "Path to the gitlab-pages binary") + httpPort = "36000" + httpsPort = "37000" + httpProxyPort = "38000" + httpProxyV2Port = "39000" + // TODO: Use TCP port 0 everywhere to avoid conflicts. The binary could output // the actual port (and type of listener) for us to read in place of the // hardcoded values below. 
listeners = []ListenSpec{ - {"http", "127.0.0.1", "37000"}, - {"http", "::1", "37000"}, - {"https", "127.0.0.1", "37001"}, - {"https", "::1", "37001"}, - {"proxy", "127.0.0.1", "37002"}, - {"proxy", "::1", "37002"}, - {"https-proxyv2", "127.0.0.1", "37003"}, - {"https-proxyv2", "::1", "37003"}, + {"http", "127.0.0.1", httpPort}, + {"https", "127.0.0.1", httpsPort}, + {"proxy", "127.0.0.1", httpProxyPort}, + {"https-proxyv2", "127.0.0.1", httpProxyV2Port}, + // TODO: re-enable IPv6 listeners once https://gitlab.com/gitlab-com/gl-infra/infrastructure/-/issues/12258 is resolved + // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/528 + // {"http", "::1", httpPort}, + // {"https", "::1", httpsPort}, + // {"proxy", "::1", httpProxyPort}, + // {"https-proxyv2", "::1", httpProxyV2Port}, } httpListener = listeners[0] - httpsListener = listeners[2] - proxyListener = listeners[4] - httpsProxyv2Listener = listeners[6] + httpsListener = listeners[1] + proxyListener = listeners[2] + httpsProxyv2Listener = listeners[3] ) func TestMain(m *testing.M) { diff --git a/test/acceptance/artifacts_test.go b/test/acceptance/artifacts_test.go index 3440ef34f..57c7a02a9 100644 --- a/test/acceptance/artifacts_test.go +++ b/test/acceptance/artifacts_test.go @@ -245,7 +245,7 @@ func TestPrivateArtifactProxyRequest(t *testing.T) { ) defer teardown() - resp, err := GetRedirectPage(t, httpListener, tt.host, tt.path) + resp, err := GetRedirectPage(t, httpsListener, tt.host, tt.path) require.NoError(t, err) defer resp.Body.Close() diff --git a/test/acceptance/auth_test.go b/test/acceptance/auth_test.go index b2233591f..958c12b7e 100644 --- a/test/acceptance/auth_test.go +++ b/test/acceptance/auth_test.go @@ -88,7 +88,7 @@ func TestWhenLoginCallbackWithWrongStateShouldFail(t *testing.T) { require.Equal(t, http.StatusUnauthorized, authrsp.StatusCode) } -func TestWhenLoginCallbackWithCorrectStateWithoutEndpoint(t *testing.T) { +func TestWhenLoginCallbackWithUnencryptedCode(t *testing.T) { 
skipUnlessEnabled(t) teardown := RunPagesProcessWithAuth(t, *pagesBinary, listeners, "") defer teardown() @@ -110,8 +110,8 @@ func TestWhenLoginCallbackWithCorrectStateWithoutEndpoint(t *testing.T) { require.NoError(t, err) defer authrsp.Body.Close() - // Will cause 503 because token endpoint is not available - require.Equal(t, http.StatusServiceUnavailable, authrsp.StatusCode) + // Will cause 500 because the code is not encrypted + require.Equal(t, http.StatusInternalServerError, authrsp.StatusCode) } func handleAccessControlArtifactRequests(t *testing.T, w http.ResponseWriter, r *http.Request) bool { @@ -232,14 +232,26 @@ func TestAccessControlUnderCustomDomain(t *testing.T) { url, err = url.Parse(authrsp.Header.Get("Location")) require.NoError(t, err) +<<<<<<< HEAD // Will redirect to custom domain require.Equal(t, tt.domain, url.Host) - require.Equal(t, "1", url.Query().Get("code")) - require.Equal(t, state, url.Query().Get("state")) + code := url.Query().Get("code") + require.NotEqual(t, "1", code) - // Run auth callback in custom domain - authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code=1&state="+ + authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code="+code+"&state="+ state, cookie) +======= + // Will redirect to custom domain + require.Equal(t, "private.domain.com", url.Host) + // code must have changed since it's encrypted + code := url.Query().Get("code") + require.NotEqual(t, "1", code) + require.Equal(t, state, url.Query().Get("state")) + + // Run auth callback in custom domain + authrsp, err = GetRedirectPageWithCookie(t, httpListener, "private.domain.com", "/auth?code="+code+"&state="+ + state, cookie) +>>>>>>> cd780a1 (Encrypt and sign OAuth code) require.NoError(t, err) defer authrsp.Body.Close() @@ -336,11 +348,13 @@ func TestCustomErrorPageWithAuth(t *testing.T) { // Will redirect to custom domain require.Equal(t, tt.domain, url.Host) - require.Equal(t, "1", url.Query().Get("code")) + // 
code must have changed since it's encrypted + code := url.Query().Get("code") + require.NotEqual(t, "1", code) require.Equal(t, state, url.Query().Get("state")) // Run auth callback in custom domain - authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code=1&state="+ + authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code="+code+"&state="+ state, cookie) require.NoError(t, err) @@ -409,12 +423,14 @@ func TestAccessControlUnderCustomDomainWithHTTPSProxy(t *testing.T) { // Will redirect to custom domain require.Equal(t, "private.domain.com", url.Host) - require.Equal(t, "1", url.Query().Get("code")) + // code must have changed since it's encrypted + code := url.Query().Get("code") + require.NotEqual(t, "1", code) require.Equal(t, state, url.Query().Get("state")) // Run auth callback in custom domain authrsp, err = GetProxyRedirectPageWithCookie(t, proxyListener, "private.domain.com", - "/auth?code=1&state="+state, cookie, true) + "/auth?code="+code+"&state="+state, cookie, true) require.NoError(t, err) defer authrsp.Body.Close() @@ -641,3 +657,87 @@ func TestAccessControlWithSSLCertFile(t *testing.T) { func TestAccessControlWithSSLCertDir(t *testing.T) { testAccessControl(t, RunPagesProcessWithAuthServerWithSSLCertDir) } + +// This proves the fix for https://gitlab.com/gitlab-org/gitlab-pages/-/issues/262 +// Read the issue description if any changes to internal/auth/ break this test. +// Related to https://tools.ietf.org/html/rfc6749#section-10.6. 
+func TestHijackedCode(t *testing.T) { + skipUnlessEnabled(t, "not-inplace-chroot") + + testServer := makeGitLabPagesAccessStub(t) + testServer.Start() + defer testServer.Close() + + teardown := RunPagesProcessWithAuthServer(t, *pagesBinary, listeners, "", testServer.URL) + defer teardown() + + /****ATTACKER******/ + // get valid cookie for a different private project + targetDomain := "private.domain.com" + attackersDomain := "group.auth.gitlab-example.com" + attackerCookie, attackerState := getValidCookieAndState(t, targetDomain) + + /****TARGET******/ + // fool target to click on modified URL with attacker's domain for redirect with a valid state + hackedURL := fmt.Sprintf("/auth?domain=http://%s&state=%s", attackersDomain, "irrelevant") + maliciousResp, err := GetProxyRedirectPageWithCookie(t, proxyListener, "projects.gitlab-example.com", hackedURL, "", true) + require.NoError(t, err) + defer maliciousResp.Body.Close() + + pagesCookie := maliciousResp.Header.Get("Set-Cookie") + + /* + OAuth flow happens here... 
+ */ + maliciousRespURL, err := url.Parse(maliciousResp.Header.Get("Location")) + require.NoError(t, err) + maliciousState := maliciousRespURL.Query().Get("state") + + // Go to auth page with correct state and code "obtained" from GitLab + authrsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, + "projects.gitlab-example.com", "/auth?code=1&state="+maliciousState, + pagesCookie, true) + + require.NoError(t, err) + defer authrsp.Body.Close() + + /****ATTACKER******/ + // Target is redirected to attacker's domain and attacker receives the proper code + require.Equal(t, http.StatusFound, authrsp.StatusCode, "should redirect to attacker's domain") + authrspURL, err := url.Parse(authrsp.Header.Get("Location")) + require.NoError(t, err) + require.Contains(t, authrspURL.String(), attackersDomain) + + // attacker's got the code + hijackedCode := authrspURL.Query().Get("code") + require.NotEmpty(t, hijackedCode) + + // attacker tries to access private pages content + impersonatingRes, err := GetProxyRedirectPageWithCookie(t, proxyListener, targetDomain, + "/auth?code="+hijackedCode+"&state="+attackerState, attackerCookie, true) + require.NoError(t, err) + defer authrsp.Body.Close() + + require.Equal(t, impersonatingRes.StatusCode, http.StatusInternalServerError, "should fail to decode code") +} + +func getValidCookieAndState(t *testing.T, domain string) (string, string) { + t.Helper() + + // follow flow to get a valid cookie + // visit https:/// + rsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, domain, "/", "", true) + require.NoError(t, err) + defer rsp.Body.Close() + + cookie := rsp.Header.Get("Set-Cookie") + require.NotEmpty(t, cookie) + + redirectURL, err := url.Parse(rsp.Header.Get("Location")) + require.NoError(t, err) + + state := redirectURL.Query().Get("state") + require.NotEmpty(t, state) + + return cookie, state +} diff --git a/test/acceptance/serving_test.go b/test/acceptance/serving_test.go index 66b5fa477..becd6b8cd 100644 --- 
a/test/acceptance/serving_test.go +++ b/test/acceptance/serving_test.go @@ -515,7 +515,8 @@ func TestKnownHostInReverseProxySetupReturns200(t *testing.T) { var listeners = []ListenSpec{ {"proxy", "127.0.0.1", "37002"}, - {"proxy", "::1", "37002"}, + // TODO: re-enable https://gitlab.com/gitlab-org/gitlab-pages/-/issues/528 + // {"proxy", "::1", "37002"}, } teardown := RunPagesProcess(t, *pagesBinary, listeners, "") diff --git a/test/acceptance/unknown_http_method_test.go b/test/acceptance/unknown_http_method_test.go new file mode 100644 index 000000000..f6c5ffee5 --- /dev/null +++ b/test/acceptance/unknown_http_method_test.go @@ -0,0 +1,23 @@ +package acceptance_test + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUnknownHTTPMethod(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + req, err := http.NewRequest("UNKNOWN", listeners[0].URL(""), nil) + require.NoError(t, err) + req.Host = "" + + resp, err := DoPagesRequest(t, httpListener, req) + require.NoError(t, err) + + require.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) +} -- GitLab From a3b9df94ab1cc789947c72f4fd881d0521b41030 Mon Sep 17 00:00:00 2001 From: Dishon Date: Mon, 18 Jan 2021 12:56:24 +0000 Subject: [PATCH 09/17] Update middleware.go --- internal/rejectmethods/middleware.go | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/internal/rejectmethods/middleware.go b/internal/rejectmethods/middleware.go index 7d4aa51d2..e78a0ce59 100644 --- a/internal/rejectmethods/middleware.go +++ b/internal/rejectmethods/middleware.go @@ -12,22 +12,7 @@ var acceptedMethods = map[string]bool{ http.MethodPost: true, http.MethodPut: true, http.MethodPatch: true, -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD -======= http.MethodDelete: true, ->>>>>>> 0966171 (Replace time.Sleep with a cancelable timer inside the cache retriever) -======= - http.MethodDelete: true, -======= ->>>>>>> 
263db26 (Rebase) ->>>>>>> bd862b1 (Rebase) -======= - http.MethodDelete: true, -======= ->>>>>>> 263db26 (Rebase) ->>>>>>> bd862b1a80582368c705b7c8487c00e2775753d3 http.MethodConnect: true, http.MethodOptions: true, http.MethodTrace: true, -- GitLab From 247018d27a0addd0d224bf69249ec5c894daee49 Mon Sep 17 00:00:00 2001 From: Dishon Date: Mon, 18 Jan 2021 13:01:00 +0000 Subject: [PATCH 10/17] add fixes --- internal/rejectmethods/middleware.go | 15 ----- internal/source/gitlab/cache/retriever.go | 81 +---------------------- 2 files changed, 3 insertions(+), 93 deletions(-) diff --git a/internal/rejectmethods/middleware.go b/internal/rejectmethods/middleware.go index 7d4aa51d2..e78a0ce59 100644 --- a/internal/rejectmethods/middleware.go +++ b/internal/rejectmethods/middleware.go @@ -12,22 +12,7 @@ var acceptedMethods = map[string]bool{ http.MethodPost: true, http.MethodPut: true, http.MethodPatch: true, -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD -======= http.MethodDelete: true, ->>>>>>> 0966171 (Replace time.Sleep with a cancelable timer inside the cache retriever) -======= - http.MethodDelete: true, -======= ->>>>>>> 263db26 (Rebase) ->>>>>>> bd862b1 (Rebase) -======= - http.MethodDelete: true, -======= ->>>>>>> 263db26 (Rebase) ->>>>>>> bd862b1a80582368c705b7c8487c00e2775753d3 http.MethodConnect: true, http.MethodOptions: true, http.MethodTrace: true, diff --git a/internal/source/gitlab/cache/retriever.go b/internal/source/gitlab/cache/retriever.go index a5af90f91..43ef2e523 100644 --- a/internal/source/gitlab/cache/retriever.go +++ b/internal/source/gitlab/cache/retriever.go @@ -1,16 +1,11 @@ package cache import ( + "time" "context" "errors" "sync" - "time" - "sync" - -<<<<<<< HEAD ->>>>>>> da862b8 (Add tests) -======= ->>>>>>> f6038c3 (fix failing ci pipeline) + log "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/gitlab-pages/internal/source/gitlab/api" @@ -68,21 +63,11 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha go 
func() { var lookup api.Lookup -<<<<<<< HEAD - Retry: -======= Retry: ->>>>>>> 263db26 (Rebase) for i := 1; i <= r.maxRetrievalRetries; i++ { lookup = r.client.GetLookup(ctx, domain) if lookup.Error != nil { -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD -======= ->>>>>>> bd862b1a80582368c705b7c8487c00e2775753d3 r.timer.start(r.maxRetrievalInterval) select { case <-r.timer.timer.C: @@ -95,41 +80,6 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha // when the retrieval context is done we stop the timerFunc r.timer.stop() break Retry -<<<<<<< HEAD -<<<<<<< HEAD -======= -======= ->>>>>>> 0966171 (Replace time.Sleep with a cancelable timer inside the cache retriever) - timer := time.NewTimer(r.maxRetrievalInterval) -======= - r.timer.start(r.maxRetrievalInterval) ->>>>>>> da862b8 (Add tests) - select { - case <-r.timer.timer.C: - // retry to GetLookup - continue Retry - case <-ctx.Done(): - // when the retrieval context is done we stop the timer -<<<<<<< HEAD - timer.Stop() - log.WithError(ctx.Err()).Debug("domain retrieval backoff canceled by context") - break -<<<<<<< HEAD ->>>>>>> 7f5da7e (Replace time.Sleep with a cancelable timer inside the cache retriever) -======= ->>>>>>> 0966171 (Replace time.Sleep with a cancelable timer inside the cache retriever) -======= - // timer.Stop() - // log.WithError(ctx.Err()).Debug("domain retrieval backoff canceled by context") - // when the retrieval context is done we stop the timerFunc - r.timer.stop() - break Retry - ->>>>>>> da862b8 (Add tests) -======= ->>>>>>> f6038c3 (fix failing ci pipeline) -======= ->>>>>>> bd862b1a80582368c705b7c8487c00e2775753d3 } } else { break @@ -143,7 +93,6 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha return response } -<<<<<<< HEAD func (t *timer) start(d time.Duration) { t.mu.Lock() defer t.mu.Unlock() @@ -165,28 +114,4 @@ func (t *timer) hasStopped() bool { return t.stopped } -======= - - func (t *timer) start(d 
time.Duration) { - t.mu.Lock() - defer t.mu.Unlock() - - t.stopped = false - t.timer = time.NewTimer(d) - } - - func (t *timer) stop() { - t.mu.Lock() - defer t.mu.Unlock() - - t.stopped = t.timer.Stop() - } - - func (t *timer) hasStopped() bool { - t.mu.Lock() - defer t.mu.Unlock() - - return t.stopped - } - ->>>>>>> 263db26 (Rebase) + \ No newline at end of file -- GitLab From fb3f8727a1d4eee9f19a5719ea0e51b23004852e Mon Sep 17 00:00:00 2001 From: Dishon Date: Mon, 18 Jan 2021 13:05:30 +0000 Subject: [PATCH 11/17] Update auth_test.go --- test/acceptance/auth_test.go | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/test/acceptance/auth_test.go b/test/acceptance/auth_test.go index 958c12b7e..fa2d768d8 100644 --- a/test/acceptance/auth_test.go +++ b/test/acceptance/auth_test.go @@ -232,7 +232,6 @@ func TestAccessControlUnderCustomDomain(t *testing.T) { url, err = url.Parse(authrsp.Header.Get("Location")) require.NoError(t, err) -<<<<<<< HEAD // Will redirect to custom domain require.Equal(t, tt.domain, url.Host) code := url.Query().Get("code") @@ -240,18 +239,6 @@ func TestAccessControlUnderCustomDomain(t *testing.T) { authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code="+code+"&state="+ state, cookie) -======= - // Will redirect to custom domain - require.Equal(t, "private.domain.com", url.Host) - // code must have changed since it's encrypted - code := url.Query().Get("code") - require.NotEqual(t, "1", code) - require.Equal(t, state, url.Query().Get("state")) - - // Run auth callback in custom domain - authrsp, err = GetRedirectPageWithCookie(t, httpListener, "private.domain.com", "/auth?code="+code+"&state="+ - state, cookie) ->>>>>>> cd780a1 (Encrypt and sign OAuth code) require.NoError(t, err) defer authrsp.Body.Close() -- GitLab From 923e4b22b33b45be4dd4f8c5228e701d57d1509c Mon Sep 17 00:00:00 2001 From: Dishon Date: Mon, 18 Jan 2021 13:32:55 +0000 Subject: [PATCH 12/17] Replace time.Sleep with a cancelable timer 
inside the cache retriever --- CHANGELOG | 27 +++ VERSION | 2 +- app.go | 25 ++- internal/auth/auth.go | 117 +++++++--- internal/auth/auth_code.go | 147 +++++++++++++ internal/auth/auth_code_test.go | 99 +++++++++ internal/auth/auth_test.go | 207 ++++++++++-------- internal/rejectmethods/middleware.go | 31 +++ internal/rejectmethods/middleware_test.go | 43 ++++ internal/source/gitlab/cache/retriever.go | 53 ++++- .../source/gitlab/cache/retriever_test.go | 27 +++ metrics/metrics.go | 7 + test/acceptance/acceptance_test.go | 29 ++- test/acceptance/artifacts_test.go | 2 +- test/acceptance/auth_test.go | 109 ++++++++- test/acceptance/serving_test.go | 3 +- test/acceptance/unknown_http_method_test.go | 23 ++ 17 files changed, 788 insertions(+), 163 deletions(-) create mode 100644 internal/auth/auth_code.go create mode 100644 internal/auth/auth_code_test.go create mode 100644 internal/rejectmethods/middleware.go create mode 100644 internal/rejectmethods/middleware_test.go create mode 100644 internal/source/gitlab/cache/retriever_test.go create mode 100644 test/acceptance/unknown_http_method_test.go diff --git a/CHANGELOG b/CHANGELOG index 9970bff88..e315ddcc0 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,12 @@ +v 1.34.0 + +- Allow DELETE HTTP method + +v 1.33.0 + +- Reject requests with unknown HTTP methods +- Encrypt OAuth code during auth flow + v 1.32.0 - Try to automatically use gitlab API as a source for domain information !402 @@ -10,6 +19,15 @@ v 1.31.0 - Add zip serving configuration flags !392 - Disable deprecated serverless serving and proxy !400 +v 1.30.2 + +- Allow DELETE HTTP method + +v 1.30.1 + +- Reject requests with unknown HTTP methods +- Encrypt OAuth code during auth flow + v 1.30.0 - Allow to refresh an existing cached archive when accessed !375 @@ -23,6 +41,15 @@ v 1.29.0 - Improve httprange timeouts !382 - Fix caching for errored ZIP VFS archives !384 +v 1.28.2 + +- Allow DELETE HTTP method + +v 1.28.1 + +- Reject requests with unknown HTTP 
methods +- Encrypt OAuth code during auth flow + v 1.28.0 - Implement basic redirects via _redirects text file !367 diff --git a/VERSION b/VERSION index 359c41089..2b17ffd50 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.32.0 +1.34.0 diff --git a/app.go b/app.go index ed06893e4..1352b630b 100644 --- a/app.go +++ b/app.go @@ -28,6 +28,7 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/logging" "gitlab.com/gitlab-org/gitlab-pages/internal/middleware" "gitlab.com/gitlab-org/gitlab-pages/internal/netutil" + "gitlab.com/gitlab-org/gitlab-pages/internal/rejectmethods" "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk/zip" "gitlab.com/gitlab-org/gitlab-pages/internal/source" @@ -337,6 +338,12 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { // Custom response headers handler = a.customHeadersMiddleware(handler) + // This MUST be the last handler! + // This handler blocks unknown HTTP methods, + // being the last means it will be evaluated first + // preventing any operation on bogus requests. + handler = rejectmethods.NewMiddleware(handler) + return handler, nil } @@ -483,10 +490,7 @@ func runApp(config appConfig) { a.Artifact = artifact.New(config.ArtifactsServer, config.ArtifactsServerTimeout, config.Domain) } - if config.ClientID != "" { - a.Auth = auth.New(config.Domain, config.StoreSecret, config.ClientID, config.ClientSecret, - config.RedirectURI, config.GitLabServer) - } + a.setAuth(config) a.Handlers = handlers.New(a.Auth, a.Artifact) @@ -524,6 +528,19 @@ func runApp(config appConfig) { a.Run() } +func (a *theApp) setAuth(config appConfig) { + if config.ClientID == "" { + return + } + + var err error + a.Auth, err = auth.New(config.Domain, config.StoreSecret, config.ClientID, config.ClientSecret, + config.RedirectURI, config.GitLabServer) + if err != nil { + log.WithError(err).Fatal("could not initialize auth package") + } +} + // fatal will log a fatal error and exit. 
func fatal(err error, message string) { log.WithError(err).Fatal(message) diff --git a/internal/auth/auth.go b/internal/auth/auth.go index eaf3c25dd..252954a62 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -16,14 +16,14 @@ import ( "github.com/gorilla/securecookie" "github.com/gorilla/sessions" log "github.com/sirupsen/logrus" + "golang.org/x/crypto/hkdf" + "gitlab.com/gitlab-org/labkit/errortracking" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/httptransport" "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/source" - - "golang.org/x/crypto/hkdf" ) // nolint: gosec @@ -47,17 +47,23 @@ var ( errFailAuth = errors.New("Failed to authenticate request") errAuthNotConfigured = errors.New("Authentication is not configured") errQueryParameter = errors.New("Failed to parse domain query parameter") + + errGenerateKeys = errors.New("could not generate auth keys") ) // Auth handles authenticating users with GitLab API type Auth struct { - pagesDomain string - clientID string - clientSecret string - redirectURI string - gitLabServer string - apiClient *http.Client - store sessions.Store + pagesDomain string + clientID string + clientSecret string + redirectURI string + gitLabServer string + authSecret string + jwtSigningKey []byte + jwtExpiry time.Duration + apiClient *http.Client + store sessions.Store + now func() time.Time // allows to stub time.Now() easily in tests } type tokenResponse struct { @@ -111,7 +117,7 @@ func (a *Auth) checkSession(w http.ResponseWriter, r *http.Request) (*sessions.S return session, nil } -// TryAuthenticate tries to authenticate user and fetch access token if request is a callback to auth +// TryAuthenticate tries to authenticate user and fetch access token if request is a callback to /auth? 
func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, domains source.Source) bool { if a == nil { return false @@ -166,11 +172,18 @@ func (a *Auth) checkAuthenticationResponse(session *sessions.Session, w http.Res return } - // Fetch access token with authorization code - token, err := a.fetchAccessToken(r.URL.Query().Get("code")) + decryptedCode, err := a.DecryptCode(r.URL.Query().Get("code"), getRequestDomain(r)) + if err != nil { + logRequest(r).WithError(err).Error("failed to decrypt secure code") + errortracking.Capture(err, errortracking.WithRequest(r)) + httperrors.Serve500(w) + return + } - // Fetching token not OK + // Fetch access token with authorization code + token, err := a.fetchAccessToken(decryptedCode) if err != nil { + // Fetching token not OK logRequest(r).WithError(err).WithField( "redirect_uri", redirectURI, ).Error(errFetchAccessToken) @@ -216,8 +229,8 @@ func (a *Auth) domainAllowed(name string, domains source.Source) bool { } func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWriter, r *http.Request, domains source.Source) bool { - // If request is for authenticating via custom domain - if shouldProxyAuth(r) { + // handle auth callback e.g. 
https://gitlab.io/auth?domain&domain&state=state + if shouldProxyAuthToGitlab(r) { domain := r.URL.Query().Get("domain") state := r.URL.Query().Get("state") @@ -266,6 +279,7 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit } // If auth request callback should be proxied to custom domain + // redirect to originating domain set in the cookie as proxy_auth_domain if shouldProxyCallbackToCustomDomain(r, session) { // Get domain started auth process proxyDomain := session.Values["proxy_auth_domain"].(string) @@ -283,9 +297,30 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit return true } - // Redirect pages under custom domain - http.Redirect(w, r, proxyDomain+r.URL.Path+"?"+r.URL.RawQuery, 302) + query := r.URL.Query() + + // prevent https://tools.ietf.org/html/rfc6749#section-10.6 and + // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/262 by encrypting + // and signing the OAuth code + signedCode, err := a.EncryptAndSignCode(proxyDomain, query.Get("code")) + if err != nil { + logRequest(r).WithError(err).Error(errSaveSession) + errortracking.Capture(err, errortracking.WithRequest(r)) + + httperrors.Serve503(w) + return true + } + + // prevent forwarding access token, more context on the security issue + // https://gitlab.com/gitlab-org/gitlab/-/issues/285244#note_451266051 + query.Del("token") + + // replace code with signed code + query.Set("code", signedCode) + // Redirect pages to originating domain with code and state to finish + // authentication process + http.Redirect(w, r, proxyDomain+r.URL.Path+"?"+query.Encode(), 302) return true } @@ -306,7 +341,7 @@ func getRequestDomain(r *http.Request) string { return "http://" + r.Host } -func shouldProxyAuth(r *http.Request) bool { +func shouldProxyAuthToGitlab(r *http.Request) bool { return r.URL.Query().Get("domain") != "" && r.URL.Query().Get("state") != "" } @@ -376,6 +411,7 @@ func (a *Auth) checkSessionIsValid(w http.ResponseWriter, r 
*http.Request) *sess return nil } + // redirect to /auth?domain=%s&state=%s if a.checkTokenExists(session, w, r) { return nil } @@ -586,28 +622,37 @@ func logRequest(r *http.Request) *log.Entry { }) } -// generateKeyPair returns key pair for secure cookie: signing and encryption key -func generateKeyPair(storeSecret string) ([]byte, []byte) { - hash := sha256.New - hkdf := hkdf.New(hash, []byte(storeSecret), []byte{}, []byte("PAGES_SIGNING_AND_ENCRYPTION_KEY")) - var keys [][]byte - for i := 0; i < 2; i++ { +// generateKeys derives count hkdf keys from a secret, ensuring the key is +// the same for the same secret used across multiple instances +func generateKeys(secret string, count int) ([][]byte, error) { + keys := make([][]byte, count) + hkdfReader := hkdf.New(sha256.New, []byte(secret), []byte{}, []byte("PAGES_SIGNING_AND_ENCRYPTION_KEY")) + + for i := 0; i < count; i++ { key := make([]byte, 32) - if _, err := io.ReadFull(hkdf, key); err != nil { - log.WithError(err).Fatal("Can't generate key pair for secure cookies") + if _, err := io.ReadFull(hkdfReader, key); err != nil { + return nil, err } - keys = append(keys, key) + + keys[i] = key + } + + if len(keys) < count { + return nil, errGenerateKeys } - return keys[0], keys[1] -} -func createCookieStore(storeSecret string) sessions.Store { - return sessions.NewCookieStore(generateKeyPair(storeSecret)) + return keys, nil } // New when authentication supported this will be used to create authentication handler func New(pagesDomain string, storeSecret string, clientID string, clientSecret string, - redirectURI string, gitLabServer string) *Auth { + redirectURI string, gitLabServer string) (*Auth, error) { + // generate 3 keys, 2 for the cookie store and 1 for JWT signing + keys, err := generateKeys(storeSecret, 3) + if err != nil { + return nil, err + } + return &Auth{ pagesDomain: pagesDomain, clientID: clientID, @@ -618,6 +663,10 @@ func New(pagesDomain string, storeSecret string, clientID string, clientSecret s 
Timeout: 5 * time.Second, Transport: httptransport.InternalTransport, }, - store: createCookieStore(storeSecret), - } + store: sessions.NewCookieStore(keys[0], keys[1]), + authSecret: storeSecret, + jwtSigningKey: keys[2], + jwtExpiry: time.Minute, + now: time.Now, + }, nil } diff --git a/internal/auth/auth_code.go b/internal/auth/auth_code.go new file mode 100644 index 000000000..d2fea5a95 --- /dev/null +++ b/internal/auth/auth_code.go @@ -0,0 +1,147 @@ +package auth + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "io" + + "github.com/dgrijalva/jwt-go" + "github.com/gorilla/securecookie" + "golang.org/x/crypto/hkdf" +) + +var ( + errInvalidToken = errors.New("invalid token") + errEmptyDomainOrCode = errors.New("empty domain or code") + errInvalidNonce = errors.New("invalid nonce") + errInvalidCode = errors.New("invalid code") +) + +// EncryptAndSignCode encrypts the OAuth code deriving the key from the domain. +// It adds the code and domain as JWT token claims and signs it using signingKey derived from +// the Auth secret. 
+func (a *Auth) EncryptAndSignCode(domain, code string) (string, error) { + if domain == "" || code == "" { + return "", errEmptyDomainOrCode + } + + nonce := base64.URLEncoding.EncodeToString(securecookie.GenerateRandomKey(16)) + + aesGcm, err := a.newAesGcmCipher(domain, nonce) + if err != nil { + return "", err + } + + // encrypt code with a randomly generated nonce + encryptedCode := aesGcm.Seal(nil, []byte(nonce), []byte(code), nil) + + // generate JWT token claims with encrypted code + claims := jwt.MapClaims{ + // standard claims + "iss": "gitlab-pages", + "iat": a.now().Unix(), + "exp": a.now().Add(a.jwtExpiry).Unix(), + // custom claims + "domain": domain, // pass the domain so we can validate the signed domain matches the requested domain + "code": hex.EncodeToString(encryptedCode), + "nonce": nonce, + } + + return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(a.jwtSigningKey) +} + +// DecryptCode decodes the secureCode as a JWT token and validates its signature. +// It then decrypts the code from the token claims and returns it. 
+func (a *Auth) DecryptCode(jwt, domain string) (string, error) { + claims, err := a.parseJWTClaims(jwt) + if err != nil { + return "", err + } + + // get nonce and encryptedCode from the JWT claims + nonce, ok := claims["nonce"].(string) + if !ok { + return "", errInvalidNonce + } + + encryptedCode, ok := claims["code"].(string) + if !ok { + return "", errInvalidCode + } + + cipherText, err := hex.DecodeString(encryptedCode) + if err != nil { + return "", err + } + + aesGcm, err := a.newAesGcmCipher(domain, nonce) + if err != nil { + return "", err + } + + decryptedCode, err := aesGcm.Open(nil, []byte(nonce), cipherText, nil) + if err != nil { + return "", err + } + + return string(decryptedCode), nil +} + +func (a *Auth) codeKey(domain string) ([]byte, error) { + hkdfReader := hkdf.New(sha256.New, []byte(a.authSecret), []byte(domain), []byte("PAGES_AUTH_CODE_ENCRYPTION_KEY")) + + key := make([]byte, 32) + if _, err := io.ReadFull(hkdfReader, key); err != nil { + return nil, err + } + + return key, nil +} + +func (a *Auth) parseJWTClaims(secureCode string) (jwt.MapClaims, error) { + token, err := jwt.Parse(secureCode, a.getSigningKey) + if err != nil { + return nil, err + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + return nil, errInvalidToken + } + + return claims, nil +} + +func (a *Auth) getSigningKey(token *jwt.Token) (interface{}, error) { + // Don't forget to validate the alg is what you expect: + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + return a.jwtSigningKey, nil +} + +func (a *Auth) newAesGcmCipher(domain, nonce string) (cipher.AEAD, error) { + // get the same key for a domain + key, err := a.codeKey(domain) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + aesGcm, err := cipher.NewGCMWithNonceSize(block, len(nonce)) + if err != nil { + return nil, err 
+ } + + return aesGcm, nil +} diff --git a/internal/auth/auth_code_test.go b/internal/auth/auth_code_test.go new file mode 100644 index 000000000..d54fcc7ea --- /dev/null +++ b/internal/auth/auth_code_test.go @@ -0,0 +1,99 @@ +package auth + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestEncryptAndDecryptSignedCode(t *testing.T) { + auth := createTestAuth(t, "") + + tests := map[string]struct { + auth *Auth + encDomain string + code string + expectedEncErrMsg string + decDomain string + expectedDecErrMsg string + }{ + "happy_path": { + auth: auth, + encDomain: "domain", + decDomain: "domain", + code: "code", + }, + "empty_domain": { + auth: auth, + encDomain: "", + code: "code", + expectedEncErrMsg: "empty domain or code", + }, + "empty_code": { + auth: auth, + encDomain: "domain", + code: "", + expectedEncErrMsg: "empty domain or code", + }, + "different_dec_domain": { + auth: auth, + encDomain: "domain", + decDomain: "another", + code: "code", + expectedDecErrMsg: "cipher: message authentication failed", + }, + "expired_token": { + auth: func() *Auth { + newAuth := *auth + newAuth.jwtExpiry = time.Nanosecond + newAuth.now = func() time.Time { + return time.Time{} + } + + return &newAuth + }(), + encDomain: "domain", + code: "code", + decDomain: "domain", + expectedDecErrMsg: "Token is expired", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + encCode, err := test.auth.EncryptAndSignCode(test.encDomain, test.code) + if test.expectedEncErrMsg != "" { + require.EqualError(t, err, test.expectedEncErrMsg) + require.Empty(t, encCode) + return + } + + require.NoError(t, err) + require.NotEmpty(t, encCode) + + decCode, err := test.auth.DecryptCode(encCode, test.decDomain) + if test.expectedDecErrMsg != "" { + require.EqualError(t, err, test.expectedDecErrMsg) + require.Empty(t, decCode) + return + } + + require.NoError(t, err) + require.Equal(t, test.code, decCode) + }) + } +} + +func 
TestDecryptCodeWithInvalidJWT(t *testing.T) { + auth1 := createTestAuth(t, "") + auth2 := createTestAuth(t, "") + auth2.jwtSigningKey = []byte("another signing key") + + encCode, err := auth1.EncryptAndSignCode("domain", "code") + require.NoError(t, err) + + decCode, err := auth2.DecryptCode(encCode, "domain") + require.EqualError(t, err, "signature is invalid") + require.Empty(t, decCode) +} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 39a533b35..ce7d83207 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "github.com/gorilla/sessions" @@ -16,17 +17,19 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/source" ) -func createAuth(t *testing.T) *Auth { - return New("pages.gitlab-example.com", +func createTestAuth(t *testing.T, url string) *Auth { + t.Helper() + + a, err := New("pages.gitlab-example.com", "something-very-secret", "id", "secret", "http://pages.gitlab-example.com/auth", - "http://gitlab-example.com") -} + url) + + require.NoError(t, err) -func defaultCookieStore() sessions.Store { - return createCookieStore("something-very-secret") + return a } type domainMock struct { @@ -48,10 +51,13 @@ func (dm *domainMock) ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Req // Which leads to negative side effects: we can't test encryption, and cookie params // like max-age and secure are not being properly set // To avoid that we use fake request, and set only session cookie without copying context -func setSessionValues(r *http.Request, values map[interface{}]interface{}) { - tmpRequest, _ := http.NewRequest("GET", "/", nil) +func setSessionValues(t *testing.T, r *http.Request, store sessions.Store, values map[interface{}]interface{}) { + t.Helper() + + tmpRequest, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + result := httptest.NewRecorder() - store := defaultCookieStore() session, _ := 
store.Get(tmpRequest, "gitlab-pages") session.Values = values @@ -63,7 +69,7 @@ func setSessionValues(r *http.Request, values map[interface{}]interface{}) { } func TestTryAuthenticate(t *testing.T) { - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/something/else") @@ -75,11 +81,12 @@ func TestTryAuthenticate(t *testing.T) { } func TestTryAuthenticateWithError(t *testing.T) { - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?error=access_denied") require.NoError(t, err) + reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} @@ -88,8 +95,7 @@ func TestTryAuthenticateWithError(t *testing.T) { } func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=invalid") @@ -97,7 +103,9 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["state"] = "state" session.Save(r, result) @@ -105,7 +113,36 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { require.Equal(t, 401, result.Code) } +func TestTryAuthenticateRemoveTokenFromRedirect(t *testing.T) { + auth := createTestAuth(t, "") + + result := httptest.NewRecorder() + reqURL, err := url.Parse("/auth?code=1&state=state&token=secret") + require.NoError(t, err) + + require.Equal(t, reqURL.Query().Get("token"), "secret", "token is present before redirecting") + reqURL.Scheme = request.SchemeHTTPS + r := &http.Request{URL: reqURL} + + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + + session.Values["state"] = 
"state" + session.Values["proxy_auth_domain"] = "https://domain.com" + session.Save(r, result) + + require.Equal(t, true, auth.TryAuthenticate(result, r, source.NewMockSource())) + require.Equal(t, http.StatusFound, result.Code) + + redirect, err := url.Parse(result.Header().Get("Location")) + require.NoError(t, err) + + require.Empty(t, redirect.Query().Get("token"), "token is gone after redirecting") +} + func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { + t.Helper() + apiServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/oauth/token": @@ -125,14 +162,17 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { apiServer.Start() defer apiServer.Close() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) + + domain := apiServer.URL + if https { + domain = strings.Replace(apiServer.URL, "http://", "https://", -1) + } - r, err := http.NewRequest("GET", "/auth?code=1&state=state", nil) + code, err := auth.EncryptAndSignCode(domain, "1") + require.NoError(t, err) + + r, err := http.NewRequest("GET", "/auth?code="+code+"&state=state", nil) require.NoError(t, err) if https { r.URL.Scheme = request.SchemeHTTPS @@ -140,14 +180,16 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { r.URL.Scheme = request.SchemeHTTP } - setSessionValues(r, map[interface{}]interface{}{ + r.Host = strings.TrimPrefix(apiServer.URL, "http://") + + setSessionValues(t, r, auth.store, map[interface{}]interface{}{ "uri": "https://pages.gitlab-example.com/project/", "state": "state", }) result := httptest.NewRecorder() require.Equal(t, true, auth.TryAuthenticate(result, r, source.NewMockSource())) - require.Equal(t, 302, result.Code) + require.Equal(t, http.StatusFound, result.Code) require.Equal(t, 
"https://pages.gitlab-example.com/project/", result.Header().Get("Location")) require.Equal(t, 600, result.Result().Cookies()[0].MaxAge) require.Equal(t, https, result.Result().Cookies()[0].Secure) @@ -177,13 +219,7 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") @@ -191,7 +227,9 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) @@ -217,13 +255,7 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) w := httptest.NewRecorder() @@ -232,7 +264,9 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, w) @@ -265,22 +299,19 @@ func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - 
"http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" - session.Save(r, result) + err = session.Save(r, result) + require.NoError(t, err) contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) require.True(t, contentServed) @@ -303,13 +334,7 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") @@ -317,7 +342,9 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) @@ -343,19 +370,16 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) 
+ session.Values["access_token"] = "abc" session.Save(r, result) @@ -364,28 +388,31 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { require.Equal(t, 302, result.Code) } -func TestGenerateKeyPair(t *testing.T) { - signingSecret, encryptionSecret := generateKeyPair("something-very-secret") - require.NotEqual(t, fmt.Sprint(signingSecret), fmt.Sprint(encryptionSecret)) - require.Equal(t, len(signingSecret), 32) - require.Equal(t, len(encryptionSecret), 32) +func TestGenerateKeys(t *testing.T) { + keys, err := generateKeys("something-very-secret", 3) + require.NoError(t, err) + require.Len(t, keys, 3) + + require.NotEqual(t, fmt.Sprint(keys[0]), fmt.Sprint(keys[1])) + require.NotEqual(t, fmt.Sprint(keys[0]), fmt.Sprint(keys[2])) + require.NotEqual(t, fmt.Sprint(keys[1]), fmt.Sprint(keys[2])) + + require.Equal(t, len(keys[0]), 32) + require.Equal(t, len(keys[1]), 32) + require.Equal(t, len(keys[2]), 32) } func TestGetTokenIfExistsWhenTokenExists(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) @@ -395,20 +422,16 @@ func TestGetTokenIfExistsWhenTokenExists(t *testing.T) { } func TestGetTokenIfExistsWhenTokenDoesNotExist(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := 
url.Parse("http://pages.gitlab-example.com/test") require.NoError(t, err) r := &http.Request{URL: reqURL, Host: "pages.gitlab-example.com", RequestURI: "/test"} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Save(r, result) token, err := auth.GetTokenIfExists(result, r) @@ -417,12 +440,7 @@ func TestGetTokenIfExistsWhenTokenDoesNotExist(t *testing.T) { } func TestCheckResponseForInvalidTokenWhenInvalidToken(t *testing.T) { - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("http://pages.gitlab-example.com/test") @@ -437,12 +455,7 @@ func TestCheckResponseForInvalidTokenWhenInvalidToken(t *testing.T) { } func TestCheckResponseForInvalidTokenWhenNotInvalidToken(t *testing.T) { - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/something") diff --git a/internal/rejectmethods/middleware.go b/internal/rejectmethods/middleware.go new file mode 100644 index 000000000..e78a0ce59 --- /dev/null +++ b/internal/rejectmethods/middleware.go @@ -0,0 +1,31 @@ +package rejectmethods + +import ( + "net/http" + + "gitlab.com/gitlab-org/gitlab-pages/metrics" +) + +var acceptedMethods = map[string]bool{ + http.MethodGet: true, + http.MethodHead: true, + http.MethodPost: true, + http.MethodPut: true, + http.MethodPatch: true, + http.MethodDelete: true, + http.MethodConnect: true, + http.MethodOptions: true, + http.MethodTrace: true, +} + +// NewMiddleware returns middleware which rejects all unknown http methods +func NewMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if 
acceptedMethods[r.Method] { + handler.ServeHTTP(w, r) + } else { + metrics.RejectedRequestsCount.Inc() + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + } + }) +} diff --git a/internal/rejectmethods/middleware_test.go b/internal/rejectmethods/middleware_test.go new file mode 100644 index 000000000..2921975ae --- /dev/null +++ b/internal/rejectmethods/middleware_test.go @@ -0,0 +1,43 @@ +package rejectmethods + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewMiddleware(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "OK\n") + }) + + middleware := NewMiddleware(handler) + + acceptedMethods := []string{"GET", "HEAD", "POST", "PUT", "PATCH", "CONNECT", "OPTIONS", "TRACE"} + for _, method := range acceptedMethods { + t.Run(method, func(t *testing.T) { + tmpRequest, _ := http.NewRequest(method, "/", nil) + recorder := httptest.NewRecorder() + + middleware.ServeHTTP(recorder, tmpRequest) + + result := recorder.Result() + + require.Equal(t, http.StatusOK, result.StatusCode) + }) + } + + t.Run("UNKNOWN", func(t *testing.T) { + tmpRequest, _ := http.NewRequest("UNKNOWN", "/", nil) + recorder := httptest.NewRecorder() + + middleware.ServeHTTP(recorder, tmpRequest) + + result := recorder.Result() + + require.Equal(t, http.StatusMethodNotAllowed, result.StatusCode) + }) +} diff --git a/internal/source/gitlab/cache/retriever.go b/internal/source/gitlab/cache/retriever.go index de37c231a..43ef2e523 100644 --- a/internal/source/gitlab/cache/retriever.go +++ b/internal/source/gitlab/cache/retriever.go @@ -1,10 +1,11 @@ package cache import ( + "time" "context" "errors" - "time" - + "sync" + log "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/gitlab-pages/internal/source/gitlab/api" @@ -13,15 +14,25 @@ import ( // Retriever is an utility type that performs an HTTP request with backoff in // case of 
errors type Retriever struct { + timer timer client api.Client retrievalTimeout time.Duration maxRetrievalInterval time.Duration maxRetrievalRetries int } +type timer struct { + mu *sync.Mutex + stopped bool + timer *time.Timer +} + // NewRetriever creates a Retriever with a client func NewRetriever(client api.Client, retrievalTimeout, maxRetrievalInterval time.Duration, maxRetrievalRetries int) *Retriever { return &Retriever{ + timer: timer{ + mu: &sync.Mutex{}, + }, client: client, retrievalTimeout: retrievalTimeout, maxRetrievalInterval: maxRetrievalInterval, @@ -52,11 +63,24 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha go func() { var lookup api.Lookup + Retry: for i := 1; i <= r.maxRetrievalRetries; i++ { lookup = r.client.GetLookup(ctx, domain) if lookup.Error != nil { - time.Sleep(r.maxRetrievalInterval) + r.timer.start(r.maxRetrievalInterval) + select { + case <-r.timer.timer.C: + // retry to GetLookup + continue Retry + case <-ctx.Done(): + // when the retrieval context is done we stop the timer + // timer.Stop() + // log.WithError(ctx.Err()).Debug("domain retrieval backoff canceled by context") + // when the retrieval context is done we stop the timerFunc + r.timer.stop() + break Retry + } } else { break } @@ -68,3 +92,26 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha return response } + +func (t *timer) start(d time.Duration) { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopped = false + t.timer = time.NewTimer(d) +} + +func (t *timer) stop() { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopped = t.timer.Stop() +} + +func (t *timer) hasStopped() bool { + t.mu.Lock() + defer t.mu.Unlock() + + return t.stopped +} + \ No newline at end of file diff --git a/internal/source/gitlab/cache/retriever_test.go b/internal/source/gitlab/cache/retriever_test.go new file mode 100644 index 000000000..774e9779e --- /dev/null +++ b/internal/source/gitlab/cache/retriever_test.go @@ -0,0 +1,27 @@ 
+package cache + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRetrieveTimerStopsWhenContextIsDone(t *testing.T) { + retrievalTimeout := time.Millisecond // quick timeout that cancels inner context + maxRetrievalInterval := time.Minute // long sleep inside resolveWithBackoff + + resolver := &client{ + domain: make(chan string), + lookups: make(chan uint64, 1), + failure: errors.New("500 error"), + } + + retriever := NewRetriever(resolver, retrievalTimeout, maxRetrievalInterval, 3) + require.False(t, retriever.timer.hasStopped(), "timer has not been stopped yet") + + lookup := retriever.Retrieve("my.gitlab.com") + require.Empty(t, lookup.Name) + require.Eventually(t, retriever.timer.hasStopped, time.Second, time.Millisecond, "timer must have been stopped") +} diff --git a/metrics/metrics.go b/metrics/metrics.go index db7cae9a8..045ff26e0 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -199,6 +199,13 @@ var ( Help: "The number of files per zip archive total count over time", }, ) + + RejectedRequestsCount = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "gitlab_pages_unknown_method_rejected_requests", + Help: "The number of requests with unknown HTTP method which were rejected", + }, + ) ) // MustRegister collectors with the Prometheus client diff --git a/test/acceptance/acceptance_test.go b/test/acceptance/acceptance_test.go index 9921076ea..ba6528c10 100644 --- a/test/acceptance/acceptance_test.go +++ b/test/acceptance/acceptance_test.go @@ -17,24 +17,31 @@ const ( var ( pagesBinary = flag.String("gitlab-pages-binary", "../../gitlab-pages", "Path to the gitlab-pages binary") + httpPort = "36000" + httpsPort = "37000" + httpProxyPort = "38000" + httpProxyV2Port = "39000" + // TODO: Use TCP port 0 everywhere to avoid conflicts. The binary could output // the actual port (and type of listener) for us to read in place of the // hardcoded values below. 
listeners = []ListenSpec{ - {"http", "127.0.0.1", "37000"}, - {"http", "::1", "37000"}, - {"https", "127.0.0.1", "37001"}, - {"https", "::1", "37001"}, - {"proxy", "127.0.0.1", "37002"}, - {"proxy", "::1", "37002"}, - {"https-proxyv2", "127.0.0.1", "37003"}, - {"https-proxyv2", "::1", "37003"}, + {"http", "127.0.0.1", httpPort}, + {"https", "127.0.0.1", httpsPort}, + {"proxy", "127.0.0.1", httpProxyPort}, + {"https-proxyv2", "127.0.0.1", httpProxyV2Port}, + // TODO: re-enable IPv6 listeners once https://gitlab.com/gitlab-com/gl-infra/infrastructure/-/issues/12258 is resolved + // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/528 + // {"http", "::1", httpPort}, + // {"https", "::1", httpsPort}, + // {"proxy", "::1", httpProxyPort}, + // {"https-proxyv2", "::1", httpProxyV2Port}, } httpListener = listeners[0] - httpsListener = listeners[2] - proxyListener = listeners[4] - httpsProxyv2Listener = listeners[6] + httpsListener = listeners[1] + proxyListener = listeners[2] + httpsProxyv2Listener = listeners[3] ) func TestMain(m *testing.M) { diff --git a/test/acceptance/artifacts_test.go b/test/acceptance/artifacts_test.go index 3440ef34f..57c7a02a9 100644 --- a/test/acceptance/artifacts_test.go +++ b/test/acceptance/artifacts_test.go @@ -245,7 +245,7 @@ func TestPrivateArtifactProxyRequest(t *testing.T) { ) defer teardown() - resp, err := GetRedirectPage(t, httpListener, tt.host, tt.path) + resp, err := GetRedirectPage(t, httpsListener, tt.host, tt.path) require.NoError(t, err) defer resp.Body.Close() diff --git a/test/acceptance/auth_test.go b/test/acceptance/auth_test.go index b2233591f..fa2d768d8 100644 --- a/test/acceptance/auth_test.go +++ b/test/acceptance/auth_test.go @@ -88,7 +88,7 @@ func TestWhenLoginCallbackWithWrongStateShouldFail(t *testing.T) { require.Equal(t, http.StatusUnauthorized, authrsp.StatusCode) } -func TestWhenLoginCallbackWithCorrectStateWithoutEndpoint(t *testing.T) { +func TestWhenLoginCallbackWithUnencryptedCode(t *testing.T) { 
skipUnlessEnabled(t) teardown := RunPagesProcessWithAuth(t, *pagesBinary, listeners, "") defer teardown() @@ -110,8 +110,8 @@ func TestWhenLoginCallbackWithCorrectStateWithoutEndpoint(t *testing.T) { require.NoError(t, err) defer authrsp.Body.Close() - // Will cause 503 because token endpoint is not available - require.Equal(t, http.StatusServiceUnavailable, authrsp.StatusCode) + // Will cause 500 because the code is not encrypted + require.Equal(t, http.StatusInternalServerError, authrsp.StatusCode) } func handleAccessControlArtifactRequests(t *testing.T, w http.ResponseWriter, r *http.Request) bool { @@ -234,11 +234,10 @@ func TestAccessControlUnderCustomDomain(t *testing.T) { // Will redirect to custom domain require.Equal(t, tt.domain, url.Host) - require.Equal(t, "1", url.Query().Get("code")) - require.Equal(t, state, url.Query().Get("state")) + code := url.Query().Get("code") + require.NotEqual(t, "1", code) - // Run auth callback in custom domain - authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code=1&state="+ + authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code="+code+"&state="+ state, cookie) require.NoError(t, err) @@ -336,11 +335,13 @@ func TestCustomErrorPageWithAuth(t *testing.T) { // Will redirect to custom domain require.Equal(t, tt.domain, url.Host) - require.Equal(t, "1", url.Query().Get("code")) + // code must have changed since it's encrypted + code := url.Query().Get("code") + require.NotEqual(t, "1", code) require.Equal(t, state, url.Query().Get("state")) // Run auth callback in custom domain - authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code=1&state="+ + authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code="+code+"&state="+ state, cookie) require.NoError(t, err) @@ -409,12 +410,14 @@ func TestAccessControlUnderCustomDomainWithHTTPSProxy(t *testing.T) { // Will redirect to custom domain require.Equal(t, "private.domain.com", 
url.Host) - require.Equal(t, "1", url.Query().Get("code")) + // code must have changed since it's encrypted + code := url.Query().Get("code") + require.NotEqual(t, "1", code) require.Equal(t, state, url.Query().Get("state")) // Run auth callback in custom domain authrsp, err = GetProxyRedirectPageWithCookie(t, proxyListener, "private.domain.com", - "/auth?code=1&state="+state, cookie, true) + "/auth?code="+code+"&state="+state, cookie, true) require.NoError(t, err) defer authrsp.Body.Close() @@ -641,3 +644,87 @@ func TestAccessControlWithSSLCertFile(t *testing.T) { func TestAccessControlWithSSLCertDir(t *testing.T) { testAccessControl(t, RunPagesProcessWithAuthServerWithSSLCertDir) } + +// This proves the fix for https://gitlab.com/gitlab-org/gitlab-pages/-/issues/262 +// Read the issue description if any changes to internal/auth/ break this test. +// Related to https://tools.ietf.org/html/rfc6749#section-10.6. +func TestHijackedCode(t *testing.T) { + skipUnlessEnabled(t, "not-inplace-chroot") + + testServer := makeGitLabPagesAccessStub(t) + testServer.Start() + defer testServer.Close() + + teardown := RunPagesProcessWithAuthServer(t, *pagesBinary, listeners, "", testServer.URL) + defer teardown() + + /****ATTACKER******/ + // get valid cookie for a different private project + targetDomain := "private.domain.com" + attackersDomain := "group.auth.gitlab-example.com" + attackerCookie, attackerState := getValidCookieAndState(t, targetDomain) + + /****TARGET******/ + // fool target to click on modified URL with attacker's domain for redirect with a valid state + hackedURL := fmt.Sprintf("/auth?domain=http://%s&state=%s", attackersDomain, "irrelevant") + maliciousResp, err := GetProxyRedirectPageWithCookie(t, proxyListener, "projects.gitlab-example.com", hackedURL, "", true) + require.NoError(t, err) + defer maliciousResp.Body.Close() + + pagesCookie := maliciousResp.Header.Get("Set-Cookie") + + /* + OAuth flow happens here... 
+ */ + maliciousRespURL, err := url.Parse(maliciousResp.Header.Get("Location")) + require.NoError(t, err) + maliciousState := maliciousRespURL.Query().Get("state") + + // Go to auth page with correct state and code "obtained" from GitLab + authrsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, + "projects.gitlab-example.com", "/auth?code=1&state="+maliciousState, + pagesCookie, true) + + require.NoError(t, err) + defer authrsp.Body.Close() + + /****ATTACKER******/ + // Target is redirected to attacker's domain and attacker receives the proper code + require.Equal(t, http.StatusFound, authrsp.StatusCode, "should redirect to attacker's domain") + authrspURL, err := url.Parse(authrsp.Header.Get("Location")) + require.NoError(t, err) + require.Contains(t, authrspURL.String(), attackersDomain) + + // attacker's got the code + hijackedCode := authrspURL.Query().Get("code") + require.NotEmpty(t, hijackedCode) + + // attacker tries to access private pages content + impersonatingRes, err := GetProxyRedirectPageWithCookie(t, proxyListener, targetDomain, + "/auth?code="+hijackedCode+"&state="+attackerState, attackerCookie, true) + require.NoError(t, err) + defer authrsp.Body.Close() + + require.Equal(t, impersonatingRes.StatusCode, http.StatusInternalServerError, "should fail to decode code") +} + +func getValidCookieAndState(t *testing.T, domain string) (string, string) { + t.Helper() + + // follow flow to get a valid cookie + // visit https:/// + rsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, domain, "/", "", true) + require.NoError(t, err) + defer rsp.Body.Close() + + cookie := rsp.Header.Get("Set-Cookie") + require.NotEmpty(t, cookie) + + redirectURL, err := url.Parse(rsp.Header.Get("Location")) + require.NoError(t, err) + + state := redirectURL.Query().Get("state") + require.NotEmpty(t, state) + + return cookie, state +} diff --git a/test/acceptance/serving_test.go b/test/acceptance/serving_test.go index 66b5fa477..becd6b8cd 100644 --- 
a/test/acceptance/serving_test.go +++ b/test/acceptance/serving_test.go @@ -515,7 +515,8 @@ func TestKnownHostInReverseProxySetupReturns200(t *testing.T) { var listeners = []ListenSpec{ {"proxy", "127.0.0.1", "37002"}, - {"proxy", "::1", "37002"}, + // TODO: re-enable https://gitlab.com/gitlab-org/gitlab-pages/-/issues/528 + // {"proxy", "::1", "37002"}, } teardown := RunPagesProcess(t, *pagesBinary, listeners, "") diff --git a/test/acceptance/unknown_http_method_test.go b/test/acceptance/unknown_http_method_test.go new file mode 100644 index 000000000..f6c5ffee5 --- /dev/null +++ b/test/acceptance/unknown_http_method_test.go @@ -0,0 +1,23 @@ +package acceptance_test + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUnknownHTTPMethod(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + req, err := http.NewRequest("UNKNOWN", listeners[0].URL(""), nil) + require.NoError(t, err) + req.Host = "" + + resp, err := DoPagesRequest(t, httpListener, req) + require.NoError(t, err) + + require.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) +} -- GitLab From 3bc1621d3b501e8849ed15dab7ff9d9be279267b Mon Sep 17 00:00:00 2001 From: Dishon Date: Mon, 18 Jan 2021 13:44:27 +0000 Subject: [PATCH 13/17] Replace time.Sleep with a cancelable timer inside the cache retriever --- .gitignore | 7 +- .gitlab/ci/prepare.yml | 5 +- .gitlab/ci/test.yml | 54 +- .tool-versions | 2 +- CHANGELOG | 52 + Makefile.build.mk | 1 + Makefile.internal.mk | 3 + Makefile.util.mk | 8 +- README.md | 28 + VERSION | 2 +- acceptance_test.go | 2068 ----------------- app.go | 90 +- app_config.go | 20 +- daemon.go | 1 + go.mod | 8 +- go.sum | 29 +- internal/auth/auth.go | 117 +- internal/auth/auth_code.go | 147 ++ internal/auth/auth_code_test.go | 99 + internal/auth/auth_test.go | 207 +- internal/config/config.go | 34 +- internal/httperrors/httperrors.go | 14 + internal/httprange/http_reader.go | 
29 +- internal/httprange/http_reader_test.go | 57 +- internal/httprange/resource.go | 56 +- internal/httprange/resource_test.go | 34 +- internal/httptransport/transport.go | 44 +- internal/httptransport/transport_test.go | 64 +- internal/jail/jail.go | 22 +- internal/jail/mount_linux.go | 4 +- internal/jail/mount_not_supported.go | 4 +- internal/middleware/headers.go | 31 + .../headers_test.go} | 2 +- internal/rejectmethods/middleware.go | 31 + internal/rejectmethods/middleware_test.go | 43 + internal/serving/disk/local/serving_test.go | 84 +- internal/serving/disk/reader.go | 104 +- internal/serving/disk/reader_test.go | 68 + internal/serving/disk/serving.go | 12 +- internal/serving/disk/zip/serving.go | 3 +- internal/serving/disk/zip/serving_test.go | 45 +- internal/serving/serverless/serverless.go | 6 + internal/serving/serving.go | 3 + internal/source/domains.go | 34 +- internal/source/domains_test.go | 1 - internal/source/gitlab/cache/retriever.go | 53 +- .../source/gitlab/cache/retriever_test.go | 27 + internal/source/gitlab/client/client.go | 4 +- internal/source/gitlab/client/client_test.go | 2 +- internal/source/gitlab/factory.go | 22 +- internal/source/gitlab/factory_test.go | 5 +- internal/tlsconfig/tlsconfig.go | 6 +- internal/tlsconfig/tlsconfig_test.go | 8 +- internal/vfs/errors.go | 18 + internal/vfs/local/vfs.go | 14 +- internal/vfs/local/vfs_test.go | 4 +- internal/vfs/vfs.go | 6 + internal/vfs/zip/archive.go | 54 +- internal/vfs/zip/archive_test.go | 157 +- internal/vfs/zip/lru_cache.go | 7 +- internal/vfs/zip/vfs.go | 123 +- internal/vfs/zip/vfs_test.go | 143 +- main.go | 43 +- metrics/metrics.go | 7 + server.go | 26 +- shared/lookups/zip-malformed.gitlab.io.json | 16 + shared/lookups/zip-not-found.gitlab.io.json | 16 + shared/lookups/zip.gitlab.io.json | 2 +- test/acceptance/acceptance_test.go | 81 + test/acceptance/acme_test.go | 73 + test/acceptance/artifacts_test.go | 299 +++ test/acceptance/auth_test.go | 730 ++++++ 
test/acceptance/config_test.go | 66 + test/acceptance/encodings_test.go | 78 + .../acceptance/helpers_test.go | 265 ++- test/acceptance/metrics_test.go | 62 + test/acceptance/proxyv2_test.go | 52 + test/acceptance/redirects_test.go | 116 + test/acceptance/serving_test.go | 574 +++++ test/acceptance/status_test.go | 44 + test/acceptance/stub_test.go | 72 + test/acceptance/tls_test.go | 130 ++ test/acceptance/unknown_http_method_test.go | 23 + test/acceptance/zip_test.go | 161 ++ tools.go | 1 + 85 files changed, 4666 insertions(+), 2701 deletions(-) delete mode 100644 acceptance_test.go create mode 100644 internal/auth/auth_code.go create mode 100644 internal/auth/auth_code_test.go create mode 100644 internal/middleware/headers.go rename internal/{config/config_test.go => middleware/headers_test.go} (99%) create mode 100644 internal/rejectmethods/middleware.go create mode 100644 internal/rejectmethods/middleware_test.go create mode 100644 internal/serving/disk/reader_test.go create mode 100644 internal/source/gitlab/cache/retriever_test.go create mode 100644 internal/vfs/errors.go create mode 100644 shared/lookups/zip-malformed.gitlab.io.json create mode 100644 shared/lookups/zip-not-found.gitlab.io.json create mode 100644 test/acceptance/acceptance_test.go create mode 100644 test/acceptance/acme_test.go create mode 100644 test/acceptance/artifacts_test.go create mode 100644 test/acceptance/auth_test.go create mode 100644 test/acceptance/config_test.go create mode 100644 test/acceptance/encodings_test.go rename helpers_test.go => test/acceptance/helpers_test.go (63%) create mode 100644 test/acceptance/metrics_test.go create mode 100644 test/acceptance/proxyv2_test.go create mode 100644 test/acceptance/redirects_test.go create mode 100644 test/acceptance/serving_test.go create mode 100644 test/acceptance/status_test.go create mode 100644 test/acceptance/stub_test.go create mode 100644 test/acceptance/tls_test.go create mode 100644 
test/acceptance/unknown_http_method_test.go create mode 100644 test/acceptance/zip_test.go diff --git a/.gitignore b/.gitignore index 7357bc4b0..e3e689d12 100644 --- a/.gitignore +++ b/.gitignore @@ -4,12 +4,11 @@ shared/pages/.update /vendor /gitlab-pages.conf /gl-code-quality-report.json +/gl-license-scanning-report.json /coverage.html +/junit-test-report.xml +/tests.out # Used by the makefile /.GOPATH /bin - -# reports -gl-license-scanning-report.json -gl-code-quality-report.json diff --git a/.gitlab/ci/prepare.yml b/.gitlab/ci/prepare.yml index bae8d1e74..1b1347ea9 100644 --- a/.gitlab/ci/prepare.yml +++ b/.gitlab/ci/prepare.yml @@ -28,7 +28,10 @@ gemnasium-dependency_scanning: <<: *rules-for-scanners secret_detection: - <<: *rules-for-scanners + stage: prepare + rules: + # For merge requests, create a pipeline. + - if: '$CI_MERGE_REQUEST_IID' gosec-sast: <<: *rules-for-scanners diff --git a/.gitlab/ci/test.yml b/.gitlab/ci/test.yml index 3218d8ee0..74d49ee6b 100644 --- a/.gitlab/ci/test.yml +++ b/.gitlab/ci/test.yml @@ -1,46 +1,62 @@ -.tests: +.tests-common: extends: .go-mod-cache stage: test tags: - gitlab-org-docker needs: ['download deps'] + artifacts: + reports: + junit: junit-test-report.xml + +.tests-unit: + extends: .tests-common script: - echo "Running all tests without daemonizing..." + - make setup - make test + - make junit-report + +.tests-acceptance-deamon: + extends: .tests-common + script: + - make setup - echo "Running just the acceptance tests daemonized (tmpdir)...." - TEST_DAEMONIZE=tmpdir make acceptance - echo "Running just the acceptance tests daemonized (inplace)...." 
- TEST_DAEMONIZE=inplace make acceptance - artifacts: - paths: - - bin/gitlab-pages + - make junit-report test:1.13: - extends: .tests + extends: .tests-unit + image: golang:1.13 + +test-acceptance:1.13: + extends: .tests-acceptance-deamon image: golang:1.13 test:1.14: - extends: .tests + extends: .tests-unit + image: golang:1.14 + +test-acceptance:1.14: + extends: .tests-acceptance-deamon image: golang:1.14 test:1.15: - extends: .tests + extends: .tests-unit + image: golang:1.15 +test-acceptance:1.15: + extends: .tests-acceptance-deamon image: golang:1.15 race: - extends: .go-mod-cache - stage: test - needs: ['download deps'] - tags: - - gitlab-org-docker + extends: .tests-common script: - echo "Running race detector" - make race cover: - stage: test - extends: .go-mod-cache - needs: ['download deps'] + extends: .tests-common script: - make setup - make generate-mocks @@ -51,9 +67,7 @@ cover: - coverage.html code_quality: - stage: test - extends: .go-mod-cache - needs: ['download deps'] + extends: .tests-common image: golangci/golangci-lint:v1.27.0 variables: REPORT_FILE: gl-code-quality-report.json @@ -69,8 +83,6 @@ code_quality: - ${REPORT_FILE} check deps: - stage: test - extends: .go-mod-cache - needs: ['download deps'] + extends: .tests-common script: - make deps-check diff --git a/.tool-versions b/.tool-versions index b1dbcb109..63d9ded15 100644 --- a/.tool-versions +++ b/.tool-versions @@ -1 +1 @@ -golang 1.15.1 +golang 1.15.5 diff --git a/CHANGELOG b/CHANGELOG index f39ebae99..e315ddcc0 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,55 @@ +v 1.34.0 + +- Allow DELETE HTTP method + +v 1.33.0 + +- Reject requests with unknown HTTP methods +- Encrypt OAuth code during auth flow + +v 1.32.0 + +- Try to automatically use gitlab API as a source for domain information !402 +- Fix https redirect loop for PROXYv2 protocol !405 + +v 1.31.0 + +- Support for HTTPS over PROXYv2 protocol !278 +- Update LabKit library to v1.0.0 !397 +- Add zip serving configuration 
flags !392 +- Disable deprecated serverless serving and proxy !400 + +v 1.30.2 + +- Allow DELETE HTTP method + +v 1.30.1 + +- Reject requests with unknown HTTP methods +- Encrypt OAuth code during auth flow + +v 1.30.0 + +- Allow to refresh an existing cached archive when accessed !375 + +v 1.29.0 + +- Fix LRU cache metrics !379 +- Upgrade go-mimedb to support new types including avif images !353 +- Return 5xx instead of 404 if pages zip serving is unavailable !381 +- Make timeouts for ZIP VFS configurable !385 +- Improve httprange timeouts !382 +- Fix caching for errored ZIP VFS archives !384 + +v 1.28.2 + +- Allow DELETE HTTP method + +v 1.28.1 + +- Reject requests with unknown HTTP methods +- Encrypt OAuth code during auth flow + v 1.28.0 - Implement basic redirects via _redirects text file !367 diff --git a/Makefile.build.mk b/Makefile.build.mk index 2c9e91591..2656e62e5 100644 --- a/Makefile.build.mk +++ b/Makefile.build.mk @@ -10,6 +10,7 @@ setup: clean .GOPATH/.ok go get github.com/wadey/gocovmerge@v0.0.0-20160331181800-b5bfa59ec0ad go get github.com/golang/mock/mockgen@v1.3.1 go get github.com/golangci/golangci-lint/cmd/golangci-lint@$(GOLANGCI_LINT_VERSION) + go get github.com/jstemmer/go-junit-report generate-mocks: .GOPATH/.ok $Q bin/mockgen -source=internal/interface.go -destination=internal/mocks/mocks.go -package=mocks diff --git a/Makefile.internal.mk b/Makefile.internal.mk index a33634fd3..54e7f7da3 100644 --- a/Makefile.internal.mk +++ b/Makefile.internal.mk @@ -43,3 +43,6 @@ bin/golangci-lint: .GOPATH/.ok @test -x $@ || \ { echo "Vendored golangci-lint not found, try running 'make setup'..."; exit 1; } +bin/go-junit-report: .GOPATH/.ok + @test -x $@ || \ + { echo "Vendored go-junit-report not found, try running 'make setup'..."; exit 1; } diff --git a/Makefile.util.mk b/Makefile.util.mk index ea465fbf2..4f190ea45 100644 --- a/Makefile.util.mk +++ b/Makefile.util.mk @@ -10,13 +10,14 @@ lint: .GOPATH/.ok bin/golangci-lint $Q ./bin/golangci-lint run 
./... --out-format $(OUT_FORMAT) $(LINT_FLAGS) | tee ${REPORT_FILE} test: .GOPATH/.ok gitlab-pages - go test $(if $V,-v) $(allpackages) + rm tests.out || true + go test $(if $V,-v) $(allpackages) 2>&1 | tee tests.out race: .GOPATH/.ok gitlab-pages CGO_ENABLED=1 go test -race $(if $V,-v) $(allpackages) acceptance: .GOPATH/.ok gitlab-pages - go test $(if $V,-v) $(IMPORT_PATH) + go test $(if $V,-v) ./test/acceptance 2>&1 | tee tests.out bench: .GOPATH/.ok gitlab-pages go test -bench=. -run=^$$ $(allpackages) @@ -55,3 +56,6 @@ deps-check: .GOPATH/.ok deps-download: .GOPATH/.ok go mod download + +junit-report: .GOPATH/.ok bin/go-junit-report + cat tests.out | ./bin/go-junit-report -set-exit-code > junit-test-report.xml diff --git a/README.md b/README.md index d8016eaf1..ab2ac6329 100644 --- a/README.md +++ b/README.md @@ -185,6 +185,34 @@ We use `gorilla/handlers.ProxyHeaders` middleware. For more information please r > NOTE: This middleware should only be used when behind a reverse proxy like nginx, HAProxy or Apache. Reverse proxies that don't (or are configured not to) strip these headers from client requests, or where these headers are accepted "as is" from a remote client (e.g. when Go is not behind a proxy), can manifest as a vulnerability if your application uses these headers for validating the 'trustworthiness' of a request. +### PROXY protocol for HTTPS + +The above `listen-proxy` option only works for plaintext HTTP, where the reverse +proxy was already able to parse the incoming HTTP traffic and inject a header for +the remote client IP. + +This does not work for HTTPS which is generally proxied at the TCP level. In +order to propagate the remote client IP in this case, you can use the +[PROXY protocol](https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt). +This is supported by HAProxy and some third party services such as Cloudflare. + +To configure PROXY protocol support, run `gitlab-pages` with the +`listen-https-proxyv2` flag. 
+ +If you are using HAProxy as your TCP load balancer, you can configure the backend +with the `send-proxy-v2` option, like so: + +``` +frontend fe + bind 127.0.0.1:12340 + mode tcp + default_backend be + +backend be + mode tcp + server app1 127.0.0.1:1234 send-proxy-v2 +``` + ### GitLab access control GitLab access control is configured with properties `auth-client-id`, `auth-client-secret`, `auth-redirect-uri`, `auth-server` and `auth-secret`. Client ID, secret and redirect uri are configured in the GitLab and should match. `auth-server` points to a GitLab instance used for authentication. `auth-redirect-uri` should be `http(s)://pages-domain/auth`. Note that if the pages-domain is not handled by GitLab pages, then the `auth-redirect-uri` should use some reserved namespace prefix (such as `http(s)://projects.pages-domain/auth`). Using HTTPS is _strongly_ encouraged. `auth-secret` is used to encrypt the session cookie, and it should be strong enough. diff --git a/VERSION b/VERSION index cfc730712..2b17ffd50 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.28.0 +1.34.0 diff --git a/acceptance_test.go b/acceptance_test.go deleted file mode 100644 index 69ec8742e..000000000 --- a/acceptance_test.go +++ /dev/null @@ -1,2068 +0,0 @@ -package main - -import ( - "crypto/tls" - "fmt" - "io/ioutil" - "mime" - "net" - "net/http" - "net/http/httptest" - "net/url" - "os" - "path" - "regexp" - "testing" - "time" - - "github.com/namsral/flag" - "github.com/stretchr/testify/require" -) - -var pagesBinary = flag.String("gitlab-pages-binary", "./gitlab-pages", "Path to the gitlab-pages binary") - -const ( - objectStorageMockServer = "127.0.0.1:37003" -) - -// TODO: Use TCP port 0 everywhere to avoid conflicts. The binary could output -// the actual port (and type of listener) for us to read in place of the -// hardcoded values below. 
-var listeners = []ListenSpec{ - {"http", "127.0.0.1", "37000"}, - {"http", "::1", "37000"}, - {"https", "127.0.0.1", "37001"}, - {"https", "::1", "37001"}, - {"proxy", "127.0.0.1", "37002"}, - {"proxy", "::1", "37002"}, -} - -var ( - httpListener = listeners[0] - httpsListener = listeners[2] - proxyListener = listeners[4] -) - -func skipUnlessEnabled(t *testing.T, conditions ...string) { - t.Helper() - - if testing.Short() { - t.Log("Acceptance tests disabled") - t.SkipNow() - } - - if _, err := os.Stat(*pagesBinary); os.IsNotExist(err) { - t.Errorf("Couldn't find gitlab-pages binary at %s", *pagesBinary) - t.FailNow() - } - - for _, condition := range conditions { - switch condition { - case "not-inplace-chroot": - if os.Getenv("TEST_DAEMONIZE") == "inplace" { - t.Log("Not supported with -daemon-inplace-chroot") - t.SkipNow() - } - default: - t.Error("Unknown condition:", condition) - t.FailNow() - } - } -} - -func TestUnknownHostReturnsNotFound(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - for _, spec := range listeners { - rsp, err := GetPageFromListener(t, spec, "invalid.invalid", "") - - require.NoError(t, err) - rsp.Body.Close() - require.Equal(t, http.StatusNotFound, rsp.StatusCode) - } -} - -func TestUnknownProjectReturnsNotFound(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "/nonexistent/") - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusNotFound, rsp.StatusCode) -} - -func TestGroupDomainReturns200(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "/") - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusOK, 
rsp.StatusCode) -} - -func TestKnownHostReturns200(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - tests := []struct { - name string - host string - path string - }{ - { - name: "lower case", - host: "group.gitlab-example.com", - path: "project/", - }, - { - name: "capital project", - host: "group.gitlab-example.com", - path: "CapitalProject/", - }, - { - name: "capital group", - host: "CapitalGroup.gitlab-example.com", - path: "project/", - }, - { - name: "capital group and project", - host: "CapitalGroup.gitlab-example.com", - path: "CapitalProject/", - }, - { - name: "subgroup", - host: "group.gitlab-example.com", - path: "subgroup/project/", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - for _, spec := range listeners { - rsp, err := GetPageFromListener(t, spec, tt.host, tt.path) - - require.NoError(t, err) - rsp.Body.Close() - require.Equal(t, http.StatusOK, rsp.StatusCode) - } - }) - } -} - -func TestNestedSubgroups(t *testing.T) { - skipUnlessEnabled(t) - - maxNestedSubgroup := 21 - - pagesRoot, err := ioutil.TempDir("", "pages-root") - require.NoError(t, err) - defer os.RemoveAll(pagesRoot) - - makeProjectIndex := func(subGroupPath string) { - projectPath := path.Join(pagesRoot, "nested", subGroupPath, "project", "public") - require.NoError(t, os.MkdirAll(projectPath, 0755)) - - projectIndex := path.Join(projectPath, "index.html") - require.NoError(t, ioutil.WriteFile(projectIndex, []byte("index"), 0644)) - } - makeProjectIndex("") - - paths := []string{""} - for i := 1; i < maxNestedSubgroup*2; i++ { - subGroupPath := fmt.Sprintf("%ssub%d/", paths[i-1], i) - paths = append(paths, subGroupPath) - - makeProjectIndex(subGroupPath) - } - - teardown := RunPagesProcess(t, *pagesBinary, listeners, "", "-pages-root", pagesRoot) - defer teardown() - - for nestingLevel, path := range paths { - t.Run(fmt.Sprintf("nested level %d", nestingLevel), func(t 
*testing.T) { - for _, spec := range listeners { - rsp, err := GetPageFromListener(t, spec, "nested.gitlab-example.com", path+"project/") - - require.NoError(t, err) - rsp.Body.Close() - if nestingLevel <= maxNestedSubgroup { - require.Equal(t, http.StatusOK, rsp.StatusCode) - } else { - require.Equal(t, http.StatusNotFound, rsp.StatusCode) - } - } - }) - } -} - -func TestCustom404(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - tests := []struct { - host string - path string - content string - }{ - { - host: "group.404.gitlab-example.com", - path: "project.404/not/existing-file", - content: "Custom 404 project page", - }, - { - host: "group.404.gitlab-example.com", - path: "project.404/", - content: "Custom 404 project page", - }, - { - host: "group.404.gitlab-example.com", - path: "not/existing-file", - content: "Custom 404 group page", - }, - { - host: "group.404.gitlab-example.com", - path: "not-existing-file", - content: "Custom 404 group page", - }, - { - host: "group.404.gitlab-example.com", - content: "Custom 404 group page", - }, - { - host: "domain.404.com", - content: "Custom domain.404 page", - }, - { - host: "group.404.gitlab-example.com", - path: "project.no.404/not/existing-file", - content: "The page you're looking for could not be found.", - }, - } - - for _, test := range tests { - t.Run(fmt.Sprintf("%s/%s", test.host, test.path), func(t *testing.T) { - for _, spec := range listeners { - rsp, err := GetPageFromListener(t, spec, test.host, test.path) - - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusNotFound, rsp.StatusCode) - - page, err := ioutil.ReadAll(rsp.Body) - require.NoError(t, err) - require.Contains(t, string(page), test.content) - } - }) - } -} - -func TestCORSWhenDisabled(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "", "-disable-cross-origin-requests") - defer teardown() - - 
for _, spec := range listeners { - for _, method := range []string{"GET", "OPTIONS"} { - rsp := doCrossOriginRequest(t, method, method, spec.URL("project/")) - - require.Equal(t, http.StatusOK, rsp.StatusCode) - require.Equal(t, "", rsp.Header.Get("Access-Control-Allow-Origin")) - require.Equal(t, "", rsp.Header.Get("Access-Control-Allow-Credentials")) - } - } -} - -func TestCORSAllowsGET(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - for _, spec := range listeners { - for _, method := range []string{"GET", "OPTIONS"} { - rsp := doCrossOriginRequest(t, method, method, spec.URL("project/")) - - require.Equal(t, http.StatusOK, rsp.StatusCode) - require.Equal(t, "*", rsp.Header.Get("Access-Control-Allow-Origin")) - require.Equal(t, "", rsp.Header.Get("Access-Control-Allow-Credentials")) - } - } -} - -func TestCORSForbidsPOST(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - for _, spec := range listeners { - rsp := doCrossOriginRequest(t, "OPTIONS", "POST", spec.URL("project/")) - - require.Equal(t, http.StatusOK, rsp.StatusCode) - require.Equal(t, "", rsp.Header.Get("Access-Control-Allow-Origin")) - require.Equal(t, "", rsp.Header.Get("Access-Control-Allow-Credentials")) - } -} - -func TestCustomHeaders(t *testing.T) { - teardown := RunPagesProcess(t, *pagesBinary, listeners, "", "-header", "X-Test1:Testing1", "-header", "X-Test2:Testing2") - defer teardown() - - for _, spec := range listeners { - rsp, err := GetPageFromListener(t, spec, "group.gitlab-example.com:", "project/") - require.NoError(t, err) - require.Equal(t, http.StatusOK, rsp.StatusCode) - require.Equal(t, "Testing1", rsp.Header.Get("X-Test1")) - require.Equal(t, "Testing2", rsp.Header.Get("X-Test2")) - } -} - -func doCrossOriginRequest(t *testing.T, method, reqMethod, url string) *http.Response { - req, err := http.NewRequest(method, url, nil) - 
require.NoError(t, err) - - req.Host = "group.gitlab-example.com" - req.Header.Add("Origin", "example.com") - req.Header.Add("Access-Control-Request-Method", reqMethod) - - var rsp *http.Response - err = fmt.Errorf("no request was made") - for start := time.Now(); time.Since(start) < 1*time.Second; { - rsp, err = DoPagesRequest(t, req) - if err == nil { - break - } - time.Sleep(100 * time.Millisecond) - } - require.NoError(t, err) - - rsp.Body.Close() - return rsp -} - -func TestKnownHostWithPortReturns200(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - for _, spec := range listeners { - rsp, err := GetPageFromListener(t, spec, "group.gitlab-example.com:"+spec.Port, "project/") - - require.NoError(t, err) - rsp.Body.Close() - require.Equal(t, http.StatusOK, rsp.StatusCode) - } -} - -func TestHttpToHttpsRedirectDisabled(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - rsp, err := GetRedirectPage(t, httpListener, "group.gitlab-example.com", "project/") - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusOK, rsp.StatusCode) - - rsp, err = GetPageFromListener(t, httpsListener, "group.gitlab-example.com", "project/") - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusOK, rsp.StatusCode) -} - -func TestHttpToHttpsRedirectEnabled(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "", "-redirect-http=true") - defer teardown() - - rsp, err := GetRedirectPage(t, httpListener, "group.gitlab-example.com", "project/") - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusTemporaryRedirect, rsp.StatusCode) - require.Equal(t, 1, len(rsp.Header["Location"])) - require.Equal(t, "https://group.gitlab-example.com/project/", rsp.Header.Get("Location")) - - rsp, err = GetPageFromListener(t, httpsListener, 
"group.gitlab-example.com", "project/") - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusOK, rsp.StatusCode) -} - -func TestHttpsOnlyGroupEnabled(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - rsp, err := GetRedirectPage(t, httpListener, "group.https-only.gitlab-example.com", "project1/") - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusMovedPermanently, rsp.StatusCode) -} - -func TestHttpsOnlyGroupDisabled(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - rsp, err := GetPageFromListener(t, httpListener, "group.https-only.gitlab-example.com", "project2/") - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusOK, rsp.StatusCode) -} - -func TestHttpsOnlyProjectEnabled(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - rsp, err := GetRedirectPage(t, httpListener, "test.my-domain.com", "/index.html") - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusMovedPermanently, rsp.StatusCode) -} - -func TestHttpsOnlyProjectDisabled(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - rsp, err := GetPageFromListener(t, httpListener, "test2.my-domain.com", "/") - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusOK, rsp.StatusCode) -} - -func TestHttpsOnlyDomainDisabled(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - rsp, err := GetPageFromListener(t, httpListener, "no.cert.com", "/") - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusOK, rsp.StatusCode) -} - -func TestPrometheusMetricsCanBeScraped(t *testing.T) { - skipUnlessEnabled(t) - 
- _, cleanup := newZipFileServerURL(t, "shared/pages/group/zip.gitlab.io/public.zip") - defer cleanup() - - teardown := RunPagesProcessWithStubGitLabServer(t, true, *pagesBinary, listeners, ":42345", []string{}) - defer teardown() - - // need to call an actual resource to populate certain metrics e.g. gitlab_pages_domains_source_api_requests_total - res, err := GetPageFromListener(t, httpListener, "zip.gitlab.io", - "/symlink.html") - require.NoError(t, err) - require.Equal(t, http.StatusOK, res.StatusCode) - - resp, err := http.Get("http://localhost:42345/metrics") - require.NoError(t, err) - - defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) - require.NoError(t, err) - - require.Contains(t, string(body), "gitlab_pages_http_in_flight_requests 0") - // TODO: remove metrics for disk source https://gitlab.com/gitlab-org/gitlab-pages/-/issues/382 - require.Contains(t, string(body), "gitlab_pages_served_domains 0") - require.Contains(t, string(body), "gitlab_pages_domains_failed_total 0") - require.Contains(t, string(body), "gitlab_pages_domains_updated_total 0") - require.Contains(t, string(body), "gitlab_pages_last_domain_update_seconds gauge") - require.Contains(t, string(body), "gitlab_pages_domains_configuration_update_duration gauge") - // end TODO - require.Contains(t, string(body), "gitlab_pages_domains_source_cache_hit") - require.Contains(t, string(body), "gitlab_pages_domains_source_cache_miss") - require.Contains(t, string(body), "gitlab_pages_domains_source_failures_total") - require.Contains(t, string(body), "gitlab_pages_serverless_requests 0") - require.Contains(t, string(body), "gitlab_pages_serverless_latency_sum 0") - require.Contains(t, string(body), "gitlab_pages_disk_serving_file_size_bytes_sum") - require.Contains(t, string(body), "gitlab_pages_serving_time_seconds_sum") - require.Contains(t, string(body), `gitlab_pages_domains_source_api_requests_total{status_code="200"}`) - require.Contains(t, string(body), 
`gitlab_pages_domains_source_api_call_duration_bucket`) - require.Contains(t, string(body), `gitlab_pages_domains_source_api_trace_duration`) - // httprange - require.Contains(t, string(body), `gitlab_pages_httprange_requests_total{status_code="206"}`) - require.Contains(t, string(body), "gitlab_pages_httprange_requests_duration_bucket") - require.Contains(t, string(body), "gitlab_pages_httprange_trace_duration") - require.Contains(t, string(body), "gitlab_pages_httprange_open_requests") - // zip archives - require.Contains(t, string(body), "gitlab_pages_zip_opened") - require.Contains(t, string(body), "gitlab_pages_zip_cache_requests") - require.Contains(t, string(body), "gitlab_pages_zip_cached_entries") - require.Contains(t, string(body), "gitlab_pages_zip_archive_entries_cached") - require.Contains(t, string(body), "gitlab_pages_zip_opened_entries_count") -} - -func TestDisabledRedirects(t *testing.T) { - skipUnlessEnabled(t) - - teardown := RunPagesProcessWithEnvs(t, true, *pagesBinary, listeners, "", []string{"FF_ENABLE_REDIRECTS=false"}) - defer teardown() - - // Test that redirects status page is forbidden - rsp, err := GetPageFromListener(t, httpListener, "group.redirects.gitlab-example.com", "/project-redirects/_redirects") - require.NoError(t, err) - defer rsp.Body.Close() - - require.Equal(t, http.StatusForbidden, rsp.StatusCode) - - // Test that redirects are disabled - rsp, err = GetRedirectPage(t, httpListener, "group.redirects.gitlab-example.com", "/project-redirects/redirect-portal.html") - require.NoError(t, err) - defer rsp.Body.Close() - - require.Equal(t, http.StatusNotFound, rsp.StatusCode) -} - -func TestRedirectStatusPage(t *testing.T) { - skipUnlessEnabled(t) - - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - rsp, err := GetPageFromListener(t, httpListener, "group.redirects.gitlab-example.com", "/project-redirects/_redirects") - require.NoError(t, err) - - body, err := ioutil.ReadAll(rsp.Body) - 
require.NoError(t, err) - defer rsp.Body.Close() - - require.Contains(t, string(body), "11 rules") - require.Equal(t, http.StatusOK, rsp.StatusCode) -} - -func TestRedirect(t *testing.T) { - skipUnlessEnabled(t) - - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - // Test that serving a file still works with redirects enabled - rsp, err := GetRedirectPage(t, httpListener, "group.redirects.gitlab-example.com", "/project-redirects/index.html") - require.NoError(t, err) - defer rsp.Body.Close() - - require.Equal(t, http.StatusOK, rsp.StatusCode) - - tests := []struct { - host string - path string - expectedStatus int - expectedLocation string - }{ - // Project domain - { - host: "group.redirects.gitlab-example.com", - path: "/project-redirects/redirect-portal.html", - expectedStatus: http.StatusFound, - expectedLocation: "/project-redirects/magic-land.html", - }, - // Make sure invalid rule does not redirect - { - host: "group.redirects.gitlab-example.com", - path: "/project-redirects/goto-domain.html", - expectedStatus: http.StatusNotFound, - expectedLocation: "", - }, - // Actual file on disk should override any redirects that match - { - host: "group.redirects.gitlab-example.com", - path: "/project-redirects/file-override.html", - expectedStatus: http.StatusOK, - expectedLocation: "", - }, - // Group-level domain - { - host: "group.redirects.gitlab-example.com", - path: "/redirect-portal.html", - expectedStatus: http.StatusFound, - expectedLocation: "/magic-land.html", - }, - // Custom domain - { - host: "redirects.custom-domain.com", - path: "/redirect-portal.html", - expectedStatus: http.StatusFound, - expectedLocation: "/magic-land.html", - }, - } - - for _, tt := range tests { - t.Run(fmt.Sprintf("%s%s -> %s (%d)", tt.host, tt.path, tt.expectedLocation, tt.expectedStatus), func(t *testing.T) { - rsp, err := GetRedirectPage(t, httpListener, tt.host, tt.path) - require.NoError(t, err) - defer rsp.Body.Close() - - require.Equal(t, 
tt.expectedLocation, rsp.Header.Get("Location")) - require.Equal(t, tt.expectedStatus, rsp.StatusCode) - }) - } -} - -func TestStatusPage(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "", "-pages-status=/@statuscheck") - defer teardown() - - rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "@statuscheck") - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusOK, rsp.StatusCode) -} - -func TestStatusNotYetReady(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcessWithoutWait(t, *pagesBinary, listeners, "", "-pages-status=/@statuscheck", "-pages-root=shared/invalid-pages") - defer teardown() - - waitForRoundtrips(t, listeners, 5*time.Second) - rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "@statuscheck") - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusServiceUnavailable, rsp.StatusCode) -} - -func TestPageNotAvailableIfNotLoaded(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcessWithoutWait(t, *pagesBinary, listeners, "", "-pages-root=shared/invalid-pages") - defer teardown() - waitForRoundtrips(t, listeners, 5*time.Second) - - rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "index.html") - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusServiceUnavailable, rsp.StatusCode) -} - -func TestObscureMIMEType(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcessWithoutWait(t, *pagesBinary, listeners, "") - defer teardown() - - require.NoError(t, httpListener.WaitUntilRequestSucceeds(nil)) - - rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "project/file.webmanifest") - require.NoError(t, err) - defer rsp.Body.Close() - - require.Equal(t, http.StatusOK, rsp.StatusCode) - mt, _, err := mime.ParseMediaType(rsp.Header.Get("Content-Type")) - require.NoError(t, err) - 
require.Equal(t, "application/manifest+json", mt) -} - -func TestCompressedEncoding(t *testing.T) { - skipUnlessEnabled(t) - - tests := []struct { - name string - host string - path string - encoding string - }{ - { - "gzip encoding", - "group.gitlab-example.com", - "index.html", - "gzip", - }, - { - "brotli encoding", - "group.gitlab-example.com", - "index.html", - "br", - }, - } - - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - rsp, err := GetCompressedPageFromListener(t, httpListener, "group.gitlab-example.com", "index.html", tt.encoding) - require.NoError(t, err) - defer rsp.Body.Close() - - require.Equal(t, http.StatusOK, rsp.StatusCode) - require.Equal(t, tt.encoding, rsp.Header.Get("Content-Encoding")) - }) - } -} - -func TestArtifactProxyRequest(t *testing.T) { - skipUnlessEnabled(t, "not-inplace-chroot") - - transport := (TestHTTPSClient.Transport).(*http.Transport) - defer func(t time.Duration) { - transport.ResponseHeaderTimeout = t - }(transport.ResponseHeaderTimeout) - transport.ResponseHeaderTimeout = 5 * time.Second - - content := "Title of the document" - contentLength := int64(len(content)) - testServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.RawPath { - case "/api/v4/projects/group%2Fproject/jobs/1/artifacts/delayed_200.html": - time.Sleep(2 * time.Second) - fallthrough - case "/api/v4/projects/group%2Fproject/jobs/1/artifacts/200.html", - "/api/v4/projects/group%2Fsubgroup%2Fproject/jobs/1/artifacts/200.html": - w.Header().Set("Content-Type", "text/html; charset=utf-8") - fmt.Fprint(w, content) - case "/api/v4/projects/group%2Fproject/jobs/1/artifacts/500.html": - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.WriteHeader(http.StatusInternalServerError) - fmt.Fprint(w, content) - default: - t.Logf("Unexpected r.URL.RawPath: %q", r.URL.RawPath) - 
w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.WriteHeader(http.StatusNotFound) - fmt.Fprint(w, content) - } - })) - - keyFile, certFile := CreateHTTPSFixtureFiles(t) - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - require.NoError(t, err) - defer os.Remove(keyFile) - defer os.Remove(certFile) - - testServer.TLS = &tls.Config{Certificates: []tls.Certificate{cert}} - testServer.StartTLS() - defer testServer.Close() - - tests := []struct { - name string - host string - path string - status int - binaryOption string - content string - length int64 - cacheControl string - contentType string - }{ - { - name: "basic proxied request", - host: "group.gitlab-example.com", - path: "/-/project/-/jobs/1/artifacts/200.html", - status: http.StatusOK, - binaryOption: "", - content: content, - length: contentLength, - cacheControl: "max-age=3600", - contentType: "text/html; charset=utf-8", - }, - { - name: "basic proxied request for subgroup", - host: "group.gitlab-example.com", - path: "/-/subgroup/project/-/jobs/1/artifacts/200.html", - status: http.StatusOK, - binaryOption: "", - content: content, - length: contentLength, - cacheControl: "max-age=3600", - contentType: "text/html; charset=utf-8", - }, - { - name: "502 error while attempting to proxy", - host: "group.gitlab-example.com", - path: "/-/project/-/jobs/1/artifacts/delayed_200.html", - status: http.StatusBadGateway, - binaryOption: "-artifacts-server-timeout=1", - content: "", - length: 0, - cacheControl: "", - contentType: "text/html; charset=utf-8", - }, - { - name: "Proxying 404 from server", - host: "group.gitlab-example.com", - path: "/-/project/-/jobs/1/artifacts/404.html", - status: http.StatusNotFound, - binaryOption: "", - content: "", - length: 0, - cacheControl: "", - contentType: "text/html; charset=utf-8", - }, - { - name: "Proxying 500 from server", - host: "group.gitlab-example.com", - path: "/-/project/-/jobs/1/artifacts/500.html", - status: http.StatusInternalServerError, - 
binaryOption: "", - content: "", - length: 0, - cacheControl: "", - contentType: "text/html; charset=utf-8", - }, - } - - // Ensure the IP address is used in the URL, as we're relying on IP SANs to - // validate - artifactServerURL := testServer.URL + "/api/v4" - t.Log("Artifact server URL", artifactServerURL) - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - teardown := RunPagesProcessWithSSLCertFile( - t, - *pagesBinary, - listeners, - "", - certFile, - "-artifacts-server="+artifactServerURL, - tt.binaryOption, - ) - defer teardown() - - resp, err := GetPageFromListener(t, httpListener, tt.host, tt.path) - require.NoError(t, err) - defer resp.Body.Close() - - require.Equal(t, tt.status, resp.StatusCode) - require.Equal(t, tt.contentType, resp.Header.Get("Content-Type")) - - if !((tt.status == http.StatusBadGateway) || (tt.status == http.StatusNotFound) || (tt.status == http.StatusInternalServerError)) { - body, err := ioutil.ReadAll(resp.Body) - require.NoError(t, err) - require.Equal(t, tt.content, string(body)) - require.Equal(t, tt.length, resp.ContentLength) - require.Equal(t, tt.cacheControl, resp.Header.Get("Cache-Control")) - } - }) - } -} - -func TestPrivateArtifactProxyRequest(t *testing.T) { - skipUnlessEnabled(t, "not-inplace-chroot") - - setupTransport(t) - - testServer := makeGitLabPagesAccessStub(t) - - keyFile, certFile := CreateHTTPSFixtureFiles(t) - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - require.NoError(t, err) - defer os.Remove(keyFile) - defer os.Remove(certFile) - - testServer.TLS = &tls.Config{Certificates: []tls.Certificate{cert}} - testServer.StartTLS() - defer testServer.Close() - - tests := []struct { - name string - host string - path string - status int - binaryOption string - }{ - { - name: "basic proxied request for private project", - host: "group.gitlab-example.com", - path: "/-/private/-/jobs/1/artifacts/200.html", - status: http.StatusOK, - binaryOption: "", - }, - { - name: "basic proxied 
request for subgroup", - host: "group.gitlab-example.com", - path: "/-/subgroup/private/-/jobs/1/artifacts/200.html", - status: http.StatusOK, - binaryOption: "", - }, - { - name: "502 error while attempting to proxy", - host: "group.gitlab-example.com", - path: "/-/private/-/jobs/1/artifacts/delayed_200.html", - status: http.StatusBadGateway, - binaryOption: "artifacts-server-timeout=1", - }, - { - name: "Proxying 404 from server", - host: "group.gitlab-example.com", - path: "/-/private/-/jobs/1/artifacts/404.html", - status: http.StatusNotFound, - binaryOption: "", - }, - { - name: "Proxying 500 from server", - host: "group.gitlab-example.com", - path: "/-/private/-/jobs/1/artifacts/500.html", - status: http.StatusInternalServerError, - binaryOption: "", - }, - } - - // Ensure the IP address is used in the URL, as we're relying on IP SANs to - // validate - artifactServerURL := testServer.URL + "/api/v4" - t.Log("Artifact server URL", artifactServerURL) - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - configFile, cleanup := defaultConfigFileWith(t, - "artifacts-server="+artifactServerURL, - "auth-server="+testServer.URL, - "auth-redirect-uri=https://projects.gitlab-example.com/auth", - tt.binaryOption) - defer cleanup() - - teardown := RunPagesProcessWithSSLCertFile( - t, - *pagesBinary, - listeners, - "", - certFile, - "-config="+configFile, - ) - defer teardown() - - resp, err := GetRedirectPage(t, httpListener, tt.host, tt.path) - require.NoError(t, err) - defer resp.Body.Close() - - require.Equal(t, http.StatusFound, resp.StatusCode) - - cookie := resp.Header.Get("Set-Cookie") - - // Redirects to the projects under gitlab pages domain for authentication flow - url, err := url.Parse(resp.Header.Get("Location")) - require.NoError(t, err) - require.Equal(t, "projects.gitlab-example.com", url.Host) - require.Equal(t, "/auth", url.Path) - state := url.Query().Get("state") - - resp, err = GetRedirectPage(t, httpsListener, url.Host, 
url.Path+"?"+url.RawQuery) - - require.NoError(t, err) - defer resp.Body.Close() - - require.Equal(t, http.StatusFound, resp.StatusCode) - pagesDomainCookie := resp.Header.Get("Set-Cookie") - - // Go to auth page with correct state will cause fetching the token - authrsp, err := GetRedirectPageWithCookie(t, httpsListener, "projects.gitlab-example.com", "/auth?code=1&state="+ - state, pagesDomainCookie) - - require.NoError(t, err) - defer authrsp.Body.Close() - - // Will redirect auth callback to correct host - url, err = url.Parse(authrsp.Header.Get("Location")) - require.NoError(t, err) - require.Equal(t, tt.host, url.Host) - require.Equal(t, "/auth", url.Path) - - // Request auth callback in project domain - authrsp, err = GetRedirectPageWithCookie(t, httpsListener, url.Host, url.Path+"?"+url.RawQuery, cookie) - require.NoError(t, err) - - // server returns the ticket, user will be redirected to the project page - require.Equal(t, http.StatusFound, authrsp.StatusCode) - cookie = authrsp.Header.Get("Set-Cookie") - resp, err = GetRedirectPageWithCookie(t, httpsListener, tt.host, tt.path, cookie) - - require.Equal(t, tt.status, resp.StatusCode) - - require.NoError(t, err) - defer resp.Body.Close() - }) - } -} - -func TestEnvironmentVariablesConfig(t *testing.T) { - skipUnlessEnabled(t) - os.Setenv("LISTEN_HTTP", net.JoinHostPort(httpListener.Host, httpListener.Port)) - defer func() { os.Unsetenv("LISTEN_HTTP") }() - - teardown := RunPagesProcessWithoutWait(t, *pagesBinary, []ListenSpec{}, "") - defer teardown() - require.NoError(t, httpListener.WaitUntilRequestSucceeds(nil)) - - rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com:", "project/") - - require.NoError(t, err) - rsp.Body.Close() - require.Equal(t, http.StatusOK, rsp.StatusCode) -} - -func TestMixedConfigSources(t *testing.T) { - skipUnlessEnabled(t) - os.Setenv("LISTEN_HTTP", net.JoinHostPort(httpListener.Host, httpListener.Port)) - defer func() { os.Unsetenv("LISTEN_HTTP") }() - - 
teardown := RunPagesProcessWithoutWait(t, *pagesBinary, []ListenSpec{httpsListener}, "") - defer teardown() - - for _, listener := range []ListenSpec{httpListener, httpsListener} { - require.NoError(t, listener.WaitUntilRequestSucceeds(nil)) - rsp, err := GetPageFromListener(t, listener, "group.gitlab-example.com", "project/") - require.NoError(t, err) - rsp.Body.Close() - - require.Equal(t, http.StatusOK, rsp.StatusCode) - } -} - -func TestMultiFlagEnvironmentVariables(t *testing.T) { - skipUnlessEnabled(t) - listenSpecs := []ListenSpec{{"http", "127.0.0.1", "37001"}, {"http", "127.0.0.1", "37002"}} - envVarValue := fmt.Sprintf("%s,%s", net.JoinHostPort("127.0.0.1", "37001"), net.JoinHostPort("127.0.0.1", "37002")) - - os.Setenv("LISTEN_HTTP", envVarValue) - defer func() { os.Unsetenv("LISTEN_HTTP") }() - - teardown := RunPagesProcess(t, *pagesBinary, []ListenSpec{}, "") - defer teardown() - - for _, listener := range listenSpecs { - require.NoError(t, listener.WaitUntilRequestSucceeds(nil)) - rsp, err := GetPageFromListener(t, listener, "group.gitlab-example.com", "project/") - - require.NoError(t, err) - rsp.Body.Close() - require.Equal(t, http.StatusOK, rsp.StatusCode) - } -} - -func TestKnownHostInReverseProxySetupReturns200(t *testing.T) { - skipUnlessEnabled(t) - - var listeners = []ListenSpec{ - {"proxy", "127.0.0.1", "37002"}, - {"proxy", "::1", "37002"}, - } - - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - for _, spec := range listeners { - rsp, err := GetProxiedPageFromListener(t, spec, "localhost", "group.gitlab-example.com", "project/") - - require.NoError(t, err) - rsp.Body.Close() - require.Equal(t, http.StatusOK, rsp.StatusCode) - } -} - -func TestWhenAuthIsDisabledPrivateIsNotAccessible(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "", "") - defer teardown() - - rsp, err := GetPageFromListener(t, httpListener, "group.auth.gitlab-example.com", 
"private.project/") - - require.NoError(t, err) - rsp.Body.Close() - require.Equal(t, http.StatusInternalServerError, rsp.StatusCode) -} - -func TestWhenAuthIsEnabledPrivateWillRedirectToAuthorize(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcessWithAuth(t, *pagesBinary, listeners, "") - defer teardown() - - rsp, err := GetRedirectPage(t, httpsListener, "group.auth.gitlab-example.com", "private.project/") - - require.NoError(t, err) - defer rsp.Body.Close() - - require.Equal(t, http.StatusFound, rsp.StatusCode) - require.Equal(t, 1, len(rsp.Header["Location"])) - url, err := url.Parse(rsp.Header.Get("Location")) - require.NoError(t, err) - rsp, err = GetRedirectPage(t, httpsListener, url.Host, url.Path+"?"+url.RawQuery) - require.NoError(t, err) - - require.Equal(t, http.StatusFound, rsp.StatusCode) - require.Equal(t, 1, len(rsp.Header["Location"])) - - url, err = url.Parse(rsp.Header.Get("Location")) - require.NoError(t, err) - - require.Equal(t, "https", url.Scheme) - require.Equal(t, "gitlab-auth.com", url.Host) - require.Equal(t, "/oauth/authorize", url.Path) - require.Equal(t, "clientID", url.Query().Get("client_id")) - require.Equal(t, "https://projects.gitlab-example.com/auth", url.Query().Get("redirect_uri")) - require.NotEqual(t, "", url.Query().Get("state")) -} - -func TestWhenAuthDeniedWillCauseUnauthorized(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcessWithAuth(t, *pagesBinary, listeners, "") - defer teardown() - - rsp, err := GetPageFromListener(t, httpsListener, "projects.gitlab-example.com", "/auth?error=access_denied") - - require.NoError(t, err) - defer rsp.Body.Close() - - require.Equal(t, http.StatusUnauthorized, rsp.StatusCode) -} -func TestWhenLoginCallbackWithWrongStateShouldFail(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcessWithAuth(t, *pagesBinary, listeners, "") - defer teardown() - - rsp, err := GetRedirectPage(t, httpsListener, "group.auth.gitlab-example.com", 
"private.project/") - - require.NoError(t, err) - defer rsp.Body.Close() - - // Go to auth page with wrong state will cause failure - authrsp, err := GetPageFromListener(t, httpsListener, "projects.gitlab-example.com", "/auth?code=0&state=0") - - require.NoError(t, err) - defer authrsp.Body.Close() - - require.Equal(t, http.StatusUnauthorized, authrsp.StatusCode) -} - -func TestWhenLoginCallbackWithCorrectStateWithoutEndpoint(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcessWithAuth(t, *pagesBinary, listeners, "") - defer teardown() - - rsp, err := GetRedirectPage(t, httpsListener, "group.auth.gitlab-example.com", "private.project/") - - require.NoError(t, err) - defer rsp.Body.Close() - - cookie := rsp.Header.Get("Set-Cookie") - - url, err := url.Parse(rsp.Header.Get("Location")) - require.NoError(t, err) - - // Go to auth page with correct state will cause fetching the token - authrsp, err := GetPageFromListenerWithCookie(t, httpsListener, "projects.gitlab-example.com", "/auth?code=1&state="+ - url.Query().Get("state"), cookie) - - require.NoError(t, err) - defer authrsp.Body.Close() - - // Will cause 503 because token endpoint is not available - require.Equal(t, http.StatusServiceUnavailable, authrsp.StatusCode) -} - -// makeGitLabPagesAccessStub provides a stub *httptest.Server to check pages_access API call. -// the result is based on the project id. 
-// -// Project IDs must be 4 digit long and the following rules applies: -// 1000-1999: Ok -// 2000-2999: Unauthorized -// 3000-3999: Invalid token -func makeGitLabPagesAccessStub(t *testing.T) *httptest.Server { - return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/oauth/token": - require.Equal(t, "POST", r.Method) - w.WriteHeader(http.StatusOK) - fmt.Fprint(w, "{\"access_token\":\"abc\"}") - case "/api/v4/user": - require.Equal(t, "Bearer abc", r.Header.Get("Authorization")) - w.WriteHeader(http.StatusOK) - default: - if handleAccessControlArtifactRequests(t, w, r) { - return - } - handleAccessControlRequests(t, w, r) - } - })) -} - -var existingAcmeTokenPath = "/.well-known/acme-challenge/existingtoken" -var notexistingAcmeTokenPath = "/.well-known/acme-challenge/notexistingtoken" - -func TestAcmeChallengesWhenItIsConfigured(t *testing.T) { - skipUnlessEnabled(t) - - teardown := RunPagesProcess(t, *pagesBinary, listeners, "", "-gitlab-server=https://gitlab-acme.com") - defer teardown() - - t.Run("When domain folder contains requested acme challenge it responds with it", func(t *testing.T) { - rsp, err := GetRedirectPage(t, httpListener, "withacmechallenge.domain.com", - existingAcmeTokenPath) - - defer rsp.Body.Close() - require.NoError(t, err) - require.Equal(t, http.StatusOK, rsp.StatusCode) - body, _ := ioutil.ReadAll(rsp.Body) - require.Equal(t, "this is token\n", string(body)) - }) - - t.Run("When domain folder doesn't contains requested acme challenge it redirects to GitLab", - func(t *testing.T) { - rsp, err := GetRedirectPage(t, httpListener, "withacmechallenge.domain.com", - notexistingAcmeTokenPath) - - defer rsp.Body.Close() - require.NoError(t, err) - require.Equal(t, http.StatusTemporaryRedirect, rsp.StatusCode) - - url, err := url.Parse(rsp.Header.Get("Location")) - require.NoError(t, err) - - require.Equal(t, url.String(), 
"https://gitlab-acme.com/-/acme-challenge?domain=withacmechallenge.domain.com&token=notexistingtoken") - }, - ) -} - -func TestAcmeChallengesWhenItIsNotConfigured(t *testing.T) { - skipUnlessEnabled(t) - - teardown := RunPagesProcess(t, *pagesBinary, listeners, "", "") - defer teardown() - - t.Run("When domain folder contains requested acme challenge it responds with it", func(t *testing.T) { - rsp, err := GetRedirectPage(t, httpListener, "withacmechallenge.domain.com", - existingAcmeTokenPath) - - defer rsp.Body.Close() - require.NoError(t, err) - require.Equal(t, http.StatusOK, rsp.StatusCode) - body, _ := ioutil.ReadAll(rsp.Body) - require.Equal(t, "this is token\n", string(body)) - }) - - t.Run("When domain folder doesn't contains requested acme challenge it returns 404", - func(t *testing.T) { - rsp, err := GetRedirectPage(t, httpListener, "withacmechallenge.domain.com", - notexistingAcmeTokenPath) - - defer rsp.Body.Close() - require.NoError(t, err) - require.Equal(t, http.StatusNotFound, rsp.StatusCode) - }, - ) -} - -func handleAccessControlArtifactRequests(t *testing.T, w http.ResponseWriter, r *http.Request) bool { - authorization := r.Header.Get("Authorization") - - switch { - case regexp.MustCompile(`/api/v4/projects/group/private/jobs/\d+/artifacts/delayed_200.html`).MatchString(r.URL.Path): - sleepIfAuthorized(t, authorization, w) - return true - case regexp.MustCompile(`/api/v4/projects/group/private/jobs/\d+/artifacts/404.html`).MatchString(r.URL.Path): - w.WriteHeader(http.StatusNotFound) - return true - case regexp.MustCompile(`/api/v4/projects/group/private/jobs/\d+/artifacts/500.html`).MatchString(r.URL.Path): - returnIfAuthorized(t, authorization, w, http.StatusInternalServerError) - return true - case regexp.MustCompile(`/api/v4/projects/group/private/jobs/\d+/artifacts/200.html`).MatchString(r.URL.Path): - returnIfAuthorized(t, authorization, w, http.StatusOK) - return true - case 
regexp.MustCompile(`/api/v4/projects/group/subgroup/private/jobs/\d+/artifacts/200.html`).MatchString(r.URL.Path): - returnIfAuthorized(t, authorization, w, http.StatusOK) - return true - default: - return false - } -} - -func handleAccessControlRequests(t *testing.T, w http.ResponseWriter, r *http.Request) { - allowedProjects := regexp.MustCompile(`/api/v4/projects/1\d{3}/pages_access`) - deniedProjects := regexp.MustCompile(`/api/v4/projects/2\d{3}/pages_access`) - invalidTokenProjects := regexp.MustCompile(`/api/v4/projects/3\d{3}/pages_access`) - - switch { - case allowedProjects.MatchString(r.URL.Path): - require.Equal(t, "Bearer abc", r.Header.Get("Authorization")) - w.WriteHeader(http.StatusOK) - case deniedProjects.MatchString(r.URL.Path): - require.Equal(t, "Bearer abc", r.Header.Get("Authorization")) - w.WriteHeader(http.StatusUnauthorized) - case invalidTokenProjects.MatchString(r.URL.Path): - require.Equal(t, "Bearer abc", r.Header.Get("Authorization")) - w.WriteHeader(http.StatusUnauthorized) - fmt.Fprint(w, "{\"error\":\"invalid_token\"}") - default: - t.Logf("Unexpected r.URL.RawPath: %q", r.URL.Path) - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.WriteHeader(http.StatusNotFound) - } -} - -func returnIfAuthorized(t *testing.T, authorization string, w http.ResponseWriter, status int) { - if authorization != "" { - require.Equal(t, "Bearer abc", authorization) - w.WriteHeader(status) - } else { - w.WriteHeader(http.StatusNotFound) - } -} - -func sleepIfAuthorized(t *testing.T, authorization string, w http.ResponseWriter) { - if authorization != "" { - require.Equal(t, "Bearer abc", authorization) - time.Sleep(2 * time.Second) - } else { - w.WriteHeader(http.StatusNotFound) - } -} - -func TestAccessControlUnderCustomDomain(t *testing.T) { - skipUnlessEnabled(t, "not-inplace-chroot") - - testServer := makeGitLabPagesAccessStub(t) - testServer.Start() - defer testServer.Close() - - teardown := RunPagesProcessWithAuthServer(t, 
*pagesBinary, listeners, "", testServer.URL) - defer teardown() - - rsp, err := GetRedirectPage(t, httpListener, "private.domain.com", "/") - require.NoError(t, err) - defer rsp.Body.Close() - - cookie := rsp.Header.Get("Set-Cookie") - - url, err := url.Parse(rsp.Header.Get("Location")) - require.NoError(t, err) - - state := url.Query().Get("state") - require.Equal(t, url.Query().Get("domain"), "http://private.domain.com") - - pagesrsp, err := GetRedirectPage(t, httpListener, url.Host, url.Path+"?"+url.RawQuery) - require.NoError(t, err) - defer pagesrsp.Body.Close() - - pagescookie := pagesrsp.Header.Get("Set-Cookie") - - // Go to auth page with correct state will cause fetching the token - authrsp, err := GetRedirectPageWithCookie(t, httpListener, "projects.gitlab-example.com", "/auth?code=1&state="+ - state, pagescookie) - - require.NoError(t, err) - defer authrsp.Body.Close() - - url, err = url.Parse(authrsp.Header.Get("Location")) - require.NoError(t, err) - - // Will redirect to custom domain - require.Equal(t, "private.domain.com", url.Host) - require.Equal(t, "1", url.Query().Get("code")) - require.Equal(t, state, url.Query().Get("state")) - - // Run auth callback in custom domain - authrsp, err = GetRedirectPageWithCookie(t, httpListener, "private.domain.com", "/auth?code=1&state="+ - state, cookie) - - require.NoError(t, err) - defer authrsp.Body.Close() - - // Will redirect to the page - cookie = authrsp.Header.Get("Set-Cookie") - require.Equal(t, http.StatusFound, authrsp.StatusCode) - - url, err = url.Parse(authrsp.Header.Get("Location")) - require.NoError(t, err) - - // Will redirect to custom domain - require.Equal(t, "http://private.domain.com/", url.String()) - - // Fetch page in custom domain - authrsp, err = GetRedirectPageWithCookie(t, httpListener, "private.domain.com", "/", cookie) - require.NoError(t, err) - require.Equal(t, http.StatusOK, authrsp.StatusCode) -} - -func TestCustomErrorPageWithAuth(t *testing.T) { - skipUnlessEnabled(t, 
"not-inplace-chroot") - testServer := makeGitLabPagesAccessStub(t) - testServer.Start() - defer testServer.Close() - - teardown := RunPagesProcessWithAuthServer(t, *pagesBinary, listeners, "", testServer.URL) - defer teardown() - - tests := []struct { - name string - domain string - path string - expectedErrorPage string - }{ - { - name: "private_project_authorized", - domain: "group.404.gitlab-example.com", - path: "/private_project/unknown", - expectedErrorPage: "Private custom 404 error page", - }, - { - name: "public_namespace_with_private_unauthorized_project", - domain: "group.404.gitlab-example.com", - // /private_unauthorized/config.json resolves project ID to 2000 which will cause a 401 from the mock GitLab testServer - path: "/private_unauthorized/unknown", - expectedErrorPage: "Custom 404 group page", - }, - { - name: "private_namespace_authorized", - domain: "group.auth.gitlab-example.com", - path: "/unknown", - expectedErrorPage: "group.auth.gitlab-example.com namespace custom 404", - }, - { - name: "private_namespace_with_private_project_auth_failed", - domain: "group.auth.gitlab-example.com", - // project ID is 2000 - path: "/private.project.1/unknown", - expectedErrorPage: "The page you're looking for could not be found.", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - rsp, err := GetRedirectPage(t, httpListener, tt.domain, tt.path) - require.NoError(t, err) - defer rsp.Body.Close() - - cookie := rsp.Header.Get("Set-Cookie") - - url, err := url.Parse(rsp.Header.Get("Location")) - require.NoError(t, err) - - state := url.Query().Get("state") - require.Equal(t, "http://"+tt.domain, url.Query().Get("domain")) - - pagesrsp, err := GetRedirectPage(t, httpListener, url.Host, url.Path+"?"+url.RawQuery) - require.NoError(t, err) - defer pagesrsp.Body.Close() - - pagescookie := pagesrsp.Header.Get("Set-Cookie") - - // Go to auth page with correct state will cause fetching the token - authrsp, err := GetRedirectPageWithCookie(t, 
httpListener, "projects.gitlab-example.com", "/auth?code=1&state="+ - state, pagescookie) - - require.NoError(t, err) - defer authrsp.Body.Close() - - url, err = url.Parse(authrsp.Header.Get("Location")) - require.NoError(t, err) - - // Will redirect to custom domain - require.Equal(t, tt.domain, url.Host) - require.Equal(t, "1", url.Query().Get("code")) - require.Equal(t, state, url.Query().Get("state")) - - // Run auth callback in custom domain - authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code=1&state="+ - state, cookie) - - require.NoError(t, err) - defer authrsp.Body.Close() - - // Will redirect to the page - groupCookie := authrsp.Header.Get("Set-Cookie") - require.Equal(t, http.StatusFound, authrsp.StatusCode) - - url, err = url.Parse(authrsp.Header.Get("Location")) - require.NoError(t, err) - - // Will redirect to custom domain error page - require.Equal(t, "http://"+tt.domain+tt.path, url.String()) - - // Fetch page in custom domain - anotherResp, err := GetRedirectPageWithCookie(t, httpListener, tt.domain, tt.path, groupCookie) - require.NoError(t, err) - - require.Equal(t, http.StatusNotFound, anotherResp.StatusCode) - - page, err := ioutil.ReadAll(anotherResp.Body) - require.NoError(t, err) - require.Contains(t, string(page), tt.expectedErrorPage) - }) - } -} - -func TestAccessControlUnderCustomDomainWithHTTPSProxy(t *testing.T) { - skipUnlessEnabled(t, "not-inplace-chroot") - - testServer := makeGitLabPagesAccessStub(t) - testServer.Start() - defer testServer.Close() - - teardown := RunPagesProcessWithAuthServer(t, *pagesBinary, listeners, "", testServer.URL) - defer teardown() - - rsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, "private.domain.com", "/", "", true) - require.NoError(t, err) - defer rsp.Body.Close() - - cookie := rsp.Header.Get("Set-Cookie") - - url, err := url.Parse(rsp.Header.Get("Location")) - require.NoError(t, err) - - state := url.Query().Get("state") - require.Equal(t, 
url.Query().Get("domain"), "https://private.domain.com") - pagesrsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, url.Host, url.Path+"?"+url.RawQuery, "", true) - require.NoError(t, err) - defer pagesrsp.Body.Close() - - pagescookie := pagesrsp.Header.Get("Set-Cookie") - - // Go to auth page with correct state will cause fetching the token - authrsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, - "projects.gitlab-example.com", "/auth?code=1&state="+state, - pagescookie, true) - - require.NoError(t, err) - defer authrsp.Body.Close() - - url, err = url.Parse(authrsp.Header.Get("Location")) - require.NoError(t, err) - - // Will redirect to custom domain - require.Equal(t, "private.domain.com", url.Host) - require.Equal(t, "1", url.Query().Get("code")) - require.Equal(t, state, url.Query().Get("state")) - - // Run auth callback in custom domain - authrsp, err = GetProxyRedirectPageWithCookie(t, proxyListener, "private.domain.com", - "/auth?code=1&state="+state, cookie, true) - - require.NoError(t, err) - defer authrsp.Body.Close() - - // Will redirect to the page - cookie = authrsp.Header.Get("Set-Cookie") - require.Equal(t, http.StatusFound, authrsp.StatusCode) - - url, err = url.Parse(authrsp.Header.Get("Location")) - require.NoError(t, err) - - // Will redirect to custom domain - require.Equal(t, "https://private.domain.com/", url.String()) - // Fetch page in custom domain - authrsp, err = GetProxyRedirectPageWithCookie(t, proxyListener, "private.domain.com", "/", - cookie, true) - require.NoError(t, err) - require.Equal(t, http.StatusOK, authrsp.StatusCode) -} - -func TestAccessControlGroupDomain404RedirectsAuth(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcessWithAuth(t, *pagesBinary, listeners, "") - defer teardown() - - rsp, err := GetRedirectPage(t, httpListener, "group.gitlab-example.com", "/nonexistent/") - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusFound, rsp.StatusCode) - // 
Redirects to the projects under gitlab pages domain for authentication flow - url, err := url.Parse(rsp.Header.Get("Location")) - require.NoError(t, err) - require.Equal(t, "projects.gitlab-example.com", url.Host) - require.Equal(t, "/auth", url.Path) -} -func TestAccessControlProject404DoesNotRedirect(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcessWithAuth(t, *pagesBinary, listeners, "") - defer teardown() - - rsp, err := GetRedirectPage(t, httpListener, "group.gitlab-example.com", "/project/nonexistent/") - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusNotFound, rsp.StatusCode) -} - -func setupTransport(t *testing.T) { - transport := (TestHTTPSClient.Transport).(*http.Transport) - defer func(t time.Duration) { - transport.ResponseHeaderTimeout = t - }(transport.ResponseHeaderTimeout) - transport.ResponseHeaderTimeout = 5 * time.Second -} - -type runPagesFunc func(t *testing.T, pagesPath string, listeners []ListenSpec, promPort string, sslCertFile string, authServer string) func() - -func testAccessControl(t *testing.T, runPages runPagesFunc) { - skipUnlessEnabled(t, "not-inplace-chroot") - - setupTransport(t) - - keyFile, certFile := CreateHTTPSFixtureFiles(t) - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - require.NoError(t, err) - defer os.Remove(keyFile) - defer os.Remove(certFile) - - testServer := makeGitLabPagesAccessStub(t) - testServer.TLS = &tls.Config{Certificates: []tls.Certificate{cert}} - testServer.StartTLS() - defer testServer.Close() - - tests := []struct { - host string - path string - status int - redirectBack bool - name string - }{ - { - name: "project with access", - host: "group.auth.gitlab-example.com", - path: "/private.project/", - status: http.StatusOK, - redirectBack: false, - }, - { - name: "project without access", - host: "group.auth.gitlab-example.com", - path: "/private.project.1/", - status: http.StatusNotFound, // Do not expose project existed - redirectBack: false, - 
}, - { - name: "invalid token test should redirect back", - host: "group.auth.gitlab-example.com", - path: "/private.project.2/", - status: http.StatusFound, - redirectBack: true, - }, - { - name: "no project should redirect to login and then return 404", - host: "group.auth.gitlab-example.com", - path: "/nonexistent/", - status: http.StatusNotFound, - redirectBack: false, - }, - { - name: "no project should redirect to login and then return 404", - host: "nonexistent.gitlab-example.com", - path: "/nonexistent/", - status: http.StatusNotFound, - redirectBack: false, - }, // subgroups - { - name: "[subgroup] project with access", - host: "group.auth.gitlab-example.com", - path: "/subgroup/private.project/", - status: http.StatusOK, - redirectBack: false, - }, - { - name: "[subgroup] project without access", - host: "group.auth.gitlab-example.com", - path: "/subgroup/private.project.1/", - status: http.StatusNotFound, // Do not expose project existed - redirectBack: false, - }, - { - name: "[subgroup] invalid token test should redirect back", - host: "group.auth.gitlab-example.com", - path: "/subgroup/private.project.2/", - status: http.StatusFound, - redirectBack: true, - }, - { - name: "[subgroup] no project should redirect to login and then return 404", - host: "group.auth.gitlab-example.com", - path: "/subgroup/nonexistent/", - status: http.StatusNotFound, - redirectBack: false, - }, - { - name: "[subgroup] no project should redirect to login and then return 404", - host: "nonexistent.gitlab-example.com", - path: "/subgroup/nonexistent/", - status: http.StatusNotFound, - redirectBack: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - teardown := runPages(t, *pagesBinary, listeners, "", certFile, testServer.URL) - defer teardown() - - rsp, err := GetRedirectPage(t, httpsListener, tt.host, tt.path) - - require.NoError(t, err) - defer rsp.Body.Close() - - require.Equal(t, http.StatusFound, rsp.StatusCode) - cookie := 
rsp.Header.Get("Set-Cookie") - - // Redirects to the projects under gitlab pages domain for authentication flow - url, err := url.Parse(rsp.Header.Get("Location")) - require.NoError(t, err) - require.Equal(t, "projects.gitlab-example.com", url.Host) - require.Equal(t, "/auth", url.Path) - state := url.Query().Get("state") - - rsp, err = GetRedirectPage(t, httpsListener, url.Host, url.Path+"?"+url.RawQuery) - - require.NoError(t, err) - defer rsp.Body.Close() - - require.Equal(t, http.StatusFound, rsp.StatusCode) - pagesDomainCookie := rsp.Header.Get("Set-Cookie") - - // Go to auth page with correct state will cause fetching the token - authrsp, err := GetRedirectPageWithCookie(t, httpsListener, "projects.gitlab-example.com", "/auth?code=1&state="+ - state, pagesDomainCookie) - - require.NoError(t, err) - defer authrsp.Body.Close() - - // Will redirect auth callback to correct host - url, err = url.Parse(authrsp.Header.Get("Location")) - require.NoError(t, err) - require.Equal(t, tt.host, url.Host) - require.Equal(t, "/auth", url.Path) - - // Request auth callback in project domain - authrsp, err = GetRedirectPageWithCookie(t, httpsListener, url.Host, url.Path+"?"+url.RawQuery, cookie) - require.NoError(t, err) - - // server returns the ticket, user will be redirected to the project page - require.Equal(t, http.StatusFound, authrsp.StatusCode) - cookie = authrsp.Header.Get("Set-Cookie") - rsp, err = GetRedirectPageWithCookie(t, httpsListener, tt.host, tt.path, cookie) - - require.NoError(t, err) - defer rsp.Body.Close() - - require.Equal(t, tt.status, rsp.StatusCode) - require.Equal(t, "", rsp.Header.Get("Cache-Control")) - - if tt.redirectBack { - url, err = url.Parse(rsp.Header.Get("Location")) - require.NoError(t, err) - - require.Equal(t, "https", url.Scheme) - require.Equal(t, tt.host, url.Host) - require.Equal(t, tt.path, url.Path) - } - }) - } -} - -func TestAccessControlWithSSLCertFile(t *testing.T) { - testAccessControl(t, 
RunPagesProcessWithAuthServerWithSSLCertFile) -} - -func TestAccessControlWithSSLCertDir(t *testing.T) { - testAccessControl(t, RunPagesProcessWithAuthServerWithSSLCertDir) -} - -func TestAcceptsSupportedCiphers(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - tlsConfig := &tls.Config{ - CipherSuites: []uint16{ - tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - }, - } - client, cleanup := ClientWithConfig(tlsConfig) - defer cleanup() - - rsp, err := client.Get(httpsListener.URL("/")) - - if rsp != nil { - rsp.Body.Close() - } - - require.NoError(t, err) -} - -func tlsConfigWithInsecureCiphersOnly() *tls.Config { - return &tls.Config{ - CipherSuites: []uint16{ - tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, - tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, - }, - MaxVersion: tls.VersionTLS12, // ciphers for TLS1.3 are not configurable and will work if enabled - } -} - -func TestRejectsUnsupportedCiphers(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "") - defer teardown() - - client, cleanup := ClientWithConfig(tlsConfigWithInsecureCiphersOnly()) - defer cleanup() - - rsp, err := client.Get(httpsListener.URL("/")) - - if rsp != nil { - rsp.Body.Close() - } - - require.Error(t, err) - require.Nil(t, rsp) -} - -func TestEnableInsecureCiphers(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcess(t, *pagesBinary, listeners, "", "-insecure-ciphers") - defer teardown() - - client, cleanup := ClientWithConfig(tlsConfigWithInsecureCiphersOnly()) - defer cleanup() - - rsp, err := client.Get(httpsListener.URL("/")) - - if rsp != nil { - rsp.Body.Close() - } - - require.NoError(t, err) -} - -func TestTLSVersions(t *testing.T) { - 
skipUnlessEnabled(t) - - tests := map[string]struct { - tlsMin string - tlsMax string - tlsClient uint16 - expectError bool - }{ - "client version not supported": {tlsMin: "tls1.1", tlsMax: "tls1.2", tlsClient: tls.VersionTLS10, expectError: true}, - "client version supported": {tlsMin: "tls1.1", tlsMax: "tls1.2", tlsClient: tls.VersionTLS12, expectError: false}, - "client and server using default settings": {tlsMin: "", tlsMax: "", tlsClient: 0, expectError: false}, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - args := []string{} - if tc.tlsMin != "" { - args = append(args, "-tls-min-version", tc.tlsMin) - } - if tc.tlsMax != "" { - args = append(args, "-tls-max-version", tc.tlsMax) - } - - teardown := RunPagesProcess(t, *pagesBinary, listeners, "", args...) - defer teardown() - - tlsConfig := &tls.Config{} - if tc.tlsClient != 0 { - tlsConfig.MinVersion = tc.tlsClient - tlsConfig.MaxVersion = tc.tlsClient - } - client, cleanup := ClientWithConfig(tlsConfig) - defer cleanup() - - rsp, err := client.Get(httpsListener.URL("/")) - - if rsp != nil { - rsp.Body.Close() - } - - if tc.expectError { - require.Error(t, err) - } else { - require.NoError(t, err) - } - }) - } -} - -func TestDomainsSource(t *testing.T) { - skipUnlessEnabled(t) - - type args struct { - configSource string - domain string - urlSuffix string - } - type want struct { - statusCode int - content string - apiCalled bool - } - tests := []struct { - name string - args args - want want - }{ - { - name: "gitlab_source_domain_exists", - args: args{ - configSource: "gitlab", - domain: "new-source-test.gitlab.io", - urlSuffix: "/my/pages/project/", - }, - want: want{ - statusCode: http.StatusOK, - content: "New Pages GitLab Source TEST OK\n", - apiCalled: true, - }, - }, - { - name: "gitlab_source_domain_does_not_exist", - args: args{ - configSource: "gitlab", - domain: "non-existent-domain.gitlab.io", - }, - want: want{ - statusCode: http.StatusNotFound, - apiCalled: true, - }, 
- }, - { - name: "disk_source_domain_exists", - args: args{ - configSource: "disk", - // test.domain.com sourced from disk configuration - domain: "test.domain.com", - urlSuffix: "/", - }, - want: want{ - statusCode: http.StatusOK, - content: "main-dir\n", - apiCalled: false, - }, - }, - { - name: "disk_source_domain_does_not_exist", - args: args{ - configSource: "disk", - domain: "non-existent-domain.gitlab.io", - }, - want: want{ - statusCode: http.StatusNotFound, - apiCalled: false, - }, - }, - { - name: "disk_source_domain_should_not_exist_under_hashed_dir", - args: args{ - configSource: "disk", - domain: "hashed.com", - }, - want: want{ - statusCode: http.StatusNotFound, - apiCalled: false, - }, - }, - // TODO: modify mock so we can test domain-config-source=auto when API/disk is not ready https://gitlab.com/gitlab-org/gitlab/-/issues/218358 - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var apiCalled bool - source := NewGitlabDomainsSourceStub(t, &apiCalled) - defer source.Close() - - gitLabAPISecretKey := CreateGitLabAPISecretKeyFixtureFile(t) - - pagesArgs := []string{"-gitlab-server", source.URL, "-api-secret-key", gitLabAPISecretKey, "-domain-config-source", tt.args.configSource} - teardown := RunPagesProcessWithEnvs(t, true, *pagesBinary, listeners, "", []string{}, pagesArgs...) 
- defer teardown() - - response, err := GetPageFromListener(t, httpListener, tt.args.domain, tt.args.urlSuffix) - require.NoError(t, err) - - require.Equal(t, tt.want.statusCode, response.StatusCode) - if tt.want.statusCode == http.StatusOK { - defer response.Body.Close() - body, err := ioutil.ReadAll(response.Body) - require.NoError(t, err) - - require.Equal(t, tt.want.content, string(body), "content mismatch") - } - - require.Equal(t, tt.want.apiCalled, apiCalled, "api called mismatch") - }) - } -} - -func TestZipServing(t *testing.T) { - skipUnlessEnabled(t) - - var apiCalled bool - source := NewGitlabDomainsSourceStub(t, &apiCalled) - defer source.Close() - - gitLabAPISecretKey := CreateGitLabAPISecretKeyFixtureFile(t) - - pagesArgs := []string{"-gitlab-server", source.URL, "-api-secret-key", gitLabAPISecretKey, "-domain-config-source", "gitlab"} - teardown := RunPagesProcessWithEnvs(t, true, *pagesBinary, listeners, "", []string{}, pagesArgs...) - defer teardown() - - _, cleanup := newZipFileServerURL(t, "shared/pages/group/zip.gitlab.io/public.zip") - defer cleanup() - - tests := map[string]struct { - urlSuffix string - expectedStatusCode int - expectedContent string - }{ - "base_domain_no_suffix": { - urlSuffix: "/", - expectedStatusCode: http.StatusOK, - expectedContent: "zip.gitlab.io/project/index.html\n", - }, - "file_exists": { - urlSuffix: "/index.html", - expectedStatusCode: http.StatusOK, - expectedContent: "zip.gitlab.io/project/index.html\n", - }, - "file_exists_in_subdir": { - urlSuffix: "/subdir/hello.html", - expectedStatusCode: http.StatusOK, - expectedContent: "zip.gitlab.io/project/subdir/hello.html\n", - }, - "file_exists_symlink": { - urlSuffix: "/symlink.html", - expectedStatusCode: http.StatusOK, - expectedContent: "symlink.html->subdir/linked.html\n", - }, - "dir": { - urlSuffix: "/subdir/", - expectedStatusCode: http.StatusNotFound, - expectedContent: "zip.gitlab.io/project/404.html\n", - }, - "file_does_not_exist": { - urlSuffix: 
"/unknown.html", - expectedStatusCode: http.StatusNotFound, - expectedContent: "zip.gitlab.io/project/404.html\n", - }, - "bad_symlink": { - urlSuffix: "/bad-symlink.html", - expectedStatusCode: http.StatusNotFound, - expectedContent: "zip.gitlab.io/project/404.html\n", - }, - } - - for name, tt := range tests { - t.Run(name, func(t *testing.T) { - response, err := GetPageFromListener(t, httpListener, "zip.gitlab.io", tt.urlSuffix) - require.NoError(t, err) - defer response.Body.Close() - - require.Equal(t, tt.expectedStatusCode, response.StatusCode) - if tt.expectedStatusCode == http.StatusOK || tt.expectedStatusCode == http.StatusNotFound { - body, err := ioutil.ReadAll(response.Body) - require.NoError(t, err) - - require.Equal(t, tt.expectedContent, string(body), "content mismatch") - } - }) - } -} diff --git a/app.go b/app.go index 5a1953962..1352b630b 100644 --- a/app.go +++ b/app.go @@ -12,22 +12,27 @@ import ( ghandlers "github.com/gorilla/handlers" "github.com/rs/cors" log "github.com/sirupsen/logrus" + + "gitlab.com/gitlab-org/go-mimedb" "gitlab.com/gitlab-org/labkit/errortracking" labmetrics "gitlab.com/gitlab-org/labkit/metrics" "gitlab.com/gitlab-org/labkit/monitoring" - "gitlab.com/lupine/go-mimedb" "gitlab.com/gitlab-org/gitlab-pages/internal/acme" "gitlab.com/gitlab-org/gitlab-pages/internal/artifact" "gitlab.com/gitlab-org/gitlab-pages/internal/auth" - headerConfig "gitlab.com/gitlab-org/gitlab-pages/internal/config" + cfg "gitlab.com/gitlab-org/gitlab-pages/internal/config" "gitlab.com/gitlab-org/gitlab-pages/internal/domain" "gitlab.com/gitlab-org/gitlab-pages/internal/handlers" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/logging" + "gitlab.com/gitlab-org/gitlab-pages/internal/middleware" "gitlab.com/gitlab-org/gitlab-pages/internal/netutil" + "gitlab.com/gitlab-org/gitlab-pages/internal/rejectmethods" "gitlab.com/gitlab-org/gitlab-pages/internal/request" + 
"gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk/zip" "gitlab.com/gitlab-org/gitlab-pages/internal/source" + "gitlab.com/gitlab-org/gitlab-pages/internal/tlsconfig" "gitlab.com/gitlab-org/gitlab-pages/metrics" ) @@ -184,7 +189,7 @@ func (a *theApp) healthCheckMiddleware(handler http.Handler) (http.Handler, erro // customHeadersMiddleware will inject custom headers into the response func (a *theApp) customHeadersMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - headerConfig.AddCustomHeaders(w, a.CustomHeaders) + middleware.AddCustomHeaders(w, a.CustomHeaders) handler.ServeHTTP(w, r) }) @@ -333,6 +338,12 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { // Custom response headers handler = a.customHeadersMiddleware(handler) + // This MUST be the last handler! + // This handler blocks unknown HTTP methods, + // being the last means it will be evaluated first + // preventing any operation on bogus requests. + handler = rejectmethods.NewMiddleware(handler) + return handler, nil } @@ -367,6 +378,11 @@ func (a *theApp) Run() { a.listenProxyFD(&wg, fd, proxyHandler, limiter) } + // Listen for HTTPS PROXYv2 requests + for _, fd := range a.ListenHTTPSProxyv2 { + a.ListenHTTPSProxyv2FD(&wg, fd, httpHandler, limiter) + } + // Serve metrics for Prometheus if a.ListenMetrics != 0 { a.listenMetricsFD(&wg, a.ListenMetrics) @@ -381,7 +397,7 @@ func (a *theApp) listenHTTPFD(wg *sync.WaitGroup, fd uintptr, httpHandler http.H wg.Add(1) go func() { defer wg.Done() - err := listenAndServe(fd, httpHandler, a.HTTP2, nil, limiter) + err := listenAndServe(fd, httpHandler, a.HTTP2, nil, limiter, false) if err != nil { capturingFatal(err, errortracking.WithField("listener", request.SchemeHTTP)) } @@ -392,7 +408,12 @@ func (a *theApp) listenHTTPSFD(wg *sync.WaitGroup, fd uintptr, httpHandler http. 
wg.Add(1) go func() { defer wg.Done() - err := listenAndServeTLS(fd, a.RootCertificate, a.RootKey, httpHandler, a.ServeTLS, a.InsecureCiphers, a.TLSMinVersion, a.TLSMaxVersion, a.HTTP2, limiter) + tlsConfig, err := a.TLSConfig() + if err != nil { + capturingFatal(err, errortracking.WithField("listener", request.SchemeHTTPS)) + } + + err = listenAndServe(fd, httpHandler, a.HTTP2, tlsConfig, limiter, false) if err != nil { capturingFatal(err, errortracking.WithField("listener", request.SchemeHTTPS)) } @@ -405,7 +426,7 @@ func (a *theApp) listenProxyFD(wg *sync.WaitGroup, fd uintptr, proxyHandler http wg.Add(1) go func(fd uintptr) { defer wg.Done() - err := listenAndServe(fd, proxyHandler, a.HTTP2, nil, limiter) + err := listenAndServe(fd, proxyHandler, a.HTTP2, nil, limiter, false) if err != nil { capturingFatal(err, errortracking.WithField("listener", "http proxy")) } @@ -413,6 +434,23 @@ func (a *theApp) listenProxyFD(wg *sync.WaitGroup, fd uintptr, proxyHandler http }() } +// https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt +func (a *theApp) ListenHTTPSProxyv2FD(wg *sync.WaitGroup, fd uintptr, httpHandler http.Handler, limiter *netutil.Limiter) { + wg.Add(1) + go func() { + defer wg.Done() + tlsConfig, err := a.TLSConfig() + if err != nil { + capturingFatal(err, errortracking.WithField("listener", request.SchemeHTTPS)) + } + + err = listenAndServe(fd, httpHandler, a.HTTP2, tlsConfig, limiter, true) + if err != nil { + capturingFatal(err, errortracking.WithField("listener", request.SchemeHTTPS)) + } + }() +} + func (a *theApp) listenMetricsFD(wg *sync.WaitGroup, fd uintptr) { wg.Add(1) go func() { @@ -452,10 +490,7 @@ func runApp(config appConfig) { a.Artifact = artifact.New(config.ArtifactsServer, config.ArtifactsServerTimeout, config.Domain) } - if config.ClientID != "" { - a.Auth = auth.New(config.Domain, config.StoreSecret, config.ClientID, config.ClientSecret, - config.RedirectURI, config.GitLabServer) - } + a.setAuth(config) a.Handlers = 
handlers.New(a.Auth, a.Artifact) @@ -464,7 +499,7 @@ func runApp(config appConfig) { } if len(config.CustomHeaders) != 0 { - customHeaders, err := headerConfig.ParseHeaderString(config.CustomHeaders) + customHeaders, err := middleware.ParseHeaderString(config.CustomHeaders) if err != nil { log.WithError(err).Fatal("Unable to parse header string") } @@ -475,10 +510,43 @@ func runApp(config appConfig) { log.WithError(err).Warn("Loading extended MIME database failed") } + c := &cfg.Config{ + Zip: &cfg.ZipServing{ + ExpirationInterval: config.ZipCacheExpiry, + CleanupInterval: config.ZipCacheCleanup, + RefreshInterval: config.ZipCacheRefresh, + OpenTimeout: config.ZipeOpenTimeout, + }, + } + + // TODO: reconfigure all VFS' + // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/512 + if err := zip.Instance().Reconfigure(c); err != nil { + fatal(err, "failed to reconfigure zip VFS") + } + a.Run() } +func (a *theApp) setAuth(config appConfig) { + if config.ClientID == "" { + return + } + + var err error + a.Auth, err = auth.New(config.Domain, config.StoreSecret, config.ClientID, config.ClientSecret, + config.RedirectURI, config.GitLabServer) + if err != nil { + log.WithError(err).Fatal("could not initialize auth package") + } +} + // fatal will log a fatal error and exit. 
func fatal(err error, message string) { log.WithError(err).Fatal(message) } + +func (a *theApp) TLSConfig() (*tls.Config, error) { + return tlsconfig.Create(a.RootCertificate, a.RootKey, a.ServeTLS, + a.InsecureCiphers, a.TLSMinVersion, a.TLSMaxVersion) +} diff --git a/app_config.go b/app_config.go index 3bc2197b1..0dd192d5d 100644 --- a/app_config.go +++ b/app_config.go @@ -10,13 +10,14 @@ type appConfig struct { RootKey []byte MaxConns int - ListenHTTP []uintptr - ListenHTTPS []uintptr - ListenProxy []uintptr - ListenMetrics uintptr - InsecureCiphers bool - TLSMinVersion uint16 - TLSMaxVersion uint16 + ListenHTTP []uintptr + ListenHTTPS []uintptr + ListenProxy []uintptr + ListenHTTPSProxyv2 []uintptr + ListenMetrics uintptr + InsecureCiphers bool + TLSMinVersion uint16 + TLSMaxVersion uint16 HTTP2 bool RedirectHTTP bool @@ -40,6 +41,11 @@ type appConfig struct { SentryDSN string SentryEnvironment string CustomHeaders []string + + ZipCacheExpiry time.Duration + ZipCacheRefresh time.Duration + ZipCacheCleanup time.Duration + ZipeOpenTimeout time.Duration } // InternalGitLabServerURL returns URL to a GitLab instance. 
diff --git a/daemon.go b/daemon.go index 11fa3e9e0..c2404e05b 100644 --- a/daemon.go +++ b/daemon.go @@ -330,6 +330,7 @@ func updateFds(config *appConfig, cmd *exec.Cmd) { config.ListenHTTP, config.ListenHTTPS, config.ListenProxy, + config.ListenHTTPSProxyv2, } { daemonUpdateFds(cmd, fds) } diff --git a/go.mod b/go.mod index 9e021ca5f..76d45a9c9 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/gorilla/handlers v1.4.2 github.com/gorilla/securecookie v1.1.1 github.com/gorilla/sessions v1.2.0 + github.com/jstemmer/go-junit-report v0.9.1 github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 github.com/karlseguin/ccache/v2 v2.0.6 github.com/karrick/godirwalk v1.10.12 @@ -19,17 +20,18 @@ require ( github.com/namsral/flag v1.7.4-pre github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/patrickmn/go-cache v2.1.0+incompatible + github.com/pires/go-proxyproto v0.2.0 github.com/prometheus/client_golang v1.6.0 github.com/rs/cors v1.7.0 - github.com/sirupsen/logrus v1.4.2 + github.com/sirupsen/logrus v1.7.0 github.com/stretchr/objx v0.2.0 // indirect github.com/stretchr/testify v1.6.1 github.com/tj/assert v0.0.3 // indirect github.com/tj/go-redirects v0.0.0-20180508180010-5c02ead0bbc5 github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce github.com/wadey/gocovmerge v0.0.0-20160331181800-b5bfa59ec0ad - gitlab.com/gitlab-org/labkit v0.0.0-20201014124351-eb1fe6499318 - gitlab.com/lupine/go-mimedb v0.0.0-20180307000149-e8af1d659877 + gitlab.com/gitlab-org/go-mimedb v1.45.0 + gitlab.com/gitlab-org/labkit v1.0.0 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f golang.org/x/net v0.0.0-20200226121028-0de0cce0169b diff --git a/go.sum b/go.sum index e9cf1de54..945b05ce7 100644 --- a/go.sum +++ b/go.sum @@ -76,10 +76,9 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/fzipp/gocyclo v0.0.0-20150627053110-6acd4345c835 
h1:roDmqJ4Qes7hrDOsWsMCce0vQHz3xiMPjJ9m4c2eeNs= github.com/fzipp/gocyclo v0.0.0-20150627053110-6acd4345c835/go.mod h1:BjL/N0+C+j9uNX+1xcNuM9vdSIcXCZrQZUYbXOFbgN8= github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv1M62hOWzwo5OXotisrKc= -github.com/getsentry/raven-go v0.1.0 h1:lc5jnN9D+q3panDpihwShgaOVvP6esoMEKbID2yhLoQ= -github.com/getsentry/raven-go v0.1.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= -github.com/getsentry/sentry-go v0.5.1 h1:MIPe7ScHADsrK2vznqmhksIUFxq7m0JfTh+ZIMkI+VQ= -github.com/getsentry/sentry-go v0.5.1/go.mod h1:B8H7x8TYDPkeWPRzGpIiFO97LZP6rL8A3hEt8lUItMw= +github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= +github.com/getsentry/sentry-go v0.7.0 h1:MR2yfR4vFfv/2+iBuSnkdQwVg7N9cJzihZ6KJu7srwQ= +github.com/getsentry/sentry-go v0.7.0/go.mod h1:pLFpD2Y5RHIKF9Bw3KH6/68DeN2K/XBJd8awjdPnUwg= github.com/gin-contrib/sse v0.0.0-20190301062529-5545eab6dad3/go.mod h1:VJ0WA2NBN22VlZ2dKZQPAPnyWw5XTlK1KymzLKsr59s= github.com/gin-gonic/gin v1.4.0/go.mod h1:OW2EZn3DO8Ln9oIKOvM++LBO+5UPHJJDH72/q/3rZdM= github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= @@ -192,8 +191,6 @@ github.com/klauspost/compress v1.8.2/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0 github.com/klauspost/compress v1.9.0/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= -github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0 
h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -234,17 +231,22 @@ github.com/nats-io/nkeys v0.0.2/go.mod h1:dab7URMsZm6Z/jp9Z5UGa87Uutgc2mVpXLC4B7 github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/oklog/ulid/v2 v2.0.2 h1:r4fFzBm+bv0wNKNh5eXTwU7i85y5x+uwkxCUTNVQqLc= +github.com/oklog/ulid/v2 v2.0.2/go.mod h1:mtBL0Qe/0HAx6/a4Z30qxVIAL1eQDweXq5lxOEiwQ68= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.10.3/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= -github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= +github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= +github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= +github.com/pires/go-proxyproto v0.2.0 
h1:WyYKlv9pkt77b+LjMvPfwrsAxviaGCFhG4KDIy1ofLY= +github.com/pires/go-proxyproto v0.2.0/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY= github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= @@ -282,10 +284,10 @@ github.com/sebest/xff v0.0.0-20160910043805-6c115e0ffa35/go.mod h1:wozgYq9WEBQBa github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= -github.com/sirupsen/logrus v1.3.0 h1:hI/7Q+DtNZ2kINb6qt/lS+IyXnHQe9e90POfeewL/ME= -github.com/sirupsen/logrus v1.3.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM= +github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/spf13/afero v1.1.2 h1:m8/z1t7/fwjysjQRYbP0RD+bUIF/8tJwPdEZsI83ACI= @@ -342,10 +344,10 @@ github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FB github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM= github.com/yudai/pp v2.0.1+incompatible/go.mod h1:PuxR/8QJ7cyCkFp/aUDS+JY727OFEZkTdatxwunjIkc= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= 
-gitlab.com/gitlab-org/labkit v0.0.0-20201014124351-eb1fe6499318 h1:3xX/pl8dQjEtBZzHPCkex4Bwr7SGmVea/Zu4JdbZrKs= -gitlab.com/gitlab-org/labkit v0.0.0-20201014124351-eb1fe6499318/go.mod h1:SNfxkfUwVNECgtmluVayv0GWFgEjjBs5AzgsowPQuo0= -gitlab.com/lupine/go-mimedb v0.0.0-20180307000149-e8af1d659877 h1:k5N2m0IPaMuwWmFTO9fyTK4IEnSm35GC/p1S7VRgUyM= -gitlab.com/lupine/go-mimedb v0.0.0-20180307000149-e8af1d659877/go.mod h1:Es0wDVbtgNqhpEXMb+yct6JKnGMrNsUSh9oio0bqqdU= +gitlab.com/gitlab-org/go-mimedb v1.45.0 h1:PO8dx6HEWzPYU6MQTYnCbpQEJzhJLW/Bh43+2VUHTgc= +gitlab.com/gitlab-org/go-mimedb v1.45.0/go.mod h1:wa9y/zOSFKmTXLyBs4clz2FNVhZQmmEQM9TxslPAjZ0= +gitlab.com/gitlab-org/labkit v1.0.0 h1:t2Wr8ygtvHfXAMlCkoEdk5pdb5Gy1IYdr41H7t4kAYw= +gitlab.com/gitlab-org/labkit v1.0.0/go.mod h1:nohrYTSLDnZix0ebXZrbZJjymRar8HeV2roWL5/jw2U= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2 h1:75k/FF0Q2YM8QYo07VPddOLBslDt1MZOdEslOHvmzAs= @@ -436,6 +438,7 @@ golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200113162924-86b910548bc1 h1:gZpLHxUX5BdYLA08Lj4YCJNN/jk7KtquiArPoeX0WvA= diff --git a/internal/auth/auth.go b/internal/auth/auth.go index eaf3c25dd..252954a62 100644 --- 
a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -16,14 +16,14 @@ import ( "github.com/gorilla/securecookie" "github.com/gorilla/sessions" log "github.com/sirupsen/logrus" + "golang.org/x/crypto/hkdf" + "gitlab.com/gitlab-org/labkit/errortracking" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/httptransport" "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/source" - - "golang.org/x/crypto/hkdf" ) // nolint: gosec @@ -47,17 +47,23 @@ var ( errFailAuth = errors.New("Failed to authenticate request") errAuthNotConfigured = errors.New("Authentication is not configured") errQueryParameter = errors.New("Failed to parse domain query parameter") + + errGenerateKeys = errors.New("could not generate auth keys") ) // Auth handles authenticating users with GitLab API type Auth struct { - pagesDomain string - clientID string - clientSecret string - redirectURI string - gitLabServer string - apiClient *http.Client - store sessions.Store + pagesDomain string + clientID string + clientSecret string + redirectURI string + gitLabServer string + authSecret string + jwtSigningKey []byte + jwtExpiry time.Duration + apiClient *http.Client + store sessions.Store + now func() time.Time // allows to stub time.Now() easily in tests } type tokenResponse struct { @@ -111,7 +117,7 @@ func (a *Auth) checkSession(w http.ResponseWriter, r *http.Request) (*sessions.S return session, nil } -// TryAuthenticate tries to authenticate user and fetch access token if request is a callback to auth +// TryAuthenticate tries to authenticate user and fetch access token if request is a callback to /auth? 
func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, domains source.Source) bool { if a == nil { return false @@ -166,11 +172,18 @@ func (a *Auth) checkAuthenticationResponse(session *sessions.Session, w http.Res return } - // Fetch access token with authorization code - token, err := a.fetchAccessToken(r.URL.Query().Get("code")) + decryptedCode, err := a.DecryptCode(r.URL.Query().Get("code"), getRequestDomain(r)) + if err != nil { + logRequest(r).WithError(err).Error("failed to decrypt secure code") + errortracking.Capture(err, errortracking.WithRequest(r)) + httperrors.Serve500(w) + return + } - // Fetching token not OK + // Fetch access token with authorization code + token, err := a.fetchAccessToken(decryptedCode) if err != nil { + // Fetching token not OK logRequest(r).WithError(err).WithField( "redirect_uri", redirectURI, ).Error(errFetchAccessToken) @@ -216,8 +229,8 @@ func (a *Auth) domainAllowed(name string, domains source.Source) bool { } func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWriter, r *http.Request, domains source.Source) bool { - // If request is for authenticating via custom domain - if shouldProxyAuth(r) { + // handle auth callback e.g. 
https://gitlab.io/auth?domain&domain&state=state + if shouldProxyAuthToGitlab(r) { domain := r.URL.Query().Get("domain") state := r.URL.Query().Get("state") @@ -266,6 +279,7 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit } // If auth request callback should be proxied to custom domain + // redirect to originating domain set in the cookie as proxy_auth_domain if shouldProxyCallbackToCustomDomain(r, session) { // Get domain started auth process proxyDomain := session.Values["proxy_auth_domain"].(string) @@ -283,9 +297,30 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit return true } - // Redirect pages under custom domain - http.Redirect(w, r, proxyDomain+r.URL.Path+"?"+r.URL.RawQuery, 302) + query := r.URL.Query() + + // prevent https://tools.ietf.org/html/rfc6749#section-10.6 and + // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/262 by encrypting + // and signing the OAuth code + signedCode, err := a.EncryptAndSignCode(proxyDomain, query.Get("code")) + if err != nil { + logRequest(r).WithError(err).Error(errSaveSession) + errortracking.Capture(err, errortracking.WithRequest(r)) + + httperrors.Serve503(w) + return true + } + + // prevent forwarding access token, more context on the security issue + // https://gitlab.com/gitlab-org/gitlab/-/issues/285244#note_451266051 + query.Del("token") + + // replace code with signed code + query.Set("code", signedCode) + // Redirect pages to originating domain with code and state to finish + // authentication process + http.Redirect(w, r, proxyDomain+r.URL.Path+"?"+query.Encode(), 302) return true } @@ -306,7 +341,7 @@ func getRequestDomain(r *http.Request) string { return "http://" + r.Host } -func shouldProxyAuth(r *http.Request) bool { +func shouldProxyAuthToGitlab(r *http.Request) bool { return r.URL.Query().Get("domain") != "" && r.URL.Query().Get("state") != "" } @@ -376,6 +411,7 @@ func (a *Auth) checkSessionIsValid(w http.ResponseWriter, r 
*http.Request) *sess return nil } + // redirect to /auth?domain=%s&state=%s if a.checkTokenExists(session, w, r) { return nil } @@ -586,28 +622,37 @@ func logRequest(r *http.Request) *log.Entry { }) } -// generateKeyPair returns key pair for secure cookie: signing and encryption key -func generateKeyPair(storeSecret string) ([]byte, []byte) { - hash := sha256.New - hkdf := hkdf.New(hash, []byte(storeSecret), []byte{}, []byte("PAGES_SIGNING_AND_ENCRYPTION_KEY")) - var keys [][]byte - for i := 0; i < 2; i++ { +// generateKeys derives count hkdf keys from a secret, ensuring the key is +// the same for the same secret used across multiple instances +func generateKeys(secret string, count int) ([][]byte, error) { + keys := make([][]byte, count) + hkdfReader := hkdf.New(sha256.New, []byte(secret), []byte{}, []byte("PAGES_SIGNING_AND_ENCRYPTION_KEY")) + + for i := 0; i < count; i++ { key := make([]byte, 32) - if _, err := io.ReadFull(hkdf, key); err != nil { - log.WithError(err).Fatal("Can't generate key pair for secure cookies") + if _, err := io.ReadFull(hkdfReader, key); err != nil { + return nil, err } - keys = append(keys, key) + + keys[i] = key + } + + if len(keys) < count { + return nil, errGenerateKeys } - return keys[0], keys[1] -} -func createCookieStore(storeSecret string) sessions.Store { - return sessions.NewCookieStore(generateKeyPair(storeSecret)) + return keys, nil } // New when authentication supported this will be used to create authentication handler func New(pagesDomain string, storeSecret string, clientID string, clientSecret string, - redirectURI string, gitLabServer string) *Auth { + redirectURI string, gitLabServer string) (*Auth, error) { + // generate 3 keys, 2 for the cookie store and 1 for JWT signing + keys, err := generateKeys(storeSecret, 3) + if err != nil { + return nil, err + } + return &Auth{ pagesDomain: pagesDomain, clientID: clientID, @@ -618,6 +663,10 @@ func New(pagesDomain string, storeSecret string, clientID string, clientSecret s 
Timeout: 5 * time.Second, Transport: httptransport.InternalTransport, }, - store: createCookieStore(storeSecret), - } + store: sessions.NewCookieStore(keys[0], keys[1]), + authSecret: storeSecret, + jwtSigningKey: keys[2], + jwtExpiry: time.Minute, + now: time.Now, + }, nil } diff --git a/internal/auth/auth_code.go b/internal/auth/auth_code.go new file mode 100644 index 000000000..d2fea5a95 --- /dev/null +++ b/internal/auth/auth_code.go @@ -0,0 +1,147 @@ +package auth + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "io" + + "github.com/dgrijalva/jwt-go" + "github.com/gorilla/securecookie" + "golang.org/x/crypto/hkdf" +) + +var ( + errInvalidToken = errors.New("invalid token") + errEmptyDomainOrCode = errors.New("empty domain or code") + errInvalidNonce = errors.New("invalid nonce") + errInvalidCode = errors.New("invalid code") +) + +// EncryptAndSignCode encrypts the OAuth code deriving the key from the domain. +// It adds the code and domain as JWT token claims and signs it using signingKey derived from +// the Auth secret. 
+func (a *Auth) EncryptAndSignCode(domain, code string) (string, error) { + if domain == "" || code == "" { + return "", errEmptyDomainOrCode + } + + nonce := base64.URLEncoding.EncodeToString(securecookie.GenerateRandomKey(16)) + + aesGcm, err := a.newAesGcmCipher(domain, nonce) + if err != nil { + return "", err + } + + // encrypt code with a randomly generated nonce + encryptedCode := aesGcm.Seal(nil, []byte(nonce), []byte(code), nil) + + // generate JWT token claims with encrypted code + claims := jwt.MapClaims{ + // standard claims + "iss": "gitlab-pages", + "iat": a.now().Unix(), + "exp": a.now().Add(a.jwtExpiry).Unix(), + // custom claims + "domain": domain, // pass the domain so we can validate the signed domain matches the requested domain + "code": hex.EncodeToString(encryptedCode), + "nonce": nonce, + } + + return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(a.jwtSigningKey) +} + +// DecryptCode decodes the secureCode as a JWT token and validates its signature. +// It then decrypts the code from the token claims and returns it. 
+func (a *Auth) DecryptCode(jwt, domain string) (string, error) { + claims, err := a.parseJWTClaims(jwt) + if err != nil { + return "", err + } + + // get nonce and encryptedCode from the JWT claims + nonce, ok := claims["nonce"].(string) + if !ok { + return "", errInvalidNonce + } + + encryptedCode, ok := claims["code"].(string) + if !ok { + return "", errInvalidCode + } + + cipherText, err := hex.DecodeString(encryptedCode) + if err != nil { + return "", err + } + + aesGcm, err := a.newAesGcmCipher(domain, nonce) + if err != nil { + return "", err + } + + decryptedCode, err := aesGcm.Open(nil, []byte(nonce), cipherText, nil) + if err != nil { + return "", err + } + + return string(decryptedCode), nil +} + +func (a *Auth) codeKey(domain string) ([]byte, error) { + hkdfReader := hkdf.New(sha256.New, []byte(a.authSecret), []byte(domain), []byte("PAGES_AUTH_CODE_ENCRYPTION_KEY")) + + key := make([]byte, 32) + if _, err := io.ReadFull(hkdfReader, key); err != nil { + return nil, err + } + + return key, nil +} + +func (a *Auth) parseJWTClaims(secureCode string) (jwt.MapClaims, error) { + token, err := jwt.Parse(secureCode, a.getSigningKey) + if err != nil { + return nil, err + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + return nil, errInvalidToken + } + + return claims, nil +} + +func (a *Auth) getSigningKey(token *jwt.Token) (interface{}, error) { + // Don't forget to validate the alg is what you expect: + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + return a.jwtSigningKey, nil +} + +func (a *Auth) newAesGcmCipher(domain, nonce string) (cipher.AEAD, error) { + // get the same key for a domain + key, err := a.codeKey(domain) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + aesGcm, err := cipher.NewGCMWithNonceSize(block, len(nonce)) + if err != nil { + return nil, err 
+ } + + return aesGcm, nil +} diff --git a/internal/auth/auth_code_test.go b/internal/auth/auth_code_test.go new file mode 100644 index 000000000..d54fcc7ea --- /dev/null +++ b/internal/auth/auth_code_test.go @@ -0,0 +1,99 @@ +package auth + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestEncryptAndDecryptSignedCode(t *testing.T) { + auth := createTestAuth(t, "") + + tests := map[string]struct { + auth *Auth + encDomain string + code string + expectedEncErrMsg string + decDomain string + expectedDecErrMsg string + }{ + "happy_path": { + auth: auth, + encDomain: "domain", + decDomain: "domain", + code: "code", + }, + "empty_domain": { + auth: auth, + encDomain: "", + code: "code", + expectedEncErrMsg: "empty domain or code", + }, + "empty_code": { + auth: auth, + encDomain: "domain", + code: "", + expectedEncErrMsg: "empty domain or code", + }, + "different_dec_domain": { + auth: auth, + encDomain: "domain", + decDomain: "another", + code: "code", + expectedDecErrMsg: "cipher: message authentication failed", + }, + "expired_token": { + auth: func() *Auth { + newAuth := *auth + newAuth.jwtExpiry = time.Nanosecond + newAuth.now = func() time.Time { + return time.Time{} + } + + return &newAuth + }(), + encDomain: "domain", + code: "code", + decDomain: "domain", + expectedDecErrMsg: "Token is expired", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + encCode, err := test.auth.EncryptAndSignCode(test.encDomain, test.code) + if test.expectedEncErrMsg != "" { + require.EqualError(t, err, test.expectedEncErrMsg) + require.Empty(t, encCode) + return + } + + require.NoError(t, err) + require.NotEmpty(t, encCode) + + decCode, err := test.auth.DecryptCode(encCode, test.decDomain) + if test.expectedDecErrMsg != "" { + require.EqualError(t, err, test.expectedDecErrMsg) + require.Empty(t, decCode) + return + } + + require.NoError(t, err) + require.Equal(t, test.code, decCode) + }) + } +} + +func 
TestDecryptCodeWithInvalidJWT(t *testing.T) { + auth1 := createTestAuth(t, "") + auth2 := createTestAuth(t, "") + auth2.jwtSigningKey = []byte("another signing key") + + encCode, err := auth1.EncryptAndSignCode("domain", "code") + require.NoError(t, err) + + decCode, err := auth2.DecryptCode(encCode, "domain") + require.EqualError(t, err, "signature is invalid") + require.Empty(t, decCode) +} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 39a533b35..ce7d83207 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "github.com/gorilla/sessions" @@ -16,17 +17,19 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/source" ) -func createAuth(t *testing.T) *Auth { - return New("pages.gitlab-example.com", +func createTestAuth(t *testing.T, url string) *Auth { + t.Helper() + + a, err := New("pages.gitlab-example.com", "something-very-secret", "id", "secret", "http://pages.gitlab-example.com/auth", - "http://gitlab-example.com") -} + url) + + require.NoError(t, err) -func defaultCookieStore() sessions.Store { - return createCookieStore("something-very-secret") + return a } type domainMock struct { @@ -48,10 +51,13 @@ func (dm *domainMock) ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Req // Which leads to negative side effects: we can't test encryption, and cookie params // like max-age and secure are not being properly set // To avoid that we use fake request, and set only session cookie without copying context -func setSessionValues(r *http.Request, values map[interface{}]interface{}) { - tmpRequest, _ := http.NewRequest("GET", "/", nil) +func setSessionValues(t *testing.T, r *http.Request, store sessions.Store, values map[interface{}]interface{}) { + t.Helper() + + tmpRequest, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + result := httptest.NewRecorder() - store := defaultCookieStore() session, _ := 
store.Get(tmpRequest, "gitlab-pages") session.Values = values @@ -63,7 +69,7 @@ func setSessionValues(r *http.Request, values map[interface{}]interface{}) { } func TestTryAuthenticate(t *testing.T) { - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/something/else") @@ -75,11 +81,12 @@ func TestTryAuthenticate(t *testing.T) { } func TestTryAuthenticateWithError(t *testing.T) { - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?error=access_denied") require.NoError(t, err) + reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} @@ -88,8 +95,7 @@ func TestTryAuthenticateWithError(t *testing.T) { } func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=invalid") @@ -97,7 +103,9 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["state"] = "state" session.Save(r, result) @@ -105,7 +113,36 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { require.Equal(t, 401, result.Code) } +func TestTryAuthenticateRemoveTokenFromRedirect(t *testing.T) { + auth := createTestAuth(t, "") + + result := httptest.NewRecorder() + reqURL, err := url.Parse("/auth?code=1&state=state&token=secret") + require.NoError(t, err) + + require.Equal(t, reqURL.Query().Get("token"), "secret", "token is present before redirecting") + reqURL.Scheme = request.SchemeHTTPS + r := &http.Request{URL: reqURL} + + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + + session.Values["state"] = 
"state" + session.Values["proxy_auth_domain"] = "https://domain.com" + session.Save(r, result) + + require.Equal(t, true, auth.TryAuthenticate(result, r, source.NewMockSource())) + require.Equal(t, http.StatusFound, result.Code) + + redirect, err := url.Parse(result.Header().Get("Location")) + require.NoError(t, err) + + require.Empty(t, redirect.Query().Get("token"), "token is gone after redirecting") +} + func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { + t.Helper() + apiServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/oauth/token": @@ -125,14 +162,17 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { apiServer.Start() defer apiServer.Close() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) + + domain := apiServer.URL + if https { + domain = strings.Replace(apiServer.URL, "http://", "https://", -1) + } - r, err := http.NewRequest("GET", "/auth?code=1&state=state", nil) + code, err := auth.EncryptAndSignCode(domain, "1") + require.NoError(t, err) + + r, err := http.NewRequest("GET", "/auth?code="+code+"&state=state", nil) require.NoError(t, err) if https { r.URL.Scheme = request.SchemeHTTPS @@ -140,14 +180,16 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { r.URL.Scheme = request.SchemeHTTP } - setSessionValues(r, map[interface{}]interface{}{ + r.Host = strings.TrimPrefix(apiServer.URL, "http://") + + setSessionValues(t, r, auth.store, map[interface{}]interface{}{ "uri": "https://pages.gitlab-example.com/project/", "state": "state", }) result := httptest.NewRecorder() require.Equal(t, true, auth.TryAuthenticate(result, r, source.NewMockSource())) - require.Equal(t, 302, result.Code) + require.Equal(t, http.StatusFound, result.Code) require.Equal(t, 
"https://pages.gitlab-example.com/project/", result.Header().Get("Location")) require.Equal(t, 600, result.Result().Cookies()[0].MaxAge) require.Equal(t, https, result.Result().Cookies()[0].Secure) @@ -177,13 +219,7 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") @@ -191,7 +227,9 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) @@ -217,13 +255,7 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) w := httptest.NewRecorder() @@ -232,7 +264,9 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, w) @@ -265,22 +299,19 @@ func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - 
"http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" - session.Save(r, result) + err = session.Save(r, result) + require.NoError(t, err) contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) require.True(t, contentServed) @@ -303,13 +334,7 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") @@ -317,7 +342,9 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) @@ -343,19 +370,16 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) 
+ session.Values["access_token"] = "abc" session.Save(r, result) @@ -364,28 +388,31 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { require.Equal(t, 302, result.Code) } -func TestGenerateKeyPair(t *testing.T) { - signingSecret, encryptionSecret := generateKeyPair("something-very-secret") - require.NotEqual(t, fmt.Sprint(signingSecret), fmt.Sprint(encryptionSecret)) - require.Equal(t, len(signingSecret), 32) - require.Equal(t, len(encryptionSecret), 32) +func TestGenerateKeys(t *testing.T) { + keys, err := generateKeys("something-very-secret", 3) + require.NoError(t, err) + require.Len(t, keys, 3) + + require.NotEqual(t, fmt.Sprint(keys[0]), fmt.Sprint(keys[1])) + require.NotEqual(t, fmt.Sprint(keys[0]), fmt.Sprint(keys[2])) + require.NotEqual(t, fmt.Sprint(keys[1]), fmt.Sprint(keys[2])) + + require.Equal(t, len(keys[0]), 32) + require.Equal(t, len(keys[1]), 32) + require.Equal(t, len(keys[2]), 32) } func TestGetTokenIfExistsWhenTokenExists(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) @@ -395,20 +422,16 @@ func TestGetTokenIfExistsWhenTokenExists(t *testing.T) { } func TestGetTokenIfExistsWhenTokenDoesNotExist(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := 
url.Parse("http://pages.gitlab-example.com/test") require.NoError(t, err) r := &http.Request{URL: reqURL, Host: "pages.gitlab-example.com", RequestURI: "/test"} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Save(r, result) token, err := auth.GetTokenIfExists(result, r) @@ -417,12 +440,7 @@ func TestGetTokenIfExistsWhenTokenDoesNotExist(t *testing.T) { } func TestCheckResponseForInvalidTokenWhenInvalidToken(t *testing.T) { - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("http://pages.gitlab-example.com/test") @@ -437,12 +455,7 @@ func TestCheckResponseForInvalidTokenWhenInvalidToken(t *testing.T) { } func TestCheckResponseForInvalidTokenWhenNotInvalidToken(t *testing.T) { - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/something") diff --git a/internal/config/config.go b/internal/config/config.go index d415b21d3..c52beef82 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,32 +1,18 @@ package config import ( - "errors" - "net/http" - "strings" + "time" ) -var errInvalidHeaderParameter = errors.New("invalid syntax specified as header parameter") - -// AddCustomHeaders adds a map of Headers to a Response -func AddCustomHeaders(w http.ResponseWriter, headers http.Header) error { - for k, v := range headers { - for _, value := range v { - w.Header().Add(k, value) - } - } - return nil +type Config struct { + Zip *ZipServing } -// ParseHeaderString parses a string of key values into a map -func ParseHeaderString(customHeaders []string) (http.Header, error) { - headers := http.Header{} - for _, keyValueString 
:= range customHeaders { - keyValue := strings.SplitN(keyValueString, ":", 2) - if len(keyValue) != 2 { - return nil, errInvalidHeaderParameter - } - headers[strings.TrimSpace(keyValue[0])] = append(headers[strings.TrimSpace(keyValue[0])], strings.TrimSpace(keyValue[1])) - } - return headers, nil +// ZipServing stores all configuration values to be used by the zip VFS opening and +// caching +type ZipServing struct { + ExpirationInterval time.Duration + CleanupInterval time.Duration + RefreshInterval time.Duration + OpenTimeout time.Duration } diff --git a/internal/httperrors/httperrors.go b/internal/httperrors/httperrors.go index 1ae5224b0..476d270c8 100644 --- a/internal/httperrors/httperrors.go +++ b/internal/httperrors/httperrors.go @@ -3,6 +3,10 @@ package httperrors import ( "fmt" "net/http" + + log "github.com/sirupsen/logrus" + + "gitlab.com/gitlab-org/labkit/errortracking" ) type content struct { @@ -177,6 +181,16 @@ func Serve500(w http.ResponseWriter) { serveErrorPage(w, content500) } +// Serve500WithRequest returns a 500 error response / HTML page to the http.ResponseWriter +func Serve500WithRequest(w http.ResponseWriter, r *http.Request, reason string, err error) { + log.WithFields(log.Fields{ + "host": r.Host, + "path": r.URL.Path, + }).WithError(err).Error(reason) + errortracking.Capture(err, errortracking.WithRequest(r)) + serveErrorPage(w, content500) +} + // Serve502 returns a 502 error response / HTML page to the http.ResponseWriter func Serve502(w http.ResponseWriter) { serveErrorPage(w, content502) diff --git a/internal/httprange/http_reader.go b/internal/httprange/http_reader.go index 44694e858..589351fa1 100644 --- a/internal/httprange/http_reader.go +++ b/internal/httprange/http_reader.go @@ -14,16 +14,16 @@ import ( ) var ( + // ErrNotFound is returned when servers responds with 404 + ErrNotFound = errors.New("resource not found") + // ErrRangeRequestsNotSupported is returned by Seek and Read - // when the remote server does not allow range 
requests (Accept-Ranges was not set) - ErrRangeRequestsNotSupported = errors.New("range requests are not supported by the remote server") + // when the remote server does not allow range requests for a given request parameters + ErrRangeRequestsNotSupported = errors.New("requests range is not supported by the remote server") // ErrInvalidRange is returned by Read when trying to read past the end of the file ErrInvalidRange = errors.New("invalid range") - // ErrContentHasChanged is returned by Read when the content has changed since the first request - ErrContentHasChanged = errors.New("content has changed since first request") - // seek errors no need to export them errSeekInvalidWhence = errors.New("invalid whence") errSeekOutsideRange = errors.New("outside of range") @@ -59,6 +59,7 @@ var httpClient = &http.Client{ metrics.HTTPRangeTraceDuration, metrics.HTTPRangeRequestDuration, metrics.HTTPRangeRequestsTotal, + httptransport.DefaultTTFBTimeout, ), } @@ -102,21 +103,12 @@ func (r *Reader) prepareRequest() (*http.Request, error) { return nil, ErrInvalidRange } - req, err := http.NewRequest("GET", r.Resource.URL, nil) + req, err := r.Resource.Request() if err != nil { return nil, err } req = req.WithContext(r.ctx) - - if r.Resource.ETag != "" { - req.Header.Set("ETag", r.Resource.ETag) - } else if r.Resource.LastModified != "" { - // Last-Modified should be a fallback mechanism in case ETag is not present - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Last-Modified - req.Header.Set("If-Range", r.Resource.LastModified) - } - req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", r.offset, r.rangeStart+r.rangeSize-1)) return req, nil @@ -129,12 +121,17 @@ func (r *Reader) setResponse(res *http.Response) error { // some servers return 200 OK for bytes=0- // TODO: should we handle r.Resource.Last-Modified as well? 
if r.offset > 0 || r.Resource.ETag != "" && r.Resource.ETag != res.Header.Get("ETag") { - return ErrContentHasChanged + r.Resource.setError(ErrRangeRequestsNotSupported) + return ErrRangeRequestsNotSupported } + case http.StatusNotFound: + r.Resource.setError(ErrNotFound) + return ErrNotFound case http.StatusPartialContent: // Requested `Range` request succeeded https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/206 break case http.StatusRequestedRangeNotSatisfiable: + r.Resource.setError(ErrRangeRequestsNotSupported) return ErrRangeRequestsNotSupported default: return fmt.Errorf("httprange: read response %d: %q", res.StatusCode, res.Status) diff --git a/internal/httprange/http_reader_test.go b/internal/httprange/http_reader_test.go index 5e9715750..97bfbf24a 100644 --- a/internal/httprange/http_reader_test.go +++ b/internal/httprange/http_reader_test.go @@ -199,45 +199,61 @@ func TestSeekAndRead(t *testing.T) { func TestReaderSetResponse(t *testing.T) { tests := map[string]struct { - status int - offset int64 - prevETag string - resEtag string - expectedErrMsg string + status int + offset int64 + prevETag string + resEtag string + expectedErrMsg string + expectedIsValid bool }{ "partial_content_success": { - status: http.StatusPartialContent, + status: http.StatusPartialContent, + expectedIsValid: true, }, "status_ok_success": { - status: http.StatusOK, + status: http.StatusOK, + expectedIsValid: true, }, "status_ok_previous_response_invalid_offset": { - status: http.StatusOK, - offset: 1, - expectedErrMsg: ErrContentHasChanged.Error(), + status: http.StatusOK, + offset: 1, + expectedErrMsg: ErrRangeRequestsNotSupported.Error(), + expectedIsValid: false, }, "status_ok_previous_response_different_etag": { - status: http.StatusOK, - prevETag: "old", - resEtag: "new", - expectedErrMsg: ErrContentHasChanged.Error(), + status: http.StatusOK, + prevETag: "old", + resEtag: "new", + expectedErrMsg: ErrRangeRequestsNotSupported.Error(), + expectedIsValid: false, }, 
"requested_range_not_satisfiable": { - status: http.StatusRequestedRangeNotSatisfiable, - expectedErrMsg: ErrRangeRequestsNotSupported.Error(), + status: http.StatusRequestedRangeNotSatisfiable, + expectedErrMsg: ErrRangeRequestsNotSupported.Error(), + expectedIsValid: false, + }, + "not_found": { + status: http.StatusNotFound, + expectedErrMsg: ErrNotFound.Error(), + expectedIsValid: false, }, "unhandled_status_code": { - status: http.StatusNotFound, - expectedErrMsg: "httprange: read response 404:", + status: http.StatusInternalServerError, + expectedErrMsg: "httprange: read response 500:", + expectedIsValid: true, }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { - r := NewReader(context.Background(), &Resource{ETag: tt.prevETag}, tt.offset, 0) + resource := &Resource{ETag: tt.prevETag} + reader := NewReader(context.Background(), resource, tt.offset, 0) res := &http.Response{StatusCode: tt.status, Header: map[string][]string{}} res.Header.Set("ETag", tt.resEtag) - err := r.setResponse(res) + err := reader.setResponse(res) + + require.Equal(t, tt.expectedIsValid, resource.Valid()) + if tt.expectedErrMsg != "" { require.Error(t, err) require.Contains(t, err.Error(), tt.expectedErrMsg) @@ -245,7 +261,6 @@ func TestReaderSetResponse(t *testing.T) { } require.NoError(t, err) - require.Equal(t, r.res, res) }) } } diff --git a/internal/httprange/resource.go b/internal/httprange/resource.go index 7e21ef292..8b908fe85 100644 --- a/internal/httprange/resource.go +++ b/internal/httprange/resource.go @@ -8,15 +8,63 @@ import ( "net/http" "strconv" "strings" + "sync/atomic" ) // Resource represents any HTTP resource that can be read by a GET operation. // It holds the resource's URL and metadata about it. 
type Resource struct { - URL string ETag string LastModified string Size int64 + + url atomic.Value + err atomic.Value +} + +func (r *Resource) URL() string { + url, _ := r.url.Load().(string) + return url +} + +func (r *Resource) SetURL(url string) { + if r.URL() == url { + // We want to avoid cache lines invalidation + // on CPU due to value change + return + } + + r.url.Store(url) +} + +func (r *Resource) Err() error { + err, _ := r.err.Load().(error) + return err +} + +func (r *Resource) Valid() bool { + return r.Err() == nil +} + +func (r *Resource) setError(err error) { + r.err.Store(err) +} + +func (r *Resource) Request() (*http.Request, error) { + req, err := http.NewRequest("GET", r.URL(), nil) + if err != nil { + return nil, err + } + + if r.ETag != "" { + req.Header.Set("ETag", r.ETag) + } else if r.LastModified != "" { + // Last-Modified should be a fallback mechanism in case ETag is not present + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Last-Modified + req.Header.Set("If-Range", r.LastModified) + } + + return req, nil } func NewResource(ctx context.Context, url string) (*Resource, error) { @@ -44,11 +92,12 @@ func NewResource(ctx context.Context, url string) (*Resource, error) { }() resource := &Resource{ - URL: url, ETag: res.Header.Get("ETag"), LastModified: res.Header.Get("Last-Modified"), } + resource.SetURL(url) + switch res.StatusCode { case http.StatusOK: resource.Size = res.ContentLength @@ -71,6 +120,9 @@ func NewResource(ctx context.Context, url string) (*Resource, error) { case http.StatusRequestedRangeNotSatisfiable: return nil, ErrRangeRequestsNotSupported + case http.StatusNotFound: + return nil, ErrNotFound + default: return nil, fmt.Errorf("httprange: new resource %d: %q", res.StatusCode, res.Status) } diff --git a/internal/httprange/resource_test.go b/internal/httprange/resource_test.go index 89d15a217..1d6481fca 100644 --- a/internal/httprange/resource_test.go +++ b/internal/httprange/resource_test.go @@ -2,17 
+2,23 @@ package httprange import ( "context" - "fmt" "net/http" "net/http/httptest" + "sync/atomic" "testing" "github.com/stretchr/testify/require" ) +func urlValue(url string) atomic.Value { + v := atomic.Value{} + v.Store(url) + return v +} + func TestNewResource(t *testing.T) { - resource := Resource{ - URL: "/some/resource", + resource := &Resource{ + url: urlValue("/some/resource"), ETag: "etag", LastModified: "Wed, 21 Oct 2015 07:28:00 GMT", Size: 1, @@ -22,7 +28,7 @@ func TestNewResource(t *testing.T) { url string status int contentRange string - want Resource + want *Resource expectedErrMsg string }{ "status_ok": { @@ -34,37 +40,43 @@ func TestNewResource(t *testing.T) { url: "/some/resource", status: http.StatusPartialContent, contentRange: "bytes 200-1000/67589", - want: func() Resource { - r := resource - r.Size = 67589 - return r - }(), + want: &Resource{ + url: urlValue("/some/resource"), + ETag: "etag", + LastModified: "Wed, 21 Oct 2015 07:28:00 GMT", + Size: 67589, + }, }, "status_partial_content_invalid_content_range": { url: "/some/resource", status: http.StatusPartialContent, contentRange: "invalid", expectedErrMsg: "invalid `Content-Range`:", + want: resource, }, "status_partial_content_content_range_not_a_number": { url: "/some/resource", status: http.StatusPartialContent, contentRange: "bytes 200-1000/notanumber", expectedErrMsg: "invalid `Content-Range`:", + want: resource, }, "StatusRequestedRangeNotSatisfiable": { url: "/some/resource", status: http.StatusRequestedRangeNotSatisfiable, expectedErrMsg: ErrRangeRequestsNotSupported.Error(), + want: resource, }, "not_found": { url: "/some/resource", status: http.StatusNotFound, - expectedErrMsg: fmt.Sprintf("httprange: new resource %d: %q", http.StatusNotFound, "404 Not Found"), + expectedErrMsg: ErrNotFound.Error(), + want: resource, }, "invalid_url": { url: "/%", expectedErrMsg: "invalid URL escape", + want: resource, }, } @@ -87,7 +99,7 @@ func TestNewResource(t *testing.T) { } 
require.NoError(t, err) - require.Contains(t, got.URL, tt.want.URL) + require.Contains(t, got.URL(), tt.want.URL()) require.Equal(t, tt.want.LastModified, got.LastModified) require.Equal(t, tt.want.ETag, got.ETag) require.Equal(t, tt.want.Size, got.Size) diff --git a/internal/httptransport/transport.go b/internal/httptransport/transport.go index bc871ea7b..d8e6a3fe3 100644 --- a/internal/httptransport/transport.go +++ b/internal/httptransport/transport.go @@ -1,6 +1,7 @@ package httptransport import ( + "context" "crypto/tls" "crypto/x509" "net" @@ -15,6 +16,13 @@ import ( log "github.com/sirupsen/logrus" ) +const ( + // DefaultTTFBTimeout is the timeout used in the meteredRoundTripper + // when calling http.Transport.RoundTrip. The request will be cancelled + // if the response takes longer than this. + DefaultTTFBTimeout = 15 * time.Second +) + var ( sysPoolOnce = &sync.Once{} sysPool *x509.CertPool @@ -26,11 +34,12 @@ var ( ) type meteredRoundTripper struct { - next http.RoundTripper - name string - tracer *prometheus.HistogramVec - durations *prometheus.HistogramVec - counter *prometheus.CounterVec + next http.RoundTripper + name string + tracer *prometheus.HistogramVec + durations *prometheus.HistogramVec + counter *prometheus.CounterVec + ttfbTimeout time.Duration } func newInternalTransport() *http.Transport { @@ -43,19 +52,24 @@ func newInternalTransport() *http.Transport { MaxIdleConns: 100, MaxIdleConnsPerHost: 100, IdleConnTimeout: 90 * time.Second, + // Set more timeouts https://gitlab.com/gitlab-org/gitlab-pages/-/issues/495 + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 15 * time.Second, + ExpectContinueTimeout: 15 * time.Second, } } // NewTransportWithMetrics will create a custom http.RoundTripper that can be used with an http.Client. // The RoundTripper will report metrics based on the collectors passed. func NewTransportWithMetrics(name string, tracerVec, durationsVec *prometheus. 
- HistogramVec, counterVec *prometheus.CounterVec) http.RoundTripper { + HistogramVec, counterVec *prometheus.CounterVec, ttfbTimeout time.Duration) http.RoundTripper { return &meteredRoundTripper{ - next: InternalTransport, - name: name, - tracer: tracerVec, - durations: durationsVec, - counter: counterVec, + next: InternalTransport, + name: name, + tracer: tracerVec, + durations: durationsVec, + counter: counterVec, + ttfbTimeout: ttfbTimeout, } } @@ -88,7 +102,13 @@ func loadPool() { func (mrt *meteredRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { start := time.Now() - r = r.WithContext(httptrace.WithClientTrace(r.Context(), mrt.newTracer(start))) + ctx := httptrace.WithClientTrace(r.Context(), mrt.newTracer(start)) + ctx, cancel := context.WithCancel(ctx) + + timer := time.AfterFunc(mrt.ttfbTimeout, cancel) + defer timer.Stop() + + r = r.WithContext(ctx) resp, err := mrt.next.RoundTrip(r) if err != nil { diff --git a/internal/httptransport/transport_test.go b/internal/httptransport/transport_test.go index a4105bef4..9059ea153 100644 --- a/internal/httptransport/transport_test.go +++ b/internal/httptransport/transport_test.go @@ -1,6 +1,8 @@ package httptransport import ( + "context" + "errors" "fmt" "net/http" "net/http/httptest" @@ -43,22 +45,17 @@ func Test_withRoundTripper(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - histVec := prometheus.NewHistogramVec(prometheus.HistogramOpts{ - Name: t.Name(), - }, []string{"status_code"}) - - counterVec := prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: t.Name(), - }, []string{"status_code"}) + histVec, counterVec := newTestMetrics(t) next := &mockRoundTripper{ res: &http.Response{ StatusCode: tt.statusCode, }, - err: tt.err, + err: tt.err, + timeout: time.Nanosecond, } - mtr := &meteredRoundTripper{next: next, durations: histVec, counter: counterVec} + mtr := &meteredRoundTripper{next: next, durations: histVec, counter: counterVec, ttfbTimeout: 
DefaultTTFBTimeout} r := httptest.NewRequest("GET", "/", nil) res, err := mtr.RoundTrip(r) @@ -78,13 +75,53 @@ func Test_withRoundTripper(t *testing.T) { } } +func TestRoundTripTTFBTimeout(t *testing.T) { + histVec, counterVec := newTestMetrics(t) + + next := &mockRoundTripper{ + res: &http.Response{ + StatusCode: http.StatusOK, + }, + timeout: time.Millisecond, + err: nil, + } + + mtr := &meteredRoundTripper{next: next, durations: histVec, counter: counterVec, ttfbTimeout: time.Nanosecond} + req, err := http.NewRequest("GET", "https://gitlab.com", nil) + require.NoError(t, err) + + res, err := mtr.RoundTrip(req) + require.Nil(t, res) + require.True(t, errors.Is(err, context.Canceled), "context must have been canceled after ttfb timeout") +} + +func newTestMetrics(t *testing.T) (*prometheus.HistogramVec, *prometheus.CounterVec) { + t.Helper() + + histVec := prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Name: t.Name(), + }, []string{"status_code"}) + + counterVec := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: t.Name(), + }, []string{"status_code"}) + + return histVec, counterVec +} + type mockRoundTripper struct { - res *http.Response - err error + res *http.Response + err error + timeout time.Duration } func (mrt *mockRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { - return mrt.res, mrt.err + select { + case <-r.Context().Done(): + return nil, r.Context().Err() + case <-time.After(mrt.timeout): + return mrt.res, mrt.err + } } func TestInternalTransportShouldHaveCustomConnectionPoolSettings(t *testing.T) { @@ -92,4 +129,7 @@ func TestInternalTransportShouldHaveCustomConnectionPoolSettings(t *testing.T) { require.EqualValues(t, 100, InternalTransport.MaxIdleConnsPerHost) require.EqualValues(t, 0, InternalTransport.MaxConnsPerHost) require.EqualValues(t, 90*time.Second, InternalTransport.IdleConnTimeout) + require.EqualValues(t, 10*time.Second, InternalTransport.TLSHandshakeTimeout) + require.EqualValues(t, 15*time.Second, 
InternalTransport.ResponseHeaderTimeout) + require.EqualValues(t, 15*time.Second, InternalTransport.ExpectContinueTimeout) } diff --git a/internal/jail/jail.go b/internal/jail/jail.go index ec64805b3..13b393745 100644 --- a/internal/jail/jail.go +++ b/internal/jail/jail.go @@ -80,14 +80,14 @@ func (j *Jail) Build() error { for _, dir := range j.directories { if err := os.Mkdir(dir.path, dir.mode); err != nil { j.removeAll() - return fmt.Errorf("Can't create directory %q. %s", dir.path, err) + return fmt.Errorf("can't create directory %q. %s", dir.path, err) } } for dest, src := range j.files { if err := handleFile(dest, src); err != nil { j.removeAll() - return fmt.Errorf("Can't copy %q -> %q. %s", src.path, dest, err) + return fmt.Errorf("can't copy %q -> %q. %s", src.path, dest, err) } } @@ -106,12 +106,12 @@ func (j *Jail) removeAll() error { // to traverse files and directories if j.deleteRoot { if err := os.RemoveAll(j.Path()); err != nil { - return fmt.Errorf("Can't delete jail %q. %s", j.Path(), err) + return fmt.Errorf("can't delete jail %q. %s", j.Path(), err) } } else { for path := range j.files { if err := os.Remove(path); err != nil { - return fmt.Errorf("Can't delete file in jail %q: %s", path, err) + return fmt.Errorf("can't delete file in jail %q: %s", path, err) } } @@ -119,7 +119,7 @@ func (j *Jail) removeAll() error { for i := len(j.directories) - 1; i >= 0; i-- { dest := j.directories[i] if err := os.Remove(dest.path); err != nil { - return fmt.Errorf("Can't delete directory in jail %q: %s", dest.path, err) + return fmt.Errorf("can't delete directory in jail %q: %s", dest.path, err) } } } @@ -134,7 +134,7 @@ func (j *Jail) Dispose() error { } if err := j.removeAll(); err != nil { - return fmt.Errorf("Can't delete jail %q. %s", j.Path(), err) + return fmt.Errorf("can't delete jail %q. 
%s", j.Path(), err) } return nil @@ -150,17 +150,17 @@ func (j *Jail) MkDir(path string, perm os.FileMode) { func (j *Jail) CharDev(path string) error { fi, err := os.Stat(path) if err != nil { - return fmt.Errorf("Can't stat %q: %s", path, err) + return fmt.Errorf("can't stat %q: %s", path, err) } if (fi.Mode() & os.ModeCharDevice) == 0 { - return fmt.Errorf("Can't mknod %q: not a character device", path) + return fmt.Errorf("can't mknod %q: not a character device", path) } // Read the device number from the underlying unix implementation of stat() sys, ok := fi.Sys().(*syscall.Stat_t) if !ok { - return fmt.Errorf("Couldn't determine rdev for %q", path) + return fmt.Errorf("couldn't determine rdev for %q", path) } jailedDest := j.ExternalPath(path) @@ -177,11 +177,11 @@ func (j *Jail) CharDev(path string) error { func (j *Jail) CopyTo(dest, src string) error { fi, err := os.Stat(src) if err != nil { - return fmt.Errorf("Can't stat %q. %s", src, err) + return fmt.Errorf("can't stat %q. %s", src, err) } if fi.IsDir() { - return fmt.Errorf("Can't copy directories. %s", src) + return fmt.Errorf("can't copy directories. %s", src) } jailedDest := j.ExternalPath(dest) diff --git a/internal/jail/mount_linux.go b/internal/jail/mount_linux.go index 7b3eb56ef..54093c401 100644 --- a/internal/jail/mount_linux.go +++ b/internal/jail/mount_linux.go @@ -33,7 +33,7 @@ func (j *Jail) mount() error { for dest, src := range j.bindMounts { var opts uintptr = unix.MS_BIND | unix.MS_REC if err := unix.Mount(src, dest, "none", opts, ""); err != nil { - return fmt.Errorf("Failed to bind mount %s on %s. %s", src, dest, err) + return fmt.Errorf("failed to bind mount %s on %s. %s", src, dest, err) } } @@ -46,7 +46,7 @@ func (j *Jail) unmount() error { // A second invocation on unmount with MNT_DETACH flag will return EINVAL // there's no need to abort with an error if bind mountpoint is already unmounted if err != unix.EINVAL { - return fmt.Errorf("Failed to unmount %s. 
%s", dest, err) + return fmt.Errorf("failed to unmount %s. %s", dest, err) } } } diff --git a/internal/jail/mount_not_supported.go b/internal/jail/mount_not_supported.go index 0ab0f5f0e..b4d3e3488 100644 --- a/internal/jail/mount_not_supported.go +++ b/internal/jail/mount_not_supported.go @@ -8,12 +8,12 @@ import ( ) func (j *Jail) Unshare() error { - return fmt.Errorf("Unshare not supported on %s", runtime.GOOS) + return fmt.Errorf("unshare not supported on %s", runtime.GOOS) } func (j *Jail) notSupported() error { if len(j.bindMounts) > 0 { - return fmt.Errorf("Bind mount not supported on %s", runtime.GOOS) + return fmt.Errorf("bind mount not supported on %s", runtime.GOOS) } return nil diff --git a/internal/middleware/headers.go b/internal/middleware/headers.go new file mode 100644 index 000000000..77b008f39 --- /dev/null +++ b/internal/middleware/headers.go @@ -0,0 +1,31 @@ +package middleware + +import ( + "errors" + "net/http" + "strings" +) + +var errInvalidHeaderParameter = errors.New("invalid syntax specified as header parameter") + +// AddCustomHeaders adds a map of Headers to a Response +func AddCustomHeaders(w http.ResponseWriter, headers http.Header) { + for k, v := range headers { + for _, value := range v { + w.Header().Add(k, value) + } + } +} + +// ParseHeaderString parses a string of key values into a map +func ParseHeaderString(customHeaders []string) (http.Header, error) { + headers := http.Header{} + for _, keyValueString := range customHeaders { + keyValue := strings.SplitN(keyValueString, ":", 2) + if len(keyValue) != 2 { + return nil, errInvalidHeaderParameter + } + headers[strings.TrimSpace(keyValue[0])] = append(headers[strings.TrimSpace(keyValue[0])], strings.TrimSpace(keyValue[1])) + } + return headers, nil +} diff --git a/internal/config/config_test.go b/internal/middleware/headers_test.go similarity index 99% rename from internal/config/config_test.go rename to internal/middleware/headers_test.go index 44afd470d..17d31b50d 100644 --- 
a/internal/config/config_test.go +++ b/internal/middleware/headers_test.go @@ -1,4 +1,4 @@ -package config +package middleware import ( "net/http/httptest" diff --git a/internal/rejectmethods/middleware.go b/internal/rejectmethods/middleware.go new file mode 100644 index 000000000..e78a0ce59 --- /dev/null +++ b/internal/rejectmethods/middleware.go @@ -0,0 +1,31 @@ +package rejectmethods + +import ( + "net/http" + + "gitlab.com/gitlab-org/gitlab-pages/metrics" +) + +var acceptedMethods = map[string]bool{ + http.MethodGet: true, + http.MethodHead: true, + http.MethodPost: true, + http.MethodPut: true, + http.MethodPatch: true, + http.MethodDelete: true, + http.MethodConnect: true, + http.MethodOptions: true, + http.MethodTrace: true, +} + +// NewMiddleware returns middleware which rejects all unknown http methods +func NewMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if acceptedMethods[r.Method] { + handler.ServeHTTP(w, r) + } else { + metrics.RejectedRequestsCount.Inc() + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + } + }) +} diff --git a/internal/rejectmethods/middleware_test.go b/internal/rejectmethods/middleware_test.go new file mode 100644 index 000000000..2921975ae --- /dev/null +++ b/internal/rejectmethods/middleware_test.go @@ -0,0 +1,43 @@ +package rejectmethods + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewMiddleware(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "OK\n") + }) + + middleware := NewMiddleware(handler) + + acceptedMethods := []string{"GET", "HEAD", "POST", "PUT", "PATCH", "CONNECT", "OPTIONS", "TRACE"} + for _, method := range acceptedMethods { + t.Run(method, func(t *testing.T) { + tmpRequest, _ := http.NewRequest(method, "/", nil) + recorder := httptest.NewRecorder() + + 
middleware.ServeHTTP(recorder, tmpRequest) + + result := recorder.Result() + + require.Equal(t, http.StatusOK, result.StatusCode) + }) + } + + t.Run("UNKNOWN", func(t *testing.T) { + tmpRequest, _ := http.NewRequest("UNKNOWN", "/", nil) + recorder := httptest.NewRecorder() + + middleware.ServeHTTP(recorder, tmpRequest) + + result := recorder.Result() + + require.Equal(t, http.StatusMethodNotAllowed, result.StatusCode) + }) +} diff --git a/internal/serving/disk/local/serving_test.go b/internal/serving/disk/local/serving_test.go index 60f01acd6..2602451f1 100644 --- a/internal/serving/disk/local/serving_test.go +++ b/internal/serving/disk/local/serving_test.go @@ -15,29 +15,79 @@ import ( func TestDisk_ServeFileHTTP(t *testing.T) { defer setUpTests(t)() - s := Instance() - w := httptest.NewRecorder() - r := httptest.NewRequest("GET", "http://group.gitlab-example.com/serving/index.html", nil) - handler := serving.Handler{ - Writer: w, - Request: r, - LookupPath: &serving.LookupPath{ - Prefix: "/serving", - Path: "group/serving/public", + tests := map[string]struct { + vfsPath string + path string + expectedStatus int + expectedBody string + }{ + "accessing /index.html": { + vfsPath: "group/serving/public", + path: "/index.html", + expectedStatus: http.StatusOK, + expectedBody: "HTML Document", + }, + "accessing /": { + vfsPath: "group/serving/public", + path: "/", + expectedStatus: http.StatusOK, + expectedBody: "HTML Document", + }, + "accessing without /": { + vfsPath: "group/serving/public", + path: "", + expectedStatus: http.StatusFound, + expectedBody: `Found.`, + }, + "accessing vfs path that is missing": { + vfsPath: "group/serving/public-missing", + path: "/index.html", + // we expect the status to not be set + expectedStatus: 0, + }, + "accessing vfs path that is forbidden (like file)": { + vfsPath: "group/serving/public/index.html", + path: "/index.html", + expectedStatus: http.StatusInternalServerError, }, - SubPath: "/index.html", } - require.True(t, 
s.ServeFileHTTP(handler)) + s := Instance() - resp := w.Result() - defer resp.Body.Close() + for name, test := range tests { + t.Run(name, func(t *testing.T) { + w := httptest.NewRecorder() + w.Code = 0 // ensure that code is not set, and it is being set by handler + r := httptest.NewRequest("GET", "http://group.gitlab-example.com/serving"+test.path, nil) - require.Equal(t, http.StatusOK, resp.StatusCode) - body, err := ioutil.ReadAll(resp.Body) - require.NoError(t, err) + handler := serving.Handler{ + Writer: w, + Request: r, + LookupPath: &serving.LookupPath{ + Prefix: "/serving/", + Path: test.vfsPath, + }, + SubPath: test.path, + } - require.Contains(t, string(body), "HTML Document") + if test.expectedStatus == 0 { + require.False(t, s.ServeFileHTTP(handler)) + require.Zero(t, w.Code, "we expect status to not be set") + return + } + + require.True(t, s.ServeFileHTTP(handler)) + + resp := w.Result() + defer resp.Body.Close() + + require.Equal(t, test.expectedStatus, resp.StatusCode) + body, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + + require.Contains(t, string(body), test.expectedBody) + }) + } } var chdirSet = false diff --git a/internal/serving/disk/reader.go b/internal/serving/disk/reader.go index e7f15a1e3..12223bad4 100644 --- a/internal/serving/disk/reader.go +++ b/internal/serving/disk/reader.go @@ -11,7 +11,9 @@ import ( "time" "github.com/prometheus/client_golang/prometheus" + "gitlab.com/gitlab-org/labkit/errortracking" + "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/redirects" "gitlab.com/gitlab-org/gitlab-pages/internal/serving" "gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk/symlink" @@ -25,62 +27,64 @@ type Reader struct { } // Show the user some validation messages for their _redirects file -func (reader *Reader) serveRedirectsStatus(h serving.Handler, redirects *redirects.Redirects) error { +func (reader *Reader) serveRedirectsStatus(h serving.Handler, redirects 
*redirects.Redirects) { h.Writer.Header().Set("Content-Type", "text/plain; charset=utf-8") h.Writer.Header().Set("X-Content-Type-Options", "nosniff") h.Writer.WriteHeader(http.StatusOK) - _, err := fmt.Fprintln(h.Writer, redirects.Status()) - return err + fmt.Fprintln(h.Writer, redirects.Status()) } -func (reader *Reader) tryRedirects(h serving.Handler) error { +// tryRedirects returns true if it successfully handled request +func (reader *Reader) tryRedirects(h serving.Handler) bool { ctx := h.Request.Context() root, err := reader.vfs.Root(ctx, h.LookupPath.Path) - if err != nil { - return err + if vfs.IsNotExist(err) { + return false + } else if err != nil { + httperrors.Serve500WithRequest(h.Writer, h.Request, "vfs.Root", err) + return true } r := redirects.ParseRedirects(ctx, root) rewrittenURL, status, err := r.Rewrite(h.Request.URL) if err != nil { - return err + if err != redirects.ErrNoRedirect { + // We assume that rewrite failure is not fatal + // and we only capture the error + errortracking.Capture(err, errortracking.WithRequest(h.Request)) + } + return false } http.Redirect(h.Writer, h.Request, rewrittenURL.Path, status) - - return nil + return true } -func (reader *Reader) tryFile(h serving.Handler) error { +// tryFile returns true if it successfully handled request +func (reader *Reader) tryFile(h serving.Handler) bool { ctx := h.Request.Context() root, err := reader.vfs.Root(ctx, h.LookupPath.Path) - if err != nil { - return err + if vfs.IsNotExist(err) { + return false + } else if err != nil { + httperrors.Serve500WithRequest(h.Writer, h.Request, + "vfs.Root", err) + return true } fullPath, err := reader.resolvePath(ctx, root, h.SubPath) request := h.Request - host := request.Host urlPath := request.URL.Path if locationError, _ := err.(*locationDirectoryError); locationError != nil { if endsWithSlash(urlPath) { fullPath, err = reader.resolvePath(ctx, root, h.SubPath, "index.html") } else { - // TODO why are we doing that? 
In tests it redirects to HTTPS. This seems wrong, - // issue about this: https://gitlab.com/gitlab-org/gitlab-pages/issues/273 - - // Concat Host with URL.Path - redirectPath := "//" + host + "/" - redirectPath += strings.TrimPrefix(urlPath, "/") - - // Ensure that there's always "/" at end - redirectPath = strings.TrimSuffix(redirectPath, "/") + "/" - http.Redirect(h.Writer, h.Request, redirectPath, 302) - return nil + http.Redirect(h.Writer, h.Request, redirectPath(h.Request), 302) + return true } } @@ -89,7 +93,9 @@ func (reader *Reader) tryFile(h serving.Handler) error { } if err != nil { - return err + // We assume that this is mostly missing file type of the error + // and additional handlers should try to process the request + return false } // Serve status of `_redirects` under `_redirects` @@ -97,34 +103,53 @@ func (reader *Reader) tryFile(h serving.Handler) error { if fullPath == redirects.ConfigFile { if os.Getenv("FF_ENABLE_REDIRECTS") != "false" { r := redirects.ParseRedirects(ctx, root) - return reader.serveRedirectsStatus(h, r) + reader.serveRedirectsStatus(h, r) + return true } h.Writer.WriteHeader(http.StatusForbidden) - return nil + return true } return reader.serveFile(ctx, h.Writer, h.Request, root, fullPath, h.LookupPath.HasAccessControl) } -func (reader *Reader) tryNotFound(h serving.Handler) error { +func redirectPath(request *http.Request) string { + url := *request.URL + + // This ensures that path starts with `///` + url.Scheme = "" + url.Host = request.Host + url.Path = strings.TrimPrefix(url.Path, "/") + "/" + + return strings.TrimSuffix(url.String(), "?") +} + +func (reader *Reader) tryNotFound(h serving.Handler) bool { ctx := h.Request.Context() root, err := reader.vfs.Root(ctx, h.LookupPath.Path) - if err != nil { - return err + if vfs.IsNotExist(err) { + return false + } else if err != nil { + httperrors.Serve500WithRequest(h.Writer, h.Request, "vfs.Root", err) + return true } page404, err := reader.resolvePath(ctx, root, "404.html") 
if err != nil { - return err + // We assume that this is mostly missing file type of the error + // and additional handlers should try to process the request + return false } err = reader.serveCustomFile(ctx, h.Writer, h.Request, http.StatusNotFound, root, page404) if err != nil { - return err + httperrors.Serve500WithRequest(h.Writer, h.Request, "serveCustomFile", err) + return true } - return nil + + return true } // Resolve the HTTP request to a path on disk, converting requests for @@ -168,19 +193,21 @@ func (reader *Reader) resolvePath(ctx context.Context, root vfs.Root, subPath .. return fullPath, nil } -func (reader *Reader) serveFile(ctx context.Context, w http.ResponseWriter, r *http.Request, root vfs.Root, origPath string, accessControl bool) error { +func (reader *Reader) serveFile(ctx context.Context, w http.ResponseWriter, r *http.Request, root vfs.Root, origPath string, accessControl bool) bool { fullPath := reader.handleContentEncoding(ctx, w, r, root, origPath) file, err := root.Open(ctx, fullPath) if err != nil { - return err + httperrors.Serve500WithRequest(w, r, "root.Open", err) + return true } defer file.Close() fi, err := root.Lstat(ctx, fullPath) if err != nil { - return err + httperrors.Serve500WithRequest(w, r, "root.Lstat", err) + return true } if !accessControl { @@ -191,7 +218,8 @@ func (reader *Reader) serveFile(ctx context.Context, w http.ResponseWriter, r *h contentType, err := reader.detectContentType(ctx, root, origPath) if err != nil { - return err + httperrors.Serve500WithRequest(w, r, "detectContentType", err) + return true } w.Header().Set("Content-Type", contentType) @@ -208,7 +236,7 @@ func (reader *Reader) serveFile(ctx context.Context, w http.ResponseWriter, r *h io.Copy(w, file) } - return nil + return true } func (reader *Reader) serveCustomFile(ctx context.Context, w http.ResponseWriter, r *http.Request, code int, root vfs.Root, origPath string) error { diff --git a/internal/serving/disk/reader_test.go 
b/internal/serving/disk/reader_test.go new file mode 100644 index 000000000..53ea3d9a5 --- /dev/null +++ b/internal/serving/disk/reader_test.go @@ -0,0 +1,68 @@ +package disk + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_redirectPath(t *testing.T) { + tests := map[string]struct { + request *http.Request + expectedPath string + }{ + "simple_url_no_path": { + request: newRequest(t, "https://domain.gitlab.io"), + expectedPath: "//domain.gitlab.io/", + }, + "path_only": { + request: newRequest(t, "https://domain.gitlab.io/index.html"), + expectedPath: "//domain.gitlab.io/index.html/", + }, + "query_only": { + request: newRequest(t, "https://domain.gitlab.io?query=test"), + expectedPath: "//domain.gitlab.io/?query=test", + }, + "empty_query": { + request: newRequest(t, "https://domain.gitlab.io?"), + expectedPath: "//domain.gitlab.io/", + }, + "fragment_only": { + request: newRequest(t, "https://domain.gitlab.io#fragment"), + expectedPath: "//domain.gitlab.io/#fragment", + }, + "path_and_query": { + request: newRequest(t, "https://domain.gitlab.io/index.html?query=test"), + expectedPath: "//domain.gitlab.io/index.html/?query=test", + }, + "path_and_fragment": { + request: newRequest(t, "https://domain.gitlab.io/index.html#fragment"), + expectedPath: "//domain.gitlab.io/index.html/#fragment", + }, + "query_and_fragment": { + request: newRequest(t, "https://domain.gitlab.io?query=test#fragment"), + expectedPath: "//domain.gitlab.io/?query=test#fragment", + }, + "path_query_and_fragment": { + request: newRequest(t, "https://domain.gitlab.io/index.html?query=test#fragment"), + expectedPath: "//domain.gitlab.io/index.html/?query=test#fragment", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + got := redirectPath(test.request) + require.Equal(t, test.expectedPath, got) + }) + } +} + +func newRequest(t *testing.T, url string) *http.Request { + t.Helper() + + r, err := http.NewRequest("GET", url, 
nil) + require.NoError(t, err) + + return r +} diff --git a/internal/serving/disk/serving.go b/internal/serving/disk/serving.go index 11b1689e3..fbcdf9f2d 100644 --- a/internal/serving/disk/serving.go +++ b/internal/serving/disk/serving.go @@ -3,6 +3,7 @@ package disk import ( "os" + "gitlab.com/gitlab-org/gitlab-pages/internal/config" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/serving" "gitlab.com/gitlab-org/gitlab-pages/internal/vfs" @@ -17,12 +18,12 @@ type Disk struct { // ServeFileHTTP serves a file from disk and returns true. It returns false // when a file could not been found. func (s *Disk) ServeFileHTTP(h serving.Handler) bool { - if s.reader.tryFile(h) == nil { + if s.reader.tryFile(h) { return true } if os.Getenv("FF_ENABLE_REDIRECTS") != "false" { - if s.reader.tryRedirects(h) == nil { + if s.reader.tryRedirects(h) { return true } } @@ -32,7 +33,7 @@ func (s *Disk) ServeFileHTTP(h serving.Handler) bool { // ServeNotFoundHTTP tries to read a custom 404 page func (s *Disk) ServeNotFoundHTTP(h serving.Handler) { - if s.reader.tryNotFound(h) == nil { + if s.reader.tryNotFound(h) { return } @@ -40,6 +41,11 @@ func (s *Disk) ServeNotFoundHTTP(h serving.Handler) { httperrors.Serve404(h.Writer) } +// Reconfigure VFS +func (s *Disk) Reconfigure(cfg *config.Config) error { + return s.reader.vfs.Reconfigure(cfg) +} + // New returns a serving instance that is capable of reading files // from the VFS func New(vfs vfs.VFS) serving.Serving { diff --git a/internal/serving/disk/zip/serving.go b/internal/serving/disk/zip/serving.go index 95894fc98..6db0be10d 100644 --- a/internal/serving/disk/zip/serving.go +++ b/internal/serving/disk/zip/serving.go @@ -1,13 +1,14 @@ package zip import ( + "gitlab.com/gitlab-org/gitlab-pages/internal/config" "gitlab.com/gitlab-org/gitlab-pages/internal/serving" "gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk" "gitlab.com/gitlab-org/gitlab-pages/internal/vfs" 
"gitlab.com/gitlab-org/gitlab-pages/internal/vfs/zip" ) -var instance = disk.New(vfs.Instrumented(zip.New())) +var instance = disk.New(vfs.Instrumented(zip.New(&config.ZipServing{}))) // Instance returns a serving instance that is capable of reading files // from a zip archives opened from a URL, most likely stored in object storage diff --git a/internal/serving/disk/zip/serving_test.go b/internal/serving/disk/zip/serving_test.go index e95432ae3..e64a761a1 100644 --- a/internal/serving/disk/zip/serving_test.go +++ b/internal/serving/disk/zip/serving_test.go @@ -5,9 +5,11 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-pages/internal/config" "gitlab.com/gitlab-org/gitlab-pages/internal/serving" "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers" ) @@ -17,32 +19,59 @@ func TestZip_ServeFileHTTP(t *testing.T) { defer cleanup() tests := map[string]struct { + vfsPath string path string expectedStatus int expectedBody string }{ "accessing /index.html": { + vfsPath: testServerURL + "/public.zip", path: "/index.html", expectedStatus: http.StatusOK, expectedBody: "zip.gitlab.io/project/index.html\n", }, "accessing /": { + vfsPath: testServerURL + "/public.zip", path: "/", expectedStatus: http.StatusOK, expectedBody: "zip.gitlab.io/project/index.html\n", }, "accessing without /": { + vfsPath: testServerURL + "/public.zip", path: "", expectedStatus: http.StatusFound, expectedBody: `Found.`, }, + "accessing archive that is 404": { + vfsPath: testServerURL + "/invalid.zip", + path: "/index.html", + // we expect the status to not be set + expectedStatus: 0, + }, + "accessing archive that is 500": { + vfsPath: testServerURL + "/500", + path: "/index.html", + expectedStatus: http.StatusInternalServerError, + }, + } + + cfg := &config.Config{ + Zip: &config.ZipServing{ + ExpirationInterval: 10 * time.Second, + CleanupInterval: 5 * time.Second, + RefreshInterval: 5 * time.Second, + OpenTimeout: 
5 * time.Second, + }, } s := Instance() + err := s.Reconfigure(cfg) + require.NoError(t, err) for name, test := range tests { t.Run(name, func(t *testing.T) { w := httptest.NewRecorder() + w.Code = 0 // ensure that code is not set, and it is being set by handler r := httptest.NewRequest("GET", "http://zip.gitlab.io/zip"+test.path, nil) handler := serving.Handler{ @@ -50,11 +79,17 @@ func TestZip_ServeFileHTTP(t *testing.T) { Request: r, LookupPath: &serving.LookupPath{ Prefix: "/zip/", - Path: testServerURL + "/public.zip", + Path: test.vfsPath, }, SubPath: test.path, } + if test.expectedStatus == 0 { + require.False(t, s.ServeFileHTTP(handler)) + require.Zero(t, w.Code, "we expect status to not be set") + return + } + require.True(t, s.ServeFileHTTP(handler)) resp := w.Result() @@ -76,9 +111,15 @@ func newZipFileServerURL(t *testing.T, zipFilePath string) (string, func()) { chdir := testhelpers.ChdirInPath(t, "../../../../shared/pages", &chdirSet) - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m := http.NewServeMux() + m.HandleFunc("/public.zip", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.ServeFile(w, r, zipFilePath) })) + m.HandleFunc("/500", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + + testServer := httptest.NewServer(m) return testServer.URL, func() { chdir() diff --git a/internal/serving/serverless/serverless.go b/internal/serving/serverless/serverless.go index e1881362a..f8bd4e87b 100644 --- a/internal/serving/serverless/serverless.go +++ b/internal/serving/serverless/serverless.go @@ -4,6 +4,7 @@ import ( "errors" "net/http/httputil" + "gitlab.com/gitlab-org/gitlab-pages/internal/config" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/serving" "gitlab.com/gitlab-org/gitlab-pages/internal/source/gitlab/api" @@ -65,3 +66,8 @@ func (s *Serverless) 
ServeFileHTTP(h serving.Handler) bool { func (s *Serverless) ServeNotFoundHTTP(h serving.Handler) { httperrors.Serve404(h.Writer) } + +// Reconfigure noop +func (s *Serverless) Reconfigure(*config.Config) error { + return nil +} diff --git a/internal/serving/serving.go b/internal/serving/serving.go index 6fde82165..786ee569e 100644 --- a/internal/serving/serving.go +++ b/internal/serving/serving.go @@ -1,7 +1,10 @@ package serving +import "gitlab.com/gitlab-org/gitlab-pages/internal/config" + // Serving is an interface used to define a serving driver type Serving interface { ServeFileHTTP(Handler) bool ServeNotFoundHTTP(Handler) + Reconfigure(config *config.Config) error } diff --git a/internal/source/domains.go b/internal/source/domains.go index 2a7a317c2..cf81fab2d 100644 --- a/internal/source/domains.go +++ b/internal/source/domains.go @@ -51,14 +51,11 @@ func NewDomains(config Config) (*Domains, error) { // returns error if -domain-config-source is not valid // returns error if -domain-config-source=gitlab and init fails func (d *Domains) setConfigSource(config Config) error { - // TODO: Handle domain-config-source=auto https://gitlab.com/gitlab-org/gitlab/-/issues/218358 - // attach gitlab by default when source is not disk (auto, gitlab) switch config.DomainConfigSource() { case "gitlab": d.configSource = sourceGitlab return d.setGitLabClient(config) case "auto": - // TODO: handle DomainConfigSource == "auto" https://gitlab.com/gitlab-org/gitlab/-/issues/218358 d.configSource = sourceAuto // enable disk for auto for now d.disk = disk.New() @@ -122,18 +119,18 @@ func (d *Domains) IsReady() bool { case sourceDisk: return d.disk.IsReady() case sourceAuto: - // TODO: implement auto https://gitlab.com/gitlab-org/gitlab/-/issues/218358, default to disk for now + // if gitlab is configured and is ready + if d.gitlab != nil && d.gitlab.IsReady() { + return true + } + return d.disk.IsReady() + default: + return false } - - return false } func (d *Domains) 
source(domain string) Source { - if d.gitlab == nil { - return d.disk - } - // This check is only needed until we enable `d.gitlab` source in all // environments (including on-premises installations) followed by removal of // `d.disk` source. This can be safely removed afterwards. @@ -141,17 +138,18 @@ func (d *Domains) source(domain string) Source { return d.gitlab } - if d.configSource == sourceDisk { + switch d.configSource { + case sourceDisk: return d.disk - } - - // TODO: handle sourceAuto https://gitlab.com/gitlab-org/gitlab/-/issues/218358 - // check IsReady for sourceAuto for now - if d.configSource == sourceGitlab || d.gitlab.IsReady() { + case sourceGitlab: return d.gitlab - } + default: + if d.gitlab != nil && d.gitlab.IsReady() { + return d.gitlab + } - return d.disk + return d.disk + } } // IsServerlessDomain checks if a domain requested is a serverless domain we diff --git a/internal/source/domains_test.go b/internal/source/domains_test.go index 36c53e6fc..abc82e423 100644 --- a/internal/source/domains_test.go +++ b/internal/source/domains_test.go @@ -60,7 +60,6 @@ func TestNewDomains(t *testing.T) { expectDiskNil: false, }, { - // TODO: https://gitlab.com/gitlab-org/gitlab/-/issues/218358 name: "auto_without_api_config", sourceConfig: sourceConfig{domainSource: "auto"}, expectGitlabNil: true, diff --git a/internal/source/gitlab/cache/retriever.go b/internal/source/gitlab/cache/retriever.go index de37c231a..43ef2e523 100644 --- a/internal/source/gitlab/cache/retriever.go +++ b/internal/source/gitlab/cache/retriever.go @@ -1,10 +1,11 @@ package cache import ( + "time" "context" "errors" - "time" - + "sync" + log "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/gitlab-pages/internal/source/gitlab/api" @@ -13,15 +14,25 @@ import ( // Retriever is an utility type that performs an HTTP request with backoff in // case of errors type Retriever struct { + timer timer client api.Client retrievalTimeout time.Duration maxRetrievalInterval time.Duration 
maxRetrievalRetries int } +type timer struct { + mu *sync.Mutex + stopped bool + timer *time.Timer +} + // NewRetriever creates a Retriever with a client func NewRetriever(client api.Client, retrievalTimeout, maxRetrievalInterval time.Duration, maxRetrievalRetries int) *Retriever { return &Retriever{ + timer: timer{ + mu: &sync.Mutex{}, + }, client: client, retrievalTimeout: retrievalTimeout, maxRetrievalInterval: maxRetrievalInterval, @@ -52,11 +63,24 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha go func() { var lookup api.Lookup + Retry: for i := 1; i <= r.maxRetrievalRetries; i++ { lookup = r.client.GetLookup(ctx, domain) if lookup.Error != nil { - time.Sleep(r.maxRetrievalInterval) + r.timer.start(r.maxRetrievalInterval) + select { + case <-r.timer.timer.C: + // retry to GetLookup + continue Retry + case <-ctx.Done(): + // when the retrieval context is done we stop the timer + // timer.Stop() + // log.WithError(ctx.Err()).Debug("domain retrieval backoff canceled by context") + // when the retrieval context is done we stop the timerFunc + r.timer.stop() + break Retry + } } else { break } @@ -68,3 +92,26 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha return response } + +func (t *timer) start(d time.Duration) { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopped = false + t.timer = time.NewTimer(d) +} + +func (t *timer) stop() { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopped = t.timer.Stop() +} + +func (t *timer) hasStopped() bool { + t.mu.Lock() + defer t.mu.Unlock() + + return t.stopped +} + \ No newline at end of file diff --git a/internal/source/gitlab/cache/retriever_test.go b/internal/source/gitlab/cache/retriever_test.go new file mode 100644 index 000000000..774e9779e --- /dev/null +++ b/internal/source/gitlab/cache/retriever_test.go @@ -0,0 +1,27 @@ +package cache + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func 
TestRetrieveTimerStopsWhenContextIsDone(t *testing.T) { + retrievalTimeout := time.Millisecond // quick timeout that cancels inner context + maxRetrievalInterval := time.Minute // long sleep inside resolveWithBackoff + + resolver := &client{ + domain: make(chan string), + lookups: make(chan uint64, 1), + failure: errors.New("500 error"), + } + + retriever := NewRetriever(resolver, retrievalTimeout, maxRetrievalInterval, 3) + require.False(t, retriever.timer.hasStopped(), "timer has not been stopped yet") + + lookup := retriever.Retrieve("my.gitlab.com") + require.Empty(t, lookup.Name) + require.Eventually(t, retriever.timer.hasStopped, time.Second, time.Millisecond, "timer must have been stopped") +} diff --git a/internal/source/gitlab/client/client.go b/internal/source/gitlab/client/client.go index 0e8235c0a..b11ea2cbb 100644 --- a/internal/source/gitlab/client/client.go +++ b/internal/source/gitlab/client/client.go @@ -60,7 +60,9 @@ func NewClient(baseURL string, secretKey []byte, connectionTimeout, jwtTokenExpi "gitlab_internal_api", metrics.DomainsSourceAPITraceDuration, metrics.DomainsSourceAPICallDuration, - metrics.DomainsSourceAPIReqTotal), + metrics.DomainsSourceAPIReqTotal, + httptransport.DefaultTTFBTimeout, + ), }, jwtTokenExpiry: jwtTokenExpiry, }, nil diff --git a/internal/source/gitlab/client/client_test.go b/internal/source/gitlab/client/client_test.go index c888a059f..6d4ce8140 100644 --- a/internal/source/gitlab/client/client_test.go +++ b/internal/source/gitlab/client/client_test.go @@ -319,7 +319,7 @@ func validateToken(t *testing.T, tokenString string) { t.Helper() token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return secretKey(t), nil diff --git a/internal/source/gitlab/factory.go 
b/internal/source/gitlab/factory.go index 41f7ea56b..b033a592a 100644 --- a/internal/source/gitlab/factory.go +++ b/internal/source/gitlab/factory.go @@ -6,7 +6,6 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/serving" "gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk/local" "gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk/zip" - "gitlab.com/gitlab-org/gitlab-pages/internal/serving/serverless" "gitlab.com/gitlab-org/gitlab-pages/internal/source/gitlab/api" ) @@ -35,14 +34,19 @@ func fabricateServing(lookup api.LookupPath) serving.Serving { case "zip": return zip.Instance() case "serverless": - serving, err := serverless.NewFromAPISource(source.Serverless) - if err != nil { - log.WithError(err).Errorf("could not fabricate serving for project %d", lookup.ProjectID) - - break - } - - return serving + log.Errorf("attempted to fabricate serverless serving for project %d", lookup.ProjectID) + + // This feature has been disalbed, for more details see + // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/467 + // + // serving, err := serverless.NewFromAPISource(source.Serverless) + // if err != nil { + // log.WithError(err).Errorf("could not fabricate serving for project %d", lookup.ProjectID) + // + // break + // } + // + // return serving } return defaultServing() diff --git a/internal/source/gitlab/factory_test.go b/internal/source/gitlab/factory_test.go index 2f3e19940..46740d354 100644 --- a/internal/source/gitlab/factory_test.go +++ b/internal/source/gitlab/factory_test.go @@ -7,7 +7,6 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/fixture" "gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk" - "gitlab.com/gitlab-org/gitlab-pages/internal/serving/serverless" "gitlab.com/gitlab-org/gitlab-pages/internal/source/gitlab/api" ) @@ -59,6 +58,8 @@ func TestFabricateServing(t *testing.T) { }, } - require.IsType(t, &serverless.Serverless{}, fabricateServing(lookup)) + // Serverless serving has been deprecated. 
+ // require.IsType(t, &serverless.Serverless{}, fabricateServing(lookup)) + require.IsType(t, &disk.Disk{}, fabricateServing(lookup)) }) } diff --git a/internal/tlsconfig/tlsconfig.go b/internal/tlsconfig/tlsconfig.go index 5d26ed520..9babf3744 100644 --- a/internal/tlsconfig/tlsconfig.go +++ b/internal/tlsconfig/tlsconfig.go @@ -73,13 +73,13 @@ func ValidateTLSVersions(min, max string) error { tlsMax, tlsMaxOk := AllTLSVersions[max] if !tlsMinOk { - return fmt.Errorf("Invalid minimum TLS version: %s", min) + return fmt.Errorf("invalid minimum TLS version: %s", min) } if !tlsMaxOk { - return fmt.Errorf("Invalid maximum TLS version: %s", max) + return fmt.Errorf("invalid maximum TLS version: %s", max) } if tlsMin > tlsMax && tlsMax > 0 { - return fmt.Errorf("Invalid maximum TLS version: %s; Should be at least %s", max, min) + return fmt.Errorf("invalid maximum TLS version: %s; should be at least %s", max, min) } return nil diff --git a/internal/tlsconfig/tlsconfig_test.go b/internal/tlsconfig/tlsconfig_test.go index e37ab51bf..00a080667 100644 --- a/internal/tlsconfig/tlsconfig_test.go +++ b/internal/tlsconfig/tlsconfig_test.go @@ -35,9 +35,9 @@ func TestValidateTLSVersions(t *testing.T) { tlsMax string err string }{ - "invalid minimum TLS version": {tlsMin: "tls123", tlsMax: "", err: "Invalid minimum TLS version: tls123"}, - "invalid maximum TLS version": {tlsMin: "", tlsMax: "tls123", err: "Invalid maximum TLS version: tls123"}, - "TLS versions conflict": {tlsMin: "tls1.2", tlsMax: "tls1.1", err: "Invalid maximum TLS version: tls1.1; Should be at least tls1.2"}, + "invalid minimum TLS version": {tlsMin: "tls123", tlsMax: "", err: "invalid minimum TLS version: tls123"}, + "invalid maximum TLS version": {tlsMin: "", tlsMax: "tls123", err: "invalid maximum TLS version: tls123"}, + "TLS versions conflict": {tlsMin: "tls1.2", tlsMax: "tls1.1", err: "invalid maximum TLS version: tls1.1; should be at least tls1.2"}, } for name, tc := range tests { @@ -53,7 +53,7 @@ func 
TestInvalidKeyPair(t *testing.T) { require.EqualError(t, err, "tls: failed to find any PEM data in certificate input") } -func TestInsecureCihers(t *testing.T) { +func TestInsecureCiphers(t *testing.T) { tlsConfig, err := Create(cert, key, getCertificate, true, tls.VersionTLS11, tls.VersionTLS12) require.NoError(t, err) require.False(t, tlsConfig.PreferServerCipherSuites) diff --git a/internal/vfs/errors.go b/internal/vfs/errors.go new file mode 100644 index 000000000..32b861925 --- /dev/null +++ b/internal/vfs/errors.go @@ -0,0 +1,18 @@ +package vfs + +import ( + "fmt" +) + +type ErrNotExist struct { + Inner error +} + +func (e ErrNotExist) Error() string { + return fmt.Sprintf("not exist: %q", e.Inner) +} + +func IsNotExist(err error) bool { + _, ok := err.(*ErrNotExist) + return ok +} diff --git a/internal/vfs/local/vfs.go b/internal/vfs/local/vfs.go index ca74dfbe5..ea54e8e83 100644 --- a/internal/vfs/local/vfs.go +++ b/internal/vfs/local/vfs.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" + "gitlab.com/gitlab-org/gitlab-pages/internal/config" "gitlab.com/gitlab-org/gitlab-pages/internal/vfs" ) @@ -20,12 +21,16 @@ func (fs VFS) Root(ctx context.Context, path string) (vfs.Root, error) { } rootPath, err = filepath.EvalSymlinks(rootPath) - if err != nil { + if os.IsNotExist(err) { + return nil, &vfs.ErrNotExist{Inner: err} + } else if err != nil { return nil, err } fi, err := os.Lstat(rootPath) - if err != nil { + if os.IsNotExist(err) { + return nil, &vfs.ErrNotExist{Inner: err} + } else if err != nil { return nil, err } @@ -39,3 +44,8 @@ func (fs VFS) Root(ctx context.Context, path string) (vfs.Root, error) { func (fs *VFS) Name() string { return "local" } + +func (fs *VFS) Reconfigure(*config.Config) error { + // noop + return nil +} diff --git a/internal/vfs/local/vfs_test.go b/internal/vfs/local/vfs_test.go index ec67d5959..b678cfa7a 100644 --- a/internal/vfs/local/vfs_test.go +++ b/internal/vfs/local/vfs_test.go @@ -9,6 +9,8 @@ import ( 
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-pages/internal/vfs" ) var localVFS = &VFS{} @@ -98,7 +100,7 @@ func TestVFSRoot(t *testing.T) { rootVFS, err := localVFS.Root(context.Background(), filepath.Join(tmpDir, test.path)) if test.expectedIsNotExist { - require.Equal(t, test.expectedIsNotExist, os.IsNotExist(err)) + require.Equal(t, test.expectedIsNotExist, vfs.IsNotExist(err)) return } diff --git a/internal/vfs/vfs.go b/internal/vfs/vfs.go index 7bd51db2c..2304f9034 100644 --- a/internal/vfs/vfs.go +++ b/internal/vfs/vfs.go @@ -6,6 +6,7 @@ import ( log "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/gitlab-pages/internal/config" "gitlab.com/gitlab-org/gitlab-pages/metrics" ) @@ -13,6 +14,7 @@ import ( type VFS interface { Root(ctx context.Context, path string) (Root, error) Name() string + Reconfigure(config *config.Config) error } func Instrumented(fs VFS) VFS { @@ -50,3 +52,7 @@ func (i *instrumentedVFS) Root(ctx context.Context, path string) (Root, error) { func (i *instrumentedVFS) Name() string { return i.fs.Name() } + +func (i *instrumentedVFS) Reconfigure(cfg *config.Config) error { + return i.fs.Reconfigure(cfg) +} diff --git a/internal/vfs/zip/archive.go b/internal/vfs/zip/archive.go index ba15af200..1137f0041 100644 --- a/internal/vfs/zip/archive.go +++ b/internal/vfs/zip/archive.go @@ -24,9 +24,6 @@ import ( const ( dirPrefix = "public/" maxSymlinkSize = 256 - - // DefaultOpenTimeout to request an archive and read its contents the first time - DefaultOpenTimeout = 30 * time.Second ) var ( @@ -35,13 +32,21 @@ var ( errNotFile = errors.New("not a file") ) +type archiveStatus int + +const ( + archiveOpening archiveStatus = iota + archiveOpenError + archiveOpened + archiveCorrupted +) + // zipArchive implements the vfs.Root interface. // It represents a zip archive saving all its files in memory. // It holds an httprange.Resource that can be read with httprange.RangedReader in chunks. 
type zipArchive struct { fs *zipVFS - path string once sync.Once done chan struct{} openTimeout time.Duration @@ -57,10 +62,9 @@ type zipArchive struct { directories map[string]*zip.FileHeader } -func newArchive(fs *zipVFS, path string, openTimeout time.Duration) *zipArchive { +func newArchive(fs *zipVFS, openTimeout time.Duration) *zipArchive { return &zipArchive{ fs: fs, - path: path, done: make(chan struct{}), files: make(map[string]*zip.File), directories: make(map[string]*zip.FileHeader), @@ -69,13 +73,15 @@ func newArchive(fs *zipVFS, path string, openTimeout time.Duration) *zipArchive } } -func (a *zipArchive) openArchive(parentCtx context.Context) (err error) { - // return early if openArchive was done already in a concurrent request - select { - case <-a.done: - return a.err +func (a *zipArchive) openArchive(parentCtx context.Context, url string) (err error) { + // always try to update URL on resource + if a.resource != nil { + a.resource.SetURL(url) + } - default: + // return early if openArchive was done already in a concurrent request + if status, err := a.openStatus(); status != archiveOpening { + return err } ctx, cancel := context.WithTimeout(parentCtx, a.openTimeout) @@ -84,7 +90,7 @@ func (a *zipArchive) openArchive(parentCtx context.Context) (err error) { a.once.Do(func() { // read archive once in its own routine with its own timeout // if parentCtx is canceled, readArchive will continue regardless and will be cached in memory - go a.readArchive() + go a.readArchive(url) }) // wait for readArchive to be done or return if the parent context is canceled @@ -106,14 +112,14 @@ func (a *zipArchive) openArchive(parentCtx context.Context) (err error) { // readArchive creates an httprange.Resource that can read the archive's contents and stores a slice of *zip.Files // that can be accessed later when calling any of th vfs.VFS operations -func (a *zipArchive) readArchive() { +func (a *zipArchive) readArchive(url string) { defer close(a.done) // readArchive 
with a timeout separate from openArchive's ctx, cancel := context.WithTimeout(context.Background(), a.openTimeout) defer cancel() - a.resource, a.err = httprange.NewResource(ctx, a.path) + a.resource, a.err = httprange.NewResource(ctx, url) if a.err != nil { metrics.ZipOpened.WithLabelValues("error").Inc() return @@ -286,3 +292,21 @@ func (a *zipArchive) Readlink(ctx context.Context, name string) (string, error) func (a *zipArchive) onEvicted() { metrics.ZipArchiveEntriesCached.Sub(float64(len(a.files))) } + +func (a *zipArchive) openStatus() (archiveStatus, error) { + select { + case <-a.done: + if a.err != nil { + return archiveOpenError, a.err + } + + if a.resource != nil && a.resource.Err() != nil { + return archiveCorrupted, a.resource.Err() + } + + return archiveOpened, nil + + default: + return archiveOpening, nil + } +} diff --git a/internal/vfs/zip/archive_test.go b/internal/vfs/zip/archive_test.go index ef6785b5e..da778e620 100644 --- a/internal/vfs/zip/archive_test.go +++ b/internal/vfs/zip/archive_test.go @@ -12,10 +12,20 @@ import ( "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-pages/internal/config" + "gitlab.com/gitlab-org/gitlab-pages/internal/httprange" "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers" ) -var chdirSet = false +var ( + chdirSet = false + zipCfg = &config.ZipServing{ + ExpirationInterval: 10 * time.Second, + CleanupInterval: 5 * time.Second, + RefreshInterval: 5 * time.Second, + OpenTimeout: 5 * time.Second, + } +) func TestOpen(t *testing.T) { zip, cleanup := openZipArchive(t, nil) @@ -71,30 +81,107 @@ func TestOpen(t *testing.T) { func TestOpenCached(t *testing.T) { var requests int64 - zip, cleanup := openZipArchive(t, &requests) + testServerURL, cleanup := newZipFileServerURL(t, "group/zip.gitlab.io/public-without-dirs.zip", &requests) defer cleanup() - t.Run("open file first time", func(t *testing.T) { - requestsStart := requests - f, err := zip.Open(context.Background(), "index.html") - 
require.NoError(t, err) - defer f.Close() + fs := New(zipCfg) + + // We use array instead of map to ensure + // predictable ordering of test execution + tests := []struct { + name string + vfsPath string + filePath string + expectedArchiveStatus archiveStatus + expectedOpenErr error + expectedReadErr error + expectedRequests int64 + }{ + { + name: "open file first time", + vfsPath: testServerURL + "/public.zip", + filePath: "index.html", + // we expect five requests to: + // read resource and zip metadata + // read file: data offset and content + expectedRequests: 5, + expectedArchiveStatus: archiveOpened, + }, + { + name: "open file second time", + vfsPath: testServerURL + "/public.zip", + filePath: "index.html", + // we expect one request to read file with cached data offset + expectedRequests: 1, + expectedArchiveStatus: archiveOpened, + }, + { + name: "when the URL changes", + vfsPath: testServerURL + "/public.zip?new-secret", + filePath: "index.html", + expectedRequests: 1, + expectedArchiveStatus: archiveOpened, + }, + { + name: "when opening cached file and content changes", + vfsPath: testServerURL + "/public.zip?changed-content=1", + filePath: "index.html", + expectedRequests: 1, + // we receive an error on `read` as `open` offset is already cached + expectedReadErr: httprange.ErrRangeRequestsNotSupported, + expectedArchiveStatus: archiveCorrupted, + }, + { + name: "after content change archive is reloaded", + vfsPath: testServerURL + "/public.zip?new-secret", + filePath: "index.html", + expectedRequests: 5, + expectedArchiveStatus: archiveOpened, + }, + { + name: "when opening non-cached file and content changes", + vfsPath: testServerURL + "/public.zip?changed-content=1", + filePath: "subdir/hello.html", + expectedRequests: 1, + // we receive an error on `read` as `open` offset is already cached + expectedOpenErr: httprange.ErrRangeRequestsNotSupported, + expectedArchiveStatus: archiveCorrupted, + }, + } - _, err = ioutil.ReadAll(f) - require.NoError(t, 
err) - require.Equal(t, int64(2), atomic.LoadInt64(&requests)-requestsStart, "we expect two requests to read file: data offset and content") - }) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + start := atomic.LoadInt64(&requests) + zip, err := fs.Root(context.Background(), test.vfsPath) + require.NoError(t, err) - t.Run("open file second time", func(t *testing.T) { - requestsStart := atomic.LoadInt64(&requests) - f, err := zip.Open(context.Background(), "index.html") - require.NoError(t, err) - defer f.Close() + f, err := zip.Open(context.Background(), test.filePath) + if test.expectedOpenErr != nil { + require.Equal(t, test.expectedOpenErr, err) + status, _ := zip.(*zipArchive).openStatus() + require.Equal(t, test.expectedArchiveStatus, status) + return + } - _, err = ioutil.ReadAll(f) - require.NoError(t, err) - require.Equal(t, int64(1), atomic.LoadInt64(&requests)-requestsStart, "we expect one request to read file with cached data offset") - }) + require.NoError(t, err) + defer f.Close() + + _, err = ioutil.ReadAll(f) + if test.expectedReadErr != nil { + require.Equal(t, test.expectedReadErr, err) + status, _ := zip.(*zipArchive).openStatus() + require.Equal(t, test.expectedArchiveStatus, status) + return + } + + require.NoError(t, err) + status, _ := zip.(*zipArchive).openStatus() + require.Equal(t, test.expectedArchiveStatus, status) + + end := atomic.LoadInt64(&requests) + require.Equal(t, test.expectedRequests, end-start) + }) + } } func TestLstat(t *testing.T) { @@ -244,12 +331,12 @@ func TestArchiveCanBeReadAfterOpenCtxCanceled(t *testing.T) { testServerURL, cleanup := newZipFileServerURL(t, "group/zip.gitlab.io/public.zip", nil) defer cleanup() - fs := New().(*zipVFS) - zip := newArchive(fs, testServerURL+"/public.zip", time.Second) + fs := New(zipCfg).(*zipVFS) + zip := newArchive(fs, time.Second) ctx, cancel := context.WithCancel(context.Background()) cancel() - err := zip.openArchive(ctx) + err := zip.openArchive(ctx, 
testServerURL+"/public.zip") require.EqualError(t, err, context.Canceled.Error()) <-zip.done @@ -267,12 +354,12 @@ func TestReadArchiveFails(t *testing.T) { testServerURL, cleanup := newZipFileServerURL(t, "group/zip.gitlab.io/public.zip", nil) defer cleanup() - fs := New().(*zipVFS) - zip := newArchive(fs, testServerURL+"/unkown.html", time.Second) + fs := New(zipCfg).(*zipVFS) + zip := newArchive(fs, time.Second) - err := zip.openArchive(context.Background()) + err := zip.openArchive(context.Background(), testServerURL+"/unkown.html") require.Error(t, err) - require.Contains(t, err.Error(), "Not Found") + require.Contains(t, err.Error(), httprange.ErrNotFound.Error()) _, err = zip.Open(context.Background(), "index.html") require.EqualError(t, err, os.ErrNotExist.Error()) @@ -287,10 +374,10 @@ func openZipArchive(t *testing.T, requests *int64) (*zipArchive, func()) { testServerURL, cleanup := newZipFileServerURL(t, "group/zip.gitlab.io/public-without-dirs.zip", requests) - fs := New().(*zipVFS) - zip := newArchive(fs, testServerURL+"/public.zip", time.Second) + fs := New(zipCfg).(*zipVFS) + zip := newArchive(fs, time.Second) - err := zip.openArchive(context.Background()) + err := zip.openArchive(context.Background(), testServerURL+"/public.zip") require.NoError(t, err) // public/ public/index.html public/404.html public/symlink.html @@ -311,10 +398,18 @@ func newZipFileServerURL(t *testing.T, zipFilePath string, requests *int64) (str m := http.NewServeMux() m.HandleFunc("/public.zip", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.ServeFile(w, r, zipFilePath) if requests != nil { atomic.AddInt64(requests, 1) } + + r.ParseForm() + + if changedContent := r.Form.Get("changed-content"); changedContent != "" { + w.WriteHeader(http.StatusRequestedRangeNotSatisfiable) + return + } + + http.ServeFile(w, r, zipFilePath) })) testServer := httptest.NewServer(m) diff --git a/internal/vfs/zip/lru_cache.go b/internal/vfs/zip/lru_cache.go index 
fed5c3602..9810e2453 100644 --- a/internal/vfs/zip/lru_cache.go +++ b/internal/vfs/zip/lru_cache.go @@ -24,16 +24,17 @@ type lruCache struct { cache *ccache.Cache } -func newLruCache(op string, maxEntries uint32, duration time.Duration) *lruCache { +func newLruCache(op string, maxEntries int64, duration time.Duration) *lruCache { configuration := ccache.Configure() - configuration.MaxSize(int64(maxEntries)) - configuration.ItemsToPrune(maxEntries / lruCacheItemsToPruneDiv) + configuration.MaxSize(maxEntries) + configuration.ItemsToPrune(uint32(maxEntries) / lruCacheItemsToPruneDiv) configuration.GetsPerPromote(lruCacheGetsPerPromote) // if item gets requested frequently promote it configuration.OnDelete(func(*ccache.Item) { metrics.ZipCachedEntries.WithLabelValues(op).Dec() }) return &lruCache{ + op: op, cache: ccache.New(configuration), duration: duration, } diff --git a/internal/vfs/zip/vfs.go b/internal/vfs/zip/vfs.go index 78a77e1cb..b27424c60 100644 --- a/internal/vfs/zip/vfs.go +++ b/internal/vfs/zip/vfs.go @@ -9,16 +9,13 @@ import ( "github.com/patrickmn/go-cache" + "gitlab.com/gitlab-org/gitlab-pages/internal/config" + "gitlab.com/gitlab-org/gitlab-pages/internal/httprange" "gitlab.com/gitlab-org/gitlab-pages/internal/vfs" "gitlab.com/gitlab-org/gitlab-pages/metrics" ) const ( - // TODO: make these configurable https://gitlab.com/gitlab-org/gitlab-pages/-/issues/464 - defaultCacheExpirationInterval = time.Minute - defaultCacheCleanupInterval = time.Minute / 2 - defaultCacheRefreshInterval = time.Minute / 2 - // we assume that each item costs around 100 bytes // this gives around 5MB of raw memory needed without acceleration structures defaultDataOffsetItems = 50000 @@ -39,6 +36,11 @@ type zipVFS struct { cache *cache.Cache cacheLock sync.Mutex + openTimeout time.Duration + cacheExpirationInterval time.Duration + cacheRefreshInterval time.Duration + cacheCleanupInterval time.Duration + dataOffsetCache *lruCache readlinkCache *lruCache @@ -46,20 +48,60 @@ 
type zipVFS struct { } // New creates a zipVFS instance that can be used by a serving request -func New() vfs.VFS { +func New(cfg *config.ZipServing) vfs.VFS { zipVFS := &zipVFS{ - cache: cache.New(defaultCacheExpirationInterval, defaultCacheCleanupInterval), - dataOffsetCache: newLruCache("data-offset", defaultDataOffsetItems, defaultDataOffsetExpirationInterval), - readlinkCache: newLruCache("readlink", defaultReadlinkItems, defaultReadlinkExpirationInterval), + cacheExpirationInterval: cfg.ExpirationInterval, + cacheRefreshInterval: cfg.RefreshInterval, + cacheCleanupInterval: cfg.CleanupInterval, + openTimeout: cfg.OpenTimeout, } - zipVFS.cache.OnEvicted(func(s string, i interface{}) { + zipVFS.resetCache() + + // TODO: To be removed with https://gitlab.com/gitlab-org/gitlab-pages/-/issues/480 + zipVFS.dataOffsetCache = newLruCache("data-offset", defaultDataOffsetItems, defaultDataOffsetExpirationInterval) + zipVFS.readlinkCache = newLruCache("readlink", defaultReadlinkItems, defaultReadlinkExpirationInterval) + + return zipVFS +} + +// Reconfigure will update the zipVFS configuration values and will reset the +// cache +func (fs *zipVFS) Reconfigure(cfg *config.Config) error { + fs.cacheLock.Lock() + defer fs.cacheLock.Unlock() + + fs.openTimeout = cfg.Zip.OpenTimeout + fs.cacheExpirationInterval = cfg.Zip.ExpirationInterval + fs.cacheRefreshInterval = cfg.Zip.RefreshInterval + fs.cacheCleanupInterval = cfg.Zip.CleanupInterval + + fs.resetCache() + + return nil +} + +func (fs *zipVFS) resetCache() { + fs.cache = cache.New(fs.cacheExpirationInterval, fs.cacheCleanupInterval) + fs.cache.OnEvicted(func(s string, i interface{}) { metrics.ZipCachedEntries.WithLabelValues("archive").Dec() i.(*zipArchive).onEvicted() }) +} - return zipVFS +func (fs *zipVFS) keyFromPath(path string) (string, error) { + // We assume that our URL is https://.../artifacts.zip?content-sign=aaa + // our caching key is `https://.../artifacts.zip` + // TODO: replace caching key with 
file_sha256 + // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/489 + key, err := url.Parse(path) + if err != nil { + return "", err + } + key.RawQuery = "" + key.Fragment = "" + return key.String(), nil } // Root opens an archive given a URL path and returns an instance of zipArchive @@ -70,18 +112,23 @@ func New() vfs.VFS { // to try and find the cached archive or return if there's an error, for example // if the context is canceled. func (fs *zipVFS) Root(ctx context.Context, path string) (vfs.Root, error) { - urlPath, err := url.Parse(path) + key, err := fs.keyFromPath(path) if err != nil { return nil, err } // we do it in loop to not use any additional locks for { - root, err := fs.findOrOpenArchive(ctx, urlPath.String()) + root, err := fs.findOrOpenArchive(ctx, key, path) if err == errAlreadyCached { continue } + // If archive is not found, return a known `vfs` error + if err == httprange.ErrNotFound { + err = &vfs.ErrNotExist{Inner: err} + } + return root, err } } @@ -94,33 +141,53 @@ func (fs *zipVFS) Name() string { // otherwise creates the archive entry in a cache and try to save it, // if saving fails it's because the archive has already been cached // (e.g. 
by another concurrent request) -func (fs *zipVFS) findOrCreateArchive(ctx context.Context, path string) (*zipArchive, error) { +func (fs *zipVFS) findOrCreateArchive(ctx context.Context, key string) (*zipArchive, error) { // This needs to happen in lock to ensure that // concurrent access will not remove it // it is needed due to the bug https://github.com/patrickmn/go-cache/issues/48 fs.cacheLock.Lock() defer fs.cacheLock.Unlock() - archive, expiry, found := fs.cache.GetWithExpiration(path) + archive, expiry, found := fs.cache.GetWithExpiration(key) if found { - metrics.ZipCacheRequests.WithLabelValues("archive", "hit").Inc() - - // TODO: do not refreshed errored archives https://gitlab.com/gitlab-org/gitlab-pages/-/merge_requests/351 - if time.Until(expiry) < defaultCacheRefreshInterval { - // refresh item - fs.cache.SetDefault(path, archive) + status, _ := archive.(*zipArchive).openStatus() + switch status { + case archiveOpening: + metrics.ZipCacheRequests.WithLabelValues("archive", "hit-opening").Inc() + + case archiveOpenError: + // this means that archive is likely corrupted + // we keep it for duration of cache entry expiry (negative cache) + metrics.ZipCacheRequests.WithLabelValues("archive", "hit-open-error").Inc() + + case archiveOpened: + if time.Until(expiry) < fs.cacheRefreshInterval { + fs.cache.SetDefault(key, archive) + metrics.ZipCacheRequests.WithLabelValues("archive", "hit-refresh").Inc() + } else { + metrics.ZipCacheRequests.WithLabelValues("archive", "hit").Inc() + } + + case archiveCorrupted: + // this means that archive is likely changed + // we should invalidate it immediately + metrics.ZipCacheRequests.WithLabelValues("archive", "corrupted").Inc() + archive = nil } - } else { - archive = newArchive(fs, path, DefaultOpenTimeout) + } + + if archive == nil { + archive = newArchive(fs, fs.openTimeout) // We call delete to ensure that expired item // is properly evicted as there's a bug in a cache library: // 
https://github.com/patrickmn/go-cache/issues/48 - fs.cache.Delete(path) + fs.cache.Delete(key) // if adding the archive to the cache fails it means it's already been added before // this is done to find concurrent additions. - if fs.cache.Add(path, archive, cache.DefaultExpiration) != nil { + if fs.cache.Add(key, archive, fs.cacheExpirationInterval) != nil { + metrics.ZipCacheRequests.WithLabelValues("archive", "already-cached").Inc() return nil, errAlreadyCached } @@ -132,13 +199,13 @@ func (fs *zipVFS) findOrCreateArchive(ctx context.Context, path string) (*zipArc } // findOrOpenArchive gets archive from cache and tries to open it -func (fs *zipVFS) findOrOpenArchive(ctx context.Context, path string) (*zipArchive, error) { - zipArchive, err := fs.findOrCreateArchive(ctx, path) +func (fs *zipVFS) findOrOpenArchive(ctx context.Context, key, path string) (*zipArchive, error) { + zipArchive, err := fs.findOrCreateArchive(ctx, key) if err != nil { return nil, err } - err = zipArchive.openArchive(ctx) + err = zipArchive.openArchive(ctx, path) if err != nil { return nil, err } diff --git a/internal/vfs/zip/vfs_test.go b/internal/vfs/zip/vfs_test.go index c12e49cd1..ffda1fb6c 100644 --- a/internal/vfs/zip/vfs_test.go +++ b/internal/vfs/zip/vfs_test.go @@ -9,6 +9,8 @@ import ( "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-pages/internal/httprange" + "gitlab.com/gitlab-org/gitlab-pages/internal/vfs" "gitlab.com/gitlab-org/gitlab-pages/metrics" ) @@ -25,7 +27,7 @@ func TestVFSRoot(t *testing.T) { }, "zip_file_does_not_exist": { path: "/unknown", - expectedErrMsg: "404 Not Found", + expectedErrMsg: vfs.ErrNotExist{Inner: httprange.ErrNotFound}.Error(), }, "invalid_url": { path: "/%", @@ -33,7 +35,7 @@ func TestVFSRoot(t *testing.T) { }, } - vfs := New() + vfs := New(zipCfg) for name, tt := range tests { t.Run(name, func(t *testing.T) { @@ -71,7 +73,7 @@ func 
TestVFSFindOrOpenArchiveConcurrentAccess(t *testing.T) { path := testServerURL + "/public.zip" - vfs := New().(*zipVFS) + vfs := New(zipCfg).(*zipVFS) root, err := vfs.Root(context.Background(), path) require.NoError(t, err) @@ -94,38 +96,131 @@ func TestVFSFindOrOpenArchiveConcurrentAccess(t *testing.T) { }() require.Eventually(t, func() bool { - _, err := vfs.findOrOpenArchive(context.Background(), path) + _, err := vfs.findOrOpenArchive(context.Background(), path, path) return err == errAlreadyCached - }, time.Second, time.Nanosecond) + }, 3*time.Second, time.Nanosecond) } -func TestVFSFindOrCreateArchiveCacheEvict(t *testing.T) { +func TestVFSFindOrOpenArchiveRefresh(t *testing.T) { testServerURL, cleanup := newZipFileServerURL(t, "group/zip.gitlab.io/public.zip", nil) defer cleanup() - path := testServerURL + "/public.zip" + // It should be large enough to not have flaky executions + const expiryInterval = 10 * time.Millisecond - vfs := New().(*zipVFS) + tests := map[string]struct { + path string + expirationInterval time.Duration + refreshInterval time.Duration - archivesMetric := metrics.ZipCachedEntries.WithLabelValues("archive") - archivesCount := testutil.ToFloat64(archivesMetric) + expectNewArchive bool + expectOpenError bool + expectArchiveRefreshed bool + }{ + "after cache expiry of successful open a new archive is returned": { + path: "/public.zip", + expirationInterval: expiryInterval, + expectNewArchive: true, + expectOpenError: false, + }, + "after cache expiry of errored open a new archive is returned": { + path: "/unknown.zip", + expirationInterval: expiryInterval, + expectNewArchive: true, + expectOpenError: true, + }, + "subsequent open during refresh interval does refresh archive": { + path: "/public.zip", + expirationInterval: time.Second, + refreshInterval: time.Second, // refresh always + expectNewArchive: false, + expectOpenError: false, + expectArchiveRefreshed: true, + }, + "subsequent open before refresh interval does not refresh 
archive": { + path: "/public.zip", + expirationInterval: time.Second, + refreshInterval: time.Millisecond, // very short interval should not refresh + expectNewArchive: false, + expectOpenError: false, + expectArchiveRefreshed: false, + }, + "subsequent open of errored archive during refresh interval does not refresh": { + path: "/unknown.zip", + expirationInterval: time.Second, + refreshInterval: time.Second, // refresh always (if not error) + expectNewArchive: false, + expectOpenError: true, + expectArchiveRefreshed: false, + }, + } - // create a new archive and increase counters - archive, err := vfs.findOrOpenArchive(context.Background(), path) - require.NoError(t, err) - require.NotNil(t, archive) + for name, test := range tests { + t.Run(name, func(t *testing.T) { + withExpectedArchiveCount(t, 1, func(t *testing.T) { + cfg := *zipCfg + cfg.ExpirationInterval = test.expirationInterval + cfg.RefreshInterval = test.refreshInterval + + vfs := New(&cfg).(*zipVFS) + + path := testServerURL + test.path + + // create a new archive and increase counters + archive1, err1 := vfs.findOrOpenArchive(context.Background(), path, path) + if test.expectOpenError { + require.Error(t, err1) + require.Nil(t, archive1) + } else { + require.NoError(t, err1) + } + + item1, exp1, found := vfs.cache.GetWithExpiration(path) + require.True(t, found) + + // give some time to for timeouts to fire + time.Sleep(expiryInterval) + + if test.expectNewArchive { + // should return a new archive + archive2, err2 := vfs.findOrOpenArchive(context.Background(), path, path) + if test.expectOpenError { + require.Error(t, err2) + require.Nil(t, archive2) + } else { + require.NoError(t, err2) + require.NotEqual(t, archive1, archive2, "a new archive should be returned") + } + return + } + + // should return exactly the same archive + archive2, err2 := vfs.findOrOpenArchive(context.Background(), path, path) + require.Equal(t, archive1, archive2, "same archive is returned") + require.Equal(t, err1, err2, 
"same error for the same archive") + + item2, exp2, found := vfs.cache.GetWithExpiration(path) + require.True(t, found) + require.Equal(t, item1, item2, "same item is returned") + + if test.expectArchiveRefreshed { + require.Greater(t, exp2.UnixNano(), exp1.UnixNano(), "archive should be refreshed") + } else { + require.Equal(t, exp1.UnixNano(), exp2.UnixNano(), "archive has not been refreshed") + } + }) + }) + } +} - // inject into cache to be "expired" - // (we could as well wait `defaultCacheExpirationInterval`) - vfs.cache.Set(path, archive, time.Nanosecond) - time.Sleep(time.Nanosecond) +func withExpectedArchiveCount(t *testing.T, archiveCount int, fn func(t *testing.T)) { + t.Helper() - // a new object is created - archive2, err := vfs.findOrOpenArchive(context.Background(), path) - require.NoError(t, err) - require.NotNil(t, archive2) - require.NotEqual(t, archive, archive2, "a different archive is returned") + archivesMetric := metrics.ZipCachedEntries.WithLabelValues("archive") + archivesCount := testutil.ToFloat64(archivesMetric) + + fn(t) archivesCountEnd := testutil.ToFloat64(archivesMetric) - require.Equal(t, float64(1), archivesCountEnd-archivesCount, "all expired archives are evicted") + require.Equal(t, float64(archiveCount), archivesCountEnd-archivesCount, "exact number of archives is cached") } diff --git a/main.go b/main.go index 1d3979225..802b98fda 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,7 @@ import ( "github.com/namsral/flag" log "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/labkit/errortracking" "gitlab.com/gitlab-org/gitlab-pages/internal/logging" @@ -29,13 +30,16 @@ var VERSION = "dev" var REVISION = "HEAD" func init() { + // TODO: move all flags to config pkg https://gitlab.com/gitlab-org/gitlab-pages/-/issues/507 flag.Var(&listenHTTP, "listen-http", "The address(es) to listen on for HTTP requests") flag.Var(&listenHTTPS, "listen-https", "The address(es) to listen on for HTTPS requests") flag.Var(&listenProxy, 
"listen-proxy", "The address(es) to listen on for proxy requests") + flag.Var(&ListenHTTPSProxyv2, "listen-https-proxyv2", "The address(es) to listen on for HTTPS PROXYv2 requests (https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt)") flag.Var(&header, "header", "The additional http header(s) that should be send to the client") } var ( + // TODO: move all flags to config pkg https://gitlab.com/gitlab-org/gitlab-pages/-/issues/507 pagesRootCert = flag.String("root-cert", "", "The default path to file certificate to serve static pages") pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages") redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS") @@ -66,7 +70,7 @@ var ( gitlabClientHTTPTimeout = flag.Duration("gitlab-client-http-timeout", 10*time.Second, "GitLab API HTTP client connection timeout in seconds (default: 10s)") gitlabClientJWTExpiry = flag.Duration("gitlab-client-jwt-expiry", 30*time.Second, "JWT Token expiry time in seconds (default: 30s)") // TODO: implement functionality for disk, auto and gitlab https://gitlab.com/gitlab-org/gitlab/-/issues/217912 - domainConfigSource = flag.String("domain-config-source", "disk", "Domain configuration source 'disk', 'auto' or 'gitlab' (default: 'disk')") + domainConfigSource = flag.String("domain-config-source", "auto", "Domain configuration source 'disk', 'auto' or 'gitlab' (default: 'auto')") clientID = flag.String("auth-client-id", "", "GitLab application Client ID") clientSecret = flag.String("auth-client-secret", "", "GitLab application Client Secret") redirectURI = flag.String("auth-redirect-uri", "", "GitLab application redirect URI") @@ -74,13 +78,19 @@ var ( insecureCiphers = flag.Bool("insecure-ciphers", false, "Use default list of cipher suites, may contain insecure ones like 3DES and RC4") tlsMinVersion = flag.String("tls-min-version", "tls1.2", tlsconfig.FlagUsage("min")) tlsMaxVersion = flag.String("tls-max-version", 
"", tlsconfig.FlagUsage("max")) + // TODO: move all flags to config pkg https://gitlab.com/gitlab-org/gitlab-pages/-/issues/507 + zipCacheExpiration = flag.Duration("zip-cache-expiration", 60*time.Second, "Zip serving archive cache expiration interval") + zipCacheCleanup = flag.Duration("zip-cache-cleanup", 30*time.Second, "Zip serving archive cache cleanup interval") + zipCacheRefresh = flag.Duration("zip-cache-refresh", 30*time.Second, "Zip serving archive cache refresh interval") + zipOpenTimeout = flag.Duration("zip-open-timeout", 30*time.Second, "Zip archive open timeout") disableCrossOriginRequests = flag.Bool("disable-cross-origin-requests", false, "Disable cross-origin requests") // See init() - listenHTTP MultiStringFlag - listenHTTPS MultiStringFlag - listenProxy MultiStringFlag + listenHTTP MultiStringFlag + listenHTTPS MultiStringFlag + listenProxy MultiStringFlag + ListenHTTPSProxyv2 MultiStringFlag header MultiStringFlag ) @@ -153,7 +163,7 @@ func setGitLabAPISecretKey(secretFile string, config *appConfig) { } if secretLength != 32 { - log.WithError(fmt.Errorf("Expected 32 bytes GitLab API secret but got %d bytes", secretLength)).Fatal("Failed to decode GitLab API secret") + log.WithError(fmt.Errorf("expected 32 bytes GitLab API secret but got %d bytes", secretLength)).Fatal("Failed to decode GitLab API secret") } config.GitLabAPISecretKey = decoded @@ -208,6 +218,11 @@ func configFromFlags() appConfig { config.SentryDSN = *sentryDSN config.SentryEnvironment = *sentryEnvironment + config.ZipCacheExpiry = *zipCacheExpiration + config.ZipCacheCleanup = *zipCacheCleanup + config.ZipCacheRefresh = *zipCacheRefresh + config.ZipeOpenTimeout = *zipOpenTimeout + checkAuthenticationConfig(config) return config @@ -274,6 +289,7 @@ func loadConfig() appConfig { "listen-http": strings.Join(listenHTTP, ","), "listen-https": strings.Join(listenHTTPS, ","), "listen-proxy": strings.Join(listenProxy, ","), + "listen-https-proxyv2": strings.Join(ListenHTTPSProxyv2, 
","), "log-format": *logFormat, "metrics-address": *metricsAddress, "pages-domain": *pagesDomain, @@ -291,6 +307,10 @@ func loadConfig() appConfig { "api-secret-key": *gitLabAPISecretKey, "domain-config-source": config.DomainConfigurationSource, "auth-redirect-uri": config.RedirectURI, + "zip-cache-expiration": config.ZipCacheExpiry, + "zip-cache-cleanup": config.ZipCacheCleanup, + "zip-cache-refresh": config.ZipCacheRefresh, + "zip-open-timeout": config.ZipeOpenTimeout, }).Debug("Start daemon with configuration") return config @@ -301,7 +321,9 @@ func appMain() { // read from -config=/path/to/gitlab-pages-config flag.String(flag.DefaultConfigFlagname, "", "path to config file") + flag.Parse() + if err := tlsconfig.ValidateTLSVersions(*tlsMinVersion, *tlsMaxVersion); err != nil { fatal(err, "invalid TLS version") } @@ -389,6 +411,17 @@ func createAppListeners(config *appConfig) []io.Closer { config.ListenProxy = append(config.ListenProxy, f.Fd()) } + for _, addr := range ListenHTTPSProxyv2.Split() { + l, f := createSocket(addr) + closers = append(closers, l, f) + + log.WithFields(log.Fields{ + "listener": addr, + }).Debug("Set up https proxyv2 listener") + + config.ListenHTTPSProxyv2 = append(config.ListenHTTPSProxyv2, f.Fd()) + } + return closers } diff --git a/metrics/metrics.go b/metrics/metrics.go index db7cae9a8..045ff26e0 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -199,6 +199,13 @@ var ( Help: "The number of files per zip archive total count over time", }, ) + + RejectedRequestsCount = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "gitlab_pages_unknown_method_rejected_requests", + Help: "The number of requests with unknown HTTP method which were rejected", + }, + ) ) // MustRegister collectors with the Prometheus client diff --git a/server.go b/server.go index 64f8f5f97..678367a3e 100644 --- a/server.go +++ b/server.go @@ -9,10 +9,10 @@ import ( "time" "github.com/gorilla/context" + proxyproto "github.com/pires/go-proxyproto" 
"golang.org/x/net/http2" "gitlab.com/gitlab-org/gitlab-pages/internal/netutil" - "gitlab.com/gitlab-org/gitlab-pages/internal/tlsconfig" ) type keepAliveListener struct { @@ -37,7 +37,7 @@ func (ln *keepAliveListener) Accept() (net.Conn, error) { return conn, nil } -func listenAndServe(fd uintptr, handler http.Handler, useHTTP2 bool, tlsConfig *tls.Config, limiter *netutil.Limiter) error { +func listenAndServe(fd uintptr, handler http.Handler, useHTTP2 bool, tlsConfig *tls.Config, limiter *netutil.Limiter, proxyv2 bool) error { // create server server := &http.Server{Handler: context.ClearHandler(handler), TLSConfig: tlsConfig} @@ -57,18 +57,20 @@ func listenAndServe(fd uintptr, handler http.Handler, useHTTP2 bool, tlsConfig * l = netutil.SharedLimitListener(l, limiter) } - if tlsConfig != nil { - tlsListener := tls.NewListener(&keepAliveListener{l}, server.TLSConfig) - return server.Serve(tlsListener) + l = &keepAliveListener{l} + + if proxyv2 { + l = &proxyproto.Listener{ + Listener: l, + Policy: func(upstream net.Addr) (proxyproto.Policy, error) { + return proxyproto.REQUIRE, nil + }, + } } - return server.Serve(&keepAliveListener{l}) -} -func listenAndServeTLS(fd uintptr, cert, key []byte, handler http.Handler, getCertificate tlsconfig.GetCertificateFunc, insecureCiphers bool, tlsMinVersion uint16, tlsMaxVersion uint16, useHTTP2 bool, limiter *netutil.Limiter) error { - tlsConfig, err := tlsconfig.Create(cert, key, getCertificate, insecureCiphers, tlsMinVersion, tlsMaxVersion) - if err != nil { - return err + if tlsConfig != nil { + l = tls.NewListener(l, server.TLSConfig) } - return listenAndServe(fd, handler, useHTTP2, tlsConfig, limiter) + return server.Serve(l) } diff --git a/shared/lookups/zip-malformed.gitlab.io.json b/shared/lookups/zip-malformed.gitlab.io.json new file mode 100644 index 000000000..8c0185dac --- /dev/null +++ b/shared/lookups/zip-malformed.gitlab.io.json @@ -0,0 +1,16 @@ +{ + "certificate": "", + "key": "", + "lookup_paths": [ + { + 
"access_control": false, + "https_only": false, + "prefix": "/", + "project_id": 123, + "source": { + "path": "http://127.0.0.1:38001/malformed.zip", + "type": "zip" + } + } + ] +} diff --git a/shared/lookups/zip-not-found.gitlab.io.json b/shared/lookups/zip-not-found.gitlab.io.json new file mode 100644 index 000000000..514b8ff2b --- /dev/null +++ b/shared/lookups/zip-not-found.gitlab.io.json @@ -0,0 +1,16 @@ +{ + "certificate": "", + "key": "", + "lookup_paths": [ + { + "access_control": false, + "https_only": false, + "prefix": "/", + "project_id": 123, + "source": { + "path": "http://127.0.0.1:38001/not-found.zip", + "type": "zip" + } + } + ] +} diff --git a/shared/lookups/zip.gitlab.io.json b/shared/lookups/zip.gitlab.io.json index cf755a582..0549adc82 100644 --- a/shared/lookups/zip.gitlab.io.json +++ b/shared/lookups/zip.gitlab.io.json @@ -8,7 +8,7 @@ "prefix": "/", "project_id": 123, "source": { - "path": "http://127.0.0.1:37003/public.zip", + "path": "http://127.0.0.1:38001/public.zip", "type": "zip" } } diff --git a/test/acceptance/acceptance_test.go b/test/acceptance/acceptance_test.go new file mode 100644 index 000000000..ba6528c10 --- /dev/null +++ b/test/acceptance/acceptance_test.go @@ -0,0 +1,81 @@ +package acceptance_test + +import ( + "flag" + "fmt" + "log" + "os" + "testing" + + "gitlab.com/gitlab-org/gitlab-pages/internal/fixture" +) + +const ( + objectStorageMockServer = "127.0.0.1:38001" +) + +var ( + pagesBinary = flag.String("gitlab-pages-binary", "../../gitlab-pages", "Path to the gitlab-pages binary") + + httpPort = "36000" + httpsPort = "37000" + httpProxyPort = "38000" + httpProxyV2Port = "39000" + + // TODO: Use TCP port 0 everywhere to avoid conflicts. The binary could output + // the actual port (and type of listener) for us to read in place of the + // hardcoded values below. 
+ listeners = []ListenSpec{ + {"http", "127.0.0.1", httpPort}, + {"https", "127.0.0.1", httpsPort}, + {"proxy", "127.0.0.1", httpProxyPort}, + {"https-proxyv2", "127.0.0.1", httpProxyV2Port}, + // TODO: re-enable IPv6 listeners once https://gitlab.com/gitlab-com/gl-infra/infrastructure/-/issues/12258 is resolved + // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/528 + // {"http", "::1", httpPort}, + // {"https", "::1", httpsPort}, + // {"proxy", "::1", httpProxyPort}, + // {"https-proxyv2", "::1", httpProxyV2Port}, + } + + httpListener = listeners[0] + httpsListener = listeners[1] + proxyListener = listeners[2] + httpsProxyv2Listener = listeners[3] +) + +func TestMain(m *testing.M) { + flag.Parse() + + if testing.Short() { + log.Println("Acceptance tests disabled") + os.Exit(0) + } + + if _, err := os.Stat(*pagesBinary); os.IsNotExist(err) { + log.Fatalf("Couldn't find gitlab-pages binary at %s\n", *pagesBinary) + } + + if ok := TestCertPool.AppendCertsFromPEM([]byte(fixture.Certificate)); !ok { + fmt.Println("Failed to load cert!") + } + + os.Exit(m.Run()) +} + +func skipUnlessEnabled(t *testing.T, conditions ...string) { + t.Helper() + + for _, condition := range conditions { + switch condition { + case "not-inplace-chroot": + if os.Getenv("TEST_DAEMONIZE") == "inplace" { + t.Log("Not supported with -daemon-inplace-chroot") + t.SkipNow() + } + default: + t.Error("Unknown condition:", condition) + t.FailNow() + } + } +} diff --git a/test/acceptance/acme_test.go b/test/acceptance/acme_test.go new file mode 100644 index 000000000..a0425b7d7 --- /dev/null +++ b/test/acceptance/acme_test.go @@ -0,0 +1,73 @@ +package acceptance_test + +import ( + "io/ioutil" + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAcmeChallengesWhenItIsNotConfigured(t *testing.T) { + skipUnlessEnabled(t) + + teardown := RunPagesProcess(t, *pagesBinary, listeners, "", "") + defer teardown() + + t.Run("When domain folder contains requested acme 
challenge it responds with it", func(t *testing.T) { + rsp, err := GetRedirectPage(t, httpListener, "withacmechallenge.domain.com", + existingAcmeTokenPath) + + defer rsp.Body.Close() + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + body, _ := ioutil.ReadAll(rsp.Body) + require.Equal(t, "this is token\n", string(body)) + }) + + t.Run("When domain folder doesn't contains requested acme challenge it returns 404", + func(t *testing.T) { + rsp, err := GetRedirectPage(t, httpListener, "withacmechallenge.domain.com", + notExistingAcmeTokenPath) + + defer rsp.Body.Close() + require.NoError(t, err) + require.Equal(t, http.StatusNotFound, rsp.StatusCode) + }, + ) +} + +func TestAcmeChallengesWhenItIsConfigured(t *testing.T) { + skipUnlessEnabled(t) + + teardown := RunPagesProcess(t, *pagesBinary, listeners, "", "-gitlab-server=https://gitlab-acme.com") + defer teardown() + + t.Run("When domain folder contains requested acme challenge it responds with it", func(t *testing.T) { + rsp, err := GetRedirectPage(t, httpListener, "withacmechallenge.domain.com", + existingAcmeTokenPath) + + defer rsp.Body.Close() + require.NoError(t, err) + require.Equal(t, http.StatusOK, rsp.StatusCode) + body, _ := ioutil.ReadAll(rsp.Body) + require.Equal(t, "this is token\n", string(body)) + }) + + t.Run("When domain folder doesn't contains requested acme challenge it redirects to GitLab", + func(t *testing.T) { + rsp, err := GetRedirectPage(t, httpListener, "withacmechallenge.domain.com", + notExistingAcmeTokenPath) + + defer rsp.Body.Close() + require.NoError(t, err) + require.Equal(t, http.StatusTemporaryRedirect, rsp.StatusCode) + + url, err := url.Parse(rsp.Header.Get("Location")) + require.NoError(t, err) + + require.Equal(t, url.String(), "https://gitlab-acme.com/-/acme-challenge?domain=withacmechallenge.domain.com&token=notexistingtoken") + }, + ) +} diff --git a/test/acceptance/artifacts_test.go b/test/acceptance/artifacts_test.go new file mode 100644 index 
000000000..57c7a02a9 --- /dev/null +++ b/test/acceptance/artifacts_test.go @@ -0,0 +1,299 @@ +package acceptance_test + +import ( + "crypto/tls" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "os" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestArtifactProxyRequest(t *testing.T) { + skipUnlessEnabled(t, "not-inplace-chroot") + + transport := (TestHTTPSClient.Transport).(*http.Transport) + defer func(t time.Duration) { + transport.ResponseHeaderTimeout = t + }(transport.ResponseHeaderTimeout) + transport.ResponseHeaderTimeout = 5 * time.Second + + content := "Title of the document" + contentLength := int64(len(content)) + testServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.RawPath { + case "/api/v4/projects/group%2Fproject/jobs/1/artifacts/delayed_200.html": + time.Sleep(2 * time.Second) + fallthrough + case "/api/v4/projects/group%2Fproject/jobs/1/artifacts/200.html", + "/api/v4/projects/group%2Fsubgroup%2Fproject/jobs/1/artifacts/200.html": + w.Header().Set("Content-Type", "text/html; charset=utf-8") + fmt.Fprint(w, content) + case "/api/v4/projects/group%2Fproject/jobs/1/artifacts/500.html": + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprint(w, content) + default: + t.Logf("Unexpected r.URL.RawPath: %q", r.URL.RawPath) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, content) + } + })) + + keyFile, certFile := CreateHTTPSFixtureFiles(t) + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + require.NoError(t, err) + defer os.Remove(keyFile) + defer os.Remove(certFile) + + testServer.TLS = &tls.Config{Certificates: []tls.Certificate{cert}} + testServer.StartTLS() + defer testServer.Close() + + tests := []struct { + name string + host string + path string + status int + binaryOption string + content string + 
length int64 + cacheControl string + contentType string + }{ + { + name: "basic proxied request", + host: "group.gitlab-example.com", + path: "/-/project/-/jobs/1/artifacts/200.html", + status: http.StatusOK, + binaryOption: "", + content: content, + length: contentLength, + cacheControl: "max-age=3600", + contentType: "text/html; charset=utf-8", + }, + { + name: "basic proxied request for subgroup", + host: "group.gitlab-example.com", + path: "/-/subgroup/project/-/jobs/1/artifacts/200.html", + status: http.StatusOK, + binaryOption: "", + content: content, + length: contentLength, + cacheControl: "max-age=3600", + contentType: "text/html; charset=utf-8", + }, + { + name: "502 error while attempting to proxy", + host: "group.gitlab-example.com", + path: "/-/project/-/jobs/1/artifacts/delayed_200.html", + status: http.StatusBadGateway, + binaryOption: "-artifacts-server-timeout=1", + content: "", + length: 0, + cacheControl: "", + contentType: "text/html; charset=utf-8", + }, + { + name: "Proxying 404 from server", + host: "group.gitlab-example.com", + path: "/-/project/-/jobs/1/artifacts/404.html", + status: http.StatusNotFound, + binaryOption: "", + content: "", + length: 0, + cacheControl: "", + contentType: "text/html; charset=utf-8", + }, + { + name: "Proxying 500 from server", + host: "group.gitlab-example.com", + path: "/-/project/-/jobs/1/artifacts/500.html", + status: http.StatusInternalServerError, + binaryOption: "", + content: "", + length: 0, + cacheControl: "", + contentType: "text/html; charset=utf-8", + }, + } + + // Ensure the IP address is used in the URL, as we're relying on IP SANs to + // validate + artifactServerURL := testServer.URL + "/api/v4" + t.Log("Artifact server URL", artifactServerURL) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + teardown := RunPagesProcessWithSSLCertFile( + t, + *pagesBinary, + listeners, + "", + certFile, + "-artifacts-server="+artifactServerURL, + tt.binaryOption, + ) + defer teardown() + + 
resp, err := GetPageFromListener(t, httpListener, tt.host, tt.path) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, tt.status, resp.StatusCode) + require.Equal(t, tt.contentType, resp.Header.Get("Content-Type")) + + if !((tt.status == http.StatusBadGateway) || (tt.status == http.StatusNotFound) || (tt.status == http.StatusInternalServerError)) { + body, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, tt.content, string(body)) + require.Equal(t, tt.length, resp.ContentLength) + require.Equal(t, tt.cacheControl, resp.Header.Get("Cache-Control")) + } + }) + } +} + +func TestPrivateArtifactProxyRequest(t *testing.T) { + skipUnlessEnabled(t, "not-inplace-chroot") + + setupTransport(t) + + testServer := makeGitLabPagesAccessStub(t) + + keyFile, certFile := CreateHTTPSFixtureFiles(t) + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + require.NoError(t, err) + defer os.Remove(keyFile) + defer os.Remove(certFile) + + testServer.TLS = &tls.Config{Certificates: []tls.Certificate{cert}} + testServer.StartTLS() + defer testServer.Close() + + tests := []struct { + name string + host string + path string + status int + binaryOption string + }{ + { + name: "basic proxied request for private project", + host: "group.gitlab-example.com", + path: "/-/private/-/jobs/1/artifacts/200.html", + status: http.StatusOK, + binaryOption: "", + }, + { + name: "basic proxied request for subgroup", + host: "group.gitlab-example.com", + path: "/-/subgroup/private/-/jobs/1/artifacts/200.html", + status: http.StatusOK, + binaryOption: "", + }, + { + name: "502 error while attempting to proxy", + host: "group.gitlab-example.com", + path: "/-/private/-/jobs/1/artifacts/delayed_200.html", + status: http.StatusBadGateway, + binaryOption: "artifacts-server-timeout=1", + }, + { + name: "Proxying 404 from server", + host: "group.gitlab-example.com", + path: "/-/private/-/jobs/1/artifacts/404.html", + status: http.StatusNotFound, + binaryOption: 
"", + }, + { + name: "Proxying 500 from server", + host: "group.gitlab-example.com", + path: "/-/private/-/jobs/1/artifacts/500.html", + status: http.StatusInternalServerError, + binaryOption: "", + }, + } + + // Ensure the IP address is used in the URL, as we're relying on IP SANs to + // validate + artifactServerURL := testServer.URL + "/api/v4" + t.Log("Artifact server URL", artifactServerURL) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + configFile, cleanup := defaultConfigFileWith(t, + "artifacts-server="+artifactServerURL, + "auth-server="+testServer.URL, + "auth-redirect-uri=https://projects.gitlab-example.com/auth", + tt.binaryOption) + defer cleanup() + + teardown := RunPagesProcessWithSSLCertFile( + t, + *pagesBinary, + listeners, + "", + certFile, + "-config="+configFile, + ) + defer teardown() + + resp, err := GetRedirectPage(t, httpsListener, tt.host, tt.path) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusFound, resp.StatusCode) + + cookie := resp.Header.Get("Set-Cookie") + + // Redirects to the projects under gitlab pages domain for authentication flow + url, err := url.Parse(resp.Header.Get("Location")) + require.NoError(t, err) + require.Equal(t, "projects.gitlab-example.com", url.Host) + require.Equal(t, "/auth", url.Path) + state := url.Query().Get("state") + + resp, err = GetRedirectPage(t, httpsListener, url.Host, url.Path+"?"+url.RawQuery) + + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusFound, resp.StatusCode) + pagesDomainCookie := resp.Header.Get("Set-Cookie") + + // Go to auth page with correct state will cause fetching the token + authrsp, err := GetRedirectPageWithCookie(t, httpsListener, "projects.gitlab-example.com", "/auth?code=1&state="+ + state, pagesDomainCookie) + + require.NoError(t, err) + defer authrsp.Body.Close() + + // Will redirect auth callback to correct host + url, err = url.Parse(authrsp.Header.Get("Location")) + 
require.NoError(t, err) + require.Equal(t, tt.host, url.Host) + require.Equal(t, "/auth", url.Path) + + // Request auth callback in project domain + authrsp, err = GetRedirectPageWithCookie(t, httpsListener, url.Host, url.Path+"?"+url.RawQuery, cookie) + require.NoError(t, err) + + // server returns the ticket, user will be redirected to the project page + require.Equal(t, http.StatusFound, authrsp.StatusCode) + cookie = authrsp.Header.Get("Set-Cookie") + resp, err = GetRedirectPageWithCookie(t, httpsListener, tt.host, tt.path, cookie) + + require.Equal(t, tt.status, resp.StatusCode) + + require.NoError(t, err) + defer resp.Body.Close() + }) + } +} diff --git a/test/acceptance/auth_test.go b/test/acceptance/auth_test.go new file mode 100644 index 000000000..fa2d768d8 --- /dev/null +++ b/test/acceptance/auth_test.go @@ -0,0 +1,730 @@ +package acceptance_test + +import ( + "crypto/tls" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "os" + "regexp" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestWhenAuthIsDisabledPrivateIsNotAccessible(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "", "") + defer teardown() + + rsp, err := GetPageFromListener(t, httpListener, "group.auth.gitlab-example.com", "private.project/") + + require.NoError(t, err) + rsp.Body.Close() + require.Equal(t, http.StatusInternalServerError, rsp.StatusCode) +} + +func TestWhenAuthIsEnabledPrivateWillRedirectToAuthorize(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcessWithAuth(t, *pagesBinary, listeners, "") + defer teardown() + + rsp, err := GetRedirectPage(t, httpsListener, "group.auth.gitlab-example.com", "private.project/") + + require.NoError(t, err) + defer rsp.Body.Close() + + require.Equal(t, http.StatusFound, rsp.StatusCode) + require.Equal(t, 1, len(rsp.Header["Location"])) + url, err := url.Parse(rsp.Header.Get("Location")) + require.NoError(t, err) + rsp, err = GetRedirectPage(t, 
httpsListener, url.Host, url.Path+"?"+url.RawQuery) + require.NoError(t, err) + + require.Equal(t, http.StatusFound, rsp.StatusCode) + require.Equal(t, 1, len(rsp.Header["Location"])) + + url, err = url.Parse(rsp.Header.Get("Location")) + require.NoError(t, err) + + require.Equal(t, "https", url.Scheme) + require.Equal(t, "gitlab-auth.com", url.Host) + require.Equal(t, "/oauth/authorize", url.Path) + require.Equal(t, "clientID", url.Query().Get("client_id")) + require.Equal(t, "https://projects.gitlab-example.com/auth", url.Query().Get("redirect_uri")) + require.NotEqual(t, "", url.Query().Get("state")) +} + +func TestWhenAuthDeniedWillCauseUnauthorized(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcessWithAuth(t, *pagesBinary, listeners, "") + defer teardown() + + rsp, err := GetPageFromListener(t, httpsListener, "projects.gitlab-example.com", "/auth?error=access_denied") + + require.NoError(t, err) + defer rsp.Body.Close() + + require.Equal(t, http.StatusUnauthorized, rsp.StatusCode) +} +func TestWhenLoginCallbackWithWrongStateShouldFail(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcessWithAuth(t, *pagesBinary, listeners, "") + defer teardown() + + rsp, err := GetRedirectPage(t, httpsListener, "group.auth.gitlab-example.com", "private.project/") + + require.NoError(t, err) + defer rsp.Body.Close() + + // Go to auth page with wrong state will cause failure + authrsp, err := GetPageFromListener(t, httpsListener, "projects.gitlab-example.com", "/auth?code=0&state=0") + + require.NoError(t, err) + defer authrsp.Body.Close() + + require.Equal(t, http.StatusUnauthorized, authrsp.StatusCode) +} + +func TestWhenLoginCallbackWithUnencryptedCode(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcessWithAuth(t, *pagesBinary, listeners, "") + defer teardown() + + rsp, err := GetRedirectPage(t, httpsListener, "group.auth.gitlab-example.com", "private.project/") + + require.NoError(t, err) + defer rsp.Body.Close() + + cookie 
:= rsp.Header.Get("Set-Cookie") + + url, err := url.Parse(rsp.Header.Get("Location")) + require.NoError(t, err) + + // Go to auth page with correct state will cause fetching the token + authrsp, err := GetPageFromListenerWithCookie(t, httpsListener, "projects.gitlab-example.com", "/auth?code=1&state="+ + url.Query().Get("state"), cookie) + + require.NoError(t, err) + defer authrsp.Body.Close() + + // Will cause 500 because the code is not encrypted + require.Equal(t, http.StatusInternalServerError, authrsp.StatusCode) +} + +func handleAccessControlArtifactRequests(t *testing.T, w http.ResponseWriter, r *http.Request) bool { + authorization := r.Header.Get("Authorization") + + switch { + case regexp.MustCompile(`/api/v4/projects/group/private/jobs/\d+/artifacts/delayed_200.html`).MatchString(r.URL.Path): + sleepIfAuthorized(t, authorization, w) + return true + case regexp.MustCompile(`/api/v4/projects/group/private/jobs/\d+/artifacts/404.html`).MatchString(r.URL.Path): + w.WriteHeader(http.StatusNotFound) + return true + case regexp.MustCompile(`/api/v4/projects/group/private/jobs/\d+/artifacts/500.html`).MatchString(r.URL.Path): + returnIfAuthorized(t, authorization, w, http.StatusInternalServerError) + return true + case regexp.MustCompile(`/api/v4/projects/group/private/jobs/\d+/artifacts/200.html`).MatchString(r.URL.Path): + returnIfAuthorized(t, authorization, w, http.StatusOK) + return true + case regexp.MustCompile(`/api/v4/projects/group/subgroup/private/jobs/\d+/artifacts/200.html`).MatchString(r.URL.Path): + returnIfAuthorized(t, authorization, w, http.StatusOK) + return true + default: + return false + } +} + +func handleAccessControlRequests(t *testing.T, w http.ResponseWriter, r *http.Request) { + allowedProjects := regexp.MustCompile(`/api/v4/projects/1\d{3}/pages_access`) + deniedProjects := regexp.MustCompile(`/api/v4/projects/2\d{3}/pages_access`) + invalidTokenProjects := regexp.MustCompile(`/api/v4/projects/3\d{3}/pages_access`) + + switch { + 
case allowedProjects.MatchString(r.URL.Path): + require.Equal(t, "Bearer abc", r.Header.Get("Authorization")) + w.WriteHeader(http.StatusOK) + case deniedProjects.MatchString(r.URL.Path): + require.Equal(t, "Bearer abc", r.Header.Get("Authorization")) + w.WriteHeader(http.StatusUnauthorized) + case invalidTokenProjects.MatchString(r.URL.Path): + require.Equal(t, "Bearer abc", r.Header.Get("Authorization")) + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprint(w, "{\"error\":\"invalid_token\"}") + default: + t.Logf("Unexpected r.URL.RawPath: %q", r.URL.Path) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusNotFound) + } +} + +func returnIfAuthorized(t *testing.T, authorization string, w http.ResponseWriter, status int) { + if authorization != "" { + require.Equal(t, "Bearer abc", authorization) + w.WriteHeader(status) + } else { + w.WriteHeader(http.StatusNotFound) + } +} + +func sleepIfAuthorized(t *testing.T, authorization string, w http.ResponseWriter) { + if authorization != "" { + require.Equal(t, "Bearer abc", authorization) + time.Sleep(2 * time.Second) + } else { + w.WriteHeader(http.StatusNotFound) + } +} + +func TestAccessControlUnderCustomDomain(t *testing.T) { + skipUnlessEnabled(t, "not-inplace-chroot") + + testServer := makeGitLabPagesAccessStub(t) + testServer.Start() + defer testServer.Close() + + teardown := RunPagesProcessWithAuthServer(t, *pagesBinary, listeners, "", testServer.URL) + defer teardown() + + tests := map[string]struct { + domain string + path string + }{ + "private_domain": { + domain: "private.domain.com", + path: "", + }, + "private_domain_with_query": { + domain: "private.domain.com", + path: "?q=test", + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + rsp, err := GetRedirectPage(t, httpListener, tt.domain, tt.path) + require.NoError(t, err) + defer rsp.Body.Close() + + cookie := rsp.Header.Get("Set-Cookie") + + url, err := url.Parse(rsp.Header.Get("Location")) + 
require.NoError(t, err) + + state := url.Query().Get("state") + require.Equal(t, "http://"+tt.domain, url.Query().Get("domain")) + + pagesrsp, err := GetRedirectPage(t, httpListener, url.Host, url.Path+"?"+url.RawQuery) + require.NoError(t, err) + defer pagesrsp.Body.Close() + + pagescookie := pagesrsp.Header.Get("Set-Cookie") + + // Go to auth page with correct state will cause fetching the token + authrsp, err := GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code=1&state="+ + state, pagescookie) + + require.NoError(t, err) + defer authrsp.Body.Close() + + url, err = url.Parse(authrsp.Header.Get("Location")) + require.NoError(t, err) + + // Will redirect to custom domain + require.Equal(t, tt.domain, url.Host) + code := url.Query().Get("code") + require.NotEqual(t, "1", code) + + authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code="+code+"&state="+ + state, cookie) + + require.NoError(t, err) + defer authrsp.Body.Close() + + // Will redirect to the page + cookie = authrsp.Header.Get("Set-Cookie") + require.Equal(t, http.StatusFound, authrsp.StatusCode) + + url, err = url.Parse(authrsp.Header.Get("Location")) + require.NoError(t, err) + + // Will redirect to custom domain + require.Equal(t, "http://"+tt.domain+"/"+tt.path, url.String()) + + // Fetch page in custom domain + authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, tt.path, cookie) + require.NoError(t, err) + require.Equal(t, http.StatusOK, authrsp.StatusCode) + }) + } +} + +func TestCustomErrorPageWithAuth(t *testing.T) { + skipUnlessEnabled(t, "not-inplace-chroot") + testServer := makeGitLabPagesAccessStub(t) + testServer.Start() + defer testServer.Close() + + teardown := RunPagesProcessWithAuthServer(t, *pagesBinary, listeners, "", testServer.URL) + defer teardown() + + tests := []struct { + name string + domain string + path string + expectedErrorPage string + }{ + { + name: "private_project_authorized", + domain: 
"group.404.gitlab-example.com", + path: "/private_project/unknown", + expectedErrorPage: "Private custom 404 error page", + }, + { + name: "public_namespace_with_private_unauthorized_project", + domain: "group.404.gitlab-example.com", + // /private_unauthorized/config.json resolves project ID to 2000 which will cause a 401 from the mock GitLab testServer + path: "/private_unauthorized/unknown", + expectedErrorPage: "Custom 404 group page", + }, + { + name: "private_namespace_authorized", + domain: "group.auth.gitlab-example.com", + path: "/unknown", + expectedErrorPage: "group.auth.gitlab-example.com namespace custom 404", + }, + { + name: "private_namespace_with_private_project_auth_failed", + domain: "group.auth.gitlab-example.com", + // project ID is 2000 + path: "/private.project.1/unknown", + expectedErrorPage: "The page you're looking for could not be found.", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rsp, err := GetRedirectPage(t, httpListener, tt.domain, tt.path) + require.NoError(t, err) + defer rsp.Body.Close() + + cookie := rsp.Header.Get("Set-Cookie") + + url, err := url.Parse(rsp.Header.Get("Location")) + require.NoError(t, err) + + state := url.Query().Get("state") + require.Equal(t, "http://"+tt.domain, url.Query().Get("domain")) + + pagesrsp, err := GetRedirectPage(t, httpListener, url.Host, url.Path+"?"+url.RawQuery) + require.NoError(t, err) + defer pagesrsp.Body.Close() + + pagescookie := pagesrsp.Header.Get("Set-Cookie") + + // Go to auth page with correct state will cause fetching the token + authrsp, err := GetRedirectPageWithCookie(t, httpListener, "projects.gitlab-example.com", "/auth?code=1&state="+ + state, pagescookie) + + require.NoError(t, err) + defer authrsp.Body.Close() + + url, err = url.Parse(authrsp.Header.Get("Location")) + require.NoError(t, err) + + // Will redirect to custom domain + require.Equal(t, tt.domain, url.Host) + // code must have changed since it's encrypted + code := 
url.Query().Get("code") + require.NotEqual(t, "1", code) + require.Equal(t, state, url.Query().Get("state")) + + // Run auth callback in custom domain + authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code="+code+"&state="+ + state, cookie) + + require.NoError(t, err) + defer authrsp.Body.Close() + + // Will redirect to the page + groupCookie := authrsp.Header.Get("Set-Cookie") + require.Equal(t, http.StatusFound, authrsp.StatusCode) + + url, err = url.Parse(authrsp.Header.Get("Location")) + require.NoError(t, err) + + // Will redirect to custom domain error page + require.Equal(t, "http://"+tt.domain+tt.path, url.String()) + + // Fetch page in custom domain + anotherResp, err := GetRedirectPageWithCookie(t, httpListener, tt.domain, tt.path, groupCookie) + require.NoError(t, err) + + require.Equal(t, http.StatusNotFound, anotherResp.StatusCode) + + page, err := ioutil.ReadAll(anotherResp.Body) + require.NoError(t, err) + require.Contains(t, string(page), tt.expectedErrorPage) + }) + } +} + +func TestAccessControlUnderCustomDomainWithHTTPSProxy(t *testing.T) { + skipUnlessEnabled(t, "not-inplace-chroot") + + testServer := makeGitLabPagesAccessStub(t) + testServer.Start() + defer testServer.Close() + + teardown := RunPagesProcessWithAuthServer(t, *pagesBinary, listeners, "", testServer.URL) + defer teardown() + + rsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, "private.domain.com", "/", "", true) + require.NoError(t, err) + defer rsp.Body.Close() + + cookie := rsp.Header.Get("Set-Cookie") + + url, err := url.Parse(rsp.Header.Get("Location")) + require.NoError(t, err) + + state := url.Query().Get("state") + require.Equal(t, url.Query().Get("domain"), "https://private.domain.com") + pagesrsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, url.Host, url.Path+"?"+url.RawQuery, "", true) + require.NoError(t, err) + defer pagesrsp.Body.Close() + + pagescookie := pagesrsp.Header.Get("Set-Cookie") + + // Go to auth page with 
correct state will cause fetching the token + authrsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, + "projects.gitlab-example.com", "/auth?code=1&state="+state, + pagescookie, true) + + require.NoError(t, err) + defer authrsp.Body.Close() + + url, err = url.Parse(authrsp.Header.Get("Location")) + require.NoError(t, err) + + // Will redirect to custom domain + require.Equal(t, "private.domain.com", url.Host) + // code must have changed since it's encrypted + code := url.Query().Get("code") + require.NotEqual(t, "1", code) + require.Equal(t, state, url.Query().Get("state")) + + // Run auth callback in custom domain + authrsp, err = GetProxyRedirectPageWithCookie(t, proxyListener, "private.domain.com", + "/auth?code="+code+"&state="+state, cookie, true) + + require.NoError(t, err) + defer authrsp.Body.Close() + + // Will redirect to the page + cookie = authrsp.Header.Get("Set-Cookie") + require.Equal(t, http.StatusFound, authrsp.StatusCode) + + url, err = url.Parse(authrsp.Header.Get("Location")) + require.NoError(t, err) + + // Will redirect to custom domain + require.Equal(t, "https://private.domain.com/", url.String()) + // Fetch page in custom domain + authrsp, err = GetProxyRedirectPageWithCookie(t, proxyListener, "private.domain.com", "/", + cookie, true) + require.NoError(t, err) + require.Equal(t, http.StatusOK, authrsp.StatusCode) +} + +func TestAccessControlGroupDomain404RedirectsAuth(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcessWithAuth(t, *pagesBinary, listeners, "") + defer teardown() + + rsp, err := GetRedirectPage(t, httpListener, "group.gitlab-example.com", "/nonexistent/") + require.NoError(t, err) + defer rsp.Body.Close() + require.Equal(t, http.StatusFound, rsp.StatusCode) + // Redirects to the projects under gitlab pages domain for authentication flow + url, err := url.Parse(rsp.Header.Get("Location")) + require.NoError(t, err) + require.Equal(t, "projects.gitlab-example.com", url.Host) + require.Equal(t, "/auth", 
url.Path) +} +func TestAccessControlProject404DoesNotRedirect(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcessWithAuth(t, *pagesBinary, listeners, "") + defer teardown() + + rsp, err := GetRedirectPage(t, httpListener, "group.gitlab-example.com", "/project/nonexistent/") + require.NoError(t, err) + defer rsp.Body.Close() + require.Equal(t, http.StatusNotFound, rsp.StatusCode) +} + +func setupTransport(t *testing.T) { + transport := (TestHTTPSClient.Transport).(*http.Transport) + defer func(t time.Duration) { + transport.ResponseHeaderTimeout = t + }(transport.ResponseHeaderTimeout) + transport.ResponseHeaderTimeout = 5 * time.Second +} + +type runPagesFunc func(t *testing.T, pagesPath string, listeners []ListenSpec, promPort string, sslCertFile string, authServer string) func() + +func testAccessControl(t *testing.T, runPages runPagesFunc) { + skipUnlessEnabled(t, "not-inplace-chroot") + + setupTransport(t) + + keyFile, certFile := CreateHTTPSFixtureFiles(t) + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + require.NoError(t, err) + defer os.Remove(keyFile) + defer os.Remove(certFile) + + testServer := makeGitLabPagesAccessStub(t) + testServer.TLS = &tls.Config{Certificates: []tls.Certificate{cert}} + testServer.StartTLS() + defer testServer.Close() + + tests := []struct { + host string + path string + status int + redirectBack bool + name string + }{ + { + name: "project with access", + host: "group.auth.gitlab-example.com", + path: "/private.project/", + status: http.StatusOK, + redirectBack: false, + }, + { + name: "project without access", + host: "group.auth.gitlab-example.com", + path: "/private.project.1/", + status: http.StatusNotFound, // Do not expose project existed + redirectBack: false, + }, + { + name: "invalid token test should redirect back", + host: "group.auth.gitlab-example.com", + path: "/private.project.2/", + status: http.StatusFound, + redirectBack: true, + }, + { + name: "no project should redirect to login and then 
return 404", + host: "group.auth.gitlab-example.com", + path: "/nonexistent/", + status: http.StatusNotFound, + redirectBack: false, + }, + { + name: "no project should redirect to login and then return 404", + host: "nonexistent.gitlab-example.com", + path: "/nonexistent/", + status: http.StatusNotFound, + redirectBack: false, + }, // subgroups + { + name: "[subgroup] project with access", + host: "group.auth.gitlab-example.com", + path: "/subgroup/private.project/", + status: http.StatusOK, + redirectBack: false, + }, + { + name: "[subgroup] project without access", + host: "group.auth.gitlab-example.com", + path: "/subgroup/private.project.1/", + status: http.StatusNotFound, // Do not expose project existed + redirectBack: false, + }, + { + name: "[subgroup] invalid token test should redirect back", + host: "group.auth.gitlab-example.com", + path: "/subgroup/private.project.2/", + status: http.StatusFound, + redirectBack: true, + }, + { + name: "[subgroup] no project should redirect to login and then return 404", + host: "group.auth.gitlab-example.com", + path: "/subgroup/nonexistent/", + status: http.StatusNotFound, + redirectBack: false, + }, + { + name: "[subgroup] no project should redirect to login and then return 404", + host: "nonexistent.gitlab-example.com", + path: "/subgroup/nonexistent/", + status: http.StatusNotFound, + redirectBack: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + teardown := runPages(t, *pagesBinary, listeners, "", certFile, testServer.URL) + defer teardown() + + rsp, err := GetRedirectPage(t, httpsListener, tt.host, tt.path) + + require.NoError(t, err) + defer rsp.Body.Close() + + require.Equal(t, http.StatusFound, rsp.StatusCode) + cookie := rsp.Header.Get("Set-Cookie") + + // Redirects to the projects under gitlab pages domain for authentication flow + url, err := url.Parse(rsp.Header.Get("Location")) + require.NoError(t, err) + require.Equal(t, "projects.gitlab-example.com", url.Host) + 
require.Equal(t, "/auth", url.Path) + state := url.Query().Get("state") + + rsp, err = GetRedirectPage(t, httpsListener, url.Host, url.Path+"?"+url.RawQuery) + + require.NoError(t, err) + defer rsp.Body.Close() + + require.Equal(t, http.StatusFound, rsp.StatusCode) + pagesDomainCookie := rsp.Header.Get("Set-Cookie") + + // Go to auth page with correct state will cause fetching the token + authrsp, err := GetRedirectPageWithCookie(t, httpsListener, "projects.gitlab-example.com", "/auth?code=1&state="+ + state, pagesDomainCookie) + + require.NoError(t, err) + defer authrsp.Body.Close() + + // Will redirect auth callback to correct host + url, err = url.Parse(authrsp.Header.Get("Location")) + require.NoError(t, err) + require.Equal(t, tt.host, url.Host) + require.Equal(t, "/auth", url.Path) + + // Request auth callback in project domain + authrsp, err = GetRedirectPageWithCookie(t, httpsListener, url.Host, url.Path+"?"+url.RawQuery, cookie) + require.NoError(t, err) + + // server returns the ticket, user will be redirected to the project page + require.Equal(t, http.StatusFound, authrsp.StatusCode) + cookie = authrsp.Header.Get("Set-Cookie") + rsp, err = GetRedirectPageWithCookie(t, httpsListener, tt.host, tt.path, cookie) + + require.NoError(t, err) + defer rsp.Body.Close() + + require.Equal(t, tt.status, rsp.StatusCode) + require.Equal(t, "", rsp.Header.Get("Cache-Control")) + + if tt.redirectBack { + url, err = url.Parse(rsp.Header.Get("Location")) + require.NoError(t, err) + + require.Equal(t, "https", url.Scheme) + require.Equal(t, tt.host, url.Host) + require.Equal(t, tt.path, url.Path) + } + }) + } +} + +func TestAccessControlWithSSLCertFile(t *testing.T) { + testAccessControl(t, RunPagesProcessWithAuthServerWithSSLCertFile) +} + +func TestAccessControlWithSSLCertDir(t *testing.T) { + testAccessControl(t, RunPagesProcessWithAuthServerWithSSLCertDir) +} + +// This proves the fix for https://gitlab.com/gitlab-org/gitlab-pages/-/issues/262 +// Read the issue 
description if any changes to internal/auth/ break this test. +// Related to https://tools.ietf.org/html/rfc6749#section-10.6. +func TestHijackedCode(t *testing.T) { + skipUnlessEnabled(t, "not-inplace-chroot") + + testServer := makeGitLabPagesAccessStub(t) + testServer.Start() + defer testServer.Close() + + teardown := RunPagesProcessWithAuthServer(t, *pagesBinary, listeners, "", testServer.URL) + defer teardown() + + /****ATTACKER******/ + // get valid cookie for a different private project + targetDomain := "private.domain.com" + attackersDomain := "group.auth.gitlab-example.com" + attackerCookie, attackerState := getValidCookieAndState(t, targetDomain) + + /****TARGET******/ + // fool target to click on modified URL with attacker's domain for redirect with a valid state + hackedURL := fmt.Sprintf("/auth?domain=http://%s&state=%s", attackersDomain, "irrelevant") + maliciousResp, err := GetProxyRedirectPageWithCookie(t, proxyListener, "projects.gitlab-example.com", hackedURL, "", true) + require.NoError(t, err) + defer maliciousResp.Body.Close() + + pagesCookie := maliciousResp.Header.Get("Set-Cookie") + + /* + OAuth flow happens here... 
+ */ + maliciousRespURL, err := url.Parse(maliciousResp.Header.Get("Location")) + require.NoError(t, err) + maliciousState := maliciousRespURL.Query().Get("state") + + // Go to auth page with correct state and code "obtained" from GitLab + authrsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, + "projects.gitlab-example.com", "/auth?code=1&state="+maliciousState, + pagesCookie, true) + + require.NoError(t, err) + defer authrsp.Body.Close() + + /****ATTACKER******/ + // Target is redirected to attacker's domain and attacker receives the proper code + require.Equal(t, http.StatusFound, authrsp.StatusCode, "should redirect to attacker's domain") + authrspURL, err := url.Parse(authrsp.Header.Get("Location")) + require.NoError(t, err) + require.Contains(t, authrspURL.String(), attackersDomain) + + // attacker's got the code + hijackedCode := authrspURL.Query().Get("code") + require.NotEmpty(t, hijackedCode) + + // attacker tries to access private pages content + impersonatingRes, err := GetProxyRedirectPageWithCookie(t, proxyListener, targetDomain, + "/auth?code="+hijackedCode+"&state="+attackerState, attackerCookie, true) + require.NoError(t, err) + defer authrsp.Body.Close() + + require.Equal(t, impersonatingRes.StatusCode, http.StatusInternalServerError, "should fail to decode code") +} + +func getValidCookieAndState(t *testing.T, domain string) (string, string) { + t.Helper() + + // follow flow to get a valid cookie + // visit https:/// + rsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, domain, "/", "", true) + require.NoError(t, err) + defer rsp.Body.Close() + + cookie := rsp.Header.Get("Set-Cookie") + require.NotEmpty(t, cookie) + + redirectURL, err := url.Parse(rsp.Header.Get("Location")) + require.NoError(t, err) + + state := redirectURL.Query().Get("state") + require.NotEmpty(t, state) + + return cookie, state +} diff --git a/test/acceptance/config_test.go b/test/acceptance/config_test.go new file mode 100644 index 000000000..93e9aa22e --- 
/dev/null +++ b/test/acceptance/config_test.go @@ -0,0 +1,66 @@ +package acceptance_test + +import ( + "fmt" + "net" + "net/http" + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestEnvironmentVariablesConfig(t *testing.T) { + skipUnlessEnabled(t) + os.Setenv("LISTEN_HTTP", net.JoinHostPort(httpListener.Host, httpListener.Port)) + defer func() { os.Unsetenv("LISTEN_HTTP") }() + + teardown := RunPagesProcessWithoutWait(t, *pagesBinary, []ListenSpec{}, "") + defer teardown() + require.NoError(t, httpListener.WaitUntilRequestSucceeds(nil)) + + rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com:", "project/") + + require.NoError(t, err) + rsp.Body.Close() + require.Equal(t, http.StatusOK, rsp.StatusCode) +} + +func TestMixedConfigSources(t *testing.T) { + skipUnlessEnabled(t) + os.Setenv("LISTEN_HTTP", net.JoinHostPort(httpListener.Host, httpListener.Port)) + defer func() { os.Unsetenv("LISTEN_HTTP") }() + + teardown := RunPagesProcessWithoutWait(t, *pagesBinary, []ListenSpec{httpsListener}, "") + defer teardown() + + for _, listener := range []ListenSpec{httpListener, httpsListener} { + require.NoError(t, listener.WaitUntilRequestSucceeds(nil)) + rsp, err := GetPageFromListener(t, listener, "group.gitlab-example.com", "project/") + require.NoError(t, err) + rsp.Body.Close() + + require.Equal(t, http.StatusOK, rsp.StatusCode) + } +} + +func TestMultiFlagEnvironmentVariables(t *testing.T) { + skipUnlessEnabled(t) + listenSpecs := []ListenSpec{{"http", "127.0.0.1", "37001"}, {"http", "127.0.0.1", "37002"}} + envVarValue := fmt.Sprintf("%s,%s", net.JoinHostPort("127.0.0.1", "37001"), net.JoinHostPort("127.0.0.1", "37002")) + + os.Setenv("LISTEN_HTTP", envVarValue) + defer func() { os.Unsetenv("LISTEN_HTTP") }() + + teardown := RunPagesProcess(t, *pagesBinary, []ListenSpec{}, "") + defer teardown() + + for _, listener := range listenSpecs { + require.NoError(t, listener.WaitUntilRequestSucceeds(nil)) + rsp, err := 
GetPageFromListener(t, listener, "group.gitlab-example.com", "project/") + + require.NoError(t, err) + rsp.Body.Close() + require.Equal(t, http.StatusOK, rsp.StatusCode) + } +} diff --git a/test/acceptance/encodings_test.go b/test/acceptance/encodings_test.go new file mode 100644 index 000000000..9b8742053 --- /dev/null +++ b/test/acceptance/encodings_test.go @@ -0,0 +1,78 @@ +package acceptance_test + +import ( + "mime" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMIMETypes(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcessWithoutWait(t, *pagesBinary, listeners, "") + defer teardown() + + require.NoError(t, httpListener.WaitUntilRequestSucceeds(nil)) + + tests := map[string]struct { + file string + expectedContentType string + }{ + "manifest_json": { + file: "file.webmanifest", + expectedContentType: "application/manifest+json", + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "project/"+tt.file) + require.NoError(t, err) + defer rsp.Body.Close() + + require.Equal(t, http.StatusOK, rsp.StatusCode) + mt, _, err := mime.ParseMediaType(rsp.Header.Get("Content-Type")) + require.NoError(t, err) + require.Equal(t, tt.expectedContentType, mt) + }) + } +} + +func TestCompressedEncoding(t *testing.T) { + skipUnlessEnabled(t) + + tests := []struct { + name string + host string + path string + encoding string + }{ + { + "gzip encoding", + "group.gitlab-example.com", + "index.html", + "gzip", + }, + { + "brotli encoding", + "group.gitlab-example.com", + "index.html", + "br", + }, + } + + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rsp, err := GetCompressedPageFromListener(t, httpListener, "group.gitlab-example.com", "index.html", tt.encoding) + require.NoError(t, err) + defer rsp.Body.Close() +
require.Equal(t, http.StatusOK, rsp.StatusCode) + require.Equal(t, tt.encoding, rsp.Header.Get("Content-Encoding")) + }) + } +} diff --git a/helpers_test.go b/test/acceptance/helpers_test.go similarity index 63% rename from helpers_test.go rename to test/acceptance/helpers_test.go index eec3c94be..d228f787b 100644 --- a/helpers_test.go +++ b/test/acceptance/helpers_test.go @@ -1,7 +1,8 @@ -package main +package acceptance_test import ( "bytes" + "context" "crypto/tls" "crypto/x509" "fmt" @@ -14,28 +15,21 @@ import ( "os/exec" "path" "strings" + "sync" "testing" "time" + proxyproto "github.com/pires/go-proxyproto" "github.com/stretchr/testify/require" - "gitlab.com/gitlab-org/gitlab-pages/internal/fixture" "gitlab.com/gitlab-org/gitlab-pages/internal/request" ) -type tWriter struct { - t *testing.T -} - -func (t *tWriter) Write(b []byte) (int, error) { - t.t.Log(string(bytes.TrimRight(b, "\r\n"))) - - return len(b), nil -} - // The HTTPS certificate isn't signed by anyone. This http client is set up // so it can talk to servers using it. var ( + // The HTTPS certificate isn't signed by anyone. This http client is set up + // so it can talk to servers using it.
TestHTTPSClient = &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{RootCAs: TestCertPool}, @@ -51,40 +45,91 @@ var ( }, } + // Proxyv2 client + TestProxyv2Client = &http.Client{ + Transport: &http.Transport{ + DialContext: Proxyv2DialContext, + TLSClientConfig: &tls.Config{RootCAs: TestCertPool}, + }, + } + + QuickTimeoutProxyv2Client = &http.Client{ + Transport: &http.Transport{ + DialContext: Proxyv2DialContext, + TLSClientConfig: &tls.Config{RootCAs: TestCertPool}, + ResponseHeaderTimeout: 100 * time.Millisecond, + }, + } + TestCertPool = x509.NewCertPool() -) -func init() { - if ok := TestCertPool.AppendCertsFromPEM([]byte(fixture.Certificate)); !ok { - fmt.Println("Failed to load cert!") + // Proxyv2 will create a dummy request with src 10.1.1.1:1000 + // and dst 20.2.2.2:2000 + Proxyv2DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + + conn, err := d.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + header := &proxyproto.Header{ + Version: 2, + Command: proxyproto.PROXY, + TransportProtocol: proxyproto.TCPv4, + SourceAddress: net.ParseIP("10.1.1.1"), + SourcePort: 1000, + DestinationAddress: net.ParseIP("20.2.2.2"), + DestinationPort: 2000, + } + + _, err = header.WriteTo(conn) + + return conn, err } + + existingAcmeTokenPath = "/.well-known/acme-challenge/existingtoken" + notExistingAcmeTokenPath = "/.well-known/acme-challenge/notexistingtoken" +) + +type tWriter struct { + t *testing.T } -func CreateHTTPSFixtureFiles(t *testing.T) (key string, cert string) { - keyfile, err := ioutil.TempFile("", "https-fixture") - require.NoError(t, err) - key = keyfile.Name() - keyfile.Close() +func (t *tWriter) Write(b []byte) (int, error) { + t.t.Log(string(bytes.TrimRight(b, "\r\n"))) - certfile, err := ioutil.TempFile("", "https-fixture") - require.NoError(t, err) - cert = certfile.Name() - certfile.Close() + return len(b), nil +} - require.NoError(t,
ioutil.WriteFile(key, []byte(fixture.Key), 0644)) - require.NoError(t, ioutil.WriteFile(cert, []byte(fixture.Certificate), 0644)) +type LogCaptureBuffer struct { + b bytes.Buffer + m sync.Mutex +} - return keyfile.Name(), certfile.Name() +func (b *LogCaptureBuffer) Read(p []byte) (n int, err error) { + b.m.Lock() + defer b.m.Unlock() + + return b.b.Read(p) } +func (b *LogCaptureBuffer) Write(p []byte) (n int, err error) { + b.m.Lock() + defer b.m.Unlock() -func CreateGitLabAPISecretKeyFixtureFile(t *testing.T) (filepath string) { - secretfile, err := ioutil.TempFile("", "gitlab-api-secret") - require.NoError(t, err) - secretfile.Close() + return b.b.Write(p) +} +func (b *LogCaptureBuffer) String() string { + b.m.Lock() + defer b.m.Unlock() - require.NoError(t, ioutil.WriteFile(secretfile.Name(), []byte(fixture.GitLabAPISecretKey), 0644)) + return b.b.String() +} +func (b *LogCaptureBuffer) Reset() { + b.m.Lock() + defer b.m.Unlock() - return secretfile.Name() + b.b.Reset() } // ListenSpec is used to point at a gitlab-pages http server, preserving the @@ -97,7 +142,7 @@ type ListenSpec struct { func (l ListenSpec) URL(suffix string) string { scheme := request.SchemeHTTP - if l.Type == request.SchemeHTTPS { + if l.Type == request.SchemeHTTPS || l.Type == "https-proxyv2" { scheme = request.SchemeHTTPS } @@ -121,7 +166,12 @@ func (l ListenSpec) WaitUntilRequestSucceeds(done chan struct{}) error { return err } - response, err := QuickTimeoutHTTPSClient.Transport.RoundTrip(req) + client := QuickTimeoutHTTPSClient + if l.Type == "https-proxyv2" { + client = QuickTimeoutProxyv2Client + } + + response, err := client.Transport.RoundTrip(req) if err != nil { time.Sleep(100 * time.Millisecond) continue @@ -147,30 +197,38 @@ func (l ListenSpec) JoinHostPort() string { // GetPageFromProcess to do a HTTP GET against a listener.
// // If run as root via sudo, the gitlab-pages process will drop privileges -func RunPagesProcess(t *testing.T, pagesPath string, listeners []ListenSpec, promPort string, extraArgs ...string) (teardown func()) { - return runPagesProcess(t, true, pagesPath, listeners, promPort, nil, extraArgs...) +func RunPagesProcess(t *testing.T, pagesBinary string, listeners []ListenSpec, promPort string, extraArgs ...string) (teardown func()) { + _, cleanup := runPagesProcess(t, true, pagesBinary, listeners, promPort, nil, extraArgs...) + return cleanup +} + +func RunPagesProcessWithoutWait(t *testing.T, pagesBinary string, listeners []ListenSpec, promPort string, extraArgs ...string) (teardown func()) { + _, cleanup := runPagesProcess(t, false, pagesBinary, listeners, promPort, nil, extraArgs...) + return cleanup } -func RunPagesProcessWithoutWait(t *testing.T, pagesPath string, listeners []ListenSpec, promPort string, extraArgs ...string) (teardown func()) { - return runPagesProcess(t, false, pagesPath, listeners, promPort, nil, extraArgs...) +func RunPagesProcessWithSSLCertFile(t *testing.T, pagesBinary string, listeners []ListenSpec, promPort string, sslCertFile string, extraArgs ...string) (teardown func()) { + _, cleanup := runPagesProcess(t, true, pagesBinary, listeners, promPort, []string{"SSL_CERT_FILE=" + sslCertFile}, extraArgs...) + return cleanup } -func RunPagesProcessWithSSLCertFile(t *testing.T, pagesPath string, listeners []ListenSpec, promPort string, sslCertFile string, extraArgs ...string) (teardown func()) { - return runPagesProcess(t, true, pagesPath, listeners, promPort, []string{"SSL_CERT_FILE=" + sslCertFile}, extraArgs...) +func RunPagesProcessWithEnvs(t *testing.T, wait bool, pagesBinary string, listeners []ListenSpec, promPort string, envs []string, extraArgs ...string) (teardown func()) { + _, cleanup := runPagesProcess(t, wait, pagesBinary, listeners, promPort, envs, extraArgs...)
+ return cleanup } -func RunPagesProcessWithEnvs(t *testing.T, wait bool, pagesPath string, listeners []ListenSpec, promPort string, envs []string, extraArgs ...string) (teardown func()) { - return runPagesProcess(t, wait, pagesPath, listeners, promPort, envs, extraArgs...) +func RunPagesProcessWithOutput(t *testing.T, pagesBinary string, listeners []ListenSpec, promPort string, extraArgs ...string) (out *LogCaptureBuffer, teardown func()) { + return runPagesProcess(t, true, pagesBinary, listeners, promPort, nil, extraArgs...) } -func RunPagesProcessWithStubGitLabServer(t *testing.T, wait bool, pagesPath string, listeners []ListenSpec, promPort string, envs []string, extraArgs ...string) (teardown func()) { +func RunPagesProcessWithStubGitLabServer(t *testing.T, wait bool, pagesBinary string, listeners []ListenSpec, promPort string, envs []string, extraArgs ...string) (teardown func()) { var apiCalled bool - source := NewGitlabDomainsSourceStub(t, &apiCalled) + source := NewGitlabDomainsSourceStub(t, &apiCalled, 0) gitLabAPISecretKey := CreateGitLabAPISecretKeyFixtureFile(t) pagesArgs := append([]string{"-gitlab-server", source.URL, "-api-secret-key", gitLabAPISecretKey, "-domain-config-source", "gitlab"}, extraArgs...) - cleanup := runPagesProcess(t, wait, pagesPath, listeners, promPort, envs, pagesArgs...) + _, cleanup := runPagesProcess(t, wait, pagesBinary, listeners, promPort, envs, pagesArgs...)
return func() { source.Close() @@ -178,27 +236,28 @@ func RunPagesProcessWithStubGitLabServer(t *testing.T, wait bool, pagesPath stri } } -func RunPagesProcessWithAuth(t *testing.T, pagesPath string, listeners []ListenSpec, promPort string) func() { +func RunPagesProcessWithAuth(t *testing.T, pagesBinary string, listeners []ListenSpec, promPort string) func() { configFile, cleanup := defaultConfigFileWith(t, "auth-server=https://gitlab-auth.com", "auth-redirect-uri=https://projects.gitlab-example.com/auth") defer cleanup() - return runPagesProcess(t, true, pagesPath, listeners, promPort, nil, + _, cleanup2 := runPagesProcess(t, true, pagesBinary, listeners, promPort, nil, "-config="+configFile, ) + return cleanup2 } -func RunPagesProcessWithAuthServer(t *testing.T, pagesPath string, listeners []ListenSpec, promPort string, authServer string) func() { - return runPagesProcessWithAuthServer(t, pagesPath, listeners, promPort, nil, authServer) +func RunPagesProcessWithAuthServer(t *testing.T, pagesBinary string, listeners []ListenSpec, promPort string, authServer string) func() { + return runPagesProcessWithAuthServer(t, pagesBinary, listeners, promPort, nil, authServer) } -func RunPagesProcessWithAuthServerWithSSLCertFile(t *testing.T, pagesPath string, listeners []ListenSpec, promPort string, sslCertFile string, authServer string) func() { - return runPagesProcessWithAuthServer(t, pagesPath, listeners, promPort, +func RunPagesProcessWithAuthServerWithSSLCertFile(t *testing.T, pagesBinary string, listeners []ListenSpec, promPort string, sslCertFile string, authServer string) func() { + return runPagesProcessWithAuthServer(t, pagesBinary, listeners, promPort, []string{"SSL_CERT_FILE=" + sslCertFile}, authServer) } -func RunPagesProcessWithAuthServerWithSSLCertDir(t *testing.T, pagesPath string, listeners []ListenSpec, promPort string, sslCertFile string, authServer string) func() { +func RunPagesProcessWithAuthServerWithSSLCertDir(t *testing.T, pagesBinary string,
listeners []ListenSpec, promPort string, sslCertFile string, authServer string) func() { // Create temporary cert dir sslCertDir, err := ioutil.TempDir("", "pages-test-SSL_CERT_DIR") require.NoError(t, err) @@ -207,7 +266,7 @@ func RunPagesProcessWithAuthServerWithSSLCertDir(t *testing.T, pagesPath string, err = copyFile(sslCertDir+"/"+path.Base(sslCertFile), sslCertFile) require.NoError(t, err) - innerCleanup := runPagesProcessWithAuthServer(t, pagesPath, listeners, promPort, + innerCleanup := runPagesProcessWithAuthServer(t, pagesBinary, listeners, promPort, []string{"SSL_CERT_DIR=" + sslCertDir}, authServer) return func() { @@ -216,29 +275,33 @@ func RunPagesProcessWithAuthServerWithSSLCertDir(t *testing.T, pagesPath string, } } -func runPagesProcessWithAuthServer(t *testing.T, pagesPath string, listeners []ListenSpec, promPort string, extraEnv []string, authServer string) func() { +func runPagesProcessWithAuthServer(t *testing.T, pagesBinary string, listeners []ListenSpec, promPort string, extraEnv []string, authServer string) func() { configFile, cleanup := defaultConfigFileWith(t, "auth-server="+authServer, "auth-redirect-uri=https://projects.gitlab-example.com/auth") defer cleanup() - return runPagesProcess(t, true, pagesPath, listeners, promPort, extraEnv, + _, cleanup2 := runPagesProcess(t, true, pagesBinary, listeners, promPort, extraEnv, "-config="+configFile) + return cleanup2 } -func runPagesProcess(t *testing.T, wait bool, pagesPath string, listeners []ListenSpec, promPort string, extraEnv []string, extraArgs ...string) (teardown func()) { +func runPagesProcess(t *testing.T, wait bool, pagesBinary string, listeners []ListenSpec, promPort string, extraEnv []string, extraArgs ...string) (*LogCaptureBuffer, func()) { t.Helper() - _, err := os.Stat(pagesPath) + _, err := os.Stat(pagesBinary) require.NoError(t, err) + logBuf := &LogCaptureBuffer{} + out := io.MultiWriter(&tWriter{t}, logBuf) + args, tempfiles := getPagesArgs(t, listeners, promPort,
extraArgs) - cmd := exec.Command(pagesPath, args...) + cmd := exec.Command(pagesBinary, args...) cmd.Env = append(os.Environ(), extraEnv...) - cmd.Stdout = &tWriter{t} - cmd.Stderr = &tWriter{t} + cmd.Stdout = out + cmd.Stderr = out require.NoError(t, cmd.Start()) - t.Logf("Running %s %v", pagesPath, args) + t.Logf("Running %s %v", pagesBinary, args) waitCh := make(chan struct{}) go func() { @@ -263,7 +326,7 @@ func runPagesProcess(t *testing.T, wait bool, pagesPath string, listeners []List } } - return cleanup + return logBuf, cleanup } func getPagesArgs(t *testing.T, listeners []ListenSpec, promPort string, extraArgs []string) (args, tempfiles []string) { @@ -285,6 +348,10 @@ func getPagesArgs(t *testing.T, listeners []ListenSpec, promPort string, extraAr args = append(args, "-root-key", key, "-root-cert", cert) } + if !contains(extraArgs, "-pages-root") { + args = append(args, "-pages-root", "../../shared/pages") + } + if promPort != "" { args = append(args, "-metrics-address", promPort) } @@ -295,6 +362,15 @@ func getPagesArgs(t *testing.T, listeners []ListenSpec, promPort string, extraAr return } +func contains(slice []string, s string) bool { + for _, e := range slice { + if e == s { + return true + } + } + return false +} + func getPagesDaemonArgs(t *testing.T) []string { mode := os.Getenv("TEST_DAEMONIZE") if mode == "" { @@ -347,7 +423,7 @@ func GetPageFromListenerWithCookie(t *testing.T, spec ListenSpec, host, urlsuffi req.Host = host - return DoPagesRequest(t, req) + return DoPagesRequest(t, spec, req) } func GetCompressedPageFromListener(t *testing.T, spec ListenSpec, host, urlsuffix string, encoding string) (*http.Response, error) { @@ -359,7 +435,7 @@ func GetCompressedPageFromListener(t *testing.T, spec ListenSpec, host, urlsuffi req.Host = host req.Header.Set("Accept-Encoding", encoding) - return DoPagesRequest(t, req) + return DoPagesRequest(t, spec, req) } func GetProxiedPageFromListener(t *testing.T, spec ListenSpec, host, xForwardedHost, urlsuffix
string) (*http.Response, error) { @@ -372,12 +448,16 @@ func GetProxiedPageFromListener(t *testing.T, spec ListenSpec, host, xForwardedH req.Host = host req.Header.Set("X-Forwarded-Host", xForwardedHost) - return DoPagesRequest(t, req) + return DoPagesRequest(t, spec, req) } -func DoPagesRequest(t *testing.T, req *http.Request) (*http.Response, error) { +func DoPagesRequest(t *testing.T, spec ListenSpec, req *http.Request) (*http.Response, error) { t.Logf("curl -X %s -H'Host: %s' %s", req.Method, req.Host, req.URL) + if spec.Type == "https-proxyv2" { + return TestProxyv2Client.Do(req) + } + return TestHTTPSClient.Do(req) } @@ -413,6 +493,10 @@ func GetRedirectPageWithHeaders(t *testing.T, spec ListenSpec, host, urlsuffix s req.Host = host + if spec.Type == "https-proxyv2" { + return TestProxyv2Client.Transport.RoundTrip(req) + } + return TestHTTPSClient.Transport.RoundTrip(req) } @@ -434,7 +518,12 @@ func waitForRoundtrips(t *testing.T, listeners []ListenSpec, timeout time.Durati t.Fatal(err) } - if response, err := QuickTimeoutHTTPSClient.Transport.RoundTrip(req); err == nil { + client := QuickTimeoutHTTPSClient + if spec.Type == "https-proxyv2" { + client = QuickTimeoutProxyv2Client + } + + if response, err := client.Transport.RoundTrip(req); err == nil { nListening++ response.Body.Close() break @@ -447,17 +536,23 @@ func waitForRoundtrips(t *testing.T, listeners []ListenSpec, timeout time.Durati require.Equal(t, len(listeners), nListening, "all listeners must be accepting TCP connections") } -func NewGitlabDomainsSourceStub(t *testing.T, apiCalled *bool) *httptest.Server { +func NewGitlabDomainsSourceStub(t *testing.T, apiCalled *bool, readyCount int) *httptest.Server { *apiCalled = false + currentStatusCount := 0 + mux := http.NewServeMux() mux.HandleFunc("/api/v4/internal/pages/status", func(w http.ResponseWriter, r *http.Request) { + if currentStatusCount < readyCount { + w.WriteHeader(http.StatusBadGateway) + return + } + w.WriteHeader(http.StatusNoContent) })
handler := func(w http.ResponseWriter, r *http.Request) { *apiCalled = true domain := r.URL.Query().Get("host") - path := "shared/lookups/" + domain + ".json" + path := "../../shared/lookups/" + domain + ".json" fixture, err := os.Open(path) if os.IsNotExist(err) { @@ -534,31 +629,3 @@ func copyFile(dest, src string) error { _, err = io.Copy(destFile, srcFile) return err } - -func newZipFileServerURL(t *testing.T, zipFilePath string) (string, func()) { - t.Helper() - - m := http.NewServeMux() - m.HandleFunc("/public.zip", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.ServeFile(w, r, zipFilePath) - })) - - // create a listener with the desired port. - l, err := net.Listen("tcp", objectStorageMockServer) - require.NoError(t, err) - - testServer := httptest.NewUnstartedServer(m) - - // NewUnstartedServer creates a listener. Close that listener and replace - // with the one we created. - testServer.Listener.Close() - testServer.Listener = l - - // Start the server. - testServer.Start() - - return testServer.URL, func() { - // Cleanup. - testServer.Close() - } -} diff --git a/test/acceptance/metrics_test.go b/test/acceptance/metrics_test.go new file mode 100644 index 000000000..64cfb60ac --- /dev/null +++ b/test/acceptance/metrics_test.go @@ -0,0 +1,62 @@ +package acceptance_test + +import ( + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPrometheusMetricsCanBeScraped(t *testing.T) { + skipUnlessEnabled(t) + + _, cleanup := newZipFileServerURL(t, "../../shared/pages/group/zip.gitlab.io/public.zip") + defer cleanup() + + teardown := RunPagesProcessWithStubGitLabServer(t, true, *pagesBinary, listeners, ":42345", []string{}) + defer teardown() + + // need to call an actual resource to populate certain metrics e.g.
gitlab_pages_domains_source_api_requests_total + res, err := GetPageFromListener(t, httpListener, "zip.gitlab.io", "/symlink.html") + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + + resp, err := http.Get("http://localhost:42345/metrics") + require.NoError(t, err) + + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + + require.Contains(t, string(body), "gitlab_pages_http_in_flight_requests 0") + // TODO: remove metrics for disk source https://gitlab.com/gitlab-org/gitlab-pages/-/issues/382 + require.Contains(t, string(body), "gitlab_pages_served_domains 0") + require.Contains(t, string(body), "gitlab_pages_domains_failed_total 0") + require.Contains(t, string(body), "gitlab_pages_domains_updated_total 0") + require.Contains(t, string(body), "gitlab_pages_last_domain_update_seconds gauge") + require.Contains(t, string(body), "gitlab_pages_domains_configuration_update_duration gauge") + // end TODO + require.Contains(t, string(body), "gitlab_pages_domains_source_cache_hit") + require.Contains(t, string(body), "gitlab_pages_domains_source_cache_miss") + require.Contains(t, string(body), "gitlab_pages_domains_source_failures_total") + require.Contains(t, string(body), "gitlab_pages_serverless_requests 0") + require.Contains(t, string(body), "gitlab_pages_serverless_latency_sum 0") + require.Contains(t, string(body), "gitlab_pages_disk_serving_file_size_bytes_sum") + require.Contains(t, string(body), "gitlab_pages_serving_time_seconds_sum") + require.Contains(t, string(body), `gitlab_pages_domains_source_api_requests_total{status_code="200"}`) + require.Contains(t, string(body), `gitlab_pages_domains_source_api_call_duration_bucket`) + require.Contains(t, string(body), `gitlab_pages_domains_source_api_trace_duration`) + // httprange + require.Contains(t, string(body), `gitlab_pages_httprange_requests_total{status_code="206"}`) + require.Contains(t, string(body),
"gitlab_pages_httprange_requests_duration_bucket") + require.Contains(t, string(body), "gitlab_pages_httprange_trace_duration") + require.Contains(t, string(body), "gitlab_pages_httprange_open_requests") + // zip archives + require.Contains(t, string(body), "gitlab_pages_zip_opened") + require.Contains(t, string(body), "gitlab_pages_zip_cache_requests") + require.Contains(t, string(body), "gitlab_pages_zip_cached_entries") + require.Contains(t, string(body), "gitlab_pages_zip_archive_entries_cached") + require.Contains(t, string(body), "gitlab_pages_zip_opened_entries_count") +} diff --git a/test/acceptance/proxyv2_test.go b/test/acceptance/proxyv2_test.go new file mode 100644 index 000000000..c407ea194 --- /dev/null +++ b/test/acceptance/proxyv2_test.go @@ -0,0 +1,52 @@ +package acceptance_test + +import ( + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestProxyv2(t *testing.T) { + skipUnlessEnabled(t) + + logBuf, teardown := RunPagesProcessWithOutput(t, *pagesBinary, listeners, "") + defer teardown() + + // the dummy client IP 10.1.1.1 is set by TestProxyv2Client + tests := map[string]struct { + host string + urlSuffix string + expectedStatusCode int + expectedContent string + expectedLog string + }{ + "basic_proxyv2_request": { + host: "group.gitlab-example.com", + urlSuffix: "project/", + expectedStatusCode: http.StatusOK, + expectedContent: "project-subdir\n", + expectedLog: "group.gitlab-example.com 10.1.1.1", + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + logBuf.Reset() + + response, err := GetPageFromListener(t, httpsProxyv2Listener, tt.host, tt.urlSuffix) + require.NoError(t, err) + defer response.Body.Close() + + require.Equal(t, tt.expectedStatusCode, response.StatusCode) + + body, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + + require.Contains(t, string(body), tt.expectedContent, "content mismatch") + + require.Contains(t, logBuf.String(), tt.expectedLog,
"log mismatch") + }) + } +} diff --git a/test/acceptance/redirects_test.go b/test/acceptance/redirects_test.go new file mode 100644 index 000000000..6c564ce69 --- /dev/null +++ b/test/acceptance/redirects_test.go @@ -0,0 +1,116 @@ +package acceptance_test + +import ( + "fmt" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDisabledRedirects(t *testing.T) { + skipUnlessEnabled(t) + + teardown := RunPagesProcessWithEnvs(t, true, *pagesBinary, listeners, "", []string{"FF_ENABLE_REDIRECTS=false"}) + defer teardown() + + // Test that redirects status page is forbidden + rsp, err := GetPageFromListener(t, httpListener, "group.redirects.gitlab-example.com", "/project-redirects/_redirects") + require.NoError(t, err) + defer rsp.Body.Close() + + require.Equal(t, http.StatusForbidden, rsp.StatusCode) + + // Test that redirects are disabled + rsp, err = GetRedirectPage(t, httpListener, "group.redirects.gitlab-example.com", "/project-redirects/redirect-portal.html") + require.NoError(t, err) + defer rsp.Body.Close() + + require.Equal(t, http.StatusNotFound, rsp.StatusCode) +} + +func TestRedirectStatusPage(t *testing.T) { + skipUnlessEnabled(t) + + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + rsp, err := GetPageFromListener(t, httpListener, "group.redirects.gitlab-example.com", "/project-redirects/_redirects") + require.NoError(t, err) + defer rsp.Body.Close() + + body, err := ioutil.ReadAll(rsp.Body) + require.NoError(t, err) + + require.Contains(t, string(body), "11 rules") + require.Equal(t, http.StatusOK, rsp.StatusCode) +} + +func TestRedirect(t *testing.T) { + skipUnlessEnabled(t) + + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + // Test that serving a file still works with redirects enabled + rsp, err := GetRedirectPage(t, httpListener, "group.redirects.gitlab-example.com", "/project-redirects/index.html") + require.NoError(t, err) + defer
rsp.Body.Close() + + require.Equal(t, http.StatusOK, rsp.StatusCode) + + tests := []struct { + host string + path string + expectedStatus int + expectedLocation string + }{ + // Project domain + { + host: "group.redirects.gitlab-example.com", + path: "/project-redirects/redirect-portal.html", + expectedStatus: http.StatusFound, + expectedLocation: "/project-redirects/magic-land.html", + }, + // Make sure invalid rule does not redirect + { + host: "group.redirects.gitlab-example.com", + path: "/project-redirects/goto-domain.html", + expectedStatus: http.StatusNotFound, + expectedLocation: "", + }, + // Actual file on disk should override any redirects that match + { + host: "group.redirects.gitlab-example.com", + path: "/project-redirects/file-override.html", + expectedStatus: http.StatusOK, + expectedLocation: "", + }, + // Group-level domain + { + host: "group.redirects.gitlab-example.com", + path: "/redirect-portal.html", + expectedStatus: http.StatusFound, + expectedLocation: "/magic-land.html", + }, + // Custom domain + { + host: "redirects.custom-domain.com", + path: "/redirect-portal.html", + expectedStatus: http.StatusFound, + expectedLocation: "/magic-land.html", + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("%s%s -> %s (%d)", tt.host, tt.path, tt.expectedLocation, tt.expectedStatus), func(t *testing.T) { + rsp, err := GetRedirectPage(t, httpListener, tt.host, tt.path) + require.NoError(t, err) + defer rsp.Body.Close() + + require.Equal(t, tt.expectedLocation, rsp.Header.Get("Location")) + require.Equal(t, tt.expectedStatus, rsp.StatusCode) + }) + } +} diff --git a/test/acceptance/serving_test.go b/test/acceptance/serving_test.go new file mode 100644 index 000000000..becd6b8cd --- /dev/null +++ b/test/acceptance/serving_test.go @@ -0,0 +1,574 @@ +package acceptance_test + +import ( + "fmt" + "io/ioutil" + "net/http" + "os" + "path" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestUnknownHostReturnsNotFound(t
*testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + for _, spec := range listeners { + rsp, err := GetPageFromListener(t, spec, "invalid.invalid", "") + + require.NoError(t, err) + rsp.Body.Close() + require.Equal(t, http.StatusNotFound, rsp.StatusCode) + } +} + +func TestUnknownProjectReturnsNotFound(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "/nonexistent/") + require.NoError(t, err) + defer rsp.Body.Close() + require.Equal(t, http.StatusNotFound, rsp.StatusCode) +} + +func TestGroupDomainReturns200(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "/") + require.NoError(t, err) + defer rsp.Body.Close() + require.Equal(t, http.StatusOK, rsp.StatusCode) +} + +func TestKnownHostReturns200(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + tests := []struct { + name string + host string + path string + }{ + { + name: "lower case", + host: "group.gitlab-example.com", + path: "project/", + }, + { + name: "capital project", + host: "group.gitlab-example.com", + path: "CapitalProject/", + }, + { + name: "capital group", + host: "CapitalGroup.gitlab-example.com", + path: "project/", + }, + { + name: "capital group and project", + host: "CapitalGroup.gitlab-example.com", + path: "CapitalProject/", + }, + { + name: "subgroup", + host: "group.gitlab-example.com", + path: "subgroup/project/", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for _, spec := range listeners { + rsp, err := GetPageFromListener(t, spec, tt.host, tt.path) + + require.NoError(t, err) + rsp.Body.Close() + require.Equal(t,
http.StatusOK, rsp.StatusCode) + } + }) + } +} + +func TestNestedSubgroups(t *testing.T) { + skipUnlessEnabled(t) + + maxNestedSubgroup := 21 + + pagesRoot, err := ioutil.TempDir("", "pages-root") + require.NoError(t, err) + defer os.RemoveAll(pagesRoot) + + makeProjectIndex := func(subGroupPath string) { + projectPath := path.Join(pagesRoot, "nested", subGroupPath, "project", "public") + require.NoError(t, os.MkdirAll(projectPath, 0755)) + + projectIndex := path.Join(projectPath, "index.html") + require.NoError(t, ioutil.WriteFile(projectIndex, []byte("index"), 0644)) + } + makeProjectIndex("") + + paths := []string{""} + for i := 1; i < maxNestedSubgroup*2; i++ { + subGroupPath := fmt.Sprintf("%ssub%d/", paths[i-1], i) + paths = append(paths, subGroupPath) + + makeProjectIndex(subGroupPath) + } + + teardown := RunPagesProcess(t, *pagesBinary, listeners, "", "-pages-root", pagesRoot) + defer teardown() + + for nestingLevel, path := range paths { + t.Run(fmt.Sprintf("nested level %d", nestingLevel), func(t *testing.T) { + for _, spec := range listeners { + rsp, err := GetPageFromListener(t, spec, "nested.gitlab-example.com", path+"project/") + + require.NoError(t, err) + rsp.Body.Close() + if nestingLevel <= maxNestedSubgroup { + require.Equal(t, http.StatusOK, rsp.StatusCode) + } else { + require.Equal(t, http.StatusNotFound, rsp.StatusCode) + } + } + }) + } +} + +func TestCustom404(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + tests := []struct { + host string + path string + content string + }{ + { + host: "group.404.gitlab-example.com", + path: "project.404/not/existing-file", + content: "Custom 404 project page", + }, + { + host: "group.404.gitlab-example.com", + path: "project.404/", + content: "Custom 404 project page", + }, + { + host: "group.404.gitlab-example.com", + path: "not/existing-file", + content: "Custom 404 group page", + }, + { + host: "group.404.gitlab-example.com", +
path: "not-existing-file", + content: "Custom 404 group page", + }, + { + host: "group.404.gitlab-example.com", + content: "Custom 404 group page", + }, + { + host: "domain.404.com", + content: "Custom domain.404 page", + }, + { + host: "group.404.gitlab-example.com", + path: "project.no.404/not/existing-file", + content: "The page you're looking for could not be found.", + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%s/%s", test.host, test.path), func(t *testing.T) { + for _, spec := range listeners { + rsp, err := GetPageFromListener(t, spec, test.host, test.path) + + require.NoError(t, err) + defer rsp.Body.Close() + require.Equal(t, http.StatusNotFound, rsp.StatusCode) + + page, err := ioutil.ReadAll(rsp.Body) + require.NoError(t, err) + require.Contains(t, string(page), test.content) + } + }) + } +} + +func TestCORSWhenDisabled(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "", "-disable-cross-origin-requests") + defer teardown() + + for _, spec := range listeners { + for _, method := range []string{"GET", "OPTIONS"} { + rsp := doCrossOriginRequest(t, spec, method, method, spec.URL("project/")) + + require.Equal(t, http.StatusOK, rsp.StatusCode) + require.Equal(t, "", rsp.Header.Get("Access-Control-Allow-Origin")) + require.Equal(t, "", rsp.Header.Get("Access-Control-Allow-Credentials")) + } + } +} + +func TestCORSAllowsGET(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + for _, spec := range listeners { + for _, method := range []string{"GET", "OPTIONS"} { + rsp := doCrossOriginRequest(t, spec, method, method, spec.URL("project/")) + + require.Equal(t, http.StatusOK, rsp.StatusCode) + require.Equal(t, "*", rsp.Header.Get("Access-Control-Allow-Origin")) + require.Equal(t, "", rsp.Header.Get("Access-Control-Allow-Credentials")) + } + } +} + +func TestCORSForbidsPOST(t *testing.T) { + skipUnlessEnabled(t) + + teardown :=
RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + for _, spec := range listeners { + rsp := doCrossOriginRequest(t, spec, "OPTIONS", "POST", spec.URL("project/")) + + require.Equal(t, http.StatusOK, rsp.StatusCode) + require.Equal(t, "", rsp.Header.Get("Access-Control-Allow-Origin")) + require.Equal(t, "", rsp.Header.Get("Access-Control-Allow-Credentials")) + } +} + +func TestCustomHeaders(t *testing.T) { + skipUnlessEnabled(t) + + teardown := RunPagesProcess(t, *pagesBinary, listeners, "", "-header", "X-Test1:Testing1", "-header", "X-Test2:Testing2") + defer teardown() + for _, spec := range listeners { + rsp, err := GetPageFromListener(t, spec, "group.gitlab-example.com:", "project/") + require.NoError(t, err) + rsp.Body.Close() + require.Equal(t, http.StatusOK, rsp.StatusCode) + require.Equal(t, "Testing1", rsp.Header.Get("X-Test1")) + require.Equal(t, "Testing2", rsp.Header.Get("X-Test2")) + } +} + +func TestKnownHostWithPortReturns200(t *testing.T) { + skipUnlessEnabled(t) + + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + for _, spec := range listeners { + rsp, err := GetPageFromListener(t, spec, "group.gitlab-example.com:"+spec.Port, "project/") + + require.NoError(t, err) + rsp.Body.Close() + require.Equal(t, http.StatusOK, rsp.StatusCode) + } +} + +func TestHttpToHttpsRedirectDisabled(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + rsp, err := GetRedirectPage(t, httpListener, "group.gitlab-example.com", "project/") + require.NoError(t, err) + defer rsp.Body.Close() + require.Equal(t, http.StatusOK, rsp.StatusCode) + + rsp, err = GetPageFromListener(t, httpsListener, "group.gitlab-example.com", "project/") + require.NoError(t, err) + defer rsp.Body.Close() + require.Equal(t, http.StatusOK, rsp.StatusCode) +} + +func TestHttpToHttpsRedirectEnabled(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners,
"", "-redirect-http=true") + defer teardown() + + rsp, err := GetRedirectPage(t, httpListener, "group.gitlab-example.com", "project/") + require.NoError(t, err) + defer rsp.Body.Close() + require.Equal(t, http.StatusTemporaryRedirect, rsp.StatusCode) + require.Equal(t, 1, len(rsp.Header["Location"])) + require.Equal(t, "https://group.gitlab-example.com/project/", rsp.Header.Get("Location")) + + rsp, err = GetPageFromListener(t, httpsListener, "group.gitlab-example.com", "project/") + require.NoError(t, err) + defer rsp.Body.Close() + require.Equal(t, http.StatusOK, rsp.StatusCode) +} + +func TestHttpsOnlyGroupEnabled(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + rsp, err := GetRedirectPage(t, httpListener, "group.https-only.gitlab-example.com", "project1/") + require.NoError(t, err) + defer rsp.Body.Close() + require.Equal(t, http.StatusMovedPermanently, rsp.StatusCode) +} + +func TestHttpsOnlyGroupDisabled(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + rsp, err := GetPageFromListener(t, httpListener, "group.https-only.gitlab-example.com", "project2/") + require.NoError(t, err) + defer rsp.Body.Close() + require.Equal(t, http.StatusOK, rsp.StatusCode) +} + +func TestHttpsOnlyProjectEnabled(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + rsp, err := GetRedirectPage(t, httpListener, "test.my-domain.com", "/index.html") + require.NoError(t, err) + defer rsp.Body.Close() + require.Equal(t, http.StatusMovedPermanently, rsp.StatusCode) +} + +func TestHttpsOnlyProjectDisabled(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + rsp, err := GetPageFromListener(t, httpListener, "test2.my-domain.com", "/") + require.NoError(t, err) + defer rsp.Body.Close() + require.Equal(t,
http.StatusOK, rsp.StatusCode) +} + +func TestHttpsOnlyDomainDisabled(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + rsp, err := GetPageFromListener(t, httpListener, "no.cert.com", "/") + require.NoError(t, err) + defer rsp.Body.Close() + require.Equal(t, http.StatusOK, rsp.StatusCode) +} + +func TestDomainsSource(t *testing.T) { + skipUnlessEnabled(t) + + type args struct { + configSource string + domain string + urlSuffix string + readyCount int + } + type want struct { + statusCode int + content string + apiCalled bool + } + tests := []struct { + name string + args args + want want + }{ + { + name: "gitlab_source_domain_exists", + args: args{ + configSource: "gitlab", + domain: "new-source-test.gitlab.io", + urlSuffix: "/my/pages/project/", + }, + want: want{ + statusCode: http.StatusOK, + content: "New Pages GitLab Source TEST OK\n", + apiCalled: true, + }, + }, + { + name: "gitlab_source_domain_does_not_exist", + args: args{ + configSource: "gitlab", + domain: "non-existent-domain.gitlab.io", + }, + want: want{ + statusCode: http.StatusNotFound, + apiCalled: true, + }, + }, + { + name: "disk_source_domain_exists", + args: args{ + configSource: "disk", + // test.domain.com sourced from disk configuration + domain: "test.domain.com", + urlSuffix: "/", + }, + want: want{ + statusCode: http.StatusOK, + content: "main-dir\n", + apiCalled: false, + }, + }, + { + name: "disk_source_domain_does_not_exist", + args: args{ + configSource: "disk", + domain: "non-existent-domain.gitlab.io", + }, + want: want{ + statusCode: http.StatusNotFound, + apiCalled: false, + }, + }, + { + name: "disk_source_domain_should_not_exist_under_hashed_dir", + args: args{ + configSource: "disk", + domain: "hashed.com", + }, + want: want{ + statusCode: http.StatusNotFound, + apiCalled: false, + }, + }, + { + name: "auto_source_gitlab_is_not_ready", + args: args{ + configSource: "auto", + domain: "test.domain.com", +
urlSuffix: "/", + readyCount: 100, // big number to ensure the API is in bad state for a while + }, + want: want{ + statusCode: http.StatusOK, + content: "main-dir\n", + apiCalled: false, + }, + }, + { + name: "auto_source_gitlab_is_ready", + args: args{ + configSource: "auto", + domain: "new-source-test.gitlab.io", + urlSuffix: "/my/pages/project/", + readyCount: 0, + }, + want: want{ + statusCode: http.StatusOK, + content: "New Pages GitLab Source TEST OK\n", + apiCalled: true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var apiCalled bool + source := NewGitlabDomainsSourceStub(t, &apiCalled, tt.args.readyCount) + defer source.Close() + + gitLabAPISecretKey := CreateGitLabAPISecretKeyFixtureFile(t) + + pagesArgs := []string{"-gitlab-server", source.URL, "-api-secret-key", gitLabAPISecretKey, "-domain-config-source", tt.args.configSource} + teardown := RunPagesProcessWithEnvs(t, true, *pagesBinary, listeners, "", []string{}, pagesArgs...) + defer teardown() + + response, err := GetPageFromListener(t, httpListener, tt.args.domain, tt.args.urlSuffix) + require.NoError(t, err) + + require.Equal(t, tt.want.statusCode, response.StatusCode) + if tt.want.statusCode == http.StatusOK { + defer response.Body.Close() + body, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + + require.Equal(t, tt.want.content, string(body), "content mismatch") + } + + require.Equal(t, tt.want.apiCalled, apiCalled, "api called mismatch") + }) + } +} + +func TestKnownHostInReverseProxySetupReturns200(t *testing.T) { + skipUnlessEnabled(t) + + var listeners = []ListenSpec{ + {"proxy", "127.0.0.1", "37002"}, + // TODO: re-enable https://gitlab.com/gitlab-org/gitlab-pages/-/issues/528 + // {"proxy", "::1", "37002"}, + } + + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + for _, spec := range listeners { + rsp, err := GetProxiedPageFromListener(t, spec, "localhost", "group.gitlab-example.com", "project/") + + 
require.NoError(t, err) + rsp.Body.Close() + require.Equal(t, http.StatusOK, rsp.StatusCode) + } +} + +func doCrossOriginRequest(t *testing.T, spec ListenSpec, method, reqMethod, url string) *http.Response { + req, err := http.NewRequest(method, url, nil) + require.NoError(t, err) + + req.Host = "group.gitlab-example.com" + req.Header.Add("Origin", "example.com") + req.Header.Add("Access-Control-Request-Method", reqMethod) + + var rsp *http.Response + err = fmt.Errorf("no request was made") + for start := time.Now(); time.Since(start) < 1*time.Second; { + rsp, err = DoPagesRequest(t, spec, req) + if err == nil { + break + } + time.Sleep(100 * time.Millisecond) + } + require.NoError(t, err) + + rsp.Body.Close() + return rsp +} + +func TestQueryStringPersistedInSlashRewrite(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + rsp, err := GetRedirectPage(t, httpsListener, "group.gitlab-example.com", "project?q=test") + require.NoError(t, err) + defer rsp.Body.Close() + + require.Equal(t, http.StatusFound, rsp.StatusCode) + require.Equal(t, 1, len(rsp.Header["Location"])) + require.Equal(t, "//group.gitlab-example.com/project/?q=test", rsp.Header.Get("Location")) + + rsp, err = GetPageFromListener(t, httpsListener, "group.gitlab-example.com", "project/?q=test") + require.NoError(t, err) + defer rsp.Body.Close() + require.Equal(t, http.StatusOK, rsp.StatusCode) +} diff --git a/test/acceptance/status_test.go b/test/acceptance/status_test.go new file mode 100644 index 000000000..8e227ed80 --- /dev/null +++ b/test/acceptance/status_test.go @@ -0,0 +1,44 @@ +package acceptance_test + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestStatusPage(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "", "-pages-status=/@statuscheck") + defer teardown() + + rsp, err := GetPageFromListener(t, httpListener, 
"group.gitlab-example.com", "@statuscheck") + require.NoError(t, err) + defer rsp.Body.Close() + require.Equal(t, http.StatusOK, rsp.StatusCode) +} + +func TestStatusNotYetReady(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcessWithoutWait(t, *pagesBinary, listeners, "", "-pages-status=/@statuscheck", "-pages-root=../../shared/invalid-pages") + defer teardown() + + waitForRoundtrips(t, listeners, 5*time.Second) + rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "@statuscheck") + require.NoError(t, err) + defer rsp.Body.Close() + require.Equal(t, http.StatusServiceUnavailable, rsp.StatusCode) +} + +func TestPageNotAvailableIfNotLoaded(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcessWithoutWait(t, *pagesBinary, listeners, "", "-pages-root=../../shared/invalid-pages") + defer teardown() + waitForRoundtrips(t, listeners, 5*time.Second) + + rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "index.html") + require.NoError(t, err) + defer rsp.Body.Close() + require.Equal(t, http.StatusServiceUnavailable, rsp.StatusCode) +} diff --git a/test/acceptance/stub_test.go b/test/acceptance/stub_test.go new file mode 100644 index 000000000..8f52ec37a --- /dev/null +++ b/test/acceptance/stub_test.go @@ -0,0 +1,72 @@ +package acceptance_test + +import ( + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-pages/internal/fixture" +) + +// makeGitLabPagesAccessStub provides a stub *httptest.Server to check pages_access API call. +// the result is based on the project id. 
+// +// Project IDs must be 4 digit long and the following rules applies: +// 1000-1999: Ok +// 2000-2999: Unauthorized +// 3000-3999: Invalid token +func makeGitLabPagesAccessStub(t *testing.T) *httptest.Server { + t.Helper() + + return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth/token": + require.Equal(t, "POST", r.Method) + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "{\"access_token\":\"abc\"}") + case "/api/v4/user": + require.Equal(t, "Bearer abc", r.Header.Get("Authorization")) + w.WriteHeader(http.StatusOK) + default: + if handleAccessControlArtifactRequests(t, w, r) { + return + } + handleAccessControlRequests(t, w, r) + } + })) +} + +func CreateHTTPSFixtureFiles(t *testing.T) (key string, cert string) { + t.Helper() + + keyfile, err := ioutil.TempFile("", "https-fixture") + require.NoError(t, err) + key = keyfile.Name() + keyfile.Close() + + certfile, err := ioutil.TempFile("", "https-fixture") + require.NoError(t, err) + cert = certfile.Name() + certfile.Close() + + require.NoError(t, ioutil.WriteFile(key, []byte(fixture.Key), 0644)) + require.NoError(t, ioutil.WriteFile(cert, []byte(fixture.Certificate), 0644)) + + return keyfile.Name(), certfile.Name() +} + +func CreateGitLabAPISecretKeyFixtureFile(t *testing.T) (filepath string) { + t.Helper() + + secretfile, err := ioutil.TempFile("", "gitlab-api-secret") + require.NoError(t, err) + secretfile.Close() + + require.NoError(t, ioutil.WriteFile(secretfile.Name(), []byte(fixture.GitLabAPISecretKey), 0644)) + + return secretfile.Name() +} diff --git a/test/acceptance/tls_test.go b/test/acceptance/tls_test.go new file mode 100644 index 000000000..3445c6c38 --- /dev/null +++ b/test/acceptance/tls_test.go @@ -0,0 +1,130 @@ +package acceptance_test + +import ( + "crypto/tls" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAcceptsSupportedCiphers(t *testing.T) { + skipUnlessEnabled(t) + teardown := 
RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + tlsConfig := &tls.Config{ + CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + }, + } + client, cleanup := ClientWithConfig(tlsConfig) + defer cleanup() + + rsp, err := client.Get(httpsListener.URL("/")) + + if rsp != nil { + rsp.Body.Close() + } + + require.NoError(t, err) +} + +func tlsConfigWithInsecureCiphersOnly() *tls.Config { + return &tls.Config{ + CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, + }, + MaxVersion: tls.VersionTLS12, // ciphers for TLS1.3 are not configurable and will work if enabled + } +} + +func TestRejectsUnsupportedCiphers(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + client, cleanup := ClientWithConfig(tlsConfigWithInsecureCiphersOnly()) + defer cleanup() + + rsp, err := client.Get(httpsListener.URL("/")) + + if rsp != nil { + rsp.Body.Close() + } + + require.Error(t, err) + require.Nil(t, rsp) +} + +func TestEnableInsecureCiphers(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "", "-insecure-ciphers") + defer teardown() + + client, cleanup := ClientWithConfig(tlsConfigWithInsecureCiphersOnly()) + defer cleanup() + + rsp, err := client.Get(httpsListener.URL("/")) + + if rsp != nil { + rsp.Body.Close() + } + + require.NoError(t, err) +} + +func TestTLSVersions(t *testing.T) { + skipUnlessEnabled(t) + + tests := map[string]struct { + tlsMin string + tlsMax string + tlsClient uint16 + expectError bool + }{ + "client version not supported": {tlsMin: "tls1.1", tlsMax: "tls1.2", tlsClient: tls.VersionTLS10, expectError: true}, + "client version 
supported": {tlsMin: "tls1.1", tlsMax: "tls1.2", tlsClient: tls.VersionTLS12, expectError: false}, + "client and server using default settings": {tlsMin: "", tlsMax: "", tlsClient: 0, expectError: false}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + args := []string{} + if tc.tlsMin != "" { + args = append(args, "-tls-min-version", tc.tlsMin) + } + if tc.tlsMax != "" { + args = append(args, "-tls-max-version", tc.tlsMax) + } + + teardown := RunPagesProcess(t, *pagesBinary, listeners, "", args...) + defer teardown() + + tlsConfig := &tls.Config{} + if tc.tlsClient != 0 { + tlsConfig.MinVersion = tc.tlsClient + tlsConfig.MaxVersion = tc.tlsClient + } + client, cleanup := ClientWithConfig(tlsConfig) + defer cleanup() + + rsp, err := client.Get(httpsListener.URL("/")) + + if rsp != nil { + rsp.Body.Close() + } + + if tc.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/test/acceptance/unknown_http_method_test.go b/test/acceptance/unknown_http_method_test.go new file mode 100644 index 000000000..f6c5ffee5 --- /dev/null +++ b/test/acceptance/unknown_http_method_test.go @@ -0,0 +1,23 @@ +package acceptance_test + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUnknownHTTPMethod(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + req, err := http.NewRequest("UNKNOWN", listeners[0].URL(""), nil) + require.NoError(t, err) + req.Host = "" + + resp, err := DoPagesRequest(t, httpListener, req) + require.NoError(t, err) + + require.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) +} diff --git a/test/acceptance/zip_test.go b/test/acceptance/zip_test.go new file mode 100644 index 000000000..5d3037c81 --- /dev/null +++ b/test/acceptance/zip_test.go @@ -0,0 +1,161 @@ +package acceptance_test + +import ( + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "testing" + + 
"github.com/stretchr/testify/require" +) + +func TestZipServing(t *testing.T) { + skipUnlessEnabled(t) + + var apiCalled bool + source := NewGitlabDomainsSourceStub(t, &apiCalled, 0) + defer source.Close() + + gitLabAPISecretKey := CreateGitLabAPISecretKeyFixtureFile(t) + + pagesArgs := []string{"-gitlab-server", source.URL, "-api-secret-key", gitLabAPISecretKey, "-domain-config-source", "gitlab"} + teardown := RunPagesProcessWithEnvs(t, true, *pagesBinary, listeners, "", []string{}, pagesArgs...) + defer teardown() + + _, cleanup := newZipFileServerURL(t, "../../shared/pages/group/zip.gitlab.io/public.zip") + defer cleanup() + + tests := map[string]struct { + host string + urlSuffix string + expectedStatusCode int + expectedContent string + }{ + "base_domain_no_suffix": { + host: "zip.gitlab.io", + urlSuffix: "/", + expectedStatusCode: http.StatusOK, + expectedContent: "zip.gitlab.io/project/index.html\n", + }, + "file_exists": { + host: "zip.gitlab.io", + urlSuffix: "/index.html", + expectedStatusCode: http.StatusOK, + expectedContent: "zip.gitlab.io/project/index.html\n", + }, + "file_exists_in_subdir": { + host: "zip.gitlab.io", + urlSuffix: "/subdir/hello.html", + expectedStatusCode: http.StatusOK, + expectedContent: "zip.gitlab.io/project/subdir/hello.html\n", + }, + "file_exists_symlink": { + host: "zip.gitlab.io", + urlSuffix: "/symlink.html", + expectedStatusCode: http.StatusOK, + expectedContent: "symlink.html->subdir/linked.html\n", + }, + "dir": { + host: "zip.gitlab.io", + urlSuffix: "/subdir/", + expectedStatusCode: http.StatusNotFound, + expectedContent: "zip.gitlab.io/project/404.html\n", + }, + "file_does_not_exist": { + host: "zip.gitlab.io", + urlSuffix: "/unknown.html", + expectedStatusCode: http.StatusNotFound, + expectedContent: "zip.gitlab.io/project/404.html\n", + }, + "bad_symlink": { + host: "zip.gitlab.io", + urlSuffix: "/bad-symlink.html", + expectedStatusCode: http.StatusNotFound, + expectedContent: "zip.gitlab.io/project/404.html\n", + 
}, + "with_not_found_zip": { + host: "zip-not-found.gitlab.io", + urlSuffix: "/", + expectedStatusCode: http.StatusNotFound, + expectedContent: "The page you're looking for could not be found", + }, + "with_malformed_zip": { + host: "zip-malformed.gitlab.io", + urlSuffix: "/", + expectedStatusCode: http.StatusInternalServerError, + expectedContent: "Something went wrong (500)", + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + response, err := GetPageFromListener(t, httpListener, tt.host, tt.urlSuffix) + require.NoError(t, err) + defer response.Body.Close() + + require.Equal(t, tt.expectedStatusCode, response.StatusCode) + + body, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + + require.Contains(t, string(body), tt.expectedContent, "content mismatch") + }) + } +} + +func TestZipServingConfigShortTimeout(t *testing.T) { + skipUnlessEnabled(t) + + var apiCalled bool + source := NewGitlabDomainsSourceStub(t, &apiCalled, 0) + defer source.Close() + + gitLabAPISecretKey := CreateGitLabAPISecretKeyFixtureFile(t) + + pagesArgs := []string{"-gitlab-server", source.URL, "-api-secret-key", gitLabAPISecretKey, "-domain-config-source", "gitlab", + "-zip-open-timeout=1ns"} // <- test purpose + + teardown := RunPagesProcessWithEnvs(t, true, *pagesBinary, listeners, "", []string{}, pagesArgs...) 
+ defer teardown() + + _, cleanup := newZipFileServerURL(t, "../../shared/pages/group/zip.gitlab.io/public.zip") + defer cleanup() + + response, err := GetPageFromListener(t, httpListener, "zip.gitlab.io", "/") + require.NoError(t, err) + defer response.Body.Close() + + require.Equal(t, http.StatusInternalServerError, response.StatusCode, "should fail to serve") +} + +func newZipFileServerURL(t *testing.T, zipFilePath string) (string, func()) { + t.Helper() + + m := http.NewServeMux() + m.HandleFunc("/public.zip", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.ServeFile(w, r, zipFilePath) + })) + m.HandleFunc("/malformed.zip", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + + // create a listener with the desired port. + l, err := net.Listen("tcp", objectStorageMockServer) + require.NoError(t, err) + + testServer := httptest.NewUnstartedServer(m) + + // NewUnstartedServer creates a listener. Close that listener and replace + // with the one we created. + testServer.Listener.Close() + testServer.Listener = l + + // Start the server. + testServer.Start() + + return testServer.URL, func() { + // Cleanup. 
+ testServer.Close() + } +} diff --git a/tools.go b/tools.go index 902fac800..38b719476 100644 --- a/tools.go +++ b/tools.go @@ -4,6 +4,7 @@ package main import ( _ "github.com/fzipp/gocyclo" + _ "github.com/jstemmer/go-junit-report" _ "github.com/wadey/gocovmerge" _ "golang.org/x/lint/golint" _ "golang.org/x/tools/cmd/goimports" -- GitLab From 5e48f9a2cacb9b74aca8810509790791680241c3 Mon Sep 17 00:00:00 2001 From: Dishon Date: Mon, 18 Jan 2021 15:16:53 +0000 Subject: [PATCH 14/17] fix failing ci pipeline --- internal/source/gitlab/cache/retriever.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/source/gitlab/cache/retriever.go b/internal/source/gitlab/cache/retriever.go index 43ef2e523..4341c7901 100644 --- a/internal/source/gitlab/cache/retriever.go +++ b/internal/source/gitlab/cache/retriever.go @@ -1,11 +1,12 @@ package cache import ( - "time" "context" "errors" "sync" - + + "time" + log "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/gitlab-pages/internal/source/gitlab/api" @@ -63,7 +64,7 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha go func() { var lookup api.Lookup - Retry: + Retry: for i := 1; i <= r.maxRetrievalRetries; i++ { lookup = r.client.GetLookup(ctx, domain) @@ -114,4 +115,3 @@ func (t *timer) hasStopped() bool { return t.stopped } - \ No newline at end of file -- GitLab From 5f07c1f16e644430b432d076184bfb58dccda76f Mon Sep 17 00:00:00 2001 From: Dishon Date: Mon, 18 Jan 2021 15:19:44 +0000 Subject: [PATCH 15/17] Replace time.Sleep with a cancelable timer inside the cache retriever --- CHANGELOG | 27 +++ VERSION | 2 +- app.go | 25 ++- go.mod | 2 +- go.sum | 4 +- internal/auth/auth.go | 117 +++++++--- internal/auth/auth_code.go | 147 +++++++++++++ internal/auth/auth_code_test.go | 99 +++++++++ internal/auth/auth_test.go | 207 ++++++++++-------- internal/logging/logging.go | 1 + internal/rejectmethods/middleware.go | 31 +++ internal/rejectmethods/middleware_test.go | 43 
++++ internal/source/gitlab/cache/retriever.go | 49 ++++- .../source/gitlab/cache/retriever_test.go | 27 +++ internal/vfs/zip/archive.go | 24 +- internal/vfs/zip/archive_test.go | 53 +++++ internal/vfs/zip/deflate_reader.go | 43 +++- metrics/metrics.go | 7 + test/acceptance/acceptance_test.go | 29 ++- test/acceptance/artifacts_test.go | 2 +- test/acceptance/auth_test.go | 109 ++++++++- test/acceptance/proxyv2_test.go | 7 +- test/acceptance/serving_test.go | 3 +- test/acceptance/unknown_http_method_test.go | 23 ++ 24 files changed, 898 insertions(+), 183 deletions(-) create mode 100644 internal/auth/auth_code.go create mode 100644 internal/auth/auth_code_test.go create mode 100644 internal/rejectmethods/middleware.go create mode 100644 internal/rejectmethods/middleware_test.go create mode 100644 internal/source/gitlab/cache/retriever_test.go create mode 100644 test/acceptance/unknown_http_method_test.go diff --git a/CHANGELOG b/CHANGELOG index 9970bff88..e315ddcc0 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,12 @@ +v 1.34.0 + +- Allow DELETE HTTP method + +v 1.33.0 + +- Reject requests with unknown HTTP methods +- Encrypt OAuth code during auth flow + v 1.32.0 - Try to automatically use gitlab API as a source for domain information !402 @@ -10,6 +19,15 @@ v 1.31.0 - Add zip serving configuration flags !392 - Disable deprecated serverless serving and proxy !400 +v 1.30.2 + +- Allow DELETE HTTP method + +v 1.30.1 + +- Reject requests with unknown HTTP methods +- Encrypt OAuth code during auth flow + v 1.30.0 - Allow to refresh an existing cached archive when accessed !375 @@ -23,6 +41,15 @@ v 1.29.0 - Improve httprange timeouts !382 - Fix caching for errored ZIP VFS archives !384 +v 1.28.2 + +- Allow DELETE HTTP method + +v 1.28.1 + +- Reject requests with unknown HTTP methods +- Encrypt OAuth code during auth flow + v 1.28.0 - Implement basic redirects via _redirects text file !367 diff --git a/VERSION b/VERSION index 359c41089..2b17ffd50 100644 --- a/VERSION 
+++ b/VERSION @@ -1 +1 @@ -1.32.0 +1.34.0 diff --git a/app.go b/app.go index ed06893e4..1352b630b 100644 --- a/app.go +++ b/app.go @@ -28,6 +28,7 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/logging" "gitlab.com/gitlab-org/gitlab-pages/internal/middleware" "gitlab.com/gitlab-org/gitlab-pages/internal/netutil" + "gitlab.com/gitlab-org/gitlab-pages/internal/rejectmethods" "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk/zip" "gitlab.com/gitlab-org/gitlab-pages/internal/source" @@ -337,6 +338,12 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { // Custom response headers handler = a.customHeadersMiddleware(handler) + // This MUST be the last handler! + // This handler blocks unknown HTTP methods, + // being the last means it will be evaluated first + // preventing any operation on bogus requests. + handler = rejectmethods.NewMiddleware(handler) + return handler, nil } @@ -483,10 +490,7 @@ func runApp(config appConfig) { a.Artifact = artifact.New(config.ArtifactsServer, config.ArtifactsServerTimeout, config.Domain) } - if config.ClientID != "" { - a.Auth = auth.New(config.Domain, config.StoreSecret, config.ClientID, config.ClientSecret, - config.RedirectURI, config.GitLabServer) - } + a.setAuth(config) a.Handlers = handlers.New(a.Auth, a.Artifact) @@ -524,6 +528,19 @@ func runApp(config appConfig) { a.Run() } +func (a *theApp) setAuth(config appConfig) { + if config.ClientID == "" { + return + } + + var err error + a.Auth, err = auth.New(config.Domain, config.StoreSecret, config.ClientID, config.ClientSecret, + config.RedirectURI, config.GitLabServer) + if err != nil { + log.WithError(err).Fatal("could not initialize auth package") + } +} + // fatal will log a fatal error and exit. 
func fatal(err error, message string) { log.WithError(err).Fatal(message) diff --git a/go.mod b/go.mod index 76d45a9c9..f06ea125b 100644 --- a/go.mod +++ b/go.mod @@ -31,7 +31,7 @@ require ( github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce github.com/wadey/gocovmerge v0.0.0-20160331181800-b5bfa59ec0ad gitlab.com/gitlab-org/go-mimedb v1.45.0 - gitlab.com/gitlab-org/labkit v1.0.0 + gitlab.com/gitlab-org/labkit v1.3.0 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f golang.org/x/net v0.0.0-20200226121028-0de0cce0169b diff --git a/go.sum b/go.sum index 945b05ce7..672bda11e 100644 --- a/go.sum +++ b/go.sum @@ -346,8 +346,8 @@ github.com/yudai/pp v2.0.1+incompatible/go.mod h1:PuxR/8QJ7cyCkFp/aUDS+JY727OFEZ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= gitlab.com/gitlab-org/go-mimedb v1.45.0 h1:PO8dx6HEWzPYU6MQTYnCbpQEJzhJLW/Bh43+2VUHTgc= gitlab.com/gitlab-org/go-mimedb v1.45.0/go.mod h1:wa9y/zOSFKmTXLyBs4clz2FNVhZQmmEQM9TxslPAjZ0= -gitlab.com/gitlab-org/labkit v1.0.0 h1:t2Wr8ygtvHfXAMlCkoEdk5pdb5Gy1IYdr41H7t4kAYw= -gitlab.com/gitlab-org/labkit v1.0.0/go.mod h1:nohrYTSLDnZix0ebXZrbZJjymRar8HeV2roWL5/jw2U= +gitlab.com/gitlab-org/labkit v1.3.0 h1:PDP4id5YEvw6juWrGE88LcTtEridtRAOyvNvUOtcc9o= +gitlab.com/gitlab-org/labkit v1.3.0/go.mod h1:nohrYTSLDnZix0ebXZrbZJjymRar8HeV2roWL5/jw2U= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2 h1:75k/FF0Q2YM8QYo07VPddOLBslDt1MZOdEslOHvmzAs= diff --git a/internal/auth/auth.go b/internal/auth/auth.go index eaf3c25dd..252954a62 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -16,14 +16,14 @@ import ( "github.com/gorilla/securecookie" "github.com/gorilla/sessions" log "github.com/sirupsen/logrus" + "golang.org/x/crypto/hkdf" + "gitlab.com/gitlab-org/labkit/errortracking" 
"gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/httptransport" "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/source" - - "golang.org/x/crypto/hkdf" ) // nolint: gosec @@ -47,17 +47,23 @@ var ( errFailAuth = errors.New("Failed to authenticate request") errAuthNotConfigured = errors.New("Authentication is not configured") errQueryParameter = errors.New("Failed to parse domain query parameter") + + errGenerateKeys = errors.New("could not generate auth keys") ) // Auth handles authenticating users with GitLab API type Auth struct { - pagesDomain string - clientID string - clientSecret string - redirectURI string - gitLabServer string - apiClient *http.Client - store sessions.Store + pagesDomain string + clientID string + clientSecret string + redirectURI string + gitLabServer string + authSecret string + jwtSigningKey []byte + jwtExpiry time.Duration + apiClient *http.Client + store sessions.Store + now func() time.Time // allows to stub time.Now() easily in tests } type tokenResponse struct { @@ -111,7 +117,7 @@ func (a *Auth) checkSession(w http.ResponseWriter, r *http.Request) (*sessions.S return session, nil } -// TryAuthenticate tries to authenticate user and fetch access token if request is a callback to auth +// TryAuthenticate tries to authenticate user and fetch access token if request is a callback to /auth? 
func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, domains source.Source) bool { if a == nil { return false @@ -166,11 +172,18 @@ func (a *Auth) checkAuthenticationResponse(session *sessions.Session, w http.Res return } - // Fetch access token with authorization code - token, err := a.fetchAccessToken(r.URL.Query().Get("code")) + decryptedCode, err := a.DecryptCode(r.URL.Query().Get("code"), getRequestDomain(r)) + if err != nil { + logRequest(r).WithError(err).Error("failed to decrypt secure code") + errortracking.Capture(err, errortracking.WithRequest(r)) + httperrors.Serve500(w) + return + } - // Fetching token not OK + // Fetch access token with authorization code + token, err := a.fetchAccessToken(decryptedCode) if err != nil { + // Fetching token not OK logRequest(r).WithError(err).WithField( "redirect_uri", redirectURI, ).Error(errFetchAccessToken) @@ -216,8 +229,8 @@ func (a *Auth) domainAllowed(name string, domains source.Source) bool { } func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWriter, r *http.Request, domains source.Source) bool { - // If request is for authenticating via custom domain - if shouldProxyAuth(r) { + // handle auth callback e.g. 
https://gitlab.io/auth?domain&domain&state=state + if shouldProxyAuthToGitlab(r) { domain := r.URL.Query().Get("domain") state := r.URL.Query().Get("state") @@ -266,6 +279,7 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit } // If auth request callback should be proxied to custom domain + // redirect to originating domain set in the cookie as proxy_auth_domain if shouldProxyCallbackToCustomDomain(r, session) { // Get domain started auth process proxyDomain := session.Values["proxy_auth_domain"].(string) @@ -283,9 +297,30 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit return true } - // Redirect pages under custom domain - http.Redirect(w, r, proxyDomain+r.URL.Path+"?"+r.URL.RawQuery, 302) + query := r.URL.Query() + + // prevent https://tools.ietf.org/html/rfc6749#section-10.6 and + // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/262 by encrypting + // and signing the OAuth code + signedCode, err := a.EncryptAndSignCode(proxyDomain, query.Get("code")) + if err != nil { + logRequest(r).WithError(err).Error(errSaveSession) + errortracking.Capture(err, errortracking.WithRequest(r)) + + httperrors.Serve503(w) + return true + } + + // prevent forwarding access token, more context on the security issue + // https://gitlab.com/gitlab-org/gitlab/-/issues/285244#note_451266051 + query.Del("token") + + // replace code with signed code + query.Set("code", signedCode) + // Redirect pages to originating domain with code and state to finish + // authentication process + http.Redirect(w, r, proxyDomain+r.URL.Path+"?"+query.Encode(), 302) return true } @@ -306,7 +341,7 @@ func getRequestDomain(r *http.Request) string { return "http://" + r.Host } -func shouldProxyAuth(r *http.Request) bool { +func shouldProxyAuthToGitlab(r *http.Request) bool { return r.URL.Query().Get("domain") != "" && r.URL.Query().Get("state") != "" } @@ -376,6 +411,7 @@ func (a *Auth) checkSessionIsValid(w http.ResponseWriter, r 
*http.Request) *sess return nil } + // redirect to /auth?domain=%s&state=%s if a.checkTokenExists(session, w, r) { return nil } @@ -586,28 +622,37 @@ func logRequest(r *http.Request) *log.Entry { }) } -// generateKeyPair returns key pair for secure cookie: signing and encryption key -func generateKeyPair(storeSecret string) ([]byte, []byte) { - hash := sha256.New - hkdf := hkdf.New(hash, []byte(storeSecret), []byte{}, []byte("PAGES_SIGNING_AND_ENCRYPTION_KEY")) - var keys [][]byte - for i := 0; i < 2; i++ { +// generateKeys derives count hkdf keys from a secret, ensuring the key is +// the same for the same secret used across multiple instances +func generateKeys(secret string, count int) ([][]byte, error) { + keys := make([][]byte, count) + hkdfReader := hkdf.New(sha256.New, []byte(secret), []byte{}, []byte("PAGES_SIGNING_AND_ENCRYPTION_KEY")) + + for i := 0; i < count; i++ { key := make([]byte, 32) - if _, err := io.ReadFull(hkdf, key); err != nil { - log.WithError(err).Fatal("Can't generate key pair for secure cookies") + if _, err := io.ReadFull(hkdfReader, key); err != nil { + return nil, err } - keys = append(keys, key) + + keys[i] = key + } + + if len(keys) < count { + return nil, errGenerateKeys } - return keys[0], keys[1] -} -func createCookieStore(storeSecret string) sessions.Store { - return sessions.NewCookieStore(generateKeyPair(storeSecret)) + return keys, nil } // New when authentication supported this will be used to create authentication handler func New(pagesDomain string, storeSecret string, clientID string, clientSecret string, - redirectURI string, gitLabServer string) *Auth { + redirectURI string, gitLabServer string) (*Auth, error) { + // generate 3 keys, 2 for the cookie store and 1 for JWT signing + keys, err := generateKeys(storeSecret, 3) + if err != nil { + return nil, err + } + return &Auth{ pagesDomain: pagesDomain, clientID: clientID, @@ -618,6 +663,10 @@ func New(pagesDomain string, storeSecret string, clientID string, clientSecret s 
Timeout: 5 * time.Second, Transport: httptransport.InternalTransport, }, - store: createCookieStore(storeSecret), - } + store: sessions.NewCookieStore(keys[0], keys[1]), + authSecret: storeSecret, + jwtSigningKey: keys[2], + jwtExpiry: time.Minute, + now: time.Now, + }, nil } diff --git a/internal/auth/auth_code.go b/internal/auth/auth_code.go new file mode 100644 index 000000000..d2fea5a95 --- /dev/null +++ b/internal/auth/auth_code.go @@ -0,0 +1,147 @@ +package auth + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "io" + + "github.com/dgrijalva/jwt-go" + "github.com/gorilla/securecookie" + "golang.org/x/crypto/hkdf" +) + +var ( + errInvalidToken = errors.New("invalid token") + errEmptyDomainOrCode = errors.New("empty domain or code") + errInvalidNonce = errors.New("invalid nonce") + errInvalidCode = errors.New("invalid code") +) + +// EncryptAndSignCode encrypts the OAuth code deriving the key from the domain. +// It adds the code and domain as JWT token claims and signs it using signingKey derived from +// the Auth secret. 
+func (a *Auth) EncryptAndSignCode(domain, code string) (string, error) { + if domain == "" || code == "" { + return "", errEmptyDomainOrCode + } + + nonce := base64.URLEncoding.EncodeToString(securecookie.GenerateRandomKey(16)) + + aesGcm, err := a.newAesGcmCipher(domain, nonce) + if err != nil { + return "", err + } + + // encrypt code with a randomly generated nonce + encryptedCode := aesGcm.Seal(nil, []byte(nonce), []byte(code), nil) + + // generate JWT token claims with encrypted code + claims := jwt.MapClaims{ + // standard claims + "iss": "gitlab-pages", + "iat": a.now().Unix(), + "exp": a.now().Add(a.jwtExpiry).Unix(), + // custom claims + "domain": domain, // pass the domain so we can validate the signed domain matches the requested domain + "code": hex.EncodeToString(encryptedCode), + "nonce": nonce, + } + + return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(a.jwtSigningKey) +} + +// DecryptCode decodes the secureCode as a JWT token and validates its signature. +// It then decrypts the code from the token claims and returns it. 
+func (a *Auth) DecryptCode(jwt, domain string) (string, error) { + claims, err := a.parseJWTClaims(jwt) + if err != nil { + return "", err + } + + // get nonce and encryptedCode from the JWT claims + nonce, ok := claims["nonce"].(string) + if !ok { + return "", errInvalidNonce + } + + encryptedCode, ok := claims["code"].(string) + if !ok { + return "", errInvalidCode + } + + cipherText, err := hex.DecodeString(encryptedCode) + if err != nil { + return "", err + } + + aesGcm, err := a.newAesGcmCipher(domain, nonce) + if err != nil { + return "", err + } + + decryptedCode, err := aesGcm.Open(nil, []byte(nonce), cipherText, nil) + if err != nil { + return "", err + } + + return string(decryptedCode), nil +} + +func (a *Auth) codeKey(domain string) ([]byte, error) { + hkdfReader := hkdf.New(sha256.New, []byte(a.authSecret), []byte(domain), []byte("PAGES_AUTH_CODE_ENCRYPTION_KEY")) + + key := make([]byte, 32) + if _, err := io.ReadFull(hkdfReader, key); err != nil { + return nil, err + } + + return key, nil +} + +func (a *Auth) parseJWTClaims(secureCode string) (jwt.MapClaims, error) { + token, err := jwt.Parse(secureCode, a.getSigningKey) + if err != nil { + return nil, err + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + return nil, errInvalidToken + } + + return claims, nil +} + +func (a *Auth) getSigningKey(token *jwt.Token) (interface{}, error) { + // Don't forget to validate the alg is what you expect: + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + return a.jwtSigningKey, nil +} + +func (a *Auth) newAesGcmCipher(domain, nonce string) (cipher.AEAD, error) { + // get the same key for a domain + key, err := a.codeKey(domain) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + aesGcm, err := cipher.NewGCMWithNonceSize(block, len(nonce)) + if err != nil { + return nil, err 
+ } + + return aesGcm, nil +} diff --git a/internal/auth/auth_code_test.go b/internal/auth/auth_code_test.go new file mode 100644 index 000000000..d54fcc7ea --- /dev/null +++ b/internal/auth/auth_code_test.go @@ -0,0 +1,99 @@ +package auth + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestEncryptAndDecryptSignedCode(t *testing.T) { + auth := createTestAuth(t, "") + + tests := map[string]struct { + auth *Auth + encDomain string + code string + expectedEncErrMsg string + decDomain string + expectedDecErrMsg string + }{ + "happy_path": { + auth: auth, + encDomain: "domain", + decDomain: "domain", + code: "code", + }, + "empty_domain": { + auth: auth, + encDomain: "", + code: "code", + expectedEncErrMsg: "empty domain or code", + }, + "empty_code": { + auth: auth, + encDomain: "domain", + code: "", + expectedEncErrMsg: "empty domain or code", + }, + "different_dec_domain": { + auth: auth, + encDomain: "domain", + decDomain: "another", + code: "code", + expectedDecErrMsg: "cipher: message authentication failed", + }, + "expired_token": { + auth: func() *Auth { + newAuth := *auth + newAuth.jwtExpiry = time.Nanosecond + newAuth.now = func() time.Time { + return time.Time{} + } + + return &newAuth + }(), + encDomain: "domain", + code: "code", + decDomain: "domain", + expectedDecErrMsg: "Token is expired", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + encCode, err := test.auth.EncryptAndSignCode(test.encDomain, test.code) + if test.expectedEncErrMsg != "" { + require.EqualError(t, err, test.expectedEncErrMsg) + require.Empty(t, encCode) + return + } + + require.NoError(t, err) + require.NotEmpty(t, encCode) + + decCode, err := test.auth.DecryptCode(encCode, test.decDomain) + if test.expectedDecErrMsg != "" { + require.EqualError(t, err, test.expectedDecErrMsg) + require.Empty(t, decCode) + return + } + + require.NoError(t, err) + require.Equal(t, test.code, decCode) + }) + } +} + +func 
TestDecryptCodeWithInvalidJWT(t *testing.T) { + auth1 := createTestAuth(t, "") + auth2 := createTestAuth(t, "") + auth2.jwtSigningKey = []byte("another signing key") + + encCode, err := auth1.EncryptAndSignCode("domain", "code") + require.NoError(t, err) + + decCode, err := auth2.DecryptCode(encCode, "domain") + require.EqualError(t, err, "signature is invalid") + require.Empty(t, decCode) +} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 39a533b35..ce7d83207 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "github.com/gorilla/sessions" @@ -16,17 +17,19 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/source" ) -func createAuth(t *testing.T) *Auth { - return New("pages.gitlab-example.com", +func createTestAuth(t *testing.T, url string) *Auth { + t.Helper() + + a, err := New("pages.gitlab-example.com", "something-very-secret", "id", "secret", "http://pages.gitlab-example.com/auth", - "http://gitlab-example.com") -} + url) + + require.NoError(t, err) -func defaultCookieStore() sessions.Store { - return createCookieStore("something-very-secret") + return a } type domainMock struct { @@ -48,10 +51,13 @@ func (dm *domainMock) ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Req // Which leads to negative side effects: we can't test encryption, and cookie params // like max-age and secure are not being properly set // To avoid that we use fake request, and set only session cookie without copying context -func setSessionValues(r *http.Request, values map[interface{}]interface{}) { - tmpRequest, _ := http.NewRequest("GET", "/", nil) +func setSessionValues(t *testing.T, r *http.Request, store sessions.Store, values map[interface{}]interface{}) { + t.Helper() + + tmpRequest, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + result := httptest.NewRecorder() - store := defaultCookieStore() session, _ := 
store.Get(tmpRequest, "gitlab-pages") session.Values = values @@ -63,7 +69,7 @@ func setSessionValues(r *http.Request, values map[interface{}]interface{}) { } func TestTryAuthenticate(t *testing.T) { - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/something/else") @@ -75,11 +81,12 @@ func TestTryAuthenticate(t *testing.T) { } func TestTryAuthenticateWithError(t *testing.T) { - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?error=access_denied") require.NoError(t, err) + reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} @@ -88,8 +95,7 @@ func TestTryAuthenticateWithError(t *testing.T) { } func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=invalid") @@ -97,7 +103,9 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["state"] = "state" session.Save(r, result) @@ -105,7 +113,36 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { require.Equal(t, 401, result.Code) } +func TestTryAuthenticateRemoveTokenFromRedirect(t *testing.T) { + auth := createTestAuth(t, "") + + result := httptest.NewRecorder() + reqURL, err := url.Parse("/auth?code=1&state=state&token=secret") + require.NoError(t, err) + + require.Equal(t, reqURL.Query().Get("token"), "secret", "token is present before redirecting") + reqURL.Scheme = request.SchemeHTTPS + r := &http.Request{URL: reqURL} + + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + + session.Values["state"] = 
"state" + session.Values["proxy_auth_domain"] = "https://domain.com" + session.Save(r, result) + + require.Equal(t, true, auth.TryAuthenticate(result, r, source.NewMockSource())) + require.Equal(t, http.StatusFound, result.Code) + + redirect, err := url.Parse(result.Header().Get("Location")) + require.NoError(t, err) + + require.Empty(t, redirect.Query().Get("token"), "token is gone after redirecting") +} + func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { + t.Helper() + apiServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/oauth/token": @@ -125,14 +162,17 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { apiServer.Start() defer apiServer.Close() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) + + domain := apiServer.URL + if https { + domain = strings.Replace(apiServer.URL, "http://", "https://", -1) + } - r, err := http.NewRequest("GET", "/auth?code=1&state=state", nil) + code, err := auth.EncryptAndSignCode(domain, "1") + require.NoError(t, err) + + r, err := http.NewRequest("GET", "/auth?code="+code+"&state=state", nil) require.NoError(t, err) if https { r.URL.Scheme = request.SchemeHTTPS @@ -140,14 +180,16 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { r.URL.Scheme = request.SchemeHTTP } - setSessionValues(r, map[interface{}]interface{}{ + r.Host = strings.TrimPrefix(apiServer.URL, "http://") + + setSessionValues(t, r, auth.store, map[interface{}]interface{}{ "uri": "https://pages.gitlab-example.com/project/", "state": "state", }) result := httptest.NewRecorder() require.Equal(t, true, auth.TryAuthenticate(result, r, source.NewMockSource())) - require.Equal(t, 302, result.Code) + require.Equal(t, http.StatusFound, result.Code) require.Equal(t, 
"https://pages.gitlab-example.com/project/", result.Header().Get("Location")) require.Equal(t, 600, result.Result().Cookies()[0].MaxAge) require.Equal(t, https, result.Result().Cookies()[0].Secure) @@ -177,13 +219,7 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") @@ -191,7 +227,9 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) @@ -217,13 +255,7 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) w := httptest.NewRecorder() @@ -232,7 +264,9 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, w) @@ -265,22 +299,19 @@ func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - 
"http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" - session.Save(r, result) + err = session.Save(r, result) + require.NoError(t, err) contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) require.True(t, contentServed) @@ -303,13 +334,7 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") @@ -317,7 +342,9 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) @@ -343,19 +370,16 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) 
+ session.Values["access_token"] = "abc" session.Save(r, result) @@ -364,28 +388,31 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { require.Equal(t, 302, result.Code) } -func TestGenerateKeyPair(t *testing.T) { - signingSecret, encryptionSecret := generateKeyPair("something-very-secret") - require.NotEqual(t, fmt.Sprint(signingSecret), fmt.Sprint(encryptionSecret)) - require.Equal(t, len(signingSecret), 32) - require.Equal(t, len(encryptionSecret), 32) +func TestGenerateKeys(t *testing.T) { + keys, err := generateKeys("something-very-secret", 3) + require.NoError(t, err) + require.Len(t, keys, 3) + + require.NotEqual(t, fmt.Sprint(keys[0]), fmt.Sprint(keys[1])) + require.NotEqual(t, fmt.Sprint(keys[0]), fmt.Sprint(keys[2])) + require.NotEqual(t, fmt.Sprint(keys[1]), fmt.Sprint(keys[2])) + + require.Equal(t, len(keys[0]), 32) + require.Equal(t, len(keys[1]), 32) + require.Equal(t, len(keys[2]), 32) } func TestGetTokenIfExistsWhenTokenExists(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) @@ -395,20 +422,16 @@ func TestGetTokenIfExistsWhenTokenExists(t *testing.T) { } func TestGetTokenIfExistsWhenTokenDoesNotExist(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := 
url.Parse("http://pages.gitlab-example.com/test") require.NoError(t, err) r := &http.Request{URL: reqURL, Host: "pages.gitlab-example.com", RequestURI: "/test"} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Save(r, result) token, err := auth.GetTokenIfExists(result, r) @@ -417,12 +440,7 @@ func TestGetTokenIfExistsWhenTokenDoesNotExist(t *testing.T) { } func TestCheckResponseForInvalidTokenWhenInvalidToken(t *testing.T) { - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("http://pages.gitlab-example.com/test") @@ -437,12 +455,7 @@ func TestCheckResponseForInvalidTokenWhenInvalidToken(t *testing.T) { } func TestCheckResponseForInvalidTokenWhenNotInvalidToken(t *testing.T) { - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/something") diff --git a/internal/logging/logging.go b/internal/logging/logging.go index 6643e169e..9a3629ffe 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -88,6 +88,7 @@ func BasicAccessLogger(handler http.Handler, format string, extraFields log.Extr return log.AccessLogger(handler, log.WithExtraFields(extraFields), log.WithAccessLogger(accessLogger), + log.WithXFFAllowed(func(sip string) bool { return false }), ), nil } diff --git a/internal/rejectmethods/middleware.go b/internal/rejectmethods/middleware.go new file mode 100644 index 000000000..e78a0ce59 --- /dev/null +++ b/internal/rejectmethods/middleware.go @@ -0,0 +1,31 @@ +package rejectmethods + +import ( + "net/http" + + "gitlab.com/gitlab-org/gitlab-pages/metrics" +) + +var acceptedMethods = map[string]bool{ + 
http.MethodGet: true, + http.MethodHead: true, + http.MethodPost: true, + http.MethodPut: true, + http.MethodPatch: true, + http.MethodDelete: true, + http.MethodConnect: true, + http.MethodOptions: true, + http.MethodTrace: true, +} + +// NewMiddleware returns middleware which rejects all unknown http methods +func NewMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if acceptedMethods[r.Method] { + handler.ServeHTTP(w, r) + } else { + metrics.RejectedRequestsCount.Inc() + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + } + }) +} diff --git a/internal/rejectmethods/middleware_test.go b/internal/rejectmethods/middleware_test.go new file mode 100644 index 000000000..2921975ae --- /dev/null +++ b/internal/rejectmethods/middleware_test.go @@ -0,0 +1,43 @@ +package rejectmethods + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewMiddleware(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "OK\n") + }) + + middleware := NewMiddleware(handler) + + acceptedMethods := []string{"GET", "HEAD", "POST", "PUT", "PATCH", "CONNECT", "OPTIONS", "TRACE"} + for _, method := range acceptedMethods { + t.Run(method, func(t *testing.T) { + tmpRequest, _ := http.NewRequest(method, "/", nil) + recorder := httptest.NewRecorder() + + middleware.ServeHTTP(recorder, tmpRequest) + + result := recorder.Result() + + require.Equal(t, http.StatusOK, result.StatusCode) + }) + } + + t.Run("UNKNOWN", func(t *testing.T) { + tmpRequest, _ := http.NewRequest("UNKNOWN", "/", nil) + recorder := httptest.NewRecorder() + + middleware.ServeHTTP(recorder, tmpRequest) + + result := recorder.Result() + + require.Equal(t, http.StatusMethodNotAllowed, result.StatusCode) + }) +} diff --git a/internal/source/gitlab/cache/retriever.go 
b/internal/source/gitlab/cache/retriever.go index de37c231a..4341c7901 100644 --- a/internal/source/gitlab/cache/retriever.go +++ b/internal/source/gitlab/cache/retriever.go @@ -3,6 +3,8 @@ package cache import ( "context" "errors" + "sync" + "time" log "github.com/sirupsen/logrus" @@ -13,15 +15,25 @@ import ( // Retriever is an utility type that performs an HTTP request with backoff in // case of errors type Retriever struct { + timer timer client api.Client retrievalTimeout time.Duration maxRetrievalInterval time.Duration maxRetrievalRetries int } +type timer struct { + mu *sync.Mutex + stopped bool + timer *time.Timer +} + // NewRetriever creates a Retriever with a client func NewRetriever(client api.Client, retrievalTimeout, maxRetrievalInterval time.Duration, maxRetrievalRetries int) *Retriever { return &Retriever{ + timer: timer{ + mu: &sync.Mutex{}, + }, client: client, retrievalTimeout: retrievalTimeout, maxRetrievalInterval: maxRetrievalInterval, @@ -52,11 +64,24 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha go func() { var lookup api.Lookup + Retry: for i := 1; i <= r.maxRetrievalRetries; i++ { lookup = r.client.GetLookup(ctx, domain) if lookup.Error != nil { - time.Sleep(r.maxRetrievalInterval) + r.timer.start(r.maxRetrievalInterval) + select { + case <-r.timer.timer.C: + // retry to GetLookup + continue Retry + case <-ctx.Done(): + // when the retrieval context is done we stop the timer + // timer.Stop() + // log.WithError(ctx.Err()).Debug("domain retrieval backoff canceled by context") + // when the retrieval context is done we stop the timerFunc + r.timer.stop() + break Retry + } } else { break } @@ -68,3 +93,25 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha return response } + +func (t *timer) start(d time.Duration) { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopped = false + t.timer = time.NewTimer(d) +} + +func (t *timer) stop() { + t.mu.Lock() + defer t.mu.Unlock() + + 
t.stopped = t.timer.Stop() +} + +func (t *timer) hasStopped() bool { + t.mu.Lock() + defer t.mu.Unlock() + + return t.stopped +} diff --git a/internal/source/gitlab/cache/retriever_test.go b/internal/source/gitlab/cache/retriever_test.go new file mode 100644 index 000000000..774e9779e --- /dev/null +++ b/internal/source/gitlab/cache/retriever_test.go @@ -0,0 +1,27 @@ +package cache + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRetrieveTimerStopsWhenContextIsDone(t *testing.T) { + retrievalTimeout := time.Millisecond // quick timeout that cancels inner context + maxRetrievalInterval := time.Minute // long sleep inside resolveWithBackoff + + resolver := &client{ + domain: make(chan string), + lookups: make(chan uint64, 1), + failure: errors.New("500 error"), + } + + retriever := NewRetriever(resolver, retrievalTimeout, maxRetrievalInterval, 3) + require.False(t, retriever.timer.hasStopped(), "timer has not been stopped yet") + + lookup := retriever.Retrieve("my.gitlab.com") + require.Empty(t, lookup.Name) + require.Eventually(t, retriever.timer.hasStopped, time.Second, time.Millisecond, "timer must have been stopped") +} diff --git a/internal/vfs/zip/archive.go b/internal/vfs/zip/archive.go index 1137f0041..9826cdd63 100644 --- a/internal/vfs/zip/archive.go +++ b/internal/vfs/zip/archive.go @@ -7,7 +7,7 @@ import ( "fmt" "io" "os" - "path/filepath" + "path" "strconv" "strings" "sync" @@ -161,34 +161,30 @@ func (a *zipArchive) readArchive(url string) { } // addPathDirectory adds a directory for a given path -func (a *zipArchive) addPathDirectory(path string) { +func (a *zipArchive) addPathDirectory(pathname string) { // Split dir and file from `path` - path, _ = filepath.Split(path) - if path == "" { + pathname, _ = path.Split(pathname) + if pathname == "" { return } - if a.directories[path] != nil { + if a.directories[pathname] != nil { return } - a.directories[path] = &zip.FileHeader{ - Name: path, + 
a.directories[pathname] = &zip.FileHeader{ + Name: pathname, } } func (a *zipArchive) findFile(name string) *zip.File { - name = filepath.Join(dirPrefix, name) + name = path.Clean(dirPrefix + name) - if file := a.files[name]; file != nil { - return file - } - - return nil + return a.files[name] } func (a *zipArchive) findDirectory(name string) *zip.FileHeader { - name = filepath.Join(dirPrefix, name) + name = path.Clean(dirPrefix + name) return a.directories[name+"/"] } diff --git a/internal/vfs/zip/archive_test.go b/internal/vfs/zip/archive_test.go index da778e620..58b7c74ab 100644 --- a/internal/vfs/zip/archive_test.go +++ b/internal/vfs/zip/archive_test.go @@ -1,11 +1,16 @@ package zip import ( + "archive/zip" + "bytes" "context" + "crypto/rand" + "io" "io/ioutil" "net/http" "net/http/httptest" "os" + "strconv" "sync/atomic" "testing" "time" @@ -419,3 +424,51 @@ func newZipFileServerURL(t *testing.T, zipFilePath string, requests *int64) (str testServer.Close() } } + +func benchmarkArchiveRead(b *testing.B, size int64) { + zbuf := new(bytes.Buffer) + + // create zip file of specified size + zw := zip.NewWriter(zbuf) + w, err := zw.Create("public/file.txt") + require.NoError(b, err) + _, err = io.CopyN(w, rand.Reader, size) + require.NoError(b, err) + require.NoError(b, zw.Close()) + + modtime := time.Now().Add(-time.Hour) + + m := http.NewServeMux() + m.HandleFunc("/public.zip", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.ServeContent(w, r, "public.zip", modtime, bytes.NewReader(zbuf.Bytes())) + })) + + ts := httptest.NewServer(m) + defer ts.Close() + + fs := New(zipCfg).(*zipVFS) + + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + z := newArchive(fs, time.Second) + err := z.openArchive(context.Background(), ts.URL+"/public.zip") + require.NoError(b, err) + + f, err := z.Open(context.Background(), "file.txt") + require.NoError(b, err) + + _, err = io.Copy(ioutil.Discard, f) + require.NoError(b, err) + + 
require.NoError(b, f.Close()) + } +} + +func BenchmarkArchiveRead(b *testing.B) { + for _, size := range []int{32 * 1024, 64 * 1024, 1024 * 1024} { + b.Run(strconv.Itoa(size), func(b *testing.B) { + benchmarkArchiveRead(b, int64(size)) + }) + } +} diff --git a/internal/vfs/zip/deflate_reader.go b/internal/vfs/zip/deflate_reader.go index 16a2d72e0..87a7da0c6 100644 --- a/internal/vfs/zip/deflate_reader.go +++ b/internal/vfs/zip/deflate_reader.go @@ -1,31 +1,66 @@ package zip import ( + "bufio" "compress/flate" + "errors" "io" + "sync" ) +var ErrClosedReader = errors.New("deflatereader: reader is closed") + +var deflateReaderPool sync.Pool + // deflateReader wrapper to support reading compressed files. // Implements the io.ReadCloser interface. type deflateReader struct { - reader io.ReadCloser + reader *bufio.Reader + closer io.Closer flateReader io.ReadCloser } // Read from flateReader func (r *deflateReader) Read(p []byte) (n int, err error) { + if r.closer == nil { + return 0, ErrClosedReader + } + return r.flateReader.Read(p) } // Close all readers func (r *deflateReader) Close() error { - r.reader.Close() + if r.closer == nil { + return ErrClosedReader + } + + defer func() { + r.closer.Close() + r.closer = nil + deflateReaderPool.Put(r) + }() + return r.flateReader.Close() } +func (r *deflateReader) reset(rc io.ReadCloser) { + r.reader.Reset(rc) + r.closer = rc + r.flateReader.(flate.Resetter).Reset(r.reader, nil) +} + func newDeflateReader(r io.ReadCloser) *deflateReader { + if dr, ok := deflateReaderPool.Get().(*deflateReader); ok { + dr.reset(r) + return dr + } + + br := bufio.NewReader(r) + return &deflateReader{ - reader: r, - flateReader: flate.NewReader(r), + reader: br, + closer: r, + flateReader: flate.NewReader(br), } } diff --git a/metrics/metrics.go b/metrics/metrics.go index db7cae9a8..045ff26e0 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -199,6 +199,13 @@ var ( Help: "The number of files per zip archive total count over time", }, ) 
+ + RejectedRequestsCount = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "gitlab_pages_unknown_method_rejected_requests", + Help: "The number of requests with unknown HTTP method which were rejected", + }, + ) ) // MustRegister collectors with the Prometheus client diff --git a/test/acceptance/acceptance_test.go b/test/acceptance/acceptance_test.go index 9921076ea..ba6528c10 100644 --- a/test/acceptance/acceptance_test.go +++ b/test/acceptance/acceptance_test.go @@ -17,24 +17,31 @@ const ( var ( pagesBinary = flag.String("gitlab-pages-binary", "../../gitlab-pages", "Path to the gitlab-pages binary") + httpPort = "36000" + httpsPort = "37000" + httpProxyPort = "38000" + httpProxyV2Port = "39000" + // TODO: Use TCP port 0 everywhere to avoid conflicts. The binary could output // the actual port (and type of listener) for us to read in place of the // hardcoded values below. listeners = []ListenSpec{ - {"http", "127.0.0.1", "37000"}, - {"http", "::1", "37000"}, - {"https", "127.0.0.1", "37001"}, - {"https", "::1", "37001"}, - {"proxy", "127.0.0.1", "37002"}, - {"proxy", "::1", "37002"}, - {"https-proxyv2", "127.0.0.1", "37003"}, - {"https-proxyv2", "::1", "37003"}, + {"http", "127.0.0.1", httpPort}, + {"https", "127.0.0.1", httpsPort}, + {"proxy", "127.0.0.1", httpProxyPort}, + {"https-proxyv2", "127.0.0.1", httpProxyV2Port}, + // TODO: re-enable IPv6 listeners once https://gitlab.com/gitlab-com/gl-infra/infrastructure/-/issues/12258 is resolved + // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/528 + // {"http", "::1", httpPort}, + // {"https", "::1", httpsPort}, + // {"proxy", "::1", httpProxyPort}, + // {"https-proxyv2", "::1", httpProxyV2Port}, } httpListener = listeners[0] - httpsListener = listeners[2] - proxyListener = listeners[4] - httpsProxyv2Listener = listeners[6] + httpsListener = listeners[1] + proxyListener = listeners[2] + httpsProxyv2Listener = listeners[3] ) func TestMain(m *testing.M) { diff --git 
a/test/acceptance/artifacts_test.go b/test/acceptance/artifacts_test.go index 3440ef34f..57c7a02a9 100644 --- a/test/acceptance/artifacts_test.go +++ b/test/acceptance/artifacts_test.go @@ -245,7 +245,7 @@ func TestPrivateArtifactProxyRequest(t *testing.T) { ) defer teardown() - resp, err := GetRedirectPage(t, httpListener, tt.host, tt.path) + resp, err := GetRedirectPage(t, httpsListener, tt.host, tt.path) require.NoError(t, err) defer resp.Body.Close() diff --git a/test/acceptance/auth_test.go b/test/acceptance/auth_test.go index b2233591f..fa2d768d8 100644 --- a/test/acceptance/auth_test.go +++ b/test/acceptance/auth_test.go @@ -88,7 +88,7 @@ func TestWhenLoginCallbackWithWrongStateShouldFail(t *testing.T) { require.Equal(t, http.StatusUnauthorized, authrsp.StatusCode) } -func TestWhenLoginCallbackWithCorrectStateWithoutEndpoint(t *testing.T) { +func TestWhenLoginCallbackWithUnencryptedCode(t *testing.T) { skipUnlessEnabled(t) teardown := RunPagesProcessWithAuth(t, *pagesBinary, listeners, "") defer teardown() @@ -110,8 +110,8 @@ func TestWhenLoginCallbackWithCorrectStateWithoutEndpoint(t *testing.T) { require.NoError(t, err) defer authrsp.Body.Close() - // Will cause 503 because token endpoint is not available - require.Equal(t, http.StatusServiceUnavailable, authrsp.StatusCode) + // Will cause 500 because the code is not encrypted + require.Equal(t, http.StatusInternalServerError, authrsp.StatusCode) } func handleAccessControlArtifactRequests(t *testing.T, w http.ResponseWriter, r *http.Request) bool { @@ -234,11 +234,10 @@ func TestAccessControlUnderCustomDomain(t *testing.T) { // Will redirect to custom domain require.Equal(t, tt.domain, url.Host) - require.Equal(t, "1", url.Query().Get("code")) - require.Equal(t, state, url.Query().Get("state")) + code := url.Query().Get("code") + require.NotEqual(t, "1", code) - // Run auth callback in custom domain - authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code=1&state="+ + authrsp, err 
= GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code="+code+"&state="+ state, cookie) require.NoError(t, err) @@ -336,11 +335,13 @@ func TestCustomErrorPageWithAuth(t *testing.T) { // Will redirect to custom domain require.Equal(t, tt.domain, url.Host) - require.Equal(t, "1", url.Query().Get("code")) + // code must have changed since it's encrypted + code := url.Query().Get("code") + require.NotEqual(t, "1", code) require.Equal(t, state, url.Query().Get("state")) // Run auth callback in custom domain - authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code=1&state="+ + authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code="+code+"&state="+ state, cookie) require.NoError(t, err) @@ -409,12 +410,14 @@ func TestAccessControlUnderCustomDomainWithHTTPSProxy(t *testing.T) { // Will redirect to custom domain require.Equal(t, "private.domain.com", url.Host) - require.Equal(t, "1", url.Query().Get("code")) + // code must have changed since it's encrypted + code := url.Query().Get("code") + require.NotEqual(t, "1", code) require.Equal(t, state, url.Query().Get("state")) // Run auth callback in custom domain authrsp, err = GetProxyRedirectPageWithCookie(t, proxyListener, "private.domain.com", - "/auth?code=1&state="+state, cookie, true) + "/auth?code="+code+"&state="+state, cookie, true) require.NoError(t, err) defer authrsp.Body.Close() @@ -641,3 +644,87 @@ func TestAccessControlWithSSLCertFile(t *testing.T) { func TestAccessControlWithSSLCertDir(t *testing.T) { testAccessControl(t, RunPagesProcessWithAuthServerWithSSLCertDir) } + +// This proves the fix for https://gitlab.com/gitlab-org/gitlab-pages/-/issues/262 +// Read the issue description if any changes to internal/auth/ break this test. +// Related to https://tools.ietf.org/html/rfc6749#section-10.6. 
+func TestHijackedCode(t *testing.T) { + skipUnlessEnabled(t, "not-inplace-chroot") + + testServer := makeGitLabPagesAccessStub(t) + testServer.Start() + defer testServer.Close() + + teardown := RunPagesProcessWithAuthServer(t, *pagesBinary, listeners, "", testServer.URL) + defer teardown() + + /****ATTACKER******/ + // get valid cookie for a different private project + targetDomain := "private.domain.com" + attackersDomain := "group.auth.gitlab-example.com" + attackerCookie, attackerState := getValidCookieAndState(t, targetDomain) + + /****TARGET******/ + // fool target to click on modified URL with attacker's domain for redirect with a valid state + hackedURL := fmt.Sprintf("/auth?domain=http://%s&state=%s", attackersDomain, "irrelevant") + maliciousResp, err := GetProxyRedirectPageWithCookie(t, proxyListener, "projects.gitlab-example.com", hackedURL, "", true) + require.NoError(t, err) + defer maliciousResp.Body.Close() + + pagesCookie := maliciousResp.Header.Get("Set-Cookie") + + /* + OAuth flow happens here... 
+ */ + maliciousRespURL, err := url.Parse(maliciousResp.Header.Get("Location")) + require.NoError(t, err) + maliciousState := maliciousRespURL.Query().Get("state") + + // Go to auth page with correct state and code "obtained" from GitLab + authrsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, + "projects.gitlab-example.com", "/auth?code=1&state="+maliciousState, + pagesCookie, true) + + require.NoError(t, err) + defer authrsp.Body.Close() + + /****ATTACKER******/ + // Target is redirected to attacker's domain and attacker receives the proper code + require.Equal(t, http.StatusFound, authrsp.StatusCode, "should redirect to attacker's domain") + authrspURL, err := url.Parse(authrsp.Header.Get("Location")) + require.NoError(t, err) + require.Contains(t, authrspURL.String(), attackersDomain) + + // attacker's got the code + hijackedCode := authrspURL.Query().Get("code") + require.NotEmpty(t, hijackedCode) + + // attacker tries to access private pages content + impersonatingRes, err := GetProxyRedirectPageWithCookie(t, proxyListener, targetDomain, + "/auth?code="+hijackedCode+"&state="+attackerState, attackerCookie, true) + require.NoError(t, err) + defer authrsp.Body.Close() + + require.Equal(t, impersonatingRes.StatusCode, http.StatusInternalServerError, "should fail to decode code") +} + +func getValidCookieAndState(t *testing.T, domain string) (string, string) { + t.Helper() + + // follow flow to get a valid cookie + // visit https:/// + rsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, domain, "/", "", true) + require.NoError(t, err) + defer rsp.Body.Close() + + cookie := rsp.Header.Get("Set-Cookie") + require.NotEmpty(t, cookie) + + redirectURL, err := url.Parse(rsp.Header.Get("Location")) + require.NoError(t, err) + + state := redirectURL.Query().Get("state") + require.NotEmpty(t, state) + + return cookie, state +} diff --git a/test/acceptance/proxyv2_test.go b/test/acceptance/proxyv2_test.go index c407ea194..2a42f0f1c 100644 --- 
a/test/acceptance/proxyv2_test.go +++ b/test/acceptance/proxyv2_test.go @@ -4,6 +4,7 @@ import ( "io/ioutil" "net/http" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -46,7 +47,11 @@ func TestProxyv2(t *testing.T) { require.Contains(t, string(body), tt.expectedContent, "content mismatch") - require.Contains(t, logBuf.String(), tt.expectedLog, "log mismatch") + // give the process enough time to write the log message + require.Eventually(t, func() bool { + require.Contains(t, logBuf.String(), tt.expectedLog, "log mismatch") + return true + }, time.Second, time.Millisecond) }) } } diff --git a/test/acceptance/serving_test.go b/test/acceptance/serving_test.go index 66b5fa477..becd6b8cd 100644 --- a/test/acceptance/serving_test.go +++ b/test/acceptance/serving_test.go @@ -515,7 +515,8 @@ func TestKnownHostInReverseProxySetupReturns200(t *testing.T) { var listeners = []ListenSpec{ {"proxy", "127.0.0.1", "37002"}, - {"proxy", "::1", "37002"}, + // TODO: re-enable https://gitlab.com/gitlab-org/gitlab-pages/-/issues/528 + // {"proxy", "::1", "37002"}, } teardown := RunPagesProcess(t, *pagesBinary, listeners, "") diff --git a/test/acceptance/unknown_http_method_test.go b/test/acceptance/unknown_http_method_test.go new file mode 100644 index 000000000..f6c5ffee5 --- /dev/null +++ b/test/acceptance/unknown_http_method_test.go @@ -0,0 +1,23 @@ +package acceptance_test + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUnknownHTTPMethod(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + req, err := http.NewRequest("UNKNOWN", listeners[0].URL(""), nil) + require.NoError(t, err) + req.Host = "" + + resp, err := DoPagesRequest(t, httpListener, req) + require.NoError(t, err) + + require.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) +} -- GitLab From 5725ae6fe4c2c96b897b051d6dbf2d957f9db0a3 Mon Sep 17 00:00:00 2001 From: Dishon Date: Mon, 18 Jan 
2021 16:02:08 +0000 Subject: [PATCH 16/17] Replace time.Sleep with a cancelable timer inside the cache retriever --- CHANGELOG | 27 +++ VERSION | 2 +- app.go | 25 ++- go.mod | 2 +- go.sum | 4 +- internal/auth/auth.go | 117 +++++++--- internal/auth/auth_code.go | 147 +++++++++++++ internal/auth/auth_code_test.go | 99 +++++++++ internal/auth/auth_test.go | 207 ++++++++++-------- internal/logging/logging.go | 1 + internal/rejectmethods/middleware.go | 31 +++ internal/rejectmethods/middleware_test.go | 43 ++++ internal/source/gitlab/cache/retriever.go | 49 ++++- .../source/gitlab/cache/retriever_test.go | 27 +++ internal/vfs/zip/archive.go | 24 +- internal/vfs/zip/archive_test.go | 53 +++++ internal/vfs/zip/deflate_reader.go | 43 +++- metrics/metrics.go | 7 + test/acceptance/acceptance_test.go | 29 ++- test/acceptance/artifacts_test.go | 2 +- test/acceptance/auth_test.go | 109 ++++++++- test/acceptance/proxyv2_test.go | 7 +- test/acceptance/serving_test.go | 3 +- test/acceptance/unknown_http_method_test.go | 23 ++ 24 files changed, 898 insertions(+), 183 deletions(-) create mode 100644 internal/auth/auth_code.go create mode 100644 internal/auth/auth_code_test.go create mode 100644 internal/rejectmethods/middleware.go create mode 100644 internal/rejectmethods/middleware_test.go create mode 100644 internal/source/gitlab/cache/retriever_test.go create mode 100644 test/acceptance/unknown_http_method_test.go diff --git a/CHANGELOG b/CHANGELOG index 9970bff88..e315ddcc0 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,12 @@ +v 1.34.0 + +- Allow DELETE HTTP method + +v 1.33.0 + +- Reject requests with unknown HTTP methods +- Encrypt OAuth code during auth flow + v 1.32.0 - Try to automatically use gitlab API as a source for domain information !402 @@ -10,6 +19,15 @@ v 1.31.0 - Add zip serving configuration flags !392 - Disable deprecated serverless serving and proxy !400 +v 1.30.2 + +- Allow DELETE HTTP method + +v 1.30.1 + +- Reject requests with unknown HTTP methods +- 
Encrypt OAuth code during auth flow + v 1.30.0 - Allow to refresh an existing cached archive when accessed !375 @@ -23,6 +41,15 @@ v 1.29.0 - Improve httprange timeouts !382 - Fix caching for errored ZIP VFS archives !384 +v 1.28.2 + +- Allow DELETE HTTP method + +v 1.28.1 + +- Reject requests with unknown HTTP methods +- Encrypt OAuth code during auth flow + v 1.28.0 - Implement basic redirects via _redirects text file !367 diff --git a/VERSION b/VERSION index 359c41089..2b17ffd50 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.32.0 +1.34.0 diff --git a/app.go b/app.go index ed06893e4..1352b630b 100644 --- a/app.go +++ b/app.go @@ -28,6 +28,7 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/logging" "gitlab.com/gitlab-org/gitlab-pages/internal/middleware" "gitlab.com/gitlab-org/gitlab-pages/internal/netutil" + "gitlab.com/gitlab-org/gitlab-pages/internal/rejectmethods" "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk/zip" "gitlab.com/gitlab-org/gitlab-pages/internal/source" @@ -337,6 +338,12 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { // Custom response headers handler = a.customHeadersMiddleware(handler) + // This MUST be the last handler! + // This handler blocks unknown HTTP methods, + // being the last means it will be evaluated first + // preventing any operation on bogus requests. 
+ handler = rejectmethods.NewMiddleware(handler) + return handler, nil } @@ -483,10 +490,7 @@ func runApp(config appConfig) { a.Artifact = artifact.New(config.ArtifactsServer, config.ArtifactsServerTimeout, config.Domain) } - if config.ClientID != "" { - a.Auth = auth.New(config.Domain, config.StoreSecret, config.ClientID, config.ClientSecret, - config.RedirectURI, config.GitLabServer) - } + a.setAuth(config) a.Handlers = handlers.New(a.Auth, a.Artifact) @@ -524,6 +528,19 @@ func runApp(config appConfig) { a.Run() } +func (a *theApp) setAuth(config appConfig) { + if config.ClientID == "" { + return + } + + var err error + a.Auth, err = auth.New(config.Domain, config.StoreSecret, config.ClientID, config.ClientSecret, + config.RedirectURI, config.GitLabServer) + if err != nil { + log.WithError(err).Fatal("could not initialize auth package") + } +} + // fatal will log a fatal error and exit. func fatal(err error, message string) { log.WithError(err).Fatal(message) diff --git a/go.mod b/go.mod index 76d45a9c9..f06ea125b 100644 --- a/go.mod +++ b/go.mod @@ -31,7 +31,7 @@ require ( github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce github.com/wadey/gocovmerge v0.0.0-20160331181800-b5bfa59ec0ad gitlab.com/gitlab-org/go-mimedb v1.45.0 - gitlab.com/gitlab-org/labkit v1.0.0 + gitlab.com/gitlab-org/labkit v1.3.0 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f golang.org/x/net v0.0.0-20200226121028-0de0cce0169b diff --git a/go.sum b/go.sum index 945b05ce7..672bda11e 100644 --- a/go.sum +++ b/go.sum @@ -346,8 +346,8 @@ github.com/yudai/pp v2.0.1+incompatible/go.mod h1:PuxR/8QJ7cyCkFp/aUDS+JY727OFEZ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= gitlab.com/gitlab-org/go-mimedb v1.45.0 h1:PO8dx6HEWzPYU6MQTYnCbpQEJzhJLW/Bh43+2VUHTgc= gitlab.com/gitlab-org/go-mimedb v1.45.0/go.mod h1:wa9y/zOSFKmTXLyBs4clz2FNVhZQmmEQM9TxslPAjZ0= -gitlab.com/gitlab-org/labkit v1.0.0 
h1:t2Wr8ygtvHfXAMlCkoEdk5pdb5Gy1IYdr41H7t4kAYw= -gitlab.com/gitlab-org/labkit v1.0.0/go.mod h1:nohrYTSLDnZix0ebXZrbZJjymRar8HeV2roWL5/jw2U= +gitlab.com/gitlab-org/labkit v1.3.0 h1:PDP4id5YEvw6juWrGE88LcTtEridtRAOyvNvUOtcc9o= +gitlab.com/gitlab-org/labkit v1.3.0/go.mod h1:nohrYTSLDnZix0ebXZrbZJjymRar8HeV2roWL5/jw2U= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2 h1:75k/FF0Q2YM8QYo07VPddOLBslDt1MZOdEslOHvmzAs= diff --git a/internal/auth/auth.go b/internal/auth/auth.go index eaf3c25dd..252954a62 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -16,14 +16,14 @@ import ( "github.com/gorilla/securecookie" "github.com/gorilla/sessions" log "github.com/sirupsen/logrus" + "golang.org/x/crypto/hkdf" + "gitlab.com/gitlab-org/labkit/errortracking" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/httptransport" "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/source" - - "golang.org/x/crypto/hkdf" ) // nolint: gosec @@ -47,17 +47,23 @@ var ( errFailAuth = errors.New("Failed to authenticate request") errAuthNotConfigured = errors.New("Authentication is not configured") errQueryParameter = errors.New("Failed to parse domain query parameter") + + errGenerateKeys = errors.New("could not generate auth keys") ) // Auth handles authenticating users with GitLab API type Auth struct { - pagesDomain string - clientID string - clientSecret string - redirectURI string - gitLabServer string - apiClient *http.Client - store sessions.Store + pagesDomain string + clientID string + clientSecret string + redirectURI string + gitLabServer string + authSecret string + jwtSigningKey []byte + jwtExpiry time.Duration + apiClient *http.Client + store sessions.Store + now func() time.Time // allows to stub time.Now() easily in tests } type 
tokenResponse struct { @@ -111,7 +117,7 @@ func (a *Auth) checkSession(w http.ResponseWriter, r *http.Request) (*sessions.S return session, nil } -// TryAuthenticate tries to authenticate user and fetch access token if request is a callback to auth +// TryAuthenticate tries to authenticate user and fetch access token if request is a callback to /auth? func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, domains source.Source) bool { if a == nil { return false @@ -166,11 +172,18 @@ func (a *Auth) checkAuthenticationResponse(session *sessions.Session, w http.Res return } - // Fetch access token with authorization code - token, err := a.fetchAccessToken(r.URL.Query().Get("code")) + decryptedCode, err := a.DecryptCode(r.URL.Query().Get("code"), getRequestDomain(r)) + if err != nil { + logRequest(r).WithError(err).Error("failed to decrypt secure code") + errortracking.Capture(err, errortracking.WithRequest(r)) + httperrors.Serve500(w) + return + } - // Fetching token not OK + // Fetch access token with authorization code + token, err := a.fetchAccessToken(decryptedCode) if err != nil { + // Fetching token not OK logRequest(r).WithError(err).WithField( "redirect_uri", redirectURI, ).Error(errFetchAccessToken) @@ -216,8 +229,8 @@ func (a *Auth) domainAllowed(name string, domains source.Source) bool { } func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWriter, r *http.Request, domains source.Source) bool { - // If request is for authenticating via custom domain - if shouldProxyAuth(r) { + // handle auth callback e.g. 
https://gitlab.io/auth?domain&domain&state=state + if shouldProxyAuthToGitlab(r) { domain := r.URL.Query().Get("domain") state := r.URL.Query().Get("state") @@ -266,6 +279,7 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit } // If auth request callback should be proxied to custom domain + // redirect to originating domain set in the cookie as proxy_auth_domain if shouldProxyCallbackToCustomDomain(r, session) { // Get domain started auth process proxyDomain := session.Values["proxy_auth_domain"].(string) @@ -283,9 +297,30 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit return true } - // Redirect pages under custom domain - http.Redirect(w, r, proxyDomain+r.URL.Path+"?"+r.URL.RawQuery, 302) + query := r.URL.Query() + + // prevent https://tools.ietf.org/html/rfc6749#section-10.6 and + // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/262 by encrypting + // and signing the OAuth code + signedCode, err := a.EncryptAndSignCode(proxyDomain, query.Get("code")) + if err != nil { + logRequest(r).WithError(err).Error(errSaveSession) + errortracking.Capture(err, errortracking.WithRequest(r)) + + httperrors.Serve503(w) + return true + } + + // prevent forwarding access token, more context on the security issue + // https://gitlab.com/gitlab-org/gitlab/-/issues/285244#note_451266051 + query.Del("token") + + // replace code with signed code + query.Set("code", signedCode) + // Redirect pages to originating domain with code and state to finish + // authentication process + http.Redirect(w, r, proxyDomain+r.URL.Path+"?"+query.Encode(), 302) return true } @@ -306,7 +341,7 @@ func getRequestDomain(r *http.Request) string { return "http://" + r.Host } -func shouldProxyAuth(r *http.Request) bool { +func shouldProxyAuthToGitlab(r *http.Request) bool { return r.URL.Query().Get("domain") != "" && r.URL.Query().Get("state") != "" } @@ -376,6 +411,7 @@ func (a *Auth) checkSessionIsValid(w http.ResponseWriter, r 
*http.Request) *sess return nil } + // redirect to /auth?domain=%s&state=%s if a.checkTokenExists(session, w, r) { return nil } @@ -586,28 +622,37 @@ func logRequest(r *http.Request) *log.Entry { }) } -// generateKeyPair returns key pair for secure cookie: signing and encryption key -func generateKeyPair(storeSecret string) ([]byte, []byte) { - hash := sha256.New - hkdf := hkdf.New(hash, []byte(storeSecret), []byte{}, []byte("PAGES_SIGNING_AND_ENCRYPTION_KEY")) - var keys [][]byte - for i := 0; i < 2; i++ { +// generateKeys derives count hkdf keys from a secret, ensuring the key is +// the same for the same secret used across multiple instances +func generateKeys(secret string, count int) ([][]byte, error) { + keys := make([][]byte, count) + hkdfReader := hkdf.New(sha256.New, []byte(secret), []byte{}, []byte("PAGES_SIGNING_AND_ENCRYPTION_KEY")) + + for i := 0; i < count; i++ { key := make([]byte, 32) - if _, err := io.ReadFull(hkdf, key); err != nil { - log.WithError(err).Fatal("Can't generate key pair for secure cookies") + if _, err := io.ReadFull(hkdfReader, key); err != nil { + return nil, err } - keys = append(keys, key) + + keys[i] = key + } + + if len(keys) < count { + return nil, errGenerateKeys } - return keys[0], keys[1] -} -func createCookieStore(storeSecret string) sessions.Store { - return sessions.NewCookieStore(generateKeyPair(storeSecret)) + return keys, nil } // New when authentication supported this will be used to create authentication handler func New(pagesDomain string, storeSecret string, clientID string, clientSecret string, - redirectURI string, gitLabServer string) *Auth { + redirectURI string, gitLabServer string) (*Auth, error) { + // generate 3 keys, 2 for the cookie store and 1 for JWT signing + keys, err := generateKeys(storeSecret, 3) + if err != nil { + return nil, err + } + return &Auth{ pagesDomain: pagesDomain, clientID: clientID, @@ -618,6 +663,10 @@ func New(pagesDomain string, storeSecret string, clientID string, clientSecret s 
Timeout: 5 * time.Second, Transport: httptransport.InternalTransport, }, - store: createCookieStore(storeSecret), - } + store: sessions.NewCookieStore(keys[0], keys[1]), + authSecret: storeSecret, + jwtSigningKey: keys[2], + jwtExpiry: time.Minute, + now: time.Now, + }, nil } diff --git a/internal/auth/auth_code.go b/internal/auth/auth_code.go new file mode 100644 index 000000000..d2fea5a95 --- /dev/null +++ b/internal/auth/auth_code.go @@ -0,0 +1,147 @@ +package auth + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "io" + + "github.com/dgrijalva/jwt-go" + "github.com/gorilla/securecookie" + "golang.org/x/crypto/hkdf" +) + +var ( + errInvalidToken = errors.New("invalid token") + errEmptyDomainOrCode = errors.New("empty domain or code") + errInvalidNonce = errors.New("invalid nonce") + errInvalidCode = errors.New("invalid code") +) + +// EncryptAndSignCode encrypts the OAuth code deriving the key from the domain. +// It adds the code and domain as JWT token claims and signs it using signingKey derived from +// the Auth secret. 
+func (a *Auth) EncryptAndSignCode(domain, code string) (string, error) { + if domain == "" || code == "" { + return "", errEmptyDomainOrCode + } + + nonce := base64.URLEncoding.EncodeToString(securecookie.GenerateRandomKey(16)) + + aesGcm, err := a.newAesGcmCipher(domain, nonce) + if err != nil { + return "", err + } + + // encrypt code with a randomly generated nonce + encryptedCode := aesGcm.Seal(nil, []byte(nonce), []byte(code), nil) + + // generate JWT token claims with encrypted code + claims := jwt.MapClaims{ + // standard claims + "iss": "gitlab-pages", + "iat": a.now().Unix(), + "exp": a.now().Add(a.jwtExpiry).Unix(), + // custom claims + "domain": domain, // pass the domain so we can validate the signed domain matches the requested domain + "code": hex.EncodeToString(encryptedCode), + "nonce": nonce, + } + + return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(a.jwtSigningKey) +} + +// DecryptCode decodes the secureCode as a JWT token and validates its signature. +// It then decrypts the code from the token claims and returns it. 
+func (a *Auth) DecryptCode(jwt, domain string) (string, error) { + claims, err := a.parseJWTClaims(jwt) + if err != nil { + return "", err + } + + // get nonce and encryptedCode from the JWT claims + nonce, ok := claims["nonce"].(string) + if !ok { + return "", errInvalidNonce + } + + encryptedCode, ok := claims["code"].(string) + if !ok { + return "", errInvalidCode + } + + cipherText, err := hex.DecodeString(encryptedCode) + if err != nil { + return "", err + } + + aesGcm, err := a.newAesGcmCipher(domain, nonce) + if err != nil { + return "", err + } + + decryptedCode, err := aesGcm.Open(nil, []byte(nonce), cipherText, nil) + if err != nil { + return "", err + } + + return string(decryptedCode), nil +} + +func (a *Auth) codeKey(domain string) ([]byte, error) { + hkdfReader := hkdf.New(sha256.New, []byte(a.authSecret), []byte(domain), []byte("PAGES_AUTH_CODE_ENCRYPTION_KEY")) + + key := make([]byte, 32) + if _, err := io.ReadFull(hkdfReader, key); err != nil { + return nil, err + } + + return key, nil +} + +func (a *Auth) parseJWTClaims(secureCode string) (jwt.MapClaims, error) { + token, err := jwt.Parse(secureCode, a.getSigningKey) + if err != nil { + return nil, err + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + return nil, errInvalidToken + } + + return claims, nil +} + +func (a *Auth) getSigningKey(token *jwt.Token) (interface{}, error) { + // Don't forget to validate the alg is what you expect: + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + return a.jwtSigningKey, nil +} + +func (a *Auth) newAesGcmCipher(domain, nonce string) (cipher.AEAD, error) { + // get the same key for a domain + key, err := a.codeKey(domain) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + aesGcm, err := cipher.NewGCMWithNonceSize(block, len(nonce)) + if err != nil { + return nil, err 
+ } + + return aesGcm, nil +} diff --git a/internal/auth/auth_code_test.go b/internal/auth/auth_code_test.go new file mode 100644 index 000000000..d54fcc7ea --- /dev/null +++ b/internal/auth/auth_code_test.go @@ -0,0 +1,99 @@ +package auth + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestEncryptAndDecryptSignedCode(t *testing.T) { + auth := createTestAuth(t, "") + + tests := map[string]struct { + auth *Auth + encDomain string + code string + expectedEncErrMsg string + decDomain string + expectedDecErrMsg string + }{ + "happy_path": { + auth: auth, + encDomain: "domain", + decDomain: "domain", + code: "code", + }, + "empty_domain": { + auth: auth, + encDomain: "", + code: "code", + expectedEncErrMsg: "empty domain or code", + }, + "empty_code": { + auth: auth, + encDomain: "domain", + code: "", + expectedEncErrMsg: "empty domain or code", + }, + "different_dec_domain": { + auth: auth, + encDomain: "domain", + decDomain: "another", + code: "code", + expectedDecErrMsg: "cipher: message authentication failed", + }, + "expired_token": { + auth: func() *Auth { + newAuth := *auth + newAuth.jwtExpiry = time.Nanosecond + newAuth.now = func() time.Time { + return time.Time{} + } + + return &newAuth + }(), + encDomain: "domain", + code: "code", + decDomain: "domain", + expectedDecErrMsg: "Token is expired", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + encCode, err := test.auth.EncryptAndSignCode(test.encDomain, test.code) + if test.expectedEncErrMsg != "" { + require.EqualError(t, err, test.expectedEncErrMsg) + require.Empty(t, encCode) + return + } + + require.NoError(t, err) + require.NotEmpty(t, encCode) + + decCode, err := test.auth.DecryptCode(encCode, test.decDomain) + if test.expectedDecErrMsg != "" { + require.EqualError(t, err, test.expectedDecErrMsg) + require.Empty(t, decCode) + return + } + + require.NoError(t, err) + require.Equal(t, test.code, decCode) + }) + } +} + +func 
TestDecryptCodeWithInvalidJWT(t *testing.T) { + auth1 := createTestAuth(t, "") + auth2 := createTestAuth(t, "") + auth2.jwtSigningKey = []byte("another signing key") + + encCode, err := auth1.EncryptAndSignCode("domain", "code") + require.NoError(t, err) + + decCode, err := auth2.DecryptCode(encCode, "domain") + require.EqualError(t, err, "signature is invalid") + require.Empty(t, decCode) +} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 39a533b35..ce7d83207 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "github.com/gorilla/sessions" @@ -16,17 +17,19 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/source" ) -func createAuth(t *testing.T) *Auth { - return New("pages.gitlab-example.com", +func createTestAuth(t *testing.T, url string) *Auth { + t.Helper() + + a, err := New("pages.gitlab-example.com", "something-very-secret", "id", "secret", "http://pages.gitlab-example.com/auth", - "http://gitlab-example.com") -} + url) + + require.NoError(t, err) -func defaultCookieStore() sessions.Store { - return createCookieStore("something-very-secret") + return a } type domainMock struct { @@ -48,10 +51,13 @@ func (dm *domainMock) ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Req // Which leads to negative side effects: we can't test encryption, and cookie params // like max-age and secure are not being properly set // To avoid that we use fake request, and set only session cookie without copying context -func setSessionValues(r *http.Request, values map[interface{}]interface{}) { - tmpRequest, _ := http.NewRequest("GET", "/", nil) +func setSessionValues(t *testing.T, r *http.Request, store sessions.Store, values map[interface{}]interface{}) { + t.Helper() + + tmpRequest, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + result := httptest.NewRecorder() - store := defaultCookieStore() session, _ := 
store.Get(tmpRequest, "gitlab-pages") session.Values = values @@ -63,7 +69,7 @@ func setSessionValues(r *http.Request, values map[interface{}]interface{}) { } func TestTryAuthenticate(t *testing.T) { - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/something/else") @@ -75,11 +81,12 @@ func TestTryAuthenticate(t *testing.T) { } func TestTryAuthenticateWithError(t *testing.T) { - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?error=access_denied") require.NoError(t, err) + reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} @@ -88,8 +95,7 @@ func TestTryAuthenticateWithError(t *testing.T) { } func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := createAuth(t) + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=invalid") @@ -97,7 +103,9 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["state"] = "state" session.Save(r, result) @@ -105,7 +113,36 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { require.Equal(t, 401, result.Code) } +func TestTryAuthenticateRemoveTokenFromRedirect(t *testing.T) { + auth := createTestAuth(t, "") + + result := httptest.NewRecorder() + reqURL, err := url.Parse("/auth?code=1&state=state&token=secret") + require.NoError(t, err) + + require.Equal(t, reqURL.Query().Get("token"), "secret", "token is present before redirecting") + reqURL.Scheme = request.SchemeHTTPS + r := &http.Request{URL: reqURL} + + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + + session.Values["state"] = 
"state" + session.Values["proxy_auth_domain"] = "https://domain.com" + session.Save(r, result) + + require.Equal(t, true, auth.TryAuthenticate(result, r, source.NewMockSource())) + require.Equal(t, http.StatusFound, result.Code) + + redirect, err := url.Parse(result.Header().Get("Location")) + require.NoError(t, err) + + require.Empty(t, redirect.Query().Get("token"), "token is gone after redirecting") +} + func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { + t.Helper() + apiServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/oauth/token": @@ -125,14 +162,17 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { apiServer.Start() defer apiServer.Close() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) + + domain := apiServer.URL + if https { + domain = strings.Replace(apiServer.URL, "http://", "https://", -1) + } - r, err := http.NewRequest("GET", "/auth?code=1&state=state", nil) + code, err := auth.EncryptAndSignCode(domain, "1") + require.NoError(t, err) + + r, err := http.NewRequest("GET", "/auth?code="+code+"&state=state", nil) require.NoError(t, err) if https { r.URL.Scheme = request.SchemeHTTPS @@ -140,14 +180,16 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { r.URL.Scheme = request.SchemeHTTP } - setSessionValues(r, map[interface{}]interface{}{ + r.Host = strings.TrimPrefix(apiServer.URL, "http://") + + setSessionValues(t, r, auth.store, map[interface{}]interface{}{ "uri": "https://pages.gitlab-example.com/project/", "state": "state", }) result := httptest.NewRecorder() require.Equal(t, true, auth.TryAuthenticate(result, r, source.NewMockSource())) - require.Equal(t, 302, result.Code) + require.Equal(t, http.StatusFound, result.Code) require.Equal(t, 
"https://pages.gitlab-example.com/project/", result.Header().Get("Location")) require.Equal(t, 600, result.Result().Cookies()[0].MaxAge) require.Equal(t, https, result.Result().Cookies()[0].Secure) @@ -177,13 +219,7 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") @@ -191,7 +227,9 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) @@ -217,13 +255,7 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) w := httptest.NewRecorder() @@ -232,7 +264,9 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, w) @@ -265,22 +299,19 @@ func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - 
"http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" - session.Save(r, result) + err = session.Save(r, result) + require.NoError(t, err) contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) require.True(t, contentServed) @@ -303,13 +334,7 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") @@ -317,7 +342,9 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) @@ -343,19 +370,16 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { apiServer.Start() defer apiServer.Close() - store := defaultCookieStore() - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - apiServer.URL) + auth := createTestAuth(t, apiServer.URL) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) 
+ session.Values["access_token"] = "abc" session.Save(r, result) @@ -364,28 +388,31 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { require.Equal(t, 302, result.Code) } -func TestGenerateKeyPair(t *testing.T) { - signingSecret, encryptionSecret := generateKeyPair("something-very-secret") - require.NotEqual(t, fmt.Sprint(signingSecret), fmt.Sprint(encryptionSecret)) - require.Equal(t, len(signingSecret), 32) - require.Equal(t, len(encryptionSecret), 32) +func TestGenerateKeys(t *testing.T) { + keys, err := generateKeys("something-very-secret", 3) + require.NoError(t, err) + require.Len(t, keys, 3) + + require.NotEqual(t, fmt.Sprint(keys[0]), fmt.Sprint(keys[1])) + require.NotEqual(t, fmt.Sprint(keys[0]), fmt.Sprint(keys[2])) + require.NotEqual(t, fmt.Sprint(keys[1]), fmt.Sprint(keys[2])) + + require.Equal(t, len(keys[0]), 32) + require.Equal(t, len(keys[1]), 32) + require.Equal(t, len(keys[2]), 32) } func TestGetTokenIfExistsWhenTokenExists(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/") require.NoError(t, err) r := &http.Request{URL: reqURL} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Values["access_token"] = "abc" session.Save(r, result) @@ -395,20 +422,16 @@ func TestGetTokenIfExistsWhenTokenExists(t *testing.T) { } func TestGetTokenIfExistsWhenTokenDoesNotExist(t *testing.T) { - store := sessions.NewCookieStore([]byte("something-very-secret")) - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := 
url.Parse("http://pages.gitlab-example.com/test") require.NoError(t, err) r := &http.Request{URL: reqURL, Host: "pages.gitlab-example.com", RequestURI: "/test"} - session, _ := store.Get(r, "gitlab-pages") + session, err := auth.store.Get(r, "gitlab-pages") + require.NoError(t, err) + session.Save(r, result) token, err := auth.GetTokenIfExists(result, r) @@ -417,12 +440,7 @@ func TestGetTokenIfExistsWhenTokenDoesNotExist(t *testing.T) { } func TestCheckResponseForInvalidTokenWhenInvalidToken(t *testing.T) { - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("http://pages.gitlab-example.com/test") @@ -437,12 +455,7 @@ func TestCheckResponseForInvalidTokenWhenInvalidToken(t *testing.T) { } func TestCheckResponseForInvalidTokenWhenNotInvalidToken(t *testing.T) { - auth := New("pages.gitlab-example.com", - "something-very-secret", - "id", - "secret", - "http://pages.gitlab-example.com/auth", - "") + auth := createTestAuth(t, "") result := httptest.NewRecorder() reqURL, err := url.Parse("/something") diff --git a/internal/logging/logging.go b/internal/logging/logging.go index 6643e169e..9a3629ffe 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -88,6 +88,7 @@ func BasicAccessLogger(handler http.Handler, format string, extraFields log.Extr return log.AccessLogger(handler, log.WithExtraFields(extraFields), log.WithAccessLogger(accessLogger), + log.WithXFFAllowed(func(sip string) bool { return false }), ), nil } diff --git a/internal/rejectmethods/middleware.go b/internal/rejectmethods/middleware.go new file mode 100644 index 000000000..e78a0ce59 --- /dev/null +++ b/internal/rejectmethods/middleware.go @@ -0,0 +1,31 @@ +package rejectmethods + +import ( + "net/http" + + "gitlab.com/gitlab-org/gitlab-pages/metrics" +) + +var acceptedMethods = map[string]bool{ + 
http.MethodGet: true, + http.MethodHead: true, + http.MethodPost: true, + http.MethodPut: true, + http.MethodPatch: true, + http.MethodDelete: true, + http.MethodConnect: true, + http.MethodOptions: true, + http.MethodTrace: true, +} + +// NewMiddleware returns middleware which rejects all unknown http methods +func NewMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if acceptedMethods[r.Method] { + handler.ServeHTTP(w, r) + } else { + metrics.RejectedRequestsCount.Inc() + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + } + }) +} diff --git a/internal/rejectmethods/middleware_test.go b/internal/rejectmethods/middleware_test.go new file mode 100644 index 000000000..2921975ae --- /dev/null +++ b/internal/rejectmethods/middleware_test.go @@ -0,0 +1,43 @@ +package rejectmethods + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewMiddleware(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "OK\n") + }) + + middleware := NewMiddleware(handler) + + acceptedMethods := []string{"GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "CONNECT", "OPTIONS", "TRACE"} + for _, method := range acceptedMethods { + t.Run(method, func(t *testing.T) { + tmpRequest, _ := http.NewRequest(method, "/", nil) + recorder := httptest.NewRecorder() + + middleware.ServeHTTP(recorder, tmpRequest) + + result := recorder.Result() + + require.Equal(t, http.StatusOK, result.StatusCode) + }) + } + + t.Run("UNKNOWN", func(t *testing.T) { + tmpRequest, _ := http.NewRequest("UNKNOWN", "/", nil) + recorder := httptest.NewRecorder() + + middleware.ServeHTTP(recorder, tmpRequest) + + result := recorder.Result() + + require.Equal(t, http.StatusMethodNotAllowed, result.StatusCode) + }) +} diff --git a/internal/source/gitlab/cache/retriever.go
b/internal/source/gitlab/cache/retriever.go index de37c231a..4341c7901 100644 --- a/internal/source/gitlab/cache/retriever.go +++ b/internal/source/gitlab/cache/retriever.go @@ -3,6 +3,8 @@ package cache import ( "context" "errors" + "sync" + "time" log "github.com/sirupsen/logrus" @@ -13,15 +15,25 @@ import ( // Retriever is an utility type that performs an HTTP request with backoff in // case of errors type Retriever struct { + timer timer client api.Client retrievalTimeout time.Duration maxRetrievalInterval time.Duration maxRetrievalRetries int } +type timer struct { + mu *sync.Mutex + stopped bool + timer *time.Timer +} + // NewRetriever creates a Retriever with a client func NewRetriever(client api.Client, retrievalTimeout, maxRetrievalInterval time.Duration, maxRetrievalRetries int) *Retriever { return &Retriever{ + timer: timer{ + mu: &sync.Mutex{}, + }, client: client, retrievalTimeout: retrievalTimeout, maxRetrievalInterval: maxRetrievalInterval, @@ -52,11 +64,24 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha go func() { var lookup api.Lookup + Retry: for i := 1; i <= r.maxRetrievalRetries; i++ { lookup = r.client.GetLookup(ctx, domain) if lookup.Error != nil { - time.Sleep(r.maxRetrievalInterval) + r.timer.start(r.maxRetrievalInterval) + select { + case <-r.timer.timer.C: + // retry to GetLookup + continue Retry + case <-ctx.Done(): + // when the retrieval context is done we stop the timer + // timer.Stop() + // log.WithError(ctx.Err()).Debug("domain retrieval backoff canceled by context") + // when the retrieval context is done we stop the timerFunc + r.timer.stop() + break Retry + } } else { break } @@ -68,3 +93,25 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha return response } + +func (t *timer) start(d time.Duration) { + t.mu.Lock() + defer t.mu.Unlock() + + t.stopped = false + t.timer = time.NewTimer(d) +} + +func (t *timer) stop() { + t.mu.Lock() + defer t.mu.Unlock() + + 
t.stopped = t.timer.Stop() +} + +func (t *timer) hasStopped() bool { + t.mu.Lock() + defer t.mu.Unlock() + + return t.stopped +} diff --git a/internal/source/gitlab/cache/retriever_test.go b/internal/source/gitlab/cache/retriever_test.go new file mode 100644 index 000000000..774e9779e --- /dev/null +++ b/internal/source/gitlab/cache/retriever_test.go @@ -0,0 +1,27 @@ +package cache + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRetrieveTimerStopsWhenContextIsDone(t *testing.T) { + retrievalTimeout := time.Millisecond // quick timeout that cancels inner context + maxRetrievalInterval := time.Minute // long sleep inside resolveWithBackoff + + resolver := &client{ + domain: make(chan string), + lookups: make(chan uint64, 1), + failure: errors.New("500 error"), + } + + retriever := NewRetriever(resolver, retrievalTimeout, maxRetrievalInterval, 3) + require.False(t, retriever.timer.hasStopped(), "timer has not been stopped yet") + + lookup := retriever.Retrieve("my.gitlab.com") + require.Empty(t, lookup.Name) + require.Eventually(t, retriever.timer.hasStopped, time.Second, time.Millisecond, "timer must have been stopped") +} diff --git a/internal/vfs/zip/archive.go b/internal/vfs/zip/archive.go index 1137f0041..9826cdd63 100644 --- a/internal/vfs/zip/archive.go +++ b/internal/vfs/zip/archive.go @@ -7,7 +7,7 @@ import ( "fmt" "io" "os" - "path/filepath" + "path" "strconv" "strings" "sync" @@ -161,34 +161,30 @@ func (a *zipArchive) readArchive(url string) { } // addPathDirectory adds a directory for a given path -func (a *zipArchive) addPathDirectory(path string) { +func (a *zipArchive) addPathDirectory(pathname string) { // Split dir and file from `path` - path, _ = filepath.Split(path) - if path == "" { + pathname, _ = path.Split(pathname) + if pathname == "" { return } - if a.directories[path] != nil { + if a.directories[pathname] != nil { return } - a.directories[path] = &zip.FileHeader{ - Name: path, + 
a.directories[pathname] = &zip.FileHeader{ + Name: pathname, } } func (a *zipArchive) findFile(name string) *zip.File { - name = filepath.Join(dirPrefix, name) + name = path.Clean(dirPrefix + name) - if file := a.files[name]; file != nil { - return file - } - - return nil + return a.files[name] } func (a *zipArchive) findDirectory(name string) *zip.FileHeader { - name = filepath.Join(dirPrefix, name) + name = path.Clean(dirPrefix + name) return a.directories[name+"/"] } diff --git a/internal/vfs/zip/archive_test.go b/internal/vfs/zip/archive_test.go index da778e620..58b7c74ab 100644 --- a/internal/vfs/zip/archive_test.go +++ b/internal/vfs/zip/archive_test.go @@ -1,11 +1,16 @@ package zip import ( + "archive/zip" + "bytes" "context" + "crypto/rand" + "io" "io/ioutil" "net/http" "net/http/httptest" "os" + "strconv" "sync/atomic" "testing" "time" @@ -419,3 +424,51 @@ func newZipFileServerURL(t *testing.T, zipFilePath string, requests *int64) (str testServer.Close() } } + +func benchmarkArchiveRead(b *testing.B, size int64) { + zbuf := new(bytes.Buffer) + + // create zip file of specified size + zw := zip.NewWriter(zbuf) + w, err := zw.Create("public/file.txt") + require.NoError(b, err) + _, err = io.CopyN(w, rand.Reader, size) + require.NoError(b, err) + require.NoError(b, zw.Close()) + + modtime := time.Now().Add(-time.Hour) + + m := http.NewServeMux() + m.HandleFunc("/public.zip", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.ServeContent(w, r, "public.zip", modtime, bytes.NewReader(zbuf.Bytes())) + })) + + ts := httptest.NewServer(m) + defer ts.Close() + + fs := New(zipCfg).(*zipVFS) + + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + z := newArchive(fs, time.Second) + err := z.openArchive(context.Background(), ts.URL+"/public.zip") + require.NoError(b, err) + + f, err := z.Open(context.Background(), "file.txt") + require.NoError(b, err) + + _, err = io.Copy(ioutil.Discard, f) + require.NoError(b, err) + + 
require.NoError(b, f.Close()) + } +} + +func BenchmarkArchiveRead(b *testing.B) { + for _, size := range []int{32 * 1024, 64 * 1024, 1024 * 1024} { + b.Run(strconv.Itoa(size), func(b *testing.B) { + benchmarkArchiveRead(b, int64(size)) + }) + } +} diff --git a/internal/vfs/zip/deflate_reader.go b/internal/vfs/zip/deflate_reader.go index 16a2d72e0..87a7da0c6 100644 --- a/internal/vfs/zip/deflate_reader.go +++ b/internal/vfs/zip/deflate_reader.go @@ -1,31 +1,66 @@ package zip import ( + "bufio" "compress/flate" + "errors" "io" + "sync" ) +var ErrClosedReader = errors.New("deflatereader: reader is closed") + +var deflateReaderPool sync.Pool + // deflateReader wrapper to support reading compressed files. // Implements the io.ReadCloser interface. type deflateReader struct { - reader io.ReadCloser + reader *bufio.Reader + closer io.Closer flateReader io.ReadCloser } // Read from flateReader func (r *deflateReader) Read(p []byte) (n int, err error) { + if r.closer == nil { + return 0, ErrClosedReader + } + return r.flateReader.Read(p) } // Close all readers func (r *deflateReader) Close() error { - r.reader.Close() + if r.closer == nil { + return ErrClosedReader + } + + defer func() { + r.closer.Close() + r.closer = nil + deflateReaderPool.Put(r) + }() + return r.flateReader.Close() } +func (r *deflateReader) reset(rc io.ReadCloser) { + r.reader.Reset(rc) + r.closer = rc + r.flateReader.(flate.Resetter).Reset(r.reader, nil) +} + func newDeflateReader(r io.ReadCloser) *deflateReader { + if dr, ok := deflateReaderPool.Get().(*deflateReader); ok { + dr.reset(r) + return dr + } + + br := bufio.NewReader(r) + return &deflateReader{ - reader: r, - flateReader: flate.NewReader(r), + reader: br, + closer: r, + flateReader: flate.NewReader(br), } } diff --git a/metrics/metrics.go b/metrics/metrics.go index db7cae9a8..045ff26e0 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -199,6 +199,13 @@ var ( Help: "The number of files per zip archive total count over time", }, ) 
+ + RejectedRequestsCount = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "gitlab_pages_unknown_method_rejected_requests", + Help: "The number of requests with unknown HTTP method which were rejected", + }, + ) ) // MustRegister collectors with the Prometheus client diff --git a/test/acceptance/acceptance_test.go b/test/acceptance/acceptance_test.go index 9921076ea..ba6528c10 100644 --- a/test/acceptance/acceptance_test.go +++ b/test/acceptance/acceptance_test.go @@ -17,24 +17,31 @@ const ( var ( pagesBinary = flag.String("gitlab-pages-binary", "../../gitlab-pages", "Path to the gitlab-pages binary") + httpPort = "36000" + httpsPort = "37000" + httpProxyPort = "38000" + httpProxyV2Port = "39000" + // TODO: Use TCP port 0 everywhere to avoid conflicts. The binary could output // the actual port (and type of listener) for us to read in place of the // hardcoded values below. listeners = []ListenSpec{ - {"http", "127.0.0.1", "37000"}, - {"http", "::1", "37000"}, - {"https", "127.0.0.1", "37001"}, - {"https", "::1", "37001"}, - {"proxy", "127.0.0.1", "37002"}, - {"proxy", "::1", "37002"}, - {"https-proxyv2", "127.0.0.1", "37003"}, - {"https-proxyv2", "::1", "37003"}, + {"http", "127.0.0.1", httpPort}, + {"https", "127.0.0.1", httpsPort}, + {"proxy", "127.0.0.1", httpProxyPort}, + {"https-proxyv2", "127.0.0.1", httpProxyV2Port}, + // TODO: re-enable IPv6 listeners once https://gitlab.com/gitlab-com/gl-infra/infrastructure/-/issues/12258 is resolved + // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/528 + // {"http", "::1", httpPort}, + // {"https", "::1", httpsPort}, + // {"proxy", "::1", httpProxyPort}, + // {"https-proxyv2", "::1", httpProxyV2Port}, } httpListener = listeners[0] - httpsListener = listeners[2] - proxyListener = listeners[4] - httpsProxyv2Listener = listeners[6] + httpsListener = listeners[1] + proxyListener = listeners[2] + httpsProxyv2Listener = listeners[3] ) func TestMain(m *testing.M) { diff --git 
a/test/acceptance/artifacts_test.go b/test/acceptance/artifacts_test.go index 3440ef34f..57c7a02a9 100644 --- a/test/acceptance/artifacts_test.go +++ b/test/acceptance/artifacts_test.go @@ -245,7 +245,7 @@ func TestPrivateArtifactProxyRequest(t *testing.T) { ) defer teardown() - resp, err := GetRedirectPage(t, httpListener, tt.host, tt.path) + resp, err := GetRedirectPage(t, httpsListener, tt.host, tt.path) require.NoError(t, err) defer resp.Body.Close() diff --git a/test/acceptance/auth_test.go b/test/acceptance/auth_test.go index b2233591f..fa2d768d8 100644 --- a/test/acceptance/auth_test.go +++ b/test/acceptance/auth_test.go @@ -88,7 +88,7 @@ func TestWhenLoginCallbackWithWrongStateShouldFail(t *testing.T) { require.Equal(t, http.StatusUnauthorized, authrsp.StatusCode) } -func TestWhenLoginCallbackWithCorrectStateWithoutEndpoint(t *testing.T) { +func TestWhenLoginCallbackWithUnencryptedCode(t *testing.T) { skipUnlessEnabled(t) teardown := RunPagesProcessWithAuth(t, *pagesBinary, listeners, "") defer teardown() @@ -110,8 +110,8 @@ func TestWhenLoginCallbackWithCorrectStateWithoutEndpoint(t *testing.T) { require.NoError(t, err) defer authrsp.Body.Close() - // Will cause 503 because token endpoint is not available - require.Equal(t, http.StatusServiceUnavailable, authrsp.StatusCode) + // Will cause 500 because the code is not encrypted + require.Equal(t, http.StatusInternalServerError, authrsp.StatusCode) } func handleAccessControlArtifactRequests(t *testing.T, w http.ResponseWriter, r *http.Request) bool { @@ -234,11 +234,10 @@ func TestAccessControlUnderCustomDomain(t *testing.T) { // Will redirect to custom domain require.Equal(t, tt.domain, url.Host) - require.Equal(t, "1", url.Query().Get("code")) - require.Equal(t, state, url.Query().Get("state")) + code := url.Query().Get("code") + require.NotEqual(t, "1", code) - // Run auth callback in custom domain - authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code=1&state="+ + authrsp, err 
= GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code="+code+"&state="+ state, cookie) require.NoError(t, err) @@ -336,11 +335,13 @@ func TestCustomErrorPageWithAuth(t *testing.T) { // Will redirect to custom domain require.Equal(t, tt.domain, url.Host) - require.Equal(t, "1", url.Query().Get("code")) + // code must have changed since it's encrypted + code := url.Query().Get("code") + require.NotEqual(t, "1", code) require.Equal(t, state, url.Query().Get("state")) // Run auth callback in custom domain - authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code=1&state="+ + authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code="+code+"&state="+ state, cookie) require.NoError(t, err) @@ -409,12 +410,14 @@ func TestAccessControlUnderCustomDomainWithHTTPSProxy(t *testing.T) { // Will redirect to custom domain require.Equal(t, "private.domain.com", url.Host) - require.Equal(t, "1", url.Query().Get("code")) + // code must have changed since it's encrypted + code := url.Query().Get("code") + require.NotEqual(t, "1", code) require.Equal(t, state, url.Query().Get("state")) // Run auth callback in custom domain authrsp, err = GetProxyRedirectPageWithCookie(t, proxyListener, "private.domain.com", - "/auth?code=1&state="+state, cookie, true) + "/auth?code="+code+"&state="+state, cookie, true) require.NoError(t, err) defer authrsp.Body.Close() @@ -641,3 +644,87 @@ func TestAccessControlWithSSLCertFile(t *testing.T) { func TestAccessControlWithSSLCertDir(t *testing.T) { testAccessControl(t, RunPagesProcessWithAuthServerWithSSLCertDir) } + +// This proves the fix for https://gitlab.com/gitlab-org/gitlab-pages/-/issues/262 +// Read the issue description if any changes to internal/auth/ break this test. +// Related to https://tools.ietf.org/html/rfc6749#section-10.6. 
+func TestHijackedCode(t *testing.T) { + skipUnlessEnabled(t, "not-inplace-chroot") + + testServer := makeGitLabPagesAccessStub(t) + testServer.Start() + defer testServer.Close() + + teardown := RunPagesProcessWithAuthServer(t, *pagesBinary, listeners, "", testServer.URL) + defer teardown() + + /****ATTACKER******/ + // get valid cookie for a different private project + targetDomain := "private.domain.com" + attackersDomain := "group.auth.gitlab-example.com" + attackerCookie, attackerState := getValidCookieAndState(t, targetDomain) + + /****TARGET******/ + // fool target to click on modified URL with attacker's domain for redirect with a valid state + hackedURL := fmt.Sprintf("/auth?domain=http://%s&state=%s", attackersDomain, "irrelevant") + maliciousResp, err := GetProxyRedirectPageWithCookie(t, proxyListener, "projects.gitlab-example.com", hackedURL, "", true) + require.NoError(t, err) + defer maliciousResp.Body.Close() + + pagesCookie := maliciousResp.Header.Get("Set-Cookie") + + /* + OAuth flow happens here... 
+ */ + maliciousRespURL, err := url.Parse(maliciousResp.Header.Get("Location")) + require.NoError(t, err) + maliciousState := maliciousRespURL.Query().Get("state") + + // Go to auth page with correct state and code "obtained" from GitLab + authrsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, + "projects.gitlab-example.com", "/auth?code=1&state="+maliciousState, + pagesCookie, true) + + require.NoError(t, err) + defer authrsp.Body.Close() + + /****ATTACKER******/ + // Target is redirected to attacker's domain and attacker receives the proper code + require.Equal(t, http.StatusFound, authrsp.StatusCode, "should redirect to attacker's domain") + authrspURL, err := url.Parse(authrsp.Header.Get("Location")) + require.NoError(t, err) + require.Contains(t, authrspURL.String(), attackersDomain) + + // attacker's got the code + hijackedCode := authrspURL.Query().Get("code") + require.NotEmpty(t, hijackedCode) + + // attacker tries to access private pages content + impersonatingRes, err := GetProxyRedirectPageWithCookie(t, proxyListener, targetDomain, + "/auth?code="+hijackedCode+"&state="+attackerState, attackerCookie, true) + require.NoError(t, err) + defer impersonatingRes.Body.Close() + + require.Equal(t, http.StatusInternalServerError, impersonatingRes.StatusCode, "should fail to decode code") +} + +func getValidCookieAndState(t *testing.T, domain string) (string, string) { + t.Helper() + + // follow flow to get a valid cookie + // visit https://{domain}/ + rsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, domain, "/", "", true) + require.NoError(t, err) + defer rsp.Body.Close() + + cookie := rsp.Header.Get("Set-Cookie") + require.NotEmpty(t, cookie) + + redirectURL, err := url.Parse(rsp.Header.Get("Location")) + require.NoError(t, err) + + state := redirectURL.Query().Get("state") + require.NotEmpty(t, state) + + return cookie, state +} diff --git a/test/acceptance/proxyv2_test.go b/test/acceptance/proxyv2_test.go index c407ea194..2a42f0f1c 100644 ---
a/test/acceptance/proxyv2_test.go +++ b/test/acceptance/proxyv2_test.go @@ -4,6 +4,8 @@ import ( "io/ioutil" "net/http" + "strings" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -46,7 +48,11 @@ func TestProxyv2(t *testing.T) { require.Contains(t, string(body), tt.expectedContent, "content mismatch") - require.Contains(t, logBuf.String(), tt.expectedLog, "log mismatch") + // give the process enough time to write the log message + require.Eventually(t, func() bool { + // poll with a plain check: a failing require.* here would abort the poll instead of retrying + return strings.Contains(logBuf.String(), tt.expectedLog) + }, time.Second, time.Millisecond, "log mismatch: %q not found", tt.expectedLog) }) } } diff --git a/test/acceptance/serving_test.go b/test/acceptance/serving_test.go index 66b5fa477..becd6b8cd 100644 --- a/test/acceptance/serving_test.go +++ b/test/acceptance/serving_test.go @@ -515,7 +515,8 @@ func TestKnownHostInReverseProxySetupReturns200(t *testing.T) { var listeners = []ListenSpec{ {"proxy", "127.0.0.1", "37002"}, - {"proxy", "::1", "37002"}, + // TODO: re-enable https://gitlab.com/gitlab-org/gitlab-pages/-/issues/528 + // {"proxy", "::1", "37002"}, } teardown := RunPagesProcess(t, *pagesBinary, listeners, "") diff --git a/test/acceptance/unknown_http_method_test.go b/test/acceptance/unknown_http_method_test.go new file mode 100644 index 000000000..f6c5ffee5 --- /dev/null +++ b/test/acceptance/unknown_http_method_test.go @@ -0,0 +1,24 @@ +package acceptance_test + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUnknownHTTPMethod(t *testing.T) { + skipUnlessEnabled(t) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") + defer teardown() + + req, err := http.NewRequest("UNKNOWN", listeners[0].URL(""), nil) + require.NoError(t, err) + req.Host = "" + + resp, err := DoPagesRequest(t, httpListener, req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) +} -- GitLab From c02c5b3dd7f3f4b162d62a9ee5520cdbce0a0a59 Mon Sep 17 00:00:00 2001 From: Dishon Date: Tue, 19 Jan
2021 06:55:44 +0000 Subject: [PATCH 17/17] Apply 1 suggestion(s) to 1 file(s) --- internal/source/gitlab/cache/retriever.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/source/gitlab/cache/retriever.go b/internal/source/gitlab/cache/retriever.go index 4341c7901..50140cfc4 100644 --- a/internal/source/gitlab/cache/retriever.go +++ b/internal/source/gitlab/cache/retriever.go @@ -76,7 +76,6 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha continue Retry case <-ctx.Done(): // when the retrieval context is done we stop the timer - // timer.Stop() // log.WithError(ctx.Err()).Debug("domain retrieval backoff canceled by context") // when the retrieval context is done we stop the timerFunc r.timer.stop() -- GitLab