From 63d37b7cff0907cdf2361f1d61f19410fd6cc626 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 14 Feb 2017 01:08:30 -0500 Subject: [PATCH] add runtime debug safety for dividing integer min value by -1 closes #260 --- src/analyze.cpp | 2 +- src/analyze.hpp | 1 + src/codegen.cpp | 26 +++++++++++++++++++++----- test/run_tests.cpp | 17 ++++++++++++++++- 4 files changed, 39 insertions(+), 7 deletions(-) diff --git a/src/analyze.cpp b/src/analyze.cpp index 1fa3d063bb..f567502c29 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -3447,7 +3447,7 @@ static int64_t max_signed_val(TypeTableEntry *type_entry) { } } -static int64_t min_signed_val(TypeTableEntry *type_entry) { +int64_t min_signed_val(TypeTableEntry *type_entry) { assert(type_entry->id == TypeTableEntryIdInt); if (type_entry->data.integral.bit_count == 64) { return INT64_MIN; diff --git a/src/analyze.hpp b/src/analyze.hpp index 012e6d19e7..3ef0696be3 100644 --- a/src/analyze.hpp +++ b/src/analyze.hpp @@ -81,6 +81,7 @@ void complete_enum(CodeGen *g, TypeTableEntry *enum_type); bool ir_get_var_is_comptime(VariableTableEntry *var); bool const_values_equal(ConstExprValue *a, ConstExprValue *b); void eval_min_max_value(CodeGen *g, TypeTableEntry *type_entry, ConstExprValue *const_val, bool is_max); +int64_t min_signed_val(TypeTableEntry *type_entry); void render_const_value(Buf *buf, ConstExprValue *const_val); void define_local_param_variables(CodeGen *g, FnTableEntry *fn_table_entry, VariableTableEntry **arg_vars); diff --git a/src/codegen.cpp b/src/codegen.cpp index b98d549e84..df4b6fcf7c 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -846,14 +846,30 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_debug_safety, LLVMValueRef val } else { zig_unreachable(); } - LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroOk"); - LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroFail"); - LLVMBuildCondBr(g->builder, is_zero_bit, fail_block, ok_block); + LLVMBasicBlockRef div_zero_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroOk"); + LLVMBasicBlockRef div_zero_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroFail"); + LLVMBuildCondBr(g->builder, is_zero_bit, div_zero_fail_block, div_zero_ok_block); - LLVMPositionBuilderAtEnd(g->builder, fail_block); + LLVMPositionBuilderAtEnd(g->builder, div_zero_fail_block); gen_debug_safety_crash(g, PanicMsgIdDivisionByZero); - LLVMPositionBuilderAtEnd(g->builder, ok_block); + LLVMPositionBuilderAtEnd(g->builder, div_zero_ok_block); + + if (type_entry->id == TypeTableEntryIdInt && type_entry->data.integral.is_signed) { + LLVMValueRef neg_1_value = LLVMConstInt(type_entry->type_ref, -1, true); + LLVMValueRef int_min_value = LLVMConstInt(type_entry->type_ref, min_signed_val(type_entry), true); + LLVMBasicBlockRef overflow_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivOverflowOk"); + LLVMBasicBlockRef overflow_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivOverflowFail"); + LLVMValueRef num_is_int_min = LLVMBuildICmp(g->builder, LLVMIntEQ, val1, int_min_value, ""); + LLVMValueRef den_is_neg_1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, neg_1_value, ""); + LLVMValueRef overflow_fail_bit = LLVMBuildAnd(g->builder, num_is_int_min, den_is_neg_1, ""); + LLVMBuildCondBr(g->builder, overflow_fail_bit, overflow_fail_block, overflow_ok_block); + + LLVMPositionBuilderAtEnd(g->builder, overflow_fail_block); + gen_debug_safety_crash(g, PanicMsgIdIntegerOverflow); + + LLVMPositionBuilderAtEnd(g->builder, overflow_ok_block); + } } if (type_entry->id == TypeTableEntryIdFloat) { diff --git a/test/run_tests.cpp b/test/run_tests.cpp index 1614c1cbdc..5fd49d9105 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -1700,13 +1700,28 @@ pub fn panic(message: []const u8) -> unreachable { error Whatever; pub fn main(args: [][]u8) -> %void { const x = neg(-32768); - if (x == 0) return error.Whatever; + if (x == 32767) return error.Whatever; } fn neg(a: i16) -> i16 { -a } )SOURCE"); + add_debug_safety_case("signed integer division overflow", R"SOURCE( +pub fn panic(message: []const u8) -> unreachable { + @breakpoint(); + while (true) {} +} +error Whatever; +pub fn main(args: [][]u8) -> %void { + const x = div(-32768, -1); + if (x == 32767) return error.Whatever; +} +fn div(a: i16, b: i16) -> i16 { + a / b +} + )SOURCE"); + add_debug_safety_case("signed shift left overflow", R"SOURCE( pub fn panic(message: []const u8) -> unreachable { @breakpoint();