diff --git a/tests/unit/ast/test_tokenizer.py b/tests/unit/ast/test_tokenizer.py
new file mode 100644
index 0000000000..f6000e0425
--- /dev/null
+++ b/tests/unit/ast/test_tokenizer.py
@@ -0,0 +1,94 @@
+"""
+Tests that the tokenizer / parser are passing correct source location
+info to the AST
+"""
+import pytest
+
+from vyper.ast.parse import parse_to_ast
+from vyper.compiler import compile_code
+from vyper.exceptions import UndeclaredDefinition
+
+
+def test_log_token_aligned():
+    # GH issue 3430
+    code = """
+event A:
+    b: uint256
+
+@external
+def f():
+    log A(b=d)
+    """
+    with pytest.raises(UndeclaredDefinition) as e:
+        compile_code(code)
+
+    expected = """
+'d' has not been declared.
+
+  function "f", line 7:12 
+       6 def f():
+  ---> 7     log A(b=d)
+  -------------------^
+       8
+    """  # noqa: W291
+    assert expected.strip() == str(e.value).strip()
+
+
+def test_log_token_aligned2():
+    # GH issue 3059
+    code = """
+interface Contract:
+    def foo(): nonpayable
+
+event MyEvent:
+    a: address
+
+@external
+def foo(c: Contract):
+    log MyEvent(a=c.address)
+    """
+    compile_code(code)
+
+
+def test_log_token_aligned3():
+    # https://github.com/vyperlang/vyper/pull/3808#pullrequestreview-1900570163
+    code = """
+import ITest
+
+implements: ITest
+
+event Foo:
+    a: address
+
+@external
+def foo(u: uint256):
+    log Foo(empty(address))
+    log i.Foo(empty(address))
+    """
+    # not semantically valid code, check we can at least parse it
+    assert parse_to_ast(code) is not None
+
+
+def test_log_token_aligned4():
+    # GH issue 4139
+    code = """
+b: public(uint256)
+
+event Transfer:
+    random: indexed(uint256)
+    shi: uint256
+
+@external
+def transfer():
+    log Transfer(T(self).b(), 10)
+    return
+    """
+    # not semantically valid code, check we can at least parse it
+    assert parse_to_ast(code) is not None
+
+
+def test_long_string_non_coding_token():
+    # GH issue 2258
+    code = '\r[[]]\ndef _(e:[],l:[]):\n """"""""""""""""""""""""""""""""""""""""""""""""""""""\n f.n()'  # noqa: E501
+    # not valid code, but should at least parse
+    assert parse_to_ast(code) is not None
diff --git a/vyper/ast/natspec.py b/vyper/ast/natspec.py
index f65a361338..f5487e8a91 100644
--- a/vyper/ast/natspec.py
+++ b/vyper/ast/natspec.py
@@ -2,6 +2,7 @@
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
+# NOTE: this is our only use of asttokens -- consider vendoring in the implementation.
+from asttokens import LineNumbers
 
 from vyper.ast import nodes as vy_ast
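
The NOTE above flags `asttokens.LineNumbers` as the project's last remaining use of asttokens. The functionality natspec actually needs is small: converting a (line, col) pair into an absolute character offset. A minimal sketch of what a vendored replacement could look like (the name `LineNumbersShim` is illustrative, not part of this diff):

    # a vendored stand-in for the line/col -> absolute offset conversion
    # that natspec needs; `LineNumbersShim` is a hypothetical name
    class LineNumbersShim:
        def __init__(self, source: str) -> None:
            ofst = 0
            self._offsets = {1: 0}  # at least one entry for empty source
            for lineno, line in enumerate(source.splitlines(keepends=True), start=1):
                self._offsets[lineno] = ofst
                ofst += len(line)

        def line_to_offset(self, line: int, column: int) -> int:
            # same contract as asttokens.LineNumbers.line_to_offset
            return self._offsets[line] + column

    ln = LineNumbersShim("x: int\ny: int\n")
    assert ln.line_to_offset(2, 0) == 7  # line 2 begins after "x: int\n"
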
diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py
index 8df295c9eb..a7cd0464ed 100644
--- a/vyper/ast/parse.py
+++ b/vyper/ast/parse.py
@@ -1,10 +1,10 @@
 import ast as python_ast
+import pickle
 import tokenize
 from decimal import Decimal
+from functools import cached_property
 from typing import Any, Dict, List, Optional, Union
 
-import asttokens
-
 from vyper.ast import nodes as vy_ast
 from vyper.ast.pre_parser import PreParser
 from vyper.compiler.settings import Settings
@@ -80,12 +80,16 @@ def _parse_to_ast_with_settings(
     try:
         py_ast = python_ast.parse(pre_parser.reformatted_code)
    except SyntaxError as e:
-        # TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors
        offset = e.offset
        if offset is not None:
            # SyntaxError offset is 1-based, not 0-based (see:
            # https://docs.python.org/3/library/exceptions.html#SyntaxError.offset)
            offset -= 1
+
+            # adjust the column of the error if it was modified by the pre-parser
+            if e.lineno is not None:  # help mypy
+                offset += pre_parser.adjustments.get((e.lineno, offset), 0)
+
        new_e = SyntaxException(str(e), vyper_source, e.lineno, offset)
 
        likely_errors = ("staticall", "staticcal")
@@ -97,6 +101,11 @@ def _parse_to_ast_with_settings(
 
         raise new_e from None
 
+    # some python AST node instances are singletons and are reused between
+    # parse() invocations. copy the python AST so that we are using fresh
+    # objects.
+    py_ast = _deepcopy_ast(py_ast)
+
     # Add dummy function node to ensure local variables are treated as `AnnAssign`
     # instead of state variables (`VariableDecl`)
     if add_fn_node:
@@ -129,6 +138,9 @@ def _parse_to_ast_with_settings(
     return pre_parser.settings, module
 
 
+LINE_INFO_FIELDS = ("lineno", "col_offset", "end_lineno", "end_col_offset")
+
+
 def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]:
     """
     Converts a Vyper AST node, or list of nodes, into a dictionary suitable for
@@ -155,7 +167,7 @@ def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]:
 
 
 def annotate_python_ast(
-    parsed_ast: python_ast.AST,
+    parsed_ast: python_ast.Module,
     vyper_source: str,
     pre_parser: PreParser,
     source_id: int = 0,
@@ -178,22 +190,19 @@ def annotate_python_ast(
     -------
         The annotated and optimized AST.
     """
-    tokens = asttokens.ASTTokens(vyper_source)
-    assert isinstance(parsed_ast, python_ast.Module)  # help mypy
-    tokens.mark_tokens(parsed_ast)
     visitor = AnnotatingVisitor(
-        vyper_source,
-        pre_parser,
-        tokens,
-        source_id,
-        module_path=module_path,
-        resolved_path=resolved_path,
+        vyper_source, pre_parser, source_id, module_path=module_path, resolved_path=resolved_path
     )
-    visitor.visit(parsed_ast)
+    visitor.start(parsed_ast)
 
     return parsed_ast
 
 
+def _deepcopy_ast(ast_node: python_ast.AST):
+    # pickle roundtrip is faster than copy.deepcopy() here.
+    return pickle.loads(pickle.dumps(ast_node))
+
+
 class AnnotatingVisitor(python_ast.NodeTransformer):
     _source_code: str
     _pre_parser: PreParser
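
The claim in `_deepcopy_ast` is easy to check standalone. The sketch below (not part of the diff; timings are machine-dependent) verifies that the pickle roundtrip yields a fresh object graph and lets you compare it against `copy.deepcopy()`:

    import ast
    import copy
    import pickle
    import timeit

    tree = ast.parse("def f(x):\n    return x + 1\n")

    # the roundtrip yields a fresh object graph: mutating the copy cannot
    # leak into singleton/shared nodes from the original parse
    copied = pickle.loads(pickle.dumps(tree))
    assert copied is not tree and copied.body[0] is not tree.body[0]

    t_copy = timeit.timeit(lambda: copy.deepcopy(tree), number=1000)
    t_pkl = timeit.timeit(lambda: pickle.loads(pickle.dumps(tree)), number=1000)
    print(f"deepcopy: {t_copy:.3f}s / pickle roundtrip: {t_pkl:.3f}s")
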
@@ -202,12 +211,10 @@ def __init__(
         self,
         source_code: str,
         pre_parser: PreParser,
-        tokens: asttokens.ASTTokens,
         source_id: int,
         module_path: Optional[str] = None,
         resolved_path: Optional[str] = None,
     ):
-        self._tokens = tokens
         self._source_id = source_id
         self._module_path = module_path
         self._resolved_path = resolved_path
@@ -216,6 +223,58 @@ def __init__(
 
         self.counter: int = 0
 
+    @cached_property
+    def source_lines(self):
+        return self._source_code.splitlines(keepends=True)
+
+    @cached_property
+    def line_offsets(self):
+        ofst = 0
+        # ensure line_offsets has at least 1 entry for 0-line source
+        ret = {1: ofst}
+        for lineno, line in enumerate(self.source_lines):
+            ret[lineno + 1] = ofst
+            ofst += len(line)
+        return ret
+
+    def start(self, node: python_ast.Module):
+        self._fix_missing_locations(node)
+        self.visit(node)
+
+    def _fix_missing_locations(self, ast_node: python_ast.Module):
+        """
+        adapted from cpython Lib/ast.py. adds line/col info to the ast,
+        but unlike Lib/ast.py, adjusts *all* ast nodes, not just the
+        ones that python defines to have line/col info.
+        https://github.com/python/cpython/blob/62729d79206014886f5d/Lib/ast.py#L228
+        """
+        assert isinstance(ast_node, python_ast.Module)
+        ast_node.lineno = 1
+        ast_node.col_offset = 0
+        ast_node.end_lineno = max(1, len(self.source_lines))
+
+        if len(self.source_lines) > 0:
+            ast_node.end_col_offset = len(self.source_lines[-1])
+        else:
+            ast_node.end_col_offset = 0
+
+        def _fix(node, parent=None):
+            for field in LINE_INFO_FIELDS:
+                if parent is not None:
+                    val = getattr(node, field, None)
+                    # special case for USub - heisenbug when coverage is
+                    # enabled in the test suite.
+                    if val is None or isinstance(node, python_ast.USub):
+                        val = getattr(parent, field)
+                    setattr(node, field, val)
+                else:
+                    assert hasattr(node, field), node
+
+            for child in python_ast.iter_child_nodes(node):
+                _fix(child, node)
+
+        _fix(ast_node)
+
     def generic_visit(self, node):
         """
         Annotate a node with information that simplifies Vyper node generation.
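
For context on the custom `_fix_missing_locations`: the stdlib helper only fills in nodes for which CPython declares location attributes, so operator nodes such as `USub` stay bare. A standalone snippet (not diff code) demonstrating why the `_fix()` closure has to copy locations from the parent instead:

    import ast

    tree = ast.parse("x = -1")
    usub = tree.body[0].value.op  # the USub operator node

    ast.fix_missing_locations(tree)
    # operator nodes declare no location attributes, so the stdlib helper
    # skips them; the _fix() closure above copies the parent's location in
    assert not hasattr(usub, "lineno")
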
@@ -223,38 +282,28 @@ def generic_visit(self, node):
         # Decorate every node with the original source code to allow pretty-printing errors
         node.full_source_code = self._source_code
         node.node_id = self.counter
-        node.ast_type = node.__class__.__name__
         self.counter += 1
+        node.ast_type = node.__class__.__name__
 
-        # Decorate every node with source end offsets
-        start = (None, None)
-        if hasattr(node, "first_token"):
-            start = node.first_token.start
-        end = (None, None)
-        if hasattr(node, "last_token"):
-            end = node.last_token.end
-            if node.last_token.type == 4:
-                # token type 4 is a `\n`, some nodes include a trailing newline
-                # here we ignore it when building the node offsets
-                end = (end[0], end[1] - 1)
-
-        node.lineno = start[0]
-        node.col_offset = start[1]
-        node.end_lineno = end[0]
-        node.end_col_offset = end[1]
-
-        # TODO: adjust end_lineno and end_col_offset when this node is in
-        # modification_offsets
-
-        if hasattr(node, "last_token"):
-            start_pos = node.first_token.startpos
-            end_pos = node.last_token.endpos
-
-            if node.last_token.type == 4:
-                # ignore trailing newline once more
-                end_pos -= 1
-            node.src = f"{start_pos}:{end_pos-start_pos}:{self._source_id}"
-            node.node_source_code = self._source_code[start_pos:end_pos]
+        adjustments = self._pre_parser.adjustments
+
+        # Load and Store behave differently inside of fix_missing_locations;
+        # we don't use them in the vyper AST so just skip adjusting the line
+        # info.
+        if isinstance(node, (python_ast.Load, python_ast.Store)):
+            return super().generic_visit(node)
+
+        adj = adjustments.get((node.lineno, node.col_offset), 0)
+        node.col_offset += adj
+
+        adj = adjustments.get((node.end_lineno, node.end_col_offset), 0)
+        node.end_col_offset += adj
+
+        start_pos = self.line_offsets[node.lineno] + node.col_offset
+        end_pos = self.line_offsets[node.end_lineno] + node.end_col_offset
+
+        node.src = f"{start_pos}:{end_pos-start_pos}:{self._source_id}"
+        node.node_source_code = self._source_code[start_pos:end_pos]
 
         return super().generic_visit(node)
 
@@ -288,12 +337,6 @@ def visit_Module(self, node):
         return self._visit_docstring(node)
 
     def visit_FunctionDef(self, node):
-        if node.decorator_list:
-            # start the source highlight at `def` to improve annotation readability
-            decorator_token = node.decorator_list[-1].last_token
-            def_token = self._tokens.find_token(decorator_token, tokenize.NAME, tok_str="def")
-            node.first_token = def_token
-
         return self._visit_docstring(node)
 
     def visit_ClassDef(self, node):
@@ -306,7 +349,7 @@
         """
         self.generic_visit(node)
 
-        node.ast_type = self._pre_parser.modification_offsets[(node.lineno, node.col_offset)]
+        node.ast_type = self._pre_parser.keyword_translations[(node.lineno, node.col_offset)]
         return node
 
     def visit_For(self, node):
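
The offset arithmetic in the new `generic_visit` can be checked by hand. A standalone sketch with made-up values showing how `line_offsets` plus the (possibly adjusted) line/col info yields the `src` string and a verbatim source slice:

    source = "x: int\ny: int\n"
    line_offsets = {1: 0, 2: 7}  # what the line_offsets cached_property computes

    # suppose a node spans `y: int` on line 2 (values made up for illustration)
    lineno, col_offset = 2, 0
    end_lineno, end_col_offset = 2, 6
    source_id = 0

    start_pos = line_offsets[lineno] + col_offset
    end_pos = line_offsets[end_lineno] + end_col_offset
    assert f"{start_pos}:{end_pos - start_pos}:{source_id}" == "7:6:0"
    assert source[start_pos:end_pos] == "y: int"
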
@@ -349,16 +392,13 @@
         try:
             fake_node = python_ast.parse(annotation_str).body[0]
+            # do we need to fix location info here?
+            fake_node = _deepcopy_ast(fake_node)
         except SyntaxError as e:
             raise SyntaxException(
                 "invalid type annotation", self._source_code, node.lineno, node.col_offset
             ) from e
 
-        # fill in with asttokens info. note we can use `self._tokens` because
-        # it is indented to exactly the same position where it appeared
-        # in the original source!
-        self._tokens.mark_tokens(fake_node)
-
         # replace the dummy target name with the real target name.
         fake_node.target = node.target
 
         # replace the For node target with the new ann_assign
@@ -383,14 +423,14 @@ def visit_Expr(self, node):
             # CMC 2024-03-03 consider unremoving this from the enclosing Expr
             node = node.value
             key = (node.lineno, node.col_offset)
-            node.ast_type = self._pre_parser.modification_offsets[key]
+            node.ast_type = self._pre_parser.keyword_translations[key]
 
         return node
 
     def visit_Await(self, node):
-        start_pos = node.lineno, node.col_offset  # grab these before generic_visit modifies them
+        start_pos = node.lineno, node.col_offset
         self.generic_visit(node)
-        node.ast_type = self._pre_parser.modification_offsets[start_pos]
+        node.ast_type = self._pre_parser.keyword_translations[start_pos]
         return node
 
     def visit_Call(self, node):
@@ -410,6 +450,9 @@
             assert len(dict_.keys) == len(dict_.values)
             for key, value in zip(dict_.keys, dict_.values):
                 replacement_kw_node = python_ast.keyword(key.id, value)
+                # set locations
+                for attr in LINE_INFO_FIELDS:
+                    setattr(replacement_kw_node, attr, getattr(key, attr))
                 kw_list.append(replacement_kw_node)
 
             node.args = []
diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py
index 5cbddffed8..8e221fb7e6 100644
--- a/vyper/ast/pre_parser.py
+++ b/vyper/ast/pre_parser.py
@@ -164,8 +164,14 @@ def consume(self, token, result):
 class PreParser:
     # Compilation settings based on the directives in the source code
     settings: Settings
-    # A mapping of class names to their original class types.
-    modification_offsets: dict[tuple[int, int], str]
+
+    # A mapping of offsets to the vyper node type for the translated keyword
+    keyword_translations: dict[tuple[int, int], str]
+
+    # Map from offsets in the reformatted (i.e. python-compatible) source code
+    # back to the column adjustment needed to recover the original vyper offset
+    adjustments: dict[tuple[int, int], int]
+
     # A mapping of line/column offsets of `For` nodes to the annotation of the for loop target
     for_loop_annotations: dict[tuple[int, int], list[TokenInfo]]
     # A list of line/column offsets of hex string literals
@@ -199,8 +205,9 @@ def parse(self, code: str):
             raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e
 
     def _parse(self, code: str):
+        adjustments: dict = {}
         result: list[TokenInfo] = []
-        modification_offsets: dict[tuple[int, int], str] = {}
+        keyword_translations: dict[tuple[int, int], str] = {}
         settings = Settings()
         for_parser = ForParser(code)
         hex_string_parser = HexStringParser()
@@ -219,6 +226,12 @@ def _parse(self, code: str):
             end = token.end
             line = token.line
 
+            # handle adjustments
+            lineno, col = token.start
+            adj = _col_adjustments[lineno]
+            newstart = lineno, col - adj
+            adjustments[lineno, col - adj] = adj
+
             if typ == COMMENT:
                 contents = string[1:].strip()
                 if contents.startswith("@version"):
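
A worked example (standalone, columns computed by hand) of what the `adjustments` table records: rewriting `staticcall` (10 characters) to `await` (5 characters) shifts every following token on the line 5 columns to the left in the reformatted code, and the table maps those positions back to the original source:

    original    = "x: uint256 = staticcall f.foo()"
    reformatted = "x: uint256 = await f.foo()"

    adjustment = len("staticcall") - len("await")  # 5

    # `f` sits at column 24 in the original source and column 19 in the
    # reformatted code; the entry is keyed by the reformatted position
    assert original[24] == "f" and reformatted[19] == "f"
    adjustments = {(1, 24 - adjustment): adjustment}

    # python reports column 19 when parsing the reformatted code;
    # generic_visit adds the adjustment to recover the original column
    col_offset = 19
    col_offset += adjustments.get((1, col_offset), 0)
    assert col_offset == 24
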
@@ -275,37 +288,32 @@
                     )
 
             if typ == NAME:
+                # see if it's a keyword we need to replace
+                new_keyword = None
                 if string in VYPER_CLASS_TYPES and start[1] == 0:
-                    toks = [TokenInfo(NAME, "class", start, end, line)]
-                    modification_offsets[start] = VYPER_CLASS_TYPES[string]
+                    new_keyword = "class"
+                    vyper_type = VYPER_CLASS_TYPES[string]
                 elif string in CUSTOM_STATEMENT_TYPES:
                     new_keyword = "yield"
-                    adjustment = len(new_keyword) - len(string)
-                    # adjustments for following staticcall/extcall modification_offsets
-                    _col_adjustments[start[0]] += adjustment
-                    toks = [TokenInfo(NAME, new_keyword, start, end, line)]
-                    modification_offsets[start] = CUSTOM_STATEMENT_TYPES[string]
+                    vyper_type = CUSTOM_STATEMENT_TYPES[string]
                 elif string in CUSTOM_EXPRESSION_TYPES:
-                    # a bit cursed technique to get untokenize to put
-                    # the new tokens in the right place so that modification_offsets
-                    # will work correctly.
-                    # (recommend comparing the result of parse with the
-                    # source code side by side to visualize the whitespace)
                     new_keyword = "await"
                     vyper_type = CUSTOM_EXPRESSION_TYPES[string]
-                    lineno, col_offset = start
-
-                    # fixup for when `extcall/staticcall` follows `log`
-                    adjustment = _col_adjustments[lineno]
-                    new_start = (lineno, col_offset + adjustment)
-                    modification_offsets[new_start] = vyper_type
+
+                if new_keyword is not None:
+                    keyword_translations[newstart] = vyper_type
 
-                    # tells untokenize to add whitespace, preserving locations
-                    diff = len(new_keyword) - len(string)
-                    new_end = end[0], end[1] + diff
+                    adjustment = len(string) - len(new_keyword)
+                    # adjustments for following tokens
+                    lineno, col = start
+                    _col_adjustments[lineno] += adjustment
 
-                    toks = [TokenInfo(NAME, new_keyword, start, new_end, line)]
+                    # a bit cursed technique to get untokenize to put
+                    # the new tokens in the right place so that
+                    # `keyword_translations` will work correctly.
+                    # (recommend comparing the result of parse with the
+                    # source code side by side to visualize the whitespace)
+                    toks = [TokenInfo(NAME, new_keyword, start, end, line)]
 
             if (typ, string) == (OP, ";"):
                 raise SyntaxException("Semi-colon statements not allowed", code, start[0], start[1])
@@ -317,8 +325,9 @@
         for k, v in for_parser.annotations.items():
             for_loop_annotations[k] = v.copy()
 
+        self.adjustments = adjustments
         self.settings = settings
-        self.modification_offsets = modification_offsets
+        self.keyword_translations = keyword_translations
         self.for_loop_annotations = for_loop_annotations
         self.hex_string_locations = hex_string_parser.locations
         self.reformatted_code = untokenize(result).decode("utf-8")
diff --git a/vyper/utils.py b/vyper/utils.py
index 999e211acb..5bebca7776 100644
--- a/vyper/utils.py
+++ b/vyper/utils.py
@@ -518,7 +518,7 @@ def timeit(msg):  # pragma: nocover
     yield
     end_time = time.perf_counter()
     total_time = end_time - start_time
-    print(f"{msg}: Took {total_time:.4f} seconds", file=sys.stderr)
+    print(f"{msg}: Took {total_time:.6f} seconds", file=sys.stderr)
 
 
 _CUMTIMES = None
@@ -527,7 +527,7 @@
 def _dump_cumtime():  # pragma: nocover
     global _CUMTIMES
     for msg, total_time in _CUMTIMES.items():
-        print(f"{msg}: Cumulative time {total_time:.4f} seconds", file=sys.stderr)
+        print(f"{msg}: Cumulative time {total_time:.3f} seconds", file=sys.stderr)
 
 
 @contextlib.contextmanager
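
Finally, an end-to-end sanity check in the spirit of the tests above (a sketch, assuming `parse_to_ast` and the `node_source_code` attribute behave as this diff describes; the snippet is not semantically valid vyper, which is fine for parsing):

    from vyper.ast.parse import parse_to_ast

    code = """
@external
def f() -> uint256:
    return staticcall self.g()
    """
    module = parse_to_ast(code)
    fn = module.body[0]
    # the pre-parser handed python `await self.g()`, but the annotated
    # offsets should index into the original source, so the recorded
    # source slice is verbatim and still spells `staticcall`
    assert fn.node_source_code in code
    assert "staticcall" in fn.node_source_code
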