Skip to content

Commit

Permalink
Fix concurrency by creating new http transport for each query. Remove…
Browse files Browse the repository at this point in the history
…d unnecessary code.
  • Loading branch information
andreachild committed Jan 10, 2025
1 parent 82bc36c commit 8f1e2a8
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 123 deletions.
32 changes: 6 additions & 26 deletions gremlin-go/driver/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ package gremlingo
import (
"crypto/tls"
"runtime"
"sync"
"time"

"golang.org/x/text/language"
Expand Down Expand Up @@ -62,7 +61,7 @@ type Client struct {
connections connectionPool
session string
connectionSettings *connectionSettings
protocol protocol
httpProtocol *httpProtocol
}

// NewClient creates a Client and configures it with the given parameters. During creation of the Client, a connection
Expand All @@ -71,7 +70,6 @@ type Client struct {
func NewClient(url string, configurations ...func(settings *ClientSettings)) (*Client, error) {
settings := &ClientSettings{
TraversalSource: "g",
TransporterType: Http,
LogVerbosity: Info,
Logger: &defaultLogger{},
Language: language.English,
Expand Down Expand Up @@ -115,16 +113,8 @@ func NewClient(url string, configurations ...func(settings *ClientSettings)) (*C
settings.MaximumConcurrentConnections, settings.MaximumConcurrentConnections)
settings.InitialConcurrentConnections = settings.MaximumConcurrentConnections
}
pool, err := newLoadBalancingPool(url, logHandler, connSettings, settings.NewConnectionThreshold,
settings.MaximumConcurrentConnections, settings.InitialConcurrentConnections)
if err != nil {
if err != nil {
logHandler.logf(Error, logErrorGeneric, "NewClient", err.Error())
}
return nil, err
}

prot, err := newHttpProtocol(logHandler, url, connSettings)
httpProt, err := newHttpProtocol(logHandler, url, connSettings)
if err != nil {
return nil, err
}
Expand All @@ -134,10 +124,10 @@ func NewClient(url string, configurations ...func(settings *ClientSettings)) (*C
traversalSource: settings.TraversalSource,
logHandler: logHandler,
transporterType: settings.TransporterType,
connections: pool,
connections: nil,
session: "",
connectionSettings: connSettings,
protocol: prot,
httpProtocol: httpProt,
}

return client, nil
Expand All @@ -155,7 +145,7 @@ func (client *Client) Close() {
client.session = ""
}
client.logHandler.logf(Info, closeClient, client.url)
client.connections.close()
//client.connections.close()
}

func (client *Client) errorCallback() {
Expand All @@ -167,17 +157,7 @@ func (client *Client) SubmitWithOptions(traversalString string, requestOptions R
client.logHandler.logf(Debug, submitStartedString, traversalString)
request := makeStringRequest(traversalString, client.traversalSource, client.session, requestOptions)

// write and send request
err := client.protocol.write(&request)
if err != nil {
return nil, err
}
results := &synchronizedMap{map[string]ResultSet{}, sync.Mutex{}}
rs := newChannelResultSet(request.requestID.String(), results)
results.store(request.requestID.String(), rs)
// read and handle response
client.protocol.readLoop(results, client.errorCallback)
// return the ResultSet from which the caller can obtain result(s)
rs, err := client.httpProtocol.send(&request)
return rs, err
}

Expand Down
2 changes: 1 addition & 1 deletion gremlin-go/driver/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func createConnection(url string, logHandler *logHandler, connSettings *connecti
initialized,
}
logHandler.log(Info, connectConnection)
protocol, err := newHttpProtocol(logHandler, url, connSettings)
protocol, err := newGremlinServerWSProtocol(logHandler, Gorilla, url, connSettings, conn.results, conn.errorCallback)
if err != nil {
logHandler.logf(Warning, failedConnection)
conn.state = closedDueToError
Expand Down
37 changes: 20 additions & 17 deletions gremlin-go/driver/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -608,24 +608,27 @@ func TestConnection(t *testing.T) {
assert.NotNil(t, client)
defer client.Close()

resultSet, err := client.Submit("g.V().count()")
assert.Nil(t, err)
assert.NotNil(t, resultSet)
result, ok, err := resultSet.One()
assert.Nil(t, err)
assert.True(t, ok)
assert.NotNil(t, result)
_, _ = fmt.Fprintf(os.Stdout, "Received result : %s\n", result)
var wg sync.WaitGroup

for i := 0; i < 5; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
resultSet, err := client.Submit("g.V().count().as('c').math('c + " + strconv.Itoa(i) + "')")
assert.Nil(t, err)
assert.NotNil(t, resultSet)
result, ok, err := resultSet.One()
assert.Nil(t, err)
assert.True(t, ok)
assert.NotNil(t, result)
c, err := result.GetInt()
assert.Equal(t, 6+i, c)
_, _ = fmt.Fprintf(os.Stdout, "Received result : %s\n", result)
}(i)
}

wg.Wait()

// submit 2nd request
resultSet, err = client.Submit("g.V().count()")
assert.Nil(t, err)
assert.NotNil(t, resultSet)
result, ok, err = resultSet.One()
assert.Nil(t, err)
assert.True(t, ok)
assert.NotNil(t, result)
_, _ = fmt.Fprintf(os.Stdout, "Received result : %s\n", result)
//
//g := cloneGraphTraversalSource(&Graph{}, NewBytecode(nil), nil)
//b := g.V().Count().Bytecode
Expand Down
140 changes: 83 additions & 57 deletions gremlin-go/driver/httpProtocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,88 +23,119 @@ import (
"encoding/base64"
"fmt"
"net/http"
"os"
"sync"
)

type httpProtocol struct {
*protocolBase
serializer serializer
logHandler *logHandler
url string
connSettings *connectionSettings
}

serializer serializer
logHandler *logHandler
func newHttpProtocol(handler *logHandler, url string, connSettings *connectionSettings) (*httpProtocol, error) {
httpProt := &httpProtocol{
serializer: newGraphBinarySerializer(handler),
logHandler: handler,
url: url,
connSettings: connSettings,
}
return httpProt, nil
}

// waits for response, deserializes and processes results
// function name is readLoop but is not actually a loop - just keeping the name due to the protocol interface for now
func (protocol *httpProtocol) readLoop(resultSets *synchronizedMap, errorCallback func()) {
msg, err := protocol.transporter.Read()
func (protocol *httpProtocol) send(request *request) (ResultSet, error) {
// TODO remove need for result set container map
results := &synchronizedMap{map[string]ResultSet{}, sync.Mutex{}}
rs := newChannelResultSet(request.requestID.String(), results)
results.store(request.requestID.String(), rs)

// Deserialize response message
fmt.Println("Deserializing response")
resp, err := protocol.serializer.deserializeMessage(msg)
fmt.Println("Serializing request")
bytes, err := protocol.serializer.serializeMessage(request)
if err != nil {
protocol.logHandler.logf(Error, logErrorGeneric, "httpReadLoop()", err.Error())
readErrorHandler(resultSets, errorCallback, err, protocol.logHandler)
return
return nil, err
}

fmt.Println("Deserialized response")
// TODO we should not need to use response/request ids to correlate responses with requests anymore after moving from ws to http
// but for simplicity of http POC we are just setting the responseId to the requestId here
resp.responseID = protocol.request.requestID
err = protocol.responseHandler(resultSets, resp)
if err != nil {
readErrorHandler(resultSets, errorCallback, err, protocol.logHandler)
return
}
transport := NewHttpTransporter(protocol.url, protocol.connSettings)

// async send request and wait for response
transport.wg.Add(1)
go func() {
defer transport.wg.Done()
err := transport.Write(bytes)
if err != nil {
protocol.errorCallback()
}
}()

// async receive response msg data
transport.wg.Add(1)
go func() {
defer transport.wg.Done()
msg, err := transport.Read()
if err != nil {
protocol.errorCallback()
} else {
protocol.receive(rs, msg, protocol.errorCallback)
}
err = transport.Close()
}()

return rs, err
}

func newHttpProtocol(handler *logHandler, url string, connSettings *connectionSettings) (protocol, error) {
transport, err := getTransportLayer(Http, url, connSettings, handler)
func (protocol *httpProtocol) receive(rs ResultSet, msg []byte, errorCallback func()) {
fmt.Println("Deserializing response")
resp, err := protocol.serializer.deserializeMessage(msg)
if err != nil {
return nil, err
protocol.logHandler.logf(Error, logErrorGeneric, "receive()", err.Error())
rs.Close()
return
}

gremlinProtocol := &httpProtocol{
protocolBase: &protocolBase{transporter: transport},
serializer: newGraphBinarySerializer(handler),
logHandler: handler,
fmt.Println("Handling response")
err = protocol.handleResponse(rs, resp)
if err != nil {
protocol.logHandler.logf(Error, logErrorGeneric, "receive()", err.Error())
rs.Close()
errorCallback()
return
}
return gremlinProtocol, nil
}

// loads results into the response set from the response
func (protocol *httpProtocol) responseHandler(resultSets *synchronizedMap, response response) error {
func (protocol *httpProtocol) handleResponse(rs ResultSet, response response) error {
fmt.Println("Handling response")

// TODO http specific response handling - below is just copy-pasted from web socket implementation for now

responseID, statusCode, metadata, data := response.responseID, response.responseStatus.code,
response.responseResult.meta, response.responseResult.data
responseIDString := responseID.String()
if resultSets.load(responseIDString) == nil {
if rs == nil {
return newError(err0501ResponseHandlerResultSetNotCreatedError)
}
if aggregateTo, ok := metadata["aggregateTo"]; ok {
resultSets.load(responseIDString).setAggregateTo(aggregateTo.(string))
rs.setAggregateTo(aggregateTo.(string))
}

// Handle status codes appropriately. If status code is http.StatusPartialContent, we need to re-read data.
if statusCode == http.StatusNoContent {
resultSets.load(responseIDString).addResult(&Result{make([]interface{}, 0)})
resultSets.load(responseIDString).Close()
rs.addResult(&Result{make([]interface{}, 0)})
rs.Close()
protocol.logHandler.logf(Debug, readComplete, responseIDString)
} else if statusCode == http.StatusOK {
// Add data and status attributes to the ResultSet.
resultSets.load(responseIDString).addResult(&Result{data})
resultSets.load(responseIDString).setStatusAttributes(response.responseStatus.attributes)
resultSets.load(responseIDString).Close()
rs.addResult(&Result{data})
rs.setStatusAttributes(response.responseStatus.attributes)
rs.Close()
protocol.logHandler.logf(Debug, readComplete, responseIDString)
} else if statusCode == http.StatusPartialContent {
// Add data to the ResultSet.
resultSets.load(responseIDString).addResult(&Result{data})
rs.addResult(&Result{data})
} else if statusCode == http.StatusProxyAuthRequired || statusCode == authenticationFailed {
// http status code 151 is not defined here, but corresponds with 403, i.e. authentication has failed.
// Server has requested basic auth.
authInfo := protocol.transporter.getAuthInfo()
authInfo := protocol.getAuthInfo()
if ok, username, password := authInfo.GetBasicAuth(); ok {
authBytes := make([]byte, 0)
authBytes = append(authBytes, 0)
Expand All @@ -113,36 +144,31 @@ func (protocol *httpProtocol) responseHandler(resultSets *synchronizedMap, respo
authBytes = append(authBytes, []byte(password)...)
encoded := base64.StdEncoding.EncodeToString(authBytes)
request := makeBasicAuthRequest(encoded)
err := protocol.write(&request)
// TODO retry
_, err := fmt.Fprintf(os.Stdout, "Skipping retry of failed request : %s\n", request.requestID)
if err != nil {
return err
}
} else {
resultSets.load(responseIDString).Close()
rs.Close()
return newError(err0503ResponseHandlerAuthError, response.responseStatus, response.responseResult)
}
} else {
newError := newError(err0502ResponseHandlerReadLoopError, response.responseStatus, statusCode)
resultSets.load(responseIDString).setError(newError)
resultSets.load(responseIDString).Close()
rs.setError(newError)
rs.Close()
protocol.logHandler.logf(Error, logErrorGeneric, "httpProtocol.responseHandler()", newError.Error())
}
return nil
}

// serializes and sends the request
func (protocol *httpProtocol) write(request *request) error {
protocol.request = request
// TODO interceptors
fmt.Println("Serializing request")
bytes, err := protocol.serializer.serializeMessage(request)
if err != nil {
return err
func (protocol *httpProtocol) getAuthInfo() AuthInfoProvider {
if protocol.connSettings.authInfo == nil {
return NoopAuthInfo
}
return protocol.transporter.Write(bytes)

return protocol.connSettings.authInfo
}

func (protocol *httpProtocol) close(wait bool) error {
return nil
func (protocol *httpProtocol) errorCallback() {
protocol.logHandler.log(Error, errorCallback)
}
Loading

0 comments on commit 8f1e2a8

Please sign in to comment.