From d1fa5692c685b804181d4658afce1e53ca74ec19 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 26 Apr 2016 11:35:56 -0700 Subject: [PATCH] add array bounds checking in debug mode closes #27 --- src/all_types.hpp | 3 ++ src/analyze.cpp | 14 +++++ src/codegen.cpp | 130 ++++++++++++++++++++++++++++++++++----------- std/builtin.zig | 2 + test/run_tests.cpp | 122 ++++++++++++++++++++++++++++++++++-------- 5 files changed, 216 insertions(+), 55 deletions(-) diff --git a/src/all_types.hpp b/src/all_types.hpp index 3bd51002c5..cc944bc35a 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1051,6 +1051,7 @@ struct FnTableEntry { bool is_extern; bool is_test; bool is_pure; + bool safety_off; BlockContext *parent_block_context; FnAnalState anal_state; @@ -1315,6 +1316,8 @@ struct BlockContext { // if this is true, then this code will not be generated bool codegen_excluded; + + bool safety_off; }; diff --git a/src/analyze.cpp b/src/analyze.cpp index 737fff0a7b..5be9e10392 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -993,6 +993,18 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t add_node_error(g, directive_node, buf_sprintf("invalid function attribute: '%s'", buf_ptr(name))); } + } else if (buf_eql_str(name, "debug_safety")) { + if (fn_table_entry->is_extern) { + add_node_error(g, directive_node, + buf_sprintf("#debug_safety invalid on extern functions")); + } else { + bool enable; + bool ok = resolve_const_expr_bool(g, import, import->block_context, + &directive_node->data.directive.expr, &enable); + if (ok && !enable) { + fn_table_entry->safety_off = true; + } + } } else if (buf_eql_str(name, "condition")) { if (fn_proto->top_level_decl.visib_mod == VisibModExport) { bool include; @@ -2102,11 +2114,13 @@ BlockContext *new_block_context(AstNode *node, BlockContext *parent) { context->parent_loop_node = parent->parent_loop_node; context->c_import_buf = parent->c_import_buf; context->codegen_excluded = parent->codegen_excluded; + context->safety_off = parent->safety_off; } if (node && node->type == NodeTypeFnDef) { AstNode *fn_proto_node = node->data.fn_def.fn_proto; context->fn_entry = fn_proto_node->data.fn_proto.fn_table_entry; + context->safety_off = context->fn_entry->safety_off; } else if (parent) { context->fn_entry = parent->fn_entry; } diff --git a/src/codegen.cpp b/src/codegen.cpp index aa3b4d1529..dda78a0653 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -330,6 +330,46 @@ static LLVMValueRef get_handle_value(CodeGen *g, AstNode *source_node, LLVMValue } } +static bool want_debug_safety(CodeGen *g, AstNode *node) { + return !g->is_release_build && !node->block_context->safety_off; +} + +static void add_bounds_check(CodeGen *g, AstNode *source_node, LLVMValueRef target_val, + LLVMIntPredicate lower_pred, LLVMValueRef lower_value, + LLVMIntPredicate upper_pred, LLVMValueRef upper_value) +{ + if (!lower_value && !upper_value) { + return; + } + if (upper_value && !lower_value) { + lower_value = upper_value; + lower_pred = upper_pred; + upper_value = nullptr; + } + + add_debug_source_node(g, source_node); + + LLVMBasicBlockRef bounds_check_fail_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "BoundsCheckFail"); + LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "BoundsCheckOk"); + LLVMBasicBlockRef lower_ok_block = upper_value ? + LLVMAppendBasicBlock(g->cur_fn->fn_value, "FirstBoundsCheckOk") : ok_block; + + LLVMValueRef lower_ok_val = LLVMBuildICmp(g->builder, lower_pred, target_val, lower_value, ""); + LLVMBuildCondBr(g->builder, lower_ok_val, lower_ok_block, bounds_check_fail_block); + + LLVMPositionBuilderAtEnd(g->builder, bounds_check_fail_block); + LLVMBuildCall(g->builder, g->trap_fn_val, nullptr, 0, ""); + LLVMBuildUnreachable(g->builder); + + if (upper_value) { + LLVMPositionBuilderAtEnd(g->builder, lower_ok_block); + LLVMValueRef upper_ok_val = LLVMBuildICmp(g->builder, upper_pred, target_val, upper_value, ""); + LLVMBuildCondBr(g->builder, upper_ok_val, ok_block, bounds_check_fail_block); + } + + LLVMPositionBuilderAtEnd(g->builder, ok_block); +} + static LLVMValueRef gen_err_name(CodeGen *g, AstNode *node) { assert(node->type == NodeTypeFnCallExpr); assert(g->generate_error_name_table); @@ -344,25 +384,10 @@ static LLVMValueRef gen_err_name(CodeGen *g, AstNode *node) { LLVMValueRef err_val = gen_expr(g, err_val_node); add_debug_source_node(g, node); - if (!g->is_release_build) { - LLVMBasicBlockRef bounds_check_fail_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "BoundsCheckFail"); - LLVMBasicBlockRef lower_ok_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "LowerBoundsCheckOk"); - LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "BoundsCheckOk"); - + if (want_debug_safety(g, node)) { LLVMValueRef zero = LLVMConstNull(LLVMTypeOf(err_val)); - LLVMValueRef is_zero_val = LLVMBuildICmp(g->builder, LLVMIntEQ, err_val, zero, ""); - LLVMBuildCondBr(g->builder, is_zero_val, bounds_check_fail_block, lower_ok_block); - - LLVMPositionBuilderAtEnd(g->builder, bounds_check_fail_block); - LLVMBuildCall(g->builder, g->trap_fn_val, nullptr, 0, ""); - LLVMBuildUnreachable(g->builder); - - LLVMPositionBuilderAtEnd(g->builder, lower_ok_block); LLVMValueRef end_val = LLVMConstInt(LLVMTypeOf(err_val), g->error_decls.length, false); - LLVMValueRef is_too_big_val = LLVMBuildICmp(g->builder, LLVMIntUGE, err_val, end_val, ""); - LLVMBuildCondBr(g->builder, is_too_big_val, bounds_check_fail_block, ok_block); - - LLVMPositionBuilderAtEnd(g->builder, ok_block); + add_bounds_check(g, node, err_val, LLVMIntNE, zero, LLVMIntULT, end_val); } LLVMValueRef indices[] = { @@ -869,6 +894,11 @@ static LLVMValueRef gen_array_elem_ptr(CodeGen *g, AstNode *source_node, LLVMVal } if (array_type->id == TypeTableEntryIdArray) { + if (want_debug_safety(g, source_node)) { + LLVMValueRef end = LLVMConstInt(g->builtin_types.entry_isize->type_ref, + array_type->data.array.len, false); + add_bounds_check(g, source_node, subscript_value, LLVMIntEQ, nullptr, LLVMIntULT, end); + } LLVMValueRef indices[] = { LLVMConstNull(g->builtin_types.entry_isize->type_ref), subscript_value @@ -887,6 +917,15 @@ static LLVMValueRef gen_array_elem_ptr(CodeGen *g, AstNode *source_node, LLVMVal assert(LLVMGetTypeKind(LLVMTypeOf(array_ptr)) == LLVMPointerTypeKind); assert(LLVMGetTypeKind(LLVMGetElementType(LLVMTypeOf(array_ptr))) == LLVMStructTypeKind); + if (want_debug_safety(g, source_node)) { + add_debug_source_node(g, source_node); + int len_index = array_type->data.structure.fields[1].gen_index; + assert(len_index >= 0); + LLVMValueRef len_ptr = LLVMBuildStructGEP(g->builder, array_ptr, len_index, ""); + LLVMValueRef len = LLVMBuildLoad(g->builder, len_ptr, ""); + add_bounds_check(g, source_node, subscript_value, LLVMIntEQ, nullptr, LLVMIntULT, len); + } + add_debug_source_node(g, source_node); int ptr_index = array_type->data.structure.fields[0].gen_index; assert(ptr_index >= 0); @@ -907,7 +946,6 @@ static LLVMValueRef gen_array_ptr(CodeGen *g, AstNode *node) { LLVMValueRef array_ptr = gen_array_base_ptr(g, array_expr_node); LLVMValueRef subscript_value = gen_expr(g, node->data.array_access_expr.subscript); - return gen_array_elem_ptr(g, node, array_ptr, array_type, subscript_value); } @@ -969,6 +1007,15 @@ static LLVMValueRef gen_slice_expr(CodeGen *g, AstNode *node) { end_val = LLVMConstInt(g->builtin_types.entry_isize->type_ref, array_type->data.array.len, false); } + if (want_debug_safety(g, node)) { + add_bounds_check(g, node, start_val, LLVMIntEQ, nullptr, LLVMIntULE, end_val); + if (node->data.slice_expr.end) { + LLVMValueRef array_end = LLVMConstInt(g->builtin_types.entry_isize->type_ref, + array_type->data.array.len, false); + add_bounds_check(g, node, end_val, LLVMIntEQ, nullptr, LLVMIntULE, array_end); + } + } + add_debug_source_node(g, node); LLVMValueRef ptr_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, 0, ""); LLVMValueRef indices[] = { @@ -987,6 +1034,10 @@ static LLVMValueRef gen_slice_expr(CodeGen *g, AstNode *node) { LLVMValueRef start_val = gen_expr(g, node->data.slice_expr.start); LLVMValueRef end_val = gen_expr(g, node->data.slice_expr.end); + if (want_debug_safety(g, node)) { + add_bounds_check(g, node, start_val, LLVMIntEQ, nullptr, LLVMIntULE, end_val); + } + add_debug_source_node(g, node); LLVMValueRef ptr_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, 0, ""); LLVMValueRef slice_start_ptr = LLVMBuildInBoundsGEP(g->builder, array_ptr, &start_val, 1, ""); @@ -1002,22 +1053,33 @@ static LLVMValueRef gen_slice_expr(CodeGen *g, AstNode *node) { assert(LLVMGetTypeKind(LLVMTypeOf(array_ptr)) == LLVMPointerTypeKind); assert(LLVMGetTypeKind(LLVMGetElementType(LLVMTypeOf(array_ptr))) == LLVMStructTypeKind); + int ptr_index = array_type->data.structure.fields[0].gen_index; + assert(ptr_index >= 0); + int len_index = array_type->data.structure.fields[1].gen_index; + assert(len_index >= 0); + + LLVMValueRef prev_end = nullptr; + if (!node->data.slice_expr.end || want_debug_safety(g, node)) { + add_debug_source_node(g, node); + LLVMValueRef src_len_ptr = LLVMBuildStructGEP(g->builder, array_ptr, len_index, ""); + prev_end = LLVMBuildLoad(g->builder, src_len_ptr, ""); + } + LLVMValueRef start_val = gen_expr(g, node->data.slice_expr.start); LLVMValueRef end_val; if (node->data.slice_expr.end) { end_val = gen_expr(g, node->data.slice_expr.end); } else { - add_debug_source_node(g, node); - int len_index = array_type->data.structure.fields[1].gen_index; - assert(len_index >= 0); - LLVMValueRef src_len_ptr = LLVMBuildStructGEP(g->builder, array_ptr, len_index, ""); - end_val = LLVMBuildLoad(g->builder, src_len_ptr, ""); + end_val = prev_end; } - int ptr_index = array_type->data.structure.fields[0].gen_index; - assert(ptr_index >= 0); - int len_index = array_type->data.structure.fields[1].gen_index; - assert(len_index >= 0); + if (want_debug_safety(g, node)) { + assert(prev_end); + add_bounds_check(g, node, start_val, LLVMIntEQ, nullptr, LLVMIntULE, end_val); + if (node->data.slice_expr.end) { + add_bounds_check(g, node, end_val, LLVMIntEQ, nullptr, LLVMIntULE, prev_end); + } + } add_debug_source_node(g, node); LLVMValueRef src_ptr_ptr = LLVMBuildStructGEP(g->builder, array_ptr, ptr_index, ""); @@ -1225,7 +1287,7 @@ static LLVMValueRef gen_prefix_op_expr(CodeGen *g, AstNode *node) { assert(expr_type->id == TypeTableEntryIdErrorUnion); TypeTableEntry *child_type = expr_type->data.error.child_type; - if (!g->is_release_build) { + if (want_debug_safety(g, node)) { LLVMValueRef err_val; if (type_has_bits(child_type)) { add_debug_source_node(g, node); @@ -1263,7 +1325,7 @@ static LLVMValueRef gen_prefix_op_expr(CodeGen *g, AstNode *node) { assert(expr_type->id == TypeTableEntryIdMaybe); TypeTableEntry *child_type = expr_type->data.maybe.child_type; - if (!g->is_release_build) { + if (want_debug_safety(g, node)) { add_debug_source_node(g, node); LLVMValueRef cond_val; if (child_type->id == TypeTableEntryIdPointer || @@ -2261,7 +2323,7 @@ static LLVMValueRef gen_container_init_expr(CodeGen *g, AstNode *node) { } else if (type_entry->id == TypeTableEntryIdUnreachable) { assert(node->data.container_init_expr.entries.length == 0); add_debug_source_node(g, node); - if (!g->is_release_build) { + if (want_debug_safety(g, node)) { LLVMBuildCall(g->builder, g->trap_fn_val, nullptr, 0, ""); } LLVMBuildUnreachable(g->builder); @@ -2575,7 +2637,7 @@ static LLVMValueRef gen_var_decl_raw(CodeGen *g, AstNode *source_node, AstNodeVa } } } - if (!ignore_uninit && !g->is_release_build) { + if (!ignore_uninit && want_debug_safety(g, source_node)) { TypeTableEntry *isize = g->builtin_types.entry_isize; uint64_t size_bytes = LLVMStoreSizeOfType(g->target_data_ref, variable->type->type_ref); uint64_t align_bytes = get_memcpy_align(g, variable->type); @@ -2790,7 +2852,7 @@ static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) { if (!else_prong) { LLVMPositionBuilderAtEnd(g->builder, else_block); add_debug_source_node(g, node); - if (!g->is_release_build) { + if (want_debug_safety(g, node)) { LLVMBuildCall(g->builder, g->trap_fn_val, nullptr, 0, ""); } LLVMBuildUnreachable(g->builder); @@ -3383,6 +3445,10 @@ static void do_code_gen(CodeGen *g) { // Generate the list of test function pointers. if (g->is_test_build) { + if (g->test_fn_count == 0) { + fprintf(stderr, "No tests to run.\n"); + exit(0); + } assert(g->test_fn_count > 0); assert(next_test_index == g->test_fn_count); diff --git a/std/builtin.zig b/std/builtin.zig index 1c636cca27..0a712c7755 100644 --- a/std/builtin.zig +++ b/std/builtin.zig @@ -1,6 +1,7 @@ // These functions are provided when not linking against libc because LLVM // sometimes generates code that calls them. +#debug_safety(false) export fn memset(dest: &u8, c: u8, n: isize) -> &u8 { var index : @typeof(n) = 0; while (index != n) { @@ -10,6 +11,7 @@ export fn memset(dest: &u8, c: u8, n: isize) -> &u8 { return dest; } +#debug_safety(false) export fn memcpy(noalias dest: &u8, noalias src: &const u8, n: isize) -> &u8 { var index : @typeof(n) = 0; while (index != n) { diff --git a/test/run_tests.cpp b/test/run_tests.cpp index b0d3006c8d..fefec2637f 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -27,6 +27,7 @@ struct TestCase { ZigList program_args; bool is_parseh; bool is_self_hosted; + bool is_debug_safety; }; static ZigList test_cases = {0}; @@ -122,6 +123,55 @@ static TestCase *add_compile_fail_case(const char *case_name, const char *source return test_case; } +static void add_debug_safety_case(const char *case_name, const char *source) { + { + TestCase *test_case = allocate(1); + test_case->is_debug_safety = true; + test_case->case_name = buf_ptr(buf_sprintf("%s (debug)", case_name)); + test_case->source_files.resize(1); + test_case->source_files.at(0).relative_path = tmp_source_path; + test_case->source_files.at(0).source_code = source; + + test_case->compiler_args.append("build"); + test_case->compiler_args.append(tmp_source_path); + + test_case->compiler_args.append("--name"); + test_case->compiler_args.append("test"); + + test_case->compiler_args.append("--export"); + test_case->compiler_args.append("exe"); + + test_case->compiler_args.append("--output"); + test_case->compiler_args.append(tmp_exe_path); + + test_cases.append(test_case); + } + { + TestCase *test_case = allocate(1); + test_case->case_name = buf_ptr(buf_sprintf("%s (release)", case_name)); + test_case->source_files.resize(1); + test_case->source_files.at(0).relative_path = tmp_source_path; + test_case->source_files.at(0).source_code = source; + test_case->output = ""; + + test_case->compiler_args.append("build"); + test_case->compiler_args.append(tmp_source_path); + + test_case->compiler_args.append("--name"); + test_case->compiler_args.append("test"); + + test_case->compiler_args.append("--export"); + test_case->compiler_args.append("exe"); + + test_case->compiler_args.append("--output"); + test_case->compiler_args.append(tmp_exe_path); + + test_case->compiler_args.append("--release"); + + test_cases.append(test_case); + } +} + static TestCase *add_parseh_case(const char *case_name, const char *source, int count, ...) { va_list ap; va_start(ap, count); @@ -1247,6 +1297,22 @@ fn bar() -> i32 { 2 } )SOURCE", 1, ".tmp_source.zig:3:15: error: unable to infer expression type"); } +static void add_debug_safety_test_cases(void) { + add_debug_safety_case("out of bounds slice access", R"SOURCE( +pub fn main(args: [][]u8) -> %void { + const a = []i32{1, 2, 3, 4}; + baz(bar(a)); +} +#static_eval_enable(false) +fn bar(a: []i32) -> i32 { + a[4] +} +#static_eval_enable(false) +fn baz(a: i32) {} + )SOURCE"); + +} + ////////////////////////////////////////////////////////////////////////////// static void add_parseh_test_cases(void) { @@ -1455,6 +1521,14 @@ static void print_compiler_invocation(TestCase *test_case) { printf("\n"); } +static void print_exe_invocation(TestCase *test_case) { + printf("%s", tmp_exe_path); + for (int i = 0; i < test_case->program_args.length; i += 1) { + printf(" %s", test_case->program_args.at(i)); + } + printf("\n"); +} + static void run_test(TestCase *test_case) { if (test_case->is_self_hosted) { return run_self_hosted_test(); @@ -1531,32 +1605,33 @@ static void run_test(TestCase *test_case) { Buf program_stdout = BUF_INIT; os_exec_process(tmp_exe_path, test_case->program_args, &return_code, &program_stderr, &program_stdout); - if (return_code != 0) { - printf("\nProgram exited with return code %d:\n", return_code); - print_compiler_invocation(test_case); - printf("%s", tmp_exe_path); - for (int i = 0; i < test_case->program_args.length; i += 1) { - printf(" %s", test_case->program_args.at(i)); + if (test_case->is_debug_safety) { + if (return_code == 0) { + printf("\nProgram expected to hit debug trap but exited with return code 0\n"); + print_compiler_invocation(test_case); + print_exe_invocation(test_case); + exit(1); + } + } else { + if (return_code != 0) { + printf("\nProgram exited with return code %d:\n", return_code); + print_compiler_invocation(test_case); + print_exe_invocation(test_case); + printf("%s\n", buf_ptr(&program_stderr)); + exit(1); } - printf("\n"); - printf("%s\n", buf_ptr(&program_stderr)); - exit(1); - } - if (!buf_eql_str(&program_stdout, test_case->output)) { - printf("\n"); - print_compiler_invocation(test_case); - printf("%s", tmp_exe_path); - for (int i = 0; i < test_case->program_args.length; i += 1) { - printf(" %s", test_case->program_args.at(i)); + if (!buf_eql_str(&program_stdout, test_case->output)) { + printf("\n"); + print_compiler_invocation(test_case); + print_exe_invocation(test_case); + printf("==== Test failed. Expected output: ====\n"); + printf("%s\n", test_case->output); + printf("========= Actual output: ==============\n"); + printf("%s\n", buf_ptr(&program_stdout)); + printf("=======================================\n"); + exit(1); } - printf("\n"); - printf("==== Test failed. Expected output: ====\n"); - printf("%s\n", test_case->output); - printf("========= Actual output: ==============\n"); - printf("%s\n", buf_ptr(&program_stdout)); - printf("=======================================\n"); - exit(1); } } @@ -1606,6 +1681,7 @@ int main(int argc, char **argv) { } } add_compiling_test_cases(); + add_debug_safety_test_cases(); add_compile_failure_test_cases(); add_parseh_test_cases(); add_self_hosted_tests();