From a1b0ce73dfeced36cba7deec628de6a8b97cfa3e Mon Sep 17 00:00:00 2001 From: z80 Date: Sat, 14 Dec 2024 17:04:58 -0500 Subject: [PATCH] parse fn and module names w regex and small renaming --- tests/test_completions.py | 6 ++-- tests/test_utils.py | 17 +++-------- vyper_lsp/analyzer/AstAnalyzer.py | 44 +++++++++++---------------- vyper_lsp/utils.py | 50 ++++++++----------------------- 4 files changed, 38 insertions(+), 79 deletions(-) diff --git a/tests/test_completions.py b/tests/test_completions.py index 928b405..1530800 100644 --- a/tests/test_completions.py +++ b/tests/test_completions.py @@ -36,7 +36,7 @@ def baz(): ) analyzer = AstAnalyzer(ast) - completions = analyzer.get_completions_in_doc(doc, params) + completions = analyzer._get_completions_in_doc(doc, params) assert len(completions.items) == 1 assert "foo" in [c.label for c in completions.items] @@ -71,7 +71,7 @@ def baz(): ) analyzer = AstAnalyzer(ast) - completions = analyzer.get_completions_in_doc(doc, params) + completions = analyzer._get_completions_in_doc(doc, params) assert len(completions.items) == 2 assert "BAR" in [c.label for c in completions.items] assert "BAZ" in [c.label for c in completions.items] @@ -103,7 +103,7 @@ def bar(): ) analyzer = AstAnalyzer(ast) - completions = analyzer.get_completions_in_doc(doc, params) + completions = analyzer._get_completions_in_doc(doc, params) assert len(completions.items) == 7 labels = [c.label for c in completions.items] assert "internal" in labels diff --git a/tests/test_utils.py b/tests/test_utils.py index b50c03c..3a72a2f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -26,18 +26,9 @@ def test_get_expression_at_cursor(): assert utils.get_expression_at_cursor(text, 21) == "self.baz (1,2,3)" -def test_get_internal_fn_name_at_cursor(): - text = "self.foo = 123" - assert utils.get_internal_fn_name_at_cursor(text, 0) is None - assert utils.get_internal_fn_name_at_cursor(text, 1) is None - assert utils.get_internal_fn_name_at_cursor(text, 5) is None - assert utils.get_internal_fn_name_at_cursor(text, 12) is None - - text = "foo_bar = self.baz (1,2,3)" - assert utils.get_internal_fn_name_at_cursor(text, 0) is None - assert utils.get_internal_fn_name_at_cursor(text, 4) is None - assert utils.get_internal_fn_name_at_cursor(text, 21) == "baz" +def test_parse_fncall_expression(): + text = "self.foo()" + assert utils.parse_fncall_expression(text) == ("self", "foo") text = "self.foo(self.bar())" - assert utils.get_internal_fn_name_at_cursor(text, 7) == "foo" - assert utils.get_internal_fn_name_at_cursor(text, 15) == "bar" + assert utils.parse_fncall_expression(text) == ("self", "bar") diff --git a/vyper_lsp/analyzer/AstAnalyzer.py b/vyper_lsp/analyzer/AstAnalyzer.py index 3ba7b25..a8045f9 100644 --- a/vyper_lsp/analyzer/AstAnalyzer.py +++ b/vyper_lsp/analyzer/AstAnalyzer.py @@ -11,6 +11,7 @@ ) from pygls.workspace import Document from vyper.ast import nodes +from vyper_lsp import utils from vyper_lsp.analyzer.BaseAnalyzer import Analyzer from vyper_lsp.ast import AST from vyper_lsp.utils import ( @@ -18,7 +19,6 @@ get_expression_at_cursor, get_word_at_cursor, get_installed_vyper_version, - get_internal_fn_name_at_cursor, ) from lsprotocol.types import ( CompletionItem, @@ -60,33 +60,26 @@ def __init__(self, ast: AST) -> None: def signature_help( self, doc: Document, params: SignatureHelpParams ) -> Optional[SignatureHelp]: - logger.info("signature help triggered") + # TODO: Implement checking external functions, module functions, and interfaces current_line = doc.lines[params.position.line] expression = get_expression_at_cursor( current_line, params.position.character - 1 ) - logger.info(f"expression: {expression}") - # regex for matching 'module.function' - fncall_pattern = "(.*)\\.(.*)" - - if matches := re.match(fncall_pattern, expression): - module, fn = matches.groups() - logger.info(f"looking up function {fn} in module {module}") - if module in self.ast.imports: - logger.info("found module") - if fn := self.ast.imports[module].functions[fn]: - logger.info(f"args: {fn.arguments}") + parsed = utils.parse_fncall_expression(expression) + if parsed is None: + return None + module, fn_name = parsed + + logger.info(f"looking up function {fn_name} in module {module}") + if module in self.ast.imports: + logger.info("found module") + if fn := self.ast.imports[module].functions[fn_name]: + logger.info(f"args: {fn.arguments}") # this returns for all external functions - # TODO: Implement checking interfaces if not expression.startswith("self."): return None - # TODO: Implement checking external functions, module functions, and interfaces - fn_name = get_internal_fn_name_at_cursor( - current_line, params.position.character - 1 - ) - if not fn_name: return None @@ -196,7 +189,7 @@ def _dot_completions_for_element( return completions - def get_completions_in_doc( + def _get_completions_in_doc( self, document: Document, params: CompletionParams ) -> CompletionList: items = [] @@ -270,7 +263,7 @@ def get_completions( self, ls: LanguageServer, params: CompletionParams ) -> CompletionList: document = ls.workspace.get_text_document(params.text_document.uri) - return self.get_completions_in_doc(document, params) + return self._get_completions_in_doc(document, params) def _format_arg(self, arg: nodes.arg) -> str: if arg.annotation is None: @@ -302,13 +295,13 @@ def _format_fn_signature(self, node: nodes.FunctionDef) -> str: function_def = match.group() return f"(Internal Function) {function_def}" - def is_internal_fn(self, expression: str): + def _is_internal_fn(self, expression: str): if not expression.startswith("self."): return False fn_name = expression.split("self.")[-1] return fn_name in self.ast.functions and self.ast.functions[fn_name].is_internal - def is_state_var(self, expression: str): + def _is_state_var(self, expression: str): if not expression.startswith("self."): return False var_name = expression.split("self.")[-1] @@ -322,12 +315,11 @@ def hover_info(self, doc: Document, pos: Position) -> Optional[str]: word = get_word_at_cursor(og_line, pos.character) full_word = get_expression_at_cursor(og_line, pos.character) - if self.is_internal_fn(full_word): - logger.info("looking for internal fn") + if self._is_internal_fn(full_word): node = self.ast.find_function_declaration_node_for_name(word) return node and self._format_fn_signature(node) - if self.is_state_var(full_word): + if self._is_state_var(full_word): node = self.ast.find_state_variable_declaration_node_for_name(word) if not node: return None diff --git a/vyper_lsp/utils.py b/vyper_lsp/utils.py index c97be2b..fa16f4f 100644 --- a/vyper_lsp/utils.py +++ b/vyper_lsp/utils.py @@ -3,7 +3,7 @@ import re from pathlib import Path from importlib.metadata import version -from typing import Optional +from typing import Optional, Tuple from lsprotocol.types import Diagnostic, DiagnosticSeverity, Position, Range from packaging.version import Version from pygls.workspace import Document @@ -125,42 +125,6 @@ def get_expression_at_cursor(sentence: str, cursor_index: int) -> str: return word -def get_internal_fn_name_at_cursor(sentence: str, cursor_index: int) -> Optional[str]: - # TODO: Improve this function to handle more cases - # should be simpler, and handle when the cursor is on "self." before a fn name - # Split the sentence into segments at each 'self.' - segments = sentence.split("self.") - - # Accumulated length to keep track of the cursor's position relative to the original sentence - accumulated_length = 0 - - for segment in segments: - if not segment: - accumulated_length += len("self.") - continue - - # Update the accumulated length for each segment - segment_start = accumulated_length - segment_end = accumulated_length + len(segment) - accumulated_length = segment_end + 5 # Update for next segment - - # Check if the cursor is within the current segment - if segment_start <= cursor_index <= segment_end: - # Extract the function name from the segment - function_name = re.findall(r"\b\w+\s*\(", segment) - if function_name: - # Take the function name closest to the cursor - closest_fn = min( - function_name, - key=lambda fn: abs( - cursor_index - (segment_start + segment.find(fn)) - ), - ) - return closest_fn.split("(")[0].strip() - - return None - - def extract_enum_name(line: str): m = re.match(r"enum\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*:", line) if m: @@ -238,3 +202,15 @@ def format_fn(func) -> str: f"def __{escape_underscores(func.name)}__({args}){return_value}: _{mutability}_" ) return out + + +def parse_fncall_expression(expression: str) -> Optional[Tuple[str, str]]: + # regex for matching 'module.function' or 'module.function(args)', not capturing args + fncall_pattern = "(.*)\\.([^\\(]+)(?:\\(.*\\))?" + + if matches := re.match(fncall_pattern, expression): + groups = matches.groups() + module, fn = groups + if "(" in module: + module = module.split("(")[-1] + return module, fn