From 6ce84409875e92d69da1fd88409e2f00d4adee9f Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 17 Jun 2021 22:40:07 -0700 Subject: [PATCH] AstGen: properly generate errdefer expressions when returning `return` statements use a new function `nodeMayEvalToError` which does some basic checks on the AST node to return never, always, or maybe. Depending on this result, AstGen skips the errdefers, always includes the errdefers, or emits a conditional branch to check whether the return value is an error that Sema will have to evaluate. Closes #8821 Unblocks #9047 --- src/AstGen.zig | 272 ++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 259 insertions(+), 13 deletions(-) diff --git a/src/AstGen.zig b/src/AstGen.zig index 54500675df..bba82f24a6 100644 --- a/src/AstGen.zig +++ b/src/AstGen.zig @@ -6005,22 +6005,55 @@ fn ret(gz: *GenZir, scope: *Scope, node: ast.Node.Index) InnerError!Zir.Inst.Ref if (gz.in_defer) return astgen.failNode(node, "cannot return from defer expression", .{}); const operand_node = node_datas[node].lhs; - if (operand_node != 0) { - const rl: ResultLoc = if (nodeMayNeedMemoryLocation(tree, operand_node)) .{ - .ptr = try gz.addNodeExtended(.ret_ptr, node), - } else .{ - .ty = try gz.addNodeExtended(.ret_type, node), - }; - const operand = try expr(gz, scope, rl, operand_node); - // TODO check operand to see if we need to generate errdefers + if (operand_node == 0) { + // Returning a void value; skip error defers. try genDefers(gz, &astgen.fn_block.?.base, scope, .none); - _ = try gz.addUnNode(.ret_node, operand, node); + _ = try gz.addUnNode(.ret_node, .void_value, node); return Zir.Inst.Ref.unreachable_value; } - // Returning a void value; skip error defers. - try genDefers(gz, &astgen.fn_block.?.base, scope, .none); - _ = try gz.addUnNode(.ret_node, .void_value, node); - return Zir.Inst.Ref.unreachable_value; + + const rl: ResultLoc = if (nodeMayNeedMemoryLocation(tree, operand_node)) .{ + .ptr = try gz.addNodeExtended(.ret_ptr, node), + } else .{ + .ty = try gz.addNodeExtended(.ret_type, node), + }; + const operand = try expr(gz, scope, rl, operand_node); + + switch (nodeMayEvalToError(tree, operand_node)) { + .never => { + // Returning a value that cannot be an error; skip error defers. + try genDefers(gz, &astgen.fn_block.?.base, scope, .none); + _ = try gz.addUnNode(.ret_node, operand, node); + return Zir.Inst.Ref.unreachable_value; + }, + .always => { + // Value is always an error. Emit both error defers and regular defers. + const err_code = try gz.addUnNode(.err_union_code, operand, node); + try genDefers(gz, &astgen.fn_block.?.base, scope, err_code); + _ = try gz.addUnNode(.ret_node, operand, node); + return Zir.Inst.Ref.unreachable_value; + }, + .maybe => { + // Emit conditional branch for generating errdefers. + const is_err = try gz.addUnNode(.is_err, operand, node); + const condbr = try gz.addCondBr(.condbr, node); + + var then_scope = gz.makeSubBlock(scope); + defer then_scope.instructions.deinit(astgen.gpa); + const err_code = try then_scope.addUnNode(.err_union_code, operand, node); + try genDefers(&then_scope, &astgen.fn_block.?.base, scope, err_code); + _ = try then_scope.addUnNode(.ret_node, operand, node); + + var else_scope = gz.makeSubBlock(scope); + defer else_scope.instructions.deinit(astgen.gpa); + try genDefers(&else_scope, &astgen.fn_block.?.base, scope, .none); + _ = try else_scope.addUnNode(.ret_node, operand, node); + + try setCondBrPayload(condbr, is_err, &then_scope, &else_scope); + + return Zir.Inst.Ref.unreachable_value; + }, + } } fn identifier( @@ -7608,6 +7641,219 @@ fn nodeMayNeedMemoryLocation(tree: *const ast.Tree, start_node: ast.Node.Index) } } +fn nodeMayEvalToError(tree: *const ast.Tree, start_node: ast.Node.Index) enum { never, always, maybe } { + const node_tags = tree.nodes.items(.tag); + const node_datas = tree.nodes.items(.data); + const main_tokens = tree.nodes.items(.main_token); + const token_tags = tree.tokens.items(.tag); + + var node = start_node; + while (true) { + switch (node_tags[node]) { + .root, + .@"usingnamespace", + .test_decl, + .switch_case, + .switch_case_one, + .container_field_init, + .container_field_align, + .container_field, + .asm_output, + .asm_input, + => unreachable, + + .error_value => return .always, + + .@"asm", + .asm_simple, + .identifier, + .field_access, + .deref, + .array_access, + .while_simple, + .while_cont, + .for_simple, + .if_simple, + .@"while", + .@"if", + .@"for", + .@"switch", + .switch_comma, + .call_one, + .call_one_comma, + .async_call_one, + .async_call_one_comma, + .call, + .call_comma, + .async_call, + .async_call_comma, + => return .maybe, + + .@"return", + .@"break", + .@"continue", + .bit_not, + .bool_not, + .global_var_decl, + .local_var_decl, + .simple_var_decl, + .aligned_var_decl, + .@"defer", + .@"errdefer", + .address_of, + .optional_type, + .negation, + .negation_wrap, + .@"resume", + .array_type, + .array_type_sentinel, + .ptr_type_aligned, + .ptr_type_sentinel, + .ptr_type, + .ptr_type_bit_range, + .@"suspend", + .@"anytype", + .fn_proto_simple, + .fn_proto_multi, + .fn_proto_one, + .fn_proto, + .fn_decl, + .anyframe_type, + .anyframe_literal, + .integer_literal, + .float_literal, + .enum_literal, + .string_literal, + .multiline_string_literal, + .char_literal, + .true_literal, + .false_literal, + .null_literal, + .undefined_literal, + .unreachable_literal, + .error_set_decl, + .container_decl, + .container_decl_trailing, + .container_decl_two, + .container_decl_two_trailing, + .container_decl_arg, + .container_decl_arg_trailing, + .tagged_union, + .tagged_union_trailing, + .tagged_union_two, + .tagged_union_two_trailing, + .tagged_union_enum_tag, + .tagged_union_enum_tag_trailing, + .add, + .add_wrap, + .array_cat, + .array_mult, + .assign, + .assign_bit_and, + .assign_bit_or, + .assign_bit_shift_left, + .assign_bit_shift_right, + .assign_bit_xor, + .assign_div, + .assign_sub, + .assign_sub_wrap, + .assign_mod, + .assign_add, + .assign_add_wrap, + .assign_mul, + .assign_mul_wrap, + .bang_equal, + .bit_and, + .bit_or, + .bit_shift_left, + .bit_shift_right, + .bit_xor, + .bool_and, + .bool_or, + .div, + .equal_equal, + .error_union, + .greater_or_equal, + .greater_than, + .less_or_equal, + .less_than, + .merge_error_sets, + .mod, + .mul, + .mul_wrap, + .switch_range, + .sub, + .sub_wrap, + .slice, + .slice_open, + .slice_sentinel, + .array_init_one, + .array_init_one_comma, + .array_init_dot_two, + .array_init_dot_two_comma, + .array_init_dot, + .array_init_dot_comma, + .array_init, + .array_init_comma, + .struct_init_one, + .struct_init_one_comma, + .struct_init_dot_two, + .struct_init_dot_two_comma, + .struct_init_dot, + .struct_init_dot_comma, + .struct_init, + .struct_init_comma, + => return .never, + + // Forward the question to the LHS sub-expression. + .grouped_expression, + .@"try", + .@"await", + .@"comptime", + .@"nosuspend", + .unwrap_optional, + => node = node_datas[node].lhs, + + // Forward the question to the RHS sub-expression. + .@"catch", + .@"orelse", + => node = node_datas[node].rhs, + + .block_two, + .block_two_semicolon, + .block, + .block_semicolon, + => { + const lbrace = main_tokens[node]; + if (token_tags[lbrace - 1] == .colon) { + // Labeled blocks may need a memory location to forward + // to their break statements. + return .maybe; + } else { + return .never; + } + }, + + .builtin_call, + .builtin_call_comma, + .builtin_call_two, + .builtin_call_two_comma, + => { + const builtin_token = main_tokens[node]; + const builtin_name = tree.tokenSlice(builtin_token); + // If the builtin is an invalid name, we don't cause an error here; instead + // let it pass, and the error will be "invalid builtin function" later. + const builtin_info = BuiltinFn.list.get(builtin_name) orelse return .maybe; + if (builtin_info.tag == .err_set_cast) { + return .always; + } else { + return .never; + } + }, + } + } +} + /// Applies `rl` semantics to `inst`. Expressions which do not do their own handling of /// result locations must call this function on their result. /// As an example, if the `ResultLoc` is `ptr`, it will write the result to the pointer.