diff --git a/gremlin-go/driver/client.go b/gremlin-go/driver/client.go index bbf138496a..b014e09963 100644 --- a/gremlin-go/driver/client.go +++ b/gremlin-go/driver/client.go @@ -22,7 +22,6 @@ package gremlingo import ( "crypto/tls" "runtime" - "sync" "time" "golang.org/x/text/language" @@ -62,7 +61,7 @@ type Client struct { connections connectionPool session string connectionSettings *connectionSettings - protocol protocol + httpProtocol *httpProtocol } // NewClient creates a Client and configures it with the given parameters. During creation of the Client, a connection @@ -71,7 +70,6 @@ type Client struct { func NewClient(url string, configurations ...func(settings *ClientSettings)) (*Client, error) { settings := &ClientSettings{ TraversalSource: "g", - TransporterType: Http, LogVerbosity: Info, Logger: &defaultLogger{}, Language: language.English, @@ -115,16 +113,8 @@ func NewClient(url string, configurations ...func(settings *ClientSettings)) (*C settings.MaximumConcurrentConnections, settings.MaximumConcurrentConnections) settings.InitialConcurrentConnections = settings.MaximumConcurrentConnections } - pool, err := newLoadBalancingPool(url, logHandler, connSettings, settings.NewConnectionThreshold, - settings.MaximumConcurrentConnections, settings.InitialConcurrentConnections) - if err != nil { - if err != nil { - logHandler.logf(Error, logErrorGeneric, "NewClient", err.Error()) - } - return nil, err - } - prot, err := newHttpProtocol(logHandler, url, connSettings) + httpProt, err := newHttpProtocol(logHandler, url, connSettings) if err != nil { return nil, err } @@ -134,10 +124,10 @@ func NewClient(url string, configurations ...func(settings *ClientSettings)) (*C traversalSource: settings.TraversalSource, logHandler: logHandler, transporterType: settings.TransporterType, - connections: pool, + connections: nil, session: "", connectionSettings: connSettings, - protocol: prot, + httpProtocol: httpProt, } return client, nil @@ -155,7 +145,7 @@ func (client *Client) Close() { client.session = "" } client.logHandler.logf(Info, closeClient, client.url) - client.connections.close() + //client.connections.close() } func (client *Client) errorCallback() { @@ -167,17 +157,7 @@ func (client *Client) SubmitWithOptions(traversalString string, requestOptions R client.logHandler.logf(Debug, submitStartedString, traversalString) request := makeStringRequest(traversalString, client.traversalSource, client.session, requestOptions) - // write and send request - err := client.protocol.write(&request) - if err != nil { - return nil, err - } - results := &synchronizedMap{map[string]ResultSet{}, sync.Mutex{}} - rs := newChannelResultSet(request.requestID.String(), results) - results.store(request.requestID.String(), rs) - // read and handle response - client.protocol.readLoop(results, client.errorCallback) - // return the ResultSet from which the caller can obtain result(s) + rs, err := client.httpProtocol.send(&request) return rs, err } diff --git a/gremlin-go/driver/connection.go b/gremlin-go/driver/connection.go index cf08c750b6..a721044340 100644 --- a/gremlin-go/driver/connection.go +++ b/gremlin-go/driver/connection.go @@ -108,7 +108,7 @@ func createConnection(url string, logHandler *logHandler, connSettings *connecti initialized, } logHandler.log(Info, connectConnection) - protocol, err := newHttpProtocol(logHandler, url, connSettings) + protocol, err := newGremlinServerWSProtocol(logHandler, Gorilla, url, connSettings, conn.results, conn.errorCallback) if err != nil { logHandler.logf(Warning, failedConnection) conn.state = closedDueToError diff --git a/gremlin-go/driver/connection_test.go b/gremlin-go/driver/connection_test.go index bd14da458b..f41bfe8d4d 100644 --- a/gremlin-go/driver/connection_test.go +++ b/gremlin-go/driver/connection_test.go @@ -608,24 +608,27 @@ func TestConnection(t *testing.T) { assert.NotNil(t, client) defer client.Close() - resultSet, err := client.Submit("g.V().count()") - assert.Nil(t, err) - assert.NotNil(t, resultSet) - result, ok, err := resultSet.One() - assert.Nil(t, err) - assert.True(t, ok) - assert.NotNil(t, result) - _, _ = fmt.Fprintf(os.Stdout, "Received result : %s\n", result) + var wg sync.WaitGroup + + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + resultSet, err := client.Submit("g.V().count().as('c').math('c + " + strconv.Itoa(i) + "')") + assert.Nil(t, err) + assert.NotNil(t, resultSet) + result, ok, err := resultSet.One() + assert.Nil(t, err) + assert.True(t, ok) + assert.NotNil(t, result) + c, err := result.GetInt() + assert.Equal(t, 6+i, c) + _, _ = fmt.Fprintf(os.Stdout, "Received result : %s\n", result) + }(i) + } + + wg.Wait() - // submit 2nd request - resultSet, err = client.Submit("g.V().count()") - assert.Nil(t, err) - assert.NotNil(t, resultSet) - result, ok, err = resultSet.One() - assert.Nil(t, err) - assert.True(t, ok) - assert.NotNil(t, result) - _, _ = fmt.Fprintf(os.Stdout, "Received result : %s\n", result) // //g := cloneGraphTraversalSource(&Graph{}, NewBytecode(nil), nil) //b := g.V().Count().Bytecode diff --git a/gremlin-go/driver/httpProtocol.go b/gremlin-go/driver/httpProtocol.go index c6d50cf667..4f3253d995 100644 --- a/gremlin-go/driver/httpProtocol.go +++ b/gremlin-go/driver/httpProtocol.go @@ -23,56 +23,87 @@ import ( "encoding/base64" "fmt" "net/http" + "os" + "sync" ) type httpProtocol struct { - *protocolBase + serializer serializer + logHandler *logHandler + url string + connSettings *connectionSettings +} - serializer serializer - logHandler *logHandler +func newHttpProtocol(handler *logHandler, url string, connSettings *connectionSettings) (*httpProtocol, error) { + httpProt := &httpProtocol{ + serializer: newGraphBinarySerializer(handler), + logHandler: handler, + url: url, + connSettings: connSettings, + } + return httpProt, nil } -// waits for response, deserializes and processes results -// function name is readLoop but is not actually a loop - just keeping the name due to the protocol interface for now -func (protocol *httpProtocol) readLoop(resultSets *synchronizedMap, errorCallback func()) { - msg, err := protocol.transporter.Read() +func (protocol *httpProtocol) send(request *request) (ResultSet, error) { + // TODO remove need for result set container map + results := &synchronizedMap{map[string]ResultSet{}, sync.Mutex{}} + rs := newChannelResultSet(request.requestID.String(), results) + results.store(request.requestID.String(), rs) - // Deserialize response message - fmt.Println("Deserializing response") - resp, err := protocol.serializer.deserializeMessage(msg) + fmt.Println("Serializing request") + bytes, err := protocol.serializer.serializeMessage(request) if err != nil { - protocol.logHandler.logf(Error, logErrorGeneric, "httpReadLoop()", err.Error()) - readErrorHandler(resultSets, errorCallback, err, protocol.logHandler) - return + return nil, err } - fmt.Println("Deserialized response") - // TODO we should not need to use response/request ids to correlate responses with requests anymore after moving from ws to http - // but for simplicity of http POC we are just setting the responseId to the requestId here - resp.responseID = protocol.request.requestID - err = protocol.responseHandler(resultSets, resp) - if err != nil { - readErrorHandler(resultSets, errorCallback, err, protocol.logHandler) - return - } + transport := NewHttpTransporter(protocol.url, protocol.connSettings) + + // async send request and wait for response + transport.wg.Add(1) + go func() { + defer transport.wg.Done() + err := transport.Write(bytes) + if err != nil { + protocol.errorCallback() + } + }() + + // async receive response msg data + transport.wg.Add(1) + go func() { + defer transport.wg.Done() + msg, err := transport.Read() + if err != nil { + protocol.errorCallback() + } else { + protocol.receive(rs, msg, protocol.errorCallback) + } + err = transport.Close() + }() + + return rs, err } -func newHttpProtocol(handler *logHandler, url string, connSettings *connectionSettings) (protocol, error) { - transport, err := getTransportLayer(Http, url, connSettings, handler) +func (protocol *httpProtocol) receive(rs ResultSet, msg []byte, errorCallback func()) { + fmt.Println("Deserializing response") + resp, err := protocol.serializer.deserializeMessage(msg) if err != nil { - return nil, err + protocol.logHandler.logf(Error, logErrorGeneric, "receive()", err.Error()) + rs.Close() + return } - gremlinProtocol := &httpProtocol{ - protocolBase: &protocolBase{transporter: transport}, - serializer: newGraphBinarySerializer(handler), - logHandler: handler, + fmt.Println("Handling response") + err = protocol.handleResponse(rs, resp) + if err != nil { + protocol.logHandler.logf(Error, logErrorGeneric, "receive()", err.Error()) + rs.Close() + errorCallback() + return } - return gremlinProtocol, nil } -// loads results into the response set from the response -func (protocol *httpProtocol) responseHandler(resultSets *synchronizedMap, response response) error { +func (protocol *httpProtocol) handleResponse(rs ResultSet, response response) error { fmt.Println("Handling response") // TODO http specific response handling - below is just copy-pasted from web socket implementation for now @@ -80,31 +111,31 @@ func (protocol *httpProtocol) responseHandler(resultSets *synchronizedMap, respo responseID, statusCode, metadata, data := response.responseID, response.responseStatus.code, response.responseResult.meta, response.responseResult.data responseIDString := responseID.String() - if resultSets.load(responseIDString) == nil { + if rs == nil { return newError(err0501ResponseHandlerResultSetNotCreatedError) } if aggregateTo, ok := metadata["aggregateTo"]; ok { - resultSets.load(responseIDString).setAggregateTo(aggregateTo.(string)) + rs.setAggregateTo(aggregateTo.(string)) } // Handle status codes appropriately. If status code is http.StatusPartialContent, we need to re-read data. if statusCode == http.StatusNoContent { - resultSets.load(responseIDString).addResult(&Result{make([]interface{}, 0)}) - resultSets.load(responseIDString).Close() + rs.addResult(&Result{make([]interface{}, 0)}) + rs.Close() protocol.logHandler.logf(Debug, readComplete, responseIDString) } else if statusCode == http.StatusOK { // Add data and status attributes to the ResultSet. - resultSets.load(responseIDString).addResult(&Result{data}) - resultSets.load(responseIDString).setStatusAttributes(response.responseStatus.attributes) - resultSets.load(responseIDString).Close() + rs.addResult(&Result{data}) + rs.setStatusAttributes(response.responseStatus.attributes) + rs.Close() protocol.logHandler.logf(Debug, readComplete, responseIDString) } else if statusCode == http.StatusPartialContent { // Add data to the ResultSet. - resultSets.load(responseIDString).addResult(&Result{data}) + rs.addResult(&Result{data}) } else if statusCode == http.StatusProxyAuthRequired || statusCode == authenticationFailed { // http status code 151 is not defined here, but corresponds with 403, i.e. authentication has failed. // Server has requested basic auth. - authInfo := protocol.transporter.getAuthInfo() + authInfo := protocol.getAuthInfo() if ok, username, password := authInfo.GetBasicAuth(); ok { authBytes := make([]byte, 0) authBytes = append(authBytes, 0) @@ -113,36 +144,31 @@ func (protocol *httpProtocol) responseHandler(resultSets *synchronizedMap, respo authBytes = append(authBytes, []byte(password)...) encoded := base64.StdEncoding.EncodeToString(authBytes) request := makeBasicAuthRequest(encoded) - err := protocol.write(&request) + // TODO retry + _, err := fmt.Fprintf(os.Stdout, "Skipping retry of failed request : %s\n", request.requestID) if err != nil { return err } } else { - resultSets.load(responseIDString).Close() + rs.Close() return newError(err0503ResponseHandlerAuthError, response.responseStatus, response.responseResult) } } else { newError := newError(err0502ResponseHandlerReadLoopError, response.responseStatus, statusCode) - resultSets.load(responseIDString).setError(newError) - resultSets.load(responseIDString).Close() + rs.setError(newError) + rs.Close() protocol.logHandler.logf(Error, logErrorGeneric, "httpProtocol.responseHandler()", newError.Error()) } return nil } -// serializes and sends the request -func (protocol *httpProtocol) write(request *request) error { - protocol.request = request - // TODO interceptors - fmt.Println("Serializing request") - bytes, err := protocol.serializer.serializeMessage(request) - if err != nil { - return err +func (protocol *httpProtocol) getAuthInfo() AuthInfoProvider { + if protocol.connSettings.authInfo == nil { + return NoopAuthInfo } - return protocol.transporter.Write(bytes) - + return protocol.connSettings.authInfo } -func (protocol *httpProtocol) close(wait bool) error { - return nil +func (protocol *httpProtocol) errorCallback() { + protocol.logHandler.log(Error, errorCallback) } diff --git a/gremlin-go/driver/httpTransporter.go b/gremlin-go/driver/httpTransporter.go index 2ebe089765..767c441f42 100644 --- a/gremlin-go/driver/httpTransporter.go +++ b/gremlin-go/driver/httpTransporter.go @@ -30,14 +30,16 @@ import ( "net/http" "net/url" "os" + "sync" ) type HttpTransporter struct { url string isClosed bool connSettings *connectionSettings - responseChannel chan []byte + responseChannel chan []byte // response channel needs to be per request, not per client client http.Client + wg *sync.WaitGroup } func NewHttpTransporter(url string, connSettings *connectionSettings) *HttpTransporter { @@ -53,20 +55,17 @@ func NewHttpTransporter(url string, connSettings *connectionSettings) *HttpTrans Timeout: connSettings.connectionTimeout, } + wg := &sync.WaitGroup{} + return &HttpTransporter{ url: url, connSettings: connSettings, responseChannel: make(chan []byte, writeChannelSizeDefault), client: c, + wg: wg, } } -func (transporter *HttpTransporter) Connect() (err error) { - // http transporter delegates connection management to the http client - // TODO verify that connections are being reused and cleaned up when appropriate - return -} - func (transporter *HttpTransporter) Write(data []byte) error { fmt.Println("Sending request message") u, err := url.Parse(transporter.url) @@ -127,13 +126,6 @@ func (transporter *HttpTransporter) Write(data []byte) error { return nil } -func (transporter *HttpTransporter) getAuthInfo() AuthInfoProvider { - if transporter.connSettings.authInfo == nil { - return NoopAuthInfo - } - return transporter.connSettings.authInfo -} - func (transporter *HttpTransporter) Read() ([]byte, error) { fmt.Println("Reading from responseChannel") msg, ok := <-transporter.responseChannel @@ -153,7 +145,3 @@ func (transporter *HttpTransporter) Close() (err error) { } return } - -func (transporter *HttpTransporter) IsClosed() bool { - return transporter.isClosed -} diff --git a/gremlin-go/driver/protocol.go b/gremlin-go/driver/protocol.go index 5ebb194aba..2f29051831 100644 --- a/gremlin-go/driver/protocol.go +++ b/gremlin-go/driver/protocol.go @@ -37,7 +37,6 @@ const authenticationFailed = uint16(151) type protocolBase struct { protocol - request *request transporter transporter } diff --git a/gremlin-go/driver/transporterFactory.go b/gremlin-go/driver/transporterFactory.go index 27f60cc9a2..6177120ec4 100644 --- a/gremlin-go/driver/transporterFactory.go +++ b/gremlin-go/driver/transporterFactory.go @@ -29,7 +29,6 @@ type TransporterType int const ( // Gorilla transport layer: github.com/gorilla/websocket Gorilla TransporterType = iota + 1 - Http TransporterType = iota + 1 ) func getTransportLayer(transporterType TransporterType, url string, connSettings *connectionSettings, logHandler *logHandler) (transporter, error) { @@ -43,10 +42,12 @@ func getTransportLayer(transporterType TransporterType, url string, connSettings writeChannel: make(chan []byte, writeChannelSizeDefault), wg: &sync.WaitGroup{}, } - case Http: - transporter = NewHttpTransporter(url, connSettings) default: return nil, newError(err0801GetTransportLayerNoTypeError) } + err := transporter.Connect() + if err != nil { + return nil, err + } return transporter, nil }