diff --git a/fir/src/lib.rs b/fir/src/lib.rs index 48619300..7b727b80 100644 --- a/fir/src/lib.rs +++ b/fir/src/lib.rs @@ -95,9 +95,9 @@ // Does that make sense? Does that indicate that for all types we must first keep a Option which is set to None? // Is this going to cause problems? -use std::collections::BTreeMap; use std::fmt::Debug; use std::hash::Hash; +use std::{collections::BTreeMap, ops::Index}; mod checks; pub mod iter; @@ -122,7 +122,7 @@ pub enum RefIdx { impl RefIdx { #[track_caller] - pub fn unwrap(&self) -> OriginIdx { + pub fn expect_resolved(&self) -> OriginIdx { match self { RefIdx::Resolved(idx) => *idx, RefIdx::Unresolved => unreachable!("unexpected `RefIdx::Unresolved` in `Fir`"), @@ -223,6 +223,22 @@ pub enum Kind { Return(Option), // to any kind } +impl Index<&RefIdx> for Fir { + type Output = Node; + + fn index(&self, index: &RefIdx) -> &Node { + &self.nodes[&index.expect_resolved()] + } +} + +impl Index<&OriginIdx> for Fir { + type Output = Node; + + fn index(&self, index: &OriginIdx) -> &Node { + &self.nodes[index] + } +} + #[derive(Debug, Clone)] pub struct Node { pub data: T, diff --git a/flatten/src/lib.rs b/flatten/src/lib.rs index b83a22b8..1130939b 100644 --- a/flatten/src/lib.rs +++ b/flatten/src/lib.rs @@ -666,7 +666,6 @@ impl<'ast> Ctx<'ast> { fn visit_type( self, - ast: AstInfo<'ast>, generics: &[GenericArgument], fields: &[TypedValue], @@ -690,7 +689,7 @@ impl<'ast> Ctx<'ast> { fn visit_var_declaration( self, ast: AstInfo<'ast>, - _mutable: &bool, + _mutable: bool, _to_declare: &Symbol, value: &'ast Ast, ) -> (Ctx<'ast>, RefIdx) { @@ -746,6 +745,28 @@ impl<'ast> Ctx<'ast> { self.append(data, kind) } + fn handle_field_instantiation(self, instantiation: &'ast Ast) -> (Ctx<'ast>, RefIdx) { + let AstNode::VarAssign { value, .. } = &instantiation.node else { + // FIXME: Ugly? + unreachable!( + "invalid AST: non var-assign in field instantiation, in type instantiation" + ) + }; + + let (ctx, value) = self.visit(value); + + let data = FlattenData { + scope: ctx.scope, + ast: AstInfo::Node(instantiation), + }; + let kind = Kind::Assignment { + to: RefIdx::Unresolved, + from: value, + }; + + ctx.append(data, kind) + } + fn visit_type_instantiation( self, ast: AstInfo<'ast>, @@ -756,7 +777,7 @@ impl<'ast> Ctx<'ast> { }: &'ast Call, ) -> (Ctx<'ast>, RefIdx) { let (ctx, generics) = self.visit_fold(generics.iter(), Ctx::handle_ty_node); - let (ctx, fields) = ctx.visit_fold(fields.iter(), Ctx::visit); + let (ctx, fields) = ctx.visit_fold(fields.iter(), Ctx::handle_field_instantiation); let data = FlattenData { scope: ctx.scope, @@ -802,7 +823,7 @@ impl<'ast> Ctx<'ast> { mutable, to_declare, value, - } => self.visit_var_declaration(node, mutable, to_declare, value), + } => self.visit_var_declaration(node, *mutable, to_declare, value), AstNode::VarAssign { to_assign, value } => { self.visit_var_assign(node, to_assign, value) } @@ -899,7 +920,7 @@ mod tests { Kind::Statements(stmts) => stmts, _ => unreachable!(), }; - let ret_idx = stmts[0].unwrap(); + let ret_idx = stmts[0].expect_resolved(); assert!(matches!( fir.nodes.get(&ret_idx).unwrap().kind, diff --git a/name_resolve/src/declarator.rs b/name_resolve/src/declarator.rs index 99b5c160..695f3324 100644 --- a/name_resolve/src/declarator.rs +++ b/name_resolve/src/declarator.rs @@ -1,60 +1,78 @@ -use fir::{Fallible, Fir, Node, RefIdx, Traversal}; +use fir::{Fallible, Fir, Node, OriginIdx, RefIdx, Traversal}; use flatten::FlattenData; -use crate::{NameResolutionError, NameResolveCtx}; +use crate::{NameResolutionError, NameResolveCtx, UniqueError}; -pub(crate) struct Declarator<'ctx>(pub(crate) &'ctx mut NameResolveCtx); +enum DefinitionKind { + Function, + Type, + Binding, +} + +pub(crate) struct Declarator<'ctx, 'enclosing>(pub(crate) &'ctx mut NameResolveCtx<'enclosing>); + +impl<'ctx, 'enclosing> Declarator<'ctx, 'enclosing> { + fn define( + &mut self, + kind: DefinitionKind, + node: &Node, + ) -> Fallible { + let (map, kind) = match kind { + DefinitionKind::Function => (&mut self.0.mappings.functions, "function"), + DefinitionKind::Type => (&mut self.0.mappings.types, "type"), + DefinitionKind::Binding => (&mut self.0.mappings.bindings, "binding"), + }; + + map.insert( + node.data.ast.symbol().unwrap().clone(), + node.origin, + self.0.enclosing_scope[node.origin], + ) + .map_err(|existing| Declarator::unique_error(node, existing, kind)) + } + + fn unique_error( + node: &Node, + existing: OriginIdx, + kind: &'static str, + ) -> NameResolutionError { + NameResolutionError::non_unique(node.data.ast.location(), UniqueError(existing, kind)) + } +} + +impl<'ast, 'ctx, 'enclosing> Traversal, NameResolutionError> + for Declarator<'ctx, 'enclosing> +{ + // TODO: Can we factor these three functions? -impl<'ast, 'ctx> Traversal, NameResolutionError> for Declarator<'ctx> { fn traverse_function( &mut self, - _fir: &Fir, + _: &Fir, node: &Node, - _generics: &[RefIdx], - _args: &[RefIdx], - _return_ty: &Option, - _block: &Option, + _: &[RefIdx], + _: &[RefIdx], + _: &Option, + _: &Option, ) -> Fallible { - self.0 - .mappings - .add_function( - node.data.ast.symbol().unwrap().clone(), - node.data.scope, - node.origin, - ) - .map_err(|ue| NameResolutionError::non_unique(node.data.ast.location(), ue)) + self.define(DefinitionKind::Function, node) } fn traverse_type( &mut self, - _fir: &Fir, + _: &Fir, node: &Node, _: &[RefIdx], _: &[RefIdx], ) -> Fallible { - self.0 - .mappings - .add_type( - node.data.ast.symbol().unwrap().clone(), - node.data.scope, - node.origin, - ) - .map_err(|ue| NameResolutionError::non_unique(node.data.ast.location(), ue)) + self.define(DefinitionKind::Type, node) } fn traverse_binding( &mut self, - _fir: &Fir, + _: &Fir, node: &Node, - _to: &RefIdx, + _: &RefIdx, ) -> Fallible { - self.0 - .mappings - .add_variable( - node.data.ast.symbol().unwrap().clone(), - node.data.scope, - node.origin, - ) - .map_err(|ue| NameResolutionError::non_unique(node.data.ast.location(), ue)) + self.define(DefinitionKind::Binding, node) } } diff --git a/name_resolve/src/lib.rs b/name_resolve/src/lib.rs index 37ff8835..d80e6d57 100644 --- a/name_resolve/src/lib.rs +++ b/name_resolve/src/lib.rs @@ -1,145 +1,192 @@ -use std::collections::HashMap; +//! The name-resolve module takes care of resolving each "name" within a `jinko` program to +//! its definition site. You can think of it as a function, which will take an unresolved +//! [`Fir`] as input, and output a new [`Fir`] where each node points to the definition it +//! uses. To do this, multiple passes are applied on the input [`Fir`]. +//! +//! 1. We start by "scoping" *all* of the definitions and usages in the program. +//! This is done via the [`Scoper`] type, which will assign an enclosing scope to +//! each node of our [`Fir`]. This is important, since name resolution in `jinko` +//! can go "backwards" or "upwards". You can access a definition in an *outer* +//! scope from an *inner* scope. This means that the outermost scope is able to +//! use definitions from the outermost scope. But if this "enclosing" or "parent" +//! relationship does not exist, then the usage is invalid: +//! +//! ```text +//! { +//! func foo() {} +//! } +//! { +//! { foo() } +//! // this scope does not have `foo`'s scope as a parent, so it will need to error out later on +//! } +//! ``` +//! +//! 2. We collect all of the definitions within the program using the [`Declarator`] struct. +//! A definition can be a function definition, a new type, as well as a new binding. This +//! "definition collection" pass will only error out if a definition is present twice, e.g.: +//! +//! ```text +//! func foo() {} +//! func foo(different: int, arguments: int) -> string { "oh no" } +//! ``` +//! +//! 2. Finally, we resolve all *usages* to their *definitions* using the [`Resolver`] type. +//! If a usage cannot be resolved to a definition, then it is an error. Similarly, if a usage +//! can be resolved to more than one definitions, we error out. The resolver does not take care +//! of resolving complex usages, such as methods, generic function calls or specialization. + +use std::{collections::HashMap, mem, ops::Index}; use error::{ErrKind, Error}; -use fir::{Fallible, Fir, Incomplete, Mapper, OriginIdx, Pass, Traversal}; +use fir::{Fallible, Fir, Incomplete, Kind, Mapper, OriginIdx, Pass, Traversal}; use flatten::FlattenData; use location::SpanTuple; use symbol::Symbol; mod declarator; mod resolver; +mod scoper; use declarator::Declarator; use resolver::{ResolveKind, Resolver}; +use scoper::Scoper; /// Error reported when an item (variable, function, type) was already declared /// in the current scope. struct UniqueError(OriginIdx, &'static str); -/// A scope contains a set of available variables, functions and types. -#[derive(Clone, Default, Debug)] -struct Scope { - variables: HashMap, - functions: HashMap, - types: HashMap, -} +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] +pub(crate) struct Scope(pub(crate) OriginIdx); -/// A scope map keeps track of the currently available scopes and the current depth -/// level. -#[derive(Clone, Default, Debug)] -struct ScopeMap { - scopes: Vec, -} +impl Scope { + pub fn replace(&mut self, new: OriginIdx) -> OriginIdx { + mem::replace(&mut self.0, new) + } -impl ScopeMap { - fn get( - &self, - key: &Symbol, - scope: usize, - map_extractor: impl Fn(&Scope) -> &HashMap, - ) -> Option<&OriginIdx> { - self.scopes - .get(scope) - // FIXME: This is buggy if there aren't any scopes, so if we get without having inserted first. e.g with the following code - // ```jinko - // name = "jinko"; // no declaration, just a binding, so we "get" without having created the scope first - // ``` - .map_or( - // This is a workaround for now but Wow! it's super fucking ugly - match scope { - 1 => None, - _ => Some(self.scopes.len() - 1), - }, - |_| Some(scope), - ) - .map(|last| &self.scopes[0..=last]) - .and_then(|scopes| { - scopes - .iter() - .map(|scope| map_extractor(scope).get(key)) - .find(|value| value.is_some())? - }) + pub fn origin(&self) -> OriginIdx { + self.0 } +} - fn insert_unique( - &mut self, - key: Symbol, - value: OriginIdx, - scope: usize, - map_extractor: impl Fn(&mut Scope) -> &mut HashMap, - ) -> Result<(), OriginIdx> { - let scope = match self.scopes.get_mut(scope) { - Some(scope) => scope, - None => { - (self.scopes.len()..=scope).for_each(|_| self.scopes.push(Scope::default())); - - &mut self.scopes[scope] - } - }; +/// This data structure maps each node from the [`Fir`] to the scope which contains it. This +/// makes finding the definition associated with a name very easy, as we can simply look at the +/// name's enclosing scope, and look for definitions. If no suitable definition, we look at the +/// parent scope of this scope, and repeat the process, until we find a definition or exhaust +/// valid parents. This struct keeps a reference on a map, making it cheap to copy and pass around. +#[derive(Clone, Copy, Debug)] +struct EnclosingScope<'enclosing>(&'enclosing HashMap); - let map = map_extractor(scope); +impl Index for EnclosingScope<'_> { + type Output = Scope; - match map.get(&key) { - Some(existing) => Err(*existing), - None => { - map.insert(key, value); - Ok(()) - } - } + fn index(&self, index: OriginIdx) -> &Self::Output { + &self.0[&index] } +} - /// Maybe get a variable in any available scopes - fn get_variable(&self, name: &Symbol, scope: usize) -> Option<&OriginIdx> { - self.get(name, scope, |scope| &scope.variables) - } +type Bindings = HashMap; - /// Maybe get a function in any available scopes - fn get_function(&self, name: &Symbol, scope: usize) -> Option<&OriginIdx> { - self.get(name, scope, |scope| &scope.functions) - } +/// Each scope in the [`scopes`] map contains the bindings associated with a given scope, +/// meaning that each scope contains a list of definitions. A definition can be thought of as the +/// mapping of a name ([`Symbol`]) to the node's index in the [`Fir`]. +#[derive(Clone, Debug)] +struct FlatScope<'enclosing> { + scopes: HashMap, + enclosing_scope: EnclosingScope<'enclosing>, +} + +/// Allow iterating on a [`FlatScope`] by going through the chain of enclosing scopes +struct FlatIterator<'scope, 'enclosing>(Option, &'scope FlatScope<'enclosing>); + +trait LookupIterator<'scope, 'enclosing> { + fn lookup_iterator(&'scope self, starting_scope: Scope) -> FlatIterator<'scope, 'enclosing>; +} - /// Maybe get a type in any available scopes - fn get_type(&self, name: &Symbol, scope: usize) -> Option<&OriginIdx> { - self.get(name, scope, |scope| &scope.types) +impl<'scope, 'enclosing> LookupIterator<'scope, 'enclosing> for FlatScope<'enclosing> { + fn lookup_iterator(&'scope self, starting_scope: Scope) -> FlatIterator<'scope, 'enclosing> { + FlatIterator(Some(starting_scope), self) } +} - /// Add a variable to the current scope if it hasn't been added before - fn add_variable( - &mut self, - name: Symbol, - scope: usize, - var: OriginIdx, - ) -> Result<(), UniqueError> { - self.insert_unique(name, var, scope, |scope| &mut scope.variables) - .map_err(|existing| UniqueError(existing, "binding")) +impl<'scope, 'enclosing> Iterator for FlatIterator<'scope, 'enclosing> { + type Item = &'scope Bindings; + + fn next(&mut self) -> Option { + let cursor = self.0; + let bindings = cursor.and_then(|scope| self.1.scopes.get(&scope)); + + self.0 = cursor + // TODO: Factor this in a method? + .and_then(|current| self.1.enclosing_scope.0.get(¤t.origin())) + .copied(); + + bindings } +} - /// Add a function to the current scope if it hasn't been added before - fn add_function( - &mut self, - name: Symbol, - scope: usize, - func: OriginIdx, - ) -> Result<(), UniqueError> { - self.insert_unique(name, func, scope, |scope| &mut scope.functions) - .map_err(|existing| UniqueError(existing, "function")) +impl<'enclosing> FlatScope<'enclosing> { + fn lookup(&self, name: &Symbol, starting_scope: Scope) -> Option<&OriginIdx> { + self.lookup_iterator(starting_scope) + .find_map(|bindings| bindings.get(name)) } - /// Add a type to the current scope if it hasn't been added before - fn add_type( - &mut self, - name: Symbol, - scope: usize, - custom_type: OriginIdx, - ) -> Result<(), UniqueError> { - self.insert_unique(name, custom_type, scope, |scope| &mut scope.types) - .map_err(|existing| UniqueError(existing, "type")) + fn insert(&mut self, name: Symbol, idx: OriginIdx, scope: Scope) -> Result<(), OriginIdx> { + // we need to use the innermost scope here, not `lookup` + if let Some(existing) = self + .scopes + .get(&scope) + .expect("interpreter error: there should always be at least one outer scope") + .get(&name) + { + return Err(*existing); + } + + self.scopes.entry(scope).or_default().insert(name, idx); + + Ok(()) } } -#[derive(Default)] -struct NameResolveCtx { - mappings: ScopeMap, +/// A scope map keeps track of the currently available scopes and the current depth +/// level. +#[derive(Clone, Debug)] +struct ScopeMap<'enclosing> { + pub bindings: FlatScope<'enclosing>, + pub functions: FlatScope<'enclosing>, + pub types: FlatScope<'enclosing>, +} + +struct NameResolveCtx<'enclosing> { + enclosing_scope: EnclosingScope<'enclosing>, + mappings: ScopeMap<'enclosing>, +} + +impl<'enclosing> NameResolveCtx<'enclosing> { + fn new(enclosing_scope: EnclosingScope<'enclosing>) -> NameResolveCtx { + let empty_scope_map: HashMap = enclosing_scope + .0 + .values() + .map(|scope_idx| (*scope_idx, Bindings::new())) + .collect(); + + NameResolveCtx { + enclosing_scope, + mappings: ScopeMap { + bindings: FlatScope { + scopes: empty_scope_map.clone(), + enclosing_scope, + }, + functions: FlatScope { + scopes: empty_scope_map.clone(), + enclosing_scope, + }, + types: FlatScope { + scopes: empty_scope_map, + enclosing_scope, + }, + }, + } + } } /// Extension type of [`Error`] to be able to implement [`IterError`]. @@ -224,12 +271,14 @@ impl NameResolutionError { // TODO: Go through mappings again to find a relevant type or var which could work Error::new(ErrKind::NameResolution) - .with_msg(format!("unresolved binding to {sym}")) + .with_msg(format!("unresolved binding to `{sym}`")) .with_loc(location) .with_hint( - Error::hint().with_msg(format!("searched for empty type named {sym}")), + Error::hint().with_msg(format!("searched for empty type named `{sym}`")), + ) + .with_hint( + Error::hint().with_msg(format!("searched for binding named `{sym}`")), ) - .with_hint(Error::hint().with_msg(format!("searched for binding named {sym}"))) } } } @@ -263,7 +312,26 @@ impl NameResolutionError { } } -impl NameResolveCtx { +impl<'enclosing> NameResolveCtx<'enclosing> { + fn scope(fir: &Fir) -> HashMap { + let root = fir.nodes.last_key_value().unwrap(); + + let mut scoper = Scoper { + current_scope: Scope(*root.0), + enclosing_scope: HashMap::new(), + }; + + let Kind::Statements(stmts) = &root.1.kind else { + unreachable!() + }; + + stmts + .iter() + .for_each(|stmt| scoper.traverse_node(fir, &fir[stmt]).unwrap()); + + scoper.enclosing_scope + } + fn insert_definitions(&mut self, fir: &Fir) -> Fallible> { Declarator(self).traverse(fir) } @@ -276,7 +344,9 @@ impl NameResolveCtx { } } -impl<'ast> Pass, FlattenData<'ast>, Error> for NameResolveCtx { +impl<'ast, 'enclosing> Pass, FlattenData<'ast>, Error> + for NameResolveCtx<'enclosing> +{ fn pre_condition(_fir: &Fir) {} fn post_condition(_fir: &Fir) {} @@ -311,7 +381,9 @@ pub trait NameResolve<'ast> { impl<'ast> NameResolve<'ast> for Fir> { fn name_resolve(self) -> Result>, Error> { - let mut ctx = NameResolveCtx::default(); + // TODO: Ugly asf + let enclosing_scope = NameResolveCtx::scope(&self); + let mut ctx = NameResolveCtx::new(EnclosingScope(&enclosing_scope)); ctx.pass(self) } @@ -337,9 +409,11 @@ mod tests { let a_reference = &fir.nodes[&OriginIdx(3)]; assert!(matches!(a_reference.kind, Kind::TypedValue { .. })); - let a_reference = match a_reference.kind { - Kind::TypedValue { value, .. } => value, - _ => unreachable!(), + let Kind::TypedValue { + value: a_reference, .. + } = a_reference.kind + else { + unreachable!() }; assert_eq!(a_reference, RefIdx::Resolved(a.origin)); @@ -409,7 +483,7 @@ mod tests { let fir = ast.flatten().name_resolve(); - assert!(fir.is_err()) + assert!(fir.is_err()); } #[test] @@ -478,9 +552,9 @@ mod tests { let fir = ast.flatten().name_resolve().unwrap(); - let x_value = &fir.nodes[&OriginIdx(3)]; - let marker_1 = &fir.nodes[&OriginIdx(2)]; - let marker_2 = &fir.nodes[&OriginIdx(1)]; + let marker_1 = dbg!(&fir.nodes[&OriginIdx(1)]); + let marker_2 = dbg!(&fir.nodes[&OriginIdx(2)]); + let x_value = dbg!(&fir.nodes[&OriginIdx(3)]); assert!(matches!(x_value.kind, Kind::TypedValue { .. })); assert!(matches!(marker_1.kind, Kind::Type { .. })); @@ -495,4 +569,19 @@ mod tests { _ => unreachable!(), } } + + #[test] + fn fail_resolution_to_fn_arg() { + let ast = ast! { + type A; + + func f(a: A) {} + + where x = a; // invalid + }; + + let fir = ast.flatten().name_resolve(); + + assert!(fir.is_err()); + } } diff --git a/name_resolve/src/resolver.rs b/name_resolve/src/resolver.rs index 80b6b458..df30e14a 100644 --- a/name_resolve/src/resolver.rs +++ b/name_resolve/src/resolver.rs @@ -24,24 +24,27 @@ impl Display for ResolveKind { } } -pub(crate) struct Resolver<'ctx>(pub(crate) &'ctx mut NameResolveCtx); +pub(crate) struct Resolver<'ctx, 'enclosing>(pub(crate) &'ctx mut NameResolveCtx<'enclosing>); -impl<'ctx> Resolver<'ctx> { +impl<'ctx, 'enclosing> Resolver<'ctx, 'enclosing> { fn get_definition( &self, kind: ResolveKind, sym: Option<&Symbol>, location: &SpanTuple, - scope: usize, + node: OriginIdx, ) -> Result { let symbol = sym.expect("attempting to get definition for non existent symbol - interpreter bug"); let mappings = &self.0.mappings; + + let scope = self.0.enclosing_scope[dbg!(node)]; + let origin = match kind { - ResolveKind::Call => mappings.get_function(symbol, scope), - ResolveKind::Type => mappings.get_type(symbol, scope), - ResolveKind::Var => mappings.get_variable(symbol, scope), + ResolveKind::Call => mappings.functions.lookup(symbol, scope), + ResolveKind::Type => mappings.types.lookup(symbol, scope), + ResolveKind::Var => mappings.bindings.lookup(symbol, scope), }; origin.map_or_else( @@ -51,8 +54,8 @@ impl<'ctx> Resolver<'ctx> { } } -impl<'ast, 'ctx> Mapper, FlattenData<'ast>, NameResolutionError> - for Resolver<'ctx> +impl<'ast, 'ctx, 'enclosing> Mapper, FlattenData<'ast>, NameResolutionError> + for Resolver<'ctx, 'enclosing> { fn map_call( &mut self, @@ -75,7 +78,7 @@ impl<'ast, 'ctx> Mapper, FlattenData<'ast>, NameResolutionErro ResolveKind::Call, data.ast.symbol(), data.ast.location(), - data.scope, + origin, )?; Ok(Node { @@ -102,13 +105,13 @@ impl<'ast, 'ctx> Mapper, FlattenData<'ast>, NameResolutionErro ResolveKind::Var, data.ast.symbol(), data.ast.location(), - data.scope, + origin, ); let ty_def = self.get_definition( ResolveKind::Type, data.ast.symbol(), data.ast.location(), - data.scope, + origin, ); // If we're dealing with a type definition, we can "early typecheck" @@ -119,7 +122,6 @@ impl<'ast, 'ctx> Mapper, FlattenData<'ast>, NameResolutionErro }; let definition = match (var_def, ty_def) { - (Ok(def), Err(_)) | (Err(_), Ok(def)) => Ok(def), (Ok(var_def), Ok(ty_def)) => Err(NameResolutionError::ambiguous_binding( var_def, ty_def, @@ -129,6 +131,7 @@ impl<'ast, 'ctx> Mapper, FlattenData<'ast>, NameResolutionErro data.ast.symbol(), data.ast.location(), )), + (Ok(def), _) | (_, Ok(def)) => Ok(def), }?; Ok(Node { @@ -151,7 +154,7 @@ impl<'ast, 'ctx> Mapper, FlattenData<'ast>, NameResolutionErro ResolveKind::Type, data.ast.symbol(), data.ast.location(), - data.scope, + origin, )?; Ok(Node { diff --git a/name_resolve/src/scoper.rs b/name_resolve/src/scoper.rs new file mode 100644 index 00000000..0a2b5e4c --- /dev/null +++ b/name_resolve/src/scoper.rs @@ -0,0 +1,194 @@ +use std::collections::HashMap; + +use crate::Scope; + +use fir::{Fallible, Fir, Kind, Node, OriginIdx, RefIdx, Traversal}; +use flatten::FlattenData; + +pub(crate) struct Scoper { + /// The current scope we are visiting, which we will use when we assign a scope to each + /// node in [`Scoper::scope`] + pub(crate) current_scope: Scope, + /// Map of each node to the scope it is contained in. This will be built progressively as + /// we visit each node + pub(crate) enclosing_scope: HashMap, +} + +impl Scoper { + /// Set the enclosing scope of `to_scope` to the current scope + fn scope(&mut self, to_scope: &Node) { + self.enclosing_scope + .insert(to_scope.origin, self.current_scope); + } + + /// Enter a new scope, replacing the context's current scope. This returns the old scope, + /// which you will need to reuse when you exit the scoped node you are visiting + fn enter_scope(&mut self, new_scope: OriginIdx) -> OriginIdx { + self.current_scope.replace(new_scope) + } + + // TODO: Move this function in `Traversal`? + fn maybe_visit_child(&mut self, fir: &Fir>, ref_idx: &RefIdx) -> Fallible<()> { + match ref_idx { + RefIdx::Resolved(origin) => self.traverse_node(fir, &fir[origin]), + // we skip unresolved nodes here + RefIdx::Unresolved => Ok(()), + } + } +} + +impl<'ast> Traversal, () /* FIXME: Ok to have void as an error type? */> + for Scoper +{ + fn traverse_assignment( + &mut self, + fir: &Fir>, + _node: &Node>, + to: &RefIdx, + from: &RefIdx, + ) -> Fallible<()> { + self.maybe_visit_child(fir, to)?; + self.maybe_visit_child(fir, from) + } + + fn traverse_function( + &mut self, + fir: &Fir>, + node: &Node>, + generics: &[RefIdx], + args: &[RefIdx], + return_ty: &Option, + block: &Option, + ) -> Fallible<()> { + let old = self.enter_scope(node.origin); + + generics + .iter() + .for_each(|generic| self.maybe_visit_child(fir, generic).unwrap()); + + args.iter() + .for_each(|arg| self.maybe_visit_child(fir, arg).unwrap()); + + block.map(|definition| self.maybe_visit_child(fir, &definition)); + return_ty.map(|ty| self.maybe_visit_child(fir, &ty)); + + self.enter_scope(old); + + Ok(()) + } + + fn traverse_statements( + &mut self, + fir: &Fir>, + node: &Node>, + stmts: &[RefIdx], + ) -> Fallible<()> { + // TODO: Ugly but can we do anything better? Can we have types which force you to exit a scope if you enter one? + let old = self.enter_scope(node.origin); + + stmts + .iter() + .for_each(|stmt| self.maybe_visit_child(fir, stmt).unwrap()); + + self.enter_scope(old); + + Ok(()) + } + + fn traverse_node( + &mut self, + fir: &Fir>, + node: &Node>, + ) -> Fallible<()> { + self.scope(node); + + match &node.kind { + Kind::TypeReference(sub_node) + | Kind::TypeOffset { + instance: sub_node, .. + } + | Kind::Binding { to: sub_node } => self.maybe_visit_child(fir, sub_node), + Kind::TypedValue { value, ty } => { + self.maybe_visit_child(fir, value)?; + self.maybe_visit_child(fir, ty) + } + Kind::Type { fields, .. } => { + let old = self.enter_scope(node.origin); + + fields.iter().for_each(|field| { + // FIXME: Is unwrap okay here? + self.maybe_visit_child(fir, field).unwrap(); + }); + + self.enter_scope(old); + + Ok(()) + } + Kind::Generic { default } => default + .map(|def| self.maybe_visit_child(fir, &def)) + .ok_or(())?, + Kind::Assignment { to, from } => self.traverse_assignment(fir, node, to, from), + Kind::Instantiation { + to, + generics, + fields, + } => { + self.maybe_visit_child(fir, to)?; + generics + .iter() + .for_each(|generic| self.maybe_visit_child(fir, generic).unwrap()); + + fields + .iter() + .for_each(|field| { self.maybe_visit_child(fir, field) }.unwrap()); + + Ok(()) + } + Kind::Call { to, generics, args } => { + self.maybe_visit_child(fir, to)?; + generics + .iter() + .for_each(|generic| self.maybe_visit_child(fir, generic).unwrap()); + args.iter() + .for_each(|arg| self.maybe_visit_child(fir, arg).unwrap()); + + Ok(()) + } + Kind::Function { + generics, + args, + return_type, + block, + } => self.traverse_function(fir, node, generics, args, return_type, block), + Kind::Statements(stmts) => self.traverse_statements(fir, node, stmts), + Kind::Conditional { + condition, + true_block, + false_block, + } => { + self.maybe_visit_child(fir, condition)?; + self.maybe_visit_child(fir, true_block)?; + false_block + .map(|else_block| self.maybe_visit_child(fir, &else_block)) + .ok_or(())? + } + Kind::Return(sub_node) => sub_node + .map(|node| self.maybe_visit_child(fir, &node)) + .ok_or(())?, + Kind::Loop { condition, block } => { + self.maybe_visit_child(fir, condition)?; + self.maybe_visit_child(fir, block)?; + + Ok(()) + } + // nothing to do for constants, other than scoping them + Kind::Constant(_) => Ok(()), + } + } + + /// Nothing to do here, right? This function should never be called + /// FIXME: Add documentation + fn traverse(&mut self, _fir: &Fir>) -> Fallible> { + unreachable!() + } +} diff --git a/typecheck/src/actual.rs b/typecheck/src/actual.rs index 712d7241..7521b4f2 100644 --- a/typecheck/src/actual.rs +++ b/typecheck/src/actual.rs @@ -10,7 +10,7 @@ use crate::{Type, TypeCtx}; pub(crate) struct Actual<'ctx>(pub(crate) &'ctx mut TypeCtx); fn innermost_type(fir: &Fir, linked_node: RefIdx) -> Option { - let linked_node = &fir.nodes[&linked_node.unwrap()]; + let linked_node = &fir.nodes[&linked_node.expect_resolved()]; let inner_opt = |fir, opt| match opt { Some(opt) => innermost_type(fir, opt), diff --git a/typecheck/src/checker.rs b/typecheck/src/checker.rs index 62f3e7ca..f9b26491 100644 --- a/typecheck/src/checker.rs +++ b/typecheck/src/checker.rs @@ -20,7 +20,7 @@ impl<'ctx> Checker<'ctx> { fn get_type(&self, of: &RefIdx) -> Option { // if at this point, the reference is unresolved, or if we haven't seen that node yet, it's // an interpreter error - *self.0.types.get(&of.unwrap()).unwrap() + *self.0.types.get(&of.expect_resolved()).unwrap() } } @@ -32,7 +32,7 @@ fn type_mismatch( ) -> Error { let get_symbol = |ty| { let Type::One(idx) = ty; - fir.nodes[&idx.unwrap()].data.ast.symbol().unwrap() + fir.nodes[&idx.expect_resolved()].data.ast.symbol().unwrap() }; let name_fmt = |ty: Option<&Symbol>| match ty { Some(ty) => format!("`{}`", ty.access().purple()), diff --git a/typecheck/src/primitives.rs b/typecheck/src/primitives.rs index 39588376..b583c8f0 100644 --- a/typecheck/src/primitives.rs +++ b/typecheck/src/primitives.rs @@ -39,7 +39,7 @@ fn validate_type( let collect_locs = |many_refs: &[RefIdx]| { many_refs .iter() - .map(|generic| &fir.nodes[&generic.unwrap()]) + .map(|generic| &fir.nodes[&generic.expect_resolved()]) .map(|node| node.data.ast.location().clone()) .collect::>() };