Commit
Move token utils into parser module
aazuspan committed Aug 11, 2024
1 parent e99b2ff commit 3cc1557
Showing 6 changed files with 210 additions and 217 deletions.
2 changes: 2 additions & 0 deletions src/spinasm_lsp/documentation.py
@@ -1,3 +1,5 @@
"""Documentation utilities for the SPINAsm LSP."""

from __future__ import annotations

from collections import UserDict
137 changes: 135 additions & 2 deletions src/spinasm_lsp/parser.py
@@ -1,9 +1,142 @@
"""The SPINAsm language parser."""

from __future__ import annotations

import bisect
import copy
from typing import Literal, TypedDict, cast

import lsprotocol.types as lsp
from asfv1 import fv1parse
from lsprotocol import types as lsp  # removed in this commit

from spinasm_lsp.utils import Symbol, Token, TokenRegistry  # removed in this commit

class Symbol(TypedDict):
"""The token specification used by asfv1."""

type: Literal[
"ASSEMBLER",
"EOF",
"INTEGER",
"LABEL",
"TARGET",
"MNEMONIC",
"OPERATOR",
"FLOAT",
"ARGSEP",
]
txt: str | None
stxt: str | None
val: int | float | None
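    # A parsed symbol might look like this (hypothetical example):
    #   {"type": "MNEMONIC", "txt": "rdax", "stxt": "RDAX", "val": None}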


class Token:
"""A token and its position in a source file."""

def __init__(self, symbol: Symbol, start: lsp.Position) -> None:
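        # The end column points at the token's final character, so the width is
        # one less than the token text's length (never negative for empty text).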
width = max(len(symbol["stxt"] or "") - 1, 0)
end = lsp.Position(line=start.line, character=start.character + width)

self.symbol: Symbol = symbol
self.range: lsp.Range = lsp.Range(start=start, end=end)
self.next_token: Token | None = None
self.prev_token: Token | None = None

def __str__(self):
return self.symbol["stxt"] or "Empty token"

def _clone(self) -> Token:
"""Return a clone of the token to avoid mutating the original."""
return copy.deepcopy(self)

def without_address_modifier(self) -> Token:
"""
Create a clone of the token with the address modifier removed.
"""
if not str(self).endswith("#") and not str(self).endswith("^"):
return self

token = self._clone()
token.symbol["stxt"] = cast(str, token.symbol["stxt"])[:-1]
token.range.end.character -= 1

return token


class TokenRegistry:
"""A registry of tokens and their positions in a source file."""

def __init__(self, tokens: list[Token] | None = None) -> None:
self._prev_token: Token | None = None

"""A dictionary mapping program lines to all Tokens on that line."""
self._tokens_by_line: dict[int, list[Token]] = {}

"""A dictionary mapping token names to all matching Tokens in the program."""
self._tokens_by_name: dict[str, list[Token]] = {}

for token in tokens or []:
self.register_token(token)

def register_token(self, token: Token) -> None:
"""Add a token to the registry."""
# TODO: Maybe handle multi-word CHO instructions here, by merging with the next
# token? The tricky part is that the next token would still end up getting
# registered unless we prevent it... If we end up with overlapping tokens, that
# will break `get_token_at_position`. I could check if prev token was CHO when
# I register RDAL, SOF, or RDA, and if so register them as one and unregister
# the previous?
if token.range.start.line not in self._tokens_by_line:
self._tokens_by_line[token.range.start.line] = []

# Record the previous and next token for each token to allow traversing
if self._prev_token:
token.prev_token = self._prev_token
self._prev_token.next_token = token

# Store the token on its line
self._tokens_by_line[token.range.start.line].append(token)
self._prev_token = token

# Store user-defined tokens together by name. Other token types could be stored,
# but currently there's no use case for retrieving their positions.
if token.symbol["type"] in ("LABEL", "TARGET"):
# Tokens are stored by name without address modifiers, so that e.g. Delay#
# and Delay can be retrieved with the same query. This allows for renaming
# all instances of a memory token.
token = token.without_address_modifier()

if str(token) not in self._tokens_by_name:
self._tokens_by_name[str(token)] = []

self._tokens_by_name[str(token)].append(token)

def get_matching_tokens(self, token_name: str) -> list[Token]:
"""Retrieve all tokens with a given name in the program."""
return self._tokens_by_name.get(token_name.upper(), [])

def get_token_at_position(self, position: lsp.Position) -> Token | None:
"""Retrieve the token at the given position."""
if position.line not in self._tokens_by_line:
return None

line_tokens = self._tokens_by_line[position.line]
token_starts = [t.range.start.character for t in line_tokens]
token_ends = [t.range.end.character for t in line_tokens]

idx = bisect.bisect_left(token_starts, position.character)

# The index returned by bisect_left points to the start value >= character. This
# will either be the first character of the token or the start of the next
# token. First check if we're out of bounds, then shift left unless we're at the
# first character of the token.
        if idx == len(line_tokens) or token_starts[idx] != position.character:
            idx -= 1

        # If the col falls before the first token on the line, we're not inside
        # a token.
        if idx < 0:
            return None

        # If the col falls after the end of the token, we're not inside a token.
        if position.character > token_ends[idx]:
            return None

return line_tokens[idx]


class SPINAsmParser(fv1parse):
    ...
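To make the new API concrete, here is a minimal sketch of Token construction and address-modifier stripping. The symbol dict below is a hypothetical example rather than output captured from asfv1:

import lsprotocol.types as lsp

from spinasm_lsp.parser import Token

# A hypothetical LABEL symbol for a memory address with a modifier suffix.
symbol = {"type": "LABEL", "txt": "Delay#", "stxt": "DELAY#", "val": None}
token = Token(symbol, start=lsp.Position(line=0, character=10))

# "DELAY#" is six characters, so the token spans columns 10-15 inclusive.
assert token.range.end.character == 15

# Stripping the modifier narrows the range by one column. The token is
# deep-copied first, so the original is left untouched.
plain = token.without_address_modifier()
assert str(plain) == "DELAY"
assert plain.range.end.character == 14
assert str(token) == "DELAY#"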
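The registry builds on this to answer the two queries the server needs: all tokens sharing a name, and the token under the cursor. Another hedged sketch, again with invented tokens:

import lsprotocol.types as lsp

from spinasm_lsp.parser import Token, TokenRegistry

# Two hypothetical references to one memory label, one with a modifier.
tokens = [
    Token({"type": "LABEL", "txt": "Delay", "stxt": "DELAY", "val": None},
          start=lsp.Position(line=0, character=4)),
    Token({"type": "LABEL", "txt": "Delay#", "stxt": "DELAY#", "val": None},
          start=lsp.Position(line=3, character=8)),
]
registry = TokenRegistry(tokens)

# Both references register under "DELAY", and lookups are case-insensitive.
assert len(registry.get_matching_tokens("delay")) == 2

# Position lookups include the token's final column (4 + 5 - 1 = 8)...
assert registry.get_token_at_position(lsp.Position(line=0, character=8)) is tokens[0]
# ...and return None past a token's end or on lines with no tokens.
assert registry.get_token_at_position(lsp.Position(line=0, character=9)) is None
assert registry.get_token_at_position(lsp.Position(line=9, character=0)) is None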
2 changes: 2 additions & 0 deletions src/spinasm_lsp/server.py
@@ -1,3 +1,5 @@
"""The SPINAsm Language Server Protocol implementation."""

from __future__ import annotations

from functools import lru_cache
136 changes: 0 additions & 136 deletions src/spinasm_lsp/utils.py

This file was deleted.

73 changes: 71 additions & 2 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
@@ -2,15 +2,84 @@

from __future__ import annotations

import lsprotocol.types as lsp
import pytest

from spinasm_lsp.parser import SPINAsmParser  # removed in this commit
from spinasm_lsp.parser import SPINAsmParser, Token, TokenRegistry

from .conftest import TEST_PATCHES  # removed in this commit
from .conftest import PATCH_DIR, TEST_PATCHES


@pytest.mark.parametrize("patch", TEST_PATCHES, ids=lambda x: x.stem)
def test_example_patches(patch):
"""Test that the example patches from SPINAsm are parsable."""
with open(patch, encoding="utf-8") as f:
assert SPINAsmParser(f.read())


@pytest.fixture()
def sentence_token_registry() -> tuple[str, TokenRegistry]:
"""A sentence with a token registry for each word."""
    sentence = "This   is a line    with words."  # uneven spacing is intentional

# Build a list of word tokens, ignoring whitespace. We'll build the tokens
# consistently with asfv1 parsed tokens.
words = list(filter(lambda x: x, sentence.split(" ")))
token_vals = [{"type": "LABEL", "txt": w, "stxt": w, "val": None} for w in words]
tokens = []
col = 0

for t in token_vals:
start = sentence.index(t["txt"], col)
token = Token(t, start=lsp.Position(line=0, character=start))
col = token.range.end.character + 1

tokens.append(token)

return sentence, TokenRegistry(tokens)


def test_get_token_from_registry(sentence_token_registry):
"""Test that tokens are correctly retrieved by position from a registry."""
sentence, reg = sentence_token_registry

# Manually build a mapping of column indexes to expected token words
token_positions = {i: None for i in range(len(sentence))}
for i in range(0, 4):
token_positions[i] = "This"
for i in range(7, 9):
token_positions[i] = "is"
for i in range(10, 11):
token_positions[i] = "a"
for i in range(12, 16):
token_positions[i] = "line"
for i in range(20, 24):
token_positions[i] = "with"
for i in range(25, 31):
token_positions[i] = "words."

for i, word in token_positions.items():
found_tok = reg.get_token_at_position(lsp.Position(line=0, character=i))
        found_val = found_tok.symbol["txt"] if found_tok is not None else None
msg = f"Expected token `{word}` at col {i}, found `{found_val}`"
assert found_val == word, msg


def test_get_token_at_invalid_position_returns_none(sentence_token_registry):
"""Test that retrieving tokens from out of bounds always returns None."""
_, reg = sentence_token_registry

assert reg.get_token_at_position(lsp.Position(line=99, character=99)) is None


def test_get_token_positions():
"""Test getting all positions of a token from a registry."""
patch = PATCH_DIR / "Basic.spn"
with open(patch) as fp:
source = fp.read()

parser = SPINAsmParser(source).parse()

all_matches = parser.token_registry.get_matching_tokens("apout")
assert len(all_matches) == 4
assert [t.range.start.line for t in all_matches] == [23, 57, 60, 70]
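Because get_matching_tokens ignores address modifiers, every occurrence of a memory token can be edited in one pass. A hypothetical rename flow built on this commit (not part of the commit itself) could look like the following; note that Token ranges treat end as the last character, while LSP ranges are end-exclusive, so the end column is widened by one:

import lsprotocol.types as lsp

from spinasm_lsp.parser import SPINAsmParser

with open("Basic.spn", encoding="utf-8") as f:  # hypothetical patch path
    parser = SPINAsmParser(f.read()).parse()

edits = []
for token in parser.token_registry.get_matching_tokens("apout"):
    # Widen the inclusive token range into an exclusive LSP range.
    end = lsp.Position(
        line=token.range.end.line, character=token.range.end.character + 1
    )
    edits.append(
        lsp.TextEdit(
            range=lsp.Range(start=token.range.start, end=end),
            new_text="ap_out",  # hypothetical new name
        )
    )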
1 changed file not shown (0 additions and 77 deletions).
