diff --git a/language_server/__main__.py b/language_server/__main__.py index 2931ef2..39a0ce6 100644 --- a/language_server/__main__.py +++ b/language_server/__main__.py @@ -2,6 +2,10 @@ import logging from lsprotocol import types as lsp +from .server.features.rename import rename_variable + +from .server.features.references import get_references + from .server.features.hover import get_hover from .server.features.definition import get_definition @@ -55,12 +59,24 @@ def semantic_tokens_full(ls: MechaLanguageServer, params: lsp.SemanticTokensPara def definition(ls: MechaLanguageServer, params: lsp.DefinitionParams): return get_definition(ls, params) +@mecha_server.feature( + lsp.TEXT_DOCUMENT_REFERENCES +) +def references(ls: MechaLanguageServer, params: lsp.ReferenceParams): + return get_references(ls, params) + @mecha_server.feature( lsp.TEXT_DOCUMENT_HOVER ) def hover(ls: MechaLanguageServer, params: lsp.HoverParams): return get_hover(ls, params) +@mecha_server.feature( + lsp.TEXT_DOCUMENT_RENAME +) +def rename(ls: MechaLanguageServer, params: lsp.RenameParams): + return rename_variable(ls, params) + def add_arguments(parser: argparse.ArgumentParser): parser.description = "simple json server example" diff --git a/language_server/server/features/definition.py b/language_server/server/features/definition.py index 280ff94..dc35850 100644 --- a/language_server/server/features/definition.py +++ b/language_server/server/features/definition.py @@ -1,5 +1,5 @@ import logging -from bolt import AstIdentifier, AstTargetIdentifier, LexicalScope +from bolt import AstIdentifier, AstTargetIdentifier, Binding, LexicalScope from lsprotocol import types as lsp from .validate import get_compilation_data @@ -9,35 +9,19 @@ get_node_at_position, node_location_to_range, node_start_to_range, + search_scope_for_binding, ) from .. import MechaLanguageServer -def search_scope( - var_name: str, node: AstIdentifier | AstTargetIdentifier, scope: LexicalScope -) -> lsp.Range | None: - variables = scope.variables - - if var_name in variables: - var_data = variables[var_name] - - for binding in var_data.bindings: - if node in binding.references or node == binding.origin: - return node_location_to_range(binding.origin) - - for child in scope.children: - if range := search_scope(var_name, node, child): - return range - - return None def get_definition(ls: MechaLanguageServer, params: lsp.DefinitionParams): compiled_doc = fetch_compilation_data(ls, params) if compiled_doc is None or compiled_doc.compiled_module is None: - return [] + return ast = compiled_doc.compiled_module.ast scope = compiled_doc.compiled_module.lexical_scope @@ -46,9 +30,11 @@ def get_definition(ls: MechaLanguageServer, params: lsp.DefinitionParams): if isinstance(node, AstIdentifier) or isinstance(node, AstTargetIdentifier): var_name = node.value - range = search_scope(var_name, node, scope) + binding, scope = search_scope_for_binding(var_name, node, scope) - if not range: + if not binding: return + range = node_location_to_range(binding.origin) + return lsp.Location(params.text_document.uri, range) diff --git a/language_server/server/features/helpers.py b/language_server/server/features/helpers.py index 3ba8338..00727aa 100644 --- a/language_server/server/features/helpers.py +++ b/language_server/server/features/helpers.py @@ -1,5 +1,6 @@ import logging from typing import Any +from bolt import AstIdentifier, AstTargetIdentifier, Binding, LexicalScope from tokenstream import SourceLocation from mecha import AstNode from lsprotocol import types as lsp @@ -69,4 +70,22 @@ def offset_location(location: SourceLocation, offset): location.pos + offset, location.lineno, location.colno + offset - ) \ No newline at end of file + ) + +def search_scope_for_binding( + var_name: str, node: AstIdentifier | AstTargetIdentifier, scope: LexicalScope +) -> tuple[Binding, LexicalScope] | None: + variables = scope.variables + + if var_name in variables: + var_data = variables[var_name] + + for binding in var_data.bindings: + if node in binding.references or node == binding.origin: + return (binding, scope) + + for child in scope.children: + if binding := search_scope_for_binding(var_name, node, child): + return binding + + return None diff --git a/language_server/server/features/references.py b/language_server/server/features/references.py new file mode 100644 index 0000000..da23ec9 --- /dev/null +++ b/language_server/server/features/references.py @@ -0,0 +1,33 @@ +from bolt import AstIdentifier, AstTargetIdentifier +from lsprotocol import types as lsp + +from .helpers import fetch_compilation_data, get_node_at_position, node_location_to_range, search_scope_for_binding + +from .. import MechaLanguageServer + + +def get_references(ls: MechaLanguageServer, params: lsp.ReferenceParams): + compiled_doc = fetch_compilation_data(ls, params) + + if compiled_doc is None or compiled_doc.compiled_module is None: + return + + ast = compiled_doc.compiled_module.ast + scope = compiled_doc.compiled_module.lexical_scope + + node = get_node_at_position(ast, params.position) + if isinstance(node, AstIdentifier) or isinstance(node, AstTargetIdentifier): + var_name = node.value + + binding = search_scope_for_binding(var_name, node, scope) + if not (result := search_scope_for_binding(var_name, node, scope)): + return + + binding, _ = result + + locations = [] + for reference in binding.references: + range = node_location_to_range(reference) + locations.append(lsp.Location(params.text_document.uri, range)) + + return locations \ No newline at end of file diff --git a/language_server/server/features/rename.py b/language_server/server/features/rename.py new file mode 100644 index 0000000..db8aa3e --- /dev/null +++ b/language_server/server/features/rename.py @@ -0,0 +1,43 @@ +from bolt import AstIdentifier, AstTargetIdentifier +from lsprotocol import types as lsp + +from .helpers import ( + fetch_compilation_data, + get_node_at_position, + node_location_to_range, + search_scope_for_binding, +) + +from .. import MechaLanguageServer + + +def rename_variable(ls: MechaLanguageServer, params: lsp.RenameParams): + compiled_doc = fetch_compilation_data(ls, params) + + if compiled_doc is None or compiled_doc.compiled_module is None: + return + + ast = compiled_doc.compiled_module.ast + scope = compiled_doc.compiled_module.lexical_scope + + node = get_node_at_position(ast, params.position) + if isinstance(node, AstIdentifier) or isinstance(node, AstTargetIdentifier): + var_name = node.value + + + + if not (result := search_scope_for_binding(var_name, node, scope)): + return + binding, _ = result + + edits = [] + edits.append( + lsp.TextEdit(node_location_to_range(binding.origin), params.new_name) + ) + + for reference in binding.references: + edits.append( + lsp.TextEdit(node_location_to_range(reference), params.new_name) + ) + + return lsp.WorkspaceEdit(changes={params.text_document.uri: edits}) diff --git a/language_server/server/features/semantics.py b/language_server/server/features/semantics.py index 4557eb0..1ff483d 100644 --- a/language_server/server/features/semantics.py +++ b/language_server/server/features/semantics.py @@ -95,15 +95,15 @@ class SemanticTokenCollector(Visitor): def command(self, node: AstCommand): match node.identifier: case "import:module": - modules: list[AstResourceLocation] = node.arguments + modules: list[AstResourceLocation] = node.arguments #type: ignore for m in modules: self.nodes.append( (m, TOKEN_TYPES["class" if m.namespace == None else "function"], 0) ) case "import:module:as:alias": - module: AstResourceLocation = node.arguments[0] - item: AstImportedItem = node.arguments[1] + module: AstResourceLocation = node.arguments[0] #type: ignore + item: AstImportedItem = node.arguments[1] #type: ignore type = TOKEN_TYPES["class" if module.namespace == None else "function"] @@ -118,8 +118,8 @@ def from_import(self, from_import: AstFromImport): logging.debug(from_import) - location: AstResourceLocation = from_import.arguments[0] - imports: tuple[AstImportedItem] = from_import.arguments[1:] + location: AstResourceLocation = from_import.arguments[0] #type: ignore + imports: tuple[AstImportedItem] = from_import.arguments[1:] #type: ignore self.nodes.append( ( @@ -195,10 +195,10 @@ def function_signature( def assignment(self, assignment: AstAssignment): operator = assignment.operator - nodes.append((assignment.target, TOKEN_TYPES["variable"], 0)) + self.nodes.append((assignment.target, TOKEN_TYPES["variable"], 0)) if assignment.type_annotation != None: - nodes.append((assignment.type_annotation, TOKEN_TYPES["class"], 0)) + self.nodes.append((assignment.type_annotation, TOKEN_TYPES["class"], 0)) def walk(self, root: AstNode):