From b8d17b11a7eba696200ab9b5819121f48ad123d1 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 5 May 2016 18:07:04 -0700 Subject: [PATCH] add tests for integer overflow crashing see #46 --- src/codegen.cpp | 70 ++++++++++++++++++++++++++-------------------- test/run_tests.cpp | 40 ++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 31 deletions(-) diff --git a/src/codegen.cpp b/src/codegen.cpp index ab45109312..8107b20845 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -1097,7 +1097,7 @@ static LLVMValueRef gen_slice_expr(CodeGen *g, AstNode *node) { LLVMBuildStore(g->builder, slice_start_ptr, ptr_field_ptr); LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, 1, ""); - LLVMValueRef len_value = LLVMBuildSub(g->builder, end_val, start_val, ""); + LLVMValueRef len_value = LLVMBuildNSWSub(g->builder, end_val, start_val, ""); LLVMBuildStore(g->builder, len_value, len_field_ptr); return tmp_struct_ptr; @@ -1115,7 +1115,7 @@ static LLVMValueRef gen_slice_expr(CodeGen *g, AstNode *node) { LLVMBuildStore(g->builder, slice_start_ptr, ptr_field_ptr); LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, 1, ""); - LLVMValueRef len_value = LLVMBuildSub(g->builder, end_val, start_val, ""); + LLVMValueRef len_value = LLVMBuildNSWSub(g->builder, end_val, start_val, ""); LLVMBuildStore(g->builder, len_value, len_field_ptr); return tmp_struct_ptr; @@ -1160,7 +1160,7 @@ static LLVMValueRef gen_slice_expr(CodeGen *g, AstNode *node) { LLVMBuildStore(g->builder, slice_start_ptr, ptr_field_ptr); LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, len_index, ""); - LLVMValueRef len_value = LLVMBuildSub(g->builder, end_val, start_val, ""); + LLVMValueRef len_value = LLVMBuildNSWSub(g->builder, end_val, start_val, ""); LLVMBuildStore(g->builder, len_value, len_field_ptr); return tmp_struct_ptr; @@ -1287,6 +1287,28 @@ static LLVMValueRef gen_lvalue(CodeGen *g, AstNode *expr_node, AstNode *node, return target_ref; } +static LLVMValueRef gen_overflow_op(CodeGen *g, TypeTableEntry *type_entry, AddSubMul op, + LLVMValueRef val1, LLVMValueRef val2) +{ + LLVMValueRef fn_val = get_int_overflow_fn(g, type_entry, op); + LLVMValueRef params[] = { + val1, + val2, + }; + LLVMValueRef result_struct = LLVMBuildCall(g->builder, fn_val, params, 2, ""); + LLVMValueRef result = LLVMBuildExtractValue(g->builder, result_struct, 0, ""); + LLVMValueRef overflow_bit = LLVMBuildExtractValue(g->builder, result_struct, 1, ""); + LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "OverflowFail"); + LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "OverflowOk"); + LLVMBuildCondBr(g->builder, overflow_bit, fail_block, ok_block); + + LLVMPositionBuilderAtEnd(g->builder, fail_block); + gen_debug_safety_crash(g); + + LLVMPositionBuilderAtEnd(g->builder, ok_block); + return result; +} + static LLVMValueRef gen_prefix_op_expr(CodeGen *g, AstNode *node) { assert(node->type == NodeTypePrefixOpExpr); assert(node->data.prefix_op_expr.primary_expr); @@ -1300,12 +1322,20 @@ static LLVMValueRef gen_prefix_op_expr(CodeGen *g, AstNode *node) { case PrefixOpNegation: { LLVMValueRef expr = gen_expr(g, expr_node); - if (expr_type->id == TypeTableEntryIdInt) { - set_debug_source_node(g, node); - return LLVMBuildNeg(g->builder, expr, ""); - } else if (expr_type->id == TypeTableEntryIdFloat) { - set_debug_source_node(g, node); + set_debug_source_node(g, node); + if (expr_type->id == TypeTableEntryIdFloat) { return LLVMBuildFNeg(g->builder, expr, ""); + } else if (expr_type->id == TypeTableEntryIdInt) { + if (expr_type->data.integral.is_wrapping) { + return LLVMBuildNeg(g->builder, expr, ""); + } else if (want_debug_safety(g, expr_node)) { + LLVMValueRef zero = LLVMConstNull(LLVMTypeOf(expr)); + return gen_overflow_op(g, expr_type, AddSubMulSub, zero, expr); + } else if (expr_type->data.integral.is_signed) { + return LLVMBuildNSWNeg(g->builder, expr, ""); + } else { + return LLVMBuildNUWNeg(g->builder, expr, ""); + } } else { zig_unreachable(); } @@ -1431,28 +1461,6 @@ static LLVMValueRef gen_prefix_op_expr(CodeGen *g, AstNode *node) { zig_unreachable(); } -static LLVMValueRef gen_overflow_op(CodeGen *g, TypeTableEntry *type_entry, AddSubMul op, - LLVMValueRef val1, LLVMValueRef val2) -{ - LLVMValueRef fn_val = get_int_overflow_fn(g, type_entry, op); - LLVMValueRef params[] = { - val1, - val2, - }; - LLVMValueRef result_struct = LLVMBuildCall(g->builder, fn_val, params, 2, ""); - LLVMValueRef result = LLVMBuildExtractValue(g->builder, result_struct, 0, ""); - LLVMValueRef overflow_bit = LLVMBuildExtractValue(g->builder, result_struct, 1, ""); - LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "OverflowFail"); - LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "OverflowOk"); - LLVMBuildCondBr(g->builder, overflow_bit, fail_block, ok_block); - - LLVMPositionBuilderAtEnd(g->builder, fail_block); - gen_debug_safety_crash(g); - - LLVMPositionBuilderAtEnd(g->builder, ok_block); - return result; -} - static LLVMValueRef gen_arithmetic_bin_op(CodeGen *g, AstNode *source_node, LLVMValueRef val1, LLVMValueRef val2, TypeTableEntry *op1_type, TypeTableEntry *op2_type, @@ -2727,7 +2735,7 @@ static LLVMValueRef gen_for_expr(CodeGen *g, AstNode *node) { LLVMPositionBuilderAtEnd(g->builder, continue_block); set_debug_source_node(g, node); - LLVMValueRef new_index_val = LLVMBuildAdd(g->builder, index_val, one_const, ""); + LLVMValueRef new_index_val = LLVMBuildNSWAdd(g->builder, index_val, one_const, ""); LLVMBuildStore(g->builder, new_index_val, index_ptr); LLVMBuildBr(g->builder, cond_block); diff --git a/test/run_tests.cpp b/test/run_tests.cpp index 9d5861d9e5..6c898533cb 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -1323,6 +1323,46 @@ fn bar(a: []i32) -> i32 { fn baz(a: i32) {} )SOURCE"); + add_debug_safety_case("integer addition overflow", R"SOURCE( +pub fn main(args: [][]u8) -> %void { + add(65530, 10); +} +#static_eval_enable(false) +fn add(a: u16, b: u16) -> u16 { + a + b +} + )SOURCE"); + + add_debug_safety_case("integer subtraction overflow", R"SOURCE( +pub fn main(args: [][]u8) -> %void { + sub(10, 20); +} +#static_eval_enable(false) +fn sub(a: u16, b: u16) -> u16 { + a - b +} + )SOURCE"); + + add_debug_safety_case("integer multiplication overflow", R"SOURCE( +pub fn main(args: [][]u8) -> %void { + mul(300, 6000); +} +#static_eval_enable(false) +fn mul(a: u16, b: u16) -> u16 { + a * b +} + )SOURCE"); + + add_debug_safety_case("integer negation overflow", R"SOURCE( +pub fn main(args: [][]u8) -> %void { + neg(-32768); +} +#static_eval_enable(false) +fn neg(a: i16) -> i16 { + -a +} + )SOURCE"); + } //////////////////////////////////////////////////////////////////////////////