diff --git a/client_test.go b/client_test.go index 5207d8c..d902bf2 100644 --- a/client_test.go +++ b/client_test.go @@ -15,54 +15,12 @@ package mqtt import ( - "bytes" "context" - "errors" "net" "testing" "time" ) -func TestConnect(t *testing.T) { - ca, cb := net.Pipe() - cli := &BaseClient{Transport: cb} - - go func() { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - _, _ = cli.Connect(ctx, "cli", - WithUserNamePassword("user", "pass"), - WithKeepAlive(0x0123), - WithCleanSession(true), - WithProtocolLevel(ProtocolLevel4), - WithWill(&Message{QoS: QoS1, Topic: "topic", Payload: []byte{0x01}}), - ) - }() - - b := make([]byte, 100) - n, err := ca.Read(b) - if err != nil { - t.Fatalf("Unexpected error: '%v'", err) - } - - expected := []byte{ - 0x10, // CONNECT - 0x25, - 0x00, 0x04, 0x4D, 0x51, 0x54, 0x54, // MQTT - 0x04, // 3.1.1 - 0xCE, 0x01, 0x23, // flags, keepalive - 0x00, 0x03, 0x63, 0x6C, 0x69, // cli - 0x00, 0x05, 0x74, 0x6F, 0x70, 0x69, 0x63, // topic - 0x00, 0x01, 0x01, // payload - 0x00, 0x04, 0x75, 0x73, 0x65, 0x72, // user - 0x00, 0x04, 0x70, 0x61, 0x73, 0x73, // pass - } - if !bytes.Equal(expected, b[:n]) { - t.Fatalf("Expected CONNECT packet: \n '%v',\ngot: \n '%v'", expected, b[:n]) - } - cli.Close() -} - func TestProtocolViolation(t *testing.T) { ca, cb := net.Pipe() cli := &BaseClient{Transport: cb} @@ -112,19 +70,3 @@ func TestProtocolViolation(t *testing.T) { t.Error("Timeout") } } - -func TestConnect_OptionsError(t *testing.T) { - errExpected := errors.New("an error") - sessionPresent, err := (&BaseClient{}).Connect( - context.Background(), "cli", - func(*ConnectOptions) error { - return errExpected - }, - ) - if err != errExpected { - t.Errorf("Expected error: ''%v'', got: ''%v''", errExpected, err) - } - if sessionPresent { - t.Errorf("SessionPresent flag must not be set on options error") - } -} diff --git a/connect.go b/connect.go index dc740ba..cc560d5 100644 --- a/connect.go +++ b/connect.go @@ -41,82 +41,102 @@ const ( connectFlagUserName connectFlag = 0x80 ) -// Connect to the broker. -func (c *BaseClient) Connect(ctx context.Context, clientID string, opts ...ConnectOption) (sessionPresent bool, err error) { - o := &ConnectOptions{ - ProtocolLevel: ProtocolLevel4, - } - for _, opt := range opts { - if err := opt(o); err != nil { - return false, err - } - } - c.mu.Lock() - c.sig = &signaller{} - c.connClosed = make(chan struct{}) - c.initID() - c.mu.Unlock() +type pktConnect struct { + ProtocolLevel ProtocolLevel + CleanSession bool + KeepAlive uint16 + ClientID string + UserName string + Password string + Will *Message +} - go func() { - err := c.serve() - if errConn := c.Close(); errConn != nil && err == nil { - err = errConn - } - c.mu.Lock() - if c.connState != StateDisconnected { - c.err = err - } - c.mu.Unlock() - c.connStateUpdate(StateClosed) - close(c.connClosed) - }() - payload := packString(clientID) +func (p *pktConnect) pack() []byte { + payload := packString(p.ClientID) var flag byte - if o.CleanSession { + if p.CleanSession { flag |= byte(connectFlagCleanSession) } - if o.Will != nil { + if p.Will != nil { flag |= byte(connectFlagWill) - switch o.Will.QoS { + switch p.Will.QoS { case QoS0: flag |= byte(connectFlagWillQoS0) case QoS1: flag |= byte(connectFlagWillQoS1) case QoS2: flag |= byte(connectFlagWillQoS2) - default: - panic("invalid QoS") } - if o.Will.Retain { + if p.Will.Retain { flag |= byte(connectFlagWillRetain) } - payload = append(payload, packString(o.Will.Topic)...) - payload = append(payload, packBytes(o.Will.Payload)...) + payload = append(payload, packString(p.Will.Topic)...) + payload = append(payload, packBytes(p.Will.Payload)...) } - if o.UserName != "" { + if p.UserName != "" { flag |= byte(connectFlagUserName) - payload = append(payload, packString(o.UserName)...) + payload = append(payload, packString(p.UserName)...) } - if o.Password != "" { + if p.Password != "" { flag |= byte(connectFlagPassword) - payload = append(payload, packString(o.Password)...) + payload = append(payload, packString(p.Password)...) } - pkt := pack( + return pack( packetConnect.b(), []byte{ 0x00, 0x04, 0x4D, 0x51, 0x54, 0x54, - byte(o.ProtocolLevel), + byte(p.ProtocolLevel), flag, }, - packUint16(o.KeepAlive), + packUint16(p.KeepAlive), payload, ) +} + +// Connect to the broker. +func (c *BaseClient) Connect(ctx context.Context, clientID string, opts ...ConnectOption) (sessionPresent bool, err error) { + o := &ConnectOptions{ + ProtocolLevel: ProtocolLevel4, + } + for _, opt := range opts { + if err := opt(o); err != nil { + return false, err + } + } + c.sig = &signaller{} + c.connClosed = make(chan struct{}) + c.initID() + + go func() { + err := c.serve() + if errConn := c.Close(); errConn != nil && err == nil { + err = errConn + } + c.mu.Lock() + if c.connState != StateDisconnected { + c.err = err + } + c.mu.Unlock() + c.connStateUpdate(StateClosed) + close(c.connClosed) + }() chConnAck := make(chan *pktConnAck, 1) c.mu.Lock() c.sig.chConnAck = chConnAck c.mu.Unlock() + + pkt := (&pktConnect{ + ProtocolLevel: o.ProtocolLevel, + CleanSession: o.CleanSession, + KeepAlive: o.KeepAlive, + ClientID: clientID, + UserName: o.UserName, + Password: o.Password, + Will: o.Will, + }).pack() + if err := c.write(pkt); err != nil { return false, err } @@ -175,6 +195,11 @@ func WithCleanSession(cleanSession bool) ConnectOption { // WithWill sets will message. func WithWill(will *Message) ConnectOption { return func(o *ConnectOptions) error { + switch will.QoS { + case QoS0, QoS1, QoS2: + default: + return ErrInvalidPacket + } o.Will = will return nil } diff --git a/connect_test.go b/connect_test.go new file mode 100644 index 0000000..d3fc0d1 --- /dev/null +++ b/connect_test.go @@ -0,0 +1,139 @@ +// Copyright 2019 The mqtt-go authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mqtt + +import ( + "bytes" + "context" + "errors" + "net" + "testing" + "time" +) + +func TestConnect(t *testing.T) { + cases := map[string]struct { + opts []ConnectOption + expected []byte + }{ + "UserPassCleanWill": { + opts: []ConnectOption{ + WithUserNamePassword("user", "pass"), + WithKeepAlive(0x0123), + WithCleanSession(true), + WithProtocolLevel(ProtocolLevel4), + WithWill(&Message{QoS: QoS1, Topic: "topic", Payload: []byte{0x01}}), + }, + expected: []byte{ + 0x10, // CONNECT + 0x25, + 0x00, 0x04, 0x4D, 0x51, 0x54, 0x54, // MQTT + 0x04, // 3.1.1 + 0xCE, 0x01, 0x23, // flags, keepalive + 0x00, 0x03, 0x63, 0x6C, 0x69, // cli + 0x00, 0x05, 0x74, 0x6F, 0x70, 0x69, 0x63, // topic + 0x00, 0x01, 0x01, // payload + 0x00, 0x04, 0x75, 0x73, 0x65, 0x72, // user + 0x00, 0x04, 0x70, 0x61, 0x73, 0x73, // pass + }, + }, + "WillQoS0": { + opts: []ConnectOption{ + WithKeepAlive(0x0123), + WithWill(&Message{QoS: QoS0, Topic: "topic", Payload: []byte{0x01}}), + }, + expected: []byte{ + 0x10, // CONNECT + 0x19, + 0x00, 0x04, 0x4D, 0x51, 0x54, 0x54, // MQTT + 0x04, // 3.1.1 + 0x04, 0x01, 0x23, // flags, keepalive + 0x00, 0x03, 0x63, 0x6C, 0x69, // cli + 0x00, 0x05, 0x74, 0x6F, 0x70, 0x69, 0x63, // topic + 0x00, 0x01, 0x01, // payload + }, + }, + "WillQoS2Retain": { + opts: []ConnectOption{ + WithKeepAlive(0x0123), + WithWill(&Message{QoS: QoS2, Retain: true, Topic: "topic", Payload: []byte{0x01}}), + }, + expected: []byte{ + 0x10, // CONNECT + 0x19, + 0x00, 0x04, 0x4D, 0x51, 0x54, 0x54, // MQTT + 0x04, // 3.1.1 + 0x34, 0x01, 0x23, // flags, keepalive + 0x00, 0x03, 0x63, 0x6C, 0x69, // cli + 0x00, 0x05, 0x74, 0x6F, 0x70, 0x69, 0x63, // topic + 0x00, 0x01, 0x01, // payload + }, + }, + "ProtocolLv3": { + opts: []ConnectOption{ + WithKeepAlive(0x0123), + WithProtocolLevel(ProtocolLevel3), + }, + expected: []byte{ + 0x10, // CONNECT + 0x0F, + 0x00, 0x04, 0x4D, 0x51, 0x54, 0x54, // MQTT + 0x03, // 3.1.1 + 0x00, 0x01, 0x23, // flags, keepalive + 0x00, 0x03, 0x63, 0x6C, 0x69, // cli + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + ca, cb := net.Pipe() + cli := &BaseClient{Transport: cb} + + go func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, _ = cli.Connect(ctx, "cli", c.opts...) + }() + + b := make([]byte, 100) + n, err := ca.Read(b) + if err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + + if !bytes.Equal(c.expected, b[:n]) { + t.Fatalf("Expected CONNECT packet: \n '%v',\ngot: \n '%v'", c.expected, b[:n]) + } + cli.Close() + }) + } +} + +func TestConnect_OptionsError(t *testing.T) { + errExpected := errors.New("an error") + sessionPresent, err := (&BaseClient{}).Connect( + context.Background(), "cli", + func(*ConnectOptions) error { + return errExpected + }, + ) + if err != errExpected { + t.Errorf("Expected error: ''%v'', got: ''%v''", errExpected, err) + } + if sessionPresent { + t.Errorf("SessionPresent flag must not be set on options error") + } +} diff --git a/disconnect.go b/disconnect.go index 774b1c3..f08eb88 100644 --- a/disconnect.go +++ b/disconnect.go @@ -20,11 +20,7 @@ import ( // Disconnect from the broker. func (c *BaseClient) Disconnect(ctx context.Context) error { - pkt := pack( - packetDisconnect.b(), - []byte{}, - []byte{}, - ) + pkt := pack(packetDisconnect.b()) c.connStateUpdate(StateDisconnected) if err := c.write(pkt); err != nil { return err diff --git a/pingreq.go b/pingreq.go index a29b867..e91a1ff 100644 --- a/pingreq.go +++ b/pingreq.go @@ -20,13 +20,12 @@ import ( // Ping to the broker. func (c *BaseClient) Ping(ctx context.Context) error { - pkt := pack(packetPingReq.b()) - chPingResp := make(chan *pktPingResp, 1) c.mu.Lock() c.sig.chPingResp = chPingResp c.mu.Unlock() + pkt := pack(packetPingReq.b()) if err := c.write(pkt); err != nil { return err } diff --git a/puback.go b/puback.go index 368782c..cc4790b 100644 --- a/puback.go +++ b/puback.go @@ -28,3 +28,10 @@ func (p *pktPubAck) parse(flag byte, contents []byte) (*pktPubAck, error) { _, p.ID = unpackUint16(contents) return p, nil } + +func (p *pktPubAck) pack() []byte { + return pack( + packetPubAck.b()|packetFromClient.b(), + packUint16(p.ID), + ) +} diff --git a/pubcomp.go b/pubcomp.go index 0fb3d2b..022adc6 100644 --- a/pubcomp.go +++ b/pubcomp.go @@ -28,3 +28,10 @@ func (p *pktPubComp) parse(flag byte, contents []byte) (*pktPubComp, error) { _, p.ID = unpackUint16(contents) return p, nil } + +func (p *pktPubComp) pack() []byte { + return pack( + packetPubComp.b()|packetFromClient.b(), + packUint16(p.ID), + ) +} diff --git a/publish.go b/publish.go index 7789940..7314690 100644 --- a/publish.go +++ b/publish.go @@ -29,16 +29,50 @@ const ( publishFlagDup publishFlag = 0x08 ) -// Publish a message to the broker. -// ID field of the message is filled if zero. -func (c *BaseClient) Publish(ctx context.Context, message *Message) error { +type pktPublish struct { + Message *Message +} + +func (p *pktPublish) parse(flag byte, contents []byte) (*pktPublish, error) { + p.Message = &Message{ + Dup: (publishFlag(flag) & publishFlagDup) != 0, + Retain: (publishFlag(flag) & publishFlagRetain) != 0, + } + switch publishFlag(flag) & publishFlagQoSMask { + case publishFlagQoS0: + p.Message.QoS = QoS0 + case publishFlagQoS1: + p.Message.QoS = QoS1 + case publishFlagQoS2: + p.Message.QoS = QoS2 + default: + return nil, ErrInvalidPacket + } + + var n, nID int + var err error + n, p.Message.Topic, err = unpackString(contents) + if err != nil { + return nil, err + } + if p.Message.QoS != QoS0 { + if len(contents)-n < 2 { + return nil, ErrInvalidPacketLength + } + nID, p.Message.ID = unpackUint16(contents[n:]) + } + p.Message.Payload = contents[n+nID:] + + return p, nil +} + +func (p *pktPublish) pack() []byte { pktHeader := packetPublish.b() - header := packString(message.Topic) - if message.Retain { + if p.Message.Retain { pktHeader |= byte(publishFlagRetain) } - switch message.QoS { + switch p.Message.QoS { case QoS0: pktHeader |= byte(publishFlagQoS0) case QoS1: @@ -48,17 +82,28 @@ func (c *BaseClient) Publish(ctx context.Context, message *Message) error { default: panic("invalid QoS") } - if message.Dup { + if p.Message.Dup { pktHeader |= byte(publishFlagDup) } - if message.ID == 0 { - message.ID = c.newID() + + header := packString(p.Message.Topic) + if p.Message.QoS != QoS0 { + header = append(header, packUint16(p.Message.ID)...) } - if message.QoS != QoS0 { - header = append(header, packUint16(message.ID)...) + return pack( + pktHeader, + header, + p.Message.Payload, + ) +} + +// Publish a message to the broker. +// ID field of the message is filled if zero. +func (c *BaseClient) Publish(ctx context.Context, message *Message) error { + if message.ID == 0 { + message.ID = c.newID() } - pkt := pack(pktHeader, header, message.Payload) var chPubAck chan *pktPubAck var chPubRec chan *pktPubRec @@ -87,6 +132,7 @@ func (c *BaseClient) Publish(ctx context.Context, message *Message) error { c.sig.mu.Unlock() } + pkt := (&pktPublish{Message: message}).pack() if err := c.write(pkt); err != nil { return err } @@ -107,7 +153,7 @@ func (c *BaseClient) Publish(ctx context.Context, message *Message) error { return ctx.Err() case <-chPubRec: } - pktPubRel := pack(packetPubRel.b()|packetFromClient.b(), packUint16(message.ID)) + pktPubRel := (&pktPubRel{ID: message.ID}).pack() if err := c.write(pktPubRel); err != nil { return err } @@ -121,38 +167,3 @@ func (c *BaseClient) Publish(ctx context.Context, message *Message) error { } return nil } - -type pktPublish struct { - Message -} - -func (p *pktPublish) parse(flag byte, contents []byte) (*pktPublish, error) { - p.Message.Dup = (publishFlag(flag) & publishFlagDup) != 0 - p.Message.Retain = (publishFlag(flag) & publishFlagRetain) != 0 - switch publishFlag(flag) & publishFlagQoSMask { - case publishFlagQoS0: - p.Message.QoS = QoS0 - case publishFlagQoS1: - p.Message.QoS = QoS1 - case publishFlagQoS2: - p.Message.QoS = QoS2 - default: - return nil, ErrInvalidPacket - } - - var n, nID int - var err error - n, p.Message.Topic, err = unpackString(contents) - if err != nil { - return nil, err - } - if p.Message.QoS != QoS0 { - if len(contents)-n < 2 { - return nil, ErrInvalidPacketLength - } - nID, p.Message.ID = unpackUint16(contents[n:]) - } - p.Message.Payload = contents[n+nID:] - - return p, nil -} diff --git a/pubrec.go b/pubrec.go index ce07a5b..7c2250d 100644 --- a/pubrec.go +++ b/pubrec.go @@ -28,3 +28,10 @@ func (p *pktPubRec) parse(flag byte, contents []byte) (*pktPubRec, error) { _, p.ID = unpackUint16(contents) return p, nil } + +func (p *pktPubRec) pack() []byte { + return pack( + packetPubRec.b()|packetFromClient.b(), + packUint16(p.ID), + ) +} diff --git a/pubrel.go b/pubrel.go index cf7983e..4e1580c 100644 --- a/pubrel.go +++ b/pubrel.go @@ -28,3 +28,10 @@ func (p *pktPubRel) parse(flag byte, contents []byte) (*pktPubRel, error) { _, p.ID = unpackUint16(contents) return p, nil } + +func (p *pktPubRel) pack() []byte { + return pack( + packetPubRel.b()|packetFromClient.b(), + packUint16(p.ID), + ) +} diff --git a/serve.go b/serve.go index ba358ca..85de2e5 100644 --- a/serve.go +++ b/serve.go @@ -67,7 +67,7 @@ func (c *BaseClient) serve() error { handler := c.handler c.mu.RUnlock() if handler != nil { - handler.Serve(&publish.Message) + handler.Serve(publish.Message) } case QoS1: // Ownership of the message is now transferred to the receiver. @@ -75,24 +75,18 @@ func (c *BaseClient) serve() error { handler := c.handler c.mu.RUnlock() if handler != nil { - handler.Serve(&publish.Message) + handler.Serve(publish.Message) } - pktPubAck := pack( - packetPubAck.b()|packetFromClient.b(), - packUint16(publish.Message.ID), - ) + pktPubAck := (&pktPubAck{ID: publish.Message.ID}).pack() if err := c.write(pktPubAck); err != nil { return err } case QoS2: - pktPubRec := pack( - packetPubRec.b()|packetFromClient.b(), - packUint16(publish.Message.ID), - ) + pktPubRec := (&pktPubRec{ID: publish.Message.ID}).pack() if err := c.write(pktPubRec); err != nil { return err } - subBuffer[publish.Message.ID] = &publish.Message + subBuffer[publish.Message.ID] = publish.Message } case packetPubAck: pubAck, err := (&pktPubAck{}).parse(pktFlag, contents) @@ -132,10 +126,7 @@ func (c *BaseClient) serve() error { delete(subBuffer, pubRel.ID) } - pktPubComp := pack( - packetPubComp.b()|packetFromClient.b(), - packUint16(pubRel.ID), - ) + pktPubComp := (&pktPubComp{ID: pubRel.ID}).pack() if err := c.write(pktPubComp); err != nil { return err } diff --git a/subscribe.go b/subscribe.go index cc10d4a..8e825bd 100644 --- a/subscribe.go +++ b/subscribe.go @@ -30,15 +30,14 @@ const ( subscribeFlagQoS2 subscribeFlag = 0x02 ) -// Subscribe topics. -func (c *BaseClient) Subscribe(ctx context.Context, subs ...Subscription) error { - pktHeader := byte(packetSubscribe | packetFromClient) - - id := c.newID() - header := packUint16(id) +type pktSubscribe struct { + ID uint16 + Subscriptions []Subscription +} +func (p *pktSubscribe) pack() []byte { var payload []byte - for _, sub := range subs { + for _, sub := range p.Subscriptions { payload = append(payload, packString(sub.Topic)...) var flag byte @@ -54,7 +53,16 @@ func (c *BaseClient) Subscribe(ctx context.Context, subs ...Subscription) error } payload = append(payload, flag) } - pkt := pack(pktHeader, header, payload) + return pack( + byte(packetSubscribe|packetFromClient), + packUint16(p.ID), + payload, + ) +} + +// Subscribe topics. +func (c *BaseClient) Subscribe(ctx context.Context, subs ...Subscription) error { + id := c.newID() chSubAck := make(chan *pktSubAck, 1) c.sig.mu.Lock() @@ -64,6 +72,7 @@ func (c *BaseClient) Subscribe(ctx context.Context, subs ...Subscription) error c.sig.chSubAck[id] = chSubAck c.sig.mu.Unlock() + pkt := (&pktSubscribe{ID: id, Subscriptions: subs}).pack() if err := c.write(pkt); err != nil { return err } diff --git a/unsubscribe.go b/unsubscribe.go index 2548f5a..42d6da2 100644 --- a/unsubscribe.go +++ b/unsubscribe.go @@ -18,18 +18,27 @@ import ( "context" ) -// Unsubscribe topics. -func (c *BaseClient) Unsubscribe(ctx context.Context, subs ...string) error { - pktHeader := byte(packetUnsubscribe | packetFromClient) - - id := c.newID() - header := packUint16(id) +type pktUnsubscribe struct { + ID uint16 + Topics []string +} +func (p *pktUnsubscribe) pack() []byte { var payload []byte - for _, sub := range subs { + for _, sub := range p.Topics { payload = append(payload, packString(sub)...) } - pkt := pack(pktHeader, header, payload) + + return pack( + byte(packetUnsubscribe|packetFromClient), + packUint16(p.ID), + payload, + ) +} + +// Unsubscribe topics. +func (c *BaseClient) Unsubscribe(ctx context.Context, subs ...string) error { + id := c.newID() chUnsubAck := make(chan *pktUnsubAck, 1) c.sig.mu.Lock() @@ -39,6 +48,7 @@ func (c *BaseClient) Unsubscribe(ctx context.Context, subs ...string) error { c.sig.chUnsubAck[id] = chUnsubAck c.sig.mu.Unlock() + pkt := (&pktUnsubscribe{ID: id, Topics: subs}).pack() if err := c.write(pkt); err != nil { return err }