diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 191731fb95..9ac4883446 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -874,12 +874,12 @@ const DeclGen = struct { }, .un => |un| { const active_field = ty.unionTagFieldIndex(un.tag.toValue(), mod).?; - const layout = self.unionLayout(ty, active_field); - const payload = if (layout.active_field_size != 0) - try self.constant(layout.active_field_ty, un.val.toValue(), .direct) + const union_obj = mod.typeToUnion(ty).?; + const field_ty = union_obj.field_types.get(ip)[active_field].toType(); + const payload = if (field_ty.hasRuntimeBitsIgnoreComptime(mod)) + try self.constant(field_ty, un.val.toValue(), .direct) else null; - return try self.unionInit(ty, active_field, payload); }, .memoized_call => unreachable, @@ -1105,29 +1105,25 @@ const DeclGen = struct { return try self.spv.ptrType(child_ty_ref, storage_class); } - /// Generate a union type, optionally with a known field. If the tag alignment is greater - /// than that of the payload, a regular union (non-packed, with both tag and payload), will - /// be generated as follows: - /// If the active field is known: + /// Generate a union type. Union types are always generated with the + /// most aligned field active. If the tag alignment is greater + /// than that of the payload, a regular union (non-packed, with both tag and + /// payload), will be generated as follows: /// struct { /// tag: TagType, - /// payload: ActivePayloadType, - /// payload_padding: [payload_size - @sizeOf(ActivePayloadType)]u8, + /// payload: MostAlignedFieldType, + /// payload_padding: [payload_size - @sizeOf(MostAlignedFieldType)]u8, /// padding: [padding_size]u8, /// } /// If the payload alignment is greater than that of the tag: /// struct { - /// payload: ActivePayloadType, - /// payload_padding: [payload_size - @sizeOf(ActivePayloadType)]u8, + /// payload: MostAlignedFieldType, + /// payload_padding: [payload_size - @sizeOf(MostAlignedFieldType)]u8, /// tag: TagType, /// padding: [padding_size]u8, /// } - /// If the active payload is unknown, it will default back to the most aligned field. This is - /// to make sure that the overal struct has the correct alignment in spir-v. /// If any of the fields' size is 0, it will be omitted. - /// NOTE: When the active field is set to something other than the most aligned field, the - /// resulting struct will be *underaligned*. - fn resolveUnionType(self: *DeclGen, ty: Type, maybe_active_field: ?usize) !CacheRef { + fn resolveUnionType(self: *DeclGen, ty: Type) !CacheRef { const mod = self.module; const ip = &mod.intern_pool; const union_obj = mod.typeToUnion(ty).?; @@ -1136,17 +1132,13 @@ const DeclGen = struct { return self.todo("packed union types", .{}); } - const layout = self.unionLayout(ty, maybe_active_field); - - if (layout.payload_size == 0) { + const layout = self.unionLayout(ty); + if (!layout.has_payload) { // No payload, so represent this as just the tag type. return try self.resolveType(union_obj.enum_tag_ty.toType(), .indirect); } - // TODO: We need to add the active field to the key, somehow. - if (maybe_active_field == null) { - if (self.type_map.get(ty.toIntern())) |info| return info.ty_ref; - } + if (self.type_map.get(ty.toIntern())) |info| return info.ty_ref; var member_types: [4]CacheRef = undefined; var member_names: [4]CacheString = undefined; @@ -1159,10 +1151,10 @@ const DeclGen = struct { member_names[layout.tag_index] = try self.spv.resolveString("(tag)"); } - if (layout.active_field_size != 0) { - const active_payload_ty_ref = try self.resolveType(layout.active_field_ty, .indirect); - member_types[layout.active_field_index] = active_payload_ty_ref; - member_names[layout.active_field_index] = try self.spv.resolveString("(payload)"); + if (layout.payload_size != 0) { + const payload_ty_ref = try self.resolveType(layout.payload_ty, .indirect); + member_types[layout.payload_index] = payload_ty_ref; + member_names[layout.payload_index] = try self.spv.resolveString("(payload)"); } if (layout.payload_padding_size != 0) { @@ -1183,9 +1175,7 @@ const DeclGen = struct { .member_names = member_names[0..layout.total_fields], } }); - if (maybe_active_field == null) { - try self.type_map.put(self.gpa, ty.toIntern(), .{ .ty_ref = ty_ref }); - } + try self.type_map.put(self.gpa, ty.toIntern(), .{ .ty_ref = ty_ref }); return ty_ref; } @@ -1453,7 +1443,7 @@ const DeclGen = struct { try self.type_map.put(self.gpa, ty.toIntern(), .{ .ty_ref = ty_ref }); return ty_ref; }, - .Union => return try self.resolveUnionType(ty, null), + .Union => return try self.resolveUnionType(ty), .ErrorSet => return try self.intType(.unsigned, 16), .ErrorUnion => { const payload_ty = ty.errorUnionPayload(mod); @@ -1567,14 +1557,16 @@ const DeclGen = struct { } const UnionLayout = struct { - active_field: u32, - active_field_ty: Type, - payload_size: u32, - + /// If false, this union is represented + /// by only an integer of the tag type. + has_payload: bool, tag_size: u32, tag_index: u32, - active_field_size: u32, - active_field_index: u32, + /// Note: This is the size of the payload type itself, NOT the size of the ENTIRE payload. + /// Use `has_payload` instead!! + payload_ty: Type, + payload_size: u32, + payload_index: u32, payload_padding_size: u32, payload_padding_index: u32, padding_size: u32, @@ -1582,23 +1574,19 @@ const DeclGen = struct { total_fields: u32, }; - fn unionLayout(self: *DeclGen, ty: Type, maybe_active_field: ?usize) UnionLayout { + fn unionLayout(self: *DeclGen, ty: Type) UnionLayout { const mod = self.module; const ip = &mod.intern_pool; const layout = ty.unionGetLayout(self.module); const union_obj = mod.typeToUnion(ty).?; - const active_field = maybe_active_field orelse layout.most_aligned_field; - const active_field_ty = union_obj.field_types.get(ip)[active_field].toType(); - var union_layout = UnionLayout{ - .active_field = @intCast(active_field), - .active_field_ty = active_field_ty, - .payload_size = @intCast(layout.payload_size), + .has_payload = layout.payload_size != 0, .tag_size = @intCast(layout.tag_size), .tag_index = undefined, - .active_field_size = undefined, - .active_field_index = undefined, + .payload_ty = undefined, + .payload_size = undefined, + .payload_index = undefined, .payload_padding_size = undefined, .payload_padding_index = undefined, .padding_size = @intCast(layout.padding), @@ -1606,11 +1594,16 @@ const DeclGen = struct { .total_fields = undefined, }; - union_layout.active_field_size = if (active_field_ty.hasRuntimeBitsIgnoreComptime(mod)) - @intCast(active_field_ty.abiSize(mod)) - else - 0; - union_layout.payload_padding_size = @intCast(layout.payload_size - union_layout.active_field_size); + if (union_layout.has_payload) { + const most_aligned_field = layout.most_aligned_field; + const most_aligned_field_ty = union_obj.field_types.get(ip)[most_aligned_field].toType(); + union_layout.payload_ty = most_aligned_field_ty; + union_layout.payload_size = @intCast(most_aligned_field_ty.abiSize(mod)); + } else { + union_layout.payload_size = 0; + } + + union_layout.payload_padding_size = @intCast(layout.payload_size - union_layout.payload_size); const tag_first = layout.tag_align.compare(.gte, layout.payload_align); var field_index: u32 = 0; @@ -1620,8 +1613,8 @@ const DeclGen = struct { field_index += 1; } - if (union_layout.active_field_size != 0) { - union_layout.active_field_index = field_index; + if (union_layout.payload_size != 0) { + union_layout.payload_index = field_index; field_index += 1; } @@ -3300,7 +3293,7 @@ const DeclGen = struct { const bin_op = self.air.instructions.items(.data)[inst].bin_op; const un_ptr_ty = self.typeOf(bin_op.lhs); const un_ty = un_ptr_ty.childType(mod); - const layout = self.unionLayout(un_ty, null); + const layout = self.unionLayout(un_ty); if (layout.tag_size == 0) return; @@ -3310,7 +3303,7 @@ const DeclGen = struct { const union_ptr_id = try self.resolve(bin_op.lhs); const new_tag_id = try self.resolve(bin_op.rhs); - if (layout.payload_size == 0) { + if (!layout.has_payload) { try self.store(tag_ty, union_ptr_id, new_tag_id, .{ .is_volatile = un_ptr_ty.isVolatilePtr(mod) }); } else { const ptr_id = try self.accessChain(tag_ptr_ty_ref, union_ptr_id, &.{layout.tag_index}); @@ -3325,11 +3318,11 @@ const DeclGen = struct { const un_ty = self.typeOf(ty_op.operand); const mod = self.module; - const layout = self.unionLayout(un_ty, null); + const layout = self.unionLayout(un_ty); if (layout.tag_size == 0) return null; const union_handle = try self.resolve(ty_op.operand); - if (layout.payload_size == 0) return union_handle; + if (!layout.has_payload) return union_handle; const tag_ty = un_ty.unionTagTypeSafety(mod).?; return try self.extractField(tag_ty, union_handle, layout.tag_index); @@ -3342,8 +3335,8 @@ const DeclGen = struct { payload: ?IdRef, ) !IdRef { // To initialize a union, generate a temporary variable with the - // type that has the right field active, then pointer-cast and store - // the active field, and finally load and return the entire union. + // union type, then get the field pointer and pointer-cast it to the + // right type to store it. Finally load the entire union. const mod = self.module; const ip = &mod.intern_pool; @@ -3354,7 +3347,7 @@ const DeclGen = struct { } const maybe_tag_ty = ty.unionTagTypeSafety(mod); - const layout = self.unionLayout(ty, active_field); + const layout = self.unionLayout(ty); const tag_int = if (layout.tag_size != 0) blk: { const tag_ty = maybe_tag_ty.?; @@ -3365,23 +3358,12 @@ const DeclGen = struct { break :blk tag_int_val.toUnsignedInt(mod); } else 0; - if (layout.payload_size == 0) { + if (!layout.has_payload) { const tag_ty_ref = try self.resolveType(maybe_tag_ty.?, .direct); return try self.constInt(tag_ty_ref, tag_int); } - // TODO: Make this use self.ptrType - const un_active_ty_ref = try self.resolveUnionType(ty, active_field); - const un_active_ptr_ty_ref = try self.spv.ptrType(un_active_ty_ref, .Function); - const un_general_ty_ref = try self.resolveType(ty, .direct); - const un_general_ptr_ty_ref = try self.spv.ptrType(un_general_ty_ref, .Function); - - const tmp_id = self.spv.allocId(); - try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{ - .id_result_type = self.typeId(un_active_ptr_ty_ref), - .id_result = tmp_id, - .storage_class = .Function, - }); + const tmp_id = try self.alloc(ty, .{ .storage_class = .Function }); if (layout.tag_size != 0) { const tag_ty_ref = try self.resolveType(maybe_tag_ty.?, .direct); @@ -3391,10 +3373,19 @@ const DeclGen = struct { try self.store(maybe_tag_ty.?, ptr_id, tag_id, .{}); } - if (layout.active_field_size != 0) { - const active_field_ptr_ty_ref = try self.ptrType(layout.active_field_ty, .Function); - const ptr_id = try self.accessChain(active_field_ptr_ty_ref, tmp_id, &.{@as(u32, @intCast(layout.active_field_index))}); - try self.store(layout.active_field_ty, ptr_id, payload.?, .{}); + const payload_ty = union_ty.field_types.get(ip)[active_field].toType(); + if (payload_ty.hasRuntimeBitsIgnoreComptime(mod)) { + const pl_ptr_ty_ref = try self.ptrType(layout.payload_ty, .Function); + const pl_ptr_id = try self.accessChain(pl_ptr_ty_ref, tmp_id, &.{layout.payload_index}); + const active_pl_ptr_ty_ref = try self.ptrType(payload_ty, .Function); + const active_pl_ptr_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpBitcast, .{ + .id_result_type = self.typeId(active_pl_ptr_ty_ref), + .id_result = active_pl_ptr_id, + .operand = pl_ptr_id, + }); + + try self.store(payload_ty, active_pl_ptr_id, payload.?, .{}); } else { assert(payload == null); } @@ -3402,34 +3393,21 @@ const DeclGen = struct { // Just leave the padding fields uninitialized... // TODO: Or should we initialize them with undef explicitly? - // Now cast the pointer and load it as the 'generic' union type. - - const casted_var_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpBitcast, .{ - .id_result_type = self.typeId(un_general_ptr_ty_ref), - .id_result = casted_var_id, - .operand = tmp_id, - }); - - const result_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpLoad, .{ - .id_result_type = self.typeId(un_general_ty_ref), - .id_result = result_id, - .pointer = casted_var_id, - }); - - return result_id; + return try self.load(ty, tmp_id, .{}); } fn airUnionInit(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { if (self.liveness.isUnused(inst)) return null; + const mod = self.module; + const ip = &mod.intern_pool; const ty_pl = self.air.instructions.items(.data)[inst].ty_pl; const extra = self.air.extraData(Air.UnionInit, ty_pl.payload).data; const ty = self.typeOfIndex(inst); - const layout = self.unionLayout(ty, extra.field_index); - const payload = if (layout.active_field_size != 0) + const union_obj = mod.typeToUnion(ty).?; + const field_ty = union_obj.field_types.get(ip)[extra.field_index].toType(); + const payload = if (field_ty.hasRuntimeBitsIgnoreComptime(mod)) try self.resolve(extra.init) else null; @@ -3458,30 +3436,24 @@ const DeclGen = struct { .Union => switch (object_ty.containerLayout(mod)) { .Packed => unreachable, // TODO else => { - // Store, pointer-cast, load - const un_general_ty_ref = try self.resolveType(object_ty, .indirect); - const un_general_ptr_ty_ref = try self.spv.ptrType(un_general_ty_ref, .Function); - const un_active_ty_ref = try self.resolveUnionType(object_ty, field_index); - const un_active_ptr_ty_ref = try self.spv.ptrType(un_active_ty_ref, .Function); - const field_ty_ref = try self.resolveType(field_ty, .indirect); - const field_ptr_ty_ref = try self.spv.ptrType(field_ty_ref, .Function); + // Store, ptr-elem-ptr, pointer-cast, load + const layout = self.unionLayout(object_ty); + assert(layout.has_payload); - const tmp_id = self.spv.allocId(); - try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{ - .id_result_type = self.typeId(un_general_ptr_ty_ref), - .id_result = tmp_id, - .storage_class = .Function, - }); + const tmp_id = try self.alloc(object_ty, .{ .storage_class = .Function }); try self.store(object_ty, tmp_id, object_id, .{}); - const casted_tmp_id = self.spv.allocId(); + + const pl_ptr_ty_ref = try self.ptrType(layout.payload_ty, .Function); + const pl_ptr_id = try self.accessChain(pl_ptr_ty_ref, tmp_id, &.{layout.payload_index}); + + const active_pl_ptr_ty_ref = try self.ptrType(field_ty, .Function); + const active_pl_ptr_id = self.spv.allocId(); try self.func.body.emit(self.spv.gpa, .OpBitcast, .{ - .id_result_type = self.typeId(un_active_ptr_ty_ref), - .id_result = casted_tmp_id, - .operand = tmp_id, + .id_result_type = self.typeId(active_pl_ptr_ty_ref), + .id_result = active_pl_ptr_id, + .operand = pl_ptr_id, }); - const layout = self.unionLayout(object_ty, field_index); - const field_ptr_id = try self.accessChain(field_ptr_ty_ref, casted_tmp_id, &.{layout.active_field_index}); - return try self.load(field_ty, field_ptr_id, .{}); + return try self.load(field_ty, active_pl_ptr_id, .{}); }, }, else => unreachable, @@ -3540,18 +3512,24 @@ const DeclGen = struct { .Union => switch (object_ty.containerLayout(mod)) { .Packed => unreachable, // TODO else => { - const storage_class = spvStorageClass(object_ptr_ty.ptrAddressSpace(mod)); - const un_active_ty_ref = try self.resolveUnionType(object_ty, field_index); - const un_active_ptr_ty_ref = try self.spv.ptrType(un_active_ty_ref, storage_class); + const layout = self.unionLayout(object_ty); + if (!layout.has_payload) { + // Asked to get a pointer to a zero-sized field. Just lower this + // to undefined, there is no reason to make it be a valid pointer. + return try self.spv.constUndef(result_ty_ref); + } - const casted_id = self.spv.allocId(); + const storage_class = spvStorageClass(object_ptr_ty.ptrAddressSpace(mod)); + const pl_ptr_ty_ref = try self.ptrType(layout.payload_ty, storage_class); + const pl_ptr_id = try self.accessChain(pl_ptr_ty_ref, object_ptr, &.{layout.payload_index}); + + const active_pl_ptr_id = self.spv.allocId(); try self.func.body.emit(self.spv.gpa, .OpBitcast, .{ - .id_result_type = self.typeId(un_active_ptr_ty_ref), - .id_result = casted_id, - .operand = object_ptr, + .id_result_type = self.typeId(result_ty_ref), + .id_result = active_pl_ptr_id, + .operand = pl_ptr_id, }); - const layout = self.unionLayout(object_ty, field_index); - return try self.accessChain(result_ty_ref, casted_id, &.{layout.active_field_index}); + return active_pl_ptr_id; }, }, else => unreachable, diff --git a/test/behavior/struct.zig b/test/behavior/struct.zig index 8e7aa59844..2edd7fae02 100644 --- a/test/behavior/struct.zig +++ b/test/behavior/struct.zig @@ -736,7 +736,6 @@ test "packed struct with u0 field access" { test "access to global struct fields" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; g_foo.bar.value = 42; try expect(g_foo.bar.value == 42); diff --git a/test/behavior/union.zig b/test/behavior/union.zig index 6b87ab96ce..ebfde8899e 100644 --- a/test/behavior/union.zig +++ b/test/behavior/union.zig @@ -399,7 +399,6 @@ test "tagged union with no payloads" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const a = UnionEnumNoPayloads{ .B = {} }; switch (a) { @@ -474,7 +473,6 @@ test "update the tag value for zero-sized unions" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = union(enum) { U0: void, @@ -515,7 +513,6 @@ test "method call on an empty union" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { const MyUnion = union(MyUnionTag) { @@ -593,7 +590,6 @@ test "tagged union with all void fields but a meaningful tag" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { const B = union(enum) { @@ -795,7 +791,6 @@ test "@unionInit stored to a const" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { const U = union(enum) { @@ -867,7 +862,6 @@ test "union no tag with struct member" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const Struct = struct {}; const Union = union { @@ -1079,7 +1073,6 @@ test "@unionInit on union with tag but no fields" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { const Type = enum(u8) { no_op = 105 }; @@ -1128,7 +1121,6 @@ test "global variable struct contains union initialized to non-most-aligned fiel if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const T = struct { const U = union(enum) { @@ -1348,7 +1340,6 @@ test "union field ptr - zero sized payload" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const U = union { foo: void, @@ -1363,7 +1354,6 @@ test "union field ptr - zero sized field" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const U = union { foo: void,