stage2: add runtime safety for invalid enum values

This commit is contained in:
Veikka Tuominen 2022-08-05 18:15:31 +03:00
parent 19d5ffc710
commit f46d7304b1
15 changed files with 131 additions and 7 deletions

View File

@ -660,6 +660,10 @@ pub const Inst = struct {
/// Uses the `pl_op` field with payload `AtomicRmw`. Operand is `ptr`.
atomic_rmw,
/// Given an enum tag value, returns true if the value has a name.
/// Result type is always `bool`.
/// Uses the `un_op` field.
is_named_enum_value,
/// Given an enum tag value, returns the tag name. The enum type may be non-exhaustive.
/// Result type is always `[:0]const u8`.
/// Uses the `un_op` field.
@ -1057,6 +1061,7 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index) Type {
.is_non_err,
.is_err_ptr,
.is_non_err_ptr,
.is_named_enum_value,
=> return Type.bool,
.const_ty => return Type.type,

View File

@ -291,6 +291,7 @@ pub fn categorizeOperand(
.is_non_err_ptr,
.ptrtoint,
.bool_to_int,
.is_named_enum_value,
.tag_name,
.error_name,
.sqrt,
@ -858,6 +859,7 @@ fn analyzeInst(
.bool_to_int,
.ret,
.ret_load,
.is_named_enum_value,
.tag_name,
.error_name,
.sqrt,

View File

@ -6933,8 +6933,12 @@ fn zirIntToEnum(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!A
}
try sema.requireRuntimeBlock(block, src, operand_src);
// TODO insert safety check to make sure the value matches an enum value
return block.addTyOp(.intcast, dest_ty, operand);
const result = try block.addTyOp(.intcast, dest_ty, operand);
if (block.wantSafety() and !dest_ty.isNonexhaustiveEnum() and sema.mod.comp.bin_file.options.use_llvm) {
const ok = try block.addUnOp(.is_named_enum_value, result);
try sema.addSafetyCheck(block, ok, .invalid_enum_value);
}
return result;
}
/// Pointer in, pointer out.
@ -15887,6 +15891,11 @@ fn zirTagName(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air
const field_name = enum_ty.enumFieldName(field_index);
return sema.addStrLit(block, field_name);
}
try sema.requireRuntimeBlock(block, src, operand_src);
if (block.wantSafety() and sema.mod.comp.bin_file.options.use_llvm) {
const ok = try block.addUnOp(.is_named_enum_value, casted_operand);
try sema.addSafetyCheck(block, ok, .invalid_enum_value);
}
// In case the value is runtime-known, we have an AIR instruction for this instead
// of trying to lower it in Sema because an optimization pass may result in the operand
// being comptime-known, which would let us elide the `tag_name` AIR instruction.
@ -20019,6 +20028,7 @@ pub const PanicId = enum {
integer_part_out_of_bounds,
corrupt_switch,
shift_rhs_too_big,
invalid_enum_value,
};
fn addSafetyCheck(
@ -20316,6 +20326,7 @@ fn safetyPanic(
.integer_part_out_of_bounds => "integer part of floating point value out of bounds",
.corrupt_switch => "switch on corrupt value",
.shift_rhs_too_big => "shift amount is greater than the type size",
.invalid_enum_value => "invalid enum value",
};
const msg_inst = msg_inst: {

View File

@ -753,6 +753,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
.float_to_int_optimized,
=> return self.fail("TODO implement optimized float mode", .{}),
.is_named_enum_value => return self.fail("TODO implement is_named_enum_value", .{}),
.wasm_memory_size => unreachable,
.wasm_memory_grow => unreachable,
// zig fmt: on

View File

@ -768,6 +768,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
.float_to_int_optimized,
=> return self.fail("TODO implement optimized float mode", .{}),
.is_named_enum_value => return self.fail("TODO implement is_named_enum_value", .{}),
.wasm_memory_size => unreachable,
.wasm_memory_grow => unreachable,
// zig fmt: on

View File

@ -693,6 +693,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
.float_to_int_optimized,
=> return self.fail("TODO implement optimized float mode", .{}),
.is_named_enum_value => return self.fail("TODO implement is_named_enum_value", .{}),
.wasm_memory_size => unreachable,
.wasm_memory_grow => unreachable,
// zig fmt: on

View File

@ -705,6 +705,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
.float_to_int_optimized,
=> @panic("TODO implement optimized float mode"),
.is_named_enum_value => @panic("TODO implement is_named_enum_value"),
.wasm_memory_size => unreachable,
.wasm_memory_grow => unreachable,
// zig fmt: on

View File

@ -1621,6 +1621,7 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue {
.tag_name,
.err_return_trace,
.set_err_return_trace,
.is_named_enum_value,
=> |tag| return self.fail("TODO: Implement wasm inst: {s}", .{@tagName(tag)}),
.add_optimized,

View File

@ -775,6 +775,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
.float_to_int_optimized,
=> return self.fail("TODO implement optimized float mode", .{}),
.is_named_enum_value => return self.fail("TODO implement is_named_enum_value", .{}),
.wasm_memory_size => unreachable,
.wasm_memory_grow => unreachable,
// zig fmt: on

View File

@ -1952,6 +1952,8 @@ fn genBody(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, OutO
.reduce_optimized,
.float_to_int_optimized,
=> return f.fail("TODO implement optimized float mode", .{}),
.is_named_enum_value => return f.fail("TODO: C backend: implement is_named_enum_value", .{}),
// zig fmt: on
};
switch (result_value) {

View File

@ -201,6 +201,8 @@ pub const Object = struct {
/// * it works for functions not all globals.
/// Therefore, this table keeps track of the mapping.
decl_map: std.AutoHashMapUnmanaged(Module.Decl.Index, *const llvm.Value),
/// Serves the same purpose as `decl_map` but only used for the `is_named_enum_value` instruction.
named_enum_map: std.AutoHashMapUnmanaged(Module.Decl.Index, *const llvm.Value),
/// Maps Zig types to LLVM types. The table memory itself is backed by the GPA of
/// the compiler, but the Type/Value memory here is backed by `type_map_arena`.
/// TODO we need to remove entries from this map in response to incremental compilation
@ -377,6 +379,7 @@ pub const Object = struct {
.target_data = target_data,
.target = options.target,
.decl_map = .{},
.named_enum_map = .{},
.type_map = .{},
.type_map_arena = std.heap.ArenaAllocator.init(gpa),
.di_type_map = .{},
@ -396,6 +399,7 @@ pub const Object = struct {
self.llvm_module.dispose();
self.context.dispose();
self.decl_map.deinit(gpa);
self.named_enum_map.deinit(gpa);
self.type_map.deinit(gpa);
self.type_map_arena.deinit();
self.extern_collisions.deinit(gpa);
@ -4180,6 +4184,8 @@ pub const FuncGen = struct {
.union_init => try self.airUnionInit(inst),
.prefetch => try self.airPrefetch(inst),
.is_named_enum_value => try self.airIsNamedEnumValue(inst),
.reduce => try self.airReduce(inst, false),
.reduce_optimized => try self.airReduce(inst, true),
@ -7882,6 +7888,87 @@ pub const FuncGen = struct {
}
}
/// Lowers the `is_named_enum_value` AIR instruction by calling the helper
/// function obtained from `getIsNamedEnumValueFunction` for the operand's type.
/// Returns `null` when liveness reports the instruction as unused.
fn airIsNamedEnumValue(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
    if (self.liveness.isUnused(inst)) return null;

    const operand_ref = self.air.instructions.items(.data)[inst].un_op;
    const tag_value = try self.resolveInst(operand_ref);

    // The helper is looked up (and possibly emitted) per operand type.
    const check_fn = try self.getIsNamedEnumValueFunction(self.air.typeOf(operand_ref));

    // The calling convention must match the one set on the helper (`.Fast`).
    const call_args = [_]*const llvm.Value{tag_value};
    return self.builder.buildCall(check_fn, &call_args, call_args.len, .Fast, .Auto, "");
}
/// Returns the LLVM helper function that reports whether an integer tag value
/// matches one of the fields of `enum_ty`, emitting the helper on first use.
///
/// Helpers are cached in the object's `named_enum_map`, keyed by the enum's
/// owner decl. The emitted function takes the enum's integer tag as its only
/// parameter and returns a `bool`: it switches on the tag and branches to a
/// block returning 1 for every field value, or to a block returning 0 otherwise.
fn getIsNamedEnumValueFunction(self: *FuncGen, enum_ty: Type) !*const llvm.Value {
    const enum_decl = enum_ty.getOwnerDecl();

    // TODO: detect when the type changes and re-emit this function.
    const gop = try self.dg.object.named_enum_map.getOrPut(self.dg.gpa, enum_decl);
    if (gop.found_existing) return gop.value_ptr.*;
    // Do not leave a cache entry behind if anything below fails.
    errdefer assert(self.dg.object.named_enum_map.remove(enum_decl));

    var arena_allocator = std.heap.ArenaAllocator.init(self.gpa);
    defer arena_allocator.deinit();
    const arena = arena_allocator.allocator();

    const mod = self.dg.module;
    // `getFullyQualifiedName` is not given the arena, so the returned buffer is
    // caller-owned and must be released explicitly; otherwise it leaks once per
    // emitted helper.
    // NOTE(review): assumes the name is allocated with the same gpa as `self.gpa`.
    const fqn = try mod.declPtr(enum_decl).getFullyQualifiedName(mod);
    defer self.gpa.free(fqn);
    const llvm_fn_name = try std.fmt.allocPrintZ(arena, "__zig_is_named_enum_value_{s}", .{fqn});

    var int_tag_type_buffer: Type.Payload.Bits = undefined;
    const int_tag_ty = enum_ty.intTagType(&int_tag_type_buffer);

    const param_types = [_]*const llvm.Type{try self.dg.lowerType(int_tag_ty)};
    const llvm_ret_ty = try self.dg.lowerType(Type.bool);
    const fn_type = llvm.functionType(llvm_ret_ty, &param_types, param_types.len, .False);

    const fn_val = self.dg.object.llvm_module.addFunction(llvm_fn_name, fn_type);
    fn_val.setLinkage(.Internal);
    fn_val.setFunctionCallConv(.Fast);
    self.dg.addCommonFnAttributes(fn_val);
    gop.value_ptr.* = fn_val;

    // Emitting the helper body moves the shared builder; restore its insertion
    // point (and debug location, when debug info is active) on every exit path.
    const prev_block = self.builder.getInsertBlock();
    const prev_debug_location = self.builder.getCurrentDebugLocation2();
    defer {
        self.builder.positionBuilderAtEnd(prev_block);
        if (self.di_scope != null) {
            self.builder.setCurrentDebugLocation2(prev_debug_location);
        }
    }

    const entry_block = self.dg.context.appendBasicBlock(fn_val, "Entry");
    self.builder.positionBuilderAtEnd(entry_block);
    self.builder.clearCurrentDebugLocation();

    const fields = enum_ty.enumFields();
    const named_block = self.dg.context.appendBasicBlock(fn_val, "Named");
    const unnamed_block = self.dg.context.appendBasicBlock(fn_val, "Unnamed");
    const tag_int_value = fn_val.getParam(0);
    // Any value without an explicit case falls through to `unnamed_block`.
    const switch_instr = self.builder.buildSwitch(tag_int_value, unnamed_block, @intCast(c_uint, fields.count()));

    for (fields.keys()) |_, field_index| {
        // Lower the field's tag value by way of an `enum_field_index` value
        // of type `enum_ty`.
        const this_tag_int_value = int: {
            var tag_val_payload: Value.Payload.U32 = .{
                .base = .{ .tag = .enum_field_index },
                .data = @intCast(u32, field_index),
            };
            break :int try self.dg.lowerValue(.{
                .ty = enum_ty,
                .val = Value.initPayload(&tag_val_payload.base),
            });
        };
        switch_instr.addCase(this_tag_int_value, named_block);
    }

    self.builder.positionBuilderAtEnd(named_block);
    _ = self.builder.buildRet(self.dg.context.intType(1).constInt(1, .False));

    self.builder.positionBuilderAtEnd(unnamed_block);
    _ = self.builder.buildRet(self.dg.context.intType(1).constInt(0, .False));

    return fn_val;
}
fn airTagName(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
if (self.liveness.isUnused(inst)) return null;

View File

@ -170,6 +170,7 @@ const Writer = struct {
.bool_to_int,
.ret,
.ret_load,
.is_named_enum_value,
.tag_name,
.error_name,
.sqrt,

View File

@ -1,9 +1,11 @@
const std = @import("std");
pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noreturn {
_ = message;
_ = stack_trace;
std.process.exit(0);
if (std.mem.eql(u8, message, "invalid enum value")) {
std.process.exit(0);
}
std.process.exit(1);
}
const Foo = enum {
A,
@ -18,6 +20,7 @@ fn bar(a: u2) Foo {
return @intToEnum(Foo, a);
}
fn baz(_: Foo) void {}
// run
// backend=stage1
// backend=llvm
// target=native

View File

@ -10,6 +10,7 @@ pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noretur
const E = enum(u32) {
X = 1,
Y = 2,
};
pub fn main() !void {
@ -21,5 +22,5 @@ pub fn main() !void {
}
// run
// backend=stage1
// backend=llvm
// target=native

View File

@ -10,6 +10,7 @@ pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noretur
const U = union(enum(u32)) {
X: u8,
Y: i8,
};
pub fn main() !void {
@ -22,5 +23,5 @@ pub fn main() !void {
}
// run
// backend=stage1
// backend=llvm
// target=native