error sets: runtime safety for int-to-err and err set cast

This commit is contained in:
Andrew Kelley 2018-02-08 21:54:44 -05:00
parent 8fc6e31567
commit 54c06bf715
5 changed files with 139 additions and 42 deletions

4
TODO
View File

@ -17,10 +17,6 @@ you can get the compiler to tell you the possible errors for an inferred error s
foo() catch |err| switch (err) {};
// TODO this is an explicit cast and should actually coerce the type
erorr set casting
// add a runtime safety check
test err should be comptime if error set has 0 members

View File

@ -1958,6 +1958,54 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable,
zig_unreachable();
}
static void add_error_range_check(CodeGen *g, TypeTableEntry *err_set_type, TypeTableEntry *int_type, LLVMValueRef target_val) {
assert(err_set_type->id == TypeTableEntryIdErrorSet);
if (type_is_global_error_set(err_set_type)) {
LLVMValueRef zero = LLVMConstNull(int_type->type_ref);
LLVMValueRef neq_zero_bit = LLVMBuildICmp(g->builder, LLVMIntNE, target_val, zero, "");
LLVMValueRef ok_bit;
BigInt biggest_possible_err_val = {0};
eval_min_max_value_int(g, int_type, &biggest_possible_err_val, true);
if (bigint_fits_in_bits(&biggest_possible_err_val, 64, false) &&
bigint_as_unsigned(&biggest_possible_err_val) < g->errors_by_index.length)
{
ok_bit = neq_zero_bit;
} else {
LLVMValueRef error_value_count = LLVMConstInt(int_type->type_ref, g->errors_by_index.length, false);
LLVMValueRef in_bounds_bit = LLVMBuildICmp(g->builder, LLVMIntULT, target_val, error_value_count, "");
ok_bit = LLVMBuildAnd(g->builder, neq_zero_bit, in_bounds_bit, "");
}
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "IntToErrOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "IntToErrFail");
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
LLVMPositionBuilderAtEnd(g->builder, fail_block);
gen_safety_crash(g, PanicMsgIdInvalidErrorCode);
LLVMPositionBuilderAtEnd(g->builder, ok_block);
} else {
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "IntToErrOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "IntToErrFail");
uint32_t err_count = err_set_type->data.error_set.err_count;
LLVMValueRef switch_instr = LLVMBuildSwitch(g->builder, target_val, fail_block, err_count);
for (uint32_t i = 0; i < err_count; i += 1) {
LLVMValueRef case_value = LLVMConstInt(g->err_tag_type->type_ref, err_set_type->data.error_set.errors[i]->value, false);
LLVMAddCase(switch_instr, case_value, ok_block);
}
LLVMPositionBuilderAtEnd(g->builder, fail_block);
gen_safety_crash(g, PanicMsgIdInvalidErrorCode);
LLVMPositionBuilderAtEnd(g->builder, ok_block);
}
}
static LLVMValueRef ir_render_cast(CodeGen *g, IrExecutable *executable,
IrInstructionCast *cast_instruction)
{
@ -2082,7 +2130,9 @@ static LLVMValueRef ir_render_cast(CodeGen *g, IrExecutable *executable,
assert(actual_type->id == TypeTableEntryIdBool);
return LLVMBuildZExt(g->builder, expr_val, wanted_type->type_ref, "");
case CastOpErrSet:
// TODO runtime safety for error casting
if (ir_want_runtime_safety(g, &cast_instruction->base)) {
add_error_range_check(g, wanted_type, g->err_tag_type, expr_val);
}
return expr_val;
}
zig_unreachable();
@ -2154,32 +2204,7 @@ static LLVMValueRef ir_render_int_to_err(CodeGen *g, IrExecutable *executable, I
LLVMValueRef target_val = ir_llvm_value(g, instruction->target);
if (ir_want_runtime_safety(g, &instruction->base)) {
LLVMValueRef zero = LLVMConstNull(actual_type->type_ref);
LLVMValueRef neq_zero_bit = LLVMBuildICmp(g->builder, LLVMIntNE, target_val, zero, "");
LLVMValueRef ok_bit;
BigInt biggest_possible_err_val = {0};
eval_min_max_value_int(g, actual_type, &biggest_possible_err_val, true);
if (bigint_fits_in_bits(&biggest_possible_err_val, 64, false) &&
bigint_as_unsigned(&biggest_possible_err_val) < g->errors_by_index.length)
{
ok_bit = neq_zero_bit;
} else {
LLVMValueRef error_value_count = LLVMConstInt(actual_type->type_ref, g->errors_by_index.length, false);
LLVMValueRef in_bounds_bit = LLVMBuildICmp(g->builder, LLVMIntULT, target_val, error_value_count, "");
ok_bit = LLVMBuildAnd(g->builder, neq_zero_bit, in_bounds_bit, "");
}
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "IntToErrOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "IntToErrFail");
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
LLVMPositionBuilderAtEnd(g->builder, fail_block);
gen_safety_crash(g, PanicMsgIdInvalidErrorCode);
LLVMPositionBuilderAtEnd(g->builder, ok_block);
add_error_range_check(g, wanted_type, actual_type, target_val);
}
return gen_widen_or_shorten(g, false, actual_type, g->err_tag_type, target_val);

View File

@ -8505,19 +8505,49 @@ static IrInstruction *ir_analyze_int_to_err(IrAnalyze *ira, IrInstruction *sourc
IrInstruction *result = ir_create_const(&ira->new_irb, source_instr->scope,
source_instr->source_node, wanted_type);
BigInt err_count;
bigint_init_unsigned(&err_count, ira->codegen->errors_by_index.length);
if (bigint_cmp_zero(&val->data.x_bigint) == CmpEQ || bigint_cmp(&val->data.x_bigint, &err_count) != CmpLT) {
Buf *val_buf = buf_alloc();
bigint_append_buf(val_buf, &val->data.x_bigint, 10);
ir_add_error(ira, source_instr,
buf_sprintf("integer value %s represents no error", buf_ptr(val_buf)));
if (!resolve_inferred_error_set(ira, wanted_type, source_instr->source_node)) {
return ira->codegen->invalid_instruction;
}
size_t index = bigint_as_unsigned(&val->data.x_bigint);
result->value.data.x_err_set = ira->codegen->errors_by_index.at(index);
return result;
if (type_is_global_error_set(wanted_type)) {
BigInt err_count;
bigint_init_unsigned(&err_count, ira->codegen->errors_by_index.length);
if (bigint_cmp_zero(&val->data.x_bigint) == CmpEQ || bigint_cmp(&val->data.x_bigint, &err_count) != CmpLT) {
Buf *val_buf = buf_alloc();
bigint_append_buf(val_buf, &val->data.x_bigint, 10);
ir_add_error(ira, source_instr,
buf_sprintf("integer value %s represents no error", buf_ptr(val_buf)));
return ira->codegen->invalid_instruction;
}
size_t index = bigint_as_unsigned(&val->data.x_bigint);
result->value.data.x_err_set = ira->codegen->errors_by_index.at(index);
return result;
} else {
ErrorTableEntry *err = nullptr;
BigInt err_int;
for (uint32_t i = 0, count = wanted_type->data.error_set.err_count; i < count; i += 1) {
ErrorTableEntry *this_err = wanted_type->data.error_set.errors[i];
bigint_init_unsigned(&err_int, this_err->value);
if (bigint_cmp(&val->data.x_bigint, &err_int) == CmpEQ) {
err = this_err;
break;
}
}
if (err == nullptr) {
Buf *val_buf = buf_alloc();
bigint_append_buf(val_buf, &val->data.x_bigint, 10);
ir_add_error(ira, source_instr,
buf_sprintf("integer value %s represents no error in '%s'", buf_ptr(val_buf), buf_ptr(&wanted_type->name)));
return ira->codegen->invalid_instruction;
}
result->value.data.x_err_set = err;
return result;
}
}
IrInstruction *result = ir_build_int_to_err(&ira->new_irb, source_instr->scope, source_instr->source_node, target);

View File

@ -1,6 +1,38 @@
const tests = @import("tests.zig");
pub fn addCases(cases: &tests.CompileErrorContext) void {
cases.add("implicit cast of error set not a subset",
\\const Set1 = error{A, B};
\\const Set2 = error{A, C};
\\export fn entry() void {
\\ foo(Set1.B);
\\}
\\fn foo(set1: Set1) void {
\\ var x: Set2 = set1;
\\}
,
".tmp_source.zig:7:19: error: expected 'Set2', found 'Set1'",
".tmp_source.zig:1:23: note: 'error.B' not a member of destination error set");
cases.add("int to err global invalid number",
\\const Set1 = error{A, B};
\\comptime {
\\ var x: usize = 3;
\\ var y = error(x);
\\}
,
".tmp_source.zig:4:18: error: integer value 3 represents no error");
cases.add("int to err non global invalid number",
\\const Set1 = error{A, B};
\\const Set2 = error{A, C};
\\comptime {
\\ var x = usize(Set1.B);
\\ var y = Set2(x);
\\}
,
".tmp_source.zig:5:17: error: integer value 2 represents no error in 'Set2'");
cases.add("@memberCount of error",
\\comptime {
\\ _ = @memberCount(error);

View File

@ -220,7 +220,7 @@ pub fn addCases(cases: &tests.CompareOutputContext) void {
\\}
);
cases.addRuntimeSafety("cast integer to error and no code matches",
cases.addRuntimeSafety("cast integer to global error and no code matches",
\\pub fn panic(message: []const u8, stack_trace: ?&@import("builtin").StackTrace) noreturn {
\\ @import("std").os.exit(126);
\\}
@ -232,6 +232,20 @@ pub fn addCases(cases: &tests.CompareOutputContext) void {
\\}
);
cases.addRuntimeSafety("cast integer to non-global error set and no match",
\\pub fn panic(message: []const u8, stack_trace: ?&@import("builtin").StackTrace) noreturn {
\\ @import("std").os.exit(126);
\\}
\\const Set1 = error{A, B};
\\const Set2 = error{A, C};
\\pub fn main() void {
\\ _ = foo(Set1.B);
\\}
\\fn foo(set1: Set1) Set2 {
\\ return Set2(set1);
\\}
);
cases.addRuntimeSafety("@alignCast misaligned",
\\pub fn panic(message: []const u8, stack_trace: ?&@import("builtin").StackTrace) noreturn {
\\ @import("std").os.exit(126);