Skip to content

Commit

Permalink
Add DialOption to set ConnState handler (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
at-wat authored Jan 6, 2020
1 parent 7975b1a commit db962e1
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 54 deletions.
46 changes: 23 additions & 23 deletions client_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,22 +134,22 @@ func TestIntegration_PublishSubscribe(t *testing.T) {
t.Run(name, func(t *testing.T) {
for _, qos := range []QoS{QoS0, QoS1, QoS2} {
t.Run(fmt.Sprintf("QoS%d", int(qos)), func(t *testing.T) {
cli, err := Dial(url, WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
chReceived := make(chan *Message, 100)

cli, err := Dial(url,
WithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
WithConnStateHandler(func(s ConnState, err error) {
switch s {
case StateClosed:
close(chReceived)
t.Errorf("Connection is expected to be disconnected, but closed.")
}
}),
)
if err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

chReceived := make(chan *Message, 100)
cli.ConnState = func(s ConnState, err error) {
switch s {
case StateActive:
case StateClosed:
close(chReceived)
t.Errorf("Connection is expected to be disconnected, but closed.")
case StateDisconnected:
}
}

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if _, err := cli.Connect(ctx, "PubSubClient"+name); err != nil {
Expand Down Expand Up @@ -254,21 +254,21 @@ func TestIntegration_Ping(t *testing.T) {
func BenchmarkPublishSubscribe(b *testing.B) {
for name, url := range urls {
b.Run(name, func(b *testing.B) {
cli, err := Dial(url, WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
chReceived := make(chan *Message, 100)

cli, err := Dial(url,
WithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
WithConnStateHandler(func(s ConnState, err error) {
switch s {
case StateClosed:
close(chReceived)
}
}),
)
if err != nil {
b.Fatalf("Unexpected error: '%v'", err)
}

chReceived := make(chan *Message, 100)
cli.ConnState = func(s ConnState, err error) {
switch s {
case StateActive:
case StateClosed:
close(chReceived)
case StateDisconnected:
}
}

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if _, err := cli.Connect(ctx, "PubSubBenchClient"+name); err != nil {
Expand Down
13 changes: 12 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ type DialOption func(*DialOptions) error
type DialOptions struct {
Dialer *net.Dialer
TLSConfig *tls.Config
ConnState func(ConnState, error)
}

// WithDialer sets dialer.
Expand All @@ -92,8 +93,18 @@ func WithTLSConfig(config *tls.Config) DialOption {
}
}

// WithConnStateHandler sets connection state change handler.
func WithConnStateHandler(handler func(ConnState, error)) DialOption {
return func(o *DialOptions) error {
o.ConnState = handler
return nil
}
}

func (d *DialOptions) dial(urlStr string) (*BaseClient, error) {
c := &BaseClient{}
c := &BaseClient{
ConnState: d.ConnState,
}

u, err := url.Parse(urlStr)
if err != nil {
Expand Down
30 changes: 9 additions & 21 deletions examples/mqtts-client-cert/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,28 +52,16 @@ func main() {

cli, err := mqtt.NewReconnectClient(
// Dialer to connect/reconnect to the server.
mqtt.DialerFunc(func() (mqtt.ClientCloser, error) {
cli, err := mqtt.Dial(
fmt.Sprintf("mqtts://%s:8883", host),
&mqtt.URLDialer{
URL: fmt.Sprintf("mqtts://%s:8883", host),
Options: []mqtt.DialOption{
mqtt.WithTLSConfig(tlsConfig),
)
if err != nil {
return nil, err
}
// Register ConnState callback to low level client
cli.ConnState = func(s mqtt.ConnState, err error) {
fmt.Printf("State changed to %s (err: %v)\n", s, err)
}
return cli, nil
}),
// If you don't need customized (with state callback) low layer client,
// just use mqtt.URLDialer:
// &mqtt.URLDialer{
// URL: fmt.Sprintf("mqtts://%s:8883", host),
// Options: []mqtt.DialOption{
// mqtt.WithTLSConfig(tlsConfig),
// },
// },
mqtt.WithConnStateHandler(func(s mqtt.ConnState, err error) {
// Register ConnState callback to low level client
fmt.Printf("State changed to %s (err: %v)\n", s, err)
}),
},
},
mqtt.WithPingInterval(10*time.Second),
mqtt.WithTimeout(5*time.Second),
mqtt.WithReconnectWait(1*time.Second, 15*time.Second),
Expand Down
13 changes: 4 additions & 9 deletions examples/wss-presign-url/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,12 @@ func main() {
host, time.Now().UnixNano(),
)
println("New URL:", url)
cli, err := mqtt.Dial(url,
return mqtt.Dial(url,
mqtt.WithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
mqtt.WithConnStateHandler(func(s mqtt.ConnState, err error) {
fmt.Printf("State changed to %s (err: %v)\n", s, err)
}),
)
if err != nil {
return nil, err
}
// Register ConnState callback to low level client
cli.ConnState = func(s mqtt.ConnState, err error) {
fmt.Printf("State changed to %s (err: %v)\n", s, err)
}
return cli, nil
}),
mqtt.WithPingInterval(10*time.Second),
mqtt.WithTimeout(5*time.Second),
Expand Down

0 comments on commit db962e1

Please sign in to comment.