From 538d5559c6f9fdf7534a9625dbd0bc90d88b4c25 Mon Sep 17 00:00:00 2001 From: Atsushi Watanabe Date: Sun, 26 Jan 2020 22:52:03 +0900 Subject: [PATCH] Add connection error test cases (#80) --- packet_test.go | 97 ++++++++++++++++++++++++++++++++++++++++++++++++++ serve.go | 46 ++++++++++++++---------- 2 files changed, 124 insertions(+), 19 deletions(-) diff --git a/packet_test.go b/packet_test.go index bcb1c69..cc7f086 100644 --- a/packet_test.go +++ b/packet_test.go @@ -17,11 +17,13 @@ package mqtt import ( "bytes" "context" + "fmt" "io" "net" "reflect" "strings" "testing" + "time" "github.com/at-wat/mqtt-go/internal/errs" ) @@ -209,3 +211,98 @@ func TestPacketSendError(t *testing.T) { } }) } + +func TestConnectionError(t *testing.T) { + resps := [][]byte{ + {0x20, 0x02, 0x00, 0x00}, + {0x90, 0x03, 0x00, 0x01, 0x00}, + {0xB0, 0x02, 0x00, 0x02}, + {0xD0, 0x00}, + {}, + {0x40, 0x02, 0x00, 0x04}, + {0x50, 0x02, 0x00, 0x05}, + {0x70, 0x02, 0x00, 0x05}, + {}, + } + reqs := []func(ctx context.Context, cli *BaseClient) error{ + func(ctx context.Context, cli *BaseClient) error { + _, err := cli.Connect(ctx, "cli") + cli.idLast = 0 + return err + }, + func(ctx context.Context, cli *BaseClient) error { + return cli.Subscribe(ctx, Subscription{Topic: "test"}) + }, + func(ctx context.Context, cli *BaseClient) error { + return cli.Unsubscribe(ctx, "test") + }, + func(ctx context.Context, cli *BaseClient) error { + return cli.Ping(ctx) + }, + func(ctx context.Context, cli *BaseClient) error { + return cli.Publish(ctx, &Message{QoS: QoS0}) + }, + func(ctx context.Context, cli *BaseClient) error { + return cli.Publish(ctx, &Message{QoS: QoS1}) + }, + func(ctx context.Context, cli *BaseClient) error { + return cli.Publish(ctx, &Message{QoS: QoS2}) + }, + func(ctx context.Context, cli *BaseClient) error { + return cli.Disconnect(ctx) + }, + } + + cases := []struct { + closeAt int + errorAt int + }{ + {0, 0}, // CONNECT + {1, 1}, // SUBSCRIBE + {2, 2}, // UNSUBSCRIBE + {3, 3}, // PINGREQ + {4, 4}, // PUBLISH QoS0 + {5, 5}, // PUBLISH QoS1 + {6, 6}, // PUBLISH QoS2 PUBREC + {7, 6}, // PUBLISH QoS2 PUBCOMP + {8, 7}, // DISCONNECT + } + for _, c := range cases { + t.Run(fmt.Sprintf("CloseAt%dErrorAt%d", c.closeAt, c.errorAt), func(t *testing.T) { + ca, cb := net.Pipe() + cli := &BaseClient{Transport: ca} + + go func() { + defer cb.Close() + for i, resp := range resps { + if i == c.closeAt { + return + } + if _, _, _, err := readPacket(cb); err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + if len(resp) > 0 { + io.Copy(cb, bytes.NewReader(resp)) + } + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + for i, req := range reqs { + err := req(ctx, cli) + if i == c.errorAt { + if !errs.Is(err, io.ErrClosedPipe) { + t.Errorf("Expected error: '%v', got: '%v'", io.ErrClosedPipe, err) + } + break + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + } + } + }) + } +} diff --git a/serve.go b/serve.go index fc8f8b4..55d23b1 100644 --- a/serve.go +++ b/serve.go @@ -18,29 +18,37 @@ import ( "io" ) +func readPacket(r io.Reader) (packetType, byte, []byte, error) { + pktTypeBytes := make([]byte, 1) + if _, err := io.ReadFull(r, pktTypeBytes); err != nil { + return 0, 0, nil, err + } + pktType := packetType(pktTypeBytes[0] & 0xF0) + pktFlag := pktTypeBytes[0] & 0x0F + var remainingLength int + for shift := uint(0); ; shift += 7 { + b := make([]byte, 1) + if _, err := io.ReadFull(r, b); err != nil { + return 0, 0, nil, err + } + remainingLength |= (int(b[0]) & 0x7F) << shift + if !(b[0]&0x80 != 0) { + break + } + } + contents := make([]byte, remainingLength) + if _, err := io.ReadFull(r, contents); err != nil { + return 0, 0, nil, err + } + return pktType, pktFlag, contents, nil +} + func (c *BaseClient) serve() error { r := c.Transport subBuffer := make(map[uint16]*Message) for { - pktTypeBytes := make([]byte, 1) - if _, err := io.ReadFull(r, pktTypeBytes); err != nil { - return err - } - pktType := packetType(pktTypeBytes[0] & 0xF0) - pktFlag := pktTypeBytes[0] & 0x0F - var remainingLength int - for shift := uint(0); ; shift += 7 { - b := make([]byte, 1) - if _, err := io.ReadFull(r, b); err != nil { - return err - } - remainingLength |= (int(b[0]) & 0x7F) << shift - if !(b[0]&0x80 != 0) { - break - } - } - contents := make([]byte, remainingLength) - if _, err := io.ReadFull(r, contents); err != nil { + pktType, pktFlag, contents, err := readPacket(r) + if err != nil { return err } // fmt.Printf("%s: %v\n", pktType, contents)