wasm: keep result of cmp on the stack

By keeping the result on the stack, we prevent codegen
from generating unneccesary locals when we have subsequent instructions
that do not have to be re-used.
This commit is contained in:
Luuk de Gram 2022-07-25 07:10:56 +02:00
parent cc6f2b67c6
commit 305b113a53
No known key found for this signature in database
GPG Key ID: A8CFE58E4DC7D664

View File

@ -2659,10 +2659,14 @@ fn airCmp(self: *Self, inst: Air.Inst.Index, op: std.math.CompareOperator) Inner
const lhs = try self.resolveInst(bin_op.lhs);
const rhs = try self.resolveInst(bin_op.rhs);
const operand_ty = self.air.typeOf(bin_op.lhs);
return self.cmp(lhs, rhs, operand_ty, op);
return (try self.cmp(lhs, rhs, operand_ty, op)).toLocal(self, Type.u32); // comparison result is always 32 bits
}
/// Compares two operands.
/// Asserts rhs is not a stack value when the lhs isn't a stack value either
/// NOTE: This leaves the result on top of the stack, rather than a new local.
fn cmp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: std.math.CompareOperator) InnerError!WValue {
assert(!(lhs != .stack and rhs == .stack));
if (ty.zigTypeTag() == .Optional and !ty.optionalReprIsPayload()) {
var buf: Type.Payload.ElemType = undefined;
const payload_ty = ty.optionalChild(&buf);
@ -2704,9 +2708,7 @@ fn cmp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: std.math.CompareOper
});
try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
const cmp_tmp = try self.allocLocal(Type.initTag(.i32)); // bool is always i32
try self.addLabel(.local_set, cmp_tmp.local);
return cmp_tmp;
return WValue{ .stack = {} };
}
fn cmpFloat16(self: *Self, lhs: WValue, rhs: WValue, op: std.math.CompareOperator) InnerError!WValue {
@ -2729,9 +2731,7 @@ fn cmpFloat16(self: *Self, lhs: WValue, rhs: WValue, op: std.math.CompareOperato
try self.emitWValue(ext_rhs);
try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
const result = try self.allocLocal(Type.initTag(.i32)); // bool is always i32
try self.addLabel(.local_set, result.local);
return result;
return WValue{ .stack = {} };
}
fn airCmpVector(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
@ -3982,13 +3982,16 @@ fn cmpOptionals(self: *Self, lhs: WValue, rhs: WValue, operand_ty: Type, op: std
try self.addImm32(0);
try self.addTag(if (op == .eq) .i32_ne else .i32_eq);
try self.addLabel(.local_set, result.local);
return result;
try self.emitWValue(result);
return WValue{ .stack = {} };
}
/// Compares big integers by checking both its high bits and low bits.
/// NOTE: Leaves the result of the comparison on top of the stack.
/// TODO: Lower this to compiler_rt call when bitsize > 128
fn cmpBigInt(self: *Self, lhs: WValue, rhs: WValue, operand_ty: Type, op: std.math.CompareOperator) InnerError!WValue {
assert(operand_ty.abiSize(self.target) >= 16);
assert(!(lhs != .stack and rhs == .stack));
if (operand_ty.intInfo(self.target).bits > 128) {
return self.fail("TODO: Support cmpBigInt for integer bitsize: '{d}'", .{operand_ty.intInfo(self.target).bits});
}
@ -4012,20 +4015,15 @@ fn cmpBigInt(self: *Self, lhs: WValue, rhs: WValue, operand_ty: Type, op: std.ma
},
else => {
const ty = if (operand_ty.isSignedInt()) Type.i64 else Type.u64;
const high_bit_eql = try self.cmp(lhs_high_bit, rhs_high_bit, ty, .eq);
const high_bit_cmp = try self.cmp(lhs_high_bit, rhs_high_bit, ty, op);
const low_bit_cmp = try self.cmp(lhs_low_bit, rhs_low_bit, ty, op);
try self.emitWValue(low_bit_cmp);
try self.emitWValue(high_bit_cmp);
try self.emitWValue(high_bit_eql);
// leave those value on top of the stack for '.select'
_ = try self.cmp(lhs_low_bit, rhs_low_bit, ty, op);
_ = try self.cmp(lhs_high_bit, rhs_high_bit, ty, op);
_ = try self.cmp(lhs_high_bit, rhs_high_bit, ty, .eq);
try self.addTag(.select);
},
}
const result = try self.allocLocal(Type.initTag(.i32));
try self.addLabel(.local_set, result.local);
return result;
return WValue{ .stack = {} };
}
fn airSetUnionTag(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
@ -4350,7 +4348,7 @@ fn airAddSubWithOverflow(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!W
if (wasm_bits == int_info.bits) {
const cmp_zero = try self.cmp(rhs, zero, lhs_ty, cmp_op);
const lt = try self.cmp(bin_op, lhs, lhs_ty, .lt);
break :blk try (try self.binOp(cmp_zero, lt, Type.u32, .xor)).toLocal(self, Type.u32); // result of cmp_zero and lt is always 32bit
break :blk try self.binOp(cmp_zero, lt, Type.u32, .xor);
}
const abs = try self.signAbsValue(bin_op, lhs_ty);
break :blk try self.cmp(abs, bin_op, lhs_ty, .neq);
@ -4358,11 +4356,12 @@ fn airAddSubWithOverflow(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!W
try self.cmp(bin_op, lhs, lhs_ty, cmp_op)
else
try self.cmp(bin_op, result, lhs_ty, .neq);
const overflow_local = try overflow_bit.toLocal(self, Type.u32);
const result_ptr = try self.allocStack(self.air.typeOfIndex(inst));
try self.store(result_ptr, result, lhs_ty, 0);
const offset = @intCast(u32, lhs_ty.abiSize(self.target));
try self.store(result_ptr, overflow_bit, Type.initTag(.u1), offset);
try self.store(result_ptr, overflow_local, Type.initTag(.u1), offset);
return result_ptr;
}
@ -4384,9 +4383,9 @@ fn airAddSubWithOverflowBigInt(self: *Self, lhs: WValue, rhs: WValue, ty: Type,
const high_op_res = try (try self.binOp(lhs_high_bit, rhs_high_bit, Type.u64, op)).toLocal(self, Type.u64);
const lt = if (op == .add) blk: {
break :blk try self.cmp(high_op_res, lhs_high_bit, Type.u64, .lt);
break :blk try (try self.cmp(high_op_res, lhs_high_bit, Type.u64, .lt)).toLocal(self, Type.u32);
} else if (op == .sub) blk: {
break :blk try self.cmp(lhs_high_bit, rhs_high_bit, Type.u64, .lt);
break :blk try (try self.cmp(lhs_high_bit, rhs_high_bit, Type.u64, .lt)).toLocal(self, Type.u32);
} else unreachable;
const tmp = try self.intcast(lt, Type.u32, Type.u64);
const tmp_op = try (try self.binOp(low_op_res, tmp, Type.u64, op)).toLocal(self, Type.u64);
@ -4400,27 +4399,23 @@ fn airAddSubWithOverflowBigInt(self: *Self, lhs: WValue, rhs: WValue, ty: Type,
const wrap = try self.binOp(to_wrap, xor_op, Type.u64, .@"and");
break :blk try self.cmp(wrap, .{ .imm64 = 0 }, Type.i64, .lt); // i64 because signed
} else blk: {
const eq = try self.cmp(tmp_op, lhs_low_bit, Type.u64, .eq);
const op_eq = try self.cmp(tmp_op, lhs_low_bit, Type.u64, if (op == .add) .lt else .gt);
const first_arg = if (op == .sub) arg: {
break :arg try self.cmp(high_op_res, lhs_high_bit, Type.u64, .gt);
} else lt;
try self.emitWValue(first_arg);
try self.emitWValue(op_eq);
try self.emitWValue(eq);
_ = try self.cmp(tmp_op, lhs_low_bit, Type.u64, if (op == .add) .lt else .gt);
_ = try self.cmp(tmp_op, lhs_low_bit, Type.u64, .eq);
try self.addTag(.select);
const overflow_bit = try self.allocLocal(Type.initTag(.u1));
try self.addLabel(.local_set, overflow_bit.local);
break :blk overflow_bit;
break :blk WValue{ .stack = {} };
};
const overflow_local = try overflow_bit.toLocal(self, Type.initTag(.u1));
const result_ptr = try self.allocStack(result_ty);
try self.store(result_ptr, high_op_res, Type.u64, 0);
try self.store(result_ptr, tmp_op, Type.u64, 8);
try self.store(result_ptr, overflow_bit, Type.initTag(.u1), 16);
try self.store(result_ptr, overflow_local, Type.initTag(.u1), 16);
return result_ptr;
}
@ -4455,11 +4450,12 @@ fn airShlWithOverflow(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
const shr = try (try self.binOp(result, rhs, lhs_ty, .shr)).toLocal(self, lhs_ty);
break :blk try self.cmp(lhs, shr, lhs_ty, .neq);
};
const overflow_local = try overflow_bit.toLocal(self, Type.initTag(.u1));
const result_ptr = try self.allocStack(self.air.typeOfIndex(inst));
try self.store(result_ptr, result, lhs_ty, 0);
const offset = @intCast(u32, lhs_ty.abiSize(self.target));
try self.store(result_ptr, overflow_bit, Type.initTag(.u1), offset);
try self.store(result_ptr, overflow_local, Type.initTag(.u1), offset);
return result_ptr;
}
@ -4502,8 +4498,7 @@ fn airMulWithOverflow(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
if (int_info.signedness == .unsigned) {
const shr = try self.binOp(bin_op, .{ .imm64 = int_info.bits }, new_ty, .shr);
const wrap = try self.intcast(shr, new_ty, lhs_ty);
const cmp_res = try self.cmp(wrap, zero, lhs_ty, .neq);
try self.emitWValue(cmp_res);
_ = try self.cmp(wrap, zero, lhs_ty, .neq);
try self.addLabel(.local_set, overflow_bit.local);
break :blk try self.intcast(bin_op, new_ty, lhs_ty);
} else {
@ -4512,8 +4507,7 @@ fn airMulWithOverflow(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
const shr_res = try self.binOp(bin_op, .{ .imm64 = int_info.bits }, new_ty, .shr);
const down_shr_res = try self.intcast(shr_res, new_ty, lhs_ty);
const cmp_res = try self.cmp(down_shr_res, shr, lhs_ty, .neq);
try self.emitWValue(cmp_res);
_ = try self.cmp(down_shr_res, shr, lhs_ty, .neq);
try self.addLabel(.local_set, overflow_bit.local);
break :blk down_cast;
}
@ -4522,8 +4516,7 @@ fn airMulWithOverflow(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
const rhs_abs = try self.signAbsValue(rhs, lhs_ty);
const bin_op = try (try self.binOp(lhs_abs, rhs_abs, lhs_ty, .mul)).toLocal(self, lhs_ty);
const mul_abs = try self.signAbsValue(bin_op, lhs_ty);
const cmp_op = try self.cmp(mul_abs, bin_op, lhs_ty, .neq);
try self.emitWValue(cmp_op);
_ = try self.cmp(mul_abs, bin_op, lhs_ty, .neq);
try self.addLabel(.local_set, overflow_bit.local);
break :blk try self.wrapOperand(bin_op, lhs_ty);
} else blk: {
@ -4533,8 +4526,7 @@ fn airMulWithOverflow(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
else
WValue{ .imm64 = int_info.bits };
const shr = try self.binOp(bin_op, shift_imm, lhs_ty, .shr);
const cmp_op = try self.cmp(shr, zero, lhs_ty, .neq);
try self.emitWValue(cmp_op);
_ = try self.cmp(shr, zero, lhs_ty, .neq);
try self.addLabel(.local_set, overflow_bit.local);
break :blk try self.wrapOperand(bin_op, lhs_ty);
};
@ -4562,12 +4554,10 @@ fn airMaxMin(self: *Self, inst: Air.Inst.Index, op: enum { max, min }) InnerErro
const lhs = try self.resolveInst(bin_op.lhs);
const rhs = try self.resolveInst(bin_op.rhs);
const cmp_result = try self.cmp(lhs, rhs, ty, if (op == .max) .gt else .lt);
// operands to select from
try self.lowerToStack(lhs);
try self.lowerToStack(rhs);
try self.emitWValue(cmp_result);
_ = try self.cmp(lhs, rhs, ty, if (op == .max) .gt else .lt);
// based on the result from comparison, return operand 0 or 1.
try self.addTag(.select);
@ -4638,7 +4628,6 @@ fn airClz(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
128 => {
const msb = try self.load(operand, Type.u64, 0);
const lsb = try self.load(operand, Type.u64, 8);
const neq = try self.cmp(lsb, .{ .imm64 = 0 }, Type.u64, .neq);
try self.emitWValue(lsb);
try self.addTag(.i64_clz);
@ -4646,7 +4635,7 @@ fn airClz(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
try self.addTag(.i64_clz);
try self.emitWValue(.{ .imm64 = 64 });
try self.addTag(.i64_add);
try self.emitWValue(neq);
_ = try self.cmp(lsb, .{ .imm64 = 0 }, Type.u64, .neq);
try self.addTag(.select);
try self.addTag(.i32_wrap_i64);
},
@ -4700,7 +4689,6 @@ fn airCtz(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
128 => {
const msb = try self.load(operand, Type.u64, 0);
const lsb = try self.load(operand, Type.u64, 8);
const neq = try self.cmp(msb, .{ .imm64 = 0 }, Type.u64, .neq);
try self.emitWValue(msb);
try self.addTag(.i64_ctz);
@ -4716,7 +4704,7 @@ fn airCtz(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
} else {
try self.addTag(.i64_add);
}
try self.emitWValue(neq);
_ = try self.cmp(msb, .{ .imm64 = 0 }, Type.u64, .neq);
try self.addTag(.select);
try self.addTag(.i32_wrap_i64);
},
@ -4952,15 +4940,13 @@ fn airDivFloor(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
64 => WValue{ .imm64 = 0 },
else => unreachable,
};
const lhs_less_than_zero = try self.cmp(lhs_res, zero, ty, .lt);
const rhs_less_than_zero = try self.cmp(rhs_res, zero, ty, .lt);
const div_result = try self.allocLocal(ty);
// leave on stack
_ = try self.binOp(lhs_res, rhs_res, ty, .div);
try self.addLabel(.local_tee, div_result.local);
try self.emitWValue(lhs_less_than_zero);
try self.emitWValue(rhs_less_than_zero);
_ = try self.cmp(lhs_res, zero, ty, .lt);
_ = try self.cmp(rhs_res, zero, ty, .lt);
switch (wasm_bits) {
32 => {
try self.addTag(.i32_xor);
@ -5140,19 +5126,17 @@ fn airSatBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
else => unreachable,
};
const cmp_result = try self.cmp(bin_result, imm_val, ty, .lt);
try self.emitWValue(bin_result);
try self.emitWValue(imm_val);
try self.emitWValue(cmp_result);
_ = try self.cmp(bin_result, imm_val, ty, .lt);
} else {
const cmp_result = try self.cmp(bin_result, lhs, ty, if (op == .add) .lt else .gt);
switch (wasm_bits) {
32 => try self.addImm32(if (op == .add) @as(i32, -1) else 0),
64 => try self.addImm64(if (op == .add) @bitCast(u64, @as(i64, -1)) else 0),
else => unreachable,
}
try self.emitWValue(bin_result);
try self.emitWValue(cmp_result);
_ = try self.cmp(bin_result, lhs, ty, if (op == .add) .lt else .gt);
}
try self.addTag(.select);
@ -5184,17 +5168,15 @@ fn signedSat(self: *Self, lhs_operand: WValue, rhs_operand: WValue, ty: Type, op
const bin_result = try (try self.binOp(lhs, rhs, ty, op)).toLocal(self, ty);
if (!is_wasm_bits) {
const cmp_result_lt = try self.cmp(bin_result, max_wvalue, ty, .lt);
try self.emitWValue(bin_result);
try self.emitWValue(max_wvalue);
try self.emitWValue(cmp_result_lt);
_ = try self.cmp(bin_result, max_wvalue, ty, .lt);
try self.addTag(.select);
try self.addLabel(.local_set, bin_result.local); // re-use local
const cmp_result_gt = try self.cmp(bin_result, min_wvalue, ty, .gt);
try self.emitWValue(bin_result);
try self.emitWValue(min_wvalue);
try self.emitWValue(cmp_result_gt);
_ = try self.cmp(bin_result, min_wvalue, ty, .gt);
try self.addTag(.select);
try self.addLabel(.local_set, bin_result.local); // re-use local
return self.wrapOperand(bin_result, ty);
@ -5204,15 +5186,14 @@ fn signedSat(self: *Self, lhs_operand: WValue, rhs_operand: WValue, ty: Type, op
64 => WValue{ .imm64 = 0 },
else => unreachable,
};
const cmp_bin_result = try self.cmp(bin_result, lhs, ty, .lt);
const cmp_zero_result = try self.cmp(rhs, zero, ty, if (op == .add) .lt else .gt);
const cmp_bin_zero_result = try self.cmp(bin_result, zero, ty, .lt);
try self.emitWValue(max_wvalue);
try self.emitWValue(min_wvalue);
try self.emitWValue(cmp_bin_zero_result);
_ = try self.cmp(bin_result, zero, ty, .lt);
try self.addTag(.select);
try self.emitWValue(bin_result);
// leave on stack
const cmp_zero_result = try self.cmp(rhs, zero, ty, if (op == .add) .lt else .gt);
const cmp_bin_result = try self.cmp(bin_result, lhs, ty, .lt);
_ = try self.binOp(cmp_zero_result, cmp_bin_result, Type.u32, .xor); // comparisons always return i32, so provide u32 as type to xor.
try self.addTag(.select);
try self.addLabel(.local_set, bin_result.local); // re-use local
@ -5239,7 +5220,6 @@ fn airShlSat(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
if (wasm_bits == int_info.bits) {
const shl = try (try self.binOp(lhs, rhs, ty, .shl)).toLocal(self, ty);
const shr = try (try self.binOp(shl, rhs, ty, .shr)).toLocal(self, ty);
const cmp_result = try self.cmp(lhs, shr, ty, .neq);
switch (wasm_bits) {
32 => blk: {
@ -5247,10 +5227,9 @@ fn airShlSat(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
try self.addImm32(-1);
break :blk;
}
const less_than_zero = try self.cmp(lhs, .{ .imm32 = 0 }, ty, .lt);
try self.addImm32(std.math.minInt(i32));
try self.addImm32(std.math.maxInt(i32));
try self.emitWValue(less_than_zero);
_ = try self.cmp(lhs, .{ .imm32 = 0 }, ty, .lt);
try self.addTag(.select);
},
64 => blk: {
@ -5258,16 +5237,15 @@ fn airShlSat(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
try self.addImm64(@bitCast(u64, @as(i64, -1)));
break :blk;
}
const less_than_zero = try self.cmp(lhs, .{ .imm64 = 0 }, ty, .lt);
try self.addImm64(@bitCast(u64, @as(i64, std.math.minInt(i64))));
try self.addImm64(@bitCast(u64, @as(i64, std.math.maxInt(i64))));
try self.emitWValue(less_than_zero);
_ = try self.cmp(lhs, .{ .imm64 = 0 }, ty, .lt);
try self.addTag(.select);
},
else => unreachable,
}
try self.emitWValue(shl);
try self.emitWValue(cmp_result);
_ = try self.cmp(lhs, shr, ty, .neq);
try self.addTag(.select);
try self.addLabel(.local_set, result.local);
return result;
@ -5282,7 +5260,6 @@ fn airShlSat(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
const shl_res = try (try self.binOp(lhs, shift_value, ty, .shl)).toLocal(self, ty);
const shl = try (try self.binOp(shl_res, rhs, ty, .shl)).toLocal(self, ty);
const shr = try (try self.binOp(shl, rhs, ty, .shr)).toLocal(self, ty);
const cmp_result = try self.cmp(shl_res, shr, ty, .neq);
switch (wasm_bits) {
32 => blk: {
@ -5291,10 +5268,9 @@ fn airShlSat(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
break :blk;
}
const less_than_zero = try self.cmp(shl_res, .{ .imm32 = 0 }, ty, .lt);
try self.addImm32(std.math.minInt(i32));
try self.addImm32(std.math.maxInt(i32));
try self.emitWValue(less_than_zero);
_ = try self.cmp(shl_res, .{ .imm32 = 0 }, ty, .lt);
try self.addTag(.select);
},
64 => blk: {
@ -5303,16 +5279,15 @@ fn airShlSat(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
break :blk;
}
const less_than_zero = try self.cmp(shl_res, .{ .imm64 = 0 }, ty, .lt);
try self.addImm64(@bitCast(u64, @as(i64, std.math.minInt(i64))));
try self.addImm64(@bitCast(u64, @as(i64, std.math.maxInt(i64))));
try self.emitWValue(less_than_zero);
_ = try self.cmp(shl_res, .{ .imm64 = 0 }, ty, .lt);
try self.addTag(.select);
},
else => unreachable,
}
try self.emitWValue(shl);
try self.emitWValue(cmp_result);
_ = try self.cmp(shl_res, shr, ty, .neq);
try self.addTag(.select);
try self.addLabel(.local_set, result.local);
const shift_result = try (try self.binOp(result, shift_value, ty, .shr)).toLocal(self, ty);