diff --git a/src/codegen.cpp b/src/codegen.cpp index 8107b20845..b3046e7473 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -440,7 +440,7 @@ static LLVMValueRef gen_cmp_exchange(CodeGen *g, AstNode *node) { LLVMAtomicOrdering failure_order = to_LLVMAtomicOrdering((AtomicOrder)failure_order_val->data.x_enum.tag); LLVMValueRef result_val = ZigLLVMBuildCmpXchg(g->builder, ptr_val, cmp_val, new_val, - success_order, failure_order, ""); + success_order, failure_order); return LLVMBuildExtractValue(g->builder, result_val, 1, ""); } @@ -1309,6 +1309,36 @@ static LLVMValueRef gen_overflow_op(CodeGen *g, TypeTableEntry *type_entry, AddS return result; } +static LLVMValueRef gen_overflow_shl_op(CodeGen *g, TypeTableEntry *type_entry, + LLVMValueRef val1, LLVMValueRef val2) +{ + // for unsigned left shifting, we do the wrapping shift, then logically shift + // right the same number of bits + // if the values don't match, we have an overflow + // for signed left shifting we do the same except arithmetic shift right + + assert(type_entry->id == TypeTableEntryIdInt); + + LLVMValueRef result = LLVMBuildShl(g->builder, val1, val2, ""); + LLVMValueRef orig_val; + if (type_entry->data.integral.is_signed) { + orig_val = LLVMBuildAShr(g->builder, result, val2, ""); + } else { + orig_val = LLVMBuildLShr(g->builder, result, val2, ""); + } + LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, val1, orig_val, ""); + + LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "OverflowOk"); + LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "OverflowFail"); + LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block); + + LLVMPositionBuilderAtEnd(g->builder, fail_block); + gen_debug_safety_crash(g); + + LLVMPositionBuilderAtEnd(g->builder, ok_block); + return result; +} + static LLVMValueRef gen_prefix_op_expr(CodeGen *g, AstNode *node) { assert(node->type == NodeTypePrefixOpExpr); assert(node->data.prefix_op_expr.primary_expr); @@ -1484,7 +1514,16 @@ static LLVMValueRef gen_arithmetic_bin_op(CodeGen *g, AstNode *source_node, case BinOpTypeBitShiftLeft: case BinOpTypeAssignBitShiftLeft: set_debug_source_node(g, source_node); - return LLVMBuildShl(g->builder, val1, val2, ""); + assert(op1_type->id == TypeTableEntryIdInt); + if (op1_type->data.integral.is_wrapping) { + return LLVMBuildShl(g->builder, val1, val2, ""); + } else if (want_debug_safety(g, source_node)) { + return gen_overflow_shl_op(g, op1_type, val1, val2); + } else if (op1_type->data.integral.is_signed) { + return ZigLLVMBuildNSWShl(g->builder, val1, val2, ""); + } else { + return ZigLLVMBuildNUWShl(g->builder, val1, val2, ""); + } case BinOpTypeBitShiftRight: case BinOpTypeAssignBitShiftRight: assert(op1_type->id == TypeTableEntryIdInt); diff --git a/src/zig_llvm.cpp b/src/zig_llvm.cpp index 4ef11f4ca2..bf59763cb2 100644 --- a/src/zig_llvm.cpp +++ b/src/zig_llvm.cpp @@ -661,14 +661,25 @@ static AtomicOrdering mapFromLLVMOrdering(LLVMAtomicOrdering Ordering) { LLVMValueRef ZigLLVMBuildCmpXchg(LLVMBuilderRef builder, LLVMValueRef ptr, LLVMValueRef cmp, LLVMValueRef new_val, LLVMAtomicOrdering success_ordering, - LLVMAtomicOrdering failure_ordering, - const char *name) + LLVMAtomicOrdering failure_ordering) { return wrap(unwrap(builder)->CreateAtomicCmpXchg(unwrap(ptr), unwrap(cmp), unwrap(new_val), mapFromLLVMOrdering(success_ordering), mapFromLLVMOrdering(failure_ordering), CrossThread)); } +LLVMValueRef ZigLLVMBuildNSWShl(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS, + const char *name) +{ + return wrap(unwrap(builder)->CreateShl(unwrap(LHS), unwrap(RHS), name, false, true)); +} + +LLVMValueRef ZigLLVMBuildNUWShl(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS, + const char *name) +{ + return wrap(unwrap(builder)->CreateShl(unwrap(LHS), unwrap(RHS), name, false, true)); +} + //------------------------------------ diff --git a/src/zig_llvm.hpp b/src/zig_llvm.hpp index f223bac117..6e824fc57c 100644 --- a/src/zig_llvm.hpp +++ b/src/zig_llvm.hpp @@ -41,7 +41,11 @@ LLVMValueRef LLVMZigBuildCall(LLVMBuilderRef B, LLVMValueRef Fn, LLVMValueRef *A LLVMValueRef ZigLLVMBuildCmpXchg(LLVMBuilderRef builder, LLVMValueRef ptr, LLVMValueRef cmp, LLVMValueRef new_val, LLVMAtomicOrdering success_ordering, - LLVMAtomicOrdering failure_ordering, + LLVMAtomicOrdering failure_ordering); + +LLVMValueRef ZigLLVMBuildNSWShl(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS, + const char *name); +LLVMValueRef ZigLLVMBuildNUWShl(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS, const char *name); // 0 is return value, 1 is first arg diff --git a/test/run_tests.cpp b/test/run_tests.cpp index 6c898533cb..81911671b6 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -1363,6 +1363,26 @@ fn neg(a: i16) -> i16 { } )SOURCE"); + add_debug_safety_case("signed shift left overflow", R"SOURCE( +pub fn main(args: [][]u8) -> %void { + shl(-16385, 1); +} +#static_eval_enable(false) +fn shl(a: i16, b: i16) -> i16 { + a << b +} + )SOURCE"); + + add_debug_safety_case("unsigned shift left overflow", R"SOURCE( +pub fn main(args: [][]u8) -> %void { + shl(0b0010111111111111, 3); +} +#static_eval_enable(false) +fn shl(a: u16, b: u16) -> u16 { + a << b +} + )SOURCE"); + } //////////////////////////////////////////////////////////////////////////////