mirror of
https://github.com/ziglang/zig.git
synced 2025-12-06 06:13:07 +00:00
Merge pull request #13552 from hryx/comparus-tautologicus
Sema: elide integer comparisons with guaranteed outcomes
This commit is contained in:
commit
901c3e9636
@ -1437,6 +1437,19 @@ pub const CompareOperator = enum {
|
||||
gt,
|
||||
/// Not equal (`!=`)
|
||||
neq,
|
||||
|
||||
/// Reverse the direction of the comparison.
|
||||
/// Use when swapping the left and right hand operands.
|
||||
pub fn reverse(op: CompareOperator) CompareOperator {
|
||||
return switch (op) {
|
||||
.lt => .gt,
|
||||
.lte => .gte,
|
||||
.gt => .lt,
|
||||
.gte => .lte,
|
||||
.eq => .eq,
|
||||
.neq => .neq,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
/// This function does the same thing as comparison operators, however the
|
||||
@ -1496,6 +1509,15 @@ test "order.compare" {
|
||||
try testing.expect(order(1, 0).compare(.neq));
|
||||
}
|
||||
|
||||
test "compare.reverse" {
|
||||
inline for (@typeInfo(CompareOperator).Enum.fields) |op_field| {
|
||||
const op = @intToEnum(CompareOperator, op_field.value);
|
||||
try testing.expect(compare(2, op, 3) == compare(3, op.reverse(), 2));
|
||||
try testing.expect(compare(3, op, 3) == compare(3, op.reverse(), 3));
|
||||
try testing.expect(compare(4, op, 3) == compare(3, op.reverse(), 4));
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a mask of all ones if value is true,
|
||||
/// and a mask of all zeroes if value is false.
|
||||
/// Compiles to one instruction for register sized integers.
|
||||
|
||||
130
src/Sema.zig
130
src/Sema.zig
@ -28497,6 +28497,19 @@ fn cmpNumeric(
|
||||
const runtime_src: LazySrcLoc = src: {
|
||||
if (try sema.resolveMaybeUndefVal(lhs)) |lhs_val| {
|
||||
if (try sema.resolveMaybeUndefVal(rhs)) |rhs_val| {
|
||||
// Compare ints: const vs. undefined (or vice versa)
|
||||
if (!lhs_val.isUndef() and (lhs_ty.isInt() or lhs_ty_tag == .ComptimeInt) and rhs_ty.isInt() and rhs_val.isUndef()) {
|
||||
try sema.resolveLazyValue(lhs_val);
|
||||
if (sema.compareIntsOnlyPossibleResult(target, lhs_val, op, rhs_ty)) |res| {
|
||||
return if (res) Air.Inst.Ref.bool_true else Air.Inst.Ref.bool_false;
|
||||
}
|
||||
} else if (!rhs_val.isUndef() and (rhs_ty.isInt() or rhs_ty_tag == .ComptimeInt) and lhs_ty.isInt() and lhs_val.isUndef()) {
|
||||
try sema.resolveLazyValue(rhs_val);
|
||||
if (sema.compareIntsOnlyPossibleResult(target, rhs_val, op.reverse(), lhs_ty)) |res| {
|
||||
return if (res) Air.Inst.Ref.bool_true else Air.Inst.Ref.bool_false;
|
||||
}
|
||||
}
|
||||
|
||||
if (lhs_val.isUndef() or rhs_val.isUndef()) {
|
||||
return sema.addConstUndef(Type.bool);
|
||||
}
|
||||
@ -28513,9 +28526,25 @@ fn cmpNumeric(
|
||||
return Air.Inst.Ref.bool_false;
|
||||
}
|
||||
} else {
|
||||
if (!lhs_val.isUndef() and (lhs_ty.isInt() or lhs_ty_tag == .ComptimeInt) and rhs_ty.isInt()) {
|
||||
// Compare ints: const vs. var
|
||||
try sema.resolveLazyValue(lhs_val);
|
||||
if (sema.compareIntsOnlyPossibleResult(target, lhs_val, op, rhs_ty)) |res| {
|
||||
return if (res) Air.Inst.Ref.bool_true else Air.Inst.Ref.bool_false;
|
||||
}
|
||||
}
|
||||
break :src rhs_src;
|
||||
}
|
||||
} else {
|
||||
if (try sema.resolveMaybeUndefVal(rhs)) |rhs_val| {
|
||||
if (!rhs_val.isUndef() and (rhs_ty.isInt() or rhs_ty_tag == .ComptimeInt) and lhs_ty.isInt()) {
|
||||
// Compare ints: var vs. const
|
||||
try sema.resolveLazyValue(rhs_val);
|
||||
if (sema.compareIntsOnlyPossibleResult(target, rhs_val, op.reverse(), lhs_ty)) |res| {
|
||||
return if (res) Air.Inst.Ref.bool_true else Air.Inst.Ref.bool_false;
|
||||
}
|
||||
}
|
||||
}
|
||||
break :src lhs_src;
|
||||
}
|
||||
};
|
||||
@ -28704,6 +28733,107 @@ fn cmpNumeric(
|
||||
return block.addBinOp(Air.Inst.Tag.fromCmpOp(op, block.float_mode == .Optimized), casted_lhs, casted_rhs);
|
||||
}
|
||||
|
||||
/// Asserts that LHS value is an int or comptime int and not undefined, and that RHS type is an int.
|
||||
/// Given a const LHS and an unknown RHS, attempt to determine whether `op` has a guaranteed result.
|
||||
/// If it cannot be determined, returns null.
|
||||
/// Otherwise returns a bool for the guaranteed comparison operation.
|
||||
fn compareIntsOnlyPossibleResult(sema: *Sema, target: std.Target, lhs_val: Value, op: std.math.CompareOperator, rhs_ty: Type) ?bool {
|
||||
const rhs_info = rhs_ty.intInfo(target);
|
||||
const vs_zero = lhs_val.orderAgainstZeroAdvanced(sema) catch unreachable;
|
||||
const is_zero = vs_zero == .eq;
|
||||
const is_negative = vs_zero == .lt;
|
||||
const is_positive = vs_zero == .gt;
|
||||
|
||||
// Anything vs. zero-sized type has guaranteed outcome.
|
||||
if (rhs_info.bits == 0) return switch (op) {
|
||||
.eq, .lte, .gte => is_zero,
|
||||
.neq, .lt, .gt => !is_zero,
|
||||
};
|
||||
|
||||
// Special case for i1, which can only be 0 or -1.
|
||||
// Zero and positive ints have guaranteed outcome.
|
||||
if (rhs_info.bits == 1 and rhs_info.signedness == .signed) {
|
||||
if (is_positive) return switch (op) {
|
||||
.gt, .gte, .neq => true,
|
||||
.lt, .lte, .eq => false,
|
||||
};
|
||||
if (is_zero) return switch (op) {
|
||||
.gte => true,
|
||||
.lt => false,
|
||||
.gt, .lte, .eq, .neq => null,
|
||||
};
|
||||
}
|
||||
|
||||
// Negative vs. unsigned has guaranteed outcome.
|
||||
if (rhs_info.signedness == .unsigned and is_negative) return switch (op) {
|
||||
.eq, .gt, .gte => false,
|
||||
.neq, .lt, .lte => true,
|
||||
};
|
||||
|
||||
const sign_adj = @boolToInt(!is_negative and rhs_info.signedness == .signed);
|
||||
const req_bits = lhs_val.intBitCountTwosComp(target) + sign_adj;
|
||||
|
||||
// No sized type can have more than 65535 bits.
|
||||
// The RHS type operand is either a runtime value or sized (but undefined) constant.
|
||||
if (req_bits > 65535) return switch (op) {
|
||||
.lt, .lte => is_negative,
|
||||
.gt, .gte => is_positive,
|
||||
.eq => false,
|
||||
.neq => true,
|
||||
};
|
||||
const fits = req_bits <= rhs_info.bits;
|
||||
|
||||
// Oversized int has guaranteed outcome.
|
||||
switch (op) {
|
||||
.eq => return if (!fits) false else null,
|
||||
.neq => return if (!fits) true else null,
|
||||
.lt, .lte => if (!fits) return is_negative,
|
||||
.gt, .gte => if (!fits) return !is_negative,
|
||||
}
|
||||
|
||||
// For any other comparison, we need to know if the LHS value is
|
||||
// equal to the maximum or minimum possible value of the RHS type.
|
||||
const edge: struct { min: bool, max: bool } = edge: {
|
||||
if (is_zero and rhs_info.signedness == .unsigned) break :edge .{
|
||||
.min = true,
|
||||
.max = false,
|
||||
};
|
||||
|
||||
if (req_bits != rhs_info.bits) break :edge .{
|
||||
.min = false,
|
||||
.max = false,
|
||||
};
|
||||
|
||||
var ty_buffer: Type.Payload.Bits = .{
|
||||
.base = .{ .tag = if (is_negative) .int_signed else .int_unsigned },
|
||||
.data = @intCast(u16, req_bits),
|
||||
};
|
||||
const ty = Type.initPayload(&ty_buffer.base);
|
||||
const pop_count = lhs_val.popCount(ty, target);
|
||||
|
||||
if (is_negative) {
|
||||
break :edge .{
|
||||
.min = pop_count == 1,
|
||||
.max = false,
|
||||
};
|
||||
} else {
|
||||
break :edge .{
|
||||
.min = false,
|
||||
.max = pop_count == req_bits - sign_adj,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
assert(fits);
|
||||
return switch (op) {
|
||||
.lt => if (edge.max) false else null,
|
||||
.lte => if (edge.min) true else null,
|
||||
.gt => if (edge.min) false else null,
|
||||
.gte => if (edge.max) true else null,
|
||||
.eq, .neq => unreachable,
|
||||
};
|
||||
}
|
||||
|
||||
/// Asserts that lhs and rhs types are both vectors.
|
||||
fn cmpVector(
|
||||
sema: *Sema,
|
||||
|
||||
@ -1723,7 +1723,8 @@ pub const TestContext = struct {
|
||||
(case_msg.src.column == std.math.maxInt(u32) or
|
||||
actual_msg.column == case_msg.src.column) and
|
||||
std.mem.eql(u8, expected_msg, actual_msg.msg) and
|
||||
case_msg.src.kind == .note)
|
||||
case_msg.src.kind == .note and
|
||||
actual_msg.count == case_msg.src.count)
|
||||
{
|
||||
handled_errors[i] = true;
|
||||
break;
|
||||
@ -1733,7 +1734,8 @@ pub const TestContext = struct {
|
||||
if (ex_tag != .plain) continue;
|
||||
|
||||
if (std.mem.eql(u8, case_msg.plain.msg, plain.msg) and
|
||||
case_msg.plain.kind == .note)
|
||||
case_msg.plain.kind == .note and
|
||||
case_msg.plain.count == plain.count)
|
||||
{
|
||||
handled_errors[i] = true;
|
||||
break;
|
||||
|
||||
@ -1756,17 +1756,8 @@ pub const Value = extern union {
|
||||
const info = ty.intInfo(target);
|
||||
|
||||
var buffer: Value.BigIntSpace = undefined;
|
||||
const operand_bigint = val.toBigInt(&buffer, target);
|
||||
|
||||
var limbs_buffer: [4]std.math.big.Limb = undefined;
|
||||
var result_bigint = BigIntMutable{
|
||||
.limbs = &limbs_buffer,
|
||||
.positive = undefined,
|
||||
.len = undefined,
|
||||
};
|
||||
result_bigint.popCount(operand_bigint, info.bits);
|
||||
|
||||
return result_bigint.toConst().to(u64) catch unreachable;
|
||||
const int = val.toBigInt(&buffer, target);
|
||||
return @intCast(u64, int.popCount(info.bits));
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -160,7 +160,7 @@ test {
|
||||
_ = @import("behavior/incomplete_struct_param_tld.zig");
|
||||
_ = @import("behavior/inline_switch.zig");
|
||||
_ = @import("behavior/int128.zig");
|
||||
_ = @import("behavior/int_div.zig");
|
||||
_ = @import("behavior/int_comparison_elision.zig");
|
||||
_ = @import("behavior/inttoptr.zig");
|
||||
_ = @import("behavior/ir_block_deps.zig");
|
||||
_ = @import("behavior/lower_strlit_to_vector.zig");
|
||||
|
||||
108
test/behavior/int_comparison_elision.zig
Normal file
108
test/behavior/int_comparison_elision.zig
Normal file
@ -0,0 +1,108 @@
|
||||
const std = @import("std");
|
||||
const minInt = std.math.minInt;
|
||||
const maxInt = std.math.maxInt;
|
||||
const builtin = @import("builtin");
|
||||
|
||||
test "int comparison elision" {
|
||||
testIntEdges(u0);
|
||||
testIntEdges(i0);
|
||||
testIntEdges(u1);
|
||||
testIntEdges(i1);
|
||||
testIntEdges(u4);
|
||||
testIntEdges(i4);
|
||||
|
||||
// TODO: support int types > 128 bits wide in other backends
|
||||
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
|
||||
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
|
||||
|
||||
// TODO: panic: integer overflow with int types > 65528 bits wide
|
||||
// TODO: LLVM generates too many parameters for wasmtime when splitting up int > 64000 bits wide
|
||||
testIntEdges(u64000);
|
||||
testIntEdges(i64000);
|
||||
}
|
||||
|
||||
// All comparisons in this test have a guaranteed result,
|
||||
// so one branch of each 'if' should never be analyzed.
|
||||
fn testIntEdges(comptime T: type) void {
|
||||
const min = minInt(T);
|
||||
const max = maxInt(T);
|
||||
|
||||
var runtime_val: T = undefined;
|
||||
|
||||
if (min > runtime_val) @compileError("analyzed impossible branch");
|
||||
if (min <= runtime_val) {} else @compileError("analyzed impossible branch");
|
||||
if (runtime_val < min) @compileError("analyzed impossible branch");
|
||||
if (runtime_val >= min) {} else @compileError("analyzed impossible branch");
|
||||
|
||||
if (min - 1 > runtime_val) @compileError("analyzed impossible branch");
|
||||
if (min - 1 >= runtime_val) @compileError("analyzed impossible branch");
|
||||
if (min - 1 < runtime_val) {} else @compileError("analyzed impossible branch");
|
||||
if (min - 1 <= runtime_val) {} else @compileError("analyzed impossible branch");
|
||||
if (min - 1 == runtime_val) @compileError("analyzed impossible branch");
|
||||
if (min - 1 != runtime_val) {} else @compileError("analyzed impossible branch");
|
||||
if (runtime_val < min - 1) @compileError("analyzed impossible branch");
|
||||
if (runtime_val <= min - 1) @compileError("analyzed impossible branch");
|
||||
if (runtime_val > min - 1) {} else @compileError("analyzed impossible branch");
|
||||
if (runtime_val >= min - 1) {} else @compileError("analyzed impossible branch");
|
||||
if (runtime_val == min - 1) @compileError("analyzed impossible branch");
|
||||
if (runtime_val != min - 1) {} else @compileError("analyzed impossible branch");
|
||||
|
||||
if (max >= runtime_val) {} else @compileError("analyzed impossible branch");
|
||||
if (max < runtime_val) @compileError("analyzed impossible branch");
|
||||
if (runtime_val <= max) {} else @compileError("analyzed impossible branch");
|
||||
if (runtime_val > max) @compileError("analyzed impossible branch");
|
||||
|
||||
if (max + 1 > runtime_val) {} else @compileError("analyzed impossible branch");
|
||||
if (max + 1 >= runtime_val) {} else @compileError("analyzed impossible branch");
|
||||
if (max + 1 < runtime_val) @compileError("analyzed impossible branch");
|
||||
if (max + 1 <= runtime_val) @compileError("analyzed impossible branch");
|
||||
if (max + 1 == runtime_val) @compileError("analyzed impossible branch");
|
||||
if (max + 1 != runtime_val) {} else @compileError("analyzed impossible branch");
|
||||
if (runtime_val < max + 1) {} else @compileError("analyzed impossible branch");
|
||||
if (runtime_val <= max + 1) {} else @compileError("analyzed impossible branch");
|
||||
if (runtime_val > max + 1) @compileError("analyzed impossible branch");
|
||||
if (runtime_val >= max + 1) @compileError("analyzed impossible branch");
|
||||
if (runtime_val == max + 1) @compileError("analyzed impossible branch");
|
||||
if (runtime_val != max + 1) {} else @compileError("analyzed impossible branch");
|
||||
|
||||
const undef_const: T = undefined;
|
||||
|
||||
if (min > undef_const) @compileError("analyzed impossible branch");
|
||||
if (min <= undef_const) {} else @compileError("analyzed impossible branch");
|
||||
if (undef_const < min) @compileError("analyzed impossible branch");
|
||||
if (undef_const >= min) {} else @compileError("analyzed impossible branch");
|
||||
|
||||
if (min - 1 > undef_const) @compileError("analyzed impossible branch");
|
||||
if (min - 1 >= undef_const) @compileError("analyzed impossible branch");
|
||||
if (min - 1 < undef_const) {} else @compileError("analyzed impossible branch");
|
||||
if (min - 1 <= undef_const) {} else @compileError("analyzed impossible branch");
|
||||
if (min - 1 == undef_const) @compileError("analyzed impossible branch");
|
||||
if (min - 1 != undef_const) {} else @compileError("analyzed impossible branch");
|
||||
if (undef_const < min - 1) @compileError("analyzed impossible branch");
|
||||
if (undef_const <= min - 1) @compileError("analyzed impossible branch");
|
||||
if (undef_const > min - 1) {} else @compileError("analyzed impossible branch");
|
||||
if (undef_const >= min - 1) {} else @compileError("analyzed impossible branch");
|
||||
if (undef_const == min - 1) @compileError("analyzed impossible branch");
|
||||
if (undef_const != min - 1) {} else @compileError("analyzed impossible branch");
|
||||
|
||||
if (max >= undef_const) {} else @compileError("analyzed impossible branch");
|
||||
if (max < undef_const) @compileError("analyzed impossible branch");
|
||||
if (undef_const <= max) {} else @compileError("analyzed impossible branch");
|
||||
if (undef_const > max) @compileError("analyzed impossible branch");
|
||||
|
||||
if (max + 1 > undef_const) {} else @compileError("analyzed impossible branch");
|
||||
if (max + 1 >= undef_const) {} else @compileError("analyzed impossible branch");
|
||||
if (max + 1 < undef_const) @compileError("analyzed impossible branch");
|
||||
if (max + 1 <= undef_const) @compileError("analyzed impossible branch");
|
||||
if (max + 1 == undef_const) @compileError("analyzed impossible branch");
|
||||
if (max + 1 != undef_const) {} else @compileError("analyzed impossible branch");
|
||||
if (undef_const < max + 1) {} else @compileError("analyzed impossible branch");
|
||||
if (undef_const <= max + 1) {} else @compileError("analyzed impossible branch");
|
||||
if (undef_const > max + 1) @compileError("analyzed impossible branch");
|
||||
if (undef_const >= max + 1) @compileError("analyzed impossible branch");
|
||||
if (undef_const == max + 1) @compileError("analyzed impossible branch");
|
||||
if (undef_const != max + 1) {} else @compileError("analyzed impossible branch");
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user