From e965e01f756ea39cd00834940c95b3fb8cc221a9 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sun, 16 Jun 2024 10:50:17 +0900 Subject: [PATCH 1/2] improve error handling in writePacket * handle error before success case. * return io.ErrShortWrite if not all bytes were written but err is nil. * return err instead of ErrInvalidConn. --- packets.go | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/packets.go b/packets.go index b90b14c5c..df850fd41 100644 --- a/packets.go +++ b/packets.go @@ -124,32 +124,32 @@ func (mc *mysqlConn) writePacket(data []byte) error { } n, err := mc.netConn.Write(data[:4+size]) - if err == nil && n == 4+size { - mc.sequence++ - if size != maxPacketSize { - return nil - } - pktLen -= size - data = data[size:] - continue - } - - // Handle error - if err == nil { // n != len(data) - mc.cleanup() - mc.log(ErrMalformPkt) - } else { + if err != nil { if cerr := mc.canceled.Value(); cerr != nil { return cerr } + mc.cleanup() if n == 0 && pktLen == len(data)-4 { // only for the first loop iteration when nothing was written yet + mc.log(err) return errBadConnNoWrite + } else { + return err } + } + if n != 4+size { + // io.Writer(b) must return a non-nil error if it cannot write len(b) bytes. + // The io.ErrShortWrite error is used to indicate that this rule has not been followed. mc.cleanup() - mc.log(err) + return io.ErrShortWrite + } + + mc.sequence++ + if size != maxPacketSize { + return nil } - return ErrInvalidConn + pktLen -= size + data = data[size:] } } From 028f718beda6237457210ecb519067bd0cd08670 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sun, 16 Jun 2024 13:25:26 +0900 Subject: [PATCH 2/2] fix tests --- connection_test.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/connection_test.go b/connection_test.go index c59cb6176..6f8d2a6d7 100644 --- a/connection_test.go +++ b/connection_test.go @@ -163,6 +163,8 @@ func TestPingMarkBadConnection(t *testing.T) { netConn: nc, buf: newBuffer(nc), maxAllowedPacket: defaultMaxAllowedPacket, + closech: make(chan struct{}), + cfg: NewConfig(), } err := mc.Ping(context.Background()) @@ -184,8 +186,8 @@ func TestPingErrInvalidConn(t *testing.T) { err := mc.Ping(context.Background()) - if err != ErrInvalidConn { - t.Errorf("expected ErrInvalidConn, got %#v", err) + if err != nc.err { + t.Errorf("expected %#v, got %#v", nc.err, err) } }