Bump github.com/go-sql-driver/mysql from 1.8.1 to 1.9.0

Bumps [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) from 1.8.1 to 1.9.0.
- [Release notes](https://github.com/go-sql-driver/mysql/releases)
- [Changelog](https://github.com/go-sql-driver/mysql/blob/master/CHANGELOG.md)
- [Commits](https://github.com/go-sql-driver/mysql/compare/v1.8.1...v1.9.0)

---
updated-dependencies:
- dependency-name: github.com/go-sql-driver/mysql
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
This commit is contained in:
dependabot[bot] 2025-02-18 13:01:41 +00:00 committed by GitHub
parent fefb993745
commit 27f4fc3134
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 720 additions and 506 deletions

2
go.mod
View file

@ -12,7 +12,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/dynamodb v1.39.4 github.com/aws/aws-sdk-go-v2/service/dynamodb v1.39.4
github.com/aws/aws-sdk-go-v2/service/kinesis v1.32.12 github.com/aws/aws-sdk-go-v2/service/kinesis v1.32.12
github.com/awslabs/kinesis-aggregation/go/v2 v2.0.0-20230808105340-e631fe742486 github.com/awslabs/kinesis-aggregation/go/v2 v2.0.0-20230808105340-e631fe742486
github.com/go-sql-driver/mysql v1.8.1 github.com/go-sql-driver/mysql v1.9.0
github.com/lib/pq v1.10.9 github.com/lib/pq v1.10.9
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.20.5 github.com/prometheus/client_golang v1.20.5

4
go.sum
View file

@ -62,8 +62,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.9.0 h1:Y0zIbQXhQKmQgTp44Y1dp3wTXcn804QoTptLZT1vtvo=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-sql-driver/mysql v1.9.0/go.mod h1:pDetrLJeA3oMujJuvXc8RJoasr589B6A9fwzD3QMrqw=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=

View file

@ -20,7 +20,10 @@ Andrew Reid <andrew.reid at tixtrack.com>
Animesh Ray <mail.rayanimesh at gmail.com> Animesh Ray <mail.rayanimesh at gmail.com>
Arne Hormann <arnehormann at gmail.com> Arne Hormann <arnehormann at gmail.com>
Ariel Mashraki <ariel at mashraki.co.il> Ariel Mashraki <ariel at mashraki.co.il>
Artur Melanchyk <artur.melanchyk@gmail.com>
Asta Xie <xiemengjun at gmail.com> Asta Xie <xiemengjun at gmail.com>
B Lamarche <blam413 at gmail.com>
Bes Dollma <bdollma@thousandeyes.com>
Brian Hendriks <brian at dolthub.com> Brian Hendriks <brian at dolthub.com>
Bulat Gaifullin <gaifullinbf at gmail.com> Bulat Gaifullin <gaifullinbf at gmail.com>
Caine Jette <jette at alum.mit.edu> Caine Jette <jette at alum.mit.edu>
@ -33,6 +36,7 @@ Daniel Montoya <dsmontoyam at gmail.com>
Daniel Nichter <nil at codenode.com> Daniel Nichter <nil at codenode.com>
Daniël van Eeden <git at myname.nl> Daniël van Eeden <git at myname.nl>
Dave Protasowski <dprotaso at gmail.com> Dave Protasowski <dprotaso at gmail.com>
Dirkjan Bussink <d.bussink at gmail.com>
DisposaBoy <disposaboy at dby.me> DisposaBoy <disposaboy at dby.me>
Egor Smolyakov <egorsmkv at gmail.com> Egor Smolyakov <egorsmkv at gmail.com>
Erwan Martin <hello at erwan.io> Erwan Martin <hello at erwan.io>
@ -50,6 +54,7 @@ ICHINOSE Shogo <shogo82148 at gmail.com>
Ilia Cimpoes <ichimpoesh at gmail.com> Ilia Cimpoes <ichimpoesh at gmail.com>
INADA Naoki <songofacandy at gmail.com> INADA Naoki <songofacandy at gmail.com>
Jacek Szwec <szwec.jacek at gmail.com> Jacek Szwec <szwec.jacek at gmail.com>
Jakub Adamus <kratky at zobak.cz>
James Harr <james.harr at gmail.com> James Harr <james.harr at gmail.com>
Janek Vedock <janekvedock at comcast.net> Janek Vedock <janekvedock at comcast.net>
Jason Ng <oblitorum at gmail.com> Jason Ng <oblitorum at gmail.com>
@ -60,6 +65,7 @@ Jennifer Purevsuren <jennifer at dolthub.com>
Jerome Meyer <jxmeyer at gmail.com> Jerome Meyer <jxmeyer at gmail.com>
Jiajia Zhong <zhong2plus at gmail.com> Jiajia Zhong <zhong2plus at gmail.com>
Jian Zhen <zhenjl at gmail.com> Jian Zhen <zhenjl at gmail.com>
Joe Mann <contact at joemann.co.uk>
Joshua Prunier <joshua.prunier at gmail.com> Joshua Prunier <joshua.prunier at gmail.com>
Julien Lefevre <julien.lefevr at gmail.com> Julien Lefevre <julien.lefevr at gmail.com>
Julien Schmidt <go-sql-driver at julienschmidt.com> Julien Schmidt <go-sql-driver at julienschmidt.com>
@ -80,6 +86,7 @@ Lunny Xiao <xiaolunwen at gmail.com>
Luke Scott <luke at webconnex.com> Luke Scott <luke at webconnex.com>
Maciej Zimnoch <maciej.zimnoch at codilime.com> Maciej Zimnoch <maciej.zimnoch at codilime.com>
Michael Woolnough <michael.woolnough at gmail.com> Michael Woolnough <michael.woolnough at gmail.com>
Nao Yokotsuka <yokotukanao at gmail.com>
Nathanial Murphy <nathanial.murphy at gmail.com> Nathanial Murphy <nathanial.murphy at gmail.com>
Nicola Peduzzi <thenikso at gmail.com> Nicola Peduzzi <thenikso at gmail.com>
Oliver Bone <owbone at github.com> Oliver Bone <owbone at github.com>
@ -89,6 +96,7 @@ Paul Bonser <misterpib at gmail.com>
Paulius Lozys <pauliuslozys at gmail.com> Paulius Lozys <pauliuslozys at gmail.com>
Peter Schultz <peter.schultz at classmarkets.com> Peter Schultz <peter.schultz at classmarkets.com>
Phil Porada <philporada at gmail.com> Phil Porada <philporada at gmail.com>
Minh Quang <minhquang4334 at gmail.com>
Rebecca Chin <rchin at pivotal.io> Rebecca Chin <rchin at pivotal.io>
Reed Allman <rdallman10 at gmail.com> Reed Allman <rdallman10 at gmail.com>
Richard Wilkes <wilkes at me.com> Richard Wilkes <wilkes at me.com>
@ -139,4 +147,5 @@ PingCAP Inc.
Pivotal Inc. Pivotal Inc.
Shattered Silicon Ltd. Shattered Silicon Ltd.
Stripe Inc. Stripe Inc.
ThousandEyes
Zendesk Inc. Zendesk Inc.

View file

@ -1,3 +1,28 @@
# Changelog
## v1.9.0 (2025-02-18)
### Major Changes
- Implement zlib compression. (#1487)
- Supported Go version is updated to Go 1.21+. (#1639)
- Add support for VECTOR type introduced in MySQL 9.0. (#1609)
- Config object can have custom dial function. (#1527)
### Bugfixes
- Fix auth errors when username/password are too long. (#1625)
- Check if MySQL supports CLIENT_CONNECT_ATTRS before sending client attributes. (#1640)
- Fix auth switch request handling. (#1666)
### Other changes
- Add "filename:line" prefix to log in go-mysql. Custom loggers now show it. (#1589)
- Improve error handling. It reduces the "busy buffer" errors. (#1595, #1601, #1641)
- Use `strconv.Atoi` to parse max_allowed_packet. (#1661)
- `rejectReadOnly` option now handles ER_READ_ONLY_MODE (1290) error too. (#1660)
## Version 1.8.1 (2024-03-26) ## Version 1.8.1 (2024-03-26)
Bugfixes: Bugfixes:

View file

@ -38,11 +38,12 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac
* Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support * Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support
* Optional `time.Time` parsing * Optional `time.Time` parsing
* Optional placeholder interpolation * Optional placeholder interpolation
* Supports zlib compression.
## Requirements ## Requirements
* Go 1.19 or higher. We aim to support the 3 latest versions of Go. * Go 1.21 or higher. We aim to support the 3 latest versions of Go.
* MySQL (5.7+) and MariaDB (10.3+) are supported. * MySQL (5.7+) and MariaDB (10.5+) are supported.
* [TiDB](https://github.com/pingcap/tidb) is supported by PingCAP. * [TiDB](https://github.com/pingcap/tidb) is supported by PingCAP.
* Do not ask questions about TiDB in our issue tracker or forum. * Do not ask questions about TiDB in our issue tracker or forum.
* [Document](https://docs.pingcap.com/tidb/v6.1/dev-guide-sample-application-golang) * [Document](https://docs.pingcap.com/tidb/v6.1/dev-guide-sample-application-golang)
@ -267,6 +268,16 @@ SELECT u.id FROM users as u
will return `u.id` instead of just `id` if `columnsWithAlias=true`. will return `u.id` instead of just `id` if `columnsWithAlias=true`.
##### `compress`
```
Type: bool
Valid Values: true, false
Default: false
```
Toggles zlib compression. false by default.
##### `interpolateParams` ##### `interpolateParams`
``` ```
@ -519,6 +530,9 @@ This driver supports the [`ColumnType` interface](https://golang.org/pkg/databas
Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts.
See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details. See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details.
> [!IMPORTANT]
> The `QueryContext`, `ExecContext`, etc. variants provided by `database/sql` will cause the connection to be closed if the provided context is cancelled or timed out before the result is received by the driver.
### `LOAD DATA LOCAL INFILE` support ### `LOAD DATA LOCAL INFILE` support
For this feature you need direct access to the package. Therefore you must change the import path (no `_`): For this feature you need direct access to the package. Therefore you must change the import path (no `_`):

View file

@ -1,19 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package.
//
// Copyright 2022 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build go1.19
// +build go1.19
package mysql
import "sync/atomic"
/******************************************************************************
* Sync utils *
******************************************************************************/
type atomicBool = atomic.Bool

View file

@ -1,47 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package.
//
// Copyright 2022 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build !go1.19
// +build !go1.19
package mysql
import "sync/atomic"
/******************************************************************************
* Sync utils *
******************************************************************************/
// atomicBool is an implementation of atomic.Bool for older version of Go.
// it is a wrapper around uint32 for usage as a boolean value with
// atomic access.
type atomicBool struct {
_ noCopy
value uint32
}
// Load returns whether the current boolean value is true
func (ab *atomicBool) Load() bool {
return atomic.LoadUint32(&ab.value) > 0
}
// Store sets the value of the bool regardless of the previous value
func (ab *atomicBool) Store(value bool) {
if value {
atomic.StoreUint32(&ab.value, 1)
} else {
atomic.StoreUint32(&ab.value, 0)
}
}
// Swap sets the value of the bool and returns the old value.
func (ab *atomicBool) Swap(value bool) bool {
if value {
return atomic.SwapUint32(&ab.value, 1) > 0
}
return atomic.SwapUint32(&ab.value, 0) > 0
}

View file

@ -10,54 +10,42 @@ package mysql
import ( import (
"io" "io"
"net"
"time"
) )
const defaultBufSize = 4096 const defaultBufSize = 4096
const maxCachedBufSize = 256 * 1024 const maxCachedBufSize = 256 * 1024
// readerFunc is a function type that is compatible with io.Reader's Read
// method. We use this function type instead of io.Reader because we want to
// just pass bound methods such as mc.readWithTimeout without a wrapper type.
type readerFunc func([]byte) (int, error)
// A buffer which is used for both reading and writing. // A buffer which is used for both reading and writing.
// This is possible since communication on each connection is synchronous. // This is possible since communication on each connection is synchronous.
// In other words, we can't write and read simultaneously on the same connection. // In other words, we can't write and read simultaneously on the same connection.
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish // The buffer is similar to bufio.Reader / Writer but zero-copy-ish
// Also highly optimized for this particular use case. // Also highly optimized for this particular use case.
// This buffer is backed by two byte slices in a double-buffering scheme
type buffer struct { type buffer struct {
buf []byte // buf is a byte buffer who's length and capacity are equal. buf []byte // read buffer.
nc net.Conn cachedBuf []byte // buffer that will be reused. len(cachedBuf) <= maxCachedBufSize.
idx int
length int
timeout time.Duration
dbuf [2][]byte // dbuf is an array with the two byte slices that back this buffer
flipcnt uint // flipccnt is the current buffer counter for double-buffering
} }
// newBuffer allocates and returns a new buffer. // newBuffer allocates and returns a new buffer.
func newBuffer(nc net.Conn) buffer { func newBuffer() buffer {
fg := make([]byte, defaultBufSize)
return buffer{ return buffer{
buf: fg, cachedBuf: make([]byte, defaultBufSize),
nc: nc,
dbuf: [2][]byte{fg, nil},
} }
} }
// flip replaces the active buffer with the background buffer // busy returns true if the read buffer is not empty.
// this is a delayed flip that simply increases the buffer counter; func (b *buffer) busy() bool {
// the actual flip will be performed the next time we call `buffer.fill` return len(b.buf) > 0
func (b *buffer) flip() {
b.flipcnt += 1
} }
// fill reads into the buffer until at least _need_ bytes are in it // fill reads into the read buffer until at least _need_ bytes are in it.
func (b *buffer) fill(need int) error { func (b *buffer) fill(need int, r readerFunc) error {
n := b.length // we'll move the contents of the current buffer to dest before filling it.
// fill data into its double-buffering target: if we've called dest := b.cachedBuf
// flip on this buffer, we'll be copying to the background buffer,
// and then filling it with network data; otherwise we'll just move
// the contents of the current buffer to the front before filling it
dest := b.dbuf[b.flipcnt&1]
// grow buffer if necessary to fit the whole packet. // grow buffer if necessary to fit the whole packet.
if need > len(dest) { if need > len(dest) {
@ -67,64 +55,48 @@ func (b *buffer) fill(need int) error {
// if the allocated buffer is not too large, move it to backing storage // if the allocated buffer is not too large, move it to backing storage
// to prevent extra allocations on applications that perform large reads // to prevent extra allocations on applications that perform large reads
if len(dest) <= maxCachedBufSize { if len(dest) <= maxCachedBufSize {
b.dbuf[b.flipcnt&1] = dest b.cachedBuf = dest
} }
} }
// if we're filling the fg buffer, move the existing data to the start of it. // move the existing data to the start of the buffer.
// if we're filling the bg buffer, copy over the data n := len(b.buf)
if n > 0 { copy(dest[:n], b.buf)
copy(dest[:n], b.buf[b.idx:])
}
b.buf = dest
b.idx = 0
for { for {
if b.timeout > 0 { nn, err := r(dest[n:])
if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil {
return err
}
}
nn, err := b.nc.Read(b.buf[n:])
n += nn n += nn
switch err { if err == nil && n < need {
case nil: continue
if n < need {
continue
}
b.length = n
return nil
case io.EOF:
if n >= need {
b.length = n
return nil
}
return io.ErrUnexpectedEOF
default:
return err
} }
b.buf = dest[:n]
if err == io.EOF {
if n < need {
err = io.ErrUnexpectedEOF
} else {
err = nil
}
}
return err
} }
} }
// returns next N bytes from buffer. // returns next N bytes from buffer.
// The returned slice is only guaranteed to be valid until the next read // The returned slice is only guaranteed to be valid until the next read
func (b *buffer) readNext(need int) ([]byte, error) { func (b *buffer) readNext(need int, r readerFunc) ([]byte, error) {
if b.length < need { if len(b.buf) < need {
// refill // refill
if err := b.fill(need); err != nil { if err := b.fill(need, r); err != nil {
return nil, err return nil, err
} }
} }
offset := b.idx data := b.buf[:need]
b.idx += need b.buf = b.buf[need:]
b.length -= need return data, nil
return b.buf[offset:b.idx], nil
} }
// takeBuffer returns a buffer with the requested size. // takeBuffer returns a buffer with the requested size.
@ -132,18 +104,18 @@ func (b *buffer) readNext(need int) ([]byte, error) {
// Otherwise a bigger buffer is made. // Otherwise a bigger buffer is made.
// Only one buffer (total) can be used at a time. // Only one buffer (total) can be used at a time.
func (b *buffer) takeBuffer(length int) ([]byte, error) { func (b *buffer) takeBuffer(length int) ([]byte, error) {
if b.length > 0 { if b.busy() {
return nil, ErrBusyBuffer return nil, ErrBusyBuffer
} }
// test (cheap) general case first // test (cheap) general case first
if length <= cap(b.buf) { if length <= len(b.cachedBuf) {
return b.buf[:length], nil return b.cachedBuf[:length], nil
} }
if length < maxPacketSize { if length < maxCachedBufSize {
b.buf = make([]byte, length) b.cachedBuf = make([]byte, length)
return b.buf, nil return b.cachedBuf, nil
} }
// buffer is larger than we want to store. // buffer is larger than we want to store.
@ -154,10 +126,10 @@ func (b *buffer) takeBuffer(length int) ([]byte, error) {
// known to be smaller than defaultBufSize. // known to be smaller than defaultBufSize.
// Only one buffer (total) can be used at a time. // Only one buffer (total) can be used at a time.
func (b *buffer) takeSmallBuffer(length int) ([]byte, error) { func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
if b.length > 0 { if b.busy() {
return nil, ErrBusyBuffer return nil, ErrBusyBuffer
} }
return b.buf[:length], nil return b.cachedBuf[:length], nil
} }
// takeCompleteBuffer returns the complete existing buffer. // takeCompleteBuffer returns the complete existing buffer.
@ -165,18 +137,15 @@ func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
// cap and len of the returned buffer will be equal. // cap and len of the returned buffer will be equal.
// Only one buffer (total) can be used at a time. // Only one buffer (total) can be used at a time.
func (b *buffer) takeCompleteBuffer() ([]byte, error) { func (b *buffer) takeCompleteBuffer() ([]byte, error) {
if b.length > 0 { if b.busy() {
return nil, ErrBusyBuffer return nil, ErrBusyBuffer
} }
return b.buf, nil return b.cachedBuf, nil
} }
// store stores buf, an updated buffer, if its suitable to do so. // store stores buf, an updated buffer, if its suitable to do so.
func (b *buffer) store(buf []byte) error { func (b *buffer) store(buf []byte) {
if b.length > 0 { if cap(buf) <= maxCachedBufSize && cap(buf) > cap(b.cachedBuf) {
return ErrBusyBuffer b.cachedBuf = buf[:cap(buf)]
} else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) {
b.buf = buf[:cap(buf)]
} }
return nil
} }

View file

@ -8,7 +8,7 @@
package mysql package mysql
const defaultCollation = "utf8mb4_general_ci" const defaultCollationID = 45 // utf8mb4_general_ci
const binaryCollationID = 63 const binaryCollationID = 63
// A list of available collations mapped to the internal ID. // A list of available collations mapped to the internal ID.

214
vendor/github.com/go-sql-driver/mysql/compress.go generated vendored Normal file
View file

@ -0,0 +1,214 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2024 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"compress/zlib"
"fmt"
"io"
"sync"
)
var (
zrPool *sync.Pool // Do not use directly. Use zDecompress() instead.
zwPool *sync.Pool // Do not use directly. Use zCompress() instead.
)
func init() {
zrPool = &sync.Pool{
New: func() any { return nil },
}
zwPool = &sync.Pool{
New: func() any {
zw, err := zlib.NewWriterLevel(new(bytes.Buffer), 2)
if err != nil {
panic(err) // compress/zlib return non-nil error only if level is invalid
}
return zw
},
}
}
func zDecompress(src []byte, dst *bytes.Buffer) (int, error) {
br := bytes.NewReader(src)
var zr io.ReadCloser
var err error
if a := zrPool.Get(); a == nil {
if zr, err = zlib.NewReader(br); err != nil {
return 0, err
}
} else {
zr = a.(io.ReadCloser)
if err := zr.(zlib.Resetter).Reset(br, nil); err != nil {
return 0, err
}
}
n, _ := dst.ReadFrom(zr) // ignore err because zr.Close() will return it again.
err = zr.Close() // zr.Close() may return chuecksum error.
zrPool.Put(zr)
return int(n), err
}
func zCompress(src []byte, dst io.Writer) error {
zw := zwPool.Get().(*zlib.Writer)
zw.Reset(dst)
if _, err := zw.Write(src); err != nil {
return err
}
err := zw.Close()
zwPool.Put(zw)
return err
}
// compIO carries the state needed for the compressed protocol on one
// connection: the connection itself plus a staging buffer holding
// decompressed payload on reads and packet bytes on writes.
type compIO struct {
	mc   *mysqlConn
	buff bytes.Buffer
}

// newCompIO returns a compIO bound to the given connection.
func newCompIO(mc *mysqlConn) *compIO {
	return &compIO{mc: mc}
}

// reset discards any buffered bytes.
func (c *compIO) reset() {
	c.buff.Reset()
}
// readNext returns the next need bytes of decompressed payload, pulling
// and inflating further compressed packets from r as required.
func (c *compIO) readNext(need int, r readerFunc) ([]byte, error) {
	// Keep reading compressed packets until enough payload is buffered.
	for c.buff.Len() < need {
		if err := c.readCompressedPacket(r); err != nil {
			return nil, err
		}
	}
	chunk := c.buff.Next(need)
	// Full-slice expression caps capacity so callers cannot append into
	// c.buff's backing array.
	return chunk[:need:need], nil
}
// readCompressedPacket reads one compressed-protocol frame from r,
// inflates it when necessary, and appends the payload to c.buff.
// It also advances the connection's sequence counters.
func (c *compIO) readCompressedPacket(r readerFunc) error {
	header, err := c.mc.buf.readNext(7, r) // size of compressed header
	if err != nil {
		return err
	}
	_ = header[6] // bounds check hint to compiler; guaranteed by readNext

	// compressed header structure: 3-byte compressed length,
	// 1-byte sequence number, 3-byte uncompressed length.
	comprLength := getUint24(header[0:3])
	compressionSequence := uint8(header[3])
	uncompressedLength := getUint24(header[4:7])
	if debug {
		fmt.Printf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n",
			comprLength, uncompressedLength, compressionSequence, c.mc.sequence)
	}
	// Do not return ErrPktSync here.
	// Server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes)
	// before receiving all packets from client. In this case, seqnr is younger than expected.
	// NOTE: Both of mariadbclient and mysqlclient do not check seqnr. Only server checks it.
	if debug && compressionSequence != c.mc.sequence {
		// fixed: typo ("cmpress") and missing trailing newline in the debug log.
		fmt.Printf("WARN: unexpected compress seq nr: expected %v, got %v\n",
			c.mc.sequence, compressionSequence)
	}
	c.mc.sequence = compressionSequence + 1
	c.mc.compressSequence = c.mc.sequence

	comprData, err := c.mc.buf.readNext(comprLength, r)
	if err != nil {
		return err
	}

	// if payload is uncompressed, its length will be specified as zero, and its
	// true length is contained in comprLength
	if uncompressedLength == 0 {
		c.buff.Write(comprData)
		return nil
	}

	// use existing capacity in c.buff if possible
	c.buff.Grow(uncompressedLength)
	nread, err := zDecompress(comprData, &c.buff)
	if err != nil {
		return err
	}
	if nread != uncompressedLength {
		return fmt.Errorf("invalid compressed packet: uncompressed length in header is %d, actual %d",
			uncompressedLength, nread)
	}
	return nil
}
// minCompressLength is the smallest payload worth compressing; writePackets
// sends shorter payloads uncompressed inside a compressed-protocol frame.
const minCompressLength = 150

// maxPayloadLen caps how much payload goes into a single compressed packet
// (maxPacketSize minus 4 bytes — presumably the per-packet header overhead;
// NOTE(review): confirm against the protocol definition of maxPacketSize).
const maxPayloadLen = maxPacketSize - 4
// writePackets sends one or some packets with compression.
// Use this instead of mc.netConn.Write() when mc.compress is true.
// It splits packets into chunks of at most maxPayloadLen bytes, compresses
// each chunk when that is worthwhile, and returns the total payload bytes
// accepted (all of them on success).
func (c *compIO) writePackets(packets []byte) (int, error) {
	totalBytes := len(packets)
	blankHeader := make([]byte, 7) // placeholder for the 7-byte compressed header
	buf := &c.buff

	for len(packets) > 0 {
		// Carve off at most one compressed packet's worth of payload.
		payloadLen := min(maxPayloadLen, len(packets))
		payload := packets[:payloadLen]
		uncompressedLen := payloadLen

		buf.Reset()
		buf.Write(blankHeader) // Buffer.Write() never returns error

		// If payload is less than minCompressLength, don't compress.
		// uncompressedLen == 0 in the header marks the frame as uncompressed.
		if uncompressedLen < minCompressLength {
			buf.Write(payload)
			uncompressedLen = 0
		} else {
			err := zCompress(payload, buf)
			if debug && err != nil {
				fmt.Printf("zCompress error: %v", err)
			}
			// Fall back to sending the payload uncompressed when compression
			// failed or did not shrink it. The comparison is intentionally
			// biased: buf.Len() still includes the 7-byte header, so
			// compression must save more than 7 bytes to be kept.
			if err != nil || buf.Len() >= uncompressedLen {
				buf.Reset()
				buf.Write(blankHeader)
				buf.Write(payload)
				uncompressedLen = 0
			}
		}

		if n, err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil {
			// To allow returning ErrBadConn when sending really 0 bytes, we sum
			// up compressed bytes that is returned by underlying Write().
			return totalBytes - len(packets) + n, err
		}
		packets = packets[payloadLen:]
	}
	return totalBytes, nil
}
// writeCompressedPacket writes a compressed packet with header.
// data should start with 7 size space for header followed by payload.
// It returns the byte count and error from the underlying write.
func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) (int, error) {
	mc := c.mc
	comprLength := len(data) - 7 // bytes following the reserved 7-byte header
	if debug {
		// fixed: add the trailing newline that the other debug prints in
		// this file emit, so log lines do not run together.
		fmt.Printf(
			"writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v\n",
			comprLength, uncompressedLen, mc.compressSequence)
	}

	// Fill in the compression header in place: 3-byte compressed length,
	// 1-byte sequence number, 3-byte uncompressed length.
	putUint24(data[0:3], comprLength)
	data[3] = mc.compressSequence
	putUint24(data[4:7], uncompressedLen)

	mc.compressSequence++

	return mc.writeWithTimeout(data)
}

View file

@ -13,10 +13,13 @@ import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"net" "net"
"runtime"
"strconv" "strconv"
"strings" "strings"
"sync/atomic"
"time" "time"
) )
@ -25,15 +28,17 @@ type mysqlConn struct {
netConn net.Conn netConn net.Conn
rawConn net.Conn // underlying connection when netConn is TLS connection. rawConn net.Conn // underlying connection when netConn is TLS connection.
result mysqlResult // managed by clearResult() and handleOkPacket(). result mysqlResult // managed by clearResult() and handleOkPacket().
compIO *compIO
cfg *Config cfg *Config
connector *connector connector *connector
maxAllowedPacket int maxAllowedPacket int
maxWriteSize int maxWriteSize int
writeTimeout time.Duration
flags clientFlag flags clientFlag
status statusFlag status statusFlag
sequence uint8 sequence uint8
compressSequence uint8
parseTime bool parseTime bool
compress bool
// for context support (Go 1.8+) // for context support (Go 1.8+)
watching bool watching bool
@ -41,71 +46,92 @@ type mysqlConn struct {
closech chan struct{} closech chan struct{}
finished chan<- struct{} finished chan<- struct{}
canceled atomicError // set non-nil if conn is canceled canceled atomicError // set non-nil if conn is canceled
closed atomicBool // set when conn is closed, before closech is closed closed atomic.Bool // set when conn is closed, before closech is closed
} }
// Helper function to call per-connection logger. // Helper function to call per-connection logger.
func (mc *mysqlConn) log(v ...any) { func (mc *mysqlConn) log(v ...any) {
_, filename, lineno, ok := runtime.Caller(1)
if ok {
pos := strings.LastIndexByte(filename, '/')
if pos != -1 {
filename = filename[pos+1:]
}
prefix := fmt.Sprintf("%s:%d ", filename, lineno)
v = append([]any{prefix}, v...)
}
mc.cfg.Logger.Print(v...) mc.cfg.Logger.Print(v...)
} }
// readWithTimeout reads from the connection, first arming the configured
// ReadTimeout as a deadline when one is set.
func (mc *mysqlConn) readWithTimeout(b []byte) (int, error) {
	if timeout := mc.cfg.ReadTimeout; timeout > 0 {
		deadline := time.Now().Add(timeout)
		if err := mc.netConn.SetReadDeadline(deadline); err != nil {
			return 0, err
		}
	}
	return mc.netConn.Read(b)
}
// writeWithTimeout writes to the connection, first arming the configured
// WriteTimeout as a deadline when one is set.
func (mc *mysqlConn) writeWithTimeout(b []byte) (int, error) {
	if timeout := mc.cfg.WriteTimeout; timeout > 0 {
		deadline := time.Now().Add(timeout)
		if err := mc.netConn.SetWriteDeadline(deadline); err != nil {
			return 0, err
		}
	}
	return mc.netConn.Write(b)
}
// resetSequence rewinds both the plain and the compressed packet
// sequence counters to zero.
func (mc *mysqlConn) resetSequence() {
	mc.compressSequence = 0
	mc.sequence = 0
}
// syncSequence must be called when finished writing some packet and before
// starting to read, so the plain sequence number catches up with the
// compressed one. No-op when compression is off.
func (mc *mysqlConn) syncSequence() {
	if !mc.compress {
		return
	}
	// Syncing compressionSequence to sequence is not documented but matches
	// `net_flush()` in both MySQL and MariaDB:
	// https://github.com/mariadb-corporation/mariadb-connector-c/blob/8228164f850b12353da24df1b93a1e53cc5e85e9/libmariadb/ma_net.c#L170-L171
	// https://github.com/mysql/mysql-server/blob/824e2b4064053f7daf17d7f3f84b7a3ed92e5fb4/sql-common/net_serv.cc#L293
	mc.sequence = mc.compressSequence
	mc.compIO.reset()
}
// Handles parameters set in DSN after the connection is established // Handles parameters set in DSN after the connection is established
func (mc *mysqlConn) handleParams() (err error) { func (mc *mysqlConn) handleParams() (err error) {
var cmdSet strings.Builder var cmdSet strings.Builder
for param, val := range mc.cfg.Params { for param, val := range mc.cfg.Params {
switch param { if cmdSet.Len() == 0 {
// Charset: character_set_connection, character_set_client, character_set_results // Heuristic: 29 chars for each other key=value to reduce reallocations
case "charset": cmdSet.Grow(4 + len(param) + 3 + len(val) + 30*(len(mc.cfg.Params)-1))
charsets := strings.Split(val, ",") cmdSet.WriteString("SET ")
for _, cs := range charsets { } else {
// ignore errors here - a charset may not exist cmdSet.WriteString(", ")
if mc.cfg.Collation != "" {
err = mc.exec("SET NAMES " + cs + " COLLATE " + mc.cfg.Collation)
} else {
err = mc.exec("SET NAMES " + cs)
}
if err == nil {
break
}
}
if err != nil {
return
}
// Other system vars accumulated in a single SET command
default:
if cmdSet.Len() == 0 {
// Heuristic: 29 chars for each other key=value to reduce reallocations
cmdSet.Grow(4 + len(param) + 3 + len(val) + 30*(len(mc.cfg.Params)-1))
cmdSet.WriteString("SET ")
} else {
cmdSet.WriteString(", ")
}
cmdSet.WriteString(param)
cmdSet.WriteString(" = ")
cmdSet.WriteString(val)
} }
cmdSet.WriteString(param)
cmdSet.WriteString(" = ")
cmdSet.WriteString(val)
} }
if cmdSet.Len() > 0 { if cmdSet.Len() > 0 {
err = mc.exec(cmdSet.String()) err = mc.exec(cmdSet.String())
if err != nil {
return
}
} }
return return
} }
// markBadConn replaces errBadConnNoWrite with driver.ErrBadConn.
// This function is used to return driver.ErrBadConn only when safe to retry.
func (mc *mysqlConn) markBadConn(err error) error { func (mc *mysqlConn) markBadConn(err error) error {
if mc == nil { if err == errBadConnNoWrite {
return err return driver.ErrBadConn
} }
if err != errBadConnNoWrite { return err
return err
}
return driver.ErrBadConn
} }
func (mc *mysqlConn) Begin() (driver.Tx, error) { func (mc *mysqlConn) Begin() (driver.Tx, error) {
@ -114,7 +140,6 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
if mc.closed.Load() { if mc.closed.Load() {
mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
var q string var q string
@ -135,10 +160,14 @@ func (mc *mysqlConn) Close() (err error) {
if !mc.closed.Load() { if !mc.closed.Load() {
err = mc.writeCommandPacket(comQuit) err = mc.writeCommandPacket(comQuit)
} }
mc.close()
return
}
// close closes the network connection and clear results without sending COM_QUIT.
func (mc *mysqlConn) close() {
mc.cleanup() mc.cleanup()
mc.clearResult() mc.clearResult()
return
} }
// Closes the network connection and unsets internal variables. Do not call this // Closes the network connection and unsets internal variables. Do not call this
@ -157,7 +186,7 @@ func (mc *mysqlConn) cleanup() {
return return
} }
if err := conn.Close(); err != nil { if err := conn.Close(); err != nil {
mc.log(err) mc.log("closing connection:", err)
} }
// This function can be called from multiple goroutines. // This function can be called from multiple goroutines.
// So we can not mc.clearResult() here. // So we can not mc.clearResult() here.
@ -176,7 +205,6 @@ func (mc *mysqlConn) error() error {
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
if mc.closed.Load() { if mc.closed.Load() {
mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
// Send command // Send command
@ -217,8 +245,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
buf, err := mc.buf.takeCompleteBuffer() buf, err := mc.buf.takeCompleteBuffer()
if err != nil { if err != nil {
// can not take the buffer. Something must be wrong with the connection // can not take the buffer. Something must be wrong with the connection
mc.log(err) mc.cleanup()
return "", ErrInvalidConn // interpolateParams would be called before sending any query.
// So its safe to retry.
return "", driver.ErrBadConn
} }
buf = buf[:0] buf = buf[:0]
argPos := 0 argPos := 0
@ -309,7 +339,6 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.closed.Load() { if mc.closed.Load() {
mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
if len(args) != 0 { if len(args) != 0 {
@ -369,7 +398,6 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
handleOk := mc.clearResult() handleOk := mc.clearResult()
if mc.closed.Load() { if mc.closed.Load() {
mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
if len(args) != 0 { if len(args) != 0 {
@ -385,31 +413,34 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
} }
// Send command // Send command
err := mc.writeCommandPacketStr(comQuery, query) err := mc.writeCommandPacketStr(comQuery, query)
if err == nil { if err != nil {
// Read Result return nil, mc.markBadConn(err)
var resLen int }
resLen, err = handleOk.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
if resLen == 0 { // Read Result
rows.rs.done = true var resLen int
resLen, err = handleOk.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
switch err := rows.NextResultSet(); err { rows := new(textRows)
case nil, io.EOF: rows.mc = mc
return rows, nil
default:
return nil, err
}
}
// Columns if resLen == 0 {
rows.rs.columns, err = mc.readColumns(resLen) rows.rs.done = true
return rows, err
switch err := rows.NextResultSet(); err {
case nil, io.EOF:
return rows, nil
default:
return nil, err
} }
} }
return nil, mc.markBadConn(err)
// Columns
rows.rs.columns, err = mc.readColumns(resLen)
return rows, err
} }
// Gets the value of the given MySQL System Variable // Gets the value of the given MySQL System Variable
@ -443,7 +474,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
return nil, err return nil, err
} }
// finish is called when the query has canceled. // cancel is called when the query has canceled.
func (mc *mysqlConn) cancel(err error) { func (mc *mysqlConn) cancel(err error) {
mc.canceled.Set(err) mc.canceled.Set(err)
mc.cleanup() mc.cleanup()
@ -464,7 +495,6 @@ func (mc *mysqlConn) finish() {
// Ping implements driver.Pinger interface // Ping implements driver.Pinger interface
func (mc *mysqlConn) Ping(ctx context.Context) (err error) { func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
if mc.closed.Load() { if mc.closed.Load() {
mc.log(ErrInvalidConn)
return driver.ErrBadConn return driver.ErrBadConn
} }
@ -650,7 +680,7 @@ func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
// ResetSession implements driver.SessionResetter. // ResetSession implements driver.SessionResetter.
// (From Go 1.10) // (From Go 1.10)
func (mc *mysqlConn) ResetSession(ctx context.Context) error { func (mc *mysqlConn) ResetSession(ctx context.Context) error {
if mc.closed.Load() { if mc.closed.Load() || mc.buf.busy() {
return driver.ErrBadConn return driver.ErrBadConn
} }
@ -684,5 +714,8 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error {
// IsValid implements driver.Validator interface // IsValid implements driver.Validator interface
// (From Go 1.15) // (From Go 1.15)
func (mc *mysqlConn) IsValid() bool { func (mc *mysqlConn) IsValid() bool {
return !mc.closed.Load() return !mc.closed.Load() && !mc.buf.busy()
} }
var _ driver.SessionResetter = &mysqlConn{}
var _ driver.Validator = &mysqlConn{}

View file

@ -11,6 +11,7 @@ package mysql
import ( import (
"context" "context"
"database/sql/driver" "database/sql/driver"
"fmt"
"net" "net"
"os" "os"
"strconv" "strconv"
@ -87,20 +88,25 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
mc.parseTime = mc.cfg.ParseTime mc.parseTime = mc.cfg.ParseTime
// Connect to Server // Connect to Server
dialsLock.RLock() dctx := ctx
dial, ok := dials[mc.cfg.Net] if mc.cfg.Timeout > 0 {
dialsLock.RUnlock() var cancel context.CancelFunc
if ok { dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout)
dctx := ctx defer cancel()
if mc.cfg.Timeout > 0 { }
var cancel context.CancelFunc
dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout) if c.cfg.DialFunc != nil {
defer cancel() mc.netConn, err = c.cfg.DialFunc(dctx, mc.cfg.Net, mc.cfg.Addr)
}
mc.netConn, err = dial(dctx, mc.cfg.Addr)
} else { } else {
nd := net.Dialer{Timeout: mc.cfg.Timeout} dialsLock.RLock()
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr) dial, ok := dials[mc.cfg.Net]
dialsLock.RUnlock()
if ok {
mc.netConn, err = dial(dctx, mc.cfg.Addr)
} else {
nd := net.Dialer{}
mc.netConn, err = nd.DialContext(dctx, mc.cfg.Net, mc.cfg.Addr)
}
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -122,11 +128,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
} }
defer mc.finish() defer mc.finish()
mc.buf = newBuffer(mc.netConn) mc.buf = newBuffer()
// Set I/O timeouts
mc.buf.timeout = mc.cfg.ReadTimeout
mc.writeTimeout = mc.cfg.WriteTimeout
// Reading Handshake Initialization Packet // Reading Handshake Initialization Packet
authData, plugin, err := mc.readHandshakePacket() authData, plugin, err := mc.readHandshakePacket()
@ -165,6 +167,10 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
return nil, err return nil, err
} }
if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
mc.compress = true
mc.compIO = newCompIO(mc)
}
if mc.cfg.MaxAllowedPacket > 0 { if mc.cfg.MaxAllowedPacket > 0 {
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
} else { } else {
@ -174,12 +180,36 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
mc.Close() mc.Close()
return nil, err return nil, err
} }
mc.maxAllowedPacket = stringToInt(maxap) - 1 n, err := strconv.Atoi(string(maxap))
if err != nil {
mc.Close()
return nil, fmt.Errorf("invalid max_allowed_packet value (%q): %w", maxap, err)
}
mc.maxAllowedPacket = n - 1
} }
if mc.maxAllowedPacket < maxPacketSize { if mc.maxAllowedPacket < maxPacketSize {
mc.maxWriteSize = mc.maxAllowedPacket mc.maxWriteSize = mc.maxAllowedPacket
} }
// Charset: character_set_connection, character_set_client, character_set_results
if len(mc.cfg.charsets) > 0 {
for _, cs := range mc.cfg.charsets {
// ignore errors here - a charset may not exist
if mc.cfg.Collation != "" {
err = mc.exec("SET NAMES " + cs + " COLLATE " + mc.cfg.Collation)
} else {
err = mc.exec("SET NAMES " + cs)
}
if err == nil {
break
}
}
if err != nil {
mc.Close()
return nil, err
}
}
// Handle DSN Params // Handle DSN Params
err = mc.handleParams() err = mc.handleParams()
if err != nil { if err != nil {

View file

@ -11,6 +11,8 @@ package mysql
import "runtime" import "runtime"
const ( const (
debug = false // for debugging. Set true only in development.
defaultAuthPlugin = "mysql_native_password" defaultAuthPlugin = "mysql_native_password"
defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355 defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355
minProtocolVersion = 10 minProtocolVersion = 10
@ -125,7 +127,10 @@ const (
fieldTypeBit fieldTypeBit
) )
const ( const (
fieldTypeJSON fieldType = iota + 0xf5 fieldTypeVector fieldType = iota + 0xf2
fieldTypeInvalid
fieldTypeBool
fieldTypeJSON
fieldTypeNewDecimal fieldTypeNewDecimal
fieldTypeEnum fieldTypeEnum
fieldTypeSet fieldTypeSet

View file

@ -44,7 +44,8 @@ type Config struct {
DBName string // Database name DBName string // Database name
Params map[string]string // Connection parameters Params map[string]string // Connection parameters
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
Collation string // Connection collation charsets []string // Connection charset. When set, this will be set in SET NAMES <charset> query
Collation string // Connection collation. When set, this will be set in SET NAMES <charset> COLLATE <collation> query
Loc *time.Location // Location for time.Time values Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name ServerPubKey string // Server public key name
@ -54,6 +55,8 @@ type Config struct {
ReadTimeout time.Duration // I/O read timeout ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout WriteTimeout time.Duration // I/O write timeout
Logger Logger // Logger Logger Logger // Logger
// DialFunc specifies the dial function for creating connections
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
// boolean fields // boolean fields
@ -70,7 +73,10 @@ type Config struct {
ParseTime bool // Parse time values to time.Time ParseTime bool // Parse time values to time.Time
RejectReadOnly bool // Reject read-only connections RejectReadOnly bool // Reject read-only connections
// unexported fields. new options should be come here // unexported fields. new options should be come here.
// boolean first. alphabetical order.
compress bool // Enable zlib compression
beforeConnect func(context.Context, *Config) error // Invoked before a connection is established beforeConnect func(context.Context, *Config) error // Invoked before a connection is established
pubKey *rsa.PublicKey // Server public key pubKey *rsa.PublicKey // Server public key
@ -90,7 +96,6 @@ func NewConfig() *Config {
AllowNativePasswords: true, AllowNativePasswords: true,
CheckConnLiveness: true, CheckConnLiveness: true,
} }
return cfg return cfg
} }
@ -122,6 +127,14 @@ func BeforeConnect(fn func(context.Context, *Config) error) Option {
} }
} }
// EnableCompress sets the compression mode.
func EnableCompression(yes bool) Option {
return func(cfg *Config) error {
cfg.compress = yes
return nil
}
}
func (cfg *Config) Clone() *Config { func (cfg *Config) Clone() *Config {
cp := *cfg cp := *cfg
if cp.TLS != nil { if cp.TLS != nil {
@ -282,6 +295,10 @@ func (cfg *Config) FormatDSN() string {
writeDSNParam(&buf, &hasParam, "clientFoundRows", "true") writeDSNParam(&buf, &hasParam, "clientFoundRows", "true")
} }
if charsets := cfg.charsets; len(charsets) > 0 {
writeDSNParam(&buf, &hasParam, "charset", strings.Join(charsets, ","))
}
if col := cfg.Collation; col != "" { if col := cfg.Collation; col != "" {
writeDSNParam(&buf, &hasParam, "collation", col) writeDSNParam(&buf, &hasParam, "collation", col)
} }
@ -290,6 +307,10 @@ func (cfg *Config) FormatDSN() string {
writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true") writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true")
} }
if cfg.compress {
writeDSNParam(&buf, &hasParam, "compress", "true")
}
if cfg.InterpolateParams { if cfg.InterpolateParams {
writeDSNParam(&buf, &hasParam, "interpolateParams", "true") writeDSNParam(&buf, &hasParam, "interpolateParams", "true")
} }
@ -501,6 +522,10 @@ func parseDSNParams(cfg *Config, params string) (err error) {
return errors.New("invalid bool value: " + value) return errors.New("invalid bool value: " + value)
} }
// charset
case "charset":
cfg.charsets = strings.Split(value, ",")
// Collation // Collation
case "collation": case "collation":
cfg.Collation = value cfg.Collation = value
@ -514,7 +539,11 @@ func parseDSNParams(cfg *Config, params string) (err error) {
// Compression // Compression
case "compress": case "compress":
return errors.New("compression not implemented yet") var isBool bool
cfg.compress, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Enable client side placeholder substitution // Enable client side placeholder substitution
case "interpolateParams": case "interpolateParams":

View file

@ -32,12 +32,12 @@ var (
// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet. // errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn // If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
// to trigger a resend. // to trigger a resend. Use mc.markBadConn(err) to do this.
// See https://github.com/go-sql-driver/mysql/pull/302 // See https://github.com/go-sql-driver/mysql/pull/302
errBadConnNoWrite = errors.New("bad connection") errBadConnNoWrite = errors.New("bad connection")
) )
var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime))
// Logger is used to log critical error messages. // Logger is used to log critical error messages.
type Logger interface { type Logger interface {

View file

@ -112,6 +112,8 @@ func (mf *mysqlField) typeDatabaseName() string {
return "VARCHAR" return "VARCHAR"
case fieldTypeYear: case fieldTypeYear:
return "YEAR" return "YEAR"
case fieldTypeVector:
return "VECTOR"
default: default:
return "" return ""
} }
@ -198,7 +200,7 @@ func (mf *mysqlField) scanType() reflect.Type {
return scanTypeNullFloat return scanTypeNullFloat
case fieldTypeBit, fieldTypeTinyBLOB, fieldTypeMediumBLOB, fieldTypeLongBLOB, case fieldTypeBit, fieldTypeTinyBLOB, fieldTypeMediumBLOB, fieldTypeLongBLOB,
fieldTypeBLOB, fieldTypeVarString, fieldTypeString, fieldTypeGeometry: fieldTypeBLOB, fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeVector:
if mf.charSet == binaryCollationID { if mf.charSet == binaryCollationID {
return scanTypeBytes return scanTypeBytes
} }

View file

@ -17,7 +17,7 @@ import (
) )
var ( var (
fileRegister map[string]bool fileRegister map[string]struct{}
fileRegisterLock sync.RWMutex fileRegisterLock sync.RWMutex
readerRegister map[string]func() io.Reader readerRegister map[string]func() io.Reader
readerRegisterLock sync.RWMutex readerRegisterLock sync.RWMutex
@ -37,10 +37,10 @@ func RegisterLocalFile(filePath string) {
fileRegisterLock.Lock() fileRegisterLock.Lock()
// lazy map init // lazy map init
if fileRegister == nil { if fileRegister == nil {
fileRegister = make(map[string]bool) fileRegister = make(map[string]struct{})
} }
fileRegister[strings.Trim(filePath, `"`)] = true fileRegister[strings.Trim(filePath, `"`)] = struct{}{}
fileRegisterLock.Unlock() fileRegisterLock.Unlock()
} }
@ -95,7 +95,6 @@ const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead a
func (mc *okHandler) handleInFileRequest(name string) (err error) { func (mc *okHandler) handleInFileRequest(name string) (err error) {
var rdr io.Reader var rdr io.Reader
var data []byte
packetSize := defaultPacketSize packetSize := defaultPacketSize
if mc.maxWriteSize < packetSize { if mc.maxWriteSize < packetSize {
packetSize = mc.maxWriteSize packetSize = mc.maxWriteSize
@ -124,9 +123,9 @@ func (mc *okHandler) handleInFileRequest(name string) (err error) {
} else { // File } else { // File
name = strings.Trim(name, `"`) name = strings.Trim(name, `"`)
fileRegisterLock.RLock() fileRegisterLock.RLock()
fr := fileRegister[name] _, exists := fileRegister[name]
fileRegisterLock.RUnlock() fileRegisterLock.RUnlock()
if mc.cfg.AllowAllFiles || fr { if mc.cfg.AllowAllFiles || exists {
var file *os.File var file *os.File
var fi os.FileInfo var fi os.FileInfo
@ -147,9 +146,11 @@ func (mc *okHandler) handleInFileRequest(name string) (err error) {
} }
// send content packets // send content packets
var data []byte
// if packetSize == 0, the Reader contains no data // if packetSize == 0, the Reader contains no data
if err == nil && packetSize > 0 { if err == nil && packetSize > 0 {
data := make([]byte, 4+packetSize) data = make([]byte, 4+packetSize)
var n int var n int
for err == nil { for err == nil {
n, err = rdr.Read(data[4:]) n, err = rdr.Read(data[4:])
@ -171,6 +172,7 @@ func (mc *okHandler) handleInFileRequest(name string) (err error) {
if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil { if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil {
return ioErr return ioErr
} }
mc.conn().syncSequence()
// read OK packet // read OK packet
if err == nil { if err == nil {

View file

@ -21,36 +21,56 @@ import (
"time" "time"
) )
// Packets documentation: // MySQL client/server protocol documentations.
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html // https://dev.mysql.com/doc/dev/mysql-server/latest/PAGE_PROTOCOL.html
// https://mariadb.com/kb/en/clientserver-protocol/
// Read packet to buffer 'data' // Read packet to buffer 'data'
func (mc *mysqlConn) readPacket() ([]byte, error) { func (mc *mysqlConn) readPacket() ([]byte, error) {
var prevData []byte var prevData []byte
invalidSequence := false
readNext := mc.buf.readNext
if mc.compress {
readNext = mc.compIO.readNext
}
for { for {
// read packet header // read packet header
data, err := mc.buf.readNext(4) data, err := readNext(4, mc.readWithTimeout)
if err != nil { if err != nil {
mc.close()
if cerr := mc.canceled.Value(); cerr != nil { if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr return nil, cerr
} }
mc.log(err) mc.log(err)
mc.Close()
return nil, ErrInvalidConn return nil, ErrInvalidConn
} }
// packet length [24 bit] // packet length [24 bit]
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) pktLen := getUint24(data[:3])
seq := data[3]
// check packet sync [8 bit] if mc.compress {
if data[3] != mc.sequence { // MySQL and MariaDB doesn't check packet nr in compressed packet.
mc.Close() if debug && seq != mc.compressSequence {
if data[3] > mc.sequence { fmt.Printf("[debug] mismatched compression sequence nr: expected: %v, got %v",
return nil, ErrPktSyncMul mc.compressSequence, seq)
} }
return nil, ErrPktSync mc.compressSequence = seq + 1
} else {
// check packet sync [8 bit]
if seq != mc.sequence {
mc.log(fmt.Sprintf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, seq))
// For large packets, we stop reading as soon as sync error.
if len(prevData) > 0 {
mc.close()
return nil, ErrPktSyncMul
}
invalidSequence = true
}
mc.sequence++
} }
mc.sequence++
// packets with length 0 terminate a previous packet which is a // packets with length 0 terminate a previous packet which is a
// multiple of (2^24)-1 bytes long // multiple of (2^24)-1 bytes long
@ -58,32 +78,38 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
// there was no previous packet // there was no previous packet
if prevData == nil { if prevData == nil {
mc.log(ErrMalformPkt) mc.log(ErrMalformPkt)
mc.Close() mc.close()
return nil, ErrInvalidConn return nil, ErrInvalidConn
} }
return prevData, nil return prevData, nil
} }
// read packet body [pktLen bytes] // read packet body [pktLen bytes]
data, err = mc.buf.readNext(pktLen) data, err = readNext(pktLen, mc.readWithTimeout)
if err != nil { if err != nil {
mc.close()
if cerr := mc.canceled.Value(); cerr != nil { if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr return nil, cerr
} }
mc.log(err) mc.log(err)
mc.Close()
return nil, ErrInvalidConn return nil, ErrInvalidConn
} }
// return data if this was the last packet // return data if this was the last packet
if pktLen < maxPacketSize { if pktLen < maxPacketSize {
// zero allocations for non-split packets // zero allocations for non-split packets
if prevData == nil { if prevData != nil {
return data, nil data = append(prevData, data...)
} }
if invalidSequence {
return append(prevData, data...), nil mc.close()
// return sync error only for regular packet.
// error packets may have wrong sequence number.
if data[0] != iERR {
return nil, ErrPktSync
}
}
return data, nil
} }
prevData = append(prevData, data...) prevData = append(prevData, data...)
@ -93,60 +119,52 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
// Write packet buffer 'data' // Write packet buffer 'data'
func (mc *mysqlConn) writePacket(data []byte) error { func (mc *mysqlConn) writePacket(data []byte) error {
pktLen := len(data) - 4 pktLen := len(data) - 4
if pktLen > mc.maxAllowedPacket { if pktLen > mc.maxAllowedPacket {
return ErrPktTooLarge return ErrPktTooLarge
} }
writeFunc := mc.writeWithTimeout
if mc.compress {
writeFunc = mc.compIO.writePackets
}
for { for {
var size int size := min(maxPacketSize, pktLen)
if pktLen >= maxPacketSize { putUint24(data[:3], size)
data[0] = 0xff
data[1] = 0xff
data[2] = 0xff
size = maxPacketSize
} else {
data[0] = byte(pktLen)
data[1] = byte(pktLen >> 8)
data[2] = byte(pktLen >> 16)
size = pktLen
}
data[3] = mc.sequence data[3] = mc.sequence
// Write packet // Write packet
if mc.writeTimeout > 0 { if debug {
if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { fmt.Printf("writePacket: size=%v seq=%v", size, mc.sequence)
return err
}
} }
n, err := mc.netConn.Write(data[:4+size]) n, err := writeFunc(data[:4+size])
if err == nil && n == 4+size { if err != nil {
mc.sequence++
if size != maxPacketSize {
return nil
}
pktLen -= size
data = data[size:]
continue
}
// Handle error
if err == nil { // n != len(data)
mc.cleanup() mc.cleanup()
mc.log(ErrMalformPkt)
} else {
if cerr := mc.canceled.Value(); cerr != nil { if cerr := mc.canceled.Value(); cerr != nil {
return cerr return cerr
} }
if n == 0 && pktLen == len(data)-4 { if n == 0 && pktLen == len(data)-4 {
// only for the first loop iteration when nothing was written yet // only for the first loop iteration when nothing was written yet
mc.log(err)
return errBadConnNoWrite return errBadConnNoWrite
} else {
return err
} }
mc.cleanup()
mc.log(err)
} }
return ErrInvalidConn if n != 4+size {
// io.Writer(b) must return a non-nil error if it cannot write len(b) bytes.
// The io.ErrShortWrite error is used to indicate that this rule has not been followed.
mc.cleanup()
return io.ErrShortWrite
}
mc.sequence++
if size != maxPacketSize {
return nil
}
pktLen -= size
data = data[size:]
} }
} }
@ -159,11 +177,6 @@ func (mc *mysqlConn) writePacket(data []byte) error {
func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) {
data, err = mc.readPacket() data, err = mc.readPacket()
if err != nil { if err != nil {
// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
// in connection initialization we don't risk retrying non-idempotent actions.
if err == ErrInvalidConn {
return nil, "", driver.ErrBadConn
}
return return
} }
@ -207,10 +220,13 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
if len(data) > pos { if len(data) > pos {
// character set [1 byte] // character set [1 byte]
// status flags [2 bytes] // status flags [2 bytes]
pos += 3
// capability flags (upper 2 bytes) [2 bytes] // capability flags (upper 2 bytes) [2 bytes]
mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
pos += 2
// length of auth-plugin-data [1 byte] // length of auth-plugin-data [1 byte]
// reserved (all [00]) [10 bytes] // reserved (all [00]) [10 bytes]
pos += 1 + 2 + 2 + 1 + 10 pos += 11
// second part of the password cipher [minimum 13 bytes], // second part of the password cipher [minimum 13 bytes],
// where len=MAX(13, length of auth-plugin-data - 8) // where len=MAX(13, length of auth-plugin-data - 8)
@ -258,13 +274,17 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
clientLocalFiles | clientLocalFiles |
clientPluginAuth | clientPluginAuth |
clientMultiResults | clientMultiResults |
clientConnectAttrs | mc.flags&clientConnectAttrs |
mc.flags&clientLongFlag mc.flags&clientLongFlag
sendConnectAttrs := mc.flags&clientConnectAttrs != 0
if mc.cfg.ClientFoundRows { if mc.cfg.ClientFoundRows {
clientFlags |= clientFoundRows clientFlags |= clientFoundRows
} }
if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
clientFlags |= clientCompress
}
// To enable TLS / SSL // To enable TLS / SSL
if mc.cfg.TLS != nil { if mc.cfg.TLS != nil {
clientFlags |= clientSSL clientFlags |= clientSSL
@ -293,43 +313,37 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
} }
// encode length of the connection attributes // encode length of the connection attributes
var connAttrsLEIBuf [9]byte var connAttrsLEI []byte
connAttrsLen := len(mc.connector.encodedAttributes) if sendConnectAttrs {
connAttrsLEI := appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen)) var connAttrsLEIBuf [9]byte
pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes) connAttrsLen := len(mc.connector.encodedAttributes)
connAttrsLEI = appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen))
pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes)
}
// Calculate packet length and get buffer with that size // Calculate packet length and get buffer with that size
data, err := mc.buf.takeBuffer(pktLen + 4) data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection mc.cleanup()
mc.log(err) return err
return errBadConnNoWrite
} }
// ClientFlags [32 bit] // ClientFlags [32 bit]
data[4] = byte(clientFlags) binary.LittleEndian.PutUint32(data[4:], uint32(clientFlags))
data[5] = byte(clientFlags >> 8)
data[6] = byte(clientFlags >> 16)
data[7] = byte(clientFlags >> 24)
// MaxPacketSize [32 bit] (none) // MaxPacketSize [32 bit] (none)
data[8] = 0x00 binary.LittleEndian.PutUint32(data[8:], 0)
data[9] = 0x00
data[10] = 0x00
data[11] = 0x00
// Collation ID [1 byte] // Collation ID [1 byte]
cname := mc.cfg.Collation data[12] = defaultCollationID
if cname == "" { if cname := mc.cfg.Collation; cname != "" {
cname = defaultCollation colID, ok := collations[cname]
} if ok {
var found bool data[12] = colID
data[12], found = collations[cname] } else if len(mc.cfg.charsets) > 0 {
if !found { // When cfg.charset is set, the collation is set by `SET NAMES <charset> COLLATE <collation>`.
// Note possibility for false negatives: return fmt.Errorf("unknown collation: %q", cname)
// could be triggered although the collation is valid if the }
// collations map does not contain entries the server supports.
return fmt.Errorf("unknown collation: %q", cname)
} }
// Filler [23 bytes] (all 0x00) // Filler [23 bytes] (all 0x00)
@ -349,10 +363,12 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
// Switch to TLS // Switch to TLS
tlsConn := tls.Client(mc.netConn, mc.cfg.TLS) tlsConn := tls.Client(mc.netConn, mc.cfg.TLS)
if err := tlsConn.Handshake(); err != nil { if err := tlsConn.Handshake(); err != nil {
if cerr := mc.canceled.Value(); cerr != nil {
return cerr
}
return err return err
} }
mc.netConn = tlsConn mc.netConn = tlsConn
mc.buf.nc = tlsConn
} }
// User [null terminated string] // User [null terminated string]
@ -378,8 +394,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
pos++ pos++
// Connection Attributes // Connection Attributes
pos += copy(data[pos:], connAttrsLEI) if sendConnectAttrs {
pos += copy(data[pos:], []byte(mc.connector.encodedAttributes)) pos += copy(data[pos:], connAttrsLEI)
pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))
}
// Send Auth packet // Send Auth packet
return mc.writePacket(data[:pos]) return mc.writePacket(data[:pos])
@ -388,11 +406,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
pktLen := 4 + len(authData) pktLen := 4 + len(authData)
data, err := mc.buf.takeSmallBuffer(pktLen) data, err := mc.buf.takeBuffer(pktLen)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection mc.cleanup()
mc.log(err) return err
return errBadConnNoWrite
} }
// Add the auth data [EOF] // Add the auth data [EOF]
@ -406,13 +423,11 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
func (mc *mysqlConn) writeCommandPacket(command byte) error { func (mc *mysqlConn) writeCommandPacket(command byte) error {
// Reset Packet Sequence // Reset Packet Sequence
mc.sequence = 0 mc.resetSequence()
data, err := mc.buf.takeSmallBuffer(4 + 1) data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection return err
mc.log(err)
return errBadConnNoWrite
} }
// Add command byte // Add command byte
@ -424,14 +439,12 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
// Reset Packet Sequence // Reset Packet Sequence
mc.sequence = 0 mc.resetSequence()
pktLen := 1 + len(arg) pktLen := 1 + len(arg)
data, err := mc.buf.takeBuffer(pktLen + 4) data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection return err
mc.log(err)
return errBadConnNoWrite
} }
// Add command byte // Add command byte
@ -441,28 +454,25 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
copy(data[5:], arg) copy(data[5:], arg)
// Send CMD packet // Send CMD packet
return mc.writePacket(data) err = mc.writePacket(data)
mc.syncSequence()
return err
} }
func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
// Reset Packet Sequence // Reset Packet Sequence
mc.sequence = 0 mc.resetSequence()
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection return err
mc.log(err)
return errBadConnNoWrite
} }
// Add command byte // Add command byte
data[4] = command data[4] = command
// Add arg [32 bit] // Add arg [32 bit]
data[5] = byte(arg) binary.LittleEndian.PutUint32(data[5:], arg)
data[6] = byte(arg >> 8)
data[7] = byte(arg >> 16)
data[8] = byte(arg >> 24)
// Send CMD packet // Send CMD packet
return mc.writePacket(data) return mc.writePacket(data)
@ -500,6 +510,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
} }
plugin := string(data[1:pluginEndIndex]) plugin := string(data[1:pluginEndIndex])
authData := data[pluginEndIndex+1:] authData := data[pluginEndIndex+1:]
if len(authData) > 0 && authData[len(authData)-1] == 0 {
authData = authData[:len(authData)-1]
}
return authData, plugin, nil return authData, plugin, nil
default: // Error otherwise default: // Error otherwise
@ -521,32 +534,33 @@ func (mc *okHandler) readResultOK() error {
} }
// Result Set Header Packet // Result Set Header Packet
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response.html
func (mc *okHandler) readResultSetHeaderPacket() (int, error) { func (mc *okHandler) readResultSetHeaderPacket() (int, error) {
// handleOkPacket replaces both values; other cases leave the values unchanged. // handleOkPacket replaces both values; other cases leave the values unchanged.
mc.result.affectedRows = append(mc.result.affectedRows, 0) mc.result.affectedRows = append(mc.result.affectedRows, 0)
mc.result.insertIds = append(mc.result.insertIds, 0) mc.result.insertIds = append(mc.result.insertIds, 0)
data, err := mc.conn().readPacket() data, err := mc.conn().readPacket()
if err == nil { if err != nil {
switch data[0] { return 0, err
case iOK:
return 0, mc.handleOkPacket(data)
case iERR:
return 0, mc.conn().handleErrorPacket(data)
case iLocalInFile:
return 0, mc.handleInFileRequest(string(data[1:]))
}
// column count
num, _, _ := readLengthEncodedInteger(data)
// ignore remaining data in the packet. see #1478.
return int(num), nil
} }
return 0, err
switch data[0] {
case iOK:
return 0, mc.handleOkPacket(data)
case iERR:
return 0, mc.conn().handleErrorPacket(data)
case iLocalInFile:
return 0, mc.handleInFileRequest(string(data[1:]))
}
// column count
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset.html
num, _, _ := readLengthEncodedInteger(data)
// ignore remaining data in the packet. see #1478.
return int(num), nil
} }
// Error Packet // Error Packet
@ -563,7 +577,8 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
// 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION
// 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover) // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover)
if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly { // 1836: ER_READ_ONLY_MODE
if (errno == 1792 || errno == 1290 || errno == 1836) && mc.cfg.RejectReadOnly {
// Oops; we are connected to a read-only connection, and won't be able // Oops; we are connected to a read-only connection, and won't be able
// to issue any write statements. Since RejectReadOnly is configured, // to issue any write statements. Since RejectReadOnly is configured,
// we throw away this connection hoping this one would have write // we throw away this connection hoping this one would have write
@ -930,19 +945,15 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
pktLen = dataOffset + argLen pktLen = dataOffset + argLen
} }
stmt.mc.sequence = 0 stmt.mc.resetSequence()
// Add command byte [1 byte] // Add command byte [1 byte]
data[4] = comStmtSendLongData data[4] = comStmtSendLongData
// Add stmtID [32 bit] // Add stmtID [32 bit]
data[5] = byte(stmt.id) binary.LittleEndian.PutUint32(data[5:], stmt.id)
data[6] = byte(stmt.id >> 8)
data[7] = byte(stmt.id >> 16)
data[8] = byte(stmt.id >> 24)
// Add paramID [16 bit] // Add paramID [16 bit]
data[9] = byte(paramID) binary.LittleEndian.PutUint16(data[9:], uint16(paramID))
data[10] = byte(paramID >> 8)
// Send CMD packet // Send CMD packet
err := stmt.mc.writePacket(data[:4+pktLen]) err := stmt.mc.writePacket(data[:4+pktLen])
@ -951,11 +962,10 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
continue continue
} }
return err return err
} }
// Reset Packet Sequence // Reset Packet Sequence
stmt.mc.sequence = 0 stmt.mc.resetSequence()
return nil return nil
} }
@ -980,7 +990,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
} }
// Reset packet-sequence // Reset packet-sequence
mc.sequence = 0 mc.resetSequence()
var data []byte var data []byte
var err error var err error
@ -992,28 +1002,20 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
// In this case the len(data) == cap(data) which is used to optimise the flow below. // In this case the len(data) == cap(data) which is used to optimise the flow below.
} }
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection return err
mc.log(err)
return errBadConnNoWrite
} }
// command [1 byte] // command [1 byte]
data[4] = comStmtExecute data[4] = comStmtExecute
// statement_id [4 bytes] // statement_id [4 bytes]
data[5] = byte(stmt.id) binary.LittleEndian.PutUint32(data[5:], stmt.id)
data[6] = byte(stmt.id >> 8)
data[7] = byte(stmt.id >> 16)
data[8] = byte(stmt.id >> 24)
// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte] // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
data[9] = 0x00 data[9] = 0x00
// iteration_count (uint32(1)) [4 bytes] // iteration_count (uint32(1)) [4 bytes]
data[10] = 0x01 binary.LittleEndian.PutUint32(data[10:], 1)
data[11] = 0x00
data[12] = 0x00
data[13] = 0x00
if len(args) > 0 { if len(args) > 0 {
pos := minPktLen pos := minPktLen
@ -1067,50 +1069,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
case int64: case int64:
paramTypes[i+i] = byte(fieldTypeLongLong) paramTypes[i+i] = byte(fieldTypeLongLong)
paramTypes[i+i+1] = 0x00 paramTypes[i+i+1] = 0x00
paramValues = binary.LittleEndian.AppendUint64(paramValues, uint64(v))
if cap(paramValues)-len(paramValues)-8 >= 0 {
paramValues = paramValues[:len(paramValues)+8]
binary.LittleEndian.PutUint64(
paramValues[len(paramValues)-8:],
uint64(v),
)
} else {
paramValues = append(paramValues,
uint64ToBytes(uint64(v))...,
)
}
case uint64: case uint64:
paramTypes[i+i] = byte(fieldTypeLongLong) paramTypes[i+i] = byte(fieldTypeLongLong)
paramTypes[i+i+1] = 0x80 // type is unsigned paramTypes[i+i+1] = 0x80 // type is unsigned
paramValues = binary.LittleEndian.AppendUint64(paramValues, uint64(v))
if cap(paramValues)-len(paramValues)-8 >= 0 {
paramValues = paramValues[:len(paramValues)+8]
binary.LittleEndian.PutUint64(
paramValues[len(paramValues)-8:],
uint64(v),
)
} else {
paramValues = append(paramValues,
uint64ToBytes(uint64(v))...,
)
}
case float64: case float64:
paramTypes[i+i] = byte(fieldTypeDouble) paramTypes[i+i] = byte(fieldTypeDouble)
paramTypes[i+i+1] = 0x00 paramTypes[i+i+1] = 0x00
paramValues = binary.LittleEndian.AppendUint64(paramValues, math.Float64bits(v))
if cap(paramValues)-len(paramValues)-8 >= 0 {
paramValues = paramValues[:len(paramValues)+8]
binary.LittleEndian.PutUint64(
paramValues[len(paramValues)-8:],
math.Float64bits(v),
)
} else {
paramValues = append(paramValues,
uint64ToBytes(math.Float64bits(v))...,
)
}
case bool: case bool:
paramTypes[i+i] = byte(fieldTypeTiny) paramTypes[i+i] = byte(fieldTypeTiny)
@ -1191,17 +1160,16 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
// In that case we must build the data packet with the new values buffer // In that case we must build the data packet with the new values buffer
if valuesCap != cap(paramValues) { if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...) data = append(data[:pos], paramValues...)
if err = mc.buf.store(data); err != nil { mc.buf.store(data) // allow this buffer to be reused
mc.log(err)
return errBadConnNoWrite
}
} }
pos += len(paramValues) pos += len(paramValues)
data = data[:pos] data = data[:pos]
} }
return mc.writePacket(data) err = mc.writePacket(data)
mc.syncSequence()
return err
} }
// For each remaining resultset in the stream, discards its rows and updates // For each remaining resultset in the stream, discards its rows and updates
@ -1325,7 +1293,8 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON: fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON,
fieldTypeVector:
var isNull bool var isNull bool
var n int var n int
dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) dest[i], isNull, n, err = readLengthEncodedString(data[pos:])

