diff --git a/README-cn.md b/README-cn.md index a5acf06ec..fd3f6022c 100644 --- a/README-cn.md +++ b/README-cn.md @@ -111,8 +111,8 @@ go run main.go 在实际的业务中,子事务可能出现失败,例如转入的子账号被冻结导致转账失败。我们对业务代码进行修改,让TransIn的正向操作失败,然后看看结果 ``` go - app.POST(qsBusiAPI+"/TransIn", common.WrapHandler(func(c *gin.Context) (interface{}, error) { - return M{"dtm_result": "FAILURE"}, nil + app.POST(qsBusiAPI+"/TransIn", common.WrapHandler2(func(c *gin.Context) interface{} { + return dtmcli.ErrFailure })) ``` diff --git a/README.md b/README.md index fc837042b..760cff95b 100644 --- a/README.md +++ b/README.md @@ -111,8 +111,8 @@ go run main.go 在实际的业务中,子事务可能出现失败,例如转入的子账号被冻结导致转账失败。我们对业务代码进行修改,让TransIn的正向操作失败,然后看看结果 ``` go - app.POST(qsBusiAPI+"/TransIn", common.WrapHandler(func(c *gin.Context) (interface{}, error) { - return M{"dtm_result": "FAILURE"}, nil + app.POST(qsBusiAPI+"/TransIn", common.WrapHandler2(func(c *gin.Context) interface{} { + return dtmcli.ErrFailure })) ``` diff --git a/bench/svr/http.go b/bench/svr/http.go index ed3f3cc53..6e8d7bc63 100644 --- a/bench/svr/http.go +++ b/bench/svr/http.go @@ -101,9 +101,9 @@ func StartSvr() { }() } -func qsAdjustBalance(uid int, amount int, c *gin.Context) (interface{}, error) { +func qsAdjustBalance(uid int, amount int, c *gin.Context) error { // nolint: unparam if strings.Contains(mode, "empty") || sqls == 0 { - return dtmcli.MapSuccess, nil + return nil } tb := dtmimp.TransBaseFromQuery(c.Request.URL.Query()) f := func(tx *sql.Tx) error { @@ -129,32 +129,32 @@ func qsAdjustBalance(uid int, amount int, c *gin.Context) (interface{}, error) { logger.FatalIfError(err) } - return dtmcli.MapSuccess, nil + return nil } func benchAddRoute(app *gin.Engine) { - app.POST(benchAPI+"/TransIn", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(benchAPI+"/TransIn", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { return qsAdjustBalance(dtmimp.MustAtoi(c.Query("uid")), 1, c) })) - app.POST(benchAPI+"/TransInCompensate", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(benchAPI+"/TransInCompensate", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { return qsAdjustBalance(dtmimp.MustAtoi(c.Query("uid")), -1, c) })) - app.POST(benchAPI+"/TransOut", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(benchAPI+"/TransOut", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { return qsAdjustBalance(dtmimp.MustAtoi(c.Query("uid")), -1, c) })) - app.POST(benchAPI+"/TransOutCompensate", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(benchAPI+"/TransOutCompensate", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { return qsAdjustBalance(dtmimp.MustAtoi(c.Query("uid")), 30, c) })) - app.Any(benchAPI+"/reloadData", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.Any(benchAPI+"/reloadData", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { reloadData() mode = c.Query("m") s := c.Query("sqls") if s != "" { sqls = dtmimp.MustAtoi(s) } - return nil, nil + return nil })) - app.Any(benchAPI+"/bench", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.Any(benchAPI+"/bench", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { uid := (atomic.AddInt32(&uidCounter, 1)-1)%total + 1 suid := fmt.Sprintf("%d", uid) suid2 := fmt.Sprintf("%d", total+1-uid) @@ -175,16 +175,15 @@ func benchAddRoute(app *gin.Engine) { _, err = dtmimp.RestyClient.R().SetBody(gin.H{}).SetQueryParam("uid", suid).Post(benchBusi + "/TransIn") dtmimp.E2P(err) } - return nil, nil + return nil })) - app.Any(benchAPI+"/benchEmptyUrl", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.Any(benchAPI+"/benchEmptyUrl", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { gid := shortuuid.New() req := gin.H{} saga := dtmcli.NewSaga(dtmutil.DefaultHTTPServer, gid). Add("", "", req). Add("", "", req) saga.WaitResult = true - err := saga.Submit() - return nil, err + return saga.Submit() })) } diff --git a/dtmcli/dtmimp/trans_base.go b/dtmcli/dtmimp/trans_base.go index 9c89a9fc0..ad0b9d1b7 100644 --- a/dtmcli/dtmimp/trans_base.go +++ b/dtmcli/dtmimp/trans_base.go @@ -9,6 +9,7 @@ package dtmimp import ( "errors" "fmt" + "net/http" "net/url" "strings" @@ -87,7 +88,7 @@ func TransCallDtm(tb *TransBase, body interface{}, operation string) error { if err != nil { return err } - if !strings.Contains(resp.String(), ResultSuccess) { + if resp.StatusCode() != http.StatusOK || strings.Contains(resp.String(), ResultFailure) { return errors.New(resp.String()) } return nil @@ -118,5 +119,8 @@ func TransRequestBranch(t *TransBase, body interface{}, branchID string, op stri }). SetHeaders(t.BranchHeaders). Post(url) - return resp, CheckResponse(resp, err) + if err == nil { + err = RespAsErrorCompatible(resp) + } + return resp, err } diff --git a/dtmcli/dtmimp/utils.go b/dtmcli/dtmimp/utils.go index e2576e208..8038b61ea 100644 --- a/dtmcli/dtmimp/utils.go +++ b/dtmcli/dtmimp/utils.go @@ -11,6 +11,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "os" "runtime" "strconv" @@ -204,37 +205,17 @@ func GetDsn(conf DBConf) string { return dsn } -// CheckResponse is check response, and return corresponding error by the condition of resp when err is nil. Otherwise, return err directly. -func CheckResponse(resp *resty.Response, err error) error { - if err == nil && resp != nil { - if resp.IsError() { - return errors.New(resp.String()) - } else if strings.Contains(resp.String(), ResultFailure) { - return ErrFailure - } else if strings.Contains(resp.String(), ResultOngoing) { - return ErrOngoing - } - } - return err -} - -// CheckResult is check result. Return err directly if err is not nil. And return corresponding error by calling CheckResponse if resp is the type of *resty.Response. -// Otherwise, return error by value of str, the string after marshal. -func CheckResult(res interface{}, err error) error { - if err != nil { - return err +// RespAsErrorCompatible translate a resty response to error +// compatible with version < v1.10 +func RespAsErrorCompatible(resp *resty.Response) error { + code := resp.StatusCode() + str := resp.String() + if code == http.StatusTooEarly || strings.Contains(str, ResultOngoing) { + return fmt.Errorf("%s. %w", str, ErrOngoing) + } else if code == http.StatusConflict || strings.Contains(str, ResultFailure) { + return fmt.Errorf("%s. %w", str, ErrFailure) + } else if code != http.StatusOK { + return errors.New(str) } - resp, ok := res.(*resty.Response) - if ok { - return CheckResponse(resp, err) - } - if res != nil { - str := MustMarshalString(res) - if strings.Contains(str, ResultFailure) { - return ErrFailure - } else if strings.Contains(str, ResultOngoing) { - return ErrOngoing - } - } - return err + return nil } diff --git a/dtmcli/types.go b/dtmcli/types.go index dbe212921..f48e30fec 100644 --- a/dtmcli/types.go +++ b/dtmcli/types.go @@ -32,6 +32,16 @@ type TransOptions = dtmimp.TransOptions // DBConf declares db configuration type DBConf = dtmimp.DBConf +// String2DtmError translate string to dtm error +func String2DtmError(str string) error { + return map[string]error{ + ResultFailure: ErrFailure, + ResultOngoing: ErrOngoing, + ResultSuccess: nil, + "": nil, + }[str] +} + // SetCurrentDBType set currentDBType func SetCurrentDBType(dbType string) { dtmimp.SetCurrentDBType(dbType) diff --git a/dtmcli/xa.go b/dtmcli/xa.go index 1f4137cbd..2b48ecbd9 100644 --- a/dtmcli/xa.go +++ b/dtmcli/xa.go @@ -59,8 +59,8 @@ func NewXaClient(server string, mysqlConf DBConf, notifyURL string, register XaR } // HandleCallback 处理commit/rollback的回调 -func (xc *XaClient) HandleCallback(gid string, branchID string, action string) (interface{}, error) { - return MapSuccess, xc.XaClientBase.HandleCallback(gid, branchID, action) +func (xc *XaClient) HandleCallback(gid string, branchID string, action string) interface{} { + return xc.XaClientBase.HandleCallback(gid, branchID, action) } // XaLocalTransaction start a xa local transaction diff --git a/dtmgrpc/dtmgimp/types.go b/dtmgrpc/dtmgimp/types.go index a11fe718c..535e9240e 100644 --- a/dtmgrpc/dtmgimp/types.go +++ b/dtmgrpc/dtmgimp/types.go @@ -11,12 +11,9 @@ import ( "fmt" "time" - "github.com/dtm-labs/dtm/dtmcli" "github.com/dtm-labs/dtm/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmcli/logger" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) // GrpcServerLog 打印grpc服务端的日志 @@ -49,15 +46,3 @@ func GrpcClientLog(ctx context.Context, method string, req, reply interface{}, c } return err } - -// Result2Error 将通用的result转成grpc的error -func Result2Error(res interface{}, err error) error { - e := dtmimp.CheckResult(res, err) - if e == dtmimp.ErrFailure { - logger.Errorf("failure: res: %v, err: %v", res, e) - return status.New(codes.Aborted, dtmcli.ResultFailure).Err() - } else if e == dtmimp.ErrOngoing { - return status.New(codes.Aborted, dtmcli.ResultOngoing).Err() - } - return e -} diff --git a/dtmgrpc/type.go b/dtmgrpc/type.go index f900170ef..194ff5919 100644 --- a/dtmgrpc/type.go +++ b/dtmgrpc/type.go @@ -9,13 +9,27 @@ package dtmgrpc import ( context "context" + "github.com/dtm-labs/dtm/dtmcli" "github.com/dtm-labs/dtm/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmgrpc/dtmgimp" "github.com/dtm-labs/dtmdriver" grpc "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" emptypb "google.golang.org/protobuf/types/known/emptypb" ) +// DtmError2GrpcError translate dtm error to grpc error +func DtmError2GrpcError(res interface{}) error { + e, ok := res.(error) + if ok && e == dtmimp.ErrFailure { + return status.New(codes.Aborted, dtmcli.ResultFailure).Err() + } else if ok && e == dtmimp.ErrOngoing { + return status.New(codes.FailedPrecondition, dtmcli.ResultOngoing).Err() + } + return e +} + // MustGenGid must gen a gid from grpcServer func MustGenGid(grpcServer string) string { dc := dtmgimp.MustGetDtmClient(grpcServer) diff --git a/dtmsvr/api.go b/dtmsvr/api.go index 7fb5673bf..83a10565a 100644 --- a/dtmsvr/api.go +++ b/dtmsvr/api.go @@ -15,7 +15,7 @@ import ( "github.com/dtm-labs/dtm/dtmsvr/storage" ) -func svcSubmit(t *TransGlobal) (interface{}, error) { +func svcSubmit(t *TransGlobal) interface{} { t.Status = dtmcli.StatusSubmitted branches, err := t.saveNew() @@ -25,35 +25,36 @@ func svcSubmit(t *TransGlobal) (interface{}, error) { dbt.changeStatus(t.Status) branches = GetStore().FindBranches(t.Gid) } else if dbt.Status != dtmcli.StatusSubmitted { - return map[string]interface{}{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status '%s', cannot sumbmit", dbt.Status)}, nil + return fmt.Errorf("current status '%s', cannot sumbmit. %w", dbt.Status, dtmcli.ErrFailure) } } - return t.Process(branches), nil + return t.Process(branches) } -func svcPrepare(t *TransGlobal) (interface{}, error) { +func svcPrepare(t *TransGlobal) interface{} { t.Status = dtmcli.StatusPrepared _, err := t.saveNew() if err == storage.ErrUniqueConflict { dbt := GetTransGlobal(t.Gid) if dbt.Status != dtmcli.StatusPrepared { - return map[string]interface{}{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("current status '%s', cannot prepare", dbt.Status)}, nil + return fmt.Errorf("current status '%s', cannot prepare. %w", dbt.Status, dtmcli.ErrFailure) } + return nil } - return dtmcli.MapSuccess, nil + return err } -func svcAbort(t *TransGlobal) (interface{}, error) { +func svcAbort(t *TransGlobal) interface{} { dbt := GetTransGlobal(t.Gid) if t.TransType != "xa" && t.TransType != "tcc" || dbt.Status != dtmcli.StatusPrepared && dbt.Status != dtmcli.StatusAborting { - return map[string]interface{}{"dtm_result": dtmcli.ResultFailure, "message": fmt.Sprintf("trans type: '%s' current status '%s', cannot abort", dbt.TransType, dbt.Status)}, nil + return fmt.Errorf("trans type: '%s' current status '%s', cannot abort. %w", dbt.TransType, dbt.Status, dtmcli.ErrFailure) } dbt.changeStatus(dtmcli.StatusAborting) branches := GetStore().FindBranches(t.Gid) - return dbt.Process(branches), nil + return dbt.Process(branches) } -func svcRegisterBranch(transType string, branch *TransBranch, data map[string]string) (ret interface{}, rerr error) { +func svcRegisterBranch(transType string, branch *TransBranch, data map[string]string) error { branches := []TransBranch{*branch, *branch} if transType == "tcc" { for i, b := range []string{dtmcli.BranchCancel, dtmcli.BranchConfirm} { @@ -66,7 +67,7 @@ func svcRegisterBranch(transType string, branch *TransBranch, data map[string]st branches[1].Op = dtmcli.BranchCommit branches[1].URL = data["url"] } else { - return nil, fmt.Errorf("unknow trans type: %s", transType) + return fmt.Errorf("unknow trans type: %s", transType) } err := dtmimp.CatchP(func() { @@ -75,9 +76,9 @@ func svcRegisterBranch(transType string, branch *TransBranch, data map[string]st if err == storage.ErrNotFound { msg := fmt.Sprintf("no trans with gid: %s status: %s found", branch.Gid, dtmcli.StatusPrepared) logger.Errorf(msg) - return map[string]interface{}{"dtm_result": dtmcli.ResultFailure, "message": msg}, nil + return fmt.Errorf("message: %s %w", msg, dtmcli.ErrFailure) } logger.Infof("LockGlobalSaveBranches result: %v: gid: %s old status: %s branches: %s", err, branch.Gid, dtmcli.StatusPrepared, dtmimp.MustMarshalString(branches)) - return dtmimp.If(err != nil, nil, dtmcli.MapSuccess), err + return err } diff --git a/dtmsvr/api_grpc.go b/dtmsvr/api_grpc.go index 8cd48d358..8bc840a83 100644 --- a/dtmsvr/api_grpc.go +++ b/dtmsvr/api_grpc.go @@ -10,7 +10,7 @@ import ( "context" "github.com/dtm-labs/dtm/dtmcli" - "github.com/dtm-labs/dtm/dtmgrpc/dtmgimp" + "github.com/dtm-labs/dtm/dtmgrpc" pb "github.com/dtm-labs/dtm/dtmgrpc/dtmgpb" "google.golang.org/protobuf/types/known/emptypb" ) @@ -25,26 +25,26 @@ func (s *dtmServer) NewGid(ctx context.Context, in *emptypb.Empty) (*pb.DtmGidRe } func (s *dtmServer) Submit(ctx context.Context, in *pb.DtmRequest) (*emptypb.Empty, error) { - r, err := svcSubmit(TransFromDtmRequest(ctx, in)) - return &emptypb.Empty{}, dtmgimp.Result2Error(r, err) + r := svcSubmit(TransFromDtmRequest(ctx, in)) + return &emptypb.Empty{}, dtmgrpc.DtmError2GrpcError(r) } func (s *dtmServer) Prepare(ctx context.Context, in *pb.DtmRequest) (*emptypb.Empty, error) { - r, err := svcPrepare(TransFromDtmRequest(ctx, in)) - return &emptypb.Empty{}, dtmgimp.Result2Error(r, err) + r := svcPrepare(TransFromDtmRequest(ctx, in)) + return &emptypb.Empty{}, dtmgrpc.DtmError2GrpcError(r) } func (s *dtmServer) Abort(ctx context.Context, in *pb.DtmRequest) (*emptypb.Empty, error) { - r, err := svcAbort(TransFromDtmRequest(ctx, in)) - return &emptypb.Empty{}, dtmgimp.Result2Error(r, err) + r := svcAbort(TransFromDtmRequest(ctx, in)) + return &emptypb.Empty{}, dtmgrpc.DtmError2GrpcError(r) } func (s *dtmServer) RegisterBranch(ctx context.Context, in *pb.DtmBranchRequest) (*emptypb.Empty, error) { - r, err := svcRegisterBranch(in.TransType, &TransBranch{ + r := svcRegisterBranch(in.TransType, &TransBranch{ Gid: in.Gid, BranchID: in.BranchID, Status: dtmcli.StatusPrepared, BinData: in.BusiPayload, }, in.Data) - return &emptypb.Empty{}, dtmgimp.Result2Error(r, err) + return &emptypb.Empty{}, dtmgrpc.DtmError2GrpcError(r) } diff --git a/dtmsvr/api_http.go b/dtmsvr/api_http.go index fa98498b7..60271009c 100644 --- a/dtmsvr/api_http.go +++ b/dtmsvr/api_http.go @@ -17,15 +17,15 @@ import ( ) func addRoute(engine *gin.Engine) { - engine.GET("/api/dtmsvr/newGid", dtmutil.WrapHandler(newGid)) - engine.POST("/api/dtmsvr/prepare", dtmutil.WrapHandler(prepare)) - engine.POST("/api/dtmsvr/submit", dtmutil.WrapHandler(submit)) - engine.POST("/api/dtmsvr/abort", dtmutil.WrapHandler(abort)) - engine.POST("/api/dtmsvr/registerBranch", dtmutil.WrapHandler(registerBranch)) - engine.POST("/api/dtmsvr/registerXaBranch", dtmutil.WrapHandler(registerBranch)) // compatible for old sdk - engine.POST("/api/dtmsvr/registerTccBranch", dtmutil.WrapHandler(registerBranch)) // compatible for old sdk - engine.GET("/api/dtmsvr/query", dtmutil.WrapHandler(query)) - engine.GET("/api/dtmsvr/all", dtmutil.WrapHandler(all)) + engine.GET("/api/dtmsvr/newGid", dtmutil.WrapHandler2(newGid)) + engine.POST("/api/dtmsvr/prepare", dtmutil.WrapHandler2(prepare)) + engine.POST("/api/dtmsvr/submit", dtmutil.WrapHandler2(submit)) + engine.POST("/api/dtmsvr/abort", dtmutil.WrapHandler2(abort)) + engine.POST("/api/dtmsvr/registerBranch", dtmutil.WrapHandler2(registerBranch)) + engine.POST("/api/dtmsvr/registerXaBranch", dtmutil.WrapHandler2(registerBranch)) // compatible for old sdk + engine.POST("/api/dtmsvr/registerTccBranch", dtmutil.WrapHandler2(registerBranch)) // compatible for old sdk + engine.GET("/api/dtmsvr/query", dtmutil.WrapHandler2(query)) + engine.GET("/api/dtmsvr/all", dtmutil.WrapHandler2(all)) // add prometheus exporter h := promhttp.Handler() @@ -34,23 +34,23 @@ func addRoute(engine *gin.Engine) { }) } -func newGid(c *gin.Context) (interface{}, error) { - return map[string]interface{}{"gid": GenGid(), "dtm_result": dtmcli.ResultSuccess}, nil +func newGid(c *gin.Context) interface{} { + return map[string]interface{}{"gid": GenGid(), "dtm_result": dtmcli.ResultSuccess} } -func prepare(c *gin.Context) (interface{}, error) { +func prepare(c *gin.Context) interface{} { return svcPrepare(TransFromContext(c)) } -func submit(c *gin.Context) (interface{}, error) { +func submit(c *gin.Context) interface{} { return svcSubmit(TransFromContext(c)) } -func abort(c *gin.Context) (interface{}, error) { +func abort(c *gin.Context) interface{} { return svcAbort(TransFromContext(c)) } -func registerBranch(c *gin.Context) (interface{}, error) { +func registerBranch(c *gin.Context) interface{} { data := map[string]string{} err := c.BindJSON(&data) e2p(err) @@ -63,19 +63,19 @@ func registerBranch(c *gin.Context) (interface{}, error) { return svcRegisterBranch(data["trans_type"], &branch, data) } -func query(c *gin.Context) (interface{}, error) { +func query(c *gin.Context) interface{} { gid := c.Query("gid") if gid == "" { - return nil, errors.New("no gid specified") + return errors.New("no gid specified") } trans := GetStore().FindTransGlobalStore(gid) branches := GetStore().FindBranches(gid) - return map[string]interface{}{"transaction": trans, "branches": branches}, nil + return map[string]interface{}{"transaction": trans, "branches": branches} } -func all(c *gin.Context) (interface{}, error) { +func all(c *gin.Context) interface{} { position := c.Query("position") slimit := dtmimp.OrString(c.Query("limit"), "100") globals := GetStore().ScanTransGlobalStores(&position, int64(dtmimp.MustAtoi(slimit))) - return map[string]interface{}{"transactions": globals, "next_position": position}, nil + return map[string]interface{}{"transactions": globals, "next_position": position} } diff --git a/dtmsvr/cron.go b/dtmsvr/cron.go index 66ff0b5b9..8ffff263b 100644 --- a/dtmsvr/cron.go +++ b/dtmsvr/cron.go @@ -32,7 +32,8 @@ func CronTransOnce() (gid string) { gid = trans.Gid trans.WaitResult = true branches := GetStore().FindBranches(gid) - trans.Process(branches) + err := trans.Process(branches) + dtmimp.E2P(err) return } diff --git a/dtmsvr/trans_process.go b/dtmsvr/trans_process.go index fc4851363..7a8fe4698 100644 --- a/dtmsvr/trans_process.go +++ b/dtmsvr/trans_process.go @@ -7,6 +7,7 @@ package dtmsvr import ( + "fmt" "time" "github.com/dtm-labs/dtm/dtmcli" @@ -16,13 +17,13 @@ import ( ) // Process process global transaction once -func (t *TransGlobal) Process(branches []TransBranch) map[string]interface{} { +func (t *TransGlobal) Process(branches []TransBranch) error { r := t.process(branches) - transactionMetrics(t, r["dtm_result"] == dtmcli.ResultSuccess) + transactionMetrics(t, r == nil) return r } -func (t *TransGlobal) process(branches []TransBranch) map[string]interface{} { +func (t *TransGlobal) process(branches []TransBranch) error { if t.Options != "" { dtmimp.MustUnmarshalString(t.Options, &t.TransOptions) } @@ -37,17 +38,17 @@ func (t *TransGlobal) process(branches []TransBranch) map[string]interface{} { logger.Errorf("processInner err: %v", err) } }() - return dtmcli.MapSuccess + return nil } submitting := t.Status == dtmcli.StatusSubmitted err := t.processInner(branches) if err != nil { - return map[string]interface{}{"dtm_result": dtmcli.ResultFailure, "message": err.Error()} + return err } if submitting && t.Status != dtmcli.StatusSucceed { - return map[string]interface{}{"dtm_result": dtmcli.ResultFailure, "message": "trans failed by user"} + return fmt.Errorf("wait result not return success: %w", dtmcli.ErrFailure) } - return dtmcli.MapSuccess + return nil } func (t *TransGlobal) processInner(branches []TransBranch) (rerr error) { diff --git a/dtmsvr/trans_status.go b/dtmsvr/trans_status.go index ec78533a6..c70e5f295 100644 --- a/dtmsvr/trans_status.go +++ b/dtmsvr/trans_status.go @@ -7,6 +7,7 @@ package dtmsvr import ( + "errors" "fmt" "strings" "time" @@ -72,15 +73,15 @@ func (t *TransGlobal) needProcess() bool { return t.Status == dtmcli.StatusSubmitted || t.Status == dtmcli.StatusAborting || t.Status == dtmcli.StatusPrepared && t.isTimeout() } -func (t *TransGlobal) getURLResult(url string, branchID, op string, branchPayload []byte) (string, error) { +func (t *TransGlobal) getURLResult(url string, branchID, op string, branchPayload []byte) error { if url == "" { // empty url is success - return dtmcli.ResultSuccess, nil + return nil } if t.Protocol == "grpc" { dtmimp.PanicIf(strings.HasPrefix(url, "http"), fmt.Errorf("bad url for grpc: %s", url)) server, method, err := dtmdriver.GetDriver().ParseServerMethod(url) if err != nil { - return "", err + return err } conn := dtmgimp.MustGetGrpcConn(server, true) ctx := dtmgimp.TransInfo2Ctx(t.Gid, t.TransType, branchID, op, "") @@ -89,17 +90,19 @@ func (t *TransGlobal) getURLResult(url string, branchID, op string, branchPayloa ctx = metadata.AppendToOutgoingContext(ctx, kvs...) err = conn.Invoke(ctx, method, branchPayload, &[]byte{}) if err == nil { - return dtmcli.ResultSuccess, nil + return nil } st, ok := status.FromError(err) if ok && st.Code() == codes.Aborted { + // version lower then v1.10, will specify Ongoing in code Aborted if st.Message() == dtmcli.ResultOngoing { - return dtmcli.ResultOngoing, nil - } else if st.Message() == dtmcli.ResultFailure { - return dtmcli.ResultFailure, nil + return dtmcli.ErrOngoing } + return dtmcli.ErrFailure + } else if ok && st.Code() == codes.FailedPrecondition { + return dtmcli.ErrOngoing } - return "", err + return err } dtmimp.PanicIf(!strings.HasPrefix(url, "http"), fmt.Errorf("bad url for http: %s", url)) resp, err := dtmimp.RestyClient.R().SetBody(string(branchPayload)). @@ -114,24 +117,21 @@ func (t *TransGlobal) getURLResult(url string, branchID, op string, branchPayloa SetHeaders(t.TransOptions.BranchHeaders). Execute(dtmimp.If(branchPayload != nil || t.TransType == "xa", "POST", "GET").(string), url) if err != nil { - return "", err + return err } - return resp.String(), nil + return dtmimp.RespAsErrorCompatible(resp) } func (t *TransGlobal) getBranchResult(branch *TransBranch) (string, error) { - body, err := t.getURLResult(branch.URL, branch.BranchID, branch.Op, branch.BinData) - if err != nil { - return "", err - } - if strings.Contains(body, dtmcli.ResultSuccess) { + err := t.getURLResult(branch.URL, branch.BranchID, branch.Op, branch.BinData) + if err == nil { return dtmcli.StatusSucceed, nil - } else if strings.HasSuffix(t.TransType, "saga") && branch.Op == dtmcli.BranchAction && strings.Contains(body, dtmcli.ResultFailure) { + } else if t.TransType == "saga" && branch.Op == dtmcli.BranchAction && errors.Is(err, dtmcli.ErrFailure) { return dtmcli.StatusFailed, nil - } else if strings.Contains(body, dtmcli.ResultOngoing) { - return "", dtmimp.ErrOngoing + } else if errors.Is(err, dtmcli.ErrOngoing) { + return "", dtmcli.ErrOngoing } - return "", fmt.Errorf("http result should contains SUCCESS|FAILURE|ONGOING. grpc error should return nil|Aborted with message(FAILURE|ONGOING). \nrefer to: https://dtm.pub/summary/arch.html#http\nunkown result will be retried: %s", body) + return "", fmt.Errorf("http/grpc result should be specified as in:\nhttps://dtm.pub/summary/arch.html#http\nunkown result will be retried: %s", err) } func (t *TransGlobal) execBranch(branch *TransBranch, branchPos int) error { diff --git a/dtmsvr/trans_type_msg.go b/dtmsvr/trans_type_msg.go index 3435042cf..5fe3eb2be 100644 --- a/dtmsvr/trans_type_msg.go +++ b/dtmsvr/trans_type_msg.go @@ -7,8 +7,8 @@ package dtmsvr import ( + "errors" "fmt" - "strings" "github.com/dtm-labs/dtm/dtmcli" "github.com/dtm-labs/dtm/dtmcli/logger" @@ -42,15 +42,15 @@ func (t *TransGlobal) mayQueryPrepared() { if !t.needProcess() || t.Status == dtmcli.StatusSubmitted { return } - body, err := t.getURLResult(t.QueryPrepared, "00", "msg", nil) - if strings.Contains(body, dtmcli.ResultSuccess) { + err := t.getURLResult(t.QueryPrepared, "00", "msg", nil) + if err == nil { t.changeStatus(dtmcli.StatusSubmitted) - } else if strings.Contains(body, dtmcli.ResultFailure) { + } else if errors.Is(err, dtmcli.ErrFailure) { t.changeStatus(dtmcli.StatusFailed) - } else if strings.Contains(body, dtmcli.ResultOngoing) { + } else if errors.Is(err, dtmcli.ErrOngoing) { t.touchCronTime(cronReset) } else { - logger.Errorf("getting result failed for %s. error: %v body %s", t.QueryPrepared, err, body) + logger.Errorf("getting result failed for %s. error: %v", t.QueryPrepared, err) t.touchCronTime(cronBackoff) } } diff --git a/dtmsvr/trans_type_saga.go b/dtmsvr/trans_type_saga.go index abaa80169..f2a88fd5b 100644 --- a/dtmsvr/trans_type_saga.go +++ b/dtmsvr/trans_type_saga.go @@ -7,6 +7,7 @@ package dtmsvr import ( + "errors" "fmt" "time" @@ -106,7 +107,7 @@ func (t *transSagaProcessor) ProcessOnce(branches []TransBranch) error { err = dtmimp.AsError(x) } resultChan <- branchResult{index: i, status: branches[i].Status, op: branches[i].Op} - if err != nil && err != dtmcli.ErrOngoing { + if err != nil && !errors.Is(err, dtmcli.ErrOngoing) { logger.Errorf("exec branch error: %v", err) } }() diff --git a/dtmutil/utils.go b/dtmutil/utils.go index 288c42425..849b53995 100644 --- a/dtmutil/utils.go +++ b/dtmutil/utils.go @@ -9,7 +9,9 @@ package dtmutil import ( "bytes" "encoding/json" + "errors" "io/ioutil" + "net/http" "os" "path/filepath" "strings" @@ -45,31 +47,61 @@ func GetGinApp() *gin.Engine { return app } -// WrapHandler name is clear -func WrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc { +// WrapHandler2 wrap a function te bo the handler of gin request +func WrapHandler2(fn func(*gin.Context) interface{}) gin.HandlerFunc { return func(c *gin.Context) { began := time.Now() - r, err := func() (r interface{}, rerr error) { - defer dtmimp.P2E(&rerr) + var err error + r := func() interface{} { + defer dtmimp.P2E(&err) return fn(c) }() - var b = []byte{} - if resp, ok := r.(*resty.Response); ok { // 如果是response,则取出body直接处理 - b = resp.Body() - } else if err == nil { - b, err = json.Marshal(r) + + status := http.StatusOK + + // in dtm test/busi, there are some functions, which will return a resty response + // pass resty response as gin's response + if resp, ok := r.(*resty.Response); ok { + b := resp.Body() + status = resp.StatusCode() + r = nil + err = json.Unmarshal(b, &r) + } + + // error maybe returned in r, assign it to err + if ne, ok := r.(error); ok && err == nil { + err = ne } + // if err != nil || r == nil. then set the status and dtm_result + // dtm_result is for compatible with version lower than v1.10 + // when >= v1.10, result test should base on status, not dtm_result. + result := map[string]interface{}{} if err != nil { - logger.Errorf("%2dms 500 %s %s %s %s", time.Since(began).Milliseconds(), err.Error(), c.Request.Method, c.Request.RequestURI, string(b)) - c.JSON(500, map[string]interface{}{"code": 500, "message": err.Error()}) + if errors.Is(err, dtmcli.ErrFailure) { + status = http.StatusConflict + result["dtm_result"] = dtmcli.ResultFailure + } else if errors.Is(err, dtmcli.ErrOngoing) { + status = http.StatusTooEarly + result["dtm_result"] = dtmcli.ResultOngoing + } else if err != nil { + status = http.StatusInternalServerError + } + result["message"] = err.Error() + r = result + } else if r == nil { + result["dtm_result"] = dtmcli.ResultSuccess + r = result + } + + b, _ := json.Marshal(r) + cont := string(b) + if status == http.StatusOK || status == http.StatusTooEarly { + logger.Infof("%2dms %d %s %s %s", time.Since(began).Milliseconds(), status, c.Request.Method, c.Request.RequestURI, cont) } else { - logger.Infof("%2dms 200 %s %s %s", time.Since(began).Milliseconds(), c.Request.Method, c.Request.RequestURI, string(b)) - c.Status(200) - c.Writer.Header().Add("Content-Type", "application/json") - _, err = c.Writer.Write(b) - dtmimp.E2P(err) + logger.Errorf("%2dms %d %s %s %s", time.Since(began).Milliseconds(), status, c.Request.Method, c.Request.RequestURI, cont) } + c.JSON(status, r) } } diff --git a/dtmutil/utils_test.go b/dtmutil/utils_test.go index 714f2177b..4ce150744 100644 --- a/dtmutil/utils_test.go +++ b/dtmutil/utils_test.go @@ -21,11 +21,11 @@ import ( func TestGin(t *testing.T) { app := GetGinApp() - app.GET("/api/sample", WrapHandler(func(c *gin.Context) (interface{}, error) { - return 1, nil + app.GET("/api/sample", WrapHandler2(func(c *gin.Context) interface{} { + return 1 })) - app.GET("/api/error", WrapHandler(func(c *gin.Context) (interface{}, error) { - return nil, errors.New("err1") + app.GET("/api/error", WrapHandler2(func(c *gin.Context) interface{} { + return errors.New("err1") })) getResultString := func(api string, body io.Reader) string { req, _ := http.NewRequest("GET", api, body) @@ -35,7 +35,7 @@ func TestGin(t *testing.T) { } assert.Equal(t, "{\"msg\":\"pong\"}", getResultString("/api/ping", nil)) assert.Equal(t, "1", getResultString("/api/sample", nil)) - assert.Equal(t, "{\"code\":500,\"message\":\"err1\"}", getResultString("/api/error", strings.NewReader("{}"))) + assert.Equal(t, "{\"message\":\"err1\"}", getResultString("/api/error", strings.NewReader("{}"))) } func TestFuncs(t *testing.T) { diff --git a/test/busi/barrier.go b/test/busi/barrier.go index ce1d7cef3..46463f7bb 100644 --- a/test/busi/barrier.go +++ b/test/busi/barrier.go @@ -18,79 +18,79 @@ import ( func init() { setupFuncs["BarrierSetup"] = func(app *gin.Engine) { - app.POST(BusiAPI+"/SagaBTransIn", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(BusiAPI+"/SagaBTransIn", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { barrier := MustBarrierFromGin(c) - return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error { + return barrier.Call(txGet(), func(tx *sql.Tx) error { return SagaAdjustBalance(tx, TransInUID, reqFrom(c).Amount, reqFrom(c).TransInResult) }) })) - app.POST(BusiAPI+"/SagaBTransInCompensate", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(BusiAPI+"/SagaBTransInCompensate", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { barrier := MustBarrierFromGin(c) - return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error { + return barrier.Call(txGet(), func(tx *sql.Tx) error { return SagaAdjustBalance(tx, TransInUID, -reqFrom(c).Amount, "") }) })) - app.POST(BusiAPI+"/SagaBTransOut", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(BusiAPI+"/SagaBTransOut", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { barrier := MustBarrierFromGin(c) - return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error { + return barrier.Call(txGet(), func(tx *sql.Tx) error { return SagaAdjustBalance(tx, TransOutUID, -reqFrom(c).Amount, reqFrom(c).TransOutResult) }) })) - app.POST(BusiAPI+"/SagaBTransOutCompensate", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(BusiAPI+"/SagaBTransOutCompensate", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { barrier := MustBarrierFromGin(c) - return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error { + return barrier.Call(txGet(), func(tx *sql.Tx) error { return SagaAdjustBalance(tx, TransOutUID, reqFrom(c).Amount, "") }) })) - app.POST(BusiAPI+"/SagaBTransOutGorm", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(BusiAPI+"/SagaBTransOutGorm", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { req := reqFrom(c) barrier := MustBarrierFromGin(c) tx := dbGet().DB.Begin() - return dtmcli.MapSuccess, barrier.Call(tx.Statement.ConnPool.(*sql.Tx), func(tx1 *sql.Tx) error { + return barrier.Call(tx.Statement.ConnPool.(*sql.Tx), func(tx1 *sql.Tx) error { return tx.Exec("update dtm_busi.user_account set balance = balance + ? where user_id = ?", -req.Amount, TransOutUID).Error }) })) - app.POST(BusiAPI+"/TccBTransInTry", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { - req := reqFrom(c) // 去重构一下,改成可以重复使用的输入 + app.POST(BusiAPI+"/TccBTransInTry", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { + req := reqFrom(c) if req.TransInResult != "" { - return req.TransInResult, nil + return dtmcli.String2DtmError(req.TransInResult) } - return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { + return MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { return tccAdjustTrading(tx, TransInUID, req.Amount) }) })) - app.POST(BusiAPI+"/TccBTransInConfirm", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { - return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { + app.POST(BusiAPI+"/TccBTransInConfirm", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { + return MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { return tccAdjustBalance(tx, TransInUID, reqFrom(c).Amount) }) })) - app.POST(BusiAPI+"/TccBTransInCancel", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { - return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { + app.POST(BusiAPI+"/TccBTransInCancel", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { + return MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { return tccAdjustTrading(tx, TransInUID, -reqFrom(c).Amount) }) })) - app.POST(BusiAPI+"/TccBTransOutTry", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(BusiAPI+"/TccBTransOutTry", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { req := reqFrom(c) if req.TransOutResult != "" { - return req.TransOutResult, nil + return dtmcli.String2DtmError(req.TransOutResult) } - return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { + return MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { return tccAdjustTrading(tx, TransOutUID, -req.Amount) }) })) - app.POST(BusiAPI+"/TccBTransOutConfirm", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { - return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { + app.POST(BusiAPI+"/TccBTransOutConfirm", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { + return MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { return tccAdjustBalance(tx, TransOutUID, -reqFrom(c).Amount) }) })) - app.POST(BusiAPI+"/TccBTransOutCancel", dtmutil.WrapHandler(TccBarrierTransOutCancel)) + app.POST(BusiAPI+"/TccBTransOutCancel", dtmutil.WrapHandler2(TccBarrierTransOutCancel)) } } // TccBarrierTransOutCancel will be use in test -func TccBarrierTransOutCancel(c *gin.Context) (interface{}, error) { - return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { +func TccBarrierTransOutCancel(c *gin.Context) interface{} { + return MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error { return tccAdjustTrading(tx, TransOutUID, reqFrom(c).Amount) }) } diff --git a/test/busi/base_grpc.go b/test/busi/base_grpc.go index 3c71a5079..63099e1b9 100644 --- a/test/busi/base_grpc.go +++ b/test/busi/base_grpc.go @@ -13,6 +13,7 @@ import ( "fmt" "net" + "github.com/dtm-labs/dtm/dtmcli" "github.com/dtm-labs/dtm/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmcli/logger" "github.com/dtm-labs/dtm/dtmgrpc" @@ -65,7 +66,9 @@ type busiServer struct { func (s *busiServer) QueryPrepared(ctx context.Context, in *BusiReq) (*BusiReply, error) { res := MainSwitch.QueryPreparedResult.Fetch() - return &BusiReply{Message: "a sample data"}, dtmgimp.Result2Error(res, nil) + err := dtmcli.String2DtmError(res) + + return &BusiReply{Message: "a sample data"}, dtmgrpc.DtmError2GrpcError(err) } func (s *busiServer) TransIn(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) { diff --git a/test/busi/base_http.go b/test/busi/base_http.go index b39d4ae1c..0feee1e89 100644 --- a/test/busi/base_http.go +++ b/test/busi/base_http.go @@ -39,7 +39,7 @@ var Busi string = fmt.Sprintf("http://localhost:%d%s", BusiPort, BusiAPI) var XaClient *dtmcli.XaClient = nil -type SleepCancelHandler func(c *gin.Context) (interface{}, error) +type SleepCancelHandler func(c *gin.Context) interface{} var sleepCancelHandler SleepCancelHandler = nil @@ -62,7 +62,7 @@ func BaseAppStartup() *gin.Engine { }) var err error XaClient, err = dtmcli.NewXaClient(dtmutil.DefaultHTTPServer, BusiConf, Busi+"/xa", func(path string, xa *dtmcli.XaClient) { - app.POST(path, dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(path, dtmutil.WrapHandler2(func(c *gin.Context) interface{} { return xa.HandleCallback(c.Query("gid"), c.Query("branch_id"), c.Query("op")) })) }) @@ -81,55 +81,76 @@ func BaseAppStartup() *gin.Engine { // BaseAddRoute add base route handler func BaseAddRoute(app *gin.Engine) { - app.POST(BusiAPI+"/TransIn", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(BusiAPI+"/TransIn", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { return handleGeneralBusiness(c, MainSwitch.TransInResult.Fetch(), reqFrom(c).TransInResult, "transIn") })) - app.POST(BusiAPI+"/TransOut", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(BusiAPI+"/TransOut", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { return handleGeneralBusiness(c, MainSwitch.TransOutResult.Fetch(), reqFrom(c).TransOutResult, "TransOut") })) - app.POST(BusiAPI+"/TransInConfirm", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(BusiAPI+"/TransInConfirm", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { return handleGeneralBusiness(c, MainSwitch.TransInConfirmResult.Fetch(), "", "TransInConfirm") })) - app.POST(BusiAPI+"/TransOutConfirm", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(BusiAPI+"/TransOutConfirm", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { return handleGeneralBusiness(c, MainSwitch.TransOutConfirmResult.Fetch(), "", "TransOutConfirm") })) - app.POST(BusiAPI+"/TransInRevert", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(BusiAPI+"/TransInRevert", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { return handleGeneralBusiness(c, MainSwitch.TransInRevertResult.Fetch(), "", "TransInRevert") })) - app.POST(BusiAPI+"/TransOutRevert", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(BusiAPI+"/TransOutRevert", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { return handleGeneralBusiness(c, MainSwitch.TransOutRevertResult.Fetch(), "", "TransOutRevert") })) - app.GET(BusiAPI+"/QueryPrepared", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(BusiAPI+"/TransInOld", oldWrapHandler(func(c *gin.Context) (interface{}, error) { + return handleGeneralBusinessCompatible(c, MainSwitch.TransInResult.Fetch(), reqFrom(c).TransInResult, "transIn") + })) + app.POST(BusiAPI+"/TransOutOld", oldWrapHandler(func(c *gin.Context) (interface{}, error) { + return handleGeneralBusinessCompatible(c, MainSwitch.TransOutResult.Fetch(), reqFrom(c).TransOutResult, "TransOut") + })) + app.POST(BusiAPI+"/TransInConfirmOld", oldWrapHandler(func(c *gin.Context) (interface{}, error) { + return handleGeneralBusinessCompatible(c, MainSwitch.TransInConfirmResult.Fetch(), "", "TransInConfirm") + })) + app.POST(BusiAPI+"/TransOutConfirmOld", oldWrapHandler(func(c *gin.Context) (interface{}, error) { + return handleGeneralBusinessCompatible(c, MainSwitch.TransOutConfirmResult.Fetch(), "", "TransOutConfirm") + })) + app.POST(BusiAPI+"/TransInRevertOld", oldWrapHandler(func(c *gin.Context) (interface{}, error) { + return handleGeneralBusinessCompatible(c, MainSwitch.TransInRevertResult.Fetch(), "", "TransInRevert") + })) + app.POST(BusiAPI+"/TransOutRevertOld", oldWrapHandler(func(c *gin.Context) (interface{}, error) { + return handleGeneralBusinessCompatible(c, MainSwitch.TransOutRevertResult.Fetch(), "", "TransOutRevert") + })) + + app.GET(BusiAPI+"/QueryPrepared", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { logger.Debugf("%s QueryPrepared", c.Query("gid")) - return dtmimp.OrString(MainSwitch.QueryPreparedResult.Fetch(), dtmcli.ResultSuccess), nil + return dtmcli.String2DtmError(dtmimp.OrString(MainSwitch.QueryPreparedResult.Fetch(), dtmcli.ResultSuccess)) })) - app.GET(BusiAPI+"/QueryPreparedB", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.GET(BusiAPI+"/QueryPreparedB", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { logger.Debugf("%s QueryPreparedB", c.Query("gid")) bb := MustBarrierFromGin(c) db := dbGet().ToSQLDB() - return error2Resp(bb.QueryPrepared(db)) + return bb.QueryPrepared(db) })) - app.POST(BusiAPI+"/TransInXa", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { - err := XaClient.XaLocalTransaction(c.Request.URL.Query(), func(db *sql.DB, xa *dtmcli.Xa) error { + app.POST(BusiAPI+"/TransInXa", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { + return XaClient.XaLocalTransaction(c.Request.URL.Query(), func(db *sql.DB, xa *dtmcli.Xa) error { return SagaAdjustBalance(db, TransInUID, reqFrom(c).Amount, reqFrom(c).TransInResult) }) - return error2Resp(err) })) - app.POST(BusiAPI+"/TransOutXa", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { - err := XaClient.XaLocalTransaction(c.Request.URL.Query(), func(db *sql.DB, xa *dtmcli.Xa) error { + app.POST(BusiAPI+"/TransOutXa", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { + return XaClient.XaLocalTransaction(c.Request.URL.Query(), func(db *sql.DB, xa *dtmcli.Xa) error { return SagaAdjustBalance(db, TransOutUID, reqFrom(c).Amount, reqFrom(c).TransOutResult) }) - return error2Resp(err) })) - app.POST(BusiAPI+"/TransInTccParent", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(BusiAPI+"/TransInTccNested", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { tcc, err := dtmcli.TccFromQuery(c.Request.URL.Query()) logger.FatalIfError(err) - logger.Debugf("TransInTccParent ") - return tcc.CallBranch(&TransReq{Amount: reqFrom(c).Amount}, Busi+"/TransIn", Busi+"/TransInConfirm", Busi+"/TransInRevert") + logger.Debugf("TransInTccNested ") + resp, err := tcc.CallBranch(&TransReq{Amount: reqFrom(c).Amount}, Busi+"/TransIn", Busi+"/TransInConfirm", Busi+"/TransInRevert") + if err != nil { + return err + } + return resp })) - app.POST(BusiAPI+"/TransOutXaGorm", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { - err := XaClient.XaLocalTransaction(c.Request.URL.Query(), func(db *sql.DB, xa *dtmcli.Xa) error { + app.POST(BusiAPI+"/TransOutXaGorm", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { + return XaClient.XaLocalTransaction(c.Request.URL.Query(), func(db *sql.DB, xa *dtmcli.Xa) error { if reqFrom(c).TransOutResult == dtmcli.ResultFailure { return dtmcli.ErrFailure } @@ -146,32 +167,31 @@ func BaseAddRoute(app *gin.Engine) { dbr := gdb.Exec("update dtm_busi.user_account set balance=balance-? where user_id=?", reqFrom(c).Amount, TransOutUID) return dbr.Error }) - return error2Resp(err) })) - app.POST(BusiAPI+"/TestPanic", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(BusiAPI+"/TestPanic", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { if c.Query("panic_error") != "" { panic(errors.New("panic_error")) } else if c.Query("panic_string") != "" { panic("panic_string") } - return "SUCCESS", nil + return nil })) - app.POST(BusiAPI+"/TccBSleepCancel", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(BusiAPI+"/TccBSleepCancel", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { return sleepCancelHandler(c) })) - app.POST(BusiAPI+"/TransOutHeaderYes", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(BusiAPI+"/TransOutHeaderYes", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { h := c.GetHeader("test_header") if h == "" { - return nil, errors.New("no test_header found in TransOutHeaderYes") + return errors.New("no test_header found in TransOutHeaderYes") } return handleGeneralBusiness(c, MainSwitch.TransOutResult.Fetch(), reqFrom(c).TransOutResult, "TransOut") })) - app.POST(BusiAPI+"/TransOutHeaderNo", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(BusiAPI+"/TransOutHeaderNo", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { h := c.GetHeader("test_header") if h != "" { - return nil, errors.New("test_header found in TransOutHeaderNo") + return errors.New("test_header found in TransOutHeaderNo") } - return dtmcli.MapSuccess, nil + return nil })) } diff --git a/test/busi/busi.go b/test/busi/busi.go index 7d6c0644b..5a4119296 100644 --- a/test/busi/busi.go +++ b/test/busi/busi.go @@ -29,25 +29,25 @@ func handleGrpcBusiness(in *BusiReq, result1 string, result2 string, busi string return status.New(codes.Internal, fmt.Sprintf("unknow result %s", res)).Err() } -func handleGeneralBusiness(c *gin.Context, result1 string, result2 string, busi string) (interface{}, error) { +func handleGeneralBusiness(c *gin.Context, result1 string, result2 string, busi string) interface{} { info := infoFromContext(c) res := dtmimp.OrString(result1, result2, dtmcli.ResultSuccess) logger.Debugf("%s %s result: %s", busi, info.String(), res) if res == "ERROR" { - return nil, errors.New("ERROR from user") + return errors.New("ERROR from user") } - return map[string]interface{}{"dtm_result": res}, nil + return dtmcli.String2DtmError(res) } -func error2Resp(err error) (interface{}, error) { - if err != nil { - s := err.Error() - if strings.Contains(s, dtmcli.ResultFailure) || strings.Contains(s, dtmcli.ResultOngoing) { - return gin.H{"dtm_result": s}, nil - } - return nil, err +// old business handler. for compatible usage +func handleGeneralBusinessCompatible(c *gin.Context, result1 string, result2 string, busi string) (interface{}, error) { + info := infoFromContext(c) + res := dtmimp.OrString(result1, result2, dtmcli.ResultSuccess) + logger.Debugf("%s %s result: %s", busi, info.String(), res) + if res == "ERROR" { + return nil, errors.New("ERROR from user") } - return gin.H{"dtm_result": dtmcli.ResultSuccess}, nil + return map[string]interface{}{"dtm_result": res}, nil } func sagaGrpcAdjustBalance(db dtmcli.DB, uid int, amount int64, result string) error { @@ -56,7 +56,6 @@ func sagaGrpcAdjustBalance(db dtmcli.DB, uid int, amount int64, result string) e } _, err := dtmimp.DBExec(db, "update dtm_busi.user_account set balance = balance + ? where user_id = ?", amount, uid) return err - } func SagaAdjustBalance(db dtmcli.DB, uid int, amount int, result string) error { diff --git a/test/busi/quick_start.go b/test/busi/quick_start.go index db3c1883a..da939ec34 100644 --- a/test/busi/quick_start.go +++ b/test/busi/quick_start.go @@ -49,20 +49,20 @@ func QsFireRequest() string { } func qsAddRoute(app *gin.Engine) { - app.POST(qsBusiAPI+"/TransIn", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(qsBusiAPI+"/TransIn", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { logger.Infof("TransIn") - return dtmcli.MapSuccess, nil + return nil })) - app.POST(qsBusiAPI+"/TransInCompensate", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(qsBusiAPI+"/TransInCompensate", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { logger.Infof("TransInCompensate") - return dtmcli.MapSuccess, nil + return nil })) - app.POST(qsBusiAPI+"/TransOut", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(qsBusiAPI+"/TransOut", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { logger.Infof("TransOut") - return dtmcli.MapSuccess, nil + return nil })) - app.POST(qsBusiAPI+"/TransOutCompensate", dtmutil.WrapHandler(func(c *gin.Context) (interface{}, error) { + app.POST(qsBusiAPI+"/TransOutCompensate", dtmutil.WrapHandler2(func(c *gin.Context) interface{} { logger.Infof("TransOutCompensate") - return dtmcli.MapSuccess, nil + return nil })) } diff --git a/test/busi/utils.go b/test/busi/utils.go index 5df72d1f6..bc5a77f2f 100644 --- a/test/busi/utils.go +++ b/test/busi/utils.go @@ -3,8 +3,10 @@ package busi import ( "context" "database/sql" + "encoding/json" "fmt" "strings" + "time" "github.com/dtm-labs/dtm/dtmcli" "github.com/dtm-labs/dtm/dtmcli/dtmimp" @@ -83,3 +85,31 @@ func SetHttpHeaderForHeadersYes(c *resty.Client, r *resty.Request) error { } return nil } + +// oldWrapHandler old wrap handler for test use of dtm +func oldWrapHandler(fn func(*gin.Context) (interface{}, error)) gin.HandlerFunc { + return func(c *gin.Context) { + began := time.Now() + r, err := func() (r interface{}, rerr error) { + defer dtmimp.P2E(&rerr) + return fn(c) + }() + var b = []byte{} + if resp, ok := r.(*resty.Response); ok { // 如果是response,则取出body直接处理 + b = resp.Body() + } else if err == nil { + b, err = json.Marshal(r) + } + + if err != nil { + logger.Errorf("%2dms 500 %s %s %s %s", time.Since(began).Milliseconds(), err.Error(), c.Request.Method, c.Request.RequestURI, string(b)) + c.JSON(500, map[string]interface{}{"code": 500, "message": err.Error()}) + } else { + logger.Infof("%2dms 200 %s %s %s", time.Since(began).Milliseconds(), c.Request.Method, c.Request.RequestURI, string(b)) + c.Status(200) + c.Writer.Header().Add("Content-Type", "application/json") + _, err = c.Writer.Write(b) + dtmimp.E2P(err) + } + } +} diff --git a/test/msg_test.go b/test/msg_test.go index a9b115259..8f26ad879 100644 --- a/test/msg_test.go +++ b/test/msg_test.go @@ -45,9 +45,6 @@ func TestMsgTimeoutFailed(t *testing.T) { msg := genMsg(dtmimp.GetFuncName()) msg.Prepare("") assert.Equal(t, StatusPrepared, getTransStatus(msg.Gid)) - busi.MainSwitch.QueryPreparedResult.SetOnce("OTHER_ERROR") - cronTransOnceForwardNow(180) - assert.Equal(t, StatusPrepared, getTransStatus(msg.Gid)) busi.MainSwitch.QueryPreparedResult.SetOnce(dtmcli.ResultOngoing) cronTransOnceForwardNow(360) assert.Equal(t, StatusPrepared, getTransStatus(msg.Gid)) diff --git a/test/tcc_barrier_test.go b/test/tcc_barrier_test.go index 376b3ca02..2a4ec9409 100644 --- a/test/tcc_barrier_test.go +++ b/test/tcc_barrier_test.go @@ -63,14 +63,14 @@ func TestTccBarrierDisorder(t *testing.T) { cancelURL := Busi + "/TccBSleepCancel" // 请参见子事务屏障里的时序图,这里为了模拟该时序图,手动拆解了callbranch branchID := tcc.NewSubBranchID() - busi.SetSleepCancelHandler(func(c *gin.Context) (interface{}, error) { - res, err := busi.TccBarrierTransOutCancel(c) + busi.SetSleepCancelHandler(func(c *gin.Context) interface{} { + res := busi.TccBarrierTransOutCancel(c) logger.Debugf("disorderHandler before cancel finish write") cancelFinishedChan <- "1" logger.Debugf("disorderHandler before cancel return read") <-cancelCanReturnChan logger.Debugf("disorderHandler after cancel return read") - return res, err + return res }) // 注册子事务 resp, err := dtmimp.RestyClient.R(). diff --git a/test/tcc_cover_test.go b/test/tcc_cover_test.go index c8d93d898..00ba7115d 100644 --- a/test/tcc_cover_test.go +++ b/test/tcc_cover_test.go @@ -6,6 +6,7 @@ import ( "github.com/dtm-labs/dtm/dtmcli" "github.com/dtm-labs/dtm/dtmcli/dtmimp" "github.com/dtm-labs/dtm/dtmutil" + "github.com/dtm-labs/dtm/test/busi" "github.com/go-resty/resty/v2" "github.com/stretchr/testify/assert" ) @@ -29,3 +30,17 @@ func TestTccCoverPanic(t *testing.T) { assert.Contains(t, err.Error(), "user panic") waitTransProcessed(gid) } + +func TestTccNested(t *testing.T) { + req := busi.GenTransReq(30, false, false) + gid := dtmimp.GetFuncName() + err := dtmcli.TccGlobalTransaction(dtmutil.DefaultHTTPServer, gid, func(tcc *dtmcli.Tcc) (*resty.Response, error) { + _, err := tcc.CallBranch(req, Busi+"/TransOut", Busi+"/TransOutConfirm", Busi+"/TransOutRevert") + assert.Nil(t, err) + return tcc.CallBranch(req, Busi+"/TransInTccNested", Busi+"/TransInConfirm", Busi+"/TransInRevert") + }) + assert.Nil(t, err) + waitTransProcessed(gid) + assert.Equal(t, StatusSucceed, getTransStatus(gid)) + assert.Equal(t, []string{StatusPrepared, StatusSucceed, StatusPrepared, StatusSucceed, StatusPrepared, StatusSucceed}, getBranchesStatus(gid)) +} diff --git a/test/tcc_old_test.go b/test/tcc_old_test.go new file mode 100644 index 000000000..9a55611bc --- /dev/null +++ b/test/tcc_old_test.go @@ -0,0 +1,66 @@ +package test + +import ( + "testing" + + "github.com/dtm-labs/dtm/dtmcli" + "github.com/dtm-labs/dtm/dtmcli/dtmimp" + "github.com/dtm-labs/dtm/dtmutil" + "github.com/dtm-labs/dtm/test/busi" + "github.com/go-resty/resty/v2" + "github.com/stretchr/testify/assert" +) + +func TestTccOldNormal(t *testing.T) { + req := busi.GenTransReq(30, false, false) + gid := dtmimp.GetFuncName() + err := dtmcli.TccGlobalTransaction(dtmutil.DefaultHTTPServer, gid, func(tcc *dtmcli.Tcc) (*resty.Response, error) { + _, err := tcc.CallBranch(req, Busi+"/TransOutOld", Busi+"/TransOutConfirmOld", Busi+"/TransOutRevertOld") + assert.Nil(t, err) + return tcc.CallBranch(req, Busi+"/TransInOld", Busi+"/TransInConfirmOld", Busi+"/TransInRevertOld") + }) + assert.Nil(t, err) + waitTransProcessed(gid) + assert.Equal(t, StatusSucceed, getTransStatus(gid)) + assert.Equal(t, []string{StatusPrepared, StatusSucceed, StatusPrepared, StatusSucceed}, getBranchesStatus(gid)) +} + +func TestTccOldRollback(t *testing.T) { + gid := dtmimp.GetFuncName() + req := busi.GenTransReq(30, false, true) + err := dtmcli.TccGlobalTransaction(dtmutil.DefaultHTTPServer, gid, func(tcc *dtmcli.Tcc) (*resty.Response, error) { + _, rerr := tcc.CallBranch(req, Busi+"/TransOutOld", Busi+"/TransOutConfirmOld", Busi+"/TransOutRevertOld") + assert.Nil(t, rerr) + busi.MainSwitch.TransOutRevertResult.SetOnce(dtmcli.ResultOngoing) + return tcc.CallBranch(req, Busi+"/TransInOld", Busi+"/TransInConfirmOld", Busi+"/TransInRevertOld") + }) + assert.Error(t, err) + waitTransProcessed(gid) + assert.Equal(t, StatusAborting, getTransStatus(gid)) + g := cronTransOnce() + assert.Equal(t, gid, g) + assert.Equal(t, StatusFailed, getTransStatus(gid)) + assert.Equal(t, []string{StatusSucceed, StatusPrepared, StatusSucceed, StatusPrepared}, getBranchesStatus(gid)) +} + +func TestTccOldTimeout(t *testing.T) { + req := busi.GenTransReq(30, false, false) + gid := dtmimp.GetFuncName() + timeoutChan := make(chan int, 1) + + err := dtmcli.TccGlobalTransaction(dtmutil.DefaultHTTPServer, gid, func(tcc *dtmcli.Tcc) (*resty.Response, error) { + _, err := tcc.CallBranch(req, Busi+"/TransOutOld", Busi+"/TransOutConfirmOld", Busi+"/TransOutRevertOld") + assert.Nil(t, err) + go func() { + cronTransOnceForwardNow(300) + timeoutChan <- 0 + }() + <-timeoutChan + _, err = tcc.CallBranch(req, Busi+"/TransInOld", Busi+"/TransInConfirmOld", Busi+"/TransInRevertOld") + assert.Error(t, err) + return nil, err + }) + assert.Error(t, err) + assert.Equal(t, StatusFailed, getTransStatus(gid)) + assert.Equal(t, []string{StatusSucceed, StatusPrepared}, getBranchesStatus(gid)) +}