diff --git a/pkg/push/api/api.go b/pkg/push/api/api.go index 140ce78cf..d2f211bd1 100644 --- a/pkg/push/api/api.go +++ b/pkg/push/api/api.go @@ -62,6 +62,7 @@ func WithLogger(l *zap.Logger) Option { type PushService struct { mysqlClient mysql.Client + pushStorage v2ps.PushStorage featureClient featureclient.Client experimentClient experimentclient.Client accountClient accountclient.Client @@ -86,6 +87,7 @@ func NewPushService( } return &PushService{ mysqlClient: mysqlClient, + pushStorage: v2ps.NewPushStorage(mysqlClient), featureClient: featureClient, experimentClient: experimentClient, accountClient: accountClient, @@ -183,26 +185,8 @@ func (s *PushService) CreatePush( } return nil, dt.Err() } - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - pushStorage := v2ps.NewPushStorage(tx) - if err := pushStorage.CreatePush(ctx, push, req.EnvironmentId); err != nil { + err = s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { + if err := s.pushStorage.CreatePush(contextWithTx, push, req.EnvironmentId); err != nil { return err } handler, err := command.NewPushCommandHandler(editor, push, s.publisher, req.EnvironmentId) @@ -330,26 +314,8 @@ func (s *PushService) createPushNoCommand( } var event *eventproto.Event - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - pushStorage := v2ps.NewPushStorage(tx) - if err := pushStorage.CreatePush(ctx, push, req.EnvironmentId); err != nil { + err = s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { + if err := s.pushStorage.CreatePush(contextWithTx, push, req.EnvironmentId); err != nil { return err } prev := &domain.Push{} @@ -506,26 +472,8 @@ func (s *PushService) UpdatePush( var updatedPushPb *pushproto.Push commands := s.createUpdatePushCommands(req) - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - pushStorage := v2ps.NewPushStorage(tx) - push, err := pushStorage.GetPush(ctx, req.Id, req.EnvironmentId) + err = s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { + push, err := s.pushStorage.GetPush(contextWithTx, req.Id, req.EnvironmentId) if err != nil { return err } @@ -539,7 +487,7 @@ func (s *PushService) UpdatePush( } } updatedPushPb = push.Push - return pushStorage.UpdatePush(ctx, push, req.EnvironmentId) + return s.pushStorage.UpdatePush(contextWithTx, push, req.EnvironmentId) }) if err != nil { switch { @@ -599,26 +547,8 @@ func (s *PushService) updatePushNoCommand( } var updatedPushPb *pushproto.Push var updatePushEvent *eventproto.Event - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - pushStorage := v2ps.NewPushStorage(tx) - push, err := pushStorage.GetPush(ctx, req.Id, req.EnvironmentId) + err := s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { + push, err := s.pushStorage.GetPush(contextWithTx, req.Id, req.EnvironmentId) if err != nil { return err } @@ -648,7 +578,7 @@ func (s *PushService) updatePushNoCommand( } updatedPushPb = updated.Push - return pushStorage.UpdatePush(ctx, updated, req.EnvironmentId) + return s.pushStorage.UpdatePush(contextWithTx, updated, req.EnvironmentId) }) if err != nil { switch { @@ -842,26 +772,8 @@ func (s *PushService) DeletePush( } var event *eventproto.Event - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - pushStorage := v2ps.NewPushStorage(tx) - push, err := pushStorage.GetPush(ctx, req.Id, req.EnvironmentId) + err = s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { + push, err := s.pushStorage.GetPush(contextWithTx, req.Id, req.EnvironmentId) if err != nil { return err } @@ -890,7 +802,7 @@ func (s *PushService) DeletePush( if err = s.publisher.Publish(ctx, event); err != nil { return err } - return pushStorage.UpdatePush(ctx, push, req.EnvironmentId) + return s.pushStorage.UpdatePush(contextWithTx, push, req.EnvironmentId) }) if err != nil { switch { @@ -935,8 +847,7 @@ func (s *PushService) GetPush( return nil, err } - pushStorage := v2ps.NewPushStorage(s.mysqlClient) - push, err := pushStorage.GetPush(ctx, req.Id, req.EnvironmentId) + push, err := s.pushStorage.GetPush(ctx, req.Id, req.EnvironmentId) if err != nil { if errors.Is(err, v2ps.ErrPushNotFound) { dt, err := statusNotFound.WithDetails(&errdetails.LocalizedMessage{ @@ -1266,8 +1177,7 @@ func (s *PushService) listPushes( } return nil, "", 0, dt.Err() } - pushStorage := v2ps.NewPushStorage(s.mysqlClient) - pushes, nextCursor, totalCount, err := pushStorage.ListPushes( + pushes, nextCursor, totalCount, err := s.pushStorage.ListPushes( ctx, whereParts, orders, diff --git a/pkg/push/api/api_test.go b/pkg/push/api/api_test.go index 0858fa7da..f8b12df61 100644 --- a/pkg/push/api/api_test.go +++ b/pkg/push/api/api_test.go @@ -29,6 +29,8 @@ import ( gstatus "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/wrapperspb" + proto "github.com/bucketeer-io/bucketeer/proto/push" + accountproto "github.com/bucketeer-io/bucketeer/proto/account" accountclientmock "github.com/bucketeer-io/bucketeer/pkg/account/client/mock" @@ -36,8 +38,11 @@ import ( featureclientmock "github.com/bucketeer-io/bucketeer/pkg/feature/client/mock" "github.com/bucketeer-io/bucketeer/pkg/locale" publishermock "github.com/bucketeer-io/bucketeer/pkg/pubsub/publisher/mock" + "github.com/bucketeer-io/bucketeer/pkg/push/domain" v2ps "github.com/bucketeer-io/bucketeer/pkg/push/storage/v2" + storagemock "github.com/bucketeer-io/bucketeer/pkg/push/storage/v2/mock" "github.com/bucketeer-io/bucketeer/pkg/rpc" + "github.com/bucketeer-io/bucketeer/pkg/storage/v2/mysql" mysqlmock "github.com/bucketeer-io/bucketeer/pkg/storage/v2/mysql/mock" "github.com/bucketeer-io/bucketeer/pkg/token" pushproto "github.com/bucketeer-io/bucketeer/proto/push" @@ -147,20 +152,15 @@ func TestCreatePushMySQL(t *testing.T) { { desc: "err: ErrAlreadyExists", setup: func(s *PushService) { - rows := mysqlmock.NewMockRows(mockController) - rows.EXPECT().Close().Return(nil) - rows.EXPECT().Next().Return(false) - rows.EXPECT().Err().Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(rows, nil) - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().ListPushes( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return([]*proto.Push{}, 0, int64(0), nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(v2ps.ErrPushAlreadyExists) + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().CreatePush( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(v2ps.ErrPushAlreadyExists) }, @@ -177,22 +177,18 @@ func TestCreatePushMySQL(t *testing.T) { { desc: "success", setup: func(s *PushService) { - rows := mysqlmock.NewMockRows(mockController) - rows.EXPECT().Close().Return(nil) - rows.EXPECT().Next().Return(false) - rows.EXPECT().Err().Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(rows, nil) - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().ListPushes( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return([]*proto.Push{}, 0, int64(0), nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().CreatePush( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) }, req: &pushproto.CreatePushRequest{ EnvironmentId: "ns0", @@ -270,20 +266,15 @@ func TestCreatePushNoCommandMySQL(t *testing.T) { { desc: "err: ErrAlreadyExists", setup: func(s *PushService) { - rows := mysqlmock.NewMockRows(mockController) - rows.EXPECT().Close().Return(nil) - rows.EXPECT().Next().Return(false) - rows.EXPECT().Err().Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(rows, nil) - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().ListPushes( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return([]*proto.Push{}, 0, int64(0), nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(v2ps.ErrPushAlreadyExists) + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().CreatePush( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(v2ps.ErrPushAlreadyExists) }, @@ -298,22 +289,18 @@ func TestCreatePushNoCommandMySQL(t *testing.T) { { desc: "success", setup: func(s *PushService) { - rows := mysqlmock.NewMockRows(mockController) - rows.EXPECT().Close().Return(nil) - rows.EXPECT().Next().Return(false) - rows.EXPECT().Err().Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(rows, nil) - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().ListPushes( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return([]*proto.Push{}, 0, int64(0), nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().CreatePush( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) }, req: &pushproto.CreatePushRequest{ EnvironmentId: "ns0", @@ -404,22 +391,17 @@ func TestUpdatePushMySQL(t *testing.T) { { desc: "err: ErrNotFound", setup: func(s *PushService) { - rows := mysqlmock.NewMockRows(mockController) - rows.EXPECT().Close().Return(nil) - rows.EXPECT().Next().Return(false) - rows.EXPECT().Err().Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(rows, nil) - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().ListPushes( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return([]*proto.Push{}, 0, int64(0), nil) + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().GetPush( gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(v2ps.ErrPushNotFound) + ).Return(nil, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(v2ps.ErrPushNotFound) }, req: &pushproto.UpdatePushRequest{ Id: "key-1", @@ -430,8 +412,20 @@ func TestUpdatePushMySQL(t *testing.T) { { desc: "success: rename", setup: func(s *PushService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().GetPush( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(&domain.Push{ + Push: &proto.Push{ + Id: "key-0", + }, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().UpdatePush( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil) }, @@ -445,10 +439,23 @@ func TestUpdatePushMySQL(t *testing.T) { { desc: "success: deletePushTags", setup: func(s *PushService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().GetPush( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(&domain.Push{ + Push: &proto.Push{ + Id: "key-0", + Tags: []string{"tag-0"}, + }, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().UpdatePush( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) }, req: &pushproto.UpdatePushRequest{ EnvironmentId: "ns0", @@ -460,20 +467,24 @@ func TestUpdatePushMySQL(t *testing.T) { { desc: "success: addPushTags", setup: func(s *PushService) { - rows := mysqlmock.NewMockRows(mockController) - rows.EXPECT().Close().Return(nil) - rows.EXPECT().Next().Return(false) - rows.EXPECT().Err().Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(rows, nil) - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().ListPushes( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return([]*proto.Push{}, 0, int64(0), nil) + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().GetPush( gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + ).Return(&domain.Push{ + Push: &proto.Push{ + Id: "key-0", + Tags: []string{"tag-0"}, + }, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().UpdatePush( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil) }, @@ -487,22 +498,21 @@ func TestUpdatePushMySQL(t *testing.T) { { desc: "success", setup: func(s *PushService) { - rows := mysqlmock.NewMockRows(mockController) - rows.EXPECT().Close().Return(nil) - rows.EXPECT().Next().Return(false) - rows.EXPECT().Err().Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(rows, nil) - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().ListPushes( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return([]*proto.Push{}, 0, int64(0), nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().GetPush( gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(nil) + ).Return(&domain.Push{ + Push: &proto.Push{ + Id: "key-0", + }, + }, nil) }, req: &pushproto.UpdatePushRequest{ EnvironmentId: "ns0", @@ -559,8 +569,20 @@ func TestUpdatePushNoCommandMySQL(t *testing.T) { { desc: "err: ErrNotFound", setup: func(s *PushService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().GetPush( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(&domain.Push{ + Push: &proto.Push{ + Id: "key-0", + }, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(v2ps.ErrPushNotFound) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().UpdatePush( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(v2ps.ErrPushNotFound) }, @@ -574,8 +596,22 @@ func TestUpdatePushNoCommandMySQL(t *testing.T) { { desc: "success update name", setup: func(s *PushService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().GetPush( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(&domain.Push{ + Push: &proto.Push{ + Id: "key-0", + Name: "push-0", + Tags: []string{"tag-0"}, + }, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().UpdatePush( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil) }, @@ -588,8 +624,22 @@ func TestUpdatePushNoCommandMySQL(t *testing.T) { { desc: "success update tags", setup: func(s *PushService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().GetPush( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(&domain.Push{ + Push: &proto.Push{ + Id: "key-0", + Name: "push-0", + Tags: []string{"tag-0"}, + }, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().UpdatePush( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil) }, @@ -602,8 +652,22 @@ func TestUpdatePushNoCommandMySQL(t *testing.T) { { desc: "success", setup: func(s *PushService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().GetPush( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(&domain.Push{ + Push: &proto.Push{ + Id: "key-0", + Name: "push-0", + Tags: []string{"tag-0"}, + }, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().UpdatePush( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil) }, @@ -746,10 +810,14 @@ func TestDeletePushMySQL(t *testing.T) { { desc: "err: ErrNotFound", setup: func(s *PushService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().GetPush( gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(v2ps.ErrPushNotFound) + ).Return(nil, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(v2ps.ErrPushNotFound) }, req: &pushproto.DeletePushRequest{ EnvironmentId: "ns0", @@ -761,9 +829,26 @@ func TestDeletePushMySQL(t *testing.T) { { desc: "success", setup: func(s *PushService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().GetPush( gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(&domain.Push{ + Push: &proto.Push{ + Id: "key-0", + }, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().UpdatePush( + gomock.Any(), &domain.Push{ + Push: &proto.Push{ + Id: "key-0", + Deleted: true, + }, + }, gomock.Any(), ).Return(nil) }, req: &pushproto.DeletePushRequest{ @@ -826,9 +911,9 @@ func TestListPushesMySQL(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *PushService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(nil, errors.New("error")) + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().ListPushes( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(nil, 0, int64(0), errors.New("error")) }, input: &pushproto.ListPushesRequest{EnvironmentId: "ns0"}, expected: nil, @@ -847,18 +932,9 @@ func TestListPushesMySQL(t *testing.T) { orgRole: toPtr(accountproto.AccountV2_Role_Organization_MEMBER), envRole: toPtr(accountproto.AccountV2_Role_Environment_VIEWER), setup: func(s *PushService) { - rows := mysqlmock.NewMockRows(mockController) - rows.EXPECT().Close().Return(nil) - rows.EXPECT().Next().Return(false) - rows.EXPECT().Err().Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(rows, nil) - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().ListPushes( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return([]*proto.Push{}, 0, int64(0), nil) }, input: &pushproto.ListPushesRequest{PageSize: 2, Cursor: "", EnvironmentId: "ns0"}, expected: &pushproto.ListPushesResponse{Pushes: []*pushproto.Push{}, Cursor: "0"}, @@ -912,11 +988,9 @@ func TestGetPushMySQL(t *testing.T) { { desc: "err: ErrNotFound", setup: func(s *PushService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(v2ps.ErrPushNotFound) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().GetPush( gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + ).Return(nil, v2ps.ErrPushNotFound) }, req: &pushproto.GetPushRequest{ EnvironmentId: "ns0", @@ -927,11 +1001,13 @@ func TestGetPushMySQL(t *testing.T) { { desc: "success", setup: func(s *PushService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( + s.pushStorage.(*storagemock.MockPushStorage).EXPECT().GetPush( gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + ).Return(&domain.Push{ + Push: &proto.Push{ + Id: "key-1", + }, + }, nil) }, req: &pushproto.GetPushRequest{ EnvironmentId: "ns0", @@ -956,6 +1032,7 @@ func newPushServiceWithMock(t *testing.T, c *gomock.Controller) *PushService { t.Helper() return &PushService{ mysqlClient: mysqlmock.NewMockClient(c), + pushStorage: storagemock.NewMockPushStorage(c), featureClient: featureclientmock.NewMockClient(c), experimentClient: experimentclientmock.NewMockClient(c), accountClient: accountclientmock.NewMockClient(c), @@ -1004,6 +1081,7 @@ func newPushService(c *gomock.Controller, specifiedEnvironmentId *string, specif return &PushService{ mysqlClient: mysqlClient, featureClient: featureclientmock.NewMockClient(c), + pushStorage: storagemock.NewMockPushStorage(c), experimentClient: experimentclientmock.NewMockClient(c), accountClient: accountClientMock, publisher: publishermock.NewMockPublisher(c), diff --git a/pkg/push/storage/v2/push.go b/pkg/push/storage/v2/push.go index 740cbc8ab..a2e5f0243 100644 --- a/pkg/push/storage/v2/push.go +++ b/pkg/push/storage/v2/push.go @@ -56,15 +56,15 @@ type PushStorage interface { } type pushStorage struct { - qe mysql.QueryExecer + client mysql.Client } -func NewPushStorage(qe mysql.QueryExecer) PushStorage { - return &pushStorage{qe: qe} +func NewPushStorage(client mysql.Client) PushStorage { + return &pushStorage{client: client} } func (s *pushStorage) CreatePush(ctx context.Context, e *domain.Push, environmentId string) error { - _, err := s.qe.ExecContext( + _, err := s.client.Qe(ctx).ExecContext( ctx, insertPushSQL, e.Id, @@ -87,7 +87,7 @@ func (s *pushStorage) CreatePush(ctx context.Context, e *domain.Push, environmen } func (s *pushStorage) UpdatePush(ctx context.Context, e *domain.Push, environmentId string) error { - result, err := s.qe.ExecContext( + result, err := s.client.Qe(ctx).ExecContext( ctx, updatePushSQL, e.FcmServiceAccount, @@ -115,7 +115,7 @@ func (s *pushStorage) UpdatePush(ctx context.Context, e *domain.Push, environmen func (s *pushStorage) GetPush(ctx context.Context, id, environmentId string) (*domain.Push, error) { push := proto.Push{} - err := s.qe.QueryRowContext( + err := s.client.Qe(ctx).QueryRowContext( ctx, selectPushSQL, id, @@ -151,7 +151,7 @@ func (s *pushStorage) ListPushes( orderBySQL := mysql.ConstructOrderBySQLString(orders) limitOffsetSQL := mysql.ConstructLimitOffsetSQLString(limit, offset) query := fmt.Sprintf(listPushesSQL, whereSQL, orderBySQL, limitOffsetSQL) - rows, err := s.qe.QueryContext(ctx, query, whereArgs...) + rows, err := s.client.Qe(ctx).QueryContext(ctx, query, whereArgs...) if err != nil { return nil, 0, 0, err } @@ -182,7 +182,7 @@ func (s *pushStorage) ListPushes( nextOffset := offset + len(pushes) var totalCount int64 countQuery := fmt.Sprintf(countPushesSQL, whereSQL) - err = s.qe.QueryRowContext(ctx, countQuery, whereArgs...).Scan(&totalCount) + err = s.client.Qe(ctx).QueryRowContext(ctx, countQuery, whereArgs...).Scan(&totalCount) if err != nil { return nil, 0, 0, err } diff --git a/pkg/push/storage/v2/push_test.go b/pkg/push/storage/v2/push_test.go index 5d13609da..ef8c68896 100644 --- a/pkg/push/storage/v2/push_test.go +++ b/pkg/push/storage/v2/push_test.go @@ -32,7 +32,7 @@ func TestNewPushStorage(t *testing.T) { t.Parallel() mockController := gomock.NewController(t) defer mockController.Finish() - storage := NewPushStorage(mock.NewMockQueryExecer(mockController)) + storage := NewPushStorage(mock.NewMockClient(mockController)) assert.IsType(t, &pushStorage{}, storage) } @@ -50,7 +50,11 @@ func TestCreatePush(t *testing.T) { { desc: "ErrPushAlreadyExists", setup: func(s *pushStorage) { - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil, mysql.ErrDuplicateEntry) }, @@ -63,7 +67,11 @@ func TestCreatePush(t *testing.T) { { desc: "Error", setup: func(s *pushStorage) { - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil, errors.New("error")) @@ -77,7 +85,11 @@ func TestCreatePush(t *testing.T) { { desc: "Success", setup: func(s *pushStorage) { - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil, nil) }, @@ -116,7 +128,11 @@ func TestUpdatePush(t *testing.T) { setup: func(s *pushStorage) { result := mock.NewMockResult(mockController) result.EXPECT().RowsAffected().Return(int64(0), nil) - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(result, nil) }, @@ -129,7 +145,11 @@ func TestUpdatePush(t *testing.T) { { desc: "Error", setup: func(s *pushStorage) { - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil, errors.New("error")) @@ -145,7 +165,11 @@ func TestUpdatePush(t *testing.T) { setup: func(s *pushStorage) { result := mock.NewMockResult(mockController) result.EXPECT().RowsAffected().Return(int64(1), nil) - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(result, nil) }, @@ -184,7 +208,11 @@ func TestGetPush(t *testing.T) { setup: func(s *pushStorage) { row := mock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(mysql.ErrNoRows) - s.qe.(*mock.MockQueryExecer).EXPECT().QueryRowContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) }, @@ -197,7 +225,11 @@ func TestGetPush(t *testing.T) { setup: func(s *pushStorage) { row := mock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(errors.New("error")) - s.qe.(*mock.MockQueryExecer).EXPECT().QueryRowContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) @@ -211,7 +243,11 @@ func TestGetPush(t *testing.T) { setup: func(s *pushStorage) { row := mock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(nil) - s.qe.(*mock.MockQueryExecer).EXPECT().QueryRowContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) }, @@ -250,7 +286,11 @@ func TestListPushs(t *testing.T) { { desc: "Error", setup: func(s *pushStorage) { - s.qe.(*mock.MockQueryExecer).EXPECT().QueryContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().QueryContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil, errors.New("error")) }, @@ -269,12 +309,16 @@ func TestListPushs(t *testing.T) { rows.EXPECT().Close().Return(nil) rows.EXPECT().Next().Return(false) rows.EXPECT().Err().Return(nil) - s.qe.(*mock.MockQueryExecer).EXPECT().QueryContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe).AnyTimes() + qe.EXPECT().QueryContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(rows, nil) row := mock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(nil) - s.qe.(*mock.MockQueryExecer).EXPECT().QueryRowContext( + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) }, @@ -313,5 +357,5 @@ func TestListPushs(t *testing.T) { func newpushStorageWithMock(t *testing.T, mockController *gomock.Controller) *pushStorage { t.Helper() - return &pushStorage{mock.NewMockQueryExecer(mockController)} + return &pushStorage{mock.NewMockClient(mockController)} }