Skip to content

Commit

Permalink
checker: fix generic lambda type binding resolution (fix vlang#22109) (
Browse files Browse the repository at this point in the history
  • Loading branch information
felipensp authored Aug 26, 2024
1 parent caa0c46 commit 426205e
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 29 deletions.
9 changes: 9 additions & 0 deletions vlib/v/checker/check_types.v
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,15 @@ fn (mut c Checker) infer_fn_generic_types(func ast.Fn, mut node ast.CallExpr) {
if param_sym.info.func.return_type.nr_muls() > 0 && typ.nr_muls() > 0 {
typ = typ.set_nr_muls(0)
}
// resolve lambda with generic return type
if arg.expr is ast.LambdaExpr && typ.has_flag(.generic) {
typ = c.comptime.resolve_generic_expr(arg.expr.expr, typ)
if typ.has_flag(.generic) {
lambda_ret_gt_name := c.table.type_to_str(typ)
idx := func.generic_names.index(lambda_ret_gt_name)
typ = node.concrete_types[idx]
}
}
}
}
} else if arg_sym.kind in [.struct_, .interface_, .sum_type] {
Expand Down
2 changes: 1 addition & 1 deletion vlib/v/checker/fn.v
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ fn (mut c Checker) fn_decl(mut node ast.FnDecl) {
}
node.params = [node.params[0], ctx_param]
node.params << params[1..]
println('new params ${node.name}')
// println('new params ${node.name}')
// println(node.params)
}
// sym := c.table.sym(typ_veb_context)
Expand Down
5 changes: 5 additions & 0 deletions vlib/v/checker/return.v
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,11 @@ fn (mut c Checker) return_stmt(mut node ast.Return) {
} else {
got_type_sym.name
}
// ignore generic casting expr on lambda in this phase
if c.inside_lambda && exp_type.has_flag(.generic)
&& node.exprs[expr_idxs[i]] is ast.CastExpr {
continue
}
c.error('cannot use `${got_type_name}` as ${c.error_type_name(exp_type)} in return argument',
pos)
}
Expand Down
30 changes: 30 additions & 0 deletions vlib/v/comptime/comptimeinfo.v
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,36 @@ fn (mut ct ComptimeInfo) comptime_get_kind_var(var ast.Ident) ?ast.ComptimeForKi
}
}

pub fn (mut ct ComptimeInfo) resolve_generic_expr(expr ast.Expr, default_typ ast.Type) ast.Type {
match expr {
ast.ParExpr {
return ct.resolve_generic_expr(expr.expr, default_typ)
}
ast.CastExpr {
return expr.typ
}
ast.InfixExpr {
if ct.is_comptime_var(expr.left) {
return ct.resolver.unwrap_generic(ct.get_comptime_var_type(expr.left))
}
if ct.is_comptime_var(expr.right) {
return ct.resolver.unwrap_generic(ct.get_comptime_var_type(expr.right))
}
return default_typ
}
ast.Ident {
return if ct.is_comptime_var(expr) {
ct.resolver.unwrap_generic(ct.get_comptime_var_type(expr))
} else {
default_typ
}
}
else {
return default_typ
}
}
}

pub struct DummyResolver {
mut:
file &ast.File = unsafe { nil }
Expand Down
29 changes: 1 addition & 28 deletions vlib/v/gen/c/fn.v
Original file line number Diff line number Diff line change
Expand Up @@ -1469,39 +1469,12 @@ fn (mut g Gen) resolve_receiver_name(node ast.CallExpr, unwrapped_rec_type ast.T
return receiver_type_name
}

fn (mut g Gen) resolve_generic_expr(expr ast.Expr, default_typ ast.Type) ast.Type {
match expr {
ast.ParExpr {
return g.resolve_generic_expr(expr.expr, default_typ)
}
ast.InfixExpr {
if g.comptime.is_comptime_var(expr.left) {
return g.unwrap_generic(g.comptime.get_comptime_var_type(expr.left))
}
if g.comptime.is_comptime_var(expr.right) {
return g.unwrap_generic(g.comptime.get_comptime_var_type(expr.right))
}
return default_typ
}
ast.Ident {
return if g.comptime.is_comptime_var(expr) {
g.unwrap_generic(g.comptime.get_comptime_var_type(expr))
} else {
default_typ
}
}
else {
return default_typ
}
}
}

fn (mut g Gen) resolve_receiver_type(node ast.CallExpr) (ast.Type, &ast.TypeSymbol) {
left_type := g.unwrap_generic(node.left_type)
mut unwrapped_rec_type := node.receiver_type
if g.cur_fn != unsafe { nil } && g.cur_fn.generic_names.len > 0 { // in generic fn
unwrapped_rec_type = g.unwrap_generic(node.receiver_type)
unwrapped_rec_type = g.resolve_generic_expr(node.left, unwrapped_rec_type)
unwrapped_rec_type = g.comptime.resolve_generic_expr(node.left, unwrapped_rec_type)
} else { // in non-generic fn
sym := g.table.sym(node.receiver_type)
match sym.info {
Expand Down
15 changes: 15 additions & 0 deletions vlib/v/tests/generic_lambda_expr_test.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
pub fn mymap[T, R](input []T, f fn (T) R) []R {
mut results := []R{cap: input.len}
for x in input {
results << f(x)
}
return results
}

fn test_main() {
assert dump(mymap([1, 2, 3, 4, 5], fn (i int) int {
return i * i
})) == [1, 4, 9, 16, 25]
assert dump(mymap([1, 2, 3, 4, 5], |x| x * x)) == [1, 4, 9, 16, 25]
assert dump(mymap([1, 2, 3, 4, 5], |x| u16(x * x))) == [u16(1), 4, 9, 16, 25]
}

0 comments on commit 426205e

Please sign in to comment.