diff --git a/server/__init__.py b/server/__init__.py index aeefb87b0..0f127a35e 100644 --- a/server/__init__.py +++ b/server/__init__.py @@ -106,13 +106,13 @@ async def do_report_dirties(): games.clear_dirty() player_service.clear_dirty() - if len(dirty_queues) > 0: + if dirty_queues: await ctx.broadcast_raw( encode_queues(dirty_queues), lambda lobby_conn: lobby_conn.authenticated ) - if len(dirty_players) > 0: + if dirty_players: await ctx.broadcast_raw( encode_players(dirty_players), lambda lobby_conn: lobby_conn.authenticated diff --git a/server/async_functions.py b/server/async_functions.py new file mode 100644 index 000000000..3555db6a6 --- /dev/null +++ b/server/async_functions.py @@ -0,0 +1,24 @@ +""" +Some helper functions for common async tasks. +""" +import asyncio +from typing import Any, List + + +async def gather_without_exceptions( + tasks: List[asyncio.Task], + *exceptions: type, +) -> List[Any]: + """ + Call gather on a list of tasks, raising the first exception that dosen't + match any of the specified exception classes. + """ + results = await asyncio.gather(*tasks, return_exceptions=True) + for result in results: + if isinstance(result, Exception): + # Check if this exception is an instance (maybe subclass) that + # should be ignored + for exc_type in exceptions: + if not isinstance(result, exc_type): + raise result + return results diff --git a/server/gameconnection.py b/server/gameconnection.py index 5dc5ae99c..4f38fa014 100644 --- a/server/gameconnection.py +++ b/server/gameconnection.py @@ -538,11 +538,14 @@ async def disconnect_all_peers(self): tasks.append(peer.send_DisconnectFromPeer(self.player.id)) - for result in await asyncio.gather(*tasks, return_exceptions=True): - if isinstance(result, Exception): + for fut in asyncio.as_completed(tasks): + try: + await fut + except Exception: self._logger.exception( "peer_sendDisconnectFromPeer failed for player %i", - self.player.id) + self.player.id + ) async def on_connection_lost(self): try: diff --git a/server/lobbyconnection.py b/server/lobbyconnection.py index 39a4adb75..64b44bfbb 100644 --- a/server/lobbyconnection.py +++ b/server/lobbyconnection.py @@ -18,6 +18,7 @@ from . import config from .abc.base_game import GameConnectionState +from .async_functions import gather_without_exceptions from .config import FAF_POLICY_SERVER_BASE_URL, TRACE, TWILIO_TTL from .db.models import ban, friends_and_foes, lobby_ban from .db.models import login as t_login @@ -334,16 +335,14 @@ async def command_admin(self, message): tasks = [] for player in self.player_service: - try: + # Check if object still exists: + # https://docs.python.org/3/library/weakref.html#weak-reference-objects + if player.lobby_connection is not None: tasks.append( player.lobby_connection.send_warning(message_text) ) - except AttributeError: - self._logger.debug("Failed to send broadcast to %s", player) - except Exception: - self._logger.exception("Failed to send broadcast to %s", player) - await asyncio.gather(*tasks, return_exceptions=True) + await gather_without_exceptions(tasks, ConnectionError) if self.player.mod: if action == "join_channel": @@ -353,18 +352,13 @@ async def command_admin(self, message): tasks = [] for user_id in user_ids: player = self.player_service[user_id] - if player: - try: - tasks.append(player.lobby_connection.send({ - "command": "social", - "autojoin": [channel] - })) - except AttributeError: - self._logger.debug("Failed to send join_channel to %s", player) - except Exception: - self._logger.exception("Failed to send join_channel to %s", player) - - await asyncio.gather(*tasks, return_exceptions=True) + if player and player.lobby_connection is not None: + tasks.append(player.lobby_connection.send({ + "command": "social", + "autojoin": [channel] + })) + + await gather_without_exceptions(tasks, ConnectionError) async def check_user_login(self, conn, username, password): # TODO: Hash passwords server-side so the hashing actually *does* something. diff --git a/server/player_service.py b/server/player_service.py index c23161dde..41e9beccc 100644 --- a/server/player_service.py +++ b/server/player_service.py @@ -149,7 +149,7 @@ async def update_data(self): async def broadcast_shutdown(self): tasks = [] for player in self: - try: + if player.lobby_connection is not None: tasks.append( player.lobby_connection.send_warning( "The server has been shut down for maintenance, " @@ -158,8 +158,11 @@ async def broadcast_shutdown(self): "We apologize for this interruption." ) ) + + for fut in asyncio.as_completed(tasks): + try: + await fut except Exception as ex: self._logger.debug( "Could not send shutdown message to %s: %s", player, ex ) - await asyncio.gather(*tasks, return_exceptions=True) diff --git a/server/servercontext.py b/server/servercontext.py index 524d50135..751434ccd 100644 --- a/server/servercontext.py +++ b/server/servercontext.py @@ -1,9 +1,11 @@ import asyncio import server -from server.decorators import with_logger -from server.protocol import QDataStreamProtocol -from server.types import Address + +from .async_functions import gather_without_exceptions +from .decorators import with_logger +from .protocol import QDataStreamProtocol +from .types import Address @with_logger @@ -58,7 +60,7 @@ async def broadcast_raw(self, message, validate_fn=lambda a: True): if validate_fn(conn): tasks.append(proto.send_raw(message)) - await asyncio.gather(*tasks) + await gather_without_exceptions(tasks, ConnectionError) async def client_connected(self, stream_reader, stream_writer): self._logger.debug("%s: Client connected", self) @@ -72,9 +74,8 @@ async def client_connected(self, stream_reader, stream_writer): message = await protocol.read_message() with server.stats.timer('connection.on_message_received'): await connection.on_message_received(message) - except ConnectionResetError: - pass - except ConnectionAbortedError: + except ConnectionError: + # User disconnected. Proceed to finally block for cleanup. pass except TimeoutError: pass diff --git a/tests/conftest.py b/tests/conftest.py index deb371192..7a5bd00b1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ import pytest from asynctest import CoroutineMock from server.api.api_accessor import ApiAccessor -from server.config import DB_LOGIN, DB_PASSWORD, DB_PORT, DB_SERVER +from server.config import DB_LOGIN, DB_PASSWORD, DB_PORT, DB_SERVER, TRACE from server.db import FAFDatabase from server.game_service import GameService from server.geoip_service import GeoIpService @@ -25,7 +25,7 @@ from server.rating import RatingType from tests.utils import MockDatabase -logging.getLogger().setLevel(logging.DEBUG) +logging.getLogger().setLevel(TRACE) def pytest_addoption(parser): diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index b379874f1..ca48a8737 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -1,12 +1,14 @@ import asyncio import hashlib import logging +from collections import defaultdict from typing import Any, Callable, Dict, Tuple from unittest import mock import pytest from aiohttp import web from server import GameService, PlayerService, run_lobby_server +from server.db.models import login from server.ladder_service import LadderService from server.protocol import QDataStreamProtocol @@ -99,6 +101,27 @@ async def start_app(): event_loop.run_until_complete(runner.cleanup()) +@pytest.fixture +async def tmp_user(database): + user_ids = defaultdict(lambda: 1) + password_plain = "foo" + password = hashlib.sha256(password_plain.encode()).hexdigest() + + async def make_user(name="TempUser"): + user_id = user_ids[name] + login_name = f"{name}{user_id}" + async with database.acquire() as conn: + await conn.execute(login.insert().values( + login=login_name, + email=f"{login_name}@example.com", + password=password, + )) + user_ids[name] += 1 + return login_name, password_plain + + return make_user + + async def connect_client(server) -> QDataStreamProtocol: return QDataStreamProtocol( *(await asyncio.open_connection(*server.sockets[0].getsockname())) diff --git a/tests/integration_tests/test_server.py b/tests/integration_tests/test_server.py index 424ae2956..e94a87c00 100644 --- a/tests/integration_tests/test_server.py +++ b/tests/integration_tests/test_server.py @@ -1,4 +1,5 @@ import asyncio +import logging import pytest from server.db.models import ban @@ -123,11 +124,6 @@ async def test_server_double_login(lobby_server): 'text': 'You have been signed out because you signed in elsewhere.' } - lobby_server.close() - proto.close() - proto2.close() - await lobby_server.wait_closed() - @fast_forward(50) async def test_ping_message(lobby_server): @@ -148,8 +144,6 @@ async def test_player_info_broadcast(lobby_server): p2, lambda m: 'player_info' in m.values() and any(map(lambda d: ('login', 'test') in d.items(), m['players'])) ) - p1.close() - p2.close() @pytest.mark.slow @@ -238,6 +232,105 @@ async def test_game_info_broadcast_to_friends(lobby_server): await asyncio.wait_for(read_until_command(proto3, "game_info"), 0.2) +@fast_forward(300) +async def test_game_info_broadcast_on_connection_error( + event_loop, lobby_server, tmp_user, ladder_service, game_service, caplog +): + """ + Causes connection errors in `do_report_dirties` which in turn will cause + closed games not to be cleaned up if the errors aren't handled properly. + """ + # This test causes way to much logging output otherwise + caplog.set_level(logging.WARNING) + + NUM_HOSTS = 10 + NUM_PLAYERS_DC = 20 + NUM_TIMES_DC = 10 + + # Number of times that games will be rehosted + NUM_GAME_REHOSTS = 20 + + # Set up our game hosts + host_protos = [] + for _ in range(NUM_HOSTS): + _, _, proto = await connect_and_sign_in( + await tmp_user("Host"), lobby_server + ) + host_protos.append(proto) + await asyncio.gather(*( + read_until_command(proto, "game_info") + for proto in host_protos + )) + + # Set up our players that will disconnect + dc_players = [await tmp_user("Disconnecter") for _ in range(NUM_PLAYERS_DC)] + + # Host the games + async def host(proto): + await proto.send_message({ + "command": "game_host", + "title": "A dirty game", + "mod": "faf", + "visibility": "public" + }) + msg = await read_until_command(proto, "game_launch") + + # Pretend like ForgedAlliance.exe opened + await proto.send_message({ + "target": "game", + "command": "GameState", + "args": ["Idle"] + }) + return msg + + async def spam_game_changes(proto): + for _ in range(NUM_GAME_REHOSTS): + # Host + await host(proto) + await asyncio.sleep(0.1) + # Leave the game + await proto.send_message({ + "target": "game", + "command": "GameState", + "args": ["Ended"] + }) + + tasks = [] + for proto in host_protos: + tasks.append(spam_game_changes(proto)) + + async def do_dc_player(player): + for _ in range(NUM_TIMES_DC): + _, _, proto = await connect_and_sign_in(player, lobby_server) + await read_until_command(proto, "game_info") + await asyncio.sleep(0.1) + proto.close() + + async def do_dc_players(): + await asyncio.gather(*( + do_dc_player(player) + for player in dc_players + )) + + tasks.append(do_dc_players()) + + # Let the guests cause a bunch of broadcasts to happen while the other + # players are disconnecting + await asyncio.gather(*tasks) + + # Wait for games to be cleaned up + for proto in host_protos: + proto.close() + ladder_service.shutdown_queues() + + # Wait for games to time out if they need to + await asyncio.sleep(35) + + # Ensure that the connection errors haven't prevented games from being + # cleaned up. + assert len(game_service.all_games) == 0 + + @pytest.mark.parametrize("user", [ ("test", "test_password"), ("ban_revoked", "ban_revoked"), diff --git a/tests/integration_tests/test_servercontext.py b/tests/integration_tests/test_servercontext.py index 5c5f787a3..fa1f92425 100644 --- a/tests/integration_tests/test_servercontext.py +++ b/tests/integration_tests/test_servercontext.py @@ -2,8 +2,9 @@ from unittest import mock import pytest -from asynctest import exhaust_callbacks +from asynctest import CoroutineMock, exhaust_callbacks from server import ServerContext, fake_statsd +from server.lobbyconnection import LobbyConnection from server.protocol import QDataStreamProtocol pytestmark = pytest.mark.asyncio @@ -14,20 +15,18 @@ def mock_server(event_loop): class MockServer: def __init__(self): self.protocol, self.peername, self.user_agent = None, None, None + self.on_connection_lost = CoroutineMock() - @asyncio.coroutine - def on_connection_made(self, protocol, peername): + async def on_connection_made(self, protocol, peername): self.protocol = protocol self.peername = peername self.protocol.writer.write_eof() self.protocol.reader.feed_eof() - @asyncio.coroutine - def on_message_received(self, msg): + async def on_message_received(self, msg): pass - mock_server = MockServer() - mock_server.on_connection_lost = mock.Mock() - return mock_server + + return MockServer() @pytest.fixture @@ -37,11 +36,32 @@ def mock_context(event_loop, request, mock_server): def fin(): ctx.close() request.addfinalizer(fin) - return event_loop.run_until_complete(ctx.listen('127.0.0.1', None)) + return event_loop.run_until_complete(ctx.listen('127.0.0.1', None)), ctx + + +@pytest.fixture +def context(event_loop, request): + def make_connection() -> LobbyConnection: + return LobbyConnection( + database=mock.Mock(), + geoip=mock.Mock(), + games=mock.Mock(), + nts_client=mock.Mock(), + players=mock.Mock(), + ladder_service=mock.Mock() + ) + + ctx = ServerContext(make_connection, name='TestServer') + + def fin(): + ctx.close() + request.addfinalizer(fin) + return event_loop.run_until_complete(ctx.listen('127.0.0.1', None)), ctx async def test_serverside_abort(event_loop, mock_context, mock_server): - (reader, writer) = await asyncio.open_connection(*mock_context.sockets[0].getsockname()) + srv, ctx = mock_context + (reader, writer) = await asyncio.open_connection(*srv.sockets[0].getsockname()) proto = QDataStreamProtocol(reader, writer) await proto.send_message({"some_junk": True}) await exhaust_callbacks(event_loop) @@ -49,6 +69,19 @@ async def test_serverside_abort(event_loop, mock_context, mock_server): mock_server.on_connection_lost.assert_any_call() +async def test_broadcast_raw(context, mock_server): + srv, ctx = context + (reader, writer) = await asyncio.open_connection( + *srv.sockets[0].getsockname() + ) + writer.close() + + # If connection errors aren't handled, this should fail due to a + # ConnectionError + for _ in range(20): + await ctx.broadcast_raw(b"Some bytes") + + async def test_server_fake_statsd(): dummy = fake_statsd.DummyConnection() # Verify that no exceptions are raised diff --git a/tests/unit_tests/test_async_functions.py b/tests/unit_tests/test_async_functions.py new file mode 100644 index 000000000..e68d60760 --- /dev/null +++ b/tests/unit_tests/test_async_functions.py @@ -0,0 +1,46 @@ +import pytest +from asynctest import CoroutineMock +from server.async_functions import gather_without_exceptions + +pytestmark = pytest.mark.asyncio + + +class CustomError(Exception): + pass + + +async def raises_connection_error(): + raise ConnectionError("Test ConnectionError") + + +async def raises_connection_reset_error(): + raise ConnectionResetError("Test ConnectionResetError") + + +async def raises_custom_error(): + raise CustomError("Test Exception") + + +async def test_gather_without_exceptions(): + completes_correctly = CoroutineMock() + + with pytest.raises(CustomError): + await gather_without_exceptions([ + raises_connection_error(), + raises_custom_error(), + completes_correctly() + ], ConnectionError) + + completes_correctly.assert_called_once() + + +async def test_gather_without_exceptions_subclass(): + completes_correctly = CoroutineMock() + + await gather_without_exceptions([ + raises_connection_error(), + raises_connection_reset_error(), + completes_correctly() + ], ConnectionError) + + completes_correctly.assert_called_once() diff --git a/tests/unit_tests/test_lobbyconnection.py b/tests/unit_tests/test_lobbyconnection.py index 165ac284f..70ee1ae2c 100644 --- a/tests/unit_tests/test_lobbyconnection.py +++ b/tests/unit_tests/test_lobbyconnection.py @@ -782,7 +782,7 @@ async def test_broadcast_during_disconnect(lobbyconnection: LobbyConnection, moc player.lobby_connection.send_warning.assert_called_with("This is a test message") -async def test_broadcast_error(lobbyconnection: LobbyConnection, mocker): +async def test_broadcast_connection_error(lobbyconnection: LobbyConnection, mocker): player = mocker.patch.object(lobbyconnection, 'player') player.login = 'Sheeo' player.admin = True @@ -790,7 +790,7 @@ async def test_broadcast_error(lobbyconnection: LobbyConnection, mocker): tuna = mock.Mock() tuna.id = 55 tuna.lobby_connection = asynctest.create_autospec(LobbyConnection) - tuna.lobby_connection.send_warning = Mock(side_effect=Exception("Some error")) + tuna.lobby_connection.send_warning.side_effect = ConnectionError("Some error") lobbyconnection.player_service = [player, tuna] # This should not leak any exceptions