From e6c682cec452c7b20319f048fe38f1245015845f Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Mon, 31 Jul 2017 16:43:52 -0400 Subject: [PATCH 01/88] packets: implemented compression protocol --- AUTHORS | 1 + benchmark_test.go | 1 + compress.go | 217 ++++++++++++++++++++++++++++++++++++++++++++ compress_test.go | 220 +++++++++++++++++++++++++++++++++++++++++++++ connection.go | 29 +++--- connection_test.go | 4 + driver.go | 9 ++ dsn.go | 7 +- packets.go | 15 +++- packets_test.go | 6 ++ 10 files changed, 492 insertions(+), 17 deletions(-) create mode 100644 compress.go create mode 100644 compress_test.go diff --git a/AUTHORS b/AUTHORS index 5526e3e90..f137bfcc7 100644 --- a/AUTHORS +++ b/AUTHORS @@ -15,6 +15,7 @@ Aaron Hopkins Achille Roussel Arne Hormann Asta Xie +B Lamarche Bulat Gaifullin Carlos Nieto Chris Moos diff --git a/benchmark_test.go b/benchmark_test.go index 7da833a2a..460553e03 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -224,6 +224,7 @@ func BenchmarkInterpolation(b *testing.B) { maxWriteSize: maxPacketSize - 1, buf: newBuffer(nil), } + mc.reader = &mc.buf args := []driver.Value{ int64(42424242), diff --git a/compress.go b/compress.go new file mode 100644 index 000000000..2349aa13b --- /dev/null +++ b/compress.go @@ -0,0 +1,217 @@ +package mysql + +import ( + "bytes" + "compress/zlib" + "io" +) + +const ( + minCompressLength = 50 +) + +type packetReader interface { + readNext(need int) ([]byte, error) +} + +type compressedReader struct { + buf packetReader + bytesBuf []byte + mc *mysqlConn +} + +type compressedWriter struct { + connWriter io.Writer + mc *mysqlConn +} + +func NewCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader { + return &compressedReader{ + buf: buf, + bytesBuf: make([]byte, 0), + mc: mc, + } +} + +func NewCompressedWriter(connWriter io.Writer, mc *mysqlConn) *compressedWriter { + return &compressedWriter{ + connWriter: connWriter, + mc: mc, + } +} + +func (cr *compressedReader) readNext(need int) ([]byte, error) { + for len(cr.bytesBuf) < need { + err := cr.uncompressPacket() + if err != nil { + return nil, err + } + } + + data := make([]byte, need) + + copy(data, cr.bytesBuf[:len(data)]) + + cr.bytesBuf = cr.bytesBuf[len(data):] + + return data, nil +} + +func (cr *compressedReader) uncompressPacket() error { + header, err := cr.buf.readNext(7) // size of compressed header + + if err != nil { + return err + } + + // compressed header structure + comprLength := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) + uncompressedLength := int(uint32(header[4]) | uint32(header[5])<<8 | uint32(header[6])<<16) + compressionSequence := uint8(header[3]) + + if compressionSequence != cr.mc.compressionSequence { + return ErrPktSync + } + + cr.mc.compressionSequence++ + + comprData, err := cr.buf.readNext(comprLength) + if err != nil { + return err + } + + // if payload is uncompressed, its length will be specified as zero, and its + // true length is contained in comprLength + if uncompressedLength == 0 { + cr.bytesBuf = append(cr.bytesBuf, comprData...) + return nil + } + + // write comprData to a bytes.buffer, then read it using zlib into data + var b bytes.Buffer + b.Write(comprData) + r, err := zlib.NewReader(&b) + + if r != nil { + defer r.Close() + } + + if err != nil { + return err + } + + data := make([]byte, uncompressedLength) + lenRead := 0 + + // http://grokbase.com/t/gg/golang-nuts/146y9ppn6b/go-nuts-stream-compression-with-compress-flate + for lenRead < uncompressedLength { + + tmp := data[lenRead:] + + n, err := r.Read(tmp) + lenRead += n + + if err == io.EOF { + if lenRead < uncompressedLength { + return io.ErrUnexpectedEOF + } + break + } + + if err != nil { + return err + } + } + + cr.bytesBuf = append(cr.bytesBuf, data...) + + return nil +} + +func (cw *compressedWriter) Write(data []byte) (int, error) { + // when asked to write an empty packet, do nothing + if len(data) == 0 { + return 0, nil + } + totalBytes := len(data) + + length := len(data) - 4 + + maxPayloadLength := maxPacketSize - 4 + + for length >= maxPayloadLength { + // cut off a slice of size max payload length + dataSmall := data[:maxPayloadLength] + lenSmall := len(dataSmall) + + var b bytes.Buffer + writer := zlib.NewWriter(&b) + _, err := writer.Write(dataSmall) + writer.Close() + if err != nil { + return 0, err + } + + err = cw.writeComprPacketToNetwork(b.Bytes(), lenSmall) + if err != nil { + return 0, err + } + + length -= maxPayloadLength + data = data[maxPayloadLength:] + } + + lenSmall := len(data) + + // do not compress if packet is too small + if lenSmall < minCompressLength { + err := cw.writeComprPacketToNetwork(data, 0) + if err != nil { + return 0, err + } + + return totalBytes, nil + } + + var b bytes.Buffer + writer := zlib.NewWriter(&b) + + _, err := writer.Write(data) + writer.Close() + + if err != nil { + return 0, err + } + + err = cw.writeComprPacketToNetwork(b.Bytes(), lenSmall) + + if err != nil { + return 0, err + } + return totalBytes, nil +} + +func (cw *compressedWriter) writeComprPacketToNetwork(data []byte, uncomprLength int) error { + data = append([]byte{0, 0, 0, 0, 0, 0, 0}, data...) + + comprLength := len(data) - 7 + + // compression header + data[0] = byte(0xff & comprLength) + data[1] = byte(0xff & (comprLength >> 8)) + data[2] = byte(0xff & (comprLength >> 16)) + + data[3] = cw.mc.compressionSequence + + //this value is never greater than maxPayloadLength + data[4] = byte(0xff & uncomprLength) + data[5] = byte(0xff & (uncomprLength >> 8)) + data[6] = byte(0xff & (uncomprLength >> 16)) + + if _, err := cw.connWriter.Write(data); err != nil { + return err + } + + cw.mc.compressionSequence++ + return nil +} diff --git a/compress_test.go b/compress_test.go new file mode 100644 index 000000000..c626ff3ee --- /dev/null +++ b/compress_test.go @@ -0,0 +1,220 @@ +package mysql + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" + "testing" +) + +func makeRandByteSlice(size int) []byte { + randBytes := make([]byte, size) + rand.Read(randBytes) + return randBytes +} + +func newMockConn() *mysqlConn { + newConn := &mysqlConn{} + return newConn +} + +type mockBuf struct { + reader io.Reader +} + +func newMockBuf(reader io.Reader) *mockBuf { + return &mockBuf{ + reader: reader, + } +} + +func (mb *mockBuf) readNext(need int) ([]byte, error) { + + data := make([]byte, need) + _, err := mb.reader.Read(data) + if err != nil { + return nil, err + } + return data, nil +} + +// compressHelper compresses uncompressedPacket and checks state variables +func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte { + // get status variables + + cs := mc.compressionSequence + + var b bytes.Buffer + connWriter := &b + + cw := NewCompressedWriter(connWriter, mc) + + n, err := cw.Write(uncompressedPacket) + + if err != nil { + t.Fatal(err.Error()) + } + + if n != len(uncompressedPacket) { + t.Fatal(fmt.Sprintf("expected to write %d bytes, wrote %d bytes", len(uncompressedPacket), n)) + } + + if len(uncompressedPacket) > 0 { + + if mc.compressionSequence != (cs + 1) { + t.Fatal(fmt.Sprintf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressionSequence)) + } + + } else { + if mc.compressionSequence != cs { + t.Fatal(fmt.Sprintf("mc.compressionSequence updated incorrectly for case of empty write, expected %d and saw %d", cs, mc.compressionSequence)) + } + } + + return b.Bytes() +} + +// roundtripHelper compresses then uncompresses uncompressedPacket and checks state variables +func roundtripHelper(t *testing.T, cSend *mysqlConn, cReceive *mysqlConn, uncompressedPacket []byte) []byte { + compressed := compressHelper(t, cSend, uncompressedPacket) + return uncompressHelper(t, cReceive, compressed, len(uncompressedPacket)) +} + +// uncompressHelper uncompresses compressedPacket and checks state variables +func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expSize int) []byte { + // get status variables + cs := mc.compressionSequence + + // mocking out buf variable + mockConnReader := bytes.NewReader(compressedPacket) + mockBuf := newMockBuf(mockConnReader) + + cr := NewCompressedReader(mockBuf, mc) + + uncompressedPacket, err := cr.readNext(expSize) + if err != nil { + if err != io.EOF { + t.Fatal(fmt.Sprintf("non-nil/non-EOF error when reading contents: %s", err.Error())) + } + } + + if expSize > 0 { + if mc.compressionSequence != (cs + 1) { + t.Fatal(fmt.Sprintf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressionSequence)) + } + } else { + if mc.compressionSequence != cs { + t.Fatal(fmt.Sprintf("mc.compressionSequence updated incorrectly for case of empty read, expected %d and saw %d", cs, mc.compressionSequence)) + } + } + return uncompressedPacket +} + +// TestCompressedReaderThenWriter tests reader and writer seperately. +func TestCompressedReaderThenWriter(t *testing.T) { + + makeTestUncompressedPacket := func(size int) []byte { + uncompressedHeader := make([]byte, 4) + uncompressedHeader[0] = byte(size) + uncompressedHeader[1] = byte(size >> 8) + uncompressedHeader[2] = byte(size >> 16) + + payload := make([]byte, size) + for i := range payload { + payload[i] = 'b' + } + + uncompressedPacket := append(uncompressedHeader, payload...) + return uncompressedPacket + } + + tests := []struct { + compressed []byte + uncompressed []byte + desc string + }{ + {compressed: []byte{5, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 'a'}, + uncompressed: []byte{1, 0, 0, 0, 'a'}, + desc: "a"}, + {compressed: []byte{10, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 'g', 'o', 'l', 'a', 'n', 'g'}, + uncompressed: []byte{6, 0, 0, 0, 'g', 'o', 'l', 'a', 'n', 'g'}, + desc: "golang"}, + {compressed: []byte{19, 0, 0, 0, 104, 0, 0, 120, 156, 74, 97, 96, 96, 72, 162, 3, 0, 4, 0, 0, 255, 255, 182, 165, 38, 173}, + uncompressed: makeTestUncompressedPacket(100), + desc: "100 bytes letter b"}, + {compressed: []byte{63, 0, 0, 0, 236, 128, 0, 120, 156, 236, 192, 129, 0, 0, 0, 8, 3, 176, 179, 70, 18, 110, 24, 129, 124, 187, 77, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 168, 241, 1, 0, 0, 255, 255, 42, 107, 93, 24}, + uncompressed: makeTestUncompressedPacket(33000), + desc: "33000 bytes letter b"}, + } + + for _, test := range tests { + s := fmt.Sprintf("Test compress uncompress with %s", test.desc) + + // test uncompression only + c := newMockConn() + uncompressed := uncompressHelper(t, c, test.compressed, len(test.uncompressed)) + if bytes.Compare(uncompressed, test.uncompressed) != 0 { + t.Fatal(fmt.Sprintf("%s: uncompression failed", s)) + } + + // test compression only + c = newMockConn() + compressed := compressHelper(t, c, test.uncompressed) + if bytes.Compare(compressed, test.compressed) != 0 { + t.Fatal(fmt.Sprintf("%s: compression failed", s)) + } + } +} + +// TestRoundtrip tests two connections, where one is reading and the other is writing +func TestRoundtrip(t *testing.T) { + + tests := []struct { + uncompressed []byte + desc string + }{ + {uncompressed: []byte("a"), + desc: "a"}, + {uncompressed: []byte{0}, + desc: "0 byte"}, + {uncompressed: []byte("hello world"), + desc: "hello world"}, + {uncompressed: make([]byte, 100), + desc: "100 bytes"}, + {uncompressed: make([]byte, 32768), + desc: "32768 bytes"}, + {uncompressed: make([]byte, 330000), + desc: "33000 bytes"}, + {uncompressed: make([]byte, 0), + desc: "nothing"}, + {uncompressed: makeRandByteSlice(10), + desc: "10 rand bytes", + }, + {uncompressed: makeRandByteSlice(100), + desc: "100 rand bytes", + }, + {uncompressed: makeRandByteSlice(32768), + desc: "32768 rand bytes", + }, + {uncompressed: makeRandByteSlice(33000), + desc: "33000 rand bytes", + }, + } + + cSend := newMockConn() + + cReceive := newMockConn() + + for _, test := range tests { + s := fmt.Sprintf("Test roundtrip with %s", test.desc) + //t.Run(s, func(t *testing.T) { + + uncompressed := roundtripHelper(t, cSend, cReceive, test.uncompressed) + if bytes.Compare(uncompressed, test.uncompressed) != 0 { + t.Fatal(fmt.Sprintf("%s: roundtrip failed", s)) + } + + //}) + } +} diff --git a/connection.go b/connection.go index 2630f5211..663f94b59 100644 --- a/connection.go +++ b/connection.go @@ -28,19 +28,22 @@ type mysqlContext interface { } type mysqlConn struct { - buf buffer - netConn net.Conn - affectedRows uint64 - insertId uint64 - cfg *Config - maxAllowedPacket int - maxWriteSize int - writeTimeout time.Duration - flags clientFlag - status statusFlag - sequence uint8 - parseTime bool - strict bool + buf buffer + netConn net.Conn + affectedRows uint64 + insertId uint64 + cfg *Config + maxAllowedPacket int + maxWriteSize int + writeTimeout time.Duration + flags clientFlag + status statusFlag + sequence uint8 + compressionSequence uint8 + parseTime bool + strict bool + reader packetReader + writer io.Writer // for context support (Go 1.8+) watching bool diff --git a/connection_test.go b/connection_test.go index 65325f101..187c76116 100644 --- a/connection_test.go +++ b/connection_test.go @@ -21,6 +21,7 @@ func TestInterpolateParams(t *testing.T) { InterpolateParams: true, }, } + mc.reader = &mc.buf q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"}) if err != nil { @@ -41,6 +42,7 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { InterpolateParams: true, }, } + mc.reader = &mc.buf q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)}) if err != driver.ErrSkip { @@ -59,6 +61,8 @@ func TestInterpolateParamsPlaceholderInString(t *testing.T) { }, } + mc.reader = &mc.buf + q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)}) // When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42` if err != driver.ErrSkip { diff --git a/driver.go b/driver.go index c341b6680..691326218 100644 --- a/driver.go +++ b/driver.go @@ -94,6 +94,10 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.buf = newBuffer(mc.netConn) + // packet reader and writer in handshake are never compressed + mc.reader = &mc.buf + mc.writer = mc.netConn + // Set I/O timeouts mc.buf.timeout = mc.cfg.ReadTimeout mc.writeTimeout = mc.cfg.WriteTimeout @@ -120,6 +124,11 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { return nil, err } + if mc.cfg.Compression { + mc.reader = NewCompressedReader(&mc.buf, mc) + mc.writer = NewCompressedWriter(mc.writer, mc) + } + if mc.cfg.MaxAllowedPacket > 0 { mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket } else { diff --git a/dsn.go b/dsn.go index ab2fdfc6a..626fde365 100644 --- a/dsn.go +++ b/dsn.go @@ -56,6 +56,7 @@ type Config struct { ParseTime bool // Parse time values to time.Time RejectReadOnly bool // Reject read-only connections Strict bool // Return warnings as errors + Compression bool // Compress packets } // FormatDSN formats the given Config into a DSN string which can be passed to @@ -445,7 +446,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { // Compression case "compress": - return errors.New("compression not implemented yet") + var isBool bool + cfg.Compression, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } // Enable client side placeholder substitution case "interpolateParams": diff --git a/packets.go b/packets.go index 9715067c4..fbdeb6897 100644 --- a/packets.go +++ b/packets.go @@ -28,7 +28,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte for { // read packet header - data, err := mc.buf.readNext(4) + data, err := mc.reader.readNext(4) if err != nil { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr @@ -64,7 +64,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // read packet body [pktLen bytes] - data, err = mc.buf.readNext(pktLen) + data, err = mc.reader.readNext(pktLen) if err != nil { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr @@ -118,7 +118,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { } } - n, err := mc.netConn.Write(data[:4+size]) + n, err := mc.writer.Write(data[:4+size]) if err == nil && n == 4+size { mc.sequence++ if size != maxPacketSize { @@ -249,6 +249,10 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientFlags |= clientFoundRows } + if mc.cfg.Compression { + clientFlags |= clientCompress + } + // To enable TLS / SSL if mc.cfg.tls != nil { clientFlags |= clientSSL @@ -314,6 +318,8 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { } mc.netConn = tlsConn mc.buf.nc = tlsConn + + mc.writer = mc.netConn } // Filler [23 bytes] (all 0x00) @@ -416,6 +422,7 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence mc.sequence = 0 + mc.compressionSequence = 0 data := mc.buf.takeSmallBuffer(4 + 1) if data == nil { @@ -434,6 +441,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { // Reset Packet Sequence mc.sequence = 0 + mc.compressionSequence = 0 pktLen := 1 + len(arg) data := mc.buf.takeBuffer(pktLen + 4) @@ -456,6 +464,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 + mc.compressionSequence = 0 data := mc.buf.takeSmallBuffer(4 + 1 + 4) if data == nil { diff --git a/packets_test.go b/packets_test.go index 31c892d85..53752e3b8 100644 --- a/packets_test.go +++ b/packets_test.go @@ -94,6 +94,8 @@ func TestReadPacketSingleByte(t *testing.T) { buf: newBuffer(conn), } + mc.reader = &mc.buf + conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} conn.maxReads = 1 packet, err := mc.readPacket() @@ -113,6 +115,7 @@ func TestReadPacketWrongSequenceID(t *testing.T) { mc := &mysqlConn{ buf: newBuffer(conn), } + mc.reader = &mc.buf // too low sequence id conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} @@ -142,6 +145,8 @@ func TestReadPacketSplit(t *testing.T) { buf: newBuffer(conn), } + mc.reader = &mc.buf + data := make([]byte, maxPacketSize*2+4*3) const pkt2ofs = maxPacketSize + 4 const pkt3ofs = 2 * (maxPacketSize + 4) @@ -247,6 +252,7 @@ func TestReadPacketFail(t *testing.T) { buf: newBuffer(conn), closech: make(chan struct{}), } + mc.reader = &mc.buf // illegal empty (stand-alone) packet conn.data = []byte{0x00, 0x00, 0x00, 0x00} From 77f679299d7320e2ab462df5b5d9e6913a43ca47 Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Wed, 16 Aug 2017 13:55:33 -0400 Subject: [PATCH 02/88] packets: implemented compression protocol CR changes --- benchmark_go18_test.go | 4 +-- benchmark_test.go | 19 +++++++++-- compress.go | 72 ++++++++++++++++++++++++++++++------------ packets.go | 3 ++ 4 files changed, 72 insertions(+), 26 deletions(-) diff --git a/benchmark_go18_test.go b/benchmark_go18_test.go index d6a7e9d6e..5522ab9cf 100644 --- a/benchmark_go18_test.go +++ b/benchmark_go18_test.go @@ -42,7 +42,7 @@ func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) { } func BenchmarkQueryContext(b *testing.B) { - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, @@ -78,7 +78,7 @@ func benchmarkExecContext(b *testing.B, db *sql.DB, p int) { } func BenchmarkExecContext(b *testing.B) { - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, diff --git a/benchmark_test.go b/benchmark_test.go index 460553e03..2d690906f 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -43,9 +43,13 @@ func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt { return stmt } -func initDB(b *testing.B, queries ...string) *sql.DB { +func initDB(b *testing.B, useCompression bool, queries ...string) *sql.DB { tb := (*TB)(b) - db := tb.checkDB(sql.Open("mysql", dsn)) + comprStr := "" + if useCompression { + comprStr = "&compress=1" + } + db := tb.checkDB(sql.Open("mysql", dsn+comprStr)) for _, query := range queries { if _, err := db.Exec(query); err != nil { if w, ok := err.(MySQLWarnings); ok { @@ -61,10 +65,19 @@ func initDB(b *testing.B, queries ...string) *sql.DB { const concurrencyLevel = 10 func BenchmarkQuery(b *testing.B) { + benchmarkQueryHelper(b, false) +} + +func BenchmarkQueryCompression(b *testing.B) { + benchmarkQueryHelper(b, true) +} + +func benchmarkQueryHelper(b *testing.B, compr bool) { + tb := (*TB)(b) b.StopTimer() b.ReportAllocs() - db := initDB(b, + db := initDB(b, compr, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, diff --git a/compress.go b/compress.go index 2349aa13b..14a131f4d 100644 --- a/compress.go +++ b/compress.go @@ -18,6 +18,8 @@ type compressedReader struct { buf packetReader bytesBuf []byte mc *mysqlConn + br *bytes.Reader + zr io.ReadCloser } type compressedWriter struct { @@ -48,12 +50,8 @@ func (cr *compressedReader) readNext(need int) ([]byte, error) { } } - data := make([]byte, need) - - copy(data, cr.bytesBuf[:len(data)]) - - cr.bytesBuf = cr.bytesBuf[len(data):] - + data := cr.bytesBuf[:need] + cr.bytesBuf = cr.bytesBuf[need:] return data, nil } @@ -88,27 +86,43 @@ func (cr *compressedReader) uncompressPacket() error { } // write comprData to a bytes.buffer, then read it using zlib into data - var b bytes.Buffer - b.Write(comprData) - r, err := zlib.NewReader(&b) + if cr.br == nil { + cr.br = bytes.NewReader(comprData) + } else { + cr.br.Reset(comprData) + } + + resetter, ok := cr.zr.(zlib.Resetter) - if r != nil { - defer r.Close() + if ok { + err := resetter.Reset(cr.br, []byte{}) + if err != nil { + return err + } + } else { + cr.zr, err = zlib.NewReader(cr.br) + if err != nil { + return err + } } - if err != nil { - return err + defer cr.zr.Close() + + //use existing capacity in bytesBuf if possible + offset := len(cr.bytesBuf) + if cap(cr.bytesBuf)-offset < uncompressedLength { + old := cr.bytesBuf + cr.bytesBuf = make([]byte, offset, offset+uncompressedLength) + copy(cr.bytesBuf, old) } - data := make([]byte, uncompressedLength) + data := cr.bytesBuf[offset : offset+uncompressedLength] + lenRead := 0 // http://grokbase.com/t/gg/golang-nuts/146y9ppn6b/go-nuts-stream-compression-with-compress-flate for lenRead < uncompressedLength { - - tmp := data[lenRead:] - - n, err := r.Read(tmp) + n, err := cr.zr.Read(data[lenRead:]) lenRead += n if err == io.EOF { @@ -152,7 +166,15 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { return 0, err } - err = cw.writeComprPacketToNetwork(b.Bytes(), lenSmall) + // if compression expands the payload, do not compress + useData := b.Bytes() + + if len(useData) > len(dataSmall) { + useData = dataSmall + lenSmall = 0 + } + + err = cw.writeComprPacketToNetwork(useData, lenSmall) if err != nil { return 0, err } @@ -163,7 +185,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { lenSmall := len(data) - // do not compress if packet is too small + // do not attempt compression if packet is too small if lenSmall < minCompressLength { err := cw.writeComprPacketToNetwork(data, 0) if err != nil { @@ -183,7 +205,15 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { return 0, err } - err = cw.writeComprPacketToNetwork(b.Bytes(), lenSmall) + // if compression expands the payload, do not compress + useData := b.Bytes() + + if len(useData) > len(data) { + useData = data + lenSmall = 0 + } + + err = cw.writeComprPacketToNetwork(useData, lenSmall) if err != nil { return 0, err diff --git a/packets.go b/packets.go index fbdeb6897..f8ff6a298 100644 --- a/packets.go +++ b/packets.go @@ -881,6 +881,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { } stmt.mc.sequence = 0 + stmt.mc.compressionSequence = 0 // Add command byte [1 byte] data[4] = comStmtSendLongData @@ -906,6 +907,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { // Reset Packet Sequence stmt.mc.sequence = 0 + stmt.mc.compressionSequence = 0 return nil } @@ -925,6 +927,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // Reset packet-sequence mc.sequence = 0 + mc.compressionSequence = 0 var data []byte From a0cf94b33baca6fd00a0d761192b6f2b4d8fd103 Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Wed, 16 Aug 2017 14:32:26 -0400 Subject: [PATCH 03/88] packets: implemented compression protocol: remove bytes.Reset for backwards compatibility --- compress.go | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/compress.go b/compress.go index 14a131f4d..17bf1560b 100644 --- a/compress.go +++ b/compress.go @@ -18,13 +18,13 @@ type compressedReader struct { buf packetReader bytesBuf []byte mc *mysqlConn - br *bytes.Reader zr io.ReadCloser } type compressedWriter struct { connWriter io.Writer mc *mysqlConn + header []byte } func NewCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader { @@ -39,6 +39,7 @@ func NewCompressedWriter(connWriter io.Writer, mc *mysqlConn) *compressedWriter return &compressedWriter{ connWriter: connWriter, mc: mc, + header: []byte{0, 0, 0, 0, 0, 0, 0}, } } @@ -86,21 +87,17 @@ func (cr *compressedReader) uncompressPacket() error { } // write comprData to a bytes.buffer, then read it using zlib into data - if cr.br == nil { - cr.br = bytes.NewReader(comprData) - } else { - cr.br.Reset(comprData) - } + br := bytes.NewReader(comprData) resetter, ok := cr.zr.(zlib.Resetter) if ok { - err := resetter.Reset(cr.br, []byte{}) + err := resetter.Reset(br, []byte{}) if err != nil { return err } } else { - cr.zr, err = zlib.NewReader(cr.br) + cr.zr, err = zlib.NewReader(br) if err != nil { return err } @@ -222,7 +219,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { } func (cw *compressedWriter) writeComprPacketToNetwork(data []byte, uncomprLength int) error { - data = append([]byte{0, 0, 0, 0, 0, 0, 0}, data...) + data = append(cw.header, data...) comprLength := len(data) - 7 From d0ea1a418dc8ab516f245ad96b63157e852d8c9d Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Fri, 18 Aug 2017 16:24:53 -0400 Subject: [PATCH 04/88] reading working --- compress.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/compress.go b/compress.go index 17bf1560b..2ccd8970c 100644 --- a/compress.go +++ b/compress.go @@ -10,6 +10,10 @@ const ( minCompressLength = 50 ) +var ( + blankHeader = []byte{0, 0, 0, 0, 0, 0, 0} +) + type packetReader interface { readNext(need int) ([]byte, error) } @@ -24,7 +28,7 @@ type compressedReader struct { type compressedWriter struct { connWriter io.Writer mc *mysqlConn - header []byte + zw *zlib.Writer } func NewCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader { @@ -39,7 +43,7 @@ func NewCompressedWriter(connWriter io.Writer, mc *mysqlConn) *compressedWriter return &compressedWriter{ connWriter: connWriter, mc: mc, - header: []byte{0, 0, 0, 0, 0, 0, 0}, + zw: zlib.NewWriter(new(bytes.Buffer)), } } @@ -89,18 +93,14 @@ func (cr *compressedReader) uncompressPacket() error { // write comprData to a bytes.buffer, then read it using zlib into data br := bytes.NewReader(comprData) - resetter, ok := cr.zr.(zlib.Resetter) - - if ok { - err := resetter.Reset(br, []byte{}) - if err != nil { - return err - } - } else { + if cr.zr == nil { cr.zr, err = zlib.NewReader(br) - if err != nil { - return err - } + } else { + err = cr.zr.(zlib.Resetter).Reset(br, nil) + } + + if err != nil { + return err } defer cr.zr.Close() @@ -219,7 +219,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { } func (cw *compressedWriter) writeComprPacketToNetwork(data []byte, uncomprLength int) error { - data = append(cw.header, data...) + data = append(blankHeader, data...) comprLength := len(data) - 7 From 477c9f844736475945470522395fe6811cca52bf Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Fri, 18 Aug 2017 16:36:28 -0400 Subject: [PATCH 05/88] writerly changes --- compress.go | 64 ++++++++++++++++++++++++++--------------------------- 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/compress.go b/compress.go index 2ccd8970c..96f431139 100644 --- a/compress.go +++ b/compress.go @@ -144,6 +144,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { if len(data) == 0 { return 0, nil } + totalBytes := len(data) length := len(data) - 4 @@ -151,27 +152,26 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { maxPayloadLength := maxPacketSize - 4 for length >= maxPayloadLength { - // cut off a slice of size max payload length - dataSmall := data[:maxPayloadLength] - lenSmall := len(dataSmall) - - var b bytes.Buffer - writer := zlib.NewWriter(&b) - _, err := writer.Write(dataSmall) - writer.Close() + payload := data[:maxPayloadLength] + payloadLen := len(payload) + + bytesBuf := &bytes.Buffer{} + bytesBuf.Write(blankHeader) + cw.zw.Reset(bytesBuf) + _, err := cw.zw.Write(payload) if err != nil { return 0, err } + cw.zw.Close() // if compression expands the payload, do not compress - useData := b.Bytes() - - if len(useData) > len(dataSmall) { - useData = dataSmall - lenSmall = 0 + compressedPayload := bytesBuf.Bytes() + if len(compressedPayload) > maxPayloadLength { + compressedPayload = append(blankHeader, payload...) + payloadLen = 0 } - err = cw.writeComprPacketToNetwork(useData, lenSmall) + err = cw.writeToNetwork(compressedPayload, payloadLen) if err != nil { return 0, err } @@ -180,46 +180,44 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { data = data[maxPayloadLength:] } - lenSmall := len(data) + payloadLen := len(data) // do not attempt compression if packet is too small - if lenSmall < minCompressLength { - err := cw.writeComprPacketToNetwork(data, 0) + if payloadLen < minCompressLength { + err := cw.writeToNetwork(append(blankHeader, data...), 0) if err != nil { return 0, err } - return totalBytes, nil } - var b bytes.Buffer - writer := zlib.NewWriter(&b) - - _, err := writer.Write(data) - writer.Close() - + bytesBuf := &bytes.Buffer{} + bytesBuf.Write(blankHeader) + cw.zw.Reset(bytesBuf) + _, err := cw.zw.Write(data) if err != nil { return 0, err } + cw.zw.Close() - // if compression expands the payload, do not compress - useData := b.Bytes() + compressedPayload := bytesBuf.Bytes() - if len(useData) > len(data) { - useData = data - lenSmall = 0 + if len(compressedPayload) > len(data) { + compressedPayload = append(blankHeader, data...) + payloadLen = 0 } - err = cw.writeComprPacketToNetwork(useData, lenSmall) - + // add header and send over the wire + err = cw.writeToNetwork(compressedPayload, payloadLen) if err != nil { return 0, err } + return totalBytes, nil + } -func (cw *compressedWriter) writeComprPacketToNetwork(data []byte, uncomprLength int) error { - data = append(blankHeader, data...) +func (cw *compressedWriter) writeToNetwork(data []byte, uncomprLength int) error { comprLength := len(data) - 7 From 996ed2d17131da2ef41e46fbdbd34186d47cb700 Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Sun, 8 Oct 2017 14:00:11 -0400 Subject: [PATCH 06/88] PR 649: adding compression (second code review) --- compress.go | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/compress.go b/compress.go index 96f431139..3425c72fa 100644 --- a/compress.go +++ b/compress.go @@ -10,9 +10,6 @@ const ( minCompressLength = 50 ) -var ( - blankHeader = []byte{0, 0, 0, 0, 0, 0, 0} -) type packetReader interface { readNext(need int) ([]byte, error) @@ -146,17 +143,16 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { } totalBytes := len(data) - length := len(data) - 4 - maxPayloadLength := maxPacketSize - 4 + blankHeader := make([]byte, 7) for length >= maxPayloadLength { payload := data[:maxPayloadLength] payloadLen := len(payload) bytesBuf := &bytes.Buffer{} - bytesBuf.Write(blankHeader) + bytesBuf.Write(blankHeader) cw.zw.Reset(bytesBuf) _, err := cw.zw.Write(payload) if err != nil { @@ -167,7 +163,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { // if compression expands the payload, do not compress compressedPayload := bytesBuf.Bytes() if len(compressedPayload) > maxPayloadLength { - compressedPayload = append(blankHeader, payload...) + compressedPayload = append(blankHeader, payload...) payloadLen = 0 } @@ -184,7 +180,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { // do not attempt compression if packet is too small if payloadLen < minCompressLength { - err := cw.writeToNetwork(append(blankHeader, data...), 0) + err := cw.writeToNetwork(append(blankHeader, data...), 0) if err != nil { return 0, err } @@ -203,7 +199,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { compressedPayload := bytesBuf.Bytes() if len(compressedPayload) > len(data) { - compressedPayload = append(blankHeader, data...) + compressedPayload = append(blankHeader, data...) payloadLen = 0 } From f74faedaa752df4ed5f57082f67fa80b60b0f62e Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Thu, 12 Oct 2017 10:31:18 +0200 Subject: [PATCH 07/88] do not query max_allowed_packet by default (#680) --- README.md | 4 ++-- const.go | 3 ++- driver_test.go | 2 +- dsn.go | 3 ++- dsn_test.go | 50 ++++++++++++++++++++++++-------------------------- 5 files changed, 31 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index b5882e6c8..6a306bb30 100644 --- a/README.md +++ b/README.md @@ -232,10 +232,10 @@ Please keep in mind, that param values must be [url.QueryEscape](https://golang. ##### `maxAllowedPacket` ``` Type: decimal number -Default: 0 +Default: 4194304 ``` -Max packet size allowed in bytes. Use `maxAllowedPacket=0` to automatically fetch the `max_allowed_packet` variable from server. +Max packet size allowed in bytes. The default value is 4 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*. ##### `multiStatements` diff --git a/const.go b/const.go index 88cfff3fd..2570b23fe 100644 --- a/const.go +++ b/const.go @@ -9,7 +9,8 @@ package mysql const ( - minProtocolVersion byte = 10 + defaultMaxAllowedPacket = 4 << 20 // 4 MiB + minProtocolVersion = 10 maxPacketSize = 1<<24 - 1 timeFormat = "2006-01-02 15:04:05.999999" ) diff --git a/driver_test.go b/driver_test.go index bc0386a09..27b067dff 100644 --- a/driver_test.go +++ b/driver_test.go @@ -964,7 +964,7 @@ func TestUint64(t *testing.T) { } func TestLongData(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, dsn+"&maxAllowedPacket=0", func(dbt *DBTest) { var maxAllowedPacketSize int err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize) if err != nil { diff --git a/dsn.go b/dsn.go index 5ebd1d9f7..e3ead3ce5 100644 --- a/dsn.go +++ b/dsn.go @@ -65,6 +65,7 @@ func NewConfig() *Config { return &Config{ Collation: defaultCollation, Loc: time.UTC, + MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, } } @@ -275,7 +276,7 @@ func (cfg *Config) FormatDSN() string { buf.WriteString(cfg.WriteTimeout.String()) } - if cfg.MaxAllowedPacket > 0 { + if cfg.MaxAllowedPacket != defaultMaxAllowedPacket { if hasParam { buf.WriteString("&maxAllowedPacket=") } else { diff --git a/dsn_test.go b/dsn_test.go index af28da351..07b223f6b 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -22,55 +22,55 @@ var testDSNs = []struct { out *Config }{{ "username:password@protocol(address)/dbname?param=value", - &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", - &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true, ColumnsWithAlias: true}, + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, ColumnsWithAlias: true}, }, { "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true", - &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true, ColumnsWithAlias: true, MultiStatements: true}, + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, ColumnsWithAlias: true, MultiStatements: true}, }, { "user@unix(/path/to/socket)/dbname?charset=utf8", - &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true, TLSConfig: "true"}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "true"}, }, { "user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true, TLSConfig: "skip-verify"}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "skip-verify"}, }, { "user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216", &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216}, }, { - "user:password@/dbname?allowNativePasswords=false", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: false}, + "user:password@/dbname?allowNativePasswords=false&maxAllowedPacket=0", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false}, }, { "user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", - &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.Local, AllowNativePasswords: true}, + &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "/dbname", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "@/", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "/", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "user:p@/ssword@/", - &Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "unix/?arg=%2Fsome%2Fpath.ext", - &Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "tcp(127.0.0.1)/dbname", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "tcp(de:ad:be:ef::ca:fe)/dbname", - &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, } @@ -233,16 +233,14 @@ func TestDSNUnsafeCollation(t *testing.T) { func TestParamsAreSorted(t *testing.T) { expected := "/dbname?interpolateParams=true&foobar=baz&quux=loo" - dsn := &Config{ - DBName: "dbname", - InterpolateParams: true, - AllowNativePasswords: true, - Params: map[string]string{ - "quux": "loo", - "foobar": "baz", - }, + cfg := NewConfig() + cfg.DBName = "dbname" + cfg.InterpolateParams = true + cfg.Params = map[string]string{ + "quux": "loo", + "foobar": "baz", } - actual := dsn.FormatDSN() + actual := cfg.FormatDSN() if actual != expected { t.Errorf("generic Config.Params were not sorted: want %#v, got %#v", expected, actual) } From b3a093e1ccb62923917ed9c6dc2374ae041ff330 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Mon, 16 Oct 2017 22:44:03 +0200 Subject: [PATCH 08/88] packets: do not call function on nulled value (#678) --- packets.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packets.go b/packets.go index ff6b1394d..40b7f1115 100644 --- a/packets.go +++ b/packets.go @@ -1155,10 +1155,11 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { } return io.EOF } + mc := rows.mc rows.mc = nil // Error otherwise - return rows.mc.handleErrorPacket(data) + return mc.handleErrorPacket(data) } // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] From 5eaa5ff08a12e4bf29321fdcc92afd1c4d21e3f7 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Tue, 17 Oct 2017 14:45:56 +0200 Subject: [PATCH 09/88] ColumnType interfaces (#667) * rows: implement driver.RowsColumnTypeScanType Implementation for time.Time not yet complete! * rows: implement driver.RowsColumnTypeNullable * rows: move fields related code to fields.go * fields: use NullTime for nullable datetime fields * fields: make fieldType its own type * rows: implement driver.RowsColumnTypeDatabaseTypeName * fields: fix copyright year * rows: compile time interface implementation checks * rows: move tests to versioned driver test files * rows: cache parseTime in resultSet instead of mysqlConn * fields: fix string and time types * rows: implement ColumnTypeLength * rows: implement ColumnTypePrecisionScale * rows: fix ColumnTypeNullable * rows: ColumnTypes tests part1 * rows: use keyed composite literals in ColumnTypes tests * rows: ColumnTypes tests part2 * rows: always use NullTime as ScanType for datetime * rows: avoid errors through rounding of time values * rows: remove parseTime cache * fields: remove unused scanTypes * rows: fix ColumnTypePrecisionScale implementation * fields: sort types alphabetical * rows: remove ColumnTypeLength implementation for now * README: document ColumnType Support --- README.md | 14 ++- connection.go | 1 + const.go | 6 +- driver_go18_test.go | 220 ++++++++++++++++++++++++++++++++++++++++++++ driver_test.go | 6 ++ fields.go | 140 ++++++++++++++++++++++++++++ packets.go | 23 +++-- rows.go | 51 ++++++++-- 8 files changed, 436 insertions(+), 25 deletions(-) create mode 100644 fields.go diff --git a/README.md b/README.md index 6a306bb30..f6eb0b0d2 100644 --- a/README.md +++ b/README.md @@ -16,10 +16,11 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac * [Parameters](#parameters) * [Examples](#examples) * [Connection pool and timeouts](#connection-pool-and-timeouts) + * [context.Context Support](#contextcontext-support) + * [ColumnType Support](#columntype-support) * [LOAD DATA LOCAL INFILE support](#load-data-local-infile-support) * [time.Time support](#timetime-support) * [Unicode support](#unicode-support) - * [context.Context Support](#contextcontext-support) * [Testing / Development](#testing--development) * [License](#license) @@ -400,6 +401,13 @@ user:password@/ ### Connection pool and timeouts The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively. +## `ColumnType` Support +This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. + +## `context.Context` Support +Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. +See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details. + ### `LOAD DATA LOCAL INFILE` support For this feature you need direct access to the package. Therefore you must change the import path (no `_`): @@ -433,10 +441,6 @@ Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAM See http://dev.mysql.com/doc/refman/5.7/en/charset-unicode.html for more details on MySQL's Unicode support. -## `context.Context` Support -Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. -See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details. - ## Testing / Development To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details. diff --git a/connection.go b/connection.go index b31d63d7e..3a30c46a9 100644 --- a/connection.go +++ b/connection.go @@ -406,6 +406,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) return nil, err } } + // Columns rows.rs.columns, err = mc.readColumns(resLen) return rows, err diff --git a/const.go b/const.go index 2570b23fe..4a19ca523 100644 --- a/const.go +++ b/const.go @@ -88,8 +88,10 @@ const ( ) // https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType +type fieldType byte + const ( - fieldTypeDecimal byte = iota + fieldTypeDecimal fieldType = iota fieldTypeTiny fieldTypeShort fieldTypeLong @@ -108,7 +110,7 @@ const ( fieldTypeBit ) const ( - fieldTypeJSON byte = iota + 0xf5 + fieldTypeJSON fieldType = iota + 0xf5 fieldTypeNewDecimal fieldTypeEnum fieldTypeSet diff --git a/driver_go18_test.go b/driver_go18_test.go index 4962838f2..953adeb8a 100644 --- a/driver_go18_test.go +++ b/driver_go18_test.go @@ -15,6 +15,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "math" "reflect" "testing" "time" @@ -35,6 +36,22 @@ var ( _ driver.StmtQueryContext = &mysqlStmt{} ) +// Ensure that all the driver interfaces are implemented +var ( + // _ driver.RowsColumnTypeLength = &binaryRows{} + // _ driver.RowsColumnTypeLength = &textRows{} + _ driver.RowsColumnTypeDatabaseTypeName = &binaryRows{} + _ driver.RowsColumnTypeDatabaseTypeName = &textRows{} + _ driver.RowsColumnTypeNullable = &binaryRows{} + _ driver.RowsColumnTypeNullable = &textRows{} + _ driver.RowsColumnTypePrecisionScale = &binaryRows{} + _ driver.RowsColumnTypePrecisionScale = &textRows{} + _ driver.RowsColumnTypeScanType = &binaryRows{} + _ driver.RowsColumnTypeScanType = &textRows{} + _ driver.RowsNextResultSet = &binaryRows{} + _ driver.RowsNextResultSet = &textRows{} +) + func TestMultiResultSet(t *testing.T) { type result struct { values [][]int @@ -558,3 +575,206 @@ func TestContextBeginReadOnly(t *testing.T) { } }) } + +func TestRowsColumnTypes(t *testing.T) { + niNULL := sql.NullInt64{Int64: 0, Valid: false} + ni0 := sql.NullInt64{Int64: 0, Valid: true} + ni1 := sql.NullInt64{Int64: 1, Valid: true} + ni42 := sql.NullInt64{Int64: 42, Valid: true} + nfNULL := sql.NullFloat64{Float64: 0.0, Valid: false} + nf0 := sql.NullFloat64{Float64: 0.0, Valid: true} + nf1337 := sql.NullFloat64{Float64: 13.37, Valid: true} + nt0 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true} + nt1 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true} + nt2 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true} + nt6 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true} + rbNULL := sql.RawBytes(nil) + rb0 := sql.RawBytes("0") + rb42 := sql.RawBytes("42") + rbTest := sql.RawBytes("Test") + + var columns = []struct { + name string + fieldType string // type used when creating table schema + databaseTypeName string // actual type used by MySQL + scanType reflect.Type + nullable bool + precision int64 // 0 if not ok + scale int64 + valuesIn [3]string + valuesOut [3]interface{} + }{ + {"boolnull", "BOOL", "TINYINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "true", "0"}, [3]interface{}{niNULL, ni1, ni0}}, + {"bool", "BOOL NOT NULL", "TINYINT", scanTypeInt8, false, 0, 0, [3]string{"1", "0", "FALSE"}, [3]interface{}{int8(1), int8(0), int8(0)}}, + {"intnull", "INTEGER", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, + {"smallint", "SMALLINT NOT NULL", "SMALLINT", scanTypeInt16, false, 0, 0, [3]string{"0", "-32768", "32767"}, [3]interface{}{int16(0), int16(-32768), int16(32767)}}, + {"smallintnull", "SMALLINT", "SMALLINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, + {"int3null", "INT(3)", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, + {"int7", "INT(7) NOT NULL", "INT", scanTypeInt32, false, 0, 0, [3]string{"0", "-1337", "42"}, [3]interface{}{int32(0), int32(-1337), int32(42)}}, + {"bigint", "BIGINT NOT NULL", "BIGINT", scanTypeInt64, false, 0, 0, [3]string{"0", "65535", "-42"}, [3]interface{}{int64(0), int64(65535), int64(-42)}}, + {"bigintnull", "BIGINT", "BIGINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "1", "42"}, [3]interface{}{niNULL, ni1, ni42}}, + {"tinyuint", "TINYINT UNSIGNED NOT NULL", "TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]interface{}{uint8(0), uint8(255), uint8(42)}}, + {"smalluint", "SMALLINT UNSIGNED NOT NULL", "SMALLINT", scanTypeUint16, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint16(0), uint16(65535), uint16(42)}}, + {"biguint", "BIGINT UNSIGNED NOT NULL", "BIGINT", scanTypeUint64, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint64(0), uint64(65535), uint64(42)}}, + {"uint13", "INT(13) UNSIGNED NOT NULL", "INT", scanTypeUint32, false, 0, 0, [3]string{"0", "1337", "42"}, [3]interface{}{uint32(0), uint32(1337), uint32(42)}}, + {"float", "FLOAT NOT NULL", "FLOAT", scanTypeFloat32, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float32(0), float32(42), float32(13.37)}}, + {"floatnull", "FLOAT", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, + {"float74null", "FLOAT(7,4)", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, 4, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, + {"double", "DOUBLE NOT NULL", "DOUBLE", scanTypeFloat64, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float64(0), float64(42), float64(13.37)}}, + {"doublenull", "DOUBLE", "DOUBLE", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, + {"decimal1", "DECIMAL(10,6) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 10, 6, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), sql.RawBytes("13.370000"), sql.RawBytes("1234.123456")}}, + {"decimal1null", "DECIMAL(10,6)", "DECIMAL", scanTypeRawBytes, true, 10, 6, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), rbNULL, sql.RawBytes("1234.123456")}}, + {"decimal2", "DECIMAL(8,4) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 8, 4, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), sql.RawBytes("13.3700"), sql.RawBytes("1234.1235")}}, + {"decimal2null", "DECIMAL(8,4)", "DECIMAL", scanTypeRawBytes, true, 8, 4, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), rbNULL, sql.RawBytes("1234.1235")}}, + {"decimal3", "DECIMAL(5,0) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 5, 0, [3]string{"0", "13.37", "-12345.123456"}, [3]interface{}{rb0, sql.RawBytes("13"), sql.RawBytes("-12345")}}, + {"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeRawBytes, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]interface{}{rb0, rbNULL, sql.RawBytes("-12345")}}, + {"char25null", "CHAR(25)", "CHAR", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"textnull", "TEXT", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"longtext", "LONGTEXT NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"datetime", "DATETIME", "DATETIME", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt0, nt0}}, + {"datetime2", "DATETIME(2)", "DATETIME", scanTypeNullTime, true, 2, 2, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt2}}, + {"datetime6", "DATETIME(6)", "DATETIME", scanTypeNullTime, true, 6, 6, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt6}}, + } + + schema := "" + values1 := "" + values2 := "" + values3 := "" + for _, column := range columns { + schema += fmt.Sprintf("`%s` %s, ", column.name, column.fieldType) + values1 += column.valuesIn[0] + ", " + values2 += column.valuesIn[1] + ", " + values3 += column.valuesIn[2] + ", " + } + schema = schema[:len(schema)-2] + values1 = values1[:len(values1)-2] + values2 = values2[:len(values2)-2] + values3 = values3[:len(values3)-2] + + dsns := []string{ + dsn + "&parseTime=true", + dsn + "&parseTime=false", + } + for _, testdsn := range dsns { + runTests(t, testdsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (" + schema + ")") + dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")") + + rows, err := dbt.db.Query("SELECT * FROM test") + if err != nil { + t.Fatalf("Query: %v", err) + } + + tt, err := rows.ColumnTypes() + if err != nil { + t.Fatalf("ColumnTypes: %v", err) + } + + if len(tt) != len(columns) { + t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt)) + } + + types := make([]reflect.Type, len(tt)) + for i, tp := range tt { + column := columns[i] + + // Name + name := tp.Name() + if name != column.name { + t.Errorf("column name mismatch %s != %s", name, column.name) + continue + } + + // DatabaseTypeName + databaseTypeName := tp.DatabaseTypeName() + if databaseTypeName != column.databaseTypeName { + t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName) + continue + } + + // ScanType + scanType := tp.ScanType() + if scanType != column.scanType { + if scanType == nil { + t.Errorf("scantype is null for column %q", name) + } else { + t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name()) + } + continue + } + types[i] = scanType + + // Nullable + nullable, ok := tp.Nullable() + if !ok { + t.Errorf("nullable not ok %q", name) + continue + } + if nullable != column.nullable { + t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable) + } + + // Length + // length, ok := tp.Length() + // if length != column.length { + // if !ok { + // t.Errorf("length not ok for column %q", name) + // } else { + // t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length) + // } + // continue + // } + + // Precision and Scale + precision, scale, ok := tp.DecimalSize() + if precision != column.precision { + if !ok { + t.Errorf("precision not ok for column %q", name) + } else { + t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision) + } + continue + } + if scale != column.scale { + if !ok { + t.Errorf("scale not ok for column %q", name) + } else { + t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale) + } + continue + } + } + + values := make([]interface{}, len(tt)) + for i := range values { + values[i] = reflect.New(types[i]).Interface() + } + i := 0 + for rows.Next() { + err = rows.Scan(values...) + if err != nil { + t.Fatalf("failed to scan values in %v", err) + } + for j := range values { + value := reflect.ValueOf(values[j]).Elem().Interface() + if !reflect.DeepEqual(value, columns[j].valuesOut[i]) { + if columns[j].scanType == scanTypeRawBytes { + t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes))) + } else { + t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i]) + } + } + } + i++ + } + if i != 3 { + t.Errorf("expected 3 rows, got %d", i) + } + + if err := rows.Close(); err != nil { + t.Errorf("error closing rows: %s", err) + } + }) + } +} diff --git a/driver_test.go b/driver_test.go index 27b067dff..53e70dab7 100644 --- a/driver_test.go +++ b/driver_test.go @@ -27,6 +27,12 @@ import ( "time" ) +// Ensure that all the driver interfaces are implemented +var ( + _ driver.Rows = &binaryRows{} + _ driver.Rows = &textRows{} +) + var ( user string pass string diff --git a/fields.go b/fields.go new file mode 100644 index 000000000..cded986d2 --- /dev/null +++ b/fields.go @@ -0,0 +1,140 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql" + "reflect" +) + +var typeDatabaseName = map[fieldType]string{ + fieldTypeBit: "BIT", + fieldTypeBLOB: "BLOB", + fieldTypeDate: "DATE", + fieldTypeDateTime: "DATETIME", + fieldTypeDecimal: "DECIMAL", + fieldTypeDouble: "DOUBLE", + fieldTypeEnum: "ENUM", + fieldTypeFloat: "FLOAT", + fieldTypeGeometry: "GEOMETRY", + fieldTypeInt24: "MEDIUMINT", + fieldTypeJSON: "JSON", + fieldTypeLong: "INT", + fieldTypeLongBLOB: "LONGBLOB", + fieldTypeLongLong: "BIGINT", + fieldTypeMediumBLOB: "MEDIUMBLOB", + fieldTypeNewDate: "DATE", + fieldTypeNewDecimal: "DECIMAL", + fieldTypeNULL: "NULL", + fieldTypeSet: "SET", + fieldTypeShort: "SMALLINT", + fieldTypeString: "CHAR", + fieldTypeTime: "TIME", + fieldTypeTimestamp: "TIMESTAMP", + fieldTypeTiny: "TINYINT", + fieldTypeTinyBLOB: "TINYBLOB", + fieldTypeVarChar: "VARCHAR", + fieldTypeVarString: "VARCHAR", + fieldTypeYear: "YEAR", +} + +var ( + scanTypeFloat32 = reflect.TypeOf(float32(0)) + scanTypeFloat64 = reflect.TypeOf(float64(0)) + scanTypeInt8 = reflect.TypeOf(int8(0)) + scanTypeInt16 = reflect.TypeOf(int16(0)) + scanTypeInt32 = reflect.TypeOf(int32(0)) + scanTypeInt64 = reflect.TypeOf(int64(0)) + scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) + scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) + scanTypeNullTime = reflect.TypeOf(NullTime{}) + scanTypeUint8 = reflect.TypeOf(uint8(0)) + scanTypeUint16 = reflect.TypeOf(uint16(0)) + scanTypeUint32 = reflect.TypeOf(uint32(0)) + scanTypeUint64 = reflect.TypeOf(uint64(0)) + scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) + scanTypeUnknown = reflect.TypeOf(new(interface{})) +) + +type mysqlField struct { + tableName string + name string + length uint32 + flags fieldFlag + fieldType fieldType + decimals byte +} + +func (mf *mysqlField) scanType() reflect.Type { + switch mf.fieldType { + case fieldTypeTiny: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint8 + } + return scanTypeInt8 + } + return scanTypeNullInt + + case fieldTypeShort, fieldTypeYear: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint16 + } + return scanTypeInt16 + } + return scanTypeNullInt + + case fieldTypeInt24, fieldTypeLong: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint32 + } + return scanTypeInt32 + } + return scanTypeNullInt + + case fieldTypeLongLong: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint64 + } + return scanTypeInt64 + } + return scanTypeNullInt + + case fieldTypeFloat: + if mf.flags&flagNotNULL != 0 { + return scanTypeFloat32 + } + return scanTypeNullFloat + + case fieldTypeDouble: + if mf.flags&flagNotNULL != 0 { + return scanTypeFloat64 + } + return scanTypeNullFloat + + case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, + fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, + fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, + fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON, + fieldTypeTime: + return scanTypeRawBytes + + case fieldTypeDate, fieldTypeNewDate, + fieldTypeTimestamp, fieldTypeDateTime: + // NullTime is always returned for more consistent behavior as it can + // handle both cases of parseTime regardless if the field is nullable. + return scanTypeNullTime + + default: + return scanTypeUnknown + } +} diff --git a/packets.go b/packets.go index 40b7f1115..97afd0abc 100644 --- a/packets.go +++ b/packets.go @@ -708,11 +708,14 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { // Filler [uint8] // Charset [charset, collation uint8] + pos += n + 1 + 2 + // Length [uint32] - pos += n + 1 + 2 + 4 + columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4]) + pos += 4 // Field type [uint8] - columns[i].fieldType = data[pos] + columns[i].fieldType = fieldType(data[pos]) pos++ // Flags [uint16] @@ -992,7 +995,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // build NULL-bitmap if arg == nil { nullMask[i/8] |= 1 << (uint(i) & 7) - paramTypes[i+i] = fieldTypeNULL + paramTypes[i+i] = byte(fieldTypeNULL) paramTypes[i+i+1] = 0x00 continue } @@ -1000,7 +1003,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // cache types and values switch v := arg.(type) { case int64: - paramTypes[i+i] = fieldTypeLongLong + paramTypes[i+i] = byte(fieldTypeLongLong) paramTypes[i+i+1] = 0x00 if cap(paramValues)-len(paramValues)-8 >= 0 { @@ -1016,7 +1019,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } case float64: - paramTypes[i+i] = fieldTypeDouble + paramTypes[i+i] = byte(fieldTypeDouble) paramTypes[i+i+1] = 0x00 if cap(paramValues)-len(paramValues)-8 >= 0 { @@ -1032,7 +1035,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } case bool: - paramTypes[i+i] = fieldTypeTiny + paramTypes[i+i] = byte(fieldTypeTiny) paramTypes[i+i+1] = 0x00 if v { @@ -1044,7 +1047,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { case []byte: // Common case (non-nil value) first if v != nil { - paramTypes[i+i] = fieldTypeString + paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { @@ -1062,11 +1065,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // Handle []byte(nil) as a NULL value nullMask[i/8] |= 1 << (uint(i) & 7) - paramTypes[i+i] = fieldTypeNULL + paramTypes[i+i] = byte(fieldTypeNULL) paramTypes[i+i+1] = 0x00 case string: - paramTypes[i+i] = fieldTypeString + paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { @@ -1081,7 +1084,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } case time.Time: - paramTypes[i+i] = fieldTypeString + paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 var a [64]byte diff --git a/rows.go b/rows.go index c7f5ee26c..18f41693e 100644 --- a/rows.go +++ b/rows.go @@ -11,16 +11,10 @@ package mysql import ( "database/sql/driver" "io" + "math" + "reflect" ) -type mysqlField struct { - tableName string - name string - flags fieldFlag - fieldType byte - decimals byte -} - type resultSet struct { columns []mysqlField columnNames []string @@ -65,6 +59,47 @@ func (rows *mysqlRows) Columns() []string { return columns } +func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string { + if name, ok := typeDatabaseName[rows.rs.columns[i].fieldType]; ok { + return name + } + return "" +} + +// func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) { +// return int64(rows.rs.columns[i].length), true +// } + +func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) { + return rows.rs.columns[i].flags&flagNotNULL == 0, true +} + +func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) { + column := rows.rs.columns[i] + decimals := int64(column.decimals) + + switch column.fieldType { + case fieldTypeDecimal, fieldTypeNewDecimal: + if decimals > 0 { + return int64(column.length) - 2, decimals, true + } + return int64(column.length) - 1, decimals, true + case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime: + return decimals, decimals, true + case fieldTypeFloat, fieldTypeDouble: + if decimals == 0x1f { + return math.MaxInt64, math.MaxInt64, true + } + return math.MaxInt64, decimals, true + } + + return 0, 0, false +} + +func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type { + return rows.rs.columns[i].scanType() +} + func (rows *mysqlRows) Close() (err error) { if f := rows.finish; f != nil { f() From ee460286f5d798ae53974561cc2e6827a78dd0ef Mon Sep 17 00:00:00 2001 From: Jeffrey Charles Date: Tue, 17 Oct 2017 13:10:23 -0400 Subject: [PATCH 10/88] Add Aurora errno to rejectReadOnly check (#634) AWS Aurora returns a 1290 after failing over requiring the connection to be closed and opened again to be able to perform writes. --- AUTHORS | 1 + README.md | 7 ++++++- packets.go | 3 ++- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/AUTHORS b/AUTHORS index 0a936b1f6..c98ef9dbd 100644 --- a/AUTHORS +++ b/AUTHORS @@ -35,6 +35,7 @@ INADA Naoki Jacek Szwec James Harr Jeff Hodges +Jeffrey Charles Jian Zhen Joshua Prunier Julien Lefevre diff --git a/README.md b/README.md index f6eb0b0d2..d24aaa0f0 100644 --- a/README.md +++ b/README.md @@ -279,7 +279,7 @@ Default: false ``` -`rejectreadOnly=true` causes the driver to reject read-only connections. This +`rejectReadOnly=true` causes the driver to reject read-only connections. This is for a possible race condition during an automatic failover, where the mysql client gets connected to a read-only replica after the failover. @@ -294,6 +294,11 @@ If you are not relying on read-only transactions to reject writes that aren't supposed to happen, setting this on some MySQL providers (such as AWS Aurora) is safer for failovers. +Note that ERROR 1290 can be returned for a `read-only` server and this option will +cause a retry for that error. However the same error number is used for some +other cases. You should ensure your application will never cause an ERROR 1290 +except for `read-only` mode when enabling this option. + ##### `timeout` diff --git a/packets.go b/packets.go index 97afd0abc..7bd2dd309 100644 --- a/packets.go +++ b/packets.go @@ -580,7 +580,8 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { errno := binary.LittleEndian.Uint16(data[1:3]) // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION - if errno == 1792 && mc.cfg.RejectReadOnly { + // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover) + if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly { // Oops; we are connected to a read-only connection, and won't be able // to issue any write statements. Since RejectReadOnly is configured, // we throw away this connection hoping this one would have write From 93aed7307deff9a0a6dc64c80b7862c29bc67c8d Mon Sep 17 00:00:00 2001 From: Jeff Hodges Date: Tue, 17 Oct 2017 11:16:16 -0700 Subject: [PATCH 11/88] allow successful TravisCI runs in forks (#639) Most forks won't be in goveralls and so this command in travis.yml was, previously, failing and causing the build to fail. Now, it doesn't! --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index fa0b2c933..6369281e8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -90,4 +90,5 @@ script: - go test -v -covermode=count -coverprofile=coverage.out - go vet ./... - test -z "$(gofmt -d -s . | tee /dev/stderr)" +after_script: - $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci From 4f10ee537a00db3ae88fa835f5e687a71639fe76 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Sun, 12 Nov 2017 22:30:34 +0100 Subject: [PATCH 12/88] Drop support for Go 1.6 and lower (#696) * Drop support for Go 1.6 and lower * Remove cloneTLSConfig for legacy Go versions --- .travis.yml | 2 -- README.md | 2 +- utils_legacy.go | 18 ------------------ 3 files changed, 1 insertion(+), 21 deletions(-) delete mode 100644 utils_legacy.go diff --git a/.travis.yml b/.travis.yml index 6369281e8..64b06a70c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,6 @@ sudo: false language: go go: - - 1.5 - - 1.6 - 1.7 - 1.8 - 1.9 diff --git a/README.md b/README.md index d24aaa0f0..299198d53 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac * Optional placeholder interpolation ## Requirements - * Go 1.5 or higher + * Go 1.7 or higher. We aim to support the 3 latest versions of Go. * MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+) --------------------------------------- diff --git a/utils_legacy.go b/utils_legacy.go deleted file mode 100644 index a03b10de2..000000000 --- a/utils_legacy.go +++ /dev/null @@ -1,18 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -// +build !go1.7 - -package mysql - -import "crypto/tls" - -func cloneTLSConfig(c *tls.Config) *tls.Config { - clone := *c - return &clone -} From 59b0f90fea7003118587750d3590ebfb0cfc3d4f Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Tue, 14 Nov 2017 09:18:14 +0100 Subject: [PATCH 13/88] Make gofmt happy (#704) --- driver_test.go | 1 - dsn.go | 1 - utils.go | 2 +- 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/driver_test.go b/driver_test.go index 53e70dab7..f6965b191 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1375,7 +1375,6 @@ func TestTimezoneConversion(t *testing.T) { // Regression test for timezone handling tzTest := func(dbt *DBTest) { - // Create table dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)") diff --git a/dsn.go b/dsn.go index e3ead3ce5..418bc86b9 100644 --- a/dsn.go +++ b/dsn.go @@ -399,7 +399,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { // cfg params switch value := param[1]; param[0] { - // Disable INFILE whitelist / enable all files case "allowAllFiles": var isBool bool diff --git a/utils.go b/utils.go index 82da83099..a92a4029b 100644 --- a/utils.go +++ b/utils.go @@ -566,8 +566,8 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) { if len(b) == 0 { return 0, true, 1 } - switch b[0] { + switch b[0] { // 251: NULL case 0xfb: return 0, true, 1 From 3fbf53ab2f434b1d97ce2ce9395251793ad1d3f7 Mon Sep 17 00:00:00 2001 From: Daniel Montoya Date: Wed, 15 Nov 2017 16:37:47 -0600 Subject: [PATCH 14/88] Added support for custom string types in ConvertValue. (#623) * Added support for custom string types. * Add author name * Added license header * Added a newline to force a commit. * Remove newline. --- AUTHORS | 1 + statement.go | 2 ++ statement_test.go | 21 +++++++++++++++++++++ 3 files changed, 24 insertions(+) create mode 100644 statement_test.go diff --git a/AUTHORS b/AUTHORS index c98ef9dbd..780561a98 100644 --- a/AUTHORS +++ b/AUTHORS @@ -19,6 +19,7 @@ B Lamarche Bulat Gaifullin Carlos Nieto Chris Moos +Daniel Montoya Daniel Nichter Daniël van Eeden Dave Protasowski diff --git a/statement.go b/statement.go index ae223507f..628174b64 100644 --- a/statement.go +++ b/statement.go @@ -157,6 +157,8 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { return int64(u64), nil case reflect.Float32, reflect.Float64: return rv.Float(), nil + case reflect.String: + return rv.String(), nil } return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) } diff --git a/statement_test.go b/statement_test.go new file mode 100644 index 000000000..8de4a8b26 --- /dev/null +++ b/statement_test.go @@ -0,0 +1,21 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import "testing" + +type customString string + +func TestConvertValueCustomTypes(t *testing.T) { + var cstr customString = "string" + c := converter{} + if _, err := c.ConvertValue(cstr); err != nil { + t.Errorf("custom string type should be valid") + } +} From f9c6a2cea1651d4e197b8034fafa768bbd44223f Mon Sep 17 00:00:00 2001 From: Justin Li Date: Thu, 16 Nov 2017 02:25:03 -0500 Subject: [PATCH 15/88] Implement NamedValueChecker for mysqlConn (#690) * Also add conversions for additional types in ConvertValue ref https://github.com/golang/go/commit/d7c0de98a96893e5608358f7578c85be7ba12b25 --- AUTHORS | 1 + connection_go18.go | 5 ++ connection_go18_test.go | 30 ++++++++++ statement.go | 8 +++ statement_test.go | 119 +++++++++++++++++++++++++++++++++++++--- 5 files changed, 156 insertions(+), 7 deletions(-) create mode 100644 connection_go18_test.go diff --git a/AUTHORS b/AUTHORS index 780561a98..95d14a076 100644 --- a/AUTHORS +++ b/AUTHORS @@ -41,6 +41,7 @@ Jian Zhen Joshua Prunier Julien Lefevre Julien Schmidt +Justin Li Justin Nuß Kamil Dziedzic Kevin Malachowski diff --git a/connection_go18.go b/connection_go18.go index 48a9cca64..1306b70b7 100644 --- a/connection_go18.go +++ b/connection_go18.go @@ -195,3 +195,8 @@ func (mc *mysqlConn) startWatcher() { } }() } + +func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = converter{}.ConvertValue(nv.Value) + return +} diff --git a/connection_go18_test.go b/connection_go18_test.go new file mode 100644 index 000000000..2719ab3b7 --- /dev/null +++ b/connection_go18_test.go @@ -0,0 +1,30 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.8 + +package mysql + +import ( + "database/sql/driver" + "testing" +) + +func TestCheckNamedValue(t *testing.T) { + value := driver.NamedValue{Value: ^uint64(0)} + x := &mysqlConn{} + err := x.CheckNamedValue(&value) + + if err != nil { + t.Fatal("uint64 high-bit not convertible", err) + } + + if value.Value != "18446744073709551615" { + t.Fatalf("uint64 high-bit not converted, got %#v %T", value.Value, value.Value) + } +} diff --git a/statement.go b/statement.go index 628174b64..4870a307c 100644 --- a/statement.go +++ b/statement.go @@ -157,6 +157,14 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { return int64(u64), nil case reflect.Float32, reflect.Float64: return rv.Float(), nil + case reflect.Bool: + return rv.Bool(), nil + case reflect.Slice: + ek := rv.Type().Elem().Kind() + if ek == reflect.Uint8 { + return rv.Bytes(), nil + } + return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) case reflect.String: return rv.String(), nil } diff --git a/statement_test.go b/statement_test.go index 8de4a8b26..98a6c1933 100644 --- a/statement_test.go +++ b/statement_test.go @@ -8,14 +8,119 @@ package mysql -import "testing" +import ( + "bytes" + "testing" +) -type customString string +func TestConvertDerivedString(t *testing.T) { + type derived string -func TestConvertValueCustomTypes(t *testing.T) { - var cstr customString = "string" - c := converter{} - if _, err := c.ConvertValue(cstr); err != nil { - t.Errorf("custom string type should be valid") + output, err := converter{}.ConvertValue(derived("value")) + if err != nil { + t.Fatal("Derived string type not convertible", err) + } + + if output != "value" { + t.Fatalf("Derived string type not converted, got %#v %T", output, output) + } +} + +func TestConvertDerivedByteSlice(t *testing.T) { + type derived []uint8 + + output, err := converter{}.ConvertValue(derived("value")) + if err != nil { + t.Fatal("Byte slice not convertible", err) + } + + if bytes.Compare(output.([]byte), []byte("value")) != 0 { + t.Fatalf("Byte slice not converted, got %#v %T", output, output) + } +} + +func TestConvertDerivedUnsupportedSlice(t *testing.T) { + type derived []int + + _, err := converter{}.ConvertValue(derived{1}) + if err == nil || err.Error() != "unsupported type mysql.derived, a slice of int" { + t.Fatal("Unexpected error", err) + } +} + +func TestConvertDerivedBool(t *testing.T) { + type derived bool + + output, err := converter{}.ConvertValue(derived(true)) + if err != nil { + t.Fatal("Derived bool type not convertible", err) + } + + if output != true { + t.Fatalf("Derived bool type not converted, got %#v %T", output, output) + } +} + +func TestConvertPointer(t *testing.T) { + str := "value" + + output, err := converter{}.ConvertValue(&str) + if err != nil { + t.Fatal("Pointer type not convertible", err) + } + + if output != "value" { + t.Fatalf("Pointer type not converted, got %#v %T", output, output) + } +} + +func TestConvertSignedIntegers(t *testing.T) { + values := []interface{}{ + int8(-42), + int16(-42), + int32(-42), + int64(-42), + int(-42), + } + + for _, value := range values { + output, err := converter{}.ConvertValue(value) + if err != nil { + t.Fatalf("%T type not convertible %s", value, err) + } + + if output != int64(-42) { + t.Fatalf("%T type not converted, got %#v %T", value, output, output) + } + } +} + +func TestConvertUnsignedIntegers(t *testing.T) { + values := []interface{}{ + uint8(42), + uint16(42), + uint32(42), + uint64(42), + uint(42), + } + + for _, value := range values { + output, err := converter{}.ConvertValue(value) + if err != nil { + t.Fatalf("%T type not convertible %s", value, err) + } + + if output != int64(42) { + t.Fatalf("%T type not converted, got %#v %T", value, output, output) + } + } + + output, err := converter{}.ConvertValue(^uint64(0)) + if err != nil { + t.Fatal("uint64 high-bit not convertible", err) + } + + if output != "18446744073709551615" { + t.Fatalf("uint64 high-bit not converted, got %#v %T", output, output) } } From 6046bf014ffba46d3753cfc64d1a3c9656318d8f Mon Sep 17 00:00:00 2001 From: Dave Stubbs Date: Thu, 16 Nov 2017 16:10:24 +0000 Subject: [PATCH 16/88] Fix Valuers by returning driver.ErrSkip if couldn't convert type internally (#709) Fixes #708 --- AUTHORS | 1 + connection_go18.go | 6 +++++- driver_test.go | 30 ++++++++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index 95d14a076..66a4ad202 100644 --- a/AUTHORS +++ b/AUTHORS @@ -72,6 +72,7 @@ Zhenye Xie # Organizations Barracuda Networks, Inc. +Counting Ltd. Google Inc. Keybase Inc. Pivotal Inc. diff --git a/connection_go18.go b/connection_go18.go index 1306b70b7..65cc63ef2 100644 --- a/connection_go18.go +++ b/connection_go18.go @@ -197,6 +197,10 @@ func (mc *mysqlConn) startWatcher() { } func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { - nv.Value, err = converter{}.ConvertValue(nv.Value) + value, err := converter{}.ConvertValue(nv.Value) + if err != nil { + return driver.ErrSkip + } + nv.Value = value return } diff --git a/driver_test.go b/driver_test.go index f6965b191..392e752a3 100644 --- a/driver_test.go +++ b/driver_test.go @@ -499,6 +499,36 @@ func TestString(t *testing.T) { }) } +type testValuer struct { + value string +} + +func (tv testValuer) Value() (driver.Value, error) { + return tv.value, nil +} + +func TestValuer(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + in := testValuer{"a_value"} + var out string + var rows *sql.Rows + + dbt.mustExec("CREATE TABLE test (value VARCHAR(255)) CHARACTER SET utf8") + dbt.mustExec("INSERT INTO test VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in.value != out { + dbt.Errorf("Valuer: %v != %s", in, out) + } + } else { + dbt.Errorf("Valuer: no data") + } + + dbt.mustExec("DROP TABLE IF EXISTS test") + }) +} + type timeTests struct { dbtype string tlayout string From 385673a27ccb40f4a14746623da6e849c80eb079 Mon Sep 17 00:00:00 2001 From: Linh Tran Tuan Date: Fri, 17 Nov 2017 14:23:23 +0700 Subject: [PATCH 17/88] statement: Fix conversion of Valuer (#710) Updates #709 Fixes #706 --- AUTHORS | 1 + connection_go18.go | 6 +----- driver_test.go | 49 ++++++++++++++++++++++++++++++++++++++++++++++ statement.go | 6 ++++++ 4 files changed, 57 insertions(+), 5 deletions(-) diff --git a/AUTHORS b/AUTHORS index 66a4ad202..3fc9ece3a 100644 --- a/AUTHORS +++ b/AUTHORS @@ -47,6 +47,7 @@ Kamil Dziedzic Kevin Malachowski Lennart Rudolph Leonardo YongUk Kim +Linh Tran Tuan Lion Yang Luca Looz Lucas Liu diff --git a/connection_go18.go b/connection_go18.go index 65cc63ef2..1306b70b7 100644 --- a/connection_go18.go +++ b/connection_go18.go @@ -197,10 +197,6 @@ func (mc *mysqlConn) startWatcher() { } func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { - value, err := converter{}.ConvertValue(nv.Value) - if err != nil { - return driver.ErrSkip - } - nv.Value = value + nv.Value, err = converter{}.ConvertValue(nv.Value) return } diff --git a/driver_test.go b/driver_test.go index 392e752a3..224a24c53 100644 --- a/driver_test.go +++ b/driver_test.go @@ -529,6 +529,55 @@ func TestValuer(t *testing.T) { }) } +type testValuerWithValidation struct { + value string +} + +func (tv testValuerWithValidation) Value() (driver.Value, error) { + if len(tv.value) == 0 { + return nil, fmt.Errorf("Invalid string valuer. Value must not be empty") + } + + return tv.value, nil +} + +func TestValuerWithValidation(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + in := testValuerWithValidation{"a_value"} + var out string + var rows *sql.Rows + + dbt.mustExec("CREATE TABLE testValuer (value VARCHAR(255)) CHARACTER SET utf8") + dbt.mustExec("INSERT INTO testValuer VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM testValuer") + defer rows.Close() + + if rows.Next() { + rows.Scan(&out) + if in.value != out { + dbt.Errorf("Valuer: %v != %s", in, out) + } + } else { + dbt.Errorf("Valuer: no data") + } + + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", testValuerWithValidation{""}); err == nil { + dbt.Errorf("Failed to check valuer error") + } + + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", nil); err != nil { + dbt.Errorf("Failed to check nil") + } + + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", map[string]bool{}); err == nil { + dbt.Errorf("Failed to check not valuer") + } + + dbt.mustExec("DROP TABLE IF EXISTS testValuer") + }) +} + type timeTests struct { dbtype string tlayout string diff --git a/statement.go b/statement.go index 4870a307c..98e57bcd8 100644 --- a/statement.go +++ b/statement.go @@ -137,6 +137,12 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { return v, nil } + if v != nil { + if valuer, ok := v.(driver.Valuer); ok { + return valuer.Value() + } + } + rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.Ptr: From 9031984e2b7bab392eb005c4c014ea66af892796 Mon Sep 17 00:00:00 2001 From: "Robert R. Russell" Date: Fri, 17 Nov 2017 05:51:24 -0600 Subject: [PATCH 18/88] Fixed imports for appengine/cloudsql (#700) * Fixed broken import for appengine/cloudsql appengine.go import path of appengine/cloudsql has changed to google.golang.org/appengine/cloudsql - Fixed. * Added my name to the AUTHORS --- AUTHORS | 1 + appengine.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index 3fc9ece3a..9988284ef 100644 --- a/AUTHORS +++ b/AUTHORS @@ -61,6 +61,7 @@ Paul Bonser Peter Schultz Rebecca Chin Runrioter Wung +Robert Russell Shuode Li Soroush Pour Stan Putrya diff --git a/appengine.go b/appengine.go index 565614eef..be41f2ee6 100644 --- a/appengine.go +++ b/appengine.go @@ -11,7 +11,7 @@ package mysql import ( - "appengine/cloudsql" + "google.golang.org/appengine/cloudsql" ) func init() { From 6992fad9c49d4e7df3c8346949e44684fa826146 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Mon, 4 Dec 2017 09:43:26 +0900 Subject: [PATCH 19/88] Fix tls=true didn't work with host without port (#718) Fixes #717 --- dsn.go | 20 +++++++++----------- dsn_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/dsn.go b/dsn.go index 418bc86b9..fa50ad3c0 100644 --- a/dsn.go +++ b/dsn.go @@ -95,6 +95,15 @@ func (cfg *Config) normalize() error { cfg.Addr = ensureHavePort(cfg.Addr) } + if cfg.tls != nil { + if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { + host, _, err := net.SplitHostPort(cfg.Addr) + if err == nil { + cfg.tls.ServerName = host + } + } + } + return nil } @@ -526,10 +535,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { if boolValue { cfg.TLSConfig = "true" cfg.tls = &tls.Config{} - host, _, err := net.SplitHostPort(cfg.Addr) - if err == nil { - cfg.tls.ServerName = host - } } else { cfg.TLSConfig = "false" } @@ -543,13 +548,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { } if tlsConfig := getTLSConfigClone(name); tlsConfig != nil { - if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { - host, _, err := net.SplitHostPort(cfg.Addr) - if err == nil { - tlsConfig.ServerName = host - } - } - cfg.TLSConfig = name cfg.tls = tlsConfig } else { diff --git a/dsn_test.go b/dsn_test.go index 07b223f6b..7507d1201 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -177,6 +177,34 @@ func TestDSNWithCustomTLS(t *testing.T) { DeregisterTLSConfig("utils_test") } +func TestDSNTLSConfig(t *testing.T) { + expectedServerName := "example.com" + dsn := "tcp(example.com:1234)/?tls=true" + + cfg, err := ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + if cfg.tls == nil { + t.Error("cfg.tls should not be nil") + } + if cfg.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName) + } + + dsn = "tcp(example.com)/?tls=true" + cfg, err = ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + if cfg.tls == nil { + t.Error("cfg.tls should not be nil") + } + if cfg.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName) + } +} + func TestDSNWithCustomTLSQueryEscape(t *testing.T) { const configKey = "&%!:" dsn := "User:password@tcp(localhost:5555)/dbname?tls=" + url.QueryEscape(configKey) From 386f84bcc4e8e23703cc5b822191df934c85fd73 Mon Sep 17 00:00:00 2001 From: Kieron Woodhouse Date: Wed, 10 Jan 2018 11:31:24 +0000 Subject: [PATCH 20/88] Differentiate between BINARY and CHAR (#724) * Differentiate between BINARY and CHAR When looking up the database type name, we now check the character set for the following field types: * CHAR * VARCHAR * BLOB * TINYBLOB * MEDIUMBLOB * LONGBLOB If the character set is 63 (which is the binary pseudo character set), we return the binary names, which are (respectively): * BINARY * VARBINARY * BLOB * TINYBLOB * MEDIUMBLOB * LONGBLOB If any other character set is in use, we return the text names, which are (again, respectively): * CHAR * VARCHAR * TEXT * TINYTEXT * MEDIUMTEXT * LONGTEXT To facilitate this, mysqlField has been extended to include a uint8 field for character set, which is read from the appropriate packet. Column type tests have been updated to ensure coverage of binary and text types. * Increase test coverage for column types --- AUTHORS | 2 + collations.go | 1 + driver_go18_test.go | 22 ++++++++- fields.go | 112 ++++++++++++++++++++++++++++++++------------ packets.go | 6 ++- rows.go | 5 +- 6 files changed, 112 insertions(+), 36 deletions(-) diff --git a/AUTHORS b/AUTHORS index 9988284ef..5d84a6eb1 100644 --- a/AUTHORS +++ b/AUTHORS @@ -45,6 +45,7 @@ Justin Li Justin Nuß Kamil Dziedzic Kevin Malachowski +Kieron Woodhouse Lennart Rudolph Leonardo YongUk Kim Linh Tran Tuan @@ -76,6 +77,7 @@ Zhenye Xie Barracuda Networks, Inc. Counting Ltd. Google Inc. +InfoSum Ltd. Keybase Inc. Pivotal Inc. Stripe Inc. diff --git a/collations.go b/collations.go index 82079cfb9..136c9e4d1 100644 --- a/collations.go +++ b/collations.go @@ -9,6 +9,7 @@ package mysql const defaultCollation = "utf8_general_ci" +const binaryCollation = "binary" // A list of available collations mapped to the internal ID. // To update this map use the following MySQL query: diff --git a/driver_go18_test.go b/driver_go18_test.go index 953adeb8a..e461455dd 100644 --- a/driver_go18_test.go +++ b/driver_go18_test.go @@ -588,10 +588,16 @@ func TestRowsColumnTypes(t *testing.T) { nt1 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true} nt2 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true} nt6 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true} + nd1 := NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true} + nd2 := NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true} + ndNULL := NullTime{Time: time.Time{}, Valid: false} rbNULL := sql.RawBytes(nil) rb0 := sql.RawBytes("0") rb42 := sql.RawBytes("42") rbTest := sql.RawBytes("Test") + rb0pad4 := sql.RawBytes("0\x00\x00\x00") // BINARY right-pads values with 0x00 + rbx0 := sql.RawBytes("\x00") + rbx42 := sql.RawBytes("\x42") var columns = []struct { name string @@ -604,6 +610,7 @@ func TestRowsColumnTypes(t *testing.T) { valuesIn [3]string valuesOut [3]interface{} }{ + {"bit8null", "BIT(8)", "BIT", scanTypeRawBytes, true, 0, 0, [3]string{"0x0", "NULL", "0x42"}, [3]interface{}{rbx0, rbNULL, rbx42}}, {"boolnull", "BOOL", "TINYINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "true", "0"}, [3]interface{}{niNULL, ni1, ni0}}, {"bool", "BOOL NOT NULL", "TINYINT", scanTypeInt8, false, 0, 0, [3]string{"1", "0", "FALSE"}, [3]interface{}{int8(1), int8(0), int8(0)}}, {"intnull", "INTEGER", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, @@ -611,6 +618,7 @@ func TestRowsColumnTypes(t *testing.T) { {"smallintnull", "SMALLINT", "SMALLINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, {"int3null", "INT(3)", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, {"int7", "INT(7) NOT NULL", "INT", scanTypeInt32, false, 0, 0, [3]string{"0", "-1337", "42"}, [3]interface{}{int32(0), int32(-1337), int32(42)}}, + {"mediumintnull", "MEDIUMINT", "MEDIUMINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "42", "NULL"}, [3]interface{}{ni0, ni42, niNULL}}, {"bigint", "BIGINT NOT NULL", "BIGINT", scanTypeInt64, false, 0, 0, [3]string{"0", "65535", "-42"}, [3]interface{}{int64(0), int64(65535), int64(-42)}}, {"bigintnull", "BIGINT", "BIGINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "1", "42"}, [3]interface{}{niNULL, ni1, ni42}}, {"tinyuint", "TINYINT UNSIGNED NOT NULL", "TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]interface{}{uint8(0), uint8(255), uint8(42)}}, @@ -630,11 +638,21 @@ func TestRowsColumnTypes(t *testing.T) { {"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeRawBytes, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]interface{}{rb0, rbNULL, sql.RawBytes("-12345")}}, {"char25null", "CHAR(25)", "CHAR", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, {"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, - {"textnull", "TEXT", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, - {"longtext", "LONGTEXT NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"binary4null", "BINARY(4)", "BINARY", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0pad4, rbNULL, rbTest}}, + {"varbinary42", "VARBINARY(42) NOT NULL", "VARBINARY", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"tinyblobnull", "TINYBLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"tinytextnull", "TINYTEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"blobnull", "BLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"textnull", "TEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"mediumblob", "MEDIUMBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"mediumtext", "MEDIUMTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"longblob", "LONGBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"longtext", "LONGTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, {"datetime", "DATETIME", "DATETIME", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt0, nt0}}, {"datetime2", "DATETIME(2)", "DATETIME", scanTypeNullTime, true, 2, 2, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt2}}, {"datetime6", "DATETIME(6)", "DATETIME", scanTypeNullTime, true, 6, 6, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt6}}, + {"date", "DATE", "DATE", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02'", "NULL", "'2006-03-04'"}, [3]interface{}{nd1, ndNULL, nd2}}, + {"year", "YEAR NOT NULL", "YEAR", scanTypeUint16, false, 0, 0, [3]string{"2006", "2000", "1994"}, [3]interface{}{uint16(2006), uint16(2000), uint16(1994)}}, } schema := "" diff --git a/fields.go b/fields.go index cded986d2..e1e2ece4b 100644 --- a/fields.go +++ b/fields.go @@ -13,35 +13,88 @@ import ( "reflect" ) -var typeDatabaseName = map[fieldType]string{ - fieldTypeBit: "BIT", - fieldTypeBLOB: "BLOB", - fieldTypeDate: "DATE", - fieldTypeDateTime: "DATETIME", - fieldTypeDecimal: "DECIMAL", - fieldTypeDouble: "DOUBLE", - fieldTypeEnum: "ENUM", - fieldTypeFloat: "FLOAT", - fieldTypeGeometry: "GEOMETRY", - fieldTypeInt24: "MEDIUMINT", - fieldTypeJSON: "JSON", - fieldTypeLong: "INT", - fieldTypeLongBLOB: "LONGBLOB", - fieldTypeLongLong: "BIGINT", - fieldTypeMediumBLOB: "MEDIUMBLOB", - fieldTypeNewDate: "DATE", - fieldTypeNewDecimal: "DECIMAL", - fieldTypeNULL: "NULL", - fieldTypeSet: "SET", - fieldTypeShort: "SMALLINT", - fieldTypeString: "CHAR", - fieldTypeTime: "TIME", - fieldTypeTimestamp: "TIMESTAMP", - fieldTypeTiny: "TINYINT", - fieldTypeTinyBLOB: "TINYBLOB", - fieldTypeVarChar: "VARCHAR", - fieldTypeVarString: "VARCHAR", - fieldTypeYear: "YEAR", +func (mf *mysqlField) typeDatabaseName() string { + switch mf.fieldType { + case fieldTypeBit: + return "BIT" + case fieldTypeBLOB: + if mf.charSet != collations[binaryCollation] { + return "TEXT" + } + return "BLOB" + case fieldTypeDate: + return "DATE" + case fieldTypeDateTime: + return "DATETIME" + case fieldTypeDecimal: + return "DECIMAL" + case fieldTypeDouble: + return "DOUBLE" + case fieldTypeEnum: + return "ENUM" + case fieldTypeFloat: + return "FLOAT" + case fieldTypeGeometry: + return "GEOMETRY" + case fieldTypeInt24: + return "MEDIUMINT" + case fieldTypeJSON: + return "JSON" + case fieldTypeLong: + return "INT" + case fieldTypeLongBLOB: + if mf.charSet != collations[binaryCollation] { + return "LONGTEXT" + } + return "LONGBLOB" + case fieldTypeLongLong: + return "BIGINT" + case fieldTypeMediumBLOB: + if mf.charSet != collations[binaryCollation] { + return "MEDIUMTEXT" + } + return "MEDIUMBLOB" + case fieldTypeNewDate: + return "DATE" + case fieldTypeNewDecimal: + return "DECIMAL" + case fieldTypeNULL: + return "NULL" + case fieldTypeSet: + return "SET" + case fieldTypeShort: + return "SMALLINT" + case fieldTypeString: + if mf.charSet == collations[binaryCollation] { + return "BINARY" + } + return "CHAR" + case fieldTypeTime: + return "TIME" + case fieldTypeTimestamp: + return "TIMESTAMP" + case fieldTypeTiny: + return "TINYINT" + case fieldTypeTinyBLOB: + if mf.charSet != collations[binaryCollation] { + return "TINYTEXT" + } + return "TINYBLOB" + case fieldTypeVarChar: + if mf.charSet == collations[binaryCollation] { + return "VARBINARY" + } + return "VARCHAR" + case fieldTypeVarString: + if mf.charSet == collations[binaryCollation] { + return "VARBINARY" + } + return "VARCHAR" + case fieldTypeYear: + return "YEAR" + default: + return "" + } } var ( @@ -69,6 +122,7 @@ type mysqlField struct { flags fieldFlag fieldType fieldType decimals byte + charSet uint8 } func (mf *mysqlField) scanType() reflect.Type { diff --git a/packets.go b/packets.go index 7bd2dd309..36ce691c5 100644 --- a/packets.go +++ b/packets.go @@ -706,10 +706,14 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { if err != nil { return nil, err } + pos += n // Filler [uint8] + pos++ + // Charset [charset, collation uint8] - pos += n + 1 + 2 + columns[i].charSet = data[pos] + pos += 2 // Length [uint32] columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4]) diff --git a/rows.go b/rows.go index 18f41693e..d3b1e2822 100644 --- a/rows.go +++ b/rows.go @@ -60,10 +60,7 @@ func (rows *mysqlRows) Columns() []string { } func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string { - if name, ok := typeDatabaseName[rows.rs.columns[i].fieldType]; ok { - return name - } - return "" + return rows.rs.columns[i].typeDatabaseName() } // func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) { From f853432d62faa9c9265c1454f7b1bae07f4b2a6e Mon Sep 17 00:00:00 2001 From: Alexey Palazhchenko Date: Wed, 10 Jan 2018 13:44:24 +0200 Subject: [PATCH 21/88] Test with latest Go patch versions (#693) --- .travis.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.travis.yml b/.travis.yml index 64b06a70c..e922f9187 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,10 +1,10 @@ sudo: false language: go go: - - 1.7 - - 1.8 - - 1.9 - - tip + - 1.7.x + - 1.8.x + - 1.9.x + - master before_install: - go get golang.org/x/tools/cmd/cover @@ -21,7 +21,7 @@ matrix: - env: DB=MYSQL57 sudo: required dist: trusty - go: 1.9 + go: 1.9.x services: - docker before_install: @@ -43,7 +43,7 @@ matrix: - env: DB=MARIA55 sudo: required dist: trusty - go: 1.9 + go: 1.9.x services: - docker before_install: @@ -65,7 +65,7 @@ matrix: - env: DB=MARIA10_1 sudo: required dist: trusty - go: 1.9 + go: 1.9.x services: - docker before_install: From d1a8b86f7fef0f773e55ca15defb2347be22a106 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Sun, 14 Jan 2018 05:07:44 +0900 Subject: [PATCH 22/88] Fix prepared statement (#734) * Fix prepared statement When there are many args and maxAllowedPacket is not enough, writeExecutePacket() attempted to use STMT_LONG_DATA even for 0byte string. But writeCommandLongData() doesn't support 0byte data. So it caused to send malfold packet. This commit loosen threshold for using STMT_LONG_DATA. * Change minimum size of LONG_DATA to 64byte * Add test which reproduce issue 730 * TestPreparedManyCols test only numParams = 65535 case * s/as possible// --- driver_test.go | 17 ++++++++++++++--- packets.go | 10 ++++++++-- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/driver_test.go b/driver_test.go index 224a24c53..7877aa979 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1669,8 +1669,9 @@ func TestStmtMultiRows(t *testing.T) { // Regression test for // * more than 32 NULL parameters (issue 209) // * more parameters than fit into the buffer (issue 201) +// * parameters * 64 > max_allowed_packet (issue 734) func TestPreparedManyCols(t *testing.T) { - const numParams = defaultBufSize + numParams := 65535 runTests(t, dsn, func(dbt *DBTest) { query := "SELECT ?" + strings.Repeat(",?", numParams-1) stmt, err := dbt.db.Prepare(query) @@ -1678,15 +1679,25 @@ func TestPreparedManyCols(t *testing.T) { dbt.Fatal(err) } defer stmt.Close() + // create more parameters than fit into the buffer // which will take nil-values params := make([]interface{}, numParams) rows, err := stmt.Query(params...) if err != nil { - stmt.Close() dbt.Fatal(err) } - defer rows.Close() + rows.Close() + + // Create 0byte string which we can't send via STMT_LONG_DATA. + for i := 0; i < numParams; i++ { + params[i] = "" + } + rows, err = stmt.Query(params...) + if err != nil { + dbt.Fatal(err) + } + rows.Close() }) } diff --git a/packets.go b/packets.go index 36ce691c5..e6d8e4e88 100644 --- a/packets.go +++ b/packets.go @@ -927,6 +927,12 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { const minPktLen = 4 + 1 + 4 + 1 + 4 mc := stmt.mc + // Determine threshould dynamically to avoid packet size shortage. + longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) + if longDataSize < 64 { + longDataSize = 64 + } + // Reset packet-sequence mc.sequence = 0 mc.compressionSequence = 0 @@ -1055,7 +1061,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 - if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { + if len(v) < longDataSize { paramValues = appendLengthEncodedInteger(paramValues, uint64(len(v)), ) @@ -1077,7 +1083,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 - if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { + if len(v) < longDataSize { paramValues = appendLengthEncodedInteger(paramValues, uint64(len(v)), ) From 31679208840e1a2a05db26638c7bcda4ff362bf1 Mon Sep 17 00:00:00 2001 From: Reed Allman Date: Wed, 24 Jan 2018 21:47:45 -0800 Subject: [PATCH 23/88] driver.ErrBadConn when init packet read fails (#736) Thank you! --- AUTHORS | 1 + packets.go | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/AUTHORS b/AUTHORS index 5d84a6eb1..d9144aece 100644 --- a/AUTHORS +++ b/AUTHORS @@ -61,6 +61,7 @@ oscarzhao Paul Bonser Peter Schultz Rebecca Chin +Reed Allman Runrioter Wung Robert Russell Shuode Li diff --git a/packets.go b/packets.go index e6d8e4e88..2e9cb4984 100644 --- a/packets.go +++ b/packets.go @@ -157,6 +157,11 @@ func (mc *mysqlConn) writePacket(data []byte) error { func (mc *mysqlConn) readInitPacket() ([]byte, error) { data, err := mc.readPacket() if err != nil { + // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since + // in connection initialization we don't risk retrying non-idempotent actions. + if err == ErrInvalidConn { + return nil, driver.ErrBadConn + } return nil, err } From fb33a2cb2ede88e3222b62f272909216a1121e5b Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Mon, 31 Jul 2017 16:43:52 -0400 Subject: [PATCH 24/88] packets: implemented compression protocol --- compress.go | 3 ++- dsn.go | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/compress.go b/compress.go index 3425c72fa..56b07362b 100644 --- a/compress.go +++ b/compress.go @@ -10,7 +10,6 @@ const ( minCompressLength = 50 ) - type packetReader interface { readNext(need int) ([]byte, error) } @@ -117,6 +116,7 @@ func (cr *compressedReader) uncompressPacket() error { // http://grokbase.com/t/gg/golang-nuts/146y9ppn6b/go-nuts-stream-compression-with-compress-flate for lenRead < uncompressedLength { n, err := cr.zr.Read(data[lenRead:]) + lenRead += n if err == io.EOF { @@ -168,6 +168,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { } err = cw.writeToNetwork(compressedPayload, payloadLen) + if err != nil { return 0, err } diff --git a/dsn.go b/dsn.go index fa50ad3c0..92f137daa 100644 --- a/dsn.go +++ b/dsn.go @@ -53,6 +53,7 @@ type Config struct { AllowOldPasswords bool // Allows the old insecure password method ClientFoundRows bool // Return number of matching rows instead of rows changed ColumnsWithAlias bool // Prepend table alias to column names + Compression bool // Compress packets InterpolateParams bool // Interpolate placeholders into query string MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time From f1746058a298b62cc541c3475df4ca4e6569d4c2 Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Wed, 16 Aug 2017 13:55:33 -0400 Subject: [PATCH 25/88] packets: implemented compression protocol CR changes --- compress.go | 1 - 1 file changed, 1 deletion(-) diff --git a/compress.go b/compress.go index 56b07362b..9339e0eca 100644 --- a/compress.go +++ b/compress.go @@ -116,7 +116,6 @@ func (cr *compressedReader) uncompressPacket() error { // http://grokbase.com/t/gg/golang-nuts/146y9ppn6b/go-nuts-stream-compression-with-compress-flate for lenRead < uncompressedLength { n, err := cr.zr.Read(data[lenRead:]) - lenRead += n if err == io.EOF { From dbd1e2befc161ab78bf0aa36e8e8bbd2e05e9ad7 Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Fri, 23 Mar 2018 11:22:20 -0400 Subject: [PATCH 26/88] third code review changes --- compress.go | 12 ++++-------- connection.go | 8 ++++++-- driver.go | 2 +- dsn.go | 4 ++-- packets.go | 2 +- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/compress.go b/compress.go index 9339e0eca..74b293a58 100644 --- a/compress.go +++ b/compress.go @@ -10,10 +10,6 @@ const ( minCompressLength = 50 ) -type packetReader interface { - readNext(need int) ([]byte, error) -} - type compressedReader struct { buf packetReader bytesBuf []byte @@ -151,7 +147,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { payloadLen := len(payload) bytesBuf := &bytes.Buffer{} - bytesBuf.Write(blankHeader) + bytesBuf.Write(blankHeader) cw.zw.Reset(bytesBuf) _, err := cw.zw.Write(payload) if err != nil { @@ -162,7 +158,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { // if compression expands the payload, do not compress compressedPayload := bytesBuf.Bytes() if len(compressedPayload) > maxPayloadLength { - compressedPayload = append(blankHeader, payload...) + compressedPayload = append(blankHeader, payload...) payloadLen = 0 } @@ -180,7 +176,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { // do not attempt compression if packet is too small if payloadLen < minCompressLength { - err := cw.writeToNetwork(append(blankHeader, data...), 0) + err := cw.writeToNetwork(append(blankHeader, data...), 0) if err != nil { return 0, err } @@ -199,7 +195,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { compressedPayload := bytesBuf.Bytes() if len(compressedPayload) > len(data) { - compressedPayload = append(blankHeader, data...) + compressedPayload = append(blankHeader, data...) payloadLen = 0 } diff --git a/connection.go b/connection.go index 3a30c46a9..cc802fa42 100644 --- a/connection.go +++ b/connection.go @@ -30,6 +30,8 @@ type mysqlContext interface { type mysqlConn struct { buf buffer netConn net.Conn + reader packetReader + writer io.Writer affectedRows uint64 insertId uint64 cfg *Config @@ -41,8 +43,6 @@ type mysqlConn struct { sequence uint8 compressionSequence uint8 parseTime bool - reader packetReader - writer io.Writer // for context support (Go 1.8+) watching bool @@ -53,6 +53,10 @@ type mysqlConn struct { closed atomicBool // set when conn is closed, before closech is closed } +type packetReader interface { + readNext(need int) ([]byte, error) +} + // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { for param, val := range mc.cfg.Params { diff --git a/driver.go b/driver.go index 86d38f70d..636ea1fb3 100644 --- a/driver.go +++ b/driver.go @@ -123,7 +123,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { return nil, err } - if mc.cfg.Compression { + if mc.cfg.Compress { mc.reader = NewCompressedReader(&mc.buf, mc) mc.writer = NewCompressedWriter(mc.writer, mc) } diff --git a/dsn.go b/dsn.go index 92f137daa..b7e9c5495 100644 --- a/dsn.go +++ b/dsn.go @@ -53,7 +53,7 @@ type Config struct { AllowOldPasswords bool // Allows the old insecure password method ClientFoundRows bool // Return number of matching rows instead of rows changed ColumnsWithAlias bool // Prepend table alias to column names - Compression bool // Compress packets + Compress bool // Compress packets InterpolateParams bool // Interpolate placeholders into query string MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time @@ -464,7 +464,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { // Compression case "compress": var isBool bool - cfg.Compression, isBool = readBool(value) + cfg.Compress, isBool = readBool(value) if !isBool { return errors.New("invalid bool value: " + value) } diff --git a/packets.go b/packets.go index 2e9cb4984..0303c426d 100644 --- a/packets.go +++ b/packets.go @@ -258,7 +258,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientFlags |= clientFoundRows } - if mc.cfg.Compression { + if mc.cfg.Compress { clientFlags |= clientCompress } From 3e12e32d9970baaaa6316d83e9e934e72e4564eb Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Fri, 23 Mar 2018 11:58:56 -0400 Subject: [PATCH 27/88] PR 649: minor cleanup --- compress.go | 4 ++-- compress_test.go | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/compress.go b/compress.go index 74b293a58..719e3625d 100644 --- a/compress.go +++ b/compress.go @@ -97,7 +97,7 @@ func (cr *compressedReader) uncompressPacket() error { defer cr.zr.Close() - //use existing capacity in bytesBuf if possible + // use existing capacity in bytesBuf if possible offset := len(cr.bytesBuf) if cap(cr.bytesBuf)-offset < uncompressedLength { old := cr.bytesBuf @@ -220,7 +220,7 @@ func (cw *compressedWriter) writeToNetwork(data []byte, uncomprLength int) error data[3] = cw.mc.compressionSequence - //this value is never greater than maxPayloadLength + // this value is never greater than maxPayloadLength data[4] = byte(0xff & uncomprLength) data[5] = byte(0xff & (uncomprLength >> 8)) data[6] = byte(0xff & (uncomprLength >> 16)) diff --git a/compress_test.go b/compress_test.go index c626ff3ee..d497ed56d 100644 --- a/compress_test.go +++ b/compress_test.go @@ -208,13 +208,10 @@ func TestRoundtrip(t *testing.T) { for _, test := range tests { s := fmt.Sprintf("Test roundtrip with %s", test.desc) - //t.Run(s, func(t *testing.T) { uncompressed := roundtripHelper(t, cSend, cReceive, test.uncompressed) if bytes.Compare(uncompressed, test.uncompressed) != 0 { t.Fatal(fmt.Sprintf("%s: roundtrip failed", s)) } - - //}) } } From 60bdaec793f67557d38e805047066911e9d8cab5 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Mon, 26 Mar 2018 19:01:55 +0900 Subject: [PATCH 28/88] Sort AUTHORS --- AUTHORS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index 6274cd5e1..291c0f2fd 100644 --- a/AUTHORS +++ b/AUTHORS @@ -64,8 +64,8 @@ Paul Bonser Peter Schultz Rebecca Chin Reed Allman -Runrioter Wung Robert Russell +Runrioter Wung Shuode Li Soroush Pour Stan Putrya From 422ab6f48ea938b28641d9cb23a67e51190f19bb Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Mon, 26 Mar 2018 19:28:02 +0900 Subject: [PATCH 29/88] Update dsn.go --- dsn.go | 1 - 1 file changed, 1 deletion(-) diff --git a/dsn.go b/dsn.go index b7e9c5495..82d15a8fb 100644 --- a/dsn.go +++ b/dsn.go @@ -58,7 +58,6 @@ type Config struct { MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time RejectReadOnly bool // Reject read-only connections - Compression bool // Compress packets } // NewConfig creates a new Config and sets default values. From 1f38652db2fa7b41697c235b1d04d701f1f52632 Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Wed, 4 Oct 2023 12:38:14 +0200 Subject: [PATCH 30/88] Please linter. --- compress_test.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/compress_test.go b/compress_test.go index d497ed56d..f5ad89045 100644 --- a/compress_test.go +++ b/compress_test.go @@ -57,18 +57,18 @@ func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []by } if n != len(uncompressedPacket) { - t.Fatal(fmt.Sprintf("expected to write %d bytes, wrote %d bytes", len(uncompressedPacket), n)) + t.Fatalf("expected to write %d bytes, wrote %d bytes", len(uncompressedPacket), n) } if len(uncompressedPacket) > 0 { if mc.compressionSequence != (cs + 1) { - t.Fatal(fmt.Sprintf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressionSequence)) + t.Fatalf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressionSequence) } } else { if mc.compressionSequence != cs { - t.Fatal(fmt.Sprintf("mc.compressionSequence updated incorrectly for case of empty write, expected %d and saw %d", cs, mc.compressionSequence)) + t.Fatalf("mc.compressionSequence updated incorrectly for case of empty write, expected %d and saw %d", cs, mc.compressionSequence) } } @@ -95,17 +95,17 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expS uncompressedPacket, err := cr.readNext(expSize) if err != nil { if err != io.EOF { - t.Fatal(fmt.Sprintf("non-nil/non-EOF error when reading contents: %s", err.Error())) + t.Fatalf("non-nil/non-EOF error when reading contents: %s", err.Error()) } } if expSize > 0 { if mc.compressionSequence != (cs + 1) { - t.Fatal(fmt.Sprintf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressionSequence)) + t.Fatalf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressionSequence) } } else { if mc.compressionSequence != cs { - t.Fatal(fmt.Sprintf("mc.compressionSequence updated incorrectly for case of empty read, expected %d and saw %d", cs, mc.compressionSequence)) + t.Fatalf("mc.compressionSequence updated incorrectly for case of empty read, expected %d and saw %d", cs, mc.compressionSequence) } } return uncompressedPacket @@ -154,15 +154,15 @@ func TestCompressedReaderThenWriter(t *testing.T) { // test uncompression only c := newMockConn() uncompressed := uncompressHelper(t, c, test.compressed, len(test.uncompressed)) - if bytes.Compare(uncompressed, test.uncompressed) != 0 { - t.Fatal(fmt.Sprintf("%s: uncompression failed", s)) + if !bytes.Equal(uncompressed, test.uncompressed) { + t.Fatalf("%s: uncompression failed", s) } // test compression only c = newMockConn() compressed := compressHelper(t, c, test.uncompressed) - if bytes.Compare(compressed, test.compressed) != 0 { - t.Fatal(fmt.Sprintf("%s: compression failed", s)) + if !bytes.Equal(compressed, test.compressed) { + t.Fatalf("%s: compression failed", s) } } } @@ -210,8 +210,8 @@ func TestRoundtrip(t *testing.T) { s := fmt.Sprintf("Test roundtrip with %s", test.desc) uncompressed := roundtripHelper(t, cSend, cReceive, test.uncompressed) - if bytes.Compare(uncompressed, test.uncompressed) != 0 { - t.Fatal(fmt.Sprintf("%s: roundtrip failed", s)) + if !bytes.Equal(uncompressed, test.uncompressed) { + t.Fatalf("%s: roundtrip failed", s) } } } From d43864e1c01df340f16d4c899a99b7a9e1ae7246 Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Wed, 4 Oct 2023 16:44:20 +0200 Subject: [PATCH 31/88] Formatting. --- benchmark_test.go | 1 - compress_test.go | 2 -- 2 files changed, 3 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 235a8054b..cba2d6783 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -72,7 +72,6 @@ func BenchmarkQueryCompression(b *testing.B) { } func benchmarkQueryHelper(b *testing.B, compr bool) { - tb := (*TB)(b) b.StopTimer() b.ReportAllocs() diff --git a/compress_test.go b/compress_test.go index f5ad89045..ee55ea368 100644 --- a/compress_test.go +++ b/compress_test.go @@ -113,7 +113,6 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expS // TestCompressedReaderThenWriter tests reader and writer seperately. func TestCompressedReaderThenWriter(t *testing.T) { - makeTestUncompressedPacket := func(size int) []byte { uncompressedHeader := make([]byte, 4) uncompressedHeader[0] = byte(size) @@ -169,7 +168,6 @@ func TestCompressedReaderThenWriter(t *testing.T) { // TestRoundtrip tests two connections, where one is reading and the other is writing func TestRoundtrip(t *testing.T) { - tests := []struct { uncompressed []byte desc string From 1c2ac702110d821250f268997211b01bd65411e7 Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Thu, 5 Oct 2023 10:43:13 +0200 Subject: [PATCH 32/88] Unexport constructors. --- compress.go | 4 ++-- compress_test.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/compress.go b/compress.go index 719e3625d..70d1dbb1e 100644 --- a/compress.go +++ b/compress.go @@ -23,7 +23,7 @@ type compressedWriter struct { zw *zlib.Writer } -func NewCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader { +func newCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader { return &compressedReader{ buf: buf, bytesBuf: make([]byte, 0), @@ -31,7 +31,7 @@ func NewCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader { } } -func NewCompressedWriter(connWriter io.Writer, mc *mysqlConn) *compressedWriter { +func newCompressedWriter(connWriter io.Writer, mc *mysqlConn) *compressedWriter { return &compressedWriter{ connWriter: connWriter, mc: mc, diff --git a/compress_test.go b/compress_test.go index ee55ea368..4be704211 100644 --- a/compress_test.go +++ b/compress_test.go @@ -48,7 +48,7 @@ func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []by var b bytes.Buffer connWriter := &b - cw := NewCompressedWriter(connWriter, mc) + cw := newCompressedWriter(connWriter, mc) n, err := cw.Write(uncompressedPacket) @@ -90,7 +90,7 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expS mockConnReader := bytes.NewReader(compressedPacket) mockBuf := newMockBuf(mockConnReader) - cr := NewCompressedReader(mockBuf, mc) + cr := newCompressedReader(mockBuf, mc) uncompressedPacket, err := cr.readNext(expSize) if err != nil { From 944e6380da6788d8f2c58d7be212b1466c15a0f9 Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Thu, 5 Oct 2023 10:44:10 +0200 Subject: [PATCH 33/88] Fix tests. --- connection_test.go | 20 ++++++++++++++------ connector.go | 9 +++++++++ packets_test.go | 9 ++++++++- 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/connection_test.go b/connection_test.go index d39787ec5..59ebb44e0 100644 --- a/connection_test.go +++ b/connection_test.go @@ -163,13 +163,17 @@ func TestCleanCancel(t *testing.T) { func TestPingMarkBadConnection(t *testing.T) { nc := badConnection{err: errors.New("boom")} - ms := &mysqlConn{ + + buf := newBuffer(nc) + mc := &mysqlConn{ netConn: nc, - buf: newBuffer(nc), + buf: buf, + reader: &buf, + writer: nc, maxAllowedPacket: defaultMaxAllowedPacket, } - err := ms.Ping(context.Background()) + err := mc.Ping(context.Background()) if err != driver.ErrBadConn { t.Errorf("expected driver.ErrBadConn, got %#v", err) @@ -178,15 +182,19 @@ func TestPingMarkBadConnection(t *testing.T) { func TestPingErrInvalidConn(t *testing.T) { nc := badConnection{err: errors.New("failed to write"), n: 10} - ms := &mysqlConn{ + + buf := newBuffer(nc) + mc := &mysqlConn{ netConn: nc, - buf: newBuffer(nc), + buf: buf, + reader: &buf, + writer: nc, maxAllowedPacket: defaultMaxAllowedPacket, closech: make(chan struct{}), cfg: NewConfig(), } - err := ms.Ping(context.Background()) + err := mc.Ping(context.Background()) if err != ErrInvalidConn { t.Errorf("expected ErrInvalidConn, got %#v", err) diff --git a/connector.go b/connector.go index 7e0b16734..b59b262b0 100644 --- a/connector.go +++ b/connector.go @@ -114,6 +114,10 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.buf = newBuffer(mc.netConn) + // packet reader and writer in handshake are never compressed + mc.reader = &mc.buf + mc.writer = mc.netConn + // Set I/O timeouts mc.buf.timeout = mc.cfg.ReadTimeout mc.writeTimeout = mc.cfg.WriteTimeout @@ -155,6 +159,11 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } + if mc.cfg.Compress { + mc.reader = newCompressedReader(&mc.buf, mc) + mc.writer = newCompressedWriter(mc.writer, mc) + } + if mc.cfg.MaxAllowedPacket > 0 { mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket } else { diff --git a/packets_test.go b/packets_test.go index 063e1b47a..d934d95ce 100644 --- a/packets_test.go +++ b/packets_test.go @@ -100,8 +100,14 @@ func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { if err != nil { panic(err) } + + buf := newBuffer(conn) + reader := newBuffer(conn) + mc := &mysqlConn{ - buf: newBuffer(conn), + buf: buf, + reader: &reader, + writer: conn, cfg: connector.cfg, connector: connector, netConn: conn, @@ -326,6 +332,7 @@ func TestRegression801(t *testing.T) { sequence: 42, closech: make(chan struct{}), } + mc.reader = &mc.buf conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0, 60, 70, 63, 58, 68, 104, 34, 97, 0, 223, 247, 33, 2, 0, 15, 128, 21, 0, From 15017fcbd35610e1a15a755b2c2ec7165beeaedf Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Thu, 5 Oct 2023 12:38:04 +0200 Subject: [PATCH 34/88] Update AUTHORS. --- AUTHORS | 1 + 1 file changed, 1 insertion(+) diff --git a/AUTHORS b/AUTHORS index ce247fb09..bae1cddee 100644 --- a/AUTHORS +++ b/AUTHORS @@ -56,6 +56,7 @@ Jeffrey Charles Jerome Meyer Jiajia Zhong Jian Zhen +Joe Mann Joshua Prunier Julien Lefevre Julien Schmidt From f400590e33e597e0ff9c64f3bfe2262353c1d232 Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Thu, 5 Oct 2023 12:39:13 +0200 Subject: [PATCH 35/88] Formatting. --- compress_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/compress_test.go b/compress_test.go index 4be704211..8812cd39d 100644 --- a/compress_test.go +++ b/compress_test.go @@ -61,7 +61,6 @@ func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []by } if len(uncompressedPacket) > 0 { - if mc.compressionSequence != (cs + 1) { t.Fatalf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressionSequence) } From 084dafb87343f46b488baad3af318b274b24ab72 Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Thu, 5 Oct 2023 12:46:23 +0200 Subject: [PATCH 36/88] Update README feature list. --- README.md | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 9257c1fd2..7c6547067 100644 --- a/README.md +++ b/README.md @@ -27,17 +27,18 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac --------------------------------------- ## Features - * Lightweight and [fast](https://github.com/go-sql-driver/sql-benchmark "golang MySQL-Driver performance") - * Native Go implementation. No C-bindings, just pure Go - * Connections over TCP/IPv4, TCP/IPv6, Unix domain sockets or [custom protocols](https://godoc.org/github.com/go-sql-driver/mysql#DialFunc) - * Automatic handling of broken connections - * Automatic Connection Pooling *(by database/sql package)* - * Supports queries larger than 16MB + * Lightweight and [fast](https://github.com/go-sql-driver/sql-benchmark "golang MySQL-Driver performance"). + * Native Go implementation. No C-bindings, just pure Go. + * Connections over TCP/IPv4, TCP/IPv6, Unix domain sockets or [custom protocols](https://godoc.org/github.com/go-sql-driver/mysql#DialFunc). + * Automatic handling of broken connections. + * Automatic Connection Pooling *(by database/sql package)*. + * Supports queries larger than 16MB. * Full [`sql.RawBytes`](https://golang.org/pkg/database/sql/#RawBytes) support. - * Intelligent `LONG DATA` handling in prepared statements - * Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support - * Optional `time.Time` parsing - * Optional placeholder interpolation + * Intelligent `LONG DATA` handling in prepared statements. + * Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support. + * Optional `time.Time` parsing. + * Optional placeholder interpolation. + * Supports zlib compression. ## Requirements * Go 1.18 or higher. We aim to support the 3 latest versions of Go. From ee87a7da283d2285773aa579d0930bf8918f6c02 Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Fri, 6 Oct 2023 13:53:52 +0200 Subject: [PATCH 37/88] Fix TLS. --- packets.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/packets.go b/packets.go index 8ebbde16d..268065bb5 100644 --- a/packets.go +++ b/packets.go @@ -387,6 +387,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string mc.rawConn = mc.netConn mc.netConn = tlsConn mc.buf.nc = tlsConn + + mc.writer = mc.netConn } // User [null terminated string] From b8cfe77798954184a5cac850442db9e949f0962a Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Wed, 11 Oct 2023 15:59:58 +0200 Subject: [PATCH 38/88] Formatting. --- compress.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/compress.go b/compress.go index 70d1dbb1e..55b5bfc28 100644 --- a/compress.go +++ b/compress.go @@ -54,7 +54,6 @@ func (cr *compressedReader) readNext(need int) ([]byte, error) { func (cr *compressedReader) uncompressPacket() error { header, err := cr.buf.readNext(7) // size of compressed header - if err != nil { return err } @@ -210,7 +209,6 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { } func (cw *compressedWriter) writeToNetwork(data []byte, uncomprLength int) error { - comprLength := len(data) - 7 // compression header From d2501ec9cf2c338df12ad880cb11b0cd19244acf Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Wed, 11 Oct 2023 16:03:30 +0200 Subject: [PATCH 39/88] Tidy up. --- compress.go | 96 +++++++++++++++++++++++------------------------------ 1 file changed, 42 insertions(+), 54 deletions(-) diff --git a/compress.go b/compress.go index 55b5bfc28..efaadd961 100644 --- a/compress.go +++ b/compress.go @@ -39,21 +39,20 @@ func newCompressedWriter(connWriter io.Writer, mc *mysqlConn) *compressedWriter } } -func (cr *compressedReader) readNext(need int) ([]byte, error) { - for len(cr.bytesBuf) < need { - err := cr.uncompressPacket() - if err != nil { +func (r *compressedReader) readNext(need int) ([]byte, error) { + for len(r.bytesBuf) < need { + if err := r.uncompressPacket(); err != nil { return nil, err } } - data := cr.bytesBuf[:need] - cr.bytesBuf = cr.bytesBuf[need:] + data := r.bytesBuf[:need] + r.bytesBuf = r.bytesBuf[need:] return data, nil } -func (cr *compressedReader) uncompressPacket() error { - header, err := cr.buf.readNext(7) // size of compressed header +func (r *compressedReader) uncompressPacket() error { + header, err := r.buf.readNext(7) // size of compressed header if err != nil { return err } @@ -62,14 +61,13 @@ func (cr *compressedReader) uncompressPacket() error { comprLength := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) uncompressedLength := int(uint32(header[4]) | uint32(header[5])<<8 | uint32(header[6])<<16) compressionSequence := uint8(header[3]) - - if compressionSequence != cr.mc.compressionSequence { + if compressionSequence != r.mc.compressionSequence { return ErrPktSync } - cr.mc.compressionSequence++ + r.mc.compressionSequence++ - comprData, err := cr.buf.readNext(comprLength) + comprData, err := r.buf.readNext(comprLength) if err != nil { return err } @@ -77,40 +75,38 @@ func (cr *compressedReader) uncompressPacket() error { // if payload is uncompressed, its length will be specified as zero, and its // true length is contained in comprLength if uncompressedLength == 0 { - cr.bytesBuf = append(cr.bytesBuf, comprData...) + r.bytesBuf = append(r.bytesBuf, comprData...) return nil } // write comprData to a bytes.buffer, then read it using zlib into data br := bytes.NewReader(comprData) - - if cr.zr == nil { - cr.zr, err = zlib.NewReader(br) + if r.zr == nil { + if r.zr, err = zlib.NewReader(br); err != nil { + return err + } } else { - err = cr.zr.(zlib.Resetter).Reset(br, nil) - } - - if err != nil { - return err + if err = r.zr.(zlib.Resetter).Reset(br, nil); err != nil { + return err + } } - - defer cr.zr.Close() + defer r.zr.Close() // use existing capacity in bytesBuf if possible - offset := len(cr.bytesBuf) - if cap(cr.bytesBuf)-offset < uncompressedLength { - old := cr.bytesBuf - cr.bytesBuf = make([]byte, offset, offset+uncompressedLength) - copy(cr.bytesBuf, old) + offset := len(r.bytesBuf) + if cap(r.bytesBuf)-offset < uncompressedLength { + old := r.bytesBuf + r.bytesBuf = make([]byte, offset, offset+uncompressedLength) + copy(r.bytesBuf, old) } - data := cr.bytesBuf[offset : offset+uncompressedLength] + data := r.bytesBuf[offset : offset+uncompressedLength] lenRead := 0 // http://grokbase.com/t/gg/golang-nuts/146y9ppn6b/go-nuts-stream-compression-with-compress-flate for lenRead < uncompressedLength { - n, err := cr.zr.Read(data[lenRead:]) + n, err := r.zr.Read(data[lenRead:]) lenRead += n if err == io.EOF { @@ -118,19 +114,17 @@ func (cr *compressedReader) uncompressPacket() error { return io.ErrUnexpectedEOF } break - } - - if err != nil { + } else if err != nil { return err } } - cr.bytesBuf = append(cr.bytesBuf, data...) + r.bytesBuf = append(r.bytesBuf, data...) return nil } -func (cw *compressedWriter) Write(data []byte) (int, error) { +func (w *compressedWriter) Write(data []byte) (int, error) { // when asked to write an empty packet, do nothing if len(data) == 0 { return 0, nil @@ -147,12 +141,11 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { bytesBuf := &bytes.Buffer{} bytesBuf.Write(blankHeader) - cw.zw.Reset(bytesBuf) - _, err := cw.zw.Write(payload) - if err != nil { + w.zw.Reset(bytesBuf) + if _, err := w.zw.Write(payload); err != nil { return 0, err } - cw.zw.Close() + w.zw.Close() // if compression expands the payload, do not compress compressedPayload := bytesBuf.Bytes() @@ -161,9 +154,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { payloadLen = 0 } - err = cw.writeToNetwork(compressedPayload, payloadLen) - - if err != nil { + if err := w.writeToNetwork(compressedPayload, payloadLen); err != nil { return 0, err } @@ -175,8 +166,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { // do not attempt compression if packet is too small if payloadLen < minCompressLength { - err := cw.writeToNetwork(append(blankHeader, data...), 0) - if err != nil { + if err := w.writeToNetwork(append(blankHeader, data...), 0); err != nil { return 0, err } return totalBytes, nil @@ -184,12 +174,11 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { bytesBuf := &bytes.Buffer{} bytesBuf.Write(blankHeader) - cw.zw.Reset(bytesBuf) - _, err := cw.zw.Write(data) - if err != nil { + w.zw.Reset(bytesBuf) + if _, err := w.zw.Write(data); err != nil { return 0, err } - cw.zw.Close() + w.zw.Close() compressedPayload := bytesBuf.Bytes() @@ -199,8 +188,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { } // add header and send over the wire - err = cw.writeToNetwork(compressedPayload, payloadLen) - if err != nil { + if err := w.writeToNetwork(compressedPayload, payloadLen); err != nil { return 0, err } @@ -208,7 +196,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { } -func (cw *compressedWriter) writeToNetwork(data []byte, uncomprLength int) error { +func (w *compressedWriter) writeToNetwork(data []byte, uncomprLength int) error { comprLength := len(data) - 7 // compression header @@ -216,17 +204,17 @@ func (cw *compressedWriter) writeToNetwork(data []byte, uncomprLength int) error data[1] = byte(0xff & (comprLength >> 8)) data[2] = byte(0xff & (comprLength >> 16)) - data[3] = cw.mc.compressionSequence + data[3] = w.mc.compressionSequence // this value is never greater than maxPayloadLength data[4] = byte(0xff & uncomprLength) data[5] = byte(0xff & (uncomprLength >> 8)) data[6] = byte(0xff & (uncomprLength >> 16)) - if _, err := cw.connWriter.Write(data); err != nil { + if _, err := w.connWriter.Write(data); err != nil { return err } - cw.mc.compressionSequence++ + w.mc.compressionSequence++ return nil } From 59c3cf1d3512df774a1d8130cd92978922818654 Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Wed, 11 Oct 2023 17:07:06 +0200 Subject: [PATCH 40/88] Fix compression negotiations. --- connector.go | 3 ++- dsn.go | 37 ++++++++++++++++++++++++++++++++----- errors.go | 1 + packets.go | 12 +++++++++++- 4 files changed, 46 insertions(+), 7 deletions(-) diff --git a/connector.go b/connector.go index b59b262b0..668aa1711 100644 --- a/connector.go +++ b/connector.go @@ -145,6 +145,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } } + if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { mc.cleanup() return nil, err @@ -159,7 +160,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } - if mc.cfg.Compress { + if mc.cfg.Compress != CompressionModeDisabled { mc.reader = newCompressedReader(&mc.buf, mc) mc.writer = newCompressedWriter(mc.writer, mc) } diff --git a/dsn.go b/dsn.go index ba42721f1..69422e5b1 100644 --- a/dsn.go +++ b/dsn.go @@ -61,13 +61,22 @@ type Config struct { CheckConnLiveness bool // Check connections for liveness before using them ClientFoundRows bool // Return number of matching rows instead of rows changed ColumnsWithAlias bool // Prepend table alias to column names - Compress bool // Compress packets InterpolateParams bool // Interpolate placeholders into query string MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time RejectReadOnly bool // Reject read-only connections + + Compress CompressionMode // Compress packets } +type CompressionMode string + +const ( + CompressionModeDisabled CompressionMode = "disabled" + CompressionModePreferred CompressionMode = "preferred" + CompressionModeRequired CompressionMode = "required" +) + // NewConfig creates a new Config and sets default values. func NewConfig() *Config { return &Config{ @@ -247,6 +256,10 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true") } + if cfg.Compress != CompressionModeDisabled { + writeDSNParam(&buf, &hasParam, "compress", url.QueryEscape(string(cfg.Compress))) + } + if cfg.InterpolateParams { writeDSNParam(&buf, &hasParam, "interpolateParams", "true") } @@ -467,10 +480,24 @@ func parseDSNParams(cfg *Config, params string) (err error) { // Compression case "compress": - var isBool bool - cfg.Compress, isBool = readBool(value) - if !isBool { - return errors.New("invalid bool value: " + value) + boolValue, isBool := readBool(value) + if isBool { + if boolValue { + cfg.Compress = CompressionModePreferred + } else { + cfg.Compress = CompressionModeDisabled + } + } else { + switch strings.ToLower(value) { + case string(CompressionModeDisabled): + cfg.Compress = CompressionModeDisabled + case string(CompressionModePreferred): + cfg.Compress = CompressionModePreferred + case string(CompressionModeRequired): + cfg.Compress = CompressionModeRequired + default: + return fmt.Errorf("invalid value for compression mode") + } } // Enable client side placeholder substitution diff --git a/errors.go b/errors.go index a9a3060c9..840d146a6 100644 --- a/errors.go +++ b/errors.go @@ -20,6 +20,7 @@ var ( ErrInvalidConn = errors.New("invalid connection") ErrMalformPkt = errors.New("malformed packet") ErrNoTLS = errors.New("TLS requested but server does not support TLS") + ErrNoCompression = errors.New("compression requested but server does not support compression") ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN") ErrNativePassword = errors.New("this user requires mysql native password authentication") ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords") diff --git a/packets.go b/packets.go index 268065bb5..f101a1f35 100644 --- a/packets.go +++ b/packets.go @@ -223,6 +223,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro if mc.flags&clientProtocol41 == 0 { return nil, "", ErrOldProtocol } + if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil { if mc.cfg.AllowFallbackToPlaintext { mc.cfg.TLS = nil @@ -230,6 +231,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro return nil, "", ErrNoTLS } } + + if mc.flags&clientCompress == 0 { + if mc.cfg.Compress != CompressionModeRequired { + mc.cfg.Compress = CompressionModeDisabled + } else { + return nil, "", ErrNoCompression + } + } + pos += 2 if len(data) > pos { @@ -293,7 +303,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string clientFlags |= clientFoundRows } - if mc.cfg.Compress { + if mc.cfg.Compress != CompressionModeDisabled { clientFlags |= clientCompress } From d1aef08c8323ffb7568ce6c4c3f986b4ed51df3f Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Thu, 12 Oct 2023 15:58:39 +0200 Subject: [PATCH 41/88] Format README. --- README.md | 176 +++++++++++++++++++++++++++++++----------------------- 1 file changed, 102 insertions(+), 74 deletions(-) diff --git a/README.md b/README.md index 7c6547067..565995c25 100644 --- a/README.md +++ b/README.md @@ -4,59 +4,66 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac ![Go-MySQL-Driver logo](https://raw.github.com/wiki/go-sql-driver/mysql/gomysql_m.png "Golang Gopher holding the MySQL Dolphin") ---------------------------------------- - * [Features](#features) - * [Requirements](#requirements) - * [Installation](#installation) - * [Usage](#usage) - * [DSN (Data Source Name)](#dsn-data-source-name) - * [Password](#password) - * [Protocol](#protocol) - * [Address](#address) - * [Parameters](#parameters) - * [Examples](#examples) - * [Connection pool and timeouts](#connection-pool-and-timeouts) - * [context.Context Support](#contextcontext-support) - * [ColumnType Support](#columntype-support) - * [LOAD DATA LOCAL INFILE support](#load-data-local-infile-support) - * [time.Time support](#timetime-support) - * [Unicode support](#unicode-support) - * [Testing / Development](#testing--development) - * [License](#license) - ---------------------------------------- +--- + +- [Features](#features) +- [Requirements](#requirements) +- [Installation](#installation) +- [Usage](#usage) + - [DSN (Data Source Name)](#dsn-data-source-name) + - [Password](#password) + - [Protocol](#protocol) + - [Address](#address) + - [Parameters](#parameters) + - [Examples](#examples) + - [Connection pool and timeouts](#connection-pool-and-timeouts) + - [context.Context Support](#contextcontext-support) + - [ColumnType Support](#columntype-support) + - [LOAD DATA LOCAL INFILE support](#load-data-local-infile-support) + - [time.Time support](#timetime-support) + - [Unicode support](#unicode-support) +- [Testing / Development](#testing--development) +- [License](#license) + +--- ## Features - * Lightweight and [fast](https://github.com/go-sql-driver/sql-benchmark "golang MySQL-Driver performance"). - * Native Go implementation. No C-bindings, just pure Go. - * Connections over TCP/IPv4, TCP/IPv6, Unix domain sockets or [custom protocols](https://godoc.org/github.com/go-sql-driver/mysql#DialFunc). - * Automatic handling of broken connections. - * Automatic Connection Pooling *(by database/sql package)*. - * Supports queries larger than 16MB. - * Full [`sql.RawBytes`](https://golang.org/pkg/database/sql/#RawBytes) support. - * Intelligent `LONG DATA` handling in prepared statements. - * Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support. - * Optional `time.Time` parsing. - * Optional placeholder interpolation. - * Supports zlib compression. + +- Lightweight and [fast](https://github.com/go-sql-driver/sql-benchmark "golang MySQL-Driver performance"). +- Native Go implementation. No C-bindings, just pure Go. +- Connections over TCP/IPv4, TCP/IPv6, Unix domain sockets or [custom protocols](https://godoc.org/github.com/go-sql-driver/mysql#DialFunc). +- Automatic handling of broken connections. +- Automatic Connection Pooling _(by database/sql package)_. +- Supports queries larger than 16MB. +- Full [`sql.RawBytes`](https://golang.org/pkg/database/sql/#RawBytes) support. +- Intelligent `LONG DATA` handling in prepared statements. +- Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support. +- Optional `time.Time` parsing. +- Optional placeholder interpolation. +- Supports zlib compression. ## Requirements - * Go 1.18 or higher. We aim to support the 3 latest versions of Go. - * MySQL (5.6+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+) ---------------------------------------- +- Go 1.18 or higher. We aim to support the 3 latest versions of Go. +- MySQL (5.6+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+) + +--- ## Installation + Simple install the package to your [$GOPATH](https://github.com/golang/go/wiki/GOPATH "GOPATH") with the [go tool](https://golang.org/cmd/go/ "go command") from shell: + ```bash $ go get -u github.com/go-sql-driver/mysql ``` + Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`. ## Usage + _Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. You only need to import the driver and can use the full [`database/sql`](https://golang.org/pkg/database/sql/) API then. -Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`: +Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`: ```go import ( @@ -88,29 +95,34 @@ db.SetMaxIdleConns(10) `db.SetMaxIdleConns()` is recommended to be set same to `db.SetMaxOpenConns()`. When it is smaller than `SetMaxOpenConns()`, connections can be opened and closed much more frequently than you expect. Idle connections can be closed by the `db.SetConnMaxLifetime()`. If you want to close idle connections more rapidly, you can use `db.SetConnMaxIdleTime()` since Go 1.15. - ### DSN (Data Source Name) The Data Source Name has a common format, like e.g. [PEAR DB](http://pear.php.net/manual/en/package.database.db.intro-dsn.php) uses it, but without type-prefix (optional parts marked by squared brackets): + ``` [username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] ``` A DSN in its fullest form: + ``` username:password@protocol(address)/dbname?param=value ``` Except for the databasename, all values are optional. So the minimal DSN is: + ``` /dbname ``` If you do not want to preselect a database, leave `dbname` empty: + ``` / ``` + This has the same effect as an empty DSN string: + ``` ``` @@ -124,13 +136,16 @@ This has the same effect as an empty DSN string: Alternatively, [Config.FormatDSN](https://godoc.org/github.com/go-sql-driver/mysql#Config.FormatDSN) can be used to create a DSN string by filling a struct. #### Password + Passwords can consist of any character. Escaping is **not** necessary. #### Protocol + See [net.Dial](https://golang.org/pkg/net/#Dial) for more information which networks are available. In general you should use an Unix domain socket if available and TCP otherwise for best performance. #### Address + For TCP and UDP networks, addresses have the form `host[:port]`. If `port` is omitted, the default port will be used. If `host` is a literal IPv6 address, it must be enclosed in square brackets. @@ -139,7 +154,8 @@ The functions [net.JoinHostPort](https://golang.org/pkg/net/#JoinHostPort) and [ For Unix domain sockets the address is the absolute path to the MySQL-Server-socket, e.g. `/var/run/mysqld/mysqld.sock` or `/tmp/mysql.sock`. #### Parameters -*Parameters are case-sensitive!* + +_Parameters are case-sensitive!_ Notice that any of `true`, `TRUE`, `True` or `1` is accepted to stand for a true boolean value. Not surprisingly, false can be specified as any of: `false`, `FALSE`, `False` or `0`. @@ -151,8 +167,8 @@ Valid Values: true, false Default: false ``` -`allowAllFiles=true` disables the file allowlist for `LOAD DATA LOCAL INFILE` and allows *all* files. -[*Might be insecure!*](https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-local) +`allowAllFiles=true` disables the file allowlist for `LOAD DATA LOCAL INFILE` and allows _all_ files. +[_Might be insecure!_](https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-local) ##### `allowCleartextPasswords` @@ -164,7 +180,6 @@ Default: false `allowCleartextPasswords=true` allows using the [cleartext client side plugin](https://dev.mysql.com/doc/en/cleartext-pluggable-authentication.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network. - ##### `allowFallbackToPlaintext` ``` @@ -182,6 +197,7 @@ Type: bool Valid Values: true, false Default: true ``` + `allowNativePasswords=false` disallows the usage of MySQL native password method. ##### `allowOldPasswords` @@ -191,6 +207,7 @@ Type: bool Valid Values: true, false Default: false ``` + `allowOldPasswords=true` allows the usage of the insecure old password method. This should be avoided, but is necessary in some cases. See also [the old_passwords wiki page](https://github.com/go-sql-driver/mysql/wiki/old_passwords). ##### `charset` @@ -228,7 +245,7 @@ Sets the collation used for client-server interaction on connection. In contrast A list of valid charsets for a server is retrievable with `SHOW COLLATION`. -The default collation (`utf8mb4_general_ci`) is supported from MySQL 5.5. You should use an older collation (e.g. `utf8_general_ci`) for older MySQL. +The default collation (`utf8mb4_general_ci`) is supported from MySQL 5.5. You should use an older collation (e.g. `utf8_general_ci`) for older MySQL. Collations for charset "ucs2", "utf16", "utf16le", and "utf32" can not be used ([ref](https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset)). @@ -270,7 +287,7 @@ Default: false If `interpolateParams` is true, placeholders (`?`) in calls to `db.Query()` and `db.Exec()` are interpolated into a single query string with given parameters. This reduces the number of roundtrips, since the driver has to prepare a statement, execute it with given parameters and close the statement again with `interpolateParams=false`. -*This can not be used together with the multibyte encodings BIG5, CP932, GB2312, GBK or SJIS. These are rejected as they may [introduce a SQL injection vulnerability](http://stackoverflow.com/a/12118602/3430118)!* +_This can not be used together with the multibyte encodings BIG5, CP932, GB2312, GBK or SJIS. These are rejected as they may [introduce a SQL injection vulnerability](http://stackoverflow.com/a/12118602/3430118)!_ ##### `loc` @@ -280,19 +297,20 @@ Valid Values: Default: UTC ``` -Sets the location for time.Time values (when using `parseTime=true`). *"Local"* sets the system's location. See [time.LoadLocation](https://golang.org/pkg/time/#LoadLocation) for details. +Sets the location for time.Time values (when using `parseTime=true`). _"Local"_ sets the system's location. See [time.LoadLocation](https://golang.org/pkg/time/#LoadLocation) for details. Note that this sets the location for time.Time values but does not change MySQL's [time_zone setting](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html). For that see the [time_zone system variable](#system-variables), which can also be set as a DSN parameter. Please keep in mind, that param values must be [url.QueryEscape](https://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`. ##### `maxAllowedPacket` + ``` Type: decimal number Default: 64*1024*1024 ``` -Max packet size allowed in bytes. The default value is 64 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*. +Max packet size allowed in bytes. The default value is 64 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server _on every connection_. ##### `multiStatements` @@ -333,7 +351,6 @@ Default: false `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string` The date or datetime like `0000-00-00 00:00:00` is converted into zero value of `time.Time`. - ##### `readTimeout` ``` @@ -341,7 +358,7 @@ Type: duration Default: 0 ``` -I/O read timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. +I/O read timeout. The value must be a decimal number with a unit suffix (_"ms"_, _"s"_, _"m"_, _"h"_), such as _"30s"_, _"0.5m"_ or _"1m30s"_. ##### `rejectReadOnly` @@ -351,7 +368,6 @@ Valid Values: true, false Default: false ``` - `rejectReadOnly=true` causes the driver to reject read-only connections. This is for a possible race condition during an automatic failover, where the mysql client gets connected to a read-only replica after the failover. @@ -372,7 +388,6 @@ cause a retry for that error. However the same error number is used for some other cases. You should ensure your application will never cause an ERROR 1290 except for `read-only` mode when enabling this option. - ##### `serverPubKey` ``` @@ -385,7 +400,6 @@ Server public keys can be registered with [`mysql.RegisterServerPubKey`](https:/ Public keys are used to transmit encrypted data, e.g. for authentication. If the server's public key is known, it should be set manually to avoid expensive and potentially insecure transmissions of the public key from the server to the client each time it is required. - ##### `timeout` ``` @@ -393,8 +407,7 @@ Type: duration Default: OS default ``` -Timeout for establishing connections, aka dial timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. - +Timeout for establishing connections, aka dial timeout. The value must be a decimal number with a unit suffix (_"ms"_, _"s"_, _"m"_, _"h"_), such as _"30s"_, _"0.5m"_ or _"1m30s"_. ##### `tls` @@ -406,7 +419,6 @@ Default: false `tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side) or use `preferred` to use TLS only when advertised by the server. This is similar to `skip-verify`, but additionally allows a fallback to a connection which is not encrypted. Neither `skip-verify` nor `preferred` add any reliable security. You can use a custom TLS config after registering it with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig). - ##### `writeTimeout` ``` @@ -414,7 +426,7 @@ Type: duration Default: 0 ``` -I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. +I/O write timeout. The value must be a decimal number with a unit suffix (_"ms"_, _"s"_, _"m"_, _"h"_), such as _"30s"_, _"0.5m"_ or _"1m30s"_. ##### `connectionAttributes` @@ -429,22 +441,25 @@ Default: none ##### System Variables Any other parameters are interpreted as system variables: - * `=`: `SET =` - * `=`: `SET =` - * `=%27%27`: `SET =''` + +- `=`: `SET =` +- `=`: `SET =` +- `=%27%27`: `SET =''` Rules: -* The values for string variables must be quoted with `'`. -* The values must also be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed! - (which implies values of string variables must be wrapped with `%27`). + +- The values for string variables must be quoted with `'`. +- The values must also be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed! + (which implies values of string variables must be wrapped with `%27`). Examples: - * `autocommit=1`: `SET autocommit=1` - * [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'` - * [`transaction_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_transaction_isolation): `SET transaction_isolation='REPEATABLE-READ'` +- `autocommit=1`: `SET autocommit=1` +- [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'` +- [`transaction_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_transaction_isolation): `SET transaction_isolation='REPEATABLE-READ'` #### Examples + ``` user@unix(/path/to/socket)/dbname ``` @@ -458,74 +473,84 @@ user:password@tcp(localhost:5555)/dbname?tls=skip-verify&autocommit=true ``` Treat warnings as errors by setting the system variable [`sql_mode`](https://dev.mysql.com/doc/refman/5.7/en/sql-mode.html): + ``` user:password@/dbname?sql_mode=TRADITIONAL ``` TCP via IPv6: + ``` user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?timeout=90s&collation=utf8mb4_unicode_ci ``` TCP on a remote host, e.g. Amazon RDS: + ``` id:password@tcp(your-amazonaws-uri.com:3306)/dbname ``` Google Cloud SQL on App Engine: + ``` user:password@unix(/cloudsql/project-id:region-name:instance-name)/dbname ``` TCP using default port (3306) on localhost: + ``` user:password@tcp/dbname?charset=utf8mb4,utf8&sys_var=esc%40ped ``` Use the default protocol (tcp) and host (localhost:3306): + ``` user:password@/dbname ``` No Database preselected: + ``` user:password@/ ``` - ### Connection pool and timeouts + The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively. ## `ColumnType` Support + This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. All Unsigned database type names will be returned `UNSIGNED ` with `INT`, `TINYINT`, `SMALLINT`, `MEDIUMINT`, `BIGINT`. ## `context.Context` Support + Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details. - ### `LOAD DATA LOCAL INFILE` support + For this feature you need direct access to the package. Therefore you must change the import path (no `_`): + ```go import "github.com/go-sql-driver/mysql" ``` -Files must be explicitly allowed by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the allowlist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-local)). +Files must be explicitly allowed by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the allowlist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([_Might be insecure!_](https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-local)). To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::` then. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore. See the [godoc of Go-MySQL-Driver](https://godoc.org/github.com/go-sql-driver/mysql "golang mysql driver documentation") for details. - ### `time.Time` support + The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your program. However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical equivalent in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](https://golang.org/pkg/time/#Location) with the `loc` DSN parameter. **Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes). - ### Unicode support + Since version 1.5 Go-MySQL-Driver automatically uses the collation ` utf8mb4_general_ci` by default. Other charsets / collations can be set using the [`charset`](#charset) or [`collation`](#collation) DSN parameter. @@ -537,6 +562,7 @@ Other charsets / collations can be set using the [`charset`](#charset) or [`coll See http://dev.mysql.com/doc/refman/8.0/en/charset-unicode.html for more details on MySQL's Unicode support. ## Testing / Development + To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details. Go-MySQL-Driver is not feature-complete yet. Your help is very appreciated. @@ -544,19 +570,21 @@ If you want to contribute, you can work on an [open issue](https://github.com/go See the [Contribution Guidelines](https://github.com/go-sql-driver/mysql/blob/master/.github/CONTRIBUTING.md) for details. ---------------------------------------- +--- ## License + Go-MySQL-Driver is licensed under the [Mozilla Public License Version 2.0](https://raw.github.com/go-sql-driver/mysql/master/LICENSE) Mozilla summarizes the license scope as follows: -> MPL: The copyleft applies to any files containing MPLed code. +> MPL: The copyleft applies to any files containing MPLed code. That means: - * You can **use** the **unchanged** source code both in private and commercially. - * When distributing, you **must publish** the source code of any **changed files** licensed under the MPL 2.0 under a) the MPL 2.0 itself or b) a compatible license (e.g. GPL 3.0 or Apache License 2.0). - * You **needn't publish** the source code of your library as long as the files licensed under the MPL 2.0 are **unchanged**. + +- You can **use** the **unchanged** source code both in private and commercially. +- When distributing, you **must publish** the source code of any **changed files** licensed under the MPL 2.0 under a) the MPL 2.0 itself or b) a compatible license (e.g. GPL 3.0 or Apache License 2.0). +- You **needn't publish** the source code of your library as long as the files licensed under the MPL 2.0 are **unchanged**. Please read the [MPL 2.0 FAQ](https://www.mozilla.org/en-US/MPL/2.0/FAQ/) if you have further questions regarding the license. From 09a4fb8621891f5bb5e7936a0fe010bb90da6424 Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Thu, 12 Oct 2023 16:06:10 +0200 Subject: [PATCH 42/88] Add usage instructions to README. --- README.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/README.md b/README.md index 565995c25..d06f0ea7d 100644 --- a/README.md +++ b/README.md @@ -277,6 +277,16 @@ SELECT u.id FROM users as u will return `u.id` instead of just `id` if `columnsWithAlias=true`. +##### `compress` + +``` +Type: string +Valid Values: disabled, preferred, required +Default: disabled +``` + +Toggles zlib compression. `compress=disabled` is the default value and disables compression even if offered by the server. `compress=preferred` uses compression if offered by the server, and `compress=required` will cause connection to fail if not offered by the server. In both of these cases, compression is also controlled by the `minCompressLength` parameter. + ##### `interpolateParams` ``` @@ -312,6 +322,15 @@ Default: 64*1024*1024 Max packet size allowed in bytes. The default value is 64 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server _on every connection_. +##### `minCompressLength` + +``` +Type: decimal number +Default: 50 +``` + +Min packet size in bytes to compress, when compression is enabled (see the `compress` parameter). Packets smaller than this will be sent uncompressed. + ##### `multiStatements` ``` From e523af28200742ac0da1e40e798754a3715b134d Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Thu, 12 Oct 2023 16:06:33 +0200 Subject: [PATCH 43/88] Add minCompressLength param. --- const.go | 11 ++++++----- dsn.go | 13 ++++++++++++- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/const.go b/const.go index 0f2621a6f..ed828cfe2 100644 --- a/const.go +++ b/const.go @@ -11,11 +11,12 @@ package mysql import "runtime" const ( - defaultAuthPlugin = "mysql_native_password" - defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355 - minProtocolVersion = 10 - maxPacketSize = 1<<24 - 1 - timeFormat = "2006-01-02 15:04:05.999999" + defaultAuthPlugin = "mysql_native_password" + defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355 + minProtocolVersion = 10 + maxPacketSize = 1<<24 - 1 + timeFormat = "2006-01-02 15:04:05.999999" + defaultMinCompressLength = 50 // Connection attributes // See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available diff --git a/dsn.go b/dsn.go index 69422e5b1..6a8c23e91 100644 --- a/dsn.go +++ b/dsn.go @@ -66,7 +66,8 @@ type Config struct { ParseTime bool // Parse time values to time.Time RejectReadOnly bool // Reject read-only connections - Compress CompressionMode // Compress packets + Compress CompressionMode // Compress packets + MinCompressLength int // Don't compress packets less than this number of bytes } type CompressionMode string @@ -85,6 +86,7 @@ func NewConfig() *Config { Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, + MinCompressLength: defaultMinCompressLength, } } @@ -260,6 +262,10 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "compress", url.QueryEscape(string(cfg.Compress))) } + if cfg.MinCompressLength != defaultMinCompressLength { + writeDSNParam(&buf, &hasParam, "minCompressLength", strconv.Itoa(cfg.MinCompressLength)) + } + if cfg.InterpolateParams { writeDSNParam(&buf, &hasParam, "interpolateParams", "true") } @@ -499,6 +505,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { return fmt.Errorf("invalid value for compression mode") } } + case "minCompressLength": + cfg.MinCompressLength, err = strconv.Atoi(value) + if err != nil { + return + } // Enable client side placeholder substitution case "interpolateParams": From efbc53b3a7265fb3f1a4f8f7bec5df930b1e91ad Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Thu, 12 Oct 2023 16:07:15 +0200 Subject: [PATCH 44/88] Fix non-compression of small packets. --- compress.go | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/compress.go b/compress.go index efaadd961..dfe6fcd2a 100644 --- a/compress.go +++ b/compress.go @@ -137,24 +137,25 @@ func (w *compressedWriter) Write(data []byte) (int, error) { for length >= maxPayloadLength { payload := data[:maxPayloadLength] - payloadLen := len(payload) + uncompressedLen := len(payload) - bytesBuf := &bytes.Buffer{} - bytesBuf.Write(blankHeader) - w.zw.Reset(bytesBuf) - if _, err := w.zw.Write(payload); err != nil { - return 0, err - } - w.zw.Close() + buf := bytes.NewBuffer(blankHeader) - // if compression expands the payload, do not compress - compressedPayload := bytesBuf.Bytes() - if len(compressedPayload) > maxPayloadLength { - compressedPayload = append(blankHeader, payload...) - payloadLen = 0 + // If payload is less than minCompressLength, don't compress. + if uncompressedLen < w.mc.cfg.MinCompressLength { + if _, err := buf.Write(payload); err != nil { + return 0, err + } + uncompressedLen = 0 + } else { + w.zw.Reset(buf) + if _, err := w.zw.Write(payload); err != nil { + return 0, err + } + w.zw.Close() } - if err := w.writeToNetwork(compressedPayload, payloadLen); err != nil { + if err := w.writeToNetwork(buf.Bytes(), uncompressedLen); err != nil { return 0, err } @@ -196,7 +197,7 @@ func (w *compressedWriter) Write(data []byte) (int, error) { } -func (w *compressedWriter) writeToNetwork(data []byte, uncomprLength int) error { +func (w *compressedWriter) writeToNetwork(data []byte, uncompressedLen int) error { comprLength := len(data) - 7 // compression header @@ -207,9 +208,9 @@ func (w *compressedWriter) writeToNetwork(data []byte, uncomprLength int) error data[3] = w.mc.compressionSequence // this value is never greater than maxPayloadLength - data[4] = byte(0xff & uncomprLength) - data[5] = byte(0xff & (uncomprLength >> 8)) - data[6] = byte(0xff & (uncomprLength >> 16)) + data[4] = byte(0xff & uncompressedLen) + data[5] = byte(0xff & (uncompressedLen >> 8)) + data[6] = byte(0xff & (uncompressedLen >> 16)) if _, err := w.connWriter.Write(data); err != nil { return err From 7610823429e78d1454c713fa7a460d980b96a403 Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Thu, 12 Oct 2023 16:10:26 +0200 Subject: [PATCH 45/88] Rename fields for clarity. --- benchmark_test.go | 2 +- connection.go | 4 ++-- connection_test.go | 14 +++++++------- connector.go | 8 ++++---- packets.go | 8 ++++---- packets_test.go | 12 ++++++------ 6 files changed, 24 insertions(+), 24 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index cba2d6783..622f57cc7 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -236,7 +236,7 @@ func BenchmarkInterpolation(b *testing.B) { maxWriteSize: maxPacketSize - 1, buf: newBuffer(nil), } - mc.reader = &mc.buf + mc.packetReader = &mc.buf args := []driver.Value{ int64(42424242), diff --git a/connection.go b/connection.go index 17bb72736..bee1ab0e5 100644 --- a/connection.go +++ b/connection.go @@ -25,8 +25,8 @@ type mysqlConn struct { netConn net.Conn rawConn net.Conn // underlying connection when netConn is TLS connection. result mysqlResult // managed by clearResult() and handleOkPacket(). - reader packetReader - writer io.Writer + packetReader packetReader + packetWriter io.Writer cfg *Config connector *connector maxAllowedPacket int diff --git a/connection_test.go b/connection_test.go index 59ebb44e0..4f129b658 100644 --- a/connection_test.go +++ b/connection_test.go @@ -25,7 +25,7 @@ func TestInterpolateParams(t *testing.T) { InterpolateParams: true, }, } - mc.reader = &mc.buf + mc.packetReader = &mc.buf q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"}) if err != nil { @@ -73,7 +73,7 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { InterpolateParams: true, }, } - mc.reader = &mc.buf + mc.packetReader = &mc.buf q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)}) if err != driver.ErrSkip { @@ -92,7 +92,7 @@ func TestInterpolateParamsPlaceholderInString(t *testing.T) { }, } - mc.reader = &mc.buf + mc.packetReader = &mc.buf q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)}) // When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42` @@ -168,8 +168,8 @@ func TestPingMarkBadConnection(t *testing.T) { mc := &mysqlConn{ netConn: nc, buf: buf, - reader: &buf, - writer: nc, + packetReader: &buf, + packetWriter: nc, maxAllowedPacket: defaultMaxAllowedPacket, } @@ -187,8 +187,8 @@ func TestPingErrInvalidConn(t *testing.T) { mc := &mysqlConn{ netConn: nc, buf: buf, - reader: &buf, - writer: nc, + packetReader: &buf, + packetWriter: nc, maxAllowedPacket: defaultMaxAllowedPacket, closech: make(chan struct{}), cfg: NewConfig(), diff --git a/connector.go b/connector.go index 6fe0e20de..e10d73362 100644 --- a/connector.go +++ b/connector.go @@ -114,8 +114,8 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.buf = newBuffer(mc.netConn) // packet reader and writer in handshake are never compressed - mc.reader = &mc.buf - mc.writer = mc.netConn + mc.packetReader = &mc.buf + mc.packetWriter = mc.netConn // Set I/O timeouts mc.buf.timeout = mc.cfg.ReadTimeout @@ -160,8 +160,8 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } if mc.cfg.Compress != CompressionModeDisabled { - mc.reader = newCompressedReader(&mc.buf, mc) - mc.writer = newCompressedWriter(mc.writer, mc) + mc.packetReader = newCompressedReader(&mc.buf, mc) + mc.packetWriter = newCompressedWriter(mc.packetWriter, mc) } if mc.cfg.MaxAllowedPacket > 0 { diff --git a/packets.go b/packets.go index 06d51f2de..c72da16b7 100644 --- a/packets.go +++ b/packets.go @@ -29,7 +29,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte for { // read packet header - data, err := mc.reader.readNext(4) + data, err := mc.packetReader.readNext(4) if err != nil { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr @@ -66,7 +66,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // read packet body [pktLen bytes] - data, err = mc.reader.readNext(pktLen) + data, err = mc.packetReader.readNext(pktLen) if err != nil { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr @@ -120,7 +120,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { } } - n, err := mc.writer.Write(data[:4+size]) + n, err := mc.packetWriter.Write(data[:4+size]) if err == nil && n == 4+size { mc.sequence++ if size != maxPacketSize { @@ -370,7 +370,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string mc.netConn = tlsConn mc.buf.nc = tlsConn - mc.writer = mc.netConn + mc.packetWriter = mc.netConn } // User [null terminated string] diff --git a/packets_test.go b/packets_test.go index 5a6120ae5..8000df3e9 100644 --- a/packets_test.go +++ b/packets_test.go @@ -106,8 +106,8 @@ func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { mc := &mysqlConn{ buf: buf, - reader: &reader, - writer: conn, + packetReader: &reader, + packetWriter: conn, cfg: connector.cfg, connector: connector, netConn: conn, @@ -124,7 +124,7 @@ func TestReadPacketSingleByte(t *testing.T) { buf: newBuffer(conn), } - mc.reader = &mc.buf + mc.packetReader = &mc.buf conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} conn.maxReads = 1 @@ -178,7 +178,7 @@ func TestReadPacketSplit(t *testing.T) { buf: newBuffer(conn), } - mc.reader = &mc.buf + mc.packetReader = &mc.buf data := make([]byte, maxPacketSize*2+4*3) const pkt2ofs = maxPacketSize + 4 @@ -286,7 +286,7 @@ func TestReadPacketFail(t *testing.T) { closech: make(chan struct{}), cfg: NewConfig(), } - mc.reader = &mc.buf + mc.packetReader = &mc.buf // illegal empty (stand-alone) packet conn.data = []byte{0x00, 0x00, 0x00, 0x00} @@ -332,7 +332,7 @@ func TestRegression801(t *testing.T) { sequence: 42, closech: make(chan struct{}), } - mc.reader = &mc.buf + mc.packetReader = &mc.buf conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0, 60, 70, 63, 58, 68, 104, 34, 97, 0, 223, 247, 33, 2, 0, 15, 128, 21, 0, From eb449fa82e713f125961953a48d336b08a7b439e Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Thu, 12 Oct 2023 17:10:32 +0200 Subject: [PATCH 46/88] Simplify compressedWriter.Write. --- compress.go | 58 ++++++++++++++--------------------------------------- 1 file changed, 15 insertions(+), 43 deletions(-) diff --git a/compress.go b/compress.go index dfe6fcd2a..39dbb26cd 100644 --- a/compress.go +++ b/compress.go @@ -6,10 +6,6 @@ import ( "io" ) -const ( - minCompressLength = 50 -) - type compressedReader struct { buf packetReader bytesBuf []byte @@ -124,6 +120,10 @@ func (r *compressedReader) uncompressPacket() error { return nil } +const maxPayloadLen = maxPacketSize - 4 + +var blankHeader = make([]byte, 7) + func (w *compressedWriter) Write(data []byte) (int, error) { // when asked to write an empty packet, do nothing if len(data) == 0 { @@ -131,13 +131,16 @@ func (w *compressedWriter) Write(data []byte) (int, error) { } totalBytes := len(data) - length := len(data) - 4 - maxPayloadLength := maxPacketSize - 4 - blankHeader := make([]byte, 7) - for length >= maxPayloadLength { - payload := data[:maxPayloadLength] - uncompressedLen := len(payload) + dataLen := len(data) + for dataLen != 0 { + payloadLen := dataLen + if payloadLen > maxPayloadLen { + payloadLen = maxPayloadLen + } + payload := data[:payloadLen] + + uncompressedLen := payloadLen buf := bytes.NewBuffer(blankHeader) @@ -159,42 +162,11 @@ func (w *compressedWriter) Write(data []byte) (int, error) { return 0, err } - length -= maxPayloadLength - data = data[maxPayloadLength:] - } - - payloadLen := len(data) - - // do not attempt compression if packet is too small - if payloadLen < minCompressLength { - if err := w.writeToNetwork(append(blankHeader, data...), 0); err != nil { - return 0, err - } - return totalBytes, nil - } - - bytesBuf := &bytes.Buffer{} - bytesBuf.Write(blankHeader) - w.zw.Reset(bytesBuf) - if _, err := w.zw.Write(data); err != nil { - return 0, err - } - w.zw.Close() - - compressedPayload := bytesBuf.Bytes() - - if len(compressedPayload) > len(data) { - compressedPayload = append(blankHeader, data...) - payloadLen = 0 - } - - // add header and send over the wire - if err := w.writeToNetwork(compressedPayload, payloadLen); err != nil { - return 0, err + dataLen -= payloadLen + data = data[payloadLen:] } return totalBytes, nil - } func (w *compressedWriter) writeToNetwork(data []byte, uncompressedLen int) error { From 75c2480c25cf27b9033e61343fd6efa324a3a40d Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Fri, 13 Oct 2023 10:25:59 +0200 Subject: [PATCH 47/88] Disable compression by default. --- dsn.go | 1 + 1 file changed, 1 insertion(+) diff --git a/dsn.go b/dsn.go index ab844915a..651ba31d9 100644 --- a/dsn.go +++ b/dsn.go @@ -86,6 +86,7 @@ func NewConfig() *Config { Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, + Compress: CompressionModeDisabled, MinCompressLength: defaultMinCompressLength, } } From 8b8b428d357b29cde455c2a8c1f6523e35bad3c6 Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Fri, 27 Oct 2023 16:42:25 +0200 Subject: [PATCH 48/88] Fix bytes.NewBuffer usage. --- compress.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/compress.go b/compress.go index 39dbb26cd..cae6642d6 100644 --- a/compress.go +++ b/compress.go @@ -142,7 +142,10 @@ func (w *compressedWriter) Write(data []byte) (int, error) { uncompressedLen := payloadLen - buf := bytes.NewBuffer(blankHeader) + var buf bytes.Buffer + if _, err := buf.Write(blankHeader); err != nil { + return 0, err + } // If payload is less than minCompressLength, don't compress. if uncompressedLen < w.mc.cfg.MinCompressLength { @@ -151,7 +154,7 @@ func (w *compressedWriter) Write(data []byte) (int, error) { } uncompressedLen = 0 } else { - w.zw.Reset(buf) + w.zw.Reset(&buf) if _, err := w.zw.Write(payload); err != nil { return 0, err } From bc3ad685c09d7f51a4cc5ed4270e1fb16804b486 Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Thu, 2 Nov 2023 19:21:42 +0100 Subject: [PATCH 49/88] Revert README formatting. --- README.md | 196 +++++++++++++++++++++--------------------------------- 1 file changed, 74 insertions(+), 122 deletions(-) diff --git a/README.md b/README.md index 924e9c47c..5c29b9bfe 100644 --- a/README.md +++ b/README.md @@ -4,66 +4,58 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac ![Go-MySQL-Driver logo](https://raw.github.com/wiki/go-sql-driver/mysql/gomysql_m.png "Golang Gopher holding the MySQL Dolphin") ---- - -- [Features](#features) -- [Requirements](#requirements) -- [Installation](#installation) -- [Usage](#usage) - - [DSN (Data Source Name)](#dsn-data-source-name) - - [Password](#password) - - [Protocol](#protocol) - - [Address](#address) - - [Parameters](#parameters) - - [Examples](#examples) - - [Connection pool and timeouts](#connection-pool-and-timeouts) - - [context.Context Support](#contextcontext-support) - - [ColumnType Support](#columntype-support) - - [LOAD DATA LOCAL INFILE support](#load-data-local-infile-support) - - [time.Time support](#timetime-support) - - [Unicode support](#unicode-support) -- [Testing / Development](#testing--development) -- [License](#license) - ---- +--------------------------------------- + * [Features](#features) + * [Requirements](#requirements) + * [Installation](#installation) + * [Usage](#usage) + * [DSN (Data Source Name)](#dsn-data-source-name) + * [Password](#password) + * [Protocol](#protocol) + * [Address](#address) + * [Parameters](#parameters) + * [Examples](#examples) + * [Connection pool and timeouts](#connection-pool-and-timeouts) + * [context.Context Support](#contextcontext-support) + * [ColumnType Support](#columntype-support) + * [LOAD DATA LOCAL INFILE support](#load-data-local-infile-support) + * [time.Time support](#timetime-support) + * [Unicode support](#unicode-support) + * [Testing / Development](#testing--development) + * [License](#license) + +--------------------------------------- ## Features - -- Lightweight and [fast](https://github.com/go-sql-driver/sql-benchmark "golang MySQL-Driver performance"). -- Native Go implementation. No C-bindings, just pure Go. -- Connections over TCP/IPv4, TCP/IPv6, Unix domain sockets or [custom protocols](https://godoc.org/github.com/go-sql-driver/mysql#DialFunc). -- Automatic handling of broken connections. -- Automatic Connection Pooling _(by database/sql package)_. -- Supports queries larger than 16MB. -- Full [`sql.RawBytes`](https://golang.org/pkg/database/sql/#RawBytes) support. -- Intelligent `LONG DATA` handling in prepared statements. -- Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support. -- Optional `time.Time` parsing. -- Optional placeholder interpolation. -- Supports zlib compression. + * Lightweight and [fast](https://github.com/go-sql-driver/sql-benchmark "golang MySQL-Driver performance") + * Native Go implementation. No C-bindings, just pure Go + * Connections over TCP/IPv4, TCP/IPv6, Unix domain sockets or [custom protocols](https://godoc.org/github.com/go-sql-driver/mysql#DialFunc) + * Automatic handling of broken connections + * Automatic Connection Pooling *(by database/sql package)* + * Supports queries larger than 16MB + * Full [`sql.RawBytes`](https://golang.org/pkg/database/sql/#RawBytes) support. + * Intelligent `LONG DATA` handling in prepared statements + * Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support + * Optional `time.Time` parsing + * Optional placeholder interpolation ## Requirements + * Go 1.18 or higher. We aim to support the 3 latest versions of Go. + * MySQL (5.6+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+) -- Go 1.18 or higher. We aim to support the 3 latest versions of Go. -- MySQL (5.6+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+) - ---- +--------------------------------------- ## Installation - Simple install the package to your [$GOPATH](https://github.com/golang/go/wiki/GOPATH "GOPATH") with the [go tool](https://golang.org/cmd/go/ "go command") from shell: - ```bash $ go get -u github.com/go-sql-driver/mysql ``` - Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`. ## Usage - _Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. You only need to import the driver and can use the full [`database/sql`](https://golang.org/pkg/database/sql/) API then. -Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`: +Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`: ```go import ( @@ -95,34 +87,29 @@ db.SetMaxIdleConns(10) `db.SetMaxIdleConns()` is recommended to be set same to `db.SetMaxOpenConns()`. When it is smaller than `SetMaxOpenConns()`, connections can be opened and closed much more frequently than you expect. Idle connections can be closed by the `db.SetConnMaxLifetime()`. If you want to close idle connections more rapidly, you can use `db.SetConnMaxIdleTime()` since Go 1.15. + ### DSN (Data Source Name) The Data Source Name has a common format, like e.g. [PEAR DB](http://pear.php.net/manual/en/package.database.db.intro-dsn.php) uses it, but without type-prefix (optional parts marked by squared brackets): - ``` [username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] ``` A DSN in its fullest form: - ``` username:password@protocol(address)/dbname?param=value ``` Except for the databasename, all values are optional. So the minimal DSN is: - ``` /dbname ``` If you do not want to preselect a database, leave `dbname` empty: - ``` / ``` - This has the same effect as an empty DSN string: - ``` ``` @@ -136,16 +123,13 @@ This has the same effect as an empty DSN string: Alternatively, [Config.FormatDSN](https://godoc.org/github.com/go-sql-driver/mysql#Config.FormatDSN) can be used to create a DSN string by filling a struct. #### Password - Passwords can consist of any character. Escaping is **not** necessary. #### Protocol - See [net.Dial](https://golang.org/pkg/net/#Dial) for more information which networks are available. In general you should use a Unix domain socket if available and TCP otherwise for best performance. #### Address - For TCP and UDP networks, addresses have the form `host[:port]`. If `port` is omitted, the default port will be used. If `host` is a literal IPv6 address, it must be enclosed in square brackets. @@ -154,8 +138,7 @@ The functions [net.JoinHostPort](https://golang.org/pkg/net/#JoinHostPort) and [ For Unix domain sockets the address is the absolute path to the MySQL-Server-socket, e.g. `/var/run/mysqld/mysqld.sock` or `/tmp/mysql.sock`. #### Parameters - -_Parameters are case-sensitive!_ +*Parameters are case-sensitive!* Notice that any of `true`, `TRUE`, `True` or `1` is accepted to stand for a true boolean value. Not surprisingly, false can be specified as any of: `false`, `FALSE`, `False` or `0`. @@ -167,8 +150,8 @@ Valid Values: true, false Default: false ``` -`allowAllFiles=true` disables the file allowlist for `LOAD DATA LOCAL INFILE` and allows _all_ files. -[_Might be insecure!_](https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-local) +`allowAllFiles=true` disables the file allowlist for `LOAD DATA LOCAL INFILE` and allows *all* files. +[*Might be insecure!*](https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-local) ##### `allowCleartextPasswords` @@ -180,6 +163,7 @@ Default: false `allowCleartextPasswords=true` allows using the [cleartext client side plugin](https://dev.mysql.com/doc/en/cleartext-pluggable-authentication.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network. + ##### `allowFallbackToPlaintext` ``` @@ -197,7 +181,6 @@ Type: bool Valid Values: true, false Default: true ``` - `allowNativePasswords=false` disallows the usage of MySQL native password method. ##### `allowOldPasswords` @@ -207,7 +190,6 @@ Type: bool Valid Values: true, false Default: false ``` - `allowOldPasswords=true` allows the usage of the insecure old password method. This should be avoided, but is necessary in some cases. See also [the old_passwords wiki page](https://github.com/go-sql-driver/mysql/wiki/old_passwords). ##### `charset` @@ -245,7 +227,7 @@ Sets the collation used for client-server interaction on connection. In contrast A list of valid charsets for a server is retrievable with `SHOW COLLATION`. -The default collation (`utf8mb4_general_ci`) is supported from MySQL 5.5. You should use an older collation (e.g. `utf8_general_ci`) for older MySQL. +The default collation (`utf8mb4_general_ci`) is supported from MySQL 5.5. You should use an older collation (e.g. `utf8_general_ci`) for older MySQL. Collations for charset "ucs2", "utf16", "utf16le", and "utf32" can not be used ([ref](https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset)). @@ -277,16 +259,6 @@ SELECT u.id FROM users as u will return `u.id` instead of just `id` if `columnsWithAlias=true`. -##### `compress` - -``` -Type: string -Valid Values: disabled, preferred, required -Default: disabled -``` - -Toggles zlib compression. `compress=disabled` is the default value and disables compression even if offered by the server. `compress=preferred` uses compression if offered by the server, and `compress=required` will cause connection to fail if not offered by the server. In both of these cases, compression is also controlled by the `minCompressLength` parameter. - ##### `interpolateParams` ``` @@ -297,7 +269,7 @@ Default: false If `interpolateParams` is true, placeholders (`?`) in calls to `db.Query()` and `db.Exec()` are interpolated into a single query string with given parameters. This reduces the number of roundtrips, since the driver has to prepare a statement, execute it with given parameters and close the statement again with `interpolateParams=false`. -_This can not be used together with the multibyte encodings BIG5, CP932, GB2312, GBK or SJIS. These are rejected as they may [introduce a SQL injection vulnerability](http://stackoverflow.com/a/12118602/3430118)!_ +*This can not be used together with the multibyte encodings BIG5, CP932, GB2312, GBK or SJIS. These are rejected as they may [introduce a SQL injection vulnerability](http://stackoverflow.com/a/12118602/3430118)!* ##### `loc` @@ -307,29 +279,19 @@ Valid Values: Default: UTC ``` -Sets the location for time.Time values (when using `parseTime=true`). _"Local"_ sets the system's location. See [time.LoadLocation](https://golang.org/pkg/time/#LoadLocation) for details. +Sets the location for time.Time values (when using `parseTime=true`). *"Local"* sets the system's location. See [time.LoadLocation](https://golang.org/pkg/time/#LoadLocation) for details. Note that this sets the location for time.Time values but does not change MySQL's [time_zone setting](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html). For that see the [time_zone system variable](#system-variables), which can also be set as a DSN parameter. Please keep in mind, that param values must be [url.QueryEscape](https://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`. ##### `maxAllowedPacket` - ``` Type: decimal number Default: 64*1024*1024 ``` -Max packet size allowed in bytes. The default value is 64 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server _on every connection_. - -##### `minCompressLength` - -``` -Type: decimal number -Default: 50 -``` - -Min packet size in bytes to compress, when compression is enabled (see the `compress` parameter). Packets smaller than this will be sent uncompressed. +Max packet size allowed in bytes. The default value is 64 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*. ##### `multiStatements` @@ -370,6 +332,7 @@ Default: false `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string` The date or datetime like `0000-00-00 00:00:00` is converted into zero value of `time.Time`. + ##### `readTimeout` ``` @@ -377,7 +340,7 @@ Type: duration Default: 0 ``` -I/O read timeout. The value must be a decimal number with a unit suffix (_"ms"_, _"s"_, _"m"_, _"h"_), such as _"30s"_, _"0.5m"_ or _"1m30s"_. +I/O read timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. ##### `rejectReadOnly` @@ -387,6 +350,7 @@ Valid Values: true, false Default: false ``` + `rejectReadOnly=true` causes the driver to reject read-only connections. This is for a possible race condition during an automatic failover, where the mysql client gets connected to a read-only replica after the failover. @@ -407,6 +371,7 @@ cause a retry for that error. However the same error number is used for some other cases. You should ensure your application will never cause an ERROR 1290 except for `read-only` mode when enabling this option. + ##### `serverPubKey` ``` @@ -419,6 +384,7 @@ Server public keys can be registered with [`mysql.RegisterServerPubKey`](https:/ Public keys are used to transmit encrypted data, e.g. for authentication. If the server's public key is known, it should be set manually to avoid expensive and potentially insecure transmissions of the public key from the server to the client each time it is required. + ##### `timeout` ``` @@ -426,7 +392,8 @@ Type: duration Default: OS default ``` -Timeout for establishing connections, aka dial timeout. The value must be a decimal number with a unit suffix (_"ms"_, _"s"_, _"m"_, _"h"_), such as _"30s"_, _"0.5m"_ or _"1m30s"_. +Timeout for establishing connections, aka dial timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. + ##### `tls` @@ -438,6 +405,7 @@ Default: false `tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side) or use `preferred` to use TLS only when advertised by the server. This is similar to `skip-verify`, but additionally allows a fallback to a connection which is not encrypted. Neither `skip-verify` nor `preferred` add any reliable security. You can use a custom TLS config after registering it with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig). + ##### `writeTimeout` ``` @@ -445,7 +413,7 @@ Type: duration Default: 0 ``` -I/O write timeout. The value must be a decimal number with a unit suffix (_"ms"_, _"s"_, _"m"_, _"h"_), such as _"30s"_, _"0.5m"_ or _"1m30s"_. +I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. ##### `connectionAttributes` @@ -460,25 +428,22 @@ Default: none ##### System Variables Any other parameters are interpreted as system variables: - -- `=`: `SET =` -- `=`: `SET =` -- `=%27%27`: `SET =''` + * `=`: `SET =` + * `=`: `SET =` + * `=%27%27`: `SET =''` Rules: - -- The values for string variables must be quoted with `'`. -- The values must also be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed! - (which implies values of string variables must be wrapped with `%27`). +* The values for string variables must be quoted with `'`. +* The values must also be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed! + (which implies values of string variables must be wrapped with `%27`). Examples: + * `autocommit=1`: `SET autocommit=1` + * [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'` + * [`transaction_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_transaction_isolation): `SET transaction_isolation='REPEATABLE-READ'` -- `autocommit=1`: `SET autocommit=1` -- [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'` -- [`transaction_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_transaction_isolation): `SET transaction_isolation='REPEATABLE-READ'` #### Examples - ``` user@unix(/path/to/socket)/dbname ``` @@ -492,84 +457,74 @@ user:password@tcp(localhost:5555)/dbname?tls=skip-verify&autocommit=true ``` Treat warnings as errors by setting the system variable [`sql_mode`](https://dev.mysql.com/doc/refman/5.7/en/sql-mode.html): - ``` user:password@/dbname?sql_mode=TRADITIONAL ``` TCP via IPv6: - ``` user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?timeout=90s&collation=utf8mb4_unicode_ci ``` TCP on a remote host, e.g. Amazon RDS: - ``` id:password@tcp(your-amazonaws-uri.com:3306)/dbname ``` Google Cloud SQL on App Engine: - ``` user:password@unix(/cloudsql/project-id:region-name:instance-name)/dbname ``` TCP using default port (3306) on localhost: - ``` user:password@tcp/dbname?charset=utf8mb4,utf8&sys_var=esc%40ped ``` Use the default protocol (tcp) and host (localhost:3306): - ``` user:password@/dbname ``` No Database preselected: - ``` user:password@/ ``` -### Connection pool and timeouts +### Connection pool and timeouts The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively. ## `ColumnType` Support - This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. All Unsigned database type names will be returned `UNSIGNED ` with `INT`, `TINYINT`, `SMALLINT`, `MEDIUMINT`, `BIGINT`. ## `context.Context` Support - Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details. -### `LOAD DATA LOCAL INFILE` support +### `LOAD DATA LOCAL INFILE` support For this feature you need direct access to the package. Therefore you must change the import path (no `_`): - ```go import "github.com/go-sql-driver/mysql" ``` -Files must be explicitly allowed by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the allowlist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([_Might be insecure!_](https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-local)). +Files must be explicitly allowed by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the allowlist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-local)). To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::` then. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore. See the [godoc of Go-MySQL-Driver](https://godoc.org/github.com/go-sql-driver/mysql "golang mysql driver documentation") for details. -### `time.Time` support +### `time.Time` support The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your program. However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical equivalent in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](https://golang.org/pkg/time/#Location) with the `loc` DSN parameter. **Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes). -### Unicode support +### Unicode support Since version 1.5 Go-MySQL-Driver automatically uses the collation ` utf8mb4_general_ci` by default. Other charsets / collations can be set using the [`charset`](#charset) or [`collation`](#collation) DSN parameter. @@ -581,7 +536,6 @@ Other charsets / collations can be set using the [`charset`](#charset) or [`coll See http://dev.mysql.com/doc/refman/8.0/en/charset-unicode.html for more details on MySQL's Unicode support. ## Testing / Development - To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details. Go-MySQL-Driver is not feature-complete yet. Your help is very appreciated. @@ -589,24 +543,22 @@ If you want to contribute, you can work on an [open issue](https://github.com/go See the [Contribution Guidelines](https://github.com/go-sql-driver/mysql/blob/master/.github/CONTRIBUTING.md) for details. ---- +--------------------------------------- ## License - Go-MySQL-Driver is licensed under the [Mozilla Public License Version 2.0](https://raw.github.com/go-sql-driver/mysql/master/LICENSE) Mozilla summarizes the license scope as follows: - > MPL: The copyleft applies to any files containing MPLed code. -That means: -- You can **use** the **unchanged** source code both in private and commercially. -- When distributing, you **must publish** the source code of any **changed files** licensed under the MPL 2.0 under a) the MPL 2.0 itself or b) a compatible license (e.g. GPL 3.0 or Apache License 2.0). -- You **needn't publish** the source code of your library as long as the files licensed under the MPL 2.0 are **unchanged**. +That means: + * You can **use** the **unchanged** source code both in private and commercially. + * When distributing, you **must publish** the source code of any **changed files** licensed under the MPL 2.0 under a) the MPL 2.0 itself or b) a compatible license (e.g. GPL 3.0 or Apache License 2.0). + * You **needn't publish** the source code of your library as long as the files licensed under the MPL 2.0 are **unchanged**. Please read the [MPL 2.0 FAQ](https://www.mozilla.org/en-US/MPL/2.0/FAQ/) if you have further questions regarding the license. You can read the full terms here: [LICENSE](https://raw.github.com/go-sql-driver/mysql/master/LICENSE). -![Go Gopher and MySQL Dolphin](https://raw.github.com/wiki/go-sql-driver/mysql/go-mysql-driver_m.jpg "Golang Gopher transporting the MySQL Dolphin in a wheelbarrow") +![Go Gopher and MySQL Dolphin](https://raw.github.com/wiki/go-sql-driver/mysql/go-mysql-driver_m.jpg "Golang Gopher transporting the MySQL Dolphin in a wheelbarrow") \ No newline at end of file From b6d9883d2b51bcd63a9b2210c9631770977401b9 Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Thu, 2 Nov 2023 19:23:42 +0100 Subject: [PATCH 50/88] Update README with compression usage. --- README.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/README.md b/README.md index 5c29b9bfe..fcac887dc 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac * Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support * Optional `time.Time` parsing * Optional placeholder interpolation + * Supports zlib compression. ## Requirements * Go 1.18 or higher. We aim to support the 3 latest versions of Go. @@ -259,6 +260,16 @@ SELECT u.id FROM users as u will return `u.id` instead of just `id` if `columnsWithAlias=true`. +##### `compress` + +``` +Type: string +Valid Values: disabled, preferred, required +Default: disabled +``` + +Toggles zlib compression. `compress=disabled` is the default value and disables compression even if offered by the server. `compress=preferred` uses compression if offered by the server, and `compress=required` will cause connection to fail if not offered by the server. In both of these cases, compression is also controlled by the `minCompressLength` parameter. + ##### `interpolateParams` ``` @@ -293,6 +304,15 @@ Default: 64*1024*1024 Max packet size allowed in bytes. The default value is 64 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*. +##### `minCompressLength` + +``` +Type: decimal number +Default: 50 +``` + +Min packet size in bytes to compress, when compression is enabled (see the `compress` parameter). Packets smaller than this will be sent uncompressed. + ##### `multiStatements` ``` From 5ec621c5765620a7d3cd84c414e2f02173f5d050 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Mon, 11 Mar 2024 19:11:01 +0900 Subject: [PATCH 51/88] simplify --- README.md | 10 ++++---- compress.go | 2 +- compress_test.go | 41 +++++++++++++---------------- connector.go | 2 +- driver_test.go | 4 +++ dsn.go | 67 +++++++++++++++--------------------------------- packets.go | 10 +------- 7 files changed, 51 insertions(+), 85 deletions(-) diff --git a/README.md b/README.md index eb8fd7c6f..06bb8853b 100644 --- a/README.md +++ b/README.md @@ -271,12 +271,12 @@ will return `u.id` instead of just `id` if `columnsWithAlias=true`. ##### `compress` ``` -Type: string -Valid Values: disabled, preferred, required -Default: disabled +Type: bool +Valid Values: true, false +Default: false ``` -Toggles zlib compression. `compress=disabled` is the default value and disables compression even if offered by the server. `compress=preferred` uses compression if offered by the server, and `compress=required` will cause connection to fail if not offered by the server. In both of these cases, compression is also controlled by the `minCompressLength` parameter. +Toggles zlib compression. false by default. ##### `interpolateParams` @@ -598,4 +598,4 @@ Please read the [MPL 2.0 FAQ](https://www.mozilla.org/en-US/MPL/2.0/FAQ/) if you You can read the full terms here: [LICENSE](https://raw.github.com/go-sql-driver/mysql/master/LICENSE). -![Go Gopher and MySQL Dolphin](https://raw.github.com/wiki/go-sql-driver/mysql/go-mysql-driver_m.jpg "Golang Gopher transporting the MySQL Dolphin in a wheelbarrow") \ No newline at end of file +![Go Gopher and MySQL Dolphin](https://raw.github.com/wiki/go-sql-driver/mysql/go-mysql-driver_m.jpg "Golang Gopher transporting the MySQL Dolphin in a wheelbarrow") diff --git a/compress.go b/compress.go index cae6642d6..7635f61f1 100644 --- a/compress.go +++ b/compress.go @@ -148,7 +148,7 @@ func (w *compressedWriter) Write(data []byte) (int, error) { } // If payload is less than minCompressLength, don't compress. - if uncompressedLen < w.mc.cfg.MinCompressLength { + if uncompressedLen < defaultMinCompressLength { if _, err := buf.Write(payload); err != nil { return 0, err } diff --git a/compress_test.go b/compress_test.go index 8812cd39d..f58a8c78f 100644 --- a/compress_test.go +++ b/compress_test.go @@ -15,7 +15,7 @@ func makeRandByteSlice(size int) []byte { } func newMockConn() *mysqlConn { - newConn := &mysqlConn{} + newConn := &mysqlConn{cfg: NewConfig()} return newConn } @@ -30,7 +30,6 @@ func newMockBuf(reader io.Reader) *mockBuf { } func (mb *mockBuf) readNext(need int) ([]byte, error) { - data := make([]byte, need) _, err := mb.reader.Read(data) if err != nil { @@ -74,12 +73,6 @@ func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []by return b.Bytes() } -// roundtripHelper compresses then uncompresses uncompressedPacket and checks state variables -func roundtripHelper(t *testing.T, cSend *mysqlConn, cReceive *mysqlConn, uncompressedPacket []byte) []byte { - compressed := compressHelper(t, cSend, uncompressedPacket) - return uncompressHelper(t, cReceive, compressed, len(uncompressedPacket)) -} - // uncompressHelper uncompresses compressedPacket and checks state variables func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expSize int) []byte { // get status variables @@ -110,21 +103,24 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expS return uncompressedPacket } +// roundtripHelper compresses then uncompresses uncompressedPacket and checks state variables +func roundtripHelper(t *testing.T, cSend *mysqlConn, cReceive *mysqlConn, uncompressedPacket []byte) []byte { + compressed := compressHelper(t, cSend, uncompressedPacket) + return uncompressHelper(t, cReceive, compressed, len(uncompressedPacket)) +} + // TestCompressedReaderThenWriter tests reader and writer seperately. func TestCompressedReaderThenWriter(t *testing.T) { - makeTestUncompressedPacket := func(size int) []byte { - uncompressedHeader := make([]byte, 4) - uncompressedHeader[0] = byte(size) - uncompressedHeader[1] = byte(size >> 8) - uncompressedHeader[2] = byte(size >> 16) - - payload := make([]byte, size) - for i := range payload { - payload[i] = 'b' + makeUncompressedPacket := func(size int) []byte { + packet := make([]byte, 4+size) + packet[0] = byte(size) + packet[1] = byte(size >> 8) + packet[2] = byte(size >> 16) + + for i := 0; i < size; i++ { + packet[4+i] = 'b' } - - uncompressedPacket := append(uncompressedHeader, payload...) - return uncompressedPacket + return packet } tests := []struct { @@ -139,10 +135,10 @@ func TestCompressedReaderThenWriter(t *testing.T) { uncompressed: []byte{6, 0, 0, 0, 'g', 'o', 'l', 'a', 'n', 'g'}, desc: "golang"}, {compressed: []byte{19, 0, 0, 0, 104, 0, 0, 120, 156, 74, 97, 96, 96, 72, 162, 3, 0, 4, 0, 0, 255, 255, 182, 165, 38, 173}, - uncompressed: makeTestUncompressedPacket(100), + uncompressed: makeUncompressedPacket(100), desc: "100 bytes letter b"}, {compressed: []byte{63, 0, 0, 0, 236, 128, 0, 120, 156, 236, 192, 129, 0, 0, 0, 8, 3, 176, 179, 70, 18, 110, 24, 129, 124, 187, 77, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 168, 241, 1, 0, 0, 255, 255, 42, 107, 93, 24}, - uncompressed: makeTestUncompressedPacket(33000), + uncompressed: makeUncompressedPacket(33000), desc: "33000 bytes letter b"}, } @@ -200,7 +196,6 @@ func TestRoundtrip(t *testing.T) { } cSend := newMockConn() - cReceive := newMockConn() for _, test := range tests { diff --git a/connector.go b/connector.go index c5906fb0b..c5bda1701 100644 --- a/connector.go +++ b/connector.go @@ -170,7 +170,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } - if mc.cfg.Compress != CompressionModeDisabled { + if mc.cfg.compress { mc.packetReader = newCompressedReader(&mc.buf, mc) mc.packetWriter = newCompressedWriter(mc.packetWriter, mc) } diff --git a/driver_test.go b/driver_test.go index 001957244..0a24808a7 100644 --- a/driver_test.go +++ b/driver_test.go @@ -3539,6 +3539,10 @@ func TestConnectionAttributes(t *testing.T) { func TestErrorInMultiResult(t *testing.T) { // https://github.com/go-sql-driver/mysql/issues/1361 + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + var db *sql.DB if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { db, err = sql.Open("mysql", dsn) diff --git a/dsn.go b/dsn.go index 4452b7240..d9d9b8e59 100644 --- a/dsn.go +++ b/dsn.go @@ -70,22 +70,15 @@ type Config struct { ParseTime bool // Parse time values to time.Time RejectReadOnly bool // Reject read-only connections - // unexported fields. new options should be come here + // unexported fields. new options should be come here. + // boolean first. alphabetical order. - beforeConnect func(context.Context, *Config) error // Invoked before a connection is established - pubKey *rsa.PublicKey // Server public key - timeTruncate time.Duration // Truncate time.Time values to the specified duration - Compress CompressionMode // Compress packets - MinCompressLength int // Don't compress packets less than this number of bytes -} - -type CompressionMode string + compress bool // Enable zlib compression -const ( - CompressionModeDisabled CompressionMode = "disabled" - CompressionModePreferred CompressionMode = "preferred" - CompressionModeRequired CompressionMode = "required" -) + beforeConnect func(context.Context, *Config) error // Invoked before a connection is established + pubKey *rsa.PublicKey // Server public key + timeTruncate time.Duration // Truncate time.Time values to the specified duration +} // Functional Options Pattern // https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis @@ -99,10 +92,7 @@ func NewConfig() *Config { Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, - Compress: CompressionModeDisabled, - MinCompressLength: defaultMinCompressLength, } - return cfg } @@ -134,6 +124,14 @@ func BeforeConnect(fn func(context.Context, *Config) error) Option { } } +// EnableCompress sets the compression mode. +func EnableCompression(yes bool) Option { + return func(cfg *Config) error { + cfg.compress = yes + return nil + } +} + func (cfg *Config) Clone() *Config { cp := *cfg if cp.TLS != nil { @@ -302,12 +300,8 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true") } - if cfg.Compress != CompressionModeDisabled { - writeDSNParam(&buf, &hasParam, "compress", url.QueryEscape(string(cfg.Compress))) - } - - if cfg.MinCompressLength != defaultMinCompressLength { - writeDSNParam(&buf, &hasParam, "minCompressLength", strconv.Itoa(cfg.MinCompressLength)) + if cfg.compress { + writeDSNParam(&buf, &hasParam, "compress", "true") } if cfg.InterpolateParams { @@ -534,29 +528,10 @@ func parseDSNParams(cfg *Config, params string) (err error) { // Compression case "compress": - boolValue, isBool := readBool(value) - if isBool { - if boolValue { - cfg.Compress = CompressionModePreferred - } else { - cfg.Compress = CompressionModeDisabled - } - } else { - switch strings.ToLower(value) { - case string(CompressionModeDisabled): - cfg.Compress = CompressionModeDisabled - case string(CompressionModePreferred): - cfg.Compress = CompressionModePreferred - case string(CompressionModeRequired): - cfg.Compress = CompressionModeRequired - default: - return fmt.Errorf("invalid value for compression mode") - } - } - case "minCompressLength": - cfg.MinCompressLength, err = strconv.Atoi(value) - if err != nil { - return + var isBool bool + cfg.compress, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) } // Enable client side placeholder substitution diff --git a/packets.go b/packets.go index d088c8a35..d06bd7015 100644 --- a/packets.go +++ b/packets.go @@ -204,14 +204,6 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro } } - if mc.flags&clientCompress == 0 { - if mc.cfg.Compress != CompressionModeRequired { - mc.cfg.Compress = CompressionModeDisabled - } else { - return nil, "", ErrNoCompression - } - } - pos += 2 if len(data) > pos { @@ -275,7 +267,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string clientFlags |= clientFoundRows } - if mc.cfg.Compress != CompressionModeDisabled { + if mc.cfg.compress { clientFlags |= clientCompress } From 3d0d418456f5a2cbe7669c8d669a94ecce7c8299 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Mon, 11 Mar 2024 19:14:03 +0900 Subject: [PATCH 52/88] change minCompressLength to 150 --- compress.go | 2 +- const.go | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/compress.go b/compress.go index 7635f61f1..24584d2e3 100644 --- a/compress.go +++ b/compress.go @@ -148,7 +148,7 @@ func (w *compressedWriter) Write(data []byte) (int, error) { } // If payload is less than minCompressLength, don't compress. - if uncompressedLen < defaultMinCompressLength { + if uncompressedLen < minCompressLength { if _, err := buf.Write(payload); err != nil { return 0, err } diff --git a/const.go b/const.go index 01660ba39..1d19ed135 100644 --- a/const.go +++ b/const.go @@ -11,12 +11,12 @@ package mysql import "runtime" const ( - defaultAuthPlugin = "mysql_native_password" - defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355 - minProtocolVersion = 10 - maxPacketSize = 1<<24 - 1 - timeFormat = "2006-01-02 15:04:05.999999" - defaultMinCompressLength = 50 + defaultAuthPlugin = "mysql_native_password" + defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355 + minProtocolVersion = 10 + maxPacketSize = 1<<24 - 1 + timeFormat = "2006-01-02 15:04:05.999999" + minCompressLength = 150 // Connection attributes // See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available From 31b8b3876168a59051aea97a465facb31002cd90 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Mon, 11 Mar 2024 19:22:15 +0900 Subject: [PATCH 53/88] fixup --- packets.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/packets.go b/packets.go index d06bd7015..78d89f65d 100644 --- a/packets.go +++ b/packets.go @@ -196,6 +196,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro return nil, "", ErrOldProtocol } + // TODO(methane): writing to mc.cfg.XXX is bad idea. Fix it later. if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil { if mc.cfg.AllowFallbackToPlaintext { mc.cfg.TLS = nil @@ -203,6 +204,9 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro return nil, "", ErrNoTLS } } + if mc.flags&clientCompress == 0 { + mc.cfg.compress = false + } pos += 2 From 9f797b1698cb991b3263ae1b2f91ad1bec30d468 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Mon, 11 Mar 2024 20:25:24 +0900 Subject: [PATCH 54/88] remove unnecessary test --- compress_test.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/compress_test.go b/compress_test.go index f58a8c78f..eded4231f 100644 --- a/compress_test.go +++ b/compress_test.go @@ -134,9 +134,6 @@ func TestCompressedReaderThenWriter(t *testing.T) { {compressed: []byte{10, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 'g', 'o', 'l', 'a', 'n', 'g'}, uncompressed: []byte{6, 0, 0, 0, 'g', 'o', 'l', 'a', 'n', 'g'}, desc: "golang"}, - {compressed: []byte{19, 0, 0, 0, 104, 0, 0, 120, 156, 74, 97, 96, 96, 72, 162, 3, 0, 4, 0, 0, 255, 255, 182, 165, 38, 173}, - uncompressed: makeUncompressedPacket(100), - desc: "100 bytes letter b"}, {compressed: []byte{63, 0, 0, 0, 236, 128, 0, 120, 156, 236, 192, 129, 0, 0, 0, 8, 3, 176, 179, 70, 18, 110, 24, 129, 124, 187, 77, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 168, 241, 1, 0, 0, 255, 255, 42, 107, 93, 24}, uncompressed: makeUncompressedPacket(33000), desc: "33000 bytes letter b"}, From d7ed57812f739cde252c1c82a7f32c15e6daa3ea Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Wed, 13 Mar 2024 20:01:37 +0900 Subject: [PATCH 55/88] code cleanup and minor improvements --- compress.go | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/compress.go b/compress.go index 24584d2e3..19922a4ea 100644 --- a/compress.go +++ b/compress.go @@ -3,6 +3,7 @@ package mysql import ( "bytes" "compress/zlib" + "fmt" "io" ) @@ -28,10 +29,15 @@ func newCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader { } func newCompressedWriter(connWriter io.Writer, mc *mysqlConn) *compressedWriter { + // level 1 or 2 is the best trade-off between speed and compression ratio + zw, err := zlib.NewWriterLevel(new(bytes.Buffer), 2) + if err != nil { + panic(err) // compress/zlib return non-nil error only if level is invalid + } return &compressedWriter{ connWriter: connWriter, mc: mc, - zw: zlib.NewWriter(new(bytes.Buffer)), + zw: zw, } } @@ -42,7 +48,7 @@ func (r *compressedReader) readNext(need int) ([]byte, error) { } } - data := r.bytesBuf[:need] + data := r.bytesBuf[:need:need] // prevent caller writes into r.bytesBuf r.bytesBuf = r.bytesBuf[need:] return data, nil } @@ -60,7 +66,6 @@ func (r *compressedReader) uncompressPacket() error { if compressionSequence != r.mc.compressionSequence { return ErrPktSync } - r.mc.compressionSequence++ comprData, err := r.buf.readNext(comprLength) @@ -97,7 +102,6 @@ func (r *compressedReader) uncompressPacket() error { } data := r.bytesBuf[offset : offset+uncompressedLength] - lenRead := 0 // http://grokbase.com/t/gg/golang-nuts/146y9ppn6b/go-nuts-stream-compression-with-compress-flate @@ -114,9 +118,11 @@ func (r *compressedReader) uncompressPacket() error { return err } } - - r.bytesBuf = append(r.bytesBuf, data...) - + if lenRead != uncompressedLength { + return fmt.Errorf("invalid compressed packet: uncompressed length in header is %d, actual %d", + uncompressedLength, lenRead) + } + r.bytesBuf = r.bytesBuf[:offset+uncompressedLength] return nil } @@ -125,13 +131,7 @@ const maxPayloadLen = maxPacketSize - 4 var blankHeader = make([]byte, 7) func (w *compressedWriter) Write(data []byte) (int, error) { - // when asked to write an empty packet, do nothing - if len(data) == 0 { - return 0, nil - } - totalBytes := len(data) - dataLen := len(data) for dataLen != 0 { payloadLen := dataLen @@ -139,7 +139,6 @@ func (w *compressedWriter) Write(data []byte) (int, error) { payloadLen = maxPayloadLen } payload := data[:payloadLen] - uncompressedLen := payloadLen var buf bytes.Buffer @@ -161,7 +160,7 @@ func (w *compressedWriter) Write(data []byte) (int, error) { w.zw.Close() } - if err := w.writeToNetwork(buf.Bytes(), uncompressedLen); err != nil { + if err := w.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil { return 0, err } @@ -172,7 +171,9 @@ func (w *compressedWriter) Write(data []byte) (int, error) { return totalBytes, nil } -func (w *compressedWriter) writeToNetwork(data []byte, uncompressedLen int) error { +// writeCompressedPacket writes a compressed packet with header. +// data should start with 7 size space for header followed by payload. +func (w *compressedWriter) writeCompressedPacket(data []byte, uncompressedLen int) error { comprLength := len(data) - 7 // compression header From 0f9ec9fc823b2cd0d8aeb39ed0db2be7c3ce3ccd Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Wed, 13 Mar 2024 20:06:27 +0900 Subject: [PATCH 56/88] remove test depends on compressed output --- compress_test.go | 49 ------------------------------------------------ 1 file changed, 49 deletions(-) diff --git a/compress_test.go b/compress_test.go index eded4231f..524f70cef 100644 --- a/compress_test.go +++ b/compress_test.go @@ -109,55 +109,6 @@ func roundtripHelper(t *testing.T, cSend *mysqlConn, cReceive *mysqlConn, uncomp return uncompressHelper(t, cReceive, compressed, len(uncompressedPacket)) } -// TestCompressedReaderThenWriter tests reader and writer seperately. -func TestCompressedReaderThenWriter(t *testing.T) { - makeUncompressedPacket := func(size int) []byte { - packet := make([]byte, 4+size) - packet[0] = byte(size) - packet[1] = byte(size >> 8) - packet[2] = byte(size >> 16) - - for i := 0; i < size; i++ { - packet[4+i] = 'b' - } - return packet - } - - tests := []struct { - compressed []byte - uncompressed []byte - desc string - }{ - {compressed: []byte{5, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 'a'}, - uncompressed: []byte{1, 0, 0, 0, 'a'}, - desc: "a"}, - {compressed: []byte{10, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 'g', 'o', 'l', 'a', 'n', 'g'}, - uncompressed: []byte{6, 0, 0, 0, 'g', 'o', 'l', 'a', 'n', 'g'}, - desc: "golang"}, - {compressed: []byte{63, 0, 0, 0, 236, 128, 0, 120, 156, 236, 192, 129, 0, 0, 0, 8, 3, 176, 179, 70, 18, 110, 24, 129, 124, 187, 77, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 168, 241, 1, 0, 0, 255, 255, 42, 107, 93, 24}, - uncompressed: makeUncompressedPacket(33000), - desc: "33000 bytes letter b"}, - } - - for _, test := range tests { - s := fmt.Sprintf("Test compress uncompress with %s", test.desc) - - // test uncompression only - c := newMockConn() - uncompressed := uncompressHelper(t, c, test.compressed, len(test.uncompressed)) - if !bytes.Equal(uncompressed, test.uncompressed) { - t.Fatalf("%s: uncompression failed", s) - } - - // test compression only - c = newMockConn() - compressed := compressHelper(t, c, test.uncompressed) - if !bytes.Equal(compressed, test.compressed) { - t.Fatalf("%s: compression failed", s) - } - } -} - // TestRoundtrip tests two connections, where one is reading and the other is writing func TestRoundtrip(t *testing.T) { tests := []struct { From a64171f512fec21ae8d51073254eddc136218900 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 14 Mar 2024 13:24:42 +0900 Subject: [PATCH 57/88] cleanup --- connector.go | 3 --- errors.go | 1 - packets.go | 4 ---- packets_test.go | 7 +------ 4 files changed, 1 insertion(+), 14 deletions(-) diff --git a/connector.go b/connector.go index c5bda1701..c920a5008 100644 --- a/connector.go +++ b/connector.go @@ -123,7 +123,6 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { defer mc.finish() mc.buf = newBuffer(mc.netConn) - // packet reader and writer in handshake are never compressed mc.packetReader = &mc.buf mc.packetWriter = mc.netConn @@ -155,7 +154,6 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } } - if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { mc.cleanup() return nil, err @@ -174,7 +172,6 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.packetReader = newCompressedReader(&mc.buf, mc) mc.packetWriter = newCompressedWriter(mc.packetWriter, mc) } - if mc.cfg.MaxAllowedPacket > 0 { mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket } else { diff --git a/errors.go b/errors.go index 840d146a6..a9a3060c9 100644 --- a/errors.go +++ b/errors.go @@ -20,7 +20,6 @@ var ( ErrInvalidConn = errors.New("invalid connection") ErrMalformPkt = errors.New("malformed packet") ErrNoTLS = errors.New("TLS requested but server does not support TLS") - ErrNoCompression = errors.New("compression requested but server does not support compression") ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN") ErrNativePassword = errors.New("this user requires mysql native password authentication") ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords") diff --git a/packets.go b/packets.go index 78d89f65d..5747cc2c0 100644 --- a/packets.go +++ b/packets.go @@ -270,16 +270,13 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string if mc.cfg.ClientFoundRows { clientFlags |= clientFoundRows } - if mc.cfg.compress { clientFlags |= clientCompress } - // To enable TLS / SSL if mc.cfg.TLS != nil { clientFlags |= clientSSL } - if mc.cfg.MultiStatements { clientFlags |= clientMultiStatements } @@ -364,7 +361,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string mc.rawConn = mc.netConn mc.netConn = tlsConn mc.buf.nc = tlsConn - mc.packetWriter = mc.netConn } diff --git a/packets_test.go b/packets_test.go index 4f772b326..abeded749 100644 --- a/packets_test.go +++ b/packets_test.go @@ -97,13 +97,10 @@ var _ net.Conn = new(mockConn) func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { conn := new(mockConn) connector := newConnector(NewConfig()) - buf := newBuffer(conn) - reader := newBuffer(conn) - mc := &mysqlConn{ buf: buf, - packetReader: &reader, + packetReader: &buf, packetWriter: conn, cfg: connector.cfg, connector: connector, @@ -120,7 +117,6 @@ func TestReadPacketSingleByte(t *testing.T) { mc := &mysqlConn{ buf: newBuffer(conn), } - mc.packetReader = &mc.buf conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} @@ -174,7 +170,6 @@ func TestReadPacketSplit(t *testing.T) { mc := &mysqlConn{ buf: newBuffer(conn), } - mc.packetReader = &mc.buf data := make([]byte, maxPacketSize*2+4*3) From d78cdf864964311d9fb92c5722b046ad253756b8 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 14 Mar 2024 13:52:24 +0900 Subject: [PATCH 58/88] fix test --- packets_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packets_test.go b/packets_test.go index abeded749..a2fe0e8a8 100644 --- a/packets_test.go +++ b/packets_test.go @@ -147,7 +147,7 @@ func TestReadPacketWrongSequenceID(t *testing.T) { { ClientSequenceID: 0, ServerSequenceID: 0x42, - ExpectedErr: ErrPktSyncMul, + ExpectedErr: ErrPktSync, }, } { conn, mc := newRWMockConn(testCase.ClientSequenceID) From eb42024b9b21ec73426e2dfd247f275b8ffe58d1 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 14 Mar 2024 11:51:00 +0000 Subject: [PATCH 59/88] fix sync error --- compress.go | 22 ++++++++++++++++++++-- connection.go | 19 ++++++++++++++++++- infile.go | 1 + packets.go | 34 +++++++++++++++++++--------------- 4 files changed, 58 insertions(+), 18 deletions(-) diff --git a/compress.go b/compress.go index 19922a4ea..2a27987f3 100644 --- a/compress.go +++ b/compress.go @@ -5,8 +5,12 @@ import ( "compress/zlib" "fmt" "io" + "os" ) +// for debugging wire protocol. +const debugTrace = false + type compressedReader struct { buf packetReader bytesBuf []byte @@ -63,6 +67,10 @@ func (r *compressedReader) uncompressPacket() error { comprLength := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) uncompressedLength := int(uint32(header[4]) | uint32(header[5])<<8 | uint32(header[6])<<16) compressionSequence := uint8(header[3]) + if debugTrace { + fmt.Fprintf(os.Stderr, "uncompress cmplen=%v uncomplen=%v seq=%v\n", + comprLength, uncompressedLength, compressionSequence) + } if compressionSequence != r.mc.compressionSequence { return ErrPktSync } @@ -133,7 +141,9 @@ var blankHeader = make([]byte, 7) func (w *compressedWriter) Write(data []byte) (int, error) { totalBytes := len(data) dataLen := len(data) - for dataLen != 0 { + var buf bytes.Buffer + + for dataLen > 0 { payloadLen := dataLen if payloadLen > maxPayloadLen { payloadLen = maxPayloadLen @@ -141,7 +151,6 @@ func (w *compressedWriter) Write(data []byte) (int, error) { payload := data[:payloadLen] uncompressedLen := payloadLen - var buf bytes.Buffer if _, err := buf.Write(blankHeader); err != nil { return 0, err } @@ -166,6 +175,7 @@ func (w *compressedWriter) Write(data []byte) (int, error) { dataLen -= payloadLen data = data[payloadLen:] + buf.Reset() } return totalBytes, nil @@ -188,7 +198,15 @@ func (w *compressedWriter) writeCompressedPacket(data []byte, uncompressedLen in data[5] = byte(0xff & (uncompressedLen >> 8)) data[6] = byte(0xff & (uncompressedLen >> 16)) + if debugTrace { + w.mc.cfg.Logger.Print( + fmt.Sprintf( + "writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v", + comprLength, uncompressedLen, int(data[3]))) + } + if _, err := w.connWriter.Write(data); err != nil { + w.mc.cfg.Logger.Print(err) return err } diff --git a/connection.go b/connection.go index c93c964c7..c497c671e 100644 --- a/connection.go +++ b/connection.go @@ -51,6 +51,24 @@ type packetReader interface { readNext(need int) ([]byte, error) } +func (mc *mysqlConn) resetSeqNo() { + mc.sequence = 0 + mc.compressionSequence = 0 +} + +// syncSeqNo must be called when: +// - at least one large packet is sent (split packet happend), and +// - finished writing, before start reading. +func (mc *mysqlConn) syncSeqNo() { + // This syncs compressionSequence to sequence. + // This is done in `net_flush()` in MySQL and MariaDB. + // https://github.com/mariadb-corporation/mariadb-connector-c/blob/8228164f850b12353da24df1b93a1e53cc5e85e9/libmariadb/ma_net.c#L170-L171 + // https://github.com/mysql/mysql-server/blob/824e2b4064053f7daf17d7f3f84b7a3ed92e5fb4/sql-common/net_serv.cc#L293 + if mc.cfg.compress { + mc.sequence = mc.compressionSequence + } +} + // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { var cmdSet strings.Builder @@ -139,7 +157,6 @@ func (mc *mysqlConn) Close() (err error) { } mc.cleanup() - return } diff --git a/infile.go b/infile.go index 0c8af9f11..c1c12390f 100644 --- a/infile.go +++ b/infile.go @@ -171,6 +171,7 @@ func (mc *okHandler) handleInFileRequest(name string) (err error) { if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil { return ioErr } + mc.conn().syncSeqNo() // read OK packet if err == nil { diff --git a/packets.go b/packets.go index 5747cc2c0..a24e33297 100644 --- a/packets.go +++ b/packets.go @@ -34,6 +34,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr } + // debug.PrintStack() mc.cfg.Logger.Print(err) mc.Close() return nil, ErrInvalidConn @@ -41,9 +42,13 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // packet length [24 bit] pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) + if debugTrace { + mc.cfg.Logger.Print(fmt.Sprintf("readPacket: packet seq = %d, mc.sequence = %d", data[3], mc.sequence)) + } // check packet sync [8 bit] if data[3] != mc.sequence { + // debug.PrintStack() mc.Close() if data[3] > mc.sequence { return nil, ErrPktSyncMul @@ -112,6 +117,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { size = pktLen } data[3] = mc.sequence + // fmt.Fprintf(os.Stderr, "writePacket: seq=%v len=%v\n", mc.sequence, pktLen) // Write packet if mc.writeTimeout > 0 { @@ -415,12 +421,12 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence - mc.sequence = 0 - mc.compressionSequence = 0 + mc.resetSeqNo() data, err := mc.buf.takeSmallBuffer(4 + 1) if err != nil { // cannot take the buffer. Something must be wrong with the connection + // debug.PrintStack() mc.cfg.Logger.Print(err) return errBadConnNoWrite } @@ -434,8 +440,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { // Reset Packet Sequence - mc.sequence = 0 - mc.compressionSequence = 0 + mc.resetSeqNo() pktLen := 1 + len(arg) data, err := mc.buf.takeBuffer(pktLen + 4) @@ -452,13 +457,14 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { copy(data[5:], arg) // Send CMD packet - return mc.writePacket(data) + err = mc.writePacket(data) + mc.syncSeqNo() + return err } func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence - mc.sequence = 0 - mc.compressionSequence = 0 + mc.resetSeqNo() data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) if err != nil { @@ -942,8 +948,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { pktLen = dataOffset + argLen } - stmt.mc.sequence = 0 - stmt.mc.compressionSequence = 0 + stmt.mc.resetSeqNo() // Add command byte [1 byte] data[4] = comStmtSendLongData @@ -964,12 +969,10 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { continue } return err - } // Reset Packet Sequence - stmt.mc.sequence = 0 - stmt.mc.compressionSequence = 0 + stmt.mc.resetSeqNo() return nil } @@ -994,8 +997,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } // Reset packet-sequence - mc.sequence = 0 - mc.compressionSequence = 0 + mc.resetSeqNo() var data []byte var err error @@ -1216,7 +1218,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { data = data[:pos] } - return mc.writePacket(data) + err = mc.writePacket(data) + mc.resetSeqNo() + return err } // For each remaining resultset in the stream, discards its rows and updates From 679cc530fe7b07bf94b60316a40c2c8063566a80 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 14 Mar 2024 13:22:12 +0000 Subject: [PATCH 60/88] fix sync error again --- packets.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/packets.go b/packets.go index a24e33297..ec0c5a080 100644 --- a/packets.go +++ b/packets.go @@ -17,6 +17,7 @@ import ( "fmt" "io" "math" + "runtime/debug" "strconv" "time" ) @@ -48,7 +49,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // check packet sync [8 bit] if data[3] != mc.sequence { - // debug.PrintStack() + if debugTrace { + debug.PrintStack() + } mc.Close() if data[3] > mc.sequence { return nil, ErrPktSyncMul @@ -117,9 +120,11 @@ func (mc *mysqlConn) writePacket(data []byte) error { size = pktLen } data[3] = mc.sequence - // fmt.Fprintf(os.Stderr, "writePacket: seq=%v len=%v\n", mc.sequence, pktLen) // Write packet + if debugTrace { + mc.cfg.Logger.Print(fmt.Sprintf("writePacket: size=%v seq=%v", size, mc.sequence)) + } if mc.writeTimeout > 0 { if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { return err @@ -426,7 +431,6 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { data, err := mc.buf.takeSmallBuffer(4 + 1) if err != nil { // cannot take the buffer. Something must be wrong with the connection - // debug.PrintStack() mc.cfg.Logger.Print(err) return errBadConnNoWrite } @@ -1219,7 +1223,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } err = mc.writePacket(data) - mc.resetSeqNo() + mc.syncSeqNo() return err } From 876af0790f5cc5726a30d0cdb3fa13cb23c1fb68 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 14 Mar 2024 14:28:01 +0000 Subject: [PATCH 61/88] fix todo --- connection.go | 3 ++- connector.go | 2 +- packets.go | 6 ++---- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/connection.go b/connection.go index c497c671e..2848436aa 100644 --- a/connection.go +++ b/connection.go @@ -37,6 +37,7 @@ type mysqlConn struct { sequence uint8 compressionSequence uint8 parseTime bool + compress bool // for context support (Go 1.8+) watching bool @@ -64,7 +65,7 @@ func (mc *mysqlConn) syncSeqNo() { // This is done in `net_flush()` in MySQL and MariaDB. // https://github.com/mariadb-corporation/mariadb-connector-c/blob/8228164f850b12353da24df1b93a1e53cc5e85e9/libmariadb/ma_net.c#L170-L171 // https://github.com/mysql/mysql-server/blob/824e2b4064053f7daf17d7f3f84b7a3ed92e5fb4/sql-common/net_serv.cc#L293 - if mc.cfg.compress { + if mc.compress { mc.sequence = mc.compressionSequence } } diff --git a/connector.go b/connector.go index c920a5008..00f353403 100644 --- a/connector.go +++ b/connector.go @@ -168,7 +168,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } - if mc.cfg.compress { + if mc.compress { mc.packetReader = newCompressedReader(&mc.buf, mc) mc.packetWriter = newCompressedWriter(mc.packetWriter, mc) } diff --git a/packets.go b/packets.go index ec0c5a080..c89477412 100644 --- a/packets.go +++ b/packets.go @@ -215,9 +215,6 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro return nil, "", ErrNoTLS } } - if mc.flags&clientCompress == 0 { - mc.cfg.compress = false - } pos += 2 @@ -281,8 +278,9 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string if mc.cfg.ClientFoundRows { clientFlags |= clientFoundRows } - if mc.cfg.compress { + if mc.cfg.compress && mc.flags&clientCompress == clientCompress { clientFlags |= clientCompress + mc.compress = true } // To enable TLS / SSL if mc.cfg.TLS != nil { From d5ad92e701e10169775b0154e75dda46f5f41225 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 15 Mar 2024 05:55:19 +0000 Subject: [PATCH 62/88] merge compressedReader and compressedWriter --- compress.go | 38 +++++++++++++------------------------- compress_test.go | 28 +++++++--------------------- connector.go | 5 +++-- packets.go | 2 -- 4 files changed, 23 insertions(+), 50 deletions(-) diff --git a/compress.go b/compress.go index 2a27987f3..c80ed02b3 100644 --- a/compress.go +++ b/compress.go @@ -11,41 +11,29 @@ import ( // for debugging wire protocol. const debugTrace = false -type compressedReader struct { - buf packetReader +type compressor struct { + mc *mysqlConn + // for reader bytesBuf []byte - mc *mysqlConn zr io.ReadCloser -} - -type compressedWriter struct { + // for writer connWriter io.Writer - mc *mysqlConn zw *zlib.Writer } -func newCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader { - return &compressedReader{ - buf: buf, - bytesBuf: make([]byte, 0), - mc: mc, - } -} - -func newCompressedWriter(connWriter io.Writer, mc *mysqlConn) *compressedWriter { - // level 1 or 2 is the best trade-off between speed and compression ratio +func newCompressor(mc *mysqlConn, w io.Writer) *compressor { zw, err := zlib.NewWriterLevel(new(bytes.Buffer), 2) if err != nil { panic(err) // compress/zlib return non-nil error only if level is invalid } - return &compressedWriter{ - connWriter: connWriter, + return &compressor{ mc: mc, + connWriter: w, zw: zw, } } -func (r *compressedReader) readNext(need int) ([]byte, error) { +func (r *compressor) readNext(need int) ([]byte, error) { for len(r.bytesBuf) < need { if err := r.uncompressPacket(); err != nil { return nil, err @@ -57,8 +45,8 @@ func (r *compressedReader) readNext(need int) ([]byte, error) { return data, nil } -func (r *compressedReader) uncompressPacket() error { - header, err := r.buf.readNext(7) // size of compressed header +func (r *compressor) uncompressPacket() error { + header, err := r.mc.buf.readNext(7) // size of compressed header if err != nil { return err } @@ -76,7 +64,7 @@ func (r *compressedReader) uncompressPacket() error { } r.mc.compressionSequence++ - comprData, err := r.buf.readNext(comprLength) + comprData, err := r.mc.buf.readNext(comprLength) if err != nil { return err } @@ -138,7 +126,7 @@ const maxPayloadLen = maxPacketSize - 4 var blankHeader = make([]byte, 7) -func (w *compressedWriter) Write(data []byte) (int, error) { +func (w *compressor) Write(data []byte) (int, error) { totalBytes := len(data) dataLen := len(data) var buf bytes.Buffer @@ -183,7 +171,7 @@ func (w *compressedWriter) Write(data []byte) (int, error) { // writeCompressedPacket writes a compressed packet with header. // data should start with 7 size space for header followed by payload. -func (w *compressedWriter) writeCompressedPacket(data []byte, uncompressedLen int) error { +func (w *compressor) writeCompressedPacket(data []byte, uncompressedLen int) error { comprLength := len(data) - 7 // compression header diff --git a/compress_test.go b/compress_test.go index 524f70cef..264d9c4b2 100644 --- a/compress_test.go +++ b/compress_test.go @@ -19,25 +19,13 @@ func newMockConn() *mysqlConn { return newConn } -type mockBuf struct { - reader io.Reader -} - -func newMockBuf(reader io.Reader) *mockBuf { - return &mockBuf{ - reader: reader, +func newMockBuf(data []byte) buffer { + return buffer{ + buf: data, + length: len(data), } } -func (mb *mockBuf) readNext(need int) ([]byte, error) { - data := make([]byte, need) - _, err := mb.reader.Read(data) - if err != nil { - return nil, err - } - return data, nil -} - // compressHelper compresses uncompressedPacket and checks state variables func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte { // get status variables @@ -47,7 +35,7 @@ func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []by var b bytes.Buffer connWriter := &b - cw := newCompressedWriter(connWriter, mc) + cw := newCompressor(mc, connWriter) n, err := cw.Write(uncompressedPacket) @@ -79,10 +67,8 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expS cs := mc.compressionSequence // mocking out buf variable - mockConnReader := bytes.NewReader(compressedPacket) - mockBuf := newMockBuf(mockConnReader) - - cr := newCompressedReader(mockBuf, mc) + mc.buf = newMockBuf(compressedPacket) + cr := newCompressor(mc, nil) uncompressedPacket, err := cr.readNext(expSize) if err != nil { diff --git a/connector.go b/connector.go index 00f353403..7945180e5 100644 --- a/connector.go +++ b/connector.go @@ -169,8 +169,9 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } if mc.compress { - mc.packetReader = newCompressedReader(&mc.buf, mc) - mc.packetWriter = newCompressedWriter(mc.packetWriter, mc) + cmpr := newCompressor(mc, mc.packetWriter) + mc.packetReader = cmpr + mc.packetWriter = cmpr } if mc.cfg.MaxAllowedPacket > 0 { mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket diff --git a/packets.go b/packets.go index c89477412..e9c86e370 100644 --- a/packets.go +++ b/packets.go @@ -206,8 +206,6 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro if mc.flags&clientProtocol41 == 0 { return nil, "", ErrOldProtocol } - - // TODO(methane): writing to mc.cfg.XXX is bad idea. Fix it later. if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil { if mc.cfg.AllowFallbackToPlaintext { mc.cfg.TLS = nil From 1c059169b7ebdb2723de617f057a1b736b4ecf0e Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 15 Mar 2024 06:53:44 +0000 Subject: [PATCH 63/88] use sync.Pool for zlib --- compress.go | 183 ++++++++++++++++++++++++++++++---------------------- 1 file changed, 106 insertions(+), 77 deletions(-) diff --git a/compress.go b/compress.go index c80ed02b3..31d4297d7 100644 --- a/compress.go +++ b/compress.go @@ -6,47 +6,108 @@ import ( "fmt" "io" "os" + "sync" ) +var ( + zrPool *sync.Pool // Do not use directly. Use zDecompress() instead. + zwPool *sync.Pool // Do not use directly. Use zCompress() instead. +) + +func init() { + zrPool = &sync.Pool{ + New: func() any { return nil }, + } + zwPool = &sync.Pool{ + New: func() any { + zw, err := zlib.NewWriterLevel(new(bytes.Buffer), 2) + if err != nil { + panic(err) // compress/zlib return non-nil error only if level is invalid + } + return zw + }, + } +} + +func zDecompress(src, dst []byte) (int, error) { + br := bytes.NewReader(src) + var zr io.ReadCloser + var err error + + if a := zrPool.Get(); a == nil { + if zr, err = zlib.NewReader(br); err != nil { + return 0, err + } + } else { + zr = a.(io.ReadCloser) + if zr.(zlib.Resetter).Reset(br, nil); err != nil { + return 0, err + } + } + defer func() { + zr.Close() + zrPool.Put(zr) + }() + + lenRead := 0 + size := len(dst) + + for lenRead < size { + n, err := zr.Read(dst[lenRead:]) + lenRead += n + + if err == io.EOF { + if lenRead < size { + return lenRead, io.ErrUnexpectedEOF + } + } else if err != nil { + return lenRead, err + } + } + return lenRead, nil +} + +func zCompress(src []byte, dst io.Writer) error { + zw := zwPool.Get().(*zlib.Writer) + zw.Reset(dst) + if _, err := zw.Write(src); err != nil { + return err + } + zw.Close() + zwPool.Put(zw) + return nil +} + // for debugging wire protocol. const debugTrace = false type compressor struct { - mc *mysqlConn - // for reader - bytesBuf []byte - zr io.ReadCloser - // for writer + mc *mysqlConn + bytesBuf []byte // read buffer (FIFO) connWriter io.Writer - zw *zlib.Writer } func newCompressor(mc *mysqlConn, w io.Writer) *compressor { - zw, err := zlib.NewWriterLevel(new(bytes.Buffer), 2) - if err != nil { - panic(err) // compress/zlib return non-nil error only if level is invalid - } return &compressor{ mc: mc, connWriter: w, - zw: zw, } } -func (r *compressor) readNext(need int) ([]byte, error) { - for len(r.bytesBuf) < need { - if err := r.uncompressPacket(); err != nil { +func (c *compressor) readNext(need int) ([]byte, error) { + for len(c.bytesBuf) < need { + if err := c.uncompressPacket(); err != nil { return nil, err } } - data := r.bytesBuf[:need:need] // prevent caller writes into r.bytesBuf - r.bytesBuf = r.bytesBuf[need:] + data := c.bytesBuf[:need:need] // prevent caller writes into r.bytesBuf + c.bytesBuf = c.bytesBuf[need:] return data, nil } -func (r *compressor) uncompressPacket() error { - header, err := r.mc.buf.readNext(7) // size of compressed header +func (c *compressor) uncompressPacket() error { + header, err := c.mc.buf.readNext(7) // size of compressed header if err != nil { return err } @@ -59,12 +120,12 @@ func (r *compressor) uncompressPacket() error { fmt.Fprintf(os.Stderr, "uncompress cmplen=%v uncomplen=%v seq=%v\n", comprLength, uncompressedLength, compressionSequence) } - if compressionSequence != r.mc.compressionSequence { + if compressionSequence != c.mc.compressionSequence { return ErrPktSync } - r.mc.compressionSequence++ + c.mc.compressionSequence++ - comprData, err := r.mc.buf.readNext(comprLength) + comprData, err := c.mc.buf.readNext(comprLength) if err != nil { return err } @@ -72,53 +133,27 @@ func (r *compressor) uncompressPacket() error { // if payload is uncompressed, its length will be specified as zero, and its // true length is contained in comprLength if uncompressedLength == 0 { - r.bytesBuf = append(r.bytesBuf, comprData...) + c.bytesBuf = append(c.bytesBuf, comprData...) return nil } - // write comprData to a bytes.buffer, then read it using zlib into data - br := bytes.NewReader(comprData) - if r.zr == nil { - if r.zr, err = zlib.NewReader(br); err != nil { - return err - } - } else { - if err = r.zr.(zlib.Resetter).Reset(br, nil); err != nil { - return err - } - } - defer r.zr.Close() - // use existing capacity in bytesBuf if possible - offset := len(r.bytesBuf) - if cap(r.bytesBuf)-offset < uncompressedLength { - old := r.bytesBuf - r.bytesBuf = make([]byte, offset, offset+uncompressedLength) - copy(r.bytesBuf, old) + offset := len(c.bytesBuf) + if cap(c.bytesBuf)-offset < uncompressedLength { + old := c.bytesBuf + c.bytesBuf = make([]byte, offset, offset+uncompressedLength) + copy(c.bytesBuf, old) } - data := r.bytesBuf[offset : offset+uncompressedLength] - lenRead := 0 - - // http://grokbase.com/t/gg/golang-nuts/146y9ppn6b/go-nuts-stream-compression-with-compress-flate - for lenRead < uncompressedLength { - n, err := r.zr.Read(data[lenRead:]) - lenRead += n - - if err == io.EOF { - if lenRead < uncompressedLength { - return io.ErrUnexpectedEOF - } - break - } else if err != nil { - return err - } + lenRead, err := zDecompress(comprData, c.bytesBuf[offset:offset+uncompressedLength]) + if err != nil { + return err } if lenRead != uncompressedLength { return fmt.Errorf("invalid compressed packet: uncompressed length in header is %d, actual %d", uncompressedLength, lenRead) } - r.bytesBuf = r.bytesBuf[:offset+uncompressedLength] + c.bytesBuf = c.bytesBuf[:offset+uncompressedLength] return nil } @@ -126,7 +161,7 @@ const maxPayloadLen = maxPacketSize - 4 var blankHeader = make([]byte, 7) -func (w *compressor) Write(data []byte) (int, error) { +func (c *compressor) Write(data []byte) (int, error) { totalBytes := len(data) dataLen := len(data) var buf bytes.Buffer @@ -150,17 +185,12 @@ func (w *compressor) Write(data []byte) (int, error) { } uncompressedLen = 0 } else { - w.zw.Reset(&buf) - if _, err := w.zw.Write(payload); err != nil { - return 0, err - } - w.zw.Close() + zCompress(payload, &buf) } - if err := w.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil { + if err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil { return 0, err } - dataLen -= payloadLen data = data[payloadLen:] buf.Reset() @@ -171,33 +201,32 @@ func (w *compressor) Write(data []byte) (int, error) { // writeCompressedPacket writes a compressed packet with header. // data should start with 7 size space for header followed by payload. -func (w *compressor) writeCompressedPacket(data []byte, uncompressedLen int) error { +func (c *compressor) writeCompressedPacket(data []byte, uncompressedLen int) error { comprLength := len(data) - 7 + if debugTrace { + c.mc.cfg.Logger.Print( + fmt.Sprintf( + "writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v", + comprLength, uncompressedLen, c.mc.compressionSequence)) + } // compression header data[0] = byte(0xff & comprLength) data[1] = byte(0xff & (comprLength >> 8)) data[2] = byte(0xff & (comprLength >> 16)) - data[3] = w.mc.compressionSequence + data[3] = c.mc.compressionSequence // this value is never greater than maxPayloadLength data[4] = byte(0xff & uncompressedLen) data[5] = byte(0xff & (uncompressedLen >> 8)) data[6] = byte(0xff & (uncompressedLen >> 16)) - if debugTrace { - w.mc.cfg.Logger.Print( - fmt.Sprintf( - "writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v", - comprLength, uncompressedLen, int(data[3]))) - } - - if _, err := w.connWriter.Write(data); err != nil { - w.mc.cfg.Logger.Print(err) + if _, err := c.connWriter.Write(data); err != nil { + c.mc.cfg.Logger.Print(err) return err } - w.mc.compressionSequence++ + c.mc.compressionSequence++ return nil } From 39e52e40c6d1024549153c0e6e12244f66052a4e Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 15 Mar 2024 07:02:26 +0000 Subject: [PATCH 64/88] cleanup --- compress.go | 3 +-- compress_test.go | 4 +--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/compress.go b/compress.go index 31d4297d7..5572f6fef 100644 --- a/compress.go +++ b/compress.go @@ -159,11 +159,10 @@ func (c *compressor) uncompressPacket() error { const maxPayloadLen = maxPacketSize - 4 -var blankHeader = make([]byte, 7) - func (c *compressor) Write(data []byte) (int, error) { totalBytes := len(data) dataLen := len(data) + blankHeader := make([]byte, 7) var buf bytes.Buffer for dataLen > 0 { diff --git a/compress_test.go b/compress_test.go index 264d9c4b2..c97e1bd5b 100644 --- a/compress_test.go +++ b/compress_test.go @@ -33,9 +33,7 @@ func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []by cs := mc.compressionSequence var b bytes.Buffer - connWriter := &b - - cw := newCompressor(mc, connWriter) + cw := newCompressor(mc, &b) n, err := cw.Write(uncompressedPacket) From 0e3ace3c8117b74735c8bbdfcdd64c06cf05dcb0 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 15 Mar 2024 08:21:28 +0000 Subject: [PATCH 65/88] code cleanup --- compress.go | 18 ++++++++++++----- compress_test.go | 28 +++++++++++++++++---------- connection.go | 50 +++++++++++++++++++++++------------------------- infile.go | 2 +- packets.go | 16 ++++++++-------- 5 files changed, 64 insertions(+), 50 deletions(-) diff --git a/compress.go b/compress.go index 5572f6fef..5059ae435 100644 --- a/compress.go +++ b/compress.go @@ -1,3 +1,11 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2024 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + package mysql import ( @@ -120,10 +128,10 @@ func (c *compressor) uncompressPacket() error { fmt.Fprintf(os.Stderr, "uncompress cmplen=%v uncomplen=%v seq=%v\n", comprLength, uncompressedLength, compressionSequence) } - if compressionSequence != c.mc.compressionSequence { + if compressionSequence != c.mc.compresSequence { return ErrPktSync } - c.mc.compressionSequence++ + c.mc.compresSequence++ comprData, err := c.mc.buf.readNext(comprLength) if err != nil { @@ -206,7 +214,7 @@ func (c *compressor) writeCompressedPacket(data []byte, uncompressedLen int) err c.mc.cfg.Logger.Print( fmt.Sprintf( "writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v", - comprLength, uncompressedLen, c.mc.compressionSequence)) + comprLength, uncompressedLen, c.mc.compresSequence)) } // compression header @@ -214,7 +222,7 @@ func (c *compressor) writeCompressedPacket(data []byte, uncompressedLen int) err data[1] = byte(0xff & (comprLength >> 8)) data[2] = byte(0xff & (comprLength >> 16)) - data[3] = c.mc.compressionSequence + data[3] = c.mc.compresSequence // this value is never greater than maxPayloadLength data[4] = byte(0xff & uncompressedLen) @@ -226,6 +234,6 @@ func (c *compressor) writeCompressedPacket(data []byte, uncompressedLen int) err return err } - c.mc.compressionSequence++ + c.mc.compresSequence++ return nil } diff --git a/compress_test.go b/compress_test.go index c97e1bd5b..eb98b4935 100644 --- a/compress_test.go +++ b/compress_test.go @@ -1,3 +1,11 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2024 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + package mysql import ( @@ -30,7 +38,7 @@ func newMockBuf(data []byte) buffer { func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte { // get status variables - cs := mc.compressionSequence + cs := mc.compresSequence var b bytes.Buffer cw := newCompressor(mc, &b) @@ -46,13 +54,13 @@ func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []by } if len(uncompressedPacket) > 0 { - if mc.compressionSequence != (cs + 1) { - t.Fatalf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressionSequence) + if mc.compresSequence != (cs + 1) { + t.Fatalf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compresSequence) } } else { - if mc.compressionSequence != cs { - t.Fatalf("mc.compressionSequence updated incorrectly for case of empty write, expected %d and saw %d", cs, mc.compressionSequence) + if mc.compresSequence != cs { + t.Fatalf("mc.compressionSequence updated incorrectly for case of empty write, expected %d and saw %d", cs, mc.compresSequence) } } @@ -62,7 +70,7 @@ func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []by // uncompressHelper uncompresses compressedPacket and checks state variables func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expSize int) []byte { // get status variables - cs := mc.compressionSequence + cs := mc.compresSequence // mocking out buf variable mc.buf = newMockBuf(compressedPacket) @@ -76,12 +84,12 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expS } if expSize > 0 { - if mc.compressionSequence != (cs + 1) { - t.Fatalf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressionSequence) + if mc.compresSequence != (cs + 1) { + t.Fatalf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compresSequence) } } else { - if mc.compressionSequence != cs { - t.Fatalf("mc.compressionSequence updated incorrectly for case of empty read, expected %d and saw %d", cs, mc.compressionSequence) + if mc.compresSequence != cs { + t.Fatalf("mc.compressionSequence updated incorrectly for case of empty read, expected %d and saw %d", cs, mc.compresSequence) } } return uncompressedPacket diff --git a/connection.go b/connection.go index 2848436aa..96d94e67a 100644 --- a/connection.go +++ b/connection.go @@ -21,23 +21,23 @@ import ( ) type mysqlConn struct { - buf buffer - netConn net.Conn - rawConn net.Conn // underlying connection when netConn is TLS connection. - result mysqlResult // managed by clearResult() and handleOkPacket(). - packetReader packetReader - packetWriter io.Writer - cfg *Config - connector *connector - maxAllowedPacket int - maxWriteSize int - writeTimeout time.Duration - flags clientFlag - status statusFlag - sequence uint8 - compressionSequence uint8 - parseTime bool - compress bool + buf buffer + netConn net.Conn + rawConn net.Conn // underlying connection when netConn is TLS connection. + result mysqlResult // managed by clearResult() and handleOkPacket(). + packetReader packetReader + packetWriter io.Writer + cfg *Config + connector *connector + maxAllowedPacket int + maxWriteSize int + writeTimeout time.Duration + flags clientFlag + status statusFlag + sequence uint8 + compresSequence uint8 + parseTime bool + compress bool // for context support (Go 1.8+) watching bool @@ -52,21 +52,19 @@ type packetReader interface { readNext(need int) ([]byte, error) } -func (mc *mysqlConn) resetSeqNo() { +func (mc *mysqlConn) resetSequenceNr() { mc.sequence = 0 - mc.compressionSequence = 0 + mc.compresSequence = 0 } -// syncSeqNo must be called when: -// - at least one large packet is sent (split packet happend), and -// - finished writing, before start reading. -func (mc *mysqlConn) syncSeqNo() { - // This syncs compressionSequence to sequence. - // This is done in `net_flush()` in MySQL and MariaDB. +// syncSequenceNr must be called when finished writing some packet and before start reading. +func (mc *mysqlConn) syncSequenceNr() { + // Syncs compressionSequence to sequence. + // This is not documented but done in `net_flush()` in MySQL and MariaDB. // https://github.com/mariadb-corporation/mariadb-connector-c/blob/8228164f850b12353da24df1b93a1e53cc5e85e9/libmariadb/ma_net.c#L170-L171 // https://github.com/mysql/mysql-server/blob/824e2b4064053f7daf17d7f3f84b7a3ed92e5fb4/sql-common/net_serv.cc#L293 if mc.compress { - mc.sequence = mc.compressionSequence + mc.sequence = mc.compresSequence } } diff --git a/infile.go b/infile.go index c1c12390f..1f1b88735 100644 --- a/infile.go +++ b/infile.go @@ -171,7 +171,7 @@ func (mc *okHandler) handleInFileRequest(name string) (err error) { if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil { return ioErr } - mc.conn().syncSeqNo() + mc.conn().syncSequenceNr() // read OK packet if err == nil { diff --git a/packets.go b/packets.go index e9c86e370..f56a3b6c7 100644 --- a/packets.go +++ b/packets.go @@ -422,7 +422,7 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence - mc.resetSeqNo() + mc.resetSequenceNr() data, err := mc.buf.takeSmallBuffer(4 + 1) if err != nil { @@ -440,7 +440,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { // Reset Packet Sequence - mc.resetSeqNo() + mc.resetSequenceNr() pktLen := 1 + len(arg) data, err := mc.buf.takeBuffer(pktLen + 4) @@ -458,13 +458,13 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { // Send CMD packet err = mc.writePacket(data) - mc.syncSeqNo() + mc.syncSequenceNr() return err } func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence - mc.resetSeqNo() + mc.resetSequenceNr() data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) if err != nil { @@ -948,7 +948,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { pktLen = dataOffset + argLen } - stmt.mc.resetSeqNo() + stmt.mc.resetSequenceNr() // Add command byte [1 byte] data[4] = comStmtSendLongData @@ -972,7 +972,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { } // Reset Packet Sequence - stmt.mc.resetSeqNo() + stmt.mc.resetSequenceNr() return nil } @@ -997,7 +997,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } // Reset packet-sequence - mc.resetSeqNo() + mc.resetSequenceNr() var data []byte var err error @@ -1219,7 +1219,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } err = mc.writePacket(data) - mc.syncSeqNo() + mc.syncSequenceNr() return err } From 750fe2aeaced26dd2ad93f5d2a6dc460ba3ef90d Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 15 Mar 2024 09:10:47 +0000 Subject: [PATCH 66/88] fix typo --- compress.go | 10 +++++----- compress_test.go | 20 ++++++++++---------- connection.go | 6 +++--- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/compress.go b/compress.go index 5059ae435..2233704cb 100644 --- a/compress.go +++ b/compress.go @@ -128,10 +128,10 @@ func (c *compressor) uncompressPacket() error { fmt.Fprintf(os.Stderr, "uncompress cmplen=%v uncomplen=%v seq=%v\n", comprLength, uncompressedLength, compressionSequence) } - if compressionSequence != c.mc.compresSequence { + if compressionSequence != c.mc.compressSequence { return ErrPktSync } - c.mc.compresSequence++ + c.mc.compressSequence++ comprData, err := c.mc.buf.readNext(comprLength) if err != nil { @@ -214,7 +214,7 @@ func (c *compressor) writeCompressedPacket(data []byte, uncompressedLen int) err c.mc.cfg.Logger.Print( fmt.Sprintf( "writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v", - comprLength, uncompressedLen, c.mc.compresSequence)) + comprLength, uncompressedLen, c.mc.compressSequence)) } // compression header @@ -222,7 +222,7 @@ func (c *compressor) writeCompressedPacket(data []byte, uncompressedLen int) err data[1] = byte(0xff & (comprLength >> 8)) data[2] = byte(0xff & (comprLength >> 16)) - data[3] = c.mc.compresSequence + data[3] = c.mc.compressSequence // this value is never greater than maxPayloadLength data[4] = byte(0xff & uncompressedLen) @@ -234,6 +234,6 @@ func (c *compressor) writeCompressedPacket(data []byte, uncompressedLen int) err return err } - c.mc.compresSequence++ + c.mc.compressSequence++ return nil } diff --git a/compress_test.go b/compress_test.go index eb98b4935..9530f931d 100644 --- a/compress_test.go +++ b/compress_test.go @@ -38,7 +38,7 @@ func newMockBuf(data []byte) buffer { func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte { // get status variables - cs := mc.compresSequence + cs := mc.compressSequence var b bytes.Buffer cw := newCompressor(mc, &b) @@ -54,13 +54,13 @@ func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []by } if len(uncompressedPacket) > 0 { - if mc.compresSequence != (cs + 1) { - t.Fatalf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compresSequence) + if mc.compressSequence != (cs + 1) { + t.Fatalf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressSequence) } } else { - if mc.compresSequence != cs { - t.Fatalf("mc.compressionSequence updated incorrectly for case of empty write, expected %d and saw %d", cs, mc.compresSequence) + if mc.compressSequence != cs { + t.Fatalf("mc.compressionSequence updated incorrectly for case of empty write, expected %d and saw %d", cs, mc.compressSequence) } } @@ -70,7 +70,7 @@ func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []by // uncompressHelper uncompresses compressedPacket and checks state variables func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expSize int) []byte { // get status variables - cs := mc.compresSequence + cs := mc.compressSequence // mocking out buf variable mc.buf = newMockBuf(compressedPacket) @@ -84,12 +84,12 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expS } if expSize > 0 { - if mc.compresSequence != (cs + 1) { - t.Fatalf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compresSequence) + if mc.compressSequence != (cs + 1) { + t.Fatalf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressSequence) } } else { - if mc.compresSequence != cs { - t.Fatalf("mc.compressionSequence updated incorrectly for case of empty read, expected %d and saw %d", cs, mc.compresSequence) + if mc.compressSequence != cs { + t.Fatalf("mc.compressionSequence updated incorrectly for case of empty read, expected %d and saw %d", cs, mc.compressSequence) } } return uncompressedPacket diff --git a/connection.go b/connection.go index 96d94e67a..861890e61 100644 --- a/connection.go +++ b/connection.go @@ -35,7 +35,7 @@ type mysqlConn struct { flags clientFlag status statusFlag sequence uint8 - compresSequence uint8 + compressSequence uint8 parseTime bool compress bool @@ -54,7 +54,7 @@ type packetReader interface { func (mc *mysqlConn) resetSequenceNr() { mc.sequence = 0 - mc.compresSequence = 0 + mc.compressSequence = 0 } // syncSequenceNr must be called when finished writing some packet and before start reading. @@ -64,7 +64,7 @@ func (mc *mysqlConn) syncSequenceNr() { // https://github.com/mariadb-corporation/mariadb-connector-c/blob/8228164f850b12353da24df1b93a1e53cc5e85e9/libmariadb/ma_net.c#L170-L171 // https://github.com/mysql/mysql-server/blob/824e2b4064053f7daf17d7f3f84b7a3ed92e5fb4/sql-common/net_serv.cc#L293 if mc.compress { - mc.sequence = mc.compresSequence + mc.sequence = mc.compressSequence } } From 0512769f39a6650c5f699c310ca14dd14ddfb101 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 15 Mar 2024 09:32:59 +0000 Subject: [PATCH 67/88] move const flag --- compress.go | 3 --- const.go | 2 ++ 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/compress.go b/compress.go index 2233704cb..cad7738a6 100644 --- a/compress.go +++ b/compress.go @@ -86,9 +86,6 @@ func zCompress(src []byte, dst io.Writer) error { return nil } -// for debugging wire protocol. -const debugTrace = false - type compressor struct { mc *mysqlConn bytesBuf []byte // read buffer (FIFO) diff --git a/const.go b/const.go index 1d19ed135..58d64c618 100644 --- a/const.go +++ b/const.go @@ -11,6 +11,8 @@ package mysql import "runtime" const ( + debugTrace = false // for debugging wire protocol. + defaultAuthPlugin = "mysql_native_password" defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355 minProtocolVersion = 10 From 60ce7881f446046016af325fc111d9131ab93311 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 15 Mar 2024 09:55:13 +0000 Subject: [PATCH 68/88] remove writer from compressor --- compress.go | 16 +++++++++------- compress_test.go | 19 +++++++++++++++---- connector.go | 2 +- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/compress.go b/compress.go index cad7738a6..4d2658877 100644 --- a/compress.go +++ b/compress.go @@ -87,15 +87,17 @@ func zCompress(src []byte, dst io.Writer) error { } type compressor struct { - mc *mysqlConn - bytesBuf []byte // read buffer (FIFO) - connWriter io.Writer + mc *mysqlConn + // read buffer (FIFO). + // We can not reuse already-read buffer until dropping Go 1.20 support. + // It is because of database/mysql's weired behavior. + // See https://github.com/go-sql-driver/mysql/issues/1435 + bytesBuf []byte } -func newCompressor(mc *mysqlConn, w io.Writer) *compressor { +func newCompressor(mc *mysqlConn) *compressor { return &compressor{ - mc: mc, - connWriter: w, + mc: mc, } } @@ -226,7 +228,7 @@ func (c *compressor) writeCompressedPacket(data []byte, uncompressedLen int) err data[5] = byte(0xff & (uncompressedLen >> 8)) data[6] = byte(0xff & (uncompressedLen >> 16)) - if _, err := c.connWriter.Write(data); err != nil { + if _, err := c.mc.netConn.Write(data); err != nil { c.mc.cfg.Logger.Print(err) return err } diff --git a/compress_test.go b/compress_test.go index 9530f931d..4193c7247 100644 --- a/compress_test.go +++ b/compress_test.go @@ -13,6 +13,7 @@ import ( "crypto/rand" "fmt" "io" + "net" "testing" ) @@ -34,14 +35,24 @@ func newMockBuf(data []byte) buffer { } } +type dummyConn struct { + buf bytes.Buffer + net.Conn +} + +func (c *dummyConn) Write(data []byte) (int, error) { + return c.buf.Write(data) +} + // compressHelper compresses uncompressedPacket and checks state variables func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte { // get status variables cs := mc.compressSequence - var b bytes.Buffer - cw := newCompressor(mc, &b) + var b dummyConn + mc.netConn = &b + cw := newCompressor(mc) n, err := cw.Write(uncompressedPacket) @@ -64,7 +75,7 @@ func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []by } } - return b.Bytes() + return b.buf.Bytes() } // uncompressHelper uncompresses compressedPacket and checks state variables @@ -74,7 +85,7 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expS // mocking out buf variable mc.buf = newMockBuf(compressedPacket) - cr := newCompressor(mc, nil) + cr := newCompressor(mc) uncompressedPacket, err := cr.readNext(expSize) if err != nil { diff --git a/connector.go b/connector.go index 7945180e5..f40444450 100644 --- a/connector.go +++ b/connector.go @@ -169,7 +169,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } if mc.compress { - cmpr := newCompressor(mc, mc.packetWriter) + cmpr := newCompressor(mc) mc.packetReader = cmpr mc.packetWriter = cmpr } From ee70acf1bb32ae2c26611d59d206761d38bee7cd Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 15 Mar 2024 10:56:20 +0000 Subject: [PATCH 69/88] remove packetWriter and simplify tests --- compress.go | 38 ++++++++++--------- compress_test.go | 92 ++++++++++++++-------------------------------- connection.go | 1 - connection_test.go | 2 - connector.go | 8 ++-- packets.go | 12 ++++-- packets_test.go | 1 - 7 files changed, 60 insertions(+), 94 deletions(-) diff --git a/compress.go b/compress.go index 4d2658877..dba660555 100644 --- a/compress.go +++ b/compress.go @@ -86,7 +86,7 @@ func zCompress(src []byte, dst io.Writer) error { return nil } -type compressor struct { +type decompressor struct { mc *mysqlConn // read buffer (FIFO). // We can not reuse already-read buffer until dropping Go 1.20 support. @@ -95,13 +95,13 @@ type compressor struct { bytesBuf []byte } -func newCompressor(mc *mysqlConn) *compressor { - return &compressor{ +func newDecompressor(mc *mysqlConn) *decompressor { + return &decompressor{ mc: mc, } } -func (c *compressor) readNext(need int) ([]byte, error) { +func (c *decompressor) readNext(need int) ([]byte, error) { for len(c.bytesBuf) < need { if err := c.uncompressPacket(); err != nil { return nil, err @@ -113,7 +113,7 @@ func (c *compressor) readNext(need int) ([]byte, error) { return data, nil } -func (c *compressor) uncompressPacket() error { +func (c *decompressor) uncompressPacket() error { header, err := c.mc.buf.readNext(7) // size of compressed header if err != nil { return err @@ -166,9 +166,11 @@ func (c *compressor) uncompressPacket() error { const maxPayloadLen = maxPacketSize - 4 -func (c *compressor) Write(data []byte) (int, error) { - totalBytes := len(data) - dataLen := len(data) +// writeCompressed sends one or some packets with compression. +// Use this instead of mc.netConn.Write() when mc.compress is true. +func (mc *mysqlConn) writeCompressed(packets []byte) (int, error) { + totalBytes := len(packets) + dataLen := len(packets) blankHeader := make([]byte, 7) var buf bytes.Buffer @@ -177,7 +179,7 @@ func (c *compressor) Write(data []byte) (int, error) { if payloadLen > maxPayloadLen { payloadLen = maxPayloadLen } - payload := data[:payloadLen] + payload := packets[:payloadLen] uncompressedLen := payloadLen if _, err := buf.Write(blankHeader); err != nil { @@ -194,11 +196,11 @@ func (c *compressor) Write(data []byte) (int, error) { zCompress(payload, &buf) } - if err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil { + if err := mc.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil { return 0, err } dataLen -= payloadLen - data = data[payloadLen:] + packets = packets[payloadLen:] buf.Reset() } @@ -207,13 +209,13 @@ func (c *compressor) Write(data []byte) (int, error) { // writeCompressedPacket writes a compressed packet with header. // data should start with 7 size space for header followed by payload. -func (c *compressor) writeCompressedPacket(data []byte, uncompressedLen int) error { +func (mc *mysqlConn) writeCompressedPacket(data []byte, uncompressedLen int) error { comprLength := len(data) - 7 if debugTrace { - c.mc.cfg.Logger.Print( + mc.cfg.Logger.Print( fmt.Sprintf( "writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v", - comprLength, uncompressedLen, c.mc.compressSequence)) + comprLength, uncompressedLen, mc.compressSequence)) } // compression header @@ -221,18 +223,18 @@ func (c *compressor) writeCompressedPacket(data []byte, uncompressedLen int) err data[1] = byte(0xff & (comprLength >> 8)) data[2] = byte(0xff & (comprLength >> 16)) - data[3] = c.mc.compressSequence + data[3] = mc.compressSequence // this value is never greater than maxPayloadLength data[4] = byte(0xff & uncompressedLen) data[5] = byte(0xff & (uncompressedLen >> 8)) data[6] = byte(0xff & (uncompressedLen >> 16)) - if _, err := c.mc.netConn.Write(data); err != nil { - c.mc.cfg.Logger.Print(err) + if _, err := mc.netConn.Write(data); err != nil { + mc.cfg.Logger.Print(err) return err } - c.mc.compressSequence++ + mc.compressSequence++ return nil } diff --git a/compress_test.go b/compress_test.go index 4193c7247..6d81db335 100644 --- a/compress_test.go +++ b/compress_test.go @@ -13,7 +13,6 @@ import ( "crypto/rand" "fmt" "io" - "net" "testing" ) @@ -23,69 +22,28 @@ func makeRandByteSlice(size int) []byte { return randBytes } -func newMockConn() *mysqlConn { - newConn := &mysqlConn{cfg: NewConfig()} - return newConn -} - -func newMockBuf(data []byte) buffer { - return buffer{ - buf: data, - length: len(data), - } -} - -type dummyConn struct { - buf bytes.Buffer - net.Conn -} - -func (c *dummyConn) Write(data []byte) (int, error) { - return c.buf.Write(data) -} - // compressHelper compresses uncompressedPacket and checks state variables func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte { - // get status variables - - cs := mc.compressSequence - - var b dummyConn - mc.netConn = &b - cw := newCompressor(mc) - - n, err := cw.Write(uncompressedPacket) + conn := new(mockConn) + mc.netConn = conn + n, err := mc.writeCompressed(uncompressedPacket) if err != nil { - t.Fatal(err.Error()) + t.Fatal(err) } - if n != len(uncompressedPacket) { t.Fatalf("expected to write %d bytes, wrote %d bytes", len(uncompressedPacket), n) } - - if len(uncompressedPacket) > 0 { - if mc.compressSequence != (cs + 1) { - t.Fatalf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressSequence) - } - - } else { - if mc.compressSequence != cs { - t.Fatalf("mc.compressionSequence updated incorrectly for case of empty write, expected %d and saw %d", cs, mc.compressSequence) - } - } - - return b.buf.Bytes() + return conn.written } // uncompressHelper uncompresses compressedPacket and checks state variables func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expSize int) []byte { - // get status variables - cs := mc.compressSequence - // mocking out buf variable - mc.buf = newMockBuf(compressedPacket) - cr := newCompressor(mc) + conn := new(mockConn) + conn.data = compressedPacket + mc.buf.nc = conn + cr := newDecompressor(mc) uncompressedPacket, err := cr.readNext(expSize) if err != nil { @@ -93,15 +51,8 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expS t.Fatalf("non-nil/non-EOF error when reading contents: %s", err.Error()) } } - - if expSize > 0 { - if mc.compressSequence != (cs + 1) { - t.Fatalf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressSequence) - } - } else { - if mc.compressSequence != cs { - t.Fatalf("mc.compressionSequence updated incorrectly for case of empty read, expected %d and saw %d", cs, mc.compressSequence) - } + if len(uncompressedPacket) != expSize { + t.Errorf("uncompressed size is unexpected. expected %d but got %d", expSize, len(uncompressedPacket)) } return uncompressedPacket } @@ -141,20 +92,33 @@ func TestRoundtrip(t *testing.T) { {uncompressed: makeRandByteSlice(32768), desc: "32768 rand bytes", }, - {uncompressed: makeRandByteSlice(33000), - desc: "33000 rand bytes", + {uncompressed: bytes.Repeat(makeRandByteSlice(100), 10000), + desc: "100 rand * 10000 repeat bytes", }, } - cSend := newMockConn() - cReceive := newMockConn() + _, cSend := newRWMockConn(0) + cSend.compress = true + _, cReceive := newRWMockConn(0) + cReceive.compress = true for _, test := range tests { s := fmt.Sprintf("Test roundtrip with %s", test.desc) + cSend.resetSequenceNr() + cReceive.resetSequenceNr() uncompressed := roundtripHelper(t, cSend, cReceive, test.uncompressed) if !bytes.Equal(uncompressed, test.uncompressed) { t.Fatalf("%s: roundtrip failed", s) } + + if cSend.sequence != cReceive.sequence { + t.Errorf("inconsistent sequence number: send=%v recv=%v", + cSend.sequence, cReceive.sequence) + } + if cSend.compressSequence != cReceive.compressSequence { + t.Errorf("inconsistent compress sequence number: send=%v recv=%v", + cSend.compressSequence, cReceive.compressSequence) + } } } diff --git a/connection.go b/connection.go index 861890e61..61216f95e 100644 --- a/connection.go +++ b/connection.go @@ -26,7 +26,6 @@ type mysqlConn struct { rawConn net.Conn // underlying connection when netConn is TLS connection. result mysqlResult // managed by clearResult() and handleOkPacket(). packetReader packetReader - packetWriter io.Writer cfg *Config connector *connector maxAllowedPacket int diff --git a/connection_test.go b/connection_test.go index 4f129b658..32d342d74 100644 --- a/connection_test.go +++ b/connection_test.go @@ -169,7 +169,6 @@ func TestPingMarkBadConnection(t *testing.T) { netConn: nc, buf: buf, packetReader: &buf, - packetWriter: nc, maxAllowedPacket: defaultMaxAllowedPacket, } @@ -188,7 +187,6 @@ func TestPingErrInvalidConn(t *testing.T) { netConn: nc, buf: buf, packetReader: &buf, - packetWriter: nc, maxAllowedPacket: defaultMaxAllowedPacket, closech: make(chan struct{}), cfg: NewConfig(), diff --git a/connector.go b/connector.go index f40444450..263859cab 100644 --- a/connector.go +++ b/connector.go @@ -125,7 +125,6 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.buf = newBuffer(mc.netConn) // packet reader and writer in handshake are never compressed mc.packetReader = &mc.buf - mc.packetWriter = mc.netConn // Set I/O timeouts mc.buf.timeout = mc.cfg.ReadTimeout @@ -168,10 +167,9 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } - if mc.compress { - cmpr := newCompressor(mc) - mc.packetReader = cmpr - mc.packetWriter = cmpr + if mc.cfg.compress && mc.flags&clientCompress == clientCompress { + mc.compress = true + mc.packetReader = newDecompressor(mc) } if mc.cfg.MaxAllowedPacket > 0 { mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket diff --git a/packets.go b/packets.go index f56a3b6c7..dfe3820d1 100644 --- a/packets.go +++ b/packets.go @@ -131,7 +131,15 @@ func (mc *mysqlConn) writePacket(data []byte) error { } } - n, err := mc.packetWriter.Write(data[:4+size]) + var ( + n int + err error + ) + if mc.compress { + n, err = mc.writeCompressed(data[:4+size]) + } else { + n, err = mc.netConn.Write(data[:4+size]) + } if err == nil && n == 4+size { mc.sequence++ if size != maxPacketSize { @@ -278,7 +286,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string } if mc.cfg.compress && mc.flags&clientCompress == clientCompress { clientFlags |= clientCompress - mc.compress = true } // To enable TLS / SSL if mc.cfg.TLS != nil { @@ -368,7 +375,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string mc.rawConn = mc.netConn mc.netConn = tlsConn mc.buf.nc = tlsConn - mc.packetWriter = mc.netConn } // User [null terminated string] diff --git a/packets_test.go b/packets_test.go index a2fe0e8a8..1d4de7af0 100644 --- a/packets_test.go +++ b/packets_test.go @@ -101,7 +101,6 @@ func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { mc := &mysqlConn{ buf: buf, packetReader: &buf, - packetWriter: conn, cfg: connector.cfg, connector: connector, netConn: conn, From 1e785615c800109b2475d4b0486297f6fd281266 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 15 Mar 2024 11:03:36 +0000 Subject: [PATCH 70/88] run tests with compression --- driver_test.go | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/driver_test.go b/driver_test.go index 0a24808a7..86f162b1e 100644 --- a/driver_test.go +++ b/driver_test.go @@ -146,7 +146,7 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { db, err := sql.Open(driverNameTest, dsn) if err != nil { - t.Fatalf("error connecting: %s", err.Error()) + t.Fatalf("connecting %q: %s", dsn, err) } defer db.Close() @@ -159,11 +159,19 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation { db2, err = sql.Open(driverNameTest, dsn2) if err != nil { - t.Fatalf("error connecting: %s", err.Error()) + t.Fatalf("connecting %q: %s", dsn2, err) } defer db2.Close() } + dsn3 := dsn + "&compress=true" + var db3 *sql.DB + db3, err = sql.Open(driverNameTest, dsn3) + if err != nil { + t.Fatalf("connecting %q: %s", dsn3, err) + } + defer db3.Close() + for _, test := range tests { test := test t.Run("default", func(t *testing.T) { @@ -178,6 +186,11 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { test(dbt2) }) } + t.Run("compress", func(t *testing.T) { + dbt3 := &DBTest{t, db3} + t.Cleanup(cleanup) + test(dbt3) + }) } } From 77d86eca95188317199e8b3ac0c6a039343addd6 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 15 Mar 2024 16:17:38 +0000 Subject: [PATCH 71/88] fix tests --- .github/workflows/test.yml | 2 +- compress.go | 22 ++++++++++++++-------- driver_test.go | 3 ++- packets.go | 22 ++++++++-------------- 4 files changed, 25 insertions(+), 24 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c5b2aa313..ea8a972b3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -83,7 +83,7 @@ jobs: my-cnf: | innodb_log_file_size=256MB innodb_buffer_pool_size=512MB - max_allowed_packet=16MB + max_allowed_packet=48MB ; TestConcurrent fails if max_connections is too large max_connections=50 local_infile=1 diff --git a/compress.go b/compress.go index dba660555..2d2bea0c3 100644 --- a/compress.go +++ b/compress.go @@ -13,7 +13,6 @@ import ( "compress/zlib" "fmt" "io" - "os" "sync" ) @@ -124,13 +123,20 @@ func (c *decompressor) uncompressPacket() error { uncompressedLength := int(uint32(header[4]) | uint32(header[5])<<8 | uint32(header[6])<<16) compressionSequence := uint8(header[3]) if debugTrace { - fmt.Fprintf(os.Stderr, "uncompress cmplen=%v uncomplen=%v seq=%v\n", - comprLength, uncompressedLength, compressionSequence) - } - if compressionSequence != c.mc.compressSequence { - return ErrPktSync - } - c.mc.compressSequence++ + c.mc.cfg.Logger.Print( + fmt.Sprintf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n", + comprLength, uncompressedLength, compressionSequence, c.mc.sequence)) + } + if compressionSequence != c.mc.sequence { + // return ErrPktSync + // server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes) + // before receiving all packets from client. In this case, seqnr is younger than expected. + c.mc.cfg.Logger.Print( + fmt.Sprintf("[warn] unexpected cmpress seq nr: expected %v, got %v", + c.mc.sequence, compressionSequence)) + } + c.mc.sequence = compressionSequence + 1 + c.mc.compressSequence = c.mc.sequence comprData, err := c.mc.buf.readNext(comprLength) if err != nil { diff --git a/driver_test.go b/driver_test.go index 86f162b1e..3f80e7415 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1277,7 +1277,8 @@ func TestLongData(t *testing.T) { var rows *sql.Rows // Long text data - const nonDataQueryLen = 28 // length query w/o value + // const nonDataQueryLen = 28 // length query w/o value + compress header + const nonDataQueryLen = 100 inS := in[:maxAllowedPacketSize-nonDataQueryLen] dbt.mustExec("INSERT INTO test VALUES('" + inS + "')") rows = dbt.mustQuery("SELECT value FROM test") diff --git a/packets.go b/packets.go index dfe3820d1..a2d5a7746 100644 --- a/packets.go +++ b/packets.go @@ -35,7 +35,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr } - // debug.PrintStack() + if debugTrace { + debug.PrintStack() + } mc.cfg.Logger.Print(err) mc.Close() return nil, ErrInvalidConn @@ -43,22 +45,14 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // packet length [24 bit] pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) - if debugTrace { - mc.cfg.Logger.Print(fmt.Sprintf("readPacket: packet seq = %d, mc.sequence = %d", data[3], mc.sequence)) - } - // check packet sync [8 bit] - if data[3] != mc.sequence { - if debugTrace { - debug.PrintStack() + if !mc.compress { // MySQL and MariaDB doesn't check packet nr in compressed packet. + // check packet sync [8 bit] + if data[3] != mc.sequence { + mc.cfg.Logger.Print(fmt.Sprintf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, data[3])) } - mc.Close() - if data[3] > mc.sequence { - return nil, ErrPktSyncMul - } - return nil, ErrPktSync + mc.sequence++ } - mc.sequence++ // packets with length 0 terminate a previous packet which is a // multiple of (2^24)-1 bytes long From e1dc55768709cab652f089c01d946679171c4e5c Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sat, 16 Mar 2024 08:31:04 +0000 Subject: [PATCH 72/88] wip --- auth.go | 2 +- compress.go | 22 +++++++++++----------- connection.go | 6 +++--- errors.go | 29 ++++++++++++++++++++++++++--- packets.go | 13 +++++++++++-- 5 files changed, 52 insertions(+), 20 deletions(-) diff --git a/auth.go b/auth.go index 658259b24..74e1bd03e 100644 --- a/auth.go +++ b/auth.go @@ -338,7 +338,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { return authEd25519(authData, mc.cfg.Passwd) default: - mc.cfg.Logger.Print("unknown auth plugin:", plugin) + mc.log("unknown auth plugin:", plugin) return nil, ErrUnknownPlugin } } diff --git a/compress.go b/compress.go index 2d2bea0c3..8b792a0e7 100644 --- a/compress.go +++ b/compress.go @@ -123,17 +123,18 @@ func (c *decompressor) uncompressPacket() error { uncompressedLength := int(uint32(header[4]) | uint32(header[5])<<8 | uint32(header[6])<<16) compressionSequence := uint8(header[3]) if debugTrace { - c.mc.cfg.Logger.Print( - fmt.Sprintf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n", - comprLength, uncompressedLength, compressionSequence, c.mc.sequence)) + traceLogger.Printf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n", + comprLength, uncompressedLength, compressionSequence, c.mc.sequence) } if compressionSequence != c.mc.sequence { // return ErrPktSync // server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes) // before receiving all packets from client. In this case, seqnr is younger than expected. - c.mc.cfg.Logger.Print( - fmt.Sprintf("[warn] unexpected cmpress seq nr: expected %v, got %v", - c.mc.sequence, compressionSequence)) + if debugTrace { + traceLogger.Printf("WARN: unexpected cmpress seq nr: expected %v, got %v", + c.mc.sequence, compressionSequence) + } + c.mc.invalid = true } c.mc.sequence = compressionSequence + 1 c.mc.compressSequence = c.mc.sequence @@ -218,10 +219,9 @@ func (mc *mysqlConn) writeCompressed(packets []byte) (int, error) { func (mc *mysqlConn) writeCompressedPacket(data []byte, uncompressedLen int) error { comprLength := len(data) - 7 if debugTrace { - mc.cfg.Logger.Print( - fmt.Sprintf( - "writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v", - comprLength, uncompressedLen, mc.compressSequence)) + traceLogger.Printf( + "writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v", + comprLength, uncompressedLen, mc.compressSequence) } // compression header @@ -237,7 +237,7 @@ func (mc *mysqlConn) writeCompressedPacket(data []byte, uncompressedLen int) err data[6] = byte(0xff & (uncompressedLen >> 16)) if _, err := mc.netConn.Write(data); err != nil { - mc.cfg.Logger.Print(err) + mc.log("writing compressed packet:", err) return err } diff --git a/connection.go b/connection.go index 61216f95e..e62b2a46e 100644 --- a/connection.go +++ b/connection.go @@ -37,6 +37,7 @@ type mysqlConn struct { compressSequence uint8 parseTime bool compress bool + invalid bool // true when the connection is in invalid state and will be closed later. // for context support (Go 1.8+) watching bool @@ -132,7 +133,6 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) { func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { if mc.closed.Load() { - mc.cfg.Logger.Print(ErrInvalidConn) return nil, driver.ErrBadConn } var q string @@ -173,7 +173,7 @@ func (mc *mysqlConn) cleanup() { return } if err := mc.netConn.Close(); err != nil { - mc.cfg.Logger.Print(err) + mc.log("closing connection:", err) } mc.clearResult() } @@ -698,5 +698,5 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error { // IsValid implements driver.Validator interface // (From Go 1.15) func (mc *mysqlConn) IsValid() bool { - return !mc.closed.Load() + return !mc.closed.Load() && !mc.invalid } diff --git a/errors.go b/errors.go index a9a3060c9..cdbb0dab7 100644 --- a/errors.go +++ b/errors.go @@ -37,18 +37,41 @@ var ( errBadConnNoWrite = errors.New("bad connection") ) -var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) +// traceLogger is used for debug trace log. +var traceLogger *log.Logger + +func init() { + if debugTrace { + traceLogger = log.New(os.Stderr, "[mysql.trace]", log.Lmicroseconds|log.Lshortfile) + } +} + +func trace(format string, v ...any) { + if debugTrace { + traceLogger.Printf(format, v...) + } +} // Logger is used to log critical error messages. type Logger interface { - Print(v ...interface{}) + Print(v ...any) +} + +var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) + +func (mc *mysqlConn) log(v ...any) { + mc.cfg.Logger.Print(v...) +} + +func (mc *mysqlConn) logf(format string, v ...any) { + mc.cfg.Logger.Print(fmt.Sprintf(format, v...)) } // NopLogger is a nop implementation of the Logger interface. type NopLogger struct{} // Print implements Logger interface. -func (nl *NopLogger) Print(_ ...interface{}) {} +func (nl *NopLogger) Print(_ ...any) {} // SetLogger is used to set the default logger for critical errors. // The initial logger is os.Stderr. diff --git a/packets.go b/packets.go index a2d5a7746..8099f97db 100644 --- a/packets.go +++ b/packets.go @@ -28,6 +28,7 @@ import ( // Read packet to buffer 'data' func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte + var rerr error = nil for { // read packet header data, err := mc.packetReader.readNext(4) @@ -46,10 +47,18 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // packet length [24 bit] pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) - if !mc.compress { // MySQL and MariaDB doesn't check packet nr in compressed packet. + if mc.compress { + // MySQL and MariaDB doesn't check packet nr in compressed packet. + if debugTrace && data[3] != mc.compressSequence { + mc.cfg.Logger.Print + } + mc.compressSequence = data[3]+1 + } else mc.compress { // check packet sync [8 bit] if data[3] != mc.sequence { mc.cfg.Logger.Print(fmt.Sprintf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, data[3])) + mc.invalid = true + rerr = ErrInvalidConn } mc.sequence++ } @@ -117,7 +126,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Write packet if debugTrace { - mc.cfg.Logger.Print(fmt.Sprintf("writePacket: size=%v seq=%v", size, mc.sequence)) + traceLogger.Printf("writePacket: size=%v seq=%v", size, mc.sequence) } if mc.writeTimeout > 0 { if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { From e9f5b2462a19e1c7a778f5cfcffb1bc5ee8c2558 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Mon, 25 Mar 2024 08:50:51 +0000 Subject: [PATCH 73/88] fix some errors --- compress.go | 2 +- connection.go | 3 +-- errors.go | 6 ------ packets.go | 43 +++++++++++++++++++++++-------------------- 4 files changed, 25 insertions(+), 29 deletions(-) diff --git a/compress.go b/compress.go index 8b792a0e7..54289a615 100644 --- a/compress.go +++ b/compress.go @@ -134,7 +134,7 @@ func (c *decompressor) uncompressPacket() error { traceLogger.Printf("WARN: unexpected cmpress seq nr: expected %v, got %v", c.mc.sequence, compressionSequence) } - c.mc.invalid = true + // TODO(methane): report error when the packet is not an error packet. } c.mc.sequence = compressionSequence + 1 c.mc.compressSequence = c.mc.sequence diff --git a/connection.go b/connection.go index 2c50fdca6..758dba0c5 100644 --- a/connection.go +++ b/connection.go @@ -38,7 +38,6 @@ type mysqlConn struct { compressSequence uint8 parseTime bool compress bool - invalid bool // true when the connection is in invalid state and will be closed later. // for context support (Go 1.8+) watching bool @@ -709,5 +708,5 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error { // IsValid implements driver.Validator interface // (From Go 1.15) func (mc *mysqlConn) IsValid() bool { - return !mc.closed.Load() && !mc.invalid + return !mc.closed.Load() } diff --git a/errors.go b/errors.go index 6005bc0fb..f6e4ff4ce 100644 --- a/errors.go +++ b/errors.go @@ -48,12 +48,6 @@ func init() { } } -func trace(format string, v ...any) { - if debugTrace { - traceLogger.Printf(format, v...) - } -} - // Logger is used to log critical error messages. type Logger interface { Print(v ...any) diff --git a/packets.go b/packets.go index cc1f6d00b..c55da8507 100644 --- a/packets.go +++ b/packets.go @@ -17,7 +17,6 @@ import ( "fmt" "io" "math" - "runtime/debug" "strconv" "time" ) @@ -28,7 +27,8 @@ import ( // Read packet to buffer 'data' func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte - var rerr error = nil + invalid := false + for { // read packet header data, err := mc.packetReader.readNext(4) @@ -36,9 +36,6 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr } - if debugTrace { - debug.PrintStack() - } mc.log(err) mc.Close() return nil, ErrInvalidConn @@ -46,19 +43,25 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // packet length [24 bit] pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) + seqNr := data[3] if mc.compress { // MySQL and MariaDB doesn't check packet nr in compressed packet. - if debugTrace && data[3] != mc.compressSequence { - mc.cfg.Logger.Print + if debugTrace && seqNr != mc.compressSequence { + mc.logf("[debug] mismatched compression sequence nr: expected: %v, got %v", + mc.compressSequence, seqNr) } - mc.compressSequence = data[3]+1 - } else mc.compress { + mc.compressSequence = seqNr + 1 + } else { // check packet sync [8 bit] - if data[3] != mc.sequence { - mc.cfg.Logger.Print(fmt.Sprintf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, data[3])) - mc.invalid = true - rerr = ErrInvalidConn + if seqNr != mc.sequence { + mc.logf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, seqNr) + // For large packets, we stop reading as soon as sync error. + if len(prevData) > 0 { + return nil, ErrPktSyncMul + } + // TODO(methane): report error when the packet is not an error packet. + invalid = true } mc.sequence++ } @@ -72,7 +75,6 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { mc.Close() return nil, ErrInvalidConn } - return prevData, nil } @@ -91,6 +93,10 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if pktLen < maxPacketSize { // zero allocations for non-split packets if prevData == nil { + if invalid && data[0] != iERR { + // return sync error only for regular packet. + return nil, ErrPktSync + } return data, nil } @@ -432,12 +438,9 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence mc.resetSequenceNr() - data, err := mc.buf.takeSmallBuffer(4 + 1) - if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite - } + // We do not use mc.buf because this function is used by mc.Close() + // and mc.Close() could be used when some error happend during read. + data := make([]byte, 5) // Add command byte data[4] = command From 1fee4a0fcfd5165dfaa2603ab2c766789d4b4e21 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Mon, 2 Dec 2024 17:12:27 +0900 Subject: [PATCH 74/88] remove traceLogger --- compress.go | 6 +++--- driver_test.go | 4 ---- errors.go | 13 ------------- packets.go | 8 ++++---- 4 files changed, 7 insertions(+), 24 deletions(-) diff --git a/compress.go b/compress.go index 54289a615..70c9dba32 100644 --- a/compress.go +++ b/compress.go @@ -123,7 +123,7 @@ func (c *decompressor) uncompressPacket() error { uncompressedLength := int(uint32(header[4]) | uint32(header[5])<<8 | uint32(header[6])<<16) compressionSequence := uint8(header[3]) if debugTrace { - traceLogger.Printf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n", + fmt.Printf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n", comprLength, uncompressedLength, compressionSequence, c.mc.sequence) } if compressionSequence != c.mc.sequence { @@ -131,7 +131,7 @@ func (c *decompressor) uncompressPacket() error { // server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes) // before receiving all packets from client. In this case, seqnr is younger than expected. if debugTrace { - traceLogger.Printf("WARN: unexpected cmpress seq nr: expected %v, got %v", + fmt.Printf("WARN: unexpected cmpress seq nr: expected %v, got %v", c.mc.sequence, compressionSequence) } // TODO(methane): report error when the packet is not an error packet. @@ -219,7 +219,7 @@ func (mc *mysqlConn) writeCompressed(packets []byte) (int, error) { func (mc *mysqlConn) writeCompressedPacket(data []byte, uncompressedLen int) error { comprLength := len(data) - 7 if debugTrace { - traceLogger.Printf( + fmt.Printf( "writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v", comprLength, uncompressedLen, mc.compressSequence) } diff --git a/driver_test.go b/driver_test.go index 264baca91..71e7cdf64 100644 --- a/driver_test.go +++ b/driver_test.go @@ -3557,10 +3557,6 @@ func TestErrorInMultiResult(t *testing.T) { t.Skipf("MySQL server not running on %s", netAddr) } // https://github.com/go-sql-driver/mysql/issues/1361 - if !available { - t.Skipf("MySQL server not running on %s", netAddr) - } - var db *sql.DB if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { db, err = sql.Open("mysql", dsn) diff --git a/errors.go b/errors.go index 136dc4158..584617b11 100644 --- a/errors.go +++ b/errors.go @@ -39,24 +39,11 @@ var ( var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime)) -// traceLogger is used for debug trace log. -var traceLogger *log.Logger - -func init() { - if debugTrace { - traceLogger = log.New(os.Stderr, "[mysql.trace]", log.Lmicroseconds|log.Lshortfile) - } -} - // Logger is used to log critical error messages. type Logger interface { Print(v ...any) } -func (mc *mysqlConn) logf(format string, v ...any) { - mc.cfg.Logger.Print(fmt.Sprintf(format, v...)) -} - // NopLogger is a nop implementation of the Logger interface. type NopLogger struct{} diff --git a/packets.go b/packets.go index 4c7de1249..1ebd4d480 100644 --- a/packets.go +++ b/packets.go @@ -49,14 +49,14 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if mc.compress { // MySQL and MariaDB doesn't check packet nr in compressed packet. if debugTrace && seqNr != mc.compressSequence { - mc.logf("[debug] mismatched compression sequence nr: expected: %v, got %v", + fmt.Printf("[debug] mismatched compression sequence nr: expected: %v, got %v", mc.compressSequence, seqNr) } mc.compressSequence = seqNr + 1 } else { // check packet sync [8 bit] if seqNr != mc.sequence { - mc.logf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, seqNr) + mc.log(fmt.Sprintf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, seqNr)) // For large packets, we stop reading as soon as sync error. if len(prevData) > 0 { return nil, ErrPktSyncMul @@ -133,7 +133,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Write packet if debugTrace { - traceLogger.Printf("writePacket: size=%v seq=%v", size, mc.sequence) + fmt.Printf("writePacket: size=%v seq=%v", size, mc.sequence) } if mc.writeTimeout > 0 { if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { @@ -228,7 +228,6 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro return nil, "", ErrNoTLS } } - pos += 2 if len(data) > pos { @@ -303,6 +302,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string if mc.cfg.TLS != nil { clientFlags |= clientSSL } + if mc.cfg.MultiStatements { clientFlags |= clientMultiStatements } From 3062a2fd1a1e090efff7f8ae29c74945b17ac7af Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Tue, 3 Dec 2024 09:46:26 +0900 Subject: [PATCH 75/88] fix README --- README.md | 9 --------- 1 file changed, 9 deletions(-) diff --git a/README.md b/README.md index a6b9d7ef0..da4593ccf 100644 --- a/README.md +++ b/README.md @@ -321,15 +321,6 @@ Default: 64*1024*1024 Max packet size allowed in bytes. The default value is 64 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*. -##### `minCompressLength` - -``` -Type: decimal number -Default: 50 -``` - -Min packet size in bytes to compress, when compression is enabled (see the `compress` parameter). Packets smaller than this will be sent uncompressed. - ##### `multiStatements` ``` From 700db2278df19708215468364af401c2c700aff0 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Wed, 4 Dec 2024 21:19:20 +0900 Subject: [PATCH 76/88] compress: better buffer reuse --- benchmark_test.go | 2 +- buffer.go | 6 +++++ compress.go | 55 +++++++++++++++++++--------------------------- compress_test.go | 5 +++-- connection.go | 5 +++-- connection_test.go | 10 ++++----- connector.go | 4 ++-- packets.go | 14 +++--------- packets_test.go | 10 ++++----- 9 files changed, 51 insertions(+), 60 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index cb2a2beac..9aef8c8b8 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -236,7 +236,7 @@ func BenchmarkInterpolation(b *testing.B) { maxWriteSize: maxPacketSize - 1, buf: newBuffer(nil), } - mc.packetReader = &mc.buf + mc.packetRW = &mc.buf args := []driver.Value{ int64(42424242), diff --git a/buffer.go b/buffer.go index dd82c9313..f71262816 100644 --- a/buffer.go +++ b/buffer.go @@ -155,3 +155,9 @@ func (b *buffer) store(buf []byte) { b.cachedBuf = buf[:cap(buf)] } } + +// writePackets is a proxy function to nc.Write. +// This is used to make the buffer type compatible with compressed I/O. +func (b *buffer) writePackets(packets []byte) (int, error) { + return b.nc.Write(packets) +} diff --git a/compress.go b/compress.go index 70c9dba32..b86e1a064 100644 --- a/compress.go +++ b/compress.go @@ -85,34 +85,28 @@ func zCompress(src []byte, dst io.Writer) error { return nil } -type decompressor struct { - mc *mysqlConn - // read buffer (FIFO). - // We can not reuse already-read buffer until dropping Go 1.20 support. - // It is because of database/mysql's weired behavior. - // See https://github.com/go-sql-driver/mysql/issues/1435 - bytesBuf []byte +type compIO struct { + mc *mysqlConn + buff bytes.Buffer } -func newDecompressor(mc *mysqlConn) *decompressor { - return &decompressor{ +func newCompIO(mc *mysqlConn) *compIO { + return &compIO{ mc: mc, } } -func (c *decompressor) readNext(need int) ([]byte, error) { - for len(c.bytesBuf) < need { +func (c *compIO) readNext(need int) ([]byte, error) { + for c.buff.Len() < need { if err := c.uncompressPacket(); err != nil { return nil, err } } - - data := c.bytesBuf[:need:need] // prevent caller writes into r.bytesBuf - c.bytesBuf = c.bytesBuf[need:] - return data, nil + data := c.buff.Next(need) + return data[:need:need], nil // prevent caller writes into c.buff } -func (c *decompressor) uncompressPacket() error { +func (c *compIO) uncompressPacket() error { header, err := c.mc.buf.readNext(7) // size of compressed header if err != nil { return err @@ -147,19 +141,14 @@ func (c *decompressor) uncompressPacket() error { // if payload is uncompressed, its length will be specified as zero, and its // true length is contained in comprLength if uncompressedLength == 0 { - c.bytesBuf = append(c.bytesBuf, comprData...) + c.buff.Write(comprData) return nil } // use existing capacity in bytesBuf if possible - offset := len(c.bytesBuf) - if cap(c.bytesBuf)-offset < uncompressedLength { - old := c.bytesBuf - c.bytesBuf = make([]byte, offset, offset+uncompressedLength) - copy(c.bytesBuf, old) - } - - lenRead, err := zDecompress(comprData, c.bytesBuf[offset:offset+uncompressedLength]) + c.buff.Grow(uncompressedLength) + dec := c.buff.AvailableBuffer()[:uncompressedLength] + lenRead, err := zDecompress(comprData, dec) if err != nil { return err } @@ -167,21 +156,22 @@ func (c *decompressor) uncompressPacket() error { return fmt.Errorf("invalid compressed packet: uncompressed length in header is %d, actual %d", uncompressedLength, lenRead) } - c.bytesBuf = c.bytesBuf[:offset+uncompressedLength] + c.buff.Write(dec) // fast copy. See bytes.Buffer.AvailableBuffer() doc. return nil } const maxPayloadLen = maxPacketSize - 4 -// writeCompressed sends one or some packets with compression. +// writePackets sends one or some packets with compression. // Use this instead of mc.netConn.Write() when mc.compress is true. -func (mc *mysqlConn) writeCompressed(packets []byte) (int, error) { +func (c *compIO) writePackets(packets []byte) (int, error) { totalBytes := len(packets) dataLen := len(packets) blankHeader := make([]byte, 7) - var buf bytes.Buffer + buf := &c.buff for dataLen > 0 { + buf.Reset() payloadLen := dataLen if payloadLen > maxPayloadLen { payloadLen = maxPayloadLen @@ -200,10 +190,10 @@ func (mc *mysqlConn) writeCompressed(packets []byte) (int, error) { } uncompressedLen = 0 } else { - zCompress(payload, &buf) + zCompress(payload, buf) } - if err := mc.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil { + if err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil { return 0, err } dataLen -= payloadLen @@ -216,7 +206,8 @@ func (mc *mysqlConn) writeCompressed(packets []byte) (int, error) { // writeCompressedPacket writes a compressed packet with header. // data should start with 7 size space for header followed by payload. -func (mc *mysqlConn) writeCompressedPacket(data []byte, uncompressedLen int) error { +func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) error { + mc := c.mc comprLength := len(data) - 7 if debugTrace { fmt.Printf( diff --git a/compress_test.go b/compress_test.go index 6d81db335..56e93dbfb 100644 --- a/compress_test.go +++ b/compress_test.go @@ -26,8 +26,9 @@ func makeRandByteSlice(size int) []byte { func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte { conn := new(mockConn) mc.netConn = conn + comp := newCompIO(mc) - n, err := mc.writeCompressed(uncompressedPacket) + n, err := comp.writePackets(uncompressedPacket) if err != nil { t.Fatal(err) } @@ -43,7 +44,7 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expS conn := new(mockConn) conn.data = compressedPacket mc.buf.nc = conn - cr := newDecompressor(mc) + cr := newCompIO(mc) uncompressedPacket, err := cr.readNext(expSize) if err != nil { diff --git a/connection.go b/connection.go index af622cd6c..6c026ab42 100644 --- a/connection.go +++ b/connection.go @@ -28,7 +28,7 @@ type mysqlConn struct { netConn net.Conn rawConn net.Conn // underlying connection when netConn is TLS connection. result mysqlResult // managed by clearResult() and handleOkPacket(). - packetReader packetReader + packetRW packetIO cfg *Config connector *connector maxAllowedPacket int @@ -65,8 +65,9 @@ func (mc *mysqlConn) log(v ...any) { mc.cfg.Logger.Print(v...) } -type packetReader interface { +type packetIO interface { readNext(need int) ([]byte, error) + writePackets(data []byte) (int, error) } func (mc *mysqlConn) resetSequenceNr() { diff --git a/connection_test.go b/connection_test.go index b55cff52b..0342ead77 100644 --- a/connection_test.go +++ b/connection_test.go @@ -25,7 +25,7 @@ func TestInterpolateParams(t *testing.T) { InterpolateParams: true, }, } - mc.packetReader = &mc.buf + mc.packetRW = &mc.buf q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"}) if err != nil { @@ -73,7 +73,7 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { InterpolateParams: true, }, } - mc.packetReader = &mc.buf + mc.packetRW = &mc.buf q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)}) if err != driver.ErrSkip { @@ -92,7 +92,7 @@ func TestInterpolateParamsPlaceholderInString(t *testing.T) { }, } - mc.packetReader = &mc.buf + mc.packetRW = &mc.buf q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)}) // When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42` @@ -168,7 +168,7 @@ func TestPingMarkBadConnection(t *testing.T) { mc := &mysqlConn{ netConn: nc, buf: buf, - packetReader: &buf, + packetRW: &buf, maxAllowedPacket: defaultMaxAllowedPacket, closech: make(chan struct{}), cfg: NewConfig(), @@ -188,7 +188,7 @@ func TestPingErrInvalidConn(t *testing.T) { mc := &mysqlConn{ netConn: nc, buf: buf, - packetReader: &buf, + packetRW: &buf, maxAllowedPacket: defaultMaxAllowedPacket, closech: make(chan struct{}), cfg: NewConfig(), diff --git a/connector.go b/connector.go index cb834efde..fc8267508 100644 --- a/connector.go +++ b/connector.go @@ -129,7 +129,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.buf = newBuffer(mc.netConn) // packet reader and writer in handshake are never compressed - mc.packetReader = &mc.buf + mc.packetRW = &mc.buf // Set I/O timeouts mc.buf.timeout = mc.cfg.ReadTimeout @@ -174,7 +174,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { if mc.cfg.compress && mc.flags&clientCompress == clientCompress { mc.compress = true - mc.packetReader = newDecompressor(mc) + mc.packetRW = newCompIO(mc) } if mc.cfg.MaxAllowedPacket > 0 { mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket diff --git a/packets.go b/packets.go index 1ebd4d480..42fda92a4 100644 --- a/packets.go +++ b/packets.go @@ -32,7 +32,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { for { // read packet header - data, err := mc.packetReader.readNext(4) + data, err := mc.packetRW.readNext(4) if err != nil { mc.close() if cerr := mc.canceled.Value(); cerr != nil { @@ -80,7 +80,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // read packet body [pktLen bytes] - data, err = mc.packetReader.readNext(pktLen) + data, err = mc.packetRW.readNext(pktLen) if err != nil { mc.close() if cerr := mc.canceled.Value(); cerr != nil { @@ -143,15 +143,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { } } - var ( - n int - err error - ) - if mc.compress { - n, err = mc.writeCompressed(data[:4+size]) - } else { - n, err = mc.netConn.Write(data[:4+size]) - } + n, err := mc.packetRW.writePackets(data[:4+size]) if err != nil { mc.cleanup() if cerr := mc.canceled.Value(); cerr != nil { diff --git a/packets_test.go b/packets_test.go index 1d4de7af0..b68a6ce4b 100644 --- a/packets_test.go +++ b/packets_test.go @@ -100,7 +100,7 @@ func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { buf := newBuffer(conn) mc := &mysqlConn{ buf: buf, - packetReader: &buf, + packetRW: &buf, cfg: connector.cfg, connector: connector, netConn: conn, @@ -116,7 +116,7 @@ func TestReadPacketSingleByte(t *testing.T) { mc := &mysqlConn{ buf: newBuffer(conn), } - mc.packetReader = &mc.buf + mc.packetRW = &mc.buf conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} conn.maxReads = 1 @@ -169,7 +169,7 @@ func TestReadPacketSplit(t *testing.T) { mc := &mysqlConn{ buf: newBuffer(conn), } - mc.packetReader = &mc.buf + mc.packetRW = &mc.buf data := make([]byte, maxPacketSize*2+4*3) const pkt2ofs = maxPacketSize + 4 @@ -277,7 +277,7 @@ func TestReadPacketFail(t *testing.T) { closech: make(chan struct{}), cfg: NewConfig(), } - mc.packetReader = &mc.buf + mc.packetRW = &mc.buf // illegal empty (stand-alone) packet conn.data = []byte{0x00, 0x00, 0x00, 0x00} @@ -323,7 +323,7 @@ func TestRegression801(t *testing.T) { sequence: 42, closech: make(chan struct{}), } - mc.packetReader = &mc.buf + mc.packetRW = &mc.buf conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0, 60, 70, 63, 58, 68, 104, 34, 97, 0, 223, 247, 33, 2, 0, 15, 128, 21, 0, From 3f89621e8c961423da43d1f811dd9248413284a0 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sun, 8 Dec 2024 15:23:40 +0900 Subject: [PATCH 77/88] some refactoring --- compress.go | 49 +++++++++++++++++++++---------------------------- packets.go | 16 +++------------- utils.go | 12 ++++++++++++ 3 files changed, 36 insertions(+), 41 deletions(-) diff --git a/compress.go b/compress.go index b86e1a064..84c476783 100644 --- a/compress.go +++ b/compress.go @@ -98,7 +98,7 @@ func newCompIO(mc *mysqlConn) *compIO { func (c *compIO) readNext(need int) ([]byte, error) { for c.buff.Len() < need { - if err := c.uncompressPacket(); err != nil { + if err := c.readCompressedPacket(); err != nil { return nil, err } } @@ -106,16 +106,17 @@ func (c *compIO) readNext(need int) ([]byte, error) { return data[:need:need], nil // prevent caller writes into c.buff } -func (c *compIO) uncompressPacket() error { +func (c *compIO) readCompressedPacket() error { header, err := c.mc.buf.readNext(7) // size of compressed header if err != nil { return err } + _ = header[6] // bounds check hint to compiler; guaranteed by readNext // compressed header structure - comprLength := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) - uncompressedLength := int(uint32(header[4]) | uint32(header[5])<<8 | uint32(header[6])<<16) + comprLength := getUint24(header[0:3]) compressionSequence := uint8(header[3]) + uncompressedLength := getUint24(header[4:7]) if debugTrace { fmt.Printf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n", comprLength, uncompressedLength, compressionSequence, c.mc.sequence) @@ -171,34 +172,34 @@ func (c *compIO) writePackets(packets []byte) (int, error) { buf := &c.buff for dataLen > 0 { - buf.Reset() - payloadLen := dataLen - if payloadLen > maxPayloadLen { - payloadLen = maxPayloadLen - } + payloadLen := min(maxPayloadLen, dataLen) payload := packets[:payloadLen] uncompressedLen := payloadLen - if _, err := buf.Write(blankHeader); err != nil { - return 0, err - } + buf.Reset() + buf.Write(blankHeader) // Buffer.Write() never returns error // If payload is less than minCompressLength, don't compress. if uncompressedLen < minCompressLength { - if _, err := buf.Write(payload); err != nil { - return 0, err - } + buf.Write(payload) uncompressedLen = 0 } else { zCompress(payload, buf) + // do not compress if compressed data is larger than uncompressed data + // I intentionally miss 7 byte header in the buf; compress should compress more than 7 bytes. + if buf.Len() > uncompressedLen { + buf.Reset() + buf.Write(blankHeader) + buf.Write(payload) + uncompressedLen = 0 + } } - if err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil { + if err := c.mc.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil { return 0, err } dataLen -= payloadLen packets = packets[payloadLen:] - buf.Reset() } return totalBytes, nil @@ -206,8 +207,7 @@ func (c *compIO) writePackets(packets []byte) (int, error) { // writeCompressedPacket writes a compressed packet with header. // data should start with 7 size space for header followed by payload. -func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) error { - mc := c.mc +func (mc *mysqlConn) writeCompressedPacket(data []byte, uncompressedLen int) error { comprLength := len(data) - 7 if debugTrace { fmt.Printf( @@ -216,16 +216,9 @@ func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) error { } // compression header - data[0] = byte(0xff & comprLength) - data[1] = byte(0xff & (comprLength >> 8)) - data[2] = byte(0xff & (comprLength >> 16)) - + putUint24(data[0:3], comprLength) data[3] = mc.compressSequence - - // this value is never greater than maxPayloadLength - data[4] = byte(0xff & uncompressedLen) - data[5] = byte(0xff & (uncompressedLen >> 8)) - data[6] = byte(0xff & (uncompressedLen >> 16)) + putUint24(data[4:7], uncompressedLen) if _, err := mc.netConn.Write(data); err != nil { mc.log("writing compressed packet:", err) diff --git a/packets.go b/packets.go index 42fda92a4..5e5832f1c 100644 --- a/packets.go +++ b/packets.go @@ -43,7 +43,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // packet length [24 bit] - pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) + pktLen := getUint24(data[:3]) seqNr := data[3] if mc.compress { @@ -117,18 +117,8 @@ func (mc *mysqlConn) writePacket(data []byte) error { } for { - var size int - if pktLen >= maxPacketSize { - data[0] = 0xff - data[1] = 0xff - data[2] = 0xff - size = maxPacketSize - } else { - data[0] = byte(pktLen) - data[1] = byte(pktLen >> 8) - data[2] = byte(pktLen >> 16) - size = pktLen - } + size := min(maxPacketSize, pktLen) + putUint24(data[:3], size) data[3] = mc.sequence // Write packet diff --git a/utils.go b/utils.go index cda24fe74..d902f3b60 100644 --- a/utils.go +++ b/utils.go @@ -490,6 +490,18 @@ func formatBinaryTime(src []byte, length uint8) (driver.Value, error) { * Convert from and to bytes * ******************************************************************************/ +// 24bit integer: used for packet headers. + +func putUint24(data []byte, n int) { + data[2] = byte(n >> 16) + data[1] = byte(n >> 8) + data[0] = byte(n) +} + +func getUint24(data []byte) int { + return int(data[2])<<16 | int(data[1])<<8 | int(data[0]) +} + func uint64ToBytes(n uint64) []byte { return []byte{ byte(n), From e8b96f2b5d29bb30a70928871cd4ff24457f4452 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Wed, 11 Dec 2024 16:47:10 +0900 Subject: [PATCH 78/88] move writeTimeout from mysqlConn to buffer --- buffer.go | 18 ++++++++++++------ compress.go | 7 ++++--- connection.go | 1 - connector.go | 4 ++-- packets.go | 7 ------- 5 files changed, 18 insertions(+), 19 deletions(-) diff --git a/buffer.go b/buffer.go index f71262816..935b63891 100644 --- a/buffer.go +++ b/buffer.go @@ -23,10 +23,11 @@ const maxCachedBufSize = 256 * 1024 // The buffer is similar to bufio.Reader / Writer but zero-copy-ish // Also highly optimized for this particular use case. type buffer struct { - buf []byte // read buffer. - cachedBuf []byte // buffer that will be reused. len(cachedBuf) <= maxCachedBufSize. - nc net.Conn - timeout time.Duration + buf []byte // read buffer. + cachedBuf []byte // buffer that will be reused. len(cachedBuf) <= maxCachedBufSize. + nc net.Conn + readTimeout time.Duration + writeTimeout time.Duration } // newBuffer allocates and returns a new buffer. @@ -64,8 +65,8 @@ func (b *buffer) fill(need int) error { copy(dest[:n], b.buf) for { - if b.timeout > 0 { - if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil { + if b.readTimeout > 0 { + if err := b.nc.SetReadDeadline(time.Now().Add(b.readTimeout)); err != nil { return err } } @@ -159,5 +160,10 @@ func (b *buffer) store(buf []byte) { // writePackets is a proxy function to nc.Write. // This is used to make the buffer type compatible with compressed I/O. func (b *buffer) writePackets(packets []byte) (int, error) { + if b.writeTimeout > 0 { + if err := b.nc.SetWriteDeadline(time.Now().Add(b.writeTimeout)); err != nil { + return 0, err + } + } return b.nc.Write(packets) } diff --git a/compress.go b/compress.go index 84c476783..937d0cf09 100644 --- a/compress.go +++ b/compress.go @@ -195,7 +195,7 @@ func (c *compIO) writePackets(packets []byte) (int, error) { } } - if err := c.mc.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil { + if err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil { return 0, err } dataLen -= payloadLen @@ -207,7 +207,8 @@ func (c *compIO) writePackets(packets []byte) (int, error) { // writeCompressedPacket writes a compressed packet with header. // data should start with 7 size space for header followed by payload. -func (mc *mysqlConn) writeCompressedPacket(data []byte, uncompressedLen int) error { +func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) error { + mc := c.mc comprLength := len(data) - 7 if debugTrace { fmt.Printf( @@ -220,7 +221,7 @@ func (mc *mysqlConn) writeCompressedPacket(data []byte, uncompressedLen int) err data[3] = mc.compressSequence putUint24(data[4:7], uncompressedLen) - if _, err := mc.netConn.Write(data); err != nil { + if _, err := mc.buf.writePackets(data); err != nil { mc.log("writing compressed packet:", err) return err } diff --git a/connection.go b/connection.go index 6c026ab42..d95369ad5 100644 --- a/connection.go +++ b/connection.go @@ -33,7 +33,6 @@ type mysqlConn struct { connector *connector maxAllowedPacket int maxWriteSize int - writeTimeout time.Duration flags clientFlag status statusFlag sequence uint8 diff --git a/connector.go b/connector.go index fc8267508..3d37c22f4 100644 --- a/connector.go +++ b/connector.go @@ -132,8 +132,8 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.packetRW = &mc.buf // Set I/O timeouts - mc.buf.timeout = mc.cfg.ReadTimeout - mc.writeTimeout = mc.cfg.WriteTimeout + mc.buf.readTimeout = mc.cfg.ReadTimeout + mc.buf.writeTimeout = mc.cfg.WriteTimeout // Reading Handshake Initialization Packet authData, plugin, err := mc.readHandshakePacket() diff --git a/packets.go b/packets.go index 5e5832f1c..b5c8d9dd1 100644 --- a/packets.go +++ b/packets.go @@ -125,13 +125,6 @@ func (mc *mysqlConn) writePacket(data []byte) error { if debugTrace { fmt.Printf("writePacket: size=%v seq=%v", size, mc.sequence) } - if mc.writeTimeout > 0 { - if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { - mc.cleanup() - mc.log(err) - return err - } - } n, err := mc.packetRW.writePackets(data[:4+size]) if err != nil { From 9fad4c0dcfec9bfa54d19f83c302e50b8a0ecb72 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 12 Dec 2024 18:01:25 +0900 Subject: [PATCH 79/88] refactoring --- benchmark_test.go | 3 +-- buffer.go | 42 +++++++++++--------------------------- compress.go | 14 ++++++------- compress_test.go | 50 ++++++++++++++++++++-------------------------- connection.go | 23 +++++++++++++++++---- connection_test.go | 21 +++++++------------ connector.go | 10 ++-------- packets.go | 27 ++++++++++++++++++------- packets_test.go | 28 +++++++++++++------------- 9 files changed, 104 insertions(+), 114 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 9aef8c8b8..5c9a046b5 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -234,9 +234,8 @@ func BenchmarkInterpolation(b *testing.B) { }, maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, - buf: newBuffer(nil), + buf: newBuffer(), } - mc.packetRW = &mc.buf args := []driver.Value{ int64(42424242), diff --git a/buffer.go b/buffer.go index 935b63891..49b2523c6 100644 --- a/buffer.go +++ b/buffer.go @@ -10,31 +10,30 @@ package mysql import ( "io" - "net" - "time" ) const defaultBufSize = 4096 const maxCachedBufSize = 256 * 1024 +// readwriteFunc is a function that compatible with io.Reader and io.Writer. +// We use this function type instead of io.ReadWriter because we want to +// just pass mc.readWithTimeout or mc.writeWithTimeout functions. +type readwriteFunc func([]byte) (int, error) + // A buffer which is used for both reading and writing. // This is possible since communication on each connection is synchronous. // In other words, we can't write and read simultaneously on the same connection. // The buffer is similar to bufio.Reader / Writer but zero-copy-ish // Also highly optimized for this particular use case. type buffer struct { - buf []byte // read buffer. - cachedBuf []byte // buffer that will be reused. len(cachedBuf) <= maxCachedBufSize. - nc net.Conn - readTimeout time.Duration - writeTimeout time.Duration + buf []byte // read buffer. + cachedBuf []byte // buffer that will be reused. len(cachedBuf) <= maxCachedBufSize. } // newBuffer allocates and returns a new buffer. -func newBuffer(nc net.Conn) buffer { +func newBuffer() buffer { return buffer{ cachedBuf: make([]byte, defaultBufSize), - nc: nc, } } @@ -44,7 +43,7 @@ func (b *buffer) busy() bool { } // fill reads into the read buffer until at least _need_ bytes are in it. -func (b *buffer) fill(need int) error { +func (b *buffer) fill(need int, r readwriteFunc) error { // we'll move the contents of the current buffer to dest before filling it. dest := b.cachedBuf @@ -65,13 +64,7 @@ func (b *buffer) fill(need int) error { copy(dest[:n], b.buf) for { - if b.readTimeout > 0 { - if err := b.nc.SetReadDeadline(time.Now().Add(b.readTimeout)); err != nil { - return err - } - } - - nn, err := b.nc.Read(dest[n:]) + nn, err := r(dest[n:]) n += nn if err == nil && n < need { @@ -93,10 +86,10 @@ func (b *buffer) fill(need int) error { // returns next N bytes from buffer. // The returned slice is only guaranteed to be valid until the next read -func (b *buffer) readNext(need int) ([]byte, error) { +func (b *buffer) readNext(need int, r readwriteFunc) ([]byte, error) { if len(b.buf) < need { // refill - if err := b.fill(need); err != nil { + if err := b.fill(need, r); err != nil { return nil, err } } @@ -156,14 +149,3 @@ func (b *buffer) store(buf []byte) { b.cachedBuf = buf[:cap(buf)] } } - -// writePackets is a proxy function to nc.Write. -// This is used to make the buffer type compatible with compressed I/O. -func (b *buffer) writePackets(packets []byte) (int, error) { - if b.writeTimeout > 0 { - if err := b.nc.SetWriteDeadline(time.Now().Add(b.writeTimeout)); err != nil { - return 0, err - } - } - return b.nc.Write(packets) -} diff --git a/compress.go b/compress.go index 937d0cf09..9b424f99b 100644 --- a/compress.go +++ b/compress.go @@ -47,7 +47,7 @@ func zDecompress(src, dst []byte) (int, error) { } } else { zr = a.(io.ReadCloser) - if zr.(zlib.Resetter).Reset(br, nil); err != nil { + if err := zr.(zlib.Resetter).Reset(br, nil); err != nil { return 0, err } } @@ -96,9 +96,9 @@ func newCompIO(mc *mysqlConn) *compIO { } } -func (c *compIO) readNext(need int) ([]byte, error) { +func (c *compIO) readNext(need int, r readwriteFunc) ([]byte, error) { for c.buff.Len() < need { - if err := c.readCompressedPacket(); err != nil { + if err := c.readCompressedPacket(r); err != nil { return nil, err } } @@ -106,8 +106,8 @@ func (c *compIO) readNext(need int) ([]byte, error) { return data[:need:need], nil // prevent caller writes into c.buff } -func (c *compIO) readCompressedPacket() error { - header, err := c.mc.buf.readNext(7) // size of compressed header +func (c *compIO) readCompressedPacket(r readwriteFunc) error { + header, err := c.mc.buf.readNext(7, r) // size of compressed header if err != nil { return err } @@ -134,7 +134,7 @@ func (c *compIO) readCompressedPacket() error { c.mc.sequence = compressionSequence + 1 c.mc.compressSequence = c.mc.sequence - comprData, err := c.mc.buf.readNext(comprLength) + comprData, err := c.mc.buf.readNext(comprLength, r) if err != nil { return err } @@ -221,7 +221,7 @@ func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) error { data[3] = mc.compressSequence putUint24(data[4:7], uncompressedLen) - if _, err := mc.buf.writePackets(data); err != nil { + if _, err := mc.writeWithTimeout(data); err != nil { mc.log("writing compressed packet:", err) return err } diff --git a/compress_test.go b/compress_test.go index 56e93dbfb..0d217efc0 100644 --- a/compress_test.go +++ b/compress_test.go @@ -11,7 +11,6 @@ package mysql import ( "bytes" "crypto/rand" - "fmt" "io" "testing" ) @@ -26,15 +25,12 @@ func makeRandByteSlice(size int) []byte { func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte { conn := new(mockConn) mc.netConn = conn - comp := newCompIO(mc) - n, err := comp.writePackets(uncompressedPacket) + err := mc.writePacket(append(make([]byte, 4), uncompressedPacket...)) if err != nil { t.Fatal(err) } - if n != len(uncompressedPacket) { - t.Fatalf("expected to write %d bytes, wrote %d bytes", len(uncompressedPacket), n) - } + return conn.written } @@ -43,10 +39,9 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expS // mocking out buf variable conn := new(mockConn) conn.data = compressedPacket - mc.buf.nc = conn - cr := newCompIO(mc) + mc.netConn = conn - uncompressedPacket, err := cr.readNext(expSize) + uncompressedPacket, err := mc.readPacket() if err != nil { if err != io.EOF { t.Fatalf("non-nil/non-EOF error when reading contents: %s", err.Error()) @@ -72,8 +67,6 @@ func TestRoundtrip(t *testing.T) { }{ {uncompressed: []byte("a"), desc: "a"}, - {uncompressed: []byte{0}, - desc: "0 byte"}, {uncompressed: []byte("hello world"), desc: "hello world"}, {uncompressed: make([]byte, 100), @@ -82,8 +75,6 @@ func TestRoundtrip(t *testing.T) { desc: "32768 bytes"}, {uncompressed: make([]byte, 330000), desc: "33000 bytes"}, - {uncompressed: make([]byte, 0), - desc: "nothing"}, {uncompressed: makeRandByteSlice(10), desc: "10 rand bytes", }, @@ -100,26 +91,29 @@ func TestRoundtrip(t *testing.T) { _, cSend := newRWMockConn(0) cSend.compress = true + cSend.compIO = newCompIO(cSend) _, cReceive := newRWMockConn(0) cReceive.compress = true + cReceive.compIO = newCompIO(cReceive) for _, test := range tests { - s := fmt.Sprintf("Test roundtrip with %s", test.desc) - cSend.resetSequenceNr() - cReceive.resetSequenceNr() + t.Run(test.desc, func(t *testing.T) { + cSend.resetSequenceNr() + cReceive.resetSequenceNr() - uncompressed := roundtripHelper(t, cSend, cReceive, test.uncompressed) - if !bytes.Equal(uncompressed, test.uncompressed) { - t.Fatalf("%s: roundtrip failed", s) - } + uncompressed := roundtripHelper(t, cSend, cReceive, test.uncompressed) + if !bytes.Equal(uncompressed, test.uncompressed) { + t.Fatalf("roundtrip failed") + } - if cSend.sequence != cReceive.sequence { - t.Errorf("inconsistent sequence number: send=%v recv=%v", - cSend.sequence, cReceive.sequence) - } - if cSend.compressSequence != cReceive.compressSequence { - t.Errorf("inconsistent compress sequence number: send=%v recv=%v", - cSend.compressSequence, cReceive.compressSequence) - } + if cSend.sequence != cReceive.sequence { + t.Errorf("inconsistent sequence number: send=%v recv=%v", + cSend.sequence, cReceive.sequence) + } + if cSend.compressSequence != cReceive.compressSequence { + t.Errorf("inconsistent compress sequence number: send=%v recv=%v", + cSend.compressSequence, cReceive.compressSequence) + } + }) } } diff --git a/connection.go b/connection.go index d95369ad5..ccf86b834 100644 --- a/connection.go +++ b/connection.go @@ -28,7 +28,7 @@ type mysqlConn struct { netConn net.Conn rawConn net.Conn // underlying connection when netConn is TLS connection. result mysqlResult // managed by clearResult() and handleOkPacket(). - packetRW packetIO + compIO *compIO cfg *Config connector *connector maxAllowedPacket int @@ -64,9 +64,24 @@ func (mc *mysqlConn) log(v ...any) { mc.cfg.Logger.Print(v...) } -type packetIO interface { - readNext(need int) ([]byte, error) - writePackets(data []byte) (int, error) +func (mc *mysqlConn) readWithTimeout(b []byte) (int, error) { + to := mc.cfg.ReadTimeout + if to > 0 { + if err := mc.netConn.SetReadDeadline(time.Now().Add(to)); err != nil { + return 0, err + } + } + return mc.netConn.Read(b) +} + +func (mc *mysqlConn) writeWithTimeout(b []byte) (int, error) { + to := mc.cfg.WriteTimeout + if to > 0 { + if err := mc.netConn.SetWriteDeadline(time.Now().Add(to)); err != nil { + return 0, err + } + } + return mc.netConn.Write(b) } func (mc *mysqlConn) resetSequenceNr() { diff --git a/connection_test.go b/connection_test.go index 0342ead77..696db758d 100644 --- a/connection_test.go +++ b/connection_test.go @@ -19,13 +19,12 @@ import ( func TestInterpolateParams(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, }, } - mc.packetRW = &mc.buf q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"}) if err != nil { @@ -40,7 +39,7 @@ func TestInterpolateParams(t *testing.T) { func TestInterpolateParamsJSONRawMessage(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -67,13 +66,12 @@ func TestInterpolateParamsJSONRawMessage(t *testing.T) { func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, }, } - mc.packetRW = &mc.buf q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)}) if err != driver.ErrSkip { @@ -85,15 +83,13 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { // https://github.com/go-sql-driver/mysql/pull/490 func TestInterpolateParamsPlaceholderInString(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, }, } - mc.packetRW = &mc.buf - q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)}) // When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42` if err != driver.ErrSkip { @@ -103,7 +99,7 @@ func TestInterpolateParamsPlaceholderInString(t *testing.T) { func TestInterpolateParamsUint64(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -164,11 +160,10 @@ func TestCleanCancel(t *testing.T) { func TestPingMarkBadConnection(t *testing.T) { nc := badConnection{err: errors.New("boom")} - buf := newBuffer(nc) + buf := newBuffer() mc := &mysqlConn{ netConn: nc, buf: buf, - packetRW: &buf, maxAllowedPacket: defaultMaxAllowedPacket, closech: make(chan struct{}), cfg: NewConfig(), @@ -184,11 +179,9 @@ func TestPingMarkBadConnection(t *testing.T) { func TestPingErrInvalidConn(t *testing.T) { nc := badConnection{err: errors.New("failed to write"), n: 10} - buf := newBuffer(nc) mc := &mysqlConn{ netConn: nc, - buf: buf, - packetRW: &buf, + buf: newBuffer(), maxAllowedPacket: defaultMaxAllowedPacket, closech: make(chan struct{}), cfg: NewConfig(), diff --git a/connector.go b/connector.go index 3d37c22f4..a4f3655ef 100644 --- a/connector.go +++ b/connector.go @@ -127,13 +127,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } defer mc.finish() - mc.buf = newBuffer(mc.netConn) - // packet reader and writer in handshake are never compressed - mc.packetRW = &mc.buf - - // Set I/O timeouts - mc.buf.readTimeout = mc.cfg.ReadTimeout - mc.buf.writeTimeout = mc.cfg.WriteTimeout + mc.buf = newBuffer() // Reading Handshake Initialization Packet authData, plugin, err := mc.readHandshakePacket() @@ -174,7 +168,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { if mc.cfg.compress && mc.flags&clientCompress == clientCompress { mc.compress = true - mc.packetRW = newCompIO(mc) + mc.compIO = newCompIO(mc) } if mc.cfg.MaxAllowedPacket > 0 { mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket diff --git a/packets.go b/packets.go index b5c8d9dd1..5914c1cb8 100644 --- a/packets.go +++ b/packets.go @@ -30,9 +30,14 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte invalid := false + readFunc := mc.buf.readNext + if mc.compress { + readFunc = mc.compIO.readNext + } + for { // read packet header - data, err := mc.packetRW.readNext(4) + data, err := readFunc(4, mc.readWithTimeout) if err != nil { mc.close() if cerr := mc.canceled.Value(); cerr != nil { @@ -59,6 +64,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { mc.log(fmt.Sprintf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, seqNr)) // For large packets, we stop reading as soon as sync error. if len(prevData) > 0 { + mc.close() return nil, ErrPktSyncMul } // TODO(methane): report error when the packet is not an error packet. @@ -80,7 +86,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // read packet body [pktLen bytes] - data, err = mc.packetRW.readNext(pktLen) + data, err = readFunc(pktLen, mc.readWithTimeout) if err != nil { mc.close() if cerr := mc.canceled.Value(); cerr != nil { @@ -94,9 +100,13 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if pktLen < maxPacketSize { // zero allocations for non-split packets if prevData == nil { - if invalid && data[0] != iERR { + if invalid { + mc.close() // return sync error only for regular packet. - return nil, ErrPktSync + // error packets may have wrong sequence number. + if data[0] != iERR { + return nil, ErrPktSync + } } return data, nil } @@ -111,11 +121,15 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // Write packet buffer 'data' func (mc *mysqlConn) writePacket(data []byte) error { pktLen := len(data) - 4 - if pktLen > mc.maxAllowedPacket { return ErrPktTooLarge } + writeFunc := mc.writeWithTimeout + if mc.compress { + writeFunc = mc.compIO.writePackets + } + for { size := min(maxPacketSize, pktLen) putUint24(data[:3], size) @@ -126,7 +140,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { fmt.Printf("writePacket: size=%v seq=%v", size, mc.sequence) } - n, err := mc.packetRW.writePackets(data[:4+size]) + n, err := writeFunc(data[:4+size]) if err != nil { mc.cleanup() if cerr := mc.canceled.Value(); cerr != nil { @@ -363,7 +377,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string return err } mc.netConn = tlsConn - mc.buf.nc = tlsConn } // User [null terminated string] diff --git a/packets_test.go b/packets_test.go index b68a6ce4b..694b0564c 100644 --- a/packets_test.go +++ b/packets_test.go @@ -97,10 +97,8 @@ var _ net.Conn = new(mockConn) func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { conn := new(mockConn) connector := newConnector(NewConfig()) - buf := newBuffer(conn) mc := &mysqlConn{ - buf: buf, - packetRW: &buf, + buf: newBuffer(), cfg: connector.cfg, connector: connector, netConn: conn, @@ -114,9 +112,10 @@ func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { func TestReadPacketSingleByte(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{ - buf: newBuffer(conn), + netConn: conn, + buf: newBuffer(), + cfg: NewConfig(), } - mc.packetRW = &mc.buf conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} conn.maxReads = 1 @@ -151,7 +150,7 @@ func TestReadPacketWrongSequenceID(t *testing.T) { } { conn, mc := newRWMockConn(testCase.ClientSequenceID) - conn.data = []byte{0x01, 0x00, 0x00, testCase.ServerSequenceID, 0xff} + conn.data = []byte{0x01, 0x00, 0x00, testCase.ServerSequenceID, 0x22} _, err := mc.readPacket() if err != testCase.ExpectedErr { t.Errorf("expected %v, got %v", testCase.ExpectedErr, err) @@ -167,9 +166,10 @@ func TestReadPacketWrongSequenceID(t *testing.T) { func TestReadPacketSplit(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{ - buf: newBuffer(conn), + netConn: conn, + buf: newBuffer(), + cfg: NewConfig(), } - mc.packetRW = &mc.buf data := make([]byte, maxPacketSize*2+4*3) const pkt2ofs = maxPacketSize + 4 @@ -273,11 +273,11 @@ func TestReadPacketSplit(t *testing.T) { func TestReadPacketFail(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{ - buf: newBuffer(conn), + netConn: conn, + buf: newBuffer(), closech: make(chan struct{}), cfg: NewConfig(), } - mc.packetRW = &mc.buf // illegal empty (stand-alone) packet conn.data = []byte{0x00, 0x00, 0x00, 0x00} @@ -290,7 +290,7 @@ func TestReadPacketFail(t *testing.T) { // reset conn.reads = 0 mc.sequence = 0 - mc.buf = newBuffer(conn) + mc.buf = newBuffer() // fail to read header conn.closed = true @@ -303,7 +303,7 @@ func TestReadPacketFail(t *testing.T) { conn.closed = false conn.reads = 0 mc.sequence = 0 - mc.buf = newBuffer(conn) + mc.buf = newBuffer() // fail to read body conn.maxReads = 1 @@ -318,12 +318,12 @@ func TestReadPacketFail(t *testing.T) { func TestRegression801(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{ - buf: newBuffer(conn), + netConn: conn, + buf: newBuffer(), cfg: new(Config), sequence: 42, closech: make(chan struct{}), } - mc.packetRW = &mc.buf conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0, 60, 70, 63, 58, 68, 104, 34, 97, 0, 223, 247, 33, 2, 0, 15, 128, 21, 0, From 44d7dee3e80f8c03e9b24ffbda0ba962e2052a27 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 12 Dec 2024 19:32:58 +0900 Subject: [PATCH 80/88] fix tests --- compress.go | 4 ++++ connection.go | 1 + driver_test.go | 8 ++++++-- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/compress.go b/compress.go index 9b424f99b..9a0db91ec 100644 --- a/compress.go +++ b/compress.go @@ -96,6 +96,10 @@ func newCompIO(mc *mysqlConn) *compIO { } } +func (c *compIO) reset() { + c.buff.Reset() +} + func (c *compIO) readNext(need int, r readwriteFunc) ([]byte, error) { for c.buff.Len() < need { if err := c.readCompressedPacket(r); err != nil { diff --git a/connection.go b/connection.go index ccf86b834..558930d3f 100644 --- a/connection.go +++ b/connection.go @@ -97,6 +97,7 @@ func (mc *mysqlConn) syncSequenceNr() { // https://github.com/mysql/mysql-server/blob/824e2b4064053f7daf17d7f3f84b7a3ed92e5fb4/sql-common/net_serv.cc#L293 if mc.compress { mc.sequence = mc.compressSequence + mc.compIO.reset() } } diff --git a/driver_test.go b/driver_test.go index 71e7cdf64..4753b2d08 100644 --- a/driver_test.go +++ b/driver_test.go @@ -971,12 +971,16 @@ func TestDateTime(t *testing.T) { var err error rows, err = dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`) if err == nil { - rows.Scan(µsecsSupported) + if rows.Next() { + rows.Scan(µsecsSupported) + } rows.Close() } rows, err = dbt.db.Query(`SELECT cast("0000-00-00" as DATE) = "0000-00-00"`) if err == nil { - rows.Scan(&zeroDateSupported) + if rows.Next() { + rows.Scan(&zeroDateSupported) + } rows.Close() } for _, setups := range testcases { From 25cf587f412f6201d66f6e0451333ecbba3851a9 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sat, 14 Dec 2024 11:55:32 +0900 Subject: [PATCH 81/88] sequenceNr -> sequence --- compress_test.go | 4 ++-- connection.go | 6 +++--- infile.go | 2 +- packets.go | 16 ++++++++-------- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/compress_test.go b/compress_test.go index 0d217efc0..c4bc927c9 100644 --- a/compress_test.go +++ b/compress_test.go @@ -98,8 +98,8 @@ func TestRoundtrip(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - cSend.resetSequenceNr() - cReceive.resetSequenceNr() + cSend.resetSequence() + cReceive.resetSequence() uncompressed := roundtripHelper(t, cSend, cReceive, test.uncompressed) if !bytes.Equal(uncompressed, test.uncompressed) { diff --git a/connection.go b/connection.go index 558930d3f..db5c981b0 100644 --- a/connection.go +++ b/connection.go @@ -84,13 +84,13 @@ func (mc *mysqlConn) writeWithTimeout(b []byte) (int, error) { return mc.netConn.Write(b) } -func (mc *mysqlConn) resetSequenceNr() { +func (mc *mysqlConn) resetSequence() { mc.sequence = 0 mc.compressSequence = 0 } -// syncSequenceNr must be called when finished writing some packet and before start reading. -func (mc *mysqlConn) syncSequenceNr() { +// syncSequence must be called when finished writing some packet and before start reading. +func (mc *mysqlConn) syncSequence() { // Syncs compressionSequence to sequence. // This is not documented but done in `net_flush()` in MySQL and MariaDB. // https://github.com/mariadb-corporation/mariadb-connector-c/blob/8228164f850b12353da24df1b93a1e53cc5e85e9/libmariadb/ma_net.c#L170-L171 diff --git a/infile.go b/infile.go index 42f2d72cd..555ef71ad 100644 --- a/infile.go +++ b/infile.go @@ -172,7 +172,7 @@ func (mc *okHandler) handleInFileRequest(name string) (err error) { if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil { return ioErr } - mc.conn().syncSequenceNr() + mc.conn().syncSequence() // read OK packet if err == nil { diff --git a/packets.go b/packets.go index 5914c1cb8..df9c61336 100644 --- a/packets.go +++ b/packets.go @@ -431,7 +431,7 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence - mc.resetSequenceNr() + mc.resetSequence() data, err := mc.buf.takeSmallBuffer(4 + 1) if err != nil { @@ -447,7 +447,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { // Reset Packet Sequence - mc.resetSequenceNr() + mc.resetSequence() pktLen := 1 + len(arg) data, err := mc.buf.takeBuffer(pktLen + 4) @@ -463,13 +463,13 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { // Send CMD packet err = mc.writePacket(data) - mc.syncSequenceNr() + mc.syncSequence() return err } func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence - mc.resetSequenceNr() + mc.resetSequence() data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) if err != nil { @@ -952,7 +952,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { pktLen = dataOffset + argLen } - stmt.mc.resetSequenceNr() + stmt.mc.resetSequence() // Add command byte [1 byte] data[4] = comStmtSendLongData @@ -976,7 +976,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { } // Reset Packet Sequence - stmt.mc.resetSequenceNr() + stmt.mc.resetSequence() return nil } @@ -1001,7 +1001,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } // Reset packet-sequence - mc.resetSequenceNr() + mc.resetSequence() var data []byte var err error @@ -1218,7 +1218,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } err = mc.writePacket(data) - mc.syncSequenceNr() + mc.syncSequence() return err } From 0bc8145ec232abe0c086d019bc0c45b1584ba64e Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sat, 14 Dec 2024 12:08:44 +0900 Subject: [PATCH 82/88] allow returning ErrBadConn on compression --- compress.go | 16 +++++++++------- connection_test.go | 5 +---- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/compress.go b/compress.go index 9a0db91ec..df7414c51 100644 --- a/compress.go +++ b/compress.go @@ -199,8 +199,10 @@ func (c *compIO) writePackets(packets []byte) (int, error) { } } - if err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil { - return 0, err + if n, err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil { + // To allow returning ErrBadConn when sending really 0 bytes, we sum + // up compressed bytes that is returned by underlying Write(). + return totalBytes - len(packets) + n, err } dataLen -= payloadLen packets = packets[payloadLen:] @@ -211,7 +213,7 @@ func (c *compIO) writePackets(packets []byte) (int, error) { // writeCompressedPacket writes a compressed packet with header. // data should start with 7 size space for header followed by payload. -func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) error { +func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) (int, error) { mc := c.mc comprLength := len(data) - 7 if debugTrace { @@ -225,11 +227,11 @@ func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) error { data[3] = mc.compressSequence putUint24(data[4:7], uncompressedLen) - if _, err := mc.writeWithTimeout(data); err != nil { - mc.log("writing compressed packet:", err) - return err + if n, err := mc.writeWithTimeout(data); err != nil { + // mc.log("writing compressed packet:", err) + return n, err } mc.compressSequence++ - return nil + return n, nil } diff --git a/connection_test.go b/connection_test.go index 696db758d..f7740898e 100644 --- a/connection_test.go +++ b/connection_test.go @@ -159,11 +159,9 @@ func TestCleanCancel(t *testing.T) { func TestPingMarkBadConnection(t *testing.T) { nc := badConnection{err: errors.New("boom")} - - buf := newBuffer() mc := &mysqlConn{ netConn: nc, - buf: buf, + buf: newBuffer(), maxAllowedPacket: defaultMaxAllowedPacket, closech: make(chan struct{}), cfg: NewConfig(), @@ -178,7 +176,6 @@ func TestPingMarkBadConnection(t *testing.T) { func TestPingErrInvalidConn(t *testing.T) { nc := badConnection{err: errors.New("failed to write"), n: 10} - mc := &mysqlConn{ netConn: nc, buf: newBuffer(), From d2ecd5755e1cf7f9b1f0fbe4f465c5b5502e8dc6 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sat, 14 Dec 2024 16:28:21 +0900 Subject: [PATCH 83/88] simplify --- buffer.go | 12 ++++----- compress.go | 73 +++++++++++++++++++---------------------------------- const.go | 3 +-- packets.go | 22 ++++++++-------- 4 files changed, 44 insertions(+), 66 deletions(-) diff --git a/buffer.go b/buffer.go index 49b2523c6..a65324315 100644 --- a/buffer.go +++ b/buffer.go @@ -15,10 +15,10 @@ import ( const defaultBufSize = 4096 const maxCachedBufSize = 256 * 1024 -// readwriteFunc is a function that compatible with io.Reader and io.Writer. -// We use this function type instead of io.ReadWriter because we want to -// just pass mc.readWithTimeout or mc.writeWithTimeout functions. -type readwriteFunc func([]byte) (int, error) +// readerFunc is a function that compatible with io.Reader. +// We use this function type instead of io.Reader because we want to +// just pass mc.readWithTimeout. +type readerFunc func([]byte) (int, error) // A buffer which is used for both reading and writing. // This is possible since communication on each connection is synchronous. @@ -43,7 +43,7 @@ func (b *buffer) busy() bool { } // fill reads into the read buffer until at least _need_ bytes are in it. -func (b *buffer) fill(need int, r readwriteFunc) error { +func (b *buffer) fill(need int, r readerFunc) error { // we'll move the contents of the current buffer to dest before filling it. dest := b.cachedBuf @@ -86,7 +86,7 @@ func (b *buffer) fill(need int, r readwriteFunc) error { // returns next N bytes from buffer. // The returned slice is only guaranteed to be valid until the next read -func (b *buffer) readNext(need int, r readwriteFunc) ([]byte, error) { +func (b *buffer) readNext(need int, r readerFunc) ([]byte, error) { if len(b.buf) < need { // refill if err := b.fill(need, r); err != nil { diff --git a/compress.go b/compress.go index df7414c51..ea6524855 100644 --- a/compress.go +++ b/compress.go @@ -36,7 +36,7 @@ func init() { } } -func zDecompress(src, dst []byte) (int, error) { +func zDecompress(src []byte, dst *bytes.Buffer) (int, error) { br := bytes.NewReader(src) var zr io.ReadCloser var err error @@ -51,27 +51,11 @@ func zDecompress(src, dst []byte) (int, error) { return 0, err } } - defer func() { - zr.Close() - zrPool.Put(zr) - }() - lenRead := 0 - size := len(dst) - - for lenRead < size { - n, err := zr.Read(dst[lenRead:]) - lenRead += n - - if err == io.EOF { - if lenRead < size { - return lenRead, io.ErrUnexpectedEOF - } - } else if err != nil { - return lenRead, err - } - } - return lenRead, nil + n, _ := dst.ReadFrom(zr) // ignore err because zr.Close() will return it again. + err = zr.Close() // zr.Close() may return chuecksum error. + zrPool.Put(zr) + return int(n), err } func zCompress(src []byte, dst io.Writer) error { @@ -100,7 +84,7 @@ func (c *compIO) reset() { c.buff.Reset() } -func (c *compIO) readNext(need int, r readwriteFunc) ([]byte, error) { +func (c *compIO) readNext(need int, r readerFunc) ([]byte, error) { for c.buff.Len() < need { if err := c.readCompressedPacket(r); err != nil { return nil, err @@ -110,7 +94,7 @@ func (c *compIO) readNext(need int, r readwriteFunc) ([]byte, error) { return data[:need:need], nil // prevent caller writes into c.buff } -func (c *compIO) readCompressedPacket(r readwriteFunc) error { +func (c *compIO) readCompressedPacket(r readerFunc) error { header, err := c.mc.buf.readNext(7, r) // size of compressed header if err != nil { return err @@ -121,19 +105,17 @@ func (c *compIO) readCompressedPacket(r readwriteFunc) error { comprLength := getUint24(header[0:3]) compressionSequence := uint8(header[3]) uncompressedLength := getUint24(header[4:7]) - if debugTrace { + if debug { fmt.Printf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n", comprLength, uncompressedLength, compressionSequence, c.mc.sequence) } - if compressionSequence != c.mc.sequence { - // return ErrPktSync - // server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes) - // before receiving all packets from client. In this case, seqnr is younger than expected. - if debugTrace { - fmt.Printf("WARN: unexpected cmpress seq nr: expected %v, got %v", - c.mc.sequence, compressionSequence) - } - // TODO(methane): report error when the packet is not an error packet. + // Do not return ErrPktSync here. + // Server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes) + // before receiving all packets from client. In this case, seqnr is younger than expected. + // NOTE: Both of mariadbclient and mysqlclient do not check seqnr. Only server checks it. + if debug && compressionSequence != c.mc.sequence { + fmt.Printf("WARN: unexpected cmpress seq nr: expected %v, got %v", + c.mc.sequence, compressionSequence) } c.mc.sequence = compressionSequence + 1 c.mc.compressSequence = c.mc.sequence @@ -152,31 +134,29 @@ func (c *compIO) readCompressedPacket(r readwriteFunc) error { // use existing capacity in bytesBuf if possible c.buff.Grow(uncompressedLength) - dec := c.buff.AvailableBuffer()[:uncompressedLength] - lenRead, err := zDecompress(comprData, dec) + nread, err := zDecompress(comprData, &c.buff) if err != nil { return err } - if lenRead != uncompressedLength { + if nread != uncompressedLength { return fmt.Errorf("invalid compressed packet: uncompressed length in header is %d, actual %d", - uncompressedLength, lenRead) + uncompressedLength, nread) } - c.buff.Write(dec) // fast copy. See bytes.Buffer.AvailableBuffer() doc. return nil } +const minCompressLength = 150 const maxPayloadLen = maxPacketSize - 4 // writePackets sends one or some packets with compression. // Use this instead of mc.netConn.Write() when mc.compress is true. func (c *compIO) writePackets(packets []byte) (int, error) { totalBytes := len(packets) - dataLen := len(packets) blankHeader := make([]byte, 7) buf := &c.buff - for dataLen > 0 { - payloadLen := min(maxPayloadLen, dataLen) + for len(packets) > 0 { + payloadLen := min(maxPayloadLen, len(packets)) payload := packets[:payloadLen] uncompressedLen := payloadLen @@ -190,8 +170,8 @@ func (c *compIO) writePackets(packets []byte) (int, error) { } else { zCompress(payload, buf) // do not compress if compressed data is larger than uncompressed data - // I intentionally miss 7 byte header in the buf; compress should compress more than 7 bytes. - if buf.Len() > uncompressedLen { + // I intentionally miss 7 byte header in the buf; compress more than 7 bytes. + if buf.Len() >= uncompressedLen { buf.Reset() buf.Write(blankHeader) buf.Write(payload) @@ -204,7 +184,6 @@ func (c *compIO) writePackets(packets []byte) (int, error) { // up compressed bytes that is returned by underlying Write(). return totalBytes - len(packets) + n, err } - dataLen -= payloadLen packets = packets[payloadLen:] } @@ -216,7 +195,7 @@ func (c *compIO) writePackets(packets []byte) (int, error) { func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) (int, error) { mc := c.mc comprLength := len(data) - 7 - if debugTrace { + if debug { fmt.Printf( "writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v", comprLength, uncompressedLen, mc.compressSequence) @@ -227,8 +206,8 @@ func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) (int, e data[3] = mc.compressSequence putUint24(data[4:7], uncompressedLen) - if n, err := mc.writeWithTimeout(data); err != nil { - // mc.log("writing compressed packet:", err) + n, err := mc.writeWithTimeout(data) + if err != nil { return n, err } diff --git a/const.go b/const.go index 51f82ea82..4aadcd642 100644 --- a/const.go +++ b/const.go @@ -11,14 +11,13 @@ package mysql import "runtime" const ( - debugTrace = false // for debugging wire protocol. + debug = false // for debugging. Set true only in development. defaultAuthPlugin = "mysql_native_password" defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355 minProtocolVersion = 10 maxPacketSize = 1<<24 - 1 timeFormat = "2006-01-02 15:04:05.999999" - minCompressLength = 150 // Connection attributes // See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available diff --git a/packets.go b/packets.go index df9c61336..b36e6c20a 100644 --- a/packets.go +++ b/packets.go @@ -30,14 +30,14 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte invalid := false - readFunc := mc.buf.readNext + readNext := mc.buf.readNext if mc.compress { - readFunc = mc.compIO.readNext + readNext = mc.compIO.readNext } for { // read packet header - data, err := readFunc(4, mc.readWithTimeout) + data, err := readNext(4, mc.readWithTimeout) if err != nil { mc.close() if cerr := mc.canceled.Value(); cerr != nil { @@ -49,19 +49,19 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // packet length [24 bit] pktLen := getUint24(data[:3]) - seqNr := data[3] + seq := data[3] if mc.compress { // MySQL and MariaDB doesn't check packet nr in compressed packet. - if debugTrace && seqNr != mc.compressSequence { + if debug && seq != mc.compressSequence { fmt.Printf("[debug] mismatched compression sequence nr: expected: %v, got %v", - mc.compressSequence, seqNr) + mc.compressSequence, seq) } - mc.compressSequence = seqNr + 1 + mc.compressSequence = seq + 1 } else { // check packet sync [8 bit] - if seqNr != mc.sequence { - mc.log(fmt.Sprintf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, seqNr)) + if seq != mc.sequence { + mc.log(fmt.Sprintf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, seq)) // For large packets, we stop reading as soon as sync error. if len(prevData) > 0 { mc.close() @@ -86,7 +86,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // read packet body [pktLen bytes] - data, err = readFunc(pktLen, mc.readWithTimeout) + data, err = readNext(pktLen, mc.readWithTimeout) if err != nil { mc.close() if cerr := mc.canceled.Value(); cerr != nil { @@ -136,7 +136,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { data[3] = mc.sequence // Write packet - if debugTrace { + if debug { fmt.Printf("writePacket: size=%v seq=%v", size, mc.sequence) } From 39db0ba31e6653cb68f41887ef297d7a26626463 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sat, 14 Dec 2024 17:21:53 +0900 Subject: [PATCH 84/88] refactor --- compress_test.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/compress_test.go b/compress_test.go index c4bc927c9..030deaefa 100644 --- a/compress_test.go +++ b/compress_test.go @@ -35,7 +35,7 @@ func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []by } // uncompressHelper uncompresses compressedPacket and checks state variables -func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expSize int) []byte { +func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte) []byte { // mocking out buf variable conn := new(mockConn) conn.data = compressedPacket @@ -47,16 +47,13 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expS t.Fatalf("non-nil/non-EOF error when reading contents: %s", err.Error()) } } - if len(uncompressedPacket) != expSize { - t.Errorf("uncompressed size is unexpected. expected %d but got %d", expSize, len(uncompressedPacket)) - } return uncompressedPacket } // roundtripHelper compresses then uncompresses uncompressedPacket and checks state variables func roundtripHelper(t *testing.T, cSend *mysqlConn, cReceive *mysqlConn, uncompressedPacket []byte) []byte { compressed := compressHelper(t, cSend, uncompressedPacket) - return uncompressHelper(t, cReceive, compressed, len(uncompressedPacket)) + return uncompressHelper(t, cReceive, compressed) } // TestRoundtrip tests two connections, where one is reading and the other is writing @@ -102,10 +99,13 @@ func TestRoundtrip(t *testing.T) { cReceive.resetSequence() uncompressed := roundtripHelper(t, cSend, cReceive, test.uncompressed) + if len(uncompressed) != len(test.uncompressed) { + t.Errorf("uncompressed size is unexpected. expected %d but got %d", + len(test.uncompressed), len(uncompressed)) + } if !bytes.Equal(uncompressed, test.uncompressed) { - t.Fatalf("roundtrip failed") + t.Errorf("roundtrip failed") } - if cSend.sequence != cReceive.sequence { t.Errorf("inconsistent sequence number: send=%v recv=%v", cSend.sequence, cReceive.sequence) From 962608a79336fdbf1a92071d441e0f79ebce9789 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sat, 14 Dec 2024 17:26:06 +0900 Subject: [PATCH 85/88] refactor --- compress.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/compress.go b/compress.go index ea6524855..862b7adbb 100644 --- a/compress.go +++ b/compress.go @@ -206,11 +206,6 @@ func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) (int, e data[3] = mc.compressSequence putUint24(data[4:7], uncompressedLen) - n, err := mc.writeWithTimeout(data) - if err != nil { - return n, err - } - mc.compressSequence++ - return n, nil + return mc.writeWithTimeout(data) } From 59d0d57ed228a052e4f671292fdd889d4a6201c0 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sat, 14 Dec 2024 17:40:29 +0900 Subject: [PATCH 86/88] refactor --- driver_test.go | 19 +++++++++++++------ packets.go | 26 ++++++++++++-------------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/driver_test.go b/driver_test.go index 4753b2d08..e79d95567 100644 --- a/driver_test.go +++ b/driver_test.go @@ -150,9 +150,8 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { t.Fatalf("connecting %q: %s", dsn, err) } defer db.Close() - - cleanup := func() { - db.Exec("DROP TABLE IF EXISTS test") + if err = db.Ping(); err != nil { + t.Fatalf("connecting %q: %s", dsn, err) } dsn2 := dsn + "&interpolateParams=true" @@ -173,23 +172,31 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { } defer db3.Close() + cleanupSql := "DROP TABLE IF EXISTS test" + for _, test := range tests { test := test t.Run("default", func(t *testing.T) { dbt := &DBTest{t, db} - t.Cleanup(cleanup) + t.Cleanup(func() { + db.Exec(cleanupSql) + }) test(dbt) }) if db2 != nil { t.Run("interpolateParams", func(t *testing.T) { dbt2 := &DBTest{t, db2} - t.Cleanup(cleanup) + t.Cleanup(func() { + db2.Exec(cleanupSql) + }) test(dbt2) }) } t.Run("compress", func(t *testing.T) { dbt3 := &DBTest{t, db3} - t.Cleanup(cleanup) + t.Cleanup(func() { + db3.Exec(cleanupSql) + }) test(dbt3) }) } diff --git a/packets.go b/packets.go index b36e6c20a..e4d2820ed 100644 --- a/packets.go +++ b/packets.go @@ -28,7 +28,7 @@ import ( // Read packet to buffer 'data' func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte - invalid := false + invalidSequence := false readNext := mc.buf.readNext if mc.compress { @@ -67,8 +67,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { mc.close() return nil, ErrPktSyncMul } - // TODO(methane): report error when the packet is not an error packet. - invalid = true + invalidSequence = true } mc.sequence++ } @@ -99,19 +98,18 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // return data if this was the last packet if pktLen < maxPacketSize { // zero allocations for non-split packets - if prevData == nil { - if invalid { - mc.close() - // return sync error only for regular packet. - // error packets may have wrong sequence number. - if data[0] != iERR { - return nil, ErrPktSync - } + if prevData != nil { + data = append(prevData, data...) + } + if invalidSequence { + mc.close() + // return sync error only for regular packet. + // error packets may have wrong sequence number. + if data[0] != iERR { + return nil, ErrPktSync } - return data, nil } - - return append(prevData, data...), nil + return data, nil } prevData = append(prevData, data...) From d58a709317895261d40e45642c3b995b542a3225 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sat, 14 Dec 2024 18:14:46 +0900 Subject: [PATCH 87/88] cleanup --- driver_test.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/driver_test.go b/driver_test.go index e79d95567..58b3cb38d 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1289,9 +1289,7 @@ func TestLongData(t *testing.T) { var rows *sql.Rows // Long text data - // const nonDataQueryLen = 28 // length query w/o value + compress header - const nonDataQueryLen = 100 - inS := in[:maxAllowedPacketSize-nonDataQueryLen] + inS := in[:maxAllowedPacketSize-100] dbt.mustExec("INSERT INTO test VALUES('" + inS + "')") rows = dbt.mustQuery("SELECT value FROM test") defer rows.Close() From 994cd82fd7d44feabd93b6c5c636722bc6f6416a Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sat, 14 Dec 2024 20:54:39 +0900 Subject: [PATCH 88/88] error check --- compress.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/compress.go b/compress.go index 862b7adbb..fa42772ac 100644 --- a/compress.go +++ b/compress.go @@ -64,9 +64,9 @@ func zCompress(src []byte, dst io.Writer) error { if _, err := zw.Write(src); err != nil { return err } - zw.Close() + err := zw.Close() zwPool.Put(zw) - return nil + return err } type compIO struct { @@ -168,10 +168,13 @@ func (c *compIO) writePackets(packets []byte) (int, error) { buf.Write(payload) uncompressedLen = 0 } else { - zCompress(payload, buf) + err := zCompress(payload, buf) + if debug && err != nil { + fmt.Printf("zCompress error: %v", err) + } // do not compress if compressed data is larger than uncompressed data - // I intentionally miss 7 byte header in the buf; compress more than 7 bytes. - if buf.Len() >= uncompressedLen { + // I intentionally miss 7 byte header in the buf; zCompress must compress more than 7 bytes. + if err != nil || buf.Len() >= uncompressedLen { buf.Reset() buf.Write(blankHeader) buf.Write(payload)