From 404defd99b76ebf2cfe46ea26248a8813e40136f Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sat, 7 May 2016 20:53:16 -0700 Subject: [PATCH] add div_exact builtin fn closes #149 --- src/all_types.hpp | 1 + src/analyze.cpp | 33 ++++++++++++++++++++++++++++ src/codegen.cpp | 52 +++++++++++++++++++++++++++++++++++++++++--- src/eval.cpp | 42 +++++++++++++++++++++++++++++++++++ src/zig_llvm.cpp | 7 ++++++ src/zig_llvm.hpp | 2 ++ test/run_tests.cpp | 10 +++++++++ test/self_hosted.zig | 9 ++++++++ 8 files changed, 153 insertions(+), 3 deletions(-) diff --git a/src/all_types.hpp b/src/all_types.hpp index 82603f6077..daf37e5956 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1128,6 +1128,7 @@ enum BuiltinFnId { BuiltinFnIdEmbedFile, BuiltinFnIdCmpExchange, BuiltinFnIdFence, + BuiltinFnIdDivExact, }; struct BuiltinFnEntry { diff --git a/src/analyze.cpp b/src/analyze.cpp index 5832998900..4bc3022765 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -4657,6 +4657,37 @@ static TypeTableEntry *analyze_fence(CodeGen *g, ImportTableEntry *import, return g->builtin_types.entry_void; } +static TypeTableEntry *analyze_div_exact(CodeGen *g, ImportTableEntry *import, + BlockContext *context, AstNode *node) +{ + assert(node->type == NodeTypeFnCallExpr); + + AstNode **op1 = &node->data.fn_call_expr.params.at(0); + AstNode **op2 = &node->data.fn_call_expr.params.at(1); + + TypeTableEntry *op1_type = analyze_expression(g, import, context, nullptr, *op1); + TypeTableEntry *op2_type = analyze_expression(g, import, context, nullptr, *op2); + + AstNode *op_nodes[] = {*op1, *op2}; + TypeTableEntry *op_types[] = {op1_type, op2_type}; + TypeTableEntry *result_type = resolve_peer_type_compatibility(g, import, context, node, + op_nodes, op_types, 2); + + if (result_type->id == TypeTableEntryIdInvalid) { + return g->builtin_types.entry_invalid; + } else if (result_type->id == TypeTableEntryIdInt) { + return result_type; + } else if (result_type->id == TypeTableEntryIdNumLitInt) { + // check for division by zero + // check for non exact division + zig_panic("TODO"); + } else { + add_node_error(g, node, + buf_sprintf("expected integer type, got '%s'", buf_ptr(&result_type->name))); + return g->builtin_types.entry_invalid; + } +} + static TypeTableEntry *analyze_builtin_fn_call_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context, TypeTableEntry *expected_type, AstNode *node) { @@ -4997,6 +5028,8 @@ static TypeTableEntry *analyze_builtin_fn_call_expr(CodeGen *g, ImportTableEntry return analyze_cmpxchg(g, import, context, node); case BuiltinFnIdFence: return analyze_fence(g, import, context, node); + case BuiltinFnIdDivExact: + return analyze_div_exact(g, import, context, node); } zig_unreachable(); } diff --git a/src/codegen.cpp b/src/codegen.cpp index 79601a3865..11eebcd996 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -217,6 +217,8 @@ static LLVMValueRef gen_assign_raw(CodeGen *g, AstNode *source_node, BinOpType b LLVMValueRef target_ref, LLVMValueRef value, TypeTableEntry *op1_type, TypeTableEntry *op2_type); static LLVMValueRef gen_unwrap_maybe(CodeGen *g, AstNode *node, LLVMValueRef maybe_struct_ref); +static LLVMValueRef gen_div(CodeGen *g, AstNode *source_node, LLVMValueRef val1, LLVMValueRef val2, + TypeTableEntry *type_entry, bool exact); static TypeTableEntry *get_type_for_type_node(AstNode *node) { Expr *expr = get_resolved_expr(node); @@ -459,6 +461,18 @@ static LLVMValueRef gen_fence(CodeGen *g, AstNode *node) { return nullptr; } +static LLVMValueRef gen_div_exact(CodeGen *g, AstNode *node) { + assert(node->type == NodeTypeFnCallExpr); + + AstNode *op1_node = node->data.fn_call_expr.params.at(0); + AstNode *op2_node = node->data.fn_call_expr.params.at(1); + + LLVMValueRef op1_val = gen_expr(g, op1_node); + LLVMValueRef op2_val = gen_expr(g, op2_node); + + return gen_div(g, node, op1_val, op2_val, get_expr_type(op1_node), true); +} + static LLVMValueRef gen_shl_with_overflow(CodeGen *g, AstNode *node) { assert(node->type == NodeTypeFnCallExpr); @@ -645,6 +659,8 @@ static LLVMValueRef gen_builtin_fn_call_expr(CodeGen *g, AstNode *node) { return gen_cmp_exchange(g, node); case BuiltinFnIdFence: return gen_fence(g, node); + case BuiltinFnIdDivExact: + return gen_div_exact(g, node); } zig_unreachable(); } @@ -1566,7 +1582,7 @@ static LLVMValueRef gen_prefix_op_expr(CodeGen *g, AstNode *node) { } static LLVMValueRef gen_div(CodeGen *g, AstNode *source_node, LLVMValueRef val1, LLVMValueRef val2, - TypeTableEntry *type_entry) + TypeTableEntry *type_entry, bool exact) { set_debug_source_node(g, source_node); @@ -1591,9 +1607,38 @@ static LLVMValueRef gen_div(CodeGen *g, AstNode *source_node, LLVMValueRef val1, } if (type_entry->id == TypeTableEntryIdFloat) { + assert(!exact); return LLVMBuildFDiv(g->builder, val1, val2, ""); + } + + assert(type_entry->id == TypeTableEntryIdInt); + + if (exact) { + if (want_debug_safety(g, source_node)) { + LLVMValueRef remainder_val; + if (type_entry->data.integral.is_signed) { + remainder_val = LLVMBuildSRem(g->builder, val1, val2, ""); + } else { + remainder_val = LLVMBuildURem(g->builder, val1, val2, ""); + } + LLVMValueRef zero = LLVMConstNull(type_entry->type_ref); + LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, remainder_val, zero, ""); + + LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "DivExactOk"); + LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "DivExactFail"); + LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block); + + LLVMPositionBuilderAtEnd(g->builder, fail_block); + gen_debug_safety_crash(g); + + LLVMPositionBuilderAtEnd(g->builder, ok_block); + } + if (type_entry->data.integral.is_signed) { + return LLVMBuildExactSDiv(g->builder, val1, val2, ""); + } else { + return ZigLLVMBuildExactUDiv(g->builder, val1, val2, ""); + } } else { - assert(type_entry->id == TypeTableEntryIdInt); if (type_entry->data.integral.is_signed) { return LLVMBuildSDiv(g->builder, val1, val2, ""); } else { @@ -1702,7 +1747,7 @@ static LLVMValueRef gen_arithmetic_bin_op(CodeGen *g, AstNode *source_node, } case BinOpTypeDiv: case BinOpTypeAssignDiv: - return gen_div(g, source_node, val1, val2, op1_type); + return gen_div(g, source_node, val1, val2, op1_type, false); case BinOpTypeMod: case BinOpTypeAssignMod: set_debug_source_node(g, source_node); @@ -4492,6 +4537,7 @@ static void define_builtin_fns(CodeGen *g) { create_builtin_fn_with_arg_count(g, BuiltinFnIdEmbedFile, "embed_file", 1); create_builtin_fn_with_arg_count(g, BuiltinFnIdCmpExchange, "cmpxchg", 5); create_builtin_fn_with_arg_count(g, BuiltinFnIdFence, "fence", 1); + create_builtin_fn_with_arg_count(g, BuiltinFnIdDivExact, "div_exact", 2); } static void init(CodeGen *g, Buf *source_path) { diff --git a/src/eval.cpp b/src/eval.cpp index 56bf6a8577..b31cc47daa 100644 --- a/src/eval.cpp +++ b/src/eval.cpp @@ -709,6 +709,46 @@ static bool eval_min_max(EvalFn *ef, AstNode *node, ConstExprValue *out_val, boo return false; } +static bool eval_div_exact(EvalFn *ef, AstNode *node, ConstExprValue *out_val) { + assert(node->type == NodeTypeFnCallExpr); + AstNode *op1_node = node->data.fn_call_expr.params.at(0); + AstNode *op2_node = node->data.fn_call_expr.params.at(1); + + TypeTableEntry *type_entry = get_resolved_expr(op1_node)->type_entry; + assert(type_entry->id == TypeTableEntryIdInt); + + ConstExprValue op1_val = {0}; + if (eval_expr(ef, op1_node, &op1_val)) return true; + + ConstExprValue op2_val = {0}; + if (eval_expr(ef, op2_node, &op2_val)) return true; + + if (op2_val.data.x_bignum.data.x_uint == 0) { + ErrorMsg *msg = add_node_error(ef->root->codegen, ef->root->fn->fn_def_node, + buf_sprintf("function evaluation caused division by zero")); + add_error_note(ef->root->codegen, msg, ef->root->call_node, buf_sprintf("called from here")); + add_error_note(ef->root->codegen, msg, node, buf_sprintf("division by zero here")); + return true; + } + + bignum_div(&out_val->data.x_bignum, &op1_val.data.x_bignum, &op2_val.data.x_bignum); + + BigNum orig_bn; + bignum_mul(&orig_bn, &out_val->data.x_bignum, &op2_val.data.x_bignum); + + if (bignum_cmp_neq(&orig_bn, &op1_val.data.x_bignum)) { + ErrorMsg *msg = add_node_error(ef->root->codegen, ef->root->fn->fn_def_node, + buf_sprintf("function evaluation violated exact division")); + add_error_note(ef->root->codegen, msg, ef->root->call_node, buf_sprintf("called from here")); + add_error_note(ef->root->codegen, msg, node, buf_sprintf("exact division violation here")); + return true; + } + + out_val->ok = true; + out_val->depends_on_compile_var = op1_val.depends_on_compile_var || op2_val.depends_on_compile_var; + return false; +} + static bool eval_fn_with_overflow(EvalFn *ef, AstNode *node, ConstExprValue *out_val, bool (*bignum_fn)(BigNum *dest, BigNum *op1, BigNum *op2)) { @@ -767,6 +807,8 @@ static bool eval_fn_call_builtin(EvalFn *ef, AstNode *node, ConstExprValue *out_ return eval_fn_with_overflow(ef, node, out_val, bignum_shl); case BuiltinFnIdFence: return false; + case BuiltinFnIdDivExact: + return eval_div_exact(ef, node, out_val); case BuiltinFnIdMemcpy: case BuiltinFnIdMemset: case BuiltinFnIdSizeof: diff --git a/src/zig_llvm.cpp b/src/zig_llvm.cpp index bf59763cb2..3e828db144 100644 --- a/src/zig_llvm.cpp +++ b/src/zig_llvm.cpp @@ -680,6 +680,13 @@ LLVMValueRef ZigLLVMBuildNUWShl(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMVa return wrap(unwrap(builder)->CreateShl(unwrap(LHS), unwrap(RHS), name, false, true)); } +LLVMValueRef ZigLLVMBuildExactUDiv(LLVMBuilderRef B, LLVMValueRef LHS, + LLVMValueRef RHS, const char *Name) +{ + return wrap(unwrap(B)->CreateExactUDiv(unwrap(LHS), unwrap(RHS), Name)); +} + + //------------------------------------ diff --git a/src/zig_llvm.hpp b/src/zig_llvm.hpp index 6e824fc57c..444081ce98 100644 --- a/src/zig_llvm.hpp +++ b/src/zig_llvm.hpp @@ -47,6 +47,8 @@ LLVMValueRef ZigLLVMBuildNSWShl(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMVa const char *name); LLVMValueRef ZigLLVMBuildNUWShl(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS, const char *name); +LLVMValueRef ZigLLVMBuildExactUDiv(LLVMBuilderRef B, LLVMValueRef LHS, + LLVMValueRef RHS, const char *Name); // 0 is return value, 1 is first arg void LLVMZigAddNonNullAttr(LLVMValueRef fn, unsigned i); diff --git a/test/run_tests.cpp b/test/run_tests.cpp index 44b2c93e83..74926d5a38 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -1456,6 +1456,16 @@ fn div0(a: i32, b: i32) -> i32 { } )SOURCE"); + add_debug_safety_case("exact division failure", R"SOURCE( +pub fn main(args: [][]u8) -> %void { + div_exact(10, 3); +} +#static_eval_enable(false) +fn div_exact(a: i32, b: i32) -> i32 { + @div_exact(a, b) +} + )SOURCE"); + } ////////////////////////////////////////////////////////////////////////////// diff --git a/test/self_hosted.zig b/test/self_hosted.zig index 34577e570e..6d8ccc9882 100644 --- a/test/self_hosted.zig +++ b/test/self_hosted.zig @@ -1603,3 +1603,12 @@ fn float_division() { fn fdiv32(a: f32, b: f32) -> f32 { a / b } + +#attribute("test") +fn exact_division() { + assert(div_exact(55, 11) == 5); +} +#static_eval_enable(false) +fn div_exact(a: u32, b: u32) -> u32 { + @div_exact(a, b) +}