diff --git a/consumer.go b/consumer.go index 05ae65c..5e2c93c 100644 --- a/consumer.go +++ b/consumer.go @@ -276,10 +276,12 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se } func isRetriableError(err error) bool { - switch err.(type) { - case *types.ExpiredIteratorException: + var expiredIteratorException *types.ExpiredIteratorException + var provisionedThroughputExceededException *types.ProvisionedThroughputExceededException + switch { + case errors.As(err, &expiredIteratorException): return true - case *types.ProvisionedThroughputExceededException: + case errors.As(err, &provisionedThroughputExceededException): return true } return false