diff --git a/consumer.go b/consumer.go index 4777356..e28fe67 100644 --- a/consumer.go +++ b/consumer.go @@ -299,10 +299,10 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se } func isRetriableError(err error) bool { - switch err.(type) { - case *types.ExpiredIteratorException: + if oe := (*types.ExpiredIteratorException)(nil); errors.As(err, &oe) { return true - case *types.ProvisionedThroughputExceededException: + } + if oe := (*types.ProvisionedThroughputExceededException)(nil); errors.As(err, &oe) { return true } return false