diff --git a/consumer.go b/consumer.go index e28fe67..7243beb 100644 --- a/consumer.go +++ b/consumer.go @@ -118,18 +118,18 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { wg := new(sync.WaitGroup) // process each of the shards - shardsInProcess := make(map[string]struct{}) + s := newShardsInProcess() for shard := range shardC { shardId := aws.ToString(shard.ShardId) - if _, ok := shardsInProcess[shardId]; ok { + if s.doesShardExist(shardId) { // safetynet: if shard already in process by another goroutine, just skipping the request continue } wg.Add(1) go func(shardID string) { - shardsInProcess[shardID] = struct{}{} + s.addShard(shardID) defer func() { - delete(shardsInProcess, shardID) + s.deleteShard(shardID) }() defer wg.Done() var err error @@ -311,3 +311,34 @@ func isRetriableError(err error) bool { func isShardClosed(nextShardIterator, currentShardIterator *string) bool { return nextShardIterator == nil || currentShardIterator == nextShardIterator } + +type shards struct { + *sync.RWMutex + shardsInProcess map[string]struct{} +} + +func newShardsInProcess() *shards { + return &shards{ + RWMutex: &sync.RWMutex{}, + shardsInProcess: make(map[string]struct{}), + } +} + +func (s *shards) addShard(shardId string) { + s.Lock() + defer s.Unlock() + s.shardsInProcess[shardId] = struct{}{} +} + +func (s *shards) doesShardExist(shardId string) bool { + s.RLock() + defer s.RUnlock() + _, ok := s.shardsInProcess[shardId] + return ok +} + +func (s *shards) deleteShard(shardId string) { + s.Lock() + defer s.Unlock() + delete(s.shardsInProcess, shardId) +}