Merge pull request #11605 from Luukdegram/wasm-mul-overflow

stage2: wasm - Improve `@mulWithOverflow` implementation
This commit is contained in:
Jakub Konka 2022-05-07 23:30:08 +02:00 committed by GitHub
commit f161d3875a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 186 additions and 130 deletions

View File

@ -1424,7 +1424,7 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue {
.add_with_overflow => self.airBinOpOverflow(inst, .add),
.sub_with_overflow => self.airBinOpOverflow(inst, .sub),
.shl_with_overflow => self.airBinOpOverflow(inst, .shl),
.mul_with_overflow => self.airBinOpOverflow(inst, .mul),
.mul_with_overflow => self.airMulWithOverflow(inst),
.clz => self.airClz(inst),
.ctz => self.airCtz(inst),
@ -1822,7 +1822,7 @@ fn store(self: *Self, lhs: WValue, rhs: WValue, ty: Type, offset: u32) InnerErro
const opcode = buildOpcode(.{
.valtype1 = valtype,
.width = abi_size * 8, // use bitsize instead of byte size
.width = abi_size * 8,
.op = .store,
});
@ -1852,21 +1852,13 @@ fn airLoad(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
fn load(self: *Self, operand: WValue, ty: Type, offset: u32) InnerError!WValue {
// load local's value from memory by its stack position
try self.emitWValue(operand);
// Build the opcode with the right bitsize
const signedness: std.builtin.Signedness = if (ty.isUnsignedInt() or
ty.zigTypeTag() == .ErrorSet or
ty.zigTypeTag() == .Bool)
.unsigned
else
.signed;
const abi_size = @intCast(u8, ty.abiSize(self.target));
const opcode = buildOpcode(.{
.valtype1 = typeToValtype(ty, self.target),
.width = abi_size * 8, // use bitsize instead of byte size
.width = abi_size * 8,
.op = .load,
.signedness = signedness,
.signedness = .unsigned,
});
try self.addMemArg(
@ -1935,7 +1927,14 @@ fn airWrapBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
const lhs = try self.resolveInst(bin_op.lhs);
const rhs = try self.resolveInst(bin_op.rhs);
return self.wrapBinOp(lhs, rhs, self.air.typeOf(bin_op.lhs), op);
const ty = self.air.typeOf(bin_op.lhs);
if (ty.zigTypeTag() == .Vector) {
return self.fail("TODO: Implement wrapping arithmetic for vectors", .{});
} else if (ty.abiSize(self.target) > 8) {
return self.fail("TODO: Implement wrapping arithmetic for bitsize > 64", .{});
}
return self.wrapBinOp(lhs, rhs, ty, op);
}
fn wrapBinOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WValue {
@ -1948,38 +1947,28 @@ fn wrapBinOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError
.signedness = if (ty.isSignedInt()) .signed else .unsigned,
});
try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
const int_info = ty.intInfo(self.target);
const bitsize = int_info.bits;
const is_signed = int_info.signedness == .signed;
// if target type bitsize is x < 32 and 32 > x < 64, we perform
// result & ((1<<N)-1) where N = bitsize or bitsize -1 incase of signed.
if (bitsize != 32 and bitsize < 64) {
// first check if we can use a single instruction,
// wasm provides those if the integers are signed and 8/16-bit.
// For arbitrary integer sizes, we use the algorithm mentioned above.
if (is_signed and bitsize == 8) {
try self.addTag(.i32_extend8_s);
} else if (is_signed and bitsize == 16) {
try self.addTag(.i32_extend16_s);
} else {
const result = (@as(u64, 1) << @intCast(u6, bitsize - @boolToInt(is_signed))) - 1;
if (bitsize < 32) {
try self.addImm32(@bitCast(i32, @intCast(u32, result)));
try self.addTag(.i32_and);
} else {
try self.addImm64(result);
try self.addTag(.i64_and);
}
}
} else if (int_info.bits > 64) {
return self.fail("TODO wasm: Integer wrapping for bitsizes larger than 64", .{});
}
// save the result in a temporary
const bin_local = try self.allocLocal(ty);
try self.addLabel(.local_set, bin_local.local);
return bin_local;
return self.wrapOperand(bin_local, ty);
}
/// Wraps an operand to its type's bit width by emitting a bitwise AND
/// against the mask `(1 << bitsize) - 1`, storing the wrapped value in a
/// fresh local which is returned to the caller.
/// Asserts the type's ABI size is <= 8 bytes (i.e. fits a wasm i32/i64).
fn wrapOperand(self: *Self, operand: WValue, ty: Type) InnerError!WValue {
assert(ty.abiSize(self.target) <= 8);
const result_local = try self.allocLocal(ty);
const bitsize = ty.intInfo(self.target).bits;
// Compute the mask in u65 so `1 << 64` does not overflow before the -1;
// for bitsize == 64 this yields maxInt(u64) and the AND is a no-op.
const result = @intCast(u64, (@as(u65, 1) << @intCast(u7, bitsize)) - 1);
try self.emitWValue(operand);
// Pick the 32- or 64-bit AND depending on which wasm value type backs `ty`.
if (bitsize <= 32) {
try self.addImm32(@bitCast(i32, @intCast(u32, result)));
try self.addTag(.i32_and);
} else {
try self.addImm64(result);
try self.addTag(.i64_and);
}
try self.addLabel(.local_set, result_local.local);
return result_local;
}
fn lowerParentPtr(self: *Self, ptr_val: Value, ptr_child_ty: Type) InnerError!WValue {
@ -2098,6 +2087,22 @@ fn lowerDeclRefValue(self: *Self, tv: TypedValue, decl_index: Module.Decl.Index)
} else return WValue{ .memory = target_sym_index };
}
/// Converts a signed integer to its 2's complement form and returns
/// an unsigned integer of the same bit width, masked to `bits` bits.
/// Asserts bits <= 64 and that the input type is a comptime-known
/// signed integer type.
/// NOTE(review): negating `value` traps for `minInt` of the input type in
/// safe builds — callers presumably never pass that value; confirm.
fn toTwosComplement(value: anytype, bits: u7) std.meta.Int(.unsigned, @typeInfo(@TypeOf(value)).Int.bits) {
const T = @TypeOf(value);
comptime assert(@typeInfo(T) == .Int);
comptime assert(@typeInfo(T).Int.signedness == .signed);
assert(bits <= 64);
const WantedT = std.meta.Int(.unsigned, @typeInfo(T).Int.bits);
// Non-negative values already share their bit pattern with the unsigned form.
if (value >= 0) return @bitCast(WantedT, value);
// Mask of the low `bits` bits, computed in u65 so `1 << 64` cannot overflow.
const max_value = @intCast(u64, (@as(u65, 1) << bits) - 1);
// Two's-complement negate the magnitude: ~(-value) + 1 reproduces the
// bit pattern of `value`, which is then truncated to `bits` bits below.
const flipped = (~-value) + 1;
const result = @bitCast(WantedT, flipped) & max_value;
return @intCast(WantedT, result);
}
fn lowerConstant(self: *Self, val: Value, ty: Type) InnerError!WValue {
if (val.isUndefDeep()) return self.emitUndefined(ty);
if (val.castTag(.decl_ref)) |decl_ref| {
@ -2114,10 +2119,12 @@ fn lowerConstant(self: *Self, val: Value, ty: Type) InnerError!WValue {
switch (ty.zigTypeTag()) {
.Int => {
const int_info = ty.intInfo(self.target);
// write constant
switch (int_info.signedness) {
.signed => switch (int_info.bits) {
0...32 => return WValue{ .imm32 = @bitCast(u32, @intCast(i32, val.toSignedInt())) },
0...32 => return WValue{ .imm32 = @intCast(u32, toTwosComplement(
val.toSignedInt(),
@intCast(u6, int_info.bits),
)) },
33...64 => return WValue{ .imm64 = @bitCast(u64, val.toSignedInt()) },
else => unreachable,
},
@ -2832,30 +2839,38 @@ fn airIntcast(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
const ty_op = self.air.instructions.items(.data)[inst].ty_op;
const ty = self.air.getRefType(ty_op.ty);
const operand = try self.resolveInst(ty_op.operand);
const ref_ty = self.air.typeOf(ty_op.operand);
const ref_info = ref_ty.intInfo(self.target);
const wanted_info = ty.intInfo(self.target);
const operand_ty = self.air.typeOf(ty_op.operand);
if (ty.abiSize(self.target) > 8 or operand_ty.abiSize(self.target) > 8) {
return self.fail("todo Wasm intcast for bitsize > 64", .{});
}
const op_bits = toWasmBits(ref_info.bits) orelse
return self.fail("TODO: Wasm intcast integer types of bitsize: {d}", .{ref_info.bits});
const wanted_bits = toWasmBits(wanted_info.bits) orelse
return self.fail("TODO: Wasm intcast integer types of bitsize: {d}", .{wanted_info.bits});
return self.intcast(operand, operand_ty, ty);
}
// hot path
/// Upcasts or downcasts an integer based on the given and wanted types,
/// and stores the result in a new operand.
/// Asserts type's bitsize <= 64
fn intcast(self: *Self, operand: WValue, given: Type, wanted: Type) InnerError!WValue {
const given_info = given.intInfo(self.target);
const wanted_info = wanted.intInfo(self.target);
assert(given_info.bits <= 64);
assert(wanted_info.bits <= 64);
const op_bits = toWasmBits(given_info.bits).?;
const wanted_bits = toWasmBits(wanted_info.bits).?;
if (op_bits == wanted_bits) return operand;
try self.emitWValue(operand);
if (op_bits > 32 and wanted_bits == 32) {
try self.emitWValue(operand);
try self.addTag(.i32_wrap_i64);
} else if (op_bits == 32 and wanted_bits > 32) {
try self.emitWValue(operand);
try self.addTag(switch (ref_info.signedness) {
try self.addTag(switch (wanted_info.signedness) {
.signed => .i64_extend_i32_s,
.unsigned => .i64_extend_i32_u,
});
} else unreachable;
const result = try self.allocLocal(ty);
const result = try self.allocLocal(wanted);
try self.addLabel(.local_set, result.local);
return result;
}
@ -3072,63 +3087,17 @@ fn airSlicePtr(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
}
fn airTrunc(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
if (self.liveness.isUnused(inst)) return WValue.none;
if (self.liveness.isUnused(inst)) return WValue{ .none = {} };
const ty_op = self.air.instructions.items(.data)[inst].ty_op;
const operand = try self.resolveInst(ty_op.operand);
const op_ty = self.air.typeOf(ty_op.operand);
const int_info = self.air.getRefType(ty_op.ty).intInfo(self.target);
const wanted_ty = self.air.getRefType(ty_op.ty);
const int_info = wanted_ty.intInfo(self.target);
const wanted_bits = int_info.bits;
const result = try self.allocLocal(self.air.getRefType(ty_op.ty));
const op_bits = op_ty.intInfo(self.target).bits;
const wasm_bits = toWasmBits(wanted_bits) orelse
_ = toWasmBits(wanted_bits) orelse {
return self.fail("TODO: Implement wasm integer truncation for integer bitsize: {d}", .{wanted_bits});
// Use wasm's instruction to wrap from 64bit to 32bit integer when possible
if (op_bits == 64 and wanted_bits == 32) {
try self.emitWValue(operand);
try self.addTag(.i32_wrap_i64);
try self.addLabel(.local_set, result.local);
return result;
}
// Any other truncation must be done manually
if (int_info.signedness == .unsigned) {
const mask = (@as(u65, 1) << @intCast(u7, wanted_bits)) - 1;
try self.emitWValue(operand);
switch (wasm_bits) {
32 => {
try self.addImm32(@bitCast(i32, @intCast(u32, mask)));
try self.addTag(.i32_and);
},
64 => {
try self.addImm64(@intCast(u64, mask));
try self.addTag(.i64_and);
},
else => unreachable,
}
} else {
const shift_bits = wasm_bits - wanted_bits;
try self.emitWValue(operand);
switch (wasm_bits) {
32 => {
try self.addImm32(@bitCast(i16, shift_bits));
try self.addTag(.i32_shl);
try self.addImm32(@bitCast(i16, shift_bits));
try self.addTag(.i32_shr_s);
},
64 => {
try self.addImm64(shift_bits);
try self.addTag(.i64_shl);
try self.addImm64(shift_bits);
try self.addTag(.i64_shr_s);
},
else => unreachable,
}
}
try self.addLabel(.local_set, result.local);
return result;
};
return self.wrapOperand(operand, wanted_ty);
}
fn airBoolToInt(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
@ -3418,7 +3387,8 @@ fn airFloatToInt(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
const result = try self.allocLocal(dest_ty);
try self.addLabel(.local_set, result.local);
return result;
return self.wrapOperand(result, dest_ty);
}
fn airIntToFloat(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
@ -3922,6 +3892,10 @@ fn airBinOpOverflow(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue
const rhs = try self.resolveInst(extra.rhs);
const lhs_ty = self.air.typeOf(extra.lhs);
if (lhs_ty.zigTypeTag() == .Vector) {
return self.fail("TODO: Implement overflow arithmetic for vectors", .{});
}
// We store the bit if it's overflowed or not in this. As it's zero-initialized
// we only need to update it if an overflow (or underflow) occured.
const overflow_bit = try self.allocLocal(Type.initTag(.u1));
@ -4008,23 +3982,6 @@ fn airBinOpOverflow(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue
}
try self.addLabel(.local_set, tmp_val.local);
break :blk tmp_val;
} else if (op == .mul) blk: {
const bin_op = try self.wrapBinOp(lhs, rhs, lhs_ty, op);
try self.startBlock(.block, wasm.block_empty);
// check if 0. true => Break out of block as cannot over -or underflow.
try self.emitWValue(lhs);
switch (wasm_bits) {
32 => try self.addTag(.i32_eqz),
64 => try self.addTag(.i64_eqz),
else => unreachable,
}
try self.addLabel(.br_if, 0);
const div = try self.binOp(bin_op, lhs, lhs_ty, .div);
const cmp_res = try self.cmp(div, rhs, lhs_ty, .neq);
try self.emitWValue(cmp_res);
try self.addLabel(.local_set, overflow_bit.local);
try self.endBlock();
break :blk bin_op;
} else try self.wrapBinOp(lhs, rhs, lhs_ty, op);
const result_ptr = try self.allocStack(self.air.typeOfIndex(inst));
@ -4035,6 +3992,99 @@ fn airBinOpOverflow(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue
return result_ptr;
}
/// Lowers `@mulWithOverflow` for integers of bit width <= 32 (wasm i32).
/// Computes the wrapped product and an overflow bit, stores both into a
/// stack-allocated result tuple, and returns a pointer to it.
fn airMulWithOverflow(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
const extra = self.air.extraData(Air.Bin, ty_pl.payload).data;
const lhs = try self.resolveInst(extra.lhs);
const rhs = try self.resolveInst(extra.rhs);
const lhs_ty = self.air.typeOf(extra.lhs);
if (lhs_ty.zigTypeTag() == .Vector) {
return self.fail("TODO: Implement overflow arithmetic for vectors", .{});
}
// We store the bit if it's overflowed or not in this. As it's zero-initialized
// we only need to update it if an overflow (or underflow) occurred.
const overflow_bit = try self.allocLocal(Type.initTag(.u1));
const int_info = lhs_ty.intInfo(self.target);
const wasm_bits = toWasmBits(int_info.bits) orelse {
return self.fail("TODO: Implement overflow arithmetic for integer bitsize: {d}", .{int_info.bits});
};
// 33..64-bit integers would need a 128-bit intermediate product; not yet done.
if (wasm_bits == 64) {
return self.fail("TODO: Implement `@mulWithOverflow` for integer bitsize: {d}", .{int_info.bits});
}
// Zero constant matching the operand's wasm value type, used for comparisons.
// (The 64 arm is unreachable in practice given the bail-out above.)
const zero = switch (wasm_bits) {
32 => WValue{ .imm32 = 0 },
64 => WValue{ .imm64 = 0 },
else => unreachable,
};
// for 32 bit integers we upcast it to a 64bit integer
const bin_op = if (int_info.bits == 32) blk: {
// Widen to 64 bits, multiply, then inspect the high 32 bits of the
// product to detect overflow.
const new_ty = if (int_info.signedness == .signed) Type.i64 else Type.u64;
const lhs_upcast = try self.intcast(lhs, lhs_ty, new_ty);
const rhs_upcast = try self.intcast(rhs, lhs_ty, new_ty);
const bin_op = try self.binOp(lhs_upcast, rhs_upcast, new_ty, .mul);
if (int_info.signedness == .unsigned) {
// Unsigned: any nonzero high half means the product overflowed.
const shr = try self.binOp(bin_op, .{ .imm64 = int_info.bits }, new_ty, .shr);
const wrap = try self.intcast(shr, new_ty, lhs_ty);
const cmp_res = try self.cmp(wrap, zero, lhs_ty, .neq);
try self.emitWValue(cmp_res);
try self.addLabel(.local_set, overflow_bit.local);
break :blk try self.intcast(bin_op, new_ty, lhs_ty);
} else {
// Signed: the high half must equal the sign-extension of the low
// half (its sign bit replicated), otherwise the product overflowed.
// NOTE(review): this relies on binOp emitting an arithmetic shift
// for signed types — confirm against binOp's opcode selection.
const down_cast = try self.intcast(bin_op, new_ty, lhs_ty);
const shr = try self.binOp(down_cast, .{ .imm32 = int_info.bits - 1 }, lhs_ty, .shr);
const shr_res = try self.binOp(bin_op, .{ .imm64 = int_info.bits }, new_ty, .shr);
const down_shr_res = try self.intcast(shr_res, new_ty, lhs_ty);
const cmp_res = try self.cmp(down_shr_res, shr, lhs_ty, .neq);
try self.emitWValue(cmp_res);
try self.addLabel(.local_set, overflow_bit.local);
break :blk down_cast;
}
} else if (int_info.signedness == .signed) blk: {
// Signed, width < 32: sign-extend both operands into the full wasm
// register via shl+shr, multiply, then check whether re-extending the
// product changes it (if so, it overflowed the narrow type).
const shift_imm = if (wasm_bits == 32)
WValue{ .imm32 = wasm_bits - int_info.bits }
else
WValue{ .imm64 = wasm_bits - int_info.bits };
const lhs_shl = try self.binOp(lhs, shift_imm, lhs_ty, .shl);
const lhs_shr = try self.binOp(lhs_shl, shift_imm, lhs_ty, .shr);
const rhs_shl = try self.binOp(rhs, shift_imm, lhs_ty, .shl);
const rhs_shr = try self.binOp(rhs_shl, shift_imm, lhs_ty, .shr);
const bin_op = try self.binOp(lhs_shr, rhs_shr, lhs_ty, .mul);
const shl = try self.binOp(bin_op, shift_imm, lhs_ty, .shl);
const shr = try self.binOp(shl, shift_imm, lhs_ty, .shr);
const cmp_op = try self.cmp(shr, bin_op, lhs_ty, .neq);
try self.emitWValue(cmp_op);
try self.addLabel(.local_set, overflow_bit.local);
break :blk try self.wrapOperand(bin_op, lhs_ty);
} else blk: {
// Unsigned, width < 32: multiply in the full register; any bits at or
// above `int_info.bits` in the product indicate overflow.
const bin_op = try self.binOp(lhs, rhs, lhs_ty, .mul);
const shift_imm = if (wasm_bits == 32)
WValue{ .imm32 = int_info.bits }
else
WValue{ .imm64 = int_info.bits };
const shr = try self.binOp(bin_op, shift_imm, lhs_ty, .shr);
const cmp_op = try self.cmp(shr, zero, lhs_ty, .neq);
try self.emitWValue(cmp_op);
try self.addLabel(.local_set, overflow_bit.local);
break :blk try self.wrapOperand(bin_op, lhs_ty);
};
// Result tuple layout: wrapped product at offset 0, overflow u1 right after.
const result_ptr = try self.allocStack(self.air.typeOfIndex(inst));
try self.store(result_ptr, bin_op, lhs_ty, 0);
const offset = @intCast(u32, lhs_ty.abiSize(self.target));
try self.store(result_ptr, overflow_bit, Type.initTag(.u1), offset);
return result_ptr;
}
fn airMaxMin(self: *Self, inst: Air.Inst.Index, op: enum { max, min }) InnerError!WValue {
if (self.liveness.isUnused(inst)) return WValue{ .none = {} };
const bin_op = self.air.instructions.items(.data)[inst].bin_op;

View File

@ -6005,6 +6005,7 @@ pub const Type = extern union {
pub const @"u64" = initTag(.u64);
pub const @"i32" = initTag(.i32);
pub const @"i64" = initTag(.i64);
pub const @"f16" = initTag(.f16);
pub const @"f32" = initTag(.f32);

View File

@ -687,7 +687,6 @@ test "basic @mulWithOverflow" {
test "extensive @mulWithOverflow" {
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
{
var a: u5 = 3;
@ -833,6 +832,12 @@ test "extensive @mulWithOverflow" {
try expect(@mulWithOverflow(i32, a, b, &res));
try expect(res == 0x7fffffff);
}
}
test "@mulWithOverflow bitsize > 32" {
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
{
var a: u62 = 3;

View File

@ -1,6 +1,6 @@
// Wrapping multiplication on a narrow signed int: in i4 (range -8..7),
// 3 *% 3 = 9, which wraps to 9 - 16 = -7.
pub fn main() void {
var i: i4 = 3;
// The stale pre-fix assertion (`!= 1`) was left interleaved with the
// corrected one; it always triggers `unreachable` since the wrapped
// result is -7. Only the corrected check is kept.
if (i *% 3 != -7) unreachable;
return;
}

View File

@ -1,6 +1,6 @@
// Wrapping addition on a narrow signed int: in i4 (range -8..7),
// 7 +% 1 = 8, which wraps to 8 - 16 = -8.
pub fn main() void {
var i: i4 = 7;
// The stale pre-fix assertion (`!= 0`) was left interleaved with the
// corrected one; it always triggers `unreachable` since the wrapped
// result is -8. Only the corrected check is kept.
if (i +% 1 != -8) unreachable;
return;
}