diff --git a/src/all_types.hpp b/src/all_types.hpp index a9304fdd0f..0d187a1c4e 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1509,6 +1509,7 @@ enum BuiltinFnId { BuiltinFnIdAtomicRmw, BuiltinFnIdAtomicLoad, BuiltinFnIdHasDecl, + BuiltinFnIdUnionInit, }; struct BuiltinFnEntry { @@ -2359,6 +2360,7 @@ enum IrInstructionId { IrInstructionIdAllocaGen, IrInstructionIdEndExpr, IrInstructionIdPtrOfArrayToSlice, + IrInstructionIdUnionInitNamedField, }; struct IrInstruction { @@ -3603,6 +3605,14 @@ struct IrInstructionAssertNonNull { IrInstruction *target; }; +struct IrInstructionUnionInitNamedField { + IrInstruction base; + + IrInstruction *union_type; + IrInstruction *field_name; + IrInstruction *value; +}; + struct IrInstructionHasDecl { IrInstruction base; diff --git a/src/codegen.cpp b/src/codegen.cpp index da9a42c651..2399152546 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -5635,6 +5635,7 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, case IrInstructionIdRef: case IrInstructionIdBitCastSrc: case IrInstructionIdTestErrSrc: + case IrInstructionIdUnionInitNamedField: zig_unreachable(); case IrInstructionIdDeclVarGen: @@ -7419,6 +7420,7 @@ static void define_builtin_fns(CodeGen *g) { create_builtin_fn(g, BuiltinFnIdFromBytes, "bytesToSlice", 2); create_builtin_fn(g, BuiltinFnIdThis, "This", 0); create_builtin_fn(g, BuiltinFnIdHasDecl, "hasDecl", 2); + create_builtin_fn(g, BuiltinFnIdUnionInit, "unionInit", 3); } static const char *bool_to_str(bool b) { diff --git a/src/ir.cpp b/src/ir.cpp index b9afe405b9..594f61982a 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -1069,6 +1069,7 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionAssertNonNull *) return IrInstructionIdAssertNonNull; } +<<<<<<< HEAD static constexpr IrInstructionId ir_instruction_id(IrInstructionHasDecl *) { return IrInstructionIdHasDecl; } @@ -1089,6 +1090,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionEndExpr *) { return IrInstructionIdEndExpr; } +static constexpr IrInstructionId ir_instruction_id(IrInstructionUnionInitNamedField *) { + return IrInstructionIdUnionInitNamedField; +} + template static T *ir_create_instruction(IrBuilder *irb, Scope *scope, AstNode *source_node) { T *special_instruction = allocate(1); @@ -3324,6 +3329,21 @@ static IrInstruction *ir_build_check_runtime_scope(IrBuilder *irb, Scope *scope, return &instruction->base; } +static IrInstruction *ir_build_union_init_2(IrBuilder *irb, Scope *scope, AstNode *source_node, + IrInstruction *union_type_value, IrInstruction *field_name_expr, IrInstruction *value) { + IrInstructionUnionInit2 *instruction = ir_build_instruction(irb, scope, source_node); + instruction->union_type_value = union_type_value; + instruction->field_name_expr = field_name_expr; + instruction->value = value; + + ir_ref_instruction(union_type_value, irb->current_basic_block); + ir_ref_instruction(field_name_expr, irb->current_basic_block); + ir_ref_instruction(value, irb->current_basic_block); + + return &instruction->base; +} + + static IrInstruction *ir_build_vector_to_array(IrAnalyze *ira, IrInstruction *source_instruction, ZigType *result_type, IrInstruction *vector, IrInstruction *result_loc) { @@ -5651,6 +5671,29 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo IrInstruction *has_decl = ir_build_has_decl(irb, scope, node, arg0_value, arg1_value); return ir_lval_wrap(irb, scope, has_decl, lval, result_loc); } + case BuiltinFnIdUnionInit: + { + + AstNode *arg0_node = node->data.fn_call_expr.params.at(0); + IrInstruction *arg0_value = ir_gen_node(irb, arg0_node, scope); + if (arg0_value == irb->codegen->invalid_instruction) + return arg0_value; + + AstNode *arg1_node = node->data.fn_call_expr.params.at(1); + IrInstruction *arg1_value = ir_gen_node(irb, arg1_node, scope); + if (arg1_value == irb->codegen->invalid_instruction) + return arg1_value; + + AstNode *arg2_node = node->data.fn_call_expr.params.at(2); + IrInstruction *arg2_value = ir_gen_node(irb, arg2_node, scope); + if (arg2_value == irb->codegen->invalid_instruction) + return arg2_value; + + IrInstruction *result = ir_build_union_init_2(irb, scope, node, arg0_value, arg1_value, arg2_value); + + // TODO: Not sure if we need ir_lval_wrap or not. + return result; + } } zig_unreachable(); } @@ -25326,6 +25369,35 @@ static IrInstruction *ir_analyze_instruction_bit_cast_src(IrAnalyze *ira, IrInst return instruction->result_loc_bit_cast->parent->gen_instruction; } +static IrInstruction *ir_analyze_instruction_union_init_2(IrAnalyze *ira, IrInstructionUnionInit2 *union_init_instruction) +{ + Error err; + IrInstruction *union_type_value = union_init_instruction->union_type_value->child; + ZigType *union_type = ir_resolve_type(ira, union_type_value); + if (type_is_invalid(union_type)) { + return ira->codegen->invalid_instruction; + } + + if (union_type->id != ZigTypeIdUnion) + return ira->codegen->invalid_instruction; + + if ((err = ensure_complete_type(ira->codegen, union_type))) + return ira->codegen->invalid_instruction; + + IrInstruction *field_name_expr = union_init_instruction->field_name_expr->child; + Buf *field_name = ir_resolve_str(ira, field_name_expr); + if (!field_name) + return ira->codegen->invalid_instruction; + + IrInstructionContainerInitFieldsField *fields = allocate(1); + + fields[0].name = field_name; + fields[0].value = union_init_instruction->value; + fields[0].source_node = union_init_instruction->base.source_node; + + return ir_analyze_container_init_fields_union(ira, &union_init_instruction->base, union_type, 1, fields); +} + static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction *instruction) { switch (instruction->id) { case IrInstructionIdInvalid: @@ -25641,6 +25713,8 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction return ir_analyze_instruction_end_expr(ira, (IrInstructionEndExpr *)instruction); case IrInstructionIdBitCastSrc: return ir_analyze_instruction_bit_cast_src(ira, (IrInstructionBitCastSrc *)instruction); + case IrInstructionIdUnionInitNamedField: + return ir_analyze_instruction_union_init_named_field(ira, (IrInstructionUnionInitNamedField *)instruction); } zig_unreachable(); } @@ -25794,6 +25868,7 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdCast: case IrInstructionIdContainerInitList: case IrInstructionIdContainerInitFields: + case IrInstructionIdUnionInitNamedField: case IrInstructionIdFieldPtr: case IrInstructionIdElemPtr: case IrInstructionIdVarPtr: diff --git a/src/ir_print.cpp b/src/ir_print.cpp index 31b7e608b7..3ecb7fb683 100644 --- a/src/ir_print.cpp +++ b/src/ir_print.cpp @@ -1626,6 +1626,16 @@ static void ir_print_undeclared_ident(IrPrint *irp, IrInstructionUndeclaredIdent fprintf(irp->f, "@undeclaredIdent(%s)", buf_ptr(instruction->name)); } +static void ir_print_union_init_named_field(IrPrint *irp, IrInstructionUnionInitNamedField *instruction) { + fprintf(irp->f, "@unionInit("); + ir_print_other_instruction(irp, instruction->union_type); + fprintf(irp->f, ", "); + ir_print_other_instruction(irp, instruction->field_name); + fprintf(irp->f, ", "); + ir_print_other_instruction(irp, instruction->value); + fprintf(irp->f, ")"); +} + static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { ir_print_prefix(irp, instruction); switch (instruction->id) { @@ -2132,6 +2142,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { case IrInstructionIdEndExpr: ir_print_end_expr(irp, (IrInstructionEndExpr *)instruction); break; + case IrInstructionIdUnionInitNamedField: + ir_print_union_init_named_field(irp, (IrInstructionUnionInitNamedField *)instruction); + break; } fprintf(irp->f, "\n"); } diff --git a/test/stage1/behavior/union.zig b/test/stage1/behavior/union.zig index 410b7e9615..7d6a8154ea 100644 --- a/test/stage1/behavior/union.zig +++ b/test/stage1/behavior/union.zig @@ -416,9 +416,44 @@ test "return union init with void payload" { two: u32, }; fn func() Outer { - return Outer{ .state = State{ .one = {} }}; + return Outer{ .state = State{ .one = {} } }; } }; S.entry(); comptime S.entry(); } + +test "@unionInit can modify a union type" { + const UnionInitEnum = union(enum) { + Boolean: bool, + Byte: u8, + }; + + var value: UnionInitEnum = undefined; + + value = @unionInit(UnionInitEnum, "Boolean", true); + expect(value.Boolean == true); + value.Boolean = false; + expect(value.Boolean == false); + + value = @unionInit(UnionInitEnum, "Byte", 2); + expect(value.Byte == 2); + value.Byte = 3; + expect(value.Byte == 3); +} + +test "@unionInit can modify a pointer value" { + const UnionInitEnum = union(enum) { + Boolean: bool, + Byte: u8, + }; + + var value: UnionInitEnum = undefined; + var value_ptr = &value; + + value_ptr.* = @unionInit(UnionInitEnum, "Boolean", true); + expect(value.Boolean == true); + + value_ptr.* = @unionInit(UnionInitEnum, "Byte", 2); + expect(value.Byte == 2); +}