From dacd70fbe41d959bb7b48b5bad8612e74231524b Mon Sep 17 00:00:00 2001 From: Ali Cheraghi Date: Wed, 7 May 2025 20:25:06 +0330 Subject: [PATCH] spirv: super basic composite int support --- src/Zcu.zig | 2 +- src/codegen/spirv.zig | 162 +++++++++++------- src/codegen/spirv/Module.zig | 7 +- src/target.zig | 3 +- .../compile_errors/@import_zon_bad_type.zig | 6 +- .../anytype_param_requires_comptime.zig | 2 +- .../bogus_method_call_on_slice.zig | 2 +- .../compile_errors/coerce_anon_struct.zig | 2 +- test/cases/compile_errors/redundant_try.zig | 4 +- test/tests.zig | 2 +- 10 files changed, 120 insertions(+), 72 deletions(-) diff --git a/src/Zcu.zig b/src/Zcu.zig index bee7fadb95..b8118b3f0b 100644 --- a/src/Zcu.zig +++ b/src/Zcu.zig @@ -3693,7 +3693,7 @@ pub fn errorSetBits(zcu: *const Zcu) u16 { const target = zcu.getTarget(); if (zcu.error_limit == 0) return 0; - if (target.cpu.arch == .spirv64) { + if (target.cpu.arch.isSpirV()) { if (!std.Target.spirv.featureSetHas(target.cpu.features, .storage_push_constant16)) { return 32; } diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index b5bba61016..2732a0a617 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -30,6 +30,7 @@ const SpvAssembler = @import("spirv/Assembler.zig"); const InstMap = std.AutoHashMapUnmanaged(Air.Inst.Index, IdRef); pub const zig_call_abi_ver = 3; +pub const big_int_bits = 32; const InternMap = std.AutoHashMapUnmanaged(struct { InternPool.Index, NavGen.Repr }, IdResult); const PtrTypeMap = std.AutoHashMapUnmanaged( @@ -376,7 +377,7 @@ const NavGen = struct { /// The number of bits required to store the type. /// For `integer` and `float`, this is equal to `bits`. /// For `strange_integer` and `bool` this is the size of the backing integer. - /// For `composite_integer` this is 0 (TODO) + /// For `composite_integer` this is the elements count. backing_bits: u16, /// Null if this type is a scalar, or the length @@ -579,11 +580,13 @@ const NavGen = struct { /// The backing type will be chosen as the smallest supported integer larger or equal to it in number of bits. /// The result is valid to be used with OpTypeInt. /// TODO: Should the result of this function be cached? - fn backingIntBits(self: *NavGen, bits: u16) ?u16 { + fn backingIntBits(self: *NavGen, bits: u16) struct { u16, bool } { // The backend will never be asked to compiler a 0-bit integer, so we won't have to handle those in this function. assert(bits != 0); - if (self.spv.hasFeature(.arbitrary_precision_integers) and bits <= 32) return bits; + if (self.spv.hasFeature(.arbitrary_precision_integers) and bits <= 32) { + return .{ bits, false }; + } // We require Int8 and Int16 capabilities and benefit Int64 when available. // 32-bit integers are always supported (see spec, 2.16.1, Data rules). @@ -596,10 +599,11 @@ const NavGen = struct { for (ints) |int| { const has_feature = if (int.feature) |feature| self.spv.hasFeature(feature) else true; - if (bits <= int.bits and has_feature) return int.bits; + if (bits <= int.bits and has_feature) return .{ int.bits, false }; } - return null; + // Big int + return .{ std.mem.alignForward(u16, bits, big_int_bits), true }; } /// Return the amount of bits in the largest supported integer type. This is either 32 (always supported), or 64 (if @@ -623,7 +627,7 @@ const NavGen = struct { return switch (scalar_ty.zigTypeTag(zcu)) { .bool => .{ .bits = 1, // Doesn't matter for this class. - .backing_bits = self.backingIntBits(1).?, + .backing_bits = self.backingIntBits(1).@"0", .vector_len = vector_len, .signedness = .unsigned, // Technically, but doesn't matter for this class. .class = .bool, @@ -638,19 +642,16 @@ const NavGen = struct { .int => blk: { const int_info = scalar_ty.intInfo(zcu); // TODO: Maybe it's useful to also return this value. - const maybe_backing_bits = self.backingIntBits(int_info.bits); + const backing_bits, const big_int = self.backingIntBits(int_info.bits); break :blk .{ .bits = int_info.bits, - .backing_bits = maybe_backing_bits orelse 0, + .backing_bits = backing_bits, .vector_len = vector_len, .signedness = int_info.signedness, - .class = if (maybe_backing_bits) |backing_bits| - if (backing_bits == int_info.bits) - .integer - else - .strange_integer - else - .composite_integer, + .class = class: { + if (big_int) break :class .composite_integer; + break :class if (backing_bits == int_info.bits) .integer else .strange_integer; + }, }; }, .@"enum" => unreachable, @@ -659,6 +660,34 @@ const NavGen = struct { }; } + /// Checks whether the type can be directly translated to SPIR-V vectors + fn isSpvVector(self: *NavGen, ty: Type) bool { + const zcu = self.pt.zcu; + if (ty.zigTypeTag(zcu) != .vector) return false; + + // TODO: This check must be expanded for types that can be represented + // as integers (enums / packed structs?) and types that are represented + // by multiple SPIR-V values. + const scalar_ty = ty.scalarType(zcu); + switch (scalar_ty.zigTypeTag(zcu)) { + .bool, + .int, + .float, + => {}, + else => return false, + } + + const elem_ty = ty.childType(zcu); + const len = ty.vectorLen(zcu); + + if (elem_ty.isNumeric(zcu) or elem_ty.toIntern() == .bool_type) { + if (len > 1 and len <= 4) return true; + if (self.spv.hasFeature(.vector16)) return (len == 8 or len == 16); + } + + return false; + } + /// Emits a bool constant in a particular representation. fn constBool(self: *NavGen, value: bool, repr: Repr) !IdRef { return switch (repr) { @@ -675,14 +704,26 @@ const NavGen = struct { const scalar_ty = ty.scalarType(zcu); const int_info = scalar_ty.intInfo(zcu); // Use backing bits so that negatives are sign extended - const backing_bits = self.backingIntBits(int_info.bits).?; // Assertion failure means big int + const backing_bits, const big_int = self.backingIntBits(int_info.bits); assert(backing_bits != 0); // u0 is comptime + const result_ty_id = try self.resolveType(scalar_ty, .indirect); const signedness: Signedness = switch (@typeInfo(@TypeOf(value))) { .int => |int| int.signedness, .comptime_int => if (value < 0) .signed else .unsigned, else => unreachable, }; + if (@sizeOf(@TypeOf(value)) >= 4 and big_int) { + const value64: u64 = switch (signedness) { + .signed => @bitCast(@as(i64, @intCast(value))), + .unsigned => @as(u64, @intCast(value)), + }; + assert(backing_bits == 64); + return self.constructComposite(result_ty_id, &.{ + try self.constInt(.u32, @as(u32, @truncate(value64))), + try self.constInt(.u32, @as(u32, @truncate(value64 << 32))), + }); + } const final_value: spec.LiteralContextDependentNumber = blk: { if (self.spv.hasFeature(.kernel)) { @@ -700,18 +741,17 @@ const NavGen = struct { break :blk switch (backing_bits) { 1...32 => .{ .uint32 = @truncate(truncated_value) }, 33...64 => .{ .uint64 = truncated_value }, - else => unreachable, // TODO: Large integer constants + else => unreachable, }; } break :blk switch (backing_bits) { 1...32 => if (signedness == .signed) .{ .int32 = @intCast(value) } else .{ .uint32 = @intCast(value) }, 33...64 => if (signedness == .signed) .{ .int64 = value } else .{ .uint64 = value }, - else => unreachable, // TODO: Large integer constants + else => unreachable, }; }; - const result_ty_id = try self.resolveType(scalar_ty, .indirect); const result_id = try self.spv.constant(result_ty_id, final_value); if (!ty.isVector(zcu)) return result_id; @@ -949,7 +989,7 @@ const NavGen = struct { // TODO: composite int // TODO: endianness const bits: u16 = @intCast(ty.bitSize(zcu)); - const bytes = std.mem.alignForward(u16, self.backingIntBits(bits).?, 8) / 8; + const bytes = std.mem.alignForward(u16, self.backingIntBits(bits).@"0", 8) / 8; var limbs: [8]u8 = undefined; @memset(&limbs, 0); val.writeToPackedMemory(ty, pt, limbs[0..bytes], 0) catch unreachable; @@ -1068,19 +1108,11 @@ const NavGen = struct { const parent_ptr_id = try self.derivePtr(oac.parent.*); const parent_ptr_ty = try oac.parent.ptrType(pt); const result_ty_id = try self.resolveType(oac.new_ptr_ty, .direct); + const child_size = oac.new_ptr_ty.childType(zcu).abiSize(zcu); - if (oac.byte_offset != 0) { - const child_size = oac.new_ptr_ty.childType(zcu).abiSize(zcu); - if (oac.byte_offset % child_size != 0) { - return self.fail("cannot perform pointer cast: '{}' to '{}'", .{ - parent_ptr_ty.fmt(pt), - oac.new_ptr_ty.fmt(pt), - }); - } - + if (parent_ptr_ty.childType(zcu).isVector(zcu) and oac.byte_offset % child_size == 0) { // Vector element ptr accesses are derived as offset_and_cast. // We can just use OpAccessChain. - assert(parent_ptr_ty.childType(zcu).zigTypeTag(zcu) == .vector); return self.accessChain( result_ty_id, parent_ptr_id, @@ -1088,15 +1120,22 @@ const NavGen = struct { ); } - // Allow changing the pointer type child only to restructure arrays. - // e.g. [3][2]T to T is fine, as is [2]T -> [2][1]T. - const result_ptr_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpBitcast, .{ - .id_result_type = result_ty_id, - .id_result = result_ptr_id, - .operand = parent_ptr_id, + if (oac.byte_offset == 0) { + // Allow changing the pointer type child only to restructure arrays. + // e.g. [3][2]T to T is fine, as is [2]T -> [2][1]T. + const result_ptr_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpBitcast, .{ + .id_result_type = result_ty_id, + .id_result = result_ptr_id, + .operand = parent_ptr_id, + }); + return result_ptr_id; + } + + return self.fail("cannot perform pointer cast: '{}' to '{}'", .{ + parent_ptr_ty.fmt(pt), + oac.new_ptr_ty.fmt(pt), }); - return result_ptr_id; }, } } @@ -1217,11 +1256,14 @@ const NavGen = struct { /// actual operations (as well as store) a Zig type of a particular number of bits. To create /// a type with an exact size, use SpvModule.intType. fn intType(self: *NavGen, signedness: std.builtin.Signedness, bits: u16) !IdRef { - const backing_bits = self.backingIntBits(bits) orelse { - // TODO: Integers too big for any native type are represented as "composite integers": - // An array of largestSupportedIntBits. - return self.todo("Implement {s} composite int type of {} bits", .{ @tagName(signedness), bits }); - }; + const backing_bits, const big_int = self.backingIntBits(bits); + if (big_int) { + if (backing_bits > 64) { + return self.fail("composite integers larger than 64bit aren't supported", .{}); + } + const int_ty = try self.resolveType(.u32, .direct); + return self.arrayType(backing_bits / big_int_bits, int_ty); + } // Kernel only supports unsigned ints. if (self.spv.hasFeature(.kernel)) { @@ -1509,6 +1551,17 @@ const NavGen = struct { return result_id; } }, + .vector => { + const elem_ty = ty.childType(zcu); + const elem_ty_id = try self.resolveType(elem_ty, repr); + const len = ty.vectorLen(zcu); + + if (self.isSpvVector(ty)) { + return try self.spv.vectorType(len, elem_ty_id); + } else { + return try self.arrayType(len, elem_ty_id); + } + }, .@"fn" => switch (repr) { .direct => { const fn_info = zcu.typeToFunc(ty).?; @@ -1577,12 +1630,6 @@ const NavGen = struct { ); return result_id; }, - .vector => { - const elem_ty = ty.childType(zcu); - const elem_ty_id = try self.resolveType(elem_ty, repr); - const len = ty.vectorLen(zcu); - return self.arrayType(len, elem_ty_id); - }, .@"struct" => { const struct_type = switch (ip.indexToKey(ty.toIntern())) { .tuple_type => |tuple| { @@ -3378,8 +3425,7 @@ const NavGen = struct { const zcu = self.pt.zcu; const ty = value.ty; switch (info.class) { - .integer, .bool, .float => return value, - .composite_integer => unreachable, // TODO + .composite_integer, .integer, .bool, .float => return value, .strange_integer => switch (info.signedness) { .unsigned => { const mask_value = if (info.bits == 64) 0xFFFF_FFFF_FFFF_FFFF else (@as(u64, 1) << @as(u6, @intCast(info.bits))) - 1; @@ -5039,7 +5085,7 @@ const NavGen = struct { const mask_id = try self.constInt(object_ty, (@as(u64, 1) << @as(u6, @intCast(field_bit_size))) - 1); const masked = try self.buildBinary(.bit_and, shift, .{ .ty = object_ty, .value = .{ .singleton = mask_id } }); const result_id = blk: { - if (self.backingIntBits(field_bit_size).? == self.backingIntBits(@intCast(object_ty.bitSize(zcu))).?) + if (self.backingIntBits(field_bit_size).@"0" == self.backingIntBits(@intCast(object_ty.bitSize(zcu))).@"0") break :blk try self.bitCast(field_int_ty, object_ty, try masked.materialize(self)); const trunc = try self.buildConvert(field_int_ty, masked); break :blk try trunc.materialize(self); @@ -5063,7 +5109,7 @@ const NavGen = struct { .{ .ty = backing_int_ty, .value = .{ .singleton = mask_id } }, ); const result_id = blk: { - if (self.backingIntBits(field_bit_size).? == self.backingIntBits(@intCast(backing_int_ty.bitSize(zcu))).?) + if (self.backingIntBits(field_bit_size).@"0" == self.backingIntBits(@intCast(backing_int_ty.bitSize(zcu))).@"0") break :blk try self.bitCast(int_ty, backing_int_ty, try masked.materialize(self)); const trunc = try self.buildConvert(int_ty, masked); break :blk try trunc.materialize(self); @@ -6100,17 +6146,15 @@ const NavGen = struct { .bool, .error_set => 1, .int => blk: { const bits = cond_ty.intInfo(zcu).bits; - const backing_bits = self.backingIntBits(bits) orelse { - return self.todo("implement composite int switch", .{}); - }; + const backing_bits, const big_int = self.backingIntBits(bits); + if (big_int) return self.todo("implement composite int switch", .{}); break :blk if (backing_bits <= 32) 1 else 2; }, .@"enum" => blk: { const int_ty = cond_ty.intTagType(zcu); const int_info = int_ty.intInfo(zcu); - const backing_bits = self.backingIntBits(int_info.bits) orelse { - return self.todo("implement composite int switch", .{}); - }; + const backing_bits, const big_int = self.backingIntBits(int_info.bits); + if (big_int) return self.todo("implement composite int switch", .{}); break :blk if (backing_bits <= 32) 1 else 2; }, .pointer => blk: { diff --git a/src/codegen/spirv/Module.zig b/src/codegen/spirv/Module.zig index 16c32c26d5..920215bee1 100644 --- a/src/codegen/spirv/Module.zig +++ b/src/codegen/spirv/Module.zig @@ -369,8 +369,11 @@ pub fn finalize(self: *Module, a: Allocator) ![]Word { // Emit memory model const addressing_model: spec.AddressingModel = blk: { if (self.hasFeature(.shader)) { - assert(self.target.cpu.arch == .spirv64); - if (self.hasFeature(.physical_storage_buffer)) break :blk .PhysicalStorageBuffer64; + if (self.hasFeature(.physical_storage_buffer)) { + assert(self.target.cpu.arch == .spirv64); + break :blk .PhysicalStorageBuffer64; + } + assert(self.target.cpu.arch == .spirv); break :blk .Logical; } diff --git a/src/target.zig b/src/target.zig index c5b2d97efb..6119b002a4 100644 --- a/src/target.zig +++ b/src/target.zig @@ -807,7 +807,8 @@ pub fn zigBackend(target: std.Target, use_llvm: bool) std.builtin.CompilerBacken .powerpc, .powerpcle, .powerpc64, .powerpc64le => .stage2_powerpc, .riscv64 => .stage2_riscv64, .sparc64 => .stage2_sparc64, - .spirv64 => .stage2_spirv64, + .spirv32 => if (target.os.tag == .opencl) .stage2_spirv64 else .other, + .spirv, .spirv64 => .stage2_spirv64, .wasm32, .wasm64 => .stage2_wasm, .x86 => .stage2_x86, .x86_64 => .stage2_x86_64, diff --git a/test/cases/compile_errors/@import_zon_bad_type.zig b/test/cases/compile_errors/@import_zon_bad_type.zig index 3265c6d92c..a2e13c4a6d 100644 --- a/test/cases/compile_errors/@import_zon_bad_type.zig +++ b/test/cases/compile_errors/@import_zon_bad_type.zig @@ -117,9 +117,9 @@ export fn testMutablePointer() void { // tmp.zig:37:38: note: imported here // neg_inf.zon:1:1: error: expected type '?u8' // tmp.zig:57:28: note: imported here -// neg_inf.zon:1:1: error: expected type 'tmp.testNonExhaustiveEnum__enum_499' +// neg_inf.zon:1:1: error: expected type 'tmp.testNonExhaustiveEnum__enum_501' // tmp.zig:62:39: note: imported here -// neg_inf.zon:1:1: error: expected type 'tmp.testUntaggedUnion__union_501' +// neg_inf.zon:1:1: error: expected type 'tmp.testUntaggedUnion__union_503' // tmp.zig:67:44: note: imported here -// neg_inf.zon:1:1: error: expected type 'tmp.testTaggedUnionVoid__union_504' +// neg_inf.zon:1:1: error: expected type 'tmp.testTaggedUnionVoid__union_506' // tmp.zig:72:50: note: imported here diff --git a/test/cases/compile_errors/anytype_param_requires_comptime.zig b/test/cases/compile_errors/anytype_param_requires_comptime.zig index 3546955e23..541a49a460 100644 --- a/test/cases/compile_errors/anytype_param_requires_comptime.zig +++ b/test/cases/compile_errors/anytype_param_requires_comptime.zig @@ -15,6 +15,6 @@ pub export fn entry() void { // error // // :7:25: error: unable to resolve comptime value -// :7:25: note: initializer of comptime-only struct 'tmp.S.foo__anon_473.C' must be comptime-known +// :7:25: note: initializer of comptime-only struct 'tmp.S.foo__anon_475.C' must be comptime-known // :4:16: note: struct requires comptime because of this field // :4:16: note: types are not available at runtime diff --git a/test/cases/compile_errors/bogus_method_call_on_slice.zig b/test/cases/compile_errors/bogus_method_call_on_slice.zig index fe30379476..598c04d2c5 100644 --- a/test/cases/compile_errors/bogus_method_call_on_slice.zig +++ b/test/cases/compile_errors/bogus_method_call_on_slice.zig @@ -16,5 +16,5 @@ pub export fn entry2() void { // // :3:6: error: no field or member function named 'copy' in '[]const u8' // :9:8: error: no field or member function named 'bar' in '@TypeOf(.{})' -// :12:18: error: no field or member function named 'bar' in 'tmp.entry2__struct_477' +// :12:18: error: no field or member function named 'bar' in 'tmp.entry2__struct_479' // :12:6: note: struct declared here diff --git a/test/cases/compile_errors/coerce_anon_struct.zig b/test/cases/compile_errors/coerce_anon_struct.zig index 75e27ddbed..ec5cf966d9 100644 --- a/test/cases/compile_errors/coerce_anon_struct.zig +++ b/test/cases/compile_errors/coerce_anon_struct.zig @@ -6,6 +6,6 @@ export fn foo() void { // error // -// :4:16: error: expected type 'tmp.T', found 'tmp.foo__struct_466' +// :4:16: error: expected type 'tmp.T', found 'tmp.foo__struct_468' // :3:16: note: struct declared here // :1:11: note: struct declared here diff --git a/test/cases/compile_errors/redundant_try.zig b/test/cases/compile_errors/redundant_try.zig index 2a3488c413..1f44cc05dc 100644 --- a/test/cases/compile_errors/redundant_try.zig +++ b/test/cases/compile_errors/redundant_try.zig @@ -44,9 +44,9 @@ comptime { // // :5:23: error: expected error union type, found 'comptime_int' // :10:23: error: expected error union type, found '@TypeOf(.{})' -// :15:23: error: expected error union type, found 'tmp.test2__struct_503' +// :15:23: error: expected error union type, found 'tmp.test2__struct_505' // :15:23: note: struct declared here -// :20:27: error: expected error union type, found 'tmp.test3__struct_505' +// :20:27: error: expected error union type, found 'tmp.test3__struct_507' // :20:27: note: struct declared here // :25:23: error: expected error union type, found 'struct { comptime *const [5:0]u8 = "hello" }' // :31:13: error: expected error union type, found 'u32' diff --git a/test/tests.zig b/test/tests.zig index 04c89444df..cc1da4cf9f 100644 --- a/test/tests.zig +++ b/test/tests.zig @@ -145,7 +145,7 @@ const test_targets = blk: { .{ .target = std.Target.Query.parse(.{ .arch_os_abi = "spirv64-vulkan", - .cpu_features = "vulkan_v1_2+int64+float16+float64", + .cpu_features = "vulkan_v1_2+physical_storage_buffer+int64+float16+float64", }) catch unreachable, .use_llvm = false, .use_lld = false,