Skip to content

Commit

Permalink
Resolve occasional exceptions in user logs by catching CancelledError (
Browse files Browse the repository at this point in the history
…#167)

* Add to resolve occasional exceptions in user logs by catching CancelledError
* Define some LAN test cases to verify exception handling by mocking underlying methods
  • Loading branch information
mill1000 authored Aug 27, 2024
1 parent e614ee5 commit 5790e9a
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 5 deletions.
4 changes: 4 additions & 0 deletions msmart/lan.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,10 @@ async def send(self, data: bytes, retries: int = RETRIES) -> List[bytes]:
# TODO could add a fatal flag to exception to trigger disconnect
self._disconnect()
raise e
except asyncio.CancelledError as e:
_LOGGER.warning("Read cancelled. Disconnecting.")
self._disconnect()
raise TimeoutError("Read cancelled.") from e

# Read any additional responses without blocking
async for resp in self._read_available():
Expand Down
121 changes: 116 additions & 5 deletions msmart/tests/test_lan.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import asyncio
import logging
import unittest
import unittest.mock as mock

from msmart.lan import _LanProtocolV3, _Packet
from msmart.lan import (LAN, AuthenticationError, ProtocolError, _LanProtocol,
_LanProtocolV3, _Packet)


class TestEncodeDecode(unittest.IsolatedAsyncioTestCase):
# pylint: disable=protected-access

async def test_encode_packet_roundtrip(self) -> None:
def test_encode_packet_roundtrip(self) -> None:
"""Test that we can encode and decode a frame."""
FRAME = bytes.fromhex(
"aa21ac8d000000000003418100ff03ff000200000000000000000000000003016971")
Expand All @@ -17,7 +21,7 @@ async def test_encode_packet_roundtrip(self) -> None:
rx_frame = _Packet.decode(packet)
self.assertEqual(rx_frame, FRAME)

async def test_decode_packet(self) -> None:
def test_decode_packet(self) -> None:
"""Test that we can decode a packet to a frame."""
PACKET = bytes.fromhex(
"5a5a01116800208000000000000000000000000060ca0000000e0000000000000000000001000000c6a90377a364cb55af337259514c6f96bf084e8c7a899b50b68920cdea36cecf11c882a88861d1f46cd87912f201218c66151f0c9fbe5941c5384e707c36ff76")
Expand All @@ -28,7 +32,7 @@ async def test_decode_packet(self) -> None:
self.assertIsNotNone(frame)
self.assertEqual(frame, EXPECTED_FRAME)

async def test_decode_v3_packet(self) -> None:
def test_decode_v3_packet(self) -> None:
"""Test that we can decode a V3 packet to payload to a frame."""
PACKET = bytes.fromhex("8370008e2063ec2b8aeb17d4e3aff77094dde7fa65cf22671adf807f490a97b927347943626e9b4f58362cf34b97a0d641f8bf0c8fcbf69ad8cca131d2d7baa70ef048c5e3f3dc78da8af4598ff47aee762a0345c18815d91b50a24dedcacde0663c4ec5e73a963dc8bbbea9a593859996eb79dcfcc6a29b96262fcaa8ea6346366efea214e4a2e48caf83489475246b6fef90192b00")
LOCAL_KEY = bytes.fromhex(
Expand All @@ -52,7 +56,7 @@ async def test_decode_v3_packet(self) -> None:
self.assertIsNotNone(frame)
self.assertEqual(frame, EXPECTED_FRAME)

async def test_encode_packet_v3_roundtrip(self) -> None:
def test_encode_packet_v3_roundtrip(self) -> None:
"""Test that we can encode a frame to V3 packet and back to the same frame."""
FRAME = bytes.fromhex(
"aa23ac00000000000303c00145660000003c0010045c6800000000000000000000018426")
Expand Down Expand Up @@ -86,5 +90,112 @@ async def test_encode_packet_v3_roundtrip(self) -> None:
self.assertEqual(rx_frame, FRAME)


class TestProtocol(unittest.IsolatedAsyncioTestCase):
# pylint: disable=protected-access

async def test_send_exceptions(self) -> None:
"""Test exception handling for send method."""
# Create a dummy LAN object to test
lan = LAN("0.0.0.0", 0, 0)

# Mock the protocol object
lan._protocol = mock.MagicMock(spec=_LanProtocol)

# Mock the read_available method so call to send() will be reached
lan._read_available = mock.MagicMock()
lan._read_available.__aiter__.return_value = None

# Mock the disconnect method to ensure it's called
lan._disconnect = mock.MagicMock()

# Test that both types of timeouts bubble up as TimeoutError
# Test asyncio.TimeoutError
lan._protocol.read.side_effect = asyncio.TimeoutError
lan._disconnect.reset_mock()
with self.assertRaisesRegex(TimeoutError, "No response from host."):
await lan.send(bytes(0))

# Assert disconnect was called
lan._disconnect.assert_called_once()

# Test TimeoutError
lan._protocol.read.side_effect = TimeoutError
lan._disconnect.reset_mock()
with self.assertRaisesRegex(TimeoutError, "No response from host."):
await lan.send(bytes(0))

lan._disconnect.assert_called_once()

# Test cancelled exceptions log a warning and bubble up as TimeoutError
with self.assertLogs("msmart", logging.WARNING) as log:

lan._protocol.read.side_effect = asyncio.CancelledError
lan._disconnect.reset_mock()
with self.assertRaisesRegex(TimeoutError, "Read cancelled."):
await lan.send(bytes(0))

# Assert disconnect was called
lan._disconnect.assert_called_once()

# Assert timeouts were logged
self.assertRegex(" ".join(log.output),
".*Read cancelled. Disconnecting.*")

# Test ProtocolErrors bubbled up with a disconnect
lan._protocol.read.side_effect = ProtocolError
lan._disconnect.reset_mock()
with self.assertRaises(ProtocolError):
await lan.send(bytes(0))

# Assert disconnect was called
lan._disconnect.assert_called_once()

async def test_authenticate_exceptions(self) -> None:
"""Test exception handling for authenticate method."""
# Create a dummy LAN object to test
lan = LAN("0.0.0.0", 0, 0)

# Mock connect method to create a protocol
def _mock_connect() -> None:
lan._protocol = _LanProtocolV3()

# Mock connect/disconnect methods to check that they're called
lan._connect = mock.AsyncMock(side_effect=_mock_connect)
lan._disconnect = mock.MagicMock()

# Assert that exception is thrown is token and key are invalid
with self.assertRaisesRegex(AuthenticationError, "Token and key must be supplied."):
await lan.authenticate(key=None, token=None)

# Assert a disconnect->connect cycle occurred
lan._disconnect.assert_called_once()
lan._connect.assert_awaited_once()

# Assert that the expected protocol class was created
self.assertEqual(lan._protocol_version, 3)
self.assertIsInstance(lan._protocol, _LanProtocolV3)

# Mock connect method to create a protocol that throws
def _mock_connect_write_error() -> None:
lan._protocol = _LanProtocolV3()
lan._protocol.write = mock.MagicMock(side_effect=ProtocolError)

# Assert that a protocol error bubbles up as AuthenticationError
lan._connect.side_effect = _mock_connect_write_error
with self.assertRaises(AuthenticationError):
await lan.authenticate(key=bytes(10), token=bytes(10))

# Mock connect method to create a protocol that timeouts
def _mock_connect_timeout() -> None:
lan._protocol = _LanProtocolV3()
lan._protocol.authenticate = mock.MagicMock(
side_effect=TimeoutError)

# Assert that timeouts bubble up
lan._connect.side_effect = _mock_connect_timeout
with self.assertRaisesRegex(TimeoutError, "No response from host."):
await lan.authenticate(key=bytes(10), token=bytes(10))


if __name__ == "__main__":
unittest.main()

0 comments on commit 5790e9a

Please sign in to comment.