Skip to content

Commit

Permalink
Merge pull request dtm-labs#145 from dtm-labs/alpha
Browse files Browse the repository at this point in the history
support headers
  • Loading branch information
yedf2 authored Jan 1, 2022
2 parents 70fba0c + aae384e commit 25b7317
Show file tree
Hide file tree
Showing 54 changed files with 781 additions and 263 deletions.
8 changes: 4 additions & 4 deletions bench/svr/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ const total = 200000
var benchPort = dtmimp.If(os.Getenv("BENCH_PORT") == "", "8083", os.Getenv("BENCH_PORT")).(string)
var benchBusi = fmt.Sprintf("http://localhost:%s%s", benchPort, benchAPI)

func sdbGet() *sql.DB {
func pdbGet() *sql.DB {
db, err := dtmimp.PooledDB(busi.BusiConf)
logger.FatalIfError(err)
return db
}

func txGet() *sql.Tx {
db := sdbGet()
db := pdbGet()
tx, err := db.Begin()
logger.FatalIfError(err)
return tx
Expand All @@ -49,7 +49,7 @@ func txGet() *sql.Tx {
func reloadData() {
time.Sleep(dtmsvr.UpdateBranchAsyncInterval * 2)
began := time.Now()
db := sdbGet()
db := pdbGet()
tables := []string{"dtm_busi.user_account", "dtm_busi.user_account_log", "dtm.trans_global", "dtm.trans_branch_op", "dtm_barrier.barrier"}
for _, t := range tables {
_, err := dtmimp.DBExec(db, fmt.Sprintf("truncate %s", t))
Expand All @@ -70,7 +70,7 @@ var mode string = ""
var sqls int = 1

func PrepareBenchDB() {
db := sdbGet()
db := pdbGet()
_, err := dtmimp.DBExec(db, "drop table if exists dtm_busi.user_account_log")
logger.FatalIfError(err)
_, err = dtmimp.DBExec(db, `create table if not exists dtm_busi.user_account_log (
Expand Down
11 changes: 10 additions & 1 deletion conf.sample.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
#####################################################################
### dtm can be run without any config.
### all config in this file is optional. the default value is as specified in each line
### all configs can be specified from env. for example:
### Store.MaxOpenConns can also specified from env: STORE_MAX_OPEN_CONNS
#####################################################################

# Store: # specify which engine to store trans status
# Driver: 'boltdb' # default store engine

Expand All @@ -19,10 +26,12 @@
# Password: 'mysecretpassword'
# Port: '5432'

### following connection config is for only Driver postgres/mysql
### following config is for only Driver postgres/mysql
# MaxOpenConns: 500
# MaxIdleConns: 500
# ConnMaxLifeTime 5 # default value is 5 (minutes)
# TransGlobalTable: 'dtm.trans_global'
# TransBranchOp: 'dtm.trans_branch_op'

### flollowing config is only for some Driver
# DataExpire: 604800 # Trans data will expire in 7 days. only for redis/boltdb.
Expand Down
2 changes: 1 addition & 1 deletion dtmcli/barrier.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func insertBarrier(tx DB, transType string, gid string, branchID string, op stri
if op == "" {
return 0, nil
}
sql := dtmimp.GetDBSpecial().GetInsertIgnoreTemplate("dtm_barrier.barrier(trans_type, gid, branch_id, op, barrier_id, reason) values(?,?,?,?,?,?)", "uniq_barrier")
sql := dtmimp.GetDBSpecial().GetInsertIgnoreTemplate(dtmimp.BarrierTableName+"(trans_type, gid, branch_id, op, barrier_id, reason) values(?,?,?,?,?,?)", "uniq_barrier")
return dtmimp.DBExec(tx, sql, transType, gid, branchID, op, barrierID, reason)
}

Expand Down
23 changes: 11 additions & 12 deletions dtmcli/dtmimp/trans_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ func (g *BranchIDGen) CurrentSubBranchID() string {

// TransOptions transaction options
type TransOptions struct {
WaitResult bool `json:"wait_result,omitempty" gorm:"-"`
TimeoutToFail int64 `json:"timeout_to_fail,omitempty" gorm:"-"` // for trans type: xa, tcc
RetryInterval int64 `json:"retry_interval,omitempty" gorm:"-"` // for trans type: msg saga xa tcc
WaitResult bool `json:"wait_result,omitempty" gorm:"-"`
TimeoutToFail int64 `json:"timeout_to_fail,omitempty" gorm:"-"` // for trans type: xa, tcc
RetryInterval int64 `json:"retry_interval,omitempty" gorm:"-"` // for trans type: msg saga xa tcc
PassthroughHeaders []string `json:"passthrough_headers,omitempty" gorm:"-"`
BranchHeaders map[string]string `json:"branch_headers,omitempty" gorm:"-"`
}

// TransBase base for all trans
Expand All @@ -62,18 +64,14 @@ type TransBase struct {
QueryPrepared string `json:"query_prepared,omitempty"` // used in MSG
}

// SetOptions set options
func (tb *TransBase) SetOptions(options *TransOptions) {
tb.TransOptions = *options
}

// NewTransBase new a TransBase
func NewTransBase(gid string, transType string, dtm string, branchID string) *TransBase {
return &TransBase{
Gid: gid,
TransType: transType,
BranchIDGen: BranchIDGen{BranchID: branchID},
Dtm: dtm,
Gid: gid,
TransType: transType,
BranchIDGen: BranchIDGen{BranchID: branchID},
Dtm: dtm,
TransOptions: TransOptions{PassthroughHeaders: PassthroughHeaders},
}
}

Expand Down Expand Up @@ -118,6 +116,7 @@ func TransRequestBranch(t *TransBase, body interface{}, branchID string, op stri
"trans_type": t.TransType,
"op": op,
}).
SetHeaders(t.BranchHeaders).
Post(url)
return resp, CheckResponse(resp, err)
}
8 changes: 7 additions & 1 deletion dtmcli/dtmimp/vars.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,19 @@ var MapFailure = map[string]interface{}{"dtm_result": ResultFailure}
// RestyClient the resty object
var RestyClient = resty.New()

// PassthroughHeaders will be passed to every sub-trans call
var PassthroughHeaders = []string{}

// BarrierTableName the table name of barrier table
var BarrierTableName = "dtm_barrier.barrier"

func init() {
// RestyClient.SetTimeout(3 * time.Second)
// RestyClient.SetRetryCount(2)
// RestyClient.SetRetryWaitTime(1 * time.Second)
RestyClient.OnBeforeRequest(func(c *resty.Client, r *resty.Request) error {
r.URL = MayReplaceLocalhost(r.URL)
logger.Debugf("requesting: %s %s %v %v", r.Method, r.URL, r.Body, r.QueryParam)
logger.Debugf("requesting: %s %s %s", r.Method, r.URL, MustMarshalString(r.Body))
return nil
})
RestyClient.OnAfterResponse(func(c *resty.Client, resp *resty.Response) error {
Expand Down
6 changes: 6 additions & 0 deletions dtmcli/tcc.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@ type TccGlobalFunc func(tcc *Tcc) (*resty.Response, error)
// gid global transaction ID
// tccFunc tcc事务函数,里面会定义全局事务的分支
func TccGlobalTransaction(dtm string, gid string, tccFunc TccGlobalFunc) (rerr error) {
return TccGlobalTransaction2(dtm, gid, func(t *Tcc) {}, tccFunc)
}

// TccGlobalTransaction2 new version of TccGlobalTransaction, add custom param
func TccGlobalTransaction2(dtm string, gid string, custom func(*Tcc), tccFunc TccGlobalFunc) (rerr error) {
tcc := &Tcc{TransBase: *dtmimp.NewTransBase(gid, "tcc", dtm, "")}
custom(tcc)
rerr = dtmimp.TransCallDtm(&tcc.TransBase, tcc, "prepare")
if rerr != nil {
return rerr
Expand Down
23 changes: 23 additions & 0 deletions dtmcli/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"

"github.com/dtm-labs/dtm/dtmcli/dtmimp"
"github.com/go-resty/resty/v2"
)

// MustGenGid generate a new gid
Expand Down Expand Up @@ -49,3 +50,25 @@ func SetXaSqlTimeoutMs(ms int) {
func GetXaSqlTimeoutMs() int {
return dtmimp.XaSqlTimeoutMs
}

func SetBarrierTableName(tablename string) {
dtmimp.BarrierTableName = tablename
}

// OnBeforeRequest add before request middleware
func OnBeforeRequest(middleware func(c *resty.Client, r *resty.Request) error) {
dtmimp.RestyClient.OnBeforeRequest(middleware)
}

// OnAfterResponse add after request middleware
func OnAfterResponse(middleware func(c *resty.Client, resp *resty.Response) error) {
dtmimp.RestyClient.OnAfterResponse(middleware)
}

// SetPassthroughHeaders experimental.
// apply to http header and grpc metadata
// dtm server will save these headers in trans creating request.
// and then passthrough them to sub-trans
func SetPassthroughHeaders(headers []string) {
dtmimp.PassthroughHeaders = headers
}
1 change: 1 addition & 0 deletions dtmcli/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ func TestTypes(t *testing.T) {
func TestXaSqlTimeout(t *testing.T) {
old := GetXaSqlTimeoutMs()
SetXaSqlTimeoutMs(old)
SetBarrierTableName(dtmimp.BarrierTableName) // just cover this func
}
12 changes: 9 additions & 3 deletions dtmcli/xa.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,17 @@ func (xc *XaClient) XaLocalTransaction(qs url.Values, xaFunc XaLocalFunc) error

// XaGlobalTransaction start a xa global transaction
func (xc *XaClient) XaGlobalTransaction(gid string, xaFunc XaGlobalFunc) (rerr error) {
xa := Xa{TransBase: *dtmimp.NewTransBase(gid, "xa", xc.XaClientBase.Server, "")}
return xc.XaGlobalTransaction2(gid, func(x *Xa) {}, xaFunc)
}

// XaGlobalTransaction start a xa global transaction
func (xc *XaClient) XaGlobalTransaction2(gid string, custom func(*Xa), xaFunc XaGlobalFunc) (rerr error) {
xa := &Xa{TransBase: *dtmimp.NewTransBase(gid, "xa", xc.XaClientBase.Server, "")}
custom(xa)
return xc.HandleGlobalTrans(&xa.TransBase, func(action string) error {
return dtmimp.TransCallDtm(&xa.TransBase, &xa, action)
return dtmimp.TransCallDtm(&xa.TransBase, xa, action)
}, func() error {
_, rerr := xaFunc(&xa)
_, rerr := xaFunc(xa)
return rerr
})
}
Expand Down
7 changes: 6 additions & 1 deletion dtmgrpc/dtmgimp/grpc_clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/dtm-labs/dtm/dtmcli/dtmimp"
"github.com/dtm-labs/dtm/dtmcli/logger"
"github.com/dtm-labs/dtm/dtmgrpc/dtmgpb"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc "google.golang.org/grpc"
)

Expand All @@ -35,6 +36,8 @@ func (cb rawCodec) Name() string { return "dtm_raw" }

var normalClients, rawClients sync.Map

var ClientInterceptors = []grpc.UnaryClientInterceptor{}

// MustGetDtmClient 1
func MustGetDtmClient(grpcServer string) dtmgpb.DtmClient {
return dtmgpb.NewDtmClient(MustGetGrpcConn(grpcServer, false))
Expand All @@ -59,7 +62,9 @@ func GetGrpcConn(grpcServer string, isRaw bool) (conn *grpc.ClientConn, rerr err
opts = grpc.WithDefaultCallOptions(grpc.ForceCodec(rawCodec{}))
}
logger.Debugf("grpc client connecting %s", grpcServer)
conn, rerr := grpc.Dial(grpcServer, grpc.WithInsecure(), grpc.WithUnaryInterceptor(GrpcClientLog), opts)
interceptors := append(ClientInterceptors, GrpcClientLog)
inOpt := grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient(interceptors...))
conn, rerr := grpc.Dial(grpcServer, inOpt, grpc.WithInsecure(), opts)
if rerr == nil {
clients.Store(grpcServer, conn)
v = conn
Expand Down
45 changes: 33 additions & 12 deletions dtmgrpc/dtmgimp/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ func DtmGrpcCall(s *dtmimp.TransBase, operation string) error {
Gid: s.Gid,
TransType: s.TransType,
TransOptions: &dtmgpb.DtmTransOptions{
WaitResult: s.WaitResult,
TimeoutToFail: s.TimeoutToFail,
RetryInterval: s.RetryInterval,
WaitResult: s.WaitResult,
TimeoutToFail: s.TimeoutToFail,
RetryInterval: s.RetryInterval,
PassthroughHeaders: s.PassthroughHeaders,
BranchHeaders: s.BranchHeaders,
},
QueryPrepared: s.QueryPrepared,
CustomedData: s.CustomData,
Expand All @@ -42,20 +44,29 @@ func DtmGrpcCall(s *dtmimp.TransBase, operation string) error {
}, &reply)
}

const mdpre string = "dtm-"
const dtmpre string = "dtm-"

// TransInfo2Ctx add trans info to grpc context
func TransInfo2Ctx(gid, transType, branchID, op, dtm string) context.Context {
md := metadata.Pairs(
mdpre+"gid", gid,
mdpre+"trans_type", transType,
mdpre+"branch_id", branchID,
mdpre+"op", op,
mdpre+"dtm", dtm,
dtmpre+"gid", gid,
dtmpre+"trans_type", transType,
dtmpre+"branch_id", branchID,
dtmpre+"op", op,
dtmpre+"dtm", dtm,
)
return metadata.NewOutgoingContext(context.Background(), md)
}

// Map2Kvs map to metadata kv
func Map2Kvs(m map[string]string) []string {
kvs := []string{}
for k, v := range m {
kvs = append(kvs, k, v)
}
return kvs
}

// LogDtmCtx logout dtm info in context metadata
func LogDtmCtx(ctx context.Context) {
tb := TransBaseFromGrpc(ctx)
Expand All @@ -64,8 +75,12 @@ func LogDtmCtx(ctx context.Context) {
}
}

func dtmGet(md metadata.MD, key string) string {
return mdGet(md, dtmpre+key)
}

func mdGet(md metadata.MD, key string) string {
v := md.Get(mdpre + key)
v := md.Get(key)
if len(v) == 0 {
return ""
}
Expand All @@ -75,7 +90,13 @@ func mdGet(md metadata.MD, key string) string {
// TransBaseFromGrpc get trans base info from a context metadata
func TransBaseFromGrpc(ctx context.Context) *dtmimp.TransBase {
md, _ := metadata.FromIncomingContext(ctx)
tb := dtmimp.NewTransBase(mdGet(md, "gid"), mdGet(md, "trans_type"), mdGet(md, "dtm"), mdGet(md, "branch_id"))
tb.Op = mdGet(md, "op")
tb := dtmimp.NewTransBase(dtmGet(md, "gid"), dtmGet(md, "trans_type"), dtmGet(md, "dtm"), dtmGet(md, "branch_id"))
tb.Op = dtmGet(md, "op")
return tb
}

// GetMetaFromContext get header from context
func GetMetaFromContext(ctx context.Context, name string) string {
md, _ := metadata.FromIncomingContext(ctx)
return mdGet(md, name)
}
Loading

0 comments on commit 25b7317

Please sign in to comment.