From 41f1d05ab82f7d7d613cc6ae445117ac45f2c5c8 Mon Sep 17 00:00:00 2001 From: Atsushi Watanabe Date: Tue, 8 Jun 2021 18:02:33 +0900 Subject: [PATCH] Avoid RetryClient task processing before base client connect (#174) * Add test cases * Pause RetryClient task runner loop until the base client is connected * Stop task runner on Disconnect * Fix to use RLock as possible --- retryclient.go | 74 +++++++++-- retryclient_integration_test.go | 215 ++++++++++++++++++++++++-------- 2 files changed, 224 insertions(+), 65 deletions(-) diff --git a/retryclient.go b/retryclient.go index 50e99cd..e9b23bc 100644 --- a/retryclient.go +++ b/retryclient.go @@ -25,11 +25,13 @@ var ErrClosedClient = errors.New("operation on closed client") // RetryClient queues unacknowledged messages and retry on reconnect. type RetryClient struct { - cli *BaseClient + cli *BaseClient + chConnectErr chan error + chConnSwitch chan struct{} retryQueue []retryFn subEstablished subscriptions // acknoledged subscriptions - mu sync.Mutex + mu sync.RWMutex handler Handler chTask chan struct{} taskQueue []func(ctx context.Context, cli *BaseClient) @@ -48,9 +50,9 @@ func (c *RetryClient) Handle(handler Handler) { // Publish tries to publish the message and immediately returns. // If it is not acknowledged to be published, the message will be queued. func (c *RetryClient) Publish(ctx context.Context, message *Message) error { - c.mu.Lock() + c.mu.RLock() cli := c.cli - c.mu.Unlock() + c.mu.RUnlock() if cli != nil { if err := cli.ValidateMessage(message); err != nil { @@ -170,21 +172,24 @@ func (c *RetryClient) unsubscribe(ctx context.Context, cli *BaseClient, topics . func (c *RetryClient) Disconnect(ctx context.Context) error { return wrapError(c.pushTask(ctx, func(ctx context.Context, cli *BaseClient) { cli.Disconnect(ctx) + c.mu.Lock() + close(c.chTask) + c.mu.Unlock() }), "retryclient: disconnecting") } // Ping to the broker. func (c *RetryClient) Ping(ctx context.Context) error { - c.mu.Lock() + c.mu.RLock() cli := c.cli - c.mu.Unlock() + c.mu.RUnlock() return wrapError(cli.Ping(ctx), "retryclient: pinging") } // Client returns the base client. func (c *RetryClient) Client() *BaseClient { - c.mu.Lock() - defer c.mu.Unlock() + c.mu.RLock() + defer c.mu.RUnlock() return c.cli } @@ -194,6 +199,11 @@ func (c *RetryClient) Client() *BaseClient { func (c *RetryClient) SetClient(ctx context.Context, cli *BaseClient) { c.mu.Lock() c.cli = cli + c.chConnectErr = make(chan error, 1) + if c.chConnSwitch != nil { + close(c.chConnSwitch) + } + c.chConnSwitch = make(chan struct{}) c.mu.Unlock() if c.chTask != nil { @@ -202,20 +212,55 @@ func (c *RetryClient) SetClient(ctx context.Context, cli *BaseClient) { c.chTask = make(chan struct{}, 1) go func() { + connected := false ctx := context.Background() + + L_TASK: for { + if !connected { + // Wait Connect if Client was replaced by SetClient. + for { + c.mu.RLock() + chConnectErr := c.chConnectErr + chConnSwitch := c.chConnSwitch + c.mu.RUnlock() + select { + case _, ok := <-chConnectErr: + if !ok { + connected = true + continue L_TASK + } + case <-chConnSwitch: + } + } + } + c.mu.Lock() + chConnSwitch := c.chConnSwitch + select { + case <-chConnSwitch: + c.mu.Unlock() + connected = false + continue + default: + } + if len(c.taskQueue) == 0 { c.mu.Unlock() - _, ok := <-c.chTask - if !ok { - return + + select { + case _, ok := <-c.chTask: + if !ok { + return + } + case <-chConnSwitch: + connected = false } continue } + cli := c.cli task := c.taskQueue[0] c.taskQueue = c.taskQueue[1:] - cli := c.cli c.mu.Unlock() task(ctx, cli) @@ -248,9 +293,14 @@ func (c *RetryClient) Connect(ctx context.Context, clientID string, opts ...Conn c.mu.Lock() cli := c.cli cli.Handle(c.handler) + chConnectErr := c.chConnectErr c.mu.Unlock() present, err := cli.Connect(ctx, clientID, opts...) + if err != nil { + chConnectErr <- err + } + close(chConnectErr) return present, wrapError(err, "retryclient: connecting") } diff --git a/retryclient_integration_test.go b/retryclient_integration_test.go index f4e0098..adeae9e 100644 --- a/retryclient_integration_test.go +++ b/retryclient_integration_test.go @@ -19,8 +19,11 @@ package mqtt import ( "context" "crypto/tls" + "sync/atomic" "testing" "time" + + "github.com/at-wat/mqtt-go/internal/filteredpipe" ) func TestIntegration_RetryClient(t *testing.T) { @@ -118,69 +121,175 @@ func TestIntegration_RetryClient_Cancel(t *testing.T) { } func TestIntegration_RetryClient_TaskQueue(t *testing.T) { - cliBase, err := Dial(urls["MQTT"], WithTLSConfig(&tls.Config{InsecureSkipVerify: true})) - if err != nil { - t.Fatalf("Unexpected error: '%v'", err) + type pubTiming string + const ( + pubBeforeSetClient pubTiming = "BeforeSetClient" + pubBeforeConnect pubTiming = "BeforeConnect" + pubAfterConnect pubTiming = "AfterConnect" + ) + pubTimings := []pubTiming{ + pubBeforeSetClient, pubBeforeConnect, pubAfterConnect, } - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() + for _, withWait := range []bool{true, false} { + name := "WithoutWait" + if withWait { + name = "WithWait" + } + withWait := withWait + t.Run(name, func(t *testing.T) { + for _, pubAt := range pubTimings { + pubAt := pubAt + t.Run(string(pubAt), func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + ctxDone, done := context.WithCancel(context.Background()) + defer done() - var cli RetryClient - cli.SetClient(ctx, cliBase) + var cnt int + const expectedCount = 100 - if _, err := cli.Connect(ctx, "RetryClientQueue"); err != nil { - t.Fatalf("Unexpected error: '%v'", err) - } + cliRecv, err := Dial(urls["MQTT"], WithTLSConfig(&tls.Config{InsecureSkipVerify: true})) + if err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + if _, err := cliRecv.Connect(ctx, "RetryClientQueueRecv"); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } - ctxDone, done := context.WithCancel(context.Background()) - defer done() - - var cnt int - cli.Handle(HandlerFunc(func(msg *Message) { - if err := cli.Publish(ctx, &Message{ - Topic: "test/queue_response", - QoS: QoS1, - Payload: []byte("message"), - }); err != nil { - t.Errorf("Unexpected error: '%v'", err) - return - } - cnt++ - if cnt == 100 { - done() - } - })) - if _, err := cli.Subscribe(ctx, Subscription{Topic: "test/queue", QoS: QoS1}); err != nil { - t.Fatal(err) + if _, err := cliRecv.Subscribe(ctx, Subscription{Topic: "test/queue", QoS: QoS1}); err != nil { + t.Fatal(err) + } + cliRecv.Handle(HandlerFunc(func(*Message) { + cnt++ + if cnt == expectedCount { + done() + } + })) + + cliBase, err := Dial(urls["MQTT"], WithTLSConfig(&tls.Config{InsecureSkipVerify: true})) + if err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + + var cli RetryClient + publish := func() { + for i := 0; i < expectedCount; i++ { + if err := cli.Publish(ctx, &Message{ + Topic: "test/queue", + QoS: QoS1, + Payload: []byte("message"), + }); err != nil { + t.Errorf("Unexpected error: '%v' (cnt=%d)", err, cnt) + return + } + select { + case <-ctx.Done(): + t.Errorf("Timeout (cnt=%d)", cnt) + default: + } + } + } + + if pubAt == pubBeforeSetClient { + publish() + } + if withWait { + time.Sleep(50 * time.Millisecond) + } + cli.SetClient(ctx, cliBase) + + if withWait { + time.Sleep(50 * time.Millisecond) + } + // Ensure there is no deadlock when SetClient before Connect. + cli.SetClient(ctx, cliBase) + + if pubAt == pubBeforeConnect { + publish() + } + if withWait { + time.Sleep(50 * time.Millisecond) + } + + if _, err := cli.Connect(ctx, "RetryClientQueue"); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + + if pubAt == pubAfterConnect { + publish() + } + + select { + case <-ctx.Done(): + t.Errorf("Timeout (cnt=%d)", cnt) + case <-ctxDone.Done(): + } + + if err := cli.Disconnect(ctx); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + }) + } + }) } - time.Sleep(10 * time.Millisecond) +} - func() { - for i := 0; i < 100; i++ { - if err := cli.Publish(ctx, &Message{ - Topic: "test/queue", - QoS: QoS1, - Payload: []byte("message"), - }); err != nil { - t.Errorf("Unexpected error: '%v' (cnt=%d)", err, cnt) - return +func TestIntegration_RetryClient_RetryInitialRequest(t *testing.T) { + for name, url := range urls { + t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + topic := "test/RetryInitialReq" + name + var sw int32 + + cli, err := NewReconnectClient( + DialerFunc(func() (*BaseClient, error) { + cli, err := Dial(url, + WithTLSConfig(&tls.Config{InsecureSkipVerify: true}), + ) + if err != nil { + return nil, err + } + ca, cb := filteredpipe.DetectAndClosePipe( + newOnOffFilter(&sw), + newOnOffFilter(&sw), + ) + filteredpipe.Connect(ca, cli.Transport) + cli.Transport = cb + return cli, nil + }), + WithReconnectWait(50*time.Millisecond, 200*time.Millisecond), + WithPingInterval(250*time.Millisecond), + WithTimeout(250*time.Millisecond), + ) + if err != nil { + t.Fatalf("Unexpected error: '%v'", err) } - select { - case <-ctx.Done(): - t.Errorf("Timeout (cnt=%d)", cnt) - default: + + if _, err := cli.Subscribe(ctx, Subscription{Topic: topic, QoS: QoS1}); err != nil { + t.Fatal(err) } - } - }() + time.Sleep(100 * time.Millisecond) - select { - case <-ctx.Done(): - t.Errorf("Timeout (cnt=%d)", cnt) - case <-ctxDone.Done(): - } + // Disconnect + atomic.StoreInt32(&sw, 1) + go func() { + time.Sleep(300 * time.Millisecond) + // Connect + atomic.StoreInt32(&sw, 0) + }() - if err := cli.Disconnect(ctx); err != nil { - t.Fatalf("Unexpected error: '%v'", err) + if _, err := cli.Connect(ctx, "RetryInitialReq"+name); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + + if err := ctx.Err(); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + + cli.Disconnect(ctx) + }) } }