diff --git a/internal/source/gitlab/cache/retriever.go b/internal/source/gitlab/cache/retriever.go
index de37c231afdb8a28a8b639da371dc5eab7d8ce35..50140cfc40f4d71ffc3b7db9ca67bfd9ae0f0dc3 100644
--- a/internal/source/gitlab/cache/retriever.go
+++ b/internal/source/gitlab/cache/retriever.go
@@ -3,6 +3,8 @@ package cache
 import (
 	"context"
 	"errors"
+	"sync"
+	"time"
 
 	log "github.com/sirupsen/logrus"
 
@@ -13,15 +15,27 @@ import (
 // Retriever is an utility type that performs an HTTP request with backoff in
 // case of errors
 type Retriever struct {
+	timer                timer
 	client               api.Client
 	retrievalTimeout     time.Duration
 	maxRetrievalInterval time.Duration
 	maxRetrievalRetries  int
 }
 
+// timer records whether the last backoff wait interrupted by the retrieval
+// context was stopped before it fired. It holds no *time.Timer itself: each
+// backoff creates its own, since lookups on one Retriever run concurrently.
+type timer struct {
+	mu      *sync.Mutex
+	stopped bool
+}
+
 // NewRetriever creates a Retriever with a client
 func NewRetriever(client api.Client, retrievalTimeout, maxRetrievalInterval time.Duration, maxRetrievalRetries int) *Retriever {
 	return &Retriever{
+		timer: timer{
+			mu: &sync.Mutex{},
+		},
 		client:               client,
 		retrievalTimeout:     retrievalTimeout,
 		maxRetrievalInterval: maxRetrievalInterval,
@@ -52,11 +66,24 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha
 	go func() {
 		var lookup api.Lookup
 
+	Retry:
 		for i := 1; i <= r.maxRetrievalRetries; i++ {
 			lookup = r.client.GetLookup(ctx, domain)
 
 			if lookup.Error != nil {
-				time.Sleep(r.maxRetrievalInterval)
+				// Each attempt owns its backoff timer: every call runs in its own
+				// goroutine on the same Retriever, so a shared timer would be raced on.
+				backoff := time.NewTimer(r.maxRetrievalInterval)
+				select {
+				case <-backoff.C:
+					// retry GetLookup
+					continue Retry
+				case <-ctx.Done():
+					log.WithError(ctx.Err()).Debug("domain retrieval backoff canceled by context")
+					// stop the pending timer instead of leaving it armed until it fires
+					r.timer.stop(backoff)
+					break Retry
+				}
 			} else {
 				break
 			}
@@ -68,3 +95,20 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domain string) <-cha
 
 	return response
 }
+
+// stop stops backoff and records whether that call prevented it from firing.
+func (t *timer) stop(backoff *time.Timer) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	t.stopped = backoff.Stop()
+}
+
+// hasStopped reports whether the last backoff interrupted by the retrieval
+// context was stopped before its timer fired.
+func (t *timer) hasStopped() bool {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	return t.stopped
+}
diff --git a/internal/source/gitlab/cache/retriever_test.go b/internal/source/gitlab/cache/retriever_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..774e9779e80215e0b06348b978e145ed1a2fea5b
--- /dev/null
+++ b/internal/source/gitlab/cache/retriever_test.go
@@ -0,0 +1,27 @@
+package cache
+
+import (
+	"errors"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestRetrieveTimerStopsWhenContextIsDone(t *testing.T) {
+	retrievalTimeout := time.Millisecond // quick timeout that cancels inner context
+	maxRetrievalInterval := time.Minute  // long backoff inside resolveWithBackoff
+
+	resolver := &client{
+		domain:  make(chan string),
+		lookups: make(chan uint64, 1),
+		failure: errors.New("500 error"),
+	}
+
+	retriever := NewRetriever(resolver, retrievalTimeout, maxRetrievalInterval, 3)
+	require.False(t, retriever.timer.hasStopped(), "timer has not been stopped yet")
+
+	lookup := retriever.Retrieve("my.gitlab.com")
+	require.Empty(t, lookup.Name)
+	require.Eventually(t, retriever.timer.hasStopped, time.Second, time.Millisecond, "timer must have been stopped")
+}