diff --git a/src/backend/txn_test.go b/src/backend/txn_test.go index ba31fc52..db9096e1 100644 --- a/src/backend/txn_test.go +++ b/src/backend/txn_test.go @@ -284,13 +284,13 @@ func TestTxnExecuteReplicaError(t *testing.T) { func TestTxnExecuteStreamFetch(t *testing.T) { defer leaktest.Check(t)() log := xlog.NewStdLog(xlog.Level(xlog.PANIC)) - fakedb, txnMgr, backends, addrs, cleanup := MockTxnMgr(log, 2) + fakedb, txnMgr, backends, addrs, cleanup := MockTxnMgrWithReplica(log, 2) defer cleanup() querys := []xcontext.QueryTuple{ - xcontext.QueryTuple{Query: "select * from node1", Backend: addrs[0]}, - xcontext.QueryTuple{Query: "select * from node2", Backend: addrs[1]}, - xcontext.QueryTuple{Query: "select * from node3", Backend: addrs[1]}, + {Query: "select * from node1", Backend: addrs[0]}, + {Query: "select * from node2", Backend: addrs[1]}, + {Query: "select * from node3", Backend: addrs[1]}, } result11 := &sqltypes.Result{ @@ -360,6 +360,33 @@ func TestTxnExecuteStreamFetch(t *testing.T) { assert.Equal(t, want, got) } + // loadbalance=1. + { + fakedb.AddQueryStream(querys[0].Query, result11) + fakedb.AddQueryStream(querys[1].Query, result12) + fakedb.AddQueryStream(querys[2].Query, result12) + + txn, err := txnMgr.CreateTxn(backends) + assert.Nil(t, err) + defer txn.Finish() + txn.SetIsExecOnRep(true) + + rctx := &xcontext.RequestContext{ + Querys: querys, + } + + callbackQr := &sqltypes.Result{} + err = txn.ExecuteStreamFetch(rctx, func(qr *sqltypes.Result) error { + callbackQr.AppendResult(qr) + return nil + }, 1024*1024) + assert.Nil(t, err) + + want := len(result11.Rows) + 2*len(result12.Rows) + got := len(callbackQr.Rows) + assert.Equal(t, want, got) + } + // execute error. { fakedb.AddQueryError(querys[0].Query, errors.New("mock.stream.query.error")) diff --git a/src/proxy/execute.go b/src/proxy/execute.go index e5fda2ab..eb1c96a9 100644 --- a/src/proxy/execute.go +++ b/src/proxy/execute.go @@ -166,6 +166,7 @@ func (spanner *Spanner) executeWithTimeout(session *driver.Session, database str // ExecuteStreamFetch used to execute a stream fetch query. func (spanner *Spanner) ExecuteStreamFetch(session *driver.Session, database string, query string, node sqlparser.Statement, callback func(qr *sqltypes.Result) error) error { log := spanner.log + conf := spanner.conf router := spanner.router scatter := spanner.scatter sessions := spanner.sessions @@ -178,6 +179,8 @@ func (spanner *Spanner) ExecuteStreamFetch(session *driver.Session, database str } defer txn.Finish() + txn.SetIsExecOnRep(conf.Proxy.LoadBalance != 0) + // binding. sessions.TxnBinding(session, txn, node, query) defer sessions.TxnUnBinding(session)