From fba6b7e4c20c3d1e1701ad737ad47a2440f44ea7 Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Sat, 13 Apr 2024 14:37:46 +0200 Subject: [PATCH 1/7] spirv: fix error code encoding --- src/link/SpirV.zig | 1 + 1 file changed, 1 insertion(+) diff --git a/src/link/SpirV.zig b/src/link/SpirV.zig index 0cc238f140..6577b0df51 100644 --- a/src/link/SpirV.zig +++ b/src/link/SpirV.zig @@ -232,6 +232,7 @@ pub fn flushModule(self: *SpirV, arena: Allocator, prog_node: std.Progress.Node) // name if it contains no strange characters is nice for debugging. URI encoding fits the bill. // We're using : as separator, which is a reserved character. + try error_info.append(':'); try std.Uri.Component.percentEncode( error_info.writer(), name.toSlice(&mod.intern_pool), From 44443b833b9c86cb5a3c50b157f13ab4097226d8 Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Sat, 13 Apr 2024 15:30:03 +0200 Subject: [PATCH 2/7] build: inherit setExecCmd from test compile steps when creating run steps This should fix #17756 --- lib/std/Build.zig | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/lib/std/Build.zig b/lib/std/Build.zig index 698aa3af34..086c0c1786 100644 --- a/lib/std/Build.zig +++ b/lib/std/Build.zig @@ -972,10 +972,24 @@ pub fn addRunArtifact(b: *Build, exe: *Step.Compile) *Step.Run { // Consider that this is declarative; the run step may not be run unless a user // option is supplied. const run_step = Step.Run.create(b, b.fmt("run {s}", .{exe.name})); - run_step.addArtifactArg(exe); + if (exe.kind == .@"test") { + if (exe.exec_cmd_args) |exec_cmd_args| { + for (exec_cmd_args) |cmd_arg| { + if (cmd_arg) |arg| { + run_step.addArg(arg); + } else { + run_step.addArtifactArg(exe); + } + } + } else { + run_step.addArtifactArg(exe); + } - if (exe.kind == .@"test" and exe.test_server_mode) { - run_step.enableTestRunnerMode(); + if (exe.test_server_mode) { + run_step.enableTestRunnerMode(); + } + } else { + run_step.addArtifactArg(exe); } return run_step; From b9d738a5cff0ab896c25f1c8abe15757bcd6a0ba Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Fri, 31 May 2024 00:09:28 +0200 Subject: [PATCH 3/7] spirv: disable tests that fail on pocl Besides the Intel OpenCL CPU runtime, we can now run the behavior tests using the Portable Computing Language. This implementation is open-source, so it will be easier for us to patch in updated versions of spirv-llvm-translator that have bug fixes etc. --- test/behavior/abs.zig | 1 + test/behavior/array.zig | 1 + test/behavior/byval_arg_var.zig | 1 + test/behavior/cast.zig | 1 + test/behavior/enum.zig | 1 + test/behavior/error.zig | 3 +++ test/behavior/hasdecl.zig | 2 ++ test/behavior/math.zig | 1 + test/behavior/optional.zig | 2 ++ test/behavior/packed-struct.zig | 2 ++ test/behavior/packed-union.zig | 2 ++ test/behavior/slice.zig | 3 +++ test/behavior/string_literals.zig | 2 ++ test/behavior/typename.zig | 5 +++++ test/behavior/union.zig | 1 + test/behavior/vector.zig | 1 + 16 files changed, 29 insertions(+) diff --git a/test/behavior/abs.zig b/test/behavior/abs.zig index 21f02b2a3d..8ca160faff 100644 --- a/test/behavior/abs.zig +++ b/test/behavior/abs.zig @@ -152,6 +152,7 @@ test "@abs int vectors" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try comptime testAbsIntVectors(1); try testAbsIntVectors(1); diff --git a/test/behavior/array.zig b/test/behavior/array.zig index d524023c9b..49a03c05e2 100644 --- a/test/behavior/array.zig +++ b/test/behavior/array.zig @@ -768,6 +768,7 @@ test "slicing array of zero-sized values" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; var arr: [32]u0 = undefined; for (arr[0..]) |*zero| diff --git a/test/behavior/byval_arg_var.zig b/test/behavior/byval_arg_var.zig index ed0fde991f..6b48769500 100644 --- a/test/behavior/byval_arg_var.zig +++ b/test/behavior/byval_arg_var.zig @@ -6,6 +6,7 @@ var result: []const u8 = "wrong"; test "pass string literal byvalue to a generic var param" { if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; start(); blowUpStack(10); diff --git a/test/behavior/cast.zig b/test/behavior/cast.zig index 46cf272e57..a3ffb7cb3a 100644 --- a/test/behavior/cast.zig +++ b/test/behavior/cast.zig @@ -1378,6 +1378,7 @@ test "assignment to optional pointer result loc" { test "cast between *[N]void and []void" { if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; var a: [4]void = undefined; const b: []void = &a; diff --git a/test/behavior/enum.zig b/test/behavior/enum.zig index dd2d83a289..0742c2d91c 100644 --- a/test/behavior/enum.zig +++ b/test/behavior/enum.zig @@ -1286,6 +1286,7 @@ test "matching captures causes enum equivalence" { test "large enum field values" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; { const E = enum(u64) { min = std.math.minInt(u64), max = std.math.maxInt(u64) }; diff --git a/test/behavior/error.zig b/test/behavior/error.zig index b579f1478e..8db9703f51 100644 --- a/test/behavior/error.zig +++ b/test/behavior/error.zig @@ -997,6 +997,7 @@ test "try used in recursive function with inferred error set" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const Value = union(enum) { values: []const @This(), @@ -1103,6 +1104,7 @@ test "result location initialization of error union with OPV payload" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { x: u0, @@ -1125,6 +1127,7 @@ test "result location initialization of error union with OPV payload" { test "return error union with i65" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try expect(try add(1000, 234) == 1234); } diff --git a/test/behavior/hasdecl.zig b/test/behavior/hasdecl.zig index 71f9200b27..8f371defdb 100644 --- a/test/behavior/hasdecl.zig +++ b/test/behavior/hasdecl.zig @@ -13,6 +13,7 @@ const Bar = struct { test "@hasDecl" { if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try expect(@hasDecl(Foo, "public_thing")); try expect(!@hasDecl(Foo, "private_thing")); @@ -25,6 +26,7 @@ test "@hasDecl" { test "@hasDecl using a sliced string literal" { if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try expect(@hasDecl(@This(), "std") == true); try expect(@hasDecl(@This(), "std"[0..0]) == false); diff --git a/test/behavior/math.zig b/test/behavior/math.zig index eaef26b804..59937515ab 100644 --- a/test/behavior/math.zig +++ b/test/behavior/math.zig @@ -622,6 +622,7 @@ test "negation wrapping" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try expectEqual(@as(u1, 1), negateWrap(u1, 1)); } diff --git a/test/behavior/optional.zig b/test/behavior/optional.zig index 02c329a7d5..7884fec6cd 100644 --- a/test/behavior/optional.zig +++ b/test/behavior/optional.zig @@ -61,6 +61,7 @@ test "optional with zero-bit type" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { fn doTheTest(comptime ZeroBit: type, comptime zero_bit: ZeroBit) !void { @@ -641,6 +642,7 @@ test "result location initialization of optional with OPV payload" { if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { x: u0, diff --git a/test/behavior/packed-struct.zig b/test/behavior/packed-struct.zig index 51a302c945..89289d6063 100644 --- a/test/behavior/packed-struct.zig +++ b/test/behavior/packed-struct.zig @@ -1306,6 +1306,8 @@ test "2-byte packed struct argument in C calling convention" { } test "packed struct contains optional pointer" { + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; + const foo: packed struct { a: ?*@This() = null, } = .{}; diff --git a/test/behavior/packed-union.zig b/test/behavior/packed-union.zig index d76f28ae59..b0b0bd7f39 100644 --- a/test/behavior/packed-union.zig +++ b/test/behavior/packed-union.zig @@ -177,6 +177,8 @@ test "assigning to non-active field at comptime" { } test "comptime packed union of pointers" { + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; + const U = packed union { a: *const u32, b: *const [1]u32, diff --git a/test/behavior/slice.zig b/test/behavior/slice.zig index 437d248127..a1f38b1dfe 100644 --- a/test/behavior/slice.zig +++ b/test/behavior/slice.zig @@ -408,6 +408,7 @@ test "slice syntax resulting in pointer-to-array" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { fn doTheTest() !void { @@ -863,6 +864,7 @@ test "global slice field access" { test "slice of void" { if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; var n: usize = 10; _ = &n; @@ -988,6 +990,7 @@ test "get address of element of zero-sized slice" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { fn destroy(_: *void) void {} diff --git a/test/behavior/string_literals.zig b/test/behavior/string_literals.zig index b1bb508503..898de2167c 100644 --- a/test/behavior/string_literals.zig +++ b/test/behavior/string_literals.zig @@ -35,6 +35,7 @@ test "@typeName() returns a string literal" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try std.testing.expect(*const [type_name.len:0]u8 == @TypeOf(type_name)); try std.testing.expect(std.mem.eql(u8, "behavior.string_literals.TestType", type_name)); @@ -49,6 +50,7 @@ test "@embedFile() returns a string literal" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try std.testing.expect(*const [expected_contents.len:0]u8 == @TypeOf(actual_contents)); try std.testing.expect(std.mem.eql(u8, expected_contents, actual_contents)); diff --git a/test/behavior/typename.zig b/test/behavior/typename.zig index e5ebbb6f47..b08de5484e 100644 --- a/test/behavior/typename.zig +++ b/test/behavior/typename.zig @@ -43,6 +43,7 @@ test "anon field init" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const Foo = .{ .T1 = struct {}, @@ -91,6 +92,7 @@ test "top level decl" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try expectEqualStrings( "behavior.typename.A_Struct", @@ -141,6 +143,7 @@ test "fn param" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // https://github.com/ziglang/zig/issues/675 try expectEqualStrings( @@ -221,6 +224,7 @@ test "local variable" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const Foo = struct { a: u32 }; const Bar = union { a: u32 }; @@ -250,6 +254,7 @@ test "anon name strategy used in sub expression" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { fn getTheName() []const u8 { diff --git a/test/behavior/union.zig b/test/behavior/union.zig index 62997d097a..004774bd17 100644 --- a/test/behavior/union.zig +++ b/test/behavior/union.zig @@ -920,6 +920,7 @@ test "union no tag with struct member" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const Struct = struct {}; const Union = union { diff --git a/test/behavior/vector.zig b/test/behavior/vector.zig index 688b36a911..8987e0c091 100644 --- a/test/behavior/vector.zig +++ b/test/behavior/vector.zig @@ -268,6 +268,7 @@ test "tuple to vector" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_llvm and builtin.cpu.arch == .aarch64) { // Regressed with LLVM 14: From 4bd9d9b7e0d769dd8d7701b73e54ae249ac7f1da Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Sun, 2 Jun 2024 15:57:18 +0200 Subject: [PATCH 4/7] spirv: change direct vector child repr to direct Previously the child type of a vector was always in indirect representation. Concretely, this meant that vectors of bools are represented by vectors of u8. This was undesirable because it introduced a difference between vectorizable operations with a scalar bool and a vector of bool. This commit changes the representation to be the same for vectors and scalars everywhere. Some issues arised with constructing vectors: it seems the previous temporary- and-pointer approach does not work properly with vectors of bool. To work around this, simply use OpCompositeConstruct. This is the proper instruction for this, but it was previously not used because of a now-solved limitation in the SPIRV-LLVM-Translator. It was not yet applied to Zig because the Intel OpenCL CPU runtime does not have a recent enough version of the translator yet, but to solve that we just switch to testing with POCL instead. --- src/codegen/spirv.zig | 272 ++++++++++++++++++++++++++++++++---------- 1 file changed, 208 insertions(+), 64 deletions(-) diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index ed04ee475b..0574b7ee9e 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -22,6 +22,7 @@ const IdResultType = spec.IdResultType; const StorageClass = spec.StorageClass; const SpvModule = @import("spirv/Module.zig"); +const IdRange = SpvModule.IdRange; const SpvSection = @import("spirv/Section.zig"); const SpvAssembler = @import("spirv/Assembler.zig"); @@ -32,7 +33,7 @@ pub const zig_call_abi_ver = 3; const InternMap = std.AutoHashMapUnmanaged(struct { InternPool.Index, DeclGen.Repr }, IdResult); const PtrTypeMap = std.AutoHashMapUnmanaged( - struct { InternPool.Index, StorageClass }, + struct { InternPool.Index, StorageClass, DeclGen.Repr }, struct { ty_id: IdRef, fwd_emitted: bool }, ); @@ -626,7 +627,7 @@ const DeclGen = struct { } /// Checks whether the type can be directly translated to SPIR-V vectors - fn isVector(self: *DeclGen, ty: Type) bool { + fn isSpvVector(self: *DeclGen, ty: Type) bool { const mod = self.module; const target = self.getTarget(); if (ty.zigTypeTag(mod) != .Vector) return false; @@ -798,26 +799,39 @@ const DeclGen = struct { /// Construct a vector at runtime. /// ty must be an vector type. - /// Constituents should be in `indirect` representation (as the elements of an vector should be). - /// Result is in `direct` representation. fn constructVector(self: *DeclGen, ty: Type, constituents: []const IdRef) !IdRef { - // The Khronos LLVM-SPIRV translator crashes because it cannot construct structs which' - // operands are not constant. - // See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1349 - // For now, just initialize the struct by setting the fields manually... - // TODO: Make this OpCompositeConstruct when we can const mod = self.module; - const ptr_composite_id = try self.alloc(ty, .{ .storage_class = .Function }); - const ptr_elem_ty_id = try self.ptrType(ty.elemType2(mod), .Function); - for (constituents, 0..) |constitent_id, index| { - const ptr_id = try self.accessChain(ptr_elem_ty_id, ptr_composite_id, &.{@as(u32, @intCast(index))}); - try self.func.body.emit(self.spv.gpa, .OpStore, .{ - .pointer = ptr_id, - .object = constitent_id, - }); - } + assert(ty.vectorLen(mod) == constituents.len); - return try self.load(ty, ptr_composite_id, .{}); + // Note: older versions of the Khronos SPRIV-LLVM translator crash on this instruction + // because it cannot construct structs which' operands are not constant. + // See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1349 + // Currently this is the case for Intel OpenCL CPU runtime (2023-WW46), but the + // alternatives dont work properly: + // - using temporaries/pointers doesn't work properly with vectors of bool, causes + // backends that use llvm to crash + // - using OpVectorInsertDynamic doesn't work for non-spirv-vectors of bool. + + const result_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpCompositeConstruct, .{ + .id_result_type = try self.resolveType(ty, .direct), + .id_result = result_id, + .constituents = constituents, + }); + return result_id; + } + + /// Construct a vector at runtime with all lanes set to the same value. + /// ty must be an vector type. + fn constructVectorSplat(self: *DeclGen, ty: Type, constituent: IdRef) !IdRef { + const mod = self.module; + const n = ty.vectorLen(mod); + + const constituents = try self.gpa.alloc(IdRef, n); + defer self.gpa.free(constituents); + @memset(constituents, constituent); + + return try self.constructVector(ty, constituents); } /// Construct an array at runtime. @@ -1031,21 +1045,27 @@ const DeclGen = struct { const constituents = try self.gpa.alloc(IdRef, @intCast(ty.arrayLenIncludingSentinel(mod))); defer self.gpa.free(constituents); + const child_repr: Repr = switch (tag) { + .array_type => .indirect, + .vector_type => .direct, + else => unreachable, + }; + switch (aggregate.storage) { .bytes => |bytes| { // TODO: This is really space inefficient, perhaps there is a better // way to do it? for (constituents, bytes.toSlice(constituents.len, ip)) |*constituent, byte| { - constituent.* = try self.constInt(elem_ty, byte, .indirect); + constituent.* = try self.constInt(elem_ty, byte, child_repr); } }, .elems => |elems| { for (constituents, elems) |*constituent, elem| { - constituent.* = try self.constant(elem_ty, Value.fromInterned(elem), .indirect); + constituent.* = try self.constant(elem_ty, Value.fromInterned(elem), child_repr); } }, .repeated_elem => |elem| { - @memset(constituents, try self.constant(elem_ty, Value.fromInterned(elem), .indirect)); + @memset(constituents, try self.constant(elem_ty, Value.fromInterned(elem), child_repr)); }, } @@ -1334,7 +1354,11 @@ const DeclGen = struct { } fn ptrType(self: *DeclGen, child_ty: Type, storage_class: StorageClass) !IdRef { - const key = .{ child_ty.toIntern(), storage_class }; + return try self.ptrType2(child_ty, storage_class, .indirect); + } + + fn ptrType2(self: *DeclGen, child_ty: Type, storage_class: StorageClass, child_repr: Repr) !IdRef { + const key = .{ child_ty.toIntern(), storage_class, child_repr }; const entry = try self.ptr_types.getOrPut(self.gpa, key); if (entry.found_existing) { const fwd_id = entry.value_ptr.ty_id; @@ -1354,7 +1378,7 @@ const DeclGen = struct { .fwd_emitted = false, }; - const child_ty_id = try self.resolveType(child_ty, .indirect); + const child_ty_id = try self.resolveType(child_ty, child_repr); try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpTypePointer, .{ .id_result = result_id, @@ -1645,11 +1669,10 @@ const DeclGen = struct { }, .Vector => { const elem_ty = ty.childType(mod); - // TODO: Make `.direct`. - const elem_ty_id = try self.resolveType(elem_ty, .indirect); + const elem_ty_id = try self.resolveType(elem_ty, repr); const len = ty.vectorLen(mod); - if (self.isVector(ty)) { + if (self.isSpvVector(ty)) { return try self.spv.vectorType(len, elem_ty_id); } else { return try self.arrayType(len, elem_ty_id); @@ -1948,7 +1971,7 @@ const DeclGen = struct { const mod = wip.dg.module; if (wip.is_array) { assert(ty.isVector(mod)); - return try wip.dg.extractField(ty.childType(mod), value, @intCast(index)); + return try wip.dg.extractVectorComponent(ty.childType(mod), value, @intCast(index)); } else { assert(index == 0); return value; @@ -1961,11 +1984,7 @@ const DeclGen = struct { /// Results is in `direct` representation. fn finalize(wip: *WipElementWise) !IdRef { if (wip.is_array) { - // Convert all the constituents to indirect, as required for the array. - for (wip.results) |*result| { - result.* = try wip.dg.convertToIndirect(wip.ty, result.*); - } - return try wip.dg.constructArray(wip.result_ty, wip.results); + return try wip.dg.constructVector(wip.result_ty, wip.results); } else { return wip.results[0]; } @@ -1982,7 +2001,7 @@ const DeclGen = struct { /// Create a new element-wise operation. fn elementWise(self: *DeclGen, result_ty: Type, force_element_wise: bool) !WipElementWise { const mod = self.module; - const is_array = result_ty.isVector(mod) and (!self.isVector(result_ty) or force_element_wise); + const is_array = result_ty.isVector(mod) and (!self.isSpvVector(result_ty) or force_element_wise); const num_results = if (is_array) result_ty.vectorLen(mod) else 1; const results = try self.gpa.alloc(IdRef, num_results); @memset(results, undefined); @@ -2253,29 +2272,102 @@ const DeclGen = struct { /// This converts the argument type from resolveType(ty, .indirect) to resolveType(ty, .direct). fn convertToDirect(self: *DeclGen, ty: Type, operand_id: IdRef) !IdRef { const mod = self.module; - return switch (ty.zigTypeTag(mod)) { - .Bool => blk: { - const result_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{ - .id_result_type = try self.resolveType(Type.bool, .direct), - .id_result = result_id, - .operand_1 = operand_id, - .operand_2 = try self.constBool(false, .indirect), - }); - break :blk result_id; + const scalar_ty = ty.scalarType(mod); + const is_spv_vector = self.isSpvVector(ty); + switch (scalar_ty.zigTypeTag(mod)) { + .Bool => { + // TODO: We may want to use something like elementWise in this function. + // First we need to audit whether this would recursively call into itself. + if (!ty.isVector(mod) or is_spv_vector) { + const result_id = self.spv.allocId(); + const scalar_false_id = try self.constBool(false, .indirect); + const false_id = if (is_spv_vector) blk: { + const index = try mod.intern_pool.get(mod.gpa, .{ + .vector_type = .{ + .len = ty.vectorLen(mod), + .child = Type.u1.toIntern(), + }, + }); + const vec_ty = Type.fromInterned(index); + break :blk try self.constructVectorSplat(vec_ty, scalar_false_id); + } else scalar_false_id; + + try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{ + .id_result_type = try self.resolveType(ty, .direct), + .id_result = result_id, + .operand_1 = operand_id, + .operand_2 = false_id, + }); + return result_id; + } + + const constituents = try self.gpa.alloc(IdRef, ty.vectorLen(mod)); + for (constituents, 0..) |*id, i| { + const element = try self.extractVectorComponent(scalar_ty, operand_id, @intCast(i)); + id.* = try self.convertToDirect(scalar_ty, element); + } + return try self.constructVector(ty, constituents); }, - else => operand_id, - }; + else => return operand_id, + } } /// Convert representation from direct (in 'register) to direct (in memory) /// This converts the argument type from resolveType(ty, .direct) to resolveType(ty, .indirect). fn convertToIndirect(self: *DeclGen, ty: Type, operand_id: IdRef) !IdRef { const mod = self.module; - return switch (ty.zigTypeTag(mod)) { - .Bool => try self.intFromBool(Type.u1, operand_id), - else => operand_id, - }; + const scalar_ty = ty.scalarType(mod); + const is_spv_vector = self.isSpvVector(ty); + switch (scalar_ty.zigTypeTag(mod)) { + .Bool => { + const result_ty = if (is_spv_vector) blk: { + const index = try mod.intern_pool.get(mod.gpa, .{ + .vector_type = .{ + .len = ty.vectorLen(mod), + .child = Type.u1.toIntern(), + }, + }); + break :blk Type.fromInterned(index); + } else Type.u1; + + if (!ty.isVector(mod) or is_spv_vector) { + // TODO: We may want to use something like elementWise in this function. + // First we need to audit whether this would recursively call into itself. + // Also unify it with intFromBool + + const scalar_zero_id = try self.constInt(Type.u1, 0, .direct); + const scalar_one_id = try self.constInt(Type.u1, 1, .direct); + + const zero_id = if (is_spv_vector) + try self.constructVectorSplat(result_ty, scalar_zero_id) + else + scalar_zero_id; + + const one_id = if (is_spv_vector) + try self.constructVectorSplat(result_ty, scalar_one_id) + else + scalar_one_id; + + const result_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpSelect, .{ + .id_result_type = try self.resolveType(result_ty, .direct), + .id_result = result_id, + .condition = operand_id, + .object_1 = one_id, + .object_2 = zero_id, + }); + return result_id; + } + + const constituents = try self.gpa.alloc(IdRef, ty.vectorLen(mod)); + for (constituents, 0..) |*id, i| { + const element = try self.extractVectorComponent(scalar_ty, operand_id, @intCast(i)); + id.* = try self.convertToIndirect(scalar_ty, element); + } + return try self.constructVector(result_ty, constituents); + }, + else => return operand_id, + } } fn extractField(self: *DeclGen, result_ty: Type, object: IdRef, field: u32) !IdRef { @@ -2292,6 +2384,21 @@ const DeclGen = struct { return try self.convertToDirect(result_ty, result_id); } + fn extractVectorComponent(self: *DeclGen, result_ty: Type, vector_id: IdRef, field: u32) !IdRef { + // Whether this is an OpTypeVector or OpTypeArray, we need to emit the same instruction regardless. + const result_ty_id = try self.resolveType(result_ty, .direct); + const result_id = self.spv.allocId(); + const indexes = [_]u32{field}; + try self.func.body.emit(self.spv.gpa, .OpCompositeExtract, .{ + .id_result_type = result_ty_id, + .id_result = result_id, + .composite = vector_id, + .indexes = &indexes, + }); + // Vector components are already stored in direct representation. + return result_id; + } + const MemoryOptions = struct { is_volatile: bool = false, }; @@ -2926,7 +3033,7 @@ const DeclGen = struct { const ov_ty = result_ty.structFieldType(1, self.module); const bool_ty_id = try self.resolveType(Type.bool, .direct); - const cmp_ty_id = if (self.isVector(operand_ty)) + const cmp_ty_id = if (self.isSpvVector(operand_ty)) // TODO: Resolving a vector type with .direct should return a SPIR-V vector try self.spv.vectorType(operand_ty.vectorLen(mod), try self.resolveType(Type.bool, .direct)) else @@ -3100,7 +3207,7 @@ const DeclGen = struct { const ov_ty = result_ty.structFieldType(1, self.module); const bool_ty_id = try self.resolveType(Type.bool, .direct); - const cmp_ty_id = if (self.isVector(operand_ty)) + const cmp_ty_id = if (self.isSpvVector(operand_ty)) // TODO: Resolving a vector type with .direct should return a SPIR-V vector try self.spv.vectorType(operand_ty.vectorLen(mod), try self.resolveType(Type.bool, .direct)) else @@ -3312,7 +3419,7 @@ const DeclGen = struct { const info = self.arithmeticTypeInfo(operand_ty); - var result_id = try self.extractField(scalar_ty, operand, 0); + var result_id = try self.extractVectorComponent(scalar_ty, operand, 0); const len = operand_ty.vectorLen(mod); switch (reduce.operation) { @@ -3320,7 +3427,7 @@ const DeclGen = struct { const cmp_op: std.math.CompareOperator = if (op == .Max) .gt else .lt; for (1..len) |i| { const lhs = result_id; - const rhs = try self.extractField(scalar_ty, operand, @intCast(i)); + const rhs = try self.extractVectorComponent(scalar_ty, operand, @intCast(i)); result_id = try self.minMax(scalar_ty, cmp_op, lhs, rhs); } @@ -3354,7 +3461,7 @@ const DeclGen = struct { for (1..len) |i| { const lhs = result_id; - const rhs = try self.extractField(scalar_ty, operand, @intCast(i)); + const rhs = try self.extractVectorComponent(scalar_ty, operand, @intCast(i)); result_id = self.spv.allocId(); try self.func.body.emitRaw(self.spv.gpa, opcode, 4); @@ -3388,9 +3495,9 @@ const DeclGen = struct { const index = elem.toSignedInt(mod); if (index >= 0) { - result_id.* = try self.extractField(wip.ty, a, @intCast(index)); + result_id.* = try self.extractVectorComponent(wip.ty, a, @intCast(index)); } else { - result_id.* = try self.extractField(wip.ty, b, @intCast(~index)); + result_id.* = try self.extractVectorComponent(wip.ty, b, @intCast(~index)); } } return try wip.finalize(); @@ -4086,8 +4193,7 @@ const DeclGen = struct { defer self.gpa.free(elem_ids); for (elements, 0..) |element, i| { - const id = try self.resolve(element); - elem_ids[i] = try self.convertToIndirect(result_ty.childType(mod), id); + elem_ids[i] = try self.resolve(element); } return try self.constructVector(result_ty, elem_ids); @@ -4234,16 +4340,54 @@ const DeclGen = struct { const array_id = try self.resolve(bin_op.lhs); const index_id = try self.resolve(bin_op.rhs); + if (self.isSpvVector(array_ty)) { + const result_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpVectorExtractDynamic, .{ + .id_result_type = try self.resolveType(elem_ty, .direct), + .id_result = result_id, + .vector = array_id, + .index = index_id, + }); + return result_id; + } + // SPIR-V doesn't have an array indexing function for some damn reason. // For now, just generate a temporary and use that. // TODO: This backend probably also should use isByRef from llvm... - const elem_ptr_ty_id = try self.ptrType(elem_ty, .Function); + const ptr_array_ty_id = try self.ptrType2(array_ty, .Function, .direct); + const ptr_elem_ty_id = try self.ptrType2(elem_ty, .Function, .direct); - const tmp_id = try self.alloc(array_ty, .{ .storage_class = .Function }); - try self.store(array_ty, tmp_id, array_id, .{}); - const elem_ptr_id = try self.accessChainId(elem_ptr_ty_id, tmp_id, &.{index_id}); - return try self.load(elem_ty, elem_ptr_id, .{}); + const tmp_id = self.spv.allocId(); + try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{ + .id_result_type = ptr_array_ty_id, + .id_result = tmp_id, + .storage_class = .Function, + }); + + try self.func.body.emit(self.spv.gpa, .OpStore, .{ + .pointer = tmp_id, + .object = array_id, + }); + + const elem_ptr_id = try self.accessChainId(ptr_elem_ty_id, tmp_id, &.{index_id}); + + const result_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpLoad, .{ + .id_result_type = try self.resolveType(elem_ty, .direct), + .id_result = result_id, + .pointer = elem_ptr_id, + }); + + if (array_ty.isVector(mod)) { + // Result is already in direct representation + return result_id; + } + + // This is an array type; the elements are stored in indirect representation. + // We have to convert the type to direct. + + return try self.convertToDirect(elem_ty, result_id); } fn airPtrElemVal(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { From 4e7159ae1d08ce74548e0adc3b3936aacc23a06e Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Sun, 2 Jun 2024 16:09:20 +0200 Subject: [PATCH 5/7] spirv: remove OpCompositeConstruct workarounds MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Now that we use POCL to test, we no longer need this ✨ --- src/codegen/spirv.zig | 47 ++++++++++++++----------------------------- 1 file changed, 15 insertions(+), 32 deletions(-) diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 0574b7ee9e..1cf67357f0 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -780,21 +780,14 @@ const DeclGen = struct { /// Result is in `direct` representation. fn constructStruct(self: *DeclGen, ty: Type, types: []const Type, constituents: []const IdRef) !IdRef { assert(types.len == constituents.len); - // The Khronos LLVM-SPIRV translator crashes because it cannot construct structs which' - // operands are not constant. - // See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1349 - // For now, just initialize the struct by setting the fields manually... - // TODO: Make this OpCompositeConstruct when we can - const ptr_composite_id = try self.alloc(ty, .{ .storage_class = .Function }); - for (constituents, types, 0..) |constitent_id, member_ty, index| { - const ptr_member_ty_id = try self.ptrType(member_ty, .Function); - const ptr_id = try self.accessChain(ptr_member_ty_id, ptr_composite_id, &.{@as(u32, @intCast(index))}); - try self.func.body.emit(self.spv.gpa, .OpStore, .{ - .pointer = ptr_id, - .object = constitent_id, - }); - } - return try self.load(ty, ptr_composite_id, .{}); + + const result_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpCompositeConstruct, .{ + .id_result_type = try self.resolveType(ty, .direct), + .id_result = result_id, + .constituents = constituents, + }); + return result_id; } /// Construct a vector at runtime. @@ -839,23 +832,13 @@ const DeclGen = struct { /// Constituents should be in `indirect` representation (as the elements of an array should be). /// Result is in `direct` representation. fn constructArray(self: *DeclGen, ty: Type, constituents: []const IdRef) !IdRef { - // The Khronos LLVM-SPIRV translator crashes because it cannot construct structs which' - // operands are not constant. - // See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1349 - // For now, just initialize the struct by setting the fields manually... - // TODO: Make this OpCompositeConstruct when we can - const mod = self.module; - const ptr_composite_id = try self.alloc(ty, .{ .storage_class = .Function }); - const ptr_elem_ty_id = try self.ptrType(ty.elemType2(mod), .Function); - for (constituents, 0..) |constitent_id, index| { - const ptr_id = try self.accessChain(ptr_elem_ty_id, ptr_composite_id, &.{@as(u32, @intCast(index))}); - try self.func.body.emit(self.spv.gpa, .OpStore, .{ - .pointer = ptr_id, - .object = constitent_id, - }); - } - - return try self.load(ty, ptr_composite_id, .{}); + const result_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpCompositeConstruct, .{ + .id_result_type = try self.resolveType(ty, .direct), + .id_result = result_id, + .constituents = constituents, + }); + return result_id; } /// This function generates a load for a constant in direct (ie, non-memory) representation. From a3b1ba82f57d5d8981a471850cbbb0db29c3a479 Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Mon, 3 Jun 2024 00:44:08 +0200 Subject: [PATCH 6/7] spirv: new vectorization helper The old vectorization helper (WipElementWise) was clunky and a bit annoying to use, and it wasn't really flexible enough. This introduces a new vectorization helper, which uses Temporary and Operation types to deduce a Vectorization to perform the operation in a reasonably efficient manner. It removes the outer loop required by WipElementWise so that implementations of AIR instructions are cleaner. This helps with sanity when we start to introduce support for composite integers. airShift, convertToDirect, convertToIndirect, and normalize are initially implemented using this new method. --- src/codegen/spirv.zig | 2591 +++++++++++++++++++++------------- src/codegen/spirv/Module.zig | 22 +- test/behavior/abs.zig | 1 - test/behavior/floatop.zig | 34 - test/behavior/math.zig | 59 +- test/behavior/select.zig | 2 - test/behavior/vector.zig | 1 - 7 files changed, 1645 insertions(+), 1065 deletions(-) diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 1cf67357f0..215a9421f1 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -3,6 +3,7 @@ const Allocator = std.mem.Allocator; const Target = std.Target; const log = std.log.scoped(.codegen); const assert = std.debug.assert; +const Signedness = std.builtin.Signedness; const Module = @import("../Module.zig"); const Decl = Module.Decl; @@ -423,6 +424,17 @@ const DeclGen = struct { return self.fail("TODO (SPIR-V): " ++ format, args); } + /// This imports the "default" extended instruction set for the target + /// For OpenCL, OpenCL.std.100. For Vulkan, GLSL.std.450. + fn importExtendedSet(self: *DeclGen) !IdResult { + const target = self.getTarget(); + return switch (target.os.tag) { + .opencl => try self.spv.importInstructionSet(.@"OpenCL.std"), + .vulkan => try self.spv.importInstructionSet(.@"GLSL.std.450"), + else => unreachable, + }; + } + /// Fetch the result-id for a previously generated instruction or constant. fn resolve(self: *DeclGen, inst: Air.Inst.Ref) !IdRef { const mod = self.module; @@ -631,6 +643,19 @@ const DeclGen = struct { const mod = self.module; const target = self.getTarget(); if (ty.zigTypeTag(mod) != .Vector) return false; + + // TODO: This check must be expanded for types that can be represented + // as integers (enums / packed structs?) and types that are represented + // by multiple SPIR-V values. + const scalar_ty = ty.scalarType(mod); + switch (scalar_ty.zigTypeTag(mod)) { + .Bool, + .Int, + .Float, + => {}, + else => return false, + } + const elem_ty = ty.childType(mod); const len = ty.vectorLen(mod); @@ -723,9 +748,13 @@ const DeclGen = struct { // Use backing bits so that negatives are sign extended const backing_bits = self.backingIntBits(int_info.bits).?; // Assertion failure means big int - const bits: u64 = switch (int_info.signedness) { - // Intcast needed to silence compile errors for when the wrong path is compiled. - // Lazy fix. + const signedness: Signedness = switch (@typeInfo(@TypeOf(value))) { + .Int => |int| int.signedness, + .ComptimeInt => if (value < 0) .signed else .unsigned, + else => unreachable, + }; + + const bits: u64 = switch (signedness) { .signed => @bitCast(@as(i64, @intCast(value))), .unsigned => @as(u64, @intCast(value)), }; @@ -1392,6 +1421,19 @@ const DeclGen = struct { return ty_id; } + fn zigScalarOrVectorTypeLike(self: *DeclGen, new_ty: Type, base_ty: Type) !Type { + const mod = self.module; + const new_scalar_ty = new_ty.scalarType(mod); + if (!base_ty.isVector(mod)) { + return new_scalar_ty; + } + + return try mod.vectorType(.{ + .len = base_ty.vectorLen(mod), + .child = new_scalar_ty.toIntern(), + }); + } + /// Generate a union type. Union types are always generated with the /// most aligned field active. If the tag alignment is greater /// than that of the payload, a regular union (non-packed, with both tag and @@ -1928,77 +1970,897 @@ const DeclGen = struct { return union_layout; } - /// This structure is used as helper for element-wise operations. It is intended - /// to be used with vectors, fake vectors (arrays) and single elements. - const WipElementWise = struct { - dg: *DeclGen, - result_ty: Type, + /// This structure represents a "temporary" value: Something we are currently + /// operating on. It typically lives no longer than the function that + /// implements a particular AIR operation. These are used to easier + /// implement vectorizable operations (see Vectorization and the build* + /// functions), and typically are only used for vectors of primitive types. + const Temporary = struct { + /// The type of the temporary. This is here mainly + /// for easier bookkeeping. Because we will never really + /// store Temporaries, they only cause extra stack space, + /// therefore no real storage is wasted. ty: Type, - /// Always in direct representation. - ty_id: IdRef, - /// True if the input is an array type. - is_array: bool, - /// The element-wise operation should fill these results before calling finalize(). - /// These should all be in **direct** representation! `finalize()` will convert - /// them to indirect if required. - results: []IdRef, + /// The value that this temporary holds. This is not necessarily + /// a value that is actually usable, or a single value: It is virtual + /// until materialize() is called, at which point is turned into + /// the usual SPIR-V representation of `self.ty`. + value: Temporary.Value, - fn deinit(wip: *WipElementWise) void { - wip.dg.gpa.free(wip.results); + const Value = union(enum) { + singleton: IdResult, + exploded_vector: IdRange, + }; + + fn init(ty: Type, singleton: IdResult) Temporary { + return .{ .ty = ty, .value = .{ .singleton = singleton } }; } - /// Utility function to extract the element at a particular index in an - /// input array. This type is expected to be a fake vector (array) if `wip.is_array`, and - /// a vector or scalar otherwise. - fn elementAt(wip: WipElementWise, ty: Type, value: IdRef, index: usize) !IdRef { - const mod = wip.dg.module; - if (wip.is_array) { - assert(ty.isVector(mod)); - return try wip.dg.extractVectorComponent(ty.childType(mod), value, @intCast(index)); - } else { - assert(index == 0); - return value; + fn materialize(self: Temporary, dg: *DeclGen) !IdResult { + const mod = dg.module; + switch (self.value) { + .singleton => |id| return id, + .exploded_vector => |range| { + assert(self.ty.isVector(mod)); + assert(self.ty.vectorLen(mod) == range.len); + const consituents = try dg.gpa.alloc(IdRef, range.len); + defer dg.gpa.free(consituents); + for (consituents, 0..range.len) |*id, i| { + id.* = range.at(i); + } + return dg.constructVector(self.ty, consituents); + }, } } - /// Turns the results of this WipElementWise into a result. This can be - /// vectors, fake vectors (arrays) and single elements, depending on `result_ty`. - /// After calling this function, this WIP is no longer usable. - /// Results is in `direct` representation. - fn finalize(wip: *WipElementWise) !IdRef { - if (wip.is_array) { - return try wip.dg.constructVector(wip.result_ty, wip.results); - } else { - return wip.results[0]; - } + fn vectorization(self: Temporary, dg: *DeclGen) Vectorization { + return Vectorization.fromType(self.ty, dg); } - /// Allocate a result id at a particular index, and return it. - fn allocId(wip: *WipElementWise, index: usize) IdRef { - assert(wip.is_array or index == 0); - wip.results[index] = wip.dg.spv.allocId(); - return wip.results[index]; + fn pun(self: Temporary, new_ty: Type) Temporary { + return .{ + .ty = new_ty, + .value = self.value, + }; + } + + /// 'Explode' a temporary into separate elements. This turns a vector + /// into a bag of elements. + fn explode(self: Temporary, dg: *DeclGen) !IdRange { + const mod = dg.module; + + // If the value is a scalar, then this is a no-op. + if (!self.ty.isVector(mod)) { + return switch (self.value) { + .singleton => |id| IdRange{ .base = @intFromEnum(id), .len = 1 }, + .exploded_vector => |range| range, + }; + } + + const ty_id = try dg.resolveType(self.ty.scalarType(mod), .direct); + const n = self.ty.vectorLen(mod); + const results = dg.spv.allocIds(n); + + const id = switch (self.value) { + .singleton => |id| id, + .exploded_vector => |range| return range, + }; + + for (0..n) |i| { + const indexes = [_]u32{@intCast(i)}; + try dg.func.body.emit(dg.spv.gpa, .OpCompositeExtract, .{ + .id_result_type = ty_id, + .id_result = results.at(i), + .composite = id, + .indexes = &indexes, + }); + } + + return results; } }; - /// Create a new element-wise operation. - fn elementWise(self: *DeclGen, result_ty: Type, force_element_wise: bool) !WipElementWise { - const mod = self.module; - const is_array = result_ty.isVector(mod) and (!self.isSpvVector(result_ty) or force_element_wise); - const num_results = if (is_array) result_ty.vectorLen(mod) else 1; - const results = try self.gpa.alloc(IdRef, num_results); - @memset(results, undefined); - - const ty = if (is_array) result_ty.scalarType(mod) else result_ty; - const ty_id = try self.resolveType(ty, .direct); - + /// Initialize a `Temporary` from an AIR value. + fn temporary(self: *DeclGen, inst: Air.Inst.Ref) !Temporary { return .{ - .dg = self, - .result_ty = result_ty, - .ty = ty, - .ty_id = ty_id, - .is_array = is_array, - .results = results, + .ty = self.typeOf(inst), + .value = .{ .singleton = try self.resolve(inst) }, + }; + } + + /// This union describes how a particular operation should be vectorized. + /// That depends on the operation and number of components of the inputs. + const Vectorization = union(enum) { + /// This is an operation between scalars. + scalar, + /// This is an operation between SPIR-V vectors. + /// Value is number of components. + spv_vectorized: u32, + /// This operation is unrolled into separate operations. + /// Inputs may still be SPIR-V vectors, for example, + /// when the operation can't be vectorized in SPIR-V. + /// Value is number of components. + unrolled: u32, + + /// Derive a vectorization from a particular type. This usually + /// only checks the size, but the source-of-truth is implemented + /// by `isSpvVector()`. + fn fromType(ty: Type, dg: *DeclGen) Vectorization { + const mod = dg.module; + if (!ty.isVector(mod)) { + return .scalar; + } else if (dg.isSpvVector(ty)) { + return .{ .spv_vectorized = ty.vectorLen(mod) }; + } else { + return .{ .unrolled = ty.vectorLen(mod) }; + } + } + + /// Given two vectorization methods, compute a "unification": a fallback + /// that works for both, according to the following rules: + /// - Scalars may broadcast + /// - SPIR-V vectorized operations may unroll + /// - Prefer scalar > SPIR-V vectorized > unrolled + fn unify(a: Vectorization, b: Vectorization) Vectorization { + if (a == .scalar and b == .scalar) { + return .scalar; + } else if (a == .spv_vectorized and b == .spv_vectorized) { + assert(a.components() == b.components()); + return .{ .spv_vectorized = a.components() }; + } else if (a == .unrolled or b == .unrolled) { + if (a == .unrolled and b == .unrolled) { + assert(a.components() == b.components()); + return .{ .unrolled = a.components() }; + } else if (a == .unrolled) { + return .{ .unrolled = a.components() }; + } else if (b == .unrolled) { + return .{ .unrolled = b.components() }; + } else { + unreachable; + } + } else { + if (a == .spv_vectorized) { + return .{ .spv_vectorized = a.components() }; + } else if (b == .spv_vectorized) { + return .{ .spv_vectorized = b.components() }; + } else { + unreachable; + } + } + } + + /// Force this vectorization to be unrolled, if its + /// an operation involving vectors. + fn unroll(self: Vectorization) Vectorization { + return switch (self) { + .scalar, .unrolled => self, + .spv_vectorized => |n| .{ .unrolled = n }, + }; + } + + /// Query the number of components that inputs of this operation have. + /// Note: for broadcasting scalars, this returns the number of elements + /// that the broadcasted vector would have. + fn components(self: Vectorization) u32 { + return switch (self) { + .scalar => 1, + .spv_vectorized => |n| n, + .unrolled => |n| n, + }; + } + + /// Query the number of operations involving this vectorization. + /// This is basically the number of components, except that SPIR-V vectorized + /// operations only need a single SPIR-V instruction. + fn operations(self: Vectorization) u32 { + return switch (self) { + .scalar, .spv_vectorized => 1, + .unrolled => |n| n, + }; + } + + /// Turns `ty` into the result-type of an individual vector operation. + /// `ty` may be a scalar or vector, it doesn't matter. + fn operationType(self: Vectorization, dg: *DeclGen, ty: Type) !Type { + const mod = dg.module; + const scalar_ty = ty.scalarType(mod); + return switch (self) { + .scalar, .unrolled => scalar_ty, + .spv_vectorized => |n| try mod.vectorType(.{ + .len = n, + .child = scalar_ty.toIntern(), + }), + }; + } + + /// Turns `ty` into the result-type of the entire operation. + /// `ty` may be a scalar or vector, it doesn't matter. + fn resultType(self: Vectorization, dg: *DeclGen, ty: Type) !Type { + const mod = dg.module; + const scalar_ty = ty.scalarType(mod); + return switch (self) { + .scalar => scalar_ty, + .unrolled, .spv_vectorized => |n| try mod.vectorType(.{ + .len = n, + .child = scalar_ty.toIntern(), + }), + }; + } + + /// Before a temporary can be used, some setup may need to be one. This function implements + /// this setup, and returns a new type that holds the relevant information on how to access + /// elements of the input. + fn prepare(self: Vectorization, dg: *DeclGen, tmp: Temporary) !PreparedOperand { + const mod = dg.module; + const is_vector = tmp.ty.isVector(mod); + const is_spv_vector = dg.isSpvVector(tmp.ty); + const value: PreparedOperand.Value = switch (tmp.value) { + .singleton => |id| switch (self) { + .scalar => blk: { + assert(!is_vector); + break :blk .{ .scalar = id }; + }, + .spv_vectorized => blk: { + if (is_vector) { + assert(is_spv_vector); + break :blk .{ .spv_vectorwise = id }; + } + + // Broadcast scalar into vector. + const vector_ty = try mod.vectorType(.{ + .len = self.components(), + .child = tmp.ty.toIntern(), + }); + + const vector = try dg.constructVectorSplat(vector_ty, id); + return .{ + .ty = vector_ty, + .value = .{ .spv_vectorwise = vector }, + }; + }, + .unrolled => blk: { + if (is_vector) { + break :blk .{ .vector_exploded = try tmp.explode(dg) }; + } else { + break :blk .{ .scalar_broadcast = id }; + } + }, + }, + .exploded_vector => |range| switch (self) { + .scalar => unreachable, + .spv_vectorized => |n| blk: { + // We can vectorize this operation, but we have an exploded vector. This can happen + // when a vectorizable operation succeeds a non-vectorizable operation. In this case, + // pack up the IDs into a SPIR-V vector. This path should not be able to be hit with + // a type that cannot do that. + assert(is_spv_vector); + assert(range.len == n); + const vec = try tmp.materialize(dg); + break :blk .{ .spv_vectorwise = vec }; + }, + .unrolled => |n| blk: { + assert(range.len == n); + break :blk .{ .vector_exploded = range }; + }, + }, + }; + + return .{ + .ty = tmp.ty, + .value = value, + }; + } + + /// Finalize the results of an operation back into a temporary. `results` is + /// a list of result-ids of the operation. + fn finalize(self: Vectorization, ty: Type, results: IdRange) Temporary { + assert(self.operations() == results.len); + const value: Temporary.Value = switch (self) { + .scalar, .spv_vectorized => blk: { + break :blk .{ .singleton = results.at(0) }; + }, + .unrolled => blk: { + break :blk .{ .exploded_vector = results }; + }, + }; + + return .{ .ty = ty, .value = value }; + } + + /// This struct represents an operand that has gone through some setup, and is + /// ready to be used as part of an operation. + const PreparedOperand = struct { + ty: Type, + value: PreparedOperand.Value, + + /// The types of value that a prepared operand can hold internally. Depends + /// on the operation and input value. + const Value = union(enum) { + /// A single scalar value that is used by a scalar operation. + scalar: IdResult, + /// A single scalar that is broadcasted in an unrolled operation. + scalar_broadcast: IdResult, + /// A SPIR-V vector that is used in SPIR-V vectorize operation. + spv_vectorwise: IdResult, + /// A vector represented by a consecutive list of IDs that is used in an unrolled operation. + vector_exploded: IdRange, + }; + + /// Query the value at a particular index of the operation. Note that + /// the index is *not* the component/lane, but the index of the *operation*. When + /// this operation is vectorized, the return value of this function is a SPIR-V vector. + /// See also `Vectorization.operations()`. + fn at(self: PreparedOperand, i: usize) IdResult { + switch (self.value) { + .scalar => |id| { + assert(i == 0); + return id; + }, + .scalar_broadcast => |id| { + return id; + }, + .spv_vectorwise => |id| { + assert(i == 0); + return id; + }, + .vector_exploded => |range| { + return range.at(i); + }, + } + } + }; + }; + + /// A utility function to compute the vectorization style of + /// a list of values. These values may be any of the following: + /// - A `Vectorization` instance + /// - A Type, in which case the vectorization is computed via `Vectorization.fromType`. + /// - A Temporary, in which case the vectorization is computed via `Temporary.vectorization`. + fn vectorization(self: *DeclGen, args: anytype) Vectorization { + var v: Vectorization = undefined; + assert(args.len >= 1); + inline for (args, 0..) |arg, i| { + const iv: Vectorization = switch (@TypeOf(arg)) { + Vectorization => arg, + Type => Vectorization.fromType(arg, self), + Temporary => arg.vectorization(self), + else => @compileError("invalid type"), + }; + if (i == 0) { + v = iv; + } else { + v = v.unify(iv); + } + } + return v; + } + + /// This function builds an OpSConvert of OpUConvert depending on the + /// signedness of the types. + fn buildIntConvert(self: *DeclGen, dst_ty: Type, src: Temporary) !Temporary { + const mod = self.module; + + const dst_ty_id = try self.resolveType(dst_ty.scalarType(mod), .direct); + const src_ty_id = try self.resolveType(src.ty.scalarType(mod), .direct); + + const v = self.vectorization(.{ dst_ty, src }); + const result_ty = try v.resultType(self, dst_ty); + + // We can directly compare integers, because those type-IDs are cached. + if (dst_ty_id == src_ty_id) { + // Nothing to do, type-pun to the right value. + // Note, Caller guarantees that the types fit (or caller will normalize after), + // so we don't have to normalize here. + // Note, dst_ty may be a scalar type even if we expect a vector, so we have to + // convert to the right type here. + return src.pun(result_ty); + } + + const ops = v.operations(); + const results = self.spv.allocIds(ops); + + const op_result_ty = try v.operationType(self, dst_ty); + const op_result_ty_id = try self.resolveType(op_result_ty, .direct); + + const opcode: Opcode = if (dst_ty.isSignedInt(mod)) .OpSConvert else .OpUConvert; + + const op_src = try v.prepare(self, src); + + for (0..ops) |i| { + try self.func.body.emitRaw(self.spv.gpa, opcode, 3); + self.func.body.writeOperand(spec.IdResultType, op_result_ty_id); + self.func.body.writeOperand(IdResult, results.at(i)); + self.func.body.writeOperand(IdResult, op_src.at(i)); + } + + return v.finalize(result_ty, results); + } + + fn buildFma(self: *DeclGen, a: Temporary, b: Temporary, c: Temporary) !Temporary { + const target = self.getTarget(); + + const v = self.vectorization(.{ a, b, c }); + const ops = v.operations(); + const results = self.spv.allocIds(ops); + + const op_result_ty = try v.operationType(self, a.ty); + const op_result_ty_id = try self.resolveType(op_result_ty, .direct); + const result_ty = try v.resultType(self, a.ty); + + const op_a = try v.prepare(self, a); + const op_b = try v.prepare(self, b); + const op_c = try v.prepare(self, c); + + const set = try self.importExtendedSet(); + + // TODO: Put these numbers in some definition + const instruction: u32 = switch (target.os.tag) { + .opencl => 26, // fma + // NOTE: Vulkan's FMA instruction does *NOT* produce the right values! + // its precision guarantees do NOT match zigs and it does NOT match OpenCLs! + // it needs to be emulated! + .vulkan => unreachable, // TODO: See above + else => unreachable, + }; + + for (0..ops) |i| { + try self.func.body.emit(self.spv.gpa, .OpExtInst, .{ + .id_result_type = op_result_ty_id, + .id_result = results.at(i), + .set = set, + .instruction = .{ .inst = instruction }, + .id_ref_4 = &.{ op_a.at(i), op_b.at(i), op_c.at(i) }, + }); + } + + return v.finalize(result_ty, results); + } + + fn buildSelect(self: *DeclGen, condition: Temporary, lhs: Temporary, rhs: Temporary) !Temporary { + const mod = self.module; + + const v = self.vectorization(.{ condition, lhs, rhs }); + const ops = v.operations(); + const results = self.spv.allocIds(ops); + + const op_result_ty = try v.operationType(self, lhs.ty); + const op_result_ty_id = try self.resolveType(op_result_ty, .direct); + const result_ty = try v.resultType(self, lhs.ty); + + assert(condition.ty.scalarType(mod).zigTypeTag(mod) == .Bool); + + const cond = try v.prepare(self, condition); + const object_1 = try v.prepare(self, lhs); + const object_2 = try v.prepare(self, rhs); + + for (0..ops) |i| { + try self.func.body.emit(self.spv.gpa, .OpSelect, .{ + .id_result_type = op_result_ty_id, + .id_result = results.at(i), + .condition = cond.at(i), + .object_1 = object_1.at(i), + .object_2 = object_2.at(i), + }); + } + + return v.finalize(result_ty, results); + } + + const CmpPredicate = enum { + l_eq, + l_ne, + i_ne, + i_eq, + s_lt, + s_gt, + s_le, + s_ge, + u_lt, + u_gt, + u_le, + u_ge, + f_oeq, + f_une, + f_olt, + f_ole, + f_ogt, + f_oge, + }; + + fn buildCmp(self: *DeclGen, pred: CmpPredicate, lhs: Temporary, rhs: Temporary) !Temporary { + const v = self.vectorization(.{ lhs, rhs }); + const ops = v.operations(); + const results = self.spv.allocIds(ops); + + const op_result_ty = try v.operationType(self, Type.bool); + const op_result_ty_id = try self.resolveType(op_result_ty, .direct); + const result_ty = try v.resultType(self, Type.bool); + + const op_lhs = try v.prepare(self, lhs); + const op_rhs = try v.prepare(self, rhs); + + const opcode: Opcode = switch (pred) { + .l_eq => .OpLogicalEqual, + .l_ne => .OpLogicalNotEqual, + .i_eq => .OpIEqual, + .i_ne => .OpINotEqual, + .s_lt => .OpSLessThan, + .s_gt => .OpSGreaterThan, + .s_le => .OpSLessThanEqual, + .s_ge => .OpSGreaterThanEqual, + .u_lt => .OpULessThan, + .u_gt => .OpUGreaterThan, + .u_le => .OpULessThanEqual, + .u_ge => .OpUGreaterThanEqual, + .f_oeq => .OpFOrdEqual, + .f_une => .OpFUnordNotEqual, + .f_olt => .OpFOrdLessThan, + .f_ole => .OpFOrdLessThanEqual, + .f_ogt => .OpFOrdGreaterThan, + .f_oge => .OpFOrdGreaterThanEqual, + }; + + for (0..ops) |i| { + try self.func.body.emitRaw(self.spv.gpa, opcode, 4); + self.func.body.writeOperand(spec.IdResultType, op_result_ty_id); + self.func.body.writeOperand(IdResult, results.at(i)); + self.func.body.writeOperand(IdResult, op_lhs.at(i)); + self.func.body.writeOperand(IdResult, op_rhs.at(i)); + } + + return v.finalize(result_ty, results); + } + + const UnaryOp = enum { + l_not, + bit_not, + i_neg, + f_neg, + i_abs, + f_abs, + clz, + ctz, + floor, + ceil, + trunc, + round, + sqrt, + sin, + cos, + tan, + exp, + exp2, + log, + log2, + log10, + }; + + fn buildUnary(self: *DeclGen, op: UnaryOp, operand: Temporary) !Temporary { + const target = self.getTarget(); + const v = blk: { + const v = self.vectorization(.{operand}); + break :blk switch (op) { + // TODO: These instructions don't seem to be working + // properly for LLVM-based backends on OpenCL for 8- and + // 16-component vectors. + .i_abs => if (target.os.tag == .opencl and v.components() >= 8) v.unroll() else v, + else => v, + }; + }; + + const ops = v.operations(); + const results = self.spv.allocIds(ops); + + const op_result_ty = try v.operationType(self, operand.ty); + const op_result_ty_id = try self.resolveType(op_result_ty, .direct); + const result_ty = try v.resultType(self, operand.ty); + + const op_operand = try v.prepare(self, operand); + + if (switch (op) { + .l_not => .OpLogicalNot, + .bit_not => .OpNot, + .i_neg => .OpSNegate, + .f_neg => .OpFNegate, + else => @as(?Opcode, null), + }) |opcode| { + for (0..ops) |i| { + try self.func.body.emitRaw(self.spv.gpa, opcode, 3); + self.func.body.writeOperand(spec.IdResultType, op_result_ty_id); + self.func.body.writeOperand(IdResult, results.at(i)); + self.func.body.writeOperand(IdResult, op_operand.at(i)); + } + } else { + const set = try self.importExtendedSet(); + const extinst: u32 = switch (target.os.tag) { + .opencl => switch (op) { + .i_abs => 141, // s_abs + .f_abs => 23, // fabs + .clz => 151, // clz + .ctz => 152, // ctz + .floor => 25, // floor + .ceil => 12, // ceil + .trunc => 66, // trunc + .round => 55, // round + .sqrt => 61, // sqrt + .sin => 57, // sin + .cos => 14, // cos + .tan => 62, // tan + .exp => 19, // exp + .exp2 => 20, // exp2 + .log => 37, // log + .log2 => 38, // log2 + .log10 => 39, // log10 + else => unreachable, + }, + // Note: We'll need to check these for floating point accuracy + // Vulkan does not put tight requirements on these, for correction + // we might want to emulate them at some point. + .vulkan => switch (op) { + .i_abs => 5, // SAbs + .f_abs => 4, // FAbs + .clz => unreachable, // TODO + .ctz => unreachable, // TODO + .floor => 8, // Floor + .ceil => 9, // Ceil + .trunc => 3, // Trunc + .round => 1, // Round + .sqrt, + .sin, + .cos, + .tan, + .exp, + .exp2, + .log, + .log2, + .log10, + => unreachable, // TODO + else => unreachable, + }, + else => unreachable, + }; + + for (0..ops) |i| { + try self.func.body.emit(self.spv.gpa, .OpExtInst, .{ + .id_result_type = op_result_ty_id, + .id_result = results.at(i), + .set = set, + .instruction = .{ .inst = extinst }, + .id_ref_4 = &.{op_operand.at(i)}, + }); + } + } + + return v.finalize(result_ty, results); + } + + const BinaryOp = enum { + i_add, + f_add, + i_sub, + f_sub, + i_mul, + f_mul, + s_div, + u_div, + f_div, + s_rem, + f_rem, + s_mod, + u_mod, + f_mod, + srl, + sra, + sll, + bit_and, + bit_or, + bit_xor, + f_max, + s_max, + u_max, + f_min, + s_min, + u_min, + l_and, + l_or, + }; + + fn buildBinary(self: *DeclGen, op: BinaryOp, lhs: Temporary, rhs: Temporary) !Temporary { + const target = self.getTarget(); + + const v = self.vectorization(.{ lhs, rhs }); + const ops = v.operations(); + const results = self.spv.allocIds(ops); + + const op_result_ty = try v.operationType(self, lhs.ty); + const op_result_ty_id = try self.resolveType(op_result_ty, .direct); + const result_ty = try v.resultType(self, lhs.ty); + + const op_lhs = try v.prepare(self, lhs); + const op_rhs = try v.prepare(self, rhs); + + if (switch (op) { + .i_add => .OpIAdd, + .f_add => .OpFAdd, + .i_sub => .OpISub, + .f_sub => .OpFSub, + .i_mul => .OpIMul, + .f_mul => .OpFMul, + .s_div => .OpSDiv, + .u_div => .OpUDiv, + .f_div => .OpFDiv, + .s_rem => .OpSRem, + .f_rem => .OpFRem, + .s_mod => .OpSMod, + .u_mod => .OpUMod, + .f_mod => .OpFMod, + .srl => .OpShiftRightLogical, + .sra => .OpShiftRightArithmetic, + .sll => .OpShiftLeftLogical, + .bit_and => .OpBitwiseAnd, + .bit_or => .OpBitwiseOr, + .bit_xor => .OpBitwiseXor, + .l_and => .OpLogicalAnd, + .l_or => .OpLogicalOr, + else => @as(?Opcode, null), + }) |opcode| { + for (0..ops) |i| { + try self.func.body.emitRaw(self.spv.gpa, opcode, 4); + self.func.body.writeOperand(spec.IdResultType, op_result_ty_id); + self.func.body.writeOperand(IdResult, results.at(i)); + self.func.body.writeOperand(IdResult, op_lhs.at(i)); + self.func.body.writeOperand(IdResult, op_rhs.at(i)); + } + } else { + const set = try self.importExtendedSet(); + + // TODO: Put these numbers in some definition + const extinst: u32 = switch (target.os.tag) { + .opencl => switch (op) { + .f_max => 27, // fmax + .s_max => 156, // s_max + .u_max => 157, // u_max + .f_min => 28, // fmin + .s_min => 158, // s_min + .u_min => 159, // u_min + else => unreachable, + }, + .vulkan => switch (op) { + .f_max => 40, // FMax + .s_max => 42, // SMax + .u_max => 41, // UMax + .f_min => 37, // FMin + .s_min => 39, // SMin + .u_min => 38, // UMin + else => unreachable, + }, + else => unreachable, + }; + + for (0..ops) |i| { + try self.func.body.emit(self.spv.gpa, .OpExtInst, .{ + .id_result_type = op_result_ty_id, + .id_result = results.at(i), + .set = set, + .instruction = .{ .inst = extinst }, + .id_ref_4 = &.{ op_lhs.at(i), op_rhs.at(i) }, + }); + } + } + + return v.finalize(result_ty, results); + } + + /// This function builds an extended multiplication, either OpSMulExtended or OpUMulExtended on Vulkan, + /// or OpIMul and s_mul_hi or u_mul_hi on OpenCL. + fn buildWideMul( + self: *DeclGen, + op: enum { + s_mul_extended, + u_mul_extended, + }, + lhs: Temporary, + rhs: Temporary, + ) !struct { Temporary, Temporary } { + const mod = self.module; + const target = self.getTarget(); + const ip = &mod.intern_pool; + + const v = lhs.vectorization(self).unify(rhs.vectorization(self)); + const ops = v.operations(); + + const arith_op_ty = try v.operationType(self, lhs.ty); + const arith_op_ty_id = try self.resolveType(arith_op_ty, .direct); + + const lhs_op = try v.prepare(self, lhs); + const rhs_op = try v.prepare(self, rhs); + + const value_results = self.spv.allocIds(ops); + const overflow_results = self.spv.allocIds(ops); + + switch (target.os.tag) { + .opencl => { + // Currently, SPIRV-LLVM-Translator based backends cannot deal with OpSMulExtended and + // OpUMulExtended. For these we will use the OpenCL s_mul_hi to compute the high-order bits + // instead. + const set = try self.importExtendedSet(); + const overflow_inst: u32 = switch (op) { + .s_mul_extended => 160, // s_mul_hi + .u_mul_extended => 203, // u_mul_hi + }; + + for (0..ops) |i| { + try self.func.body.emit(self.spv.gpa, .OpIMul, .{ + .id_result_type = arith_op_ty_id, + .id_result = value_results.at(i), + .operand_1 = lhs_op.at(i), + .operand_2 = rhs_op.at(i), + }); + + try self.func.body.emit(self.spv.gpa, .OpExtInst, .{ + .id_result_type = arith_op_ty_id, + .id_result = overflow_results.at(i), + .set = set, + .instruction = .{ .inst = overflow_inst }, + .id_ref_4 = &.{ lhs_op.at(i), rhs_op.at(i) }, + }); + } + }, + .vulkan => { + const op_result_ty = blk: { + // Operations return a struct{T, T} + // where T is maybe vectorized. + const types = [2]InternPool.Index{ arith_op_ty.toIntern(), arith_op_ty.toIntern() }; + const values = [2]InternPool.Index{ .none, .none }; + const index = try ip.getAnonStructType(mod.gpa, .{ + .types = &types, + .values = &values, + .names = &.{}, + }); + break :blk Type.fromInterned(index); + }; + const op_result_ty_id = try self.resolveType(op_result_ty, .direct); + + const opcode: Opcode = switch (op) { + .s_mul_extended => .OpSMulExtended, + .u_mul_extended => .OpUMulExtended, + }; + + for (0..ops) |i| { + const op_result = self.spv.allocId(); + + try self.func.body.emitRaw(self.spv.gpa, opcode, 4); + self.func.body.writeOperand(spec.IdResultType, op_result_ty_id); + self.func.body.writeOperand(IdResult, op_result); + self.func.body.writeOperand(IdResult, lhs_op.at(i)); + self.func.body.writeOperand(IdResult, rhs_op.at(i)); + + // The above operation returns a struct. We might want to expand + // Temporary to deal with the fact that these are structs eventually, + // but for now, take the struct apart and return two separate vectors. + + try self.func.body.emit(self.spv.gpa, .OpCompositeExtract, .{ + .id_result_type = arith_op_ty_id, + .id_result = value_results.at(i), + .composite = op_result, + .indexes = &.{0}, + }); + + try self.func.body.emit(self.spv.gpa, .OpCompositeExtract, .{ + .id_result_type = arith_op_ty_id, + .id_result = overflow_results.at(i), + .composite = op_result, + .indexes = &.{1}, + }); + } + }, + else => unreachable, + } + + const result_ty = try v.resultType(self, lhs.ty); + return .{ + v.finalize(result_ty, value_results), + v.finalize(result_ty, overflow_results), }; } @@ -2237,59 +3099,42 @@ const DeclGen = struct { } } - fn intFromBool(self: *DeclGen, ty: Type, condition_id: IdRef) !IdRef { - const zero_id = try self.constInt(ty, 0, .direct); - const one_id = try self.constInt(ty, 1, .direct); - const result_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpSelect, .{ - .id_result_type = try self.resolveType(ty, .direct), - .id_result = result_id, - .condition = condition_id, - .object_1 = one_id, - .object_2 = zero_id, - }); - return result_id; + fn intFromBool(self: *DeclGen, value: Temporary) !Temporary { + return try self.intFromBool2(value, Type.u1); + } + + fn intFromBool2(self: *DeclGen, value: Temporary, result_ty: Type) !Temporary { + const zero_id = try self.constInt(result_ty, 0, .direct); + const one_id = try self.constInt(result_ty, 1, .direct); + + return try self.buildSelect( + value, + Temporary.init(result_ty, one_id), + Temporary.init(result_ty, zero_id), + ); } /// Convert representation from indirect (in memory) to direct (in 'register') /// This converts the argument type from resolveType(ty, .indirect) to resolveType(ty, .direct). fn convertToDirect(self: *DeclGen, ty: Type, operand_id: IdRef) !IdRef { const mod = self.module; - const scalar_ty = ty.scalarType(mod); - const is_spv_vector = self.isSpvVector(ty); - switch (scalar_ty.zigTypeTag(mod)) { + switch (ty.scalarType(mod).zigTypeTag(mod)) { .Bool => { - // TODO: We may want to use something like elementWise in this function. - // First we need to audit whether this would recursively call into itself. - if (!ty.isVector(mod) or is_spv_vector) { - const result_id = self.spv.allocId(); - const scalar_false_id = try self.constBool(false, .indirect); - const false_id = if (is_spv_vector) blk: { - const index = try mod.intern_pool.get(mod.gpa, .{ - .vector_type = .{ - .len = ty.vectorLen(mod), - .child = Type.u1.toIntern(), - }, - }); - const vec_ty = Type.fromInterned(index); - break :blk try self.constructVectorSplat(vec_ty, scalar_false_id); - } else scalar_false_id; + const false_id = try self.constBool(false, .indirect); + // The operation below requires inputs in direct representation, but the operand + // is actually in indirect representation. + // Cheekily swap out the type to the direct equivalent of the indirect type here, they have the + // same representation when converted to SPIR-V. + const operand_ty = try self.zigScalarOrVectorTypeLike(Type.u1, ty); + // Note: We can guarantee that these are the same ID due to the SPIR-V Module's `vector_types` cache! + assert(try self.resolveType(operand_ty, .direct) == try self.resolveType(ty, .indirect)); - try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{ - .id_result_type = try self.resolveType(ty, .direct), - .id_result = result_id, - .operand_1 = operand_id, - .operand_2 = false_id, - }); - return result_id; - } - - const constituents = try self.gpa.alloc(IdRef, ty.vectorLen(mod)); - for (constituents, 0..) |*id, i| { - const element = try self.extractVectorComponent(scalar_ty, operand_id, @intCast(i)); - id.* = try self.convertToDirect(scalar_ty, element); - } - return try self.constructVector(ty, constituents); + const result = try self.buildCmp( + .i_ne, + Temporary.init(operand_ty, operand_id), + Temporary.init(Type.u1, false_id), + ); + return try result.materialize(self); }, else => return operand_id, } @@ -2299,55 +3144,10 @@ const DeclGen = struct { /// This converts the argument type from resolveType(ty, .direct) to resolveType(ty, .indirect). fn convertToIndirect(self: *DeclGen, ty: Type, operand_id: IdRef) !IdRef { const mod = self.module; - const scalar_ty = ty.scalarType(mod); - const is_spv_vector = self.isSpvVector(ty); - switch (scalar_ty.zigTypeTag(mod)) { + switch (ty.scalarType(mod).zigTypeTag(mod)) { .Bool => { - const result_ty = if (is_spv_vector) blk: { - const index = try mod.intern_pool.get(mod.gpa, .{ - .vector_type = .{ - .len = ty.vectorLen(mod), - .child = Type.u1.toIntern(), - }, - }); - break :blk Type.fromInterned(index); - } else Type.u1; - - if (!ty.isVector(mod) or is_spv_vector) { - // TODO: We may want to use something like elementWise in this function. - // First we need to audit whether this would recursively call into itself. - // Also unify it with intFromBool - - const scalar_zero_id = try self.constInt(Type.u1, 0, .direct); - const scalar_one_id = try self.constInt(Type.u1, 1, .direct); - - const zero_id = if (is_spv_vector) - try self.constructVectorSplat(result_ty, scalar_zero_id) - else - scalar_zero_id; - - const one_id = if (is_spv_vector) - try self.constructVectorSplat(result_ty, scalar_one_id) - else - scalar_one_id; - - const result_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpSelect, .{ - .id_result_type = try self.resolveType(result_ty, .direct), - .id_result = result_id, - .condition = operand_id, - .object_1 = one_id, - .object_2 = zero_id, - }); - return result_id; - } - - const constituents = try self.gpa.alloc(IdRef, ty.vectorLen(mod)); - for (constituents, 0..) |*id, i| { - const element = try self.extractVectorComponent(scalar_ty, operand_id, @intCast(i)); - id.* = try self.convertToIndirect(scalar_ty, element); - } - return try self.constructVector(result_ty, constituents); + const result = try self.intFromBool(Temporary.init(ty, operand_id)); + return try result.materialize(self); }, else => return operand_id, } @@ -2428,26 +3228,35 @@ const DeclGen = struct { const air_tags = self.air.instructions.items(.tag); const maybe_result_id: ?IdRef = switch (air_tags[@intFromEnum(inst)]) { // zig fmt: off - .add, .add_wrap, .add_optimized => try self.airArithOp(inst, .OpFAdd, .OpIAdd, .OpIAdd), - .sub, .sub_wrap, .sub_optimized => try self.airArithOp(inst, .OpFSub, .OpISub, .OpISub), - .mul, .mul_wrap, .mul_optimized => try self.airArithOp(inst, .OpFMul, .OpIMul, .OpIMul), - + .add, .add_wrap, .add_optimized => try self.airArithOp(inst, .f_add, .i_add, .i_add), + .sub, .sub_wrap, .sub_optimized => try self.airArithOp(inst, .f_sub, .i_sub, .i_sub), + .mul, .mul_wrap, .mul_optimized => try self.airArithOp(inst, .f_mul, .i_mul, .i_mul), + .sqrt => try self.airUnOpSimple(inst, .sqrt), + .sin => try self.airUnOpSimple(inst, .sin), + .cos => try self.airUnOpSimple(inst, .cos), + .tan => try self.airUnOpSimple(inst, .tan), + .exp => try self.airUnOpSimple(inst, .exp), + .exp2 => try self.airUnOpSimple(inst, .exp2), + .log => try self.airUnOpSimple(inst, .log), + .log2 => try self.airUnOpSimple(inst, .log2), + .log10 => try self.airUnOpSimple(inst, .log10), .abs => try self.airAbs(inst), - .floor => try self.airFloor(inst), + .floor => try self.airUnOpSimple(inst, .floor), + .ceil => try self.airUnOpSimple(inst, .ceil), + .round => try self.airUnOpSimple(inst, .round), + .trunc_float => try self.airUnOpSimple(inst, .trunc), + .neg, .neg_optimized => try self.airUnOpSimple(inst, .f_neg), - .div_floor => try self.airDivFloor(inst), + .div_float, .div_float_optimized => try self.airArithOp(inst, .f_div, .s_div, .u_div), + .div_floor, .div_floor_optimized => try self.airDivFloor(inst), + .div_trunc, .div_trunc_optimized => try self.airDivTrunc(inst), - .div_float, - .div_float_optimized, - .div_trunc, - .div_trunc_optimized => try self.airArithOp(inst, .OpFDiv, .OpSDiv, .OpUDiv), - .rem, .rem_optimized => try self.airArithOp(inst, .OpFRem, .OpSRem, .OpSRem), - .mod, .mod_optimized => try self.airArithOp(inst, .OpFMod, .OpSMod, .OpSMod), + .rem, .rem_optimized => try self.airArithOp(inst, .f_rem, .s_rem, .u_mod), + .mod, .mod_optimized => try self.airArithOp(inst, .f_mod, .s_mod, .u_mod), - - .add_with_overflow => try self.airAddSubOverflow(inst, .OpIAdd, .OpULessThan, .OpSLessThan), - .sub_with_overflow => try self.airAddSubOverflow(inst, .OpISub, .OpUGreaterThan, .OpSGreaterThan), + .add_with_overflow => try self.airAddSubOverflow(inst, .i_add, .u_lt, .s_lt), + .sub_with_overflow => try self.airAddSubOverflow(inst, .i_sub, .u_gt, .s_gt), .mul_with_overflow => try self.airMulOverflow(inst), .shl_with_overflow => try self.airShlOverflow(inst), @@ -2456,6 +3265,8 @@ const DeclGen = struct { .ctz => try self.airClzCtz(inst, .ctz), .clz => try self.airClzCtz(inst, .clz), + .select => try self.airSelect(inst), + .splat => try self.airSplat(inst), .reduce, .reduce_optimized => try self.airReduce(inst), .shuffle => try self.airShuffle(inst), @@ -2463,17 +3274,17 @@ const DeclGen = struct { .ptr_add => try self.airPtrAdd(inst), .ptr_sub => try self.airPtrSub(inst), - .bit_and => try self.airBinOpSimple(inst, .OpBitwiseAnd), - .bit_or => try self.airBinOpSimple(inst, .OpBitwiseOr), - .xor => try self.airBinOpSimple(inst, .OpBitwiseXor), - .bool_and => try self.airBinOpSimple(inst, .OpLogicalAnd), - .bool_or => try self.airBinOpSimple(inst, .OpLogicalOr), + .bit_and => try self.airBinOpSimple(inst, .bit_and), + .bit_or => try self.airBinOpSimple(inst, .bit_or), + .xor => try self.airBinOpSimple(inst, .bit_xor), + .bool_and => try self.airBinOpSimple(inst, .l_and), + .bool_or => try self.airBinOpSimple(inst, .l_or), - .shl, .shl_exact => try self.airShift(inst, .OpShiftLeftLogical, .OpShiftLeftLogical), - .shr, .shr_exact => try self.airShift(inst, .OpShiftRightLogical, .OpShiftRightArithmetic), + .shl, .shl_exact => try self.airShift(inst, .sll, .sll), + .shr, .shr_exact => try self.airShift(inst, .srl, .sra), - .min => try self.airMinMax(inst, .lt), - .max => try self.airMinMax(inst, .gt), + .min => try self.airMinMax(inst, .min), + .max => try self.airMinMax(inst, .max), .bitcast => try self.airBitCast(inst), .intcast, .trunc => try self.airIntCast(inst), @@ -2574,39 +3385,23 @@ const DeclGen = struct { try self.inst_results.putNoClobber(self.gpa, inst, result_id); } - fn binOpSimple(self: *DeclGen, ty: Type, lhs_id: IdRef, rhs_id: IdRef, comptime opcode: Opcode) !IdRef { - var wip = try self.elementWise(ty, false); - defer wip.deinit(); - for (0..wip.results.len) |i| { - try self.func.body.emit(self.spv.gpa, opcode, .{ - .id_result_type = wip.ty_id, - .id_result = wip.allocId(i), - .operand_1 = try wip.elementAt(ty, lhs_id, i), - .operand_2 = try wip.elementAt(ty, rhs_id, i), - }); - } - return try wip.finalize(); - } - - fn airBinOpSimple(self: *DeclGen, inst: Air.Inst.Index, comptime opcode: Opcode) !?IdRef { + fn airBinOpSimple(self: *DeclGen, inst: Air.Inst.Index, op: BinaryOp) !?IdRef { const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op; - const lhs_id = try self.resolve(bin_op.lhs); - const rhs_id = try self.resolve(bin_op.rhs); - const ty = self.typeOf(bin_op.lhs); + const lhs = try self.temporary(bin_op.lhs); + const rhs = try self.temporary(bin_op.rhs); - return try self.binOpSimple(ty, lhs_id, rhs_id, opcode); + const result = try self.buildBinary(op, lhs, rhs); + return try result.materialize(self); } - fn airShift(self: *DeclGen, inst: Air.Inst.Index, comptime unsigned: Opcode, comptime signed: Opcode) !?IdRef { + fn airShift(self: *DeclGen, inst: Air.Inst.Index, unsigned: BinaryOp, signed: BinaryOp) !?IdRef { const mod = self.module; const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op; - const lhs_id = try self.resolve(bin_op.lhs); - const rhs_id = try self.resolve(bin_op.rhs); + + const base = try self.temporary(bin_op.lhs); + const shift = try self.temporary(bin_op.rhs); const result_ty = self.typeOfIndex(inst); - const shift_ty = self.typeOf(bin_op.rhs); - const scalar_result_ty_id = try self.resolveType(result_ty.scalarType(mod), .direct); - const scalar_shift_ty_id = try self.resolveType(shift_ty.scalarType(mod), .direct); const info = self.arithmeticTypeInfo(result_ty); switch (info.class) { @@ -2615,121 +3410,58 @@ const DeclGen = struct { .float, .bool => unreachable, } - var wip = try self.elementWise(result_ty, false); - defer wip.deinit(); - for (wip.results, 0..) |*result_id, i| { - const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i); - const rhs_elem_id = try wip.elementAt(shift_ty, rhs_id, i); + // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that, + // so just manually upcast it if required. - // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that, - // so just manually upcast it if required. - const shift_id = if (scalar_shift_ty_id != scalar_result_ty_id) blk: { - const shift_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpUConvert, .{ - .id_result_type = wip.ty_id, - .id_result = shift_id, - .unsigned_value = rhs_elem_id, - }); - break :blk shift_id; - } else rhs_elem_id; + // Note: The sign may differ here between the shift and the base type, in case + // of an arithmetic right shift. SPIR-V still expects the same type, + // so in that case we have to cast convert to signed. + const casted_shift = try self.buildIntConvert(base.ty.scalarType(mod), shift); - const value_id = self.spv.allocId(); - const args = .{ - .id_result_type = wip.ty_id, - .id_result = value_id, - .base = lhs_elem_id, - .shift = shift_id, - }; + const shifted = switch (info.signedness) { + .unsigned => try self.buildBinary(unsigned, base, casted_shift), + .signed => try self.buildBinary(signed, base, casted_shift), + }; - if (result_ty.isSignedInt(mod)) { - try self.func.body.emit(self.spv.gpa, signed, args); - } else { - try self.func.body.emit(self.spv.gpa, unsigned, args); - } - - result_id.* = try self.normalize(wip.ty, value_id, info); - } - return try wip.finalize(); + const result = try self.normalize(shifted, info); + return try result.materialize(self); } - fn airMinMax(self: *DeclGen, inst: Air.Inst.Index, op: std.math.CompareOperator) !?IdRef { + const MinMax = enum { min, max }; + + fn airMinMax(self: *DeclGen, inst: Air.Inst.Index, op: MinMax) !?IdRef { const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op; - const lhs_id = try self.resolve(bin_op.lhs); - const rhs_id = try self.resolve(bin_op.rhs); - const result_ty = self.typeOfIndex(inst); - return try self.minMax(result_ty, op, lhs_id, rhs_id); + const lhs = try self.temporary(bin_op.lhs); + const rhs = try self.temporary(bin_op.rhs); + + const result = try self.minMax(lhs, rhs, op); + return try result.materialize(self); } - fn minMax(self: *DeclGen, result_ty: Type, op: std.math.CompareOperator, lhs_id: IdRef, rhs_id: IdRef) !IdRef { - const info = self.arithmeticTypeInfo(result_ty); - const target = self.getTarget(); + fn minMax(self: *DeclGen, lhs: Temporary, rhs: Temporary, op: MinMax) !Temporary { + const info = self.arithmeticTypeInfo(lhs.ty); - const use_backup_codegen = target.os.tag == .opencl and info.class != .float; - var wip = try self.elementWise(result_ty, use_backup_codegen); - defer wip.deinit(); + const binop: BinaryOp = switch (info.class) { + .float => switch (op) { + .min => .f_min, + .max => .f_max, + }, + .integer, .strange_integer => switch (info.signedness) { + .signed => switch (op) { + .min => .s_min, + .max => .s_max, + }, + .unsigned => switch (op) { + .min => .u_min, + .max => .u_max, + }, + }, + .composite_integer => unreachable, // TODO + .bool => unreachable, + }; - for (wip.results, 0..) |*result_id, i| { - const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i); - const rhs_elem_id = try wip.elementAt(result_ty, rhs_id, i); - - if (use_backup_codegen) { - const cmp_id = try self.cmp(op, Type.bool, wip.ty, lhs_elem_id, rhs_elem_id); - result_id.* = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpSelect, .{ - .id_result_type = wip.ty_id, - .id_result = result_id.*, - .condition = cmp_id, - .object_1 = lhs_elem_id, - .object_2 = rhs_elem_id, - }); - } else { - const ext_inst: Word = switch (target.os.tag) { - .opencl => switch (op) { - .lt => 28, // fmin - .gt => 27, // fmax - else => unreachable, - }, - .vulkan => switch (info.class) { - .float => switch (op) { - .lt => 37, // FMin - .gt => 40, // FMax - else => unreachable, - }, - .integer, .strange_integer => switch (info.signedness) { - .signed => switch (op) { - .lt => 39, // SMin - .gt => 42, // SMax - else => unreachable, - }, - .unsigned => switch (op) { - .lt => 38, // UMin - .gt => 41, // UMax - else => unreachable, - }, - }, - .composite_integer => unreachable, // TODO - .bool => unreachable, - }, - else => unreachable, - }; - const set_id = switch (target.os.tag) { - .opencl => try self.spv.importInstructionSet(.@"OpenCL.std"), - .vulkan => try self.spv.importInstructionSet(.@"GLSL.std.450"), - else => unreachable, - }; - - result_id.* = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpExtInst, .{ - .id_result_type = wip.ty_id, - .id_result = result_id.*, - .set = set_id, - .instruction = .{ .inst = ext_inst }, - .id_ref_4 = &.{ lhs_elem_id, rhs_elem_id }, - }); - } - } - return wip.finalize(); + return try self.buildBinary(binop, lhs, rhs); } /// This function normalizes values to a canonical representation @@ -2740,41 +3472,24 @@ const DeclGen = struct { /// - Signed integers are also sign extended if they are negative. /// All other values are returned unmodified (this makes strange integer /// wrapping easier to use in generic operations). - fn normalize(self: *DeclGen, ty: Type, value_id: IdRef, info: ArithmeticTypeInfo) !IdRef { + fn normalize(self: *DeclGen, value: Temporary, info: ArithmeticTypeInfo) !Temporary { + const mod = self.module; + const ty = value.ty; switch (info.class) { - .integer, .bool, .float => return value_id, + .integer, .bool, .float => return value, .composite_integer => unreachable, // TODO .strange_integer => switch (info.signedness) { .unsigned => { const mask_value = if (info.bits == 64) 0xFFFF_FFFF_FFFF_FFFF else (@as(u64, 1) << @as(u6, @intCast(info.bits))) - 1; - const result_id = self.spv.allocId(); - const mask_id = try self.constInt(ty, mask_value, .direct); - try self.func.body.emit(self.spv.gpa, .OpBitwiseAnd, .{ - .id_result_type = try self.resolveType(ty, .direct), - .id_result = result_id, - .operand_1 = value_id, - .operand_2 = mask_id, - }); - return result_id; + const mask_id = try self.constInt(ty.scalarType(mod), mask_value, .direct); + return try self.buildBinary(.bit_and, value, Temporary.init(ty.scalarType(mod), mask_id)); }, .signed => { // Shift left and right so that we can copy the sight bit that way. - const shift_amt_id = try self.constInt(ty, info.backing_bits - info.bits, .direct); - const left_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpShiftLeftLogical, .{ - .id_result_type = try self.resolveType(ty, .direct), - .id_result = left_id, - .base = value_id, - .shift = shift_amt_id, - }); - const right_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpShiftRightArithmetic, .{ - .id_result_type = try self.resolveType(ty, .direct), - .id_result = right_id, - .base = left_id, - .shift = shift_amt_id, - }); - return right_id; + const shift_amt_id = try self.constInt(ty.scalarType(mod), info.backing_bits - info.bits, .direct); + const shift_amt = Temporary.init(ty.scalarType(mod), shift_amt_id); + const left = try self.buildBinary(.sll, value, shift_amt); + return try self.buildBinary(.sra, left, shift_amt); }, }, } @@ -2782,491 +3497,438 @@ const DeclGen = struct { fn airDivFloor(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op; - const lhs_id = try self.resolve(bin_op.lhs); - const rhs_id = try self.resolve(bin_op.rhs); - const ty = self.typeOfIndex(inst); - const ty_id = try self.resolveType(ty, .direct); - const info = self.arithmeticTypeInfo(ty); + + const lhs = try self.temporary(bin_op.lhs); + const rhs = try self.temporary(bin_op.rhs); + + const info = self.arithmeticTypeInfo(lhs.ty); switch (info.class) { .composite_integer => unreachable, // TODO .integer, .strange_integer => { - const zero_id = try self.constInt(ty, 0, .direct); - const one_id = try self.constInt(ty, 1, .direct); + switch (info.signedness) { + .unsigned => { + const result = try self.buildBinary(.u_div, lhs, rhs); + return try result.materialize(self); + }, + .signed => {}, + } - // (a ^ b) > 0 - const bin_bitwise_id = try self.binOpSimple(ty, lhs_id, rhs_id, .OpBitwiseXor); - const is_positive_id = try self.cmp(.gt, Type.bool, ty, bin_bitwise_id, zero_id); + // For signed integers: + // (a / b) - (a % b != 0 && a < 0 != b < 0); + // There shouldn't be any overflow issues. - // a / b - const positive_div_id = try self.arithOp(ty, lhs_id, rhs_id, .OpFDiv, .OpSDiv, .OpUDiv); + const div = try self.buildBinary(.s_div, lhs, rhs); + const rem = try self.buildBinary(.s_rem, lhs, rhs); - // - (abs(a) + abs(b) - 1) / abs(b) - const lhs_abs = try self.abs(ty, ty, lhs_id); - const rhs_abs = try self.abs(ty, ty, rhs_id); - const negative_div_lhs = try self.arithOp( - ty, - try self.arithOp(ty, lhs_abs, rhs_abs, .OpFAdd, .OpIAdd, .OpIAdd), - one_id, - .OpFSub, - .OpISub, - .OpISub, + const zero = Temporary.init(lhs.ty, try self.constInt(lhs.ty, 0, .direct)); + + const rem_is_not_zero = try self.buildCmp(.i_ne, rem, zero); + + const result_negative = try self.buildCmp( + .l_ne, + try self.buildCmp(.s_lt, lhs, zero), + try self.buildCmp(.s_lt, rhs, zero), + ); + const rem_is_not_zero_and_result_is_negative = try self.buildBinary( + .l_and, + rem_is_not_zero, + result_negative, ); - const negative_div_id = try self.arithOp(ty, negative_div_lhs, rhs_abs, .OpFDiv, .OpSDiv, .OpUDiv); - const negated_negative_div_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpSNegate, .{ - .id_result_type = ty_id, - .id_result = negated_negative_div_id, - .operand = negative_div_id, - }); - const result_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpSelect, .{ - .id_result_type = ty_id, - .id_result = result_id, - .condition = is_positive_id, - .object_1 = positive_div_id, - .object_2 = negated_negative_div_id, - }); - return result_id; + const result = try self.buildBinary( + .i_sub, + div, + try self.intFromBool2(rem_is_not_zero_and_result_is_negative, div.ty), + ); + + return try result.materialize(self); }, .float => { - const div_id = try self.arithOp(ty, lhs_id, rhs_id, .OpFDiv, .OpSDiv, .OpUDiv); - return try self.floor(ty, div_id); + const div = try self.buildBinary(.f_div, lhs, rhs); + const result = try self.buildUnary(.floor, div); + return try result.materialize(self); }, .bool => unreachable, } } - fn airFloor(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { - const un_op = self.air.instructions.items(.data)[@intFromEnum(inst)].un_op; - const operand_id = try self.resolve(un_op); - const result_ty = self.typeOfIndex(inst); - return try self.floor(result_ty, operand_id); + fn airDivTrunc(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { + const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op; + + const lhs = try self.temporary(bin_op.lhs); + const rhs = try self.temporary(bin_op.rhs); + + const info = self.arithmeticTypeInfo(lhs.ty); + switch (info.class) { + .composite_integer => unreachable, // TODO + .integer, .strange_integer => switch (info.signedness) { + .unsigned => { + const result = try self.buildBinary(.u_div, lhs, rhs); + return try result.materialize(self); + }, + .signed => { + const result = try self.buildBinary(.s_div, lhs, rhs); + return try result.materialize(self); + }, + }, + .float => { + const div = try self.buildBinary(.f_div, lhs, rhs); + const result = try self.buildUnary(.trunc, div); + return try result.materialize(self); + }, + .bool => unreachable, + } } - fn floor(self: *DeclGen, ty: Type, operand_id: IdRef) !IdRef { - const target = self.getTarget(); - const ty_id = try self.resolveType(ty, .direct); - const ext_inst: Word = switch (target.os.tag) { - .opencl => 25, - .vulkan => 8, - else => unreachable, - }; - const set_id = switch (target.os.tag) { - .opencl => try self.spv.importInstructionSet(.@"OpenCL.std"), - .vulkan => try self.spv.importInstructionSet(.@"GLSL.std.450"), - else => unreachable, - }; - - const result_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpExtInst, .{ - .id_result_type = ty_id, - .id_result = result_id, - .set = set_id, - .instruction = .{ .inst = ext_inst }, - .id_ref_4 = &.{operand_id}, - }); - return result_id; + fn airUnOpSimple(self: *DeclGen, inst: Air.Inst.Index, op: UnaryOp) !?IdRef { + const un_op = self.air.instructions.items(.data)[@intFromEnum(inst)].un_op; + const operand = try self.temporary(un_op); + const result = try self.buildUnary(op, operand); + return try result.materialize(self); } fn airArithOp( self: *DeclGen, inst: Air.Inst.Index, - comptime fop: Opcode, - comptime sop: Opcode, - comptime uop: Opcode, + comptime fop: BinaryOp, + comptime sop: BinaryOp, + comptime uop: BinaryOp, ) !?IdRef { - // LHS and RHS are guaranteed to have the same type, and AIR guarantees - // the result to be the same as the LHS and RHS, which matches SPIR-V. - const ty = self.typeOfIndex(inst); const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op; - const lhs_id = try self.resolve(bin_op.lhs); - const rhs_id = try self.resolve(bin_op.rhs); - assert(self.typeOf(bin_op.lhs).eql(ty, self.module)); - assert(self.typeOf(bin_op.rhs).eql(ty, self.module)); + const lhs = try self.temporary(bin_op.lhs); + const rhs = try self.temporary(bin_op.rhs); - return try self.arithOp(ty, lhs_id, rhs_id, fop, sop, uop); - } + const info = self.arithmeticTypeInfo(lhs.ty); - fn arithOp( - self: *DeclGen, - ty: Type, - lhs_id: IdRef, - rhs_id: IdRef, - comptime fop: Opcode, - comptime sop: Opcode, - comptime uop: Opcode, - ) !IdRef { - // Binary operations are generally applicable to both scalar and vector operations - // in SPIR-V, but int and float versions of operations require different opcodes. - const info = self.arithmeticTypeInfo(ty); - - const opcode_index: usize = switch (info.class) { - .composite_integer => { - return self.todo("binary operations for composite integers", .{}); - }, + const result = switch (info.class) { + .composite_integer => unreachable, // TODO .integer, .strange_integer => switch (info.signedness) { - .signed => 1, - .unsigned => 2, + .signed => try self.buildBinary(sop, lhs, rhs), + .unsigned => try self.buildBinary(uop, lhs, rhs), }, - .float => 0, + .float => try self.buildBinary(fop, lhs, rhs), .bool => unreachable, }; - var wip = try self.elementWise(ty, false); - defer wip.deinit(); - for (wip.results, 0..) |*result_id, i| { - const lhs_elem_id = try wip.elementAt(ty, lhs_id, i); - const rhs_elem_id = try wip.elementAt(ty, rhs_id, i); - - const value_id = self.spv.allocId(); - const operands = .{ - .id_result_type = wip.ty_id, - .id_result = value_id, - .operand_1 = lhs_elem_id, - .operand_2 = rhs_elem_id, - }; - - switch (opcode_index) { - 0 => try self.func.body.emit(self.spv.gpa, fop, operands), - 1 => try self.func.body.emit(self.spv.gpa, sop, operands), - 2 => try self.func.body.emit(self.spv.gpa, uop, operands), - else => unreachable, - } - - // TODO: Trap on overflow? Probably going to be annoying. - // TODO: Look into SPV_KHR_no_integer_wrap_decoration which provides NoSignedWrap/NoUnsignedWrap. - result_id.* = try self.normalize(wip.ty, value_id, info); - } - - return try wip.finalize(); + return try result.materialize(self); } fn airAbs(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; - const operand_id = try self.resolve(ty_op.operand); + const operand = try self.temporary(ty_op.operand); // Note: operand_ty may be signed, while ty is always unsigned! - const operand_ty = self.typeOf(ty_op.operand); const result_ty = self.typeOfIndex(inst); - return try self.abs(result_ty, operand_ty, operand_id); + const result = try self.abs(result_ty, operand); + return try result.materialize(self); } - fn abs(self: *DeclGen, result_ty: Type, operand_ty: Type, operand_id: IdRef) !IdRef { + fn abs(self: *DeclGen, result_ty: Type, value: Temporary) !Temporary { const target = self.getTarget(); - const operand_info = self.arithmeticTypeInfo(operand_ty); + const operand_info = self.arithmeticTypeInfo(value.ty); - var wip = try self.elementWise(result_ty, false); - defer wip.deinit(); + switch (operand_info.class) { + .float => return try self.buildUnary(.f_abs, value), + .integer, .strange_integer => { + const abs_value = try self.buildUnary(.i_abs, value); - for (wip.results, 0..) |*result_id, i| { - const elem_id = try wip.elementAt(operand_ty, operand_id, i); + // TODO: We may need to bitcast the result to a uint + // depending on the result type. Do that when + // bitCast is implemented for vectors. + // This is only relevant for Vulkan + assert(target.os.tag != .vulkan); // TODO - const ext_inst: Word = switch (target.os.tag) { - .opencl => switch (operand_info.class) { - .float => 23, // fabs - .integer, .strange_integer => switch (operand_info.signedness) { - .signed => 141, // s_abs - .unsigned => 201, // u_abs - }, - .composite_integer => unreachable, // TODO - .bool => unreachable, - }, - .vulkan => switch (operand_info.class) { - .float => 4, // FAbs - .integer, .strange_integer => 5, // SAbs - .composite_integer => unreachable, // TODO - .bool => unreachable, - }, - else => unreachable, - }; - const set_id = switch (target.os.tag) { - .opencl => try self.spv.importInstructionSet(.@"OpenCL.std"), - .vulkan => try self.spv.importInstructionSet(.@"GLSL.std.450"), - else => unreachable, - }; - - result_id.* = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpExtInst, .{ - .id_result_type = wip.ty_id, - .id_result = result_id.*, - .set = set_id, - .instruction = .{ .inst = ext_inst }, - .id_ref_4 = &.{elem_id}, - }); + return try self.normalize(abs_value, self.arithmeticTypeInfo(result_ty)); + }, + .composite_integer => unreachable, // TODO + .bool => unreachable, } - return try wip.finalize(); } fn airAddSubOverflow( self: *DeclGen, inst: Air.Inst.Index, - comptime add: Opcode, - comptime ucmp: Opcode, - comptime scmp: Opcode, + comptime add: BinaryOp, + comptime ucmp: CmpPredicate, + comptime scmp: CmpPredicate, ) !?IdRef { - const mod = self.module; + // Note: OpIAddCarry and OpISubBorrow are not really useful here: For unsigned numbers, + // there is in both cases only one extra operation required. For signed operations, + // the overflow bit is set then going from 0x80.. to 0x00.., but this doesn't actually + // normally set a carry bit. So the SPIR-V overflow operations are not particularly + // useful here. + const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl; const extra = self.air.extraData(Air.Bin, ty_pl.payload).data; - const lhs = try self.resolve(extra.lhs); - const rhs = try self.resolve(extra.rhs); + + const lhs = try self.temporary(extra.lhs); + const rhs = try self.temporary(extra.rhs); const result_ty = self.typeOfIndex(inst); - const operand_ty = self.typeOf(extra.lhs); - const ov_ty = result_ty.structFieldType(1, self.module); - const bool_ty_id = try self.resolveType(Type.bool, .direct); - const cmp_ty_id = if (self.isSpvVector(operand_ty)) - // TODO: Resolving a vector type with .direct should return a SPIR-V vector - try self.spv.vectorType(operand_ty.vectorLen(mod), try self.resolveType(Type.bool, .direct)) - else - bool_ty_id; - - const info = self.arithmeticTypeInfo(operand_ty); + const info = self.arithmeticTypeInfo(lhs.ty); switch (info.class) { - .composite_integer => return self.todo("overflow ops for composite integers", .{}), + .composite_integer => unreachable, // TODO .strange_integer, .integer => {}, .float, .bool => unreachable, } - var wip_result = try self.elementWise(operand_ty, false); - defer wip_result.deinit(); - var wip_ov = try self.elementWise(ov_ty, false); - defer wip_ov.deinit(); - for (wip_result.results, wip_ov.results, 0..) |*result_id, *ov_id, i| { - const lhs_elem_id = try wip_result.elementAt(operand_ty, lhs, i); - const rhs_elem_id = try wip_result.elementAt(operand_ty, rhs, i); + const sum = try self.buildBinary(add, lhs, rhs); + const result = try self.normalize(sum, info); - // Normalize both so that we can properly check for overflow - const value_id = self.spv.allocId(); + const overflowed = switch (info.signedness) { + // Overflow happened if the result is smaller than either of the operands. It doesn't matter which. + // For subtraction the conditions need to be swapped. + .unsigned => try self.buildCmp(ucmp, result, lhs), + // For addition, overflow happened if: + // - rhs is negative and value > lhs + // - rhs is positive and value < lhs + // This can be shortened to: + // (rhs < 0 and value > lhs) or (rhs >= 0 and value <= lhs) + // = (rhs < 0) == (value > lhs) + // = (rhs < 0) == (lhs < value) + // Note that signed overflow is also wrapping in spir-v. + // For subtraction, overflow happened if: + // - rhs is negative and value < lhs + // - rhs is positive and value > lhs + // This can be shortened to: + // (rhs < 0 and value < lhs) or (rhs >= 0 and value >= lhs) + // = (rhs < 0) == (value < lhs) + // = (rhs < 0) == (lhs > value) + .signed => blk: { + const zero = Temporary.init(rhs.ty, try self.constInt(rhs.ty, 0, .direct)); + const rhs_lt_zero = try self.buildCmp(.s_lt, rhs, zero); + const result_gt_lhs = try self.buildCmp(scmp, lhs, result); + break :blk try self.buildCmp(.l_eq, rhs_lt_zero, result_gt_lhs); + }, + }; - try self.func.body.emit(self.spv.gpa, add, .{ - .id_result_type = wip_result.ty_id, - .id_result = value_id, - .operand_1 = lhs_elem_id, - .operand_2 = rhs_elem_id, - }); - - // Normalize the result so that the comparisons go well - result_id.* = try self.normalize(wip_result.ty, value_id, info); - - const overflowed_id = switch (info.signedness) { - .unsigned => blk: { - // Overflow happened if the result is smaller than either of the operands. It doesn't matter which. - // For subtraction the conditions need to be swapped. - const overflowed_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, ucmp, .{ - .id_result_type = cmp_ty_id, - .id_result = overflowed_id, - .operand_1 = result_id.*, - .operand_2 = lhs_elem_id, - }); - break :blk overflowed_id; - }, - .signed => blk: { - // lhs - rhs - // For addition, overflow happened if: - // - rhs is negative and value > lhs - // - rhs is positive and value < lhs - // This can be shortened to: - // (rhs < 0 and value > lhs) or (rhs >= 0 and value <= lhs) - // = (rhs < 0) == (value > lhs) - // = (rhs < 0) == (lhs < value) - // Note that signed overflow is also wrapping in spir-v. - // For subtraction, overflow happened if: - // - rhs is negative and value < lhs - // - rhs is positive and value > lhs - // This can be shortened to: - // (rhs < 0 and value < lhs) or (rhs >= 0 and value >= lhs) - // = (rhs < 0) == (value < lhs) - // = (rhs < 0) == (lhs > value) - - const rhs_lt_zero_id = self.spv.allocId(); - const zero_id = try self.constInt(wip_result.ty, 0, .direct); - try self.func.body.emit(self.spv.gpa, .OpSLessThan, .{ - .id_result_type = cmp_ty_id, - .id_result = rhs_lt_zero_id, - .operand_1 = rhs_elem_id, - .operand_2 = zero_id, - }); - - const value_gt_lhs_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, scmp, .{ - .id_result_type = cmp_ty_id, - .id_result = value_gt_lhs_id, - .operand_1 = lhs_elem_id, - .operand_2 = result_id.*, - }); - - const overflowed_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpLogicalEqual, .{ - .id_result_type = cmp_ty_id, - .id_result = overflowed_id, - .operand_1 = rhs_lt_zero_id, - .operand_2 = value_gt_lhs_id, - }); - break :blk overflowed_id; - }, - }; - - ov_id.* = try self.intFromBool(wip_ov.ty, overflowed_id); - } + const ov = try self.intFromBool(overflowed); return try self.constructStruct( result_ty, - &.{ operand_ty, ov_ty }, - &.{ try wip_result.finalize(), try wip_ov.finalize() }, + &.{ result.ty, ov.ty }, + &.{ try result.materialize(self), try ov.materialize(self) }, ); } fn airMulOverflow(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { + const target = self.getTarget(); + const mod = self.module; + const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl; const extra = self.air.extraData(Air.Bin, ty_pl.payload).data; - const lhs = try self.resolve(extra.lhs); - const rhs = try self.resolve(extra.rhs); + + const lhs = try self.temporary(extra.lhs); + const rhs = try self.temporary(extra.rhs); const result_ty = self.typeOfIndex(inst); - const operand_ty = self.typeOf(extra.lhs); - const ov_ty = result_ty.structFieldType(1, self.module); - const info = self.arithmeticTypeInfo(operand_ty); + const info = self.arithmeticTypeInfo(lhs.ty); switch (info.class) { - .composite_integer => return self.todo("overflow ops for composite integers", .{}), + .composite_integer => unreachable, // TODO .strange_integer, .integer => {}, .float, .bool => unreachable, } - var wip_result = try self.elementWise(operand_ty, true); - defer wip_result.deinit(); - var wip_ov = try self.elementWise(ov_ty, true); - defer wip_ov.deinit(); + // There are 3 cases which we have to deal with: + // - If info.bits < 32 / 2, we will upcast to 32 and check the higher bits + // - If info.bits > 32 / 2, we have to use extended multiplication + // - Additionally, if info.bits != 32, we'll have to check the high bits + // of the result too. - const zero_id = try self.constInt(wip_result.ty, 0, .direct); - const zero_ov_id = try self.constInt(wip_ov.ty, 0, .direct); - const one_ov_id = try self.constInt(wip_ov.ty, 1, .direct); + const largest_int_bits: u16 = if (Target.spirv.featureSetHas(target.cpu.features, .Int64)) 64 else 32; + // If non-null, the number of bits that the multiplication should be performed in. If + // null, we have to use wide multiplication. + const maybe_op_ty_bits: ?u16 = switch (info.bits) { + 0 => unreachable, + 1...16 => 32, + 17...32 => if (largest_int_bits > 32) 64 else null, // Upcast if we can. + 33...64 => null, // Always use wide multiplication. + else => unreachable, // TODO: Composite integers + }; - for (wip_result.results, wip_ov.results, 0..) |*result_id, *ov_id, i| { - const lhs_elem_id = try wip_result.elementAt(operand_ty, lhs, i); - const rhs_elem_id = try wip_result.elementAt(operand_ty, rhs, i); + const result, const overflowed = switch (info.signedness) { + .unsigned => blk: { + if (maybe_op_ty_bits) |op_ty_bits| { + const op_ty = try mod.intType(.unsigned, op_ty_bits); + const casted_lhs = try self.buildIntConvert(op_ty, lhs); + const casted_rhs = try self.buildIntConvert(op_ty, rhs); - result_id.* = try self.arithOp(wip_result.ty, lhs_elem_id, rhs_elem_id, .OpFMul, .OpIMul, .OpIMul); + const full_result = try self.buildBinary(.i_mul, casted_lhs, casted_rhs); - // (a != 0) and (x / a != b) - const not_zero_id = try self.cmp(.neq, Type.bool, wip_result.ty, lhs_elem_id, zero_id); - const res_rhs_id = try self.arithOp(wip_result.ty, result_id.*, lhs_elem_id, .OpFDiv, .OpSDiv, .OpUDiv); - const res_rhs_not_rhs_id = try self.cmp(.neq, Type.bool, wip_result.ty, res_rhs_id, rhs_elem_id); - const cond_id = try self.binOpSimple(Type.bool, not_zero_id, res_rhs_not_rhs_id, .OpLogicalAnd); + const low_bits = try self.buildIntConvert(lhs.ty, full_result); + const result = try self.normalize(low_bits, info); - ov_id.* = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpSelect, .{ - .id_result_type = wip_ov.ty_id, - .id_result = ov_id.*, - .condition = cond_id, - .object_1 = one_ov_id, - .object_2 = zero_ov_id, - }); - } + // Shift the result bits away to get the overflow bits. + const shift = Temporary.init(full_result.ty, try self.constInt(full_result.ty, info.bits, .direct)); + const overflow = try self.buildBinary(.srl, full_result, shift); + + // Directly check if its zero in the op_ty without converting first. + const zero = Temporary.init(full_result.ty, try self.constInt(full_result.ty, 0, .direct)); + const overflowed = try self.buildCmp(.i_ne, zero, overflow); + + break :blk .{ result, overflowed }; + } + + const low_bits, const high_bits = try self.buildWideMul(.u_mul_extended, lhs, rhs); + + // Truncate the result, if required. + const result = try self.normalize(low_bits, info); + + // Overflow happened if the high-bits of the result are non-zero OR if the + // high bits of the low word of the result (those outside the range of the + // int) are nonzero. + const zero = Temporary.init(lhs.ty, try self.constInt(lhs.ty, 0, .direct)); + const high_overflowed = try self.buildCmp(.i_ne, zero, high_bits); + + // If no overflow bits in low_bits, no extra work needs to be done. + if (info.backing_bits == info.bits) { + break :blk .{ result, high_overflowed }; + } + + // Shift the result bits away to get the overflow bits. + const shift = Temporary.init(lhs.ty, try self.constInt(lhs.ty, info.bits, .direct)); + const low_overflow = try self.buildBinary(.srl, low_bits, shift); + const low_overflowed = try self.buildCmp(.i_ne, zero, low_overflow); + + const overflowed = try self.buildBinary(.l_or, low_overflowed, high_overflowed); + + break :blk .{ result, overflowed }; + }, + .signed => blk: { + // - lhs >= 0, rhxs >= 0: expect positive; overflow should be 0 + // - lhs == 0 : expect positive; overflow should be 0 + // - rhs == 0: expect positive; overflow should be 0 + // - lhs > 0, rhs < 0: expect negative; overflow should be -1 + // - lhs < 0, rhs > 0: expect negative; overflow should be -1 + // - lhs <= 0, rhs <= 0: expect positive; overflow should be 0 + // ------ + // overflow should be -1 when + // (lhs > 0 && rhs < 0) || (lhs < 0 && rhs > 0) + + const zero = Temporary.init(lhs.ty, try self.constInt(lhs.ty, 0, .direct)); + const lhs_negative = try self.buildCmp(.s_lt, lhs, zero); + const rhs_negative = try self.buildCmp(.s_lt, rhs, zero); + const lhs_positive = try self.buildCmp(.s_gt, lhs, zero); + const rhs_positive = try self.buildCmp(.s_gt, rhs, zero); + + // Set to `true` if we expect -1. + const expected_overflow_bit = try self.buildBinary( + .l_or, + try self.buildBinary(.l_and, lhs_positive, rhs_negative), + try self.buildBinary(.l_and, lhs_negative, rhs_positive), + ); + + if (maybe_op_ty_bits) |op_ty_bits| { + const op_ty = try mod.intType(.signed, op_ty_bits); + // Assume normalized; sign bit is set. We want a sign extend. + const casted_lhs = try self.buildIntConvert(op_ty, lhs); + const casted_rhs = try self.buildIntConvert(op_ty, rhs); + + const full_result = try self.buildBinary(.i_mul, casted_lhs, casted_rhs); + + // Truncate to the result type. + const low_bits = try self.buildIntConvert(lhs.ty, full_result); + const result = try self.normalize(low_bits, info); + + // Now, we need to check the overflow bits AND the sign + // bit for the expceted overflow bits. + // To do that, shift out everything bit the sign bit and + // then check what remains. + const shift = Temporary.init(full_result.ty, try self.constInt(full_result.ty, info.bits - 1, .direct)); + // Use SRA so that any sign bits are duplicated. Now we can just check if ALL bits are set + // for negative cases. + const overflow = try self.buildBinary(.sra, full_result, shift); + + const long_all_set = Temporary.init(full_result.ty, try self.constInt(full_result.ty, -1, .direct)); + const long_zero = Temporary.init(full_result.ty, try self.constInt(full_result.ty, 0, .direct)); + const mask = try self.buildSelect(expected_overflow_bit, long_all_set, long_zero); + + const overflowed = try self.buildCmp(.i_ne, mask, overflow); + + break :blk .{ result, overflowed }; + } + + const low_bits, const high_bits = try self.buildWideMul(.s_mul_extended, lhs, rhs); + + // Truncate result if required. + const result = try self.normalize(low_bits, info); + + const all_set = Temporary.init(lhs.ty, try self.constInt(lhs.ty, -1, .direct)); + const mask = try self.buildSelect(expected_overflow_bit, all_set, zero); + + // Like with unsigned, overflow happened if high_bits are not the ones we expect, + // and we also need to check some ones from the low bits. + + const high_overflowed = try self.buildCmp(.i_ne, mask, high_bits); + + // If no overflow bits in low_bits, no extra work needs to be done. + // Careful, we still have to check the sign bit, so this branch + // only goes for i33 and such. + if (info.backing_bits == info.bits + 1) { + break :blk .{ result, high_overflowed }; + } + + // Shift the result bits away to get the overflow bits. + const shift = Temporary.init(lhs.ty, try self.constInt(lhs.ty, info.bits - 1, .direct)); + // Use SRA so that any sign bits are duplicated. Now we can just check if ALL bits are set + // for negative cases. + const low_overflow = try self.buildBinary(.sra, low_bits, shift); + const low_overflowed = try self.buildCmp(.i_ne, mask, low_overflow); + + const overflowed = try self.buildBinary(.l_or, low_overflowed, high_overflowed); + + break :blk .{ result, overflowed }; + }, + }; + + const ov = try self.intFromBool(overflowed); return try self.constructStruct( result_ty, - &.{ operand_ty, ov_ty }, - &.{ try wip_result.finalize(), try wip_ov.finalize() }, + &.{ result.ty, ov.ty }, + &.{ try result.materialize(self), try ov.materialize(self) }, ); } fn airShlOverflow(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const mod = self.module; + const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl; const extra = self.air.extraData(Air.Bin, ty_pl.payload).data; - const lhs = try self.resolve(extra.lhs); - const rhs = try self.resolve(extra.rhs); + + const base = try self.temporary(extra.lhs); + const shift = try self.temporary(extra.rhs); const result_ty = self.typeOfIndex(inst); - const operand_ty = self.typeOf(extra.lhs); - const shift_ty = self.typeOf(extra.rhs); - const scalar_shift_ty_id = try self.resolveType(shift_ty.scalarType(mod), .direct); - const scalar_operand_ty_id = try self.resolveType(operand_ty.scalarType(mod), .direct); - const ov_ty = result_ty.structFieldType(1, self.module); - - const bool_ty_id = try self.resolveType(Type.bool, .direct); - const cmp_ty_id = if (self.isSpvVector(operand_ty)) - // TODO: Resolving a vector type with .direct should return a SPIR-V vector - try self.spv.vectorType(operand_ty.vectorLen(mod), try self.resolveType(Type.bool, .direct)) - else - bool_ty_id; - - const info = self.arithmeticTypeInfo(operand_ty); + const info = self.arithmeticTypeInfo(base.ty); switch (info.class) { - .composite_integer => return self.todo("overflow shift for composite integers", .{}), + .composite_integer => unreachable, // TODO .integer, .strange_integer => {}, .float, .bool => unreachable, } - var wip_result = try self.elementWise(operand_ty, false); - defer wip_result.deinit(); - var wip_ov = try self.elementWise(ov_ty, false); - defer wip_ov.deinit(); - for (wip_result.results, wip_ov.results, 0..) |*result_id, *ov_id, i| { - const lhs_elem_id = try wip_result.elementAt(operand_ty, lhs, i); - const rhs_elem_id = try wip_result.elementAt(shift_ty, rhs, i); + // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that, + // so just manually upcast it if required. + const casted_shift = try self.buildIntConvert(base.ty.scalarType(mod), shift); - // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that, - // so just manually upcast it if required. - const shift_id = if (scalar_shift_ty_id != scalar_operand_ty_id) blk: { - const shift_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpUConvert, .{ - .id_result_type = wip_result.ty_id, - .id_result = shift_id, - .unsigned_value = rhs_elem_id, - }); - break :blk shift_id; - } else rhs_elem_id; + const left = try self.buildBinary(.sll, base, casted_shift); + const result = try self.normalize(left, info); - const value_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpShiftLeftLogical, .{ - .id_result_type = wip_result.ty_id, - .id_result = value_id, - .base = lhs_elem_id, - .shift = shift_id, - }); - result_id.* = try self.normalize(wip_result.ty, value_id, info); + const right = switch (info.signedness) { + .unsigned => try self.buildBinary(.srl, result, casted_shift), + .signed => try self.buildBinary(.sra, result, casted_shift), + }; - const right_shift_id = self.spv.allocId(); - switch (info.signedness) { - .signed => { - try self.func.body.emit(self.spv.gpa, .OpShiftRightArithmetic, .{ - .id_result_type = wip_result.ty_id, - .id_result = right_shift_id, - .base = result_id.*, - .shift = shift_id, - }); - }, - .unsigned => { - try self.func.body.emit(self.spv.gpa, .OpShiftRightLogical, .{ - .id_result_type = wip_result.ty_id, - .id_result = right_shift_id, - .base = result_id.*, - .shift = shift_id, - }); - }, - } - - const overflowed_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{ - .id_result_type = cmp_ty_id, - .id_result = overflowed_id, - .operand_1 = lhs_elem_id, - .operand_2 = right_shift_id, - }); - - ov_id.* = try self.intFromBool(wip_ov.ty, overflowed_id); - } + const overflowed = try self.buildCmp(.i_ne, base, right); + const ov = try self.intFromBool(overflowed); return try self.constructStruct( result_ty, - &.{ operand_ty, ov_ty }, - &.{ try wip_result.finalize(), try wip_ov.finalize() }, + &.{ result.ty, ov.ty }, + &.{ try result.materialize(self), try ov.materialize(self) }, ); } @@ -3274,122 +3936,67 @@ const DeclGen = struct { const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op; const extra = self.air.extraData(Air.Bin, pl_op.payload).data; - const mulend1 = try self.resolve(extra.lhs); - const mulend2 = try self.resolve(extra.rhs); - const addend = try self.resolve(pl_op.operand); + const a = try self.temporary(extra.lhs); + const b = try self.temporary(extra.rhs); + const c = try self.temporary(pl_op.operand); - const ty = self.typeOfIndex(inst); - - const info = self.arithmeticTypeInfo(ty); + const result_ty = self.typeOfIndex(inst); + const info = self.arithmeticTypeInfo(result_ty); assert(info.class == .float); // .mul_add is only emitted for floats - var wip = try self.elementWise(ty, false); - defer wip.deinit(); - for (0..wip.results.len) |i| { - const mul_result = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpFMul, .{ - .id_result_type = wip.ty_id, - .id_result = mul_result, - .operand_1 = try wip.elementAt(ty, mulend1, i), - .operand_2 = try wip.elementAt(ty, mulend2, i), - }); - - try self.func.body.emit(self.spv.gpa, .OpFAdd, .{ - .id_result_type = wip.ty_id, - .id_result = wip.allocId(i), - .operand_1 = mul_result, - .operand_2 = try wip.elementAt(ty, addend, i), - }); - } - return try wip.finalize(); + const result = try self.buildFma(a, b, c); + return try result.materialize(self); } - fn airClzCtz(self: *DeclGen, inst: Air.Inst.Index, op: enum { clz, ctz }) !?IdRef { + fn airClzCtz(self: *DeclGen, inst: Air.Inst.Index, op: UnaryOp) !?IdRef { if (self.liveness.isUnused(inst)) return null; const mod = self.module; const target = self.getTarget(); const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; - const result_ty = self.typeOfIndex(inst); - const operand_ty = self.typeOf(ty_op.operand); - const operand = try self.resolve(ty_op.operand); + const operand = try self.temporary(ty_op.operand); - const info = self.arithmeticTypeInfo(operand_ty); + const scalar_result_ty = self.typeOfIndex(inst).scalarType(mod); + + const info = self.arithmeticTypeInfo(operand.ty); switch (info.class) { .composite_integer => unreachable, // TODO .integer, .strange_integer => {}, .float, .bool => unreachable, } - var wip = try self.elementWise(result_ty, false); - defer wip.deinit(); - - const elem_ty = if (wip.is_array) operand_ty.scalarType(mod) else operand_ty; - const elem_ty_id = try self.resolveType(elem_ty, .direct); - - for (wip.results, 0..) |*result_id, i| { - const elem = try wip.elementAt(operand_ty, operand, i); - - switch (target.os.tag) { - .opencl => { - const set = try self.spv.importInstructionSet(.@"OpenCL.std"); - const ext_inst: u32 = switch (op) { - .clz => 151, // clz - .ctz => 152, // ctz - }; - - // Note: result of OpenCL ctz/clz returns operand_ty, and we want result_ty. - // result_ty is always large enough to hold the result, so we might have to down - // cast it. - const tmp = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpExtInst, .{ - .id_result_type = elem_ty_id, - .id_result = tmp, - .set = set, - .instruction = .{ .inst = ext_inst }, - .id_ref_4 = &.{elem}, - }); - - // TODO: Comparison should be removed.. - // Its valid because SpvModule caches numeric types - if (wip.ty_id == elem_ty_id) { - result_id.* = tmp; - continue; - } - - result_id.* = self.spv.allocId(); - if (result_ty.scalarType(mod).isSignedInt(mod)) { - assert(elem_ty.scalarType(mod).isSignedInt(mod)); - try self.func.body.emit(self.spv.gpa, .OpSConvert, .{ - .id_result_type = wip.ty_id, - .id_result = result_id.*, - .signed_value = tmp, - }); - } else { - assert(elem_ty.scalarType(mod).isUnsignedInt(mod)); - try self.func.body.emit(self.spv.gpa, .OpUConvert, .{ - .id_result_type = wip.ty_id, - .id_result = result_id.*, - .unsigned_value = tmp, - }); - } - }, - .vulkan => unreachable, // TODO - else => unreachable, - } + switch (target.os.tag) { + .vulkan => unreachable, // TODO + else => {}, } - return try wip.finalize(); + const count = try self.buildUnary(op, operand); + + // Result of OpenCL ctz/clz returns operand.ty, and we want result_ty. + // result_ty is always large enough to hold the result, so we might have to down + // cast it. + const result = try self.buildIntConvert(scalar_result_ty, count); + return try result.materialize(self); + } + + fn airSelect(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { + const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op; + const extra = self.air.extraData(Air.Bin, pl_op.payload).data; + const pred = try self.temporary(pl_op.operand); + const a = try self.temporary(extra.lhs); + const b = try self.temporary(extra.rhs); + + const result = try self.buildSelect(pred, a, b); + return try result.materialize(self); } fn airSplat(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; + const operand_id = try self.resolve(ty_op.operand); const result_ty = self.typeOfIndex(inst); - var wip = try self.elementWise(result_ty, true); - defer wip.deinit(); - @memset(wip.results, operand_id); - return try wip.finalize(); + + return try self.constructVectorSplat(result_ty, operand_id); } fn airReduce(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { @@ -3402,23 +4009,33 @@ const DeclGen = struct { const info = self.arithmeticTypeInfo(operand_ty); - var result_id = try self.extractVectorComponent(scalar_ty, operand, 0); const len = operand_ty.vectorLen(mod); + const first = try self.extractVectorComponent(scalar_ty, operand, 0); + switch (reduce.operation) { .Min, .Max => |op| { - const cmp_op: std.math.CompareOperator = if (op == .Max) .gt else .lt; + var result = Temporary.init(scalar_ty, first); + const cmp_op: MinMax = switch (op) { + .Max => .max, + .Min => .min, + else => unreachable, + }; for (1..len) |i| { - const lhs = result_id; - const rhs = try self.extractVectorComponent(scalar_ty, operand, @intCast(i)); - result_id = try self.minMax(scalar_ty, cmp_op, lhs, rhs); + const lhs = result; + const rhs_id = try self.extractVectorComponent(scalar_ty, operand, @intCast(i)); + const rhs = Temporary.init(scalar_ty, rhs_id); + + result = try self.minMax(lhs, rhs, cmp_op); } - return result_id; + return try result.materialize(self); }, else => {}, } + var result_id = first; + const opcode: Opcode = switch (info.class) { .bool => switch (reduce.operation) { .And => .OpLogicalAnd, @@ -3602,50 +4219,66 @@ const DeclGen = struct { fn cmp( self: *DeclGen, op: std.math.CompareOperator, - result_ty: Type, - ty: Type, - lhs_id: IdRef, - rhs_id: IdRef, - ) !IdRef { + lhs: Temporary, + rhs: Temporary, + ) !Temporary { const mod = self.module; - var cmp_lhs_id = lhs_id; - var cmp_rhs_id = rhs_id; - const bool_ty_id = try self.resolveType(Type.bool, .direct); - const op_ty = switch (ty.zigTypeTag(mod)) { - .Int, .Bool, .Float => ty, - .Enum => ty.intTagType(mod), - .ErrorSet => Type.u16, - .Pointer => blk: { + const scalar_ty = lhs.ty.scalarType(mod); + const is_vector = lhs.ty.isVector(mod); + + switch (scalar_ty.zigTypeTag(mod)) { + .Int, .Bool, .Float => {}, + .Enum => { + assert(!is_vector); + const ty = lhs.ty.intTagType(mod); + return try self.cmp(op, lhs.pun(ty), rhs.pun(ty)); + }, + .ErrorSet => { + assert(!is_vector); + return try self.cmp(op, lhs.pun(Type.u16), rhs.pun(Type.u16)); + }, + .Pointer => { + assert(!is_vector); // Note that while SPIR-V offers OpPtrEqual and OpPtrNotEqual, they are // currently not implemented in the SPIR-V LLVM translator. Thus, we emit these using // OpConvertPtrToU... - cmp_lhs_id = self.spv.allocId(); - cmp_rhs_id = self.spv.allocId(); const usize_ty_id = try self.resolveType(Type.usize, .direct); + const lhs_int_id = self.spv.allocId(); try self.func.body.emit(self.spv.gpa, .OpConvertPtrToU, .{ .id_result_type = usize_ty_id, - .id_result = cmp_lhs_id, - .pointer = lhs_id, + .id_result = lhs_int_id, + .pointer = try lhs.materialize(self), }); + const rhs_int_id = self.spv.allocId(); try self.func.body.emit(self.spv.gpa, .OpConvertPtrToU, .{ .id_result_type = usize_ty_id, - .id_result = cmp_rhs_id, - .pointer = rhs_id, + .id_result = rhs_int_id, + .pointer = try rhs.materialize(self), }); - break :blk Type.usize; + const lhs_int = Temporary.init(Type.usize, lhs_int_id); + const rhs_int = Temporary.init(Type.usize, rhs_int_id); + return try self.cmp(op, lhs_int, rhs_int); }, .Optional => { + assert(!is_vector); + + const ty = lhs.ty; + const payload_ty = ty.optionalChild(mod); if (ty.optionalReprIsPayload(mod)) { assert(payload_ty.hasRuntimeBitsIgnoreComptime(mod)); assert(!payload_ty.isSlice(mod)); - return self.cmp(op, Type.bool, payload_ty, lhs_id, rhs_id); + + return try self.cmp(op, lhs.pun(payload_ty), rhs.pun(payload_ty)); } + const lhs_id = try lhs.materialize(self); + const rhs_id = try rhs.materialize(self); + const lhs_valid_id = if (payload_ty.hasRuntimeBitsIgnoreComptime(mod)) try self.extractField(Type.bool, lhs_id, 1) else @@ -3656,8 +4289,11 @@ const DeclGen = struct { else try self.convertToDirect(Type.bool, rhs_id); + const lhs_valid = Temporary.init(Type.bool, lhs_valid_id); + const rhs_valid = Temporary.init(Type.bool, rhs_valid_id); + if (!payload_ty.hasRuntimeBitsIgnoreComptime(mod)) { - return try self.cmp(op, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id); + return try self.cmp(op, lhs_valid, rhs_valid); } // a = lhs_valid @@ -3678,118 +4314,71 @@ const DeclGen = struct { const lhs_pl_id = try self.extractField(payload_ty, lhs_id, 0); const rhs_pl_id = try self.extractField(payload_ty, rhs_id, 0); - switch (op) { - .eq => { - const valid_eq_id = try self.cmp(.eq, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id); - const pl_eq_id = try self.cmp(op, Type.bool, payload_ty, lhs_pl_id, rhs_pl_id); - const lhs_not_valid_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpLogicalNot, .{ - .id_result_type = bool_ty_id, - .id_result = lhs_not_valid_id, - .operand = lhs_valid_id, - }); - const impl_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{ - .id_result_type = bool_ty_id, - .id_result = impl_id, - .operand_1 = lhs_not_valid_id, - .operand_2 = pl_eq_id, - }); - const result_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpLogicalAnd, .{ - .id_result_type = bool_ty_id, - .id_result = result_id, - .operand_1 = valid_eq_id, - .operand_2 = impl_id, - }); - return result_id; - }, - .neq => { - const valid_neq_id = try self.cmp(.neq, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id); - const pl_neq_id = try self.cmp(op, Type.bool, payload_ty, lhs_pl_id, rhs_pl_id); + const lhs_pl = Temporary.init(payload_ty, lhs_pl_id); + const rhs_pl = Temporary.init(payload_ty, rhs_pl_id); - const impl_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpLogicalAnd, .{ - .id_result_type = bool_ty_id, - .id_result = impl_id, - .operand_1 = lhs_valid_id, - .operand_2 = pl_neq_id, - }); - const result_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{ - .id_result_type = bool_ty_id, - .id_result = result_id, - .operand_1 = valid_neq_id, - .operand_2 = impl_id, - }); - return result_id; - }, + return switch (op) { + .eq => try self.buildBinary( + .l_and, + try self.cmp(.eq, lhs_valid, rhs_valid), + try self.buildBinary( + .l_or, + try self.buildUnary(.l_not, lhs_valid), + try self.cmp(.eq, lhs_pl, rhs_pl), + ), + ), + .neq => try self.buildBinary( + .l_or, + try self.cmp(.neq, lhs_valid, rhs_valid), + try self.buildBinary( + .l_and, + lhs_valid, + try self.cmp(.neq, lhs_pl, rhs_pl), + ), + ), else => unreachable, - } - }, - .Vector => { - var wip = try self.elementWise(result_ty, true); - defer wip.deinit(); - const scalar_ty = ty.scalarType(mod); - for (wip.results, 0..) |*result_id, i| { - const lhs_elem_id = try wip.elementAt(ty, lhs_id, i); - const rhs_elem_id = try wip.elementAt(ty, rhs_id, i); - result_id.* = try self.cmp(op, Type.bool, scalar_ty, lhs_elem_id, rhs_elem_id); - } - return wip.finalize(); + }; }, else => unreachable, - }; + } - const opcode: Opcode = opcode: { - const info = self.arithmeticTypeInfo(op_ty); - const signedness = switch (info.class) { - .composite_integer => { - return self.todo("binary operations for composite integers", .{}); - }, - .float => break :opcode switch (op) { - .eq => .OpFOrdEqual, - .neq => .OpFUnordNotEqual, - .lt => .OpFOrdLessThan, - .lte => .OpFOrdLessThanEqual, - .gt => .OpFOrdGreaterThan, - .gte => .OpFOrdGreaterThanEqual, - }, - .bool => break :opcode switch (op) { - .eq => .OpLogicalEqual, - .neq => .OpLogicalNotEqual, - else => unreachable, - }, - .integer, .strange_integer => info.signedness, - }; - - break :opcode switch (signedness) { - .unsigned => switch (op) { - .eq => .OpIEqual, - .neq => .OpINotEqual, - .lt => .OpULessThan, - .lte => .OpULessThanEqual, - .gt => .OpUGreaterThan, - .gte => .OpUGreaterThanEqual, - }, + const info = self.arithmeticTypeInfo(scalar_ty); + const pred: CmpPredicate = switch (info.class) { + .composite_integer => unreachable, // TODO + .float => switch (op) { + .eq => .f_oeq, + .neq => .f_une, + .lt => .f_olt, + .lte => .f_ole, + .gt => .f_ogt, + .gte => .f_oge, + }, + .bool => switch (op) { + .eq => .l_eq, + .neq => .l_ne, + else => unreachable, + }, + .integer, .strange_integer => switch (info.signedness) { .signed => switch (op) { - .eq => .OpIEqual, - .neq => .OpINotEqual, - .lt => .OpSLessThan, - .lte => .OpSLessThanEqual, - .gt => .OpSGreaterThan, - .gte => .OpSGreaterThanEqual, + .eq => .i_eq, + .neq => .i_ne, + .lt => .s_lt, + .lte => .s_le, + .gt => .s_gt, + .gte => .s_ge, }, - }; + .unsigned => switch (op) { + .eq => .i_eq, + .neq => .i_ne, + .lt => .u_lt, + .lte => .u_le, + .gt => .u_gt, + .gte => .u_ge, + }, + }, }; - const result_id = self.spv.allocId(); - try self.func.body.emitRaw(self.spv.gpa, opcode, 4); - self.func.body.writeOperand(spec.IdResultType, bool_ty_id); - self.func.body.writeOperand(spec.IdResult, result_id); - self.func.body.writeOperand(spec.IdResultType, cmp_lhs_id); - self.func.body.writeOperand(spec.IdResultType, cmp_rhs_id); - return result_id; + return try self.buildCmp(pred, lhs, rhs); } fn airCmp( @@ -3798,24 +4387,22 @@ const DeclGen = struct { comptime op: std.math.CompareOperator, ) !?IdRef { const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op; - const lhs_id = try self.resolve(bin_op.lhs); - const rhs_id = try self.resolve(bin_op.rhs); - const ty = self.typeOf(bin_op.lhs); - const result_ty = self.typeOfIndex(inst); + const lhs = try self.temporary(bin_op.lhs); + const rhs = try self.temporary(bin_op.rhs); - return try self.cmp(op, result_ty, ty, lhs_id, rhs_id); + const result = try self.cmp(op, lhs, rhs); + return try result.materialize(self); } fn airVectorCmp(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl; const vec_cmp = self.air.extraData(Air.VectorCmp, ty_pl.payload).data; - const lhs_id = try self.resolve(vec_cmp.lhs); - const rhs_id = try self.resolve(vec_cmp.rhs); + const lhs = try self.temporary(vec_cmp.lhs); + const rhs = try self.temporary(vec_cmp.rhs); const op = vec_cmp.compareOperator(); - const ty = self.typeOf(vec_cmp.lhs); - const result_ty = self.typeOfIndex(inst); - return try self.cmp(op, result_ty, ty, lhs_id, rhs_id); + const result = try self.cmp(op, lhs, rhs); + return try result.materialize(self); } /// Bitcast one type to another. Note: both types, input, output are expected in **direct** representation. @@ -3881,7 +4468,8 @@ const DeclGen = struct { // should we change the representation of strange integers? if (dst_ty.zigTypeTag(mod) == .Int) { const info = self.arithmeticTypeInfo(dst_ty); - return try self.normalize(dst_ty, result_id, info); + const result = try self.normalize(Temporary.init(dst_ty, result_id), info); + return try result.materialize(self); } return result_id; @@ -3897,46 +4485,28 @@ const DeclGen = struct { fn airIntCast(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; - const operand_id = try self.resolve(ty_op.operand); - const src_ty = self.typeOf(ty_op.operand); + const src = try self.temporary(ty_op.operand); const dst_ty = self.typeOfIndex(inst); - const src_info = self.arithmeticTypeInfo(src_ty); + const src_info = self.arithmeticTypeInfo(src.ty); const dst_info = self.arithmeticTypeInfo(dst_ty); if (src_info.backing_bits == dst_info.backing_bits) { - return operand_id; + return try src.materialize(self); } - var wip = try self.elementWise(dst_ty, false); - defer wip.deinit(); - for (wip.results, 0..) |*result_id, i| { - const elem_id = try wip.elementAt(src_ty, operand_id, i); - const value_id = self.spv.allocId(); - switch (dst_info.signedness) { - .signed => try self.func.body.emit(self.spv.gpa, .OpSConvert, .{ - .id_result_type = wip.ty_id, - .id_result = value_id, - .signed_value = elem_id, - }), - .unsigned => try self.func.body.emit(self.spv.gpa, .OpUConvert, .{ - .id_result_type = wip.ty_id, - .id_result = value_id, - .unsigned_value = elem_id, - }), - } + const converted = try self.buildIntConvert(dst_ty, src); - // Make sure to normalize the result if shrinking. - // Because strange ints are sign extended in their backing - // type, we don't need to normalize when growing the type. The - // representation is already the same. - if (dst_info.bits < src_info.bits) { - result_id.* = try self.normalize(wip.ty, value_id, dst_info); - } else { - result_id.* = value_id; - } - } - return try wip.finalize(); + // Make sure to normalize the result if shrinking. + // Because strange ints are sign extended in their backing + // type, we don't need to normalize when growing the type. The + // representation is already the same. + const result = if (dst_info.bits < src_info.bits) + try self.normalize(converted, dst_info) + else + converted; + + return try result.materialize(self); } fn intFromPtr(self: *DeclGen, operand_id: IdRef) !IdRef { @@ -4011,16 +4581,9 @@ const DeclGen = struct { fn airIntFromBool(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const un_op = self.air.instructions.items(.data)[@intFromEnum(inst)].un_op; - const operand_id = try self.resolve(un_op); - const result_ty = self.typeOfIndex(inst); - - var wip = try self.elementWise(result_ty, false); - defer wip.deinit(); - for (wip.results, 0..) |*result_id, i| { - const elem_id = try wip.elementAt(Type.bool, operand_id, i); - result_id.* = try self.intFromBool(wip.ty, elem_id); - } - return try wip.finalize(); + const operand = try self.temporary(un_op); + const result = try self.intFromBool(operand); + return try result.materialize(self); } fn airFloatCast(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { @@ -4040,33 +4603,21 @@ const DeclGen = struct { fn airNot(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; - const operand_id = try self.resolve(ty_op.operand); + const operand = try self.temporary(ty_op.operand); const result_ty = self.typeOfIndex(inst); const info = self.arithmeticTypeInfo(result_ty); - var wip = try self.elementWise(result_ty, false); - defer wip.deinit(); + const result = switch (info.class) { + .bool => try self.buildUnary(.l_not, operand), + .float => unreachable, + .composite_integer => unreachable, // TODO + .strange_integer, .integer => blk: { + const complement = try self.buildUnary(.bit_not, operand); + break :blk try self.normalize(complement, info); + }, + }; - for (0..wip.results.len) |i| { - const args = .{ - .id_result_type = wip.ty_id, - .id_result = wip.allocId(i), - .operand = try wip.elementAt(result_ty, operand_id, i), - }; - switch (info.class) { - .bool => { - try self.func.body.emit(self.spv.gpa, .OpLogicalNot, args); - }, - .float => unreachable, - .composite_integer => unreachable, // TODO - .strange_integer, .integer => { - // Note: strange integer bits will be masked before operations that do not hold under modulo. - try self.func.body.emit(self.spv.gpa, .OpNot, args); - }, - } - } - - return try wip.finalize(); + return try result.materialize(self); } fn airArrayToSlice(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { @@ -4338,8 +4889,11 @@ const DeclGen = struct { // For now, just generate a temporary and use that. // TODO: This backend probably also should use isByRef from llvm... + const is_vector = array_ty.isVector(mod); + + const elem_repr: Repr = if (is_vector) .direct else .indirect; const ptr_array_ty_id = try self.ptrType2(array_ty, .Function, .direct); - const ptr_elem_ty_id = try self.ptrType2(elem_ty, .Function, .direct); + const ptr_elem_ty_id = try self.ptrType2(elem_ty, .Function, elem_repr); const tmp_id = self.spv.allocId(); try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{ @@ -4357,12 +4911,12 @@ const DeclGen = struct { const result_id = self.spv.allocId(); try self.func.body.emit(self.spv.gpa, .OpLoad, .{ - .id_result_type = try self.resolveType(elem_ty, .direct), + .id_result_type = try self.resolveType(elem_ty, elem_repr), .id_result = result_id, .pointer = elem_ptr_id, }); - if (array_ty.isVector(mod)) { + if (is_vector) { // Result is already in direct representation return result_id; } @@ -4585,7 +5139,10 @@ const DeclGen = struct { if (field_offset == 0) break :base_ptr_int field_ptr_int; const field_offset_id = try self.constInt(Type.usize, field_offset, .direct); - break :base_ptr_int try self.binOpSimple(Type.usize, field_ptr_int, field_offset_id, .OpISub); + const field_ptr_tmp = Temporary.init(Type.usize, field_ptr_int); + const field_offset_tmp = Temporary.init(Type.usize, field_offset_id); + const result = try self.buildBinary(.i_sub, field_ptr_tmp, field_offset_tmp); + break :base_ptr_int try result.materialize(self); }; const base_ptr = self.spv.allocId(); @@ -5400,13 +5957,17 @@ const DeclGen = struct { else loaded_id; - const payload_ty_id = try self.resolveType(ptr_ty, .direct); - const null_id = try self.spv.constNull(payload_ty_id); + const ptr_ty_id = try self.resolveType(ptr_ty, .direct); + const null_id = try self.spv.constNull(ptr_ty_id); + const null_tmp = Temporary.init(ptr_ty, null_id); + const ptr = Temporary.init(ptr_ty, ptr_id); + const op: std.math.CompareOperator = switch (pred) { .is_null => .eq, .is_non_null => .neq, }; - return try self.cmp(op, Type.bool, ptr_ty, ptr_id, null_id); + const result = try self.cmp(op, ptr, null_tmp); + return try result.materialize(self); } const is_non_null_id = blk: { diff --git a/src/codegen/spirv/Module.zig b/src/codegen/spirv/Module.zig index 88fe677345..d1b2171786 100644 --- a/src/codegen/spirv/Module.zig +++ b/src/codegen/spirv/Module.zig @@ -155,6 +155,9 @@ cache: struct { void_type: ?IdRef = null, int_types: std.AutoHashMapUnmanaged(std.builtin.Type.Int, IdRef) = .{}, float_types: std.AutoHashMapUnmanaged(std.builtin.Type.Float, IdRef) = .{}, + // This cache is required so that @Vector(X, u1) in direct representation has the + // same ID as @Vector(X, bool) in indirect representation. + vector_types: std.AutoHashMapUnmanaged(struct { IdRef, u32 }, IdRef) = .{}, } = .{}, /// Set of Decls, referred to by Decl.Index. @@ -194,6 +197,7 @@ pub fn deinit(self: *Module) void { self.cache.int_types.deinit(self.gpa); self.cache.float_types.deinit(self.gpa); + self.cache.vector_types.deinit(self.gpa); self.decls.deinit(self.gpa); self.decl_deps.deinit(self.gpa); @@ -474,13 +478,17 @@ pub fn floatType(self: *Module, bits: u16) !IdRef { } pub fn vectorType(self: *Module, len: u32, child_id: IdRef) !IdRef { - const result_id = self.allocId(); - try self.sections.types_globals_constants.emit(self.gpa, .OpTypeVector, .{ - .id_result = result_id, - .component_type = child_id, - .component_count = len, - }); - return result_id; + const entry = try self.cache.vector_types.getOrPut(self.gpa, .{ child_id, len }); + if (!entry.found_existing) { + const result_id = self.allocId(); + entry.value_ptr.* = result_id; + try self.sections.types_globals_constants.emit(self.gpa, .OpTypeVector, .{ + .id_result = result_id, + .component_type = child_id, + .component_count = len, + }); + } + return entry.value_ptr.*; } pub fn constUndef(self: *Module, ty_id: IdRef) !IdRef { diff --git a/test/behavior/abs.zig b/test/behavior/abs.zig index 8ca160faff..21f02b2a3d 100644 --- a/test/behavior/abs.zig +++ b/test/behavior/abs.zig @@ -152,7 +152,6 @@ test "@abs int vectors" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try comptime testAbsIntVectors(1); try testAbsIntVectors(1); diff --git a/test/behavior/floatop.zig b/test/behavior/floatop.zig index 2e18b58d3c..d32319c644 100644 --- a/test/behavior/floatop.zig +++ b/test/behavior/floatop.zig @@ -275,7 +275,6 @@ test "@sqrt f16" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -287,7 +286,6 @@ test "@sqrt f32/f64" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -389,7 +387,6 @@ test "@sqrt with vectors" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; try testSqrtWithVectors(); @@ -410,7 +407,6 @@ test "@sin f16" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -422,7 +418,6 @@ test "@sin f32/f64" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -464,7 +459,6 @@ test "@sin with vectors" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -486,7 +480,6 @@ test "@cos f16" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -498,7 +491,6 @@ test "@cos f32/f64" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -540,7 +532,6 @@ test "@cos with vectors" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -574,7 +565,6 @@ test "@tan f32/f64" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -616,7 +606,6 @@ test "@tan with vectors" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -638,7 +627,6 @@ test "@exp f16" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -650,7 +638,6 @@ test "@exp f32/f64" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -696,7 +683,6 @@ test "@exp with vectors" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -718,7 +704,6 @@ test "@exp2 f16" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -730,7 +715,6 @@ test "@exp2 f32/f64" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -771,7 +755,6 @@ test "@exp2 with @vectors" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -793,7 +776,6 @@ test "@log f16" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -805,7 +787,6 @@ test "@log f32/f64" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -847,7 +828,6 @@ test "@log with @vectors" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -866,7 +846,6 @@ test "@log2 f16" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -878,7 +857,6 @@ test "@log2 f32/f64" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -919,7 +897,6 @@ test "@log2 with vectors" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; // https://github.com/ziglang/zig/issues/13681 if (builtin.zig_backend == .stage2_llvm and @@ -945,7 +922,6 @@ test "@log10 f16" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -957,7 +933,6 @@ test "@log10 f32/f64" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -998,7 +973,6 @@ test "@log10 with vectors" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -1243,7 +1217,6 @@ test "@ceil f16" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; try testCeil(f16); @@ -1255,7 +1228,6 @@ test "@ceil f32/f64" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; try testCeil(f32); @@ -1320,7 +1292,6 @@ test "@ceil with vectors" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and !comptime std.Target.x86.featureSetHas(builtin.cpu.features, .sse4_1)) return error.SkipZigTest; @@ -1344,7 +1315,6 @@ test "@trunc f16" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_llvm and builtin.cpu.arch.isMIPS()) { @@ -1361,7 +1331,6 @@ test "@trunc f32/f64" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_llvm and builtin.cpu.arch.isMIPS()) { @@ -1430,7 +1399,6 @@ fn testTrunc(comptime T: type) !void { test "@trunc with vectors" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and @@ -1454,7 +1422,6 @@ test "neg f16" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -1472,7 +1439,6 @@ test "neg f32/f64" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; diff --git a/test/behavior/math.zig b/test/behavior/math.zig index 59937515ab..66f86ede89 100644 --- a/test/behavior/math.zig +++ b/test/behavior/math.zig @@ -440,7 +440,6 @@ test "division" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -530,7 +529,6 @@ test "division half-precision floats" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -622,7 +620,6 @@ test "negation wrapping" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try expectEqual(@as(u1, 1), negateWrap(u1, 1)); } @@ -1031,6 +1028,60 @@ test "@mulWithOverflow bitsize > 32" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO + { + var a: u40 = 3; + var b: u40 = 0x55_5555_5555; + var ov = @mulWithOverflow(a, b); + + try expect(ov[0] == 0xff_ffff_ffff); + try expect(ov[1] == 0); + + // Check that overflow bits in the low-word of wide-multiplications are checked too. + // Intermediate result is less than 2**64 + b = 0x55_5555_5556; + ov = @mulWithOverflow(a, b); + try expect(ov[0] == 2); + try expect(ov[1] == 1); + + // Check that overflow bits in the high-word of wide-multiplications are checked too. + // Intermediate result is more than 2**64 and bits 40..64 are not set. + a = 0x10_0000_0000; + b = 0x10_0000_0000; + ov = @mulWithOverflow(a, b); + try expect(ov[0] == 0); + try expect(ov[1] == 1); + } + + { + var a: i40 = 3; + var b: i40 = -0x2a_aaaa_aaaa; + var ov = @mulWithOverflow(a, b); + + try expect(ov[0] == -0x7f_ffff_fffe); + try expect(ov[1] == 0); + + // Check that the sign bit is properly checked + b = -0x2a_aaaa_aaab; + ov = @mulWithOverflow(a, b); + try expect(ov[0] == 0x7f_ffff_ffff); + try expect(ov[1] == 1); + + // Check that the low-order bits above the sign are checked. + a = 6; + ov = @mulWithOverflow(a, b); + try expect(ov[0] == -2); + try expect(ov[1] == 1); + + // Check that overflow bits in the high-word of wide-multiplications are checked too. + // high parts and sign of low-order bits are all 1. + a = 0x08_0000_0000; + b = -0x08_0000_0001; + ov = @mulWithOverflow(a, b); + + try expect(ov[0] == -0x8_0000_0000); + try expect(ov[1] == 1); + } + { var a: u62 = 3; _ = &a; @@ -1580,7 +1631,6 @@ test "@round f16" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; @@ -1592,7 +1642,6 @@ test "@round f32/f64" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; diff --git a/test/behavior/select.zig b/test/behavior/select.zig index 90166dcfe5..f2a6cf8a63 100644 --- a/test/behavior/select.zig +++ b/test/behavior/select.zig @@ -8,7 +8,6 @@ test "@select vectors" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; try comptime selectVectors(); @@ -39,7 +38,6 @@ test "@select arrays" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and !comptime std.Target.x86.featureSetHas(builtin.cpu.features, .avx2)) return error.SkipZigTest; diff --git a/test/behavior/vector.zig b/test/behavior/vector.zig index 8987e0c091..2e860e1001 100644 --- a/test/behavior/vector.zig +++ b/test/behavior/vector.zig @@ -548,7 +548,6 @@ test "vector division operators" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_llvm and comptime builtin.cpu.arch.isArmOrThumb()) return error.SkipZigTest; if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; const S = struct { From a567f3871ec06f3e6a8c0e6424aba556f1069ccc Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Tue, 4 Jun 2024 22:09:15 +0200 Subject: [PATCH 7/7] spirv: improve shuffle codegen --- src/codegen/spirv.zig | 63 +++++++++++++++++++++++++---- test/behavior/shuffle.zig | 83 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 8 deletions(-) diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 215a9421f1..09185211ef 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -4082,25 +4082,72 @@ const DeclGen = struct { const b = try self.resolve(extra.b); const mask = Value.fromInterned(extra.mask); - const ty = self.typeOfIndex(inst); + // Note: number of components in the result, a, and b may differ. + const result_ty = self.typeOfIndex(inst); + const a_ty = self.typeOf(extra.a); + const b_ty = self.typeOf(extra.b); - var wip = try self.elementWise(ty, true); - defer wip.deinit(); - for (wip.results, 0..) |*result_id, i| { + const scalar_ty = result_ty.scalarType(mod); + const scalar_ty_id = try self.resolveType(scalar_ty, .direct); + + // If all of the types are SPIR-V vectors, we can use OpVectorShuffle. + if (self.isSpvVector(result_ty) and self.isSpvVector(a_ty) and self.isSpvVector(b_ty)) { + // The SPIR-V shuffle instruction is similar to the Air instruction, except that the elements are + // numbered consecutively instead of using negatives. + + const components = try self.gpa.alloc(Word, result_ty.vectorLen(mod)); + defer self.gpa.free(components); + + const a_len = a_ty.vectorLen(mod); + + for (components, 0..) |*component, i| { + const elem = try mask.elemValue(mod, i); + if (elem.isUndef(mod)) { + // This is explicitly valid for OpVectorShuffle, it indicates undefined. + component.* = 0xFFFF_FFFF; + continue; + } + + const index = elem.toSignedInt(mod); + if (index >= 0) { + component.* = @intCast(index); + } else { + component.* = @intCast(~index + a_len); + } + } + + const result_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpVectorShuffle, .{ + .id_result_type = try self.resolveType(result_ty, .direct), + .id_result = result_id, + .vector_1 = a, + .vector_2 = b, + .components = components, + }); + return result_id; + } + + // Fall back to manually extracting and inserting components. + + const components = try self.gpa.alloc(IdRef, result_ty.vectorLen(mod)); + defer self.gpa.free(components); + + for (components, 0..) |*id, i| { const elem = try mask.elemValue(mod, i); if (elem.isUndef(mod)) { - result_id.* = try self.spv.constUndef(wip.ty_id); + id.* = try self.spv.constUndef(scalar_ty_id); continue; } const index = elem.toSignedInt(mod); if (index >= 0) { - result_id.* = try self.extractVectorComponent(wip.ty, a, @intCast(index)); + id.* = try self.extractVectorComponent(scalar_ty, a, @intCast(index)); } else { - result_id.* = try self.extractVectorComponent(wip.ty, b, @intCast(~index)); + id.* = try self.extractVectorComponent(scalar_ty, b, @intCast(~index)); } } - return try wip.finalize(); + + return try self.constructVector(result_ty, components); } fn indicesToIds(self: *DeclGen, indices: []const u32) ![]IdRef { diff --git a/test/behavior/shuffle.zig b/test/behavior/shuffle.zig index fb16f3fbb3..2bcdbd1581 100644 --- a/test/behavior/shuffle.zig +++ b/test/behavior/shuffle.zig @@ -2,6 +2,7 @@ const std = @import("std"); const builtin = @import("builtin"); const mem = std.mem; const expect = std.testing.expect; +const expectEqual = std.testing.expectEqual; test "@shuffle int" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO @@ -49,6 +50,88 @@ test "@shuffle int" { try comptime S.doTheTest(); } +test "@shuffle int strange sizes" { + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + + try comptime testShuffle(2, 2, 2); + try testShuffle(2, 2, 2); + try comptime testShuffle(4, 4, 4); + try testShuffle(4, 4, 4); + try comptime testShuffle(7, 4, 4); + try testShuffle(7, 4, 4); + try comptime testShuffle(8, 6, 4); + try testShuffle(8, 6, 4); + try comptime testShuffle(2, 7, 5); + try testShuffle(2, 7, 5); + try comptime testShuffle(13, 16, 12); + try testShuffle(13, 16, 12); + try comptime testShuffle(19, 3, 17); + try testShuffle(19, 3, 17); + try comptime testShuffle(1, 10, 1); + try testShuffle(1, 10, 1); +} + +fn testShuffle( + comptime x_len: comptime_int, + comptime a_len: comptime_int, + comptime b_len: comptime_int, +) !void { + const T = i32; + const XT = @Vector(x_len, T); + const AT = @Vector(a_len, T); + const BT = @Vector(b_len, T); + + const a_elems = comptime blk: { + var elems: [a_len]T = undefined; + for (&elems, 0..) |*elem, i| elem.* = @intCast(100 + i); + break :blk elems; + }; + var a: AT = a_elems; + _ = &a; + + const b_elems = comptime blk: { + var elems: [b_len]T = undefined; + for (&elems, 0..) |*elem, i| elem.* = @intCast(1000 + i); + break :blk elems; + }; + var b: BT = b_elems; + _ = &b; + + const mask_seed: []const i32 = &.{ -14, -31, 23, 1, 21, 13, 17, -21, -10, -27, -16, -5, 15, 14, -2, 26, 2, -31, -24, -16 }; + + const mask = comptime blk: { + var elems: [x_len]i32 = undefined; + for (&elems, 0..) |*elem, i| { + const mask_val = mask_seed[i]; + if (mask_val >= 0) { + elem.* = @mod(mask_val, a_len); + } else { + elem.* = @mod(mask_val, -b_len); + } + } + + break :blk elems; + }; + + const x: XT = @shuffle(T, a, b, mask); + + const x_elems: [x_len]T = x; + for (mask, x_elems) |m, x_elem| { + if (m >= 0) { + // Element from A + try expectEqual(x_elem, a_elems[@intCast(m)]); + } else { + // Element from B + try expectEqual(x_elem, b_elems[@intCast(~m)]); + } + } +} + test "@shuffle bool 1" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO