From d0522ad9ef60d026aef78573b7bfd59288477395 Mon Sep 17 00:00:00 2001 From: z80 Date: Fri, 27 Oct 2023 14:52:03 -0400 Subject: [PATCH] fix: fallback through ASTs also count underscore as part of variable names --- vyper_lsp/analyzer/AstAnalyzer.py | 2 + vyper_lsp/ast.py | 128 +++++++++++------------------- vyper_lsp/utils.py | 9 ++- 3 files changed, 56 insertions(+), 83 deletions(-) diff --git a/vyper_lsp/analyzer/AstAnalyzer.py b/vyper_lsp/analyzer/AstAnalyzer.py index 24c456c..819f949 100644 --- a/vyper_lsp/analyzer/AstAnalyzer.py +++ b/vyper_lsp/analyzer/AstAnalyzer.py @@ -92,6 +92,8 @@ def get_completions( return CompletionList(is_incomplete=False, items=[]) def hover_info(self, document: Document, pos: Position) -> Optional[str]: + if len(document.lines) < pos.line: + return None og_line = document.lines[pos.line] word = get_word_at_cursor(og_line, pos.character) full_word = get_expression_at_cursor(og_line, pos.character) diff --git a/vyper_lsp/ast.py b/vyper_lsp/ast.py index 6272089..5ff0c04 100644 --- a/vyper_lsp/ast.py +++ b/vyper_lsp/ast.py @@ -45,43 +45,51 @@ def build_ast(self, src: str): print(f"Error generating folded AST, {e}") pass - def get_enums(self) -> List[str]: - if self.ast_data_unfolded is None: + def get_descendants_from_best_ast(self, *args, **kwargs): + if self.ast_data_unfolded: + return self.ast_data_unfolded.get_descendants(*args, **kwargs) + elif self.ast_data: + return self.ast_data.get_descendants(*args, **kwargs) + elif self.ast_data_folded: + return self.ast_data_folded.get_descendants(*args, **kwargs) + else: return [] - return [ - node.name for node in self.ast_data_unfolded.get_descendants(nodes.EnumDef) - ] + def get_children_from_best_ast(self, *args, **kwargs): + if self.ast_data_unfolded: + return self.ast_data_unfolded.get_children(*args, **kwargs) + elif self.ast_data: + return self.ast_data.get_children(*args, **kwargs) + elif self.ast_data_folded: + return self.ast_data_folded.get_children(*args, **kwargs) + else: + return [] + + def get_enums(self) -> List[str]: + return [node.name for node in self.get_descendants_from_best_ast(nodes.EnumDef)] def get_structs(self) -> List[str]: if self.ast_data_unfolded is None: return [] return [ - node.name - for node in self.ast_data_unfolded.get_descendants(nodes.StructDef) + node.name for node in self.get_descendants_from_best_ast(nodes.StructDef) ] def get_events(self) -> List[str]: - if self.ast_data_unfolded is None: - return [] - return [ - node.name for node in self.ast_data_unfolded.get_descendants(nodes.EventDef) + node.name for node in self.get_descendants_from_best_ast(nodes.EventDef) ] def get_user_defined_types(self): - if self.ast_data_unfolded is None: - return [] - return [ node.name - for node in self.ast_data_unfolded.get_descendants( - self.custom_type_node_types - ) + for node in self.get_descendants_from_best_ast(self.custom_type_node_types) ] def get_constants(self): + # NOTE: Constants should be fetched from self.ast_data, they are missing + # from self.ast_data_unfolded and self.ast_data_folded if self.ast_data is None: return [] @@ -92,9 +100,6 @@ def get_constants(self): ] def get_enum_variants(self, enum: str): - if self.ast_data_unfolded is None: - return [] - enum_node = self.find_type_declaration_node_for_name(enum) if enum_node is None: return [] @@ -102,9 +107,6 @@ def get_enum_variants(self, enum: str): return [node.value.id for node in enum_node.get_children(nodes.Expr)] def get_struct_fields(self, struct: str): - if self.ast_data_unfolded is None: - return [] - struct_node = self.find_type_declaration_node_for_name(struct) if struct_node is None: return [] @@ -112,18 +114,16 @@ def get_struct_fields(self, struct: str): return [node.target.id for node in struct_node.get_children(nodes.AnnAssign)] def get_state_variables(self): + # NOTE: The state variables should be fetched from self.ast_data, they are + # missing from self.ast_data_unfolded and self.ast_data_folded when constants if self.ast_data is None: return [] - return [ node.target.id for node in self.ast_data.get_descendants(nodes.VariableDecl) ] def get_internal_function_nodes(self): - if self.ast_data_unfolded is None: - return [] - - function_nodes = self.ast_data_unfolded.get_descendants(nodes.FunctionDef) + function_nodes = self.get_descendants_from_best_ast(nodes.FunctionDef) inernal_nodes = [] for node in function_nodes: @@ -134,28 +134,21 @@ def get_internal_function_nodes(self): return inernal_nodes def get_internal_functions(self): - if self.ast_data_unfolded is None: - return [] - return [node.name for node in self.get_internal_function_nodes()] def find_nodes_referencing_internal_function(self, function: str): - if self.ast_data_unfolded is None: - return [] - - return self.ast_data_unfolded.get_descendants( + return self.get_descendants_from_best_ast( nodes.Call, {"func.attr": function, "func.value.id": "self"} ) def find_nodes_referencing_state_variable(self, variable: str): - if self.ast_data_unfolded is None: - return [] - - return self.ast_data_unfolded.get_descendants( + return self.get_descendants_from_best_ast( nodes.Attribute, {"value.id": "self", "attr": variable} ) def find_nodes_referencing_constant(self, constant: str): + # NOTE: Constants should be fetched from self.ast_data, they are missing + # from self.ast_data_unfolded and self.ast_data_folded if self.ast_data_unfolded is None: return [] @@ -169,9 +162,6 @@ def find_nodes_referencing_constant(self, constant: str): ] def get_attributes_for_symbol(self, symbol: str): - if self.ast_data_unfolded is None: - return [] - node = self.find_type_declaration_node_for_name(symbol) if node is None: return [] @@ -184,10 +174,7 @@ def get_attributes_for_symbol(self, symbol: str): return [] def find_function_declaration_node_for_name(self, function: str): - if self.ast_data_unfolded is None: - return None - - for node in self.ast_data_unfolded.get_descendants(nodes.FunctionDef): + for node in self.get_descendants_from_best_ast(nodes.FunctionDef): name_match = node.name == function not_interface_declaration = not isinstance( node.get_ancestor(), nodes.InterfaceDef @@ -198,6 +185,8 @@ def find_function_declaration_node_for_name(self, function: str): return None def find_state_variable_declaration_node_for_name(self, variable: str): + # NOTE: The state variables should be fetched from self.ast_data, they are + # missing from self.ast_data_unfolded and self.ast_data_folded when constants if self.ast_data is None: return None @@ -208,10 +197,7 @@ def find_state_variable_declaration_node_for_name(self, variable: str): return None def find_type_declaration_node_for_name(self, symbol: str): - if self.ast_data_unfolded is None: - return None - - for node in self.ast_data_unfolded.get_descendants(self.custom_type_node_types): + for node in self.get_descendants_from_best_ast(self.custom_type_node_types): if node.name == symbol: return node if isinstance(node, nodes.EnumDef): @@ -222,20 +208,17 @@ def find_type_declaration_node_for_name(self, symbol: str): return None def find_nodes_referencing_enum(self, enum: str): - if self.ast_data_unfolded is None: - return [] - return_nodes = [] - for node in self.ast_data_unfolded.get_descendants( + for node in self.get_descendants_from_best_ast( nodes.AnnAssign, {"annotation.id": enum} ): return_nodes.append(node) - for node in self.ast_data_unfolded.get_descendants( + for node in self.get_descendants_from_best_ast( nodes.Attribute, {"value.id": enum} ): return_nodes.append(node) - for node in self.ast_data_unfolded.get_descendants( + for node in self.get_descendants_from_best_ast( nodes.VariableDecl, {"annotation.id": enum} ): return_nodes.append(node) @@ -243,32 +226,24 @@ def find_nodes_referencing_enum(self, enum: str): return return_nodes def find_nodes_referencing_enum_variant(self, enum: str, variant: str): - if self.ast_data_unfolded is None: - return [] - - return self.ast_data_unfolded.get_descendants( + return self.get_descendants_from_best_ast( nodes.Attribute, {"attr": variant, "value.id": enum} ) def find_nodes_referencing_struct(self, struct: str): - if self.ast_data_unfolded is None: - return [] - return_nodes = [] - for node in self.ast_data_unfolded.get_descendants( + for node in self.get_descendants_from_best_ast( nodes.AnnAssign, {"annotation.id": struct} ): return_nodes.append(node) - for node in self.ast_data_unfolded.get_descendants( - nodes.Call, {"func.id": struct} - ): + for node in self.get_descendants_from_best_ast(nodes.Call, {"func.id": struct}): return_nodes.append(node) - for node in self.ast_data_unfolded.get_descendants( + for node in self.get_descendants_from_best_ast( nodes.VariableDecl, {"annotation.id": struct} ): return_nodes.append(node) - for node in self.ast_data_unfolded.get_descendants( + for node in self.get_descendants_from_best_ast( nodes.FunctionDef, {"returns.id": struct} ): return_nodes.append(node) @@ -276,20 +251,14 @@ def find_nodes_referencing_struct(self, struct: str): return return_nodes def find_top_level_node_at_pos(self, pos: Position) -> Optional[VyperNode]: - if self.ast_data_unfolded is None: - return None - - for node in self.ast_data_unfolded.get_children(): + for node in self.get_children_from_best_ast(): if node.lineno <= pos.line and node.end_lineno >= pos.line: return node def find_nodes_referencing_symbol(self, symbol: str): - if self.ast_data is None: - return [] - return_nodes = [] - for node in self.ast_data.get_descendants(nodes.Name, {"id": symbol}): + for node in self.get_descendants_from_best_ast(nodes.Name, {"id": symbol}): parent = node.get_ancestor() if isinstance(parent, nodes.Dict): if symbol not in [key.id for key in parent.keys]: @@ -305,10 +274,7 @@ def find_nodes_referencing_symbol(self, symbol: str): return return_nodes def find_node_declaring_symbol(self, symbol: str): - if self.ast_data_unfolded is None: - return None - - for node in self.ast_data_unfolded.get_descendants( + for node in self.get_descendants_from_best_ast( (nodes.AnnAssign, nodes.VariableDecl) ): if node.target.id == symbol: diff --git a/vyper_lsp/utils.py b/vyper_lsp/utils.py index 8a01ec5..d6aa12e 100644 --- a/vyper_lsp/utils.py +++ b/vyper_lsp/utils.py @@ -38,16 +38,21 @@ def is_attribute_access(line): return bool(re.match(reg, line.strip())) +def is_word_char(char): + # true for alnum and underscore + return char.isalnum() or char == "_" + + def get_word_at_cursor(sentence: str, cursor_index: int) -> str: start = cursor_index end = cursor_index # Find the start of the word - while start > 0 and sentence[start - 1].isalnum(): + while start > 0 and is_word_char(sentence[start - 1]): start -= 1 # Find the end of the word - while end < len(sentence) and sentence[end].isalnum(): + while end < len(sentence) and is_word_char(sentence[end]): end += 1 # Extract the word