Skip to content

Commit

Permalink
Support reconnect in paho wrapper (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
at-wat authored Dec 24, 2019
1 parent f031e20 commit 5f1371a
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 103 deletions.
183 changes: 129 additions & 54 deletions paho/paho.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ import (
var errNotConnected = errors.New("not connected")

type pahoWrapper struct {
cli mqtt.ClientCloser
cli mqtt.Client
cliCloser mqtt.Closer
serveMux *mqtt.ServeMux
pahoConfig *paho.ClientOptions
mu sync.Mutex
Expand All @@ -43,17 +44,20 @@ func NewClient(o *paho.ClientOptions) paho.Client {
if len(o.Servers) != 1 {
panic("unsupported number of servers")
}
if o.AutoReconnect {
panic("paho style auto-reconnect is not supported")
}

return w
}

func (c *pahoWrapper) IsConnected() bool {
c.mu.Lock()
cli := c.cliCloser
c.mu.Unlock()
if cli == nil {
return false
}
select {
case <-c.cli.Done():
if c.cli.Err() == nil {
case <-cli.Done():
if cli.Err() == nil {
return true
}
default:
Expand All @@ -63,79 +67,150 @@ func (c *pahoWrapper) IsConnected() bool {
}

func (c *pahoWrapper) IsConnectionOpen() bool {
c.mu.Lock()
cli := c.cliCloser
c.mu.Unlock()
if cli == nil {
return false
}
select {
case <-c.cli.Done():
case <-cli.Done():
default:
return true
}
return false
}

func (c *pahoWrapper) Connect() paho.Token {
opts := []mqtt.ConnectOption{
mqtt.WithUserNamePassword(c.pahoConfig.Username, c.pahoConfig.Password),
mqtt.WithCleanSession(c.pahoConfig.CleanSession),
mqtt.WithKeepAlive(uint16(c.pahoConfig.KeepAlive)),
}
if c.pahoConfig.ProtocolVersion > 0 {
opts = append(opts,
mqtt.WithProtocolLevel(mqtt.ProtocolLevel(c.pahoConfig.ProtocolVersion)),
)
}
if c.pahoConfig.WillEnabled {
opts = append(opts, mqtt.WithWill(&mqtt.Message{
Topic: c.pahoConfig.WillTopic,
Payload: c.pahoConfig.WillPayload,
QoS: mqtt.QoS(c.pahoConfig.WillQos),
Retain: c.pahoConfig.WillRetained,
}))
}
if c.pahoConfig.AutoReconnect {
return c.connectRetry(opts)
}
return c.connectOnce(opts)
}

func (c *pahoWrapper) connectRetry(opts []mqtt.ConnectOption) paho.Token {
token := newToken()
go func() {
cli, err := mqtt.Dial(
c.pahoConfig.Servers[0].String(),
mqtt.WithTLSConfig(c.pahoConfig.TLSConfig),
pingInterval := time.Duration(c.pahoConfig.KeepAlive) * time.Second

cli, err := mqtt.NewReconnectClient(context.Background(),
mqtt.DialerFunc(func() (mqtt.ClientCloser, error) {
cb, err := mqtt.Dial(c.pahoConfig.Servers[0].String(),
mqtt.WithTLSConfig(c.pahoConfig.TLSConfig),
)
if err != nil {
return nil, err
}
cb.ConnState = func(s mqtt.ConnState, err error) {
switch s {
case mqtt.StateActive:
if c.pahoConfig.OnConnect != nil {
c.pahoConfig.OnConnect(c)
}
case mqtt.StateClosed:
if c.pahoConfig.OnConnectionLost != nil {
c.pahoConfig.OnConnectionLost(c, err)
}
}
}
c.mu.Lock()
c.cliCloser = cb
c.mu.Unlock()
return cb, err
}),
c.pahoConfig.ClientID,
mqtt.WithConnectOption(opts...),
mqtt.WithPingInterval(pingInterval),
mqtt.WithTimeout(c.pahoConfig.PingTimeout),
mqtt.WithReconnectWait(
time.Second, // c.pahoConfig.ConnectRetryInterval,
10*time.Second, // c.pahoConfig.MaxReconnectInterval,
),
)
if err != nil {
token.err = err
token.release()
return
}
cli.ConnState = func(s mqtt.ConnState, err error) {
switch s {
case mqtt.StateActive:
if c.pahoConfig.OnConnect != nil {
c.pahoConfig.OnConnect(c)
}
case mqtt.StateClosed:
if c.pahoConfig.OnConnectionLost != nil {
c.pahoConfig.OnConnectionLost(c, err)
}
}
}
cli.Handle(c.serveMux)
c.mu.Lock()
c.cli = cli
c.mu.Unlock()

opts := []mqtt.ConnectOption{
mqtt.WithUserNamePassword(c.pahoConfig.Username, c.pahoConfig.Password),
mqtt.WithCleanSession(c.pahoConfig.CleanSession),
mqtt.WithKeepAlive(uint16(c.pahoConfig.KeepAlive)),
}
if c.pahoConfig.ProtocolVersion > 0 {
opts = append(opts,
mqtt.WithProtocolLevel(mqtt.ProtocolLevel(c.pahoConfig.ProtocolVersion)),
token.release()
}()
return token
}

func (c *pahoWrapper) connectOnce(opts []mqtt.ConnectOption) paho.Token {
token := newToken()
go func() {
for { // Connect retry loop
cli, err := mqtt.Dial(
c.pahoConfig.Servers[0].String(),
mqtt.WithTLSConfig(c.pahoConfig.TLSConfig),
)
}
if c.pahoConfig.WillEnabled {
opts = append(opts, mqtt.WithWill(&mqtt.Message{
Topic: c.pahoConfig.WillTopic,
Payload: c.pahoConfig.WillPayload,
QoS: mqtt.QoS(c.pahoConfig.WillQos),
Retain: c.pahoConfig.WillRetained,
}))
}
_, token.err = c.cli.Connect(context.Background(), c.pahoConfig.ClientID, opts...)
if token.err == nil {
if c.pahoConfig.KeepAlive > 0 {
// Start keep alive.
go func() {
timeout := c.pahoConfig.PingTimeout
if timeout < time.Second {
timeout = time.Second
if err != nil {
// if c.pahoConfig.ConnectRetry {
// time.Sleep(c.pahoConfig.ConnectRetryInterval)
// continue
// }
token.err = err
token.release()
return
}
cli.ConnState = func(s mqtt.ConnState, err error) {
switch s {
case mqtt.StateActive:
if c.pahoConfig.OnConnect != nil {
c.pahoConfig.OnConnect(c)
}
case mqtt.StateClosed:
if c.pahoConfig.OnConnectionLost != nil {
c.pahoConfig.OnConnectionLost(c, err)
}
_ = mqtt.KeepAlive(
context.Background(), cli,
time.Duration(c.pahoConfig.KeepAlive)*time.Second,
timeout,
)
}()
}
}
cli.Handle(c.serveMux)
c.mu.Lock()
c.cli = cli
c.cliCloser = cli
c.mu.Unlock()

_, token.err = c.cli.Connect(context.Background(), c.pahoConfig.ClientID, opts...)
if token.err == nil {
if c.pahoConfig.KeepAlive > 0 {
// Start keep alive.
go func() {
_ = mqtt.KeepAlive(
context.Background(), cli,
time.Duration(c.pahoConfig.KeepAlive)*time.Second,
c.pahoConfig.PingTimeout,
)
}()
}
}
token.release()
return
}
token.release()
}()
return token
}
Expand Down
103 changes: 54 additions & 49 deletions paho/paho_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,58 +26,63 @@ import (
)

func TestIntegration_PublishSubscribe(t *testing.T) {
opts := paho.NewClientOptions()
server, err := url.Parse("mqtt://localhost:1883")
if err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
opts.Servers = []*url.URL{server}
opts.AutoReconnect = false
opts.ClientID = "PahoWrapper"
opts.KeepAlive = 0
for name, recon := range map[string]bool{"Reconnect": true, "NoReconnect": false} {
t.Run(name, func(t *testing.T) {
opts := paho.NewClientOptions()
server, err := url.Parse("mqtt://localhost:1883")
if err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
opts.Servers = []*url.URL{server}
opts.AutoReconnect = recon
opts.ClientID = "PahoWrapper"
opts.KeepAlive = 0

cli := NewClient(opts)
token := cli.Connect()
if !token.WaitTimeout(5 * time.Second) {
t.Fatal("Connect timeout")
}
msg := make(chan paho.Message, 100)
token = cli.Subscribe("paho", 1, func(c paho.Client, m paho.Message) {
msg <- m
})
if !token.WaitTimeout(5 * time.Second) {
t.Fatal("Subscribe timeout")
}
token = cli.Publish("paho", 1, false, []byte{0x12})
if !token.WaitTimeout(5 * time.Second) {
t.Fatal("Publish timeout")
}
cli := NewClient(opts)
token := cli.Connect()
if !token.WaitTimeout(5 * time.Second) {
t.Fatal("Connect timeout")
}

if !cli.IsConnected() {
t.Error("Not connected")
}
if !cli.IsConnectionOpen() {
t.Error("Not connection open")
}
msg := make(chan paho.Message, 100)
token = cli.Subscribe("paho"+name, 1, func(c paho.Client, m paho.Message) {
msg <- m
})
if !token.WaitTimeout(5 * time.Second) {
t.Fatal("Subscribe timeout")
}
token = cli.Publish("paho"+name, 1, false, []byte{0x12})
if !token.WaitTimeout(5 * time.Second) {
t.Fatal("Publish timeout")
}

select {
case m := <-msg:
if m.Topic() != "paho" {
t.Errorf("Expected topic: 'topic', got: %s", m.Topic())
}
if !bytes.Equal(m.Payload(), []byte{0x12}) {
t.Errorf("Expected payload: [18], got: %v", m.Payload())
}
case <-time.After(5 * time.Second):
t.Fatal("Message timeout")
}
cli.Disconnect(10)
time.Sleep(time.Second)
if !cli.IsConnected() {
t.Error("Not connected")
}
if !cli.IsConnectionOpen() {
t.Error("Not connection open")
}

if cli.IsConnected() {
t.Error("Connected after disconnect")
}
if cli.IsConnectionOpen() {
t.Error("Connection open after disconnect")
select {
case m := <-msg:
if m.Topic() != "paho"+name {
t.Errorf("Expected topic: 'topic%s', got: %s", name, m.Topic())
}
if !bytes.Equal(m.Payload(), []byte{0x12}) {
t.Errorf("Expected payload: [18], got: %v", m.Payload())
}
case <-time.After(5 * time.Second):
t.Errorf("Message timeout")
}
cli.Disconnect(10)
time.Sleep(time.Second)

if cli.IsConnected() {
t.Error("Connected after disconnect")
}
if cli.IsConnectionOpen() {
t.Error("Connection open after disconnect")
}
})
}
}

0 comments on commit 5f1371a

Please sign in to comment.