diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b1c1f2b3..2e07fea9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -83,7 +83,7 @@ jobs: my-cnf: | innodb_log_file_size=256MB innodb_buffer_pool_size=512MB - max_allowed_packet=16MB + max_allowed_packet=48MB ; TestConcurrent fails if max_connections is too large max_connections=50 local_infile=1 diff --git a/AUTHORS b/AUTHORS index a9860850..e117441a 100644 --- a/AUTHORS +++ b/AUTHORS @@ -21,6 +21,7 @@ Animesh Ray Arne Hormann Ariel Mashraki Asta Xie +B Lamarche Brian Hendriks Bulat Gaifullin Caine Jette @@ -62,6 +63,7 @@ Jennifer Purevsuren Jerome Meyer Jiajia Zhong Jian Zhen +Joe Mann Joshua Prunier Julien Lefevre Julien Schmidt diff --git a/README.md b/README.md index e9d9222b..da4593cc 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac * Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support * Optional `time.Time` parsing * Optional placeholder interpolation + * Supports zlib compression. ## Requirements @@ -267,6 +268,16 @@ SELECT u.id FROM users as u will return `u.id` instead of just `id` if `columnsWithAlias=true`. +##### `compress` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +Toggles zlib compression. false by default. + ##### `interpolateParams` ``` diff --git a/benchmark_test.go b/benchmark_test.go index a4ecc0a6..5c9a046b 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -46,9 +46,13 @@ func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt { return stmt } -func initDB(b *testing.B, queries ...string) *sql.DB { +func initDB(b *testing.B, useCompression bool, queries ...string) *sql.DB { tb := (*TB)(b) - db := tb.checkDB(sql.Open(driverNameTest, dsn)) + comprStr := "" + if useCompression { + comprStr = "&compress=1" + } + db := tb.checkDB(sql.Open(driverNameTest, dsn+comprStr)) for _, query := range queries { if _, err := db.Exec(query); err != nil { b.Fatalf("error on %q: %v", query, err) @@ -60,10 +64,18 @@ func initDB(b *testing.B, queries ...string) *sql.DB { const concurrencyLevel = 10 func BenchmarkQuery(b *testing.B) { + benchmarkQueryHelper(b, false) +} + +func BenchmarkQueryCompression(b *testing.B) { + benchmarkQueryHelper(b, true) +} + +func benchmarkQueryHelper(b *testing.B, compr bool) { tb := (*TB)(b) b.StopTimer() b.ReportAllocs() - db := initDB(b, + db := initDB(b, compr, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, @@ -222,7 +234,7 @@ func BenchmarkInterpolation(b *testing.B) { }, maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, - buf: newBuffer(nil), + buf: newBuffer(), } args := []driver.Value{ @@ -269,7 +281,7 @@ func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) { } func BenchmarkQueryContext(b *testing.B) { - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, @@ -305,7 +317,7 @@ func benchmarkExecContext(b *testing.B, db *sql.DB, p int) { } func BenchmarkExecContext(b *testing.B) { - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, @@ -323,7 +335,7 @@ func BenchmarkExecContext(b *testing.B) { // "size=" means size of each blobs. func BenchmarkQueryRawBytes(b *testing.B) { var sizes []int = []int{100, 1000, 2000, 4000, 8000, 12000, 16000, 32000, 64000, 256000} - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS bench_rawbytes", "CREATE TABLE bench_rawbytes (id INT PRIMARY KEY, val LONGBLOB)", ) @@ -376,7 +388,7 @@ func BenchmarkQueryRawBytes(b *testing.B) { // BenchmarkReceiveMassiveRows measures performance of receiving large number of rows. func BenchmarkReceiveMassiveRows(b *testing.B) { // Setup -- prepare 10000 rows. - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val TEXT)") defer db.Close() diff --git a/buffer.go b/buffer.go index dd82c931..a6532431 100644 --- a/buffer.go +++ b/buffer.go @@ -10,13 +10,16 @@ package mysql import ( "io" - "net" - "time" ) const defaultBufSize = 4096 const maxCachedBufSize = 256 * 1024 +// readerFunc is a function that compatible with io.Reader. +// We use this function type instead of io.Reader because we want to +// just pass mc.readWithTimeout. +type readerFunc func([]byte) (int, error) + // A buffer which is used for both reading and writing. // This is possible since communication on each connection is synchronous. // In other words, we can't write and read simultaneously on the same connection. @@ -25,15 +28,12 @@ const maxCachedBufSize = 256 * 1024 type buffer struct { buf []byte // read buffer. cachedBuf []byte // buffer that will be reused. len(cachedBuf) <= maxCachedBufSize. - nc net.Conn - timeout time.Duration } // newBuffer allocates and returns a new buffer. -func newBuffer(nc net.Conn) buffer { +func newBuffer() buffer { return buffer{ cachedBuf: make([]byte, defaultBufSize), - nc: nc, } } @@ -43,7 +43,7 @@ func (b *buffer) busy() bool { } // fill reads into the read buffer until at least _need_ bytes are in it. -func (b *buffer) fill(need int) error { +func (b *buffer) fill(need int, r readerFunc) error { // we'll move the contents of the current buffer to dest before filling it. dest := b.cachedBuf @@ -64,13 +64,7 @@ func (b *buffer) fill(need int) error { copy(dest[:n], b.buf) for { - if b.timeout > 0 { - if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil { - return err - } - } - - nn, err := b.nc.Read(dest[n:]) + nn, err := r(dest[n:]) n += nn if err == nil && n < need { @@ -92,10 +86,10 @@ func (b *buffer) fill(need int) error { // returns next N bytes from buffer. // The returned slice is only guaranteed to be valid until the next read -func (b *buffer) readNext(need int) ([]byte, error) { +func (b *buffer) readNext(need int, r readerFunc) ([]byte, error) { if len(b.buf) < need { // refill - if err := b.fill(need); err != nil { + if err := b.fill(need, r); err != nil { return nil, err } } diff --git a/compress.go b/compress.go new file mode 100644 index 00000000..fa42772a --- /dev/null +++ b/compress.go @@ -0,0 +1,214 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2024 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "compress/zlib" + "fmt" + "io" + "sync" +) + +var ( + zrPool *sync.Pool // Do not use directly. Use zDecompress() instead. + zwPool *sync.Pool // Do not use directly. Use zCompress() instead. +) + +func init() { + zrPool = &sync.Pool{ + New: func() any { return nil }, + } + zwPool = &sync.Pool{ + New: func() any { + zw, err := zlib.NewWriterLevel(new(bytes.Buffer), 2) + if err != nil { + panic(err) // compress/zlib return non-nil error only if level is invalid + } + return zw + }, + } +} + +func zDecompress(src []byte, dst *bytes.Buffer) (int, error) { + br := bytes.NewReader(src) + var zr io.ReadCloser + var err error + + if a := zrPool.Get(); a == nil { + if zr, err = zlib.NewReader(br); err != nil { + return 0, err + } + } else { + zr = a.(io.ReadCloser) + if err := zr.(zlib.Resetter).Reset(br, nil); err != nil { + return 0, err + } + } + + n, _ := dst.ReadFrom(zr) // ignore err because zr.Close() will return it again. + err = zr.Close() // zr.Close() may return chuecksum error. + zrPool.Put(zr) + return int(n), err +} + +func zCompress(src []byte, dst io.Writer) error { + zw := zwPool.Get().(*zlib.Writer) + zw.Reset(dst) + if _, err := zw.Write(src); err != nil { + return err + } + err := zw.Close() + zwPool.Put(zw) + return err +} + +type compIO struct { + mc *mysqlConn + buff bytes.Buffer +} + +func newCompIO(mc *mysqlConn) *compIO { + return &compIO{ + mc: mc, + } +} + +func (c *compIO) reset() { + c.buff.Reset() +} + +func (c *compIO) readNext(need int, r readerFunc) ([]byte, error) { + for c.buff.Len() < need { + if err := c.readCompressedPacket(r); err != nil { + return nil, err + } + } + data := c.buff.Next(need) + return data[:need:need], nil // prevent caller writes into c.buff +} + +func (c *compIO) readCompressedPacket(r readerFunc) error { + header, err := c.mc.buf.readNext(7, r) // size of compressed header + if err != nil { + return err + } + _ = header[6] // bounds check hint to compiler; guaranteed by readNext + + // compressed header structure + comprLength := getUint24(header[0:3]) + compressionSequence := uint8(header[3]) + uncompressedLength := getUint24(header[4:7]) + if debug { + fmt.Printf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n", + comprLength, uncompressedLength, compressionSequence, c.mc.sequence) + } + // Do not return ErrPktSync here. + // Server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes) + // before receiving all packets from client. In this case, seqnr is younger than expected. + // NOTE: Both of mariadbclient and mysqlclient do not check seqnr. Only server checks it. + if debug && compressionSequence != c.mc.sequence { + fmt.Printf("WARN: unexpected cmpress seq nr: expected %v, got %v", + c.mc.sequence, compressionSequence) + } + c.mc.sequence = compressionSequence + 1 + c.mc.compressSequence = c.mc.sequence + + comprData, err := c.mc.buf.readNext(comprLength, r) + if err != nil { + return err + } + + // if payload is uncompressed, its length will be specified as zero, and its + // true length is contained in comprLength + if uncompressedLength == 0 { + c.buff.Write(comprData) + return nil + } + + // use existing capacity in bytesBuf if possible + c.buff.Grow(uncompressedLength) + nread, err := zDecompress(comprData, &c.buff) + if err != nil { + return err + } + if nread != uncompressedLength { + return fmt.Errorf("invalid compressed packet: uncompressed length in header is %d, actual %d", + uncompressedLength, nread) + } + return nil +} + +const minCompressLength = 150 +const maxPayloadLen = maxPacketSize - 4 + +// writePackets sends one or some packets with compression. +// Use this instead of mc.netConn.Write() when mc.compress is true. +func (c *compIO) writePackets(packets []byte) (int, error) { + totalBytes := len(packets) + blankHeader := make([]byte, 7) + buf := &c.buff + + for len(packets) > 0 { + payloadLen := min(maxPayloadLen, len(packets)) + payload := packets[:payloadLen] + uncompressedLen := payloadLen + + buf.Reset() + buf.Write(blankHeader) // Buffer.Write() never returns error + + // If payload is less than minCompressLength, don't compress. + if uncompressedLen < minCompressLength { + buf.Write(payload) + uncompressedLen = 0 + } else { + err := zCompress(payload, buf) + if debug && err != nil { + fmt.Printf("zCompress error: %v", err) + } + // do not compress if compressed data is larger than uncompressed data + // I intentionally miss 7 byte header in the buf; zCompress must compress more than 7 bytes. + if err != nil || buf.Len() >= uncompressedLen { + buf.Reset() + buf.Write(blankHeader) + buf.Write(payload) + uncompressedLen = 0 + } + } + + if n, err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil { + // To allow returning ErrBadConn when sending really 0 bytes, we sum + // up compressed bytes that is returned by underlying Write(). + return totalBytes - len(packets) + n, err + } + packets = packets[payloadLen:] + } + + return totalBytes, nil +} + +// writeCompressedPacket writes a compressed packet with header. +// data should start with 7 size space for header followed by payload. +func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) (int, error) { + mc := c.mc + comprLength := len(data) - 7 + if debug { + fmt.Printf( + "writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v", + comprLength, uncompressedLen, mc.compressSequence) + } + + // compression header + putUint24(data[0:3], comprLength) + data[3] = mc.compressSequence + putUint24(data[4:7], uncompressedLen) + + mc.compressSequence++ + return mc.writeWithTimeout(data) +} diff --git a/compress_test.go b/compress_test.go new file mode 100644 index 00000000..030deaef --- /dev/null +++ b/compress_test.go @@ -0,0 +1,119 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2024 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "crypto/rand" + "io" + "testing" +) + +func makeRandByteSlice(size int) []byte { + randBytes := make([]byte, size) + rand.Read(randBytes) + return randBytes +} + +// compressHelper compresses uncompressedPacket and checks state variables +func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte { + conn := new(mockConn) + mc.netConn = conn + + err := mc.writePacket(append(make([]byte, 4), uncompressedPacket...)) + if err != nil { + t.Fatal(err) + } + + return conn.written +} + +// uncompressHelper uncompresses compressedPacket and checks state variables +func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte) []byte { + // mocking out buf variable + conn := new(mockConn) + conn.data = compressedPacket + mc.netConn = conn + + uncompressedPacket, err := mc.readPacket() + if err != nil { + if err != io.EOF { + t.Fatalf("non-nil/non-EOF error when reading contents: %s", err.Error()) + } + } + return uncompressedPacket +} + +// roundtripHelper compresses then uncompresses uncompressedPacket and checks state variables +func roundtripHelper(t *testing.T, cSend *mysqlConn, cReceive *mysqlConn, uncompressedPacket []byte) []byte { + compressed := compressHelper(t, cSend, uncompressedPacket) + return uncompressHelper(t, cReceive, compressed) +} + +// TestRoundtrip tests two connections, where one is reading and the other is writing +func TestRoundtrip(t *testing.T) { + tests := []struct { + uncompressed []byte + desc string + }{ + {uncompressed: []byte("a"), + desc: "a"}, + {uncompressed: []byte("hello world"), + desc: "hello world"}, + {uncompressed: make([]byte, 100), + desc: "100 bytes"}, + {uncompressed: make([]byte, 32768), + desc: "32768 bytes"}, + {uncompressed: make([]byte, 330000), + desc: "33000 bytes"}, + {uncompressed: makeRandByteSlice(10), + desc: "10 rand bytes", + }, + {uncompressed: makeRandByteSlice(100), + desc: "100 rand bytes", + }, + {uncompressed: makeRandByteSlice(32768), + desc: "32768 rand bytes", + }, + {uncompressed: bytes.Repeat(makeRandByteSlice(100), 10000), + desc: "100 rand * 10000 repeat bytes", + }, + } + + _, cSend := newRWMockConn(0) + cSend.compress = true + cSend.compIO = newCompIO(cSend) + _, cReceive := newRWMockConn(0) + cReceive.compress = true + cReceive.compIO = newCompIO(cReceive) + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + cSend.resetSequence() + cReceive.resetSequence() + + uncompressed := roundtripHelper(t, cSend, cReceive, test.uncompressed) + if len(uncompressed) != len(test.uncompressed) { + t.Errorf("uncompressed size is unexpected. expected %d but got %d", + len(test.uncompressed), len(uncompressed)) + } + if !bytes.Equal(uncompressed, test.uncompressed) { + t.Errorf("roundtrip failed") + } + if cSend.sequence != cReceive.sequence { + t.Errorf("inconsistent sequence number: send=%v recv=%v", + cSend.sequence, cReceive.sequence) + } + if cSend.compressSequence != cReceive.compressSequence { + t.Errorf("inconsistent compress sequence number: send=%v recv=%v", + cSend.compressSequence, cReceive.compressSequence) + } + }) + } +} diff --git a/connection.go b/connection.go index c220a836..db5c981b 100644 --- a/connection.go +++ b/connection.go @@ -28,15 +28,17 @@ type mysqlConn struct { netConn net.Conn rawConn net.Conn // underlying connection when netConn is TLS connection. result mysqlResult // managed by clearResult() and handleOkPacket(). + compIO *compIO cfg *Config connector *connector maxAllowedPacket int maxWriteSize int - writeTimeout time.Duration flags clientFlag status statusFlag sequence uint8 + compressSequence uint8 parseTime bool + compress bool // for context support (Go 1.8+) watching bool @@ -62,6 +64,43 @@ func (mc *mysqlConn) log(v ...any) { mc.cfg.Logger.Print(v...) } +func (mc *mysqlConn) readWithTimeout(b []byte) (int, error) { + to := mc.cfg.ReadTimeout + if to > 0 { + if err := mc.netConn.SetReadDeadline(time.Now().Add(to)); err != nil { + return 0, err + } + } + return mc.netConn.Read(b) +} + +func (mc *mysqlConn) writeWithTimeout(b []byte) (int, error) { + to := mc.cfg.WriteTimeout + if to > 0 { + if err := mc.netConn.SetWriteDeadline(time.Now().Add(to)); err != nil { + return 0, err + } + } + return mc.netConn.Write(b) +} + +func (mc *mysqlConn) resetSequence() { + mc.sequence = 0 + mc.compressSequence = 0 +} + +// syncSequence must be called when finished writing some packet and before start reading. +func (mc *mysqlConn) syncSequence() { + // Syncs compressionSequence to sequence. + // This is not documented but done in `net_flush()` in MySQL and MariaDB. + // https://github.com/mariadb-corporation/mariadb-connector-c/blob/8228164f850b12353da24df1b93a1e53cc5e85e9/libmariadb/ma_net.c#L170-L171 + // https://github.com/mysql/mysql-server/blob/824e2b4064053f7daf17d7f3f84b7a3ed92e5fb4/sql-common/net_serv.cc#L293 + if mc.compress { + mc.sequence = mc.compressSequence + mc.compIO.reset() + } +} + // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { var cmdSet strings.Builder @@ -147,7 +186,7 @@ func (mc *mysqlConn) cleanup() { return } if err := conn.Close(); err != nil { - mc.log(err) + mc.log("closing connection:", err) } // This function can be called from multiple goroutines. // So we can not mc.clearResult() here. diff --git a/connection_test.go b/connection_test.go index 6f8d2a6d..f7740898 100644 --- a/connection_test.go +++ b/connection_test.go @@ -19,7 +19,7 @@ import ( func TestInterpolateParams(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -39,7 +39,7 @@ func TestInterpolateParams(t *testing.T) { func TestInterpolateParamsJSONRawMessage(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -66,7 +66,7 @@ func TestInterpolateParamsJSONRawMessage(t *testing.T) { func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -83,7 +83,7 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { // https://github.com/go-sql-driver/mysql/pull/490 func TestInterpolateParamsPlaceholderInString(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -99,7 +99,7 @@ func TestInterpolateParamsPlaceholderInString(t *testing.T) { func TestInterpolateParamsUint64(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -161,7 +161,7 @@ func TestPingMarkBadConnection(t *testing.T) { nc := badConnection{err: errors.New("boom")} mc := &mysqlConn{ netConn: nc, - buf: newBuffer(nc), + buf: newBuffer(), maxAllowedPacket: defaultMaxAllowedPacket, closech: make(chan struct{}), cfg: NewConfig(), @@ -178,7 +178,7 @@ func TestPingErrInvalidConn(t *testing.T) { nc := badConnection{err: errors.New("failed to write"), n: 10} mc := &mysqlConn{ netConn: nc, - buf: newBuffer(nc), + buf: newBuffer(), maxAllowedPacket: defaultMaxAllowedPacket, closech: make(chan struct{}), cfg: NewConfig(), diff --git a/connector.go b/connector.go index 769b3adc..a4f3655e 100644 --- a/connector.go +++ b/connector.go @@ -127,11 +127,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } defer mc.finish() - mc.buf = newBuffer(mc.netConn) - - // Set I/O timeouts - mc.buf.timeout = mc.cfg.ReadTimeout - mc.writeTimeout = mc.cfg.WriteTimeout + mc.buf = newBuffer() // Reading Handshake Initialization Packet authData, plugin, err := mc.readHandshakePacket() @@ -170,6 +166,10 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } + if mc.cfg.compress && mc.flags&clientCompress == clientCompress { + mc.compress = true + mc.compIO = newCompIO(mc) + } if mc.cfg.MaxAllowedPacket > 0 { mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket } else { diff --git a/const.go b/const.go index 0cee9b2e..4aadcd64 100644 --- a/const.go +++ b/const.go @@ -11,6 +11,8 @@ package mysql import "runtime" const ( + debug = false // for debugging. Set true only in development. + defaultAuthPlugin = "mysql_native_password" defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355 minProtocolVersion = 10 diff --git a/driver_test.go b/driver_test.go index 24d73c34..58b3cb38 100644 --- a/driver_test.go +++ b/driver_test.go @@ -147,12 +147,11 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { db, err := sql.Open(driverNameTest, dsn) if err != nil { - t.Fatalf("error connecting: %s", err.Error()) + t.Fatalf("connecting %q: %s", dsn, err) } defer db.Close() - - cleanup := func() { - db.Exec("DROP TABLE IF EXISTS test") + if err = db.Ping(); err != nil { + t.Fatalf("connecting %q: %s", dsn, err) } dsn2 := dsn + "&interpolateParams=true" @@ -160,25 +159,46 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation { db2, err = sql.Open(driverNameTest, dsn2) if err != nil { - t.Fatalf("error connecting: %s", err.Error()) + t.Fatalf("connecting %q: %s", dsn2, err) } defer db2.Close() } + dsn3 := dsn + "&compress=true" + var db3 *sql.DB + db3, err = sql.Open(driverNameTest, dsn3) + if err != nil { + t.Fatalf("connecting %q: %s", dsn3, err) + } + defer db3.Close() + + cleanupSql := "DROP TABLE IF EXISTS test" + for _, test := range tests { test := test t.Run("default", func(t *testing.T) { dbt := &DBTest{t, db} - t.Cleanup(cleanup) + t.Cleanup(func() { + db.Exec(cleanupSql) + }) test(dbt) }) if db2 != nil { t.Run("interpolateParams", func(t *testing.T) { dbt2 := &DBTest{t, db2} - t.Cleanup(cleanup) + t.Cleanup(func() { + db2.Exec(cleanupSql) + }) test(dbt2) }) } + t.Run("compress", func(t *testing.T) { + dbt3 := &DBTest{t, db3} + t.Cleanup(func() { + db3.Exec(cleanupSql) + }) + test(dbt3) + }) } } @@ -958,12 +978,16 @@ func TestDateTime(t *testing.T) { var err error rows, err = dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`) if err == nil { - rows.Scan(µsecsSupported) + if rows.Next() { + rows.Scan(µsecsSupported) + } rows.Close() } rows, err = dbt.db.Query(`SELECT cast("0000-00-00" as DATE) = "0000-00-00"`) if err == nil { - rows.Scan(&zeroDateSupported) + if rows.Next() { + rows.Scan(&zeroDateSupported) + } rows.Close() } for _, setups := range testcases { @@ -1265,8 +1289,7 @@ func TestLongData(t *testing.T) { var rows *sql.Rows // Long text data - const nonDataQueryLen = 28 // length query w/o value - inS := in[:maxAllowedPacketSize-nonDataQueryLen] + inS := in[:maxAllowedPacketSize-100] dbt.mustExec("INSERT INTO test VALUES('" + inS + "')") rows = dbt.mustQuery("SELECT value FROM test") defer rows.Close() diff --git a/dsn.go b/dsn.go index f391a8fc..9b560b73 100644 --- a/dsn.go +++ b/dsn.go @@ -73,7 +73,10 @@ type Config struct { ParseTime bool // Parse time values to time.Time RejectReadOnly bool // Reject read-only connections - // unexported fields. new options should be come here + // unexported fields. new options should be come here. + // boolean first. alphabetical order. + + compress bool // Enable zlib compression beforeConnect func(context.Context, *Config) error // Invoked before a connection is established pubKey *rsa.PublicKey // Server public key @@ -93,7 +96,6 @@ func NewConfig() *Config { AllowNativePasswords: true, CheckConnLiveness: true, } - return cfg } @@ -125,6 +127,14 @@ func BeforeConnect(fn func(context.Context, *Config) error) Option { } } +// EnableCompress sets the compression mode. +func EnableCompression(yes bool) Option { + return func(cfg *Config) error { + cfg.compress = yes + return nil + } +} + func (cfg *Config) Clone() *Config { cp := *cfg if cp.TLS != nil { @@ -297,6 +307,10 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true") } + if cfg.compress { + writeDSNParam(&buf, &hasParam, "compress", "true") + } + if cfg.InterpolateParams { writeDSNParam(&buf, &hasParam, "interpolateParams", "true") } @@ -525,7 +539,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { // Compression case "compress": - return errors.New("compression not implemented yet") + var isBool bool + cfg.compress, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } // Enable client side placeholder substitution case "interpolateParams": diff --git a/infile.go b/infile.go index cf892bea..555ef71a 100644 --- a/infile.go +++ b/infile.go @@ -172,6 +172,7 @@ func (mc *okHandler) handleInFileRequest(name string) (err error) { if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil { return ioErr } + mc.conn().syncSequence() // read OK packet if err == nil { diff --git a/packets.go b/packets.go index 736e4418..e4d2820e 100644 --- a/packets.go +++ b/packets.go @@ -28,9 +28,16 @@ import ( // Read packet to buffer 'data' func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte + invalidSequence := false + + readNext := mc.buf.readNext + if mc.compress { + readNext = mc.compIO.readNext + } + for { // read packet header - data, err := mc.buf.readNext(4) + data, err := readNext(4, mc.readWithTimeout) if err != nil { mc.close() if cerr := mc.canceled.Value(); cerr != nil { @@ -41,17 +48,29 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // packet length [24 bit] - pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) - - // check packet sync [8 bit] - if data[3] != mc.sequence { - mc.close() - if data[3] > mc.sequence { - return nil, ErrPktSyncMul + pktLen := getUint24(data[:3]) + seq := data[3] + + if mc.compress { + // MySQL and MariaDB doesn't check packet nr in compressed packet. + if debug && seq != mc.compressSequence { + fmt.Printf("[debug] mismatched compression sequence nr: expected: %v, got %v", + mc.compressSequence, seq) + } + mc.compressSequence = seq + 1 + } else { + // check packet sync [8 bit] + if seq != mc.sequence { + mc.log(fmt.Sprintf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, seq)) + // For large packets, we stop reading as soon as sync error. + if len(prevData) > 0 { + mc.close() + return nil, ErrPktSyncMul + } + invalidSequence = true } - return nil, ErrPktSync + mc.sequence++ } - mc.sequence++ // packets with length 0 terminate a previous packet which is a // multiple of (2^24)-1 bytes long @@ -62,12 +81,11 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { mc.close() return nil, ErrInvalidConn } - return prevData, nil } // read packet body [pktLen bytes] - data, err = mc.buf.readNext(pktLen) + data, err = readNext(pktLen, mc.readWithTimeout) if err != nil { mc.close() if cerr := mc.canceled.Value(); cerr != nil { @@ -80,11 +98,18 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // return data if this was the last packet if pktLen < maxPacketSize { // zero allocations for non-split packets - if prevData == nil { - return data, nil + if prevData != nil { + data = append(prevData, data...) } - - return append(prevData, data...), nil + if invalidSequence { + mc.close() + // return sync error only for regular packet. + // error packets may have wrong sequence number. + if data[0] != iERR { + return nil, ErrPktSync + } + } + return data, nil } prevData = append(prevData, data...) @@ -94,36 +119,26 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // Write packet buffer 'data' func (mc *mysqlConn) writePacket(data []byte) error { pktLen := len(data) - 4 - if pktLen > mc.maxAllowedPacket { return ErrPktTooLarge } + writeFunc := mc.writeWithTimeout + if mc.compress { + writeFunc = mc.compIO.writePackets + } + for { - var size int - if pktLen >= maxPacketSize { - data[0] = 0xff - data[1] = 0xff - data[2] = 0xff - size = maxPacketSize - } else { - data[0] = byte(pktLen) - data[1] = byte(pktLen >> 8) - data[2] = byte(pktLen >> 16) - size = pktLen - } + size := min(maxPacketSize, pktLen) + putUint24(data[:3], size) data[3] = mc.sequence // Write packet - if mc.writeTimeout > 0 { - if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { - mc.cleanup() - mc.log(err) - return err - } + if debug { + fmt.Printf("writePacket: size=%v seq=%v", size, mc.sequence) } - n, err := mc.netConn.Write(data[:4+size]) + n, err := writeFunc(data[:4+size]) if err != nil { mc.cleanup() if cerr := mc.canceled.Value(); cerr != nil { @@ -267,7 +282,9 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string if mc.cfg.ClientFoundRows { clientFlags |= clientFoundRows } - + if mc.cfg.compress && mc.flags&clientCompress == clientCompress { + clientFlags |= clientCompress + } // To enable TLS / SSL if mc.cfg.TLS != nil { clientFlags |= clientSSL @@ -358,7 +375,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string return err } mc.netConn = tlsConn - mc.buf.nc = tlsConn } // User [null terminated string] @@ -413,7 +429,7 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence - mc.sequence = 0 + mc.resetSequence() data, err := mc.buf.takeSmallBuffer(4 + 1) if err != nil { @@ -429,7 +445,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { // Reset Packet Sequence - mc.sequence = 0 + mc.resetSequence() pktLen := 1 + len(arg) data, err := mc.buf.takeBuffer(pktLen + 4) @@ -444,12 +460,14 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { copy(data[5:], arg) // Send CMD packet - return mc.writePacket(data) + err = mc.writePacket(data) + mc.syncSequence() + return err } func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence - mc.sequence = 0 + mc.resetSequence() data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) if err != nil { @@ -932,7 +950,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { pktLen = dataOffset + argLen } - stmt.mc.sequence = 0 + stmt.mc.resetSequence() // Add command byte [1 byte] data[4] = comStmtSendLongData @@ -953,11 +971,10 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { continue } return err - } // Reset Packet Sequence - stmt.mc.sequence = 0 + stmt.mc.resetSequence() return nil } @@ -982,7 +999,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } // Reset packet-sequence - mc.sequence = 0 + mc.resetSequence() var data []byte var err error @@ -1198,7 +1215,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { data = data[:pos] } - return mc.writePacket(data) + err = mc.writePacket(data) + mc.syncSequence() + return err } // For each remaining resultset in the stream, discards its rows and updates diff --git a/packets_test.go b/packets_test.go index fa4683ea..694b0564 100644 --- a/packets_test.go +++ b/packets_test.go @@ -98,7 +98,7 @@ func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { conn := new(mockConn) connector := newConnector(NewConfig()) mc := &mysqlConn{ - buf: newBuffer(conn), + buf: newBuffer(), cfg: connector.cfg, connector: connector, netConn: conn, @@ -112,7 +112,9 @@ func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { func TestReadPacketSingleByte(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{ - buf: newBuffer(conn), + netConn: conn, + buf: newBuffer(), + cfg: NewConfig(), } conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} @@ -143,12 +145,12 @@ func TestReadPacketWrongSequenceID(t *testing.T) { { ClientSequenceID: 0, ServerSequenceID: 0x42, - ExpectedErr: ErrPktSyncMul, + ExpectedErr: ErrPktSync, }, } { conn, mc := newRWMockConn(testCase.ClientSequenceID) - conn.data = []byte{0x01, 0x00, 0x00, testCase.ServerSequenceID, 0xff} + conn.data = []byte{0x01, 0x00, 0x00, testCase.ServerSequenceID, 0x22} _, err := mc.readPacket() if err != testCase.ExpectedErr { t.Errorf("expected %v, got %v", testCase.ExpectedErr, err) @@ -164,7 +166,9 @@ func TestReadPacketWrongSequenceID(t *testing.T) { func TestReadPacketSplit(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{ - buf: newBuffer(conn), + netConn: conn, + buf: newBuffer(), + cfg: NewConfig(), } data := make([]byte, maxPacketSize*2+4*3) @@ -269,7 +273,8 @@ func TestReadPacketSplit(t *testing.T) { func TestReadPacketFail(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{ - buf: newBuffer(conn), + netConn: conn, + buf: newBuffer(), closech: make(chan struct{}), cfg: NewConfig(), } @@ -285,7 +290,7 @@ func TestReadPacketFail(t *testing.T) { // reset conn.reads = 0 mc.sequence = 0 - mc.buf = newBuffer(conn) + mc.buf = newBuffer() // fail to read header conn.closed = true @@ -298,7 +303,7 @@ func TestReadPacketFail(t *testing.T) { conn.closed = false conn.reads = 0 mc.sequence = 0 - mc.buf = newBuffer(conn) + mc.buf = newBuffer() // fail to read body conn.maxReads = 1 @@ -313,7 +318,8 @@ func TestReadPacketFail(t *testing.T) { func TestRegression801(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{ - buf: newBuffer(conn), + netConn: conn, + buf: newBuffer(), cfg: new(Config), sequence: 42, closech: make(chan struct{}), diff --git a/utils.go b/utils.go index cda24fe7..d902f3b6 100644 --- a/utils.go +++ b/utils.go @@ -490,6 +490,18 @@ func formatBinaryTime(src []byte, length uint8) (driver.Value, error) { * Convert from and to bytes * ******************************************************************************/ +// 24bit integer: used for packet headers. + +func putUint24(data []byte, n int) { + data[2] = byte(n >> 16) + data[1] = byte(n >> 8) + data[0] = byte(n) +} + +func getUint24(data []byte) int { + return int(data[2])<<16 | int(data[1])<<8 | int(data[0]) +} + func uint64ToBytes(n uint64) []byte { return []byte{ byte(n),