Skip to content

Commit

Permalink
fix: fallback through ASTs
Browse files Browse the repository at this point in the history
also count underscore as part of variable names
  • Loading branch information
z80dev committed Oct 27, 2023
1 parent f19aaec commit d0522ad
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 83 deletions.
2 changes: 2 additions & 0 deletions vyper_lsp/analyzer/AstAnalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def get_completions(
return CompletionList(is_incomplete=False, items=[])

def hover_info(self, document: Document, pos: Position) -> Optional[str]:
if len(document.lines) < pos.line:
return None
og_line = document.lines[pos.line]
word = get_word_at_cursor(og_line, pos.character)
full_word = get_expression_at_cursor(og_line, pos.character)
Expand Down
128 changes: 47 additions & 81 deletions vyper_lsp/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,43 +45,51 @@ def build_ast(self, src: str):
print(f"Error generating folded AST, {e}")
pass

def get_enums(self) -> List[str]:
if self.ast_data_unfolded is None:
def get_descendants_from_best_ast(self, *args, **kwargs):
if self.ast_data_unfolded:
return self.ast_data_unfolded.get_descendants(*args, **kwargs)
elif self.ast_data:
return self.ast_data.get_descendants(*args, **kwargs)
elif self.ast_data_folded:
return self.ast_data_folded.get_descendants(*args, **kwargs)
else:
return []

return [
node.name for node in self.ast_data_unfolded.get_descendants(nodes.EnumDef)
]
def get_children_from_best_ast(self, *args, **kwargs):
if self.ast_data_unfolded:
return self.ast_data_unfolded.get_children(*args, **kwargs)
elif self.ast_data:
return self.ast_data.get_children(*args, **kwargs)
elif self.ast_data_folded:
return self.ast_data_folded.get_children(*args, **kwargs)
else:
return []

def get_enums(self) -> List[str]:
return [node.name for node in self.get_descendants_from_best_ast(nodes.EnumDef)]

def get_structs(self) -> List[str]:
if self.ast_data_unfolded is None:
return []

return [
node.name
for node in self.ast_data_unfolded.get_descendants(nodes.StructDef)
node.name for node in self.get_descendants_from_best_ast(nodes.StructDef)
]

def get_events(self) -> List[str]:
if self.ast_data_unfolded is None:
return []

return [
node.name for node in self.ast_data_unfolded.get_descendants(nodes.EventDef)
node.name for node in self.get_descendants_from_best_ast(nodes.EventDef)
]

def get_user_defined_types(self):
if self.ast_data_unfolded is None:
return []

return [
node.name
for node in self.ast_data_unfolded.get_descendants(
self.custom_type_node_types
)
for node in self.get_descendants_from_best_ast(self.custom_type_node_types)
]

def get_constants(self):
# NOTE: Constants should be fetched from self.ast_data, they are missing
# from self.ast_data_unfolded and self.ast_data_folded
if self.ast_data is None:
return []

Expand All @@ -92,38 +100,30 @@ def get_constants(self):
]

def get_enum_variants(self, enum: str):
if self.ast_data_unfolded is None:
return []

enum_node = self.find_type_declaration_node_for_name(enum)
if enum_node is None:
return []

return [node.value.id for node in enum_node.get_children(nodes.Expr)]

def get_struct_fields(self, struct: str):
if self.ast_data_unfolded is None:
return []

struct_node = self.find_type_declaration_node_for_name(struct)
if struct_node is None:
return []

return [node.target.id for node in struct_node.get_children(nodes.AnnAssign)]

def get_state_variables(self):
# NOTE: The state variables should be fetched from self.ast_data, they are
# missing from self.ast_data_unfolded and self.ast_data_folded when constants
if self.ast_data is None:
return []

return [
node.target.id for node in self.ast_data.get_descendants(nodes.VariableDecl)
]

def get_internal_function_nodes(self):
if self.ast_data_unfolded is None:
return []

function_nodes = self.ast_data_unfolded.get_descendants(nodes.FunctionDef)
function_nodes = self.get_descendants_from_best_ast(nodes.FunctionDef)
inernal_nodes = []

for node in function_nodes:
Expand All @@ -134,28 +134,21 @@ def get_internal_function_nodes(self):
return inernal_nodes

def get_internal_functions(self):
if self.ast_data_unfolded is None:
return []

return [node.name for node in self.get_internal_function_nodes()]

def find_nodes_referencing_internal_function(self, function: str):
if self.ast_data_unfolded is None:
return []

