Skip to content

Commit

Permalink
Refactoring (#34)
Browse files Browse the repository at this point in the history
* Refactoring
* Fix example
* Add tests
* Fix default option handling
* Fix paho.IsConnected
  • Loading branch information
at-wat authored Dec 24, 2019
1 parent 21e2736 commit f031e20
Show file tree
Hide file tree
Showing 8 changed files with 273 additions and 86 deletions.
5 changes: 5 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ script:
-tags integration \
-race -coverprofile=coverage.txt -covermode=atomic
- (cd paho; go vet ./...)
- |
(cd paho; go test $(go list ./...) \
-v \
-tags integration \
-race -coverprofile=coverage.txt -covermode=atomic)
after_script:
- docker-compose down
Expand Down
8 changes: 8 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ type Dialer interface {
Dial() (ClientCloser, error)
}

// DialerFunc type is an adapter to use functions as MQTT connection dialer.
type DialerFunc func() (ClientCloser, error)

// Dial calls d().
func (d DialerFunc) Dial() (ClientCloser, error) {
return d()
}

// Dial creates connection using its values.
func (d *URLDialer) Dial() (ClientCloser, error) {
return Dial(d.URL, d.Options...)
Expand Down
8 changes: 1 addition & 7 deletions examples/wss-presign-url/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,6 @@ import (
"github.com/at-wat/mqtt-go"
)

type dialerFunc func() (mqtt.ClientCloser, error)

func (d dialerFunc) Dial() (mqtt.ClientCloser, error) {
return d()
}

func main() {
if len(os.Args) < 2 {
fmt.Printf("usage: %s server-host.domain\n", os.Args[0])
Expand All @@ -44,7 +38,7 @@ func main() {

cli, err := mqtt.NewReconnectClient(ctx,
// Dialer to connect/reconnect to the server.
dialerFunc(func() (mqtt.ClientCloser, error) {
mqtt.DialerFunc(func() (mqtt.ClientCloser, error) {
// Presign URL here.
url := fmt.Sprintf("wss://%s:9443?token=%x",
host, time.Now().UnixNano(),
Expand Down
61 changes: 61 additions & 0 deletions paho/message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright 2019 The mqtt-go authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package mqtt

import (
"github.com/at-wat/mqtt-go"
paho "github.com/eclipse/paho.mqtt.golang"
)

func (c *pahoWrapper) wrapMessageHandler(h paho.MessageHandler) mqtt.Handler {
return mqtt.HandlerFunc(func(m *mqtt.Message) {
h(c, wrapMessage(m))
})
}

func wrapMessage(msg *mqtt.Message) paho.Message {
return &wrappedMessage{msg}
}

type wrappedMessage struct {
*mqtt.Message
}

func (m *wrappedMessage) Duplicate() bool {
return m.Message.Dup
}

func (m *wrappedMessage) Qos() byte {
return byte(m.Message.QoS)
}

func (m *wrappedMessage) Retained() bool {
return m.Message.Retain
}

func (m *wrappedMessage) Topic() string {
return m.Message.Topic
}

func (m *wrappedMessage) MessageID() uint16 {
return m.Message.ID
}

func (m *wrappedMessage) Payload() []byte {
return m.Message.Payload
}

func (m *wrappedMessage) Ack() {
}
49 changes: 49 additions & 0 deletions paho/message_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package mqtt

import (
"bytes"
"testing"

"github.com/at-wat/mqtt-go"
paho "github.com/eclipse/paho.mqtt.golang"
)

func TestWrapMessageHandler(t *testing.T) {
msg := make(chan paho.Message, 100)
ph := func(c paho.Client, m paho.Message) {
msg <- m
}
h := (&pahoWrapper{}).wrapMessageHandler(ph)

h.Serve(&mqtt.Message{
Topic: "topic",
QoS: mqtt.QoS1,
Payload: []byte{0x01, 0x02},
Dup: true,
Retain: true,
ID: 0x1234,
})

if len(msg) != 1 {
t.Fatalf("Expected number of handled messages: 1, got: %d", len(msg))
}
m := <-msg
if m.Topic() != "topic" {
t.Errorf("Expected topic: 'topic', got: '%s'", m.Topic())
}
if m.Qos() != 1 {
t.Errorf("Expected QoS: 1, got: %d", m.Qos())
}
if !bytes.Equal([]byte{0x01, 0x02}, m.Payload()) {
t.Errorf("Expected payload: [1, 2], got: %v", m.Payload())
}
if !m.Duplicate() {
t.Errorf("Expected dup: true, got: %v", m.Duplicate())
}
if !m.Retained() {
t.Errorf("Expected retain: true, got: %v", m.Retained())
}
if m.MessageID() != 0x1234 {
t.Errorf("Expected ID: 1234, got: %x", m.MessageID())
}
}
93 changes: 14 additions & 79 deletions paho/paho.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,87 +50,14 @@ func NewClient(o *paho.ClientOptions) paho.Client {
return w
}

type token struct {
err error
done chan struct{}
}

func newToken() *token {
return &token{
done: make(chan struct{}),
}
}

func (t *token) release() {
close(t.done)
}

func (t *token) Wait() bool {
<-t.done
return true
}

func (t *token) WaitTimeout(d time.Duration) bool {
select {
case <-t.done:
return true
case <-time.After(d):
return false
}
}

func (t *token) Error() error {
return t.err
}

type wrappedMessage struct {
*mqtt.Message
}

func (m *wrappedMessage) Duplicate() bool {
return m.Message.Dup
}

func (m *wrappedMessage) Qos() byte {
return byte(m.Message.QoS)
}

func (m *wrappedMessage) Retained() bool {
return m.Message.Retain
}

func (m *wrappedMessage) Topic() string {
return m.Message.Topic
}

func (m *wrappedMessage) MessageID() uint16 {
return m.Message.ID
}

func (m *wrappedMessage) Payload() []byte {
return m.Message.Payload
}

func (m *wrappedMessage) Ack() {
}

func wrapMessage(msg *mqtt.Message) paho.Message {
return &wrappedMessage{msg}
}

func (c *pahoWrapper) wrapMessageHandler(h paho.MessageHandler) mqtt.Handler {
return mqtt.HandlerFunc(func(m *mqtt.Message) {
h(c, wrapMessage(m))
})
}

func (c *pahoWrapper) IsConnected() bool {
select {
case <-c.cli.Done():
default:
if c.cli.Err() != nil {
if c.cli.Err() == nil {
return true
}
default:
return true
}
return false
}
Expand Down Expand Up @@ -159,9 +86,13 @@ func (c *pahoWrapper) Connect() paho.Token {
cli.ConnState = func(s mqtt.ConnState, err error) {
switch s {
case mqtt.StateActive:
c.pahoConfig.OnConnect(c)
if c.pahoConfig.OnConnect != nil {
c.pahoConfig.OnConnect(c)
}
case mqtt.StateClosed:
c.pahoConfig.OnConnectionLost(c, err)
if c.pahoConfig.OnConnectionLost != nil {
c.pahoConfig.OnConnectionLost(c, err)
}
}
}
cli.Handle(c.serveMux)
Expand All @@ -173,7 +104,11 @@ func (c *pahoWrapper) Connect() paho.Token {
mqtt.WithUserNamePassword(c.pahoConfig.Username, c.pahoConfig.Password),
mqtt.WithCleanSession(c.pahoConfig.CleanSession),
mqtt.WithKeepAlive(uint16(c.pahoConfig.KeepAlive)),
mqtt.WithProtocolLevel(mqtt.ProtocolLevel(c.pahoConfig.ProtocolVersion)),
}
if c.pahoConfig.ProtocolVersion > 0 {
opts = append(opts,
mqtt.WithProtocolLevel(mqtt.ProtocolLevel(c.pahoConfig.ProtocolVersion)),
)
}
if c.pahoConfig.WillEnabled {
opts = append(opts, mqtt.WithWill(&mqtt.Message{
Expand Down
83 changes: 83 additions & 0 deletions paho/paho_integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// +build integration

// Copyright 2019 The mqtt-go authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package mqtt

import (
"bytes"
"net/url"
"testing"
"time"

paho "github.com/eclipse/paho.mqtt.golang"
)

func TestIntegration_PublishSubscribe(t *testing.T) {
opts := paho.NewClientOptions()
server, err := url.Parse("mqtt://localhost:1883")
if err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
opts.Servers = []*url.URL{server}
opts.AutoReconnect = false
opts.ClientID = "PahoWrapper"
opts.KeepAlive = 0

cli := NewClient(opts)
token := cli.Connect()
if !token.WaitTimeout(5 * time.Second) {
t.Fatal("Connect timeout")
}
msg := make(chan paho.Message, 100)
token = cli.Subscribe("paho", 1, func(c paho.Client, m paho.Message) {
msg <- m
})
if !token.WaitTimeout(5 * time.Second) {
t.Fatal("Subscribe timeout")
}
token = cli.Publish("paho", 1, false, []byte{0x12})
if !token.WaitTimeout(5 * time.Second) {
t.Fatal("Publish timeout")
}

if !cli.IsConnected() {
t.Error("Not connected")
}
if !cli.IsConnectionOpen() {
t.Error("Not connection open")
}

select {
case m := <-msg:
if m.Topic() != "paho" {
t.Errorf("Expected topic: 'topic', got: %s", m.Topic())
}
if !bytes.Equal(m.Payload(), []byte{0x12}) {
t.Errorf("Expected payload: [18], got: %v", m.Payload())
}
case <-time.After(5 * time.Second):
t.Fatal("Message timeout")
}
cli.Disconnect(10)
time.Sleep(time.Second)

if cli.IsConnected() {
t.Error("Connected after disconnect")
}
if cli.IsConnectionOpen() {
t.Error("Connection open after disconnect")
}
}
Loading

0 comments on commit f031e20

Please sign in to comment.