diff --git a/buffer.go b/buffer.go index f7126281..935b6389 100644 --- a/buffer.go +++ b/buffer.go @@ -23,10 +23,11 @@ const maxCachedBufSize = 256 * 1024 // The buffer is similar to bufio.Reader / Writer but zero-copy-ish // Also highly optimized for this particular use case. type buffer struct { - buf []byte // read buffer. - cachedBuf []byte // buffer that will be reused. len(cachedBuf) <= maxCachedBufSize. - nc net.Conn - timeout time.Duration + buf []byte // read buffer. + cachedBuf []byte // buffer that will be reused. len(cachedBuf) <= maxCachedBufSize. + nc net.Conn + readTimeout time.Duration + writeTimeout time.Duration } // newBuffer allocates and returns a new buffer. @@ -64,8 +65,8 @@ func (b *buffer) fill(need int) error { copy(dest[:n], b.buf) for { - if b.timeout > 0 { - if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil { + if b.readTimeout > 0 { + if err := b.nc.SetReadDeadline(time.Now().Add(b.readTimeout)); err != nil { return err } } @@ -159,5 +160,10 @@ func (b *buffer) store(buf []byte) { // writePackets is a proxy function to nc.Write. // This is used to make the buffer type compatible with compressed I/O. func (b *buffer) writePackets(packets []byte) (int, error) { + if b.writeTimeout > 0 { + if err := b.nc.SetWriteDeadline(time.Now().Add(b.writeTimeout)); err != nil { + return 0, err + } + } return b.nc.Write(packets) } diff --git a/compress.go b/compress.go index 84c47678..937d0cf0 100644 --- a/compress.go +++ b/compress.go @@ -195,7 +195,7 @@ func (c *compIO) writePackets(packets []byte) (int, error) { } } - if err := c.mc.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil { + if err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil { return 0, err } dataLen -= payloadLen @@ -207,7 +207,8 @@ func (c *compIO) writePackets(packets []byte) (int, error) { // writeCompressedPacket writes a compressed packet with header. // data should start with 7 size space for header followed by payload. -func (mc *mysqlConn) writeCompressedPacket(data []byte, uncompressedLen int) error { +func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) error { + mc := c.mc comprLength := len(data) - 7 if debugTrace { fmt.Printf( @@ -220,7 +221,7 @@ func (mc *mysqlConn) writeCompressedPacket(data []byte, uncompressedLen int) err data[3] = mc.compressSequence putUint24(data[4:7], uncompressedLen) - if _, err := mc.netConn.Write(data); err != nil { + if _, err := mc.buf.writePackets(data); err != nil { mc.log("writing compressed packet:", err) return err } diff --git a/connection.go b/connection.go index 6c026ab4..d95369ad 100644 --- a/connection.go +++ b/connection.go @@ -33,7 +33,6 @@ type mysqlConn struct { connector *connector maxAllowedPacket int maxWriteSize int - writeTimeout time.Duration flags clientFlag status statusFlag sequence uint8 diff --git a/connector.go b/connector.go index fc826750..3d37c22f 100644 --- a/connector.go +++ b/connector.go @@ -132,8 +132,8 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.packetRW = &mc.buf // Set I/O timeouts - mc.buf.timeout = mc.cfg.ReadTimeout - mc.writeTimeout = mc.cfg.WriteTimeout + mc.buf.readTimeout = mc.cfg.ReadTimeout + mc.buf.writeTimeout = mc.cfg.WriteTimeout // Reading Handshake Initialization Packet authData, plugin, err := mc.readHandshakePacket() diff --git a/packets.go b/packets.go index 5e5832f1..b5c8d9dd 100644 --- a/packets.go +++ b/packets.go @@ -125,13 +125,6 @@ func (mc *mysqlConn) writePacket(data []byte) error { if debugTrace { fmt.Printf("writePacket: size=%v seq=%v", size, mc.sequence) } - if mc.writeTimeout > 0 { - if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { - mc.cleanup() - mc.log(err) - return err - } - } n, err := mc.packetRW.writePackets(data[:4+size]) if err != nil {