diff --git a/tests/test_ast.py b/tests/test_ast.py index ad80263..388b3ae 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -41,6 +41,18 @@ def test_get_enum_variants(ast): """ ast.build_ast(src) assert ast.get_enum_variants("Foo") == ["Bar", "Baz"] + assert ast.get_enum_variants("Bar") == [] + + +def test_get_struct_fields(ast): + src = """ +struct Foo: + bar: uint256 + baz: address + """ + ast.build_ast(src) + assert ast.get_struct_fields("Foo") == ["bar", "baz"] + assert ast.get_struct_fields("Bar") == [] def test_get_events(ast): @@ -272,6 +284,7 @@ def foo(): ast.build_ast(src) assert ast.get_attributes_for_symbol("Foo") == ["bar", "baz"] assert ast.get_attributes_for_symbol("Bar") == ["Baz"] + assert ast.get_attributes_for_symbol("Baz") == [] def test_find_function_declaration_node_for_name(ast): @@ -290,6 +303,7 @@ def bar(): assert ( ast.find_function_declaration_node_for_name("foo").lineno == 3 ) # line number of def foo(), counting first newline + assert ast.find_function_declaration_node_for_name("baz") is None def test_find_state_variable_declaration_node_for_name(ast): @@ -314,6 +328,7 @@ def foo(): assert ( ast.find_state_variable_declaration_node_for_name("z").lineno == 4 ) # line number of z: bool, counting first newline + assert ast.find_state_variable_declaration_node_for_name("baz") is None def test_find_type_declaration_node_for_name(ast): @@ -336,6 +351,10 @@ def foo(): assert ( ast.find_type_declaration_node_for_name("Bar").lineno == 6 ) # line number of enum Bar, counting first newline + assert ( + ast.find_type_declaration_node_for_name("Baz").lineno == 7 + ) # line number of Baz, counting first newline + assert ast.find_type_declaration_node_for_name("baz") is None def test_find_top_level_node_at_position(ast): @@ -383,3 +402,33 @@ def foo(): assert ( ast.find_node_declaring_symbol("y").lineno == 3 ) # line number of y: address, counting first newline + + +def test_ast_no_data_returns_empty_and_none(ast: AST): + ast.ast_data = None + ast.ast_data_folded = None + ast.ast_data_unfolded = None + + assert ast.get_constants() == [] + assert ast.get_enums() == [] + assert ast.get_enum_variants("Foo") == [] + assert ast.get_events() == [] + assert ast.get_structs() == [] + assert ast.get_user_defined_types() == [] + assert ast.get_state_variables() == [] + assert ast.get_internal_functions() == [] + assert ast.get_struct_fields("Foo") == [] + assert ast.get_internal_function_nodes() == [] + assert ast.find_nodes_referencing_internal_function("foo") == [] + assert ast.find_nodes_referencing_state_variable("x") == [] + assert ast.find_nodes_referencing_constant("x") == [] + assert ast.find_nodes_referencing_enum("Foo") == [] + assert ast.find_nodes_referencing_enum_variant("Foo", "Bar") == [] + assert ast.find_nodes_referencing_struct("Foo") == [] + assert ast.find_nodes_referencing_symbol("x") == [] + assert ast.get_attributes_for_symbol("Foo") == [] + assert ast.find_function_declaration_node_for_name("foo") is None + assert ast.find_state_variable_declaration_node_for_name("x") is None + assert ast.find_type_declaration_node_for_name("Foo") is None + assert ast.find_top_level_node_at_pos(Position(line=0, character=0)) is None + assert ast.find_node_declaring_symbol("x") is None diff --git a/tests/test_navigation.py b/tests/test_navigation.py index 30c69c6..a7ffd9b 100644 --- a/tests/test_navigation.py +++ b/tests/test_navigation.py @@ -60,6 +60,30 @@ def test_find_references_storage_var(doc, navigator): assert len(references) == 3 +def test_find_references_constant(doc, navigator): + pos = Position(line=16, character=0) + references = navigator.find_references(doc, pos) + assert len(references) == 2 + + +def test_find_references_function_local_var(doc, navigator): + pos = Position(line=20, character=5) + references = navigator.find_references(doc, pos) + assert len(references) == 1 + + +def test_find_internal_fn_implementation(doc, navigator: ASTNavigator): + pos = Position(line=35, character=17) + implementation = navigator.find_implementation(doc, pos) + assert implementation and implementation.start.line == 40 + + +def test_find_interface_fn_implementation(doc, navigator: ASTNavigator): + pos = Position(line=51, character=10) + implementation = navigator.find_implementation(doc, pos) + assert implementation and implementation.start.line == 57 + + def test_find_declaration_constant(doc, navigator: ASTNavigator): pos = Position(line=20, character=19) declaration = navigator.find_declaration(doc, pos) @@ -105,3 +129,15 @@ def test_find_declaration_storage_var(doc, navigator: ASTNavigator): pos = Position(line=26, character=9) declaration = navigator.find_declaration(doc, pos) assert declaration and declaration.start.line == 13 + + +def test_find_declaration_function_local_var(doc, navigator: ASTNavigator): + pos = Position(line=26, character=13) + declaration = navigator.find_declaration(doc, pos) + assert declaration and declaration.start.line == 22 + + +def test_find_implementation_variable_returns_none(doc, navigator: ASTNavigator): + pos = Position(line=26, character=13) + implementation = navigator.find_implementation(doc, pos) + assert implementation is None diff --git a/vyper_lsp/ast.py b/vyper_lsp/ast.py index 933182b..6272089 100644 --- a/vyper_lsp/ast.py +++ b/vyper_lsp/ast.py @@ -99,7 +99,7 @@ def get_enum_variants(self, enum: str): if enum_node is None: return [] - return [node.value.id for node in enum_node.get_children()] + return [node.value.id for node in enum_node.get_children(nodes.Expr)] def get_struct_fields(self, struct: str): if self.ast_data_unfolded is None: diff --git a/vyper_lsp/navigation.py b/vyper_lsp/navigation.py index a669f5f..625630c 100644 --- a/vyper_lsp/navigation.py +++ b/vyper_lsp/navigation.py @@ -7,6 +7,8 @@ from vyper_lsp.ast import AST from vyper_lsp.utils import get_expression_at_cursor, get_word_at_cursor +ENUM_VARIANT_PATTERN = re.compile(r"([a-zA-Z_][a-zA-Z0-9_]*)\.([a-zA-Z_][a-zA-Z0-9_]*)") + # this class should abstract away all the AST stuff # and just provide a simple interface for navigation @@ -137,7 +139,6 @@ def find_declaration(self, document: Document, pos: Position) -> Optional[Range] word = get_word_at_cursor(line_content, pos.character) full_word = get_expression_at_cursor(line_content, pos.character) top_level_node = self.ast.find_top_level_node_at_pos(pos) - node = None # Determine the type of declaration and find it if full_word.startswith("self."): @@ -150,8 +151,10 @@ def find_declaration(self, document: Document, pos: Position) -> Optional[Range] elif word in self.ast.get_constants(): return self.find_state_variable_declaration(word) elif isinstance(top_level_node, FunctionDef): - node = self.find_variable_declaration_under_node(top_level_node, word) - if not node: + range = self.find_variable_declaration_under_node(top_level_node, word) + if range: + return range + else: match = ENUM_VARIANT_PATTERN.match(full_word) if ( match @@ -160,9 +163,6 @@ def find_declaration(self, document: Document, pos: Position) -> Optional[Range] ): return self.find_type_declaration(match.group(1)) - if node: - return _create_range(node) - def find_implementation(self, document: Document, pos: Position) -> Optional[Range]: og_line = document.lines[pos.line] word = get_word_at_cursor(og_line, pos.character) @@ -178,13 +178,3 @@ def find_implementation(self, document: Document, pos: Position) -> Optional[Ran return self.find_function_declaration(word) else: return None - - -ENUM_VARIANT_PATTERN = re.compile(r"([a-zA-Z_][a-zA-Z0-9_]*)\.([a-zA-Z_][a-zA-Z0-9_]*)") - - -def _create_range(node) -> Range: - return Range( - start=Position(line=node.lineno - 1, character=node.col_offset), - end=Position(line=node.end_lineno - 1, character=node.end_col_offset), - )