Skip to content

Commit

Permalink
Bugfix: Handle connection errors (#524)
Browse files Browse the repository at this point in the history
* Add some exception handling for gathered ConnectionErrors

* Refactor length checks to be more pythonic

* Add test for checking that broadcast_raw ignores ConnectionErrors

* Use asyncio.as_completed to handle exceptions per future

* Check weak references explicitly

* Add test for gather_without_exceptions

* Adjust test to check for ConnectionError's specifically

* Fix gather_without_exception subclass behavior

* Log at TRACE level during tests

* Added a fixture for creating temporary users on demand

* Test for generating connection errors

* Handle all ConnectionErrors the same

* Add weak reference check to broadcast_shutdown
  • Loading branch information
Askaholic authored Feb 2, 2020
1 parent 0ea040f commit 273ff76
Show file tree
Hide file tree
Showing 12 changed files with 273 additions and 53 deletions.
4 changes: 2 additions & 2 deletions server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,13 @@ async def do_report_dirties():
games.clear_dirty()
player_service.clear_dirty()

if len(dirty_queues) > 0:
if dirty_queues:
await ctx.broadcast_raw(
encode_queues(dirty_queues),
lambda lobby_conn: lobby_conn.authenticated
)

if len(dirty_players) > 0:
if dirty_players:
await ctx.broadcast_raw(
encode_players(dirty_players),
lambda lobby_conn: lobby_conn.authenticated
Expand Down
24 changes: 24 additions & 0 deletions server/async_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Some helper functions for common async tasks.
"""
import asyncio
from typing import Any, List


async def gather_without_exceptions(
    tasks: List[asyncio.Task],
    *exceptions: type,
) -> List[Any]:
    """
    Call gather on a list of tasks, raising the first exception that doesn't
    match any of the specified exception classes.

    The tasks are awaited with `asyncio.gather(..., return_exceptions=True)`,
    so all of them finish before any error is re-raised.

    :param tasks: awaitables to run concurrently.
    :param exceptions: exception classes to ignore (subclasses included).
        When no class is given, nothing is ignored.
    :return: the gathered results in the same order as `tasks`. Ignored
        exceptions are left in the list as exception instances.
    :raises Exception: the first gathered exception (in task order) that is
        not an instance of any class in `exceptions`.
    """
    results = await asyncio.gather(*tasks, return_exceptions=True)
    for result in results:
        # `isinstance` with a tuple is true if *any* class matches, which is
        # the documented contract. Checking the classes one at a time and
        # raising on the first mismatch would wrongly raise an exception that
        # matches a different class in the list. An empty tuple never
        # matches, so with no classes specified every exception is raised.
        if isinstance(result, Exception) and not isinstance(result, exceptions):
            raise result
    return results
9 changes: 6 additions & 3 deletions server/gameconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,11 +538,14 @@ async def disconnect_all_peers(self):

tasks.append(peer.send_DisconnectFromPeer(self.player.id))

for result in await asyncio.gather(*tasks, return_exceptions=True):
if isinstance(result, Exception):
for fut in asyncio.as_completed(tasks):
try:
await fut
except Exception:
self._logger.exception(
"peer_sendDisconnectFromPeer failed for player %i",
self.player.id)
self.player.id
)

async def on_connection_lost(self):
try:
Expand Down
30 changes: 12 additions & 18 deletions server/lobbyconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from . import config
from .abc.base_game import GameConnectionState
from .async_functions import gather_without_exceptions
from .config import FAF_POLICY_SERVER_BASE_URL, TRACE, TWILIO_TTL
from .db.models import ban, friends_and_foes, lobby_ban
from .db.models import login as t_login
Expand Down Expand Up @@ -334,16 +335,14 @@ async def command_admin(self, message):

tasks = []
for player in self.player_service:
try:
# Check if object still exists:
# https://docs.python.org/3/library/weakref.html#weak-reference-objects
if player.lobby_connection is not None:
tasks.append(
player.lobby_connection.send_warning(message_text)
)
except AttributeError:
self._logger.debug("Failed to send broadcast to %s", player)
except Exception:
self._logger.exception("Failed to send broadcast to %s", player)

await asyncio.gather(*tasks, return_exceptions=True)
await gather_without_exceptions(tasks, ConnectionError)

if self.player.mod:
if action == "join_channel":
Expand All @@ -353,18 +352,13 @@ async def command_admin(self, message):
tasks = []
for user_id in user_ids:
player = self.player_service[user_id]
if player:
try:
tasks.append(player.lobby_connection.send({
"command": "social",
"autojoin": [channel]
}))
except AttributeError:
self._logger.debug("Failed to send join_channel to %s", player)
except Exception:
self._logger.exception("Failed to send join_channel to %s", player)

await asyncio.gather(*tasks, return_exceptions=True)
if player and player.lobby_connection is not None:
tasks.append(player.lobby_connection.send({
"command": "social",
"autojoin": [channel]
}))

await gather_without_exceptions(tasks, ConnectionError)

async def check_user_login(self, conn, username, password):
# TODO: Hash passwords server-side so the hashing actually *does* something.
Expand Down
7 changes: 5 additions & 2 deletions server/player_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ async def update_data(self):
async def broadcast_shutdown(self):
tasks = []
for player in self:
try:
if player.lobby_connection is not None:
tasks.append(
player.lobby_connection.send_warning(
"The server has been shut down for maintenance, "
Expand All @@ -158,8 +158,11 @@ async def broadcast_shutdown(self):
"We apologize for this interruption."
)
)

for fut in asyncio.as_completed(tasks):
try:
await fut
except Exception as ex:
self._logger.debug(
"Could not send shutdown message to %s: %s", player, ex
)
await asyncio.gather(*tasks, return_exceptions=True)
15 changes: 8 additions & 7 deletions server/servercontext.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio

import server
from server.decorators import with_logger
from server.protocol import QDataStreamProtocol
from server.types import Address

from .async_functions import gather_without_exceptions
from .decorators import with_logger
from .protocol import QDataStreamProtocol
from .types import Address


@with_logger
Expand Down Expand Up @@ -58,7 +60,7 @@ async def broadcast_raw(self, message, validate_fn=lambda a: True):
if validate_fn(conn):
tasks.append(proto.send_raw(message))

await asyncio.gather(*tasks)
await gather_without_exceptions(tasks, ConnectionError)

async def client_connected(self, stream_reader, stream_writer):
self._logger.debug("%s: Client connected", self)
Expand All @@ -72,9 +74,8 @@ async def client_connected(self, stream_reader, stream_writer):
message = await protocol.read_message()
with server.stats.timer('connection.on_message_received'):
await connection.on_message_received(message)
except ConnectionResetError:
pass
except ConnectionAbortedError:
except ConnectionError:
# User disconnected. Proceed to finally block for cleanup.
pass
except TimeoutError:
pass
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest
from asynctest import CoroutineMock
from server.api.api_accessor import ApiAccessor
from server.config import DB_LOGIN, DB_PASSWORD, DB_PORT, DB_SERVER
from server.config import DB_LOGIN, DB_PASSWORD, DB_PORT, DB_SERVER, TRACE
from server.db import FAFDatabase
from server.game_service import GameService
from server.geoip_service import GeoIpService
Expand All @@ -25,7 +25,7 @@
from server.rating import RatingType
from tests.utils import MockDatabase

logging.getLogger().setLevel(logging.DEBUG)
logging.getLogger().setLevel(TRACE)


def pytest_addoption(parser):
Expand Down
23 changes: 23 additions & 0 deletions tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import asyncio
import hashlib
import logging
from collections import defaultdict
from typing import Any, Callable, Dict, Tuple
from unittest import mock

import pytest
from aiohttp import web
from server import GameService, PlayerService, run_lobby_server
from server.db.models import login
from server.ladder_service import LadderService
from server.protocol import QDataStreamProtocol

Expand Down Expand Up @@ -99,6 +101,27 @@ async def start_app():
event_loop.run_until_complete(runner.cleanup())


@pytest.fixture
async def tmp_user(database):
    """
    Provide a factory coroutine that inserts a fresh user row on demand.

    Each call to the returned `make_user(name)` inserts a login named
    `<name><n>`, where `n` counts up from 1 separately for every `name`, and
    returns `(login_name, plain_text_password)`.
    """
    plain_password = "foo"
    hashed_password = hashlib.sha256(plain_password.encode()).hexdigest()
    # Next numeric suffix to hand out for each name prefix
    next_suffix = defaultdict(lambda: 1)

    async def make_user(name="TempUser"):
        suffix = next_suffix[name]
        login_name = f"{name}{suffix}"
        insert_stmt = login.insert().values(
            login=login_name,
            email=f"{login_name}@example.com",
            password=hashed_password,
        )
        async with database.acquire() as conn:
            await conn.execute(insert_stmt)
        # Only advance the counter once the row has actually been written
        next_suffix[name] += 1
        return login_name, plain_password

    return make_user


async def connect_client(server) -> QDataStreamProtocol:
return QDataStreamProtocol(
*(await asyncio.open_connection(*server.sockets[0].getsockname()))
Expand Down
107 changes: 100 additions & 7 deletions tests/integration_tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import logging

import pytest
from server.db.models import ban
Expand Down Expand Up @@ -123,11 +124,6 @@ async def test_server_double_login(lobby_server):
'text': 'You have been signed out because you signed in elsewhere.'
}

lobby_server.close()
proto.close()
proto2.close()
await lobby_server.wait_closed()


@fast_forward(50)
async def test_ping_message(lobby_server):
Expand All @@ -148,8 +144,6 @@ async def test_player_info_broadcast(lobby_server):
p2, lambda m: 'player_info' in m.values()
and any(map(lambda d: ('login', 'test') in d.items(), m['players']))
)
p1.close()
p2.close()


@pytest.mark.slow
Expand Down Expand Up @@ -238,6 +232,105 @@ async def test_game_info_broadcast_to_friends(lobby_server):
await asyncio.wait_for(read_until_command(proto3, "game_info"), 0.2)


@fast_forward(300)
async def test_game_info_broadcast_on_connection_error(
    event_loop, lobby_server, tmp_user, ladder_service, game_service, caplog
):
    """
    Provokes connection errors in `do_report_dirties`. If those errors are not
    handled properly, closed games are never cleaned up, which the final
    assertion detects.
    """
    # Without this the test produces far too much logging output
    caplog.set_level(logging.WARNING)

    host_count = 10
    disconnecter_count = 20
    reconnects_per_disconnecter = 10
    # How many times each host opens and abandons a game
    rehosts_per_host = 20

    async def send_game_state(proto, state):
        await proto.send_message({
            "target": "game",
            "command": "GameState",
            "args": [state]
        })

    async def host(proto):
        await proto.send_message({
            "command": "game_host",
            "title": "A dirty game",
            "mod": "faf",
            "visibility": "public"
        })
        launch_msg = await read_until_command(proto, "game_launch")

        # Pretend like ForgedAlliance.exe opened
        await send_game_state(proto, "Idle")
        return launch_msg

    async def spam_game_changes(proto):
        for _ in range(rehosts_per_host):
            await host(proto)
            await asyncio.sleep(0.1)
            # Leave the game again
            await send_game_state(proto, "Ended")

    async def do_dc_player(player):
        for _ in range(reconnects_per_disconnecter):
            _, _, proto = await connect_and_sign_in(player, lobby_server)
            await read_until_command(proto, "game_info")
            await asyncio.sleep(0.1)
            proto.close()

    async def do_dc_players():
        await asyncio.gather(*map(do_dc_player, dc_players))

    # Sign in the game hosts one after another
    host_protos = []
    for _ in range(host_count):
        credentials = await tmp_user("Host")
        _, _, proto = await connect_and_sign_in(credentials, lobby_server)
        host_protos.append(proto)
    await asyncio.gather(*[
        read_until_command(host_proto, "game_info")
        for host_proto in host_protos
    ])

    # Create the accounts that will keep connecting and disconnecting
    dc_players = []
    for _ in range(disconnecter_count):
        dc_players.append(await tmp_user("Disconnecter"))

    # The hosts trigger a stream of broadcasts while the other players are
    # busy dropping their connections
    workloads = [spam_game_changes(host_proto) for host_proto in host_protos]
    workloads.append(do_dc_players())
    await asyncio.gather(*workloads)

    # Disconnect the hosts so their games can be cleaned up
    for host_proto in host_protos:
        host_proto.close()
    ladder_service.shutdown_queues()

    # Give any remaining games time to time out
    await asyncio.sleep(35)

    # The connection errors must not have stopped games from being cleaned
    # up.
    assert len(game_service.all_games) == 0


@pytest.mark.parametrize("user", [
("test", "test_password"),
("ban_revoked", "ban_revoked"),
Expand Down
Loading

0 comments on commit 273ff76

Please sign in to comment.