Skip to content

Commit

Permalink
feat(sdk): use a persistent session object for GraphQL requests (wand…
Browse files Browse the repository at this point in the history
…b#5075)

* Make local (altered) copy of RequestHTTPTransport as GraphQLSession

* Use GraphQLSession instead of RequestHTTPTransport

* Add mypy ignores for vendored gql, graphql-core

* Remove unused '# type: ignore' declarations for wandb_{gql,graphql}

* Use new yea-wandb branch in tox.ini
  • Loading branch information
moredatarequired authored Mar 17, 2023
1 parent a4aa143 commit 6d3afec
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 11 deletions.
6 changes: 6 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -476,3 +476,9 @@ ignore_missing_imports = True

[mypy-wandb.vendor.*]
ignore_missing_imports = True

[mypy-wandb_gql.*]
ignore_missing_imports = True

[mypy-wandb_graphql.*]
ignore_missing_imports = True
6 changes: 3 additions & 3 deletions tests/pytest_tests/unit_tests_old/utils/mock_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,10 @@ def mock_server(mocker):
mock = RequestsMock(app, ctx)
# We mock out all requests libraries, couldn't find a way to mock the core lib
sdk_path = "wandb.sdk"
# From previous wandb_gql transport library.
mocker.patch("wandb_gql.transport.requests.requests", mock)

mocker.patch("wandb.wandb_sdk.lib.gql_request.requests", mock)
mocker.patch("wandb.wandb_sdk.internal.file_stream.requests", mock)
mocker.patch("wandb.wandb_sdk.internal.internal_api.requests", mock)
mocker.patch("wandb.wandb_sdk.internal.update.requests", mock)
Expand Down Expand Up @@ -372,7 +375,6 @@ def to_dict(self):


class SnoopRelay:

_inject_count: int
_inject_time: float

Expand All @@ -383,7 +385,6 @@ def __init__(self):
def relay(self, func):
@functools.wraps(func)
def wrapper(*args, **kwargs):

# Normal mockserver mode, disable live relay and call next function
if not os.environ.get("MOCKSERVER_RELAY"):
return func(*args, **kwargs)
Expand Down Expand Up @@ -1659,7 +1660,6 @@ def graphql():
c["alerts"].append(adict)
return {"data": {"notifyScriptableRunAlert": {"success": True}}}
if "query SearchUsers" in body["query"]:

return {
"data": {
"users": {
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ envlist=

[base]
setenv =
YEA_WANDB_VERSION = 0.9.5
YEA_WANDB_VERSION = 0.9.6
; Setting low network buffer so that we exercise flow control logic
WANDB__NETWORK_BUFFER = 1000

Expand Down
4 changes: 2 additions & 2 deletions wandb/apis/public.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import requests
from wandb_gql import Client, gql
from wandb_gql.client import RetryError
from wandb_gql.transport.requests import RequestsHTTPTransport

import wandb
from wandb import __version__, env, util
Expand All @@ -58,6 +57,7 @@
apply_patch,
)
from wandb.sdk.lib import filesystem, ipython, retry, runid
from wandb.sdk.lib.gql_request import GraphQLSession
from wandb.sdk.lib.hashutil import b64_to_hex_id, hex_to_b64_id, md5_file_b64

if TYPE_CHECKING:
Expand Down Expand Up @@ -424,7 +424,7 @@ def __init__(
self._default_entity = None
self._timeout = timeout if timeout is not None else self._HTTP_TIMEOUT
self._base_client = Client(
transport=RequestsHTTPTransport(
transport=GraphQLSession(
headers={"User-Agent": self.user_agent, "Use-Admin-Privileges": "true"},
use_json=True,
# this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
Expand Down
8 changes: 4 additions & 4 deletions wandb/sdk/internal/internal_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@
import click
import requests
import yaml
from wandb_gql import Client, gql # type: ignore
from wandb_gql.client import RetryError # type: ignore
from wandb_gql.transport.requests import RequestsHTTPTransport # type: ignore
from wandb_gql import Client, gql
from wandb_gql.client import RetryError

import wandb
from wandb import __version__, env, util
from wandb.apis.normalize import normalize_exceptions
from wandb.errors import CommError, UsageError
from wandb.integration.sagemaker import parse_sm_secrets
from wandb.old.settings import Settings
from wandb.sdk.lib.gql_request import GraphQLSession
from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64

from ..lib import retry
Expand Down Expand Up @@ -195,7 +195,7 @@ def __init__(
"heartbeat_seconds": 30,
}
self.client = Client(
transport=RequestsHTTPTransport(
transport=GraphQLSession(
headers={
"User-Agent": self.user_agent,
"X-WANDB-USERNAME": env.get_username(env=self._environ),
Expand Down
62 changes: 62 additions & 0 deletions wandb/sdk/lib/gql_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""A simple GraphQL client for sending queries and mutations.
Note: This was originally wandb/vendor/gql-0.2.0/wandb_gql/transport/requests.py.
The only substantial change is that a single requests.Session object is re-used for every request.
"""

from typing import Any, Callable, Dict, Optional, Tuple, Union

import requests
from wandb_gql.transport.http import HTTPTransport
from wandb_graphql.execution import ExecutionResult
from wandb_graphql.language import ast
from wandb_graphql.language.printer import print_ast


class GraphQLSession(HTTPTransport):
    """GraphQL HTTP transport that posts every request through one ``requests.Session``."""

    def __init__(
        self,
        url: str,
        auth: Optional[Union[Tuple[str, str], Callable]] = None,
        use_json: bool = False,
        timeout: Optional[Union[int, float]] = None,
        **kwargs: Any,
    ) -> None:
        """Set up the transport and the session it re-uses.

        Args:
            url: The GraphQL URL.
            auth: Auth tuple or callable for Basic/Digest/Custom HTTP auth;
                stored on the session's ``auth`` attribute.
            use_json: When true the payload is passed to ``post`` as ``json=``,
                otherwise as ``data=``.
            timeout: Default timeout applied when ``execute`` is not given a
                truthy one (Default: None).
            **kwargs: Forwarded unchanged to ``HTTPTransport.__init__``.
        """
        super().__init__(url, **kwargs)

        self.use_json = use_json
        self.default_timeout = timeout

        # One session for the lifetime of the transport, created here and
        # used by every execute() call.
        session = requests.Session()
        session.auth = auth
        self.session = session

    def execute(
        self,
        document: ast.Node,
        variable_values: Optional[Dict] = None,
        timeout: Optional[Union[int, float]] = None,
    ) -> ExecutionResult:
        """Send one GraphQL document and wrap the server's reply.

        Args:
            document: Parsed GraphQL AST; it is printed back to a query string.
            variable_values: Variables for the query; a falsy value is sent as ``{}``.
            timeout: Per-call timeout; a falsy value falls back to the default
                given at construction.

        Returns:
            An ``ExecutionResult`` built from the ``data`` and ``errors`` entries
            of the decoded JSON response.

        Raises:
            RuntimeError: If the response carries neither ``data`` nor ``errors``.

        Any exception raised by ``raise_for_status()`` on the response is
        propagated unchanged.
        """
        if variable_values:
            variables = variable_values
        else:
            variables = {}
        body = {"query": print_ast(document), "variables": variables}

        # The keyword name decides how the body is handed to post().
        if self.use_json:
            body_kwarg = "json"
        else:
            body_kwarg = "data"

        # NOTE(review): self.url, self.headers and self.cookies are not set in
        # this class; presumably HTTPTransport.__init__ provides them — confirm.
        response = self.session.post(
            self.url,
            headers=self.headers,
            cookies=self.cookies,
            timeout=timeout or self.default_timeout,
            **{body_kwarg: body},
        )
        response.raise_for_status()

        decoded = response.json()
        data = decoded.get("data")
        errors = decoded.get("errors")
        if data is None and errors is None:
            raise RuntimeError(f"Received non-compatible response: {decoded}")
        return ExecutionResult(data=data, errors=errors)
2 changes: 1 addition & 1 deletion wandb/sdk/verify/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import click
import requests
from pkg_resources import parse_version
from wandb_gql import gql # type: ignore
from wandb_gql import gql

import wandb
from wandb.sdk.lib import runid
Expand Down

0 comments on commit 6d3afec

Please sign in to comment.