Sema: consider type bounds when refining result type of @min/@max

I achieved this through a major refactor of the logic of analyzeMinMax.
This change should be compatible with vectors of comptime_int, which
Andrew said are supposed to work (but which currently do not).
This commit is contained in:
mlugg 2023-06-14 00:51:31 +01:00 committed by Andrew Kelley
parent 5d9e8f27d0
commit c4cc796695
3 changed files with 239 additions and 112 deletions

View File

@ -22984,104 +22984,127 @@ fn analyzeMinMax(
else => @compileError("unreachable"),
};
// First, find all comptime-known arguments, and get their min/max
// The set of runtime-known operands. Set up in the loop below.
var runtime_known = try std.DynamicBitSet.initFull(sema.arena, operands.len);
// The current minmax value - initially this will always be comptime-known, then we'll add
// runtime values into the mix later.
var cur_minmax: ?Air.Inst.Ref = null;
var cur_minmax_src: LazySrcLoc = undefined; // defined if cur_minmax not null
// The current known scalar bounds of the value.
var bounds_status: enum {
unknown, // We've only seen undef comptime_ints so far, so do not know the bounds.
defined, // We've seen only integers, so the bounds are defined.
non_integral, // There are floats in the mix, so the bounds aren't defined.
} = .unknown;
var cur_min_scalar: Value = undefined;
var cur_max_scalar: Value = undefined;
// First, find all comptime-known arguments, and get their min/max
for (operands, operand_srcs, 0..) |operand, operand_src, operand_idx| {
// Resolve the value now to avoid redundant calls to `checkSimdBinOp` - we'll have to call
// it in the runtime path anyway since the result type may have been refined
const uncasted_operand_val = (try sema.resolveMaybeUndefVal(operand)) orelse continue;
if (cur_minmax) |cur| {
const simd_op = try sema.checkSimdBinOp(block, src, cur, operand, cur_minmax_src, operand_src);
const cur_val = simd_op.lhs_val.?; // cur_minmax is comptime-known
const operand_val = simd_op.rhs_val.?; // we checked the operand was resolvable above
const unresolved_uncoerced_val = try sema.resolveMaybeUndefVal(operand) orelse continue;
const uncoerced_val = try sema.resolveLazyValue(unresolved_uncoerced_val);
runtime_known.unset(operand_idx);
runtime_known.unset(operand_idx);
if (cur_val.isUndef(mod)) continue; // result is also undef
if (operand_val.isUndef(mod)) {
cur_minmax = try sema.addConstUndef(simd_op.result_ty);
continue;
}
const resolved_cur_val = try sema.resolveLazyValue(cur_val);
const resolved_operand_val = try sema.resolveLazyValue(operand_val);
const vec_len = simd_op.len orelse {
const result_val = opFunc(resolved_cur_val, resolved_operand_val, mod);
cur_minmax = try sema.addConstant(simd_op.result_ty, result_val);
continue;
};
const elems = try sema.arena.alloc(InternPool.Index, vec_len);
for (elems, 0..) |*elem, i| {
const lhs_elem_val = try resolved_cur_val.elemValue(mod, i);
const rhs_elem_val = try resolved_operand_val.elemValue(mod, i);
elem.* = try opFunc(lhs_elem_val, rhs_elem_val, mod).intern(simd_op.scalar_ty, mod);
}
cur_minmax = try sema.addConstant(simd_op.result_ty, (try mod.intern(.{ .aggregate = .{
.ty = simd_op.result_ty.toIntern(),
.storage = .{ .elems = elems },
} })).toValue());
} else {
runtime_known.unset(operand_idx);
cur_minmax = try sema.addConstant(sema.typeOf(operand), uncasted_operand_val);
cur_minmax_src = operand_src;
switch (bounds_status) {
.unknown, .defined => refine_bounds: {
const ty = sema.typeOf(operand);
if (!ty.scalarType(mod).isInt(mod) and !ty.scalarType(mod).eql(Type.comptime_int, mod)) {
bounds_status = .non_integral;
break :refine_bounds;
}
const scalar_bounds: ?[2]Value = bounds: {
if (!ty.isVector(mod)) break :bounds try uncoerced_val.intValueBounds(mod);
var cur_bounds: [2]Value = try Value.intValueBounds(try uncoerced_val.elemValue(mod, 0), mod) orelse break :bounds null;
const len = try sema.usizeCast(block, src, ty.vectorLen(mod));
for (1..len) |i| {
const elem = try uncoerced_val.elemValue(mod, i);
const elem_bounds = try elem.intValueBounds(mod) orelse break :bounds null;
cur_bounds = .{
Value.numberMin(elem_bounds[0], cur_bounds[0], mod),
Value.numberMax(elem_bounds[1], cur_bounds[1], mod),
};
}
break :bounds cur_bounds;
};
if (scalar_bounds) |bounds| {
if (bounds_status == .unknown) {
cur_min_scalar = bounds[0];
cur_max_scalar = bounds[1];
bounds_status = .defined;
} else {
cur_min_scalar = opFunc(cur_min_scalar, bounds[0], mod);
cur_max_scalar = opFunc(cur_max_scalar, bounds[1], mod);
}
}
},
.non_integral => {},
}
const cur = cur_minmax orelse {
cur_minmax = operand;
cur_minmax_src = operand_src;
continue;
};
const simd_op = try sema.checkSimdBinOp(block, src, cur, operand, cur_minmax_src, operand_src);
const cur_val = try sema.resolveLazyValue(simd_op.lhs_val.?); // cur_minmax is comptime-known
const operand_val = try sema.resolveLazyValue(simd_op.rhs_val.?); // we checked the operand was resolvable above
const vec_len = simd_op.len orelse {
const result_val = opFunc(cur_val, operand_val, mod);
cur_minmax = try sema.addConstant(simd_op.result_ty, result_val);
continue;
};
const elems = try sema.arena.alloc(InternPool.Index, vec_len);
for (elems, 0..) |*elem, i| {
const lhs_elem_val = try cur_val.elemValue(mod, i);
const rhs_elem_val = try operand_val.elemValue(mod, i);
const uncoerced_elem = opFunc(lhs_elem_val, rhs_elem_val, mod);
elem.* = (try mod.getCoerced(uncoerced_elem, simd_op.scalar_ty)).toIntern();
}
cur_minmax = try sema.addConstant(simd_op.result_ty, (try mod.intern(.{ .aggregate = .{
.ty = simd_op.result_ty.toIntern(),
.storage = .{ .elems = elems },
} })).toValue());
}
const opt_runtime_idx = runtime_known.findFirstSet();
const comptime_refined_ty: ?Type = if (cur_minmax) |ct_minmax_ref| refined: {
// Refine the comptime-known result type based on the operation
if (cur_minmax) |ct_minmax_ref| refine: {
// Refine the comptime-known result type based on the bounds. This isn't strictly necessary
// in the runtime case, since we'll refine the type again later, but keeping things as small
// as possible will allow us to emit more optimal AIR (if all the runtime operands have
// smaller types than the non-refined comptime type).
const val = (try sema.resolveMaybeUndefVal(ct_minmax_ref)).?;
const orig_ty = sema.typeOf(ct_minmax_ref);
if (opt_runtime_idx == null and orig_ty.eql(Type.comptime_int, mod)) {
if (opt_runtime_idx == null and orig_ty.scalarType(mod).eql(Type.comptime_int, mod)) {
// If all arguments were `comptime_int`, and there are no runtime args, we'll preserve that type
break :refined orig_ty;
break :refine;
}
const refined_ty = if (orig_ty.zigTypeTag(mod) == .Vector) blk: {
const elem_ty = orig_ty.childType(mod);
const len = orig_ty.vectorLen(mod);
// We can't refine float types
if (orig_ty.scalarType(mod).isAnyFloat()) break :refine;
if (len == 0) break :blk orig_ty;
if (elem_ty.isAnyFloat()) break :blk orig_ty; // can't refine floats
assert(bounds_status == .defined); // there was a non-comptime-int integral comptime-known arg
var cur_min: Value = try val.elemValue(mod, 0);
var cur_max: Value = cur_min;
for (1..len) |idx| {
const elem_val = try val.elemValue(mod, idx);
if (elem_val.isUndef(mod)) break :blk orig_ty; // can't refine undef
if (Value.order(elem_val, cur_min, mod).compare(.lt)) cur_min = elem_val;
if (Value.order(elem_val, cur_max, mod).compare(.gt)) cur_max = elem_val;
}
const refined_scalar_ty = try mod.intFittingRange(cur_min_scalar, cur_max_scalar);
const refined_ty = if (orig_ty.isVector(mod)) try mod.vectorType(.{
.len = orig_ty.vectorLen(mod),
.child = refined_scalar_ty.toIntern(),
}) else refined_scalar_ty;
const refined_elem_ty = try mod.intFittingRange(cur_min, cur_max);
break :blk try mod.vectorType(.{
.len = len,
.child = refined_elem_ty.toIntern(),
});
} else blk: {
if (orig_ty.isAnyFloat()) break :blk orig_ty; // can't refine floats
if (val.isUndef(mod)) break :blk orig_ty; // can't refine undef
break :blk try mod.intFittingRange(val, val);
};
// Apply the refined type to the current value - this isn't strictly necessary in the
// runtime case since we'll refine again afterwards, but keeping things as small as possible
// will allow us to emit more optimal AIR (if all the runtime operands have smaller types
// than the non-refined comptime type).
if (!refined_ty.eql(orig_ty, mod)) {
if (std.debug.runtime_safety) {
assert(try sema.intFitsInType(val, refined_ty, null));
}
cur_minmax = try sema.coerceInMemory(val, refined_ty);
// Apply the refined type to the current value
if (std.debug.runtime_safety) {
assert(try sema.intFitsInType(val, refined_ty, null));
}
break :refined refined_ty;
} else null;
cur_minmax = try sema.coerceInMemory(val, refined_ty);
}
const runtime_idx = opt_runtime_idx orelse return cur_minmax.?;
const runtime_src = operand_srcs[runtime_idx];
@ -23102,6 +23125,11 @@ fn analyzeMinMax(
cur_minmax = operands[0];
cur_minmax_src = runtime_src;
runtime_known.unset(0); // don't look at this operand in the loop below
const scalar_ty = sema.typeOf(cur_minmax.?).scalarType(mod);
if (scalar_ty.isInt(mod)) {
cur_min_scalar = try scalar_ty.minInt(mod, scalar_ty);
cur_max_scalar = try scalar_ty.maxInt(mod, scalar_ty);
}
}
var it = runtime_known.iterator(.{});
@ -23112,49 +23140,49 @@ fn analyzeMinMax(
const rhs_src = operand_srcs[idx];
const simd_op = try sema.checkSimdBinOp(block, src, lhs, rhs, lhs_src, rhs_src);
if (known_undef) {
cur_minmax = try sema.addConstant(simd_op.result_ty, Value.undef);
cur_minmax = try sema.addConstUndef(simd_op.result_ty);
} else {
cur_minmax = try block.addBinOp(air_tag, simd_op.lhs, simd_op.rhs);
}
// Compute the bounds of this type
switch (bounds_status) {
.unknown, .defined => refine_bounds: {
const scalar_ty = sema.typeOf(rhs).scalarType(mod);
if (scalar_ty.isAnyFloat()) {
bounds_status = .non_integral;
break :refine_bounds;
}
const scalar_min = try scalar_ty.minInt(mod, scalar_ty);
const scalar_max = try scalar_ty.maxInt(mod, scalar_ty);
if (bounds_status == .unknown) {
cur_min_scalar = scalar_min;
cur_max_scalar = scalar_max;
bounds_status = .defined;
} else {
cur_min_scalar = opFunc(cur_min_scalar, scalar_min, mod);
cur_max_scalar = opFunc(cur_max_scalar, scalar_max, mod);
}
},
.non_integral => {},
}
}
if (comptime_refined_ty) |comptime_ty| refine: {
// Finally, refine the type based on the comptime-known bound.
if (known_undef) break :refine; // can't refine undef
const unrefined_ty = sema.typeOf(cur_minmax.?);
const is_vector = unrefined_ty.zigTypeTag(mod) == .Vector;
const comptime_elem_ty = if (is_vector) comptime_ty.childType(mod) else comptime_ty;
const unrefined_elem_ty = if (is_vector) unrefined_ty.childType(mod) else unrefined_ty;
// Finally, refine the type based on the known bounds.
const unrefined_ty = sema.typeOf(cur_minmax.?);
if (unrefined_ty.scalarType(mod).isAnyFloat()) {
// We can't refine floats, so we're done.
return cur_minmax.?;
}
assert(bounds_status == .defined); // there were integral runtime operands
const refined_scalar_ty = try mod.intFittingRange(cur_min_scalar, cur_max_scalar);
const refined_ty = if (unrefined_ty.isVector(mod)) try mod.vectorType(.{
.len = unrefined_ty.vectorLen(mod),
.child = refined_scalar_ty.toIntern(),
}) else refined_scalar_ty;
if (unrefined_elem_ty.isAnyFloat()) break :refine; // we can't refine floats
// Compute the final bounds based on the runtime type and the comptime-known bound type
const min_val = switch (air_tag) {
.min => try unrefined_elem_ty.minInt(mod, unrefined_elem_ty),
.max => try comptime_elem_ty.minInt(mod, comptime_elem_ty), // @max(ct, rt) >= ct
else => unreachable,
};
const max_val = switch (air_tag) {
.min => try comptime_elem_ty.maxInt(mod, comptime_elem_ty), // @min(ct, rt) <= ct
.max => try unrefined_elem_ty.maxInt(mod, unrefined_elem_ty),
else => unreachable,
};
// Find the smallest type which can contain these bounds
const final_elem_ty = try mod.intFittingRange(min_val, max_val);
const final_ty = if (is_vector)
try mod.vectorType(.{
.len = unrefined_ty.vectorLen(mod),
.child = final_elem_ty.toIntern(),
})
else
final_elem_ty;
if (!final_ty.eql(unrefined_ty, mod)) {
// We've reduced the type - cast the result down
return block.addTyOp(.intcast, final_ty, cur_minmax.?);
}
if (!refined_ty.eql(unrefined_ty, mod)) {
// We've reduced the type - cast the result down
return block.addTyOp(.intcast, refined_ty, cur_minmax.?);
}
return cur_minmax.?;

View File

@ -4146,6 +4146,20 @@ pub const Value = struct {
return val.toIntern() == .generic_poison;
}
/// Returns the comptime-known `[min, max]` bounds of the integer `val`, which may be a
/// `comptime_int` or a fixed-width integer.
/// - When `val` is not undef, both bounds are `val` itself.
/// - When `val` is undef with a fixed-width type, the bounds are that type's minimum and maximum.
/// - When `val` is an undef `comptime_int`, no bounds exist and null is returned.
pub fn intValueBounds(val: Value, mod: *Module) !?[2]Value {
    if (val.isUndef(mod)) {
        const ip_ty = mod.intern_pool.typeOf(val.toIntern());
        // An undef comptime_int carries no range information at all.
        if (ip_ty == .comptime_int_type) return null;
        const int_ty = ip_ty.toType();
        const lower = try int_ty.minInt(mod, int_ty);
        const upper = try int_ty.maxInt(mod, int_ty);
        return .{ lower, upper };
    }
    // A defined value is its own lower and upper bound.
    return .{ val, val };
}
/// This type is not copyable since it may contain pointers to its inner data.
pub const Payload = struct {
tag: Tag,

View File

@ -1,6 +1,7 @@
const std = @import("std");
const builtin = @import("builtin");
const mem = std.mem;
const assert = std.debug.assert;
const expect = std.testing.expect;
const expectEqual = std.testing.expectEqual;
@ -210,3 +211,87 @@ test "@min/@max on comptime_int" {
try expectEqual(-2, min);
try expectEqual(2, max);
}
// Checks the result types and values of @min/@max applied to three scalar integer
// variables of different widths (u16, u32, u8).
test "@min/@max notices bounds from types" {
// Skip this test on the backends listed below.
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
var x: u16 = 123;
var y: u32 = 456;
var z: u8 = 10;
const min = @min(x, y, z);
const max = @max(x, y, z);
// The minimum is expected to take the narrowest operand type (u8) and the
// maximum the widest operand type (u32).
comptime assert(@TypeOf(min) == u8);
comptime assert(@TypeOf(max) == u32);
// z holds the smallest value and y the largest.
try expectEqual(z, min);
try expectEqual(y, max);
}
// Vector counterpart of the scalar bounds test: applies @min/@max to three 2-element
// vectors whose element types are u16, u32 and u8, then checks result types and values.
test "@min/@max notices bounds from vector types" {
// Skip this test on the backends listed below.
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
var x: @Vector(2, u16) = .{ 30, 67 };
var y: @Vector(2, u32) = .{ 20, 500 };
var z: @Vector(2, u8) = .{ 60, 15 };
const min = @min(x, y, z);
const max = @max(x, y, z);
// The element type of the minimum is expected to be the narrowest operand element
// type (u8); the element type of the maximum the widest (u32).
comptime assert(@TypeOf(min) == @Vector(2, u8));
comptime assert(@TypeOf(max) == @Vector(2, u32));
// Expected results are element-wise: min(30, 20, 60) = 20, min(67, 500, 15) = 15,
// max(30, 20, 60) = 60, max(67, 500, 15) = 500.
try expectEqual(@Vector(2, u8){ 20, 15 }, min);
try expectEqual(@Vector(2, u32){ 60, 500 }, max);
}
// Checks only the result types of @min/@max when one operand is a u32 variable and the
// other is a u16 constant initialised to `undefined`.
test "@min/@max notices bounds from types when comptime-known value is undef" {
// Skip this test on the backends listed below.
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
var x: u32 = 1_000_000;
const y: u16 = undefined;
// y is comptime-known, but is undef, so bounds cannot be refined using its value
const min = @min(x, y);
const max = @max(x, y);
// With no usable value for y, the expected result types follow the operand types:
// u16 for the minimum and u32 for the maximum.
comptime assert(@TypeOf(min) == u16);
comptime assert(@TypeOf(max) == u32);
// Cannot assert values as one was undefined
}
// Checks @min/@max of a @Vector(2, u32) variable and a @Vector(2, u16) constant whose
// second element is `undefined`: result types are asserted, and values only at index 0.
test "@min/@max notices bounds from vector types when element of comptime-known vector is undef" {
// Skip this test on the backends listed below.
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
var x: @Vector(2, u32) = .{ 1_000_000, 12345 };
const y: @Vector(2, u16) = .{ 10, undefined };
// y is comptime-known, but an element is undef, so bounds cannot be refined using its value
const min = @min(x, y);
const max = @max(x, y);
// The expected element types follow the operand element types: u16 for the minimum
// and u32 for the maximum.
comptime assert(@TypeOf(min) == @Vector(2, u16));
comptime assert(@TypeOf(max) == @Vector(2, u32));
// Index 0 is fully defined: min(1_000_000, 10) = 10 and max(1_000_000, 10) = 1_000_000.
try expectEqual(@as(u16, 10), min[0]);
try expectEqual(@as(u32, 1_000_000), max[0]);
// Cannot assert values at index 1 as one was undefined
}