#11 checks in vendor files

This commit is contained in:
Alex Senger 2024-04-10 16:50:33 +02:00
parent b71c1964bc
commit 262dce8ba6
No known key found for this signature in database
GPG key ID: 0B4A96F8AF6934CF
969 changed files with 257038 additions and 3 deletions

3
.gitignore vendored
View file

@ -30,9 +30,6 @@ _testmain.go
.DS_Store
tags*
# Vendored files
vendor/**
# Benchmark files
prof.cpu
prof.mem

27
vendor/filippo.io/edwards25519/LICENSE generated vendored Normal file
View file

@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

14
vendor/filippo.io/edwards25519/README.md generated vendored Normal file
View file

@ -0,0 +1,14 @@
# filippo.io/edwards25519
```
import "filippo.io/edwards25519"
```
This library implements the edwards25519 elliptic curve, exposing the necessary APIs to build a wide array of higher-level primitives.
Read the docs at [pkg.go.dev/filippo.io/edwards25519](https://pkg.go.dev/filippo.io/edwards25519).
The code is originally derived from Adam Langley's internal implementation in the Go standard library, and includes George Tankersley's [performance improvements](https://golang.org/cl/71950). It was then further developed by Henry de Valence for use in ristretto255, and was finally [merged back into the Go standard library](https://golang.org/cl/276272) as of Go 1.17. It now tracks the upstream codebase and extends it with additional functionality.
Most users don't need this package, and should instead use `crypto/ed25519` for signatures, `golang.org/x/crypto/curve25519` for Diffie-Hellman, or `github.com/gtank/ristretto255` for prime order group logic. However, for anyone currently using a fork of `crypto/internal/edwards25519`/`crypto/ed25519/internal/edwards25519` or `github.com/agl/edwards25519`, this package should be a safer, faster, and more powerful alternative.
Since this package is meant to curb proliferation of edwards25519 implementations in the Go ecosystem, it welcomes requests for new APIs or reviewable performance improvements.

20
vendor/filippo.io/edwards25519/doc.go generated vendored Normal file
View file

@ -0,0 +1,20 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package edwards25519 implements group logic for the twisted Edwards curve
//
// -x^2 + y^2 = 1 + -(121665/121666)*x^2*y^2
//
// This is better known as the Edwards curve equivalent to Curve25519, and is
// the curve used by the Ed25519 signature scheme.
//
// Most users don't need this package, and should instead use crypto/ed25519 for
// signatures, golang.org/x/crypto/curve25519 for Diffie-Hellman, or
// github.com/gtank/ristretto255 for prime order group logic.
//
// However, developers who do need to interact with low-level edwards25519
// operations can use this package, which is an extended version of
// crypto/internal/edwards25519 from the standard library repackaged as
// an importable module.
package edwards25519

427
vendor/filippo.io/edwards25519/edwards25519.go generated vendored Normal file
View file

@ -0,0 +1,427 @@
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"errors"
"filippo.io/edwards25519/field"
)
// Point types.
//
// The package uses several internal coordinate systems from
// https://eprint.iacr.org/2008/522 to speed up the group law.

// projP1xP1 is a point in "completed" P1xP1 coordinates, the output format of
// addition and doubling: x = X/Z and y = Y/T (see fromP1xP1).
type projP1xP1 struct {
	X, Y, Z, T field.Element
}

// projP2 is a point in projective coordinates: x = X/Z, y = Y/Z.
type projP2 struct {
	X, Y, Z field.Element
}

// Point represents a point on the edwards25519 curve.
//
// This type works similarly to math/big.Int, and all arguments and receivers
// are allowed to alias.
//
// The zero value is NOT valid, and it may be used only as a receiver.
type Point struct {
	// Make the type not comparable (i.e. used with == or as a map key), as
	// equivalent points can be represented by different Go values.
	_ incomparable

	// The point is internally represented in extended coordinates (X, Y, Z, T)
	// where x = X/Z, y = Y/Z, and xy = T/Z per https://eprint.iacr.org/2008/522.
	x, y, z, t field.Element
}

// incomparable is a zero-size field that makes any struct embedding it
// non-comparable (func values don't support ==).
type incomparable [0]func()
// checkInitialized panics if any of points is the invalid zero value.
//
// Only x and y are checked: no valid curve point has both coordinates equal
// to the zero field element, so this cannot misfire on an initialized Point.
func checkInitialized(points ...*Point) {
	for _, p := range points {
		if p.x == (field.Element{}) && p.y == (field.Element{}) {
			panic("edwards25519: use of uninitialized Point")
		}
	}
}
// projCached is a point in the precomputed form used by the projective
// addition formulas: Y+X, Y-X, Z, and 2d*T (see projCached.FromP3).
type projCached struct {
	YplusX, YminusX, Z, T2d field.Element
}

// affineCached is like projCached but with Z normalized to 1 and therefore
// omitted (see affineCached.FromP3).
type affineCached struct {
	YplusX, YminusX, T2d field.Element
}
// Constructors.

// Zero sets v to the identity element (X:Y:Z) = (0:1:1), i.e. the affine
// point (0, 1), and returns v.
func (v *projP2) Zero() *projP2 {
	v.X.Zero()
	v.Y.One()
	v.Z.One()
	return v
}
// identity is the point at infinity. Its canonical encoding is y = 1 with the
// sign bit clear, which decodes to the affine point (0, 1).
var identity, _ = new(Point).SetBytes([]byte{
	1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})

// NewIdentityPoint returns a new Point set to the identity.
func NewIdentityPoint() *Point {
	return new(Point).Set(identity)
}
// generator is the canonical curve basepoint. See TestGenerator for the
// correspondence of this encoding with the values in RFC 8032.
// (The encoding is y = 4/5 in little-endian with the sign bit clear.)
var generator, _ = new(Point).SetBytes([]byte{
	0x58, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
	0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
	0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
	0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66})

// NewGeneratorPoint returns a new Point set to the canonical generator.
func NewGeneratorPoint() *Point {
	return new(Point).Set(generator)
}
// Zero sets v to the cached form of the identity point (Y+X = Y-X = Z = 1,
// 2d*T = 0), and returns v.
func (v *projCached) Zero() *projCached {
	v.YplusX.One()
	v.YminusX.One()
	v.Z.One()
	v.T2d.Zero()
	return v
}
// Zero sets v to the affine cached form of the identity point (Y+X = Y-X = 1,
// 2d*T = 0; Z is implicitly 1), and returns v.
func (v *affineCached) Zero() *affineCached {
	v.YplusX.One()
	v.YminusX.One()
	v.T2d.Zero()
	return v
}
// Assignments.

// Set sets v = u, and returns v.
//
// As with all Point methods, v and u may alias.
func (v *Point) Set(u *Point) *Point {
	v.x.Set(&u.x)
	v.y.Set(&u.y)
	v.z.Set(&u.z)
	v.t.Set(&u.t)
	return v
}
// Encoding.

// Bytes returns the canonical 32-byte encoding of v, according to RFC 8032,
// Section 5.1.2.
func (v *Point) Bytes() []byte {
	// This function is outlined to make the allocations inline in the caller
	// rather than happen on the heap.
	var buf [32]byte
	return v.bytes(&buf)
}
// bytes encodes v into buf and returns buf as a slice: the little-endian
// affine y-coordinate with the sign of x stored in the top bit of the last
// byte, per RFC 8032.
func (v *Point) bytes(buf *[32]byte) []byte {
	checkInitialized(v)

	var zInv, x, y field.Element
	zInv.Invert(&v.z)       // zInv = 1 / Z
	x.Multiply(&v.x, &zInv) // x = X / Z
	y.Multiply(&v.y, &zInv) // y = Y / Z

	out := copyFieldElement(buf, &y)
	out[31] |= byte(x.IsNegative() << 7) // sign bit of x in the MSB
	return out
}
// feOne is the field element 1, shared by decoding and Montgomery conversion.
var feOne = new(field.Element).One()

// SetBytes sets v = x, where x is a 32-byte encoding of v. If x does not
// represent a valid point on the curve, SetBytes returns nil and an error and
// the receiver is unchanged. Otherwise, SetBytes returns v.
//
// Note that SetBytes accepts all non-canonical encodings of valid points.
// That is, it follows decoding rules that match most implementations in
// the ecosystem rather than RFC 8032.
func (v *Point) SetBytes(x []byte) (*Point, error) {
	// Specifically, the non-canonical encodings that are accepted are
	//   1) the ones where the field element is not reduced (see the
	//      (*field.Element).SetBytes docs) and
	//   2) the ones where the x-coordinate is zero and the sign bit is set.
	//
	// Read more at https://hdevalence.ca/blog/2020-10-04-its-25519am,
	// specifically the "Canonical A, R" section.

	y, err := new(field.Element).SetBytes(x)
	if err != nil {
		return nil, errors.New("edwards25519: invalid point encoding length")
	}

	// Recover x from y via the curve equation:
	// -x² + y² = 1 + dx²y²
	// x² + dx²y² = x²(dy² + 1) = y² - 1
	// x² = (y² - 1) / (dy² + 1)

	// u = y² - 1
	y2 := new(field.Element).Square(y)
	u := new(field.Element).Subtract(y2, feOne)

	// v = dy² + 1
	vv := new(field.Element).Multiply(y2, d)
	vv = vv.Add(vv, feOne)

	// x = +√(u/v)
	xx, wasSquare := new(field.Element).SqrtRatio(u, vv)
	if wasSquare == 0 {
		// u/v is not a square: the encoding is not a point on the curve.
		return nil, errors.New("edwards25519: invalid point encoding")
	}

	// Select the negative square root if the sign bit is set.
	xxNeg := new(field.Element).Negate(xx)
	xx = xx.Select(xxNeg, xx, int(x[31]>>7))

	v.x.Set(xx)
	v.y.Set(y)
	v.z.One()
	v.t.Multiply(xx, y) // xy = T / Z
	return v, nil
}
// copyFieldElement serializes v into the caller-supplied buffer and returns
// it as a slice, letting the caller keep the backing array on its stack.
func copyFieldElement(buf *[32]byte, v *field.Element) []byte {
	out := buf[:]
	copy(out, v.Bytes())
	return out
}
// Conversions.

// FromP1xP1 converts p from completed P1xP1 coordinates to projective P2
// coordinates, and returns v. (x = X/Z and y = Y/T map to (XT : YZ : ZT).)
func (v *projP2) FromP1xP1(p *projP1xP1) *projP2 {
	v.X.Multiply(&p.X, &p.T)
	v.Y.Multiply(&p.Y, &p.Z)
	v.Z.Multiply(&p.Z, &p.T)
	return v
}
// FromP3 converts p from extended P3 coordinates to projective P2 coordinates
// by dropping T, and returns v.
func (v *projP2) FromP3(p *Point) *projP2 {
	v.X.Set(&p.x)
	v.Y.Set(&p.y)
	v.Z.Set(&p.z)
	return v
}
// fromP1xP1 converts p from completed P1xP1 coordinates to extended P3
// coordinates, and returns v: (X/Z, Y/T) becomes (XT : YZ : ZT : XY).
func (v *Point) fromP1xP1(p *projP1xP1) *Point {
	v.x.Multiply(&p.X, &p.T)
	v.y.Multiply(&p.Y, &p.Z)
	v.z.Multiply(&p.Z, &p.T)
	v.t.Multiply(&p.X, &p.Y)
	return v
}
// fromP2 converts p from projective P2 coordinates to extended P3 coordinates,
// and returns v: (X : Y : Z) becomes (XZ : YZ : Z² : XY), so that
// t = xy holds without an inversion.
func (v *Point) fromP2(p *projP2) *Point {
	v.x.Multiply(&p.X, &p.Z)
	v.y.Multiply(&p.Y, &p.Z)
	v.z.Square(&p.Z)
	v.t.Multiply(&p.X, &p.Y)
	return v
}
// d is a constant in the curve equation -x² + y² = 1 + d*x²*y²,
// d = -121665/121666 mod p, stored little-endian.
var d, _ = new(field.Element).SetBytes([]byte{
	0xa3, 0x78, 0x59, 0x13, 0xca, 0x4d, 0xeb, 0x75,
	0xab, 0xd8, 0x41, 0x41, 0x4d, 0x0a, 0x70, 0x00,
	0x98, 0xe8, 0x79, 0x77, 0x79, 0x40, 0xc7, 0x8c,
	0x73, 0xfe, 0x6f, 0x2b, 0xee, 0x6c, 0x03, 0x52})

// d2 = 2*d, precomputed for the cached addition formulas.
var d2 = new(field.Element).Add(d, d)
// FromP3 precomputes the cached form (Y+X, Y-X, Z, 2d*T) of p for use in the
// re-addition formulas, and returns v.
func (v *projCached) FromP3(p *Point) *projCached {
	v.YplusX.Add(&p.y, &p.x)
	v.YminusX.Subtract(&p.y, &p.x)
	v.Z.Set(&p.z)
	v.T2d.Multiply(&p.t, d2)
	return v
}
// FromP3 precomputes the affine cached form of p, and returns v. It divides
// through by Z (one field inversion) so Z can be dropped from the cache.
func (v *affineCached) FromP3(p *Point) *affineCached {
	v.YplusX.Add(&p.y, &p.x)
	v.YminusX.Subtract(&p.y, &p.x)
	v.T2d.Multiply(&p.t, d2)

	var invZ field.Element
	invZ.Invert(&p.z)
	v.YplusX.Multiply(&v.YplusX, &invZ)
	v.YminusX.Multiply(&v.YminusX, &invZ)
	v.T2d.Multiply(&v.T2d, &invZ)
	return v
}
// (Re)addition and subtraction.

// Add sets v = p + q, and returns v.
func (v *Point) Add(p, q *Point) *Point {
	checkInitialized(p, q)
	qCached := new(projCached).FromP3(q)
	result := new(projP1xP1).Add(p, qCached)
	return v.fromP1xP1(result)
}
// Subtract sets v = p - q, and returns v.
func (v *Point) Subtract(p, q *Point) *Point {
	checkInitialized(p, q)
	qCached := new(projCached).FromP3(q)
	result := new(projP1xP1).Sub(p, qCached)
	return v.fromP1xP1(result)
}
// Add sets v = p + q in completed P1xP1 coordinates, using the extended
// twisted Edwards addition formulas with q in cached form.
func (v *projP1xP1) Add(p *Point, q *projCached) *projP1xP1 {
	var YplusX, YminusX, PP, MM, TT2d, ZZ2 field.Element

	YplusX.Add(&p.y, &p.x)
	YminusX.Subtract(&p.y, &p.x)

	PP.Multiply(&YplusX, &q.YplusX)
	MM.Multiply(&YminusX, &q.YminusX)
	TT2d.Multiply(&p.t, &q.T2d)
	ZZ2.Multiply(&p.z, &q.Z)

	ZZ2.Add(&ZZ2, &ZZ2)

	v.X.Subtract(&PP, &MM)
	v.Y.Add(&PP, &MM)
	v.Z.Add(&ZZ2, &TT2d)
	v.T.Subtract(&ZZ2, &TT2d)
	return v
}
// Sub sets v = p - q in completed P1xP1 coordinates. It is Add with q negated,
// which for the cached form amounts to swapping Y+X with Y-X and negating T2d;
// here that negation is folded into the formulas (the "flipped sign" lines).
func (v *projP1xP1) Sub(p *Point, q *projCached) *projP1xP1 {
	var YplusX, YminusX, PP, MM, TT2d, ZZ2 field.Element

	YplusX.Add(&p.y, &p.x)
	YminusX.Subtract(&p.y, &p.x)

	PP.Multiply(&YplusX, &q.YminusX) // flipped sign
	MM.Multiply(&YminusX, &q.YplusX) // flipped sign
	TT2d.Multiply(&p.t, &q.T2d)
	ZZ2.Multiply(&p.z, &q.Z)

	ZZ2.Add(&ZZ2, &ZZ2)

	v.X.Subtract(&PP, &MM)
	v.Y.Add(&PP, &MM)
	v.Z.Subtract(&ZZ2, &TT2d) // flipped sign
	v.T.Add(&ZZ2, &TT2d)      // flipped sign
	return v
}
// AddAffine sets v = p + q in completed P1xP1 coordinates. It is Add
// specialized for an affine cached q (Z = 1), saving one multiplication.
func (v *projP1xP1) AddAffine(p *Point, q *affineCached) *projP1xP1 {
	var YplusX, YminusX, PP, MM, TT2d, Z2 field.Element

	YplusX.Add(&p.y, &p.x)
	YminusX.Subtract(&p.y, &p.x)

	PP.Multiply(&YplusX, &q.YplusX)
	MM.Multiply(&YminusX, &q.YminusX)
	TT2d.Multiply(&p.t, &q.T2d)

	Z2.Add(&p.z, &p.z)

	v.X.Subtract(&PP, &MM)
	v.Y.Add(&PP, &MM)
	v.Z.Add(&Z2, &TT2d)
	v.T.Subtract(&Z2, &TT2d)
	return v
}
// SubAffine sets v = p - q in completed P1xP1 coordinates, the affine (Z = 1)
// counterpart of Sub with the negation of q folded into the formulas.
func (v *projP1xP1) SubAffine(p *Point, q *affineCached) *projP1xP1 {
	var YplusX, YminusX, PP, MM, TT2d, Z2 field.Element

	YplusX.Add(&p.y, &p.x)
	YminusX.Subtract(&p.y, &p.x)

	PP.Multiply(&YplusX, &q.YminusX) // flipped sign
	MM.Multiply(&YminusX, &q.YplusX) // flipped sign
	TT2d.Multiply(&p.t, &q.T2d)

	Z2.Add(&p.z, &p.z)

	v.X.Subtract(&PP, &MM)
	v.Y.Add(&PP, &MM)
	v.Z.Subtract(&Z2, &TT2d) // flipped sign
	v.T.Add(&Z2, &TT2d)      // flipped sign
	return v
}
// Doubling.

// Double sets v = 2*p in completed P1xP1 coordinates, using the standard
// projective doubling formulas for twisted Edwards curves.
func (v *projP1xP1) Double(p *projP2) *projP1xP1 {
	var XX, YY, ZZ2, XplusYsq field.Element

	XX.Square(&p.X)
	YY.Square(&p.Y)
	ZZ2.Square(&p.Z)
	ZZ2.Add(&ZZ2, &ZZ2)
	XplusYsq.Add(&p.X, &p.Y)
	XplusYsq.Square(&XplusYsq)

	v.Y.Add(&YY, &XX)
	v.Z.Subtract(&YY, &XX)

	v.X.Subtract(&XplusYsq, &v.Y) // (X+Y)² - X² - Y² = 2XY
	v.T.Subtract(&ZZ2, &v.Z)
	return v
}
// Negation.

// Negate sets v = -p, and returns v. On this curve negation flips the sign of
// x (and hence of t = xy), leaving y and z unchanged.
func (v *Point) Negate(p *Point) *Point {
	checkInitialized(p)
	v.x.Negate(&p.x)
	v.y.Set(&p.y)
	v.z.Set(&p.z)
	v.t.Negate(&p.t)
	return v
}
// Equal returns 1 if v is equivalent to u, and 0 otherwise.
//
// It cross-multiplies by the Z coordinates so points in different projective
// representations of the same element compare equal, without any inversion.
func (v *Point) Equal(u *Point) int {
	checkInitialized(v, u)

	var t1, t2, t3, t4 field.Element
	t1.Multiply(&v.x, &u.z) // x_v * z_u == x_u * z_v  <=>  x_v/z_v == x_u/z_u
	t2.Multiply(&u.x, &v.z)
	t3.Multiply(&v.y, &u.z)
	t4.Multiply(&u.y, &v.z)

	return t1.Equal(&t2) & t3.Equal(&t4)
}
// Constant-time operations

// Select sets v to a if cond == 1 and to b if cond == 0, in constant time.
func (v *projCached) Select(a, b *projCached, cond int) *projCached {
	v.YplusX.Select(&a.YplusX, &b.YplusX, cond)
	v.YminusX.Select(&a.YminusX, &b.YminusX, cond)
	v.Z.Select(&a.Z, &b.Z, cond)
	v.T2d.Select(&a.T2d, &b.T2d, cond)
	return v
}
// Select sets v to a if cond == 1 and to b if cond == 0, in constant time.
func (v *affineCached) Select(a, b *affineCached, cond int) *affineCached {
	v.YplusX.Select(&a.YplusX, &b.YplusX, cond)
	v.YminusX.Select(&a.YminusX, &b.YminusX, cond)
	v.T2d.Select(&a.T2d, &b.T2d, cond)
	return v
}
// CondNeg negates v if cond == 1 and leaves it unchanged if cond == 0, in
// constant time. Negating a cached point swaps Y+X with Y-X and negates T2d.
func (v *projCached) CondNeg(cond int) *projCached {
	v.YplusX.Swap(&v.YminusX, cond)
	v.T2d.Select(new(field.Element).Negate(&v.T2d), &v.T2d, cond)
	return v
}
// CondNeg negates v if cond == 1 and leaves it unchanged if cond == 0, in
// constant time. Negating a cached point swaps Y+X with Y-X and negates T2d.
func (v *affineCached) CondNeg(cond int) *affineCached {
	v.YplusX.Swap(&v.YminusX, cond)
	v.T2d.Select(new(field.Element).Negate(&v.T2d), &v.T2d, cond)
	return v
}

349
vendor/filippo.io/edwards25519/extra.go generated vendored Normal file
View file

@ -0,0 +1,349 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
// This file contains additional functionality that is not included in the
// upstream crypto/internal/edwards25519 package.
import (
"errors"
"filippo.io/edwards25519/field"
)
// ExtendedCoordinates returns v in extended coordinates (X:Y:Z:T) where
// x = X/Z, y = Y/Z, and xy = T/Z as in https://eprint.iacr.org/2008/522.
func (v *Point) ExtendedCoordinates() (X, Y, Z, T *field.Element) {
	// This function is outlined to make the allocations inline in the caller
	// rather than happen on the heap. Don't change the style without making
	// sure it doesn't increase the inliner cost.
	var e [4]field.Element
	X, Y, Z, T = v.extendedCoordinates(&e)
	return
}
// extendedCoordinates copies v's internal coordinates into the caller-supplied
// array e and returns pointers into it, so the internal state is not aliased.
func (v *Point) extendedCoordinates(e *[4]field.Element) (X, Y, Z, T *field.Element) {
	checkInitialized(v)
	X = e[0].Set(&v.x)
	Y = e[1].Set(&v.y)
	Z = e[2].Set(&v.z)
	T = e[3].Set(&v.t)
	return
}
// SetExtendedCoordinates sets v = (X:Y:Z:T) in extended coordinates where
// x = X/Z, y = Y/Z, and xy = T/Z as in https://eprint.iacr.org/2008/522.
//
// If the coordinates are invalid or don't represent a valid point on the curve,
// SetExtendedCoordinates returns nil and an error and the receiver is
// unchanged. Otherwise, SetExtendedCoordinates returns v.
func (v *Point) SetExtendedCoordinates(X, Y, Z, T *field.Element) (*Point, error) {
	if !isOnCurve(X, Y, Z, T) {
		return nil, errors.New("edwards25519: invalid point coordinates")
	}
	v.x.Set(X)
	v.y.Set(Y)
	v.z.Set(Z)
	v.t.Set(T)
	return v, nil
}
// isOnCurve reports whether (X:Y:Z:T) satisfies both the projective curve
// equation and the extended-coordinates invariant T = XY/Z.
func isOnCurve(X, Y, Z, T *field.Element) bool {
	var lhs, rhs field.Element
	XX := new(field.Element).Square(X)
	YY := new(field.Element).Square(Y)
	ZZ := new(field.Element).Square(Z)
	TT := new(field.Element).Square(T)
	// -x² + y² = 1 + dx²y²
	// -(X/Z)² + (Y/Z)² = 1 + d(T/Z)²
	// -X² + Y² = Z² + dT²
	lhs.Subtract(YY, XX)
	rhs.Multiply(d, TT).Add(&rhs, ZZ)
	if lhs.Equal(&rhs) != 1 {
		return false
	}
	// xy = T/Z
	// XY/Z² = T/Z
	// XY = TZ
	lhs.Multiply(X, Y)
	rhs.Multiply(T, Z)
	return lhs.Equal(&rhs) == 1
}
// BytesMontgomery converts v to a point on the birationally-equivalent
// Curve25519 Montgomery curve, and returns its canonical 32 bytes encoding
// according to RFC 7748.
//
// Note that BytesMontgomery only encodes the u-coordinate, so v and -v encode
// to the same value. If v is the identity point, BytesMontgomery returns 32
// zero bytes, analogously to the X25519 function.
//
// The lack of an inverse operation (such as SetMontgomeryBytes) is deliberate:
// while every valid edwards25519 point has a unique u-coordinate Montgomery
// encoding, X25519 accepts inputs on the quadratic twist, which don't correspond
// to any edwards25519 point, and every other X25519 input corresponds to two
// edwards25519 points.
func (v *Point) BytesMontgomery() []byte {
	// This function is outlined to make the allocations inline in the caller
	// rather than happen on the heap.
	var buf [32]byte
	return v.bytesMontgomery(&buf)
}
// bytesMontgomery writes the Montgomery u-coordinate of v into buf and
// returns buf as a slice.
//
// Note: if v is the identity, y = 1 and 1 - y = 0, whose Invert returns 0,
// which is how the documented all-zero output arises.
func (v *Point) bytesMontgomery(buf *[32]byte) []byte {
	checkInitialized(v)

	// RFC 7748, Section 4.1 provides the bilinear map to calculate the
	// Montgomery u-coordinate
	//
	//     u = (1 + y) / (1 - y)
	//
	// where y = Y / Z.

	var y, recip, u field.Element

	y.Multiply(&v.y, y.Invert(&v.z))        // y = Y / Z
	recip.Invert(recip.Subtract(feOne, &y)) // r = 1/(1 - y)
	u.Multiply(u.Add(feOne, &y), &recip)    // u = (1 + y)*r

	return copyFieldElement(buf, &u)
}
// MultByCofactor sets v = 8 * p, and returns v, computed as three doublings
// (the curve's cofactor is 8).
func (v *Point) MultByCofactor(p *Point) *Point {
	checkInitialized(p)
	result := projP1xP1{}
	pp := (&projP2{}).FromP3(p)
	result.Double(pp)
	pp.FromP1xP1(&result)
	result.Double(pp)
	pp.FromP1xP1(&result)
	result.Double(pp)
	return v.fromP1xP1(&result)
}
// pow2k squares s k times: given k > 0, it sets s = s**(2**k).
func (s *Scalar) pow2k(k int) {
	for i := 0; i < k; i++ {
		s.Multiply(s, s)
	}
}
// Invert sets s to the inverse of a nonzero scalar v, and returns s.
//
// If t is zero, Invert returns zero.
//
// Inversion is computed as t^(l-2) mod l (Fermat's little theorem, where l is
// the group order), using a fixed sliding-window addition chain so the work
// is independent of the value of t.
func (s *Scalar) Invert(t *Scalar) *Scalar {
	// Uses a hardcoded sliding window of width 4.
	var table [8]Scalar
	var tt Scalar
	tt.Multiply(t, t)
	table[0] = *t
	for i := 0; i < 7; i++ {
		table[i+1].Multiply(&table[i], &tt)
	}
	// Now table = [t**1, t**3, t**5, t**7, t**9, t**11, t**13, t**15]
	// so t**k = t[k/2] for odd k

	// To compute the sliding window digits, use the following Sage script:

	// sage: import itertools
	// sage: def sliding_window(w,k):
	// ....:     digits = []
	// ....:     while k > 0:
	// ....:         if k % 2 == 1:
	// ....:             kmod = k % (2**w)
	// ....:             digits.append(kmod)
	// ....:             k = k - kmod
	// ....:         else:
	// ....:             digits.append(0)
	// ....:         k = k // 2
	// ....:     return digits

	// Now we can compute s roughly as follows:

	// sage: s = 1
	// sage: for coeff in reversed(sliding_window(4,l-2)):
	// ....:     s = s*s
	// ....:     if coeff > 0 :
	// ....:         s = s*t**coeff

	// This works on one bit at a time, with many runs of zeros.
	// The digits can be collapsed into [(count, coeff)] as follows:

	// sage: [(len(list(group)),d) for d,group in itertools.groupby(sliding_window(4,l-2))]

	// Entries of the form (k, 0) turn into pow2k(k)
	// Entries of the form (1, coeff) turn into a squaring and then a table lookup.
	// We can fold the squaring into the previous pow2k(k) as pow2k(k+1).

	// Note: indices like table[9/2] are deliberate constant expressions —
	// t**k lives at table[k/2] for odd k (see above).
	*s = table[1/2]
	s.pow2k(127 + 1)
	s.Multiply(s, &table[1/2])
	s.pow2k(4 + 1)
	s.Multiply(s, &table[9/2])
	s.pow2k(3 + 1)
	s.Multiply(s, &table[11/2])
	s.pow2k(3 + 1)
	s.Multiply(s, &table[13/2])
	s.pow2k(3 + 1)
	s.Multiply(s, &table[15/2])
	s.pow2k(4 + 1)
	s.Multiply(s, &table[7/2])
	s.pow2k(4 + 1)
	s.Multiply(s, &table[15/2])
	s.pow2k(3 + 1)
	s.Multiply(s, &table[5/2])
	s.pow2k(3 + 1)
	s.Multiply(s, &table[1/2])
	s.pow2k(4 + 1)
	s.Multiply(s, &table[15/2])
	s.pow2k(4 + 1)
	s.Multiply(s, &table[15/2])
	s.pow2k(4 + 1)
	s.Multiply(s, &table[7/2])
	s.pow2k(3 + 1)
	s.Multiply(s, &table[3/2])
	s.pow2k(4 + 1)
	s.Multiply(s, &table[11/2])
	s.pow2k(5 + 1)
	s.Multiply(s, &table[11/2])
	s.pow2k(9 + 1)
	s.Multiply(s, &table[9/2])
	s.pow2k(3 + 1)
	s.Multiply(s, &table[3/2])
	s.pow2k(4 + 1)
	s.Multiply(s, &table[3/2])
	s.pow2k(4 + 1)
	s.Multiply(s, &table[3/2])
	s.pow2k(4 + 1)
	s.Multiply(s, &table[9/2])
	s.pow2k(3 + 1)
	s.Multiply(s, &table[7/2])
	s.pow2k(3 + 1)
	s.Multiply(s, &table[3/2])
	s.pow2k(3 + 1)
	s.Multiply(s, &table[13/2])
	s.pow2k(3 + 1)
	s.Multiply(s, &table[7/2])
	s.pow2k(4 + 1)
	s.Multiply(s, &table[9/2])
	s.pow2k(3 + 1)
	s.Multiply(s, &table[15/2])
	s.pow2k(4 + 1)
	s.Multiply(s, &table[11/2])

	return s
}
// MultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
//
// Execution time depends only on the lengths of the two slices, which must match.
func (v *Point) MultiScalarMult(scalars []*Scalar, points []*Point) *Point {
	if len(scalars) != len(points) {
		panic("edwards25519: called MultiScalarMult with different size inputs")
	}
	checkInitialized(points...)

	// Proceed as in the single-base case, but share doublings
	// between each point in the multiscalar equation.

	// Build lookup tables for each point
	tables := make([]projLookupTable, len(points))
	for i := range tables {
		tables[i].FromP3(points[i])
	}
	// Compute signed radix-16 digits for each scalar
	digits := make([][64]int8, len(scalars))
	for i := range digits {
		digits[i] = scalars[i].signedRadix16()
	}

	// Unwrap first loop iteration to save computing 16*identity
	multiple := &projCached{}
	tmp1 := &projP1xP1{}
	tmp2 := &projP2{}

	// The accumulator must start at the identity: the unwrapped iteration
	// below adds into v directly, and would otherwise fold in whatever stale
	// or zero state the receiver happened to hold.
	v.Set(NewIdentityPoint())

	// Lookup-and-add the appropriate multiple of each input point
	for j := range tables {
		tables[j].SelectInto(multiple, digits[j][63])
		tmp1.Add(v, multiple) // tmp1 = v + x_(j,63)*Q in P1xP1 coords
		v.fromP1xP1(tmp1)     // update v
	}
	tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
	for i := 62; i >= 0; i-- {
		tmp1.Double(tmp2)    // tmp1 =  2*(prev) in P1xP1 coords
		tmp2.FromP1xP1(tmp1) // tmp2 =  2*(prev) in P2 coords
		tmp1.Double(tmp2)    // tmp1 =  4*(prev) in P1xP1 coords
		tmp2.FromP1xP1(tmp1) // tmp2 =  4*(prev) in P2 coords
		tmp1.Double(tmp2)    // tmp1 =  8*(prev) in P1xP1 coords
		tmp2.FromP1xP1(tmp1) // tmp2 =  8*(prev) in P2 coords
		tmp1.Double(tmp2)    // tmp1 = 16*(prev) in P1xP1 coords
		v.fromP1xP1(tmp1)    // v    = 16*(prev) in P3 coords

		// Lookup-and-add the appropriate multiple of each input point
		for j := range tables {
			tables[j].SelectInto(multiple, digits[j][i])
			tmp1.Add(v, multiple) // tmp1 = v + x_(j,i)*Q in P1xP1 coords
			v.fromP1xP1(tmp1)     // update v
		}
		tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
	}
	return v
}
// VarTimeMultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
//
// Execution time depends on the inputs.
func (v *Point) VarTimeMultiScalarMult(scalars []*Scalar, points []*Point) *Point {
	if len(scalars) != len(points) {
		panic("edwards25519: called VarTimeMultiScalarMult with different size inputs")
	}
	checkInitialized(points...)

	// Generalize double-base NAF computation to arbitrary sizes.
	// Here all the points are dynamic, so we only use the smaller
	// tables.

	// Build lookup tables for each point
	tables := make([]nafLookupTable5, len(points))
	for i := range tables {
		tables[i].FromP3(points[i])
	}
	// Compute a NAF for each scalar
	nafs := make([][256]int8, len(scalars))
	for i := range nafs {
		nafs[i] = scalars[i].nonAdjacentForm(5)
	}

	multiple := &projCached{}
	tmp1 := &projP1xP1{}
	tmp2 := &projP2{}
	tmp2.Zero() // the accumulator starts at the identity

	// Move from high to low bits, doubling the accumulator
	// at each iteration and checking whether there is a nonzero
	// coefficient to look up a multiple of.
	//
	// Skip trying to find the first nonzero coefficient, because
	// searching might be more work than a few extra doublings.
	for i := 255; i >= 0; i-- {
		tmp1.Double(tmp2)

		// Variable-time lookup: the branch on the NAF digit leaks which
		// digits are zero, which is fine for the VarTime contract.
		for j := range nafs {
			if nafs[j][i] > 0 {
				v.fromP1xP1(tmp1)
				tables[j].SelectInto(multiple, nafs[j][i])
				tmp1.Add(v, multiple)
			} else if nafs[j][i] < 0 {
				v.fromP1xP1(tmp1)
				tables[j].SelectInto(multiple, -nafs[j][i])
				tmp1.Sub(v, multiple)
			}
		}

		tmp2.FromP1xP1(tmp1)
	}

	v.fromP2(tmp2)
	return v
}

420
vendor/filippo.io/edwards25519/field/fe.go generated vendored Normal file
View file

@ -0,0 +1,420 @@
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package field implements fast arithmetic modulo 2^255-19.
package field
import (
"crypto/subtle"
"encoding/binary"
"errors"
"math/bits"
)
// Element represents an element of the field GF(2^255-19). Note that this
// is not a cryptographically secure group, and should only be used to interact
// with edwards25519.Point coordinates.
//
// This type works similarly to math/big.Int, and all arguments and receivers
// are allowed to alias.
//
// The zero value is a valid zero element.
type Element struct {
	// An element t represents the integer
	//     t.l0 + t.l1*2^51 + t.l2*2^102 + t.l3*2^153 + t.l4*2^204
	//
	// Between operations, all limbs are expected to be lower than 2^52.
	// (The extra headroom above 2^51 is what lets Add/Subtract defer full
	// reduction to a cheap carry propagation.)
	l0 uint64
	l1 uint64
	l2 uint64
	l3 uint64
	l4 uint64
}

// maskLow51Bits selects the low 51 bits of a limb.
const maskLow51Bits uint64 = (1 << 51) - 1
// feZero is the field element 0.
var feZero = &Element{0, 0, 0, 0, 0}

// Zero sets v = 0, and returns v.
func (v *Element) Zero() *Element {
	*v = *feZero
	return v
}
// feOne is the field element 1.
var feOne = &Element{1, 0, 0, 0, 0}

// One sets v = 1, and returns v.
func (v *Element) One() *Element {
	*v = *feOne
	return v
}
// reduce reduces v modulo 2^255 - 19 and returns it, leaving v in fully
// canonical form (each limb below 2^51 and the value below the modulus).
func (v *Element) reduce() *Element {
	v.carryPropagate()

	// After the light reduction we now have a field element representation
	// v < 2^255 + 2^13 * 19, but need v < 2^255 - 19.

	// If v >= 2^255 - 19, then v + 19 >= 2^255, which would overflow 2^255 - 1,
	// generating a carry. That is, c will be 0 if v < 2^255 - 19, and 1 otherwise.
	c := (v.l0 + 19) >> 51
	c = (v.l1 + c) >> 51
	c = (v.l2 + c) >> 51
	c = (v.l3 + c) >> 51
	c = (v.l4 + c) >> 51

	// If v < 2^255 - 19 and c = 0, this will be a no-op. Otherwise, it's
	// effectively applying the reduction identity to the carry.
	v.l0 += 19 * c

	v.l1 += v.l0 >> 51
	v.l0 = v.l0 & maskLow51Bits
	v.l2 += v.l1 >> 51
	v.l1 = v.l1 & maskLow51Bits
	v.l3 += v.l2 >> 51
	v.l2 = v.l2 & maskLow51Bits
	v.l4 += v.l3 >> 51
	v.l3 = v.l3 & maskLow51Bits
	// no additional carry
	v.l4 = v.l4 & maskLow51Bits

	return v
}
// Add sets v = a + b, and returns v.
func (v *Element) Add(a, b *Element) *Element {
	v.l0 = a.l0 + b.l0
	v.l1 = a.l1 + b.l1
	v.l2 = a.l2 + b.l2
	v.l3 = a.l3 + b.l3
	v.l4 = a.l4 + b.l4
	// Using the generic implementation here is actually faster than the
	// assembly. Probably because the body of this function is so simple that
	// the compiler can figure out better optimizations by inlining the carry
	// propagation.
	return v.carryPropagateGeneric()
}
// Subtract sets v = a - b, and returns v.
func (v *Element) Subtract(a, b *Element) *Element {
	// We first add 2 * p, to guarantee the subtraction won't underflow, and
	// then subtract b (which can be up to 2^255 + 2^13 * 19).
	// The magic constants below are the limbs of 2*p in this representation.
	v.l0 = (a.l0 + 0xFFFFFFFFFFFDA) - b.l0
	v.l1 = (a.l1 + 0xFFFFFFFFFFFFE) - b.l1
	v.l2 = (a.l2 + 0xFFFFFFFFFFFFE) - b.l2
	v.l3 = (a.l3 + 0xFFFFFFFFFFFFE) - b.l3
	v.l4 = (a.l4 + 0xFFFFFFFFFFFFE) - b.l4
	return v.carryPropagate()
}
// Negate sets v = -a, and returns v, computed as 0 - a.
func (v *Element) Negate(a *Element) *Element {
	return v.Subtract(feZero, a)
}
// Invert sets v = 1/z mod p, and returns v.
//
// If z == 0, Invert returns v = 0.
func (v *Element) Invert(z *Element) *Element {
	// Inversion is implemented as exponentiation with exponent p - 2
	// (Fermat's little theorem). It uses the same sequence of 255 squarings
	// and 11 multiplications as [Curve25519].
	var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t Element

	z2.Square(z)             // 2
	t.Square(&z2)            // 4
	t.Square(&t)             // 8
	z9.Multiply(&t, z)       // 9
	z11.Multiply(&z9, &z2)   // 11
	t.Square(&z11)           // 22
	z2_5_0.Multiply(&t, &z9) // 31 = 2^5 - 2^0

	t.Square(&z2_5_0) // 2^6 - 2^1
	for i := 0; i < 4; i++ {
		t.Square(&t) // 2^10 - 2^5
	}
	z2_10_0.Multiply(&t, &z2_5_0) // 2^10 - 2^0

	t.Square(&z2_10_0) // 2^11 - 2^1
	for i := 0; i < 9; i++ {
		t.Square(&t) // 2^20 - 2^10
	}
	z2_20_0.Multiply(&t, &z2_10_0) // 2^20 - 2^0

	t.Square(&z2_20_0) // 2^21 - 2^1
	for i := 0; i < 19; i++ {
		t.Square(&t) // 2^40 - 2^20
	}
	t.Multiply(&t, &z2_20_0) // 2^40 - 2^0

	t.Square(&t) // 2^41 - 2^1
	for i := 0; i < 9; i++ {
		t.Square(&t) // 2^50 - 2^10
	}
	z2_50_0.Multiply(&t, &z2_10_0) // 2^50 - 2^0

	t.Square(&z2_50_0) // 2^51 - 2^1
	for i := 0; i < 49; i++ {
		t.Square(&t) // 2^100 - 2^50
	}
	z2_100_0.Multiply(&t, &z2_50_0) // 2^100 - 2^0

	t.Square(&z2_100_0) // 2^101 - 2^1
	for i := 0; i < 99; i++ {
		t.Square(&t) // 2^200 - 2^100
	}
	t.Multiply(&t, &z2_100_0) // 2^200 - 2^0

	t.Square(&t) // 2^201 - 2^1
	for i := 0; i < 49; i++ {
		t.Square(&t) // 2^250 - 2^50
	}
	t.Multiply(&t, &z2_50_0) // 2^250 - 2^0

	t.Square(&t) // 2^251 - 2^1
	t.Square(&t) // 2^252 - 2^2
	t.Square(&t) // 2^253 - 2^3
	t.Square(&t) // 2^254 - 2^4
	t.Square(&t) // 2^255 - 2^5

	return v.Multiply(&t, &z11) // 2^255 - 21 = p - 2
}
// Set sets v = a, and returns v. v and a may alias.
func (v *Element) Set(a *Element) *Element {
	v.l0 = a.l0
	v.l1 = a.l1
	v.l2 = a.l2
	v.l3 = a.l3
	v.l4 = a.l4
	return v
}
// SetBytes sets v to x, where x is a 32-byte little-endian encoding. If x is
// not of the right length, SetBytes returns nil and an error, and the
// receiver is unchanged.
//
// Consistent with RFC 7748, the most significant bit (the high bit of the
// last byte) is ignored, and non-canonical values (2^255-19 through 2^255-1)
// are accepted. Note that this is laxer than specified by RFC 8032, but
// consistent with most Ed25519 implementations.
func (v *Element) SetBytes(x []byte) (*Element, error) {
	if len(x) != 32 {
		return nil, errors.New("edwards25519: invalid field element input size")
	}

	// Each limb is loaded from an 8-byte window positioned so that its 51
	// bits land in the low bits after the shift, then masked.

	// Bits 0:51 (bytes 0:8, bits 0:64, shift 0, mask 51).
	v.l0 = binary.LittleEndian.Uint64(x[0:8])
	v.l0 &= maskLow51Bits
	// Bits 51:102 (bytes 6:14, bits 48:112, shift 3, mask 51).
	v.l1 = binary.LittleEndian.Uint64(x[6:14]) >> 3
	v.l1 &= maskLow51Bits
	// Bits 102:153 (bytes 12:20, bits 96:160, shift 6, mask 51).
	v.l2 = binary.LittleEndian.Uint64(x[12:20]) >> 6
	v.l2 &= maskLow51Bits
	// Bits 153:204 (bytes 19:27, bits 152:216, shift 1, mask 51).
	v.l3 = binary.LittleEndian.Uint64(x[19:27]) >> 1
	v.l3 &= maskLow51Bits
	// Bits 204:255 (bytes 24:32, bits 192:256, shift 12, mask 51).
	// Note: not bytes 25:33, shift 4, to avoid overread.
	v.l4 = binary.LittleEndian.Uint64(x[24:32]) >> 12
	v.l4 &= maskLow51Bits

	return v, nil
}
// Bytes returns the canonical 32-byte little-endian encoding of v.
func (v *Element) Bytes() []byte {
	// This function is outlined to make the allocations inline in the caller
	// rather than happen on the heap.
	var out [32]byte
	return v.bytes(&out)
}
// bytes serializes a fully reduced copy of v into out, scattering each
// 51-bit limb into its byte range with OR so overlapping bytes combine.
func (v *Element) bytes(out *[32]byte) []byte {
	// Work on a copy so the receiver is not mutated by the reduction.
	t := *v
	t.reduce()

	var buf [8]byte
	for i, l := range [5]uint64{t.l0, t.l1, t.l2, t.l3, t.l4} {
		bitsOffset := i * 51
		binary.LittleEndian.PutUint64(buf[:], l<<uint(bitsOffset%8))
		for i, bb := range buf {
			off := bitsOffset/8 + i
			if off >= len(out) {
				break
			}
			out[off] |= bb
		}
	}

	return out[:]
}
// Equal returns 1 if v and u are equal, and 0 otherwise.
//
// It compares the canonical encodings in constant time, so different limb
// representations of the same value compare equal.
func (v *Element) Equal(u *Element) int {
	sa, sv := u.Bytes(), v.Bytes()
	return subtle.ConstantTimeCompare(sa, sv)
}
// mask64Bits returns an all-ones 64-bit mask if cond is 1, and 0 if cond is 0.
//
// Branch-free: negating the 0/1 condition in two's complement yields the mask.
func mask64Bits(cond int) uint64 { return -uint64(cond) }
// Select sets v to a if cond == 1, and to b if cond == 0, in constant time.
func (v *Element) Select(a, b *Element, cond int) *Element {
	m := mask64Bits(cond)
	v.l0 = (m & a.l0) | (^m & b.l0)
	v.l1 = (m & a.l1) | (^m & b.l1)
	v.l2 = (m & a.l2) | (^m & b.l2)
	v.l3 = (m & a.l3) | (^m & b.l3)
	v.l4 = (m & a.l4) | (^m & b.l4)
	return v
}
// Swap swaps v and u if cond == 1, or leaves them unchanged if cond == 0.
func (v *Element) Swap(u *Element, cond int) {
	mask := mask64Bits(cond)

	// For each limb pair, x is either 0 (cond == 0) or v.lN ^ u.lN
	// (cond == 1); XORing x into both limbs therefore either leaves them
	// alone or exchanges them, without branching.
	x := mask & (v.l0 ^ u.l0)
	v.l0, u.l0 = v.l0^x, u.l0^x
	x = mask & (v.l1 ^ u.l1)
	v.l1, u.l1 = v.l1^x, u.l1^x
	x = mask & (v.l2 ^ u.l2)
	v.l2, u.l2 = v.l2^x, u.l2^x
	x = mask & (v.l3 ^ u.l3)
	v.l3, u.l3 = v.l3^x, u.l3^x
	x = mask & (v.l4 ^ u.l4)
	v.l4, u.l4 = v.l4^x, u.l4^x
}
// IsNegative returns 1 if v is negative, and 0 otherwise.
func (v *Element) IsNegative() int {
	// An element counts as negative when the low bit of its canonical
	// encoding is set.
	b := v.Bytes()
	return int(b[0] & 1)
}
// Absolute sets v to |u|, and returns v.
func (v *Element) Absolute(u *Element) *Element {
	// Constant-time: always compute -u, then pick it only when u is negative.
	neg := new(Element).Negate(u)
	return v.Select(neg, u, u.IsNegative())
}
// Multiply sets v = x * y, and returns v.
func (v *Element) Multiply(x, y *Element) *Element {
	// feMul is assembly-backed on amd64 (fe_amd64.s) and dispatches to
	// feMulGeneric on other platforms.
	feMul(v, x, y)
	return v
}
// Square sets v = x * x, and returns v.
func (v *Element) Square(x *Element) *Element {
	// feSquare is assembly-backed on amd64 (fe_amd64.s) and dispatches to
	// feSquareGeneric on other platforms.
	feSquare(v, x)
	return v
}
// Mult32 sets v = x * y, and returns v.
func (v *Element) Mult32(x *Element, y uint32) *Element {
	// Multiply each 51-bit limb by y, splitting every product into its low
	// 51 bits and the overflow above them.
	lo0, hi0 := mul51(x.l0, y)
	lo1, hi1 := mul51(x.l1, y)
	lo2, hi2 := mul51(x.l2, y)
	lo3, hi3 := mul51(x.l3, y)
	lo4, hi4 := mul51(x.l4, y)

	// Carry each overflow into the limb above; the top overflow wraps around
	// to the bottom limb multiplied by 19 (the reduction identity).
	v.l0 = lo0 + 19*hi4
	v.l1 = lo1 + hi0
	v.l2 = lo2 + hi1
	v.l3 = lo3 + hi2
	v.l4 = lo4 + hi3

	// The hi parts are only 32 bits plus any previous excess, so no further
	// carry propagation is needed here.
	return v
}
// mul51 returns lo + hi * 2⁵¹ = a * b.
func mul51(a uint64, b uint32) (uint64, uint64) {
	// Split the full 128-bit product at bit 51 instead of bit 64.
	h, l := bits.Mul64(a, uint64(b))
	return l & maskLow51Bits, h<<13 | l>>51
}
// Pow22523 set v = x^((p-5)/8), and returns v. (p-5)/8 is 2^252-3.
//
// The body is a fixed addition chain for the exponent 2²⁵²−3; each step is
// annotated with the power of x it produces.
func (v *Element) Pow22523(x *Element) *Element {
	var t0, t1, t2 Element

	t0.Square(x)             // x^2
	t1.Square(&t0)           // x^4
	t1.Square(&t1)           // x^8
	t1.Multiply(x, &t1)      // x^9
	t0.Multiply(&t0, &t1)    // x^11
	t0.Square(&t0)           // x^22
	t0.Multiply(&t1, &t0)    // x^31
	t1.Square(&t0)           // x^62
	for i := 1; i < 5; i++ { // x^992
		t1.Square(&t1)
	}
	t0.Multiply(&t1, &t0)     // x^1023 -> 1023 = 2^10 - 1
	t1.Square(&t0)            // 2^11 - 2
	for i := 1; i < 10; i++ { // 2^20 - 2^10
		t1.Square(&t1)
	}
	t1.Multiply(&t1, &t0)     // 2^20 - 1
	t2.Square(&t1)            // 2^21 - 2
	for i := 1; i < 20; i++ { // 2^40 - 2^20
		t2.Square(&t2)
	}
	t1.Multiply(&t2, &t1)     // 2^40 - 1
	t1.Square(&t1)            // 2^41 - 2
	for i := 1; i < 10; i++ { // 2^50 - 2^10
		t1.Square(&t1)
	}
	t0.Multiply(&t1, &t0)     // 2^50 - 1
	t1.Square(&t0)            // 2^51 - 2
	for i := 1; i < 50; i++ { // 2^100 - 2^50
		t1.Square(&t1)
	}
	t1.Multiply(&t1, &t0)      // 2^100 - 1
	t2.Square(&t1)             // 2^101 - 2
	for i := 1; i < 100; i++ { // 2^200 - 2^100
		t2.Square(&t2)
	}
	t1.Multiply(&t2, &t1)     // 2^200 - 1
	t1.Square(&t1)            // 2^201 - 2
	for i := 1; i < 50; i++ { // 2^250 - 2^50
		t1.Square(&t1)
	}
	t0.Multiply(&t1, &t0)     // 2^250 - 1
	t0.Square(&t0)            // 2^251 - 2
	t0.Square(&t0)            // 2^252 - 4
	return v.Multiply(&t0, x) // 2^252 - 3 -> x^(2^252-3)
}
// sqrtM1 is 2^((p-1)/4), which squared is equal to -1 by Euler's Criterion.
// It is used by SqrtRatio below to flip the sign of a candidate root.
var sqrtM1 = &Element{1718705420411056, 234908883556509,
	2233514472574048, 2117202627021982, 765476049583133}
// SqrtRatio sets r to the non-negative square root of the ratio of u and v.
//
// If u/v is square, SqrtRatio returns r and 1. If u/v is not square, SqrtRatio
// sets r according to Section 4.3 of draft-irtf-cfrg-ristretto255-decaf448-00,
// and returns r and 0.
func (r *Element) SqrtRatio(u, v *Element) (R *Element, wasSquare int) {
	t0 := new(Element) // scratch, overwritten by every t0.* call below

	// r = (u * v3) * (u * v7)^((p-5)/8)
	v2 := new(Element).Square(v)
	uv3 := new(Element).Multiply(u, t0.Multiply(v2, v))
	uv7 := new(Element).Multiply(uv3, t0.Square(v2))
	rr := new(Element).Multiply(uv3, t0.Pow22523(uv7))

	check := new(Element).Multiply(v, t0.Square(rr)) // check = v * r^2

	// Classify the candidate root by comparing check against u, -u, and
	// -u*sqrtM1, all in constant time.
	uNeg := new(Element).Negate(u)
	correctSignSqrt := check.Equal(u)
	flippedSignSqrt := check.Equal(uNeg)
	flippedSignSqrtI := check.Equal(t0.Multiply(uNeg, sqrtM1))

	rPrime := new(Element).Multiply(rr, sqrtM1) // r_prime = SQRT_M1 * r
	// r = CT_SELECT(r_prime IF flipped_sign_sqrt | flipped_sign_sqrt_i ELSE r)
	rr.Select(rPrime, rr, flippedSignSqrt|flippedSignSqrtI)

	r.Absolute(rr) // Choose the nonnegative square root.
	return r, correctSignSqrt | flippedSignSqrt
}

16
vendor/filippo.io/edwards25519/field/fe_amd64.go generated vendored Normal file
View file

@ -0,0 +1,16 @@
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.

//go:build amd64 && gc && !purego
// +build amd64,gc,!purego

package field

// feMul sets out = a * b. It works like feMulGeneric.
// Implemented in assembly in fe_amd64.s.
//
//go:noescape
func feMul(out *Element, a *Element, b *Element)

// feSquare sets out = a * a. It works like feSquareGeneric.
// Implemented in assembly in fe_amd64.s.
//
//go:noescape
func feSquare(out *Element, a *Element)

379
vendor/filippo.io/edwards25519/field/fe_amd64.s generated vendored Normal file
View file

@ -0,0 +1,379 @@
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
//go:build amd64 && gc && !purego
// +build amd64,gc,!purego
#include "textflag.h"
// func feMul(out *Element, a *Element, b *Element)
//
// Register map: CX = a, BX = b; the five unreduced 128-bit coefficients are
// kept in lo:hi register pairs r0 = DI:SI, r1 = R9:R8, r2 = R11:R10,
// r3 = R13:R12, r4 = R15:R14.
TEXT ·feMul(SB), NOSPLIT, $0-24
	MOVQ a+8(FP), CX
	MOVQ b+16(FP), BX

	// r0 = a0×b0
	MOVQ (CX), AX
	MULQ (BX)
	MOVQ AX, DI
	MOVQ DX, SI

	// r0 += 19×a1×b4
	MOVQ   8(CX), AX
	IMUL3Q $0x13, AX, AX
	MULQ   32(BX)
	ADDQ   AX, DI
	ADCQ   DX, SI

	// r0 += 19×a2×b3
	MOVQ   16(CX), AX
	IMUL3Q $0x13, AX, AX
	MULQ   24(BX)
	ADDQ   AX, DI
	ADCQ   DX, SI

	// r0 += 19×a3×b2
	MOVQ   24(CX), AX
	IMUL3Q $0x13, AX, AX
	MULQ   16(BX)
	ADDQ   AX, DI
	ADCQ   DX, SI

	// r0 += 19×a4×b1
	MOVQ   32(CX), AX
	IMUL3Q $0x13, AX, AX
	MULQ   8(BX)
	ADDQ   AX, DI
	ADCQ   DX, SI

	// r1 = a0×b1
	MOVQ (CX), AX
	MULQ 8(BX)
	MOVQ AX, R9
	MOVQ DX, R8

	// r1 += a1×b0
	MOVQ 8(CX), AX
	MULQ (BX)
	ADDQ AX, R9
	ADCQ DX, R8

	// r1 += 19×a2×b4
	MOVQ   16(CX), AX
	IMUL3Q $0x13, AX, AX
	MULQ   32(BX)
	ADDQ   AX, R9
	ADCQ   DX, R8

	// r1 += 19×a3×b3
	MOVQ   24(CX), AX
	IMUL3Q $0x13, AX, AX
	MULQ   24(BX)
	ADDQ   AX, R9
	ADCQ   DX, R8

	// r1 += 19×a4×b2
	MOVQ   32(CX), AX
	IMUL3Q $0x13, AX, AX
	MULQ   16(BX)
	ADDQ   AX, R9
	ADCQ   DX, R8

	// r2 = a0×b2
	MOVQ (CX), AX
	MULQ 16(BX)
	MOVQ AX, R11
	MOVQ DX, R10

	// r2 += a1×b1
	MOVQ 8(CX), AX
	MULQ 8(BX)
	ADDQ AX, R11
	ADCQ DX, R10

	// r2 += a2×b0
	MOVQ 16(CX), AX
	MULQ (BX)
	ADDQ AX, R11
	ADCQ DX, R10

	// r2 += 19×a3×b4
	MOVQ   24(CX), AX
	IMUL3Q $0x13, AX, AX
	MULQ   32(BX)
	ADDQ   AX, R11
	ADCQ   DX, R10

	// r2 += 19×a4×b3
	MOVQ   32(CX), AX
	IMUL3Q $0x13, AX, AX
	MULQ   24(BX)
	ADDQ   AX, R11
	ADCQ   DX, R10

	// r3 = a0×b3
	MOVQ (CX), AX
	MULQ 24(BX)
	MOVQ AX, R13
	MOVQ DX, R12

	// r3 += a1×b2
	MOVQ 8(CX), AX
	MULQ 16(BX)
	ADDQ AX, R13
	ADCQ DX, R12

	// r3 += a2×b1
	MOVQ 16(CX), AX
	MULQ 8(BX)
	ADDQ AX, R13
	ADCQ DX, R12

	// r3 += a3×b0
	MOVQ 24(CX), AX
	MULQ (BX)
	ADDQ AX, R13
	ADCQ DX, R12

	// r3 += 19×a4×b4
	MOVQ   32(CX), AX
	IMUL3Q $0x13, AX, AX
	MULQ   32(BX)
	ADDQ   AX, R13
	ADCQ   DX, R12

	// r4 = a0×b4
	MOVQ (CX), AX
	MULQ 32(BX)
	MOVQ AX, R15
	MOVQ DX, R14

	// r4 += a1×b3
	MOVQ 8(CX), AX
	MULQ 24(BX)
	ADDQ AX, R15
	ADCQ DX, R14

	// r4 += a2×b2
	MOVQ 16(CX), AX
	MULQ 16(CX)
	ADDQ AX, R15
	ADCQ DX, R14

	// r4 += a3×b1
	MOVQ 24(CX), AX
	MULQ 8(BX)
	ADDQ AX, R15
	ADCQ DX, R14

	// r4 += a4×b0
	MOVQ 32(CX), AX
	MULQ (BX)
	ADDQ AX, R15
	ADCQ DX, R14

	// First reduction chain
	MOVQ   $0x0007ffffffffffff, AX
	SHLQ   $0x0d, DI, SI
	SHLQ   $0x0d, R9, R8
	SHLQ   $0x0d, R11, R10
	SHLQ   $0x0d, R13, R12
	SHLQ   $0x0d, R15, R14
	ANDQ   AX, DI
	IMUL3Q $0x13, R14, R14
	ADDQ   R14, DI
	ANDQ   AX, R9
	ADDQ   SI, R9
	ANDQ   AX, R11
	ADDQ   R8, R11
	ANDQ   AX, R13
	ADDQ   R10, R13
	ANDQ   AX, R15
	ADDQ   R12, R15

	// Second reduction chain (carryPropagate)
	MOVQ   DI, SI
	SHRQ   $0x33, SI
	MOVQ   R9, R8
	SHRQ   $0x33, R8
	MOVQ   R11, R10
	SHRQ   $0x33, R10
	MOVQ   R13, R12
	SHRQ   $0x33, R12
	MOVQ   R15, R14
	SHRQ   $0x33, R14
	ANDQ   AX, DI
	IMUL3Q $0x13, R14, R14
	ADDQ   R14, DI
	ANDQ   AX, R9
	ADDQ   SI, R9
	ANDQ   AX, R11
	ADDQ   R8, R11
	ANDQ   AX, R13
	ADDQ   R10, R13
	ANDQ   AX, R15
	ADDQ   R12, R15

	// Store output
	MOVQ out+0(FP), AX
	MOVQ DI, (AX)
	MOVQ R9, 8(AX)
	MOVQ R11, 16(AX)
	MOVQ R13, 24(AX)
	MOVQ R15, 32(AX)
	RET
// func feSquare(out *Element, a *Element)
//
// Register map: CX = a; the five unreduced 128-bit coefficients are kept in
// lo:hi register pairs r0 = SI:BX, r1 = R8:DI, r2 = R10:R9, r3 = R12:R11,
// r4 = R14:R13.
TEXT ·feSquare(SB), NOSPLIT, $0-16
	MOVQ a+8(FP), CX

	// r0 = l0×l0
	MOVQ (CX), AX
	MULQ (CX)
	MOVQ AX, SI
	MOVQ DX, BX

	// r0 += 38×l1×l4
	MOVQ   8(CX), AX
	IMUL3Q $0x26, AX, AX
	MULQ   32(CX)
	ADDQ   AX, SI
	ADCQ   DX, BX

	// r0 += 38×l2×l3
	MOVQ   16(CX), AX
	IMUL3Q $0x26, AX, AX
	MULQ   24(CX)
	ADDQ   AX, SI
	ADCQ   DX, BX

	// r1 = 2×l0×l1
	MOVQ (CX), AX
	SHLQ $0x01, AX
	MULQ 8(CX)
	MOVQ AX, R8
	MOVQ DX, DI

	// r1 += 38×l2×l4
	MOVQ   16(CX), AX
	IMUL3Q $0x26, AX, AX
	MULQ   32(CX)
	ADDQ   AX, R8
	ADCQ   DX, DI

	// r1 += 19×l3×l3
	MOVQ   24(CX), AX
	IMUL3Q $0x13, AX, AX
	MULQ   24(CX)
	ADDQ   AX, R8
	ADCQ   DX, DI

	// r2 = 2×l0×l2
	MOVQ (CX), AX
	SHLQ $0x01, AX
	MULQ 16(CX)
	MOVQ AX, R10
	MOVQ DX, R9

	// r2 += l1×l1
	MOVQ 8(CX), AX
	MULQ 8(CX)
	ADDQ AX, R10
	ADCQ DX, R9

	// r2 += 38×l3×l4
	MOVQ   24(CX), AX
	IMUL3Q $0x26, AX, AX
	MULQ   32(CX)
	ADDQ   AX, R10
	ADCQ   DX, R9

	// r3 = 2×l0×l3
	MOVQ (CX), AX
	SHLQ $0x01, AX
	MULQ 24(CX)
	MOVQ AX, R12
	MOVQ DX, R11

	// r3 += 2×l1×l2
	MOVQ   8(CX), AX
	IMUL3Q $0x02, AX, AX
	MULQ   16(CX)
	ADDQ   AX, R12
	ADCQ   DX, R11

	// r3 += 19×l4×l4
	MOVQ   32(CX), AX
	IMUL3Q $0x13, AX, AX
	MULQ   32(CX)
	ADDQ   AX, R12
	ADCQ   DX, R11

	// r4 = 2×l0×l4
	MOVQ (CX), AX
	SHLQ $0x01, AX
	MULQ 32(CX)
	MOVQ AX, R14
	MOVQ DX, R13

	// r4 += 2×l1×l3
	MOVQ   8(CX), AX
	IMUL3Q $0x02, AX, AX
	MULQ   24(CX)
	ADDQ   AX, R14
	ADCQ   DX, R13

	// r4 += l2×l2
	MOVQ 16(CX), AX
	MULQ 16(CX)
	ADDQ AX, R14
	ADCQ DX, R13

	// First reduction chain
	MOVQ   $0x0007ffffffffffff, AX
	SHLQ   $0x0d, SI, BX
	SHLQ   $0x0d, R8, DI
	SHLQ   $0x0d, R10, R9
	SHLQ   $0x0d, R12, R11
	SHLQ   $0x0d, R14, R13
	ANDQ   AX, SI
	IMUL3Q $0x13, R13, R13
	ADDQ   R13, SI
	ANDQ   AX, R8
	ADDQ   BX, R8
	ANDQ   AX, R10
	ADDQ   DI, R10
	ANDQ   AX, R12
	ADDQ   R9, R12
	ANDQ   AX, R14
	ADDQ   R11, R14

	// Second reduction chain (carryPropagate)
	MOVQ   SI, BX
	SHRQ   $0x33, BX
	MOVQ   R8, DI
	SHRQ   $0x33, DI
	MOVQ   R10, R9
	SHRQ   $0x33, R9
	MOVQ   R12, R11
	SHRQ   $0x33, R11
	MOVQ   R14, R13
	SHRQ   $0x33, R13
	ANDQ   AX, SI
	IMUL3Q $0x13, R13, R13
	ADDQ   R13, SI
	ANDQ   AX, R8
	ADDQ   BX, R8
	ANDQ   AX, R10
	ADDQ   DI, R10
	ANDQ   AX, R12
	ADDQ   R9, R12
	ANDQ   AX, R14
	ADDQ   R11, R14

	// Store output
	MOVQ out+0(FP), AX
	MOVQ SI, (AX)
	MOVQ R8, 8(AX)
	MOVQ R10, 16(AX)
	MOVQ R12, 24(AX)
	MOVQ R14, 32(AX)
	RET

12
vendor/filippo.io/edwards25519/field/fe_amd64_noasm.go generated vendored Normal file
View file

@ -0,0 +1,12 @@
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build !amd64 || !gc || purego
// +build !amd64 !gc purego

package field

// feMul sets v = x * y; on non-amd64 builds it falls back to the portable
// generic implementation.
func feMul(v, x, y *Element) { feMulGeneric(v, x, y) }

// feSquare sets v = x * x; on non-amd64 builds it falls back to the portable
// generic implementation.
func feSquare(v, x *Element) { feSquareGeneric(v, x) }

16
vendor/filippo.io/edwards25519/field/fe_arm64.go generated vendored Normal file
View file

@ -0,0 +1,16 @@
// Copyright (c) 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build arm64 && gc && !purego
// +build arm64,gc,!purego

package field

// carryPropagate is implemented in assembly in fe_arm64.s.
//
//go:noescape
func carryPropagate(v *Element)

// carryPropagate reduces v's limbs in place via the arm64 assembly routine,
// and returns v.
func (v *Element) carryPropagate() *Element {
	carryPropagate(v)
	return v
}

42
vendor/filippo.io/edwards25519/field/fe_arm64.s generated vendored Normal file
View file

@ -0,0 +1,42 @@
// Copyright (c) 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build arm64 && gc && !purego
#include "textflag.h"
// carryPropagate works exactly like carryPropagateGeneric and uses the
// same AND, ADD, and LSR+MADD instructions emitted by the compiler, but
// avoids loading R0-R4 twice and uses LDP and STP.
//
// See https://golang.org/issues/43145 for the main compiler issue.
//
// func carryPropagate(v *Element)
TEXT ·carryPropagate(SB),NOFRAME|NOSPLIT,$0-8
	MOVD v+0(FP), R20

	// Load limbs l0-l4 into R0-R4.
	LDP 0(R20), (R0, R1)
	LDP 16(R20), (R2, R3)
	MOVD 32(R20), R4

	// R10-R14 = each limb masked to its low 51 bits.
	AND $0x7ffffffffffff, R0, R10
	AND $0x7ffffffffffff, R1, R11
	AND $0x7ffffffffffff, R2, R12
	AND $0x7ffffffffffff, R3, R13
	AND $0x7ffffffffffff, R4, R14

	// Add each limb's carry (bits 51 and up) into the limb above it.
	ADD R0>>51, R11, R11
	ADD R1>>51, R12, R12
	ADD R2>>51, R13, R13
	ADD R3>>51, R14, R14

	// The top carry wraps around to l0 times 19 (the reduction identity):
	// R4>>51 * 19 + R10 -> R10
	LSR $51, R4, R21
	MOVD $19, R22
	MADD R22, R10, R21, R10

	// Store the reduced limbs back.
	STP (R10, R11), 0(R20)
	STP (R12, R13), 16(R20)
	MOVD R14, 32(R20)
	RET

12
vendor/filippo.io/edwards25519/field/fe_arm64_noasm.go generated vendored Normal file
View file

@ -0,0 +1,12 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build !arm64 || !gc || purego
// +build !arm64 !gc purego

package field

// carryPropagate falls back to the portable generic implementation when the
// arm64 assembly is unavailable.
func (v *Element) carryPropagate() *Element {
	return v.carryPropagateGeneric()
}

50
vendor/filippo.io/edwards25519/field/fe_extra.go generated vendored Normal file
View file

@ -0,0 +1,50 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package field
import "errors"
// This file contains additional functionality that is not included in the
// upstream crypto/ed25519/edwards25519/field package.
// SetWideBytes sets v to x, where x is a 64-byte little-endian encoding, which
// is reduced modulo the field order. If x is not of the right length,
// SetWideBytes returns nil and an error, and the receiver is unchanged.
//
// SetWideBytes is not necessary to select a uniformly distributed value, and is
// only provided for compatibility: SetBytes can be used instead as the chance
// of bias is less than 2⁻²⁵⁰.
func (v *Element) SetWideBytes(x []byte) (*Element, error) {
	if len(x) != 64 {
		return nil, errors.New("edwards25519: invalid SetWideBytes input size")
	}

	// Split the 64 bytes into two elements, and extract the most significant
	// bit of each, which is ignored by SetBytes.
	lo, _ := new(Element).SetBytes(x[:32])
	loMSB := uint64(x[31] >> 7) // bit 255 of the low half
	hi, _ := new(Element).SetBytes(x[32:])
	hiMSB := uint64(x[63] >> 7) // bit 511 overall

	// The output we want is
	//
	//   v = lo + loMSB * 2²⁵⁵ + hi * 2²⁵⁶ + hiMSB * 2⁵¹¹
	//
	// which applying the reduction identity comes out to
	//
	//   v = lo + loMSB * 19 + hi * 2 * 19 + hiMSB * 2 * 19²
	//
	// l0 will be the sum of a 52 bits value (lo.l0), plus a 5 bits value
	// (loMSB * 19), a 6 bits value (hi.l0 * 2 * 19), and a 10 bits value
	// (hiMSB * 2 * 19²), so it fits in a uint64.
	v.l0 = lo.l0 + loMSB*19 + hi.l0*2*19 + hiMSB*2*19*19
	v.l1 = lo.l1 + hi.l1*2*19
	v.l2 = lo.l2 + hi.l2*2*19
	v.l3 = lo.l3 + hi.l3*2*19
	v.l4 = lo.l4 + hi.l4*2*19

	// A single carry pass brings every limb back below 52 bits.
	return v.carryPropagate(), nil
}

266
vendor/filippo.io/edwards25519/field/fe_generic.go generated vendored Normal file
View file

@ -0,0 +1,266 @@
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package field
import "math/bits"
// uint128 holds a 128-bit number as two 64-bit limbs, for use with the
// bits.Mul64 and bits.Add64 intrinsics.
type uint128 struct {
	lo, hi uint64 // value = hi * 2⁶⁴ + lo
}
// mul64 returns a * b as a uint128.
func mul64(a, b uint64) uint128 {
	var r uint128
	r.hi, r.lo = bits.Mul64(a, b)
	return r
}
// addMul64 returns v + a * b.
func addMul64(v uint128, a, b uint64) uint128 {
	// Full 64x64->128 product, then a 128-bit add; the carry out of the high
	// word is discarded.
	productHi, productLo := bits.Mul64(a, b)
	sumLo, carry := bits.Add64(productLo, v.lo, 0)
	sumHi, _ := bits.Add64(productHi, v.hi, carry)
	return uint128{sumLo, sumHi}
}
// shiftRightBy51 returns a >> 51. a is assumed to be at most 115 bits.
func shiftRightBy51(a uint128) uint64 {
	// The high limb contributes its bits shifted up by 64-51 = 13.
	return a.lo>>51 | a.hi<<13
}
// feMulGeneric sets v = a * b using portable 64-bit arithmetic.
func feMulGeneric(v, a, b *Element) {
	a0 := a.l0
	a1 := a.l1
	a2 := a.l2
	a3 := a.l3
	a4 := a.l4

	b0 := b.l0
	b1 := b.l1
	b2 := b.l2
	b3 := b.l3
	b4 := b.l4

	// Limb multiplication works like pen-and-paper columnar multiplication, but
	// with 51-bit limbs instead of digits.
	//
	//                          a4   a3   a2   a1   a0  x
	//                          b4   b3   b2   b1   b0  =
	//                         ------------------------
	//                        a4b0 a3b0 a2b0 a1b0 a0b0  +
	//                   a4b1 a3b1 a2b1 a1b1 a0b1       +
	//              a4b2 a3b2 a2b2 a1b2 a0b2            +
	//         a4b3 a3b3 a2b3 a1b3 a0b3                 +
	//    a4b4 a3b4 a2b4 a1b4 a0b4                      =
	//   ----------------------------------------------
	//      r8   r7   r6   r5   r4   r3   r2   r1   r0
	//
	// We can then use the reduction identity (a * 2²⁵⁵ + b = a * 19 + b) to
	// reduce the limbs that would overflow 255 bits. r5 * 2²⁵⁵ becomes 19 * r5,
	// r6 * 2³⁰⁶ becomes 19 * r6 * 2⁵¹, etc.
	//
	// Reduction can be carried out simultaneously to multiplication. For
	// example, we do not compute r5: whenever the result of a multiplication
	// belongs to r5, like a1b4, we multiply it by 19 and add the result to r0.
	//
	//            a4b0    a3b0    a2b0    a1b0    a0b0  +
	//            a3b1    a2b1    a1b1    a0b1 19×a4b1  +
	//            a2b2    a1b2    a0b2 19×a4b2 19×a3b2  +
	//            a1b3    a0b3 19×a4b3 19×a3b3 19×a2b3  +
	//            a0b4 19×a4b4 19×a3b4 19×a2b4 19×a1b4  =
	//           --------------------------------------
	//              r4      r3      r2      r1      r0
	//
	// Finally we add up the columns into wide, overlapping limbs.

	// Precompute the 19× terms used by the reduction above.
	a1_19 := a1 * 19
	a2_19 := a2 * 19
	a3_19 := a3 * 19
	a4_19 := a4 * 19

	// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
	r0 := mul64(a0, b0)
	r0 = addMul64(r0, a1_19, b4)
	r0 = addMul64(r0, a2_19, b3)
	r0 = addMul64(r0, a3_19, b2)
	r0 = addMul64(r0, a4_19, b1)

	// r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2)
	r1 := mul64(a0, b1)
	r1 = addMul64(r1, a1, b0)
	r1 = addMul64(r1, a2_19, b4)
	r1 = addMul64(r1, a3_19, b3)
	r1 = addMul64(r1, a4_19, b2)

	// r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3)
	r2 := mul64(a0, b2)
	r2 = addMul64(r2, a1, b1)
	r2 = addMul64(r2, a2, b0)
	r2 = addMul64(r2, a3_19, b4)
	r2 = addMul64(r2, a4_19, b3)

	// r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4
	r3 := mul64(a0, b3)
	r3 = addMul64(r3, a1, b2)
	r3 = addMul64(r3, a2, b1)
	r3 = addMul64(r3, a3, b0)
	r3 = addMul64(r3, a4_19, b4)

	// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
	r4 := mul64(a0, b4)
	r4 = addMul64(r4, a1, b3)
	r4 = addMul64(r4, a2, b2)
	r4 = addMul64(r4, a3, b1)
	r4 = addMul64(r4, a4, b0)

	// After the multiplication, we need to reduce (carry) the five coefficients
	// to obtain a result with limbs that are at most slightly larger than 2⁵¹,
	// to respect the Element invariant.
	//
	// Overall, the reduction works the same as carryPropagate, except with
	// wider inputs: we take the carry for each coefficient by shifting it right
	// by 51, and add it to the limb above it. The top carry is multiplied by 19
	// according to the reduction identity and added to the lowest limb.
	//
	// The largest coefficient (r0) will be at most 111 bits, which guarantees
	// that all carries are at most 111 - 51 = 60 bits, which fits in a uint64.
	//
	//     r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
	//     r0 < 2⁵²×2⁵² + 19×(2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵²)
	//     r0 < (1 + 19 × 4) × 2⁵² × 2⁵²
	//     r0 < 2⁷ × 2⁵² × 2⁵²
	//     r0 < 2¹¹¹
	//
	// Moreover, the top coefficient (r4) is at most 107 bits, so c4 is at most
	// 56 bits, and c4 * 19 is at most 61 bits, which again fits in a uint64 and
	// allows us to easily apply the reduction identity.
	//
	//     r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
	//     r4 < 5 × 2⁵² × 2⁵²
	//     r4 < 2¹⁰⁷
	//
	c0 := shiftRightBy51(r0)
	c1 := shiftRightBy51(r1)
	c2 := shiftRightBy51(r2)
	c3 := shiftRightBy51(r3)
	c4 := shiftRightBy51(r4)

	// Keep each coefficient's low 51 bits and add the carry from the one
	// below; the top carry wraps around into rr0 multiplied by 19.
	rr0 := r0.lo&maskLow51Bits + c4*19
	rr1 := r1.lo&maskLow51Bits + c0
	rr2 := r2.lo&maskLow51Bits + c1
	rr3 := r3.lo&maskLow51Bits + c2
	rr4 := r4.lo&maskLow51Bits + c3

	// Now all coefficients fit into 64-bit registers but are still too large to
	// be passed around as an Element. We therefore do one last carry chain,
	// where the carries will be small enough to fit in the wiggle room above 2⁵¹.
	*v = Element{rr0, rr1, rr2, rr3, rr4}
	v.carryPropagate()
}
// feSquareGeneric sets v = a * a using portable 64-bit arithmetic.
func feSquareGeneric(v, a *Element) {
	l0 := a.l0
	l1 := a.l1
	l2 := a.l2
	l3 := a.l3
	l4 := a.l4

	// Squaring works precisely like multiplication above, but thanks to its
	// symmetry we get to group a few terms together.
	//
	//                          l4   l3   l2   l1   l0  x
	//                          l4   l3   l2   l1   l0  =
	//                         ------------------------
	//                        l4l0 l3l0 l2l0 l1l0 l0l0  +
	//                   l4l1 l3l1 l2l1 l1l1 l0l1       +
	//              l4l2 l3l2 l2l2 l1l2 l0l2            +
	//         l4l3 l3l3 l2l3 l1l3 l0l3                 +
	//    l4l4 l3l4 l2l4 l1l4 l0l4                      =
	//   ----------------------------------------------
	//      r8   r7   r6   r5   r4   r3   r2   r1   r0
	//
	//            l4l0    l3l0    l2l0    l1l0    l0l0  +
	//            l3l1    l2l1    l1l1    l0l1 19×l4l1  +
	//            l2l2    l1l2    l0l2 19×l4l2 19×l3l2  +
	//            l1l3    l0l3 19×l4l3 19×l3l3 19×l2l3  +
	//            l0l4 19×l4l4 19×l3l4 19×l2l4 19×l1l4  =
	//           --------------------------------------
	//              r4      r3      r2      r1      r0
	//
	// With precomputed 2×, 19×, and 2×19× terms, we can compute each limb with
	// only three Mul64 and four Add64, instead of five and eight.

	l0_2 := l0 * 2
	l1_2 := l1 * 2

	l1_38 := l1 * 38
	l2_38 := l2 * 38
	l3_38 := l3 * 38

	l3_19 := l3 * 19
	l4_19 := l4 * 19

	// r0 = l0×l0 + 19×(l1×l4 + l2×l3 + l3×l2 + l4×l1) = l0×l0 + 19×2×(l1×l4 + l2×l3)
	r0 := mul64(l0, l0)
	r0 = addMul64(r0, l1_38, l4)
	r0 = addMul64(r0, l2_38, l3)

	// r1 = l0×l1 + l1×l0 + 19×(l2×l4 + l3×l3 + l4×l2) = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3
	r1 := mul64(l0_2, l1)
	r1 = addMul64(r1, l2_38, l4)
	r1 = addMul64(r1, l3_19, l3)

	// r2 = l0×l2 + l1×l1 + l2×l0 + 19×(l3×l4 + l4×l3) = 2×l0×l2 + l1×l1 + 19×2×l3×l4
	r2 := mul64(l0_2, l2)
	r2 = addMul64(r2, l1, l1)
	r2 = addMul64(r2, l3_38, l4)

	// r3 = l0×l3 + l1×l2 + l2×l1 + l3×l0 + 19×l4×l4 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4
	r3 := mul64(l0_2, l3)
	r3 = addMul64(r3, l1_2, l2)
	r3 = addMul64(r3, l4_19, l4)

	// r4 = l0×l4 + l1×l3 + l2×l2 + l3×l1 + l4×l0 = 2×l0×l4 + 2×l1×l3 + l2×l2
	r4 := mul64(l0_2, l4)
	r4 = addMul64(r4, l1_2, l3)
	r4 = addMul64(r4, l2, l2)

	// The reduction and final carry chain work exactly as in feMulGeneric.
	c0 := shiftRightBy51(r0)
	c1 := shiftRightBy51(r1)
	c2 := shiftRightBy51(r2)
	c3 := shiftRightBy51(r3)
	c4 := shiftRightBy51(r4)

	rr0 := r0.lo&maskLow51Bits + c4*19
	rr1 := r1.lo&maskLow51Bits + c0
	rr2 := r2.lo&maskLow51Bits + c1
	rr3 := r3.lo&maskLow51Bits + c2
	rr4 := r4.lo&maskLow51Bits + c3

	*v = Element{rr0, rr1, rr2, rr3, rr4}
	v.carryPropagate()
}
// carryPropagateGeneric brings the limbs below 52 bits by applying the reduction
// identity (a * 2²⁵⁵ + b = a * 19 + b) to the l4 carry.
func (v *Element) carryPropagateGeneric() *Element {
	// Snapshot the limbs so each assignment below reads the pre-update values.
	l0, l1, l2, l3, l4 := v.l0, v.l1, v.l2, v.l3, v.l4

	// Each limb keeps its low 51 bits and absorbs the carry (bits 51 and up)
	// of the limb below it; the top carry wraps into l0 multiplied by 19.
	// l4's carry is at most 64 - 51 = 13 bits, so times 19 it is at most 18
	// bits, and the final l0 is at most 52 bits. Similarly for the rest.
	v.l0 = l0&maskLow51Bits + (l4>>51)*19
	v.l1 = l1&maskLow51Bits + l0>>51
	v.l2 = l2&maskLow51Bits + l1>>51
	v.l3 = l3&maskLow51Bits + l2>>51
	v.l4 = l4&maskLow51Bits + l3>>51

	return v
}

343
vendor/filippo.io/edwards25519/scalar.go generated vendored Normal file
View file

@ -0,0 +1,343 @@
// Copyright (c) 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"encoding/binary"
"errors"
)
// A Scalar is an integer modulo
//
//	l = 2^252 + 27742317777372353535851937790883648493
//
// which is the prime order of the edwards25519 group.
//
// This type works similarly to math/big.Int, and all arguments and
// receivers are allowed to alias.
//
// The zero value is a valid zero element.
type Scalar struct {
	// s is the scalar in the Montgomery domain, in the format of the
	// fiat-crypto implementation: four 64-bit little-endian limbs spanning
	// the 2^256 Montgomery domain (see scalar_fiat.go).
	s fiatScalarMontgomeryDomainFieldElement
}
// The field implementation in scalar_fiat.go is generated by the fiat-crypto
// project (https://github.com/mit-plv/fiat-crypto) at version v0.0.9 (23d2dbc)
// from a formally verified model.
//
// fiat-crypto code comes under the following license.
//
// Copyright (c) 2015-2020 The fiat-crypto Authors. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// THIS SOFTWARE IS PROVIDED BY the fiat-crypto authors "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL Berkeley Software Design,
// Inc. BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// NewScalar returns a new zero Scalar.
func NewScalar() *Scalar {
	return new(Scalar)
}
// MultiplyAdd sets s = x * y + z mod l, and returns s. It is equivalent to
// using Multiply and then Add.
func (s *Scalar) MultiplyAdd(x, y, z *Scalar) *Scalar {
	// Copy z up front so the result stays correct even when z aliases s,
	// which the Multiply below would otherwise clobber.
	saved := new(Scalar).Set(z)
	return s.Multiply(x, y).Add(s, saved)
}
// Add sets s = x + y mod l, and returns s.
func (s *Scalar) Add(x, y *Scalar) *Scalar {
	// s = 1 * x + y mod l, via the formally verified fiat-crypto addition.
	fiatScalarAdd(&s.s, &x.s, &y.s)
	return s
}
// Subtract sets s = x - y mod l, and returns s.
func (s *Scalar) Subtract(x, y *Scalar) *Scalar {
	// s = -1 * y + x mod l, via the formally verified fiat-crypto subtraction.
	fiatScalarSub(&s.s, &x.s, &y.s)
	return s
}
// Negate sets s = -x mod l, and returns s.
func (s *Scalar) Negate(x *Scalar) *Scalar {
	// s = -1 * x + 0 mod l, via the formally verified fiat-crypto negation.
	fiatScalarOpp(&s.s, &x.s)
	return s
}
// Multiply sets s = x * y mod l, and returns s.
func (s *Scalar) Multiply(x, y *Scalar) *Scalar {
	// s = x * y + 0 mod l, via the formally verified fiat-crypto
	// Montgomery multiplication.
	fiatScalarMul(&s.s, &x.s, &y.s)
	return s
}
// Set sets s = x, and returns s.
func (s *Scalar) Set(x *Scalar) *Scalar {
	// Scalar holds only a fixed-size limb array, so a shallow copy is a
	// full copy.
	*s = *x
	return s
}
// SetUniformBytes sets s = x mod l, where x is a 64-byte little-endian integer.
// If x is not of the right length, SetUniformBytes returns nil and an error,
// and the receiver is unchanged.
//
// SetUniformBytes can be used to set s to a uniformly distributed value given
// 64 uniformly distributed random bytes.
func (s *Scalar) SetUniformBytes(x []byte) (*Scalar, error) {
	if len(x) != 64 {
		return nil, errors.New("edwards25519: invalid SetUniformBytes input length")
	}

	// We have a value x of 512 bits, but our fiatScalarFromBytes function
	// expects an input lower than l, which is a little over 252 bits.
	//
	// Instead of writing a reduction function that operates on wider inputs, we
	// can interpret x as the sum of three shorter values a, b, and c.
	//
	//    x = a + b * 2^168 + c * 2^336  mod l
	//
	// We then precompute 2^168 and 2^336 modulo l, and perform the reduction
	// with two multiplications and two additions.
	s.setShortBytes(x[:21]) // s = a
	t := new(Scalar).setShortBytes(x[21:42])
	s.Add(s, t.Multiply(t, scalarTwo168)) // s += b * 2^168
	t.setShortBytes(x[42:])
	s.Add(s, t.Multiply(t, scalarTwo336)) // s += c * 2^336

	return s, nil
}
// scalarTwo168 and scalarTwo336 are 2^168 and 2^336 modulo l, encoded as a
// fiatScalarMontgomeryDomainFieldElement, which is a little-endian 4-limb value
// in the 2^256 Montgomery domain. They are the precomputed constants used by
// SetUniformBytes for its split reduction.
var scalarTwo168 = &Scalar{s: [4]uint64{0x5b8ab432eac74798, 0x38afddd6de59d5d7,
	0xa2c131b399411b7c, 0x6329a7ed9ce5a30}}
var scalarTwo336 = &Scalar{s: [4]uint64{0xbd3d108e2b35ecc5, 0x5c3a3718bdf9c90b,
	0x63aa97a331b4f2ee, 0x3d217f5be65cb5c}}
// setShortBytes sets s = x mod l, where x is a little-endian integer shorter
// than 32 bytes.
func (s *Scalar) setShortBytes(x []byte) *Scalar {
	if len(x) >= 32 {
		panic("edwards25519: internal error: setShortBytes called with a long string")
	}
	// Zero-pad x to 32 bytes. A value shorter than 32 bytes is below 2²⁴⁸
	// and therefore below l, as fiatScalarFromBytes requires.
	var buf [32]byte
	copy(buf[:], x)
	fiatScalarFromBytes((*[4]uint64)(&s.s), &buf)
	// Convert the plain value in place into the Montgomery domain used by
	// the other fiat-crypto operations.
	fiatScalarToMontgomery(&s.s, (*fiatScalarNonMontgomeryDomainFieldElement)(&s.s))
	return s
}
// SetCanonicalBytes sets s = x, where x is a 32-byte little-endian encoding of
// s, and returns s. If x is not a canonical encoding of s, SetCanonicalBytes
// returns nil and an error, and the receiver is unchanged.
func (s *Scalar) SetCanonicalBytes(x []byte) (*Scalar, error) {
	// NOTE(review): these error strings lack the "edwards25519:" prefix used
	// elsewhere in this package; left as-is since callers may match on them.
	if len(x) != 32 {
		return nil, errors.New("invalid scalar length")
	}
	if !isReduced(x) {
		return nil, errors.New("invalid scalar encoding")
	}
	fiatScalarFromBytes((*[4]uint64)(&s.s), (*[32]byte)(x))
	fiatScalarToMontgomery(&s.s, (*fiatScalarNonMontgomeryDomainFieldElement)(&s.s))
	return s, nil
}
// scalarMinusOneBytes is l - 1 in little endian.
var scalarMinusOneBytes = [32]byte{236, 211, 245, 92, 26, 99, 18, 88, 214, 156, 247, 162, 222, 249, 222, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16}

// isReduced returns whether the given scalar in 32-byte little endian encoded
// form is reduced modulo l.
func isReduced(s []byte) bool {
	if len(s) != 32 {
		return false
	}

	// Compare against l - 1 from the most significant byte down: the first
	// differing byte decides the ordering, and matching every byte means s
	// equals l - 1, which is still reduced.
	for i := len(s) - 1; i >= 0; i-- {
		if s[i] != scalarMinusOneBytes[i] {
			return s[i] < scalarMinusOneBytes[i]
		}
	}
	return true
}
// SetBytesWithClamping applies the buffer pruning described in RFC 8032,
// Section 5.1.5 (also known as clamping) and sets s to the result. The input
// must be 32 bytes, and it is not modified. If x is not of the right length,
// SetBytesWithClamping returns nil and an error, and the receiver is unchanged.
//
// Note that since Scalar values are always reduced modulo the prime order of
// the curve, the resulting value will not preserve any of the cofactor-clearing
// properties that clamping is meant to provide. It will however work as
// expected as long as it is applied to points on the prime order subgroup, like
// in Ed25519. In fact, it is lost to history why RFC 8032 adopted the
// irrelevant RFC 7748 clamping, but it is now required for compatibility.
func (s *Scalar) SetBytesWithClamping(x []byte) (*Scalar, error) {
	// The high-bit tweaks of clamping also guard against a historical
	// Montgomery ladder implementation bug; that protection is likewise lost
	// to the reduction and irrelevant to edwards25519.
	if len(x) != 32 {
		return nil, errors.New("edwards25519: invalid SetBytesWithClamping input length")
	}

	// Clamping sets the 2^254 bit, which can push the value above l, so the
	// result must go through the wide reduction of SetUniformBytes rather
	// than a plain canonical decode.
	var wide [64]byte
	copy(wide[:], x)
	wide[0] &= 248
	wide[31] &= 63
	wide[31] |= 64
	return s.SetUniformBytes(wide[:])
}
// Bytes returns the canonical 32-byte little-endian encoding of s.
func (s *Scalar) Bytes() []byte {
	// Outlined so the array is allocated inline in the caller rather than on
	// the heap.
	var out [32]byte
	return s.bytes(&out)
}
func (s *Scalar) bytes(out *[32]byte) []byte {
	// Leave the Montgomery domain before serializing.
	var plain fiatScalarNonMontgomeryDomainFieldElement
	fiatScalarFromMontgomery(&plain, &s.s)
	fiatScalarToBytes(out, (*[4]uint64)(&plain))
	return out[:]
}
// Equal returns 1 if s and t are equal, and 0 otherwise.
func (s *Scalar) Equal(t *Scalar) int {
	// s == t exactly when s - t is zero. fiatScalarNonzero folds the
	// difference's limbs into a single word that is zero iff they all are.
	var diff fiatScalarMontgomeryDomainFieldElement
	fiatScalarSub(&diff, &s.s, &t.s)
	var x uint64
	fiatScalarNonzero(&x, (*[4]uint64)(&diff))
	// Collapse x to its low bit without branching: after the OR cascade the
	// low bit is set iff any bit of x was set.
	x |= x >> 32
	x |= x >> 16
	x |= x >> 8
	x |= x >> 4
	x |= x >> 2
	x |= x >> 1
	return int(^x) & 1
}
// nonAdjacentForm computes a width-w non-adjacent form for this scalar.
//
// w must be between 2 and 8, or nonAdjacentForm will panic.
func (s *Scalar) nonAdjacentForm(w uint) [256]int8 {
	// This implementation is adapted from the one
	// in curve25519-dalek and is documented there:
	// https://github.com/dalek-cryptography/curve25519-dalek/blob/f630041af28e9a405255f98a8a93adca18e4315b/src/scalar.rs#L800-L871
	b := s.Bytes()
	if b[31] > 127 {
		panic("scalar has high bit set illegally")
	}
	if w < 2 {
		panic("w must be at least 2 by the definition of NAF")
	} else if w > 8 {
		panic("NAF digits must fit in int8")
	}

	var naf [256]int8
	// digits has a fifth, always-zero limb so the two-word read below can
	// safely access digits[1+indexU64] when a window straddles limb 3.
	var digits [5]uint64

	for i := 0; i < 4; i++ {
		digits[i] = binary.LittleEndian.Uint64(b[i*8:])
	}

	width := uint64(1 << w)
	windowMask := uint64(width - 1)

	pos := uint(0)
	carry := uint64(0)
	for pos < 256 {
		indexU64 := pos / 64
		indexBit := pos % 64
		var bitBuf uint64
		if indexBit < 64-w {
			// This window's bits are contained in a single u64
			bitBuf = digits[indexU64] >> indexBit
		} else {
			// Combine the current 64 bits with bits from the next 64
			bitBuf = (digits[indexU64] >> indexBit) | (digits[1+indexU64] << (64 - indexBit))
		}

		// Add carry into the current window
		window := carry + (bitBuf & windowMask)

		if window&1 == 0 {
			// If the window value is even, preserve the carry and continue.
			// Why is the carry preserved?
			// If carry == 0 and window & 1 == 0,
			// then the next carry should be 0
			// If carry == 1 and window & 1 == 0,
			// then bit_buf & 1 == 1 so the next carry should be 1
			pos += 1
			continue
		}

		if window < width/2 {
			carry = 0
			naf[pos] = int8(window)
		} else {
			// Recenter: emit window - 2^w (a negative digit) and carry 1
			// into the next window.
			carry = 1
			naf[pos] = int8(window) - int8(width)
		}

		pos += w
	}
	return naf
}
// signedRadix16 returns the 64 signed radix-16 digits d_i of s, each in
// [-8, 8], such that s = sum(d_i * 16^i). It is used by the constant-time
// fixed-base multiplication (ScalarBaseMult).
func (s *Scalar) signedRadix16() [64]int8 {
	b := s.Bytes()
	if b[31] > 127 {
		panic("scalar has high bit set illegally")
	}
	var digits [64]int8
	// Compute unsigned radix-16 digits:
	for i := 0; i < 32; i++ {
		digits[2*i] = int8(b[i] & 15)
		digits[2*i+1] = int8((b[i] >> 4) & 15)
	}
	// Recenter coefficients:
	for i := 0; i < 63; i++ {
		// If the digit is >= 8, borrow 16 from it and carry 1 into the next
		// digit; this maps each of the first 63 digits into [-8, 7].
		carry := (digits[i] + 8) >> 4
		digits[i] -= carry << 4
		digits[i+1] += carry
	}
	return digits
}

1147
vendor/filippo.io/edwards25519/scalar_fiat.go generated vendored Normal file

File diff suppressed because it is too large Load diff

214
vendor/filippo.io/edwards25519/scalarmult.go generated vendored Normal file
View file

@ -0,0 +1,214 @@
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import "sync"
// basepointTable is a set of 32 affineLookupTables, where table i is generated
// from 256i * basepoint. It is precomputed the first time it's used.
func basepointTable() *[32]affineLookupTable {
	basepointTablePrecomp.initOnce.Do(func() {
		p := NewGeneratorPoint()
		for i := 0; i < 32; i++ {
			basepointTablePrecomp.table[i].FromP3(p)
			// Eight doublings multiply p by 2^8 = 256, advancing to the
			// generator multiple for the next table.
			for j := 0; j < 8; j++ {
				p.Add(p, p)
			}
		}
	})
	return &basepointTablePrecomp.table
}

// basepointTablePrecomp holds the lazily-built basepoint tables; initOnce
// guarantees the precomputation above runs exactly once.
var basepointTablePrecomp struct {
	table    [32]affineLookupTable
	initOnce sync.Once
}
// ScalarBaseMult sets v = x * B, where B is the canonical generator, and
// returns v.
//
// The scalar multiplication is done in constant time.
func (v *Point) ScalarBaseMult(x *Scalar) *Point {
	basepointTable := basepointTable()
	// Write x = sum(x_i * 16^i) so  x*B = sum( B*x_i*16^i )
	// as described in the Ed25519 paper
	//
	// Group even and odd coefficients
	// x*B     = x_0*16^0*B + x_2*16^2*B + ... + x_62*16^62*B
	//         + x_1*16^1*B + x_3*16^3*B + ... + x_63*16^63*B
	// x*B     = x_0*16^0*B + x_2*16^2*B + ... + x_62*16^62*B
	//    + 16*( x_1*16^0*B + x_3*16^2*B + ... + x_63*16^62*B)
	//
	// We use a lookup table for each i to get x_i*16^(2*i)*B
	// and do four doublings to multiply by 16.
	digits := x.signedRadix16()
	multiple := &affineCached{}
	tmp1 := &projP1xP1{}
	tmp2 := &projP2{}
	// Accumulate the odd components first
	v.Set(NewIdentityPoint())
	for i := 1; i < 64; i += 2 {
		// Constant-time table lookup of digits[i] * 16^(2*(i/2)) * B.
		basepointTable[i/2].SelectInto(multiple, digits[i])
		tmp1.AddAffine(v, multiple)
		v.fromP1xP1(tmp1)
	}
	// Multiply by 16
	tmp2.FromP3(v)       // tmp2 =    v in P2 coords
	tmp1.Double(tmp2)    // tmp1 =  2*v in P1xP1 coords
	tmp2.FromP1xP1(tmp1) // tmp2 =  2*v in P2 coords
	tmp1.Double(tmp2)    // tmp1 =  4*v in P1xP1 coords
	tmp2.FromP1xP1(tmp1) // tmp2 =  4*v in P2 coords
	tmp1.Double(tmp2)    // tmp1 =  8*v in P1xP1 coords
	tmp2.FromP1xP1(tmp1) // tmp2 =  8*v in P2 coords
	tmp1.Double(tmp2)    // tmp1 = 16*v in P1xP1 coords
	v.fromP1xP1(tmp1)    // now v = 16*(odd components)
	// Accumulate the even components
	for i := 0; i < 64; i += 2 {
		basepointTable[i/2].SelectInto(multiple, digits[i])
		tmp1.AddAffine(v, multiple)
		v.fromP1xP1(tmp1)
	}
	return v
}
// ScalarMult sets v = x * q, and returns v.
//
// The scalar multiplication is done in constant time.
func (v *Point) ScalarMult(x *Scalar, q *Point) *Point {
	checkInitialized(q)
	// Build the per-call table {Q, 2Q, ..., 8Q} for constant-time lookups.
	var table projLookupTable
	table.FromP3(q)
	// Write x = sum(x_i * 16^i)
	// so  x*Q = sum( Q*x_i*16^i )
	//         = Q*x_0 + 16*(Q*x_1 + 16*( ... + Q*x_63) ... )
	//           <------compute inside out---------
	//
	// We use the lookup table to get the x_i*Q values
	// and do four doublings to compute 16*Q
	digits := x.signedRadix16()
	// Unwrap first loop iteration to save computing 16*identity
	multiple := &projCached{}
	tmp1 := &projP1xP1{}
	tmp2 := &projP2{}
	table.SelectInto(multiple, digits[63])
	v.Set(NewIdentityPoint())
	tmp1.Add(v, multiple) // tmp1 = x_63*Q in P1xP1 coords
	for i := 62; i >= 0; i-- {
		tmp2.FromP1xP1(tmp1) // tmp2 =    (prev) in P2 coords
		tmp1.Double(tmp2)    // tmp1 =  2*(prev) in P1xP1 coords
		tmp2.FromP1xP1(tmp1) // tmp2 =  2*(prev) in P2 coords
		tmp1.Double(tmp2)    // tmp1 =  4*(prev) in P1xP1 coords
		tmp2.FromP1xP1(tmp1) // tmp2 =  4*(prev) in P2 coords
		tmp1.Double(tmp2)    // tmp1 =  8*(prev) in P1xP1 coords
		tmp2.FromP1xP1(tmp1) // tmp2 =  8*(prev) in P2 coords
		tmp1.Double(tmp2)    // tmp1 = 16*(prev) in P1xP1 coords
		v.fromP1xP1(tmp1)    // v = 16*(prev) in P3 coords
		table.SelectInto(multiple, digits[i])
		tmp1.Add(v, multiple) // tmp1 = x_i*Q + 16*(prev) in P1xP1 coords
	}
	v.fromP1xP1(tmp1)
	return v
}
// basepointNafTable is the nafLookupTable8 for the basepoint.
// It is precomputed the first time it's used.
func basepointNafTable() *nafLookupTable8 {
	basepointNafTablePrecomp.initOnce.Do(func() {
		basepointNafTablePrecomp.table.FromP3(NewGeneratorPoint())
	})
	return &basepointNafTablePrecomp.table
}

// basepointNafTablePrecomp holds the lazily-built NAF table for the
// generator; initOnce guarantees it is computed exactly once.
var basepointNafTablePrecomp struct {
	table    nafLookupTable8
	initOnce sync.Once
}
// VarTimeDoubleScalarBaseMult sets v = a * A + b * B, where B is the canonical
// generator, and returns v.
//
// Execution time depends on the inputs.
func (v *Point) VarTimeDoubleScalarBaseMult(a *Scalar, A *Point, b *Scalar) *Point {
	checkInitialized(A)

	// Similarly to the single variable-base approach, we compute
	// digits and use them with a lookup table. However, because
	// we are allowed to do variable-time operations, we don't
	// need constant-time lookups or constant-time digit
	// computations.
	//
	// So we use a non-adjacent form of some width w instead of
	// radix 16. This is like a binary representation (one digit
	// for each binary place) but we allow the digits to grow in
	// magnitude up to 2^{w-1} so that the nonzero digits are as
	// sparse as possible. Intuitively, this "condenses" the
	// "mass" of the scalar onto sparse coefficients (meaning
	// fewer additions).

	basepointNafTable := basepointNafTable()
	var aTable nafLookupTable5
	aTable.FromP3(A)
	// Because the basepoint is fixed, we can use a wider NAF
	// corresponding to a bigger table.
	aNaf := a.nonAdjacentForm(5)
	bNaf := b.nonAdjacentForm(8)

	// Find the first nonzero coefficient, so the main loop below can start
	// at the top of the scalars' actual bit length instead of at 255.
	i := 255
	for j := i; j >= 0; j-- {
		if aNaf[j] != 0 || bNaf[j] != 0 {
			// Fix: record the position of the first nonzero digit. The
			// original loop broke without assigning i, so i stayed 255 and
			// the leading run of zero coefficients was doubled through
			// anyway — a correct result, but pure wasted work.
			i = j
			break
		}
	}

	multA := &projCached{}
	multB := &affineCached{}
	tmp1 := &projP1xP1{}
	tmp2 := &projP2{}
	tmp2.Zero()

	// Move from high to low bits, doubling the accumulator
	// at each iteration and checking whether there is a nonzero
	// coefficient to look up a multiple of.
	for ; i >= 0; i-- {
		tmp1.Double(tmp2)

		// Only update v if we have a nonzero coeff to add in.
		if aNaf[i] > 0 {
			v.fromP1xP1(tmp1)
			aTable.SelectInto(multA, aNaf[i])
			tmp1.Add(v, multA)
		} else if aNaf[i] < 0 {
			v.fromP1xP1(tmp1)
			aTable.SelectInto(multA, -aNaf[i])
			tmp1.Sub(v, multA)
		}

		if bNaf[i] > 0 {
			v.fromP1xP1(tmp1)
			basepointNafTable.SelectInto(multB, bNaf[i])
			tmp1.AddAffine(v, multB)
		} else if bNaf[i] < 0 {
			v.fromP1xP1(tmp1)
			basepointNafTable.SelectInto(multB, -bNaf[i])
			tmp1.SubAffine(v, multB)
		}

		tmp2.FromP1xP1(tmp1)
	}

	v.fromP2(tmp2)
	return v
}

129
vendor/filippo.io/edwards25519/tables.go generated vendored Normal file
View file

@ -0,0 +1,129 @@
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"crypto/subtle"
)
// A dynamic lookup table for variable-base, constant-time scalar muls.
type projLookupTable struct {
	points [8]projCached // Q, 2Q, ..., 8Q (see FromP3)
}

// A precomputed lookup table for fixed-base, constant-time scalar muls.
type affineLookupTable struct {
	points [8]affineCached // Q, 2Q, ..., 8Q (see FromP3)
}

// A dynamic lookup table for variable-base, variable-time scalar muls.
type nafLookupTable5 struct {
	points [8]projCached // odd multiples Q, 3Q, 5Q, ..., 15Q (see FromP3)
}

// A precomputed lookup table for fixed-base, variable-time scalar muls.
type nafLookupTable8 struct {
	points [64]affineCached // odd multiples Q, 3Q, ..., 127Q (see FromP3)
}
// Constructors.
// Builds a lookup table at runtime. Fast.
func (v *projLookupTable) FromP3(q *Point) {
	// Goal: v.points[i] = (i+1)*Q, i.e., Q, 2Q, ..., 8Q
	// This allows lookup of -8Q, ..., -Q, 0, Q, ..., 8Q
	v.points[0].FromP3(q)
	tmpP3 := Point{}
	tmpP1xP1 := projP1xP1{}
	for i := 0; i < 7; i++ {
		// Compute (i+1)*Q as Q + i*Q and convert to a projCached
		// This is needlessly complicated because the API has explicit
		// receivers instead of creating stack objects and relying on RVO
		v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.Add(q, &v.points[i])))
	}
}
// This is not optimised for speed; fixed-base tables should be precomputed.
func (v *affineLookupTable) FromP3(q *Point) {
	// Goal: v.points[i] = (i+1)*Q, i.e., Q, 2Q, ..., 8Q
	// This allows lookup of -8Q, ..., -Q, 0, Q, ..., 8Q
	v.points[0].FromP3(q)
	tmpP3 := Point{}
	tmpP1xP1 := projP1xP1{}
	for i := 0; i < 7; i++ {
		// Compute (i+1)*Q as Q + i*Q and convert to affineCached
		v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.AddAffine(q, &v.points[i])))
	}
}
// Builds a lookup table at runtime. Fast.
func (v *nafLookupTable5) FromP3(q *Point) {
	// Goal: v.points[i] = (2*i+1)*Q, i.e., Q, 3Q, 5Q, ..., 15Q
	// This allows lookup of -15Q, ..., -3Q, -Q, 0, Q, 3Q, ..., 15Q
	v.points[0].FromP3(q)
	// q2 = 2Q is the stride between consecutive odd multiples.
	q2 := Point{}
	q2.Add(q, q)
	tmpP3 := Point{}
	tmpP1xP1 := projP1xP1{}
	for i := 0; i < 7; i++ {
		v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.Add(&q2, &v.points[i])))
	}
}
// This is not optimised for speed; fixed-base tables should be precomputed.
// Goal: v.points[i] = (2*i+1)*Q, i.e., the 64 odd multiples Q, 3Q, ..., 127Q.
func (v *nafLookupTable8) FromP3(q *Point) {
	v.points[0].FromP3(q)
	// q2 = 2Q is the stride between consecutive odd multiples.
	q2 := Point{}
	q2.Add(q, q)
	tmpP3 := Point{}
	tmpP1xP1 := projP1xP1{}
	for i := 0; i < 63; i++ {
		v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.AddAffine(&q2, &v.points[i])))
	}
}
// Selectors.
// Set dest to x*Q, where -8 <= x <= 8, in constant time.
func (v *projLookupTable) SelectInto(dest *projCached, x int8) {
	// Compute xabs = |x|
	// xmask is 0 for x >= 0 and all-ones for x < 0 (arithmetic shift of the
	// sign bit), so (x + xmask) ^ xmask is a branch-free absolute value.
	xmask := x >> 7
	xabs := uint8((x + xmask) ^ xmask)
	dest.Zero()
	// Scan the whole table unconditionally so the access pattern is
	// independent of x.
	for j := 1; j <= 8; j++ {
		// Set dest = j*Q if |x| = j
		cond := subtle.ConstantTimeByteEq(xabs, uint8(j))
		dest.Select(&v.points[j-1], dest, cond)
	}
	// Now dest = |x|*Q, conditionally negate to get x*Q
	dest.CondNeg(int(xmask & 1))
}
// Set dest to x*Q, where -8 <= x <= 8, in constant time.
func (v *affineLookupTable) SelectInto(dest *affineCached, x int8) {
	// Compute xabs = |x|
	// xmask is 0 for x >= 0 and all-ones for x < 0 (arithmetic shift of the
	// sign bit), so (x + xmask) ^ xmask is a branch-free absolute value.
	xmask := x >> 7
	xabs := uint8((x + xmask) ^ xmask)
	dest.Zero()
	// Scan the whole table unconditionally so the access pattern is
	// independent of x.
	for j := 1; j <= 8; j++ {
		// Set dest = j*Q if |x| = j
		cond := subtle.ConstantTimeByteEq(xabs, uint8(j))
		dest.Select(&v.points[j-1], dest, cond)
	}
	// Now dest = |x|*Q, conditionally negate to get x*Q
	dest.CondNeg(int(xmask & 1))
}
// Given odd x with 0 < x < 2^4, return x*Q (in variable time).
func (v *nafLookupTable5) SelectInto(dest *projCached, x int8) {
	// points[i] holds (2*i+1)*Q, so an odd x maps to index x/2.
	*dest = v.points[x/2]
}
// Given odd x with 0 < x < 2^7, return x*Q (in variable time).
func (v *nafLookupTable8) SelectInto(dest *affineCached, x int8) {
	// points[i] holds (2*i+1)*Q, so an odd x maps to index x/2.
	*dest = v.points[x/2]
}

4
vendor/github.com/DATA-DOG/go-sqlmock/.gitignore generated vendored Normal file
View file

@ -0,0 +1,4 @@
/examples/blog/blog
/examples/orders/orders
/examples/basic/basic
.idea/

29
vendor/github.com/DATA-DOG/go-sqlmock/.travis.yml generated vendored Normal file
View file

@ -0,0 +1,29 @@
language: go
go_import_path: github.com/DATA-DOG/go-sqlmock
go:
- 1.2.x
- 1.3.x
- 1.4 # has no cover tool for latest releases
- 1.5.x
- 1.6.x
- 1.7.x
- 1.8.x
- 1.9.x
- 1.10.x
- 1.11.x
- 1.12.x
- 1.13.x
- 1.14.x
- 1.15.x
- 1.16.x
- 1.17.x
script:
- go vet
- test -z "$(go fmt ./...)" # fail if not formatted properly
- go test -race -coverprofile=coverage.txt -covermode=atomic
after_success:
- bash <(curl -s https://codecov.io/bash)

28
vendor/github.com/DATA-DOG/go-sqlmock/LICENSE generated vendored Normal file
View file

@ -0,0 +1,28 @@
The three clause BSD license (http://en.wikipedia.org/wiki/BSD_licenses)
Copyright (c) 2013-2019, DATA-DOG team
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* The name DataDog.lt may not be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT,
INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

265
vendor/github.com/DATA-DOG/go-sqlmock/README.md generated vendored Normal file
View file

@ -0,0 +1,265 @@
[![Build Status](https://travis-ci.org/DATA-DOG/go-sqlmock.svg)](https://travis-ci.org/DATA-DOG/go-sqlmock)
[![GoDoc](https://godoc.org/github.com/DATA-DOG/go-sqlmock?status.svg)](https://godoc.org/github.com/DATA-DOG/go-sqlmock)
[![Go Report Card](https://goreportcard.com/badge/github.com/DATA-DOG/go-sqlmock)](https://goreportcard.com/report/github.com/DATA-DOG/go-sqlmock)
[![codecov.io](https://codecov.io/github/DATA-DOG/go-sqlmock/branch/master/graph/badge.svg)](https://codecov.io/github/DATA-DOG/go-sqlmock)
# Sql driver mock for Golang
**sqlmock** is a mock library implementing [sql/driver](https://godoc.org/database/sql/driver). It has one and only
one purpose - to simulate any **sql** driver behavior in tests, without needing a real database connection. It helps to
maintain a correct **TDD** workflow.
- this library is now complete and stable. (you may not find new changes for this reason)
- supports concurrency and multiple connections.
- supports **go1.8** Context related feature mocking and Named sql parameters.
- does not require any modifications to your source code.
- the driver allows to mock any sql driver method behavior.
- has strict by default expectation order matching.
- has no third party dependencies.
**NOTE:** in **v1.2.0** **sqlmock.Rows** has changed to struct from interface, if you were using any type references to that
interface, you will need to switch it to a pointer struct type. Also, **sqlmock.Rows** were used to implement **driver.Rows**
interface, which was not required or useful for mocking and was removed. Hope it will not cause issues.
## Looking for maintainers
I do not have much spare time for this library and am willing to transfer the repository ownership
to a person or an organization motivated to maintain it. Open up a conversation if you are interested. See #230.
## Install
go get github.com/DATA-DOG/go-sqlmock
## Documentation and Examples
Visit [godoc](http://godoc.org/github.com/DATA-DOG/go-sqlmock) for general examples and public api reference.
See **.travis.yml** for supported **go** versions.
Different use case, is to functionally test with a real database - [go-txdb](https://github.com/DATA-DOG/go-txdb)
all database related actions are isolated within a single transaction so the database can remain in the same state.
See implementation examples:
- [blog API server](https://github.com/DATA-DOG/go-sqlmock/tree/master/examples/blog)
- [the same orders example](https://github.com/DATA-DOG/go-sqlmock/tree/master/examples/orders)
### Something you may want to test, assuming you use the [go-mysql-driver](https://github.com/go-sql-driver/mysql)
``` go
package main
import (
"database/sql"
_ "github.com/go-sql-driver/mysql"
)
func recordStats(db *sql.DB, userID, productID int64) (err error) {
tx, err := db.Begin()
if err != nil {
return
}
defer func() {
switch err {
case nil:
err = tx.Commit()
default:
tx.Rollback()
}
}()
if _, err = tx.Exec("UPDATE products SET views = views + 1"); err != nil {
return
}
if _, err = tx.Exec("INSERT INTO product_viewers (user_id, product_id) VALUES (?, ?)", userID, productID); err != nil {
return
}
return
}
func main() {
// @NOTE: the real connection is not required for tests
db, err := sql.Open("mysql", "root@/blog")
if err != nil {
panic(err)
}
defer db.Close()
if err = recordStats(db, 1 /*some user id*/, 5 /*some product id*/); err != nil {
panic(err)
}
}
```
### Tests with sqlmock
``` go
package main
import (
"fmt"
"testing"
"github.com/DATA-DOG/go-sqlmock"
)
// a successful case
func TestShouldUpdateStats(t *testing.T) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectBegin()
mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("INSERT INTO product_viewers").WithArgs(2, 3).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
// now we execute our method
if err = recordStats(db, 2, 3); err != nil {
t.Errorf("error was not expected while updating stats: %s", err)
}
// we make sure that all expectations were met
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
// a failing test case
func TestShouldRollbackStatUpdatesOnFailure(t *testing.T) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectBegin()
mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("INSERT INTO product_viewers").
WithArgs(2, 3).
WillReturnError(fmt.Errorf("some error"))
mock.ExpectRollback()
// now we execute our method
if err = recordStats(db, 2, 3); err == nil {
t.Errorf("was expecting an error, but there was none")
}
// we make sure that all expectations were met
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
```
## Customize SQL query matching
There were plenty of requests from users regarding SQL query string validation or different matching options.
We have now implemented the `QueryMatcher` interface, which can be passed through an option when calling
`sqlmock.New` or `sqlmock.NewWithDSN`.
This now allows to include some library, which would allow for example to parse and validate `mysql` SQL AST.
And create a custom QueryMatcher in order to validate SQL in sophisticated ways.
By default, **sqlmock** is preserving backward compatibility and default query matcher is `sqlmock.QueryMatcherRegexp`
which uses expected SQL string as a regular expression to match incoming query string. There is an equality matcher:
`QueryMatcherEqual` which will do a full case sensitive match.
In order to customize the QueryMatcher, use the following:
``` go
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
```
The query matcher can be fully customized based on user needs. **sqlmock** will not
provide a standard sql parsing matchers, since various drivers may not follow the same SQL standard.
## Matching arguments like time.Time
There may be arguments which are of `struct` type and cannot be compared easily by value like `time.Time`. In this case
**sqlmock** provides an [Argument](https://godoc.org/github.com/DATA-DOG/go-sqlmock#Argument) interface which
can be used in more sophisticated matching. Here is a simple example of time argument matching:
``` go
type AnyTime struct{}
// Match satisfies sqlmock.Argument interface
func (a AnyTime) Match(v driver.Value) bool {
_, ok := v.(time.Time)
return ok
}
func TestAnyTimeArgument(t *testing.T) {
t.Parallel()
db, mock, err := sqlmock.New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectExec("INSERT INTO users").
WithArgs("john", AnyTime{}).
WillReturnResult(sqlmock.NewResult(1, 1))
_, err = db.Exec("INSERT INTO users(name, created_at) VALUES (?, ?)", "john", time.Now())
if err != nil {
t.Errorf("error '%s' was not expected, while inserting a row", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
```
It only asserts that argument is of `time.Time` type.
## Run tests
go test -race
## Change Log
- **2019-04-06** - added functionality to mock a sql MetaData request
- **2019-02-13** - added `go.mod` removed the references and suggestions using `gopkg.in`.
- **2018-12-11** - added expectation of Rows to be closed, while mocking expected query.
- **2018-12-11** - introduced an option to provide **QueryMatcher** in order to customize SQL query matching.
- **2017-09-01** - it is now possible to expect that prepared statement will be closed,
using **ExpectedPrepare.WillBeClosed**.
- **2017-02-09** - implemented support for **go1.8** features. **Rows** interface was changed to struct
but contains all methods as before and should maintain backwards compatibility. **ExpectedQuery.WillReturnRows** may now
accept multiple row sets.
- **2016-11-02** - `db.Prepare()` was not validating expected prepare SQL
query. It should still be validated even if Exec or Query is not
executed on that prepared statement.
- **2016-02-23** - added **sqlmock.AnyArg()** function to provide any kind
of argument matcher.
- **2016-02-23** - convert expected arguments to driver.Value as natural
driver does, the change may affect time.Time comparison and will be
stricter. See [issue](https://github.com/DATA-DOG/go-sqlmock/issues/31).
- **2015-08-27** - **v1** api change, concurrency support, all known issues fixed.
- **2014-08-16** instead of **panic** during reflect type mismatch when comparing query arguments - now return error
- **2014-08-14** added **sqlmock.NewErrorResult** which gives an option to return driver.Result with errors for
interface methods, see [issue](https://github.com/DATA-DOG/go-sqlmock/issues/5)
- **2014-05-29** allow to match arguments in more sophisticated ways, by providing an **sqlmock.Argument** interface
- **2014-04-21** introduce **sqlmock.New()** to open a mock database connection for tests. This method
calls sql.DB.Ping to ensure that connection is open, see [issue](https://github.com/DATA-DOG/go-sqlmock/issues/4).
This way on Close it will surely assert if all expectations are met, even if database was not triggered at all.
The old way is still available, but it is advisable to call db.Ping manually before asserting with db.Close.
- **2014-02-14** RowsFromCSVString is now a part of Rows interface named as FromCSVString.
It has changed to allow more ways to construct rows and to easily extend this API in future.
See [issue 1](https://github.com/DATA-DOG/go-sqlmock/issues/1)
**RowsFromCSVString** is deprecated and will be removed in future
## Contributions
Feel free to open a pull request. Note, if you wish to contribute an extension to public (exported methods or types) -
please open an issue before, to discuss whether these changes can be accepted. All backward incompatible changes are
and will be treated cautiously
## License
The [three clause BSD license](http://en.wikipedia.org/wiki/BSD_licenses)

24
vendor/github.com/DATA-DOG/go-sqlmock/argument.go generated vendored Normal file
View file

@ -0,0 +1,24 @@
package sqlmock
import "database/sql/driver"
// Argument lets callers plug custom matching logic into an ExpectedQuery or
// ExpectedExec expectation: Match reports whether a concrete driver.Value
// passed by the code under test satisfies the expectation.
type Argument interface {
	Match(driver.Value) bool
}

// anyArgument is the trivial Argument implementation: it accepts everything.
type anyArgument struct{}

// Match reports true for every driver.Value.
func (anyArgument) Match(driver.Value) bool {
	return true
}

// AnyArg returns an Argument that matches any kind of argument.
//
// Useful for time.Time or similar kinds of arguments.
func AnyArg() Argument {
	var a anyArgument
	return a
}

77
vendor/github.com/DATA-DOG/go-sqlmock/column.go generated vendored Normal file
View file

@ -0,0 +1,77 @@
package sqlmock
import "reflect"
// Column describes the metadata sqlmock reports for one result-set column
// through rows.ColumnTypes(). Build it fluently: start with NewColumn and
// chain the With*/OfType/Nullable setters.
type Column struct {
	name       string
	dbType     string
	nullable   bool
	nullableOk bool
	length     int64
	lengthOk   bool
	precision  int64
	scale      int64
	psOk       bool
	scanType   reflect.Type
}

// NewColumn returns a Column with specified name
func NewColumn(name string) *Column {
	return &Column{name: name}
}

// Name reports the column's name.
func (col *Column) Name() string {
	return col.name
}

// DbType reports the column's database type string.
func (col *Column) DbType() string {
	return col.dbType
}

// IsNullable reports the nullable flag and whether it was explicitly set.
func (col *Column) IsNullable() (bool, bool) {
	return col.nullable, col.nullableOk
}

// Length reports the column length and whether it was explicitly set.
func (col *Column) Length() (int64, bool) {
	return col.length, col.lengthOk
}

// PrecisionScale reports precision, scale, and whether they were set.
func (col *Column) PrecisionScale() (int64, int64, bool) {
	return col.precision, col.scale, col.psOk
}

// ScanType reports the Go type produced when scanning this column.
func (col *Column) ScanType() reflect.Type {
	return col.scanType
}

// Nullable returns the column with nullable metadata set
func (col *Column) Nullable(nullable bool) *Column {
	col.nullable, col.nullableOk = nullable, true
	return col
}

// OfType returns the column with type metadata set
func (col *Column) OfType(dbType string, sampleValue interface{}) *Column {
	col.dbType, col.scanType = dbType, reflect.TypeOf(sampleValue)
	return col
}

// WithLength returns the column with length metadata set.
func (col *Column) WithLength(length int64) *Column {
	col.length, col.lengthOk = length, true
	return col
}

// WithPrecisionAndScale returns the column with precision and scale metadata set.
func (col *Column) WithPrecisionAndScale(precision, scale int64) *Column {
	col.precision, col.scale, col.psOk = precision, scale, true
	return col
}

81
vendor/github.com/DATA-DOG/go-sqlmock/driver.go generated vendored Normal file
View file

@ -0,0 +1,81 @@
package sqlmock
import (
"database/sql"
"database/sql/driver"
"fmt"
"sync"
)
// pool is the single process-wide mock driver shared by every mock
// connection created through New/NewWithDSN.
var pool *mockDriver

// init registers the "sqlmock" driver with database/sql so that
// sql.Open("sqlmock", dsn) resolves to this package's mock connections.
func init() {
	pool = &mockDriver{
		conns: make(map[string]*sqlmock),
	}
	sql.Register("sqlmock", pool)
}
// mockDriver implements driver.Driver by handing out pre-registered mock
// connections keyed by DSN. The embedded mutex guards counter and conns.
type mockDriver struct {
	sync.Mutex
	counter int                // used by New to mint unique DSNs
	conns   map[string]*sqlmock // registered mock connections, by DSN
}

// Open returns the mock connection previously registered for dsn (via New or
// NewWithDSN); it errors if nothing was registered under that DSN.
func (d *mockDriver) Open(dsn string) (driver.Conn, error) {
	d.Lock()
	defer d.Unlock()
	c, ok := d.conns[dsn]
	if !ok {
		// c is nil here; the error tells the caller to create the mock first.
		return c, fmt.Errorf("expected a connection to be available, but it is not")
	}
	// Count how many times this DSN has been opened.
	c.opened++
	return c, nil
}
// New creates sqlmock database connection and a mock to manage expectations.
// Accepts options, like ValueConverterOption, to use a ValueConverter from
// a specific driver.
// Pings db so that all expectations could be
// asserted.
func New(options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
	pool.Lock()
	// Mint a unique DSN per call so concurrent tests get independent mocks.
	dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter)
	pool.counter++
	smock := &sqlmock{dsn: dsn, drv: pool, ordered: true}
	pool.conns[dsn] = smock
	// NOTE(review): the pool must be unlocked before smock.open — opening the
	// connection presumably goes back through pool.Open, which takes the same
	// lock; confirm against sqlmock.open.
	pool.Unlock()
	return smock.open(options)
}
// NewWithDSN creates sqlmock database connection with a specific DSN
// and a mock to manage expectations.
// Accepts options, like ValueConverterOption, to use a ValueConverter from
// a specific driver.
// Pings db so that all expectations could be asserted.
//
// This method is introduced because of sql abstraction
// libraries, which do not provide a way to initialize
// with sql.DB instance. For example GORM library.
//
// Note, it will error if attempted to create with an
// already used dsn
//
// It is not recommended to use this method, unless you
// really need it and there is no other way around.
func NewWithDSN(dsn string, options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
	pool.Lock()
	// Reject DSNs that are already registered so two mocks never share one.
	if _, ok := pool.conns[dsn]; ok {
		pool.Unlock()
		return nil, nil, fmt.Errorf("cannot create a new mock database with the same dsn: %s", dsn)
	}
	smock := &sqlmock{dsn: dsn, drv: pool, ordered: true}
	pool.conns[dsn] = smock
	// Unlock before smock.open for the same re-entrancy reason as in New.
	pool.Unlock()
	return smock.open(options)
}

403
vendor/github.com/DATA-DOG/go-sqlmock/expectations.go generated vendored Normal file
View file

@ -0,0 +1,403 @@
package sqlmock
import (
"database/sql/driver"
"fmt"
"strings"
"sync"
"time"
)
// an expectation interface
// Every Expected* type in this file satisfies it via commonExpectation.
type expectation interface {
	fulfilled() bool // reports whether the expectation has been triggered
	Lock()           // provided by the embedded sync.Mutex
	Unlock()
	String() string // human-readable description for error messages
}
// common expectation struct
// satisfies the expectation interface
type commonExpectation struct {
	sync.Mutex
	triggered bool  // set once the expectation has been matched
	err       error // optional error returned when the expectation fires
}

// fulfilled reports whether this expectation has already been matched.
func (e *commonExpectation) fulfilled() bool {
	return e.triggered
}
// ExpectedClose is used to manage *sql.DB.Close expectation
// returned by *Sqlmock.ExpectClose.
type ExpectedClose struct {
	commonExpectation
}

// WillReturnError allows to set an error for *sql.DB.Close action
func (e *ExpectedClose) WillReturnError(err error) *ExpectedClose {
	e.err = err
	return e
}

// String returns a human-readable description of the expectation.
func (e *ExpectedClose) String() string {
	if e.err == nil {
		return "ExpectedClose => expecting database Close"
	}
	return fmt.Sprintf("ExpectedClose => expecting database Close, which should return error: %s", e.err)
}
// ExpectedBegin is used to manage *sql.DB.Begin expectation
// returned by *Sqlmock.ExpectBegin.
type ExpectedBegin struct {
	commonExpectation
	delay time.Duration
}

// WillReturnError allows to set an error for *sql.DB.Begin action
func (e *ExpectedBegin) WillReturnError(err error) *ExpectedBegin {
	e.err = err
	return e
}

// WillDelayFor allows to specify duration for which it will delay
// result. May be used together with Context
func (e *ExpectedBegin) WillDelayFor(duration time.Duration) *ExpectedBegin {
	e.delay = duration
	return e
}

// String returns a human-readable description of the expectation.
func (e *ExpectedBegin) String() string {
	if e.err == nil {
		return "ExpectedBegin => expecting database transaction Begin"
	}
	return fmt.Sprintf("ExpectedBegin => expecting database transaction Begin, which should return error: %s", e.err)
}
// ExpectedCommit is used to manage *sql.Tx.Commit expectation
// returned by *Sqlmock.ExpectCommit.
type ExpectedCommit struct {
	commonExpectation
}

// WillReturnError allows to set an error for *sql.Tx.Commit action
func (e *ExpectedCommit) WillReturnError(err error) *ExpectedCommit {
	e.err = err
	return e
}

// String returns a human-readable description of the expectation.
func (e *ExpectedCommit) String() string {
	if e.err == nil {
		return "ExpectedCommit => expecting transaction Commit"
	}
	return fmt.Sprintf("ExpectedCommit => expecting transaction Commit, which should return error: %s", e.err)
}
// ExpectedRollback is used to manage *sql.Tx.Rollback expectation
// returned by *Sqlmock.ExpectRollback.
type ExpectedRollback struct {
	commonExpectation
}

// WillReturnError allows to set an error for *sql.Tx.Rollback action
func (e *ExpectedRollback) WillReturnError(err error) *ExpectedRollback {
	e.err = err
	return e
}

// String returns a human-readable description of the expectation.
func (e *ExpectedRollback) String() string {
	if e.err == nil {
		return "ExpectedRollback => expecting transaction Rollback"
	}
	return fmt.Sprintf("ExpectedRollback => expecting transaction Rollback, which should return error: %s", e.err)
}
// ExpectedQuery is used to manage *sql.DB.Query, *dql.DB.QueryRow, *sql.Tx.Query,
// *sql.Tx.QueryRow, *sql.Stmt.Query or *sql.Stmt.QueryRow expectations.
// Returned by *Sqlmock.ExpectQuery.
type ExpectedQuery struct {
queryBasedExpectation
rows driver.Rows
delay time.Duration
rowsMustBeClosed bool
rowsWereClosed bool
}
// WithArgs will match given expected args to actual database query arguments.
// if at least one argument does not match, it will return an error. For specific
// arguments an sqlmock.Argument interface can be used to match an argument.
// Must not be used together with WithoutArgs()
func (e *ExpectedQuery) WithArgs(args ...driver.Value) *ExpectedQuery {
	// panic early: mixing the two styles is a programming error in the test
	if e.noArgs {
		panic("WithArgs() and WithoutArgs() must not be used together")
	}
	e.args = args
	return e
}

// WithoutArgs will ensure that no arguments are passed for this query.
// if at least one argument is passed, it will return an error. This allows
// for stricter validation of the query arguments.
// Must not be used together with WithArgs()
func (e *ExpectedQuery) WithoutArgs() *ExpectedQuery {
	if len(e.args) > 0 {
		panic("WithoutArgs() and WithArgs() must not be used together")
	}
	e.noArgs = true
	return e
}

// RowsWillBeClosed expects this query rows to be closed.
func (e *ExpectedQuery) RowsWillBeClosed() *ExpectedQuery {
	e.rowsMustBeClosed = true
	return e
}

// WillReturnError allows to set an error for expected database query
func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery {
	e.err = err
	return e
}

// WillDelayFor allows to specify duration for which it will delay
// result. May be used together with Context
func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery {
	e.delay = duration
	return e
}
// String renders this expectation for error messages and debugging:
// the SQL pattern, the expected arguments (or lack thereof), the rows
// to be returned and the configured error.
func (e *ExpectedQuery) String() string {
	var b strings.Builder
	b.WriteString("ExpectedQuery => expecting Query, QueryContext or QueryRow which:")
	b.WriteString("\n - matches sql: '")
	b.WriteString(e.expectSQL)
	b.WriteString("'")

	var msg string
	if len(e.args) == 0 {
		b.WriteString("\n - is without arguments")
		msg = b.String()
	} else {
		b.WriteString("\n - is with arguments:\n")
		for i, arg := range e.args {
			fmt.Fprintf(&b, "  %d - %+v\n", i, arg)
		}
		// drop the trailing newline left by the last argument line
		msg = strings.TrimSpace(b.String())
	}

	if e.rows != nil {
		msg += fmt.Sprintf("\n - %s", e.rows)
	}
	if e.err != nil {
		msg += fmt.Sprintf("\n - should return error: %s", e.err)
	}
	return msg
}
// ExpectedExec is used to manage *sql.DB.Exec, *sql.Tx.Exec or *sql.Stmt.Exec expectations.
// Returned by *Sqlmock.ExpectExec.
type ExpectedExec struct {
	queryBasedExpectation
	result driver.Result // result to return from the matched exec
	delay  time.Duration // artificial delay applied before returning
}
// WithArgs will match given expected args to actual database exec operation arguments.
// if at least one argument does not match, it will return an error. For specific
// arguments an sqlmock.Argument interface can be used to match an argument.
// Must not be used together with WithoutArgs()
func (e *ExpectedExec) WithArgs(args ...driver.Value) *ExpectedExec {
	// Guard against mixing the two argument-expectation styles, mirroring
	// ExpectedQuery.WithArgs. The previous check (len(e.args) > 0) never
	// fired after WithoutArgs() — which leaves e.args nil — and instead
	// panicked with a misleading message when WithArgs was simply called
	// twice to replace the expected arguments.
	if e.noArgs {
		panic("WithArgs() and WithoutArgs() must not be used together")
	}
	e.args = args
	return e
}
// WithoutArgs will ensure that no args are passed for this expected database exec action.
// if at least one argument is passed, it will return an error. This allows for stricter
// validation of the query arguments.
// Must not be used together with WithArgs()
func (e *ExpectedExec) WithoutArgs() *ExpectedExec {
	// panic early: mixing the two styles is a programming error in the test
	if len(e.args) > 0 {
		panic("WithoutArgs() and WithArgs() must not be used together")
	}
	e.noArgs = true
	return e
}

// WillReturnError allows to set an error for expected database exec action
func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec {
	e.err = err
	return e
}

// WillDelayFor allows to specify duration for which it will delay
// result. May be used together with Context
func (e *ExpectedExec) WillDelayFor(duration time.Duration) *ExpectedExec {
	e.delay = duration
	return e
}
// String renders this expectation for error messages and debugging:
// the SQL pattern, expected arguments, the mocked result and the
// configured error.
func (e *ExpectedExec) String() string {
	var b strings.Builder
	b.WriteString("ExpectedExec => expecting Exec or ExecContext which:")
	b.WriteString("\n - matches sql: '")
	b.WriteString(e.expectSQL)
	b.WriteString("'")

	if len(e.args) == 0 {
		b.WriteString("\n - is without arguments")
	} else {
		b.WriteString("\n - is with arguments:\n")
		// newline-separate the argument lines (equivalent to strings.Join)
		for i, arg := range e.args {
			if i > 0 {
				b.WriteByte('\n')
			}
			fmt.Fprintf(&b, "  %d - %+v", i, arg)
		}
	}

	// only the package-private *result type can be introspected here
	if res, ok := e.result.(*result); ok && e.result != nil {
		b.WriteString("\n - should return Result having:")
		fmt.Fprintf(&b, "\n      LastInsertId: %d", res.insertID)
		fmt.Fprintf(&b, "\n      RowsAffected: %d", res.rowsAffected)
		if res.err != nil {
			fmt.Fprintf(&b, "\n      Error: %s", res.err)
		}
	}

	if e.err != nil {
		fmt.Fprintf(&b, "\n - should return error: %s", e.err)
	}
	return b.String()
}
// WillReturnResult arranges for an expected Exec() to return a particular
// result, there is sqlmock.NewResult(lastInsertID int64, affectedRows int64) method
// to build a corresponding result. Or if actions needs to be tested against errors
// sqlmock.NewErrorResult(err error) to return a given error.
func (e *ExpectedExec) WillReturnResult(result driver.Result) *ExpectedExec {
	e.result = result
	return e
}
// ExpectedPrepare is used to manage *sql.DB.Prepare or *sql.Tx.Prepare expectations.
// Returned by *Sqlmock.ExpectPrepare.
type ExpectedPrepare struct {
	commonExpectation
	mock         *sqlmock      // back-reference so ExpectQuery/ExpectExec can register expectations
	expectSQL    string        // SQL pattern the Prepare call must match
	statement    driver.Stmt   // prepared statement handle, if any
	closeErr     error         // error returned from the statement's Close
	mustBeClosed bool          // set by WillBeClosed()
	wasClosed    bool          // flipped when the statement is closed
	delay        time.Duration // artificial delay applied before returning
}
// WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action.
func (e *ExpectedPrepare) WillReturnError(err error) *ExpectedPrepare {
	e.err = err
	return e
}

// WillReturnCloseError allows to set an error for this prepared statement Close action
func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare {
	e.closeErr = err
	return e
}

// WillDelayFor allows to specify duration for which it will delay
// result. May be used together with Context
func (e *ExpectedPrepare) WillDelayFor(duration time.Duration) *ExpectedPrepare {
	e.delay = duration
	return e
}

// WillBeClosed expects this prepared statement to
// be closed.
func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare {
	e.mustBeClosed = true
	return e
}
// ExpectQuery allows to expect Query() or QueryRow() on this prepared statement.
// This method is convenient in order to prevent duplicating sql query string matching.
func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
	eq := &ExpectedQuery{}
	// reuse this prepare expectation's SQL pattern and converter
	eq.expectSQL = e.expectSQL
	eq.converter = e.mock.converter
	e.mock.expected = append(e.mock.expected, eq)
	return eq
}

// ExpectExec allows to expect Exec() on this prepared statement.
// This method is convenient in order to prevent duplicating sql query string matching.
func (e *ExpectedPrepare) ExpectExec() *ExpectedExec {
	eq := &ExpectedExec{}
	// reuse this prepare expectation's SQL pattern and converter
	eq.expectSQL = e.expectSQL
	eq.converter = e.mock.converter
	e.mock.expected = append(e.mock.expected, eq)
	return eq
}
// String renders this expectation for error messages and debugging:
// the SQL pattern plus any configured Prepare and Close errors.
func (e *ExpectedPrepare) String() string {
	var b strings.Builder
	b.WriteString("ExpectedPrepare => expecting Prepare statement which:")
	b.WriteString("\n - matches sql: '")
	b.WriteString(e.expectSQL)
	b.WriteString("'")
	if e.err != nil {
		fmt.Fprintf(&b, "\n - should return error: %s", e.err)
	}
	if e.closeErr != nil {
		fmt.Fprintf(&b, "\n - should return error on Close: %s", e.closeErr)
	}
	return b.String()
}
// queryBasedExpectation is embedded by ExpectedQuery and ExpectedExec;
// it adds SQL-string and argument matching logic on top of
// commonExpectation.
type queryBasedExpectation struct {
	commonExpectation
	expectSQL string                // SQL pattern matched by the configured QueryMatcher
	converter driver.ValueConverter // converts expected args to driver values before comparing
	args      []driver.Value        // expected arguments; nil means "don't check"
	noArgs    bool                  // ensure no args are passed
}

// ExpectedPing is used to manage *sql.DB.Ping expectations.
// Returned by *Sqlmock.ExpectPing.
type ExpectedPing struct {
	commonExpectation
	delay time.Duration // artificial delay applied before returning
}
// WillDelayFor allows to specify duration for which it will delay result. May
// be used together with Context.
func (e *ExpectedPing) WillDelayFor(duration time.Duration) *ExpectedPing {
	e.delay = duration
	return e
}

// WillReturnError allows to set an error for expected database ping
func (e *ExpectedPing) WillReturnError(err error) *ExpectedPing {
	e.err = err
	return e
}

// String returns a human-readable representation of this expectation,
// including the configured error, if any.
func (e *ExpectedPing) String() string {
	msg := "ExpectedPing => expecting database Ping"
	if e.err != nil {
		msg += fmt.Sprintf(", which should return error: %s", e.err)
	}
	return msg
}

View file

@ -0,0 +1,71 @@
//go:build !go1.8
// +build !go1.8
package sqlmock
import (
"database/sql/driver"
"fmt"
"reflect"
)
// WillReturnRows specifies the set of resulting rows that will be returned
// by the triggered query. (Pre-go1.8 build: only a single result set is
// supported, so exactly one *Rows is accepted.)
func (e *ExpectedQuery) WillReturnRows(rows *Rows) *ExpectedQuery {
	e.rows = &rowSets{sets: []*Rows{rows}, ex: e}
	return e
}
// argsMatches compares the actual query arguments against the expected ones.
// A nil e.args means "don't validate arguments" unless WithoutArgs() was used,
// in which case any argument is an error. Expected values implementing the
// Argument interface are matched via their Match method; all other expected
// values are first passed through the configured driver.ValueConverter and
// then compared with reflect.DeepEqual.
func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
	if nil == e.args {
		// WithoutArgs() demands an empty argument list
		if e.noArgs && len(args) > 0 {
			return fmt.Errorf("expected 0, but got %d arguments", len(args))
		}
		return nil
	}
	if len(args) != len(e.args) {
		return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args))
	}
	for k, v := range args {
		// custom argument matcher
		matcher, ok := e.args[k].(Argument)
		if ok {
			// @TODO: does it make sense to pass value instead of named value?
			if !matcher.Match(v.Value) {
				return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
			}
			continue
		}
		dval := e.args[k]
		// convert to driver converter
		darg, err := e.converter.ConvertValue(dval)
		if err != nil {
			return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
		}

		// the converter must produce a valid driver.Value subset type
		if !driver.IsValue(darg) {
			return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg)
		}

		if !reflect.DeepEqual(darg, v.Value) {
			return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value)
		}
	}
	return nil
}
// attemptArgMatch validates the given arguments against this expectation,
// converting any panic raised by argsMatches (for example from a custom
// Argument matcher or a value converter) into a plain error so callers
// never see the panic.
func (e *queryBasedExpectation) attemptArgMatch(args []namedValue) (err error) {
	// catch panic
	defer func() {
		if r := recover(); r != nil {
			if rerr, ok := r.(error); ok {
				// previously a panic carrying an error value was detected
				// but never assigned, silently returning a nil error
				err = rerr
			} else {
				// %v renders any panic payload; the old e.(string)
				// assertion re-panicked on non-string, non-error values
				err = fmt.Errorf("%v", r)
			}
		}
	}()

	err = e.argsMatches(args)
	return
}

View file

@ -0,0 +1,89 @@
//go:build go1.8
// +build go1.8
package sqlmock
import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
)
// WillReturnRows specifies the set of resulting rows that will be returned
// by the triggered query. Multiple row sets model multi-result-set queries
// (go1.8+ driver.RowsNextResultSet support).
func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery {
	defs := 0
	sets := make([]*Rows, len(rows))
	for i, r := range rows {
		sets[i] = r
		if r.def != nil {
			defs++
		}
	}
	// only expose column-type metadata when every set carries definitions;
	// a mixed bag would make the metadata interfaces lie for some sets
	if defs > 0 && defs == len(sets) {
		e.rows = &rowSetsWithDefinition{&rowSets{sets: sets, ex: e}}
	} else {
		e.rows = &rowSets{sets: sets, ex: e}
	}
	return e
}
// argsMatches compares the actual driver.NamedValue arguments against the
// expected ones. A nil e.args means "don't validate arguments" unless
// WithoutArgs() was used, in which case any argument is an error. Expected
// values implementing the Argument interface are matched via their Match
// method; sql.NamedArg expectations additionally require a matching name,
// ordinal expectations a matching position. Remaining values are converted
// through the configured driver.ValueConverter and compared with
// reflect.DeepEqual.
func (e *queryBasedExpectation) argsMatches(args []driver.NamedValue) error {
	if nil == e.args {
		// WithoutArgs() demands an empty argument list
		if e.noArgs && len(args) > 0 {
			return fmt.Errorf("expected 0, but got %d arguments", len(args))
		}
		return nil
	}
	if len(args) != len(e.args) {
		return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args))
	}
	// @TODO should we assert either all args are named or ordinal?
	for k, v := range args {
		// custom argument matcher
		matcher, ok := e.args[k].(Argument)
		if ok {
			if !matcher.Match(v.Value) {
				return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
			}
			continue
		}

		dval := e.args[k]
		if named, isNamed := dval.(sql.NamedArg); isNamed {
			// named expectation: compare names, then match the wrapped value
			dval = named.Value
			if v.Name != named.Name {
				return fmt.Errorf("named argument %d: name: \"%s\" does not match expected: \"%s\"", k, v.Name, named.Name)
			}
		} else if k+1 != v.Ordinal {
			// positional expectation: ordinals are 1-based
			return fmt.Errorf("argument %d: ordinal position: %d does not match expected: %d", k, k+1, v.Ordinal)
		}

		// convert to driver converter
		darg, err := e.converter.ConvertValue(dval)
		if err != nil {
			return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
		}

		if !reflect.DeepEqual(darg, v.Value) {
			return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value)
		}
	}
	return nil
}
// attemptArgMatch validates the given arguments against this expectation,
// converting any panic raised by argsMatches (for example from a custom
// Argument matcher or a value converter) into a plain error so callers
// never see the panic.
func (e *queryBasedExpectation) attemptArgMatch(args []driver.NamedValue) (err error) {
	// catch panic
	defer func() {
		if r := recover(); r != nil {
			if rerr, ok := r.(error); ok {
				// previously a panic carrying an error value was detected
				// but never assigned, silently returning a nil error
				err = rerr
			} else {
				// %v renders any panic payload; the old e.(string)
				// assertion re-panicked on non-string, non-error values
				err = fmt.Errorf("%v", r)
			}
		}
	}()

	err = e.argsMatches(args)
	return
}

38
vendor/github.com/DATA-DOG/go-sqlmock/options.go generated vendored Normal file
View file

@ -0,0 +1,38 @@
package sqlmock
import "database/sql/driver"
// ValueConverterOption allows to create a sqlmock connection
// with a custom ValueConverter to support drivers with special data types.
func ValueConverterOption(converter driver.ValueConverter) func(*sqlmock) error {
	return func(s *sqlmock) error {
		s.converter = converter
		return nil
	}
}

// QueryMatcherOption allows to customize SQL query matcher
// and match SQL query strings in more sophisticated ways.
// The default QueryMatcher is QueryMatcherRegexp.
func QueryMatcherOption(queryMatcher QueryMatcher) func(*sqlmock) error {
	return func(s *sqlmock) error {
		s.queryMatcher = queryMatcher
		return nil
	}
}

// MonitorPingsOption determines whether calls to Ping on the driver should be
// observed and mocked.
//
// If true is passed, we will check these calls were expected. Expectations can
// be registered using the ExpectPing() method on the mock.
//
// If false is passed or this option is omitted, calls to Ping will not be
// considered when determining expectations and calls to ExpectPing will have
// no effect.
func MonitorPingsOption(monitorPings bool) func(*sqlmock) error {
	return func(s *sqlmock) error {
		s.monitorPings = monitorPings
		return nil
	}
}

68
vendor/github.com/DATA-DOG/go-sqlmock/query.go generated vendored Normal file
View file

@ -0,0 +1,68 @@
package sqlmock
import (
"fmt"
"regexp"
"strings"
)
// re matches any run of one or more whitespace characters.
var re = regexp.MustCompile(`\s+`)

// stripQuery normalizes an SQL string for comparison: every run of
// whitespace (spaces, tabs, newlines) collapses to a single space and
// the result is trimmed at both ends.
func stripQuery(q string) string {
	collapsed := re.ReplaceAllString(q, " ")
	return strings.TrimSpace(collapsed)
}
// QueryMatcher is an SQL query string matcher interface,
// which can be used to customize validation of SQL query strings.
// As an example, external library could be used to build
// and validate SQL ast, columns selected.
//
// sqlmock can be customized to implement a different QueryMatcher
// configured through an option when sqlmock.New or sqlmock.NewWithDSN
// is called, default QueryMatcher is QueryMatcherRegexp.
type QueryMatcher interface {
	// Match expected SQL query string without whitespace to
	// actual SQL.
	Match(expectedSQL, actualSQL string) error
}

// QueryMatcherFunc type is an adapter to allow the use of
// ordinary functions as QueryMatcher. If f is a function
// with the appropriate signature, QueryMatcherFunc(f) is a
// QueryMatcher that calls f.
type QueryMatcherFunc func(expectedSQL, actualSQL string) error

// Match implements the QueryMatcher interface by calling f itself.
func (f QueryMatcherFunc) Match(expectedSQL, actualSQL string) error {
	return f(expectedSQL, actualSQL)
}
// QueryMatcherRegexp is the default SQL query matcher
// used by sqlmock. It parses expectedSQL to a regular
// expression and attempts to match actualSQL.
var QueryMatcherRegexp QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error {
	// whitespace is normalized on both sides before matching
	expect := stripQuery(expectedSQL)
	actual := stripQuery(actualSQL)
	re, err := regexp.Compile(expect)
	if err != nil {
		return err
	}
	if !re.MatchString(actual) {
		return fmt.Errorf(`could not match actual sql: "%s" with expected regexp "%s"`, actual, re.String())
	}
	return nil
})

// QueryMatcherEqual is the SQL query matcher
// which simply tries a case sensitive match of
// expected and actual SQL strings without whitespace.
var QueryMatcherEqual QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error {
	// whitespace is normalized on both sides before comparison
	expect := stripQuery(expectedSQL)
	actual := stripQuery(actualSQL)
	if actual != expect {
		return fmt.Errorf(`actual sql: "%s" does not equal to expected "%s"`, actual, expect)
	}
	return nil
})

39
vendor/github.com/DATA-DOG/go-sqlmock/result.go generated vendored Normal file
View file

@ -0,0 +1,39 @@
package sqlmock
import (
"database/sql/driver"
)
// result satisfies the sql driver.Result interface, carrying the
// last-insert id and rows-affected count (or a configured error)
// returned by mocked Exec queries.
type result struct {
	insertID     int64
	rowsAffected int64
	err          error
}

// NewResult creates a new sql driver Result
// for Exec based query mocks.
func NewResult(lastInsertID int64, rowsAffected int64) driver.Result {
	res := result{
		insertID:     lastInsertID,
		rowsAffected: rowsAffected,
	}
	return &res
}

// NewErrorResult creates a new sql driver Result
// which returns the given error from both interface methods.
func NewErrorResult(err error) driver.Result {
	res := result{err: err}
	return &res
}

// LastInsertId returns the mocked last-insert id, or the configured error.
func (r *result) LastInsertId() (int64, error) {
	return r.insertID, r.err
}

// RowsAffected returns the mocked affected-row count, or the configured error.
func (r *result) RowsAffected() (int64, error) {
	return r.rowsAffected, r.err
}

226
vendor/github.com/DATA-DOG/go-sqlmock/rows.go generated vendored Normal file
View file

@ -0,0 +1,226 @@
package sqlmock
import (
"bytes"
"database/sql/driver"
"encoding/csv"
"errors"
"fmt"
"io"
"strings"
)
// invalidate is the filler pattern written over previously returned raw
// byte buffers to expose misuse of sql.RawBytes after Next/Scan/Close.
const invalidate = "☠☠☠ MEMORY OVERWRITTEN ☠☠☠ "

// CSVColumnParser is a function which converts trimmed csv
// column string to a []byte representation. Currently
// transforms NULL to nil
var CSVColumnParser = func(s string) interface{} {
	if strings.ToLower(s) == "null" {
		return nil
	}
	return []byte(s)
}
// rowSets implements driver.Rows over one or more mocked result sets.
type rowSets struct {
	sets []*Rows        // the mocked result sets, traversed in order
	pos  int            // index of the current result set
	ex   *ExpectedQuery // owning expectation, notified on Close
	raw  [][]byte       // byte buffers handed out since the last invalidation
}

// Columns returns the column names of the current result set.
func (rs *rowSets) Columns() []string {
	return rs.sets[rs.pos].cols
}

// Close marks the owning expectation's rows as closed and returns the
// configured close error; handed-out raw buffers are invalidated first.
func (rs *rowSets) Close() error {
	rs.invalidateRaw()
	rs.ex.rowsWereClosed = true
	return rs.sets[rs.pos].closeErr
}
// Next advances to the next row of the current result set and copies its
// column values into dest, returning io.EOF once all rows are consumed.
func (rs *rowSets) Next(dest []driver.Value) error {
	r := rs.sets[rs.pos]
	r.pos++
	// any raw buffers handed out earlier become invalid on every Next call
	rs.invalidateRaw()
	if r.pos > len(r.rows) {
		return io.EOF // per interface spec
	}

	for i, col := range r.rows[r.pos-1] {
		if b, ok := rawBytes(col); ok {
			// remember the copy so it can be overwritten later, simulating
			// sql.RawBytes shared-memory semantics
			rs.raw = append(rs.raw, b)
			dest[i] = b
			continue
		}
		dest[i] = col
	}

	// a row-specific error may have been configured via RowError
	return r.nextErr[r.pos-1]
}
// String transforms the mocked row sets into a debuggable, printable string.
func (rs *rowSets) String() string {
	if rs.empty() {
		return "with empty rows"
	}

	msg := "should return rows:\n"
	// single result set: list rows flat, without a set header
	if len(rs.sets) == 1 {
		for n, row := range rs.sets[0].rows {
			msg += fmt.Sprintf("  row %d - %+v\n", n, row)
		}
		return strings.TrimSpace(msg)
	}
	// multiple result sets: group rows under a per-set header
	for i, set := range rs.sets {
		msg += fmt.Sprintf("  result set: %d\n", i)
		for n, row := range set.rows {
			msg += fmt.Sprintf("    row %d - %+v\n", n, row)
		}
	}
	return strings.TrimSpace(msg)
}
// empty reports whether every mocked result set contains zero rows.
func (rs *rowSets) empty() bool {
	for _, set := range rs.sets {
		if len(set.rows) > 0 {
			return false
		}
	}
	return true
}
func rawBytes(col driver.Value) (_ []byte, ok bool) {
val, ok := col.([]byte)
if !ok || len(val) == 0 {
return nil, false
}
// Copy the bytes from the mocked row into a shared raw buffer, which we'll replace the content of later
// This allows scanning into sql.RawBytes to correctly become invalid on subsequent calls to Next(), Scan() or Close()
b := make([]byte, len(val))
copy(b, val)
return b, true
}
// Bytes that could have been scanned as sql.RawBytes are only valid until the next call to Next, Scan or Close.
// If those occur, we must replace their content to simulate the shared memory to expose misuse of sql.RawBytes
func (rs *rowSets) invalidateRaw() {
	// Replace the content of slices previously returned
	b := []byte(invalidate)
	for _, r := range rs.raw {
		// tile the marker pattern across the full length of each buffer
		copy(r, bytes.Repeat(b, len(r)/len(b)+1))
	}
	// Start with new slices for the next scan
	rs.raw = nil
}
// Rows is a mocked collection of rows to
// return for Query result
type Rows struct {
	converter driver.ValueConverter // converts user-supplied values to driver values in AddRow
	cols      []string              // column names
	def       []*Column             // optional column metadata (NewRowsWithColumnDefinition)
	rows      [][]driver.Value      // row data, one slice per row
	pos       int                   // read cursor, advanced by rowSets.Next
	nextErr   map[int]error         // per-row errors configured via RowError
	closeErr  error                 // error returned on Close, configured via CloseError
}

// NewRows allows Rows to be created from a
// sql driver.Value slice or from the CSV string and
// to be used as sql driver.Rows.
// Use Sqlmock.NewRows instead if using a custom converter
func NewRows(columns []string) *Rows {
	return &Rows{
		cols:      columns,
		nextErr:   make(map[int]error),
		converter: driver.DefaultParameterConverter,
	}
}
// CloseError allows to set an error
// which will be returned by rows.Close
// function.
//
// The close error will be triggered only in cases
// when rows.Next() EOF was not yet reached, that is
// a default sql library behavior
func (r *Rows) CloseError(err error) *Rows {
	r.closeErr = err
	return r
}

// RowError allows to set an error
// which will be returned when a given
// row number is read
func (r *Rows) RowError(row int, err error) *Rows {
	r.nextErr[row] = err
	return r
}
// AddRow composed from database driver.Value slice
// return the same instance to perform subsequent actions.
// Note that the number of values must match the number
// of columns; a mismatch panics since it is a test-setup error.
func (r *Rows) AddRow(values ...driver.Value) *Rows {
	if len(values) != len(r.cols) {
		panic(fmt.Sprintf("Expected number of values to match number of columns: expected %d, actual %d", len(values), len(r.cols)))
	}

	row := make([]driver.Value, len(r.cols))
	for i, v := range values {
		// Convert user-friendly values (such as int or driver.Valuer)
		// to database/sql native value (driver.Value such as int64)
		var err error
		v, err = r.converter.ConvertValue(v)
		if err != nil {
			panic(fmt.Errorf(
				"row #%d, column #%d (%q) type %T: %s",
				len(r.rows)+1, i, r.cols[i], values[i], err,
			))
		}

		row[i] = v
	}

	r.rows = append(r.rows, row)
	return r
}

// AddRows adds multiple rows composed from database driver.Value slice and
// returns the same instance to perform subsequent actions.
func (r *Rows) AddRows(values ...[]driver.Value) *Rows {
	for _, value := range values {
		r.AddRow(value...)
	}

	return r
}
// FromCSVString build rows from csv string.
// return the same instance to perform subsequent actions.
// Note that the number of values must match the number
// of columns. Parse failures panic since they indicate a
// test-setup error; each cell is transformed via CSVColumnParser.
func (r *Rows) FromCSVString(s string) *Rows {
	res := strings.NewReader(strings.TrimSpace(s))
	csvReader := csv.NewReader(res)

	for {
		res, err := csvReader.Read()
		if err != nil {
			if errors.Is(err, io.EOF) {
				break
			}

			panic(fmt.Sprintf("Parsing CSV string failed: %s", err.Error()))
		}

		row := make([]driver.Value, len(r.cols))
		for i, v := range res {
			row[i] = CSVColumnParser(strings.TrimSpace(v))
		}
		r.rows = append(r.rows, row)
	}
	return r
}

74
vendor/github.com/DATA-DOG/go-sqlmock/rows_go18.go generated vendored Normal file
View file

@ -0,0 +1,74 @@
// +build go1.8
package sqlmock
import (
"database/sql/driver"
"io"
"reflect"
)
// HasNextResultSet implements the "RowsNextResultSet" interface: it reports
// whether another mocked result set follows the current one.
func (rs *rowSets) HasNextResultSet() bool {
	return rs.pos+1 < len(rs.sets)
}

// NextResultSet implements the "RowsNextResultSet" interface: it advances
// to the next mocked result set, or returns io.EOF when none remains.
func (rs *rowSets) NextResultSet() error {
	if !rs.HasNextResultSet() {
		return io.EOF
	}

	rs.pos++
	return nil
}
// rowSetsWithDefinition wraps rowSets for rows created with
// sqlmock.NewRowsWithColumnDefinition, additionally exposing the
// column-metadata interfaces of database/sql/driver.
type rowSetsWithDefinition struct {
	*rowSets
}

// ColumnTypeDatabaseTypeName implements the "RowsColumnTypeDatabaseTypeName" interface.
func (rs *rowSetsWithDefinition) ColumnTypeDatabaseTypeName(index int) string {
	return rs.getDefinition(index).DbType()
}

// ColumnTypeLength implements the "RowsColumnTypeLength" interface.
func (rs *rowSetsWithDefinition) ColumnTypeLength(index int) (length int64, ok bool) {
	return rs.getDefinition(index).Length()
}

// ColumnTypeNullable implements the "RowsColumnTypeNullable" interface.
func (rs *rowSetsWithDefinition) ColumnTypeNullable(index int) (nullable, ok bool) {
	return rs.getDefinition(index).IsNullable()
}

// ColumnTypePrecisionScale implements the "RowsColumnTypePrecisionScale" interface.
func (rs *rowSetsWithDefinition) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
	return rs.getDefinition(index).PrecisionScale()
}

// ColumnTypeScanType is defined from driver.RowsColumnTypeScanType
func (rs *rowSetsWithDefinition) ColumnTypeScanType(index int) reflect.Type {
	return rs.getDefinition(index).ScanType()
}

// getDefinition returns the column definition from the current set's metadata.
func (rs *rowSetsWithDefinition) getDefinition(index int) *Column {
	return rs.sets[rs.pos].def[index]
}
// NewRowsWithColumnDefinition returns mocked Rows that also carry
// per-column metadata (type name, length, nullability, precision/scale,
// scan type), exposed through the driver column-type interfaces.
func NewRowsWithColumnDefinition(columns ...*Column) *Rows {
	cols := make([]string, len(columns))
	for i, column := range columns {
		cols[i] = column.Name()
	}

	return &Rows{
		cols:      cols,
		def:       columns,
		nextErr:   make(map[int]error),
		converter: driver.DefaultParameterConverter,
	}
}

439
vendor/github.com/DATA-DOG/go-sqlmock/sqlmock.go generated vendored Normal file
View file

@ -0,0 +1,439 @@
/*
Package sqlmock is a mock library implementing a sql driver, which has one and
only one purpose - to simulate any sql driver behavior in tests, without needing
a real database connection. It helps to maintain a correct **TDD** workflow.
It does not require any modifications to your source code in order to test
and mock database operations. Supports concurrency and multiple database mocking.
The driver allows to mock any sql driver method behavior.
*/
package sqlmock
import (
"database/sql"
"database/sql/driver"
"fmt"
"time"
)
// SqlmockCommon serves to create expectations
// for any kind of database action in order to mock
// and test real database behavior. It holds the methods shared by the
// per-Go-version Sqlmock interfaces.
type SqlmockCommon interface {
	// ExpectClose queues an expectation for this database
	// action to be triggered. the *ExpectedClose allows
	// to mock database response
	ExpectClose() *ExpectedClose

	// ExpectationsWereMet checks whether all queued expectations
	// were met in order. If any of them was not met - an error is returned.
	ExpectationsWereMet() error

	// ExpectPrepare expects Prepare() to be called with expectedSQL query.
	// the *ExpectedPrepare allows to mock database response.
	// Note that you may expect Query() or Exec() on the *ExpectedPrepare
	// statement to prevent repeating expectedSQL
	ExpectPrepare(expectedSQL string) *ExpectedPrepare

	// ExpectQuery expects Query() or QueryRow() to be called with expectedSQL query.
	// the *ExpectedQuery allows to mock database response.
	ExpectQuery(expectedSQL string) *ExpectedQuery

	// ExpectExec expects Exec() to be called with expectedSQL query.
	// the *ExpectedExec allows to mock database response
	ExpectExec(expectedSQL string) *ExpectedExec

	// ExpectBegin expects *sql.DB.Begin to be called.
	// the *ExpectedBegin allows to mock database response
	ExpectBegin() *ExpectedBegin

	// ExpectCommit expects *sql.Tx.Commit to be called.
	// the *ExpectedCommit allows to mock database response
	ExpectCommit() *ExpectedCommit

	// ExpectRollback expects *sql.Tx.Rollback to be called.
	// the *ExpectedRollback allows to mock database response
	ExpectRollback() *ExpectedRollback

	// ExpectPing expects *sql.DB.Ping to be called.
	// the *ExpectedPing allows to mock database response
	//
	// Ping support only exists in the SQL library in Go 1.8 and above.
	// ExpectPing in Go <=1.7 will return an ExpectedPing but not register
	// any expectations.
	//
	// You must enable pings using MonitorPingsOption for this to register
	// any expectations.
	ExpectPing() *ExpectedPing

	// MatchExpectationsInOrder gives an option whether to match all
	// expectations in the order they were set or not.
	//
	// By default it is set to - true. But if you use goroutines
	// to parallelize your query execution, that option may
	// be handy.
	//
	// This option may be turned on anytime during tests. As soon
	// as it is switched to false, expectations will be matched
	// in any order. Or otherwise if switched to true, any unmatched
	// expectations will be expected in order
	MatchExpectationsInOrder(bool)

	// NewRows allows Rows to be created from a
	// sql driver.Value slice or from the CSV string and
	// to be used as sql driver.Rows.
	NewRows(columns []string) *Rows
}
// sqlmock is the concrete mock connection; it implements driver.Conn,
// driver.Tx and the Sqlmock expectation API.
type sqlmock struct {
	ordered      bool                  // whether expectations must match in registration order
	dsn          string                // dsn this connection was opened with
	opened       int                   // open-connection refcount for this dsn
	drv          *mockDriver           // owning driver; guards the shared conns map
	converter    driver.ValueConverter // converts expected args before matching
	queryMatcher QueryMatcher          // strategy used to match SQL strings
	monitorPings bool                  // whether Ping calls are checked against expectations
	expected     []expectation         // queued expectations, in registration order
}
// open wires this mock into database/sql, applies the given options and
// returns the *sql.DB handle together with the mock itself. Defaults are
// filled in for any option not supplied (converter, query matcher).
func (c *sqlmock) open(options []func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
	db, err := sql.Open("sqlmock", c.dsn)
	if err != nil {
		return db, c, err
	}
	for _, option := range options {
		err := option(c)
		if err != nil {
			return db, c, err
		}
	}
	if c.converter == nil {
		c.converter = driver.DefaultParameterConverter
	}
	if c.queryMatcher == nil {
		c.queryMatcher = QueryMatcherRegexp
	}

	if c.monitorPings {
		// We call Ping on the driver shortly to verify startup assertions by
		// driving internal behaviour of the sql standard library. We don't
		// want this call to ping to be monitored for expectation purposes so
		// temporarily disable.
		c.monitorPings = false
		defer func() { c.monitorPings = true }()
	}
	return db, c, db.Ping()
}
// ExpectClose queues an expectation for the database Close action and
// returns it so the response can be mocked.
func (c *sqlmock) ExpectClose() *ExpectedClose {
	e := &ExpectedClose{}
	c.expected = append(c.expected, e)
	return e
}

// MatchExpectationsInOrder toggles strict in-order matching of expectations.
func (c *sqlmock) MatchExpectationsInOrder(b bool) {
	c.ordered = b
}
// Close a mock database driver connection. It may or may not
// be called depending on the circumstances, but if it is called
// there must be an *ExpectedClose expectation satisfied.
// meets http://golang.org/pkg/database/sql/driver/#Conn interface
func (c *sqlmock) Close() error {
	c.drv.Lock()
	defer c.drv.Unlock()

	// drop this connection from the driver's registry once the last
	// handle for the dsn is closed
	c.opened--
	if c.opened == 0 {
		delete(c.drv.conns, c.dsn)
	}

	var expected *ExpectedClose
	var fulfilled int
	var ok bool
	for _, next := range c.expected {
		next.Lock()
		if next.fulfilled() {
			next.Unlock()
			fulfilled++
			continue
		}

		if expected, ok = next.(*ExpectedClose); ok {
			break
		}

		next.Unlock()
		if c.ordered {
			return fmt.Errorf("call to database Close, was not expected, next expectation is: %s", next)
		}
	}

	if expected == nil {
		msg := "call to database Close was not expected"
		if fulfilled == len(c.expected) {
			msg = "all expectations were already fulfilled, " + msg
		}
		// explicit %s verb: fmt.Errorf(msg) is flagged by go vet
		// (non-constant format string) and would misbehave if msg
		// ever contained a '%'
		return fmt.Errorf("%s", msg)
	}

	expected.triggered = true
	expected.Unlock()
	return expected.err
}
// ExpectationsWereMet reports the first registered expectation that was
// not satisfied, including unmet closing requirements on prepared
// statements and returned rows. A nil return means all expectations were met.
func (c *sqlmock) ExpectationsWereMet() error {
	for _, e := range c.expected {
		e.Lock()
		fulfilled := e.fulfilled()
		e.Unlock()

		if !fulfilled {
			return fmt.Errorf("there is a remaining expectation which was not matched: %s", e)
		}

		// for expected prepared statement check whether it was closed if expected
		if prep, ok := e.(*ExpectedPrepare); ok {
			if prep.mustBeClosed && !prep.wasClosed {
				return fmt.Errorf("expected prepared statement to be closed, but it was not: %s", prep)
			}
		}

		// must check whether all expected queried rows are closed
		if query, ok := e.(*ExpectedQuery); ok {
			if query.rowsMustBeClosed && !query.rowsWereClosed {
				return fmt.Errorf("expected query rows to be closed, but it was not: %s", query)
			}
		}
	}
	return nil
}
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface.
// It applies any configured delay before returning, and returns the
// connection itself as the driver.Tx.
func (c *sqlmock) Begin() (driver.Tx, error) {
	ex, err := c.begin()
	if ex != nil {
		// honor WillDelayFor even when an error is configured
		time.Sleep(ex.delay)
	}
	if err != nil {
		return nil, err
	}

	return c, nil
}
// begin locates the next unfulfilled *ExpectedBegin, marks it triggered
// and returns it along with its configured error.
func (c *sqlmock) begin() (*ExpectedBegin, error) {
	var expected *ExpectedBegin
	var ok bool
	var fulfilled int
	for _, next := range c.expected {
		next.Lock()
		if next.fulfilled() {
			next.Unlock()
			fulfilled++
			continue
		}

		if expected, ok = next.(*ExpectedBegin); ok {
			break
		}

		next.Unlock()
		if c.ordered {
			return nil, fmt.Errorf("call to database transaction Begin, was not expected, next expectation is: %s", next)
		}
	}
	if expected == nil {
		msg := "call to database transaction Begin was not expected"
		if fulfilled == len(c.expected) {
			msg = "all expectations were already fulfilled, " + msg
		}
		// explicit %s verb: fmt.Errorf(msg) is flagged by go vet
		// (non-constant format string) and would misbehave if msg
		// ever contained a '%'
		return nil, fmt.Errorf("%s", msg)
	}

	expected.triggered = true
	expected.Unlock()

	return expected, expected.err
}
// ExpectBegin queues an expectation for *sql.DB.Begin and returns it so
// the response can be mocked.
func (c *sqlmock) ExpectBegin() *ExpectedBegin {
	e := &ExpectedBegin{}
	c.expected = append(c.expected, e)
	return e
}

// ExpectExec queues an expectation for Exec() with the given SQL pattern
// and returns it so the response can be mocked.
func (c *sqlmock) ExpectExec(expectedSQL string) *ExpectedExec {
	e := &ExpectedExec{}
	e.expectSQL = expectedSQL
	e.converter = c.converter
	c.expected = append(c.expected, e)
	return e
}
// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface.
// It applies any configured delay before returning the mocked statement.
func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
	ex, err := c.prepare(query)
	if ex != nil {
		// honor WillDelayFor even when an error is configured
		time.Sleep(ex.delay)
	}
	if err != nil {
		return nil, err
	}

	return &statement{c, ex, query}, nil
}
// prepare locates the matching unfulfilled *ExpectedPrepare for the given
// query, marks it triggered and returns it with its configured error.
// In ordered mode the next unfulfilled expectation must be a Prepare;
// in unordered mode any queued Prepare whose SQL matches is accepted.
func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
	var expected *ExpectedPrepare
	var fulfilled int
	var ok bool

	for _, next := range c.expected {
		next.Lock()
		if next.fulfilled() {
			next.Unlock()
			fulfilled++
			continue
		}

		if c.ordered {
			if expected, ok = next.(*ExpectedPrepare); ok {
				break
			}

			next.Unlock()
			return nil, fmt.Errorf("call to Prepare statement with query '%s', was not expected, next expectation is: %s", query, next)
		}

		if pr, ok := next.(*ExpectedPrepare); ok {
			if err := c.queryMatcher.Match(pr.expectSQL, query); err == nil {
				expected = pr
				break
			}
		}
		next.Unlock()
	}

	if expected == nil {
		msg := "call to Prepare '%s' query was not expected"
		if fulfilled == len(c.expected) {
			msg = "all expectations were already fulfilled, " + msg
		}
		return nil, fmt.Errorf(msg, query)
	}
	defer expected.Unlock()
	// in ordered mode the SQL has not been matched yet, so verify it here
	if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
		return nil, fmt.Errorf("Prepare: %v", err)
	}

	expected.triggered = true
	return expected, expected.err
}
func (c *sqlmock) ExpectPrepare(expectedSQL string) *ExpectedPrepare {
e := &ExpectedPrepare{expectSQL: expectedSQL, mock: c}
c.expected = append(c.expected, e)
return e
}
func (c *sqlmock) ExpectQuery(expectedSQL string) *ExpectedQuery {
e := &ExpectedQuery{}
e.expectSQL = expectedSQL
e.converter = c.converter
c.expected = append(c.expected, e)
return e
}
func (c *sqlmock) ExpectCommit() *ExpectedCommit {
e := &ExpectedCommit{}
c.expected = append(c.expected, e)
return e
}
// ExpectRollback queues an expectation that the transaction will be rolled back.
func (c *sqlmock) ExpectRollback() *ExpectedRollback {
	exp := &ExpectedRollback{}
	c.expected = append(c.expected, exp)
	return exp
}
// Commit meets http://golang.org/pkg/database/sql/driver/#Tx
//
// It consumes the next pending *ExpectedCommit expectation. In ordered mode
// an out-of-turn Commit is an error; in unordered mode all pending
// expectations are scanned.
func (c *sqlmock) Commit() error {
	var expected *ExpectedCommit
	var fulfilled int
	var ok bool
	for _, next := range c.expected {
		next.Lock()
		if next.fulfilled() {
			next.Unlock()
			fulfilled++
			continue
		}
		if expected, ok = next.(*ExpectedCommit); ok {
			break // found it; lock is released after triggering below
		}
		next.Unlock()
		if c.ordered {
			return fmt.Errorf("call to Commit transaction, was not expected, next expectation is: %s", next)
		}
	}
	if expected == nil {
		msg := "call to Commit transaction was not expected"
		if fulfilled == len(c.expected) {
			msg = "all expectations were already fulfilled, " + msg
		}
		// msg is not a format string; use an explicit verb so go vet's
		// printf check does not flag a non-constant format string
		return fmt.Errorf("%s", msg)
	}
	expected.triggered = true
	expected.Unlock()
	return expected.err
}
// Rollback meets http://golang.org/pkg/database/sql/driver/#Tx
//
// It consumes the next pending *ExpectedRollback expectation. In ordered
// mode an out-of-turn Rollback is an error; in unordered mode all pending
// expectations are scanned.
func (c *sqlmock) Rollback() error {
	var expected *ExpectedRollback
	var fulfilled int
	var ok bool
	for _, next := range c.expected {
		next.Lock()
		if next.fulfilled() {
			next.Unlock()
			fulfilled++
			continue
		}
		if expected, ok = next.(*ExpectedRollback); ok {
			break // found it; lock is released after triggering below
		}
		next.Unlock()
		if c.ordered {
			return fmt.Errorf("call to Rollback transaction, was not expected, next expectation is: %s", next)
		}
	}
	if expected == nil {
		msg := "call to Rollback transaction was not expected"
		if fulfilled == len(c.expected) {
			msg = "all expectations were already fulfilled, " + msg
		}
		// msg is not a format string; use an explicit verb so go vet's
		// printf check does not flag a non-constant format string
		return fmt.Errorf("%s", msg)
	}
	expected.triggered = true
	expected.Unlock()
	return expected.err
}
// NewRows allows Rows to be created from a
// sql driver.Value slice or from the CSV string and
// to be used as sql driver.Rows.
// The mock's configured value converter is attached to the result.
func (c *sqlmock) NewRows(columns []string) *Rows {
	r := NewRows(columns)
	r.converter = c.converter
	return r
}

View file

@ -0,0 +1,191 @@
// +build !go1.8
package sqlmock
import (
"database/sql/driver"
"fmt"
"log"
"time"
)
// Sqlmock interface for Go up to 1.7.
// It adds nothing beyond the shared method set on these Go versions.
type Sqlmock interface {
	// Embed common methods
	SqlmockCommon
}
// namedValue mirrors driver.NamedValue for Go versions before 1.8,
// carrying an optional name, a 1-based ordinal and the value itself.
type namedValue struct {
	Name string
	Ordinal int
	Value driver.Value
}
// ExpectPing is a no-op on Go 1.7 and below: it only logs a notice and
// returns a dummy expectation that is never appended to the queue.
func (c *sqlmock) ExpectPing() *ExpectedPing {
	log.Println("ExpectPing has no effect on Go 1.7 or below")
	return &ExpectedPing{}
}
// Query meets http://golang.org/pkg/database/sql/driver/#Queryer
func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) {
	// wrap positional values with their 1-based ordinals
	named := make([]namedValue, len(args))
	for idx, val := range args {
		named[idx] = namedValue{Ordinal: idx + 1, Value: val}
	}
	ex, err := c.query(query, named)
	if ex != nil {
		// honor the expectation's configured delay even on error
		time.Sleep(ex.delay)
	}
	if err != nil {
		return nil, err
	}
	return ex.rows, nil
}
// query locates a pending *ExpectedQuery matching the SQL and args.
// In ordered mode the next unfulfilled expectation must be a Query; in
// unordered mode any pending Query whose SQL and args match is accepted.
// On success the expectation is marked triggered; its rows must be set.
func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) {
	var expected *ExpectedQuery
	var fulfilled int
	var ok bool
	for _, next := range c.expected {
		next.Lock()
		if next.fulfilled() {
			next.Unlock()
			fulfilled++
			continue
		}
		if c.ordered {
			if expected, ok = next.(*ExpectedQuery); ok {
				break // keep holding its lock; released via defer below
			}
			next.Unlock()
			return nil, fmt.Errorf("call to Query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
		}
		if qr, ok := next.(*ExpectedQuery); ok {
			if err := c.queryMatcher.Match(qr.expectSQL, query); err != nil {
				next.Unlock()
				continue
			}
			if err := qr.attemptArgMatch(args); err == nil {
				expected = qr
				break // matched; lock still held
			}
		}
		next.Unlock()
	}
	if expected == nil {
		msg := "call to Query '%s' with args %+v was not expected"
		if fulfilled == len(c.expected) {
			msg = "all expectations were already fulfilled, " + msg
		}
		return nil, fmt.Errorf(msg, query, args)
	}
	defer expected.Unlock()
	// re-validate SQL and args so the ordered path reports matcher errors
	if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
		return nil, fmt.Errorf("Query: %v", err)
	}
	if err := expected.argsMatches(args); err != nil {
		return nil, fmt.Errorf("Query '%s', arguments do not match: %s", query, err)
	}
	expected.triggered = true
	if expected.err != nil {
		return expected, expected.err // mocked to return error
	}
	if expected.rows == nil {
		return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
	}
	return expected, nil
}
// Exec meets http://golang.org/pkg/database/sql/driver/#Execer
func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) {
	// wrap positional values with their 1-based ordinals
	named := make([]namedValue, len(args))
	for idx, val := range args {
		named[idx] = namedValue{Ordinal: idx + 1, Value: val}
	}
	ex, err := c.exec(query, named)
	if ex != nil {
		// honor the expectation's configured delay even on error
		time.Sleep(ex.delay)
	}
	if err != nil {
		return nil, err
	}
	return ex.result, nil
}
// exec locates a pending *ExpectedExec matching the SQL and args.
// In ordered mode the next unfulfilled expectation must be an Exec; in
// unordered mode any pending Exec whose SQL and args match is accepted.
// On success the expectation is marked triggered; its result must be set.
func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) {
	var expected *ExpectedExec
	var fulfilled int
	var ok bool
	for _, next := range c.expected {
		next.Lock()
		if next.fulfilled() {
			next.Unlock()
			fulfilled++
			continue
		}
		if c.ordered {
			if expected, ok = next.(*ExpectedExec); ok {
				break // keep holding its lock; released via defer below
			}
			next.Unlock()
			return nil, fmt.Errorf("call to ExecQuery '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
		}
		if exec, ok := next.(*ExpectedExec); ok {
			if err := c.queryMatcher.Match(exec.expectSQL, query); err != nil {
				next.Unlock()
				continue
			}
			if err := exec.attemptArgMatch(args); err == nil {
				expected = exec
				break // matched; lock still held
			}
		}
		next.Unlock()
	}
	if expected == nil {
		msg := "call to ExecQuery '%s' with args %+v was not expected"
		if fulfilled == len(c.expected) {
			msg = "all expectations were already fulfilled, " + msg
		}
		return nil, fmt.Errorf(msg, query, args)
	}
	defer expected.Unlock()
	// re-validate SQL and args so the ordered path reports matcher errors
	if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
		return nil, fmt.Errorf("ExecQuery: %v", err)
	}
	if err := expected.argsMatches(args); err != nil {
		return nil, fmt.Errorf("ExecQuery '%s', arguments do not match: %s", query, err)
	}
	expected.triggered = true
	if expected.err != nil {
		return expected, expected.err // mocked to return error
	}
	if expected.result == nil {
		return nil, fmt.Errorf("ExecQuery '%s' with args %+v, must return a database/sql/driver.Result, but it was not set for expectation %T as %+v", query, args, expected, expected)
	}
	return expected, nil
}

356
vendor/github.com/DATA-DOG/go-sqlmock/sqlmock_go18.go generated vendored Normal file
View file

@ -0,0 +1,356 @@
// +build go1.8
package sqlmock
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"log"
"time"
)
// Sqlmock interface for Go 1.8+
type Sqlmock interface {
	// Embed common methods
	SqlmockCommon

	// NewRowsWithColumnDefinition allows Rows to be created from a
	// sql driver.Value slice with a definition of sql metadata
	NewRowsWithColumnDefinition(columns ...*Column) *Rows

	// NewColumn allows the creation of a Column with the given name.
	NewColumn(name string) *Column
}
// ErrCancelled defines an error value, which can be expected in case of
// such cancellation error. The *Context methods below return it when the
// supplied context is done before an expectation's delay has elapsed.
var ErrCancelled = errors.New("canceling query due to user request")
// Implement the "QueryerContext" interface
func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
	ex, err := c.query(query, args)
	if ex != nil {
		select {
		// wait out the expectation's configured delay before reporting
		case <-time.After(ex.delay):
			if err != nil {
				return nil, err
			}
			return ex.rows, nil
		// context cancellation wins over the delay
		case <-ctx.Done():
			return nil, ErrCancelled
		}
	}
	return nil, err
}
// Implement the "ExecerContext" interface
func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
	ex, err := c.exec(query, args)
	if ex != nil {
		select {
		// wait out the expectation's configured delay before reporting
		case <-time.After(ex.delay):
			if err != nil {
				return nil, err
			}
			return ex.result, nil
		// context cancellation wins over the delay
		case <-ctx.Done():
			return nil, ErrCancelled
		}
	}
	return nil, err
}
// Implement the "ConnBeginTx" interface
// NOTE(review): opts is currently ignored by the mock — confirm this is intended.
func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
	ex, err := c.begin()
	if ex != nil {
		select {
		case <-time.After(ex.delay):
			if err != nil {
				return nil, err
			}
			// the mock connection itself acts as the driver.Tx
			return c, nil
		case <-ctx.Done():
			return nil, ErrCancelled
		}
	}
	return nil, err
}
// Implement the "ConnPrepareContext" interface
func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
	ex, err := c.prepare(query)
	if ex != nil {
		select {
		// wait out the expectation's configured delay before reporting
		case <-time.After(ex.delay):
			if err != nil {
				return nil, err
			}
			return &statement{c, ex, query}, nil
		// context cancellation wins over the delay
		case <-ctx.Done():
			return nil, ErrCancelled
		}
	}
	return nil, err
}
// Implement the "Pinger" interface - the explicit DB driver ping was only added to database/sql in Go 1.8
func (c *sqlmock) Ping(ctx context.Context) error {
	// pings are ignored entirely unless monitoring was enabled
	if !c.monitorPings {
		return nil
	}
	ex, err := c.ping()
	if ex != nil {
		// honor the configured delay, unless the context is cancelled first
		select {
		case <-ctx.Done():
			return ErrCancelled
		case <-time.After(ex.delay):
		}
	}
	return err
}
// ping locates a pending *ExpectedPing expectation and marks it triggered.
// In ordered mode an out-of-turn ping is an error; in unordered mode all
// pending expectations are scanned.
func (c *sqlmock) ping() (*ExpectedPing, error) {
	var expected *ExpectedPing
	var fulfilled int
	var ok bool
	for _, next := range c.expected {
		next.Lock()
		if next.fulfilled() {
			next.Unlock()
			fulfilled++
			continue
		}
		if expected, ok = next.(*ExpectedPing); ok {
			break // found it; lock is released after triggering below
		}
		next.Unlock()
		if c.ordered {
			return nil, fmt.Errorf("call to database Ping, was not expected, next expectation is: %s", next)
		}
	}
	if expected == nil {
		msg := "call to database Ping was not expected"
		if fulfilled == len(c.expected) {
			msg = "all expectations were already fulfilled, " + msg
		}
		// msg is not a format string; errors.New avoids the go vet printf
		// warning that fmt.Errorf(msg) triggers for non-constant formats
		return nil, errors.New(msg)
	}
	expected.triggered = true
	expected.Unlock()
	return expected, expected.err
}
// Implement the "StmtExecContext" interface by delegating to the owning
// connection with the prepared statement's query.
func (stmt *statement) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	return stmt.conn.ExecContext(ctx, stmt.query, args)
}
// Implement the "StmtQueryContext" interface by delegating to the owning
// connection with the prepared statement's query.
func (stmt *statement) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	return stmt.conn.QueryContext(ctx, stmt.query, args)
}
// ExpectPing queues an expectation for a database Ping. It is effective
// only when ping monitoring was enabled; otherwise it logs and returns nil.
func (c *sqlmock) ExpectPing() *ExpectedPing {
	if !c.monitorPings {
		log.Println("ExpectPing will have no effect as monitoring pings is disabled. Use MonitorPingsOption to enable.")
		return nil
	}
	exp := &ExpectedPing{}
	c.expected = append(c.expected, exp)
	return exp
}
// Query meets http://golang.org/pkg/database/sql/driver/#Queryer
// Deprecated: Drivers should implement QueryerContext instead.
func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) {
	// wrap positional values with their 1-based ordinals
	named := make([]driver.NamedValue, len(args))
	for idx, val := range args {
		named[idx] = driver.NamedValue{Ordinal: idx + 1, Value: val}
	}
	ex, err := c.query(query, named)
	if ex != nil {
		// honor the expectation's configured delay even on error
		time.Sleep(ex.delay)
	}
	if err != nil {
		return nil, err
	}
	return ex.rows, nil
}
// query locates a pending *ExpectedQuery matching the SQL and args.
// In ordered mode the next unfulfilled expectation must be a Query; in
// unordered mode any pending Query whose SQL and args match is accepted.
// On success the expectation is marked triggered; its rows must be set.
func (c *sqlmock) query(query string, args []driver.NamedValue) (*ExpectedQuery, error) {
	var expected *ExpectedQuery
	var fulfilled int
	var ok bool
	for _, next := range c.expected {
		next.Lock()
		if next.fulfilled() {
			next.Unlock()
			fulfilled++
			continue
		}
		if c.ordered {
			if expected, ok = next.(*ExpectedQuery); ok {
				break // keep holding its lock; released via defer below
			}
			next.Unlock()
			return nil, fmt.Errorf("call to Query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
		}
		if qr, ok := next.(*ExpectedQuery); ok {
			if err := c.queryMatcher.Match(qr.expectSQL, query); err != nil {
				next.Unlock()
				continue
			}
			if err := qr.attemptArgMatch(args); err == nil {
				expected = qr
				break // matched; lock still held
			}
		}
		next.Unlock()
	}
	if expected == nil {
		msg := "call to Query '%s' with args %+v was not expected"
		if fulfilled == len(c.expected) {
			msg = "all expectations were already fulfilled, " + msg
		}
		return nil, fmt.Errorf(msg, query, args)
	}
	defer expected.Unlock()
	// re-validate SQL and args so the ordered path reports matcher errors
	if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
		return nil, fmt.Errorf("Query: %v", err)
	}
	if err := expected.argsMatches(args); err != nil {
		return nil, fmt.Errorf("Query '%s', arguments do not match: %s", query, err)
	}
	expected.triggered = true
	if expected.err != nil {
		return expected, expected.err // mocked to return error
	}
	if expected.rows == nil {
		return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
	}
	return expected, nil
}
// Exec meets http://golang.org/pkg/database/sql/driver/#Execer
// Deprecated: Drivers should implement ExecerContext instead.
func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) {
	// wrap positional values with their 1-based ordinals
	named := make([]driver.NamedValue, len(args))
	for idx, val := range args {
		named[idx] = driver.NamedValue{Ordinal: idx + 1, Value: val}
	}
	ex, err := c.exec(query, named)
	if ex != nil {
		// honor the expectation's configured delay even on error
		time.Sleep(ex.delay)
	}
	if err != nil {
		return nil, err
	}
	return ex.result, nil
}
// exec locates a pending *ExpectedExec matching the SQL and args.
// In ordered mode the next unfulfilled expectation must be an Exec; in
// unordered mode any pending Exec whose SQL and args match is accepted.
// On success the expectation is marked triggered; its result must be set.
func (c *sqlmock) exec(query string, args []driver.NamedValue) (*ExpectedExec, error) {
	var expected *ExpectedExec
	var fulfilled int
	var ok bool
	for _, next := range c.expected {
		next.Lock()
		if next.fulfilled() {
			next.Unlock()
			fulfilled++
			continue
		}
		if c.ordered {
			if expected, ok = next.(*ExpectedExec); ok {
				break // keep holding its lock; released via defer below
			}
			next.Unlock()
			return nil, fmt.Errorf("call to ExecQuery '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
		}
		if exec, ok := next.(*ExpectedExec); ok {
			if err := c.queryMatcher.Match(exec.expectSQL, query); err != nil {
				next.Unlock()
				continue
			}
			if err := exec.attemptArgMatch(args); err == nil {
				expected = exec
				break // matched; lock still held
			}
		}
		next.Unlock()
	}
	if expected == nil {
		msg := "call to ExecQuery '%s' with args %+v was not expected"
		if fulfilled == len(c.expected) {
			msg = "all expectations were already fulfilled, " + msg
		}
		return nil, fmt.Errorf(msg, query, args)
	}
	defer expected.Unlock()
	// re-validate SQL and args so the ordered path reports matcher errors
	if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
		return nil, fmt.Errorf("ExecQuery: %v", err)
	}
	if err := expected.argsMatches(args); err != nil {
		return nil, fmt.Errorf("ExecQuery '%s', arguments do not match: %s", query, err)
	}
	expected.triggered = true
	if expected.err != nil {
		return expected, expected.err // mocked to return error
	}
	if expected.result == nil {
		return nil, fmt.Errorf("ExecQuery '%s' with args %+v, must return a database/sql/driver.Result, but it was not set for expectation %T as %+v", query, args, expected, expected)
	}
	return expected, nil
}
// @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions)
// NewRowsWithColumnDefinition allows Rows to be created from a
// sql driver.Value slice with a definition of sql metadata.
// The mock's configured value converter is attached to the result.
func (c *sqlmock) NewRowsWithColumnDefinition(columns ...*Column) *Rows {
	r := NewRowsWithColumnDefinition(columns...)
	r.converter = c.converter
	return r
}
// NewColumn allows the creation of a Column that can be enhanced with
// metadata using OfType/Nullable/WithLength/WithPrecisionAndScale methods.
func (c *sqlmock) NewColumn(name string) *Column {
	return NewColumn(name)
}

View file

@ -0,0 +1,11 @@
// +build go1.8,!go1.9
package sqlmock
import "database/sql/driver"
// CheckNamedValue meets https://golang.org/pkg/database/sql/driver/#NamedValueChecker
// All values are passed through the mock's configured value converter.
func (c *sqlmock) CheckNamedValue(nv *driver.NamedValue) (err error) {
	nv.Value, err = c.converter.ConvertValue(nv.Value)
	return err
}

19
vendor/github.com/DATA-DOG/go-sqlmock/sqlmock_go19.go generated vendored Normal file
View file

@ -0,0 +1,19 @@
// +build go1.9
package sqlmock
import (
"database/sql"
"database/sql/driver"
)
// CheckNamedValue meets https://golang.org/pkg/database/sql/driver/#NamedValueChecker
// sql.Out output parameters are accepted as-is (no conversion); all other
// values are passed through the mock's configured value converter.
func (c *sqlmock) CheckNamedValue(nv *driver.NamedValue) (err error) {
	switch nv.Value.(type) {
	case sql.Out:
		return nil
	default:
		nv.Value, err = c.converter.ConvertValue(nv.Value)
		return err
	}
}

16
vendor/github.com/DATA-DOG/go-sqlmock/statement.go generated vendored Normal file
View file

@ -0,0 +1,16 @@
package sqlmock
// statement is the driver.Stmt implementation backed by the mock
// connection and the matched *ExpectedPrepare expectation.
type statement struct {
	conn *sqlmock
	ex *ExpectedPrepare
	query string
}
// Close marks the prepared statement's expectation as closed and returns
// any close error configured on it.
func (stmt *statement) Close() error {
	stmt.ex.wasClosed = true
	return stmt.ex.closeErr
}
// NumInput returns -1, meaning the driver does not know the number of
// placeholders, so database/sql skips its argument-count sanity check.
func (stmt *statement) NumInput() int {
	return -1
}

View file

@ -0,0 +1,17 @@
// +build !go1.8
package sqlmock
import (
"database/sql/driver"
)
// Exec delegates to the owning connection's Exec with the prepared query.
// Deprecated: Drivers should implement ExecerContext instead.
func (stmt *statement) Exec(args []driver.Value) (driver.Result, error) {
	return stmt.conn.Exec(stmt.query, args)
}
// Query delegates to the owning connection's Query with the prepared query.
// Deprecated: Drivers should implement StmtQueryContext instead (or additionally).
func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) {
	return stmt.conn.Query(stmt.query, args)
}

View file

@ -0,0 +1,26 @@
// +build go1.8
package sqlmock
import (
"context"
"database/sql/driver"
)
// Exec delegates to ExecContext with a background context, converting the
// positional args to named values first.
// Deprecated: Drivers should implement ExecerContext instead.
func (stmt *statement) Exec(args []driver.Value) (driver.Result, error) {
	return stmt.conn.ExecContext(context.Background(), stmt.query, convertValueToNamedValue(args))
}
// Query delegates to QueryContext with a background context, converting the
// positional args to named values first.
// Deprecated: Drivers should implement StmtQueryContext instead (or additionally).
func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) {
	return stmt.conn.QueryContext(context.Background(), stmt.query, convertValueToNamedValue(args))
}
func convertValueToNamedValue(args []driver.Value) []driver.NamedValue {
namedArgs := make([]driver.NamedValue, len(args))
for i, v := range args {
namedArgs[i] = driver.NamedValue{Ordinal: i + 1, Value: v}
}
return namedArgs
}

24
vendor/github.com/alicebob/gopher-json/LICENSE generated vendored Normal file
View file

@ -0,0 +1,24 @@
This is free and unencumbered software released into the public domain.
Anyone is free to copy, modify, publish, use, compile, sell, or
distribute this software, either in source code form or as a compiled
binary, for any purpose, commercial or non-commercial, and by any
means.
In jurisdictions that recognize copyright laws, the author or authors
of this software dedicate any and all copyright interest in the
software to the public domain. We make this dedication for the benefit
of the public at large and to the detriment of our heirs and
successors. We intend this dedication to be an overt act of
relinquishment in perpetuity of all present and future rights to this
software under copyright law.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
For more information, please refer to <http://unlicense.org/>

7
vendor/github.com/alicebob/gopher-json/README.md generated vendored Normal file
View file

@ -0,0 +1,7 @@
# gopher-json [![GoDoc](https://godoc.org/layeh.com/gopher-json?status.svg)](https://godoc.org/layeh.com/gopher-json)
Package json is a simple JSON encoder/decoder for [gopher-lua](https://github.com/yuin/gopher-lua).
## License
Public domain

33
vendor/github.com/alicebob/gopher-json/doc.go generated vendored Normal file
View file

@ -0,0 +1,33 @@
// Package json is a simple JSON encoder/decoder for gopher-lua.
//
// Documentation
//
// The following functions are exposed by the library:
// decode(string): Decodes a JSON string. Returns nil and an error string if
// the string could not be decoded.
// encode(value): Encodes a value into a JSON string. Returns nil and an error
// string if the value could not be encoded.
//
// The following types are supported:
//
// Lua | JSON
// ---------+-----
// nil | null
// number | number
// string | string
// table | object: when table is non-empty and has only string keys
// | array: when table is empty, or has only sequential numeric keys
// | starting from 1
//
// Attempting to encode any other Lua type will result in an error.
//
// Example
//
// Below is an example usage of the library:
// import (
// luajson "layeh.com/gopher-json"
// )
//
// L := lua.NewState()
// luajson.Preload(s)
package json

189
vendor/github.com/alicebob/gopher-json/json.go generated vendored Normal file
View file

@ -0,0 +1,189 @@
package json
import (
"encoding/json"
"errors"
"github.com/yuin/gopher-lua"
)
// Preload adds json to the given Lua state's package.preload table. After it
// has been preloaded, it can be loaded using require:
//
//  local json = require("json")
func Preload(L *lua.LState) {
	L.PreloadModule("json", Loader)
}
// Loader is the module loader function. It builds the module table from the
// api function map and pushes it as the single return value.
func Loader(L *lua.LState) int {
	t := L.NewTable()
	L.SetFuncs(t, api)
	L.Push(t)
	return 1
}
// api maps the module's exported Lua function names to their Go handlers.
var api = map[string]lua.LGFunction{
	"decode": apiDecode,
	"encode": apiEncode,
}
// apiDecode is the Lua-facing decode(string) function: it requires exactly
// one argument; on success it pushes the decoded value, on failure it
// pushes nil plus the error string.
func apiDecode(L *lua.LState) int {
	if L.GetTop() != 1 {
		L.Error(lua.LString("bad argument #1 to decode"), 1)
		return 0
	}
	str := L.CheckString(1)
	value, err := Decode(L, []byte(str))
	if err != nil {
		L.Push(lua.LNil)
		L.Push(lua.LString(err.Error()))
		return 2
	}
	L.Push(value)
	return 1
}
// apiEncode is the Lua-facing encode(value) function: it requires exactly
// one argument; on success it pushes the JSON string, on failure it pushes
// nil plus the error string.
func apiEncode(L *lua.LState) int {
	if L.GetTop() != 1 {
		L.Error(lua.LString("bad argument #1 to encode"), 1)
		return 0
	}
	value := L.CheckAny(1)
	data, err := Encode(value)
	if err != nil {
		L.Push(lua.LNil)
		L.Push(lua.LString(err.Error()))
		return 2
	}
	L.Push(lua.LString(string(data)))
	return 1
}
// Sentinel errors returned while encoding Lua tables to JSON.
var (
	errNested = errors.New("cannot encode recursively nested tables to JSON")
	errSparseArray = errors.New("cannot encode sparse array")
	errInvalidKeys = errors.New("cannot encode mixed or invalid key types")
)
// invalidTypeError is the Lua value type that could not be encoded to JSON.
type invalidTypeError lua.LValueType

// Error implements the error interface.
func (i invalidTypeError) Error() string {
	return `cannot encode ` + lua.LValueType(i).String() + ` to JSON`
}
// Encode returns the JSON encoding of value.
// A fresh visited set is created per call for recursion detection.
func Encode(value lua.LValue) ([]byte, error) {
	return json.Marshal(jsonValue{
		LValue: value,
		visited: make(map[*lua.LTable]bool),
	})
}
// jsonValue wraps a lua.LValue for JSON marshaling; visited tracks the
// tables already seen so recursive nesting can be rejected.
type jsonValue struct {
	lua.LValue
	visited map[*lua.LTable]bool
}
// MarshalJSON implements json.Marshaler for wrapped Lua values.
// Tables encode as arrays when keyed sequentially from 1, as objects when
// string-keyed, and as `[]` when empty; mixed keys are rejected.
func (j jsonValue) MarshalJSON() (data []byte, err error) {
	switch converted := j.LValue.(type) {
	case lua.LBool:
		data, err = json.Marshal(bool(converted))
	case lua.LNumber:
		data, err = json.Marshal(float64(converted))
	case *lua.LNilType:
		data = []byte(`null`)
	case lua.LString:
		data, err = json.Marshal(string(converted))
	case *lua.LTable:
		if j.visited[converted] {
			return nil, errNested
		}
		j.visited[converted] = true
		// inspect the first key to decide between array and object encoding
		key, value := converted.Next(lua.LNil)
		switch key.Type() {
		case lua.LTNil: // empty table
			data = []byte(`[]`)
		case lua.LTNumber:
			arr := make([]jsonValue, 0, converted.Len())
			expectedKey := lua.LNumber(1)
			for key != lua.LNil {
				if key.Type() != lua.LTNumber {
					err = errInvalidKeys
					return
				}
				if expectedKey != key {
					// numeric keys must be sequential starting at 1
					err = errSparseArray
					return
				}
				arr = append(arr, jsonValue{value, j.visited})
				expectedKey++
				key, value = converted.Next(key)
			}
			data, err = json.Marshal(arr)
		case lua.LTString:
			obj := make(map[string]jsonValue)
			for key != lua.LNil {
				if key.Type() != lua.LTString {
					err = errInvalidKeys
					return
				}
				obj[key.String()] = jsonValue{value, j.visited}
				key, value = converted.Next(key)
			}
			data, err = json.Marshal(obj)
		default:
			err = errInvalidKeys
		}
	default:
		err = invalidTypeError(j.LValue.Type())
	}
	return
}
// Decode converts the JSON encoded data to Lua values.
// It returns the json.Unmarshal error unchanged when data is not valid JSON.
func Decode(L *lua.LState, data []byte) (lua.LValue, error) {
	var value interface{}
	err := json.Unmarshal(data, &value)
	if err != nil {
		return nil, err
	}
	return DecodeValue(L, value), nil
}
// DecodeValue converts the value to a Lua value.
//
// This function only converts values that the encoding/json package decodes to.
// All other values will return lua.LNil.
func DecodeValue(L *lua.LState, value interface{}) lua.LValue {
	switch converted := value.(type) {
	case bool:
		return lua.LBool(converted)
	case float64:
		return lua.LNumber(converted)
	case string:
		return lua.LString(converted)
	case json.Number:
		// json.Number is kept as a string to avoid float precision loss
		return lua.LString(converted)
	case []interface{}:
		arr := L.CreateTable(len(converted), 0)
		for _, item := range converted {
			arr.Append(DecodeValue(L, item))
		}
		return arr
	case map[string]interface{}:
		tbl := L.CreateTable(0, len(converted))
		for key, item := range converted {
			tbl.RawSetH(lua.LString(key), DecodeValue(L, item))
		}
		return tbl
	case nil:
		return lua.LNil
	}
	return lua.LNil
}

1
vendor/github.com/alicebob/miniredis/.gitignore generated vendored Normal file
View file

@ -0,0 +1 @@
/integration/redis_src/

13
vendor/github.com/alicebob/miniredis/.travis.yml generated vendored Normal file
View file

@ -0,0 +1,13 @@
language: go
before_script:
- (cd ./integration && ./get_redis.sh)
install: go get -t
script: make test testrace int
sudo: false
go:
- 1.11

21
vendor/github.com/alicebob/miniredis/LICENSE generated vendored Normal file
View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2014 Harmen
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

12
vendor/github.com/alicebob/miniredis/Makefile generated vendored Normal file
View file

@ -0,0 +1,12 @@
.PHONY: all test testrace int
all: test
test:
go test ./...
testrace:
go test -race ./...
int:
${MAKE} -C integration all

336
vendor/github.com/alicebob/miniredis/README.md generated vendored Normal file
View file

@ -0,0 +1,336 @@
# Miniredis
Pure Go Redis test server, used in Go unittests.
## About
Sometimes you want to test code which uses Redis, without making it a full-blown
integration test.
Miniredis implements (parts of) the Redis server, to be used in unittests. It
enables a simple, cheap, in-memory, Redis replacement, with a real TCP interface. Think of it as the Redis version of `net/http/httptest`.
It saves you from using mock code, and since the redis server lives in the
test process you can query for values directly, without going through the server
stack.
There are no dependencies on external binaries, so you can easily integrate it in automated build processes.
## Changelog
### 2.5.0
Added ZPopMin and ZPopMax
### v2.4.6
support for TIME (thanks @leon-barrett and @lirao)
support for ZREVRANGEBYLEX
fix for SINTER (thanks @robstein)
updates for latest redis
### 2.4.4
Fixed nil Lua return value (#43)
### 2.4.3
Fixed using Lua with authenticated redis.
### 2.4.2
Changed redigo import path.
### 2.4
Minor cleanups. Miniredis now requires Go >= 1.9 (only for the tests. If you don't run the tests you can use an older Go version).
### 2.3.1
Lua changes: added `cjson` library, and `redis.sha1hex()`.
### 2.3
Added the `EVAL`, `EVALSHA`, and `SCRIPT` commands. Uses a pure Go Lua interpreter. Please open an issue if there are problems with any Lua code.
### 2.2
Introduced `StartAddr()`.
### 2.1
Internal cleanups. No changes in functionality.
### 2.0
2.0.0 improves TTLs to be `time.Duration` values. `.Expire()` is removed and
replaced by `.TTL()`, which returns the TTL as a `time.Duration`.
This should be the change needed to upgrade:
1.0:
m.Expire() == 4
2.0:
m.TTL() == 4 * time.Second
Furthermore, `.SetTime()` is added to help with `EXPIREAT` commands, and `.FastForward()` is introduced to test keys expiration.
## Commands
Implemented commands:
- Connection (complete)
- AUTH -- see RequireAuth()
- ECHO
- PING
- SELECT
- QUIT
- Key
- DEL
- EXISTS
- EXPIRE
- EXPIREAT
- KEYS
- MOVE
- PERSIST
- PEXPIRE
- PEXPIREAT
- PTTL
- RENAME
- RENAMENX
- RANDOMKEY -- call math.rand.Seed(...) once before using.
- TTL
- TYPE
- SCAN
- Transactions (complete)
- DISCARD
- EXEC
- MULTI
- UNWATCH
- WATCH
- Server
- DBSIZE
- FLUSHALL
- FLUSHDB
- TIME -- returns time.Now() or value set by SetTime()
- String keys (complete)
- APPEND
- BITCOUNT
- BITOP
- BITPOS
- DECR
- DECRBY
- GET
- GETBIT
- GETRANGE
- GETSET
- INCR
- INCRBY
- INCRBYFLOAT
- MGET
- MSET
- MSETNX
- PSETEX
- SET
- SETBIT
- SETEX
- SETNX
- SETRANGE
- STRLEN
- Hash keys (complete)
- HDEL
- HEXISTS
- HGET
- HGETALL
- HINCRBY
- HINCRBYFLOAT
- HKEYS
- HLEN
- HMGET
- HMSET
- HSET
- HSETNX
- HVALS
- HSCAN
- List keys (complete)
- BLPOP
- BRPOP
- BRPOPLPUSH
- LINDEX
- LINSERT
- LLEN
- LPOP
- LPUSH
- LPUSHX
- LRANGE
- LREM
- LSET
- LTRIM
- RPOP
- RPOPLPUSH
- RPUSH
- RPUSHX
- Set keys (complete)
- SADD
- SCARD
- SDIFF
- SDIFFSTORE
- SINTER
- SINTERSTORE
- SISMEMBER
- SMEMBERS
- SMOVE
- SPOP -- call math.rand.Seed(...) once before using.
- SRANDMEMBER -- call math.rand.Seed(...) once before using.
- SREM
- SUNION
- SUNIONSTORE
- SSCAN
- Sorted Set keys (complete)
- ZADD
- ZCARD
- ZCOUNT
- ZINCRBY
- ZINTERSTORE
- ZLEXCOUNT
- ZPOPMIN
- ZPOPMAX
- ZRANGE
- ZRANGEBYLEX
- ZRANGEBYSCORE
- ZRANK
- ZREM
- ZREMRANGEBYLEX
- ZREMRANGEBYRANK
- ZREMRANGEBYSCORE
- ZREVRANGE
- ZREVRANGEBYLEX
- ZREVRANGEBYSCORE
- ZREVRANK
- ZSCORE
- ZUNIONSTORE
- ZSCAN
- Scripting
- EVAL
- EVALSHA
- SCRIPT LOAD
- SCRIPT EXISTS
- SCRIPT FLUSH
## TTLs, key expiration, and time
Since miniredis is intended to be used in unittests TTLs don't decrease
automatically. You can use `TTL()` to get the TTL (as a time.Duration) of a
key. It will return 0 when no TTL is set.
`m.FastForward(d)` can be used to decrement all TTLs. All TTLs which become <=
0 will be removed.
EXPIREAT and PEXPIREAT values will be
converted to a duration. For that you can either set m.SetTime(t) to use that
time as the base for the (P)EXPIREAT conversion, or don't call SetTime(), in
which case time.Now() will be used.
SetTime() also sets the value returned by TIME, which defaults to time.Now().
It is not updated by FastForward, only by SetTime.
## Example
``` Go
func TestSomething(t *testing.T) {
s, err := miniredis.Run()
if err != nil {
panic(err)
}
defer s.Close()
// Optionally set some keys your code expects:
s.Set("foo", "bar")
s.HSet("some", "other", "key")
// Run your code and see if it behaves.
// An example using the redigo library from "github.com/gomodule/redigo/redis":
c, err := redis.Dial("tcp", s.Addr())
_, err = c.Do("SET", "foo", "bar")
// Optionally check values in redis...
if got, err := s.Get("foo"); err != nil || got != "bar" {
t.Error("'foo' has the wrong value")
}
// ... or use a helper for that:
s.CheckGet(t, "foo", "bar")
// TTL and expiration:
s.Set("foo", "bar")
s.SetTTL("foo", 10*time.Second)
s.FastForward(11 * time.Second)
if s.Exists("foo") {
t.Fatal("'foo' should not have existed anymore")
}
}
```
## Not supported
Commands which will probably not be implemented:
- CLUSTER (all)
- ~~CLUSTER *~~
- ~~READONLY~~
- ~~READWRITE~~
- GEO (all) -- unless someone needs these
- ~~GEOADD~~
- ~~GEODIST~~
- ~~GEOHASH~~
- ~~GEOPOS~~
- ~~GEORADIUS~~
- ~~GEORADIUSBYMEMBER~~
- HyperLogLog (all) -- unless someone needs these
- ~~PFADD~~
- ~~PFCOUNT~~
- ~~PFMERGE~~
- Key
- ~~DUMP~~
- ~~MIGRATE~~
- ~~OBJECT~~
- ~~RESTORE~~
- ~~WAIT~~
- Pub/Sub (all)
- ~~PSUBSCRIBE~~
- ~~PUBLISH~~
- ~~PUBSUB~~
- ~~PUNSUBSCRIBE~~
- ~~SUBSCRIBE~~
- ~~UNSUBSCRIBE~~
- Scripting
- ~~SCRIPT DEBUG~~
- ~~SCRIPT KILL~~
- Server
- ~~BGSAVE~~
- ~~BGWRITEAOF~~
- ~~CLIENT *~~
- ~~COMMAND *~~
- ~~CONFIG *~~
- ~~DEBUG *~~
- ~~INFO~~
- ~~LASTSAVE~~
- ~~MONITOR~~
- ~~ROLE~~
- ~~SAVE~~
- ~~SHUTDOWN~~
- ~~SLAVEOF~~
- ~~SLOWLOG~~
- ~~SYNC~~
## &c.
Tests are run against Redis 5.0.3. The [./integration](./integration/) subdir
compares miniredis against a real redis instance.
[![Build Status](https://travis-ci.org/alicebob/miniredis.svg?branch=master)](https://travis-ci.org/alicebob/miniredis)
[![GoDoc](https://godoc.org/github.com/alicebob/miniredis?status.svg)](https://godoc.org/github.com/alicebob/miniredis)

68
vendor/github.com/alicebob/miniredis/check.go generated vendored Normal file
View file

@ -0,0 +1,68 @@
package miniredis
// 'Fail' methods.
import (
"fmt"
"path/filepath"
"reflect"
"runtime"
"sort"
)
// T is the minimal testing interface the Check helpers need; *testing.T
// implements it.
type T interface {
	Fail()
}
// CheckGet fails the test (via t.Fail) unless there is a string key with
// the expected value. Normal use case is `m.CheckGet(t, "username", "theking")`.
func (m *Miniredis) CheckGet(t T, key, expected string) {
	found, err := m.Get(key)
	if err != nil {
		lError(t, "GET error, key %#v: %v", key, err)
		return
	}
	if found != expected {
		lError(t, "GET error, key %#v: Expected %#v, got %#v", key, expected, found)
		return
	}
}
// CheckList fails the test unless there is a list key holding exactly the
// expected values, in order.
// Normal use case is `m.CheckList(t, "favorite_colors", "red", "green", "infrared")`.
func (m *Miniredis) CheckList(t T, key string, expected ...string) {
	found, err := m.List(key)
	if err != nil {
		lError(t, "List error, key %#v: %v", key, err)
		return
	}
	if !reflect.DeepEqual(expected, found) {
		lError(t, "List error, key %#v: Expected %#v, got %#v", key, expected, found)
		return
	}
}
// CheckSet fails the test unless there is a set key holding exactly the
// expected members. Order does not matter: the expected values are sorted
// before comparison (Members presumably returns a sorted slice — matches
// the sort below; verify against db implementation).
// Normal use case is `m.CheckSet(t, "visited", "Rome", "Stockholm", "Dublin")`.
func (m *Miniredis) CheckSet(t T, key string, expected ...string) {
	found, err := m.Members(key)
	if err != nil {
		lError(t, "Set error, key %#v: %v", key, err)
		return
	}
	sort.Strings(expected)
	if !reflect.DeepEqual(expected, found) {
		lError(t, "Set error, key %#v: Expected %#v, got %#v", key, expected, found)
		return
	}
}
// lError prints a formatted failure message, prefixed with the file:line of
// the test that called the Check* helper, and marks the test as failed.
func lError(t T, format string, args ...interface{}) {
	// Caller(2): skip lError itself and the Check* helper so the reported
	// location is the user's test code.
	_, file, line, _ := runtime.Caller(2)
	msg := fmt.Sprintf(format, args...)
	fmt.Printf("%s:%d: %s\n", filepath.Base(file), line, msg)
	t.Fail()
}

96
vendor/github.com/alicebob/miniredis/cmd_connection.go generated vendored Normal file
View file

@ -0,0 +1,96 @@
// Commands from https://redis.io/commands#connection
package miniredis
import (
"strconv"
"github.com/alicebob/miniredis/server"
)
// commandsConnection registers the connection-level commands
// (https://redis.io/commands#connection) with the server.
func commandsConnection(m *Miniredis) {
	m.srv.Register("AUTH", m.cmdAuth)
	m.srv.Register("ECHO", m.cmdEcho)
	m.srv.Register("PING", m.cmdPing)
	m.srv.Register("SELECT", m.cmdSelect)
	m.srv.Register("QUIT", m.cmdQuit)
}
// PING
// cmdPing replies with the inline string "PONG". Any arguments are ignored.
func (m *Miniredis) cmdPing(c *server.Peer, cmd string, args []string) {
	if !m.handleAuth(c) {
		return
	}
	c.WriteInline("PONG")
}
// AUTH
// cmdAuth authenticates the peer against the server password. It is an
// error to AUTH when no password is configured, or when the password does
// not match; on success the peer is flagged authenticated.
func (m *Miniredis) cmdAuth(c *server.Peer, cmd string, args []string) {
	if len(args) != 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	pw := args[0]

	m.Lock()
	defer m.Unlock()

	if m.password == "" {
		c.WriteError("ERR Client sent AUTH, but no password is set")
		return
	}
	if m.password != pw {
		c.WriteError("ERR invalid password")
		return
	}

	setAuthenticated(c)
	c.WriteOK()
}
// ECHO
// cmdEcho replies with its single argument as a bulk string.
func (m *Miniredis) cmdEcho(c *server.Peer, cmd string, args []string) {
	if len(args) != 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	msg := args[0]
	c.WriteBulk(msg)
}
// SELECT
// cmdSelect switches this connection's context to another database.
// Fix: a non-numeric index used to be silently treated as DB 0; it is now
// rejected with an integer-parse error, like other handlers in this package.
// NOTE(review): negative/huge indexes are still accepted — m.db presumably
// creates databases on demand; confirm whether a range check is wanted.
func (m *Miniredis) cmdSelect(c *server.Peer, cmd string, args []string) {
	if len(args) != 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	id, err := strconv.Atoi(args[0])
	if err != nil {
		setDirty(c)
		c.WriteError(msgInvalidInt)
		return
	}

	m.Lock()
	defer m.Unlock()

	ctx := getCtx(c)
	ctx.selectedDB = id
	c.WriteOK()
}
// QUIT
// cmdQuit acknowledges with OK and closes the peer connection.
func (m *Miniredis) cmdQuit(c *server.Peer, cmd string, args []string) {
	// QUIT isn't transactionfied and accepts any arguments.
	c.WriteOK()
	c.Close()
}

479
vendor/github.com/alicebob/miniredis/cmd_generic.go generated vendored Normal file
View file

@ -0,0 +1,479 @@
// Commands from https://redis.io/commands#generic
package miniredis
import (
"math/rand"
"strconv"
"strings"
"time"
"github.com/alicebob/miniredis/server"
)
// commandsGeneric registers the key-space commands
// (https://redis.io/commands#generic): EXPIRE, TTL, PERSIST, &c.
// Commented-out names are intentionally unimplemented (see README).
func commandsGeneric(m *Miniredis) {
	m.srv.Register("DEL", m.cmdDel)
	// DUMP
	m.srv.Register("EXISTS", m.cmdExists)
	m.srv.Register("EXPIRE", makeCmdExpire(m, false, time.Second))
	m.srv.Register("EXPIREAT", makeCmdExpire(m, true, time.Second))
	m.srv.Register("KEYS", m.cmdKeys)
	// MIGRATE
	m.srv.Register("MOVE", m.cmdMove)
	// OBJECT
	m.srv.Register("PERSIST", m.cmdPersist)
	m.srv.Register("PEXPIRE", makeCmdExpire(m, false, time.Millisecond))
	m.srv.Register("PEXPIREAT", makeCmdExpire(m, true, time.Millisecond))
	m.srv.Register("PTTL", m.cmdPTTL)
	m.srv.Register("RANDOMKEY", m.cmdRandomkey)
	m.srv.Register("RENAME", m.cmdRename)
	m.srv.Register("RENAMENX", m.cmdRenamenx)
	// RESTORE
	// SORT
	m.srv.Register("TTL", m.cmdTTL)
	m.srv.Register("TYPE", m.cmdType)
	m.srv.Register("SCAN", m.cmdScan)
}
// makeCmdExpire builds the handler shared by EXPIRE, PEXPIRE, EXPIREAT, and
// PEXPIREAT. d is the time unit (time.Second or time.Millisecond). If unix
// is set, the numeric argument is an absolute unix timestamp in that unit
// and is converted to a duration relative to "now".
// Replies 0 when the key does not exist, 1 when the TTL was set.
func makeCmdExpire(m *Miniredis, unix bool, d time.Duration) func(*server.Peer, string, []string) {
	return func(c *server.Peer, cmd string, args []string) {
		if len(args) != 2 {
			setDirty(c)
			c.WriteError(errWrongNumber(cmd))
			return
		}
		if !m.handleAuth(c) {
			return
		}
		key := args[0]
		value := args[1]
		i, err := strconv.Atoi(value)
		if err != nil {
			setDirty(c)
			c.WriteError(msgInvalidInt)
			return
		}

		withTx(m, c, func(c *server.Peer, ctx *connCtx) {
			db := m.db(ctx.selectedDB)

			// Key must be present.
			if _, ok := db.keys[key]; !ok {
				c.WriteInt(0)
				return
			}
			if unix {
				var ts time.Time
				switch d {
				case time.Millisecond:
					// Split i into whole seconds plus remainder nanoseconds.
					ts = time.Unix(int64(i/1000), 1000000*int64(i%1000))
				case time.Second:
					ts = time.Unix(int64(i), 0)
				default:
					// Only the two units above are ever registered.
					panic("invalid time unit (d). Fixme!")
				}
				// m.now is presumably a test-controlled clock override
				// (zero means: use the real wall clock).
				now := m.now
				if now.IsZero() {
					now = time.Now().UTC()
				}
				db.ttl[key] = ts.Sub(now)
			} else {
				db.ttl[key] = time.Duration(i) * d
			}
			db.keyVersion[key]++
			// An already-expired timestamp deletes the key right away.
			db.checkTTL(key)
			c.WriteInt(1)
		})
	}
}
// TTL
// cmdTTL replies with the remaining time to live in whole seconds,
// -2 when the key does not exist, and -1 when it has no expiration.
func (m *Miniredis) cmdTTL(c *server.Peer, cmd string, args []string) {
	if len(args) != 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key := args[0]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if _, ok := db.keys[key]; !ok {
			// No such key
			c.WriteInt(-2)
			return
		}

		v, ok := db.ttl[key]
		if !ok {
			// no expire value
			c.WriteInt(-1)
			return
		}
		c.WriteInt(int(v.Seconds()))
	})
}
// PTTL
// cmdPTTL is TTL with a millisecond reply: remaining TTL in ms,
// -2 when the key does not exist, -1 when it has no expiration.
func (m *Miniredis) cmdPTTL(c *server.Peer, cmd string, args []string) {
	if len(args) != 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key := args[0]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if _, ok := db.keys[key]; !ok {
			// no such key
			c.WriteInt(-2)
			return
		}

		v, ok := db.ttl[key]
		if !ok {
			// no expire value
			c.WriteInt(-1)
			return
		}
		// nanoseconds -> milliseconds
		c.WriteInt(int(v.Nanoseconds() / 1000000))
	})
}
// PERSIST
// cmdPersist removes the expiration from a key. Replies 1 when a TTL was
// removed, 0 when the key is missing or had no TTL.
func (m *Miniredis) cmdPersist(c *server.Peer, cmd string, args []string) {
	if len(args) != 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key := args[0]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if _, ok := db.keys[key]; !ok {
			// no such key
			c.WriteInt(0)
			return
		}

		if _, ok := db.ttl[key]; !ok {
			// no expire value
			c.WriteInt(0)
			return
		}
		delete(db.ttl, key)
		db.keyVersion[key]++
		c.WriteInt(1)
	})
}
// DEL
// cmdDel deletes all given keys and replies with the number of keys that
// actually existed. Deleting a key also drops its TTL.
func (m *Miniredis) cmdDel(c *server.Peer, cmd string, args []string) {
	if !m.handleAuth(c) {
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		count := 0
		for _, key := range args {
			if db.exists(key) {
				count++
			}
			db.del(key, true) // delete expire
		}
		c.WriteInt(count)
	})
}
// TYPE
// cmdType replies with the type of the key ("string", "hash", "list", ...)
// as an inline string, or "none" when the key does not exist.
// Fix: the arity error used to be the ad-hoc string "usage error"; it now
// uses errWrongNumber(cmd) like every other handler in this package.
func (m *Miniredis) cmdType(c *server.Peer, cmd string, args []string) {
	if len(args) != 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key := args[0]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		t, ok := db.keys[key]
		if !ok {
			c.WriteInline("none")
			return
		}
		c.WriteInline(t)
	})
}
// EXISTS
// cmdExists replies with how many of the given keys exist. A key repeated
// in the arguments is counted once per occurrence, as in real Redis.
func (m *Miniredis) cmdExists(c *server.Peer, cmd string, args []string) {
	if len(args) < 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		found := 0
		for _, k := range args {
			if db.exists(k) {
				found++
			}
		}
		c.WriteInt(found)
	})
}
// MOVE
// cmdMove moves a key from the current database to another one.
// Replies 1 on success, 0 when the move failed (e.g. missing source key),
// and errors when source and destination databases are the same.
// Fixes: a non-numeric DB index used to be silently treated as 0 and is
// now rejected; the inner shadowing of targetDB is renamed for clarity.
func (m *Miniredis) cmdMove(c *server.Peer, cmd string, args []string) {
	if len(args) != 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key := args[0]
	targetDB, err := strconv.Atoi(args[1])
	if err != nil {
		setDirty(c)
		c.WriteError(msgInvalidInt)
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		if ctx.selectedDB == targetDB {
			c.WriteError("ERR source and destination objects are the same")
			return
		}
		db := m.db(ctx.selectedDB)
		target := m.db(targetDB)
		if !db.move(key, target) {
			c.WriteInt(0)
			return
		}
		c.WriteInt(1)
	})
}
// KEYS
// cmdKeys replies with all keys matching the given glob-style pattern.
// Like real Redis KEYS this scans the whole keyspace in one go.
func (m *Miniredis) cmdKeys(c *server.Peer, cmd string, args []string) {
	if len(args) != 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key := args[0]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		keys := matchKeys(db.allKeys(), key)
		c.WriteLen(len(keys))
		for _, s := range keys {
			c.WriteBulk(s)
		}
	})
}
// RANDOMKEY
// cmdRandomkey replies with a uniformly random existing key, or nil when
// the database is empty. Uses the global math/rand source; map iteration
// order is already random, but rand.Intn makes the pick explicit.
func (m *Miniredis) cmdRandomkey(c *server.Peer, cmd string, args []string) {
	if len(args) != 0 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if len(db.keys) == 0 {
			c.WriteNull()
			return
		}
		// Walk the map until the nr-th entry is reached.
		nr := rand.Intn(len(db.keys))
		for k := range db.keys {
			if nr == 0 {
				c.WriteBulk(k)
				return
			}
			nr--
		}
	})
}
// RENAME
// cmdRename renames a key, overwriting the destination if it exists.
// Errors when the source key does not exist.
func (m *Miniredis) cmdRename(c *server.Peer, cmd string, args []string) {
	if len(args) != 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	from, to := args[0], args[1]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if !db.exists(from) {
			c.WriteError(msgKeyNotFound)
			return
		}

		db.rename(from, to)
		c.WriteOK()
	})
}
// RENAMENX
// cmdRenamenx renames a key only when the destination does not yet exist.
// Replies 1 when renamed, 0 when the destination already exists, and
// errors when the source key is missing.
func (m *Miniredis) cmdRenamenx(c *server.Peer, cmd string, args []string) {
	if len(args) != 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	from, to := args[0], args[1]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if !db.exists(from) {
			c.WriteError(msgKeyNotFound)
			return
		}

		if db.exists(to) {
			c.WriteInt(0)
			return
		}

		db.rename(from, to)
		c.WriteInt(1)
	})
}
// SCAN
// cmdScan implements a simplified SCAN: all (matched) keys are returned in
// a single page, so the reply always carries next-cursor "0". Supported
// options: COUNT (validated, then ignored) and MATCH (glob pattern).
func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) {
	if len(args) < 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	cursor, err := strconv.Atoi(args[0])
	if err != nil {
		setDirty(c)
		c.WriteError(msgInvalidCursor)
		return
	}
	args = args[1:]

	// MATCH and COUNT options
	var withMatch bool
	var match string
	for len(args) > 0 {
		if strings.ToLower(args[0]) == "count" {
			// we do nothing with count, but it must parse as an int
			if len(args) < 2 {
				setDirty(c)
				c.WriteError(msgSyntaxError)
				return
			}
			if _, err := strconv.Atoi(args[1]); err != nil {
				setDirty(c)
				c.WriteError(msgInvalidInt)
				return
			}
			args = args[2:]
			continue
		}
		if strings.ToLower(args[0]) == "match" {
			if len(args) < 2 {
				setDirty(c)
				c.WriteError(msgSyntaxError)
				return
			}
			withMatch = true
			match, args = args[1], args[2:]
			continue
		}
		setDirty(c)
		c.WriteError(msgSyntaxError)
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)
		// We return _all_ (matched) keys every time.

		if cursor != 0 {
			// Invalid cursor: report the iteration as finished and empty.
			c.WriteLen(2)
			c.WriteBulk("0") // no next cursor
			c.WriteLen(0)    // no elements
			return
		}

		keys := db.allKeys()
		if withMatch {
			keys = matchKeys(keys, match)
		}

		c.WriteLen(2)
		c.WriteBulk("0") // no next cursor
		c.WriteLen(len(keys))
		for _, k := range keys {
			c.WriteBulk(k)
		}
	})
}

571
vendor/github.com/alicebob/miniredis/cmd_hash.go generated vendored Normal file
View file

@ -0,0 +1,571 @@
// Commands from https://redis.io/commands#hash
package miniredis
import (
"strconv"
"strings"
"github.com/alicebob/miniredis/server"
)
// commandsHash registers all hash-value commands
// (https://redis.io/commands#hash) with the server.
func commandsHash(m *Miniredis) {
	m.srv.Register("HDEL", m.cmdHdel)
	m.srv.Register("HEXISTS", m.cmdHexists)
	m.srv.Register("HGET", m.cmdHget)
	m.srv.Register("HGETALL", m.cmdHgetall)
	m.srv.Register("HINCRBY", m.cmdHincrby)
	m.srv.Register("HINCRBYFLOAT", m.cmdHincrbyfloat)
	m.srv.Register("HKEYS", m.cmdHkeys)
	m.srv.Register("HLEN", m.cmdHlen)
	m.srv.Register("HMGET", m.cmdHmget)
	m.srv.Register("HMSET", m.cmdHmset)
	m.srv.Register("HSET", m.cmdHset)
	m.srv.Register("HSETNX", m.cmdHsetnx)
	m.srv.Register("HVALS", m.cmdHvals)
	m.srv.Register("HSCAN", m.cmdHscan)
}
// HSET
// cmdHset sets a single field in a hash. Per the Redis HSET contract the
// reply is 1 when the field was created and 0 when an existing field was
// updated — so hashSet presumably reports "field already existed";
// confirm against the db implementation.
func (m *Miniredis) cmdHset(c *server.Peer, cmd string, args []string) {
	if len(args) != 3 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key, field, value := args[0], args[1], args[2]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if t, ok := db.keys[key]; ok && t != "hash" {
			c.WriteError(msgWrongType)
			return
		}

		if db.hashSet(key, field, value) {
			c.WriteInt(0)
		} else {
			c.WriteInt(1)
		}
	})
}
// HSETNX
// cmdHsetnx sets a hash field only if it does not exist yet. Replies 1
// when the field was set, 0 when it already existed. Creates the hash key
// (and its type entry) on demand.
func (m *Miniredis) cmdHsetnx(c *server.Peer, cmd string, args []string) {
	if len(args) != 3 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key, field, value := args[0], args[1], args[2]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if t, ok := db.keys[key]; ok && t != "hash" {
			c.WriteError(msgWrongType)
			return
		}

		if _, ok := db.hashKeys[key]; !ok {
			db.hashKeys[key] = map[string]string{}
			db.keys[key] = "hash"
		}
		_, ok := db.hashKeys[key][field]
		if ok {
			c.WriteInt(0)
			return
		}
		db.hashKeys[key][field] = value
		db.keyVersion[key]++
		c.WriteInt(1)
	})
}
// HMSET
// cmdHmset sets multiple field/value pairs on a hash and replies OK.
// The argument list after the key must have even length.
func (m *Miniredis) cmdHmset(c *server.Peer, cmd string, args []string) {
	if len(args) < 3 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key, args := args[0], args[1:]
	if len(args)%2 != 0 {
		setDirty(c)
		// non-default error message, matching real Redis for odd pairs
		c.WriteError("ERR wrong number of arguments for HMSET")
		return
	}
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if t, ok := db.keys[key]; ok && t != "hash" {
			c.WriteError(msgWrongType)
			return
		}

		for len(args) > 0 {
			field, value := args[0], args[1]
			args = args[2:]

			db.hashSet(key, field, value)
		}
		c.WriteOK()
	})
}
// HGET
// cmdHget replies with the value of a single hash field, or nil when the
// key or field does not exist.
func (m *Miniredis) cmdHget(c *server.Peer, cmd string, args []string) {
	if len(args) != 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key, field := args[0], args[1]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		t, ok := db.keys[key]
		if !ok {
			c.WriteNull()
			return
		}
		if t != "hash" {
			c.WriteError(msgWrongType)
			return
		}
		value, ok := db.hashKeys[key][field]
		if !ok {
			c.WriteNull()
			return
		}
		c.WriteBulk(value)
	})
}
// HDEL
// cmdHdel deletes one or more hash fields and replies with how many were
// actually removed. When the last field goes, the whole key is deleted.
// NOTE(review): field deletion does not bump db.keyVersion[key] here,
// unlike the other mutating hash commands — confirm whether WATCH-style
// version tracking should see HDEL.
func (m *Miniredis) cmdHdel(c *server.Peer, cmd string, args []string) {
	if len(args) < 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key, fields := args[0], args[1:]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		t, ok := db.keys[key]
		if !ok {
			// No key is zero deleted
			c.WriteInt(0)
			return
		}
		if t != "hash" {
			c.WriteError(msgWrongType)
			return
		}

		deleted := 0
		for _, f := range fields {
			_, ok := db.hashKeys[key][f]
			if !ok {
				continue
			}
			delete(db.hashKeys[key], f)
			deleted++
		}
		c.WriteInt(deleted)

		// Nothing left. Remove the whole key.
		if len(db.hashKeys[key]) == 0 {
			db.del(key, true)
		}
	})
}
// HEXISTS
// cmdHexists replies 1 when the hash field exists, 0 otherwise.
func (m *Miniredis) cmdHexists(c *server.Peer, cmd string, args []string) {
	if len(args) != 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key, field := args[0], args[1]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		t, ok := db.keys[key]
		if !ok {
			c.WriteInt(0)
			return
		}
		if t != "hash" {
			c.WriteError(msgWrongType)
			return
		}
		if _, ok := db.hashKeys[key][field]; !ok {
			c.WriteInt(0)
			return
		}
		c.WriteInt(1)
	})
}
// HGETALL
// cmdHgetall replies with a flat field1, value1, field2, value2, ... array
// for the whole hash (empty array for a missing key). Fields are listed in
// the order db.hashFields returns them.
func (m *Miniredis) cmdHgetall(c *server.Peer, cmd string, args []string) {
	if len(args) != 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key := args[0]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		t, ok := db.keys[key]
		if !ok {
			c.WriteLen(0)
			return
		}
		if t != "hash" {
			c.WriteError(msgWrongType)
			return
		}

		c.WriteLen(len(db.hashKeys[key]) * 2)
		for _, k := range db.hashFields(key) {
			c.WriteBulk(k)
			c.WriteBulk(db.hashGet(key, k))
		}
	})
}
// HKEYS
// cmdHkeys replies with all field names of the hash (empty array for a
// missing key).
func (m *Miniredis) cmdHkeys(c *server.Peer, cmd string, args []string) {
	if len(args) != 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key := args[0]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if !db.exists(key) {
			c.WriteLen(0)
			return
		}
		if db.t(key) != "hash" {
			c.WriteError(msgWrongType)
			return
		}

		fields := db.hashFields(key)
		c.WriteLen(len(fields))
		for _, f := range fields {
			c.WriteBulk(f)
		}
	})
}
// HVALS
// cmdHvals replies with all values of the hash (empty array for a missing
// key).
// Fix: the old implementation ranged directly over the Go map, so the
// value order changed between calls. It now walks db.hashFields — the same
// field enumeration HGETALL and HSCAN use — making replies deterministic.
func (m *Miniredis) cmdHvals(c *server.Peer, cmd string, args []string) {
	if len(args) != 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key := args[0]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		t, ok := db.keys[key]
		if !ok {
			c.WriteLen(0)
			return
		}
		if t != "hash" {
			c.WriteError(msgWrongType)
			return
		}

		c.WriteLen(len(db.hashKeys[key]))
		for _, f := range db.hashFields(key) {
			c.WriteBulk(db.hashGet(key, f))
		}
	})
}
// HLEN
// cmdHlen replies with the number of fields in the hash at key: 0 when the
// key is absent, a WRONGTYPE error for non-hash keys.
func (m *Miniredis) cmdHlen(c *server.Peer, cmd string, args []string) {
	if len(args) != 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key := args[0]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		kind, present := db.keys[key]
		switch {
		case !present:
			c.WriteInt(0)
		case kind != "hash":
			c.WriteError(msgWrongType)
		default:
			c.WriteInt(len(db.hashKeys[key]))
		}
	})
}
// HMGET
// cmdHmget replies with the values of the requested fields, in request
// order, with nil for every missing field (and for every field of a
// missing key).
func (m *Miniredis) cmdHmget(c *server.Peer, cmd string, args []string) {
	if len(args) < 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key := args[0]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if t, ok := db.keys[key]; ok && t != "hash" {
			c.WriteError(msgWrongType)
			return
		}

		f, ok := db.hashKeys[key]
		if !ok {
			// treat a missing key as an empty hash
			f = map[string]string{}
		}

		c.WriteLen(len(args) - 1)
		for _, k := range args[1:] {
			v, ok := f[k]
			if !ok {
				c.WriteNull()
				continue
			}
			c.WriteBulk(v)
		}
	})
}
// HINCRBY
// cmdHincrby adds an integer delta (may be negative) to a hash field and
// replies with the new value. hashIncr errors (e.g. non-integer current
// value) are forwarded to the client verbatim.
func (m *Miniredis) cmdHincrby(c *server.Peer, cmd string, args []string) {
	if len(args) != 3 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key, field, deltas := args[0], args[1], args[2]
	delta, err := strconv.Atoi(deltas)
	if err != nil {
		setDirty(c)
		c.WriteError(msgInvalidInt)
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if t, ok := db.keys[key]; ok && t != "hash" {
			c.WriteError(msgWrongType)
			return
		}

		v, err := db.hashIncr(key, field, delta)
		if err != nil {
			c.WriteError(err.Error())
			return
		}
		c.WriteInt(v)
	})
}
// HINCRBYFLOAT
// cmdHincrbyfloat adds a float delta to a hash field and replies with the
// new value as a bulk string (formatted via formatFloat).
func (m *Miniredis) cmdHincrbyfloat(c *server.Peer, cmd string, args []string) {
	if len(args) != 3 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key, field, deltas := args[0], args[1], args[2]
	delta, err := strconv.ParseFloat(deltas, 64)
	if err != nil {
		setDirty(c)
		c.WriteError(msgInvalidFloat)
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if t, ok := db.keys[key]; ok && t != "hash" {
			c.WriteError(msgWrongType)
			return
		}

		v, err := db.hashIncrfloat(key, field, delta)
		if err != nil {
			c.WriteError(err.Error())
			return
		}
		c.WriteBulk(formatFloat(v))
	})
}
// HSCAN
// cmdHscan implements a simplified HSCAN: all (matched) field/value pairs
// are returned in a single page with next-cursor "0". Supported options:
// COUNT (validated, then ignored) and MATCH (glob pattern on field names).
// Fix: the wrong-type reply used ErrWrongType.Error() while every sibling
// handler writes msgWrongType; both carry the same text, so this is a
// consistency-only change.
func (m *Miniredis) cmdHscan(c *server.Peer, cmd string, args []string) {
	if len(args) < 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key := args[0]
	cursor, err := strconv.Atoi(args[1])
	if err != nil {
		setDirty(c)
		c.WriteError(msgInvalidCursor)
		return
	}
	args = args[2:]

	// MATCH and COUNT options
	var withMatch bool
	var match string
	for len(args) > 0 {
		if strings.ToLower(args[0]) == "count" {
			// we do nothing with count, but it must parse as an int
			if len(args) < 2 {
				setDirty(c)
				c.WriteError(msgSyntaxError)
				return
			}
			_, err := strconv.Atoi(args[1])
			if err != nil {
				setDirty(c)
				c.WriteError(msgInvalidInt)
				return
			}
			args = args[2:]
			continue
		}
		if strings.ToLower(args[0]) == "match" {
			if len(args) < 2 {
				setDirty(c)
				c.WriteError(msgSyntaxError)
				return
			}
			withMatch = true
			match, args = args[1], args[2:]
			continue
		}
		setDirty(c)
		c.WriteError(msgSyntaxError)
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)
		// return _all_ (matched) keys every time

		if cursor != 0 {
			// Invalid cursor: report the iteration as finished and empty.
			c.WriteLen(2)
			c.WriteBulk("0") // no next cursor
			c.WriteLen(0)    // no elements
			return
		}

		if db.exists(key) && db.t(key) != "hash" {
			c.WriteError(msgWrongType)
			return
		}

		members := db.hashFields(key)
		if withMatch {
			members = matchKeys(members, match)
		}

		c.WriteLen(2)
		c.WriteBulk("0") // no next cursor
		// HSCAN gives key, values.
		c.WriteLen(len(members) * 2)
		for _, k := range members {
			c.WriteBulk(k)
			c.WriteBulk(db.hashGet(key, k))
		}
	})
}

687
vendor/github.com/alicebob/miniredis/cmd_list.go generated vendored Normal file
View file

@ -0,0 +1,687 @@
// Commands from https://redis.io/commands#list
package miniredis
import (
"strconv"
"strings"
"time"
"github.com/alicebob/miniredis/server"
)
// leftright selects which end of a list an operation works on, so the
// L*/R* command pairs can share one implementation.
type leftright int

const (
	left leftright = iota
	right
)
// commandsList registers the list commands
// (https://redis.io/commands#list) with the server (mostly L*/R* pairs).
func commandsList(m *Miniredis) {
	m.srv.Register("BLPOP", m.cmdBlpop)
	m.srv.Register("BRPOP", m.cmdBrpop)
	m.srv.Register("BRPOPLPUSH", m.cmdBrpoplpush)
	m.srv.Register("LINDEX", m.cmdLindex)
	m.srv.Register("LINSERT", m.cmdLinsert)
	m.srv.Register("LLEN", m.cmdLlen)
	m.srv.Register("LPOP", m.cmdLpop)
	m.srv.Register("LPUSH", m.cmdLpush)
	m.srv.Register("LPUSHX", m.cmdLpushx)
	m.srv.Register("LRANGE", m.cmdLrange)
	m.srv.Register("LREM", m.cmdLrem)
	m.srv.Register("LSET", m.cmdLset)
	m.srv.Register("LTRIM", m.cmdLtrim)
	m.srv.Register("RPOP", m.cmdRpop)
	m.srv.Register("RPOPLPUSH", m.cmdRpoplpush)
	m.srv.Register("RPUSH", m.cmdRpush)
	m.srv.Register("RPUSHX", m.cmdRpushx)
}
// BLPOP
// cmdBlpop is the blocking pop from the head of a list.
func (m *Miniredis) cmdBlpop(c *server.Peer, cmd string, args []string) {
	m.cmdBXpop(c, cmd, args, left)
}
// BRPOP
// cmdBrpop is the blocking pop from the tail of a list.
func (m *Miniredis) cmdBrpop(c *server.Peer, cmd string, args []string) {
	m.cmdBXpop(c, cmd, args, right)
}
// cmdBXpop implements BLPOP/BRPOP: block (via the blocking helper, which
// presumably re-runs the predicate until it returns true or the timeout
// elapses) until one of the given keys holds a non-empty list, then pop
// from the chosen end and reply [key, value]. A zero timeout means wait
// forever; timeout expiry replies nil.
func (m *Miniredis) cmdBXpop(c *server.Peer, cmd string, args []string, lr leftright) {
	if len(args) < 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	// Last argument is the timeout in seconds; everything before are keys.
	timeoutS := args[len(args)-1]
	keys := args[:len(args)-1]

	timeout, err := strconv.Atoi(timeoutS)
	if err != nil {
		setDirty(c)
		c.WriteError(msgInvalidTimeout)
		return
	}
	if timeout < 0 {
		setDirty(c)
		c.WriteError(msgNegTimeout)
		return
	}

	blocking(
		m,
		c,
		time.Duration(timeout)*time.Second,
		func(c *server.Peer, ctx *connCtx) bool {
			db := m.db(ctx.selectedDB)
			for _, key := range keys {
				if !db.exists(key) {
					continue
				}
				if db.t(key) != "list" {
					c.WriteError(msgWrongType)
					return true
				}

				if len(db.listKeys[key]) == 0 {
					continue
				}
				c.WriteLen(2)
				c.WriteBulk(key)
				var v string
				switch lr {
				case left:
					v = db.listLpop(key)
				case right:
					v = db.listPop(key)
				}
				c.WriteBulk(v)
				return true
			}
			return false
		},
		func(c *server.Peer) {
			// timeout
			c.WriteNull()
		},
	)
}
// LINDEX
// cmdLindex replies with the element at the given index (negative counts
// from the tail), or nil when the key is missing or the index is out of
// range.
func (m *Miniredis) cmdLindex(c *server.Peer, cmd string, args []string) {
	if len(args) != 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key, offsets := args[0], args[1]
	offset, err := strconv.Atoi(offsets)
	if err != nil {
		setDirty(c)
		c.WriteError(msgInvalidInt)
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		t, ok := db.keys[key]
		if !ok {
			// No such key
			c.WriteNull()
			return
		}
		if t != "list" {
			c.WriteError(msgWrongType)
			return
		}

		l := db.listKeys[key]
		if offset < 0 {
			offset = len(l) + offset
		}
		if offset < 0 || offset > len(l)-1 {
			c.WriteNull()
			return
		}
		c.WriteBulk(l[offset])
	})
}
// LINSERT
// cmdLinsert inserts value BEFORE or AFTER the first occurrence of pivot.
// Replies the new list length, 0 when the key does not exist, and -1 when
// the pivot was not found.
func (m *Miniredis) cmdLinsert(c *server.Peer, cmd string, args []string) {
	if len(args) != 4 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key := args[0]
	// where: -1 inserts before the pivot, +1 after it.
	where := 0
	switch strings.ToLower(args[1]) {
	case "before":
		where = -1
	case "after":
		where = +1
	default:
		setDirty(c)
		c.WriteError(msgSyntaxError)
		return
	}
	pivot := args[2]
	value := args[3]

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		t, ok := db.keys[key]
		if !ok {
			// No such key
			c.WriteInt(0)
			return
		}
		if t != "list" {
			c.WriteError(msgWrongType)
			return
		}

		l := db.listKeys[key]
		for i, el := range l {
			if el != pivot {
				continue
			}

			if where < 0 {
				l = append(l[:i], append(listKey{value}, l[i:]...)...)
			} else {
				if i == len(l)-1 {
					l = append(l, value)
				} else {
					l = append(l[:i+1], append(listKey{value}, l[i+1:]...)...)
				}
			}
			db.listKeys[key] = l
			db.keyVersion[key]++
			c.WriteInt(len(l))
			return
		}
		c.WriteInt(-1)
	})
}
// LLEN
// cmdLlen replies with the length of the list at key: 0 when the key is
// absent, a WRONGTYPE error for non-list keys.
func (m *Miniredis) cmdLlen(c *server.Peer, cmd string, args []string) {
	if len(args) != 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key := args[0]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		kind, present := db.keys[key]
		switch {
		case !present:
			// No such key. That's zero length.
			c.WriteInt(0)
		case kind != "list":
			c.WriteError(msgWrongType)
		default:
			c.WriteInt(len(db.listKeys[key]))
		}
	})
}
// LPOP
// cmdLpop pops one element from the head of a list.
func (m *Miniredis) cmdLpop(c *server.Peer, cmd string, args []string) {
	m.cmdXpop(c, cmd, args, left)
}
// RPOP
// cmdRpop pops one element from the tail of a list.
func (m *Miniredis) cmdRpop(c *server.Peer, cmd string, args []string) {
	m.cmdXpop(c, cmd, args, right)
}
// cmdXpop implements LPOP/RPOP: pop one element from the chosen end of
// the list and reply with it, or nil when the key does not exist.
func (m *Miniredis) cmdXpop(c *server.Peer, cmd string, args []string, lr leftright) {
	if len(args) != 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	key := args[0]

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if !db.exists(key) {
			// non-existing key is fine
			c.WriteNull()
			return
		}
		if db.t(key) != "list" {
			c.WriteError(msgWrongType)
			return
		}

		var elem string
		switch lr {
		case left:
			elem = db.listLpop(key)
		case right:
			elem = db.listPop(key)
		}
		c.WriteBulk(elem)
	})
}
// LPUSH
// cmdLpush pushes one or more values onto the head of a list.
func (m *Miniredis) cmdLpush(c *server.Peer, cmd string, args []string) {
	m.cmdXpush(c, cmd, args, left)
}
// RPUSH
// cmdRpush pushes one or more values onto the tail of a list.
func (m *Miniredis) cmdRpush(c *server.Peer, cmd string, args []string) {
	m.cmdXpush(c, cmd, args, right)
}
// cmdXpush implements LPUSH/RPUSH: push each value in argument order onto
// the chosen end of the list (creating it if needed) and reply with the
// final list length.
func (m *Miniredis) cmdXpush(c *server.Peer, cmd string, args []string, lr leftright) {
	if len(args) < 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key, args := args[0], args[1:]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if db.exists(key) && db.t(key) != "list" {
			c.WriteError(msgWrongType)
			return
		}

		var newLen int
		for _, value := range args {
			switch lr {
			case left:
				newLen = db.listLpush(key, value)
			case right:
				newLen = db.listPush(key, value)
			}
		}
		c.WriteInt(newLen)
	})
}
// LPUSHX
// cmdLpushx pushes onto the head of a list, but only if the key exists.
func (m *Miniredis) cmdLpushx(c *server.Peer, cmd string, args []string) {
	m.cmdXpushx(c, cmd, args, left)
}
// RPUSHX
// cmdRpushx pushes onto the tail of a list, but only if the key exists.
func (m *Miniredis) cmdRpushx(c *server.Peer, cmd string, args []string) {
	m.cmdXpushx(c, cmd, args, right)
}
// cmdXpushx implements LPUSHX/RPUSHX: like cmdXpush, but replies 0 and
// does nothing when the key does not already exist.
func (m *Miniredis) cmdXpushx(c *server.Peer, cmd string, args []string, lr leftright) {
	if len(args) < 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	key, args := args[0], args[1:]

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if !db.exists(key) {
			c.WriteInt(0)
			return
		}
		if db.t(key) != "list" {
			c.WriteError(msgWrongType)
			return
		}

		var newLen int
		for _, value := range args {
			switch lr {
			case left:
				newLen = db.listLpush(key, value)
			case right:
				newLen = db.listPush(key, value)
			}
		}
		c.WriteInt(newLen)
	})
}
// LRANGE
// cmdLrange replies with the elements between start and end (inclusive;
// negative offsets count from the tail, clamped by redisRange).
func (m *Miniredis) cmdLrange(c *server.Peer, cmd string, args []string) {
	if len(args) != 3 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key := args[0]
	start, err := strconv.Atoi(args[1])
	if err != nil {
		setDirty(c)
		c.WriteError(msgInvalidInt)
		return
	}
	end, err := strconv.Atoi(args[2])
	if err != nil {
		setDirty(c)
		c.WriteError(msgInvalidInt)
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if t, ok := db.keys[key]; ok && t != "list" {
			c.WriteError(msgWrongType)
			return
		}

		l := db.listKeys[key]
		if len(l) == 0 {
			c.WriteLen(0)
			return
		}

		rs, re := redisRange(len(l), start, end, false)
		c.WriteLen(re - rs)
		for _, el := range l[rs:re] {
			c.WriteBulk(el)
		}
	})
}
// LREM
// cmdLrem removes occurrences of value from the list and replies with the
// number removed. count > 0 removes up to count from the head, count < 0
// up to -count from the tail (implemented by reversing, filtering, and
// reversing back), count == 0 removes them all. An emptied list deletes
// the key.
func (m *Miniredis) cmdLrem(c *server.Peer, cmd string, args []string) {
	if len(args) != 3 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key := args[0]
	count, err := strconv.Atoi(args[1])
	if err != nil {
		setDirty(c)
		c.WriteError(msgInvalidInt)
		return
	}
	value := args[2]

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if !db.exists(key) {
			c.WriteInt(0)
			return
		}
		if db.t(key) != "list" {
			c.WriteError(msgWrongType)
			return
		}

		l := db.listKeys[key]
		// Reverse in place so the scan below runs tail-first; the stored
		// list is replaced (or the key deleted) before this returns.
		if count < 0 {
			reverseSlice(l)
		}
		deleted := 0
		newL := []string{}
		toDelete := len(l)
		if count < 0 {
			toDelete = -count
		}
		if count > 0 {
			toDelete = count
		}
		for _, el := range l {
			if el == value {
				if toDelete > 0 {
					deleted++
					toDelete--
					continue
				}
			}
			newL = append(newL, el)
		}
		if count < 0 {
			reverseSlice(newL)
		}
		if len(newL) == 0 {
			db.del(key, true)
		} else {
			db.listKeys[key] = newL
			db.keyVersion[key]++
		}

		c.WriteInt(deleted)
	})
}
// LSET
// cmdLset overwrites the element at the given index (negative counts from
// the tail). Errors when the key is missing or the index is out of range.
func (m *Miniredis) cmdLset(c *server.Peer, cmd string, args []string) {
	if len(args) != 3 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key := args[0]
	index, err := strconv.Atoi(args[1])
	if err != nil {
		setDirty(c)
		c.WriteError(msgInvalidInt)
		return
	}
	value := args[2]

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if !db.exists(key) {
			c.WriteError(msgKeyNotFound)
			return
		}
		if db.t(key) != "list" {
			c.WriteError(msgWrongType)
			return
		}

		l := db.listKeys[key]
		if index < 0 {
			index = len(l) + index
		}
		if index < 0 || index > len(l)-1 {
			c.WriteError(msgOutOfRange)
			return
		}
		// In-place write through the shared backing array.
		l[index] = value
		db.keyVersion[key]++

		c.WriteOK()
	})
}
// LTRIM
// cmdLtrim trims the list to the elements between start and end
// (inclusive, via redisRange). Trimming to nothing deletes the key.
// Always replies OK, even for a missing key.
func (m *Miniredis) cmdLtrim(c *server.Peer, cmd string, args []string) {
	if len(args) != 3 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key := args[0]
	start, err := strconv.Atoi(args[1])
	if err != nil {
		setDirty(c)
		c.WriteError(msgInvalidInt)
		return
	}
	end, err := strconv.Atoi(args[2])
	if err != nil {
		setDirty(c)
		c.WriteError(msgInvalidInt)
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		t, ok := db.keys[key]
		if !ok {
			c.WriteOK()
			return
		}
		if t != "list" {
			c.WriteError(msgWrongType)
			return
		}

		l := db.listKeys[key]
		rs, re := redisRange(len(l), start, end, false)
		// NOTE: the sub-slice shares the old backing array.
		l = l[rs:re]
		if len(l) == 0 {
			db.del(key, true)
		} else {
			db.listKeys[key] = l
			db.keyVersion[key]++
		}
		c.WriteOK()
	})
}
// RPOPLPUSH
// cmdRpoplpush atomically pops the tail of src, pushes it onto the head
// of dst, and replies with the element; nil when src does not exist.
func (m *Miniredis) cmdRpoplpush(c *server.Peer, cmd string, args []string) {
	if len(args) != 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	src, dst := args[0], args[1]

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if !db.exists(src) {
			c.WriteNull()
			return
		}
		if db.t(src) != "list" || (db.exists(dst) && db.t(dst) != "list") {
			c.WriteError(msgWrongType)
			return
		}
		elem := db.listPop(src)
		db.listLpush(dst, elem)
		c.WriteBulk(elem)
	})
}
// BRPOPLPUSH
// cmdBrpoplpush is the blocking RPOPLPUSH: wait (via the blocking helper)
// until src holds an element, then move it to the head of dst and reply
// with it. A zero timeout waits forever; timeout expiry replies nil.
func (m *Miniredis) cmdBrpoplpush(c *server.Peer, cmd string, args []string) {
	if len(args) != 3 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	src := args[0]
	dst := args[1]
	timeout, err := strconv.Atoi(args[2])
	if err != nil {
		setDirty(c)
		c.WriteError(msgInvalidTimeout)
		return
	}
	if timeout < 0 {
		setDirty(c)
		c.WriteError(msgNegTimeout)
		return
	}

	blocking(
		m,
		c,
		time.Duration(timeout)*time.Second,
		func(c *server.Peer, ctx *connCtx) bool {
			db := m.db(ctx.selectedDB)

			if !db.exists(src) {
				return false
			}
			if db.t(src) != "list" || (db.exists(dst) && db.t(dst) != "list") {
				c.WriteError(msgWrongType)
				return true
			}
			if len(db.listKeys[src]) == 0 {
				return false
			}
			elem := db.listPop(src)
			db.listLpush(dst, elem)
			c.WriteBulk(elem)
			return true
		},
		func(c *server.Peer) {
			// timeout
			c.WriteNull()
		},
	)
}

220
vendor/github.com/alicebob/miniredis/cmd_scripting.go generated vendored Normal file
View file

@ -0,0 +1,220 @@
package miniredis
import (
"crypto/sha1"
"encoding/hex"
"fmt"
"io"
"strconv"
"strings"
luajson "github.com/alicebob/gopher-json"
"github.com/yuin/gopher-lua"
"github.com/yuin/gopher-lua/parse"
"github.com/alicebob/miniredis/server"
)
// commandsScripting registers the scripting commands (EVAL, EVALSHA, SCRIPT)
// with the server.
func commandsScripting(m *Miniredis) {
	handlers := []struct {
		name string
		fn   func(*server.Peer, string, []string)
	}{
		{"EVAL", m.cmdEval},
		{"EVALSHA", m.cmdEvalsha},
		{"SCRIPT", m.cmdScript},
	}
	for _, h := range handlers {
		m.srv.Register(h.name, h.fn)
	}
}
// Execute lua. Needs to run m.Lock()ed, from within withTx().
// Sets up a sandboxed Lua state with KEYS, ARGV, cjson, and the `redis`
// module, then runs `script` and writes its result (or an error) to c.
func (m *Miniredis) runLuaScript(c *server.Peer, script string, args []string) {
	l := lua.NewState(lua.Options{SkipOpenLibs: true})
	defer l.Close()

	// Taken from the go-lua manual: open only a whitelist of standard Lua
	// libraries instead of everything (no io/os etc.).
	for _, pair := range []struct {
		n string
		f lua.LGFunction
	}{
		{lua.LoadLibName, lua.OpenPackage},
		{lua.BaseLibName, lua.OpenBase},
		{lua.CoroutineLibName, lua.OpenCoroutine},
		{lua.TabLibName, lua.OpenTable},
		{lua.StringLibName, lua.OpenString},
		{lua.MathLibName, lua.OpenMath},
	} {
		if err := l.CallByParam(lua.P{
			Fn:      l.NewFunction(pair.f),
			NRet:    0,
			Protect: true,
		}, lua.LString(pair.n)); err != nil {
			panic(err)
		}
	}

	luajson.Preload(l)
	requireGlobal(l, "cjson", "json")

	// m.redigo() dials back into this very server; drop the lock while
	// connecting so the accept/serve path can make progress.
	m.Unlock()
	conn := m.redigo()
	m.Lock()
	defer conn.Close()

	// set global variable KEYS
	keysTable := l.NewTable()
	keysS, args := args[0], args[1:]
	keysLen, err := strconv.Atoi(keysS)
	if err != nil {
		c.WriteError(msgInvalidInt)
		return
	}
	if keysLen < 0 {
		c.WriteError(msgNegativeKeysNumber)
		return
	}
	if keysLen > len(args) {
		c.WriteError(msgInvalidKeysNumber)
		return
	}
	keys, args := args[:keysLen], args[keysLen:]
	for i, k := range keys {
		// Lua arrays are 1-based.
		l.RawSet(keysTable, lua.LNumber(i+1), lua.LString(k))
	}
	l.SetGlobal("KEYS", keysTable)

	// Remaining arguments become ARGV.
	argvTable := l.NewTable()
	for i, a := range args {
		l.RawSet(argvTable, lua.LNumber(i+1), lua.LString(a))
	}
	l.SetGlobal("ARGV", argvTable)

	redisFuncs := mkLuaFuncs(conn)
	// Register command handlers
	l.Push(l.NewFunction(func(l *lua.LState) int {
		mod := l.RegisterModule("redis", redisFuncs).(*lua.LTable)
		l.Push(mod)
		return 1
	}))

	l.Push(lua.LString("redis"))
	l.Call(1, 0)

	m.Unlock() // This runs in a transaction, but can access our db recursively
	defer m.Lock()
	if err := l.DoString(script); err != nil {
		c.WriteError(errLuaParseError(err))
		return
	}

	luaToRedis(l, c, l.Get(1))
}
// cmdEval implements EVAL: run a Lua script with numkeys/KEYS/ARGV arguments.
func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) {
	if len(args) < 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	script, args := args[0], args[1:]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		m.runLuaScript(c, script, args)
	})
}
// cmdEvalsha implements EVALSHA: run a script previously stored via
// SCRIPT LOAD, addressed by its SHA1 hex digest.
func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) {
	if len(args) < 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	sha, args := args[0], args[1:]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		script, ok := m.scripts[sha]
		if !ok {
			// Unknown digest: NOSCRIPT error, client should fall back to EVAL.
			c.WriteError(msgNoScriptFound)
			return
		}
		m.runLuaScript(c, script, args)
	})
}
// cmdScript implements the SCRIPT subcommands LOAD, EXISTS, and FLUSH for
// the script cache used by EVALSHA.
func (m *Miniredis) cmdScript(c *server.Peer, cmd string, args []string) {
	if len(args) < 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	subcmd, args := args[0], args[1:]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		switch strings.ToLower(subcmd) {
		case "load":
			if len(args) != 1 {
				c.WriteError(fmt.Sprintf(msgFScriptUsage, "LOAD"))
				return
			}
			script := args[0]
			// Parse (but don't run) the script so LOAD rejects bad Lua.
			if _, err := parse.Parse(strings.NewReader(script), "user_script"); err != nil {
				c.WriteError(errLuaParseError(err))
				return
			}
			sha := sha1Hex(script)
			m.scripts[sha] = script
			c.WriteBulk(sha)
		case "exists":
			// One 0/1 reply per requested digest.
			c.WriteLen(len(args))
			for _, arg := range args {
				if _, ok := m.scripts[arg]; ok {
					c.WriteInt(1)
				} else {
					c.WriteInt(0)
				}
			}
		case "flush":
			if len(args) != 0 {
				c.WriteError(fmt.Sprintf(msgFScriptUsage, "FLUSH"))
				return
			}
			m.scripts = map[string]string{}
			c.WriteOK()
		default:
			c.WriteError(fmt.Sprintf(msgFScriptUsage, strings.ToUpper(subcmd)))
		}
	})
}
// sha1Hex returns the lowercase hex-encoded SHA1 digest of s, as used for
// script identifiers by SCRIPT LOAD / EVALSHA.
//
// Uses sha1.Sum directly instead of a hash.Hash plus io.WriteString, whose
// error result was silently discarded.
func sha1Hex(s string) string {
	sum := sha1.Sum([]byte(s))
	return hex.EncodeToString(sum[:])
}
// requireGlobal imports module modName into the global namespace with the
// identifier id. panics if an error results from the function execution.
func requireGlobal(l *lua.LState, id, modName string) {
	// Equivalent of Lua `id = require(modName)`.
	if err := l.CallByParam(lua.P{
		Fn:      l.GetGlobal("require"),
		NRet:    1,
		Protect: true,
	}, lua.LString(modName)); err != nil {
		panic(err)
	}
	mod := l.Get(-1)
	l.Pop(1)

	l.SetGlobal(id, mod)
}

104
vendor/github.com/alicebob/miniredis/cmd_server.go generated vendored Normal file
View file

@ -0,0 +1,104 @@
// Commands from https://redis.io/commands#server
package miniredis
import (
"strconv"
"strings"
"time"
"github.com/alicebob/miniredis/server"
)
// commandsServer registers the server-management commands
// (DBSIZE, FLUSHALL, FLUSHDB, TIME).
func commandsServer(m *Miniredis) {
	handlers := []struct {
		name string
		fn   func(*server.Peer, string, []string)
	}{
		{"DBSIZE", m.cmdDbsize},
		{"FLUSHALL", m.cmdFlushall},
		{"FLUSHDB", m.cmdFlushdb},
		{"TIME", m.cmdTime},
	}
	for _, h := range handlers {
		m.srv.Register(h.name, h.fn)
	}
}
// DBSIZE -- reply with the number of keys in the currently selected database.
func (m *Miniredis) cmdDbsize(c *server.Peer, cmd string, args []string) {
	if len(args) > 0 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)
		c.WriteInt(len(db.keys))
	})
}
// FLUSHALL -- remove every key from every database. An optional ASYNC
// argument is accepted but ignored (this implementation is synchronous).
func (m *Miniredis) cmdFlushall(c *server.Peer, cmd string, args []string) {
	if len(args) > 0 && strings.ToLower(args[0]) == "async" {
		args = args[1:]
	}
	if len(args) > 0 {
		setDirty(c)
		c.WriteError(msgSyntaxError)
		return
	}
	if !m.handleAuth(c) {
		return
	}
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		m.flushAll()
		c.WriteOK()
	})
}
// FLUSHDB -- remove every key from the selected database only. An optional
// ASYNC argument is accepted but ignored (this implementation is synchronous).
func (m *Miniredis) cmdFlushdb(c *server.Peer, cmd string, args []string) {
	if len(args) > 0 && strings.ToLower(args[0]) == "async" {
		args = args[1:]
	}
	if len(args) > 0 {
		setDirty(c)
		c.WriteError(msgSyntaxError)
		return
	}
	if !m.handleAuth(c) {
		return
	}
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		m.db(ctx.selectedDB).flush()
		c.WriteOK()
	})
}
// TIME: time values are returned in string format instead of int.
// Replies with [unix seconds, microseconds within the second]. Uses m.now
// when set (test time-travel support), otherwise the wall clock.
func (m *Miniredis) cmdTime(c *server.Peer, cmd string, args []string) {
	if len(args) > 0 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		now := m.now
		if now.IsZero() {
			now = time.Now()
		}
		nanos := now.UnixNano()
		seconds := nanos / 1000000000
		microseconds := (nanos / 1000) % 1000000
		c.WriteLen(2)
		c.WriteBulk(strconv.FormatInt(seconds, 10))
		c.WriteBulk(strconv.FormatInt(microseconds, 10))
	})
}

639
vendor/github.com/alicebob/miniredis/cmd_set.go generated vendored Normal file
View file

@ -0,0 +1,639 @@
// Commands from https://redis.io/commands#set
package miniredis
import (
"math/rand"
"strconv"
"strings"
"github.com/alicebob/miniredis/server"
)
// commandsSet handles all set value operations: it registers every
// SET-family command with the server.
func commandsSet(m *Miniredis) {
	handlers := []struct {
		name string
		fn   func(*server.Peer, string, []string)
	}{
		{"SADD", m.cmdSadd},
		{"SCARD", m.cmdScard},
		{"SDIFF", m.cmdSdiff},
		{"SDIFFSTORE", m.cmdSdiffstore},
		{"SINTER", m.cmdSinter},
		{"SINTERSTORE", m.cmdSinterstore},
		{"SISMEMBER", m.cmdSismember},
		{"SMEMBERS", m.cmdSmembers},
		{"SMOVE", m.cmdSmove},
		{"SPOP", m.cmdSpop},
		{"SRANDMEMBER", m.cmdSrandmember},
		{"SREM", m.cmdSrem},
		{"SUNION", m.cmdSunion},
		{"SUNIONSTORE", m.cmdSunionstore},
		{"SSCAN", m.cmdSscan},
	}
	for _, h := range handlers {
		m.srv.Register(h.name, h.fn)
	}
}
// SADD -- add one or more members to a set, replying with the number of
// members that were actually new.
func (m *Miniredis) cmdSadd(c *server.Peer, cmd string, args []string) {
	if len(args) < 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	key, elems := args[0], args[1:]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)
		if db.exists(key) && db.t(key) != "set" {
			c.WriteError(ErrWrongType.Error())
			return
		}
		added := db.setAdd(key, elems...)
		c.WriteInt(added)
	})
}
// SCARD -- reply with the cardinality of a set (0 for a missing key).
func (m *Miniredis) cmdScard(c *server.Peer, cmd string, args []string) {
	if len(args) != 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	key := args[0]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)
		if !db.exists(key) {
			c.WriteInt(0)
			return
		}
		if db.t(key) != "set" {
			c.WriteError(ErrWrongType.Error())
			return
		}
		members := db.setMembers(key)
		c.WriteInt(len(members))
	})
}
// SDIFF -- reply with the members of the first set minus the members of all
// subsequent sets. Missing keys are treated as empty sets by setDiff.
func (m *Miniredis) cmdSdiff(c *server.Peer, cmd string, args []string) {
	if len(args) < 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	keys := args
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)
		set, err := db.setDiff(keys)
		if err != nil {
			c.WriteError(err.Error())
			return
		}
		c.WriteLen(len(set))
		// Map iteration: reply order is unspecified, matching Redis sets.
		for k := range set {
			c.WriteBulk(k)
		}
	})
}
// SDIFFSTORE -- like SDIFF, but store the result in dest (replacing any
// previous value) and reply with the result's cardinality.
func (m *Miniredis) cmdSdiffstore(c *server.Peer, cmd string, args []string) {
	if len(args) < 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	dest, keys := args[0], args[1:]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)
		set, err := db.setDiff(keys)
		if err != nil {
			c.WriteError(err.Error())
			return
		}
		db.del(dest, true)
		db.setSet(dest, set)
		c.WriteInt(len(set))
	})
}
// SINTER -- reply with the intersection of the given sets. Any missing key
// makes the intersection empty (handled in setInter).
func (m *Miniredis) cmdSinter(c *server.Peer, cmd string, args []string) {
	if len(args) < 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	keys := args
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)
		set, err := db.setInter(keys)
		if err != nil {
			c.WriteError(err.Error())
			return
		}
		c.WriteLen(len(set))
		for k := range set {
			c.WriteBulk(k)
		}
	})
}
// SINTERSTORE -- like SINTER, but store the result in dest (replacing any
// previous value) and reply with the result's cardinality.
func (m *Miniredis) cmdSinterstore(c *server.Peer, cmd string, args []string) {
	if len(args) < 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	dest, keys := args[0], args[1:]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)
		set, err := db.setInter(keys)
		if err != nil {
			c.WriteError(err.Error())
			return
		}
		db.del(dest, true)
		db.setSet(dest, set)
		c.WriteInt(len(set))
	})
}
// SISMEMBER -- reply 1 if value is a member of the set at key, 0 otherwise
// (including when the key does not exist).
func (m *Miniredis) cmdSismember(c *server.Peer, cmd string, args []string) {
	if len(args) != 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	key, value := args[0], args[1]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)
		if !db.exists(key) {
			c.WriteInt(0)
			return
		}
		if db.t(key) != "set" {
			c.WriteError(ErrWrongType.Error())
			return
		}
		if db.setIsMember(key, value) {
			c.WriteInt(1)
			return
		}
		c.WriteInt(0)
	})
}
// SMEMBERS -- reply with all members of the set at key (sorted, via
// setMembers), or an empty array for a missing key.
func (m *Miniredis) cmdSmembers(c *server.Peer, cmd string, args []string) {
	if len(args) != 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	key := args[0]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)
		if !db.exists(key) {
			c.WriteLen(0)
			return
		}
		if db.t(key) != "set" {
			c.WriteError(ErrWrongType.Error())
			return
		}
		members := db.setMembers(key)
		c.WriteLen(len(members))
		for _, elem := range members {
			c.WriteBulk(elem)
		}
	})
}
// SMOVE -- atomically move member from the set at src to the set at dst.
// Replies 1 when the member was moved, 0 when src or the member is missing.
func (m *Miniredis) cmdSmove(c *server.Peer, cmd string, args []string) {
	if len(args) != 3 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	src, dst, member := args[0], args[1], args[2]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)
		if !db.exists(src) {
			c.WriteInt(0)
			return
		}
		if db.t(src) != "set" {
			c.WriteError(ErrWrongType.Error())
			return
		}
		if db.exists(dst) && db.t(dst) != "set" {
			c.WriteError(ErrWrongType.Error())
			return
		}
		if !db.setIsMember(src, member) {
			c.WriteInt(0)
			return
		}
		db.setRem(src, member)
		db.setAdd(dst, member)
		c.WriteInt(1)
	})
}
// SPOP -- remove and return one (or, with a COUNT argument, up to COUNT)
// random member(s) of the set at key. Without COUNT the reply is a single
// bulk string or nil; with COUNT it is always an array.
func (m *Miniredis) cmdSpop(c *server.Peer, cmd string, args []string) {
	if len(args) == 0 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	key, args := args[0], args[1:]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		// Optional COUNT argument changes the reply shape below.
		withCount := false
		count := 1
		if len(args) > 0 {
			v, err := strconv.Atoi(args[0])
			if err != nil {
				setDirty(c)
				c.WriteError(msgInvalidInt)
				return
			}
			count = v
			withCount = true
			args = args[1:]
		}
		if len(args) > 0 {
			setDirty(c)
			c.WriteError(msgInvalidInt)
			return
		}

		if !db.exists(key) {
			if !withCount {
				c.WriteNull()
				return
			}
			c.WriteLen(0)
			return
		}

		if db.t(key) != "set" {
			c.WriteError(ErrWrongType.Error())
			return
		}

		var deleted []string
		for i := 0; i < count; i++ {
			members := db.setMembers(key)
			if len(members) == 0 {
				break
			}
			member := members[rand.Intn(len(members))]
			db.setRem(key, member)
			deleted = append(deleted, member)
		}
		// without `count` return a single value...
		if !withCount {
			if len(deleted) == 0 {
				c.WriteNull()
				return
			}
			c.WriteBulk(deleted[0])
			return
		}
		// ... with `count` return a list
		c.WriteLen(len(deleted))
		for _, v := range deleted {
			c.WriteBulk(v)
		}
	})
}
// SRANDMEMBER -- return random member(s) of the set at key without removing
// them. A negative COUNT allows duplicates; a positive COUNT returns up to
// COUNT distinct members.
// NOTE(review): with a COUNT argument and a missing key this replies nil
// rather than an empty array — verify against real Redis behavior.
func (m *Miniredis) cmdSrandmember(c *server.Peer, cmd string, args []string) {
	if len(args) < 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if len(args) > 2 {
		setDirty(c)
		c.WriteError(msgSyntaxError)
		return
	}
	if !m.handleAuth(c) {
		return
	}
	key := args[0]
	count := 0
	withCount := false
	if len(args) == 2 {
		var err error
		count, err = strconv.Atoi(args[1])
		if err != nil {
			setDirty(c)
			c.WriteError(msgInvalidInt)
			return
		}
		withCount = true
	}
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)
		if !db.exists(key) {
			c.WriteNull()
			return
		}
		if db.t(key) != "set" {
			c.WriteError(ErrWrongType.Error())
			return
		}
		members := db.setMembers(key)
		if count < 0 {
			// Non-unique elements is allowed with negative count.
			c.WriteLen(-count)
			for count != 0 {
				member := members[rand.Intn(len(members))]
				c.WriteBulk(member)
				count++
			}
			return
		}

		// Must be unique elements.
		shuffle(members)
		if count > len(members) {
			count = len(members)
		}
		if !withCount {
			c.WriteBulk(members[0])
			return
		}
		c.WriteLen(count)
		for i := range make([]struct{}, count) {
			c.WriteBulk(members[i])
		}
	})
}
// SREM -- remove the given members from the set at key, replying with the
// number of members actually removed (0 for a missing key).
func (m *Miniredis) cmdSrem(c *server.Peer, cmd string, args []string) {
	if len(args) < 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	key, fields := args[0], args[1:]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)
		if !db.exists(key) {
			c.WriteInt(0)
			return
		}
		if db.t(key) != "set" {
			c.WriteError(ErrWrongType.Error())
			return
		}
		c.WriteInt(db.setRem(key, fields...))
	})
}
// SUNION -- reply with the union of the given sets. Missing keys are treated
// as empty sets by setUnion.
func (m *Miniredis) cmdSunion(c *server.Peer, cmd string, args []string) {
	if len(args) < 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	keys := args
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)
		set, err := db.setUnion(keys)
		if err != nil {
			c.WriteError(err.Error())
			return
		}
		c.WriteLen(len(set))
		for k := range set {
			c.WriteBulk(k)
		}
	})
}
// SUNIONSTORE -- like SUNION, but store the result in dest (replacing any
// previous value) and reply with the result's cardinality.
func (m *Miniredis) cmdSunionstore(c *server.Peer, cmd string, args []string) {
	if len(args) < 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	dest, keys := args[0], args[1:]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)
		set, err := db.setUnion(keys)
		if err != nil {
			c.WriteError(err.Error())
			return
		}
		db.del(dest, true)
		db.setSet(dest, set)
		c.WriteInt(len(set))
	})
}
// SSCAN -- incremental iteration over a set. This implementation returns all
// (matched) members in a single pass: cursor 0 yields everything with a next
// cursor of 0; any other cursor is treated as exhausted. COUNT is parsed but
// ignored; MATCH filters with a glob pattern.
func (m *Miniredis) cmdSscan(c *server.Peer, cmd string, args []string) {
	if len(args) < 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	key := args[0]
	cursor, err := strconv.Atoi(args[1])
	if err != nil {
		setDirty(c)
		c.WriteError(msgInvalidCursor)
		return
	}
	args = args[2:]
	// MATCH and COUNT options
	var withMatch bool
	var match string
	for len(args) > 0 {
		if strings.ToLower(args[0]) == "count" {
			if len(args) < 2 {
				setDirty(c)
				c.WriteError(msgSyntaxError)
				return
			}
			_, err := strconv.Atoi(args[1])
			if err != nil {
				setDirty(c)
				c.WriteError(msgInvalidInt)
				return
			}
			// We do nothing with count.
			args = args[2:]
			continue
		}
		if strings.ToLower(args[0]) == "match" {
			if len(args) < 2 {
				setDirty(c)
				c.WriteError(msgSyntaxError)
				return
			}
			withMatch = true
			match = args[1]
			args = args[2:]
			continue
		}
		setDirty(c)
		c.WriteError(msgSyntaxError)
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		db := m.db(ctx.selectedDB)
		// return _all_ (matched) keys every time
		if cursor != 0 {
			// invalid cursor
			c.WriteLen(2)
			c.WriteBulk("0") // no next cursor
			c.WriteLen(0)    // no elements
			return
		}
		if db.exists(key) && db.t(key) != "set" {
			c.WriteError(ErrWrongType.Error())
			return
		}

		members := db.setMembers(key)
		if withMatch {
			members = matchKeys(members, match)
		}

		c.WriteLen(2)
		c.WriteBulk("0") // no next cursor
		c.WriteLen(len(members))
		for _, k := range members {
			c.WriteBulk(k)
		}
	})
}
// shuffle permutes m in place, uniformly at random.
//
// The previous version applied len(m) random transpositions (rand.Intn twice
// per step), which is not a uniform shuffle, and used the non-idiomatic
// `for _ = range` form flagged by go vet. rand.Shuffle performs a proper
// Fisher-Yates shuffle.
func shuffle(m []string) {
	rand.Shuffle(len(m), func(i, j int) {
		m[i], m[j] = m[j], m[i]
	})
}

1399
vendor/github.com/alicebob/miniredis/cmd_sorted_set.go generated vendored Normal file

File diff suppressed because it is too large Load diff

1075
vendor/github.com/alicebob/miniredis/cmd_string.go generated vendored Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,152 @@
// Commands from https://redis.io/commands#transactions
package miniredis
import (
"github.com/alicebob/miniredis/server"
)
// commandsTransaction handles MULTI &c.: it registers the transaction
// commands with the server.
func commandsTransaction(m *Miniredis) {
	handlers := []struct {
		name string
		fn   func(*server.Peer, string, []string)
	}{
		{"DISCARD", m.cmdDiscard},
		{"EXEC", m.cmdExec},
		{"MULTI", m.cmdMulti},
		{"UNWATCH", m.cmdUnwatch},
		{"WATCH", m.cmdWatch},
	}
	for _, h := range handlers {
		m.srv.Register(h.name, h.fn)
	}
}
// MULTI -- start a transaction; subsequent commands are queued until EXEC.
// NOTE(review): unlike the sibling handlers, the wrong-arity branch does not
// call setDirty(c) — confirm whether that is intentional.
func (m *Miniredis) cmdMulti(c *server.Peer, cmd string, args []string) {
	if len(args) != 0 {
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	ctx := getCtx(c)

	if inTx(ctx) {
		c.WriteError("ERR MULTI calls can not be nested")
		return
	}

	startTx(ctx)

	c.WriteOK()
}
// EXEC -- run all commands queued since MULTI. Aborts with an empty reply if
// any WATCHed key changed, or with EXECABORT if queueing saw an error.
func (m *Miniredis) cmdExec(c *server.Peer, cmd string, args []string) {
	if len(args) != 0 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	ctx := getCtx(c)

	if !inTx(ctx) {
		c.WriteError("ERR EXEC without MULTI")
		return
	}

	if ctx.dirtyTransaction {
		c.WriteError("EXECABORT Transaction discarded because of previous errors.")
		return
	}

	m.Lock()
	defer m.Unlock()

	// Check WATCHed keys.
	for t, version := range ctx.watch {
		if m.db(t.db).keyVersion[t.key] > version {
			// Abort! Abort!
			stopTx(ctx)
			c.WriteLen(0)
			return
		}
	}

	// Replay the queued commands under the lock.
	c.WriteLen(len(ctx.transaction))
	for _, cb := range ctx.transaction {
		cb(c, ctx)
	}
	// wake up anyone who waits on anything.
	m.signal.Broadcast()

	stopTx(ctx)
}
// DISCARD -- drop all commands queued since MULTI and leave the transaction.
func (m *Miniredis) cmdDiscard(c *server.Peer, cmd string, args []string) {
	if len(args) != 0 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	ctx := getCtx(c)
	if !inTx(ctx) {
		c.WriteError("ERR DISCARD without MULTI")
		return
	}

	stopTx(ctx)
	c.WriteOK()
}
// WATCH -- record the current version of each given key; a later EXEC aborts
// if any of them changed in the meantime. Not allowed inside MULTI.
func (m *Miniredis) cmdWatch(c *server.Peer, cmd string, args []string) {
	if len(args) == 0 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	ctx := getCtx(c)
	if inTx(ctx) {
		c.WriteError("ERR WATCH in MULTI")
		return
	}

	m.Lock()
	defer m.Unlock()
	db := m.db(ctx.selectedDB)

	for _, key := range args {
		watch(db, ctx, key)
	}
	c.WriteOK()
}
// UNWATCH -- forget all WATCHed keys for this connection.
func (m *Miniredis) cmdUnwatch(c *server.Peer, cmd string, args []string) {
	if len(args) != 0 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}

	// Doesn't matter if UNWATCH is in a TX or not. Looks like a Redis bug to me.
	unwatch(getCtx(c))

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		// Do nothing if it's called in a transaction.
		c.WriteOK()
	})
}

551
vendor/github.com/alicebob/miniredis/db.go generated vendored Normal file
View file

@ -0,0 +1,551 @@
package miniredis
import (
"sort"
"strconv"
"time"
)
// exists reports whether key k is present in this database.
func (db *RedisDB) exists(k string) bool {
	_, ok := db.keys[k]
	return ok
}
// t gives the type of a key ("string", "hash", "list", "set", "zset"),
// or "" when the key does not exist.
func (db *RedisDB) t(k string) string {
	return db.keys[k]
}
// allKeys returns all keys. Sorted.
func (db *RedisDB) allKeys() []string {
	res := make([]string, 0, len(db.keys))
	for k := range db.keys {
		res = append(res, k)
	}
	sort.Strings(res) // To make things deterministic.
	return res
}
// flush removes all keys and values by resetting every per-type map.
// NOTE(review): keyVersion is not reset here — presumably so outstanding
// WATCHes keep working; confirm against the WATCH/EXEC path.
func (db *RedisDB) flush() {
	db.keys = map[string]string{}
	db.stringKeys = map[string]string{}
	db.hashKeys = map[string]hashKey{}
	db.listKeys = map[string]listKey{}
	db.setKeys = map[string]setKey{}
	db.sortedsetKeys = map[string]sortedSet{}
	db.ttl = map[string]time.Duration{}
}
// move something to another db. Will return ok. Or not.
// Fails (returns false) when the key already exists in the target database
// or does not exist in this one. Carries the TTL along and bumps the target's
// key version.
func (db *RedisDB) move(key string, to *RedisDB) bool {
	if _, ok := to.keys[key]; ok {
		return false
	}

	t, ok := db.keys[key]
	if !ok {
		return false
	}
	to.keys[key] = db.keys[key]
	switch t {
	case "string":
		to.stringKeys[key] = db.stringKeys[key]
	case "hash":
		to.hashKeys[key] = db.hashKeys[key]
	case "list":
		to.listKeys[key] = db.listKeys[key]
	case "set":
		to.setKeys[key] = db.setKeys[key]
	case "zset":
		to.sortedsetKeys[key] = db.sortedsetKeys[key]
	default:
		panic("unhandled key type")
	}
	to.keyVersion[key]++
	if v, ok := db.ttl[key]; ok {
		to.ttl[key] = v
	}
	db.del(key, true)
	return true
}
// rename moves the value at `from` to `to`, overwriting any existing value
// at `to` and carrying the TTL along. The caller must ensure `from` exists:
// an unknown (empty) type hits the panic below.
func (db *RedisDB) rename(from, to string) {
	db.del(to, true)
	switch db.t(from) {
	case "string":
		db.stringKeys[to] = db.stringKeys[from]
	case "hash":
		db.hashKeys[to] = db.hashKeys[from]
	case "list":
		db.listKeys[to] = db.listKeys[from]
	case "set":
		db.setKeys[to] = db.setKeys[from]
	case "zset":
		db.sortedsetKeys[to] = db.sortedsetKeys[from]
	default:
		panic("missing case")
	}
	db.keys[to] = db.keys[from]
	db.keyVersion[to]++
	db.ttl[to] = db.ttl[from]

	db.del(from, true)
}
// del removes key k and its value. delTTL controls whether the TTL entry is
// also removed (internal re-set operations keep it). Bumps the key version
// so WATCHers notice the change. No-op for a missing key.
func (db *RedisDB) del(k string, delTTL bool) {
	if !db.exists(k) {
		return
	}
	t := db.t(k)
	delete(db.keys, k)
	db.keyVersion[k]++
	if delTTL {
		delete(db.ttl, k)
	}
	switch t {
	case "string":
		delete(db.stringKeys, k)
	case "hash":
		delete(db.hashKeys, k)
	case "list":
		delete(db.listKeys, k)
	case "set":
		delete(db.setKeys, k)
	case "zset":
		delete(db.sortedsetKeys, k)
	default:
		panic("Unknown key type: " + t)
	}
}
// stringGet returns the string key or "" on error/nonexists.
func (db *RedisDB) stringGet(k string) string {
	if t, ok := db.keys[k]; !ok || t != "string" {
		return ""
	}
	return db.stringKeys[k]
}
// stringSet force set()s a key. Does not touch expire.
// Deletes any existing value of another type first (keeping the TTL).
func (db *RedisDB) stringSet(k, v string) {
	db.del(k, false)
	db.keys[k] = "string"
	db.stringKeys[k] = v
	db.keyVersion[k]++
}
// stringIncr changes an int string value by delta and returns the new value.
// A non-integer existing value yields ErrIntValueError; a missing key counts
// as 0.
func (db *RedisDB) stringIncr(k string, delta int) (int, error) {
	v := 0
	if sv, ok := db.stringKeys[k]; ok {
		var err error
		v, err = strconv.Atoi(sv)
		if err != nil {
			return 0, ErrIntValueError
		}
	}
	v += delta
	db.stringSet(k, strconv.Itoa(v))
	return v, nil
}
// stringIncrfloat changes a float string value by delta and returns the new
// value. A non-float existing value yields ErrFloatValueError; a missing key
// counts as 0.
func (db *RedisDB) stringIncrfloat(k string, delta float64) (float64, error) {
	v := 0.0
	if sv, ok := db.stringKeys[k]; ok {
		var err error
		v, err = strconv.ParseFloat(sv, 64)
		if err != nil {
			return 0, ErrFloatValueError
		}
	}
	v += delta
	db.stringSet(k, formatFloat(v))
	return v, nil
}
// listLpush is 'left push', aka unshift. Returns the new length.
// Creates the list (and its type entry) if missing.
func (db *RedisDB) listLpush(k, v string) int {
	l, ok := db.listKeys[k]
	if !ok {
		db.keys[k] = "list"
	}
	l = append([]string{v}, l...)
	db.listKeys[k] = l
	db.keyVersion[k]++
	return len(l)
}
// listLpop is 'left pop', aka shift. The caller must ensure the list exists
// and is non-empty (l[0] panics otherwise). Deletes the key when the last
// element is removed.
func (db *RedisDB) listLpop(k string) string {
	l := db.listKeys[k]
	el := l[0]
	l = l[1:]
	if len(l) == 0 {
		db.del(k, true)
	} else {
		db.listKeys[k] = l
	}
	db.keyVersion[k]++
	return el
}
// listPush appends values to the right of the list at k, creating the list
// if missing. Returns the new length.
func (db *RedisDB) listPush(k string, v ...string) int {
	l, ok := db.listKeys[k]
	if !ok {
		db.keys[k] = "list"
	}
	l = append(l, v...)
	db.listKeys[k] = l
	db.keyVersion[k]++
	return len(l)
}
// listPop is 'right pop': removes and returns the last element. The caller
// must ensure the list exists and is non-empty. Deletes the key when the
// last element is removed.
func (db *RedisDB) listPop(k string) string {
	l := db.listKeys[k]
	el := l[len(l)-1]
	l = l[:len(l)-1]
	if len(l) == 0 {
		db.del(k, true)
	} else {
		db.listKeys[k] = l
		db.keyVersion[k]++
	}
	return el
}
// setSet replaces a whole set at key k.
func (db *RedisDB) setSet(k string, set setKey) {
	db.keys[k] = "set"
	db.setKeys[k] = set
	db.keyVersion[k]++
}
// setAdd adds members to a set, creating it if missing. Returns the number
// of members that were actually new.
func (db *RedisDB) setAdd(k string, elems ...string) int {
	s, ok := db.setKeys[k]
	if !ok {
		s = setKey{}
		db.keys[k] = "set"
	}
	added := 0
	for _, e := range elems {
		if _, ok := s[e]; !ok {
			added++
		}
		s[e] = struct{}{}
	}
	db.setKeys[k] = s
	db.keyVersion[k]++
	return added
}
// setRem removes members from a set. Returns the number of members actually
// deleted. Deletes the key entirely when the set becomes empty.
func (db *RedisDB) setRem(k string, fields ...string) int {
	s, ok := db.setKeys[k]
	if !ok {
		return 0
	}
	removed := 0
	for _, f := range fields {
		if _, ok := s[f]; ok {
			removed++
			delete(s, f)
		}
	}
	if len(s) == 0 {
		db.del(k, true)
	} else {
		db.setKeys[k] = s
	}
	db.keyVersion[k]++
	return removed
}
// setMembers returns all members of a set, sorted for determinism.
func (db *RedisDB) setMembers(k string) []string {
	set := db.setKeys[k]
	members := make([]string, 0, len(set))
	for k := range set {
		members = append(members, k)
	}
	sort.Strings(members)
	return members
}
// setIsMember reports whether v is a member of the set at k
// (false for a missing key).
func (db *RedisDB) setIsMember(k, v string) bool {
	set, ok := db.setKeys[k]
	if !ok {
		return false
	}
	_, ok = set[v]
	return ok
}
// hashFields returns all (sorted) keys ('fields') for a hash key.
func (db *RedisDB) hashFields(k string) []string {
	v := db.hashKeys[k]
	r := make([]string, 0, len(v))
	for k := range v {
		r = append(r, k)
	}
	sort.Strings(r)
	return r
}
// hashGet returns the value of a hash field, or "" when either the key or
// the field is missing.
func (db *RedisDB) hashGet(key, field string) string {
	return db.hashKeys[key][field]
}
// hashSet sets field f of hash k to v and returns whether the field already
// existed. A value of another type at k is silently replaced.
func (db *RedisDB) hashSet(k, f, v string) bool {
	if t, ok := db.keys[k]; ok && t != "hash" {
		db.del(k, true)
	}
	db.keys[k] = "hash"
	if _, ok := db.hashKeys[k]; !ok {
		db.hashKeys[k] = map[string]string{}
	}
	_, ok := db.hashKeys[k][f]
	db.hashKeys[k][f] = v
	db.keyVersion[k]++
	return ok
}
// hashIncr changes an int hash field value by delta and returns the new
// value. A non-integer existing value yields ErrIntValueError; a missing
// key or field counts as 0.
func (db *RedisDB) hashIncr(key, field string, delta int) (int, error) {
	v := 0
	if h, ok := db.hashKeys[key]; ok {
		if f, ok := h[field]; ok {
			var err error
			v, err = strconv.Atoi(f)
			if err != nil {
				return 0, ErrIntValueError
			}
		}
	}
	v += delta
	db.hashSet(key, field, strconv.Itoa(v))
	return v, nil
}
// hashIncrfloat changes a float hash field value by delta and returns the
// new value. A non-float existing value yields ErrFloatValueError; a missing
// key or field counts as 0.
func (db *RedisDB) hashIncrfloat(key, field string, delta float64) (float64, error) {
	v := 0.0
	if h, ok := db.hashKeys[key]; ok {
		if f, ok := h[field]; ok {
			var err error
			v, err = strconv.ParseFloat(f, 64)
			if err != nil {
				return 0, ErrFloatValueError
			}
		}
	}
	v += delta
	db.hashSet(key, field, formatFloat(v))
	return v, nil
}
// sortedSet returns the sorted set at key as a member→score map.
func (db *RedisDB) sortedSet(key string) map[string]float64 {
	ss := db.sortedsetKeys[key]
	return map[string]float64(ss)
}
// ssetSet sets a complete sorted set at key, replacing any previous value.
func (db *RedisDB) ssetSet(key string, sset sortedSet) {
	db.keys[key] = "zset"
	db.keyVersion[key]++
	db.sortedsetKeys[key] = sset
}
// ssetAdd adds member to a sorted set (creating it if missing), updating the
// score when the member exists. Returns whether this was a new member.
func (db *RedisDB) ssetAdd(key string, score float64, member string) bool {
	ss, ok := db.sortedsetKeys[key]
	if !ok {
		ss = newSortedSet()
		db.keys[key] = "zset"
	}
	_, ok = ss[member]
	ss[member] = score
	db.sortedsetKeys[key] = ss
	db.keyVersion[key]++
	return !ok
}
// ssetMembers returns all members from a sorted set, ordered by score
// (ascending). Returns nil for a missing key.
func (db *RedisDB) ssetMembers(key string) []string {
	ss, ok := db.sortedsetKeys[key]
	if !ok {
		return nil
	}
	elems := ss.byScore(asc)
	members := make([]string, 0, len(elems))
	for _, e := range elems {
		members = append(members, e.member)
	}
	return members
}
// ssetElements returns all members+scores from a sorted set, ordered by
// score (ascending). Returns nil for a missing key.
func (db *RedisDB) ssetElements(key string) ssElems {
	ss, ok := db.sortedsetKeys[key]
	if !ok {
		return nil
	}
	return ss.byScore(asc)
}
// ssetCard is the sorted set cardinality (0 for a missing key).
func (db *RedisDB) ssetCard(key string) int {
	ss := db.sortedsetKeys[key]
	return ss.card()
}
// ssetRank is the sorted set rank of member in direction d; the bool
// reports whether the member was found.
func (db *RedisDB) ssetRank(key, member string, d direction) (int, bool) {
	ss := db.sortedsetKeys[key]
	return ss.rankByScore(member, d)
}
// ssetScore is the sorted set score of member (zero value when missing).
func (db *RedisDB) ssetScore(key, member string) float64 {
	ss := db.sortedsetKeys[key]
	return ss[member]
}
// ssetRem deletes member from the sorted set at key and reports whether it
// existed. Deletes the key entirely when the last member is removed.
// NOTE(review): unlike setRem, keyVersion is not bumped when the set stays
// non-empty — confirm whether WATCH should notice ZREM.
func (db *RedisDB) ssetRem(key, member string) bool {
	ss := db.sortedsetKeys[key]
	_, ok := ss[member]
	delete(ss, member)
	if len(ss) == 0 {
		// Delete key on removal of last member
		db.del(key, true)
	}
	return ok
}
// ssetExists tells if a member exists in a sorted set.
func (db *RedisDB) ssetExists(key, member string) bool {
	ss := db.sortedsetKeys[key]
	_, ok := ss[member]
	return ok
}
// ssetIncrby changes a sorted set score by delta, creating the set (and the
// member, starting from 0) when missing. Returns the new score.
func (db *RedisDB) ssetIncrby(k, m string, delta float64) float64 {
	ss, ok := db.sortedsetKeys[k]
	if !ok {
		ss = newSortedSet()
		db.keys[k] = "zset"
		db.sortedsetKeys[k] = ss
	}

	v, _ := ss.get(m)
	v += delta
	ss.set(v, m)
	db.keyVersion[k]++
	return v
}
// setDiff implements the logic behind SDIFF*: members of the first key minus
// members of the rest. Missing keys count as empty sets; a non-set value at
// any key yields ErrWrongType.
func (db *RedisDB) setDiff(keys []string) (setKey, error) {
	key := keys[0]
	keys = keys[1:]
	if db.exists(key) && db.t(key) != "set" {
		return nil, ErrWrongType
	}
	// Copy the first set so the stored one isn't mutated.
	s := setKey{}
	for k := range db.setKeys[key] {
		s[k] = struct{}{}
	}
	for _, sk := range keys {
		if !db.exists(sk) {
			continue
		}
		if db.t(sk) != "set" {
			return nil, ErrWrongType
		}
		for e := range db.setKeys[sk] {
			delete(s, e)
		}
	}
	return s, nil
}
// setInter implements the logic behind SINTER*: the intersection of all
// given sets. Any missing key short-circuits to an empty set; a non-set
// value at any key yields ErrWrongType.
func (db *RedisDB) setInter(keys []string) (setKey, error) {
	key := keys[0]
	keys = keys[1:]
	if !db.exists(key) {
		return setKey{}, nil
	}
	if db.t(key) != "set" {
		return nil, ErrWrongType
	}
	// Copy the first set so the stored one isn't mutated.
	s := setKey{}
	for k := range db.setKeys[key] {
		s[k] = struct{}{}
	}
	for _, sk := range keys {
		if !db.exists(sk) {
			return setKey{}, nil
		}
		if db.t(sk) != "set" {
			return nil, ErrWrongType
		}
		other := db.setKeys[sk]
		for e := range s {
			if _, ok := other[e]; ok {
				continue
			}
			delete(s, e)
		}
	}
	return s, nil
}
// setUnion implements the logic behind SUNION*: the union of all given sets.
// Missing keys count as empty sets; a non-set value at any key yields
// ErrWrongType.
func (db *RedisDB) setUnion(keys []string) (setKey, error) {
	key := keys[0]
	keys = keys[1:]
	if db.exists(key) && db.t(key) != "set" {
		return nil, ErrWrongType
	}
	// Copy the first set so the stored one isn't mutated.
	s := setKey{}
	for k := range db.setKeys[key] {
		s[k] = struct{}{}
	}
	for _, sk := range keys {
		if !db.exists(sk) {
			continue
		}
		if db.t(sk) != "set" {
			return nil, ErrWrongType
		}
		for e := range db.setKeys[sk] {
			s[e] = struct{}{}
		}
	}
	return s, nil
}
// fastForward proceeds the current timestamp with duration, works as a time
// machine: every key's remaining TTL is decreased and expired keys are
// deleted via checkTTL.
func (db *RedisDB) fastForward(duration time.Duration) {
	for _, key := range db.allKeys() {
		if value, ok := db.ttl[key]; ok {
			db.ttl[key] = value - duration
			db.checkTTL(key)
		}
	}
}
// checkTTL deletes key when its TTL has run out (<= 0). No-op for keys
// without a TTL.
func (db *RedisDB) checkTTL(key string) {
	if v, ok := db.ttl[key]; ok && v <= 0 {
		db.del(key, true)
	}
}

549
vendor/github.com/alicebob/miniredis/direct.go generated vendored Normal file
View file

@ -0,0 +1,549 @@
package miniredis
// Commands to modify and query our databases directly.
import (
"errors"
"time"
)
var (
	// ErrKeyNotFound is returned when a key doesn't exist.
	ErrKeyNotFound = errors.New(msgKeyNotFound)

	// ErrWrongType is returned when a key is not the right type.
	ErrWrongType = errors.New(msgWrongType)

	// ErrIntValueError can be returned by INCRBY.
	ErrIntValueError = errors.New(msgInvalidInt)

	// ErrFloatValueError can be returned by INCRBYFLOAT.
	ErrFloatValueError = errors.New(msgInvalidFloat)
)
// Select sets the DB id for all direct commands.
func (m *Miniredis) Select(i int) {
	m.Lock()
	defer m.Unlock()
	m.selectedDB = i
}
// Keys returns all keys from the selected database, sorted.
func (m *Miniredis) Keys() []string {
	return m.DB(m.selectedDB).Keys()
}
// Keys returns all keys, sorted. Locks the shared master lock.
func (db *RedisDB) Keys() []string {
	db.master.Lock()
	defer db.master.Unlock()
	return db.allKeys()
}
// FlushAll removes all keys from all databases.
func (m *Miniredis) FlushAll() {
	m.Lock()
	defer m.Unlock()
	m.flushAll()
}
// flushAll empties every database. Caller must hold m's lock.
func (m *Miniredis) flushAll() {
	for _, db := range m.dbs {
		db.flush()
	}
}
// FlushDB removes all keys from the selected database.
func (m *Miniredis) FlushDB() {
	m.DB(m.selectedDB).FlushDB()
}
// FlushDB removes all keys. Locks the shared master lock.
func (db *RedisDB) FlushDB() {
	db.master.Lock()
	defer db.master.Unlock()
	db.flush()
}
// Get returns string keys added with SET.
func (m *Miniredis) Get(k string) (string, error) {
	return m.DB(m.selectedDB).Get(k)
}
// Get returns a string key.
func (db *RedisDB) Get(k string) (string, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return "", ErrKeyNotFound
}
if db.t(k) != "string" {
return "", ErrWrongType
}
return db.stringGet(k), nil
}
// Set sets a string key. Removes expire.
func (m *Miniredis) Set(k, v string) error {
return m.DB(m.selectedDB).Set(k, v)
}
// Set sets a string key. Removes expire.
// Unlike redis the key can't be an existing non-string key.
func (db *RedisDB) Set(k, v string) error {
db.master.Lock()
defer db.master.Unlock()
if db.exists(k) && db.t(k) != "string" {
return ErrWrongType
}
db.del(k, true) // Remove expire
db.stringSet(k, v)
return nil
}
// Incr changes a int string value by delta.
func (m *Miniredis) Incr(k string, delta int) (int, error) {
return m.DB(m.selectedDB).Incr(k, delta)
}
// Incr changes a int string value by delta.
func (db *RedisDB) Incr(k string, delta int) (int, error) {
db.master.Lock()
defer db.master.Unlock()
if db.exists(k) && db.t(k) != "string" {
return 0, ErrWrongType
}
return db.stringIncr(k, delta)
}
// Incrfloat changes a float string value by delta.
func (m *Miniredis) Incrfloat(k string, delta float64) (float64, error) {
return m.DB(m.selectedDB).Incrfloat(k, delta)
}
// Incrfloat changes a float string value by delta.
func (db *RedisDB) Incrfloat(k string, delta float64) (float64, error) {
db.master.Lock()
defer db.master.Unlock()
if db.exists(k) && db.t(k) != "string" {
return 0, ErrWrongType
}
return db.stringIncrfloat(k, delta)
}
// List returns the list k, or an error if it's not there or something else.
// This is the same as the Redis command `LRANGE 0 -1`, but you can do your own
// range-ing.
func (m *Miniredis) List(k string) ([]string, error) {
return m.DB(m.selectedDB).List(k)
}
// List returns the list k, or an error if it's not there or something else.
// This is the same as the Redis command `LRANGE 0 -1`, but you can do your own
// range-ing.
func (db *RedisDB) List(k string) ([]string, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return nil, ErrKeyNotFound
}
if db.t(k) != "list" {
return nil, ErrWrongType
}
return db.listKeys[k], nil
}
// Lpush is an unshift. Returns the new length.
func (m *Miniredis) Lpush(k, v string) (int, error) {
return m.DB(m.selectedDB).Lpush(k, v)
}
// Lpush is an unshift. Returns the new length.
func (db *RedisDB) Lpush(k, v string) (int, error) {
db.master.Lock()
defer db.master.Unlock()
if db.exists(k) && db.t(k) != "list" {
return 0, ErrWrongType
}
return db.listLpush(k, v), nil
}
// Lpop is a shift. Returns the popped element.
func (m *Miniredis) Lpop(k string) (string, error) {
return m.DB(m.selectedDB).Lpop(k)
}
// Lpop is a shift. Returns the popped element.
func (db *RedisDB) Lpop(k string) (string, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return "", ErrKeyNotFound
}
if db.t(k) != "list" {
return "", ErrWrongType
}
return db.listLpop(k), nil
}
// Push add element at the end. Is called RPUSH in redis. Returns the new length.
func (m *Miniredis) Push(k string, v ...string) (int, error) {
return m.DB(m.selectedDB).Push(k, v...)
}
// Push add element at the end. Is called RPUSH in redis. Returns the new length.
func (db *RedisDB) Push(k string, v ...string) (int, error) {
db.master.Lock()
defer db.master.Unlock()
if db.exists(k) && db.t(k) != "list" {
return 0, ErrWrongType
}
return db.listPush(k, v...), nil
}
// Pop removes and returns the last element. Is called RPOP in Redis.
func (m *Miniredis) Pop(k string) (string, error) {
return m.DB(m.selectedDB).Pop(k)
}
// Pop removes and returns the last element. Is called RPOP in Redis.
func (db *RedisDB) Pop(k string) (string, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return "", ErrKeyNotFound
}
if db.t(k) != "list" {
return "", ErrWrongType
}
return db.listPop(k), nil
}
// SetAdd adds keys to a set. Returns the number of new keys.
func (m *Miniredis) SetAdd(k string, elems ...string) (int, error) {
return m.DB(m.selectedDB).SetAdd(k, elems...)
}
// SetAdd adds keys to a set. Returns the number of new keys.
func (db *RedisDB) SetAdd(k string, elems ...string) (int, error) {
db.master.Lock()
defer db.master.Unlock()
if db.exists(k) && db.t(k) != "set" {
return 0, ErrWrongType
}
return db.setAdd(k, elems...), nil
}
// Members gives all set keys. Sorted.
func (m *Miniredis) Members(k string) ([]string, error) {
return m.DB(m.selectedDB).Members(k)
}
// Members gives all set keys. Sorted.
func (db *RedisDB) Members(k string) ([]string, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return nil, ErrKeyNotFound
}
if db.t(k) != "set" {
return nil, ErrWrongType
}
return db.setMembers(k), nil
}
// IsMember tells if value is in the set.
func (m *Miniredis) IsMember(k, v string) (bool, error) {
return m.DB(m.selectedDB).IsMember(k, v)
}
// IsMember tells if value is in the set.
func (db *RedisDB) IsMember(k, v string) (bool, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return false, ErrKeyNotFound
}
if db.t(k) != "set" {
return false, ErrWrongType
}
return db.setIsMember(k, v), nil
}
// HKeys returns all (sorted) keys ('fields') for a hash key.
func (m *Miniredis) HKeys(k string) ([]string, error) {
return m.DB(m.selectedDB).HKeys(k)
}
// HKeys returns all (sorted) keys ('fields') for a hash key.
func (db *RedisDB) HKeys(key string) ([]string, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(key) {
return nil, ErrKeyNotFound
}
if db.t(key) != "hash" {
return nil, ErrWrongType
}
return db.hashFields(key), nil
}
// Del deletes a key and any expiration value. Returns whether there was a key.
func (m *Miniredis) Del(k string) bool {
return m.DB(m.selectedDB).Del(k)
}
// Del deletes a key and any expiration value. Returns whether there was a key.
func (db *RedisDB) Del(k string) bool {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return false
}
db.del(k, true)
return true
}
// TTL is the left over time to live. As set via EXPIRE, PEXPIRE, EXPIREAT,
// PEXPIREAT.
// 0 if not set.
func (m *Miniredis) TTL(k string) time.Duration {
return m.DB(m.selectedDB).TTL(k)
}
// TTL is the left over time to live. As set via EXPIRE, PEXPIRE, EXPIREAT,
// PEXPIREAT.
// 0 if not set.
func (db *RedisDB) TTL(k string) time.Duration {
db.master.Lock()
defer db.master.Unlock()
return db.ttl[k]
}
// SetTTL sets the TTL of a key.
func (m *Miniredis) SetTTL(k string, ttl time.Duration) {
m.DB(m.selectedDB).SetTTL(k, ttl)
}
// SetTTL sets the time to live of a key.
func (db *RedisDB) SetTTL(k string, ttl time.Duration) {
db.master.Lock()
defer db.master.Unlock()
db.ttl[k] = ttl
db.keyVersion[k]++
}
// Type gives the type of a key, or ""
func (m *Miniredis) Type(k string) string {
return m.DB(m.selectedDB).Type(k)
}
// Type gives the type of a key, or ""
func (db *RedisDB) Type(k string) string {
db.master.Lock()
defer db.master.Unlock()
return db.t(k)
}
// Exists tells whether a key exists.
func (m *Miniredis) Exists(k string) bool {
return m.DB(m.selectedDB).Exists(k)
}
// Exists tells whether a key exists.
func (db *RedisDB) Exists(k string) bool {
db.master.Lock()
defer db.master.Unlock()
return db.exists(k)
}
// HGet returns hash keys added with HSET.
// This will return an empty string if the key is not set. Redis would return
// a nil.
// Returns empty string when the key is of a different type.
func (m *Miniredis) HGet(k, f string) string {
return m.DB(m.selectedDB).HGet(k, f)
}
// HGet returns hash keys added with HSET.
// Returns empty string when the key is of a different type.
func (db *RedisDB) HGet(k, f string) string {
	db.master.Lock()
	defer db.master.Unlock()
	// A missing key, a missing field, and a key of another type are
	// indistinguishable here: all yield "".
	h, ok := db.hashKeys[k]
	if !ok {
		return ""
	}
	return h[f]
}
// HSet sets a hash key.
// If there is another key by the same name it will be gone.
func (m *Miniredis) HSet(k, f, v string) {
m.DB(m.selectedDB).HSet(k, f, v)
}
// HSet sets a hash key.
// If there is another key by the same name it will be gone.
func (db *RedisDB) HSet(k, f, v string) {
db.master.Lock()
defer db.master.Unlock()
db.hashSet(k, f, v)
}
// HDel deletes a hash key.
func (m *Miniredis) HDel(k, f string) {
m.DB(m.selectedDB).HDel(k, f)
}
// HDel deletes a hash key.
func (db *RedisDB) HDel(k, f string) {
db.master.Lock()
defer db.master.Unlock()
db.hdel(k, f)
}
// hdel removes a single field from hash k and bumps the key version.
// It is a no-op (no version bump) when the hash itself doesn't exist;
// deleting a missing field from an existing hash still bumps the version.
func (db *RedisDB) hdel(k, f string) {
	if _, ok := db.hashKeys[k]; !ok {
		return
	}
	delete(db.hashKeys[k], f)
	db.keyVersion[k]++
}
// HIncr increases a key/field by delta (int).
func (m *Miniredis) HIncr(k, f string, delta int) (int, error) {
return m.DB(m.selectedDB).HIncr(k, f, delta)
}
// HIncr increases a key/field by delta (int).
func (db *RedisDB) HIncr(k, f string, delta int) (int, error) {
db.master.Lock()
defer db.master.Unlock()
return db.hashIncr(k, f, delta)
}
// HIncrfloat increases a key/field by delta (float).
func (m *Miniredis) HIncrfloat(k, f string, delta float64) (float64, error) {
return m.DB(m.selectedDB).HIncrfloat(k, f, delta)
}
// HIncrfloat increases a key/field by delta (float).
func (db *RedisDB) HIncrfloat(k, f string, delta float64) (float64, error) {
db.master.Lock()
defer db.master.Unlock()
return db.hashIncrfloat(k, f, delta)
}
// SRem removes fields from a set. Returns number of deleted fields.
func (m *Miniredis) SRem(k string, fields ...string) (int, error) {
return m.DB(m.selectedDB).SRem(k, fields...)
}
// SRem removes fields from a set. Returns number of deleted fields.
func (db *RedisDB) SRem(k string, fields ...string) (int, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return 0, ErrKeyNotFound
}
if db.t(k) != "set" {
return 0, ErrWrongType
}
return db.setRem(k, fields...), nil
}
// ZAdd adds a score,member to a sorted set.
func (m *Miniredis) ZAdd(k string, score float64, member string) (bool, error) {
return m.DB(m.selectedDB).ZAdd(k, score, member)
}
// ZAdd adds a score,member to a sorted set.
func (db *RedisDB) ZAdd(k string, score float64, member string) (bool, error) {
db.master.Lock()
defer db.master.Unlock()
if db.exists(k) && db.t(k) != "zset" {
return false, ErrWrongType
}
return db.ssetAdd(k, score, member), nil
}
// ZMembers returns all members by score
func (m *Miniredis) ZMembers(k string) ([]string, error) {
return m.DB(m.selectedDB).ZMembers(k)
}
// ZMembers returns all members by score
func (db *RedisDB) ZMembers(k string) ([]string, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return nil, ErrKeyNotFound
}
if db.t(k) != "zset" {
return nil, ErrWrongType
}
return db.ssetMembers(k), nil
}
// SortedSet returns a raw string->float64 map.
func (m *Miniredis) SortedSet(k string) (map[string]float64, error) {
return m.DB(m.selectedDB).SortedSet(k)
}
// SortedSet returns a raw string->float64 map.
func (db *RedisDB) SortedSet(k string) (map[string]float64, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return nil, ErrKeyNotFound
}
if db.t(k) != "zset" {
return nil, ErrWrongType
}
return db.sortedSet(k), nil
}
// ZRem deletes a member. Returns whether the was a key.
func (m *Miniredis) ZRem(k, member string) (bool, error) {
return m.DB(m.selectedDB).ZRem(k, member)
}
// ZRem deletes a member. Returns whether the was a key.
func (db *RedisDB) ZRem(k, member string) (bool, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return false, ErrKeyNotFound
}
if db.t(k) != "zset" {
return false, ErrWrongType
}
return db.ssetRem(k, member), nil
}
// ZScore gives the score of a sorted set member.
func (m *Miniredis) ZScore(k, member string) (float64, error) {
return m.DB(m.selectedDB).ZScore(k, member)
}
// ZScore gives the score of a sorted set member.
func (db *RedisDB) ZScore(k, member string) (float64, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return 0, ErrKeyNotFound
}
if db.t(k) != "zset" {
return 0, ErrWrongType
}
return db.ssetScore(k, member), nil
}

65
vendor/github.com/alicebob/miniredis/keys.go generated vendored Normal file
View file

@ -0,0 +1,65 @@
package miniredis
// Translate the 'KEYS' argument ('foo*', 'f??', &c.) into a regexp.
import (
"bytes"
"regexp"
)
// patternRE translates a Redis KEYS-style glob ('foo*', 'f??', '[a-z]', ...)
// into an anchored regexp. It returns nil when the pattern can never match
// anything (an empty char class, or a dangling trailing backslash).
// Literal runs are protected by sandwiching them in \Q...\E quoting.
func patternRE(k string) *regexp.Regexp {
	var out bytes.Buffer
	out.WriteString(`^\Q`)
	i := 0
	for i < len(k) {
		c := k[i]
		switch c {
		case '*':
			out.WriteString(`\E.*\Q`)
		case '?':
			out.WriteString(`\E.\Q`)
		case '[':
			var class bytes.Buffer
			i++
			for i < len(k) && k[i] != ']' {
				if k[i] == '\\' {
					if i == len(k)-1 {
						// Dangling backslash inside the class: unmatchable.
						return nil
					}
					class.WriteByte(k[i])
					i++
				}
				class.WriteByte(k[i])
				i++
			}
			if class.Len() == 0 {
				// '[]' is valid in Redis, but matches nothing.
				return nil
			}
			out.WriteString(`\E[`)
			out.Write(class.Bytes())
			out.WriteString(`]\Q`)
		case '\\':
			if i == len(k)-1 {
				// Dangling backslash: unmatchable.
				return nil
			}
			// Drop the backslash, keep the escaped character literally.
			i++
			out.WriteByte(k[i])
		default:
			out.WriteByte(c)
		}
		i++
	}
	out.WriteString(`\E$`)
	return regexp.MustCompile(out.String())
}

189
vendor/github.com/alicebob/miniredis/lua.go generated vendored Normal file
View file

@ -0,0 +1,189 @@
package miniredis
import (
redigo "github.com/gomodule/redigo/redis"
"github.com/yuin/gopher-lua"
"github.com/alicebob/miniredis/server"
)
func mkLuaFuncs(conn redigo.Conn) map[string]lua.LGFunction {
mkCall := func(failFast bool) func(l *lua.LState) int {
return func(l *lua.LState) int {
top := l.GetTop()
if top == 0 {
l.Error(lua.LString("Please specify at least one argument for redis.call()"), 1)
return 0
}
var args []interface{}
for i := 1; i <= top; i++ {
switch a := l.Get(i).(type) {
// case lua.LBool:
// args[i-2] = a
case lua.LNumber:
// value, _ := strconv.ParseFloat(lua.LVAsString(arg), 64)
args = append(args, float64(a))
case lua.LString:
args = append(args, string(a))
default:
l.Error(lua.LString("Lua redis() command arguments must be strings or integers"), 1)
return 0
}
}
cmd, ok := args[0].(string)
if !ok {
l.Error(lua.LString("Unknown Redis command called from Lua script"), 1)
return 0
}
res, err := conn.Do(cmd, args[1:]...)
if err != nil {
if failFast {
// call() mode
l.Error(lua.LString(err.Error()), 1)
return 0
}
// pcall() mode
l.Push(lua.LNil)
return 1
}
if res == nil {
l.Push(lua.LFalse)
} else {
switch r := res.(type) {
case int64:
l.Push(lua.LNumber(r))
case []uint8:
l.Push(lua.LString(string(r)))
case []interface{}:
l.Push(redisToLua(l, r))
case string:
l.Push(lua.LString(r))
default:
panic("type not handled")
}
}
return 1
}
}
return map[string]lua.LGFunction{
"call": mkCall(true),
"pcall": mkCall(false),
"error_reply": func(l *lua.LState) int {
msg := l.CheckString(1)
res := &lua.LTable{}
res.RawSetString("err", lua.LString(msg))
l.Push(res)
return 1
},
"status_reply": func(l *lua.LState) int {
msg := l.CheckString(1)
res := &lua.LTable{}
res.RawSetString("ok", lua.LString(msg))
l.Push(res)
return 1
},
"sha1hex": func(l *lua.LState) int {
top := l.GetTop()
if top != 1 {
l.Error(lua.LString("wrong number of arguments"), 1)
return 0
}
msg := lua.LVAsString(l.Get(1))
l.Push(lua.LString(sha1Hex(msg)))
return 1
},
"replicate_commands": func(l *lua.LState) int {
// ignored
return 1
},
}
}
// luaToRedis writes a Lua value back to the client using the Redis reply
// conventions used by EVAL: booleans become 1/nil, numbers are truncated to
// integers, tables are arrays (with special cases for {err=...}/{ok=...}).
func luaToRedis(l *lua.LState, c *server.Peer, value lua.LValue) {
	if value == nil {
		c.WriteNull()
		return
	}
	switch t := value.(type) {
	case *lua.LNilType:
		c.WriteNull()
	case lua.LBool:
		// There is no boolean reply: true -> integer 1, false -> nil.
		if lua.LVAsBool(value) {
			c.WriteInt(1)
		} else {
			c.WriteNull()
		}
	case lua.LNumber:
		// Numbers lose their fractional part on the way out.
		c.WriteInt(int(lua.LVAsNumber(value)))
	case lua.LString:
		s := lua.LVAsString(value)
		if s == "OK" {
			c.WriteInline(s)
		} else {
			c.WriteBulk(s)
		}
	case *lua.LTable:
		// special case for tables with an 'err' or 'ok' field
		// note: according to the docs this only counts when 'err' or 'ok' is
		// the only field.
		if s := t.RawGetString("err"); s.Type() != lua.LTNil {
			c.WriteError(s.String())
			return
		}
		if s := t.RawGetString("ok"); s.Type() != lua.LTNil {
			c.WriteInline(s.String())
			return
		}
		// Treat the table as a 1-based array; stop at the first nil hole.
		result := []lua.LValue{}
		for j := 1; true; j++ {
			val := l.GetTable(value, lua.LNumber(j))
			if val == nil {
				// NOTE(review): GetTable is not expected to return a Go nil
				// (absent entries come back as lua.LTNil), so this branch
				// looks unreachable — confirm against gopher-lua docs.
				result = append(result, val)
				continue
			}
			if val.Type() == lua.LTNil {
				break
			}
			result = append(result, val)
		}
		c.WriteLen(len(result))
		for _, r := range result {
			luaToRedis(l, c, r)
		}
	default:
		panic("....")
	}
}
// redisToLua converts a Redis multi-bulk reply into a 1-based Lua table.
// Nil elements map to boolean false, mirroring the Redis Lua bridge.
func redisToLua(l *lua.LState, res []interface{}) *lua.LTable {
	rettb := l.NewTable()
	for _, e := range res {
		var v lua.LValue
		if e == nil {
			// Nil replies surface as false in Lua.
			v = lua.LFalse
		} else {
			switch et := e.(type) {
			case int64:
				v = lua.LNumber(et)
			case []uint8:
				v = lua.LString(string(et))
			case []interface{}:
				// Nested multi-bulk replies become nested tables.
				v = redisToLua(l, et)
			case string:
				v = lua.LString(et)
			default:
				// TODO: oops? (this assertion panics if e is not a string)
				v = lua.LString(e.(string))
			}
		}
		l.RawSet(rettb, lua.LNumber(rettb.Len()+1), v)
	}
	return rettb
}

373
vendor/github.com/alicebob/miniredis/miniredis.go generated vendored Normal file
View file

@ -0,0 +1,373 @@
// Package miniredis is a pure Go Redis test server, for use in Go unittests.
// There are no dependencies on system binaries, and every server you start
// will be empty.
//
// Start a server with `s, err := miniredis.Run()`.
// Stop it with `defer s.Close()`.
//
// Point your Redis client to `s.Addr()` or `s.Host(), s.Port()`.
//
// Set keys directly via s.Set(...) and similar commands, or use a Redis client.
//
// For direct use you can select a Redis database with either `s.Select(12);
// s.Get("foo")` or `s.DB(12).Get("foo")`.
//
package miniredis
import (
"fmt"
"net"
"strconv"
"sync"
"time"
redigo "github.com/gomodule/redigo/redis"
"github.com/alicebob/miniredis/server"
)
type hashKey map[string]string
type listKey []string
type setKey map[string]struct{}
// RedisDB holds a single (numbered) Redis database.
type RedisDB struct {
master *sync.Mutex // pointer to the lock in Miniredis
id int // db id
keys map[string]string // Master map of keys with their type
stringKeys map[string]string // GET/SET &c. keys
hashKeys map[string]hashKey // MGET/MSET &c. keys
listKeys map[string]listKey // LPUSH &c. keys
setKeys map[string]setKey // SADD &c. keys
sortedsetKeys map[string]sortedSet // ZADD &c. keys
ttl map[string]time.Duration // effective TTL values
keyVersion map[string]uint // used to watch values
}
// Miniredis is a Redis server implementation.
type Miniredis struct {
sync.Mutex
srv *server.Server
port int
password string
dbs map[int]*RedisDB
selectedDB int // DB id used in the direct Get(), Set() &c.
scripts map[string]string // sha1 -> lua src
signal *sync.Cond
now time.Time // used to make a duration from EXPIREAT. time.Now() if not set.
}
type txCmd func(*server.Peer, *connCtx)
// database id + key combo
type dbKey struct {
db int
key string
}
// connCtx has all state for a single connection.
type connCtx struct {
selectedDB int // selected DB
authenticated bool // auth enabled and a valid AUTH seen
transaction []txCmd // transaction callbacks. Or nil.
dirtyTransaction bool // any error during QUEUEing.
watch map[dbKey]uint // WATCHed keys.
}
// NewMiniRedis makes a new, non-started, Miniredis object.
func NewMiniRedis() *Miniredis {
m := Miniredis{
dbs: map[int]*RedisDB{},
scripts: map[string]string{},
}
m.signal = sync.NewCond(&m)
return &m
}
func newRedisDB(id int, l *sync.Mutex) RedisDB {
return RedisDB{
id: id,
master: l,
keys: map[string]string{},
stringKeys: map[string]string{},
hashKeys: map[string]hashKey{},
listKeys: map[string]listKey{},
setKeys: map[string]setKey{},
sortedsetKeys: map[string]sortedSet{},
ttl: map[string]time.Duration{},
keyVersion: map[string]uint{},
}
}
// Run creates and Start()s a Miniredis.
func Run() (*Miniredis, error) {
m := NewMiniRedis()
return m, m.Start()
}
// Start starts a server. It listens on a random port on localhost. See also
// Addr().
func (m *Miniredis) Start() error {
s, err := server.NewServer(fmt.Sprintf("127.0.0.1:%d", m.port))
if err != nil {
return err
}
return m.start(s)
}
// StartAddr runs miniredis with a given addr. Examples: "127.0.0.1:6379",
// ":6379", or "127.0.0.1:0"
func (m *Miniredis) StartAddr(addr string) error {
s, err := server.NewServer(addr)
if err != nil {
return err
}
return m.start(s)
}
func (m *Miniredis) start(s *server.Server) error {
m.Lock()
defer m.Unlock()
m.srv = s
m.port = s.Addr().Port
commandsConnection(m)
commandsGeneric(m)
commandsServer(m)
commandsString(m)
commandsHash(m)
commandsList(m)
commandsSet(m)
commandsSortedSet(m)
commandsTransaction(m)
commandsScripting(m)
return nil
}
// Restart restarts a Close()d server on the same port. Values will be
// preserved.
func (m *Miniredis) Restart() error {
return m.Start()
}
// Close shuts down a Miniredis.
func (m *Miniredis) Close() {
m.Lock()
defer m.Unlock()
if m.srv == nil {
return
}
m.srv.Close()
m.srv = nil
}
// RequireAuth makes every connection need to AUTH first. Disable again by
// setting an empty string.
func (m *Miniredis) RequireAuth(pw string) {
m.Lock()
defer m.Unlock()
m.password = pw
}
// DB returns a DB by ID.
func (m *Miniredis) DB(i int) *RedisDB {
m.Lock()
defer m.Unlock()
return m.db(i)
}
// get DB. No locks!
func (m *Miniredis) db(i int) *RedisDB {
if db, ok := m.dbs[i]; ok {
return db
}
db := newRedisDB(i, &m.Mutex) // the DB has our lock.
m.dbs[i] = &db
return &db
}
// Addr returns '127.0.0.1:12345'. Can be given to a Dial(). See also Host()
// and Port(), which return the same things.
func (m *Miniredis) Addr() string {
m.Lock()
defer m.Unlock()
return m.srv.Addr().String()
}
// Host returns the host part of Addr().
func (m *Miniredis) Host() string {
m.Lock()
defer m.Unlock()
return m.srv.Addr().IP.String()
}
// Port returns the (random) port part of Addr().
func (m *Miniredis) Port() string {
m.Lock()
defer m.Unlock()
return strconv.Itoa(m.srv.Addr().Port)
}
// CommandCount returns the number of processed commands.
func (m *Miniredis) CommandCount() int {
m.Lock()
defer m.Unlock()
return int(m.srv.TotalCommands())
}
// CurrentConnectionCount returns the number of currently connected clients.
func (m *Miniredis) CurrentConnectionCount() int {
m.Lock()
defer m.Unlock()
return m.srv.ClientsLen()
}
// TotalConnectionCount returns the number of client connections since server start.
func (m *Miniredis) TotalConnectionCount() int {
m.Lock()
defer m.Unlock()
return int(m.srv.TotalConnections())
}
// FastForward decreases all TTLs by the given duration. All TTLs <= 0 will be
// expired.
func (m *Miniredis) FastForward(duration time.Duration) {
m.Lock()
defer m.Unlock()
for _, db := range m.dbs {
db.fastForward(duration)
}
}
// redigo returns a redigo.Conn connected to this server over a net.Pipe.
// Used internally to run Lua scripts against the server itself.
func (m *Miniredis) redigo() redigo.Conn {
	c1, c2 := net.Pipe()
	m.srv.ServeConn(c1)
	c := redigo.NewConn(c2, 0, 0)
	if m.password != "" {
		// Best effort: a failed AUTH leaves the connection unauthenticated
		// and any later command will surface the error to the caller.
		_, _ = c.Do("AUTH", m.password)
	}
	return c
}
// Dump returns a text version of the selected DB, usable for debugging.
func (m *Miniredis) Dump() string {
m.Lock()
defer m.Unlock()
var (
maxLen = 60
indent = " "
db = m.db(m.selectedDB)
r = ""
v = func(s string) string {
suffix := ""
if len(s) > maxLen {
suffix = fmt.Sprintf("...(%d)", len(s))
s = s[:maxLen-len(suffix)]
}
return fmt.Sprintf("%q%s", s, suffix)
}
)
for _, k := range db.allKeys() {
r += fmt.Sprintf("- %s\n", k)
t := db.t(k)
switch t {
case "string":
r += fmt.Sprintf("%s%s\n", indent, v(db.stringKeys[k]))
case "hash":
for _, hk := range db.hashFields(k) {
r += fmt.Sprintf("%s%s: %s\n", indent, hk, v(db.hashGet(k, hk)))
}
case "list":
for _, lk := range db.listKeys[k] {
r += fmt.Sprintf("%s%s\n", indent, v(lk))
}
case "set":
for _, mk := range db.setMembers(k) {
r += fmt.Sprintf("%s%s\n", indent, v(mk))
}
case "zset":
for _, el := range db.ssetElements(k) {
r += fmt.Sprintf("%s%f: %s\n", indent, el.score, v(el.member))
}
default:
r += fmt.Sprintf("%s(a %s, fixme!)\n", indent, t)
}
}
return r
}
// SetTime sets the time against which EXPIREAT values are compared. EXPIREAT
// will use time.Now() if this is not set.
func (m *Miniredis) SetTime(t time.Time) {
m.Lock()
defer m.Unlock()
m.now = t
}
// handleAuth returns false if connection has no access. It sends the reply.
func (m *Miniredis) handleAuth(c *server.Peer) bool {
m.Lock()
defer m.Unlock()
if m.password == "" {
return true
}
if !getCtx(c).authenticated {
c.WriteError("NOAUTH Authentication required.")
return false
}
return true
}
// getCtx returns the per-connection state, lazily initializing it.
func getCtx(c *server.Peer) *connCtx {
	if c.Ctx == nil {
		c.Ctx = &connCtx{}
	}
	return c.Ctx.(*connCtx)
}

// startTx begins a MULTI block: an empty, clean command queue.
func startTx(ctx *connCtx) {
	ctx.transaction = []txCmd{}
	ctx.dirtyTransaction = false
}

// stopTx ends a MULTI block (EXEC/DISCARD) and drops any WATCHed keys.
func stopTx(ctx *connCtx) {
	ctx.transaction = nil
	unwatch(ctx)
}

// inTx reports whether a MULTI block is open on this connection.
func inTx(ctx *connCtx) bool {
	return ctx.transaction != nil
}

// addTxCmd queues a command callback for execution at EXEC time.
func addTxCmd(ctx *connCtx, cb txCmd) {
	ctx.transaction = append(ctx.transaction, cb)
}

// watch records the current version of a key so EXEC can detect changes.
func watch(db *RedisDB, ctx *connCtx, key string) {
	if ctx.watch == nil {
		ctx.watch = map[dbKey]uint{}
	}
	ctx.watch[dbKey{db: db.id, key: key}] = db.keyVersion[key] // Can be 0.
}

// unwatch forgets all WATCHed keys.
func unwatch(ctx *connCtx) {
	ctx.watch = nil
}

// setDirty can be called even when not in an tx. Is an no-op then.
func setDirty(c *server.Peer) {
	if c.Ctx == nil {
		// No transaction. Not relevant.
		return
	}
	getCtx(c).dirtyTransaction = true
}

// setAuthenticated marks this connection as having passed AUTH.
func setAuthenticated(c *server.Peer) {
	getCtx(c).authenticated = true
}

208
vendor/github.com/alicebob/miniredis/redis.go generated vendored Normal file
View file

@ -0,0 +1,208 @@
package miniredis
import (
"fmt"
"math"
"strings"
"sync"
"time"
"github.com/alicebob/miniredis/server"
)
const (
msgWrongType = "WRONGTYPE Operation against a key holding the wrong kind of value"
msgInvalidInt = "ERR value is not an integer or out of range"
msgInvalidFloat = "ERR value is not a valid float"
msgInvalidMinMax = "ERR min or max is not a float"
msgInvalidRangeItem = "ERR min or max not valid string range item"
msgInvalidTimeout = "ERR timeout is not an integer or out of range"
msgSyntaxError = "ERR syntax error"
msgKeyNotFound = "ERR no such key"
msgOutOfRange = "ERR index out of range"
msgInvalidCursor = "ERR invalid cursor"
msgXXandNX = "ERR XX and NX options at the same time are not compatible"
msgNegTimeout = "ERR timeout is negative"
msgInvalidSETime = "ERR invalid expire time in set"
msgInvalidSETEXTime = "ERR invalid expire time in setex"
msgInvalidPSETEXTime = "ERR invalid expire time in psetex"
msgInvalidKeysNumber = "ERR Number of keys can't be greater than number of args"
msgNegativeKeysNumber = "ERR Number of keys can't be negative"
msgFScriptUsage = "ERR Unknown subcommand or wrong number of arguments for '%s'. Try SCRIPT HELP."
msgSingleElementPair = "ERR INCR option supports a single increment-element pair"
msgNoScriptFound = "NOSCRIPT No matching script. Please use EVAL."
)
// errWrongNumber renders the standard Redis arity error for cmd.
// The command name is lowercased, matching real Redis output.
func errWrongNumber(cmd string) string {
	lower := strings.ToLower(cmd)
	return fmt.Sprintf("ERR wrong number of arguments for '%s' command", lower)
}
func errLuaParseError(err error) string {
return fmt.Sprintf("ERR Error compiling script (new function): %s", err.Error())
}
// withTx wraps the non-argument-checking part of command handling code in
// transaction logic.
//
// When the connection is inside a MULTI block the callback is queued (and
// "QUEUED" is written back) instead of being executed; otherwise it runs
// immediately under the server lock.
func withTx(
	m *Miniredis,
	c *server.Peer,
	cb txCmd,
) {
	ctx := getCtx(c)
	if inTx(ctx) {
		// MULTI in progress: defer execution until EXEC.
		addTxCmd(ctx, cb)
		c.WriteInline("QUEUED")
		return
	}
	m.Lock()
	cb(c, ctx)
	// done, wake up anyone who waits on anything.
	m.signal.Broadcast()
	m.Unlock()
}
// blockCmd is executed returns whether it is done
type blockCmd func(*server.Peer, *connCtx) bool

// blocking keeps trying a command until the callback returns true. Calls
// onTimeout after the timeout (or when we call this in a transaction).
func blocking(
	m *Miniredis,
	c *server.Peer,
	timeout time.Duration,
	cb blockCmd,
	onTimeout func(*server.Peer),
) {
	var (
		ctx = getCtx(c)
		dl  *time.Timer
		dlc <-chan time.Time
	)
	if inTx(ctx) {
		// Inside MULTI a blocking command cannot block: queue it, and at
		// EXEC time fall straight through to the timeout path if it can't
		// complete on the first try.
		addTxCmd(ctx, func(c *server.Peer, ctx *connCtx) {
			if !cb(c, ctx) {
				onTimeout(c)
			}
		})
		c.WriteInline("QUEUED")
		return
	}
	if timeout != 0 {
		// timeout == 0 means "block forever": dlc stays nil so its select
		// case can never fire.
		dl = time.NewTimer(timeout)
		defer dl.Stop()
		dlc = dl.C
	}
	m.Lock()
	defer m.Unlock()
	for {
		done := cb(c, ctx)
		if done {
			return
		}
		// there is no cond.WaitTimeout(), hence the goroutine to wait
		// for a timeout
		var (
			wg     sync.WaitGroup
			wakeup = make(chan struct{}, 1)
		)
		wg.Add(1)
		go func() {
			// Wait releases m while blocked; a Broadcast from any writer
			// (or from the timeout branch below) wakes it up.
			m.signal.Wait()
			wakeup <- struct{}{}
			wg.Done()
		}()
		select {
		case <-wakeup:
		case <-dlc:
			onTimeout(c)
			m.signal.Broadcast() // to kill the wakeup go routine
			wg.Wait()
			return
		}
		wg.Wait()
	}
}
// formatFloat formats a float the way redis does (sort-of)
func formatFloat(v float64) string {
// Format with %f and strip trailing 0s. This is the most like Redis does
// it :(
// .12 is the magic number where most output is the same as Redis.
if math.IsInf(v, +1) {
return "inf"
}
if math.IsInf(v, -1) {
return "-inf"
}
sv := fmt.Sprintf("%.12f", v)
for strings.Contains(sv, ".") {
if sv[len(sv)-1] != '0' {
break
}
// Remove trailing 0s.
sv = sv[:len(sv)-1]
// Ends with a '.'.
if sv[len(sv)-1] == '.' {
sv = sv[:len(sv)-1]
break
}
}
return sv
}
// redisRange gives Go offsets for something l long with start/end in
// Redis semantics (both may be negative; end is inclusive).
// The result is usable as v[start:end].
// With stringSymantics set, a large negative end clamps to index 0 instead
// of producing an empty range, matching GETRANGE on string keys.
func redisRange(l, start, end int, stringSymantics bool) (int, int) {
	if start < 0 {
		if start += l; start < 0 {
			start = 0
		}
	}
	if start > l {
		start = l
	}
	if end < 0 {
		if end += l; end < 0 {
			// Out of range on the left.
			end = -1
			if stringSymantics {
				end = 0
			}
		}
	}
	// Redis end is inclusive; Go slice end is exclusive.
	end++
	if end > l {
		end = l
	}
	if end < start {
		return 0, 0
	}
	return start, end
}
// matchKeys filters keys down to those matching the Redis KEYS pattern.
// An invalid or never-matching pattern yields an empty result.
func matchKeys(keys []string, match string) []string {
	re := patternRE(match)
	if re == nil {
		// Special case: the pattern can never match anything.
		return nil
	}
	matched := []string{}
	for _, key := range keys {
		if re.MatchString(key) {
			matched = append(matched, key)
		}
	}
	return matched
}

9
vendor/github.com/alicebob/miniredis/server/Makefile generated vendored Normal file
View file

@ -0,0 +1,9 @@
.PHONY: all build test
all: build test
build:
go build
test:
go test

84
vendor/github.com/alicebob/miniredis/server/proto.go generated vendored Normal file
View file

@ -0,0 +1,84 @@
package server
import (
"bufio"
"errors"
"strconv"
)
// ErrProtocol is the general error for unexpected input
var ErrProtocol = errors.New("invalid request")
// client always sends arrays with bulk strings
func readArray(rd *bufio.Reader) ([]string, error) {
line, err := rd.ReadString('\n')
if err != nil {
return nil, err
}
if len(line) < 3 {
return nil, ErrProtocol
}
switch line[0] {
default:
return nil, ErrProtocol
case '*':
l, err := strconv.Atoi(line[1 : len(line)-2])
if err != nil {
return nil, err
}
// l can be -1
var fields []string
for ; l > 0; l-- {
s, err := readString(rd)
if err != nil {
return nil, err
}
fields = append(fields, s)
}
return fields, nil
}
}
func readString(rd *bufio.Reader) (string, error) {
line, err := rd.ReadString('\n')
if err != nil {
return "", err
}
if len(line) < 3 {
return "", ErrProtocol
}
switch line[0] {
default:
return "", ErrProtocol
case '+', '-', ':':
// +: simple string
// -: errors
// :: integer
// Simple line based replies.
return string(line[1 : len(line)-2]), nil
case '$':
// bulk strings are: `$5\r\nhello\r\n`
length, err := strconv.Atoi(line[1 : len(line)-2])
if err != nil {
return "", err
}
if length < 0 {
// -1 is a nil response
return "", nil
}
var (
buf = make([]byte, length+2)
pos = 0
)
for pos < length+2 {
n, err := rd.Read(buf[pos:])
if err != nil {
return "", err
}
pos += n
}
return string(buf[:length]), nil
}
}

242
vendor/github.com/alicebob/miniredis/server/server.go generated vendored Normal file
View file

@ -0,0 +1,242 @@
package server
import (
"bufio"
"fmt"
"net"
"strings"
"sync"
"unicode"
)
// errUnknownCommand formats the redis-style error reply for an unregistered
// command, listing at most the first 20 arguments (each backquoted, with a
// trailing ", " after every one, matching real redis output).
func errUnknownCommand(cmd string, args []string) string {
	var b strings.Builder
	fmt.Fprintf(&b, "ERR unknown command `%s`, with args beginning with: ", cmd)
	if len(args) > 20 {
		args = args[:20]
	}
	for _, a := range args {
		fmt.Fprintf(&b, "`%s`, ", a)
	}
	return b.String()
}
// Cmd is what Register expects: a callback invoked per matching command,
// given the peer, the uppercased command name, and the remaining arguments.
type Cmd func(c *Peer, cmd string, args []string)

// Server is a simple redis server
type Server struct {
	l         net.Listener          // nil after Close
	cmds      map[string]Cmd        // registered commands, keyed by uppercase name
	peers     map[net.Conn]struct{} // live connections, force-closed by Close
	mu        sync.Mutex            // guards every field in this struct
	wg        sync.WaitGroup        // accept loop + one goroutine per connection
	infoConns int                   // total connections accepted since start
	infoCmds  int                   // total known commands dispatched since start
}
// NewServer makes a server listening on addr. Close with .Close().
func NewServer(addr string) (*Server, error) {
	l, err := net.Listen("tcp", addr)
	if err != nil {
		return nil, err
	}

	srv := &Server{
		cmds:  map[string]Cmd{},
		peers: map[net.Conn]struct{}{},
		l:     l,
	}

	srv.wg.Add(1)
	go func() {
		defer srv.wg.Done()
		srv.serve(l)
	}()
	return srv, nil
}
// serve accepts connections until the listener is closed (by Close).
func (s *Server) serve(l net.Listener) {
	for {
		conn, err := l.Accept()
		if err != nil {
			// Listener closed; shut the accept loop down.
			break
		}
		s.ServeConn(conn)
	}
}
// ServeConn handles a net.Conn. Nice with net.Pipe()
//
// The connection is served on its own goroutine, tracked by s.wg so Close()
// can wait for it. The peer is registered under s.mu before serving (so
// Close() can force-close it) and deregistered when the command loop ends.
func (s *Server) ServeConn(conn net.Conn) {
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		defer conn.Close()

		s.mu.Lock()
		s.peers[conn] = struct{}{}
		s.infoConns++
		s.mu.Unlock()

		s.servePeer(conn)

		s.mu.Lock()
		delete(s.peers, conn)
		s.mu.Unlock()
	}()
}
// Addr has the net.Addr struct; nil once the server is closed.
func (s *Server) Addr() *net.TCPAddr {
	s.mu.Lock()
	defer s.mu.Unlock()

	listener := s.l
	if listener == nil {
		return nil
	}
	return listener.Addr().(*net.TCPAddr)
}
// Close a server started with NewServer. It will wait until all clients are
// closed.
//
// Under the lock: the listener is closed (stopping the accept loop) and
// nilled (so Addr() returns nil afterwards), and every live peer conn is
// closed to unblock its read loop. wg.Wait() must run AFTER the lock is
// released, because the per-connection goroutines take s.mu on their way out.
func (s *Server) Close() {
	s.mu.Lock()
	if s.l != nil {
		s.l.Close()
	}
	s.l = nil
	for c := range s.peers {
		c.Close()
	}
	s.mu.Unlock()
	s.wg.Wait()
}
// Register a command. It can't have been registered before. Safe to call on a
// running server.
func (s *Server) Register(cmd string, f Cmd) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	normalized := strings.ToUpper(cmd)
	if _, exists := s.cmds[normalized]; exists {
		return fmt.Errorf("command already registered: %s", normalized)
	}
	s.cmds[normalized] = f
	return nil
}
// servePeer runs the command loop for one connection, returning when the
// client disconnects, sends an invalid request, or a handler closes the peer.
func (s *Server) servePeer(c net.Conn) {
	reader := bufio.NewReader(c)
	peer := &Peer{
		w: bufio.NewWriter(c),
	}

	for {
		args, err := readArray(reader)
		if err != nil {
			return
		}
		s.dispatch(peer, args)
		peer.w.Flush()
		if peer.closed {
			c.Close()
			return
		}
	}
}
// dispatch looks up the command callback (case-insensitively) and invokes it,
// replying with an error for unknown commands.
//
// An empty argument list (a client sending `*0\r\n` makes readArray return
// nil fields) is rejected with an error reply instead of panicking on args[0].
func (s *Server) dispatch(c *Peer, args []string) {
	if len(args) == 0 {
		c.WriteError("ERR empty command")
		return
	}

	cmd, args := args[0], args[1:]
	cmdUp := strings.ToUpper(cmd)
	s.mu.Lock()
	cb, ok := s.cmds[cmdUp]
	s.mu.Unlock()
	if !ok {
		c.WriteError(errUnknownCommand(cmd, args))
		return
	}

	s.mu.Lock()
	s.infoCmds++
	s.mu.Unlock()
	cb(c, cmdUp, args)
}
// TotalCommands is total (known) commands since this the server started
func (s *Server) TotalCommands() int {
	s.mu.Lock()
	n := s.infoCmds
	s.mu.Unlock()
	return n
}

// ClientsLen gives the number of connected clients right now
func (s *Server) ClientsLen() int {
	s.mu.Lock()
	n := len(s.peers)
	s.mu.Unlock()
	return n
}

// TotalConnections give the number of clients connected since the server
// started, including the currently connected ones
func (s *Server) TotalConnections() int {
	s.mu.Lock()
	n := s.infoConns
	s.mu.Unlock()
	return n
}
// Peer is a client connected to the server
type Peer struct {
	w      *bufio.Writer // buffered replies; flushed after every command
	closed bool          // set by Close; servePeer then tears the conn down
	Ctx    interface{}   // anything goes, server won't touch this
}
// Flush the write buffer. Called automatically after every redis command
func (c *Peer) Flush() {
	c.w.Flush()
}

// Close the client connection after the current command is done. servePeer
// checks the flag after each command and closes the underlying conn.
func (c *Peer) Close() {
	c.closed = true
}

// WriteError writes a redis 'Error'. Whitespace (including newlines) in e is
// flattened to single spaces so the reply stays one protocol line.
func (c *Peer) WriteError(e string) {
	fmt.Fprintf(c.w, "-%s\r\n", toInline(e))
}

// WriteInline writes a redis inline string, whitespace-flattened the same
// way as WriteError.
func (c *Peer) WriteInline(s string) {
	fmt.Fprintf(c.w, "+%s\r\n", toInline(s))
}

// WriteOK write the inline string `OK`
func (c *Peer) WriteOK() {
	c.WriteInline("OK")
}

// WriteBulk writes a bulk string. Length-prefixed, so s may contain any
// bytes, including CRLF.
func (c *Peer) WriteBulk(s string) {
	fmt.Fprintf(c.w, "$%d\r\n%s\r\n", len(s), s)
}

// WriteNull writes a redis Null element (`$-1`)
func (c *Peer) WriteNull() {
	fmt.Fprintf(c.w, "$-1\r\n")
}

// WriteLen starts an array with the given length; the caller must follow
// with exactly n elements.
func (c *Peer) WriteLen(n int) {
	fmt.Fprintf(c.w, "*%d\r\n", n)
}

// WriteInt writes an integer
func (c *Peer) WriteInt(i int) {
	fmt.Fprintf(c.w, ":%d\r\n", i)
}
// toInline collapses every whitespace rune (per unicode.IsSpace) in s to a
// single space, making the string safe for one-line redis replies.
func toInline(s string) string {
	blankOut := func(r rune) rune {
		if unicode.IsSpace(r) {
			return ' '
		}
		return r
	}
	return strings.Map(blankOut, s)
}

97
vendor/github.com/alicebob/miniredis/sorted_set.go generated vendored Normal file
View file

@ -0,0 +1,97 @@
package miniredis
// The most KISS way to implement a sorted set. Luckily we don't care about
// performance that much.
import (
"sort"
)
// direction selects iteration order for sorted-set traversals.
type direction int

const (
	// asc iterates lowest score first; desc highest first.
	asc direction = iota
	desc
)

// sortedSet maps member name to score; ordering is computed on demand.
type sortedSet map[string]float64
// ssElem pairs a member with its score, for sorting.
type ssElem struct {
	score  float64
	member string
}

// ssElems is a flat list of set elements.
type ssElems []ssElem

// byScore sorts elements by score, breaking ties on the member string.
type byScore ssElems

func (sse byScore) Len() int      { return len(sse) }
func (sse byScore) Swap(i, j int) { sse[i], sse[j] = sse[j], sse[i] }
func (sse byScore) Less(i, j int) bool {
	a, b := sse[i], sse[j]
	if a.score == b.score {
		return a.member < b.member
	}
	return a.score < b.score
}
// newSortedSet returns an empty sorted set.
func newSortedSet() sortedSet {
	return sortedSet{}
}

// card returns the number of members.
func (ss *sortedSet) card() int {
	return len(*ss)
}

// set stores member with the given score, overwriting any previous score.
func (ss *sortedSet) set(score float64, member string) {
	(*ss)[member] = score
}

// get returns the member's score and whether the member is present.
func (ss *sortedSet) get(member string) (float64, bool) {
	score, ok := (*ss)[member]
	return score, ok
}
// elems returns the members as a flat list of ssElem, ready to sort.
func (ss *sortedSet) elems() ssElems {
	out := make(ssElems, 0, len(*ss))
	for member, score := range *ss {
		out = append(out, ssElem{score, member})
	}
	return out
}

// byScore returns all elements ordered by (score, member), reversed when
// d is desc.
func (ss *sortedSet) byScore(d direction) ssElems {
	ordered := ss.elems()
	sort.Sort(byScore(ordered))
	if d == desc {
		reverseElems(ordered)
	}
	return ordered
}
// rankByScore gives the 0-based index of member in d order, or false when
// the member is not in the set.
func (ss *sortedSet) rankByScore(member string, d direction) (int, bool) {
	if _, ok := (*ss)[member]; !ok {
		return 0, false
	}
	for rank, elem := range ss.byScore(d) {
		if elem.member == member {
			return rank, true
		}
	}
	return 0, false // unreachable: membership was checked above
}
// reverseSlice reverses o in place. Uses the canonical two-index swap
// instead of the previous `range make([]struct{}, len/2)` construction.
func reverseSlice(o []string) {
	for i, j := 0, len(o)-1; i < j; i, j = i+1, j-1 {
		o[i], o[j] = o[j], o[i]
	}
}
// reverseElems reverses o in place. Uses the canonical two-index swap
// instead of the previous `range make([]struct{}, len/2)` construction.
func reverseElems(o ssElems) {
	for i, j := 0, len(o)-1; i < j; i, j = i+1, j-1 {
		o[i], o[j] = o[j], o[i]
	}
}

1
vendor/github.com/apex/log/.gitignore generated vendored Normal file
View file

@ -0,0 +1 @@
.envrc

75
vendor/github.com/apex/log/History.md generated vendored Normal file
View file

@ -0,0 +1,75 @@
v1.9.0 / 2020-08-18
===================
* add `WithDuration()` method to record a duration as milliseconds
* add: ignore nil errors in `WithError()`
* change trace duration to milliseconds (arguably a breaking change)
v1.8.0 / 2020-08-05
===================
* refactor apexlogs handler to not make the AddEvents() call if there are no events to flush
v1.7.1 / 2020-08-05
===================
* fix potential nil panic in apexlogs handler
v1.7.0 / 2020-08-03
===================
* add FlushSync() to apexlogs handler
v1.6.0 / 2020-07-13
===================
* update apex/logs dep to v1.0.0
* docs: mention that Flush() is non-blocking now, use Close()
v1.5.0 / 2020-07-11
===================
* add buffering to Apex Logs handler
v1.4.0 / 2020-06-16
===================
* add AuthToken to apexlogs handler
v1.3.0 / 2020-05-26
===================
* change FromContext() to always return a logger
v1.2.0 / 2020-05-26
===================
* add log.NewContext() and log.FromContext(). Closes #78
v1.1.4 / 2020-04-22
===================
* add apexlogs HTTPClient support
v1.1.3 / 2020-04-22
===================
* add events len check before flushing to apexlogs handler
v1.1.2 / 2020-01-29
===================
* refactor apexlogs handler to use github.com/apex/logs client
v1.1.1 / 2019-06-24
===================
* add go.mod
* add rough pass at apexlogs handler
v1.1.0 / 2018-10-11
===================
* fix: cli handler to show non-string fields appropriately
* fix: cli using fatih/color to better support windows

22
vendor/github.com/apex/log/LICENSE generated vendored Normal file
View file

@ -0,0 +1,22 @@
(The MIT License)
Copyright (c) 2015 TJ Holowaychuk tj@tjholowaychuk.com
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
'Software'), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

2
vendor/github.com/apex/log/Makefile generated vendored Normal file
View file

@ -0,0 +1,2 @@
include github.com/tj/make/golang

61
vendor/github.com/apex/log/Readme.md generated vendored Normal file
View file

@ -0,0 +1,61 @@
![Structured logging for golang](assets/title.png)
Package log implements a simple structured logging API inspired by Logrus, designed with centralization in mind. Read more on [Medium](https://medium.com/@tjholowaychuk/apex-log-e8d9627f4a9a#.rav8yhkud).
## Handlers
- __apexlogs__ handler for [Apex Logs](https://apex.sh/logs/)
- __cli__ human-friendly CLI output
- __discard__ discards all logs
- __es__  Elasticsearch handler
- __graylog__ Graylog handler
- __json__  JSON output handler
- __kinesis__ AWS Kinesis handler
- __level__  level filter handler
- __logfmt__  logfmt plain-text formatter
- __memory__ in-memory handler for tests
- __multi__ fan-out to multiple handlers
- __papertrail__ Papertrail handler
- __text__  human-friendly colored output
- __delta__  outputs the delta between log calls and spinner
## Example
Example using the [Apex Logs](https://apex.sh/logs/) handler.
```go
package main
import (
"errors"
"time"
"github.com/apex/log"
)
func main() {
ctx := log.WithFields(log.Fields{
"file": "something.png",
"type": "image/png",
"user": "tobi",
})
for range time.Tick(time.Millisecond * 200) {
ctx.Info("upload")
ctx.Info("upload complete")
ctx.Warn("upload retry")
ctx.WithError(errors.New("unauthorized")).Error("upload failed")
ctx.Errorf("failed to upload %s", "img.png")
}
}
```
---
[![Build Status](https://semaphoreci.com/api/v1/projects/d8a8b1c0-45b0-4b89-b066-99d788d0b94c/642077/badge.svg)](https://semaphoreci.com/tj/log)
[![GoDoc](https://godoc.org/github.com/apex/log?status.svg)](https://godoc.org/github.com/apex/log)
![](https://img.shields.io/badge/license-MIT-blue.svg)
![](https://img.shields.io/badge/status-stable-green.svg)
<a href="https://apex.sh"><img src="http://tjholowaychuk.com:6000/svg/sponsor"></a>

19
vendor/github.com/apex/log/context.go generated vendored Normal file
View file

@ -0,0 +1,19 @@
package log
import "context"
// logKey is a private context key; the unexported type guarantees no
// collisions with keys set by other packages.
type logKey struct{}

// NewContext returns a copy of ctx carrying logger v.
func NewContext(ctx context.Context, v Interface) context.Context {
	return context.WithValue(ctx, logKey{}, v)
}

// FromContext returns the logger stored in ctx, falling back to the package
// default Log when none is set.
func FromContext(ctx context.Context) Interface {
	v, ok := ctx.Value(logKey{}).(Interface)
	if !ok {
		return Log
	}
	return v
}

45
vendor/github.com/apex/log/default.go generated vendored Normal file
View file

@ -0,0 +1,45 @@
package log
import (
"bytes"
"fmt"
"log"
"sort"
)
// field is a name/value pair, used for sorting entry fields.
type field struct {
	Name  string
	Value interface{}
}

// byName sorts fields lexicographically by name.
type byName []field

func (a byName) Len() int           { return len(a) }
func (a byName) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
func (a byName) Less(i, j int) bool { return a[i].Name < a[j].Name }
// handleStdLog outputs the entry to the stdlib log, appending the fields
// sorted by name as key=value pairs.
func handleStdLog(e *Entry) error {
	level := levelNames[e.Level]

	var fields []field
	for k, v := range e.Fields {
		fields = append(fields, field{k, v})
	}
	sort.Sort(byName(fields))

	var b bytes.Buffer
	// Right-align the level and pad the message so fields line up in columns.
	fmt.Fprintf(&b, "%5s %-25s", level, e.Message)
	for _, f := range fields {
		fmt.Fprintf(&b, " %s=%v", f.Name, f.Value)
	}

	log.Println(b.String())

	return nil
}

10
vendor/github.com/apex/log/doc.go generated vendored Normal file
View file

@ -0,0 +1,10 @@
/*
Package log implements a simple structured logging API designed with few assumptions. Designed for
centralized logging solutions such as Kinesis which require encoding and decoding before fanning-out
to handlers.
You may use this package with inline handlers, much like Logrus, however a centralized solution
is recommended so that apps do not need to be re-deployed to add or remove logging service
providers.
*/
package log

182
vendor/github.com/apex/log/entry.go generated vendored Normal file
View file

@ -0,0 +1,182 @@
package log
import (
"fmt"
"os"
"strings"
"time"
)
// assert interface compliance.
var _ Interface = (*Entry)(nil)

// Now returns the current time. Declared as a variable so it can be stubbed.
var Now = time.Now

// Entry represents a single log entry.
type Entry struct {
	Logger    *Logger   `json:"-"`
	Fields    Fields    `json:"fields"`
	Level     Level     `json:"level"`
	Timestamp time.Time `json:"timestamp"`
	Message   string    `json:"message"`
	start     time.Time // set by Trace, read by Stop to compute the duration
	fields    []Fields  // pending field sets, collapsed by mergedFields
}
// NewEntry returns a new entry for `log`.
func NewEntry(log *Logger) *Entry {
	e := Entry{Logger: log}
	return &e
}

// WithFields returns a new entry extending this one with `fields`.
func (e *Entry) WithFields(fields Fielder) *Entry {
	merged := make([]Fields, 0, len(e.fields)+1)
	merged = append(merged, e.fields...)
	merged = append(merged, fields.Fields())
	return &Entry{
		Logger: e.Logger,
		fields: merged,
	}
}

// WithField returns a new entry with the single `key`/`value` pair added.
func (e *Entry) WithField(key string, value interface{}) *Entry {
	return e.WithFields(Fields{key: value})
}

// WithDuration returns a new entry whose "duration" field holds d expressed
// in milliseconds.
func (e *Entry) WithDuration(d time.Duration) *Entry {
	return e.WithField("duration", d.Milliseconds())
}
// WithError returns a new entry with the "error" set to `err`.
//
// The given error may implement .Fielder, if it does the method
// will add all its `.Fields()` into the returned entry. A nil err returns
// e unchanged. Errors carrying a pkg/errors stack trace also get a
// "source" field with the originating function, file and line.
func (e *Entry) WithError(err error) *Entry {
	if err == nil {
		return e
	}

	ctx := e.WithField("error", err.Error())

	if s, ok := err.(stackTracer); ok {
		frame := s.StackTrace()[0]

		// pkg/errors Frame format verbs: %n = function name,
		// %+s = "function\n\tfull/path/file.go", %d = line number.
		name := fmt.Sprintf("%n", frame)
		file := fmt.Sprintf("%+s", frame)
		line := fmt.Sprintf("%d", frame)

		// Keep only the file path from the "%+s" output.
		parts := strings.Split(file, "\n\t")
		if len(parts) > 1 {
			file = parts[1]
		}

		ctx = ctx.WithField("source", fmt.Sprintf("%s: %s:%s", name, file, line))
	}

	if f, ok := err.(Fielder); ok {
		ctx = ctx.WithFields(f.Fields())
	}

	return ctx
}
// Debug level message.
func (e *Entry) Debug(msg string) {
	e.Logger.log(DebugLevel, e, msg)
}

// Info level message.
func (e *Entry) Info(msg string) {
	e.Logger.log(InfoLevel, e, msg)
}

// Warn level message.
func (e *Entry) Warn(msg string) {
	e.Logger.log(WarnLevel, e, msg)
}

// Error level message.
func (e *Entry) Error(msg string) {
	e.Logger.log(ErrorLevel, e, msg)
}

// Fatal level message, followed by an exit.
// Note: os.Exit runs no deferred functions.
func (e *Entry) Fatal(msg string) {
	e.Logger.log(FatalLevel, e, msg)
	os.Exit(1)
}

// Debugf level formatted message (fmt.Sprintf semantics).
func (e *Entry) Debugf(msg string, v ...interface{}) {
	e.Debug(fmt.Sprintf(msg, v...))
}

// Infof level formatted message (fmt.Sprintf semantics).
func (e *Entry) Infof(msg string, v ...interface{}) {
	e.Info(fmt.Sprintf(msg, v...))
}

// Warnf level formatted message (fmt.Sprintf semantics).
func (e *Entry) Warnf(msg string, v ...interface{}) {
	e.Warn(fmt.Sprintf(msg, v...))
}

// Errorf level formatted message (fmt.Sprintf semantics).
func (e *Entry) Errorf(msg string, v ...interface{}) {
	e.Error(fmt.Sprintf(msg, v...))
}

// Fatalf level formatted message, followed by an exit.
func (e *Entry) Fatalf(msg string, v ...interface{}) {
	e.Fatal(fmt.Sprintf(msg, v...))
}
// Trace logs msg at info level and returns an entry whose Stop method fires
// the corresponding completion log — useful with defer.
func (e *Entry) Trace(msg string) *Entry {
	e.Info(msg)
	started := e.WithFields(e.Fields)
	started.Message = msg
	started.start = time.Now()
	return started
}
// Stop should be used with Trace, to fire off the completion message. When
// an `err` is passed the "error" field is set, and the log level is error.
func (e *Entry) Stop(err *error) {
	done := e.WithDuration(time.Since(e.start))
	if err == nil || *err == nil {
		done.Info(e.Message)
		return
	}
	done.WithError(*err).Error(e.Message)
}
// mergedFields collapses the pending field sets into a single map; later
// sets overwrite earlier keys.
func (e *Entry) mergedFields() Fields {
	merged := Fields{}
	for _, set := range e.fields {
		for k, v := range set {
			merged[k] = v
		}
	}
	return merged
}

// finalize returns a timestamped copy of the entry with Fields merged,
// ready to hand to the handler.
func (e *Entry) finalize(level Level, msg string) *Entry {
	return &Entry{
		Logger:    e.Logger,
		Fields:    e.mergedFields(),
		Level:     level,
		Message:   msg,
		Timestamp: Now(),
	}
}

80
vendor/github.com/apex/log/handlers/text/text.go generated vendored Normal file
View file

@ -0,0 +1,80 @@
// Package text implements a development-friendly textual handler.
package text
import (
"fmt"
"io"
"os"
"sync"
"time"
"github.com/apex/log"
)
// Default handler outputting to stderr.
var Default = New(os.Stderr)

// start time, used to print a per-entry timestamp relative to process start.
var start = time.Now()

// ANSI color codes.
const (
	none   = 0
	red    = 31
	green  = 32
	yellow = 33
	blue   = 34
	gray   = 37
)

// Colors maps each log level to its color code.
var Colors = [...]int{
	log.DebugLevel: gray,
	log.InfoLevel:  blue,
	log.WarnLevel:  yellow,
	log.ErrorLevel: red,
	log.FatalLevel: red,
}

// Strings maps each log level to its display name.
var Strings = [...]string{
	log.DebugLevel: "DEBUG",
	log.InfoLevel:  "INFO",
	log.WarnLevel:  "WARN",
	log.ErrorLevel: "ERROR",
	log.FatalLevel: "FATAL",
}
// Handler writes colorized, column-aligned entries to Writer; the mutex
// serializes concurrent HandleLog calls.
type Handler struct {
	mu     sync.Mutex
	Writer io.Writer
}

// New returns a Handler writing to w.
func New(w io.Writer) *Handler {
	h := Handler{Writer: w}
	return &h
}
// HandleLog implements log.Handler.
func (h *Handler) HandleLog(e *log.Entry) error {
	color := Colors[e.Level]
	level := Strings[e.Level]
	names := e.Fields.Names()

	// Serialize writers so concurrent entries don't interleave.
	h.mu.Lock()
	defer h.mu.Unlock()

	// Whole seconds since process start, printed as a relative timestamp.
	ts := time.Since(start) / time.Second
	fmt.Fprintf(h.Writer, "\033[%dm%6s\033[0m[%04d] %-25s", color, level, ts, e.Message)

	for _, name := range names {
		fmt.Fprintf(h.Writer, " \033[%dm%s\033[0m=%v", color, name, e.Fields.Get(name))
	}

	fmt.Fprintln(h.Writer)

	return nil
}

22
vendor/github.com/apex/log/interface.go generated vendored Normal file
View file

@ -0,0 +1,22 @@
package log
import "time"
// Interface represents the API of both Logger and Entry.
type Interface interface {
	// Derivation: each returns a new *Entry carrying the extra context.
	WithFields(Fielder) *Entry
	WithField(string, interface{}) *Entry
	WithDuration(time.Duration) *Entry
	WithError(error) *Entry
	// Leveled messages.
	Debug(string)
	Info(string)
	Warn(string)
	Error(string)
	Fatal(string)
	// Leveled formatted messages (fmt.Sprintf semantics).
	Debugf(string, ...interface{})
	Infof(string, ...interface{})
	Warnf(string, ...interface{})
	Errorf(string, ...interface{})
	Fatalf(string, ...interface{})
	// Trace returns an entry whose Stop method logs completion.
	Trace(string) *Entry
}

81
vendor/github.com/apex/log/levels.go generated vendored Normal file
View file

@ -0,0 +1,81 @@
package log
import (
"bytes"
"errors"
"strings"
)
// ErrInvalidLevel is returned if the severity level is invalid.
var ErrInvalidLevel = errors.New("invalid level")

// Level of severity.
type Level int

// Log levels.
const (
	InvalidLevel Level = iota - 1
	DebugLevel
	InfoLevel
	WarnLevel
	ErrorLevel
	FatalLevel
)

// levelNames maps each valid Level to its canonical lowercase name.
var levelNames = [...]string{
	DebugLevel: "debug",
	InfoLevel:  "info",
	WarnLevel:  "warn",
	ErrorLevel: "error",
	FatalLevel: "fatal",
}

// levelStrings maps lowercase names to levels; "warning" is an alias.
var levelStrings = map[string]Level{
	"debug":   DebugLevel,
	"info":    InfoLevel,
	"warn":    WarnLevel,
	"warning": WarnLevel,
	"error":   ErrorLevel,
	"fatal":   FatalLevel,
}

// String implementation. Out-of-range levels (including InvalidLevel, which
// ParseLevel returns) yield "invalid" instead of panicking on the array
// index.
func (l Level) String() string {
	if l < DebugLevel || l > FatalLevel {
		return "invalid"
	}
	return levelNames[l]
}

// MarshalJSON implementation.
func (l Level) MarshalJSON() ([]byte, error) {
	return []byte(`"` + l.String() + `"`), nil
}

// UnmarshalJSON implementation.
func (l *Level) UnmarshalJSON(b []byte) error {
	v, err := ParseLevel(string(bytes.Trim(b, `"`)))
	if err != nil {
		return err
	}

	*l = v
	return nil
}

// ParseLevel parses a level name (case-insensitive), returning
// (InvalidLevel, ErrInvalidLevel) for unknown names.
func ParseLevel(s string) (Level, error) {
	l, ok := levelStrings[strings.ToLower(s)]
	if !ok {
		return InvalidLevel, ErrInvalidLevel
	}

	return l, nil
}

// MustParseLevel parses a level name or panics. The panic message includes
// the offending input to aid debugging.
func MustParseLevel(s string) Level {
	l, err := ParseLevel(s)
	if err != nil {
		panic("invalid log level: " + s)
	}

	return l
}

156
vendor/github.com/apex/log/logger.go generated vendored Normal file
View file

@ -0,0 +1,156 @@
package log
import (
stdlog "log"
"sort"
"time"
)
// assert interface compliance.
var _ Interface = (*Logger)(nil)
// Fielder is an interface for providing fields to custom types.
type Fielder interface {
	Fields() Fields
}

// Fields represents a map of entry level data used for structured logging.
type Fields map[string]interface{}

// Fields implements Fielder.
func (f Fields) Fields() Fields {
	return f
}

// Get field value by name.
func (f Fields) Get(name string) interface{} {
	return f[name]
}

// Names returns field names, sorted. The result slice is pre-sized to
// len(f) to avoid append-growth reallocations.
func (f Fields) Names() []string {
	v := make([]string, 0, len(f))
	for k := range f {
		v = append(v, k)
	}

	sort.Strings(v)
	return v
}
// The HandlerFunc type is an adapter to allow the use of ordinary functions as
// log handlers. If f is a function with the appropriate signature,
// HandlerFunc(f) is a Handler object that calls f.
type HandlerFunc func(*Entry) error

// HandleLog calls f(e).
func (f HandlerFunc) HandleLog(e *Entry) error {
	return f(e)
}

// Handler is used to handle log events, outputting them to
// stdio or sending them to remote services. See the "handlers"
// directory for implementations.
//
// It is left up to Handlers to implement thread-safety.
type Handler interface {
	HandleLog(*Entry) error
}

// Logger represents a logger with configurable Level and Handler.
type Logger struct {
	Handler Handler // destination for finalized entries
	Level   Level   // entries below this level are dropped (see log)
}
// WithFields returns a new entry with `fields` set.
func (l *Logger) WithFields(fields Fielder) *Entry {
	entry := NewEntry(l)
	return entry.WithFields(fields.Fields())
}

// WithField returns a new entry with the `key` and `value` set.
//
// Note that the `key` should not have spaces in it - use camel
// case or underscores
func (l *Logger) WithField(key string, value interface{}) *Entry {
	entry := NewEntry(l)
	return entry.WithField(key, value)
}

// WithDuration returns a new entry with the "duration" field set
// to the given duration in milliseconds.
func (l *Logger) WithDuration(d time.Duration) *Entry {
	entry := NewEntry(l)
	return entry.WithDuration(d)
}

// WithError returns a new entry with the "error" set to `err`.
func (l *Logger) WithError(err error) *Entry {
	entry := NewEntry(l)
	return entry.WithError(err)
}

// Debug level message.
func (l *Logger) Debug(msg string) {
	entry := NewEntry(l)
	entry.Debug(msg)
}

// Info level message.
func (l *Logger) Info(msg string) {
	entry := NewEntry(l)
	entry.Info(msg)
}

// Warn level message.
func (l *Logger) Warn(msg string) {
	entry := NewEntry(l)
	entry.Warn(msg)
}

// Error level message.
func (l *Logger) Error(msg string) {
	entry := NewEntry(l)
	entry.Error(msg)
}

// Fatal level message, followed by an exit.
func (l *Logger) Fatal(msg string) {
	entry := NewEntry(l)
	entry.Fatal(msg)
}

// Debugf level formatted message.
func (l *Logger) Debugf(msg string, v ...interface{}) {
	entry := NewEntry(l)
	entry.Debugf(msg, v...)
}

// Infof level formatted message.
func (l *Logger) Infof(msg string, v ...interface{}) {
	entry := NewEntry(l)
	entry.Infof(msg, v...)
}

// Warnf level formatted message.
func (l *Logger) Warnf(msg string, v ...interface{}) {
	entry := NewEntry(l)
	entry.Warnf(msg, v...)
}

// Errorf level formatted message.
func (l *Logger) Errorf(msg string, v ...interface{}) {
	entry := NewEntry(l)
	entry.Errorf(msg, v...)
}

// Fatalf level formatted message, followed by an exit.
func (l *Logger) Fatalf(msg string, v ...interface{}) {
	entry := NewEntry(l)
	entry.Fatalf(msg, v...)
}

// Trace returns a new entry with a Stop method to fire off
// a corresponding completion log, useful with defer.
func (l *Logger) Trace(msg string) *Entry {
	entry := NewEntry(l)
	return entry.Trace(msg)
}
// log the message, invoking the handler. The entry is finalized (fields
// merged, timestamp set) only after the level check, so suppressed entries
// pay no merging cost. Handler failures go to the stdlib log.
func (l *Logger) log(level Level, e *Entry, msg string) {
	if level < l.Level {
		return
	}

	finalized := e.finalize(level, msg)
	if err := l.Handler.HandleLog(finalized); err != nil {
		stdlog.Printf("error logging: %s", err)
	}
}

108
vendor/github.com/apex/log/pkg.go generated vendored Normal file
View file

@ -0,0 +1,108 @@
package log
import "time"
// Log is the package-level default logger singleton.
var Log Interface = &Logger{
	Handler: HandlerFunc(handleStdLog),
	Level:   InfoLevel,
}

// SetHandler sets the handler on the default logger. This is not
// thread-safe. The default handler outputs to the stdlib log.
func SetHandler(h Handler) {
	logger, ok := Log.(*Logger)
	if !ok {
		return
	}
	logger.Handler = h
}

// SetLevel sets the log level on the default logger. This is not
// thread-safe.
func SetLevel(l Level) {
	logger, ok := Log.(*Logger)
	if !ok {
		return
	}
	logger.Level = l
}

// SetLevelFromString sets the log level from a string, panicking when
// invalid. This is not thread-safe.
func SetLevelFromString(s string) {
	logger, ok := Log.(*Logger)
	if !ok {
		return
	}
	logger.Level = MustParseLevel(s)
}
// WithFields returns a new entry with `fields` set, on the package-level
// default logger Log.
func WithFields(fields Fielder) *Entry {
	return Log.WithFields(fields)
}

// WithField returns a new entry with the `key` and `value` set.
func WithField(key string, value interface{}) *Entry {
	return Log.WithField(key, value)
}

// WithDuration returns a new entry with the "duration" field set
// to the given duration in milliseconds.
func WithDuration(d time.Duration) *Entry {
	return Log.WithDuration(d)
}

// WithError returns a new entry with the "error" set to `err`.
func WithError(err error) *Entry {
	return Log.WithError(err)
}

// Debug level message.
func Debug(msg string) {
	Log.Debug(msg)
}

// Info level message.
func Info(msg string) {
	Log.Info(msg)
}

// Warn level message.
func Warn(msg string) {
	Log.Warn(msg)
}

// Error level message.
func Error(msg string) {
	Log.Error(msg)
}

// Fatal level message, followed by an exit.
func Fatal(msg string) {
	Log.Fatal(msg)
}

// Debugf level formatted message.
func Debugf(msg string, v ...interface{}) {
	Log.Debugf(msg, v...)
}

// Infof level formatted message.
func Infof(msg string, v ...interface{}) {
	Log.Infof(msg, v...)
}

// Warnf level formatted message.
func Warnf(msg string, v ...interface{}) {
	Log.Warnf(msg, v...)
}

// Errorf level formatted message.
func Errorf(msg string, v ...interface{}) {
	Log.Errorf(msg, v...)
}

// Fatalf level formatted message, followed by an exit.
func Fatalf(msg string, v ...interface{}) {
	Log.Fatalf(msg, v...)
}

// Trace returns a new entry with a Stop method to fire off
// a corresponding completion log, useful with defer.
func Trace(msg string) *Entry {
	return Log.Trace(msg)
}

8
vendor/github.com/apex/log/stack.go generated vendored Normal file
View file

@ -0,0 +1,8 @@
package log
import "github.com/pkg/errors"
// stackTracer matches errors created by github.com/pkg/errors, which expose
// the stack captured at construction time; WithError uses the first frame
// as the log entry's "source".
type stackTracer interface {
	StackTrace() errors.StackTrace
}

202
vendor/github.com/aws/aws-sdk-go-v2/LICENSE.txt generated vendored Normal file
View file

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

3
vendor/github.com/aws/aws-sdk-go-v2/NOTICE.txt generated vendored Normal file
View file

@ -0,0 +1,3 @@
AWS SDK for Go
Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Copyright 2014-2015 Stripe, Inc.

208
vendor/github.com/aws/aws-sdk-go-v2/aws/config.go generated vendored Normal file
View file

@ -0,0 +1,208 @@
package aws
import (
"net/http"
smithybearer "github.com/aws/smithy-go/auth/bearer"
"github.com/aws/smithy-go/logging"
"github.com/aws/smithy-go/middleware"
)
// HTTPClient provides the interface to provide custom HTTPClients. Generally
// *http.Client is sufficient for most use cases. The HTTPClient should not
// follow 301 or 302 redirects.
type HTTPClient interface {
	// Do sends the HTTP request and returns the response, mirroring
	// (*http.Client).Do.
	Do(*http.Request) (*http.Response, error)
}
// A Config provides service configuration for service clients.
type Config struct {
	// The region to send requests to. This parameter is required and must
	// be configured globally or on a per-client basis unless otherwise
	// noted. A full list of regions is found in the "Regions and Endpoints"
	// document.
	//
	// See http://docs.aws.amazon.com/general/latest/gr/rande.html for
	// information on AWS regions.
	Region string

	// The credentials object to use when signing requests.
	// Use the LoadDefaultConfig to load configuration from all the SDK's supported
	// sources, and resolve credentials using the SDK's default credential chain.
	Credentials CredentialsProvider

	// The Bearer Authentication token provider to use for authenticating API
	// operation calls with a Bearer Authentication token. The API clients and
	// operation must support Bearer Authentication scheme in order for the
	// token provider to be used. API clients created with NewFromConfig will
	// automatically be configured with this option, if the API client supports
	// Bearer Authentication.
	//
	// The SDK's config.LoadDefaultConfig can automatically populate this
	// option for external configuration options such as SSO session.
	// https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-sso.html
	BearerAuthTokenProvider smithybearer.TokenProvider

	// The HTTP Client the SDK's API clients will use to invoke HTTP requests.
	// The SDK defaults to a BuildableClient allowing API clients to create
	// copies of the HTTP Client for service specific customizations.
	//
	// Use a (*http.Client) for custom behavior. Using a custom http.Client
	// will prevent the SDK from modifying the HTTP client.
	HTTPClient HTTPClient

	// An endpoint resolver that can be used to provide or override an endpoint
	// for the given service and region.
	//
	// See the `aws.EndpointResolver` documentation for additional usage
	// information.
	//
	// Deprecated: See Config.EndpointResolverWithOptions
	EndpointResolver EndpointResolver

	// An endpoint resolver that can be used to provide or override an endpoint
	// for the given service and region.
	//
	// When EndpointResolverWithOptions is specified, it will be used by a
	// service client rather than using EndpointResolver if also specified.
	//
	// See the `aws.EndpointResolverWithOptions` documentation for additional
	// usage information.
	//
	// Deprecated: with the release of endpoint resolution v2 in API clients,
	// EndpointResolver and EndpointResolverWithOptions are deprecated.
	// Providing a value for this field will likely prevent you from using
	// newer endpoint-related service features. See API client options
	// EndpointResolverV2 and BaseEndpoint.
	EndpointResolverWithOptions EndpointResolverWithOptions

	// RetryMaxAttempts specifies the maximum number of attempts an API client
	// will call an operation that fails with a retryable error.
	//
	// API Clients will only use this value to construct a retryer if the
	// Config.Retryer member is nil. This value will be ignored if
	// Retryer is not nil.
	RetryMaxAttempts int

	// RetryMode specifies the retry model the API client will be created with.
	//
	// API Clients will only use this value to construct a retryer if the
	// Config.Retryer member is nil. This value will be ignored if
	// Retryer is not nil.
	RetryMode RetryMode

	// Retryer is a function that provides a Retryer implementation. A Retryer
	// guides how HTTP requests should be retried in case of recoverable
	// failures. When nil the API client will use a default retryer.
	//
	// In general, the provider function should return a new instance of a
	// Retryer if you are attempting to provide a consistent Retryer
	// configuration across all clients. This will ensure that each client will
	// be provided a new instance of the Retryer implementation, and will avoid
	// issues such as sharing the same retry token bucket across services.
	//
	// If not nil, RetryMaxAttempts, and RetryMode will be ignored by API
	// clients.
	Retryer func() Retryer

	// ConfigSources are the sources that were used to construct the Config.
	// Allows for additional configuration to be loaded by clients.
	ConfigSources []interface{}

	// APIOptions provides the set of middleware mutations modify how the API
	// client requests will be handled. This is useful for adding additional
	// tracing data to a request, or changing behavior of the SDK's client.
	APIOptions []func(*middleware.Stack) error

	// The logger writer interface to write logging messages to. Defaults to
	// standard error.
	Logger logging.Logger

	// Configures the events that will be sent to the configured logger. This
	// can be used to configure the logging of signing, retries, request, and
	// responses of the SDK clients.
	//
	// See the ClientLogMode type documentation for the complete set of logging
	// modes and available configuration.
	ClientLogMode ClientLogMode

	// The configured DefaultsMode. If not specified, service clients will
	// default to legacy.
	//
	// Supported modes are: auto, cross-region, in-region, legacy, mobile,
	// standard
	DefaultsMode DefaultsMode

	// The RuntimeEnvironment configuration, only populated if the DefaultsMode
	// is set to DefaultsModeAuto and is initialized by
	// `config.LoadDefaultConfig`. You should not populate this structure
	// programmatically, or rely on the values here within your applications.
	RuntimeEnvironment RuntimeEnvironment

	// AppId is an optional application specific identifier that can be set.
	// When set it will be appended to the User-Agent header of every request
	// in the form of App/{AppId}. This variable is sourced from environment
	// variable AWS_SDK_UA_APP_ID or the shared config profile attribute sdk_ua_app_id.
	// See https://docs.aws.amazon.com/sdkref/latest/guide/settings-reference.html for
	// more information on environment variables and shared config settings.
	AppID string

	// BaseEndpoint is an intermediary transfer location to a service specific
	// BaseEndpoint on a service's Options.
	BaseEndpoint *string

	// DisableRequestCompression toggles if an operation request could be
	// compressed or not. Will be set to false by default. This variable is sourced from
	// environment variable AWS_DISABLE_REQUEST_COMPRESSION or the shared config profile attribute
	// disable_request_compression
	DisableRequestCompression bool

	// RequestMinCompressSizeBytes sets the inclusive min bytes of a request body that could be
	// compressed. Will be set to 10240 by default and must be within 0 and 10485760 bytes inclusively.
	// This variable is sourced from environment variable AWS_REQUEST_MIN_COMPRESSION_SIZE_BYTES or
	// the shared config profile attribute request_min_compression_size_bytes
	RequestMinCompressSizeBytes int64
}
// NewConfig returns a pointer to a zero-valued Config that can be chained
// with builder methods to set multiple configuration values inline without
// using pointers.
func NewConfig() *Config {
	return new(Config)
}
// Copy will return a shallow copy of the Config object.
func (c Config) Copy() Config {
	// The value receiver is already a shallow copy of the caller's Config.
	return c
}
// EndpointDiscoveryEnableState indicates if endpoint discovery is
// enabled, disabled, auto or unset state.
//
// Default behavior (Auto or Unset) indicates operations that require endpoint
// discovery will use Endpoint Discovery by default. Operations that
// optionally use Endpoint Discovery will not use Endpoint Discovery
// unless EndpointDiscovery is explicitly enabled.
type EndpointDiscoveryEnableState uint
// Enumeration values for EndpointDiscoveryEnableState.
const (
	// EndpointDiscoveryUnset represents EndpointDiscoveryEnableState is unset.
	// Users do not need to use this value explicitly. The behavior for unset
	// is the same as for EndpointDiscoveryAuto.
	EndpointDiscoveryUnset EndpointDiscoveryEnableState = iota

	// EndpointDiscoveryAuto represents an AUTO state that allows endpoint
	// discovery only when required by the api. This is the default
	// configuration resolved by the client if endpoint discovery is neither
	// enabled nor disabled.
	EndpointDiscoveryAuto // default state

	// EndpointDiscoveryDisabled indicates client MUST not perform endpoint
	// discovery even when required.
	EndpointDiscoveryDisabled

	// EndpointDiscoveryEnabled indicates client MUST always perform endpoint
	// discovery if supported for the operation.
	EndpointDiscoveryEnabled
)

22
vendor/github.com/aws/aws-sdk-go-v2/aws/context.go generated vendored Normal file
View file

@ -0,0 +1,22 @@
package aws
import (
"context"
"time"
)
// suppressedContext wraps a context.Context and hides its deadline,
// cancellation signal, and error, while still exposing the wrapped context's
// values through the embedded Context.
type suppressedContext struct {
	context.Context
}
// Deadline reports that no deadline is set, regardless of the wrapped context.
func (s *suppressedContext) Deadline() (deadline time.Time, ok bool) {
	return time.Time{}, false
}
// Done returns a nil channel, so the context is never considered canceled.
func (s *suppressedContext) Done() <-chan struct{} {
	return nil
}
// Err always returns nil, consistent with Done never closing.
func (s *suppressedContext) Err() error {
	return nil
}

View file

@ -0,0 +1,224 @@
package aws
import (
"context"
"fmt"
"sync/atomic"
"time"
sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
"github.com/aws/aws-sdk-go-v2/internal/sync/singleflight"
)
// CredentialsCacheOptions are the options for configuring a CredentialsCache.
type CredentialsCacheOptions struct {
	// ExpiryWindow will allow the credentials to trigger refreshing prior to
	// the credentials actually expiring. This is beneficial so race conditions
	// with expiring credentials do not cause requests to fail unexpectedly
	// due to ExpiredTokenException exceptions.
	//
	// An ExpiryWindow of 10s would cause calls to IsExpired() to return true
	// 10 seconds before the credentials are actually expired. This can cause an
	// increased number of requests to refresh the credentials to occur.
	//
	// If ExpiryWindow is 0 or less it will be ignored.
	ExpiryWindow time.Duration

	// ExpiryWindowJitterFrac provides a mechanism for randomizing the
	// expiration of credentials within the configured ExpiryWindow by a random
	// percentage. Valid values are between 0.0 and 1.0.
	//
	// As an example if ExpiryWindow is 60 seconds and ExpiryWindowJitterFrac
	// is 0.5 then credentials will be set to expire between 30 to 60 seconds
	// prior to their actual expiration time.
	//
	// If ExpiryWindow is 0 or less then ExpiryWindowJitterFrac is ignored.
	// If ExpiryWindowJitterFrac is 0 then no randomization will be applied to the window.
	// If ExpiryWindowJitterFrac < 0 the value will be treated as 0.
	// If ExpiryWindowJitterFrac > 1 the value will be treated as 1.
	ExpiryWindowJitterFrac float64
}
// CredentialsCache provides caching and concurrency safe credentials retrieval
// via the provider's retrieve method.
//
// CredentialsCache will look for optional interfaces on the Provider to adjust
// how the credential cache handles credentials caching.
//
//   - HandleFailRefreshCredentialsCacheStrategy - Allows provider to handle
//     credential refresh failures. This could return an updated Credentials
//     value, or attempt another means of retrieving credentials.
//
//   - AdjustExpiresByCredentialsCacheStrategy - Allows provider to adjust how
//     credentials Expires is modified. This could modify how the Credentials
//     Expires is adjusted based on the CredentialsCache ExpiryWindow option.
//     Such as providing a floor not to reduce the Expires below.
type CredentialsCache struct {
	// The wrapped provider credentials are sourced from.
	provider CredentialsProvider
	// Options resolved by NewCredentialsCache; not mutated afterwards.
	options CredentialsCacheOptions
	// Holds a *Credentials; a typed nil pointer marks an invalidated cache.
	creds atomic.Value
	// Collapses concurrent refreshes into a single provider call.
	sf singleflight.Group
}
// NewCredentialsCache returns a CredentialsCache that wraps provider, which
// is expected to not be nil. A variadic list of one or more functions can be
// provided to modify the CredentialsCache configuration, such as the
// credential expiry window and jitter.
func NewCredentialsCache(provider CredentialsProvider, optFns ...func(options *CredentialsCacheOptions)) *CredentialsCache {
	var opts CredentialsCacheOptions
	for _, fn := range optFns {
		fn(&opts)
	}

	// Clamp the options into their documented valid ranges.
	if opts.ExpiryWindow < 0 {
		opts.ExpiryWindow = 0
	}
	switch {
	case opts.ExpiryWindowJitterFrac < 0:
		opts.ExpiryWindowJitterFrac = 0
	case opts.ExpiryWindowJitterFrac > 1:
		opts.ExpiryWindowJitterFrac = 1
	}

	return &CredentialsCache{
		provider: provider,
		options:  opts,
	}
}
// Retrieve returns the credentials. If the credentials have already been
// retrieved, and not expired the cached credentials will be returned. If the
// credentials have not been retrieved yet, or expired the provider's Retrieve
// method will be called.
//
// Returns an error if the provider's retrieve method returns an error.
func (p *CredentialsCache) Retrieve(ctx context.Context) (Credentials, error) {
	if creds, ok := p.getCreds(); ok && !creds.Expired() {
		return creds, nil
	}

	// Collapse concurrent refreshes into a single provider call. The shared
	// call runs under suppressedContext, which hides this caller's deadline
	// and cancellation from the provider while preserving context values.
	resCh := p.sf.DoChan("", func() (interface{}, error) {
		return p.singleRetrieve(&suppressedContext{ctx})
	})
	select {
	case res := <-resCh:
		return res.Val.(Credentials), res.Err
	case <-ctx.Done():
		// The waiting caller still honors its own context, even though the
		// shared refresh continues in the background.
		return Credentials{}, &RequestCanceledError{Err: ctx.Err()}
	}
}
// singleRetrieve refreshes the cached credentials from the wrapped provider.
// It is executed under singleflight, so only one refresh runs at a time.
// Returns the refreshed (or recovered) Credentials as an interface{} for the
// singleflight result channel.
func (p *CredentialsCache) singleRetrieve(ctx context.Context) (interface{}, error) {
	// Re-check the cache: another caller may have completed a refresh between
	// our caller's check and this singleflight execution.
	currCreds, ok := p.getCreds()
	if ok && !currCreds.Expired() {
		return currCreds, nil
	}

	newCreds, err := p.provider.Retrieve(ctx)
	if err != nil {
		// Give the provider a chance to recover from the refresh failure,
		// e.g. by returning the previously cached credentials.
		handleFailToRefresh := defaultHandleFailToRefresh
		if cs, ok := p.provider.(HandleFailRefreshCredentialsCacheStrategy); ok {
			handleFailToRefresh = cs.HandleFailToRefresh
		}
		newCreds, err = handleFailToRefresh(ctx, currCreds, err)
		if err != nil {
			return Credentials{}, fmt.Errorf("failed to refresh cached credentials, %w", err)
		}
	}

	if newCreds.CanExpire && p.options.ExpiryWindow > 0 {
		// Pull the effective expiry earlier by ExpiryWindow, reduced by a
		// random jitter fraction of the window, so many clients do not all
		// refresh at the same instant.
		adjustExpiresBy := defaultAdjustExpiresBy
		if cs, ok := p.provider.(AdjustExpiresByCredentialsCacheStrategy); ok {
			adjustExpiresBy = cs.AdjustExpiresBy
		}

		randFloat64, err := sdkrand.CryptoRandFloat64()
		if err != nil {
			return Credentials{}, fmt.Errorf("failed to get random provider, %w", err)
		}

		var jitter time.Duration
		if p.options.ExpiryWindowJitterFrac > 0 {
			jitter = time.Duration(randFloat64 *
				p.options.ExpiryWindowJitterFrac * float64(p.options.ExpiryWindow))
		}

		newCreds, err = adjustExpiresBy(newCreds, -(p.options.ExpiryWindow - jitter))
		if err != nil {
			return Credentials{}, fmt.Errorf("failed to adjust credentials expires, %w", err)
		}
	}

	p.creds.Store(&newCreds)
	return newCreds, nil
}
// getCreds returns the currently stored credentials and true. Returns false
// when no usable credentials are stored.
func (p *CredentialsCache) getCreds() (Credentials, bool) {
	stored := p.creds.Load()
	if stored == nil {
		return Credentials{}, false
	}

	creds := stored.(*Credentials)
	if creds == nil || !creds.HasKeys() {
		// Either invalidated (typed nil pointer) or incomplete credentials.
		return Credentials{}, false
	}
	return *creds, true
}
// Invalidate will invalidate the cached credentials. The next call to Retrieve
// will cause the provider's Retrieve method to be called.
func (p *CredentialsCache) Invalidate() {
	// Store a typed nil pointer rather than untyped nil (which atomic.Value
	// rejects); getCreds treats a nil *Credentials as "nothing cached".
	p.creds.Store((*Credentials)(nil))
}
// IsCredentialsProvider returns whether the credential provider wrapped by
// CredentialsCache matches the target provider type, delegating to the
// package-level IsCredentialsProvider with the wrapped provider.
func (p *CredentialsCache) IsCredentialsProvider(target CredentialsProvider) bool {
	return IsCredentialsProvider(p.provider, target)
}
// HandleFailRefreshCredentialsCacheStrategy is an optional interface for a
// CredentialsProvider wrapped by CredentialsCache, allowing the provider to
// control how a failed credentials refresh is handled.
type HandleFailRefreshCredentialsCacheStrategy interface {
	// Given the previously cached Credentials, if any, and the refresh error,
	// may return a new or modified set of Credentials, or an error.
	//
	// Credential caches may use default implementation if nil.
	HandleFailToRefresh(context.Context, Credentials, error) (Credentials, error)
}
// defaultHandleFailToRefresh discards the previously cached credentials and
// returns the passed in refresh error unchanged.
func defaultHandleFailToRefresh(ctx context.Context, _ Credentials, err error) (Credentials, error) {
	return Credentials{}, err
}
// AdjustExpiresByCredentialsCacheStrategy is an optional interface for a
// CredentialsProvider wrapped by CredentialsCache, allowing the provider to
// intercept adjustments to the Credentials expiry based on expectations and
// use cases of the CredentialsProvider.
//
// Credential caches may use default implementation if nil.
type AdjustExpiresByCredentialsCacheStrategy interface {
	// Given a Credentials as input, applies any mutations and returns the
	// potentially updated Credentials, or an error.
	AdjustExpiresBy(Credentials, time.Duration) (Credentials, error)
}
// defaultAdjustExpiresBy shifts the credentials' Expires by dur when the
// credentials can expire; credentials with CanExpire false pass through
// unchanged. It never returns an error.
func defaultAdjustExpiresBy(creds Credentials, dur time.Duration) (Credentials, error) {
	if creds.CanExpire {
		creds.Expires = creds.Expires.Add(dur)
	}
	return creds, nil
}

170
vendor/github.com/aws/aws-sdk-go-v2/aws/credentials.go generated vendored Normal file
View file

@ -0,0 +1,170 @@
package aws
import (
"context"
"fmt"
"reflect"
"time"
"github.com/aws/aws-sdk-go-v2/internal/sdk"
)
// AnonymousCredentials provides a sentinel CredentialsProvider that should be
// used to instruct the SDK's signing middleware to not sign the request.
//
// Using `nil` credentials when configuring an API client will achieve the same
// result. The AnonymousCredentials type allows you to configure the SDK's
// external config loading to not attempt to source credentials from the shared
// config or environment.
//
// For example you can use this CredentialsProvider with an API client's
// Options to instruct the client not to sign a request for accessing public
// S3 bucket objects.
//
// The following example demonstrates using the AnonymousCredentials to prevent
// SDK's external config loading attempt to resolve credentials.
//
//	cfg, err := config.LoadDefaultConfig(context.TODO(),
//		config.WithCredentialsProvider(aws.AnonymousCredentials{}),
//	)
//	if err != nil {
//		log.Fatalf("failed to load config, %v", err)
//	}
//
//	client := s3.NewFromConfig(cfg)
//
// Alternatively you can leave the API client Option's `Credential` member to
// nil. If using the `NewFromConfig` constructor you'll need to explicitly set
// the `Credentials` member to nil, if the external config resolved a
// credential provider.
//
//	client := s3.New(s3.Options{
//		// Credentials defaults to a nil value.
//	})
//
// This can also be configured for specific operations calls too.
//
//	cfg, err := config.LoadDefaultConfig(context.TODO())
//	if err != nil {
//		log.Fatalf("failed to load config, %v", err)
//	}
//
//	client := s3.NewFromConfig(cfg)
//
//	result, err := client.GetObject(context.TODO(), &s3.GetObjectInput{
//		Bucket: aws.String("example-bucket"),
//		Key:    aws.String("example-key"),
//	}, func(o *s3.Options) {
//		o.Credentials = nil
//		// Or
//		o.Credentials = aws.AnonymousCredentials{}
//	})
type AnonymousCredentials struct{}
// Retrieve implements the CredentialsProvider interface, but will always
// return an error, and cannot be used to sign a request. The
// AnonymousCredentials type is used as a sentinel type instructing the AWS
// request signing middleware to not sign a request.
func (AnonymousCredentials) Retrieve(context.Context) (Credentials, error) {
	return Credentials{Source: "AnonymousCredentials"},
		fmt.Errorf("the AnonymousCredentials is not a valid credential provider, and cannot be used to sign AWS requests with")
}
// A Credentials is the AWS credentials value for individual credential fields.
type Credentials struct {
	// AWS Access key ID
	AccessKeyID string

	// AWS Secret Access Key
	SecretAccessKey string

	// AWS Session Token
	SessionToken string

	// Source of the credentials, e.g. the name of the provider that supplied
	// them (AnonymousCredentials sets this to "AnonymousCredentials").
	Source string

	// States if the credentials can expire or not.
	CanExpire bool

	// The time the credentials will expire at. Should be ignored if CanExpire
	// is false.
	Expires time.Time
}
// Expired returns if the credentials have expired.
func (v Credentials) Expired() bool {
	if !v.CanExpire {
		return false
	}
	// Round(0) on the current time truncates only the monotonic clock
	// reading, ensuring expiry is always judged against reported wall-clock
	// time.
	return !v.Expires.After(sdk.NowTime().Round(0))
}
// HasKeys returns if both the access key ID and secret access key are set.
func (v Credentials) HasKeys() bool {
	return v.AccessKeyID != "" && v.SecretAccessKey != ""
}
// A CredentialsProvider is the interface for any component which will provide
// Credentials. A CredentialsProvider is required to manage its own
// Expired state, and what to be expired means.
//
// A credentials provider implementation can be wrapped with a CredentialCache
// to cache the credential value retrieved. Without the cache the SDK will
// attempt to retrieve the credentials for every request.
type CredentialsProvider interface {
	// Retrieve returns nil if it successfully retrieved the value.
	// Error is returned if the value were not obtainable, or empty.
	Retrieve(ctx context.Context) (Credentials, error)
}
// CredentialsProviderFunc provides a helper wrapping a function value to
// satisfy the CredentialsProvider interface.
type CredentialsProviderFunc func(context.Context) (Credentials, error)
// Retrieve delegates to the function value the CredentialsProviderFunc wraps.
func (fn CredentialsProviderFunc) Retrieve(ctx context.Context) (Credentials, error) {
	return fn(ctx)
}
// isCredentialsProvider is an optional interface a CredentialsProvider
// implementation can satisfy to take over the comparison performed by
// IsCredentialsProvider, instead of the default reflection-based type check.
type isCredentialsProvider interface {
	IsCredentialsProvider(CredentialsProvider) bool
}
// IsCredentialsProvider returns whether the target CredentialProvider is the same type as provider when comparing the
// implementation type.
//
// If provider has a method IsCredentialsProvider(CredentialsProvider) bool it will be responsible for validating
// whether target matches the credential provider type.
//
// When comparing the CredentialProvider implementations provider and target for equality, the following rules are used:
//
//	If provider is of type T and target is of type V, true if type *T is the same as type *V, otherwise false
//	If provider is of type *T and target is of type V, true if type *T is the same as type *V, otherwise false
//	If provider is of type T and target is of type *V, true if type *T is the same as type *V, otherwise false
//	If provider is of type *T and target is of type *V, true if type *T is the same as type *V, otherwise false
func IsCredentialsProvider(provider, target CredentialsProvider) bool {
	if provider == nil || target == nil {
		// Equal only when both sides are nil.
		return provider == target
	}

	// Let the provider perform its own comparison when it opts in.
	if custom, ok := provider.(isCredentialsProvider); ok {
		return custom.IsCredentialsProvider(target)
	}

	// Normalize both sides to their pointer type so T and *T compare equal.
	asPointerType := func(v CredentialsProvider) reflect.Type {
		t := reflect.TypeOf(v)
		if t.Kind() != reflect.Ptr {
			t = reflect.PtrTo(t)
		}
		return t
	}

	return asPointerType(target).AssignableTo(asPointerType(provider))
}

View file

@ -0,0 +1,38 @@
package defaults
import (
"github.com/aws/aws-sdk-go-v2/aws"
"runtime"
"strings"
)
// getGOOS reports the operating system target the binary was built for.
// Declared as a package variable so tests can stub the value.
var getGOOS = func() string {
	return runtime.GOOS
}
// ResolveDefaultsModeAuto is used to determine the effective aws.DefaultsMode when the mode
// is set to aws.DefaultsModeAuto.
func ResolveDefaultsModeAuto(region string, environment aws.RuntimeEnvironment) aws.DefaultsMode {
	switch getGOOS() {
	case "android", "ios":
		return aws.DefaultsModeMobile
	}

	// Determine the region the application appears to be executing in, first
	// from the runtime environment hint, then from EC2 instance metadata.
	var detectedRegion string
	if environment.EnvironmentIdentifier != "" {
		detectedRegion = environment.Region
	}
	if detectedRegion == "" && environment.EC2InstanceMetadataRegion != "" {
		detectedRegion = environment.EC2InstanceMetadataRegion
	}

	// Without both a configured and a detected region no comparison can be
	// made, so fall back to the standard defaults.
	if region == "" || detectedRegion == "" {
		return aws.DefaultsModeStandard
	}
	if strings.EqualFold(region, detectedRegion) {
		return aws.DefaultsModeInRegion
	}
	return aws.DefaultsModeCrossRegion
}

View file

@ -0,0 +1,43 @@
package defaults
import (
"time"
"github.com/aws/aws-sdk-go-v2/aws"
)
// Configuration is the set of SDK configuration options that are determined based
// on the configured DefaultsMode.
type Configuration struct {
	// RetryMode is the configuration's default retry mode API clients should
	// use for constructing a Retryer.
	RetryMode aws.RetryMode

	// ConnectTimeout is the maximum amount of time a dial will wait for
	// a connect to complete. A nil value means the timeout is not set.
	//
	// See https://pkg.go.dev/net#Dialer.Timeout
	ConnectTimeout *time.Duration

	// TLSNegotiationTimeout specifies the maximum amount of time to wait
	// for a TLS handshake. A nil value means the timeout is not set.
	//
	// See https://pkg.go.dev/net/http#Transport.TLSHandshakeTimeout
	TLSNegotiationTimeout *time.Duration
}
// GetConnectTimeout returns the ConnectTimeout value, returns false if the value is not set.
func (c *Configuration) GetConnectTimeout() (time.Duration, bool) {
	if v := c.ConnectTimeout; v != nil {
		return *v, true
	}
	return 0, false
}
// GetTLSNegotiationTimeout returns the TLSNegotiationTimeout value, returns false if the value is not set.
func (c *Configuration) GetTLSNegotiationTimeout() (time.Duration, bool) {
	if v := c.TLSNegotiationTimeout; v != nil {
		return *v, true
	}
	return 0, false
}

View file

@ -0,0 +1,50 @@
// Code generated by github.com/aws/aws-sdk-go-v2/internal/codegen/cmd/defaultsconfig. DO NOT EDIT.
package defaults
import (
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
"time"
)
// GetModeConfiguration returns the default Configuration descriptor for the given mode.
//
// Supports the following modes: cross-region, in-region, mobile, standard
func GetModeConfiguration(mode aws.DefaultsMode) (Configuration, error) {
	// Normalize the provided mode so matching is case-insensitive.
	var mv aws.DefaultsMode
	mv.SetFromString(string(mode))

	switch mv {
	case aws.DefaultsModeCrossRegion:
		settings := Configuration{
			ConnectTimeout:        aws.Duration(3100 * time.Millisecond),
			RetryMode:             aws.RetryMode("standard"),
			TLSNegotiationTimeout: aws.Duration(3100 * time.Millisecond),
		}
		return settings, nil
	case aws.DefaultsModeInRegion:
		settings := Configuration{
			ConnectTimeout:        aws.Duration(1100 * time.Millisecond),
			RetryMode:             aws.RetryMode("standard"),
			TLSNegotiationTimeout: aws.Duration(1100 * time.Millisecond),
		}
		return settings, nil
	case aws.DefaultsModeMobile:
		settings := Configuration{
			ConnectTimeout:        aws.Duration(30000 * time.Millisecond),
			RetryMode:             aws.RetryMode("standard"),
			TLSNegotiationTimeout: aws.Duration(30000 * time.Millisecond),
		}
		return settings, nil
	case aws.DefaultsModeStandard:
		settings := Configuration{
			ConnectTimeout:        aws.Duration(3100 * time.Millisecond),
			RetryMode:             aws.RetryMode("standard"),
			TLSNegotiationTimeout: aws.Duration(3100 * time.Millisecond),
		}
		return settings, nil
	default:
		return Configuration{}, fmt.Errorf("unsupported defaults mode: %v", mode)
	}
}

View file

@ -0,0 +1,2 @@
// Package defaults provides recommended configuration values for AWS SDKs and CLIs.
package defaults

View file

@ -0,0 +1,95 @@
// Code generated by github.com/aws/aws-sdk-go-v2/internal/codegen/cmd/defaultsmode. DO NOT EDIT.
package aws
import (
"strings"
)
// DefaultsMode is the SDK defaults mode setting.
type DefaultsMode string

// The DefaultsMode constants.
const (
	// DefaultsModeAuto is an experimental mode that builds on the standard mode.
	// The SDK will attempt to discover the execution environment to determine the
	// appropriate settings automatically.
	//
	// Note that the auto detection is heuristics-based and does not guarantee 100%
	// accuracy. STANDARD mode will be used if the execution environment cannot
	// be determined. The auto detection might query EC2 Instance Metadata service
	// (https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html),
	// which might introduce latency. Therefore we recommend choosing an explicit
	// defaults_mode instead if startup latency is critical to your application.
	DefaultsModeAuto DefaultsMode = "auto"

	// DefaultsModeCrossRegion builds on the standard mode and includes optimization
	// tailored for applications which call AWS services in a different region.
	//
	// Note that the default values vended from this mode might change as best practices
	// may evolve. As a result, it is encouraged to perform tests when upgrading
	// the SDK.
	DefaultsModeCrossRegion DefaultsMode = "cross-region"

	// DefaultsModeInRegion builds on the standard mode and includes optimization
	// tailored for applications which call AWS services from within the same AWS
	// region.
	//
	// Note that the default values vended from this mode might change as best practices
	// may evolve. As a result, it is encouraged to perform tests when upgrading
	// the SDK.
	DefaultsModeInRegion DefaultsMode = "in-region"

	// DefaultsModeLegacy provides default settings that vary per SDK and were used
	// prior to establishment of defaults_mode.
	DefaultsModeLegacy DefaultsMode = "legacy"

	// DefaultsModeMobile builds on the standard mode and includes optimization
	// tailored for mobile applications.
	//
	// Note that the default values vended from this mode might change as best practices
	// may evolve. As a result, it is encouraged to perform tests when upgrading
	// the SDK.
	DefaultsModeMobile DefaultsMode = "mobile"

	// DefaultsModeStandard provides the latest recommended default values that
	// should be safe to run in most scenarios.
	//
	// Note that the default values vended from this mode might change as best practices
	// may evolve. As a result, it is encouraged to perform tests when upgrading
	// the SDK.
	DefaultsModeStandard DefaultsMode = "standard"
)
// SetFromString sets the DefaultsMode value to one of the pre-defined constants that matches
// the provided string when compared using EqualFold. If the value does not match a known
// constant it will be set to as-is and the function will return false. As a special case, if the
// provided value is a zero-length string, the mode will be set to DefaultsModeLegacy.
func (d *DefaultsMode) SetFromString(v string) (ok bool) {
	switch {
	case strings.EqualFold(v, string(DefaultsModeAuto)):
		*d = DefaultsModeAuto
		ok = true
	case strings.EqualFold(v, string(DefaultsModeCrossRegion)):
		*d = DefaultsModeCrossRegion
		ok = true
	case strings.EqualFold(v, string(DefaultsModeInRegion)):
		*d = DefaultsModeInRegion
		ok = true
	case strings.EqualFold(v, string(DefaultsModeLegacy)):
		*d = DefaultsModeLegacy
		ok = true
	case strings.EqualFold(v, string(DefaultsModeMobile)):
		*d = DefaultsModeMobile
		ok = true
	case strings.EqualFold(v, string(DefaultsModeStandard)):
		*d = DefaultsModeStandard
		ok = true
	case len(v) == 0:
		// Empty input defaults to the legacy mode and is considered a match.
		*d = DefaultsModeLegacy
		ok = true
	default:
		// Unknown value: keep it as provided and report no match.
		*d = DefaultsMode(v)
	}
	return ok
}

62
vendor/github.com/aws/aws-sdk-go-v2/aws/doc.go generated vendored Normal file
View file

@ -0,0 +1,62 @@
// Package aws provides the core SDK's utilities and shared types. Use this package's
// utilities to simplify setting and reading API operations parameters.
//
// # Value and Pointer Conversion Utilities
//
// This package includes a helper conversion utility for each scalar type the SDK's
// API use. These utilities make getting a pointer of the scalar, and dereferencing
// a pointer easier.
//
// Each conversion utility comes in two forms. Value to Pointer and Pointer to Value.
// The Pointer to value will safely dereference the pointer and return its value.
// If the pointer was nil, the scalar's zero value will be returned.
//
// The value to pointer functions will be named after the scalar type. So to get
// a *string from a string value use the "String" function. This makes it easy
// to get a pointer of a literal string value, because getting the address of a
// literal requires assigning the value to a variable first.
//
// var strPtr *string
//
// // Without the SDK's conversion functions
// str := "my string"
// strPtr = &str
//
// // With the SDK's conversion functions
// strPtr = aws.String("my string")
//
// // Convert *string to string value
// str = aws.ToString(strPtr)
//
// In addition to scalars the aws package also includes conversion utilities for
// map and slice for commonly used types in API parameters. The map and slice
// conversion functions use similar naming pattern as the scalar conversion
// functions.
//
// var strPtrs []*string
// var strs []string = []string{"Go", "Gophers", "Go"}
//
// // Convert []string to []*string
// strPtrs = aws.StringSlice(strs)
//
// // Convert []*string to []string
// strs = aws.ToStringSlice(strPtrs)
//
// # SDK Default HTTP Client
//
// The SDK will use the http.DefaultClient if a HTTP client is not provided to
// the SDK's Session, or service client constructor. This means that if the
// http.DefaultClient is modified by other components of your application the
// modifications will be picked up by the SDK as well.
//
// In some cases this might be intended, but it is a better practice to create
// a custom HTTP Client to share explicitly through your application. You can
// configure the SDK to use the custom HTTP Client by setting the HTTPClient
// value of the SDK's Config type when creating a Session or service client.
package aws
// generate.go uses a build tag of "ignore", go run doesn't need to specify
// this because go run ignores all build flags when running a go file directly.
//go:generate go run -tags codegen generate.go
//go:generate go run -tags codegen logging_generate.go
//go:generate gofmt -w -s .

229
vendor/github.com/aws/aws-sdk-go-v2/aws/endpoints.go generated vendored Normal file
View file

@ -0,0 +1,229 @@
package aws
import (
"fmt"
)
// DualStackEndpointState is a constant to describe the dual-stack endpoint resolution behavior.
type DualStackEndpointState uint

// The DualStackEndpointState constants.
const (
	// DualStackEndpointStateUnset is the default value behavior for dual-stack endpoint resolution.
	DualStackEndpointStateUnset DualStackEndpointState = iota

	// DualStackEndpointStateEnabled enables dual-stack endpoint resolution for service endpoints.
	DualStackEndpointStateEnabled

	// DualStackEndpointStateDisabled disables dual-stack endpoint resolution for endpoints.
	DualStackEndpointStateDisabled
)
// GetUseDualStackEndpoint takes a service's EndpointResolverOptions and returns the UseDualStackEndpoint value.
// Returns boolean false if the provided options does not have a method to retrieve the DualStackEndpointState.
func GetUseDualStackEndpoint(options ...interface{}) (value DualStackEndpointState, found bool) {
	type dualStackGetter interface {
		GetUseDualStackEndpoint() DualStackEndpointState
	}

	// Return the state from the first option that exposes the getter.
	for _, option := range options {
		getter, ok := option.(dualStackGetter)
		if !ok {
			continue
		}
		return getter.GetUseDualStackEndpoint(), true
	}
	return value, false
}
// FIPSEndpointState is a constant to describe the FIPS endpoint resolution behavior.
type FIPSEndpointState uint

// The FIPSEndpointState constants.
const (
	// FIPSEndpointStateUnset is the default value behavior for FIPS endpoint resolution.
	FIPSEndpointStateUnset FIPSEndpointState = iota

	// FIPSEndpointStateEnabled enables FIPS endpoint resolution for service endpoints.
	FIPSEndpointStateEnabled

	// FIPSEndpointStateDisabled disables FIPS endpoint resolution for endpoints.
	FIPSEndpointStateDisabled
)
// GetUseFIPSEndpoint takes a service's EndpointResolverOptions and returns the UseFIPSEndpoint value.
// Returns boolean false if the provided options does not have a method to retrieve the FIPSEndpointState.
func GetUseFIPSEndpoint(options ...interface{}) (value FIPSEndpointState, found bool) {
	type iface interface {
		GetUseFIPSEndpoint() FIPSEndpointState
	}
	// Use the first option that exposes the getter.
	for _, option := range options {
		if i, ok := option.(iface); ok {
			value = i.GetUseFIPSEndpoint()
			found = true
			break
		}
	}
	return value, found
}
// Endpoint represents the endpoint a service client should make API operation
// calls to.
//
// The SDK will automatically resolve these endpoints per API client using
// internal endpoint resolvers. If you'd like to provide custom endpoint
// resolving behavior you can implement the EndpointResolver interface.
type Endpoint struct {
	// The base URL endpoint the SDK API clients will use to make API calls to.
	// The SDK will suffix URI path and query elements to this endpoint.
	URL string

	// Specifies if the endpoint's hostname can be modified by the SDK's API
	// client.
	//
	// If the hostname is mutable the SDK API clients may modify any part of
	// the hostname based on the requirements of the API, (e.g. adding, or
	// removing content in the hostname). Such as, Amazon S3 API client
	// prefixing "bucketname" to the hostname, or changing the
	// hostname service name component from "s3." to "s3-accesspoint.dualstack."
	// for the dualstack endpoint of an S3 Accesspoint resource.
	//
	// Care should be taken when providing a custom endpoint for an API. If the
	// endpoint hostname is mutable, and the client cannot modify the endpoint
	// correctly, the operation call will most likely fail, or have undefined
	// behavior.
	//
	// If hostname is immutable, the SDK API clients will not modify the
	// hostname of the URL. This may cause the API client not to function
	// correctly if the API requires the operation specific hostname values
	// to be used by the client.
	//
	// This flag does not modify the API client's behavior if this endpoint
	// will be used instead of Endpoint Discovery, or if the endpoint will be
	// used to perform Endpoint Discovery. That behavior is configured via the
	// API Client's Options.
	HostnameImmutable bool

	// The AWS partition the endpoint belongs to.
	PartitionID string

	// The service name that should be used for signing the requests to the
	// endpoint.
	SigningName string

	// The region that should be used for signing the request to the endpoint.
	SigningRegion string

	// The signing method that should be used for signing the requests to the
	// endpoint.
	SigningMethod string

	// The source of the Endpoint. By default, this will be EndpointSourceServiceMetadata.
	// When providing a custom endpoint, you should set the source as EndpointSourceCustom.
	// If source is not provided when providing a custom endpoint, the SDK may not
	// perform required host mutations correctly. Source should be used along with
	// HostnameImmutable property as per the usage requirement.
	Source EndpointSource
}
// EndpointSource is the endpoint source type.
type EndpointSource int

// The EndpointSource constants.
const (
	// EndpointSourceServiceMetadata denotes service modeled endpoint metadata is used as Endpoint Source.
	EndpointSourceServiceMetadata EndpointSource = iota

	// EndpointSourceCustom denotes endpoint is a custom endpoint. This source should be used when
	// user provides a custom endpoint to be used by the SDK.
	EndpointSourceCustom
)
// EndpointNotFoundError is a sentinel error to indicate that the
// EndpointResolver implementation was unable to resolve an endpoint for the
// given service and region. Resolvers should use this to indicate that an API
// client should fallback and attempt to use its internal default resolver to
// resolve the endpoint.
type EndpointNotFoundError struct {
	// Err is the underlying cause of the resolution failure.
	Err error
}

// Error is the error message.
func (e *EndpointNotFoundError) Error() string {
	return fmt.Sprintf("endpoint not found, %v", e.Err)
}

// Unwrap returns the underlying error so errors.Is/As can inspect the cause.
func (e *EndpointNotFoundError) Unwrap() error {
	return e.Err
}
// EndpointResolver is an endpoint resolver that can be used to provide or
// override an endpoint for the given service and region. API clients will
// attempt to use the EndpointResolver first to resolve an endpoint if
// available. If the EndpointResolver returns an EndpointNotFoundError error,
// API clients will fallback to attempting to resolve the endpoint using its
// internal default endpoint resolver.
//
// Deprecated: See EndpointResolverWithOptions.
type EndpointResolver interface {
	// ResolveEndpoint resolves the endpoint for the given service and region.
	ResolveEndpoint(service, region string) (Endpoint, error)
}

// EndpointResolverFunc wraps a function to satisfy the EndpointResolver interface.
//
// Deprecated: See EndpointResolverWithOptionsFunc.
type EndpointResolverFunc func(service, region string) (Endpoint, error)

// ResolveEndpoint calls the wrapped function and returns the results.
//
// Deprecated: See EndpointResolverWithOptions.ResolveEndpoint.
func (e EndpointResolverFunc) ResolveEndpoint(service, region string) (Endpoint, error) {
	return e(service, region)
}
// EndpointResolverWithOptions is an endpoint resolver that can be used to provide or
// override an endpoint for the given service, region, and the service client's EndpointOptions. API clients will
// attempt to use the EndpointResolverWithOptions first to resolve an endpoint if
// available. If the EndpointResolverWithOptions returns an EndpointNotFoundError error,
// API clients will fallback to attempting to resolve the endpoint using its
// internal default endpoint resolver.
type EndpointResolverWithOptions interface {
	// ResolveEndpoint resolves the endpoint for the given service, region, and
	// the service client's endpoint resolver options.
	ResolveEndpoint(service, region string, options ...interface{}) (Endpoint, error)
}

// EndpointResolverWithOptionsFunc wraps a function to satisfy the EndpointResolverWithOptions interface.
type EndpointResolverWithOptionsFunc func(service, region string, options ...interface{}) (Endpoint, error)

// ResolveEndpoint calls the wrapped function and returns the results.
func (e EndpointResolverWithOptionsFunc) ResolveEndpoint(service, region string, options ...interface{}) (Endpoint, error) {
	return e(service, region, options...)
}
// GetDisableHTTPS takes a service's EndpointResolverOptions and returns the DisableHTTPS value.
// Returns boolean false if the provided options does not have a method to retrieve the DisableHTTPS.
func GetDisableHTTPS(options ...interface{}) (value bool, found bool) {
	type httpsDisableGetter interface {
		GetDisableHTTPS() bool
	}

	// Return the flag from the first option that exposes the getter.
	for _, option := range options {
		getter, ok := option.(httpsDisableGetter)
		if !ok {
			continue
		}
		return getter.GetDisableHTTPS(), true
	}
	return value, false
}
// GetResolvedRegion takes a service's EndpointResolverOptions and returns the ResolvedRegion value.
// Returns boolean false if the provided options does not have a method to retrieve the ResolvedRegion.
func GetResolvedRegion(options ...interface{}) (value string, found bool) {
	type resolvedRegionGetter interface {
		GetResolvedRegion() string
	}

	// Return the region from the first option that exposes the getter.
	for _, option := range options {
		getter, ok := option.(resolvedRegionGetter)
		if !ok {
			continue
		}
		return getter.GetResolvedRegion(), true
	}
	return value, false
}

Some files were not shown because too many files have changed in this diff Show more