diff --git a/test/test_nmt.py b/test/test_nmt.py index e4bb607b..5acda27d 100644 --- a/test/test_nmt.py +++ b/test/test_nmt.py @@ -1,6 +1,7 @@ import time import unittest +import can import canopen from canopen.nmt import NMT_STATES, NMT_COMMANDS from .util import SAMPLE_EDS @@ -36,23 +37,101 @@ def test_state_set_invalid(self): with self.assertRaisesRegex(ValueError, "INVALID"): self.nmt.state = "INVALID" - def test_state_get_invalid(self): - # This is a known bug; it will be changed in gh-500. - self.nmt._state = 255 - self.assertEqual(self.nmt.state, 255) + +class TestNmtMaster(unittest.TestCase): + NODE_ID = 2 + COB_ID = 0x700 + NODE_ID + PERIOD = 0.01 + TIMEOUT = PERIOD * 2 + + def setUp(self): + bus = can.ThreadSafeBus( + interface="virtual", + channel="test", + receive_own_messages=True, + ) + net = canopen.Network(bus) + net.connect() + with self.assertLogs(): + node = net.add_node(self.NODE_ID, SAMPLE_EDS) + + self.bus = bus + self.net = net + self.node = node + + def tearDown(self): + self.net.disconnect() + + def test_nmt_master_no_heartbeat(self): + with self.assertRaisesRegex(canopen.nmt.NmtError, "heartbeat"): + self.node.nmt.wait_for_heartbeat(self.TIMEOUT) + with self.assertRaisesRegex(canopen.nmt.NmtError, "boot-up"): + self.node.nmt.wait_for_bootup(self.TIMEOUT) + + def test_nmt_master_on_heartbeat(self): + # Skip the special INITIALISING case. + for code in [st for st in NMT_STATES if st != 0]: + with self.subTest(code=code): + data = bytes([code]) + task = self.net.send_periodic(self.COB_ID, data, self.PERIOD) + self.addCleanup(task.stop) + actual = self.node.nmt.wait_for_heartbeat(self.TIMEOUT) + task.stop() + expected = NMT_STATES[code] + self.assertEqual(actual, expected) + + def test_nmt_master_on_heartbeat_initialising(self): + task = self.net.send_periodic(self.COB_ID, b"\x00", self.PERIOD) + self.addCleanup(task.stop) + self.node.nmt.wait_for_bootup(self.TIMEOUT) + state = self.node.nmt.wait_for_heartbeat(self.TIMEOUT) + self.assertEqual(state, "PRE-OPERATIONAL") + + def test_nmt_master_on_heartbeat_unknown_state(self): + task = self.net.send_periodic(self.COB_ID, b"\xcb", self.PERIOD) + self.addCleanup(task.stop) + state = self.node.nmt.wait_for_heartbeat(self.TIMEOUT) + # Expect the high bit to be masked out, and the resulting integer + # returned as it is. See gh-500 for the data type inconsistency. + self.assertEqual(state, 0x4b) + + def test_nmt_master_add_heartbeat_callback(self): + from threading import Event + event = Event() + state = None + def hook(st): + nonlocal state + state = st + event.set() + self.node.nmt.add_heartbeat_callback(hook) + self.net.send_message(self.COB_ID, bytes([127])) + self.assertTrue(event.wait(self.TIMEOUT)) + self.assertEqual(state, 127) + + def test_nmt_master_node_guarding(self): + self.node.nmt.start_node_guarding(self.PERIOD) + msg = self.bus.recv(self.TIMEOUT) + self.assertIsNotNone(msg) + self.assertEqual(msg.arbitration_id, self.COB_ID) + self.assertEqual(msg.dlc, 0) + + self.node.nmt.stop_node_guarding() + self.assertIsNone(self.bus.recv(self.TIMEOUT)) class TestNmtSlave(unittest.TestCase): def setUp(self): self.network1 = canopen.Network() self.network1.connect("test", interface="virtual") - self.remote_node = self.network1.add_node(2, SAMPLE_EDS) + with self.assertLogs(): + self.remote_node = self.network1.add_node(2, SAMPLE_EDS) self.network2 = canopen.Network() self.network2.connect("test", interface="virtual") - self.local_node = self.network2.create_node(2, SAMPLE_EDS) - self.remote_node2 = self.network1.add_node(3, SAMPLE_EDS) - self.local_node2 = self.network2.create_node(3, SAMPLE_EDS) + with self.assertLogs(): + self.local_node = self.network2.create_node(2, SAMPLE_EDS) + self.remote_node2 = self.network1.add_node(3, SAMPLE_EDS) + self.local_node2 = self.network2.create_node(3, SAMPLE_EDS) def tearDown(self): self.network1.disconnect()