compiler: add intcast_safe AIR instruction

This instruction is like `intcast`, but includes two safety checks:

* Checks that the int is in range of the destination type
* If the destination type is an exhaustive enum, checks that the int
  is a named enum value

This instruction is locked behind the `safety_checked_instructions`
backend feature; if unsupported, Sema will emit a fallback, as with
other safety-checked instructions.

This instruction is used to add a missing safety check for `@enumFromInt`
truncating bits. This check also has a fallback for backends which do
not yet support `safety_checked_instructions`.

Resolves: #21946
This commit is contained in:
mlugg 2025-01-29 18:45:08 +00:00 committed by Matthew Lugg
parent c5e34df555
commit b01d6b156c
17 changed files with 249 additions and 29 deletions

View File

@ -574,6 +574,12 @@ pub const Inst = struct {
/// See `trunc` for integer truncation.
/// Uses the `ty_op` field.
intcast,
/// Like `intcast`, but includes two safety checks:
/// * triggers a safety panic if the cast truncates bits
/// * triggers a safety panic if the destination type is an exhaustive enum
/// and the operand is not a valid value of this type; i.e. equivalent to
/// a safety check based on `.is_named_enum_value`
intcast_safe,
/// Truncate higher bits from an integer, resulting in an integer with the same
/// sign but an equal or smaller number of bits.
/// Uses the `ty_op` field.
@ -1463,6 +1469,7 @@ pub fn typeOfIndex(air: *const Air, inst: Air.Inst.Index, ip: *const InternPool)
.fpext,
.fptrunc,
.intcast,
.intcast_safe,
.trunc,
.optional_payload,
.optional_payload_ptr,
@ -1712,6 +1719,7 @@ pub fn mustLower(air: Air, inst: Air.Inst.Index, ip: *const InternPool) bool {
.add_safe,
.sub_safe,
.mul_safe,
.intcast_safe,
=> true,
.add,

View File

@ -104,6 +104,7 @@ fn checkBody(air: Air, body: []const Air.Inst.Index, zcu: *Zcu) bool {
.fptrunc,
.fpext,
.intcast,
.intcast_safe,
.trunc,
.optional_payload,
.optional_payload_ptr,

View File

@ -345,6 +345,7 @@ pub fn categorizeOperand(
.fpext,
.fptrunc,
.intcast,
.intcast_safe,
.trunc,
.optional_payload,
.optional_payload_ptr,
@ -977,6 +978,7 @@ fn analyzeInst(
.fpext,
.fptrunc,
.intcast,
.intcast_safe,
.trunc,
.optional_payload,
.optional_payload_ptr,

View File

@ -81,6 +81,7 @@ fn verifyBody(self: *Verify, body: []const Air.Inst.Index) Error!void {
.fpext,
.fptrunc,
.intcast,
.intcast_safe,
.trunc,
.optional_payload,
.optional_payload_ptr,

View File

@ -8798,11 +8798,12 @@ fn zirEnumFromInt(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
const operand_src = block.builtinCallArgSrc(inst_data.src_node, 0);
const dest_ty = try sema.resolveDestType(block, src, extra.lhs, .remove_eu_opt, "@enumFromInt");
const operand = try sema.resolveInst(extra.rhs);
const operand_ty = sema.typeOf(operand);
if (dest_ty.zigTypeTag(zcu) != .@"enum") {
return sema.fail(block, src, "expected enum, found '{}'", .{dest_ty.fmt(pt)});
}
_ = try sema.checkIntType(block, operand_src, sema.typeOf(operand));
_ = try sema.checkIntType(block, operand_src, operand_ty);
if (try sema.resolveValue(operand)) |int_val| {
if (dest_ty.isNonexhaustiveEnum(zcu)) {
@ -8830,23 +8831,39 @@ fn zirEnumFromInt(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
}
if (try sema.typeHasOnePossibleValue(dest_ty)) |opv| {
const result = Air.internedToRef(opv.toIntern());
// The operand is runtime-known but the result is comptime-known. In
// this case we still need a safety check.
// TODO add a safety check here. we can't use is_named_enum_value -
// it needs to convert the enum back to int and make sure it equals the operand int.
return result;
if (block.wantSafety()) {
// The operand is runtime-known but the result is comptime-known. In
// this case we still need a safety check.
const expect_int_val = switch (zcu.intern_pool.indexToKey(opv.toIntern())) {
.enum_tag => |enum_tag| enum_tag.int,
else => unreachable,
};
const expect_int_coerced = try pt.getCoerced(.fromInterned(expect_int_val), operand_ty);
const ok = try block.addBinOp(.cmp_eq, operand, Air.internedToRef(expect_int_coerced.toIntern()));
try sema.addSafetyCheck(block, src, ok, .invalid_enum_value);
}
return Air.internedToRef(opv.toIntern());
}
try sema.requireRuntimeBlock(block, src, operand_src);
const result = try block.addTyOp(.intcast, dest_ty, operand);
if (block.wantSafety() and !dest_ty.isNonexhaustiveEnum(zcu) and
zcu.backendSupportsFeature(.is_named_enum_value))
{
const ok = try block.addUnOp(.is_named_enum_value, result);
try sema.addSafetyCheck(block, src, ok, .invalid_enum_value);
if (block.wantSafety()) {
if (zcu.backendSupportsFeature(.safety_checked_instructions)) {
_ = try sema.preparePanicId(src, .invalid_enum_value);
return block.addTyOp(.intcast_safe, dest_ty, operand);
} else {
// Slightly silly fallback case...
const int_tag_ty = dest_ty.intTagType(zcu);
// Use `intCast`, since it'll set up the Sema-emitted safety checks for us!
const int_val = try sema.intCast(block, src, int_tag_ty, src, operand, src, true, true);
const result = try block.addBitCast(dest_ty, int_val);
if (zcu.backendSupportsFeature(.is_named_enum_value)) {
const ok = try block.addUnOp(.is_named_enum_value, result);
try sema.addSafetyCheck(block, src, ok, .invalid_enum_value);
}
return result;
}
}
return result;
return block.addTyOp(.intcast, dest_ty, operand);
}
/// Pointer in, pointer out.
@ -10192,7 +10209,7 @@ fn zirIntCast(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air
const dest_ty = try sema.resolveDestType(block, src, extra.lhs, .remove_eu_opt, "@intCast");
const operand = try sema.resolveInst(extra.rhs);
return sema.intCast(block, block.nodeOffset(inst_data.src_node), dest_ty, src, operand, operand_src, true);
return sema.intCast(block, block.nodeOffset(inst_data.src_node), dest_ty, src, operand, operand_src, true, false);
}
fn intCast(
@ -10204,6 +10221,7 @@ fn intCast(
operand: Air.Inst.Ref,
operand_src: LazySrcLoc,
runtime_safety: bool,
safety_panics_are_enum: bool,
) CompileError!Air.Inst.Ref {
const pt = sema.pt;
const zcu = pt.zcu;
@ -10242,7 +10260,7 @@ fn intCast(
const is_in_range = try block.addBinOp(.cmp_lte, operand, zero_inst);
break :ok is_in_range;
};
try sema.addSafetyCheck(block, src, ok, .cast_truncated_data);
try sema.addSafetyCheck(block, src, ok, if (safety_panics_are_enum) .invalid_enum_value else .cast_truncated_data);
}
}
@ -10251,6 +10269,11 @@ fn intCast(
try sema.requireRuntimeBlock(block, src, operand_src);
if (runtime_safety and block.wantSafety()) {
if (zcu.backendSupportsFeature(.safety_checked_instructions)) {
_ = try sema.preparePanicId(src, .negative_to_unsigned);
_ = try sema.preparePanicId(src, .cast_truncated_data);
return block.addTyOp(.intcast_safe, dest_ty, operand);
}
const actual_info = operand_scalar_ty.intInfo(zcu);
const wanted_info = dest_scalar_ty.intInfo(zcu);
const actual_bits = actual_info.bits;
@ -10305,7 +10328,7 @@ fn intCast(
break :ok is_in_range;
};
// TODO negative_to_unsigned?
try sema.addSafetyCheck(block, src, ok, .cast_truncated_data);
try sema.addSafetyCheck(block, src, ok, if (safety_panics_are_enum) .invalid_enum_value else .cast_truncated_data);
} else {
const ok = if (is_vector) ok: {
const is_in_range = try block.addCmpVector(operand, dest_max, .lte);
@ -10321,7 +10344,7 @@ fn intCast(
const is_in_range = try block.addBinOp(.cmp_lte, operand, dest_max);
break :ok is_in_range;
};
try sema.addSafetyCheck(block, src, ok, .cast_truncated_data);
try sema.addSafetyCheck(block, src, ok, if (safety_panics_are_enum) .invalid_enum_value else .cast_truncated_data);
}
} else if (actual_info.signedness == .signed and wanted_info.signedness == .unsigned) {
// no shrinkage, yes sign loss
@ -10344,7 +10367,7 @@ fn intCast(
const is_in_range = try block.addBinOp(.cmp_gte, operand, zero_inst);
break :ok is_in_range;
};
try sema.addSafetyCheck(block, src, ok, .negative_to_unsigned);
try sema.addSafetyCheck(block, src, ok, if (safety_panics_are_enum) .invalid_enum_value else .negative_to_unsigned);
}
}
return block.addTyOp(.intcast, dest_ty, operand);
@ -14149,7 +14172,7 @@ fn zirShl(
{
const max_int = Air.internedToRef((try lhs_ty.maxInt(pt, lhs_ty)).toIntern());
const rhs_limited = try sema.analyzeMinMax(block, rhs_src, .min, &.{ rhs, max_int }, &.{ rhs_src, rhs_src });
break :rhs try sema.intCast(block, src, lhs_ty, rhs_src, rhs_limited, rhs_src, false);
break :rhs try sema.intCast(block, src, lhs_ty, rhs_src, rhs_limited, rhs_src, false, false);
} else {
break :rhs rhs;
}

View File

@ -3313,7 +3313,7 @@ pub fn addGlobalAssembly(zcu: *Zcu, unit: AnalUnit, source: []const u8) !void {
pub const Feature = enum {
/// When this feature is enabled, Sema will emit calls to
/// `std.builtin.Panic` functions for things like safety checks and
/// `std.builtin.panic` functions for things like safety checks and
/// unreachables. Otherwise traps will be emitted.
panic_fn,
/// When this feature is enabled, Sema will insert tracer functions for gathering a stack
@ -3329,6 +3329,7 @@ pub const Feature = enum {
/// * `Air.Inst.Tag.add_safe`
/// * `Air.Inst.Tag.sub_safe`
/// * `Air.Inst.Tag.mul_safe`
/// * `Air.Inst.Tag.intcast_safe`
/// The motivation for this feature is that it makes AIR smaller, and makes it easier
/// to generate better machine code in the backends. All backends should migrate to
/// enabling this feature.

View File

@ -871,6 +871,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
.add_safe,
.sub_safe,
.mul_safe,
.intcast_safe,
=> return self.fail("TODO implement safety_checked_instructions", .{}),
.is_named_enum_value => return self.fail("TODO implement is_named_enum_value", .{}),

View File

@ -860,6 +860,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
.add_safe,
.sub_safe,
.mul_safe,
.intcast_safe,
=> return self.fail("TODO implement safety_checked_instructions", .{}),
.is_named_enum_value => return self.fail("TODO implement is_named_enum_value", .{}),

View File

@ -1517,6 +1517,7 @@ fn genBody(func: *Func, body: []const Air.Inst.Index) InnerError!void {
.add_safe,
.sub_safe,
.mul_safe,
.intcast_safe,
=> return func.fail("TODO implement safety_checked_instructions", .{}),
.cmp_lt,

View File

@ -714,6 +714,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
.add_safe,
.sub_safe,
.mul_safe,
.intcast_safe,
=> @panic("TODO implement safety_checked_instructions"),
.is_named_enum_value => @panic("TODO implement is_named_enum_value"),

View File

@ -2084,6 +2084,7 @@ fn genInst(cg: *CodeGen, inst: Air.Inst.Index) InnerError!void {
.add_safe,
.sub_safe,
.mul_safe,
.intcast_safe,
=> return cg.fail("TODO implement safety_checked_instructions", .{}),
.work_item_id,

View File

@ -2549,6 +2549,7 @@ fn genBody(cg: *CodeGen, body: []const Air.Inst.Index) InnerError!void {
.add_safe,
.sub_safe,
.mul_safe,
.intcast_safe,
=> return cg.fail("TODO implement safety_checked_instructions", .{}),
.add_optimized => try cg.airBinOp(inst, .add),

View File

@ -3436,6 +3436,7 @@ fn genBodyInner(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail,
.add_safe,
.sub_safe,
.mul_safe,
.intcast_safe,
=> return f.fail("TODO implement safety_checked_instructions", .{}),
.is_named_enum_value => return f.fail("TODO: C backend: implement is_named_enum_value", .{}),

View File

@ -5162,7 +5162,8 @@ pub const FuncGen = struct {
.try_cold => try self.airTry(body[i..], true),
.try_ptr => try self.airTryPtr(inst, false),
.try_ptr_cold => try self.airTryPtr(inst, true),
.intcast => try self.airIntCast(inst),
.intcast => try self.airIntCast(inst, false),
.intcast_safe => try self.airIntCast(inst, true),
.trunc => try self.airTrunc(inst),
.fptrunc => try self.airFptrunc(inst),
.fpext => try self.airFpext(inst),
@ -9246,20 +9247,110 @@ pub const FuncGen = struct {
}
}
fn airIntCast(self: *FuncGen, inst: Air.Inst.Index) !Builder.Value {
const o = self.ng.object;
fn airIntCast(fg: *FuncGen, inst: Air.Inst.Index, safety: bool) !Builder.Value {
const o = fg.ng.object;
const zcu = o.pt.zcu;
const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
const dest_ty = self.typeOfIndex(inst);
const ty_op = fg.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
const dest_ty = fg.typeOfIndex(inst);
const dest_llvm_ty = try o.lowerType(dest_ty);
const operand = try self.resolveInst(ty_op.operand);
const operand_ty = self.typeOf(ty_op.operand);
const operand = try fg.resolveInst(ty_op.operand);
const operand_ty = fg.typeOf(ty_op.operand);
const operand_info = operand_ty.intInfo(zcu);
return self.wip.conv(switch (operand_info.signedness) {
const dest_is_enum = dest_ty.zigTypeTag(zcu) == .@"enum";
safety: {
if (!safety) break :safety;
const dest_scalar = dest_ty.scalarType(zcu);
const operand_scalar = operand_ty.scalarType(zcu);
const dest_info = dest_ty.intInfo(zcu);
const have_min_check, const have_max_check = c: {
const dest_pos_bits = dest_info.bits - @intFromBool(dest_info.signedness == .signed);
const operand_pos_bits = operand_info.bits - @intFromBool(operand_info.signedness == .signed);
const dest_allows_neg = dest_info.signedness == .signed and dest_info.bits > 0;
const operand_maybe_neg = operand_info.signedness == .signed and operand_info.bits > 0;
break :c .{
operand_maybe_neg and (!dest_allows_neg or dest_info.bits < operand_info.bits),
dest_pos_bits < operand_pos_bits,
};
};
if (!have_min_check and !have_max_check) break :safety;
const operand_llvm_ty = try o.lowerType(operand_ty);
const operand_scalar_llvm_ty = try o.lowerType(operand_scalar);
const is_vector = operand_ty.zigTypeTag(zcu) == .vector;
assert(is_vector == (dest_ty.zigTypeTag(zcu) == .vector));
const min_panic_id: Zcu.SimplePanicId, const max_panic_id: Zcu.SimplePanicId = id: {
if (dest_is_enum) break :id .{ .invalid_enum_value, .invalid_enum_value };
if (dest_info.signedness == .unsigned) break :id .{ .negative_to_unsigned, .cast_truncated_data };
break :id .{ .cast_truncated_data, .cast_truncated_data };
};
if (have_min_check) {
const min_const_scalar = try minIntConst(&o.builder, dest_scalar, operand_scalar_llvm_ty, zcu);
const min_val = if (is_vector) try o.builder.splatValue(operand_llvm_ty, min_const_scalar) else min_const_scalar.toValue();
const ok_maybe_vec = try fg.cmp(.normal, .gte, operand_ty, operand, min_val);
const ok = if (is_vector) ok: {
const vec_ty = ok_maybe_vec.typeOfWip(&fg.wip);
break :ok try fg.wip.callIntrinsic(.normal, .none, .@"vector.reduce.and", &.{vec_ty}, &.{ok_maybe_vec}, "");
} else ok_maybe_vec;
const fail_block = try fg.wip.block(1, "IntMinFail");
const ok_block = try fg.wip.block(1, "IntMinOk");
_ = try fg.wip.brCond(ok, ok_block, fail_block, .none);
fg.wip.cursor = .{ .block = fail_block };
try fg.buildSimplePanic(min_panic_id);
fg.wip.cursor = .{ .block = ok_block };
}
if (have_max_check) {
const max_const_scalar = try maxIntConst(&o.builder, dest_scalar, operand_scalar_llvm_ty, zcu);
const max_val = if (is_vector) try o.builder.splatValue(operand_llvm_ty, max_const_scalar) else max_const_scalar.toValue();
const ok_maybe_vec = try fg.cmp(.normal, .lte, operand_ty, operand, max_val);
const ok = if (is_vector) ok: {
const vec_ty = ok_maybe_vec.typeOfWip(&fg.wip);
break :ok try fg.wip.callIntrinsic(.normal, .none, .@"vector.reduce.and", &.{vec_ty}, &.{ok_maybe_vec}, "");
} else ok_maybe_vec;
const fail_block = try fg.wip.block(1, "IntMaxFail");
const ok_block = try fg.wip.block(1, "IntMaxOk");
_ = try fg.wip.brCond(ok, ok_block, fail_block, .none);
fg.wip.cursor = .{ .block = fail_block };
try fg.buildSimplePanic(max_panic_id);
fg.wip.cursor = .{ .block = ok_block };
}
}
const result = try fg.wip.conv(switch (operand_info.signedness) {
.signed => .signed,
.unsigned => .unsigned,
}, operand, dest_llvm_ty, "");
if (safety and dest_is_enum and !dest_ty.isNonexhaustiveEnum(zcu)) {
const llvm_fn = try fg.getIsNamedEnumValueFunction(dest_ty);
const is_valid_enum_val = try fg.wip.call(
.normal,
.fastcc,
.none,
llvm_fn.typeOf(&o.builder),
llvm_fn.toValue(&o.builder),
&.{result},
"",
);
const fail_block = try fg.wip.block(1, "ValidEnumFail");
const ok_block = try fg.wip.block(1, "ValidEnumOk");
_ = try fg.wip.brCond(is_valid_enum_val, ok_block, fail_block, .none);
fg.wip.cursor = .{ .block = fail_block };
try fg.buildSimplePanic(.invalid_enum_value);
fg.wip.cursor = .{ .block = ok_block };
}
return result;
}
fn airTrunc(self: *FuncGen, inst: Air.Inst.Index) !Builder.Value {
@ -12953,3 +13044,42 @@ pub fn initializeLLVMTarget(arch: std.Target.Cpu.Arch) void {
=> unreachable,
}
}
fn minIntConst(b: *Builder, min_ty: Type, as_ty: Builder.Type, zcu: *const Zcu) Allocator.Error!Builder.Constant {
const info = min_ty.intInfo(zcu);
if (info.signedness == .unsigned or info.bits == 0) {
return b.intConst(as_ty, 0);
}
if (std.math.cast(u6, info.bits - 1)) |shift| {
const min_val: i64 = @as(i64, std.math.minInt(i64)) >> (63 - shift);
return b.intConst(as_ty, min_val);
}
var res: std.math.big.int.Managed = try .init(zcu.gpa);
defer res.deinit();
try res.setTwosCompIntLimit(.min, info.signedness, info.bits);
return b.bigIntConst(as_ty, res.toConst());
}
fn maxIntConst(b: *Builder, max_ty: Type, as_ty: Builder.Type, zcu: *const Zcu) Allocator.Error!Builder.Constant {
const info = max_ty.intInfo(zcu);
switch (info.bits) {
0 => return b.intConst(as_ty, 0),
1 => switch (info.signedness) {
.signed => return b.intConst(as_ty, 0),
.unsigned => return b.intConst(as_ty, 1),
},
else => {},
}
const unsigned_bits = switch (info.signedness) {
.unsigned => info.bits,
.signed => info.bits - 1,
};
if (std.math.cast(u6, unsigned_bits)) |shift| {
const max_val: u64 = (@as(u64, 1) << shift) - 1;
return b.intConst(as_ty, max_val);
}
var res: std.math.big.int.Managed = try .init(zcu.gpa);
defer res.deinit();
try res.setTwosCompIntLimit(.max, info.signedness, info.bits);
return b.bigIntConst(as_ty, res.toConst());
}

View File

@ -223,6 +223,7 @@ const Writer = struct {
.fptrunc,
.fpext,
.intcast,
.intcast_safe,
.trunc,
.optional_payload,
.optional_payload_ptr,

View File

@ -0,0 +1,23 @@
const std = @import("std");
pub fn panic(message: []const u8, _: ?*std.builtin.StackTrace, _: ?usize) noreturn {
if (std.mem.eql(u8, message, "invalid enum value")) {
std.process.exit(0);
}
std.process.exit(1);
}
pub fn main() u8 {
var num: i5 = undefined;
num = 14;
const E = enum(u3) { a, b, c, d, e, f, g, h };
const invalid: E = @enumFromInt(num);
_ = invalid;
return 1;
}
// run
// backend=llvm
// target=native

View File

@ -0,0 +1,23 @@
const std = @import("std");
pub fn panic(message: []const u8, _: ?*std.builtin.StackTrace, _: ?usize) noreturn {
if (std.mem.eql(u8, message, "invalid enum value")) {
std.process.exit(0);
}
std.process.exit(1);
}
pub fn main() u8 {
var num: u8 = undefined;
num = 250;
const E = enum(u6) { _ };
const invalid: E = @enumFromInt(num);
_ = invalid;
return 1;
}
// run
// backend=llvm
// target=native