Skip to content

Commit

Permalink
ssh health check refresh server on fail
Browse files Browse the repository at this point in the history
  • Loading branch information
yanksyoon committed Jun 12, 2024
1 parent 0a71699 commit 7d307d6
Show file tree
Hide file tree
Showing 3 changed files with 233 additions and 21 deletions.
55 changes: 36 additions & 19 deletions src/openstack_cloud/openstack_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,10 +691,12 @@ def _health_check(conn: OpenstackConnection, server_name: str, startup: bool = F
return False
if server.status not in (_INSTANCE_STATUS_ACTIVE, _INSTANCE_STATUS_BUILDING):
return False
return OpenstackRunnerManager._ssh_health_check(server=server, startup=startup)
return OpenstackRunnerManager._ssh_health_check(
conn=conn, server_name=server_name, startup=startup
)

@staticmethod
def _ssh_health_check(server: Server, startup: bool) -> bool:
def _ssh_health_check(conn: OpenstackConnection, server_name: str, startup: bool) -> bool:
"""Use SSH to check whether runner application is running.
A healthy runner is defined as:
Expand All @@ -703,43 +705,49 @@ def _ssh_health_check(server: Server, startup: bool) -> bool:
3. Runner.Listener exists (waiting for job).
Args:
server: The openstack server instance to check connections.
conn: The Openstack connection instance.
server_name: The openstack server instance to check connections.
startup: Check only whether the startup is successful.
Returns:
Whether the runner application is running.
"""
try:
ssh_conn = OpenstackRunnerManager._get_ssh_connection(server=server)
ssh_conn = OpenstackRunnerManager._get_ssh_connection(
conn=conn, server_name=server_name
)
except _SSHError as exc:
logger.error("[ALERT]: Unable to SSH to server: %s, reason: %s", server.name, str(exc))
logger.error("[ALERT]: Unable to SSH to server: %s, reason: %s", server_name, str(exc))
return True

result: invoke.runners.Result = ssh_conn.run("ps aux", warn=True)
logger.debug("Output of `ps aux` on %s stderr: %s", server.name, result.stderr)
logger.debug("Output of `ps aux` on %s stdout: %s", server.name, result.stdout)
logger.debug("Output of `ps aux` on %s stderr: %s", server_name, result.stderr)
logger.debug("Output of `ps aux` on %s stdout: %s", server_name, result.stdout)
if not result.ok or RUNNER_STARTUP_PROCESS not in result.stdout:
logger.warning("List all process command failed on %s ", server.name)
logger.warning("List all process command failed on %s ", server_name)
return False
logger.info("Runner process found to be healthy on %s", server.name)
logger.info("Runner process found to be healthy on %s", server_name)
if startup:
return True

if RUNNER_WORKER_PROCESS in result.stdout or RUNNER_LISTENER_PROCESS in result.stdout:
return True

logger.error("[ALERT] Health check failed for server: %s", server.name)
logger.error("[ALERT] Health check failed for server: %s", server_name)
return True

@staticmethod
@retry(tries=3, delay=5, max_delay=60, backoff=2, local_logger=logger)
def _get_ssh_connection(server: Server, timeout: int = 30) -> SshConnection:
def _get_ssh_connection(
conn: OpenstackConnection, server_name: str, timeout: int = 30
) -> SshConnection:
"""Get a valid ssh connection within a network for a given openstack instance.
The SSH connection will attempt to establish connection until the timeout configured.
Args:
server: The Openstack server instance.
conn: The Openstack connection instance.
server_name: The Openstack server instance name.
timeout: Timeout in seconds to attempt connection to each available server address.
Raises:
Expand All @@ -748,6 +756,9 @@ def _get_ssh_connection(server: Server, timeout: int = 30) -> SshConnection:
Returns:
An SSH connection to OpenStack server instance.
"""
server: Server | None = conn.get_server(name_or_id=server_name)
if not server:
raise _SSHError(f"Server gone while trying to get SSH connection: {server_name}.")
if not server.key_name:
raise _SSHError(
f"Unable to create SSH connection, no valid keypair found for {server.name}"
Expand All @@ -756,6 +767,7 @@ def _get_ssh_connection(server: Server, timeout: int = 30) -> SshConnection:
if not key_path.exists():
raise _SSHError(f"Missing keyfile for server: {server.name}, key path: {key_path}")
network_address_list = server.addresses.values()
print(network_address_list)
if not network_address_list:
raise _SSHError(f"No addresses found for OpenStack server {server.name}")

Expand Down Expand Up @@ -1185,7 +1197,7 @@ def _remove_one_runner(
logger.info(
"Pulling metrics and deleting server for OpenStack runner %s", instance_name
)
self._pull_metrics(server, instance_name)
self._pull_metrics(conn=conn, instance_name=instance_name)
self._remove_openstack_runner(conn, server, remove_token)
else:
logger.info(
Expand All @@ -1204,11 +1216,11 @@ def _remove_one_runner(
"Found unexpected exception, please contact the developers", exc_info=True
)

def _pull_metrics(self, server: Server, instance_name: str) -> None:
def _pull_metrics(self, conn: OpenstackConnection, instance_name: str) -> None:
"""Pull metrics from the runner into the respective storage for the runner.
Args:
server: The Openstack server instance.
conn: The Openstack connection instance.
instance_name: The Openstack server name.
"""
try:
Expand All @@ -1222,7 +1234,7 @@ def _pull_metrics(self, server: Server, instance_name: str) -> None:
return

try:
ssh_conn = self._get_ssh_connection(server=server)
ssh_conn = self._get_ssh_connection(conn=conn, server_name=instance_name)
except _SSHError as exc:
logger.info("Failed to pull metrics for %s: %s", instance_name, exc)
return
Expand Down Expand Up @@ -1314,7 +1326,7 @@ def _remove_openstack_runner(
remove_token: The GitHub runner remove token.
"""
try:
self._run_github_removal_script(server=server, remove_token=remove_token)
self._run_github_removal_script(conn=conn, server=server, remove_token=remove_token)
except (TimeoutError, invoke.exceptions.UnexpectedExit, GithubRunnerRemoveError):
logger.warning(
"Failed to run runner removal script for %s", server.name, exc_info=True
Expand All @@ -1337,10 +1349,13 @@ def _remove_openstack_runner(
"Found unexpected exception, please contact the developers", exc_info=True
)

def _run_github_removal_script(self, server: Server, remove_token: str | None) -> None:
def _run_github_removal_script(
self, conn: OpenstackConnection, server: Server, remove_token: str | None
) -> None:
"""Run Github runner removal script.
Args:
conn: The Openstack connection instance.
server: The Openstack server instance.
remove_token: The GitHub instance removal token.
Expand All @@ -1350,7 +1365,9 @@ def _run_github_removal_script(self, server: Server, remove_token: str | None) -
if not remove_token:
return
try:
ssh_conn = OpenstackRunnerManager._get_ssh_connection(server=server)
ssh_conn = OpenstackRunnerManager._get_ssh_connection(
conn=conn, server_name=server.name
)
except _SSHError as exc:
logger.error(
"Unable to run GitHub removal script, server: %s, reason: %s",
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

import copy
import secrets
import typing
import unittest.mock
from pathlib import Path

import pytest

import utilities
from openstack_cloud import openstack_manager
from tests.unit.mock import MockGhapiClient, MockLxdClient, MockRepoPolicyComplianceClient

Expand Down Expand Up @@ -148,3 +150,34 @@ def multi_clouds_yaml_fixture(clouds_yaml: dict) -> dict:
}
}
return multi_clouds_yaml


@pytest.fixture(name="skip_retry")
def skip_retry_fixture(monkeypatch: pytest.MonkeyPatch):
"""Fixture for skipping retry for functions with retry decorator."""

def patched_retry(*args, **kwargs):
"""A fallthrough decorator.
Args:
args: Positional arguments placeholder.
kwargs: Keyword arguments placeholder.
Returns:
The fallthrough decorator.
"""

def patched_retry_decorator(func: typing.Callable):
"""The fallthrough decorator.
Args:
func: The function to decorate.
Returns:
the function without any additional features.
"""
return func

return patched_retry_decorator

monkeypatch.setattr(utilities, "retry", patched_retry)
166 changes: 164 additions & 2 deletions tests/unit/test_openstack_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ def test_repo_policy_config(
[
pytest.param(None, id="no server"),
pytest.param(factories.MockOpenstackServer(status="SHUTOFF"), id="shutoff"),
pytest.param(factories.MockOpenstackServer(status="REBUILD"), id="not active/buliding"),
pytest.param(factories.MockOpenstackServer(status="REBUILD"), id="not active/building"),
],
)
def test__health_check(server: factories.MockOpenstackServer | None):
Expand Down Expand Up @@ -981,5 +981,167 @@ def test__ssh_health_check_healthy(
)

assert openstack_manager.OpenstackRunnerManager._ssh_health_check(
server=MagicMock(), startup=startup
conn=MagicMock(), server_name=MagicMock(), startup=startup
)


@pytest.mark.usefixtures("skip_retry")
def test__get_ssh_connection_server_gone():
"""
arrange: given a mock Openstack get_server function that returns None.
act: when _get_ssh_connection is called.
assert: _SSHError is raised.
"""
mock_connection = MagicMock()
mock_connection.get_server.return_value = None

with pytest.raises(openstack_manager._SSHError) as exc:
openstack_manager.OpenstackRunnerManager._get_ssh_connection(
conn=mock_connection, server_name="test"
)

assert "Server gone while trying to get SSH connection" in str(exc.getrepr())


@pytest.mark.usefixtures("skip_retry")
def test__get_ssh_connection_no_server_key():
"""
arrange: given a mock server instance with no key attached.
act: when _get_ssh_connection is called.
assert: _SSHError is raised.
"""
mock_server = MagicMock()
mock_server.key_name = None
mock_connection = MagicMock()
mock_connection.get_server.return_value = mock_server

with pytest.raises(openstack_manager._SSHError) as exc:
openstack_manager.OpenstackRunnerManager._get_ssh_connection(
conn=mock_connection, server_name="test"
)

assert "Unable to create SSH connection, no valid keypair found" in str(exc.getrepr())


@pytest.mark.usefixtures("skip_retry")
def test__get_ssh_connection_key_not_exists(monkeypatch: pytest.MonkeyPatch):
"""
arrange: given a monkeypatched _get_key_path function that returns a non-existent path.
act: when _get_ssh_connection is called.
assert: _SSHError is raised.
"""
monkeypatch.setattr(
openstack_manager.OpenstackRunnerManager,
"_get_key_path",
MagicMock(return_value=Path("does-not-exist")),
)
mock_connection = MagicMock()

with pytest.raises(openstack_manager._SSHError) as exc:
openstack_manager.OpenstackRunnerManager._get_ssh_connection(
conn=mock_connection, server_name="test"
)

assert "Missing keyfile for server" in str(exc.getrepr())


@pytest.mark.usefixtures("skip_retry")
def test__get_ssh_connection_server_no_addresses(monkeypatch: pytest.MonkeyPatch):
"""
arrange: given a mock server instance with no server addresses.
act: when _get_ssh_connection is called.
assert: _SSHError is raised.
"""
monkeypatch.setattr(
openstack_manager.OpenstackRunnerManager,
"_get_key_path",
MagicMock(return_value=Path(".")),
)
mock_server = MagicMock()
mock_server.addresses = {}
mock_connection = MagicMock()
mock_connection.get_server.return_value = mock_server

with pytest.raises(openstack_manager._SSHError) as exc:
openstack_manager.OpenstackRunnerManager._get_ssh_connection(
conn=mock_connection, server_name="test"
)

assert "Missing keyfile for server" in str(exc.getrepr())


@pytest.mark.usefixtures("skip_retry")
@pytest.mark.parametrize(
"run",
[
pytest.param(MagicMock(side_effect=TimeoutError), id="timeout error"),
pytest.param(
MagicMock(return_value=factories.MockSSHRunResult(exited=1)), id="result not ok"
),
pytest.param(
MagicMock(return_value=factories.MockSSHRunResult(exited=0, stdout="")),
id="empty response",
),
],
)
def test__get_ssh_connection_server_no_valid_connections(
monkeypatch: pytest.MonkeyPatch, run: MagicMock
):
"""
arrange: given a monkeypatched Connection instance that returns invalid connections.
act: when _get_ssh_connection is called.
assert: _SSHError is raised.
"""
monkeypatch.setattr(
openstack_manager.OpenstackRunnerManager,
"_get_key_path",
MagicMock(return_value=Path(".")),
)
mock_server = MagicMock()
mock_server.addresses = {"test": [{"addr": "test-address"}]}
mock_connection = MagicMock()
mock_connection.get_server.return_value = mock_server
mock_ssh_connection = MagicMock()
mock_ssh_connection.run = run
monkeypatch.setattr(
openstack_manager, "SshConnection", MagicMock(return_value=mock_ssh_connection)
)

with pytest.raises(openstack_manager._SSHError) as exc:
openstack_manager.OpenstackRunnerManager._get_ssh_connection(
conn=mock_connection, server_name="test"
)

assert "Missing keyfile for server" in str(exc.getrepr())


@pytest.mark.usefixtures("skip_retry")
def test__get_ssh_connection_server(monkeypatch: pytest.MonkeyPatch):
"""
arrange: given monkeypatched SSH connection instance.
act: when _get_ssh_connection is called.
assert: the SSH connection instance is returned.
"""
monkeypatch.setattr(
openstack_manager.OpenstackRunnerManager,
"_get_key_path",
MagicMock(return_value=Path(".")),
)
mock_server = MagicMock()
mock_server.addresses = {"test": [{"addr": "test-address"}]}
mock_connection = MagicMock()
mock_connection.get_server.return_value = mock_server
mock_ssh_connection = MagicMock()
mock_ssh_connection.run = MagicMock(
return_value=factories.MockSSHRunResult(exited=0, stdout="hello world")
)
monkeypatch.setattr(
openstack_manager, "SshConnection", MagicMock(return_value=mock_ssh_connection)
)

assert (
openstack_manager.OpenstackRunnerManager._get_ssh_connection(
conn=mock_connection, server_name="test"
)
== mock_ssh_connection
)

0 comments on commit 7d307d6

Please sign in to comment.