View file

@ -111,13 +111,6 @@ func (rows *mysqlRows) Close() (err error) {
return err return err
} }
// flip the buffer for this connection if we need to drain it.
// note that for a successful query (i.e. one where rows.next()
// has been called until it returns false), `rows.mc` will be nil
// by the time the user calls `(*Rows).Close`, so we won't reach this
// see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47
mc.buf.flip()
// Remove unread packets from stream // Remove unread packets from stream
if !rows.rs.done { if !rows.rs.done {
err = mc.readUntilEOF() err = mc.readUntilEOF()

View file

@ -24,11 +24,12 @@ type mysqlStmt struct {
func (stmt *mysqlStmt) Close() error { func (stmt *mysqlStmt) Close() error {
if stmt.mc == nil || stmt.mc.closed.Load() { if stmt.mc == nil || stmt.mc.closed.Load() {
// driver.Stmt.Close can be called more than once, thus this function // driver.Stmt.Close could be called more than once, thus this function
// has to be idempotent. // had to be idempotent. See also Issue #450 and golang/go#16019.
// See also Issue #450 and golang/go#16019. // This bug has been fixed in Go 1.8.
//errLog.Print(ErrInvalidConn) // https://github.com/golang/go/commit/90b8a0ca2d0b565c7c7199ffcf77b15ea6b6db3a
return driver.ErrBadConn // But we keep this function idempotent because it is safer.
return nil
} }
err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
@ -51,7 +52,6 @@ func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) {
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
if stmt.mc.closed.Load() { if stmt.mc.closed.Load() {
stmt.mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
// Send command // Send command
@ -95,7 +95,6 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
if stmt.mc.closed.Load() { if stmt.mc.closed.Load() {
stmt.mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
// Send command // Send command

View file

@ -490,17 +490,16 @@ func formatBinaryTime(src []byte, length uint8) (driver.Value, error) {
* Convert from and to bytes * * Convert from and to bytes *
******************************************************************************/ ******************************************************************************/
func uint64ToBytes(n uint64) []byte { // 24bit integer: used for packet headers.
return []byte{
byte(n), func putUint24(data []byte, n int) {
byte(n >> 8), data[2] = byte(n >> 16)
byte(n >> 16), data[1] = byte(n >> 8)
byte(n >> 24), data[0] = byte(n)
byte(n >> 32), }
byte(n >> 40),
byte(n >> 48), func getUint24(data []byte) int {
byte(n >> 56), return int(data[2])<<16 | int(data[1])<<8 | int(data[0])
}
} }
func uint64ToString(n uint64) []byte { func uint64ToString(n uint64) []byte {
@ -525,16 +524,6 @@ func uint64ToString(n uint64) []byte {
return a[i:] return a[i:]
} }
// treats string value as unsigned integer representation
func stringToInt(b []byte) int {
val := 0
for i := range b {
val *= 10
val += int(b[i] - 0x30)
}
return val
}
// returns the string read as a bytes slice, whether the value is NULL, // returns the string read as a bytes slice, whether the value is NULL,
// the number of bytes read and an error, in case the string is longer than // the number of bytes read and an error, in case the string is longer than
// the input slice // the input slice
@ -586,18 +575,15 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
// 252: value of following 2 // 252: value of following 2
case 0xfc: case 0xfc:
return uint64(b[1]) | uint64(b[2])<<8, false, 3 return uint64(binary.LittleEndian.Uint16(b[1:])), false, 3
// 253: value of following 3 // 253: value of following 3
case 0xfd: case 0xfd:
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4 return uint64(getUint24(b[1:])), false, 4
// 254: value of following 8 // 254: value of following 8
case 0xfe: case 0xfe:
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | return uint64(binary.LittleEndian.Uint64(b[1:])), false, 9
uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
uint64(b[7])<<48 | uint64(b[8])<<56,
false, 9
} }
// 0-250: value of first byte // 0-250: value of first byte
@ -611,13 +597,14 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte {
return append(b, byte(n)) return append(b, byte(n))
case n <= 0xffff: case n <= 0xffff:
return append(b, 0xfc, byte(n), byte(n>>8)) b = append(b, 0xfc)
return binary.LittleEndian.AppendUint16(b, uint16(n))
case n <= 0xffffff: case n <= 0xffffff:
return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16)) return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16))
} }
return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24), b = append(b, 0xfe)
byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) return binary.LittleEndian.AppendUint64(b, n)
} }
func appendLengthEncodedString(b []byte, s string) []byte { func appendLengthEncodedString(b []byte, s string) []byte {

4
vendor/modules.txt vendored
View file

@ -150,8 +150,8 @@ github.com/cespare/xxhash/v2
# github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f # github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f
## explicit ## explicit
github.com/dgryski/go-rendezvous github.com/dgryski/go-rendezvous
# github.com/go-sql-driver/mysql v1.8.1 # github.com/go-sql-driver/mysql v1.9.0
## explicit; go 1.18 ## explicit; go 1.21
github.com/go-sql-driver/mysql github.com/go-sql-driver/mysql
# github.com/golang/protobuf v1.5.4 # github.com/golang/protobuf v1.5.4
## explicit; go 1.17 ## explicit; go 1.17