From e57e835904d54b236cbaf0eedf5b8cea90a94542 Mon Sep 17 00:00:00 2001
From: Stevie Hryciw
Date: Mon, 14 Nov 2022 15:56:59 -0800
Subject: [PATCH] Sema: elide integer comparisons with guaranteed outcomes

---
 src/Sema.zig                             | 136 +++++++++++++++++++++++
 src/value.zig                            |  13 +--
 test/behavior.zig                        |   1 +
 test/behavior/int_comparison_elision.zig | 108 ++++++++++++++++++
 4 files changed, 247 insertions(+), 11 deletions(-)
 create mode 100644 test/behavior/int_comparison_elision.zig

diff --git a/src/Sema.zig b/src/Sema.zig
index 8c7c8b0dd7..8f00cbb425 100644
--- a/src/Sema.zig
+++ b/src/Sema.zig
@@ -28460,6 +28460,17 @@ fn cmpNumeric(
     const runtime_src: LazySrcLoc = src: {
         if (try sema.resolveMaybeUndefVal(lhs)) |lhs_val| {
             if (try sema.resolveMaybeUndefVal(rhs)) |rhs_val| {
+                // Compare ints: const vs. undefined (or vice versa)
+                if (!lhs_val.isUndef() and (lhs_ty.isInt() or lhs_ty_tag == .ComptimeInt) and rhs_ty.isInt() and rhs_val.isUndef()) {
+                    if (sema.compareIntsOnlyPossibleResult(target, lhs_val, op, rhs_ty)) |res| {
+                        return if (res) Air.Inst.Ref.bool_true else Air.Inst.Ref.bool_false;
+                    }
+                } else if (!rhs_val.isUndef() and (rhs_ty.isInt() or rhs_ty_tag == .ComptimeInt) and lhs_ty.isInt() and lhs_val.isUndef()) {
+                    if (sema.compareIntsOnlyPossibleResult(target, rhs_val, op.reverse(), lhs_ty)) |res| {
+                        return if (res) Air.Inst.Ref.bool_true else Air.Inst.Ref.bool_false;
+                    }
+                }
+
                 if (lhs_val.isUndef() or rhs_val.isUndef()) {
                     return sema.addConstUndef(Type.bool);
                 }
@@ -28476,9 +28487,23 @@ fn cmpNumeric(
                     return Air.Inst.Ref.bool_false;
                 }
             } else {
+                if (!lhs_val.isUndef() and (lhs_ty.isInt() or lhs_ty_tag == .ComptimeInt) and rhs_ty.isInt()) {
+                    // Compare ints: const vs. var
+                    if (sema.compareIntsOnlyPossibleResult(target, lhs_val, op, rhs_ty)) |res| {
+                        return if (res) Air.Inst.Ref.bool_true else Air.Inst.Ref.bool_false;
+                    }
+                }
                 break :src rhs_src;
             }
         } else {
+            if (try sema.resolveMaybeUndefVal(rhs)) |rhs_val| {
+                if (!rhs_val.isUndef() and (rhs_ty.isInt() or rhs_ty_tag == .ComptimeInt) and lhs_ty.isInt()) {
+                    // Compare ints: var vs. const
+                    if (sema.compareIntsOnlyPossibleResult(target, rhs_val, op.reverse(), lhs_ty)) |res| {
+                        return if (res) Air.Inst.Ref.bool_true else Air.Inst.Ref.bool_false;
+                    }
+                }
+            }
             break :src lhs_src;
         }
     };
@@ -28667,6 +28692,117 @@ fn cmpNumeric(
     return block.addBinOp(Air.Inst.Tag.fromCmpOp(op, block.float_mode == .Optimized), casted_lhs, casted_rhs);
 }
 
+/// Asserts that LHS value is an int or comptime int and not undefined, and that RHS type is an int.
+/// Given a const LHS and an unknown RHS, attempt to determine whether `op` has a guaranteed result
+/// for every value the RHS type can represent. The comparison is evaluated as `lhs_val op rhs`;
+/// callers with the constant on the right-hand side pass `op.reverse()`.
+/// If it cannot be determined, returns null.
+/// Otherwise returns a bool for the guaranteed comparison operation.
+fn compareIntsOnlyPossibleResult(sema: *Sema, target: std.Target, lhs_val: Value, op: std.math.CompareOperator, rhs_ty: Type) ?bool {
+    const rhs_info = rhs_ty.intInfo(target);
+    const vs_zero = lhs_val.orderAgainstZeroAdvanced(sema) catch unreachable;
+    const is_zero = vs_zero == .eq;
+    const is_negative = vs_zero == .lt;
+    const is_positive = vs_zero == .gt;
+
+    // A zero-bit RHS can only hold the value 0, so every comparison reduces
+    // to comparing the LHS against zero.
+    if (rhs_info.bits == 0) return switch (op) {
+        .eq => is_zero,
+        .neq => !is_zero,
+        .lt => is_negative,
+        .lte => !is_positive,
+        .gt => is_positive,
+        .gte => !is_negative,
+    };
+
+    // Special case for i1, which can only be 0 or -1.
+    // Zero and positive ints have guaranteed outcome.
+    if (rhs_info.bits == 1 and rhs_info.signedness == .signed) {
+        if (is_positive) return switch (op) {
+            .gt, .gte, .neq => true,
+            .lt, .lte, .eq => false,
+        };
+        if (is_zero) return switch (op) {
+            .gte => true,
+            .lt => false,
+            .gt, .lte, .eq, .neq => null,
+        };
+    }
+
+    // Negative vs. unsigned has guaranteed outcome.
+    if (rhs_info.signedness == .unsigned and is_negative) return switch (op) {
+        .eq, .gt, .gte => false,
+        .neq, .lt, .lte => true,
+    };
+
+    // A non-negative LHS needs one extra (sign) bit to fit in a signed RHS type.
+    const sign_adj = @boolToInt(!is_negative and rhs_info.signedness == .signed);
+    const req_bits = lhs_val.intBitCountTwosComp(target) + sign_adj;
+
+    // No sized type can have more than 65535 bits.
+    // The RHS type operand is either a runtime value or sized (but undefined) constant.
+    if (req_bits > 65535) return switch (op) {
+        .lt, .lte => is_negative,
+        .gt, .gte => is_positive,
+        .eq => false,
+        .neq => true,
+    };
+    const fits = req_bits <= rhs_info.bits;
+
+    // Oversized int has guaranteed outcome.
+    switch (op) {
+        .eq => return if (!fits) false else null,
+        .neq => return if (!fits) true else null,
+        .lt, .lte => if (!fits) return is_negative,
+        .gt, .gte => if (!fits) return !is_negative,
+    }
+
+    // For any other comparison, we need to know if the LHS value is
+    // equal to the maximum or minimum possible value of the RHS type.
+    const edge: struct { min: bool, max: bool } = edge: {
+        if (is_zero and rhs_info.signedness == .unsigned) break :edge .{
+            .min = true,
+            .max = false,
+        };
+
+        if (req_bits != rhs_info.bits) break :edge .{
+            .min = false,
+            .max = false,
+        };
+
+        var ty_buffer: Type.Payload.Bits = .{
+            .base = .{ .tag = if (is_negative) .int_signed else .int_unsigned },
+            .data = @intCast(u16, req_bits),
+        };
+        const ty = Type.initPayload(&ty_buffer.base);
+        const pop_count = lhs_val.popCount(ty, target);
+
+        if (is_negative) {
+            // The minimum value of a signed type has only its sign bit set.
+            break :edge .{
+                .min = pop_count == 1,
+                .max = false,
+            };
+        } else {
+            // The maximum value has every non-sign bit set.
+            break :edge .{
+                .min = false,
+                .max = pop_count == req_bits - sign_adj,
+            };
+        }
+    };
+
+    assert(fits);
+    return switch (op) {
+        .lt => if (edge.max) false else null,
+        .lte => if (edge.min) true else null,
+        .gt => if (edge.min) false else null,
+        .gte => if (edge.max) true else null,
+        .eq, .neq => unreachable,
+    };
+}
+
 /// Asserts that lhs and rhs types are both vectors.
 fn cmpVector(
     sema: *Sema,
diff --git a/src/value.zig b/src/value.zig
index 3d5636ee34..b0558487f9 100644
--- a/src/value.zig
+++ b/src/value.zig
@@ -1756,17 +1756,8 @@ pub const Value = extern union {
                 const info = ty.intInfo(target);
 
                 var buffer: Value.BigIntSpace = undefined;
-                const operand_bigint = val.toBigInt(&buffer, target);
-
-                var limbs_buffer: [4]std.math.big.Limb = undefined;
-                var result_bigint = BigIntMutable{
-                    .limbs = &limbs_buffer,
-                    .positive = undefined,
-                    .len = undefined,
-                };
-                result_bigint.popCount(operand_bigint, info.bits);
-
-                return result_bigint.toConst().to(u64) catch unreachable;
+                const int = val.toBigInt(&buffer, target);
+                return @intCast(u64, int.popCount(info.bits));
             },
         }
     }
