From ce5d934f5f713a0dbc8787d9ffe58b4962042f8f Mon Sep 17 00:00:00 2001 From: Luuk de Gram Date: Wed, 15 Jun 2022 22:03:18 +0200 Subject: [PATCH] wasm: saturating add and sub for signed integers --- src/arch/wasm/CodeGen.zig | 85 ++++++++++++++++++++++++++++++--------- 1 file changed, 67 insertions(+), 18 deletions(-) diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig index 2e1d88083c..0a93db7fb5 100644 --- a/src/arch/wasm/CodeGen.zig +++ b/src/arch/wasm/CodeGen.zig @@ -4885,8 +4885,8 @@ fn airSatBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue { const bin_op = self.air.instructions.items(.data)[inst].bin_op; const ty = self.air.typeOfIndex(inst); - const lhs_operand = try self.resolveInst(bin_op.lhs); - const rhs_operand = try self.resolveInst(bin_op.rhs); + const lhs = try self.resolveInst(bin_op.lhs); + const rhs = try self.resolveInst(bin_op.rhs); const int_info = ty.intInfo(self.target); const is_signed = int_info.signedness == .signed; @@ -4895,22 +4895,12 @@ fn airSatBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue { return self.fail("TODO: saturating arithmetic for integers with bitsize '{d}'", .{int_info.bits}); } + if (is_signed) { + return signedSat(self, lhs, rhs, ty, op); + } + const wasm_bits = toWasmBits(int_info.bits).?; - - const lhs = if (is_signed) blk: { - break :blk try self.signAbsValue(lhs_operand, ty); - } else lhs_operand; - const rhs = if (is_signed) blk: { - break :blk try self.signAbsValue(rhs_operand, ty); - } else rhs_operand; - - const opcode = buildOpcode(.{ .op = op, .valtype1 = typeToValtype(ty, self.target) }); - try self.emitWValue(lhs); - try self.emitWValue(rhs); - try self.addTag(Mir.Inst.Tag.fromOpcode(opcode)); - const bin_result = try self.allocLocal(ty); - try self.addLabel(.local_set, bin_result.local); - + const bin_result = try self.binOp(lhs, rhs, ty, op); if (wasm_bits != int_info.bits and op == .add) { const val: u64 = @intCast(u64, (@as(u65, 1) << @intCast(u7, int_info.bits)) - 1); const imm_val = switch (wasm_bits) { @@ -4919,7 +4909,7 @@ fn airSatBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue { else => unreachable, }; - const cmp_result = try self.cmp(bin_result, imm_val, ty, if (op == .add) .lt else .gt); + const cmp_result = try self.cmp(bin_result, imm_val, ty, .lt); try self.emitWValue(bin_result); try self.emitWValue(imm_val); try self.emitWValue(cmp_result); @@ -4939,3 +4929,62 @@ fn airSatBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue { try self.addLabel(.local_set, result.local); return result; } + +fn signedSat(self: *Self, lhs_operand: WValue, rhs_operand: WValue, ty: Type, op: Op) InnerError!WValue { + const int_info = ty.intInfo(self.target); + const wasm_bits = toWasmBits(int_info.bits).?; + const is_wasm_bits = wasm_bits == int_info.bits; + + const lhs = if (!is_wasm_bits) try self.signAbsValue(lhs_operand, ty) else lhs_operand; + const rhs = if (!is_wasm_bits) try self.signAbsValue(rhs_operand, ty) else rhs_operand; + + const max_val: u64 = @intCast(u64, (@as(u65, 1) << @intCast(u7, int_info.bits - 1)) - 1); + const min_val = @intCast(i64, ~@intCast(u63, max_val)); + const max_wvalue = switch (wasm_bits) { + 32 => WValue{ .imm32 = @intCast(u32, max_val) }, + 64 => WValue{ .imm64 = max_val }, + else => unreachable, + }; + const min_wvalue = switch (wasm_bits) { + 32 => WValue{ .imm32 = @bitCast(u32, @truncate(i32, min_val)) }, + 64 => WValue{ .imm64 = @bitCast(u64, min_val) }, + else => unreachable, + }; + + const bin_result = try self.binOp(lhs, rhs, ty, op); + if (!is_wasm_bits) { + const cmp_result_lt = try self.cmp(bin_result, max_wvalue, ty, .lt); + try self.emitWValue(bin_result); + try self.emitWValue(max_wvalue); + try self.emitWValue(cmp_result_lt); + try self.addTag(.select); + try self.addLabel(.local_set, bin_result.local); // re-use local + + const cmp_result_gt = try self.cmp(bin_result, min_wvalue, ty, .gt); + try self.emitWValue(bin_result); + try self.emitWValue(min_wvalue); + try self.emitWValue(cmp_result_gt); + try self.addTag(.select); + try self.addLabel(.local_set, bin_result.local); // re-use local + return self.wrapOperand(bin_result, ty); + } else { + const zero = switch (wasm_bits) { + 32 => WValue{ .imm32 = 0 }, + 64 => WValue{ .imm64 = 0 }, + else => unreachable, + }; + const cmp_bin_result = try self.cmp(bin_result, lhs, ty, .lt); + const cmp_zero_result = try self.cmp(rhs, zero, ty, if (op == .add) .lt else .gt); + const xor = try self.binOp(cmp_zero_result, cmp_bin_result, ty, .xor); + const cmp_bin_zero_result = try self.cmp(bin_result, zero, ty, .lt); + try self.emitWValue(max_wvalue); + try self.emitWValue(min_wvalue); + try self.emitWValue(cmp_bin_zero_result); + try self.addTag(.select); + try self.emitWValue(bin_result); + try self.emitWValue(xor); + try self.addTag(.select); + try self.addLabel(.local_set, bin_result.local); // re-use local + return bin_result; + } +}