diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index a49b1cc391cf84..3732043ec09277 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -1647,7 +1647,7 @@ fn (mut c Checker) selector_expr(mut node ast.SelectorExpr) ast.Type { node.typ = field.typ if node.or_block.kind == .block { c.expected_or_type = node.typ.clear_option_and_result() - c.stmts_ending_with_expression(mut node.or_block.stmts) + c.stmts_ending_with_expression(mut node.or_block.stmts, c.expected_or_type) c.check_or_expr(node.or_block, node.typ, c.expected_or_type, node) c.expected_or_type = ast.void_type } @@ -2593,7 +2593,7 @@ fn (mut c Checker) import_stmt(node ast.Import) { fn (mut c Checker) stmts(mut stmts []ast.Stmt) { old_stmt_level := c.stmt_level c.stmt_level = 0 - c.stmts_ending_with_expression(mut stmts) + c.stmts_ending_with_expression(mut stmts, c.expected_or_type) c.stmt_level = old_stmt_level } @@ -2602,7 +2602,7 @@ fn (mut c Checker) stmts(mut stmts []ast.Stmt) { // `x := opt() or { stmt1 stmt2 ExprStmt }`, // `x := if cond { stmt1 stmt2 ExprStmt } else { stmt2 stmt3 ExprStmt }`, // `x := match expr { Type1 { stmt1 stmt2 ExprStmt } else { stmt2 stmt3 ExprStmt }`. -fn (mut c Checker) stmts_ending_with_expression(mut stmts []ast.Stmt) { +fn (mut c Checker) stmts_ending_with_expression(mut stmts []ast.Stmt, expected_or_type ast.Type) { if stmts.len == 0 { c.scope_returns = false return @@ -2623,7 +2623,10 @@ fn (mut c Checker) stmts_ending_with_expression(mut stmts []ast.Stmt) { unreachable = stmt.pos } } + prev_expected_or_type := c.expected_or_type + c.expected_or_type = expected_or_type c.stmt(mut stmt) + c.expected_or_type = prev_expected_or_type if !c.inside_anon_fn && c.in_for_count > 0 && stmt is ast.BranchStmt && stmt.kind in [.key_continue, .key_break] { c.scope_returns = true @@ -3649,7 +3652,7 @@ fn (mut c Checker) ident(mut node ast.Ident) ast.Type { } unwrapped_typ := typ.clear_option_and_result() c.expected_or_type = unwrapped_typ - c.stmts_ending_with_expression(mut node.or_expr.stmts) + c.stmts_ending_with_expression(mut node.or_expr.stmts, c.expected_or_type) c.check_or_expr(node.or_expr, typ, c.expected_or_type, node) return unwrapped_typ } @@ -3759,7 +3762,7 @@ fn (mut c Checker) ident(mut node ast.Ident) ast.Type { } unwrapped_typ := typ.clear_option_and_result() c.expected_or_type = unwrapped_typ - c.stmts_ending_with_expression(mut node.or_expr.stmts) + c.stmts_ending_with_expression(mut node.or_expr.stmts, c.expected_or_type) c.check_or_expr(node.or_expr, typ, c.expected_or_type, node) return unwrapped_typ } @@ -3821,7 +3824,7 @@ fn (mut c Checker) ident(mut node ast.Ident) ast.Type { if node.or_expr.kind != .absent { unwrapped_typ := typ.clear_option_and_result() c.expected_or_type = unwrapped_typ - c.stmts_ending_with_expression(mut node.or_expr.stmts) + c.stmts_ending_with_expression(mut node.or_expr.stmts, c.expected_or_type) c.check_or_expr(node.or_expr, typ, c.expected_or_type, node) } return typ @@ -4448,7 +4451,7 @@ fn (mut c Checker) prefix_expr(mut node ast.PrefixExpr) ast.Type { if node.op == .arrow { raw_right_sym := c.table.final_sym(right_type) if raw_right_sym.kind == .chan { - c.stmts_ending_with_expression(mut node.or_block.stmts) + c.stmts_ending_with_expression(mut node.or_block.stmts, c.expected_or_type) return raw_right_sym.chan_info().elem_type } c.type_error_for_operator('<-', '`chan`', raw_right_sym.name, node.pos) @@ -4645,7 +4648,7 @@ fn (mut c Checker) index_expr(mut node ast.IndexExpr) ast.Type { if node.or_expr.stmts.len > 0 && node.or_expr.stmts.last() is ast.ExprStmt { c.expected_or_type = typ } - c.stmts_ending_with_expression(mut node.or_expr.stmts) + c.stmts_ending_with_expression(mut node.or_expr.stmts, c.expected_or_type) c.check_expr_option_or_result_call(node, typ) return typ } diff --git a/vlib/v/checker/comptime.v b/vlib/v/checker/comptime.v index 2c7b3e8aee7844..c4de20371b57a5 100644 --- a/vlib/v/checker/comptime.v +++ b/vlib/v/checker/comptime.v @@ -123,7 +123,7 @@ fn (mut c Checker) comptime_call(mut node ast.ComptimeCall) ast.Type { // check each arg expression node.args[i].typ = c.expr(mut arg.expr) } - c.stmts_ending_with_expression(mut node.or_block.stmts) + c.stmts_ending_with_expression(mut node.or_block.stmts, c.expected_or_type) return c.comptime.get_comptime_var_type(node) } if node.method_name == 'res' { diff --git a/vlib/v/checker/fn.v b/vlib/v/checker/fn.v index 20eff3b01d7bf0..aa59cb5619c10c 100644 --- a/vlib/v/checker/fn.v +++ b/vlib/v/checker/fn.v @@ -600,7 +600,7 @@ fn (mut c Checker) call_expr(mut node ast.CallExpr) ast.Type { } } c.expected_or_type = node.return_type.clear_flag(.result) - c.stmts_ending_with_expression(mut node.or_block.stmts) + c.stmts_ending_with_expression(mut node.or_block.stmts, c.expected_or_type) c.expected_or_type = ast.void_type if !c.inside_const && c.table.cur_fn != unsafe { nil } && !c.table.cur_fn.is_main diff --git a/vlib/v/checker/if.v b/vlib/v/checker/if.v index bf716f0bfac61d..88fd2427e12ee6 100644 --- a/vlib/v/checker/if.v +++ b/vlib/v/checker/if.v @@ -324,7 +324,7 @@ fn (mut c Checker) if_expr(mut node ast.IfExpr) ast.Type { } if !c.skip_flags { if node_is_expr { - c.stmts_ending_with_expression(mut branch.stmts) + c.stmts_ending_with_expression(mut branch.stmts, c.expected_or_type) } else { c.stmts(mut branch.stmts) c.check_non_expr_branch_last_stmt(branch.stmts) @@ -341,7 +341,7 @@ fn (mut c Checker) if_expr(mut node ast.IfExpr) ast.Type { node.branches[i].stmts = [] } if node_is_expr { - c.stmts_ending_with_expression(mut branch.stmts) + c.stmts_ending_with_expression(mut branch.stmts, c.expected_or_type) } else { c.stmts(mut branch.stmts) c.check_non_expr_branch_last_stmt(branch.stmts) @@ -364,7 +364,7 @@ fn (mut c Checker) if_expr(mut node ast.IfExpr) ast.Type { // smartcast sumtypes and interfaces when using `is` c.smartcast_if_conds(mut branch.cond, mut branch.scope) if node_is_expr { - c.stmts_ending_with_expression(mut branch.stmts) + c.stmts_ending_with_expression(mut branch.stmts, c.expected_or_type) } else { c.stmts(mut branch.stmts) c.check_non_expr_branch_last_stmt(branch.stmts) diff --git a/vlib/v/checker/infix.v b/vlib/v/checker/infix.v index 36bbab2c580433..cebf2591833473 100644 --- a/vlib/v/checker/infix.v +++ b/vlib/v/checker/infix.v @@ -757,7 +757,7 @@ fn (mut c Checker) infix_expr(mut node ast.InfixExpr) ast.Type { c.error('cannot push `${c.table.type_to_str(right_type)}` on `${left_sym.name}`', right_pos) } - c.stmts_ending_with_expression(mut node.or_block.stmts) + c.stmts_ending_with_expression(mut node.or_block.stmts, c.expected_or_type) } else { c.error('cannot push on non-channel `${left_sym.name}`', left_pos) } diff --git a/vlib/v/checker/match.v b/vlib/v/checker/match.v index 9e33f7e95eb310..d65fbf633a656d 100644 --- a/vlib/v/checker/match.v +++ b/vlib/v/checker/match.v @@ -43,7 +43,7 @@ fn (mut c Checker) match_expr(mut node ast.MatchExpr) ast.Type { mut nbranches_without_return := 0 for mut branch in node.branches { if node.is_expr { - c.stmts_ending_with_expression(mut branch.stmts) + c.stmts_ending_with_expression(mut branch.stmts, c.expected_or_type) } else { c.stmts(mut branch.stmts) } diff --git a/vlib/v/checker/orm.v b/vlib/v/checker/orm.v index c67465e60e794e..33151bb027b4a3 100644 --- a/vlib/v/checker/orm.v +++ b/vlib/v/checker/orm.v @@ -626,7 +626,7 @@ fn (mut c Checker) check_orm_or_expr(mut expr ORMExpr) { if expr.or_expr.kind == .block { c.expected_or_type = return_type.clear_flag(.result) - c.stmts_ending_with_expression(mut expr.or_expr.stmts) + c.stmts_ending_with_expression(mut expr.or_expr.stmts, c.expected_or_type) c.expected_or_type = ast.void_type } } diff --git a/vlib/v/tests/result_call_or_block_with_stmts_test.v b/vlib/v/tests/result_call_or_block_with_stmts_test.v new file mode 100644 index 00000000000000..8a46bb6b07bdd4 --- /dev/null +++ b/vlib/v/tests/result_call_or_block_with_stmts_test.v @@ -0,0 +1,28 @@ +fn str_ret_fn(str string) string { + return str +} + +fn foo(str string) !string { + if str.contains('foo') { + return str + } + return error('error') +} + +fn test_result_call_or_block_with_stmts() { + var := foo('bar') or { + foo_var := str_ret_fn('foo') + if foo_var == 'foo' { + foo(foo_var) or { + eprintln(err) + exit(1) + } + } else { + eprintln(err) + exit(1) + } + } + + println(var) + assert var == 'foo' +}