Skip to content

Commit

Permalink
Add connection error test cases (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
at-wat authored Jan 26, 2020
1 parent a2ad0a3 commit 538d555
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 19 deletions.
97 changes: 97 additions & 0 deletions packet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ package mqtt
import (
"bytes"
"context"
"fmt"
"io"
"net"
"reflect"
"strings"
"testing"
"time"

"github.com/at-wat/mqtt-go/internal/errs"
)
Expand Down Expand Up @@ -209,3 +211,98 @@ func TestPacketSendError(t *testing.T) {
}
})
}

func TestConnectionError(t *testing.T) {
resps := [][]byte{
{0x20, 0x02, 0x00, 0x00},
{0x90, 0x03, 0x00, 0x01, 0x00},
{0xB0, 0x02, 0x00, 0x02},
{0xD0, 0x00},
{},
{0x40, 0x02, 0x00, 0x04},
{0x50, 0x02, 0x00, 0x05},
{0x70, 0x02, 0x00, 0x05},
{},
}
reqs := []func(ctx context.Context, cli *BaseClient) error{
func(ctx context.Context, cli *BaseClient) error {
_, err := cli.Connect(ctx, "cli")
cli.idLast = 0
return err
},
func(ctx context.Context, cli *BaseClient) error {
return cli.Subscribe(ctx, Subscription{Topic: "test"})
},
func(ctx context.Context, cli *BaseClient) error {
return cli.Unsubscribe(ctx, "test")
},
func(ctx context.Context, cli *BaseClient) error {
return cli.Ping(ctx)
},
func(ctx context.Context, cli *BaseClient) error {
return cli.Publish(ctx, &Message{QoS: QoS0})
},
func(ctx context.Context, cli *BaseClient) error {
return cli.Publish(ctx, &Message{QoS: QoS1})
},
func(ctx context.Context, cli *BaseClient) error {
return cli.Publish(ctx, &Message{QoS: QoS2})
},
func(ctx context.Context, cli *BaseClient) error {
return cli.Disconnect(ctx)
},
}

cases := []struct {
closeAt int
errorAt int
}{
{0, 0}, // CONNECT
{1, 1}, // SUBSCRIBE
{2, 2}, // UNSUBSCRIBE
{3, 3}, // PINGREQ
{4, 4}, // PUBLISH QoS0
{5, 5}, // PUBLISH QoS1
{6, 6}, // PUBLISH QoS2 PUBREC
{7, 6}, // PUBLISH QoS2 PUBCOMP
{8, 7}, // DISCONNECT
}
for _, c := range cases {
t.Run(fmt.Sprintf("CloseAt%dErrorAt%d", c.closeAt, c.errorAt), func(t *testing.T) {
ca, cb := net.Pipe()
cli := &BaseClient{Transport: ca}

go func() {
defer cb.Close()
for i, resp := range resps {
if i == c.closeAt {
return
}
if _, _, _, err := readPacket(cb); err != nil {
t.Errorf("Unexpected error: %v", err)
return
}
if len(resp) > 0 {
io.Copy(cb, bytes.NewReader(resp))
}
}
}()

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
for i, req := range reqs {
err := req(ctx, cli)
if i == c.errorAt {
if !errs.Is(err, io.ErrClosedPipe) {
t.Errorf("Expected error: '%v', got: '%v'", io.ErrClosedPipe, err)
}
break
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
}
}
})
}
}
46 changes: 27 additions & 19 deletions serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,37 @@ import (
"io"
)

func readPacket(r io.Reader) (packetType, byte, []byte, error) {
pktTypeBytes := make([]byte, 1)
if _, err := io.ReadFull(r, pktTypeBytes); err != nil {
return 0, 0, nil, err
}
pktType := packetType(pktTypeBytes[0] & 0xF0)
pktFlag := pktTypeBytes[0] & 0x0F
var remainingLength int
for shift := uint(0); ; shift += 7 {
b := make([]byte, 1)
if _, err := io.ReadFull(r, b); err != nil {
return 0, 0, nil, err
}
remainingLength |= (int(b[0]) & 0x7F) << shift
if !(b[0]&0x80 != 0) {
break
}
}
contents := make([]byte, remainingLength)
if _, err := io.ReadFull(r, contents); err != nil {
return 0, 0, nil, err
}
return pktType, pktFlag, contents, nil
}

func (c *BaseClient) serve() error {
r := c.Transport
subBuffer := make(map[uint16]*Message)
for {
pktTypeBytes := make([]byte, 1)
if _, err := io.ReadFull(r, pktTypeBytes); err != nil {
return err
}
pktType := packetType(pktTypeBytes[0] & 0xF0)
pktFlag := pktTypeBytes[0] & 0x0F
var remainingLength int
for shift := uint(0); ; shift += 7 {
b := make([]byte, 1)
if _, err := io.ReadFull(r, b); err != nil {
return err
}
remainingLength |= (int(b[0]) & 0x7F) << shift
if !(b[0]&0x80 != 0) {
break
}
}
contents := make([]byte, remainingLength)
if _, err := io.ReadFull(r, contents); err != nil {
pktType, pktFlag, contents, err := readPacket(r)
if err != nil {
return err
}
// fmt.Printf("%s: %v\n", pktType, contents)
Expand Down

0 comments on commit 538d555

Please sign in to comment.