Skip to content

Commit

Permalink
fix tests, imports wip
Browse files Browse the repository at this point in the history
  • Loading branch information
z80dev committed Dec 13, 2024
1 parent 057f546 commit f4dcfdd
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 85 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ loguru = "^0.6.0"
tree-sitter = "^0.20.1"
pydantic = "^1.10"
lark = "^1.1.7"
lsprotocol = "^2023.0.0b1"
lsprotocol = "^2023.0.1"
vyper = "^0.4.0"
vvm = "^0.2.0"
packaging = "^23.1"
pygls = "^1.1.2"
pygls = "^1.3.1"

[tool.poetry.group.dev.dependencies]
flake8 = "^5.0.4"
Expand Down
1 change: 0 additions & 1 deletion tests/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,6 @@ def foo():

def test_ast_no_data_returns_empty_and_none(ast: AST):
ast.ast_data = None
ast.ast_data_folded = None
ast.ast_data_annotated = None

assert ast.get_constants() == []
Expand Down
110 changes: 85 additions & 25 deletions vyper_lsp/analyzer/AstAnalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import List, Optional
from packaging.version import Version
from lsprotocol.types import (
CompletionItemLabelDetails,
Diagnostic,
DiagnosticSeverity,
ParameterInformation,
Expand All @@ -16,12 +17,11 @@
from vyper_lsp.analyzer.BaseAnalyzer import Analyzer
from vyper_lsp.ast import AST
from vyper_lsp.utils import (
format_fn,
get_expression_at_cursor,
get_word_at_cursor,
get_installed_vyper_version,
get_internal_fn_name_at_cursor,
is_internal_fn,
is_state_var,
)
from lsprotocol.types import (
CompletionItem,
Expand All @@ -34,7 +34,7 @@
pattern_text = r"(.+) is deprecated\. Please use `(.+)` instead\."
deprecation_pattern = re.compile(pattern_text)

min_vyper_version = Version("0.3.7")
min_vyper_version = Version("0.4.0")

# Available base types
UNSIGNED_INTEGER_TYPES = {f"uint{8*(i)}" for i in range(32, 0, -1)}
Expand Down Expand Up @@ -62,19 +62,36 @@ def __init__(self, ast: AST) -> None:

def signature_help(
self, doc: Document, params: SignatureHelpParams
) -> SignatureHelp:
) -> Optional[SignatureHelp]:
logger.info("signature help triggered")
current_line = doc.lines[params.position.line]
expression = get_expression_at_cursor(
current_line, params.position.character - 1
)
logger.info(f"expression: {expression}")
# regex for matching 'module.function'
fncall_pattern = "(.*)\\.(.*)"

if matches := re.match(fncall_pattern, expression):
module, fn = matches.groups()
logger.info(f"looking up function {fn} in module {module}")
if module in self.ast.imports:
logger.info(f"found module")
if fn := self.ast.imports[module].functions[fn]:
logger.info(f"args: {fn.arguments}")


# this returns for all external functions
# TODO: Implement checking interfaces
if not expression.startswith("self."):
return None

# TODO: Implement checking external functions, module functions, and interfaces
fn_name = get_internal_fn_name_at_cursor(
current_line, params.position.character - 1
)

# this returns for all external functions
# TODO: Implement checking interfaces
if not expression.startswith("self."):
if not fn_name:
return None

node = self.ast.find_function_declaration_node_for_name(fn_name)
Expand Down Expand Up @@ -114,23 +131,54 @@ def signature_help(
active_signature=0,
)

def dot_completions_for_element(self, element: str, top_level_node = None) -> List[CompletionItem]:
def _dot_completions_for_element(self, element: str, top_level_node = None, line: str="") -> List[CompletionItem]:
completions = []
logger.info(f"getting dot completions for element: {element}")
#logger.info(f"import keys: {self.ast.imports.keys()}")
self.ast.imports.keys()
if element == "self":
for fn in self.ast.get_internal_functions():
completions.append(CompletionItem(label=fn))
# TODO: This should exclude constants and immutables
for var in self.ast.get_state_variables():
completions.append(CompletionItem(label=var))
elif element in self.ast.imported_fns_for_alias:
if isinstance(top_level_node, nodes.FunctionDef):
for name, fn in self.ast.imported_fns_for_alias[element].items():
if fn.is_internal or fn.is_deploy:
completions.append(CompletionItem(label=name))
elif isinstance(top_level_node, nodes.ExportsDecl):
for name, fn in self.ast.imported_fns_for_alias[element].items():
if fn.is_external:
completions.append(CompletionItem(label=name))
elif self.ast.imports and element in self.ast.imports.keys():
for name, fn in self.ast.imports[element].functions.items():
doc_string = ""
if getattr(fn.ast_def, "doc_string", False):
doc_string = fn.ast_def.doc_string.value

#out = self._format_fn_signature(fn.decl_node)
out = format_fn(fn)

# NOTE: this just gets ignored by most editors
# so we put the signature in the documentation string also
completion_item_label_details = CompletionItemLabelDetails(detail=out)

doc_string = f"{out}\n{doc_string}"

show_external: bool = isinstance(top_level_node, nodes.ExportsDecl) or line.startswith("exports:")
show_internal_and_deploy: bool = isinstance(top_level_node, nodes.FunctionDef)

if show_internal_and_deploy and (fn.is_internal or fn.is_deploy):
completions.append(CompletionItem(label=name, documentation=doc_string, label_details=completion_item_label_details))
elif show_external and fn.is_external:
completions.append(CompletionItem(label=name, documentation=doc_string, label_details=completion_item_label_details))
elif element in self.ast.flags:
members = self.ast.flags[element]._flag_members
for member in members.keys():
completions.append(CompletionItem(label=member))

if isinstance(top_level_node, nodes.FunctionDef):
var_declarations = top_level_node.get_descendants(nodes.AnnAssign, filters={"target.id": element})
assert len(var_declarations) <= 1
for vardecl in var_declarations:
type_name = vardecl.annotation.id
structt = self.ast.structs.get(type_name, None)
if structt:
for member in structt.members:
completions.append(CompletionItem(label=member))

return completions

def get_completions_in_doc(
Expand All @@ -147,20 +195,18 @@ def get_completions_in_doc(

if params.context.trigger_character == ".":
# get element before the dot
# TODO: this could lead to bugs if we're not at EOL
element = current_line.split(" ")[-1].split(".")[0]
logger.info(f"Element: {element}")

pos = params.position
surrounding_node = self.ast.find_top_level_node_at_pos(pos)
logger.info(f"Surrounding node: {surrounding_node}")

# internal functions and state variables
dot_completions = self.dot_completions_for_element(element, top_level_node=surrounding_node)
# internal + imported fns, state vars, and flags
dot_completions = self._dot_completions_for_element(element, top_level_node=surrounding_node, line=current_line)
if len(dot_completions) > 0:
return CompletionList(is_incomplete=False, items=dot_completions)
else:
# TODO: This is currently only correct for enums
# For structs, we'll need to get the type of the variable
logger.info(f"no dot completions for {element}")
for attr in self.ast.get_attributes_for_symbol(element):
items.append(CompletionItem(label=attr))
completions = CompletionList(is_incomplete=False, items=items)
Expand Down Expand Up @@ -230,6 +276,19 @@ def _format_fn_signature(self, node: nodes.FunctionDef) -> str:
function_def = match.group()
return f"(Internal Function) {function_def}"

def is_internal_fn(self, expression: str):
if not expression.startswith("self."):
return False
fn_name = expression.split("self.")[-1]
return fn_name in self.ast.functions and self.ast.functions[fn_name].is_internal

def is_state_var(self, expression: str):
if not expression.startswith("self."):
return False
var_name = expression.split("self.")[-1]
return var_name in self.ast.variables


def hover_info(self, document: Document, pos: Position) -> Optional[str]:
if len(document.lines) < pos.line:
return None
Expand All @@ -238,11 +297,12 @@ def hover_info(self, document: Document, pos: Position) -> Optional[str]:
word = get_word_at_cursor(og_line, pos.character)
full_word = get_expression_at_cursor(og_line, pos.character)

if is_internal_fn(full_word):
if self.is_internal_fn(full_word):
logger.info("looking for internal fn")
node = self.ast.find_function_declaration_node_for_name(word)
return node and self._format_fn_signature(node)

if is_state_var(full_word):
if self.is_state_var(full_word):
node = self.ast.find_state_variable_declaration_node_for_name(word)
if not node:
return None
Expand Down
84 changes: 49 additions & 35 deletions vyper_lsp/ast.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import copy
from functools import cached_property
import logging
from pathlib import Path
from typing import Optional, List
from lsprotocol.types import Diagnostic, DiagnosticSeverity, Position
from pygls.workspace import Document
from vyper.ast import Module, VyperNode, nodes
from vyper.compiler import CompilerData, FileInput
from vyper.compiler import CompilerData
from vyper.compiler.input_bundle import FilesystemInputBundle
from vyper.compiler.phases import ModuleT
from vyper.compiler.phases import DEFAULT_CONTRACT_PATH, ModuleT
from vyper.semantics.types import StructT
from vyper.semantics.types.user import FlagT
from vyper.utils import VyperException
from vyper.cli.vyper_compile import get_search_paths
import warnings
Expand All @@ -23,13 +26,18 @@

class AST:
ast_data = None
ast_data_folded = None
ast_data_annotated = None

custom_type_node_types = (nodes.StructDef, nodes.FlagDef)

# Data parsed from AST for easy access
imported_fns_for_alias = {}
# Module Data
functions = {}
variables = {}
flags = {}
structs = {}

# Import Data
imports = {}

@classmethod
def from_node(cls, node: VyperNode):
Expand All @@ -38,39 +46,41 @@ def from_node(cls, node: VyperNode):
ast.ast_data_annotated = node
return ast

def _load_functions(self, ast: Module):
import_from_nodes = ast.get_descendants((nodes.ImportFrom, nodes.Import))
def _load_import_data(self):
ast = self.ast_data_annotated
if ast is None:
return
import_nodes = ast.get_descendants((nodes.ImportFrom, nodes.Import))
node: nodes.ImportFrom | nodes.Import
for node in import_from_nodes:
imports = {}
for node in import_nodes:
import_info = node._metadata["import_info"]
module_t: ModuleT = import_info.typ.module_t
alias = node._metadata["import_info"].alias
if alias not in self.imported_fns_for_alias:
self.imported_fns_for_alias[alias] = {}
for name, fn in module_t.functions.items():
self.imported_fns_for_alias[alias][name] = fn
return
imports[alias] = module_t

self.imports = imports

def _load_module_data(self):
ast = self.ast_data_annotated
if ast is None:
return
self.functions = ast._metadata["type"].functions
self.variables = ast._metadata["type"].variables

def get_imported_functions_for_alias(self, alias: str):
functions = {}
if alias not in self.imported_fns_for_alias:
return {}
for name, fn in self.imported_fns_for_alias[alias].items():
functions[name] = fn
return functions
flagt_list = [FlagT.from_FlagDef(node) for node in ast._metadata["type"].flag_defs]
self.flags = {flagt.name: flagt for flagt in flagt_list}

structt_list = [StructT.from_StructDef(node) for node in ast._metadata["type"].struct_defs]
self.structs = {structt.name: structt for structt in structt_list}

def update_ast(self, doc: Document) -> List[Diagnostic]:
diagnostics = self.build_ast(doc)
has_errors = any(d.severity == DiagnosticSeverity.Error for d in diagnostics)
logger.info(f"AST updated with {len(diagnostics)} diagnostics")
logger.info(f"AST updated with errors: {has_errors}")
if self.ast_data_annotated is not None:
logger.info("Loading functions from updated AST")
self._load_functions(self.ast_data_annotated)
return diagnostics

def build_ast(self, doc: Document) -> List[Diagnostic]:
def build_ast(self, doc: Document | str) -> List[Diagnostic]:
if isinstance(doc, str):
doc = Document(uri=str(DEFAULT_CONTRACT_PATH), source=doc)
uri_parent_path = working_directory_for_document(doc)
search_paths = get_search_paths([str(uri_parent_path)])
fileinput = document_to_fileinput(doc)
Expand All @@ -84,6 +94,10 @@ def build_ast(self, doc: Document) -> List[Diagnostic]:
# out from under us when folding stuff happens
self.ast_data = copy.deepcopy(compiler_data.vyper_module)
self.ast_data_annotated = compiler_data.annotated_vyper_module

self._load_module_data()
self._load_import_data()

except VyperException as e:
# make message string include class name
message = f"{e.__class__.__name__}: {e}"
Expand Down Expand Up @@ -141,10 +155,12 @@ def get_top_level_nodes(self, *args, **kwargs):
return self.best_ast.get_children(*args, **kwargs)

def get_enums(self) -> List[str]:
return [node.name for node in self.get_descendants(nodes.FlagDef)]
#return [node.name for node in self.get_descendants(nodes.FlagDef)]
return list(self.flags.keys())

def get_structs(self) -> List[str]:
return [node.name for node in self.get_descendants(nodes.StructDef)]
#return [node.name for node in self.get_descendants(nodes.StructDef)]
return list(self.structs.keys())

def get_events(self) -> List[str]:
return [node.name for node in self.get_descendants(nodes.EventDef)]
Expand All @@ -155,13 +171,13 @@ def get_user_defined_types(self):
def get_constants(self):
# NOTE: Constants should be fetched from self.ast_data, they are missing
# from self.ast_data_unfolded and self.ast_data_folded
# NOTE: This may no longer be the case with the new AST format
if self.ast_data is None:
return []

return [
node.target.id
for node in self.ast_data.get_children(nodes.VariableDecl)
if node.is_constant
for node in self.ast_data.get_children(nodes.VariableDecl, {"is_constant": True})
]

def get_enum_variants(self, enum: str):
Expand All @@ -188,9 +204,6 @@ def get_state_variables(self):
node.target.id for node in self.ast_data.get_descendants(nodes.VariableDecl)
]

def get_import_nodes(self) -> List[nodes.Import | nodes.ImportFrom]:
return self.get_descendants((nodes.Import, nodes.ImportFrom))

def get_internal_function_nodes(self):
function_nodes = self.get_descendants(nodes.FunctionDef)
internal_nodes = []
Expand All @@ -203,7 +216,8 @@ def get_internal_function_nodes(self):
return internal_nodes

def get_internal_functions(self):
return [node.name for node in self.get_internal_function_nodes()]
internal_fn_names = [k for k, v in self.functions.items() if v.is_internal]
return internal_fn_names

def find_nodes_referencing_internal_function(self, function: str):
return self.get_descendants(
Expand Down
Loading

0 comments on commit f4dcfdd

Please sign in to comment.