diff --git a/dtmcli/barrier.go b/dtmcli/barrier.go index 48e5ae393..c264bd66c 100644 --- a/dtmcli/barrier.go +++ b/dtmcli/barrier.go @@ -95,10 +95,10 @@ func (bb *BranchBarrier) Call(tx *sql.Tx, busiCall BarrierBusiFunc) (rerr error) // CallWithDB the same as Call, but with *sql.DB func (bb *BranchBarrier) CallWithDB(db *sql.DB, busiCall BarrierBusiFunc) error { tx, err := db.Begin() - if err != nil { - return err + if err == nil { + err = bb.Call(tx, busiCall) } - return bb.Call(tx, busiCall) + return err } // QueryPrepared queries prepared data diff --git a/dtmcli/msg.go b/dtmcli/msg.go index 0998ed78d..66ac5ebdf 100644 --- a/dtmcli/msg.go +++ b/dtmcli/msg.go @@ -40,7 +40,7 @@ func (s *Msg) Submit() error { return dtmimp.TransCallDtm(&s.TransBase, s, "submit") } -// PrepareAndSubmit execs prepare and submit operation +// PrepareAndSubmit one method for the entire busi->prepare->submit func (s *Msg) PrepareAndSubmit(queryPrepared string, db *sql.DB, busiCall BarrierBusiFunc) error { bb, err := BarrierFrom(s.TransType, s.Gid, "00", "msg") // a special barrier for msg QueryPrepared if err == nil { diff --git a/dtmcli/types.go b/dtmcli/types.go index 46d399afd..dbe212921 100644 --- a/dtmcli/types.go +++ b/dtmcli/types.go @@ -57,14 +57,9 @@ func SetBarrierTableName(tablename string) { dtmimp.BarrierTableName = tablename } -// OnBeforeRequest add before request middleware -func OnBeforeRequest(middleware func(c *resty.Client, r *resty.Request) error) { - dtmimp.RestyClient.OnBeforeRequest(middleware) -} - -// OnAfterResponse add after request middleware -func OnAfterResponse(middleware func(c *resty.Client, resp *resty.Response) error) { - dtmimp.RestyClient.OnAfterResponse(middleware) +// GetRestyClient get the resty.Client for http request +func GetRestyClient() *resty.Client { + return dtmimp.RestyClient } // SetPassthroughHeaders experimental. diff --git a/dtmgrpc/msg.go b/dtmgrpc/msg.go index e11271309..cbb88425f 100644 --- a/dtmgrpc/msg.go +++ b/dtmgrpc/msg.go @@ -7,6 +7,8 @@ package dtmgrpc import ( + "database/sql" + "github.com/dtm-labs/dtm/dtmcli" "github.com/dtm-labs/dtm/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmgrpc/dtmgimp" @@ -40,3 +42,21 @@ func (s *MsgGrpc) Prepare(queryPrepared string) error { func (s *MsgGrpc) Submit() error { return dtmgimp.DtmGrpcCall(&s.TransBase, "Submit") } + +// PrepareAndSubmit one method for the entire busi->prepare->submit +func (s *MsgGrpc) PrepareAndSubmit(queryPrepared string, db *sql.DB, busiCall dtmcli.BarrierBusiFunc) error { + bb, err := dtmcli.BarrierFrom(s.TransType, s.Gid, "00", "msg") // a special barrier for msg QueryPrepared + if err == nil { + err = bb.CallWithDB(db, func(tx *sql.Tx) error { + err := busiCall(tx) + if err == nil { + err = s.Prepare(queryPrepared) + } + return err + }) + } + if err == nil { + err = s.Submit() + } + return err +} diff --git a/test/busi/barrier.go b/test/busi/barrier.go index 8ca19de39..ce1d7cef3 100644 --- a/test/busi/barrier.go +++ b/test/busi/barrier.go @@ -122,3 +122,8 @@ func (s *busiServer) TransOutRevertBSaga(ctx context.Context, in *BusiReq) (*emp return sagaGrpcAdjustBalance(tx, TransOutUID, in.Amount, "") }) } + +func (s *busiServer) QueryPreparedB(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) { + barrier := MustBarrierFromGrpc(ctx) + return &emptypb.Empty{}, barrier.QueryPrepared(dbGet().ToSQLDB()) +} diff --git a/test/dtmsvr_test.go b/test/dtmsvr_test.go index ae4cf1794..6078e5b8e 100644 --- a/test/dtmsvr_test.go +++ b/test/dtmsvr_test.go @@ -18,6 +18,7 @@ import ( ) var DtmServer = dtmutil.DefaultHTTPServer +var DtmGrpcServer = dtmutil.DefaultGrpcServer var Busi = busi.Busi func getTransStatus(gid string) string { diff --git a/test/main_test.go b/test/main_test.go index 0f4add2a2..69d29cf7a 100644 --- a/test/main_test.go +++ b/test/main_test.go @@ -37,8 +37,8 @@ func TestMain(m *testing.M) { conf.UpdateBranchSync = 1 dtmgrpc.AddUnaryInterceptor(busi.SetGrpcHeaderForHeadersYes) - dtmcli.OnBeforeRequest(busi.SetHttpHeaderForHeadersYes) - dtmcli.OnAfterResponse(func(c *resty.Client, resp *resty.Response) error { return nil }) + dtmcli.GetRestyClient().OnBeforeRequest(busi.SetHttpHeaderForHeadersYes) + dtmcli.GetRestyClient().OnAfterResponse(func(c *resty.Client, resp *resty.Response) error { return nil }) tenv := os.Getenv("TEST_STORE") if tenv == "boltdb" { diff --git a/test/msg_grpc_barrier_test.go b/test/msg_grpc_barrier_test.go new file mode 100644 index 000000000..ca1b93dbc --- /dev/null +++ b/test/msg_grpc_barrier_test.go @@ -0,0 +1,54 @@ +package test + +import ( + "database/sql" + "errors" + "reflect" + "testing" + + "bou.ke/monkey" + "github.com/dtm-labs/dtm/dtmcli/dtmimp" + "github.com/dtm-labs/dtm/dtmgrpc" + "github.com/dtm-labs/dtm/test/busi" + "github.com/stretchr/testify/assert" +) + +func TestMsgGrpcPrepareAndSubmit(t *testing.T) { + before := getBeforeBalances() + gid := dtmimp.GetFuncName() + req := busi.GenBusiReq(30, false, false) + msg := dtmgrpc.NewMsgGrpc(DtmGrpcServer, gid). + Add(busi.BusiGrpc+"/busi.Busi/TransInBSaga", req) + err := msg.PrepareAndSubmit(busi.BusiGrpc+"/busi.Busi/QueryPreparedB", dbGet().ToSQLDB(), func(tx *sql.Tx) error { + return busi.SagaAdjustBalance(tx, busi.TransOutUID, -int(req.Amount), "SUCCESS") + }) + assert.Nil(t, err) + waitTransProcessed(msg.Gid) + assert.Equal(t, []string{StatusSucceed}, getBranchesStatus(msg.Gid)) + assert.Equal(t, StatusSucceed, getTransStatus(msg.Gid)) + assertNotSameBalance(t, before) +} + +func TestMsgGrpcPrepareAndSubmitCommitAfterFailed(t *testing.T) { + if conf.Store.IsDB() { // cannot patch tx.Commit, because Prepare also do Commit + return + } + before := getBeforeBalances() + gid := dtmimp.GetFuncName() + req := busi.GenBusiReq(30, false, false) + msg := dtmgrpc.NewMsgGrpc(DtmGrpcServer, gid). + Add(busi.BusiGrpc+"/busi.Busi/TransInBSaga", req) + var guard *monkey.PatchGuard + err := msg.PrepareAndSubmit(busi.BusiGrpc+"/busi.Busi/QueryPreparedB", dbGet().ToSQLDB(), func(tx *sql.Tx) error { + err := busi.SagaAdjustBalance(tx, busi.TransOutUID, -int(req.Amount), "SUCCESS") + guard = monkey.PatchInstanceMethod(reflect.TypeOf(tx), "Commit", func(tx *sql.Tx) error { + guard.Unpatch() + _ = tx.Commit() + return errors.New("test error for patch") + }) + return err + }) + assert.Error(t, err) + cronTransOnceForwardNow(180) + assertNotSameBalance(t, before) +}