diff --git a/test/behavior.zig b/test/behavior.zig
index ebd1e1afb7..1f739508ec 100644
--- a/test/behavior.zig
+++ b/test/behavior.zig
@@ -158,7 +158,8 @@ test {
     _ = @import("behavior/incomplete_struct_param_tld.zig");
     _ = @import("behavior/inline_switch.zig");
     _ = @import("behavior/int128.zig");
+    _ = @import("behavior/int_comparison_elision.zig");
     _ = @import("behavior/int_div.zig");
     _ = @import("behavior/inttoptr.zig");
     _ = @import("behavior/ir_block_deps.zig");
     _ = @import("behavior/math.zig");
diff --git a/test/behavior/int_comparison_elision.zig b/test/behavior/int_comparison_elision.zig
new file mode 100644
index 0000000000..5e13e00e83
--- /dev/null
+++ b/test/behavior/int_comparison_elision.zig
@@ -0,0 +1,108 @@
+const std = @import("std");
+const minInt = std.math.minInt;
+const maxInt = std.math.maxInt;
+const builtin = @import("builtin");
+
+test "int comparison elision" {
+    testIntEdges(u0);
+    testIntEdges(i0);
+    testIntEdges(u1);
+    testIntEdges(i1);
+    testIntEdges(u4);
+    testIntEdges(i4);
+
+    // TODO: support int types > 128 bits wide in other backends
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+
+    // TODO: panic: integer overflow with int types > 65528 bits wide
+    // TODO: LLVM generates too many parameters for wasmtime when splitting up int > 64000 bits wide
+    testIntEdges(u64000);
+    testIntEdges(i64000);
+}
+
+// All comparisons in this test have a guaranteed result,
+// so one branch of each 'if' should never be analyzed.
+fn testIntEdges(comptime T: type) void {
+    const min = minInt(T);
+    const max = maxInt(T);
+
+    var runtime_val: T = undefined;
+
+    if (min > runtime_val) @compileError("analyzed impossible branch");
+    if (min <= runtime_val) {} else @compileError("analyzed impossible branch");
+    if (runtime_val < min) @compileError("analyzed impossible branch");
+    if (runtime_val >= min) {} else @compileError("analyzed impossible branch");
+
+    if (min - 1 > runtime_val) @compileError("analyzed impossible branch");
+    if (min - 1 >= runtime_val) @compileError("analyzed impossible branch");
+    if (min - 1 < runtime_val) {} else @compileError("analyzed impossible branch");
+    if (min - 1 <= runtime_val) {} else @compileError("analyzed impossible branch");
+    if (min - 1 == runtime_val) @compileError("analyzed impossible branch");
+    if (min - 1 != runtime_val) {} else @compileError("analyzed impossible branch");
+    if (runtime_val < min - 1) @compileError("analyzed impossible branch");
+    if (runtime_val <= min - 1) @compileError("analyzed impossible branch");
+    if (runtime_val > min - 1) {} else @compileError("analyzed impossible branch");
+    if (runtime_val >= min - 1) {} else @compileError("analyzed impossible branch");
+    if (runtime_val == min - 1) @compileError("analyzed impossible branch");
+    if (runtime_val != min - 1) {} else @compileError("analyzed impossible branch");
+
+    if (max >= runtime_val) {} else @compileError("analyzed impossible branch");
+    if (max < runtime_val) @compileError("analyzed impossible branch");
+    if (runtime_val <= max) {} else @compileError("analyzed impossible branch");
+    if (runtime_val > max) @compileError("analyzed impossible branch");
+
+    if (max + 1 > runtime_val) {} else @compileError("analyzed impossible branch");
+    if (max + 1 >= runtime_val) {} else @compileError("analyzed impossible branch");
+    if (max + 1 < runtime_val) @compileError("analyzed impossible branch");
+    if (max + 1 <= runtime_val) @compileError("analyzed impossible branch");
+    if (max + 1 == runtime_val) @compileError("analyzed impossible branch");
+    if (max + 1 != runtime_val) {} else @compileError("analyzed impossible branch");
+    if (runtime_val < max + 1) {} else @compileError("analyzed impossible branch");
+    if (runtime_val <= max + 1) {} else @compileError("analyzed impossible branch");
+    if (runtime_val > max + 1) @compileError("analyzed impossible branch");
+    if (runtime_val >= max + 1) @compileError("analyzed impossible branch");
+    if (runtime_val == max + 1) @compileError("analyzed impossible branch");
+    if (runtime_val != max + 1) {} else @compileError("analyzed impossible branch");
+
+    const undef_const: T = undefined;
+
+    if (min > undef_const) @compileError("analyzed impossible branch");
+    if (min <= undef_const) {} else @compileError("analyzed impossible branch");
+    if (undef_const < min) @compileError("analyzed impossible branch");
+    if (undef_const >= min) {} else @compileError("analyzed impossible branch");
+
+    if (min - 1 > undef_const) @compileError("analyzed impossible branch");
+    if (min - 1 >= undef_const) @compileError("analyzed impossible branch");
+    if (min - 1 < undef_const) {} else @compileError("analyzed impossible branch");
+    if (min - 1 <= undef_const) {} else @compileError("analyzed impossible branch");
+    if (min - 1 == undef_const) @compileError("analyzed impossible branch");
+    if (min - 1 != undef_const) {} else @compileError("analyzed impossible branch");
+    if (undef_const < min - 1) @compileError("analyzed impossible branch");
+    if (undef_const <= min - 1) @compileError("analyzed impossible branch");
+    if (undef_const > min - 1) {} else @compileError("analyzed impossible branch");
+    if (undef_const >= min - 1) {} else @compileError("analyzed impossible branch");
+    if (undef_const == min - 1) @compileError("analyzed impossible branch");
+    if (undef_const != min - 1) {} else @compileError("analyzed impossible branch");
+
+    if (max >= undef_const) {} else @compileError("analyzed impossible branch");
+    if (max < undef_const) @compileError("analyzed impossible branch");
+    if (undef_const <= max) {} else @compileError("analyzed impossible branch");
+    if (undef_const > max) @compileError("analyzed impossible branch");
+
+    if (max + 1 > undef_const) {} else @compileError("analyzed impossible branch");
+    if (max + 1 >= undef_const) {} else @compileError("analyzed impossible branch");
+    if (max + 1 < undef_const) @compileError("analyzed impossible branch");
+    if (max + 1 <= undef_const) @compileError("analyzed impossible branch");
+    if (max + 1 == undef_const) @compileError("analyzed impossible branch");
+    if (max + 1 != undef_const) {} else @compileError("analyzed impossible branch");
+    if (undef_const < max + 1) {} else @compileError("analyzed impossible branch");
+    if (undef_const <= max + 1) {} else @compileError("analyzed impossible branch");
+    if (undef_const > max + 1) @compileError("analyzed impossible branch");
+    if (undef_const >= max + 1) @compileError("analyzed impossible branch");
+    if (undef_const == max + 1) @compileError("analyzed impossible branch");
+    if (undef_const != max + 1) {} else @compileError("analyzed impossible branch");
+}