Commit
update retries to use defaults and ensure that a timeout or deadline is set
mikealfare committed Nov 20, 2024
1 parent 89e568b commit 9aecdc7
Showing 5 changed files with 51 additions and 55 deletions.
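The core pattern in this commit: derive retries from the BigQuery client library's `DEFAULT_JOB_RETRY` rather than hand-rolled constants, and always bound them with a timeout or deadline. A minimal standalone sketch of that pattern (the variable names here are illustrative, not taken from the diff):

    # a minimal sketch: derive retries from DEFAULT_JOB_RETRY instead of
    # hand-rolled constants, and always bound them with a timeout or deadline
    from google.cloud.bigquery.retry import DEFAULT_JOB_RETRY

    _MINUTE = 60.0

    # keep the library's predicate and backoff, but cap total retry time
    job_retry = DEFAULT_JOB_RETRY.with_timeout(5 * _MINUTE)

    # apply a user-supplied deadline only when one exists; passing None
    # through would leave the retry unbounded
    user_deadline = None  # e.g. credentials.job_retry_deadline_seconds
    if user_deadline:
        job_retry = job_retry.with_deadline(user_deadline)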
22 changes: 11 additions & 11 deletions dbt/adapters/bigquery/connections.py
@@ -142,7 +142,9 @@ def cancel_open(self):
        with self.exception_handler(f"Cancel job: {job_id}"):
            client.cancel_job(
                job_id,
-               retry=self._retry.create_reopen_with_deadline(connection),
+               retry=self._retry.create_job_execution_retry_with_reopen(
+                   connection
+               ),
            )
        self.close(connection)

@@ -441,7 +443,7 @@ def copy_bq_table(self, source, destination, write_disposition) -> None:
            source_ref_array,
            destination_ref,
            job_config=CopyJobConfig(write_disposition=write_disposition),
-           retry=self._retry.create_reopen_with_deadline(conn),
+           retry=self._retry.create_job_execution_retry_with_reopen(conn),
        )
        copy_job.result(timeout=self._retry.create_job_execution_timeout(fallback=300))

@@ -454,15 +456,14 @@ def write_dataframe_to_table(
        identifier: str,
        table_schema: List[SchemaField],
        field_delimiter: str,
-       fallback_timeout: Optional[float] = None,
    ) -> None:
        load_config = LoadJobConfig(
            skip_leading_rows=1,
            schema=table_schema,
            field_delimiter=field_delimiter,
        )
        table = self.table_ref(database, schema, identifier)
-       self._write_file_to_table(client, file_path, table, load_config, fallback_timeout)
+       self._write_file_to_table(client, file_path, table, load_config)

    def write_file_to_table(
        self,
@@ -479,22 +480,21 @@ def write_file_to_table(
        config["schema"] = json.load(config["schema"])
        load_config = LoadJobConfig(**config)
        table = self.table_ref(database, schema, identifier)
-       self._write_file_to_table(client, file_path, table, load_config, fallback_timeout)
+       self._write_file_to_table(client, file_path, table, load_config)

    def _write_file_to_table(
        self,
        client: Client,
        file_path: str,
        table: TableReference,
        config: LoadJobConfig,
-       fallback_timeout: Optional[float] = None,
    ) -> None:

        with self.exception_handler("LOAD TABLE"):
            with open(file_path, "rb") as f:
                job = client.load_table_from_file(f, table, rewind=True, job_config=config)

-           response = job.result(retry=self._retry.create_retry(fallback=fallback_timeout))
+           response = job.result(retry=self._retry.create_job_execution_retry())

        if response.state != "DONE":
            raise DbtRuntimeError("BigQuery Timeout Exceeded")
@@ -521,7 +521,7 @@ def get_bq_table(self, database, schema, identifier) -> Table:
        schema = schema or conn.credentials.schema
        return client.get_table(
            table=self.table_ref(database, schema, identifier),
-           retry=self._retry.create_reopen_with_deadline(conn),
+           retry=self._retry.create_job_execution_retry_with_reopen(conn),
        )

    def drop_dataset(self, database, schema) -> None:
@@ -532,7 +532,7 @@ def drop_dataset(self, database, schema) -> None:
            dataset=self.dataset_ref(database, schema),
            delete_contents=True,
            not_found_ok=True,
-           retry=self._retry.create_reopen_with_deadline(conn),
+           retry=self._retry.create_job_execution_retry_with_reopen(conn),
        )

    def create_dataset(self, database, schema) -> Dataset:
@@ -542,7 +542,7 @@ def create_dataset(self, database, schema) -> Dataset:
        return client.create_dataset(
            dataset=self.dataset_ref(database, schema),
            exists_ok=True,
-           retry=self._retry.create_reopen_with_deadline(conn),
+           retry=self._retry.create_job_execution_retry_with_reopen(conn),
        )

    def list_dataset(self, database: str):
@@ -555,7 +555,7 @@ def list_dataset(self, database: str):
        all_datasets = client.list_datasets(
            project=database.strip("`"),
            max_results=10000,
-           retry=self._retry.create_reopen_with_deadline(conn),
+           retry=self._retry.create_job_execution_retry_with_reopen(conn),
        )
        return [ds.dataset_id for ds in all_datasets]

2 changes: 0 additions & 2 deletions dbt/adapters/bigquery/impl.py
@@ -671,7 +671,6 @@ def load_dataframe(
            table_name,
            table_schema,
            field_delimiter,
-           fallback_timeout=300,
        )

    @available.parse_none
@@ -692,7 +691,6 @@ def upload_file(
            database,
            table_schema,
            table_name,
-           fallback_timeout=300,
            **kwargs,
        )

2 changes: 1 addition & 1 deletion dbt/adapters/bigquery/python_submissions.py
@@ -43,7 +43,7 @@ def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None

        # set retry policy, default to timeout after 24 hours
        retry = RetryFactory(credentials)
-       self._polling_retry = retry.create_polling(
+       self._polling_retry = retry.create_job_execution_polling(
            model_timeout=parsed_model["config"].get("timeout")
        )

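The polling retry resolves through the same fallback chain. A small sketch of what `create_job_execution_polling` amounts to, assuming only the public `google.api_core` objects the diff already imports (the wrapper function here is illustrative):

    from typing import Optional

    from google.api_core.future.polling import DEFAULT_POLLING
    from google.api_core.retry import Retry

    _DAY = 24 * 60 * 60.0

    def job_execution_polling(model_timeout: Optional[float] = None,
                              connection_timeout: Optional[float] = None) -> Retry:
        # a model-level timeout wins, then the connection-level setting,
        # then a 24-hour ceiling so polling can never run unbounded
        return DEFAULT_POLLING.with_timeout(model_timeout or connection_timeout or _DAY)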
78 changes: 38 additions & 40 deletions dbt/adapters/bigquery/retry.py
@@ -1,10 +1,9 @@
from typing import Callable, Optional

-from google.api_core.exceptions import Forbidden
from google.api_core.future.polling import DEFAULT_POLLING
from google.api_core.retry import Retry
-from google.cloud.bigquery.retry import DEFAULT_RETRY
-from google.cloud.exceptions import BadGateway, BadRequest, ServerError
+from google.cloud.bigquery.retry import DEFAULT_JOB_RETRY
+from google.cloud.exceptions import BadRequest
from requests.exceptions import ConnectionError

from dbt.adapters.contracts.connection import Connection, ConnectionState
@@ -17,14 +16,8 @@

_logger = AdapterLogger("BigQuery")

-
-_SECOND = 1.0
-_MINUTE = 60 * _SECOND
-_HOUR = 60 * _MINUTE
-_DAY = 24 * _HOUR
-_DEFAULT_INITIAL_DELAY = _SECOND
-_DEFAULT_MAXIMUM_DELAY = 3 * _SECOND
-_DEFAULT_POLLING_MAXIMUM_DELAY = 10 * _SECOND
+_MINUTE = 60.0
+_DAY = 24 * 60 * 60.0


class RetryFactory:
@@ -35,34 +28,37 @@ def __init__(self, credentials: BigQueryCredentials) -> None:
        self._job_execution_timeout = credentials.job_execution_timeout_seconds
        self._job_deadline = credentials.job_retry_deadline_seconds

-   def create_job_creation_timeout(self, fallback: float = _MINUTE) -> float:
-       return (
-           self._job_creation_timeout or fallback
-       )  # keep _MINUTE here so it's not overridden by passing fallback=None
+   def create_job_creation_timeout(self) -> float:
+       return self._job_creation_timeout or 1 * _MINUTE

-   def create_job_execution_timeout(self, fallback: float = _DAY) -> float:
-       return (
-           self._job_execution_timeout or fallback
-       )  # keep _DAY here so it's not overridden by passing fallback=None
+   def create_job_execution_timeout(self, fallback: float = 1 * _DAY) -> float:
+       return self._job_execution_timeout or fallback

-   def create_retry(self, fallback: Optional[float] = None) -> Retry:
-       return DEFAULT_RETRY.with_timeout(self._job_execution_timeout or fallback or _DAY)
+   def create_job_execution_retry(self) -> Retry:
+       return DEFAULT_JOB_RETRY.with_timeout(self.create_job_execution_timeout(5 * _MINUTE))

-   def create_polling(self, model_timeout: Optional[float] = None) -> Retry:
-       return DEFAULT_POLLING.with_timeout(model_timeout or self._job_execution_timeout or _DAY)
+   def create_job_execution_polling(self, model_timeout: Optional[float] = None) -> Retry:
+       return DEFAULT_POLLING.with_timeout(model_timeout or self.create_job_execution_timeout())

-   def create_reopen_with_deadline(self, connection: Connection) -> Retry:
+   def create_job_execution_retry_with_reopen(self, connection: Connection) -> Retry:
        """
        This strategy mimics what was accomplished with _retry_and_handle
        """
-       return Retry(
-           predicate=_DeferredException(self._retries),
-           initial=_DEFAULT_INITIAL_DELAY,
-           maximum=_DEFAULT_MAXIMUM_DELAY,
-           deadline=self._job_deadline,
-           on_error=_create_reopen_on_error(connection),
+
+       retry = DEFAULT_JOB_RETRY.with_delay(maximum=3.0).with_predicate(
+           _DeferredException(self._retries)
        )

+       # there is no `with_on_error` method, but we want to retain the defaults on `DEFAULT_JOB_RETRY`
+       retry._on_error = _create_reopen_on_error(connection)
+
+       # don't override the default deadline to None if the user did not provide one,
+       # the process will never end
+       if deadline := self._job_deadline:
+           return retry.with_deadline(deadline)
+
+       return retry


class _DeferredException:
    """
@@ -95,7 +91,7 @@ def __call__(self, error: Exception) -> bool:

def _create_reopen_on_error(connection: Connection) -> Callable[[Exception], None]:

-   def on_error(error: Exception):
+   def on_error(error: Exception) -> None:
        if isinstance(error, (ConnectionResetError, ConnectionError)):
            _logger.warning("Reopening connection after {!r}".format(error))
            connection.handle.close()
@@ -116,13 +112,15 @@ def on_error(error: Exception):


def _is_retryable(error: Exception) -> bool:
-   """Return true for errors that are unlikely to occur again if retried."""
-   if isinstance(
-       error, (BadGateway, BadRequest, ConnectionError, ConnectionResetError, ServerError)
-   ):
-       return True
-   elif isinstance(error, Forbidden) and any(
-       e["reason"] == "rateLimitExceeded" for e in error.errors
-   ):
+   """
+   Extend the default predicate `_job_should_retry` to include BadRequest.
+   Because `_job_should_retry` is private, take the predicate directly off of `DEFAULT_JOB_RETRY`.
+   This is expected to be more stable.
+   """
+
+   # this is effectively an or, but it's more readable, especially if we add more in the future
+   if isinstance(error, BadRequest):
        return True
-   return False
+
+   return DEFAULT_JOB_RETRY._predicate(error)
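As the comment in `create_job_execution_retry_with_reopen` notes, `Retry` exposes `with_predicate`, `with_delay`, and `with_deadline` but no `with_on_error`, hence the assignment to the private `_on_error` attribute. A self-contained sketch of the same hook wired through the public constructor instead; the flaky function and log message are illustrative only, not part of this commit:

    import random

    from google.api_core.retry import Retry, if_exception_type

    def log_and_reopen(error: Exception) -> None:
        # the commit's hook reopens the dbt connection; here we only log
        print(f"retrying after {error!r}")

    retry = Retry(
        predicate=if_exception_type(ConnectionResetError),
        initial=1.0,      # first backoff, in seconds
        maximum=3.0,      # backoff cap, matching with_delay(maximum=3.0)
        timeout=30.0,     # total retry budget, so the process always ends
        on_error=log_and_reopen,
    )

    @retry
    def flaky() -> str:
        if random.random() < 0.5:
            raise ConnectionResetError
        return "ok"

    print(flaky())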
2 changes: 1 addition & 1 deletion tests/unit/test_bigquery_connection_manager.py
@@ -39,7 +39,7 @@ def setUp(self):
    def test_retry_connection_reset(self, mock_client_factory):
        new_mock_client = mock_client_factory.return_value

-       @self.connections._retry.create_reopen_with_deadline(self.mock_connection)
+       @self.connections._retry.create_job_execution_retry_with_reopen(self.mock_connection)
        def generate_connection_reset_error():
            raise ConnectionResetError

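The test relies on `Retry` objects doubling as decorators, and on `_DeferredException` matching retryable errors only until the retry budget is spent. The class body is elided above, so this is a hypothetical sketch of that shape, not the committed implementation:

    class CountingPredicate:
        """Hypothetical stand-in for _DeferredException: match up to
        `retries` retryable errors, then stop so Retry re-raises."""

        def __init__(self, retries: int) -> None:
            self._retries = retries
            self._count = 0

        def __call__(self, error: Exception) -> bool:
            if not isinstance(error, (ConnectionError, ConnectionResetError)):
                return False
            self._count += 1
            return self._count <= self._retries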
