From db9ca80b69978e33154dc0c22b5631fe62ac8348 Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Wed, 16 Nov 2022 11:16:07 -0500 Subject: [PATCH] Fix race condition making it possible for batches to be >batchSize --- util/batching_queue.go | 13 ++++++++----- util/batching_queue_test.go | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/util/batching_queue.go b/util/batching_queue.go index 86901bcd..85ba9be9 100644 --- a/util/batching_queue.go +++ b/util/batching_queue.go @@ -48,10 +48,13 @@ func NewBatchingQueue[T any](batchSize int, timeout time.Duration) *BatchingQueu func (q *BatchingQueue[T]) Enqueue(element T) { q.mu.Lock() q.in = append(q.in, element) - limitReached := len(q.in) == q.batchSize + var elements []T + if len(q.in) == q.batchSize { + elements = q.dequeueAll() + } q.mu.Unlock() - if limitReached { - q.out <- q.dequeueAll() + if len(elements) > 0 { + q.out <- elements } } @@ -61,8 +64,6 @@ func (q *BatchingQueue[T]) Dequeue() <-chan []T { } func (q *BatchingQueue[T]) dequeueAll() []T { - q.mu.Lock() - defer q.mu.Unlock() elements := make([]T, len(q.in)) copy(elements, q.in) q.in = q.in[:0] @@ -75,7 +76,9 @@ func (q *BatchingQueue[T]) timeoutTicker() { } ticker := time.NewTicker(q.timeout) for range ticker.C { + q.mu.Lock() elements := q.dequeueAll() + q.mu.Unlock() if len(elements) > 0 { q.out <- elements } diff --git a/util/batching_queue_test.go b/util/batching_queue_test.go index f7dccfab..b3c41a4c 100644 --- a/util/batching_queue_test.go +++ b/util/batching_queue_test.go @@ -24,7 +24,7 @@ func TestBatchingQueue_InfTimeout(t *testing.T) { for i := 0; i < 101; i++ { go q.Enqueue(i) } - time.Sleep(500 * time.Millisecond) + time.Sleep(time.Second) mu.Lock() require.Equal(t, 100, total) // One is missing, stuck in the last batch! require.Equal(t, 4, len(batches))