From d48af541c7aa235948621cdbc250d983af303977 Mon Sep 17 00:00:00 2001 From: Veikka Tuominen Date: Sun, 21 Aug 2022 12:25:19 +0300 Subject: [PATCH] Sema: handle union and enum field order being different Closes #12543 --- src/Sema.zig | 33 ++++++++++++++++++--------------- src/codegen.zig | 2 +- src/codegen/c.zig | 3 +-- src/codegen/llvm.zig | 2 +- src/type.zig | 9 ++++++++- test/behavior/union.zig | 24 ++++++++++++++++++++++++ 6 files changed, 53 insertions(+), 20 deletions(-) diff --git a/src/Sema.zig b/src/Sema.zig index 5c66c8d6a1..28ff4cc1c2 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -3615,8 +3615,6 @@ fn validateUnionInit( union_ptr: Air.Inst.Ref, is_comptime: bool, ) CompileError!void { - const union_obj = union_ty.cast(Type.Payload.Union).?.data; - if (instrs.len != 1) { const msg = msg: { const msg = try sema.errMsg( @@ -3650,7 +3648,8 @@ fn validateUnionInit( const field_src: LazySrcLoc = .{ .node_offset_initializer = field_ptr_data.src_node }; const field_ptr_extra = sema.code.extraData(Zir.Inst.Field, field_ptr_data.payload_index).data; const field_name = sema.code.nullTerminatedString(field_ptr_extra.field_name_start); - const field_index = try sema.unionFieldIndex(block, union_ty, field_name, field_src); + // Validate the field access but ignore the index since we want the tag enum field index. + _ = try sema.unionFieldIndex(block, union_ty, field_name, field_src); const air_tags = sema.air_instructions.items(.tag); const air_datas = sema.air_instructions.items(.data); const field_ptr_air_ref = sema.inst_map.get(field_ptr).?; @@ -3709,7 +3708,9 @@ fn validateUnionInit( break; } - const tag_val = try Value.Tag.enum_field_index.create(sema.arena, field_index); + const tag_ty = union_ty.unionTagTypeHypothetical(); + const enum_field_index = @intCast(u32, tag_ty.enumFieldIndex(field_name).?); + const tag_val = try Value.Tag.enum_field_index.create(sema.arena, enum_field_index); if (init_val) |val| { // Our task is to delete all the `field_ptr` and `store` instructions, and insert @@ -3726,7 +3727,7 @@ fn validateUnionInit( } try sema.requireFunctionBlock(block, init_src); - const new_tag = try sema.addConstant(union_obj.tag_ty, tag_val); + const new_tag = try sema.addConstant(tag_ty, tag_val); _ = try block.addBinOp(.set_union_tag, union_ptr, new_tag); } @@ -8838,13 +8839,11 @@ fn zirSwitchCapture( switch (operand_ty.zigTypeTag()) { .Union => { const union_obj = operand_ty.cast(Type.Payload.Union).?.data; - const enum_ty = union_obj.tag_ty; - const first_item = try sema.resolveInst(items[0]); // Previous switch validation ensured this will succeed const first_item_val = sema.resolveConstValue(block, .unneeded, first_item, undefined) catch unreachable; - const first_field_index = @intCast(u32, enum_ty.enumTagFieldIndex(first_item_val, sema.mod).?); + const first_field_index = @intCast(u32, operand_ty.unionTagFieldIndex(first_item_val, sema.mod).?); const first_field = union_obj.fields.values()[first_field_index]; for (items[1..]) |item, i| { @@ -8852,7 +8851,7 @@ fn zirSwitchCapture( // Previous switch validation ensured this will succeed const item_val = sema.resolveConstValue(block, .unneeded, item_ref, undefined) catch unreachable; - const field_index = enum_ty.enumTagFieldIndex(item_val, sema.mod).?; + const field_index = operand_ty.unionTagFieldIndex(item_val, sema.mod).?; const field = union_obj.fields.values()[field_index]; if (!field.ty.eql(first_field.ty, sema.mod)) { const msg = msg: { @@ -15585,7 +15584,9 @@ fn unionInit( const init = try sema.coerce(block, field.ty, uncasted_init, init_src); if (try sema.resolveMaybeUndefVal(block, init_src, init)) |init_val| { - const tag_val = try Value.Tag.enum_field_index.create(sema.arena, field_index); + const tag_ty = union_ty.unionTagTypeHypothetical(); + const enum_field_index = @intCast(u32, tag_ty.enumFieldIndex(field_name).?); + const tag_val = try Value.Tag.enum_field_index.create(sema.arena, enum_field_index); return sema.addConstant(union_ty, try Value.Tag.@"union".create(sema.arena, .{ .tag = tag_val, .val = init_val, @@ -15683,7 +15684,9 @@ fn zirStructInit( const field_type_extra = sema.code.extraData(Zir.Inst.FieldType, field_type_data.payload_index).data; const field_name = sema.code.nullTerminatedString(field_type_extra.name_start); const field_index = try sema.unionFieldIndex(block, resolved_ty, field_name, field_src); - const tag_val = try Value.Tag.enum_field_index.create(sema.arena, field_index); + const tag_ty = resolved_ty.unionTagTypeHypothetical(); + const enum_field_index = @intCast(u32, tag_ty.enumFieldIndex(field_name).?); + const tag_val = try Value.Tag.enum_field_index.create(sema.arena, enum_field_index); const init_inst = try sema.resolveInst(item.data.init); if (try sema.resolveMaybeUndefVal(block, field_src, init_inst)) |val| { @@ -16448,9 +16451,8 @@ fn zirReify(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData, in const type_info = try sema.coerce(block, type_info_ty, uncasted_operand, operand_src); const val = try sema.resolveConstValue(block, operand_src, type_info, "operand to @Type must be comptime known"); const union_val = val.cast(Value.Payload.Union).?.data; - const tag_ty = type_info_ty.unionTagType().?; const target = mod.getTarget(); - const tag_index = tag_ty.enumTagFieldIndex(union_val.tag, mod).?; + const tag_index = type_info_ty.unionTagFieldIndex(union_val.tag, mod).?; if (union_val.val.anyUndef()) return sema.failWithUseOfUndef(block, src); switch (@intToEnum(std.builtin.TypeId, tag_index)) { .Type => return Air.Inst.Ref.type_type, @@ -25155,8 +25157,7 @@ fn coerceEnumToUnion( const enum_tag = try sema.coerce(block, tag_ty, inst, inst_src); if (try sema.resolveDefinedValue(block, inst_src, enum_tag)) |val| { - const union_obj = union_ty.cast(Type.Payload.Union).?.data; - const field_index = union_obj.tag_ty.enumTagFieldIndex(val, sema.mod) orelse { + const field_index = union_ty.unionTagFieldIndex(val, sema.mod) orelse { const msg = msg: { const msg = try sema.errMsg(block, inst_src, "union '{}' has no tag with value '{}'", .{ union_ty.fmt(sema.mod), val.fmtValue(tag_ty, sema.mod), @@ -25167,6 +25168,8 @@ fn coerceEnumToUnion( }; return sema.failWithOwnedErrorMsg(msg); }; + + const union_obj = union_ty.cast(Type.Payload.Union).?.data; const field = union_obj.fields.values()[field_index]; const field_ty = try sema.resolveTypeFields(block, inst_src, field.ty); if (field_ty.zigTypeTag() == .NoReturn) { diff --git a/src/codegen.zig b/src/codegen.zig index 025decdb4b..f5340458a5 100644 --- a/src/codegen.zig +++ b/src/codegen.zig @@ -607,7 +607,7 @@ pub fn generateSymbol( const union_ty = typed_value.ty.cast(Type.Payload.Union).?.data; const mod = bin_file.options.module.?; - const field_index = union_ty.tag_ty.enumTagFieldIndex(union_obj.tag, mod).?; + const field_index = typed_value.ty.unionTagFieldIndex(union_obj.tag, mod).?; assert(union_ty.haveFieldTypes()); const field_ty = union_ty.fields.values()[field_index].ty; if (!field_ty.hasRuntimeBits()) { diff --git a/src/codegen/c.zig b/src/codegen/c.zig index 81a892183f..4a09c09cc9 100644 --- a/src/codegen/c.zig +++ b/src/codegen/c.zig @@ -835,7 +835,6 @@ pub const DeclGen = struct { }, .Union => { const union_obj = val.castTag(.@"union").?.data; - const union_ty = ty.cast(Type.Payload.Union).?.data; const layout = ty.unionGetLayout(target); try writer.writeAll("("); @@ -851,7 +850,7 @@ pub const DeclGen = struct { try writer.writeAll(".payload = {"); } - const index = union_ty.tag_ty.enumTagFieldIndex(union_obj.tag, dg.module).?; + const index = ty.unionTagFieldIndex(union_obj.tag, dg.module).?; const field_ty = ty.unionFields().values()[index].ty; const field_name = ty.unionFields().keys()[index]; if (field_ty.hasRuntimeBits()) { diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index d50b463606..5c537cd5bc 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -3502,7 +3502,7 @@ pub const DeclGen = struct { }); } const union_obj = tv.ty.cast(Type.Payload.Union).?.data; - const field_index = union_obj.tag_ty.enumTagFieldIndex(tag_and_val.tag, dg.module).?; + const field_index = tv.ty.unionTagFieldIndex(tag_and_val.tag, dg.module).?; assert(union_obj.haveFieldTypes()); // Sometimes we must make an unnamed struct because LLVM does diff --git a/src/type.zig b/src/type.zig index f6afa33df1..d2885f537f 100644 --- a/src/type.zig +++ b/src/type.zig @@ -4285,11 +4285,18 @@ pub const Type = extern union { pub fn unionFieldType(ty: Type, enum_tag: Value, mod: *Module) Type { const union_obj = ty.cast(Payload.Union).?.data; - const index = union_obj.tag_ty.enumTagFieldIndex(enum_tag, mod).?; + const index = ty.unionTagFieldIndex(enum_tag, mod).?; assert(union_obj.haveFieldTypes()); return union_obj.fields.values()[index].ty; } + pub fn unionTagFieldIndex(ty: Type, enum_tag: Value, mod: *Module) ?usize { + const union_obj = ty.cast(Payload.Union).?.data; + const index = union_obj.tag_ty.enumTagFieldIndex(enum_tag, mod) orelse return null; + const name = union_obj.tag_ty.enumFieldName(index); + return union_obj.fields.getIndex(name); + } + pub fn unionHasAllZeroBitFieldTypes(ty: Type) bool { return ty.cast(Payload.Union).?.data.hasAllZeroBitFieldTypes(); } diff --git a/test/behavior/union.zig b/test/behavior/union.zig index efca75af30..92f277b946 100644 --- a/test/behavior/union.zig +++ b/test/behavior/union.zig @@ -1301,3 +1301,27 @@ test "noreturn field in union" { } try expect(count == 5); } + +test "union and enum field order doesn't match" { + if (builtin.zig_backend == .stage1) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + + const MyTag = enum(u32) { + b = 1337, + a = 1666, + }; + const MyUnion = union(MyTag) { + a: f32, + b: void, + }; + var x: MyUnion = .{ .a = 666 }; + switch (x) { + .a => |my_f32| { + try expect(@TypeOf(my_f32) == f32); + }, + .b => unreachable, + } + x = .b; + try expect(x == .b); +}