diff --git a/src/all_types.hpp b/src/all_types.hpp index 38964d0091..48323e58ad 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -312,6 +312,7 @@ enum LazyValueId { LazyValueIdOptType, LazyValueIdSliceType, LazyValueIdFnType, + LazyValueIdErrUnionType, }; struct LazyValue { @@ -372,6 +373,14 @@ struct LazyValueFnType { bool is_generic; }; +struct LazyValueErrUnionType { + LazyValue base; + + IrAnalyze *ira; + IrInstruction *err_set_type; + IrInstruction *payload_type; +}; + struct ConstExprValue { ZigType *type; ConstValSpecial special; diff --git a/src/analyze.cpp b/src/analyze.cpp index 965bd57e02..9ae7e99547 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -1015,6 +1015,7 @@ static Error type_val_resolve_zero_bits(CodeGen *g, ConstExprValue *type_val, Zi } case LazyValueIdOptType: case LazyValueIdSliceType: + case LazyValueIdErrUnionType: *is_zero_bits = false; return ErrorNone; case LazyValueIdFnType: { @@ -1040,6 +1041,7 @@ Error type_val_resolve_is_opaque_type(CodeGen *g, ConstExprValue *type_val, bool case LazyValueIdPtrType: case LazyValueIdFnType: case LazyValueIdOptType: + case LazyValueIdErrUnionType: *is_opaque_type = false; return ErrorNone; } @@ -1094,6 +1096,11 @@ static ReqCompTime type_val_resolve_requires_comptime(CodeGen *g, ConstExprValue } return ReqCompTimeNo; } + case LazyValueIdErrUnionType: { + LazyValueErrUnionType *lazy_err_union_type = + reinterpret_cast(type_val->data.x_lazy); + return type_val_resolve_requires_comptime(g, &lazy_err_union_type->payload_type->value); + } } zig_unreachable(); } @@ -1102,10 +1109,8 @@ static Error type_val_resolve_abi_size(CodeGen *g, AstNode *source_node, ConstEx size_t *abi_size, size_t *size_in_bits) { Error err; - if (type_val->data.x_lazy->id == LazyValueIdOptType) { - if ((err = ir_resolve_lazy(g, source_node, type_val))) - return err; - } + +start_over: if (type_val->special != ConstValSpecialLazy) { assert(type_val->special == ConstValSpecialStatic); ZigType *ty = type_val->data.x_type; @@ -1129,7 +1134,10 @@ static Error type_val_resolve_abi_size(CodeGen *g, AstNode *source_node, ConstEx *size_in_bits = g->builtin_types.entry_usize->size_in_bits; return ErrorNone; case LazyValueIdOptType: - zig_unreachable(); + case LazyValueIdErrUnionType: + if ((err = ir_resolve_lazy(g, source_node, type_val))) + return err; + goto start_over; } zig_unreachable(); } @@ -1161,6 +1169,19 @@ Error type_val_resolve_abi_align(CodeGen *g, ConstExprValue *type_val, uint32_t LazyValueOptType *lazy_opt_type = reinterpret_cast(type_val->data.x_lazy); return type_val_resolve_abi_align(g, &lazy_opt_type->payload_type->value, abi_align); } + case LazyValueIdErrUnionType: { + LazyValueErrUnionType *lazy_err_union_type = + reinterpret_cast(type_val->data.x_lazy); + uint32_t payload_abi_align; + if ((err = type_val_resolve_abi_align(g, &lazy_err_union_type->payload_type->value, + &payload_abi_align))) + { + return err; + } + *abi_align = (payload_abi_align > g->err_tag_type->abi_align) ? + payload_abi_align : g->err_tag_type->abi_align; + return ErrorNone; + } } zig_unreachable(); } @@ -1189,6 +1210,18 @@ static OnePossibleValue type_val_resolve_has_one_possible_value(CodeGen *g, Cons return OnePossibleValueNo; } } + case LazyValueIdErrUnionType: { + LazyValueErrUnionType *lazy_err_union_type = + reinterpret_cast(type_val->data.x_lazy); + switch (type_val_resolve_has_one_possible_value(g, &lazy_err_union_type->err_set_type->value)) { + case OnePossibleValueInvalid: + return OnePossibleValueInvalid; + case OnePossibleValueNo: + return OnePossibleValueNo; + case OnePossibleValueYes: + return type_val_resolve_has_one_possible_value(g, &lazy_err_union_type->payload_type->value); + } + } } zig_unreachable(); } diff --git a/src/ir.cpp b/src/ir.cpp index 15fa4ccbe1..52cf69de82 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -14626,28 +14626,23 @@ static IrInstruction *ir_analyze_instruction_error_return_trace(IrAnalyze *ira, static IrInstruction *ir_analyze_instruction_error_union(IrAnalyze *ira, IrInstructionErrorUnion *instruction) { - Error err; + IrInstruction *result = ir_const(ira, &instruction->base, ira->codegen->builtin_types.entry_type); + result->value.special = ConstValSpecialLazy; - ZigType *err_set_type = ir_resolve_type(ira, instruction->err_set->child); - if (type_is_invalid(err_set_type)) + LazyValueErrUnionType *lazy_err_union_type = allocate(1); + lazy_err_union_type->ira = ira; + result->value.data.x_lazy = &lazy_err_union_type->base; + lazy_err_union_type->base.id = LazyValueIdErrUnionType; + + lazy_err_union_type->err_set_type = instruction->err_set->child; + if (ir_resolve_type_lazy(ira, lazy_err_union_type->err_set_type) == nullptr) return ira->codegen->invalid_instruction; - ZigType *payload_type = ir_resolve_type(ira, instruction->payload->child); - if (type_is_invalid(payload_type)) + lazy_err_union_type->payload_type = instruction->payload->child; + if (ir_resolve_type_lazy(ira, lazy_err_union_type->payload_type) == nullptr) return ira->codegen->invalid_instruction; - if (err_set_type->id != ZigTypeIdErrorSet) { - ir_add_error(ira, instruction->err_set->child, - buf_sprintf("expected error set type, found type '%s'", - buf_ptr(&err_set_type->name))); - return ira->codegen->invalid_instruction; - } - - if ((err = type_resolve(ira->codegen, payload_type, ResolveStatusSizeKnown))) - return ira->codegen->invalid_instruction; - ZigType *result_type = get_error_union_type(ira->codegen, err_set_type, payload_type); - - return ir_const_type(ira, &instruction->base, result_type); + return result; } static IrInstruction *ir_analyze_alloca(IrAnalyze *ira, IrInstruction *source_inst, ZigType *var_type, @@ -25698,6 +25693,34 @@ static Error ir_resolve_lazy_raw(AstNode *source_node, ConstExprValue *val) { val->data.x_type = fn_type; return ErrorNone; } + case LazyValueIdErrUnionType: { + LazyValueErrUnionType *lazy_err_union_type = + reinterpret_cast(val->data.x_lazy); + IrAnalyze *ira = lazy_err_union_type->ira; + + ZigType *err_set_type = ir_resolve_type(ira, lazy_err_union_type->err_set_type); + if (type_is_invalid(err_set_type)) + return ErrorSemanticAnalyzeFail; + + ZigType *payload_type = ir_resolve_type(ira, lazy_err_union_type->payload_type); + if (type_is_invalid(payload_type)) + return ErrorSemanticAnalyzeFail; + + if (err_set_type->id != ZigTypeIdErrorSet) { + ir_add_error(ira, lazy_err_union_type->err_set_type, + buf_sprintf("expected error set type, found type '%s'", + buf_ptr(&err_set_type->name))); + return ErrorSemanticAnalyzeFail; + } + + if ((err = type_resolve(ira->codegen, payload_type, ResolveStatusSizeKnown))) + return ErrorSemanticAnalyzeFail; + + assert(val->type->id == ZigTypeIdMetaType); + val->data.x_type = get_error_union_type(ira->codegen, err_set_type, payload_type); + val->special = ConstValSpecialStatic; + return ErrorNone; + } } zig_unreachable(); } diff --git a/test/stage1/behavior/error.zig b/test/stage1/behavior/error.zig index 264f140c9d..fefd95a850 100644 --- a/test/stage1/behavior/error.zig +++ b/test/stage1/behavior/error.zig @@ -375,3 +375,23 @@ test "implicit cast to optional to error union to return result loc" { S.entry(); //comptime S.entry(); TODO } + +test "function pointer with return type that is error union with payload which is pointer of parent struct" { + const S = struct { + const Foo = struct { + fun: fn (a: i32) (anyerror!*Foo), + }; + + const Err = error{UnspecifiedErr}; + + fn bar(a: i32) anyerror!*Foo { + return Err.UnspecifiedErr; + } + + fn doTheTest() void { + var x = Foo{ .fun = bar }; + expectError(error.UnspecifiedErr, x.fun(1)); + } + }; + S.doTheTest(); +}