diff --git a/typecheck/src/actual.rs b/typecheck/src/actual.rs index 1fd1d986..9c533ba2 100644 --- a/typecheck/src/actual.rs +++ b/typecheck/src/actual.rs @@ -82,15 +82,9 @@ impl<'ctx> TypeLinkResolver<'ctx> { // we don't insert here so that we can get the typeref directly later on - does that make sense? // self.new.types.new_type(final_type.clone()); - let final_type = if self.old.primitives.is_primitive_union_type(r) { - Type::PrimitiveUnion(*r) - } else { - Type::record(*r) - }; - ControlFlow::Break(ChainEnd { intermediate_nodes, - final_type, + final_type: Type::record(*r), }) } TypeVariable::Union(u) => { diff --git a/typecheck/src/checker.rs b/typecheck/src/checker.rs index d5f4b11c..6db8b981 100644 --- a/typecheck/src/checker.rs +++ b/typecheck/src/checker.rs @@ -84,7 +84,7 @@ impl<'ctx> Checker<'ctx> { BuiltinType::Bool => Type::record(self.0.primitives.bool_type), }; - if valid_union_type.is_superset_of(&expected_ty) { + if expected_ty.can_widen_to(&valid_union_type) { Ok(vec![expected_ty; arity]) } else { Err(unexpected_arithmetic_type( @@ -209,14 +209,14 @@ impl<'ctx> Traversal, Error> for Checker<'ctx> { return Ok(()); } - let ret_ty = return_ty - .as_ref() - .map(|b| self.get_type(b)) - .unwrap_or(self.unit()); - let block_ty = block - .as_ref() - .map(|b| self.get_type(b)) - .unwrap_or(self.unit()); + let type_or_unit = |ty: &Option| { + ty.as_ref() + .map(|ty| self.get_type(ty)) + .unwrap_or(self.unit()) + }; + + let ret_ty = type_or_unit(return_ty); + let block_ty = type_or_unit(block); if !block_ty.can_widen_to(ret_ty) { let err = type_mismatch( diff --git a/typecheck/src/collectors.rs b/typecheck/src/collectors.rs new file mode 100644 index 00000000..672c206b --- /dev/null +++ b/typecheck/src/collectors.rs @@ -0,0 +1,4 @@ +//! A collection of collectors (heh) necessary for the well-being of the type system. + +pub mod constants; +pub mod primitives; diff --git a/typecheck/src/collectors/constants.rs b/typecheck/src/collectors/constants.rs new file mode 100644 index 00000000..9a9e72d5 --- /dev/null +++ b/typecheck/src/collectors/constants.rs @@ -0,0 +1,57 @@ +//! Collect all primitive union type constants in the program in order to build our primitive union types properly. There are three primitive union types: `char`, `int` and `string`, so this module collects all character, integer and string constants. + +use std::collections::HashSet; +use std::convert::Infallible; + +use fir::{Fallible, Fir, Node, OriginIdx, RefIdx, Traversal}; +use flatten::{AstInfo, FlattenData}; + +#[derive(Default)] +pub struct ConstantCollector { + pub(crate) integers: HashSet, + pub(crate) characters: HashSet, + pub(crate) strings: HashSet, +} + +impl ConstantCollector { + pub fn new() -> ConstantCollector { + ConstantCollector::default() + } + + fn add_integer(&mut self, idx: OriginIdx) { + self.integers.insert(RefIdx::Resolved(idx)); + } + + fn add_character(&mut self, idx: OriginIdx) { + self.characters.insert(RefIdx::Resolved(idx)); + } + + fn add_string(&mut self, idx: OriginIdx) { + self.strings.insert(RefIdx::Resolved(idx)); + } +} + +impl Traversal, Infallible> for ConstantCollector { + fn traverse_constant( + &mut self, + _: &Fir>, + node: &Node>, + _: &RefIdx, + ) -> Fallible { + match node.data.ast { + AstInfo::Node(ast::Ast { + node: ast::Node::Constant(value), + .. + }) => match value { + ast::Value::Integer(_) => self.add_integer(node.origin), + ast::Value::Char(_) => self.add_character(node.origin), + ast::Value::Str(_) => self.add_string(node.origin), + // do nothing - the other constants are not part of primitive union types + _ => {} + }, + _ => unreachable!("Fir constant with non-node AST info. this is an interpreter error"), + }; + + Ok(()) + } +} diff --git a/typecheck/src/primitives.rs b/typecheck/src/collectors/primitives.rs similarity index 94% rename from typecheck/src/primitives.rs rename to typecheck/src/collectors/primitives.rs index 1e7b1623..8bfa47be 100644 --- a/typecheck/src/primitives.rs +++ b/typecheck/src/collectors/primitives.rs @@ -32,17 +32,6 @@ pub struct PrimitiveTypes { pub(crate) string_type: OriginIdx, } -impl PrimitiveTypes { - /// Check if a given type declaration's [`OriginIdx`] corresponds to a primitive union type. There - /// are three primitive union types: `char`, `int` and `string`. - pub fn is_primitive_union_type(&self, ty: &OriginIdx) -> bool { - match *ty { - idx if idx == self.char_type || idx == self.int_type || idx == self.string_type => true, - _ => false, - } - } -} - fn validate_type( fir: &Fir>, sym: &Symbol, diff --git a/typecheck/src/lib.rs b/typecheck/src/lib.rs index 9c03e53b..abee43b4 100644 --- a/typecheck/src/lib.rs +++ b/typecheck/src/lib.rs @@ -1,20 +1,23 @@ mod actual; mod checker; -mod primitives; +mod collectors; mod typemap; mod typer; use std::collections::{HashMap, HashSet}; use error::{ErrKind, Error}; -use fir::{Fir, Incomplete, Mapper, OriginIdx, Pass, RefIdx, Traversal}; +use fir::{Fir, Incomplete, Kind, Mapper, OriginIdx, Pass, RefIdx, Traversal}; use flatten::FlattenData; use actual::Actual; use checker::Checker; use typer::Typer; -use primitives::PrimitiveTypes; +use collectors::{ + constants::ConstantCollector, + primitives::{self, PrimitiveTypes}, +}; #[derive(Clone, Debug, Eq, PartialEq)] // FIXME: Should that be a hashset RefIdx or OriginIdx? @@ -44,20 +47,15 @@ impl TypeSet { /// This is of course not a realistic definition to put in our standard library (and it gets worse for `string`) /// so these types have to be handled separately. #[derive(Clone, Debug, Eq, PartialEq)] -pub enum Type { - PrimitiveUnion(OriginIdx), - Set(OriginIdx, TypeSet), -} +pub struct Type(OriginIdx, TypeSet); impl Type { pub fn origin(&self) -> OriginIdx { - match self { - Type::PrimitiveUnion(idx) | Type::Set(idx, _) => *idx, - } + self.0 } pub fn builtin(set: HashSet) -> Type { - Type::Set(OriginIdx(u64::MAX), TypeSet(set)) + Type(OriginIdx(u64::MAX), TypeSet(set)) } pub fn record(origin: OriginIdx) -> Type { @@ -65,26 +63,23 @@ impl Type { // FIXME: Switch to keeping HashSet instead set.insert(RefIdx::Resolved(origin)); - Type::Set(origin, TypeSet(set)) + Type(origin, TypeSet(set)) } pub fn union(origin: OriginIdx, variants: impl Iterator) -> Type { - Type::Set(origin, TypeSet(variants.collect())) + Type(origin, TypeSet(variants.collect())) } pub fn set(&self) -> &TypeSet { - match self { - Type::Set(_, set) => set, - Type::PrimitiveUnion(_) => unreachable!("trying to access a non-representable set for a primitive union. this is an interpreter error"), - } + &self.1 } pub fn is_superset_of(&self, other: &Type) -> bool { - return self.set().contains(other.set()); + self.set().contains(other.set()) } - pub fn can_widen_to(&self, superset: (&Type, ) -> bool { - return superset.set().contains(self.set()); + pub fn can_widen_to(&self, superset: &Type) -> bool { + superset.set().contains(self.set()) } } @@ -110,9 +105,33 @@ pub trait TypeCheck: Sized { } impl<'ast> TypeCheck>> for Fir> { - fn type_check(self) -> Result>, Error> { + fn type_check(mut self) -> Result>, Error> { let primitives = primitives::find(&self)?; + let mut const_collector = ConstantCollector::new(); + const_collector.traverse(&self)?; + + // We can now build our primitive union types. Because the first TypeCtx deals + // with [`TypeVariable`]s, it's not possible to directly create a TypeSet - so + // we can do that later on during typechecking, right before the actual + // typechecking. An alternative is to modify the [`Fir`] directly by creating + // new nodes for these primitive unions, which is probably a little cleaner and + // less spaghetti. + let mk_constant_types = |set: HashSet| set.into_iter().collect(); + + self[primitives.int_type].kind = Kind::UnionType { + generics: vec![], + variants: mk_constant_types(const_collector.integers), + }; + self[primitives.char_type].kind = Kind::UnionType { + generics: vec![], + variants: mk_constant_types(const_collector.characters), + }; + self[primitives.string_type].kind = Kind::UnionType { + generics: vec![], + variants: mk_constant_types(const_collector.strings), + }; + TypeCtx { primitives, types: HashMap::new(), @@ -481,7 +500,7 @@ mod tests { #[test] fn typeset_makes_sense() { - let superset = Type::Set( + let superset = Type( OriginIdx(4), TypeSet( [ @@ -494,7 +513,7 @@ mod tests { .collect(), ), ); - let set = Type::Set( + let set = Type( OriginIdx(5), TypeSet( [ @@ -506,7 +525,7 @@ mod tests { ), ); let single = Type::record(OriginIdx(0)); - let empty = Type::Set(OriginIdx(7), TypeSet(HashSet::new())); + let empty = Type(OriginIdx(7), TypeSet(HashSet::new())); // FIXME: Decide on empty's behavior