Sema: consider type bounds when refining result type of @min/@max
I achieved this through a major refactor of the logic of analyzeMinMax. This change should be compatible with vectors of comptime_int, which Andrew said are supposed to work (but which currently do not).
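For illustration only, a minimal sketch of the user-visible effect, mirroring the scalar test added at the end of this diff: even when every operand is runtime-known, the result type of @min/@max is now refined from the operands' type bounds.

const std = @import("std");

test "result type refined from operand type bounds (sketch)" {
    var x: u16 = 123;
    var y: u32 = 456;
    var z: u8 = 10;
    // No operand value is comptime-known; only the operand types bound the result.
    const min = @min(x, y, z);
    const max = @max(x, y, z);
    comptime std.debug.assert(@TypeOf(min) == u8); // the minimum can never exceed u8's maximum
    comptime std.debug.assert(@TypeOf(max) == u32); // the maximum may need the full u32 range
    try std.testing.expectEqual(z, min);
    try std.testing.expectEqual(y, max);
}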
This commit is contained in:
parent 5d9e8f27d0
commit c4cc796695

252 src/Sema.zig
@@ -22984,104 +22984,127 @@ fn analyzeMinMax(
else => @compileError("unreachable"),
};

// First, find all comptime-known arguments, and get their min/max
// The set of runtime-known operands. Set up in the loop below.
var runtime_known = try std.DynamicBitSet.initFull(sema.arena, operands.len);
// The current minmax value - initially this will always be comptime-known, then we'll add
// runtime values into the mix later.
var cur_minmax: ?Air.Inst.Ref = null;
var cur_minmax_src: LazySrcLoc = undefined; // defined if cur_minmax not null
// The current known scalar bounds of the value.
var bounds_status: enum {
unknown, // We've only seen undef comptime_ints so far, so do not know the bounds.
defined, // We've seen only integers, so the bounds are defined.
non_integral, // There are floats in the mix, so the bounds aren't defined.
} = .unknown;
var cur_min_scalar: Value = undefined;
var cur_max_scalar: Value = undefined;

// First, find all comptime-known arguments, and get their min/max

for (operands, operand_srcs, 0..) |operand, operand_src, operand_idx| {
// Resolve the value now to avoid redundant calls to `checkSimdBinOp` - we'll have to call
// it in the runtime path anyway since the result type may have been refined
const uncasted_operand_val = (try sema.resolveMaybeUndefVal(operand)) orelse continue;
if (cur_minmax) |cur| {
const simd_op = try sema.checkSimdBinOp(block, src, cur, operand, cur_minmax_src, operand_src);
const cur_val = simd_op.lhs_val.?; // cur_minmax is comptime-known
const operand_val = simd_op.rhs_val.?; // we checked the operand was resolvable above
const unresolved_uncoerced_val = try sema.resolveMaybeUndefVal(operand) orelse continue;
const uncoerced_val = try sema.resolveLazyValue(unresolved_uncoerced_val);

runtime_known.unset(operand_idx);
runtime_known.unset(operand_idx);

if (cur_val.isUndef(mod)) continue; // result is also undef
if (operand_val.isUndef(mod)) {
cur_minmax = try sema.addConstUndef(simd_op.result_ty);
continue;
}

const resolved_cur_val = try sema.resolveLazyValue(cur_val);
const resolved_operand_val = try sema.resolveLazyValue(operand_val);

const vec_len = simd_op.len orelse {
const result_val = opFunc(resolved_cur_val, resolved_operand_val, mod);
cur_minmax = try sema.addConstant(simd_op.result_ty, result_val);
continue;
};
const elems = try sema.arena.alloc(InternPool.Index, vec_len);
for (elems, 0..) |*elem, i| {
const lhs_elem_val = try resolved_cur_val.elemValue(mod, i);
const rhs_elem_val = try resolved_operand_val.elemValue(mod, i);
elem.* = try opFunc(lhs_elem_val, rhs_elem_val, mod).intern(simd_op.scalar_ty, mod);
}
cur_minmax = try sema.addConstant(simd_op.result_ty, (try mod.intern(.{ .aggregate = .{
.ty = simd_op.result_ty.toIntern(),
.storage = .{ .elems = elems },
} })).toValue());
} else {
runtime_known.unset(operand_idx);
cur_minmax = try sema.addConstant(sema.typeOf(operand), uncasted_operand_val);
cur_minmax_src = operand_src;
switch (bounds_status) {
.unknown, .defined => refine_bounds: {
const ty = sema.typeOf(operand);
if (!ty.scalarType(mod).isInt(mod) and !ty.scalarType(mod).eql(Type.comptime_int, mod)) {
bounds_status = .non_integral;
break :refine_bounds;
}
const scalar_bounds: ?[2]Value = bounds: {
if (!ty.isVector(mod)) break :bounds try uncoerced_val.intValueBounds(mod);
var cur_bounds: [2]Value = try Value.intValueBounds(try uncoerced_val.elemValue(mod, 0), mod) orelse break :bounds null;
const len = try sema.usizeCast(block, src, ty.vectorLen(mod));
for (1..len) |i| {
const elem = try uncoerced_val.elemValue(mod, i);
const elem_bounds = try elem.intValueBounds(mod) orelse break :bounds null;
cur_bounds = .{
Value.numberMin(elem_bounds[0], cur_bounds[0], mod),
Value.numberMax(elem_bounds[1], cur_bounds[1], mod),
};
}
break :bounds cur_bounds;
};
if (scalar_bounds) |bounds| {
if (bounds_status == .unknown) {
cur_min_scalar = bounds[0];
cur_max_scalar = bounds[1];
bounds_status = .defined;
} else {
cur_min_scalar = opFunc(cur_min_scalar, bounds[0], mod);
cur_max_scalar = opFunc(cur_max_scalar, bounds[1], mod);
}
}
},
.non_integral => {},
}

const cur = cur_minmax orelse {
cur_minmax = operand;
cur_minmax_src = operand_src;
continue;
};

const simd_op = try sema.checkSimdBinOp(block, src, cur, operand, cur_minmax_src, operand_src);
const cur_val = try sema.resolveLazyValue(simd_op.lhs_val.?); // cur_minmax is comptime-known
const operand_val = try sema.resolveLazyValue(simd_op.rhs_val.?); // we checked the operand was resolvable above

const vec_len = simd_op.len orelse {
const result_val = opFunc(cur_val, operand_val, mod);
cur_minmax = try sema.addConstant(simd_op.result_ty, result_val);
continue;
};
const elems = try sema.arena.alloc(InternPool.Index, vec_len);
for (elems, 0..) |*elem, i| {
const lhs_elem_val = try cur_val.elemValue(mod, i);
const rhs_elem_val = try operand_val.elemValue(mod, i);
const uncoerced_elem = opFunc(lhs_elem_val, rhs_elem_val, mod);
elem.* = (try mod.getCoerced(uncoerced_elem, simd_op.scalar_ty)).toIntern();
}
cur_minmax = try sema.addConstant(simd_op.result_ty, (try mod.intern(.{ .aggregate = .{
.ty = simd_op.result_ty.toIntern(),
.storage = .{ .elems = elems },
} })).toValue());
}

const opt_runtime_idx = runtime_known.findFirstSet();

const comptime_refined_ty: ?Type = if (cur_minmax) |ct_minmax_ref| refined: {
// Refine the comptime-known result type based on the operation
if (cur_minmax) |ct_minmax_ref| refine: {
// Refine the comptime-known result type based on the bounds. This isn't strictly necessary
// in the runtime case, since we'll refine the type again later, but keeping things as small
// as possible will allow us to emit more optimal AIR (if all the runtime operands have
// smaller types than the non-refined comptime type).

const val = (try sema.resolveMaybeUndefVal(ct_minmax_ref)).?;
const orig_ty = sema.typeOf(ct_minmax_ref);

if (opt_runtime_idx == null and orig_ty.eql(Type.comptime_int, mod)) {
if (opt_runtime_idx == null and orig_ty.scalarType(mod).eql(Type.comptime_int, mod)) {
// If all arguments were `comptime_int`, and there are no runtime args, we'll preserve that type
break :refined orig_ty;
break :refine;
}

const refined_ty = if (orig_ty.zigTypeTag(mod) == .Vector) blk: {
const elem_ty = orig_ty.childType(mod);
const len = orig_ty.vectorLen(mod);
// We can't refine float types
if (orig_ty.scalarType(mod).isAnyFloat()) break :refine;

if (len == 0) break :blk orig_ty;
if (elem_ty.isAnyFloat()) break :blk orig_ty; // can't refine floats
assert(bounds_status == .defined); // there was a non-comptime-int integral comptime-known arg

var cur_min: Value = try val.elemValue(mod, 0);
var cur_max: Value = cur_min;
for (1..len) |idx| {
const elem_val = try val.elemValue(mod, idx);
if (elem_val.isUndef(mod)) break :blk orig_ty; // can't refine undef
if (Value.order(elem_val, cur_min, mod).compare(.lt)) cur_min = elem_val;
if (Value.order(elem_val, cur_max, mod).compare(.gt)) cur_max = elem_val;
}
const refined_scalar_ty = try mod.intFittingRange(cur_min_scalar, cur_max_scalar);
const refined_ty = if (orig_ty.isVector(mod)) try mod.vectorType(.{
.len = orig_ty.vectorLen(mod),
.child = refined_scalar_ty.toIntern(),
}) else refined_scalar_ty;

const refined_elem_ty = try mod.intFittingRange(cur_min, cur_max);
break :blk try mod.vectorType(.{
.len = len,
.child = refined_elem_ty.toIntern(),
});
} else blk: {
if (orig_ty.isAnyFloat()) break :blk orig_ty; // can't refine floats
if (val.isUndef(mod)) break :blk orig_ty; // can't refine undef
break :blk try mod.intFittingRange(val, val);
};

// Apply the refined type to the current value - this isn't strictly necessary in the
// runtime case since we'll refine again afterwards, but keeping things as small as possible
// will allow us to emit more optimal AIR (if all the runtime operands have smaller types
// than the non-refined comptime type).
if (!refined_ty.eql(orig_ty, mod)) {
if (std.debug.runtime_safety) {
assert(try sema.intFitsInType(val, refined_ty, null));
}
cur_minmax = try sema.coerceInMemory(val, refined_ty);
// Apply the refined type to the current value
if (std.debug.runtime_safety) {
assert(try sema.intFitsInType(val, refined_ty, null));
}

break :refined refined_ty;
} else null;
cur_minmax = try sema.coerceInMemory(val, refined_ty);
}

const runtime_idx = opt_runtime_idx orelse return cur_minmax.?;
const runtime_src = operand_srcs[runtime_idx];
@@ -23102,6 +23125,11 @@ fn analyzeMinMax(
cur_minmax = operands[0];
cur_minmax_src = runtime_src;
runtime_known.unset(0); // don't look at this operand in the loop below
const scalar_ty = sema.typeOf(cur_minmax.?).scalarType(mod);
if (scalar_ty.isInt(mod)) {
cur_min_scalar = try scalar_ty.minInt(mod, scalar_ty);
cur_max_scalar = try scalar_ty.maxInt(mod, scalar_ty);
}
}

var it = runtime_known.iterator(.{});
@@ -23112,49 +23140,49 @@ fn analyzeMinMax(
const rhs_src = operand_srcs[idx];
const simd_op = try sema.checkSimdBinOp(block, src, lhs, rhs, lhs_src, rhs_src);
if (known_undef) {
cur_minmax = try sema.addConstant(simd_op.result_ty, Value.undef);
cur_minmax = try sema.addConstUndef(simd_op.result_ty);
} else {
cur_minmax = try block.addBinOp(air_tag, simd_op.lhs, simd_op.rhs);
}
// Compute the bounds of this type
switch (bounds_status) {
.unknown, .defined => refine_bounds: {
const scalar_ty = sema.typeOf(rhs).scalarType(mod);
if (scalar_ty.isAnyFloat()) {
bounds_status = .non_integral;
break :refine_bounds;
}
const scalar_min = try scalar_ty.minInt(mod, scalar_ty);
const scalar_max = try scalar_ty.maxInt(mod, scalar_ty);
if (bounds_status == .unknown) {
cur_min_scalar = scalar_min;
cur_max_scalar = scalar_max;
bounds_status = .defined;
} else {
cur_min_scalar = opFunc(cur_min_scalar, scalar_min, mod);
cur_max_scalar = opFunc(cur_max_scalar, scalar_max, mod);
}
},
.non_integral => {},
}
}

if (comptime_refined_ty) |comptime_ty| refine: {
// Finally, refine the type based on the comptime-known bound.
if (known_undef) break :refine; // can't refine undef
const unrefined_ty = sema.typeOf(cur_minmax.?);
const is_vector = unrefined_ty.zigTypeTag(mod) == .Vector;
const comptime_elem_ty = if (is_vector) comptime_ty.childType(mod) else comptime_ty;
const unrefined_elem_ty = if (is_vector) unrefined_ty.childType(mod) else unrefined_ty;
// Finally, refine the type based on the known bounds.
const unrefined_ty = sema.typeOf(cur_minmax.?);
if (unrefined_ty.scalarType(mod).isAnyFloat()) {
// We can't refine floats, so we're done.
return cur_minmax.?;
}
assert(bounds_status == .defined); // there were integral runtime operands
const refined_scalar_ty = try mod.intFittingRange(cur_min_scalar, cur_max_scalar);
const refined_ty = if (unrefined_ty.isVector(mod)) try mod.vectorType(.{
.len = unrefined_ty.vectorLen(mod),
.child = refined_scalar_ty.toIntern(),
}) else refined_scalar_ty;

if (unrefined_elem_ty.isAnyFloat()) break :refine; // we can't refine floats

// Compute the final bounds based on the runtime type and the comptime-known bound type
const min_val = switch (air_tag) {
.min => try unrefined_elem_ty.minInt(mod, unrefined_elem_ty),
.max => try comptime_elem_ty.minInt(mod, comptime_elem_ty), // @max(ct, rt) >= ct
else => unreachable,
};
const max_val = switch (air_tag) {
.min => try comptime_elem_ty.maxInt(mod, comptime_elem_ty), // @min(ct, rt) <= ct
.max => try unrefined_elem_ty.maxInt(mod, unrefined_elem_ty),
else => unreachable,
};

// Find the smallest type which can contain these bounds
const final_elem_ty = try mod.intFittingRange(min_val, max_val);

const final_ty = if (is_vector)
try mod.vectorType(.{
.len = unrefined_ty.vectorLen(mod),
.child = final_elem_ty.toIntern(),
})
else
final_elem_ty;

if (!final_ty.eql(unrefined_ty, mod)) {
// We've reduced the type - cast the result down
return block.addTyOp(.intcast, final_ty, cur_minmax.?);
}
if (!refined_ty.eql(unrefined_ty, mod)) {
// We've reduced the type - cast the result down
return block.addTyOp(.intcast, refined_ty, cur_minmax.?);
}

return cur_minmax.?;
@@ -4146,6 +4146,20 @@ pub const Value = struct {
return val.toIntern() == .generic_poison;
}

/// For an integer (comptime or fixed-width) `val`, returns the comptime-known bounds of the value.
/// If `val` is not undef, the bounds are both `val`.
/// If `val` is undef and has a fixed-width type, the bounds are the bounds of the type.
/// If `val` is undef and is a `comptime_int`, returns null.
pub fn intValueBounds(val: Value, mod: *Module) !?[2]Value {
if (!val.isUndef(mod)) return .{ val, val };
const ty = mod.intern_pool.typeOf(val.toIntern());
if (ty == .comptime_int_type) return null;
return .{
try ty.toType().minInt(mod, ty.toType()),
try ty.toType().maxInt(mod, ty.toType()),
};
}

/// This type is not copyable since it may contain pointers to its inner data.
pub const Payload = struct {
tag: Tag,
@@ -1,6 +1,7 @@
const std = @import("std");
const builtin = @import("builtin");
const mem = std.mem;
const assert = std.debug.assert;
const expect = std.testing.expect;
const expectEqual = std.testing.expectEqual;
@@ -210,3 +211,87 @@ test "@min/@max on comptime_int" {
try expectEqual(-2, min);
try expectEqual(2, max);
}

test "@min/@max notices bounds from types" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;

var x: u16 = 123;
var y: u32 = 456;
var z: u8 = 10;

const min = @min(x, y, z);
const max = @max(x, y, z);

comptime assert(@TypeOf(min) == u8);
comptime assert(@TypeOf(max) == u32);

try expectEqual(z, min);
try expectEqual(y, max);
}

test "@min/@max notices bounds from vector types" {
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;

var x: @Vector(2, u16) = .{ 30, 67 };
var y: @Vector(2, u32) = .{ 20, 500 };
var z: @Vector(2, u8) = .{ 60, 15 };

const min = @min(x, y, z);
const max = @max(x, y, z);

comptime assert(@TypeOf(min) == @Vector(2, u8));
comptime assert(@TypeOf(max) == @Vector(2, u32));

try expectEqual(@Vector(2, u8){ 20, 15 }, min);
try expectEqual(@Vector(2, u32){ 60, 500 }, max);
}

test "@min/@max notices bounds from types when comptime-known value is undef" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;

var x: u32 = 1_000_000;
const y: u16 = undefined;
// y is comptime-known, but is undef, so bounds cannot be refined using its value

const min = @min(x, y);
const max = @max(x, y);

comptime assert(@TypeOf(min) == u16);
comptime assert(@TypeOf(max) == u32);

// Cannot assert values as one was undefined
}

test "@min/@max notices bounds from vector types when element of comptime-known vector is undef" {
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;

var x: @Vector(2, u32) = .{ 1_000_000, 12345 };
const y: @Vector(2, u16) = .{ 10, undefined };
// y is comptime-known, but an element is undef, so bounds cannot be refined using its value

const min = @min(x, y);
const max = @max(x, y);

comptime assert(@TypeOf(min) == @Vector(2, u16));
comptime assert(@TypeOf(max) == @Vector(2, u32));

try expectEqual(@as(u16, 10), min[0]);
try expectEqual(@as(u32, 1_000_000), max[0]);
// Cannot assert values at index 1 as one was undefined
}