diff --git a/consumer.go b/consumer.go index e28fe67..4777356 100644 --- a/consumer.go +++ b/consumer.go @@ -299,10 +299,10 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se } func isRetriableError(err error) bool { - if oe := (*types.ExpiredIteratorException)(nil); errors.As(err, &oe) { + switch err.(type) { + case *types.ExpiredIteratorException: return true - } - if oe := (*types.ProvisionedThroughputExceededException)(nil); errors.As(err, &oe) { + case *types.ProvisionedThroughputExceededException: return true } return false