Skip to content

Commit

Permalink
typecheck: typecheck function calls and arguments properly
Browse files Browse the repository at this point in the history
  • Loading branch information
CohenArthur committed Oct 29, 2023
1 parent 5db3313 commit ff875b9
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 23 deletions.
127 changes: 108 additions & 19 deletions typecheck/src/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@ use error::{ErrKind, Error};
use fir::{Fallible, Fir, Node, RefIdx, Traversal};
use flatten::FlattenData;
use location::SpanTuple;
use symbol::Symbol;

use colored::Colorize;

use crate::{Type, TypeCtx};

Expand All @@ -24,33 +21,66 @@ impl<'ctx> Checker<'ctx> {
}
}

fn type_mismatch(
loc: &SpanTuple,
fir: &Fir<FlattenData>,
expected: Option<Type>,
got: Option<Type>,
) -> Error {
mod format {
use colored::Colorize;
use symbol::Symbol;

pub fn number(value: usize) -> String {
match value {
0 => format!("{}", "no".purple()),
rest => format!("{}", rest.to_string().purple()),
}
}

pub fn plural(to_pluralize: &str, value: usize) -> String {
match value {
1 => to_pluralize.to_string(),
_ => format!("{to_pluralize}s"),
}
}

pub fn ty(ty: Option<&Symbol>) -> String {
match ty {
Some(ty) => format!("`{}`", ty.access().purple()),
None => format!("{}", "no type".green()),
}
}
}

struct Expected(Option<Type>);
struct Got(Option<Type>);

fn type_mismatch(loc: &SpanTuple, fir: &Fir<FlattenData>, expected: Expected, got: Got) -> Error {
let get_symbol = |ty| {
let Type::One(idx) = ty;
fir.nodes[&idx.expect_resolved()].data.ast.symbol().unwrap()
};
let name_fmt = |ty: Option<&Symbol>| match ty {
Some(ty) => format!("`{}`", ty.access().purple()),
None => format!("{}", "no type".green()),
};

let expected_ty = expected.map(get_symbol);
let got_ty = got.map(get_symbol);
let expected_ty = expected.0.map(get_symbol);
let got_ty = got.0.map(get_symbol);

Error::new(ErrKind::TypeChecker)
.with_msg(format!(
"type mismatch found: expected {}, got {}",
name_fmt(expected_ty),
name_fmt(got_ty)
format::ty(expected_ty),
format::ty(got_ty)
))
.with_loc(loc.clone()) // FIXME: Missing hint
}

fn argument_count_mismatch(loc: &SpanTuple, expected: &[RefIdx], got: &[RefIdx]) -> Error {
Error::new(ErrKind::TypeChecker)
.with_msg(format!(
"argument count mismatch: expected {} {}, got {} {}",
format::number(expected.len()),
format::plural("argument", expected.len()),
format::number(got.len()),
format::plural("argument", got.len()),
))
.with_loc(loc.clone())
// FIXME: missing hint
}

impl<'ctx> Traversal<FlattenData<'_>, Error> for Checker<'ctx> {
fn traverse_function(
&mut self,
Expand All @@ -70,7 +100,12 @@ impl<'ctx> Traversal<FlattenData<'_>, Error> for Checker<'ctx> {
let block_ty = block.as_ref().and_then(|b| self.get_type(b));

if ret_ty != block_ty {
let err = type_mismatch(node.data.ast.location(), fir, ret_ty, block_ty);
let err = type_mismatch(
node.data.ast.location(),
fir,
Expected(ret_ty),
Got(block_ty),
);
let err = match (ret_ty, block_ty) {
(None, Some(_)) => err
.with_hint(Error::hint().with_msg(String::from(
Expand All @@ -95,6 +130,55 @@ impl<'ctx> Traversal<FlattenData<'_>, Error> for Checker<'ctx> {
}
}

fn traverse_call(
&mut self,
fir: &Fir<FlattenData<'_>>,
node: &Node<FlattenData<'_>>,
to: &RefIdx,
_generics: &[RefIdx],
args: &[RefIdx],
) -> Fallible<Error> {
let function = &fir.nodes[&to.expect_resolved()];
let def_args = match &function.kind {
fir::Kind::Function { args, .. } => args,
_ => unreachable!("resolved call to a non-function. this is an interpreter error."),
};

if def_args.len() != args.len() {
return Err(argument_count_mismatch(
node.data.ast.location(),
def_args,
args,
));
}

// now we can safely zip both argument slices
let errs = def_args
.iter()
.zip(args)
.fold(Vec::new(), |mut errs, (def_arg, arg)| {
let expected = self.get_type(def_arg);
let got = self.get_type(arg);

if expected != got {
errs.push(type_mismatch(
node.data.ast.location(),
fir,
Expected(expected),
Got(got),
))
}

errs
});

if errs.is_empty() {
Ok(())
} else {
Err(Error::new(ErrKind::Multiple(errs)))
}
}

fn traverse_assignment(
&mut self,
fir: &Fir<FlattenData>,
Expand All @@ -106,7 +190,12 @@ impl<'ctx> Traversal<FlattenData<'_>, Error> for Checker<'ctx> {
let from = self.get_type(from);

if to != from {
Err(type_mismatch(node.data.ast.location(), fir, to, from))
Err(type_mismatch(
node.data.ast.location(),
fir,
Expected(to),
Got(from),
))
} else {
Ok(())
}
Expand Down
75 changes: 75 additions & 0 deletions typecheck/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,4 +238,79 @@ mod tests {

assert!(fir.is_ok())
}

#[test]
fn typeck_function_call_argument_count_mismatch() {
let ast = ast! {
func foo(one: int, two: int) -> int { one }

foo(15)
};

let fir = fir!(ast).type_check();

assert!(fir.is_err());
}

#[test]
fn typeck_function_call_argument_count_match() {
let ast = ast! {
func foo(one: int, two: int) -> int { one }

foo(15, 14)
};

let fir = fir!(ast).type_check();

assert!(fir.is_ok());
}

#[test]
fn typeck_method_call() {
let ast = ast! {
func foo(one: string, two: int) -> int { two }

"hoo".foo(15)
};

let fir = fir!(ast).type_check();

assert!(fir.is_ok());
}

#[test]
fn typeck_method_call2() {
let ast = ast! {
func foo(one: string, two: int, three: char) -> int { two }

"hoo".foo(15, 14)
};

let fir = fir!(ast).type_check();

assert!(fir.is_err());
}

#[test]
fn typeck_call_complex_arg() {
let ast = ast! {
type Marker;

func take_marker(m: Marker) {}
func get_marker() -> Marker { Marker }

take_marker(Marker);

where m = Marker;
take_marker(m);

take_marker(get_marker());

get_marker().take_marker();
};

let fir = fir!(ast).type_check();

assert!(fir.is_ok())
}
}
8 changes: 4 additions & 4 deletions typecheck/src/typer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,9 @@ impl<'ast> Mapper<FlattenData<'ast>, FlattenData<'ast>, Error> for Typer<'_> {
match node.kind {
fir::Kind::Constant(c) => self.map_constant(node.data, node.origin, c),
// Declarations and assignments are void
fir::Kind::Type { .. }
| fir::Kind::Function { .. }
| fir::Kind::Binding { .. }
| fir::Kind::Assignment { .. } => self.ty(node, None),
fir::Kind::Type { .. } | fir::Kind::Function { .. } | fir::Kind::Assignment { .. } => {
self.ty(node, None)
}
// // FIXME: This might be the wrong way to go about this
// // special case where we want to change the `ty` of a `TypedValue`
// fir::Kind::TypedValue {
Expand All @@ -105,6 +104,7 @@ impl<'ast> Mapper<FlattenData<'ast>, FlattenData<'ast>, Error> for Typer<'_> {
ty: RefIdx::Unresolved,
value: ty,
}
| fir::Kind::Binding { to: ty }
| fir::Kind::TypedValue { ty, .. }
| fir::Kind::Instantiation { to: ty, .. }
| fir::Kind::Call { to: ty, .. }
Expand Down

0 comments on commit ff875b9

Please sign in to comment.