diff --git a/tests/test_symbol.py b/tests/test_symbol.py index a8a55d21a..038317143 100644 --- a/tests/test_symbol.py +++ b/tests/test_symbol.py @@ -263,3 +263,72 @@ def test_all_symbols(self): ] prog = elf_symbol_program(*elf_syms) self.assert_symbols_equal_unordered(prog.symbols(), syms) + + +class TestSymbolFinder(TestCase): + TEST_SYMS = [ + Symbol("one", 0xFFFF1000, 16, SymbolBinding.LOCAL, SymbolKind.FUNC), + Symbol("two", 0xFFFF2000, 16, SymbolBinding.GLOBAL, SymbolKind.FUNC), + ] + + def finder(self, arg_name, arg_address, arg_one): + self.called = True + res = [] + self.assertEqual(self.expected_name, arg_name) + self.assertEqual(self.expected_address, arg_address) + self.assertEqual(self.expected_one, arg_one) + for sym in self.TEST_SYMS: + if arg_name and sym.name == arg_name: + res.append(sym) + elif arg_address and sym.address <= arg_address < sym.address + sym.size: + res.append(sym) + elif not arg_name and not arg_address: + res.append(sym) + return res + + def setUp(self): + self.prog = Program() + self.prog.add_symbol_finder(self.finder) + self.called = False + + def expect_args(self, name, address, one): + self.expected_name = name + self.expected_address = address + self.expected_one = one + + def test_args_single_string(self): + self.expect_args("search_symbol", None, True) + with self.assertRaises(LookupError): + self.prog.symbol("search_symbol") + self.assertTrue(self.called) + + def test_args_single_int(self): + self.expect_args(None, 0xFF00, True) + with self.assertRaises(LookupError): + self.prog.symbol(0xFF00) + self.assertTrue(self.called) + + def test_single_with_result(self): + self.expect_args("one", None, True) + self.assertEqual(self.prog.symbol("one"), self.TEST_SYMS[0]) + self.assertTrue(self.called) + + def test_args_many_string(self): + self.expect_args("search_symbol", None, False) + self.assertEqual(self.prog.symbols("search_symbol"), []) + self.assertTrue(self.called) + + def test_args_many_int(self): + self.expect_args(None, 0xFF00, False) + self.assertEqual(self.prog.symbols(0xFF00), []) + self.assertTrue(self.called) + + def test_many_with_result(self): + self.expect_args(None, 0xFFFF2008, False) + self.assertEqual(self.prog.symbols(0xFFFF2008), [self.TEST_SYMS[1]]) + self.assertTrue(self.called) + + def test_many_without_filter(self): + self.expect_args(None, None, False) + self.assertEqual(self.prog.symbols(), self.TEST_SYMS) + self.assertTrue(self.called)