Skip to content

Commit

Permalink
Fix RetryClient deadlock (#118)
Browse files Browse the repository at this point in the history
* Add test to reproduce deadlock
* Fix RetryClient task queue
* Revert examples not to use ServeAsync
  • Loading branch information
at-wat authored May 30, 2020
1 parent 0a7623d commit 218a69c
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 21 deletions.
7 changes: 1 addition & 6 deletions examples/mqtts-client-cert/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,7 @@ func main() {
}

mux := &mqtt.ServeMux{} // Multiplex message handlers by topic name.

// Register mux as a low-level handler.
// Wrap by ServeAsync to call handler in a new goroutine.
// note: Default handler processes messages in serial.
// It causes deadlock if QoS>=1 message is published in QoS>=1 message handler.
cli.Handle(&mqtt.ServeAsync{Handler: mux})
cli.Handle(mux) // Register mux as a low-level handler.

mux.Handle("#", // Handle all topics by this handler.
mqtt.HandlerFunc(func(msg *mqtt.Message) {
Expand Down
7 changes: 1 addition & 6 deletions examples/wss-presign-url/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,7 @@ func main() {
}

mux := &mqtt.ServeMux{} // Multiplex message handlers by topic name.

// Register mux as a low-level handler.
// Wrap by ServeAsync to call handler in a new goroutine.
// note: Default handler processes messages in serial.
// It causes deadlock if QoS>=1 message is published in QoS>=1 message handler.
cli.Handle(&mqtt.ServeAsync{Handler: mux})
cli.Handle(mux) // Register mux as a low-level handler.

mux.Handle("#", // Handle all topics by this handler.
mqtt.HandlerFunc(func(msg *mqtt.Message) {
Expand Down
41 changes: 32 additions & 9 deletions retryclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@ package mqtt

import (
"context"
"errors"
"sync"
)

// ErrClosedClient means operation was requested on closed client.
var ErrClosedClient = errors.New("operation on closed client")

// RetryClient queues unacknowledged messages and retries them on reconnect.
type RetryClient struct {
cli ClientCloser
Expand All @@ -30,7 +34,8 @@ type RetryClient struct {
mu sync.Mutex
muQueue sync.Mutex
handler Handler
chTask chan func(ctx context.Context, cli Client)
chTask chan struct{}
taskQueue []func(ctx context.Context, cli Client)
}

// Handle registers the message handler.
Expand Down Expand Up @@ -156,7 +161,6 @@ func (c *RetryClient) Disconnect(ctx context.Context) error {
err := c.pushTask(ctx, func(ctx context.Context, cli Client) {
cli.Disconnect(ctx)
})
close(c.chTask)
return err
}

Expand All @@ -179,26 +183,45 @@ func (c *RetryClient) SetClient(ctx context.Context, cli ClientCloser) {
return
}

c.chTask = make(chan func(ctx context.Context, cli Client))
c.chTask = make(chan struct{}, 1)
go func() {
ctx := context.Background()
for task := range c.chTask {
for {
c.mu.Lock()
if len(c.taskQueue) == 0 {
c.mu.Unlock()
_, ok := <-c.chTask
if !ok {
return
}
continue
}
task := c.taskQueue[0]
c.taskQueue = c.taskQueue[1:]
cli := c.cli
c.mu.Unlock()

task(ctx, cli)
}
}()
}

func (c *RetryClient) pushTask(ctx context.Context, task func(ctx context.Context, cli Client)) error {
c.mu.Lock()
chTask := c.chTask
c.mu.Unlock()
defer c.mu.Unlock()

select {
case _, ok := <-c.chTask:
if !ok {
return ErrClosedClient
}
default:
}

c.taskQueue = append(c.taskQueue, task)
select {
case <-ctx.Done():
return ctx.Err()
case chTask <- task:
case c.chTask <- struct{}{}:
default:
}
return nil
}
Expand Down
68 changes: 68 additions & 0 deletions retryclient_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,71 @@ func TestIntegration_RetryClient_Cancel(t *testing.T) {
t.Fatalf("Unexpected error: '%v'", err)
}
}

// TestIntegration_RetryClient_TaskQueue exercises RetryClient with a message
// handler that itself publishes a QoS1 message for every message it receives.
//
// The test subscribes to "test/queue", publishes 100 QoS1 messages to that
// topic, and answers each received message with a QoS1 publish to
// "test/queue_response". It succeeds when the handler has completed 100 times
// before the 2-second context expires, and reports "Timeout" otherwise.
func TestIntegration_RetryClient_TaskQueue(t *testing.T) {
	cliBase, err := Dial(urls["MQTT"], WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
	if err != nil {
		t.Fatalf("Unexpected error: '%v'", err)
	}

	// ctx bounds the whole test; every client call below shares it.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	var cli RetryClient
	cli.SetClient(ctx, cliBase)

	if _, err := cli.Connect(ctx, "RetryClientQueue"); err != nil {
		t.Fatalf("Unexpected error: '%v'", err)
	}

	// ctxDone is cancelled by the handler once it has processed 100 messages.
	ctxDone, done := context.WithCancel(context.Background())
	defer done()

	// NOTE(review): cnt is written by the handler and read in the failure
	// messages below without synchronization; presumably the handler runs on
	// a different goroutine than the test body — confirm and guard if so.
	var cnt int
	cli.Handle(HandlerFunc(func(msg *Message) {
		// Publish a QoS1 message from inside the handler.
		if err := cli.Publish(ctx, &Message{
			Topic:   "test/queue_response",
			QoS:     QoS1,
			Payload: []byte("message"),
		}); err != nil {
			t.Errorf("Unexpected error: '%v'", err)
			return
		}
		cnt++
		if cnt == 100 {
			done()
		}
	}))
	if err := cli.Subscribe(ctx, Subscription{Topic: "test/queue", QoS: QoS1}); err != nil {
		t.Fatal(err)
	}
	time.Sleep(10 * time.Millisecond)

	// The publish phase is wrapped in a function literal so that it can be
	// abandoned with return while still reaching the Disconnect below.
	func() {
		for i := 0; i < 100; i++ {
			if err := cli.Publish(ctx, &Message{
				Topic:   "test/queue",
				QoS:     QoS1,
				Payload: []byte("message"),
			}); err != nil {
				t.Errorf("Unexpected error: '%v' (cnt=%d)", err, cnt)
				return
			}
			select {
			case <-ctx.Done():
				// Stop publishing once the deadline has passed; continuing
				// would only repeat the same timeout error on each iteration.
				t.Errorf("Timeout (cnt=%d)", cnt)
				return
			default:
			}
		}
	}()

	// Wait until either the handler has seen all 100 messages or ctx expires.
	select {
	case <-ctx.Done():
		t.Errorf("Timeout (cnt=%d)", cnt)
	case <-ctxDone.Done():
	}

	if err := cli.Disconnect(ctx); err != nil {
		t.Fatalf("Unexpected error: '%v'", err)
	}
}

0 comments on commit 218a69c

Please sign in to comment.