diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 8a77899bc0..b11a822390 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -1662,13 +1662,11 @@ pub const DeclGen = struct { return try self.convertToDirect(result_ty, result_id); } - fn load(self: *DeclGen, ptr_ty: Type, ptr_id: IdRef) !IdRef { - const mod = self.module; - const value_ty = ptr_ty.childType(mod); + fn load(self: *DeclGen, value_ty: Type, ptr_id: IdRef, is_volatile: bool) !IdRef { const indirect_value_ty_ref = try self.resolveType(value_ty, .indirect); const result_id = self.spv.allocId(); const access = spec.MemoryAccess.Extended{ - .Volatile = ptr_ty.isVolatilePtr(mod), + .Volatile = is_volatile, }; try self.func.body.emit(self.spv.gpa, .OpLoad, .{ .id_result_type = self.typeId(indirect_value_ty_ref), @@ -1679,12 +1677,10 @@ pub const DeclGen = struct { return try self.convertToDirect(value_ty, result_id); } - fn store(self: *DeclGen, ptr_ty: Type, ptr_id: IdRef, value_id: IdRef) !void { - const mod = self.module; - const value_ty = ptr_ty.childType(mod); + fn store(self: *DeclGen, value_ty: Type, ptr_id: IdRef, value_id: IdRef, is_volatile: bool) !void { const indirect_value_id = try self.convertToIndirect(value_ty, value_id); const access = spec.MemoryAccess.Extended{ - .Volatile = ptr_ty.isVolatilePtr(mod), + .Volatile = is_volatile, }; try self.func.body.emit(self.spv.gpa, .OpStore, .{ .pointer = ptr_id, @@ -1754,6 +1750,7 @@ pub const DeclGen = struct { .ptr_elem_ptr => try self.airPtrElemPtr(inst), .ptr_elem_val => try self.airPtrElemVal(inst), + .set_union_tag => return try self.airSetUnionTag(inst), .get_union_tag => try self.airGetUnionTag(inst), .struct_field_val => try self.airStructFieldVal(inst), @@ -2512,7 +2509,7 @@ pub const DeclGen = struct { const slice_ptr = try self.extractField(ptr_ty, slice_id, 0); const elem_ptr = try self.ptrAccessChain(ptr_ty_ref, slice_ptr, index_id, &.{}); - return try self.load(slice_ty, elem_ptr); + return try self.load(slice_ty.childType(mod), elem_ptr, slice_ty.isVolatilePtr(mod)); } fn ptrElemPtr(self: *DeclGen, ptr_ty: Type, ptr_id: IdRef, index_id: IdRef) !IdRef { @@ -2548,25 +2545,41 @@ pub const DeclGen = struct { } fn airPtrElemVal(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { + if (self.liveness.isUnused(inst)) return null; + const mod = self.module; const bin_op = self.air.instructions.items(.data)[inst].bin_op; const ptr_ty = self.typeOf(bin_op.lhs); + const elem_ty = self.typeOfIndex(inst); const ptr_id = try self.resolve(bin_op.lhs); const index_id = try self.resolve(bin_op.rhs); - const elem_ptr_id = try self.ptrElemPtr(ptr_ty, ptr_id, index_id); + return try self.load(elem_ty, elem_ptr_id, ptr_ty.isVolatilePtr(mod)); + } - // If we have a pointer-to-array, construct an element pointer to use with load() - // If we pass ptr_ty directly, it will attempt to load the entire array rather than - // just an element. - var elem_ptr_info = ptr_ty.ptrInfo(mod); - elem_ptr_info.flags.size = .One; - const elem_ptr_ty = try mod.intern_pool.get(mod.gpa, .{ .ptr_type = elem_ptr_info }); + fn airSetUnionTag(self: *DeclGen, inst: Air.Inst.Index) !void { + const mod = self.module; + const bin_op = self.air.instructions.items(.data)[inst].bin_op; + const un_ptr_ty = self.typeOf(bin_op.lhs); + const un_ty = un_ptr_ty.childType(mod); + const layout = self.unionLayout(un_ty, null); - return try self.load(elem_ptr_ty.toType(), elem_ptr_id); + if (layout.tag_size == 0) return; + + const tag_ty = un_ty.unionTagTypeSafety(mod).?; + const tag_ty_ref = try self.resolveType(tag_ty, .indirect); + const tag_ptr_ty_ref = try self.spv.ptrType(tag_ty_ref, spvStorageClass(un_ptr_ty.ptrAddressSpace(mod))); + + const union_ptr_id = try self.resolve(bin_op.lhs); + const new_tag_id = try self.resolve(bin_op.rhs); + + const ptr_id = try self.accessChain(tag_ptr_ty_ref, union_ptr_id, &.{layout.tag_index}); + try self.store(tag_ty, ptr_id, new_tag_id, un_ptr_ty.isVolatilePtr(mod)); } fn airGetUnionTag(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { + if (self.liveness.isUnused(inst)) return null; + const ty_op = self.air.instructions.items(.data)[inst].ty_op; const un_ty = self.typeOf(ty_op.operand); @@ -2588,25 +2601,25 @@ pub const DeclGen = struct { const ty_pl = self.air.instructions.items(.data)[inst].ty_pl; const struct_field = self.air.extraData(Air.StructField, ty_pl.payload).data; - const container_ty = self.typeOf(struct_field.struct_operand); + const object_ty = self.typeOf(struct_field.struct_operand); const object_id = try self.resolve(struct_field.struct_operand); const field_index = struct_field.field_index; - const field_ty = container_ty.structFieldType(field_index, mod); + const field_ty = object_ty.structFieldType(field_index, mod); if (!field_ty.hasRuntimeBitsIgnoreComptime(mod)) return null; - switch (container_ty.zigTypeTag(mod)) { - .Struct => switch (container_ty.containerLayout(mod)) { + switch (object_ty.zigTypeTag(mod)) { + .Struct => switch (object_ty.containerLayout(mod)) { .Packed => unreachable, // TODO else => return try self.extractField(field_ty, object_id, field_index), }, - .Union => switch (container_ty.containerLayout(mod)) { + .Union => switch (object_ty.containerLayout(mod)) { .Packed => unreachable, // TODO else => { // Store, pointer-cast, load - const un_general_ty_ref = try self.resolveType(container_ty, .indirect); + const un_general_ty_ref = try self.resolveType(object_ty, .indirect); const un_general_ptr_ty_ref = try self.spv.ptrType(un_general_ty_ref, .Function); - const un_active_ty_ref = try self.resolveUnionType(container_ty, field_index); + const un_active_ty_ref = try self.resolveUnionType(object_ty, field_index); const un_active_ptr_ty_ref = try self.spv.ptrType(un_active_ty_ref, .Function); const field_ty_ref = try self.resolveType(field_ty, .indirect); const field_ptr_ty_ref = try self.spv.ptrType(field_ty_ref, .Function); @@ -2617,31 +2630,20 @@ pub const DeclGen = struct { .id_result = tmp_id, .storage_class = .Function, }); - try self.func.body.emit(self.spv.gpa, .OpStore, .{ - .pointer = tmp_id, - .object = object_id, - }); + try self.store(object_ty, tmp_id, object_id, false); const casted_tmp_id = self.spv.allocId(); try self.func.body.emit(self.spv.gpa, .OpBitcast, .{ .id_result_type = self.typeId(un_active_ptr_ty_ref), .id_result = casted_tmp_id, .operand = tmp_id, }); - const layout = self.unionLayout(container_ty, field_index); + const layout = self.unionLayout(object_ty, field_index); const field_ptr_id = try self.accessChain(field_ptr_ty_ref, casted_tmp_id, &.{layout.active_field_index}); - const result_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpLoad, .{ - .id_result_type = self.typeId(field_ty_ref), - .id_result = result_id, - .pointer = field_ptr_id, - }); - return try self.convertToDirect(field_ty, result_id); + return try self.load(field_ty, field_ptr_id, false); }, }, else => unreachable, } - - // return try self.extractField(field_ty, object_id, field_index); } fn structFieldPtr( @@ -2866,19 +2868,21 @@ pub const DeclGen = struct { const mod = self.module; const ty_op = self.air.instructions.items(.data)[inst].ty_op; const ptr_ty = self.typeOf(ty_op.operand); + const elem_ty = self.typeOfIndex(inst); const operand = try self.resolve(ty_op.operand); if (!ptr_ty.isVolatilePtr(mod) and self.liveness.isUnused(inst)) return null; - return try self.load(ptr_ty, operand); + return try self.load(elem_ty, operand, ptr_ty.isVolatilePtr(mod)); } fn airStore(self: *DeclGen, inst: Air.Inst.Index) !void { const bin_op = self.air.instructions.items(.data)[inst].bin_op; const ptr_ty = self.typeOf(bin_op.lhs); + const elem_ty = ptr_ty.childType(self.module); const ptr = try self.resolve(bin_op.lhs); const value = try self.resolve(bin_op.rhs); - try self.store(ptr_ty, ptr, value); + try self.store(elem_ty, ptr, value, ptr_ty.isVolatilePtr(self.module)); } fn airLoop(self: *DeclGen, inst: Air.Inst.Index) !void { @@ -2922,7 +2926,7 @@ pub const DeclGen = struct { } const ptr = try self.resolve(un_op); - const value = try self.load(ptr_ty, ptr); + const value = try self.load(ret_ty, ptr, ptr_ty.isVolatilePtr(mod)); try self.func.body.emit(self.spv.gpa, .OpReturnValue, .{ .value = value, }); diff --git a/test/behavior/union.zig b/test/behavior/union.zig index 003b6d3f8f..d87b411e99 100644 --- a/test/behavior/union.zig +++ b/test/behavior/union.zig @@ -29,7 +29,6 @@ test "init union with runtime value - floats" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; var foo: FooWithFloats = undefined; @@ -59,7 +58,6 @@ test "init union with runtime value" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; var foo: Foo = undefined; @@ -170,7 +168,6 @@ test "constant tagged union with payload" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; var empty = TaggedUnionWithPayload{ .Empty = {} }; var full = TaggedUnionWithPayload{ .Full = 13 }; @@ -508,7 +505,6 @@ test "union initializer generates padding only if needed" { test "runtime tag name with single field" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const U = union(enum) { A: i32, @@ -585,7 +581,6 @@ test "tagged union as return value" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; switch (returnAnInt(13)) { TaggedFoo.One => |value| try expect(value == 13), @@ -630,7 +625,6 @@ test "union(enum(u32)) with specified and unspecified tag values" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try comptime expect(Tag(Tag(MultipleChoice2)) == u32); try testEnumWithSpecifiedAndUnspecifiedTagValues(MultipleChoice2{ .C = 123 }); @@ -668,7 +662,6 @@ test "switch on union with only 1 field" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; var r: PartialInst = undefined; r = PartialInst.Compiled; @@ -697,7 +690,6 @@ const PartialInstWithPayload = union(enum) { test "union with only 1 field casted to its enum type which has enum value specified" { if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const Literal = union(enum) { Number: f64, @@ -782,7 +774,6 @@ test "return union init with void payload" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { fn entry() !void { @@ -836,7 +827,6 @@ test "@unionInit can modify a union type" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const UnionInitEnum = union(enum) { Boolean: bool, @@ -860,7 +850,6 @@ test "@unionInit can modify a pointer value" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const UnionInitEnum = union(enum) { Boolean: bool, @@ -917,7 +906,6 @@ test "anonymous union literal syntax" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { const Number = union { @@ -1041,7 +1029,6 @@ test "switching on non exhaustive union" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { const E = enum(u8) { @@ -1225,7 +1212,6 @@ test "union tag is set when initiated as a temporary value at runtime" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const U = union(enum) { a, @@ -1263,7 +1249,6 @@ test "return an extern union from C calling convention" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const namespace = struct { const S = extern struct { @@ -1294,7 +1279,6 @@ test "noreturn field in union" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const U = union(enum) { a: u32, @@ -1475,7 +1459,6 @@ test "no dependency loop when function pointer in union returns the union" { test "union reassignment can use previous value" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const U = union { a: u32, @@ -1527,7 +1510,6 @@ test "reinterpreting enum value inside packed union" { test "access the tag of a global tagged union" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const U = union(enum) { a, @@ -1539,7 +1521,6 @@ test "access the tag of a global tagged union" { test "coerce enum literal to union in result loc" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const U = union(enum) { a,