diff --git a/internal/featureflag/ff_unauthenticated_concurrency.go b/internal/featureflag/ff_unauthenticated_concurrency.go new file mode 100644 index 0000000000000000000000000000000000000000..d68a5c13d78ec831048f98f6ac8fdc1c1e4d66ec --- /dev/null +++ b/internal/featureflag/ff_unauthenticated_concurrency.go @@ -0,0 +1,10 @@ +package featureflag + +// LimitUnauthenticated allows the concurrency limiter to limit unauthenticated +// requests separately from authenticated requests. +var LimitUnauthenticated = NewFeatureFlag( + "limit_unauthenticated", + "v18.6.0", + "https://gitlab.com/gitlab-org/gitaly/-/issues/6955", + true, +) diff --git a/internal/gitaly/config/config.go b/internal/gitaly/config/config.go index 3b0c90ccc05bb96b6b5f5f7b4ad8c710935c5087..1e53e05f508d830957539d0e9befa94d069c70bf 100644 --- a/internal/gitaly/config/config.go +++ b/internal/gitaly/config/config.go @@ -525,8 +525,24 @@ type Logging struct { // Requests that come in after the maximum number of concurrent requests are in progress will wait // in a queue that is bounded by MaxQueueSize. type Concurrency struct { + ConcurrencyLimits // RPC is the name of the RPC to set concurrency limits for RPC string `json:"rpc" toml:"rpc"` + // Unauthenticated sets the limits for unauthenticated requests + Unauthenticated ConcurrencyLimits `json:"unauthenticated" toml:"unauthenticated"` +} + +// ConcurrencyLimits sets the limits for adaptive limiting +type ConcurrencyLimits struct { + // MaxPerRepo is the maximum number of concurrent calls for a given repository. This config is used only + // if Adaptive is false. + MaxPerRepo int `json:"max_per_repo" toml:"max_per_repo"` + // MaxQueueSize is the maximum number of requests in the queue waiting to be picked up + // after which subsequent requests will return with an error. 
+ MaxQueueSize int `json:"max_queue_size" toml:"max_queue_size"` + // MaxQueueWait is the maximum time a request can remain in the concurrency queue + // waiting to be picked up by Gitaly + MaxQueueWait duration.Duration `json:"max_queue_wait" toml:"max_queue_wait"` // Adaptive determines the behavior of the concurrency limit. If set to true, the concurrency limit is dynamic // and starts at InitialLimit, then adjusts within the range [MinLimit, MaxLimit] based on current resource // usage. If set to false, the concurrency limit is static and is set to MaxPerRepo. @@ -537,15 +553,6 @@ type Concurrency struct { MaxLimit int `json:"max_limit,omitempty" toml:"max_limit,omitempty"` // MinLimit is the mini adaptive concurrency limit. MinLimit int `json:"min_limit,omitempty" toml:"min_limit,omitempty"` - // MaxPerRepo is the maximum number of concurrent calls for a given repository. This config is used only - // if Adaptive is false. - MaxPerRepo int `json:"max_per_repo" toml:"max_per_repo"` - // MaxQueueSize is the maximum number of requests in the queue waiting to be picked up - // after which subsequent requests will return with an error. - MaxQueueSize int `json:"max_queue_size" toml:"max_queue_size"` - // MaxQueueWait is the maximum time a request can remain in the concurrency queue - // waiting to be picked up by Gitaly - MaxQueueWait duration.Duration `json:"max_queue_wait" toml:"max_queue_wait"` } // Validate runs validation on all fields and compose all found errors. 
diff --git a/internal/gitaly/config/config_test.go b/internal/gitaly/config/config_test.go index ae61e876564209a2fe21cba350af3f1802c3f13f..2232d06382d120b947a1820436d092916d5d2b60 100644 --- a/internal/gitaly/config/config_test.go +++ b/internal/gitaly/config/config_test.go @@ -2056,9 +2056,11 @@ func TestConcurrency(t *testing.T) { max_per_repo = 20 `, expectedCfg: []Concurrency{{ - RPC: "/gitaly.CommitService/ListCommitsByOid", - MaxPerRepo: 20, - MaxQueueSize: 500, + ConcurrencyLimits: ConcurrencyLimits{ + MaxPerRepo: 20, + MaxQueueSize: 500, + }, + RPC: "/gitaly.CommitService/ListCommitsByOid", }}, }, { @@ -2070,10 +2072,12 @@ func TestConcurrency(t *testing.T) { max_queue_wait = "10s" `, expectedCfg: []Concurrency{{ - RPC: "/gitaly.CommitService/ListCommitsByOid", - MaxPerRepo: 20, - MaxQueueSize: 100, - MaxQueueWait: duration.Duration(10 * time.Second), + ConcurrencyLimits: ConcurrencyLimits{ + MaxPerRepo: 20, + MaxQueueSize: 100, + MaxQueueWait: duration.Duration(10 * time.Second), + }, + RPC: "/gitaly.CommitService/ListCommitsByOid", }}, }, { @@ -2085,10 +2089,12 @@ func TestConcurrency(t *testing.T) { max_queue_wait = "1m" `, expectedCfg: []Concurrency{{ - RPC: "/gitaly.CommitService/ListCommitsByOid", - MaxPerRepo: 20, - MaxQueueSize: 100, - MaxQueueWait: duration.Duration(1 * time.Minute), + ConcurrencyLimits: ConcurrencyLimits{ + MaxPerRepo: 20, + MaxQueueSize: 100, + MaxQueueWait: duration.Duration(1 * time.Minute), + }, + RPC: "/gitaly.CommitService/ListCommitsByOid", }}, }, { @@ -2106,15 +2112,19 @@ func TestConcurrency(t *testing.T) { `, expectedCfg: []Concurrency{ { - RPC: "/gitaly.CommitService/ListCommits", - MaxPerRepo: 20, - MaxQueueSize: 20, + ConcurrencyLimits: ConcurrencyLimits{ + MaxPerRepo: 20, + MaxQueueSize: 20, + }, + RPC: "/gitaly.CommitService/ListCommits", }, { - RPC: "/gitaly.CommitService/ListCommitsByOid", - MaxPerRepo: 30, - MaxQueueSize: 500, - MaxQueueWait: duration.Duration(10 * time.Second), + ConcurrencyLimits: 
ConcurrencyLimits{ + MaxPerRepo: 30, + MaxQueueSize: 500, + MaxQueueWait: duration.Duration(10 * time.Second), + }, + RPC: "/gitaly.CommitService/ListCommitsByOid", }, }, }, @@ -2130,13 +2140,15 @@ func TestConcurrency(t *testing.T) { initial_limit = 40 `, expectedCfg: []Concurrency{{ - RPC: "/gitaly.SmartHTTPService/PostUploadPack", - MaxQueueSize: 100, - MaxQueueWait: duration.Duration(1 * time.Minute), - Adaptive: true, - MinLimit: 10, - MaxLimit: 60, - InitialLimit: 40, + ConcurrencyLimits: ConcurrencyLimits{ + MaxQueueSize: 100, + MaxQueueWait: duration.Duration(1 * time.Minute), + Adaptive: true, + MinLimit: 10, + MaxLimit: 60, + InitialLimit: 40, + }, + RPC: "/gitaly.SmartHTTPService/PostUploadPack", }}, }, } @@ -2160,9 +2172,9 @@ func TestConcurrency(t *testing.T) { func TestConcurrency_Validate(t *testing.T) { t.Parallel() - require.NoError(t, Concurrency{MaxPerRepo: 0, MaxQueueSize: 1}.Validate()) - require.NoError(t, Concurrency{MaxPerRepo: 1, MaxQueueSize: 1}.Validate()) - require.NoError(t, Concurrency{MaxPerRepo: 100, MaxQueueSize: 100}.Validate()) + require.NoError(t, Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxPerRepo: 0, MaxQueueSize: 1}}.Validate()) + require.NoError(t, Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxPerRepo: 1, MaxQueueSize: 1}}.Validate()) + require.NoError(t, Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxPerRepo: 100, MaxQueueSize: 100}}.Validate()) require.Equal( t, cfgerror.ValidationErrors{ @@ -2171,12 +2183,12 @@ func TestConcurrency_Validate(t *testing.T) { "max_per_repo", ), }, - Concurrency{MaxPerRepo: -1, MaxQueueSize: 1}.Validate(), + Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxPerRepo: -1, MaxQueueSize: 1}}.Validate(), ) - require.NoError(t, Concurrency{Adaptive: true, InitialLimit: 1, MinLimit: 1, MaxLimit: 100, MaxQueueSize: 100}.Validate()) - require.NoError(t, Concurrency{Adaptive: true, InitialLimit: 10, MinLimit: 1, MaxLimit: 100, MaxQueueSize: 100}.Validate()) - require.NoError(t, 
Concurrency{Adaptive: true, InitialLimit: 100, MinLimit: 1, MaxLimit: 100, MaxQueueSize: 100}.Validate()) + require.NoError(t, Concurrency{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 1, MinLimit: 1, MaxLimit: 100, MaxQueueSize: 100}}.Validate()) + require.NoError(t, Concurrency{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 10, MinLimit: 1, MaxLimit: 100, MaxQueueSize: 100}}.Validate()) + require.NoError(t, Concurrency{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 100, MinLimit: 1, MaxLimit: 100, MaxQueueSize: 100}}.Validate()) require.Equal( t, cfgerror.ValidationErrors{ @@ -2185,7 +2197,7 @@ func TestConcurrency_Validate(t *testing.T) { "min_limit", ), }, - Concurrency{Adaptive: true, InitialLimit: 0, MinLimit: 0, MaxLimit: 100, MaxQueueSize: 100}.Validate(), + Concurrency{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 0, MinLimit: 0, MaxLimit: 100, MaxQueueSize: 100}}.Validate(), ) require.Equal( t, @@ -2195,7 +2207,7 @@ func TestConcurrency_Validate(t *testing.T) { "initial_limit", ), }, - Concurrency{Adaptive: true, InitialLimit: -1, MinLimit: 1, MaxLimit: 100, MaxQueueSize: 100}.Validate(), + Concurrency{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: -1, MinLimit: 1, MaxLimit: 100, MaxQueueSize: 100}}.Validate(), ) require.Equal( t, @@ -2205,7 +2217,7 @@ func TestConcurrency_Validate(t *testing.T) { "initial_limit", ), }, - Concurrency{Adaptive: true, InitialLimit: 10, MinLimit: 11, MaxLimit: 100, MaxQueueSize: 100}.Validate(), + Concurrency{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 10, MinLimit: 11, MaxLimit: 100, MaxQueueSize: 100}}.Validate(), ) require.Equal( t, @@ -2215,7 +2227,7 @@ func TestConcurrency_Validate(t *testing.T) { "max_limit", ), }, - Concurrency{Adaptive: true, InitialLimit: 10, MinLimit: 5, MaxLimit: 3, MaxQueueSize: 100}.Validate(), + Concurrency{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, 
InitialLimit: 10, MinLimit: 5, MaxLimit: 3, MaxQueueSize: 100}}.Validate(), ) require.Equal( t, @@ -2225,7 +2237,7 @@ func TestConcurrency_Validate(t *testing.T) { "min_limit", ), }, - Concurrency{Adaptive: true, InitialLimit: 5, MinLimit: -1, MaxLimit: 99, MaxQueueSize: 100}.Validate(), + Concurrency{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 5, MinLimit: -1, MaxLimit: 99, MaxQueueSize: 100}}.Validate(), ) require.Equal( t, @@ -2235,11 +2247,11 @@ func TestConcurrency_Validate(t *testing.T) { "max_limit", ), }, - Concurrency{Adaptive: true, InitialLimit: 10, MinLimit: 5, MaxLimit: -1, MaxQueueSize: 100}.Validate(), + Concurrency{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 10, MinLimit: 5, MaxLimit: -1, MaxQueueSize: 100}}.Validate(), ) - require.NoError(t, Concurrency{MaxQueueSize: 1}.Validate()) - require.NoError(t, Concurrency{MaxQueueSize: 100}.Validate()) + require.NoError(t, Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxQueueSize: 1}}.Validate()) + require.NoError(t, Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxQueueSize: 100}}.Validate()) require.Equal( t, cfgerror.ValidationErrors{ @@ -2248,7 +2260,7 @@ func TestConcurrency_Validate(t *testing.T) { "max_queue_size", ), }, - Concurrency{MaxQueueSize: 0}.Validate(), + Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxQueueSize: 0}}.Validate(), ) require.Equal( t, @@ -2258,10 +2270,10 @@ func TestConcurrency_Validate(t *testing.T) { "max_queue_size", ), }, - Concurrency{MaxQueueSize: -1}.Validate(), + Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxQueueSize: -1}}.Validate(), ) - require.NoError(t, Concurrency{MaxQueueWait: duration.Duration(1), MaxQueueSize: 1}.Validate()) + require.NoError(t, Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxQueueWait: duration.Duration(1), MaxQueueSize: 1}}.Validate()) require.Equal( t, cfgerror.ValidationErrors{ @@ -2270,7 +2282,7 @@ func TestConcurrency_Validate(t *testing.T) { "max_queue_wait", ), }, - 
Concurrency{MaxQueueWait: duration.Duration(-time.Minute), MaxQueueSize: 1}.Validate(), + Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxQueueWait: duration.Duration(-time.Minute), MaxQueueSize: 1}}.Validate(), ) } diff --git a/internal/gitaly/server/auth/auth.go b/internal/gitaly/server/auth/auth.go index b6db8a3a39aa743cb0e8ccf9fa4093d39322e334..93e81ee01f9a59aa8b7b20b66ea3336f38698a45 100644 --- a/internal/gitaly/server/auth/auth.go +++ b/internal/gitaly/server/auth/auth.go @@ -23,6 +23,20 @@ var authCount = promauto.NewCounterVec( []string{"enforced", "status"}, ) +type authenticatedKey struct{} + +// IsAuthenticated returns true if the request has been validated by the auth interceptor. +// This is different from just having an auth token in the metadata - this confirms the token +// was cryptographically validated. +func IsAuthenticated(ctx context.Context) bool { + authenticated, ok := ctx.Value(authenticatedKey{}).(bool) + return ok && authenticated +} + +func setAuthenticated(ctx context.Context) context.Context { + return context.WithValue(ctx, authenticatedKey{}, true) +} + // UnauthenticatedHealthService wraps the health server and disables authentication for all of its methods. 
type UnauthenticatedHealthService struct{ grpc_health_v1.HealthServer } @@ -52,6 +66,8 @@ func checkFunc(conf gitalycfgauth.Config) func(ctx context.Context) (context.Con switch status.Code(err) { case codes.OK: countStatus(okLabel(conf.Transitioning), conf.Transitioning).Inc() + // Mark the context as authenticated only when validation succeeds + ctx = setAuthenticated(ctx) case codes.Unauthenticated: countStatus("unauthenticated", conf.Transitioning).Inc() case codes.PermissionDenied: diff --git a/internal/gitaly/server/auth_test.go b/internal/gitaly/server/auth_test.go index ab87f67a3c3b6a1406eaf03e0df3bb54c1c2ad6c..5db3457aaaa62374ffe3247e72847d623fbe1006 100644 --- a/internal/gitaly/server/auth_test.go +++ b/internal/gitaly/server/auth_test.go @@ -360,8 +360,10 @@ func TestAuthBeforeLimit(t *testing.T) { cfg := testcfg.Build(t, testcfg.WithBase(config.Cfg{ Auth: auth.Config{Token: "abc123"}, Concurrency: []config.Concurrency{{ - RPC: "/gitaly.OperationService/UserCreateTag", - MaxPerRepo: 1, + RPC: "/gitaly.OperationService/UserCreateTag", + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 1, + }, }}, }, )) diff --git a/internal/grpc/middleware/limithandler/middleware.go b/internal/grpc/middleware/limithandler/middleware.go index 0af1946d4297f67d5052ce764f31a1f2d7448826..959a8ac1a604f70cd36c6dbeff06086d5ece85f0 100644 --- a/internal/grpc/middleware/limithandler/middleware.go +++ b/internal/grpc/middleware/limithandler/middleware.go @@ -5,7 +5,9 @@ import ( "fmt" "github.com/prometheus/client_golang/prometheus" + "gitlab.com/gitlab-org/gitaly/v18/internal/featureflag" "gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/config" + "gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/server/auth" "gitlab.com/gitlab-org/gitaly/v18/internal/grpc/middleware/requestinfohandler" "gitlab.com/gitlab-org/gitaly/v18/internal/limiter" "google.golang.org/grpc" @@ -25,10 +27,11 @@ func LimitConcurrencyByRepo(ctx context.Context) string { // LimiterMiddleware contains 
rate limiter state type LimiterMiddleware struct { - methodLimiters map[string]limiter.Limiter - getLockKey GetLockKey - requestsDroppedMetric *prometheus.CounterVec - collect func(metrics chan<- prometheus.Metric) + methodLimiters map[string]limiter.Limiter + methodLimitersUnauthenticated map[string]limiter.Limiter + getLockKey GetLockKey + requestsDroppedMetric *prometheus.CounterVec + collect func(metrics chan<- prometheus.Metric) } // New creates a new middleware that limits requests. SetupFunc sets up the @@ -76,7 +79,19 @@ func (c *LimiterMiddleware) UnaryInterceptor() grpc.UnaryServerInterceptor { return handler(ctx, req) } + // Check if request is authenticated limiter := c.methodLimiters[info.FullMethod] + + if featureflag.LimitUnauthenticated.IsEnabled(ctx) { + unauthLimiter, ok := c.methodLimitersUnauthenticated[info.FullMethod] + // Use auth.IsAuthenticated to check if the token was cryptographically validated, + // not just whether a token was present in metadata. This prevents spoofed tokens + // from bypassing unauthenticated rate limits. + if !auth.IsAuthenticated(ctx) && ok { + limiter = unauthLimiter + } + } + if limiter == nil { // No concurrency limiting return handler(ctx, req) @@ -125,7 +140,20 @@ func (w *wrappedStream) RecvMsg(m interface{}) error { return nil } + // Check if request is authenticated limiter := w.limiterMiddleware.methodLimiters[w.info.FullMethod] + + if featureflag.LimitUnauthenticated.IsEnabled(ctx) { + unauthLimiter, ok := w.limiterMiddleware.methodLimitersUnauthenticated[w.info.FullMethod] + // Use auth.IsAuthenticated to check if the token was cryptographically validated, + // not just whether a token was present in metadata. This prevents spoofed tokens + // from bypassing unauthenticated rate limits. 
+ if !auth.IsAuthenticated(ctx) && ok { + // Unauthenticated request + limiter = unauthLimiter + } + } + if limiter == nil { // No concurrency limiting return nil @@ -158,7 +186,10 @@ func (w *wrappedStream) RecvMsg(m interface{}) error { // requests based on RPC and repository func WithConcurrencyLimiters(cfg config.Cfg) (map[string]*limiter.AdaptiveLimit, SetupFunc) { perRPCLimits := map[string]*limiter.AdaptiveLimit{} + perRPCLimitsUnauthenticated := map[string]*limiter.AdaptiveLimit{} + for _, concurrency := range cfg.Concurrency { + // Create authenticated limiter limitName := fmt.Sprintf("perRPC%s", concurrency.RPC) if concurrency.Adaptive { perRPCLimits[concurrency.RPC] = limiter.NewAdaptiveLimit(limitName, limiter.AdaptiveSetting{ @@ -172,6 +203,25 @@ func WithConcurrencyLimiters(cfg config.Cfg) (map[string]*limiter.AdaptiveLimit, Initial: concurrency.MaxPerRepo, }) } + + // Create unauthenticated limiter if configured + unauthLimits := concurrency.Unauthenticated + if unauthLimits.Adaptive || unauthLimits.MaxPerRepo > 0 || + unauthLimits.InitialLimit > 0 || unauthLimits.MaxLimit > 0 || unauthLimits.MinLimit > 0 { + limitNameUnauth := fmt.Sprintf("perRPC%s-unauthenticated", concurrency.RPC) + if unauthLimits.Adaptive { + perRPCLimitsUnauthenticated[concurrency.RPC] = limiter.NewAdaptiveLimit(limitNameUnauth, limiter.AdaptiveSetting{ + Initial: unauthLimits.InitialLimit, + Max: unauthLimits.MaxLimit, + Min: unauthLimits.MinLimit, + BackoffFactor: limiter.DefaultBackoffFactor, + }) + } else if unauthLimits.MaxPerRepo > 0 { + perRPCLimitsUnauthenticated[concurrency.RPC] = limiter.NewAdaptiveLimit(limitNameUnauth, limiter.AdaptiveSetting{ + Initial: unauthLimits.MaxPerRepo, + }) + } + } } return perRPCLimits, func(cfg config.Cfg, middleware *LimiterMiddleware) { acquiringSecondsMetric := prometheus.NewHistogramVec( @@ -210,7 +260,10 @@ func WithConcurrencyLimiters(cfg config.Cfg) (map[string]*limiter.AdaptiveLimit, } result := make(map[string]limiter.Limiter) + 
resultUnauthenticated := make(map[string]limiter.Limiter) + for _, concurrency := range cfg.Concurrency { + // Create authenticated limiter result[concurrency.RPC] = limiter.NewConcurrencyLimiter( perRPCLimits[concurrency.RPC], concurrency.MaxQueueSize, @@ -220,6 +273,20 @@ func WithConcurrencyLimiters(cfg config.Cfg) (map[string]*limiter.AdaptiveLimit, queuedMetric, inProgressMetric, acquiringSecondsMetric, middleware.requestsDroppedMetric, ), ) + + // Create unauthenticated limiter if configured + if adaptiveLimit, ok := perRPCLimitsUnauthenticated[concurrency.RPC]; ok { + unauthLimits := concurrency.Unauthenticated + resultUnauthenticated[concurrency.RPC] = limiter.NewConcurrencyLimiter( + adaptiveLimit, + unauthLimits.MaxQueueSize, + unauthLimits.MaxQueueWait.Duration(), + limiter.NewPerRPCPromMonitor( + "gitaly", concurrency.RPC+"-unauthenticated", + queuedMetric, inProgressMetric, acquiringSecondsMetric, middleware.requestsDroppedMetric, + ), + ) + } } // Set default for ReplicateRepository. 
@@ -237,5 +304,6 @@ func WithConcurrencyLimiters(cfg config.Cfg) (map[string]*limiter.AdaptiveLimit, } middleware.methodLimiters = result + middleware.methodLimitersUnauthenticated = resultUnauthenticated } } diff --git a/internal/grpc/middleware/limithandler/middleware_test.go b/internal/grpc/middleware/limithandler/middleware_test.go index 4bf5c1dec43e369bb601112512cf3f687badca9a..dbcb7805381043b1468e29a53968f88cfb562b1c 100644 --- a/internal/grpc/middleware/limithandler/middleware_test.go +++ b/internal/grpc/middleware/limithandler/middleware_test.go @@ -11,9 +11,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + gitalyauth "gitlab.com/gitlab-org/gitaly/v18/auth" "gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/config" + gitalycfgauth "gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/config/auth" + "gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/server/auth" "gitlab.com/gitlab-org/gitaly/v18/internal/grpc/client" "gitlab.com/gitlab-org/gitaly/v18/internal/grpc/middleware/limithandler" + "gitlab.com/gitlab-org/gitaly/v18/internal/grpc/middleware/requestinfohandler" "gitlab.com/gitlab-org/gitaly/v18/internal/helper/duration" "gitlab.com/gitlab-org/gitaly/v18/internal/limiter" "gitlab.com/gitlab-org/gitaly/v18/internal/structerr" @@ -35,19 +39,25 @@ func TestWithConcurrencyLimiters(t *testing.T) { cfg := config.Cfg{ Concurrency: []config.Concurrency{ { - RPC: "/grpc.testing.TestService/UnaryCall", - MaxPerRepo: 1, + RPC: "/grpc.testing.TestService/UnaryCall", + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 1, + }, }, { - RPC: "/grpc.testing.TestService/FullDuplexCall", - MaxPerRepo: 99, + RPC: "/grpc.testing.TestService/FullDuplexCall", + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 99, + }, }, { - RPC: "/grpc.testing.TestService/AnotherUnaryCall", - Adaptive: true, - MinLimit: 5, - InitialLimit: 10, - MaxLimit: 15, + RPC: "/grpc.testing.TestService/AnotherUnaryCall", + ConcurrencyLimits: 
config.ConcurrencyLimits{ + Adaptive: true, + MinLimit: 5, + InitialLimit: 10, + MaxLimit: 15, + }, }, }, } @@ -81,7 +91,12 @@ func TestUnaryLimitHandler(t *testing.T) { cfg := config.Cfg{ Concurrency: []config.Concurrency{ - {RPC: "/grpc.testing.TestService/UnaryCall", MaxPerRepo: 2}, + { + RPC: "/grpc.testing.TestService/UnaryCall", + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 2, + }, + }, }, } @@ -141,23 +156,25 @@ func TestUnaryLimitHandler_queueing(t *testing.T) { cfg := config.Cfg{ Concurrency: []config.Concurrency{ { - RPC: "/grpc.testing.TestService/UnaryCall", - MaxPerRepo: 1, - MaxQueueSize: 1, - // This test setups two requests: - // - The first one is eligible. It enters the handler and blocks the queue. - // - The second request is blocked until timeout. - // Both of them shares this timeout. Internally, the limiter creates a context - // deadline to reject timed out requests. If it's set too low, there's a tiny - // possibility that the context reaches the deadline when the limiter checks the - // request. Thus, setting a reasonable timeout here and adding some retry - // attempts below make the test stable. - // Another approach is to implement a hooking mechanism that allows us to - // override context deadline setup. However, that approach exposes the internal - // implementation of the limiter. It also adds unnecessarily logics. - // Congiuring the timeout is more straight-forward and close to the expected - // behavior. - MaxQueueWait: duration.Duration(100 * time.Millisecond), + RPC: "/grpc.testing.TestService/UnaryCall", + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 1, + MaxQueueSize: 1, + // This test sets up two requests: + // - The first one is eligible. It enters the handler and blocks the queue. + // - The second request is blocked until timeout. + // Both of them share this timeout. Internally, the limiter creates a context + // deadline to reject timed out requests. 
If it's set too low, there's a tiny + possibility that the context reaches the deadline when the limiter checks the + request. Thus, setting a reasonable timeout here and adding some retry + attempts below make the test stable. + // Another approach is to implement a hooking mechanism that allows us to + override context deadline setup. However, that approach exposes the internal + implementation of the limiter. It also adds unnecessary logic. + Configuring the timeout is more straight-forward and close to the expected + behavior. + MaxQueueWait: duration.Duration(100 * time.Millisecond), + }, }, }, } @@ -221,13 +238,17 @@ func TestUnaryLimitHandler_queueing(t *testing.T) { // that has no wait limit. We of course expect that the actual // config should not have any maximum queueing time. { - RPC: "dummy", - MaxPerRepo: 1, - MaxQueueWait: duration.Duration(1 * time.Nanosecond), + RPC: "dummy", + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 1, + MaxQueueWait: duration.Duration(1 * time.Nanosecond), + }, }, { - RPC: "/grpc.testing.TestService/UnaryCall", - MaxPerRepo: 1, + RPC: "/grpc.testing.TestService/UnaryCall", + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 1, + }, }, }, } @@ -487,9 +508,11 @@ func TestStreamLimitHandler(t *testing.T) { cfg := config.Cfg{ Concurrency: []config.Concurrency{ { - RPC: tc.fullname, - MaxPerRepo: tc.maxConcurrency, - MaxQueueSize: maxQueueSize, + RPC: tc.fullname, + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: tc.maxConcurrency, + MaxQueueSize: maxQueueSize, + }, }, }, } @@ -540,7 +563,13 @@ func TestStreamLimitHandler_error(t *testing.T) { cfg := config.Cfg{ Concurrency: []config.Concurrency{ - {RPC: "/grpc.testing.TestService/FullDuplexCall", MaxPerRepo: 1, MaxQueueSize: 1}, + { + RPC: "/grpc.testing.TestService/FullDuplexCall", + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 1, + MaxQueueSize: 1, + }, + }, }, } @@ -660,7 +689,13 @@ func 
TestConcurrencyLimitHandlerMetrics(t *testing.T) { methodName := "/grpc.testing.TestService/UnaryCall" cfg := config.Cfg{ Concurrency: []config.Concurrency{ - {RPC: methodName, MaxPerRepo: 1, MaxQueueSize: 1}, + { + RPC: methodName, + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 1, + MaxQueueSize: 1, + }, + }, }, } @@ -737,6 +772,237 @@ func TestConcurrencyLimitHandlerMetrics(t *testing.T) { <-respCh } +func TestAuthenticatedVsUnauthenticatedLimiting(t *testing.T) { + t.Parallel() + + t.Run("unary: authenticated and unauthenticated requests use separate limiters", func(t *testing.T) { + t.Parallel() + + s := &queueTestServer{ + server: server{ + blockCh: make(chan struct{}), + }, + reqArrivedCh: make(chan struct{}), + } + + cfg := config.Cfg{ + Concurrency: []config.Concurrency{ + { + RPC: "/grpc.testing.TestService/UnaryCall", + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 2, // Authenticated: 2 concurrent + MaxQueueSize: 10, + }, + Unauthenticated: config.ConcurrencyLimits{ + MaxPerRepo: 1, // Unauthenticated: 1 concurrent + MaxQueueSize: 10, + }, + }, + }, + } + + _, setupPerRPCConcurrencyLimiters := limithandler.WithConcurrencyLimiters(cfg) + lh := limithandler.New(cfg, fixedLockKey, setupPerRPCConcurrencyLimiters) + srv, serverSocketPath := runServerWithAuth(t, s, lh.UnaryInterceptor(), nil) + defer srv.Stop() + + client, conn := newClient(t, serverSocketPath) + defer conn.Close() + + authClient, authConn := newAuthenticatedClient(t, serverSocketPath, "test-secret") + defer authConn.Close() + + ctx := testhelper.Context(t) + + // First, send 2 authenticated requests - both should be accepted (limit is 2) + var wg sync.WaitGroup + wg.Add(2) + for i := 0; i < 2; i++ { + go func() { + defer wg.Done() + _, err := authClient.UnaryCall(ctx, &grpc_testing.SimpleRequest{}) + require.NoError(t, err) + }() + } + + // Wait for both authenticated requests to arrive + <-s.reqArrivedCh + <-s.reqArrivedCh + + // Now send an unauthenticated request 
- it should also be accepted + // because it uses a separate limiter + wg.Add(1) + go func() { + defer wg.Done() + _, err := client.UnaryCall(ctx, &grpc_testing.SimpleRequest{}) + require.NoError(t, err) + }() + + // Wait for the unauthenticated request to arrive + <-s.reqArrivedCh + + // Verify no more requests can get through (both limiters saturated) + select { + case <-s.reqArrivedCh: + require.FailNow(t, "received unexpected fourth request") + case <-time.After(100 * time.Millisecond): + } + + // Unblock all requests + close(s.blockCh) + wg.Wait() + }) + + t.Run("unary: unauthenticated falls back to authenticated limiter when not configured", func(t *testing.T) { + t.Parallel() + + s := &queueTestServer{ + server: server{ + blockCh: make(chan struct{}), + }, + reqArrivedCh: make(chan struct{}), + } + + cfg := config.Cfg{ + Concurrency: []config.Concurrency{ + { + RPC: "/grpc.testing.TestService/UnaryCall", + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 2, // Only authenticated limiter configured + MaxQueueSize: 10, + }, + // No unauthenticated limiter configured + }, + }, + } + + _, setupPerRPCConcurrencyLimiters := limithandler.WithConcurrencyLimiters(cfg) + lh := limithandler.New(cfg, fixedLockKey, setupPerRPCConcurrencyLimiters) + srv, serverSocketPath := runServerWithAuth(t, s, lh.UnaryInterceptor(), nil) + defer srv.Stop() + + client, conn := newClient(t, serverSocketPath) + defer conn.Close() + + authClient, authConn := newAuthenticatedClient(t, serverSocketPath, "test-secret") + defer authConn.Close() + + ctx := testhelper.Context(t) + + var wg sync.WaitGroup + + // Send 1 authenticated and 1 unauthenticated request + // Both should be accepted (they share the same limiter with limit 2) + wg.Add(2) + go func() { + defer wg.Done() + _, err := authClient.UnaryCall(ctx, &grpc_testing.SimpleRequest{}) + require.NoError(t, err) + }() + go func() { + defer wg.Done() + _, err := client.UnaryCall(ctx, &grpc_testing.SimpleRequest{}) + 
require.NoError(t, err) + }() + + // Wait for both requests to arrive + <-s.reqArrivedCh + <-s.reqArrivedCh + + // Verify no more requests can get through (shared limiter saturated) + select { + case <-s.reqArrivedCh: + require.FailNow(t, "received unexpected third request") + case <-time.After(100 * time.Millisecond): + } + + // Unblock all requests + close(s.blockCh) + wg.Wait() + }) + + t.Run("stream: authenticated and unauthenticated requests use separate limiters", func(t *testing.T) { + t.Parallel() + + s := &queueTestServer{ + server: server{ + blockCh: make(chan struct{}), + }, + reqArrivedCh: make(chan struct{}), + } + + cfg := config.Cfg{ + Concurrency: []config.Concurrency{ + { + RPC: "/grpc.testing.TestService/FullDuplexCall", + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 2, // Authenticated: 2 concurrent + MaxQueueSize: 10, + }, + Unauthenticated: config.ConcurrencyLimits{ + MaxPerRepo: 1, // Unauthenticated: 1 concurrent + MaxQueueSize: 10, + }, + }, + }, + } + + _, setupPerRPCConcurrencyLimiters := limithandler.WithConcurrencyLimiters(cfg) + lh := limithandler.New(cfg, fixedLockKey, setupPerRPCConcurrencyLimiters) + srv, serverSocketPath := runServerWithAuth(t, s, nil, lh.StreamInterceptor()) + defer srv.Stop() + + client, conn := newClient(t, serverSocketPath) + defer conn.Close() + + authClient, authConn := newAuthenticatedClient(t, serverSocketPath, "test-secret") + defer authConn.Close() + + ctx := testhelper.Context(t) + + respChan := make(chan *grpc_testing.StreamingOutputCallResponse) + + // Send 2 authenticated streams + for i := 0; i < 2; i++ { + go func() { + stream, err := authClient.FullDuplexCall(ctx) + require.NoError(t, err) + require.NoError(t, stream.Send(&grpc_testing.StreamingOutputCallRequest{})) + require.NoError(t, stream.CloseSend()) + resp, err := stream.Recv() + require.NoError(t, err) + respChan <- resp + }() + } + + // Wait for both authenticated streams to arrive + <-s.reqArrivedCh + <-s.reqArrivedCh + + // 
Send 1 unauthenticated stream - should be accepted with separate limiter + go func() { + stream, err := client.FullDuplexCall(ctx) + require.NoError(t, err) + require.NoError(t, stream.Send(&grpc_testing.StreamingOutputCallRequest{})) + require.NoError(t, stream.CloseSend()) + resp, err := stream.Recv() + require.NoError(t, err) + respChan <- resp + }() + + // Wait for the unauthenticated stream to arrive + <-s.reqArrivedCh + + // Unblock all streams + close(s.blockCh) + + // Collect all responses + for i := 0; i < 3; i++ { + <-respChan + } + }) +} + func runServer(tb testing.TB, s grpc_testing.TestServiceServer, opt ...grpc.ServerOption) (*grpc.Server, string) { serverSocketPath := testhelper.GetTemporaryGitalySocketFileName(tb) grpcServer := grpc.NewServer(opt...) @@ -750,6 +1016,47 @@ func runServer(tb testing.TB, s grpc_testing.TestServiceServer, opt ...grpc.Serv return grpcServer, "unix://" + serverSocketPath } +func runServerWithAuth(tb testing.TB, s grpc_testing.TestServiceServer, unaryInt grpc.UnaryServerInterceptor, streamInt grpc.StreamServerInterceptor) (*grpc.Server, string) { + serverSocketPath := testhelper.GetTemporaryGitalySocketFileName(tb) + + var unaryInterceptors []grpc.UnaryServerInterceptor + var streamInterceptors []grpc.StreamServerInterceptor + + // Add requestinfohandler first to extract authentication info + unaryInterceptors = append(unaryInterceptors, requestinfohandler.UnaryInterceptor) + streamInterceptors = append(streamInterceptors, requestinfohandler.StreamInterceptor) + + // Add auth interceptor to validate tokens and set authenticated flag + // Use transitioning mode so invalid tokens don't block requests (for testing unauthenticated flow) + authCfg := gitalycfgauth.Config{ + Token: "test-secret", + Transitioning: true, + } + unaryInterceptors = append(unaryInterceptors, auth.UnaryServerInterceptor(authCfg)) + streamInterceptors = append(streamInterceptors, auth.StreamServerInterceptor(authCfg)) + + // Then add the limiter 
interceptor + if unaryInt != nil { + unaryInterceptors = append(unaryInterceptors, unaryInt) + } + if streamInt != nil { + streamInterceptors = append(streamInterceptors, streamInt) + } + + grpcServer := grpc.NewServer( + grpc.ChainUnaryInterceptor(unaryInterceptors...), + grpc.ChainStreamInterceptor(streamInterceptors...), + ) + grpc_testing.RegisterTestServiceServer(grpcServer, s) + + lis, err := net.Listen("unix", serverSocketPath) + require.NoError(tb, err) + + go testhelper.MustServe(tb, grpcServer, lis) + + return grpcServer, "unix://" + serverSocketPath +} + func newClient(tb testing.TB, serverSocketPath string) (grpc_testing.TestServiceClient, *grpc.ClientConn) { conn, err := client.New(testhelper.Context(tb), serverSocketPath) if err != nil { @@ -758,3 +1065,18 @@ func newClient(tb testing.TB, serverSocketPath string) (grpc_testing.TestService return grpc_testing.NewTestServiceClient(conn), conn } + +func newAuthenticatedClient(tb testing.TB, serverSocketPath, secret string) (grpc_testing.TestServiceClient, *grpc.ClientConn) { + conn, err := client.New( + testhelper.Context(tb), + serverSocketPath, + client.WithGrpcOptions([]grpc.DialOption{ + grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(secret)), + }), + ) + if err != nil { + tb.Fatal(err) + } + + return grpc_testing.NewTestServiceClient(conn), conn +} diff --git a/internal/testhelper/testhelper.go b/internal/testhelper/testhelper.go index 976ba9de2ae39b0d13a21558f5a43a5f92403b80..508884a16b3d473b6c64902f2464debb72915554 100644 --- a/internal/testhelper/testhelper.go +++ b/internal/testhelper/testhelper.go @@ -337,6 +337,9 @@ func ContextWithoutCancel(opts ...ContextOpt) context.Context { // Enable trace2 logs for receive pack ctx = featureflag.ContextWithFeatureFlag(ctx, featureflag.ReceivePackTrace2Hook, true) + // Enable unauthenticated limiter + ctx = featureflag.ContextWithFeatureFlag(ctx, featureflag.LimitUnauthenticated, true) + for _, opt := range opts { ctx = opt(ctx) }