From 76d5696434095e39d9aaae92c1533b2d016c1a31 Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Sun, 21 Jan 2024 22:24:53 +0100 Subject: [PATCH] spirv: air abs --- src/codegen/spirv.zig | 71 +++++++++++++++++++++++++++++++++++++++ test/behavior/abs.zig | 7 ++-- test/behavior/cast.zig | 1 - test/behavior/floatop.zig | 3 -- test/behavior/math.zig | 1 - 5 files changed, 73 insertions(+), 10 deletions(-) diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 28f2c1677c..9e7c49d1a5 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -678,6 +678,18 @@ const DeclGen = struct { } } + /// Emits a float constant + fn constFloat(self: *DeclGen, ty_ref: CacheRef, value: f128) !IdRef { + const ty = self.spv.cache.lookup(ty_ref).float_type; + return switch (ty.bits) { + 16 => try self.spv.resolveId(.{ .float = .{ .ty = ty_ref, .value = .{ .float16 = @floatCast(value) } } }), + 32 => try self.spv.resolveId(.{ .float = .{ .ty = ty_ref, .value = .{ .float32 = @floatCast(value) } } }), + 64 => try self.spv.resolveId(.{ .float = .{ .ty = ty_ref, .value = .{ .float64 = @floatCast(value) } } }), + 80, 128 => unreachable, // TODO + else => unreachable, + }; + } + /// Construct a struct at runtime. /// ty must be a struct type. /// Constituents should be in `indirect` representation (as the elements of a struct should be). @@ -2164,6 +2176,8 @@ const DeclGen = struct { .sub, .sub_wrap, .sub_optimized => try self.airArithOp(inst, .OpFSub, .OpISub, .OpISub), .mul, .mul_wrap, .mul_optimized => try self.airArithOp(inst, .OpFMul, .OpIMul, .OpIMul), + .abs => try self.airAbs(inst), + .div_float, .div_float_optimized, // TODO: Check that this is the right operation. @@ -2562,6 +2576,63 @@ const DeclGen = struct { return try wip.finalize(); } + fn airAbs(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { + if (self.liveness.isUnused(inst)) return null; + + const mod = self.module; + const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; + const operand_id = try self.resolve(ty_op.operand); + // Note: operand_ty may be signed, while ty is always unsigned! + const operand_ty = self.typeOf(ty_op.operand); + const ty = self.typeOfIndex(inst); + const info = self.arithmeticTypeInfo(ty); + const operand_scalar_ty = operand_ty.scalarType(mod); + const operand_scalar_ty_ref = try self.resolveType(operand_scalar_ty, .direct); + + var wip = try self.elementWise(ty); + defer wip.deinit(); + + const zero_id = switch (info.class) { + .float => try self.constFloat(operand_scalar_ty_ref, 0), + .integer, .strange_integer => try self.constInt(operand_scalar_ty_ref, 0), + .composite_integer => unreachable, // TODO + .bool => unreachable, + }; + for (wip.results, 0..) |*result_id, i| { + const elem_id = try wip.elementAt(operand_ty, operand_id, i); + // Idk why spir-v doesn't have a dedicated abs() instruction in the base + // instruction set. For now we're just going to negate and check to avoid + // importing the extinst. + const neg_id = self.spv.allocId(); + const args = .{ + .id_result_type = self.typeId(operand_scalar_ty_ref), + .id_result = neg_id, + .operand_1 = zero_id, + .operand_2 = elem_id, + }; + switch (info.class) { + .float => try self.func.body.emit(self.spv.gpa, .OpFSub, args), + .integer, .strange_integer => try self.func.body.emit(self.spv.gpa, .OpISub, args), + .composite_integer => unreachable, // TODO + .bool => unreachable, + } + const neg_norm_id = try self.normalize(wip.scalar_ty_ref, neg_id, info); + + const gt_zero_id = try self.cmp(.gt, Type.bool, operand_scalar_ty, elem_id, zero_id); + const abs_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpSelect, .{ + .id_result_type = self.typeId(operand_scalar_ty_ref), + .id_result = abs_id, + .condition = gt_zero_id, + .object_1 = elem_id, + .object_2 = neg_norm_id, + }); + // For Shader, we may need to cast from signed to unsigned here. + result_id.* = try self.bitCast(wip.scalar_ty, operand_scalar_ty, abs_id); + } + return try wip.finalize(); + } + fn airAddSubOverflow( self: *DeclGen, inst: Air.Inst.Index, diff --git a/test/behavior/abs.zig b/test/behavior/abs.zig index d8666405a0..abea715ea0 100644 --- a/test/behavior/abs.zig +++ b/test/behavior/abs.zig @@ -7,7 +7,6 @@ test "@abs integers" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try comptime testAbsIntegers(); try testAbsIntegers(); @@ -95,7 +94,6 @@ test "@abs floats" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf) return error.SkipZigTest; try comptime testAbsFloats(f16); @@ -105,9 +103,9 @@ test "@abs floats" { try comptime testAbsFloats(f64); try testAbsFloats(f64); try comptime testAbsFloats(f80); - if (builtin.zig_backend != .stage2_wasm) try testAbsFloats(f80); + if (builtin.zig_backend != .stage2_wasm and builtin.zig_backend != .stage2_spirv64) try testAbsFloats(f80); try comptime testAbsFloats(f128); - if (builtin.zig_backend != .stage2_wasm) try testAbsFloats(f128); + if (builtin.zig_backend != .stage2_wasm and builtin.zig_backend != .stage2_spirv64) try testAbsFloats(f128); } fn testAbsFloats(comptime T: type) !void { @@ -155,7 +153,6 @@ test "@abs int vectors" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try comptime testAbsIntVectors(1); try testAbsIntVectors(1); diff --git a/test/behavior/cast.zig b/test/behavior/cast.zig index 48feb86ef1..c59a9803c0 100644 --- a/test/behavior/cast.zig +++ b/test/behavior/cast.zig @@ -2465,7 +2465,6 @@ test "@as does not corrupt values with incompatible representations" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf) return error.SkipZigTest; const x: f32 = @as(f16, blk: { diff --git a/test/behavior/floatop.zig b/test/behavior/floatop.zig index 568fe6deef..43654bc2b9 100644 --- a/test/behavior/floatop.zig +++ b/test/behavior/floatop.zig @@ -969,7 +969,6 @@ test "@abs f16" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf) return error.SkipZigTest; - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try testFabs(f16); try comptime testFabs(f16); @@ -979,7 +978,6 @@ test "@abs f32/f64" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try testFabs(f32); try comptime testFabs(f32); @@ -1070,7 +1068,6 @@ fn testFabs(comptime T: type) !void { test "@abs with vectors" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO try testFabsWithVectors(); diff --git a/test/behavior/math.zig b/test/behavior/math.zig index 3aa65dddbb..dc4d5f894a 100644 --- a/test/behavior/math.zig +++ b/test/behavior/math.zig @@ -1687,7 +1687,6 @@ test "absFloat" { if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try testAbsFloat(); try comptime testAbsFloat();