return self.ast_data_unfolded.get_descendants(
return self.get_descendants_from_best_ast(
nodes.Call, {"func.attr": function, "func.value.id": "self"}
)

def find_nodes_referencing_state_variable(self, variable: str):
if self.ast_data_unfolded is None:
return []

return self.ast_data_unfolded.get_descendants(
return self.get_descendants_from_best_ast(
nodes.Attribute, {"value.id": "self", "attr": variable}
)

def find_nodes_referencing_constant(self, constant: str):
# NOTE: Constants should be fetched from self.ast_data, they are missing
# from self.ast_data_unfolded and self.ast_data_folded
if self.ast_data_unfolded is None:
return []

Expand All @@ -169,9 +162,6 @@ def find_nodes_referencing_constant(self, constant: str):
]

def get_attributes_for_symbol(self, symbol: str):
if self.ast_data_unfolded is None:
return []

node = self.find_type_declaration_node_for_name(symbol)
if node is None:
return []
Expand All @@ -184,10 +174,7 @@ def get_attributes_for_symbol(self, symbol: str):
return []

def find_function_declaration_node_for_name(self, function: str):
if self.ast_data_unfolded is None:
return None

for node in self.ast_data_unfolded.get_descendants(nodes.FunctionDef):
for node in self.get_descendants_from_best_ast(nodes.FunctionDef):
name_match = node.name == function
not_interface_declaration = not isinstance(
node.get_ancestor(), nodes.InterfaceDef
Expand All @@ -198,6 +185,8 @@ def find_function_declaration_node_for_name(self, function: str):
return None

def find_state_variable_declaration_node_for_name(self, variable: str):
# NOTE: The state variables should be fetched from self.ast_data, they are
# missing from self.ast_data_unfolded and self.ast_data_folded when constants
if self.ast_data is None:
return None

Expand All @@ -208,10 +197,7 @@ def find_state_variable_declaration_node_for_name(self, variable: str):
return None

def find_type_declaration_node_for_name(self, symbol: str):
if self.ast_data_unfolded is None:
return None

for node in self.ast_data_unfolded.get_descendants(self.custom_type_node_types):
for node in self.get_descendants_from_best_ast(self.custom_type_node_types):
if node.name == symbol:
return node
if isinstance(node, nodes.EnumDef):
Expand All @@ -222,74 +208,57 @@ def find_type_declaration_node_for_name(self, symbol: str):
return None

def find_nodes_referencing_enum(self, enum: str):
if self.ast_data_unfolded is None:
return []

return_nodes = []

for node in self.ast_data_unfolded.get_descendants(
for node in self.get_descendants_from_best_ast(
nodes.AnnAssign, {"annotation.id": enum}
):
return_nodes.append(node)
for node in self.ast_data_unfolded.get_descendants(
for node in self.get_descendants_from_best_ast(
nodes.Attribute, {"value.id": enum}
):
return_nodes.append(node)
for node in self.ast_data_unfolded.get_descendants(
for node in self.get_descendants_from_best_ast(
nodes.VariableDecl, {"annotation.id": enum}
):
return_nodes.append(node)

return return_nodes

def find_nodes_referencing_enum_variant(self, enum: str, variant: str):
if self.ast_data_unfolded is None:
return []

return self.ast_data_unfolded.get_descendants(
return self.get_descendants_from_best_ast(
nodes.Attribute, {"attr": variant, "value.id": enum}
)

def find_nodes_referencing_struct(self, struct: str):
if self.ast_data_unfolded is None:
return []

return_nodes = []

for node in self.ast_data_unfolded.get_descendants(
for node in self.get_descendants_from_best_ast(
nodes.AnnAssign, {"annotation.id": struct}
):
return_nodes.append(node)
for node in self.ast_data_unfolded.get_descendants(
nodes.Call, {"func.id": struct}
):
for node in self.get_descendants_from_best_ast(nodes.Call, {"func.id": struct}):
return_nodes.append(node)
for node in self.ast_data_unfolded.get_descendants(
for node in self.get_descendants_from_best_ast(
nodes.VariableDecl, {"annotation.id": struct}
):
return_nodes.append(node)
for node in self.ast_data_unfolded.get_descendants(
for node in self.get_descendants_from_best_ast(
nodes.FunctionDef, {"returns.id": struct}
):
return_nodes.append(node)

return return_nodes

def find_top_level_node_at_pos(self, pos: Position) -> Optional[VyperNode]:
if self.ast_data_unfolded is None:
return None

for node in self.ast_data_unfolded.get_children():
for node in self.get_children_from_best_ast():
if node.lineno <= pos.line and node.end_lineno >= pos.line:
return node

def find_nodes_referencing_symbol(self, symbol: str):
if self.ast_data is None:
return []

return_nodes = []

for node in self.ast_data.get_descendants(nodes.Name, {"id": symbol}):
for node in self.get_descendants_from_best_ast(nodes.Name, {"id": symbol}):
parent = node.get_ancestor()
if isinstance(parent, nodes.Dict):
if symbol not in [key.id for key in parent.keys]:
Expand All @@ -305,10 +274,7 @@ def find_nodes_referencing_symbol(self, symbol: str):
return return_nodes

def find_node_declaring_symbol(self, symbol: str):
if self.ast_data_unfolded is None:
return None

for node in self.ast_data_unfolded.get_descendants(
for node in self.get_descendants_from_best_ast(
(nodes.AnnAssign, nodes.VariableDecl)
):
if node.target.id == symbol:
Expand Down
9 changes: 7 additions & 2 deletions vyper_lsp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,21 @@ def is_attribute_access(line):
return bool(re.match(reg, line.strip()))


def is_word_char(char):
# true for alnum and underscore
return char.isalnum() or char == "_"


def get_word_at_cursor(sentence: str, cursor_index: int) -> str:
start = cursor_index
end = cursor_index

# Find the start of the word
while start > 0 and sentence[start - 1].isalnum():
while start > 0 and is_word_char(sentence[start - 1]):
start -= 1

# Find the end of the word
while end < len(sentence) and sentence[end].isalnum():
while end < len(sentence) and is_word_char(sentence[end]):
end += 1

# Extract the word
Expand Down

0 comments on commit d0522ad

Please sign in to comment.