Skip to content

Commit

Permalink
fix(anta): Add Semaphore to AsyncEOSDevice (#1042)
Browse files Browse the repository at this point in the history
  • Loading branch information
carl-baillargeon authored Feb 21, 2025
1 parent ecce5e5 commit 45f3b5b
Showing 1 changed file with 74 additions and 50 deletions.
124 changes: 74 additions & 50 deletions anta/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
# https://github.com/pyca/cryptography/issues/7236#issuecomment-1131908472
CLIENT_KEYS = asyncssh.public_key.load_default_keypairs()

# Limit concurrency to 100 requests (HTTPX default) to avoid high-concurrency performance issues
# See: https://github.com/encode/httpx/issues/3215
MAX_CONCURRENT_REQUESTS = 100


class AntaCache:
"""Class to be used as cache.
Expand Down Expand Up @@ -296,6 +300,7 @@ async def copy(self, sources: list[Path], destination: Path, direction: Literal[
raise NotImplementedError(msg)


# pylint: disable=too-many-instance-attributes
class AsyncEOSDevice(AntaDevice):
"""Implementation of AntaDevice for EOS using aio-eapi.
Expand Down Expand Up @@ -388,6 +393,10 @@ def __init__( # noqa: PLR0913
host=host, port=ssh_port, username=username, password=password, client_keys=CLIENT_KEYS, **ssh_params
)

# In Python 3.9, Semaphore must be created within a running event loop
# TODO: Once we drop Python 3.9 support, initialize the semaphore here
self._command_semaphore: asyncio.Semaphore | None = None

def __rich_repr__(self) -> Iterator[tuple[str, Any]]:
"""Implement Rich Repr Protocol.
Expand Down Expand Up @@ -431,6 +440,15 @@ def _keys(self) -> tuple[Any, ...]:
"""
return (self._session.host, self._session.port)

async def _get_semaphore(self) -> asyncio.Semaphore:
"""Return the semaphore, initializing it if needed.
TODO: Remove this method once we drop Python 3.9 support.
"""
if self._command_semaphore is None:
self._command_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
return self._command_semaphore

async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None:
"""Collect device command output from EOS using aio-eapi.
Expand All @@ -445,57 +463,63 @@ async def _collect(self, command: AntaCommand, *, collection_id: str | None = No
collection_id
An identifier used to build the eAPI request ID.
"""
commands: list[dict[str, str | int]] = []
if self.enable and self._enable_password is not None:
commands.append(
{
"cmd": "enable",
"input": str(self._enable_password),
},
)
elif self.enable:
# No password
commands.append({"cmd": "enable"})
commands += [{"cmd": command.command, "revision": command.revision}] if command.revision else [{"cmd": command.command}]
try:
response: list[dict[str, Any] | str] = await self._session.cli(
commands=commands,
ofmt=command.ofmt,
version=command.version,
req_id=f"ANTA-{collection_id}-{id(command)}" if collection_id else f"ANTA-{id(command)}",
) # type: ignore[assignment] # multiple commands returns a list
# Do not keep response of 'enable' command
command.output = response[-1]
except asynceapi.EapiCommandError as e:
# This block catches exceptions related to EOS issuing an error.
self._log_eapi_command_error(command, e)
except TimeoutException as e:
# This block catches Timeout exceptions.
command.errors = [exc_to_str(e)]
timeouts = self._session.timeout.as_dict()
logger.error(
"%s occurred while sending a command to %s. Consider increasing the timeout.\nCurrent timeouts: Connect: %s | Read: %s | Write: %s | Pool: %s",
exc_to_str(e),
self.name,
timeouts["connect"],
timeouts["read"],
timeouts["write"],
timeouts["pool"],
)
except (ConnectError, OSError) as e:
# This block catches OSError and socket issues related exceptions.
command.errors = [exc_to_str(e)]
if (isinstance(exc := e.__cause__, httpcore.ConnectError) and isinstance(os_error := exc.__context__, OSError)) or isinstance(os_error := e, OSError): # pylint: disable=no-member
if isinstance(os_error.__cause__, OSError):
os_error = os_error.__cause__
logger.error("A local OS error occurred while connecting to %s: %s.", self.name, os_error)
else:
semaphore = await self._get_semaphore()

async with semaphore:
commands: list[dict[str, str | int]] = []
if self.enable and self._enable_password is not None:
commands.append(
{
"cmd": "enable",
"input": str(self._enable_password),
},
)
elif self.enable:
# No password
commands.append({"cmd": "enable"})
commands += [{"cmd": command.command, "revision": command.revision}] if command.revision else [{"cmd": command.command}]
try:
response: list[dict[str, Any] | str] = await self._session.cli(
commands=commands,
ofmt=command.ofmt,
version=command.version,
req_id=f"ANTA-{collection_id}-{id(command)}" if collection_id else f"ANTA-{id(command)}",
) # type: ignore[assignment] # multiple commands returns a list
# Do not keep response of 'enable' command
command.output = response[-1]
except asynceapi.EapiCommandError as e:
# This block catches exceptions related to EOS issuing an error.
self._log_eapi_command_error(command, e)
except TimeoutException as e:
# This block catches Timeout exceptions.
command.errors = [exc_to_str(e)]
timeouts = self._session.timeout.as_dict()
logger.error(
"%s occurred while sending a command to %s. Consider increasing the timeout.\nCurrent timeouts: Connect: %s | Read: %s | Write: %s | Pool: %s",
exc_to_str(e),
self.name,
timeouts["connect"],
timeouts["read"],
timeouts["write"],
timeouts["pool"],
)
except (ConnectError, OSError) as e:
# This block catches OSError and socket issues related exceptions.
command.errors = [exc_to_str(e)]
# pylint: disable=no-member
if (isinstance(exc := e.__cause__, httpcore.ConnectError) and isinstance(os_error := exc.__context__, OSError)) or isinstance(
os_error := e, OSError
):
if isinstance(os_error.__cause__, OSError):
os_error = os_error.__cause__
logger.error("A local OS error occurred while connecting to %s: %s.", self.name, os_error)
else:
anta_log_exception(e, f"An error occurred while issuing an eAPI request to {self.name}", logger)
except HTTPError as e:
# This block catches most of the httpx Exceptions and logs a general message.
command.errors = [exc_to_str(e)]
anta_log_exception(e, f"An error occurred while issuing an eAPI request to {self.name}", logger)
except HTTPError as e:
# This block catches most of the httpx Exceptions and logs a general message.
command.errors = [exc_to_str(e)]
anta_log_exception(e, f"An error occurred while issuing an eAPI request to {self.name}", logger)
logger.debug("%s: %s", self.name, command)
logger.debug("%s: %s", self.name, command)

def _log_eapi_command_error(self, command: AntaCommand, e: asynceapi.EapiCommandError) -> None:
"""Appropriately log the eapi command error."""
Expand Down

0 comments on commit 45f3b5b

Please sign in to comment.