From 631d1b63a8027c49073995e28aab489534f01efa Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Sun, 21 Jan 2024 20:38:56 +0100 Subject: [PATCH] spirv: fix shuffle properly --- src/codegen/spirv.zig | 32 +++++++++++++------------------- test/behavior/abs.zig | 1 - test/behavior/cast.zig | 2 -- test/behavior/shuffle.zig | 3 --- test/behavior/vector.zig | 2 -- 5 files changed, 13 insertions(+), 27 deletions(-) diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index be3c9957e1..28f2c1677c 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -2876,37 +2876,31 @@ const DeclGen = struct { fn airShuffle(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const mod = self.module; if (self.liveness.isUnused(inst)) return null; - const ty = self.typeOfIndex(inst); const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl; const extra = self.air.extraData(Air.Shuffle, ty_pl.payload).data; const a = try self.resolve(extra.a); const b = try self.resolve(extra.b); const mask = Value.fromInterned(extra.mask); - const mask_len = extra.mask_len; - const a_len = self.typeOf(extra.a).vectorLen(mod); - const result_id = self.spv.allocId(); - const result_type_id = try self.resolveTypeId(ty); - // Similar to LLVM, SPIR-V uses indices larger than the length of the first vector - // to index into the second vector. - try self.func.body.emitRaw(self.spv.gpa, .OpVectorShuffle, 4 + mask_len); - self.func.body.writeOperand(spec.IdResultType, result_type_id); - self.func.body.writeOperand(spec.IdResult, result_id); - self.func.body.writeOperand(spec.IdRef, a); - self.func.body.writeOperand(spec.IdRef, b); + const ty = self.typeOfIndex(inst); - var i: usize = 0; - while (i < mask_len) : (i += 1) { + var wip = try self.elementWise(ty); + defer wip.deinit(); + for (wip.results, 0..) |*result_id, i| { const elem = try mask.elemValue(mod, i); if (elem.isUndef(mod)) { - self.func.body.writeOperand(spec.LiteralInteger, 0xFFFF_FFFF); + result_id.* = try self.spv.constUndef(wip.scalar_ty_ref); + continue; + } + + const index = elem.toSignedInt(mod); + if (index >= 0) { + result_id.* = try self.extractField(wip.scalar_ty, a, @intCast(index)); } else { - const int = elem.toSignedInt(mod); - const unsigned = if (int >= 0) @as(u32, @intCast(int)) else @as(u32, @intCast(~int + a_len)); - self.func.body.writeOperand(spec.LiteralInteger, unsigned); + result_id.* = try self.extractField(wip.scalar_ty, b, @intCast(~index)); } } - return result_id; + return try wip.finalize(); } fn indicesToIds(self: *DeclGen, indices: []const u32) ![]IdRef { diff --git a/test/behavior/abs.zig b/test/behavior/abs.zig index fad29a1a58..d8666405a0 100644 --- a/test/behavior/abs.zig +++ b/test/behavior/abs.zig @@ -224,7 +224,6 @@ test "@abs unsigned int vectors" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try comptime testAbsUnsignedIntVectors(1); try testAbsUnsignedIntVectors(1); diff --git a/test/behavior/cast.zig b/test/behavior/cast.zig index be25bde693..48feb86ef1 100644 --- a/test/behavior/cast.zig +++ b/test/behavior/cast.zig @@ -605,7 +605,6 @@ test "@intCast on vector" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { fn doTheTest() !void { @@ -2508,7 +2507,6 @@ test "@intCast vector of signed integer" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO diff --git a/test/behavior/shuffle.zig b/test/behavior/shuffle.zig index e9d7706ff4..95913be3af 100644 --- a/test/behavior/shuffle.zig +++ b/test/behavior/shuffle.zig @@ -8,7 +8,6 @@ test "@shuffle int" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { fn doTheTest() !void { @@ -54,7 +53,6 @@ test "@shuffle bool 1" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const S = struct { fn doTheTest() !void { @@ -77,7 +75,6 @@ test "@shuffle bool 2" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_llvm) { // https://github.com/ziglang/zig/issues/3246 diff --git a/test/behavior/vector.zig b/test/behavior/vector.zig index b23eac924d..26d60c337a 100644 --- a/test/behavior/vector.zig +++ b/test/behavior/vector.zig @@ -910,7 +910,6 @@ test "mask parameter of @shuffle is comptime scope" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const __v4hi = @Vector(4, i16); var v4_a = __v4hi{ 0, 0, 0, 0 }; @@ -1322,7 +1321,6 @@ test "array operands to shuffle are coerced to vectors" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; const mask = [5]i32{ -1, 0, 1, 2, 3 };