diff --git a/paho/paho.go b/paho/paho.go index bc38d43..f986a75 100644 --- a/paho/paho.go +++ b/paho/paho.go @@ -27,7 +27,8 @@ import ( var errNotConnected = errors.New("not connected") type pahoWrapper struct { - cli mqtt.ClientCloser + cli mqtt.Client + cliCloser mqtt.Closer serveMux *mqtt.ServeMux pahoConfig *paho.ClientOptions mu sync.Mutex @@ -43,17 +44,20 @@ func NewClient(o *paho.ClientOptions) paho.Client { if len(o.Servers) != 1 { panic("unsupported number of servers") } - if o.AutoReconnect { - panic("paho style auto-reconnect is not supported") - } return w } func (c *pahoWrapper) IsConnected() bool { + c.mu.Lock() + cli := c.cliCloser + c.mu.Unlock() + if cli == nil { + return false + } select { - case <-c.cli.Done(): - if c.cli.Err() == nil { + case <-cli.Done(): + if cli.Err() == nil { return true } default: @@ -63,8 +67,14 @@ func (c *pahoWrapper) IsConnected() bool { } func (c *pahoWrapper) IsConnectionOpen() bool { + c.mu.Lock() + cli := c.cliCloser + c.mu.Unlock() + if cli == nil { + return false + } select { - case <-c.cli.Done(): + case <-cli.Done(): default: return true } @@ -72,70 +82,135 @@ func (c *pahoWrapper) IsConnectionOpen() bool { } func (c *pahoWrapper) Connect() paho.Token { + opts := []mqtt.ConnectOption{ + mqtt.WithUserNamePassword(c.pahoConfig.Username, c.pahoConfig.Password), + mqtt.WithCleanSession(c.pahoConfig.CleanSession), + mqtt.WithKeepAlive(uint16(c.pahoConfig.KeepAlive)), + } + if c.pahoConfig.ProtocolVersion > 0 { + opts = append(opts, + mqtt.WithProtocolLevel(mqtt.ProtocolLevel(c.pahoConfig.ProtocolVersion)), + ) + } + if c.pahoConfig.WillEnabled { + opts = append(opts, mqtt.WithWill(&mqtt.Message{ + Topic: c.pahoConfig.WillTopic, + Payload: c.pahoConfig.WillPayload, + QoS: mqtt.QoS(c.pahoConfig.WillQos), + Retain: c.pahoConfig.WillRetained, + })) + } + if c.pahoConfig.AutoReconnect { + return c.connectRetry(opts) + } + return c.connectOnce(opts) +} + +func (c *pahoWrapper) connectRetry(opts []mqtt.ConnectOption) paho.Token { token := newToken() go func() { - cli, err := mqtt.Dial( - c.pahoConfig.Servers[0].String(), - mqtt.WithTLSConfig(c.pahoConfig.TLSConfig), + pingInterval := time.Duration(c.pahoConfig.KeepAlive) * time.Second + + cli, err := mqtt.NewReconnectClient(context.Background(), + mqtt.DialerFunc(func() (mqtt.ClientCloser, error) { + cb, err := mqtt.Dial(c.pahoConfig.Servers[0].String(), + mqtt.WithTLSConfig(c.pahoConfig.TLSConfig), + ) + if err != nil { + return nil, err + } + cb.ConnState = func(s mqtt.ConnState, err error) { + switch s { + case mqtt.StateActive: + if c.pahoConfig.OnConnect != nil { + c.pahoConfig.OnConnect(c) + } + case mqtt.StateClosed: + if c.pahoConfig.OnConnectionLost != nil { + c.pahoConfig.OnConnectionLost(c, err) + } + } + } + c.mu.Lock() + c.cliCloser = cb + c.mu.Unlock() + return cb, err + }), + c.pahoConfig.ClientID, + mqtt.WithConnectOption(opts...), + mqtt.WithPingInterval(pingInterval), + mqtt.WithTimeout(c.pahoConfig.PingTimeout), + mqtt.WithReconnectWait( + time.Second, // c.pahoConfig.ConnectRetryInterval, + 10*time.Second, // c.pahoConfig.MaxReconnectInterval, + ), ) if err != nil { token.err = err token.release() return } - cli.ConnState = func(s mqtt.ConnState, err error) { - switch s { - case mqtt.StateActive: - if c.pahoConfig.OnConnect != nil { - c.pahoConfig.OnConnect(c) - } - case mqtt.StateClosed: - if c.pahoConfig.OnConnectionLost != nil { - c.pahoConfig.OnConnectionLost(c, err) - } - } - } cli.Handle(c.serveMux) c.mu.Lock() c.cli = cli c.mu.Unlock() - opts := []mqtt.ConnectOption{ - mqtt.WithUserNamePassword(c.pahoConfig.Username, c.pahoConfig.Password), - mqtt.WithCleanSession(c.pahoConfig.CleanSession), - mqtt.WithKeepAlive(uint16(c.pahoConfig.KeepAlive)), - } - if c.pahoConfig.ProtocolVersion > 0 { - opts = append(opts, - mqtt.WithProtocolLevel(mqtt.ProtocolLevel(c.pahoConfig.ProtocolVersion)), + token.release() + }() + return token +} + +func (c *pahoWrapper) connectOnce(opts []mqtt.ConnectOption) paho.Token { + token := newToken() + go func() { + for { // Connect retry loop + cli, err := mqtt.Dial( + c.pahoConfig.Servers[0].String(), + mqtt.WithTLSConfig(c.pahoConfig.TLSConfig), ) - } - if c.pahoConfig.WillEnabled { - opts = append(opts, mqtt.WithWill(&mqtt.Message{ - Topic: c.pahoConfig.WillTopic, - Payload: c.pahoConfig.WillPayload, - QoS: mqtt.QoS(c.pahoConfig.WillQos), - Retain: c.pahoConfig.WillRetained, - })) - } - _, token.err = c.cli.Connect(context.Background(), c.pahoConfig.ClientID, opts...) - if token.err == nil { - if c.pahoConfig.KeepAlive > 0 { - // Start keep alive. - go func() { - timeout := c.pahoConfig.PingTimeout - if timeout < time.Second { - timeout = time.Second + if err != nil { + // if c.pahoConfig.ConnectRetry { + // time.Sleep(c.pahoConfig.ConnectRetryInterval) + // continue + // } + token.err = err + token.release() + return + } + cli.ConnState = func(s mqtt.ConnState, err error) { + switch s { + case mqtt.StateActive: + if c.pahoConfig.OnConnect != nil { + c.pahoConfig.OnConnect(c) + } + case mqtt.StateClosed: + if c.pahoConfig.OnConnectionLost != nil { + c.pahoConfig.OnConnectionLost(c, err) } - _ = mqtt.KeepAlive( - context.Background(), cli, - time.Duration(c.pahoConfig.KeepAlive)*time.Second, - timeout, - ) - }() + } + } + cli.Handle(c.serveMux) + c.mu.Lock() + c.cli = cli + c.cliCloser = cli + c.mu.Unlock() + + _, token.err = c.cli.Connect(context.Background(), c.pahoConfig.ClientID, opts...) + if token.err == nil { + if c.pahoConfig.KeepAlive > 0 { + // Start keep alive. + go func() { + _ = mqtt.KeepAlive( + context.Background(), cli, + time.Duration(c.pahoConfig.KeepAlive)*time.Second, + c.pahoConfig.PingTimeout, + ) + }() + } } + token.release() + return } - token.release() }() return token } diff --git a/paho/paho_integration_test.go b/paho/paho_integration_test.go index f46228e..97e2eb4 100644 --- a/paho/paho_integration_test.go +++ b/paho/paho_integration_test.go @@ -26,58 +26,63 @@ import ( ) func TestIntegration_PublishSubscribe(t *testing.T) { - opts := paho.NewClientOptions() - server, err := url.Parse("mqtt://localhost:1883") - if err != nil { - t.Fatalf("Unexpected error: '%v'", err) - } - opts.Servers = []*url.URL{server} - opts.AutoReconnect = false - opts.ClientID = "PahoWrapper" - opts.KeepAlive = 0 + for name, recon := range map[string]bool{"Reconnect": true, "NoReconnect": false} { + t.Run(name, func(t *testing.T) { + opts := paho.NewClientOptions() + server, err := url.Parse("mqtt://localhost:1883") + if err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + opts.Servers = []*url.URL{server} + opts.AutoReconnect = recon + opts.ClientID = "PahoWrapper" + opts.KeepAlive = 0 - cli := NewClient(opts) - token := cli.Connect() - if !token.WaitTimeout(5 * time.Second) { - t.Fatal("Connect timeout") - } - msg := make(chan paho.Message, 100) - token = cli.Subscribe("paho", 1, func(c paho.Client, m paho.Message) { - msg <- m - }) - if !token.WaitTimeout(5 * time.Second) { - t.Fatal("Subscribe timeout") - } - token = cli.Publish("paho", 1, false, []byte{0x12}) - if !token.WaitTimeout(5 * time.Second) { - t.Fatal("Publish timeout") - } + cli := NewClient(opts) + token := cli.Connect() + if !token.WaitTimeout(5 * time.Second) { + t.Fatal("Connect timeout") + } - if !cli.IsConnected() { - t.Error("Not connected") - } - if !cli.IsConnectionOpen() { - t.Error("Not connection open") - } + msg := make(chan paho.Message, 100) + token = cli.Subscribe("paho"+name, 1, func(c paho.Client, m paho.Message) { + msg <- m + }) + if !token.WaitTimeout(5 * time.Second) { + t.Fatal("Subscribe timeout") + } + token = cli.Publish("paho"+name, 1, false, []byte{0x12}) + if !token.WaitTimeout(5 * time.Second) { + t.Fatal("Publish timeout") + } - select { - case m := <-msg: - if m.Topic() != "paho" { - t.Errorf("Expected topic: 'topic', got: %s", m.Topic()) - } - if !bytes.Equal(m.Payload(), []byte{0x12}) { - t.Errorf("Expected payload: [18], got: %v", m.Payload()) - } - case <-time.After(5 * time.Second): - t.Fatal("Message timeout") - } - cli.Disconnect(10) - time.Sleep(time.Second) + if !cli.IsConnected() { + t.Error("Not connected") + } + if !cli.IsConnectionOpen() { + t.Error("Not connection open") + } - if cli.IsConnected() { - t.Error("Connected after disconnect") - } - if cli.IsConnectionOpen() { - t.Error("Connection open after disconnect") + select { + case m := <-msg: + if m.Topic() != "paho"+name { + t.Errorf("Expected topic: 'topic%s', got: %s", name, m.Topic()) + } + if !bytes.Equal(m.Payload(), []byte{0x12}) { + t.Errorf("Expected payload: [18], got: %v", m.Payload()) + } + case <-time.After(5 * time.Second): + t.Errorf("Message timeout") + } + cli.Disconnect(10) + time.Sleep(time.Second) + + if cli.IsConnected() { + t.Error("Connected after disconnect") + } + if cli.IsConnectionOpen() { + t.Error("Connection open after disconnect") + } + }) } }