From 0a7bdc00771dbad1dfe5eb93a7cade89059d227a Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sat, 9 Feb 2019 14:44:33 -0500 Subject: [PATCH] implement vector addition with safety checking. This would work if @llvm.sadd.with.overflow supported vectors, which it does in trunk. But it does not support them in LLVM 7 or even in the LLVM 8 release branch. So the next commit after this will have to use a different strategy, but when LLVM 9 comes out it may be worth coming back to this one. --- src/all_types.hpp | 3 + src/analyze.cpp | 6 +- src/codegen.cpp | 154 ++++++++++++++++++++++++++-------------------- 3 files changed, 95 insertions(+), 68 deletions(-) diff --git a/src/all_types.hpp b/src/all_types.hpp index 842c9ae904..908c0e327c 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1538,6 +1538,8 @@ enum ZigLLVMFnId { ZigLLVMFnIdBitReverse, }; +// There are a bunch of places in code that rely on these values being in +// exactly this order. enum AddSubMul { AddSubMulAdd = 0, AddSubMulSub = 1, @@ -1563,6 +1565,7 @@ struct ZigLLVMFnKey { struct { AddSubMul add_sub_mul; uint32_t bit_count; + uint32_t vector_len; // 0 means not a vector bool is_signed; } overflow_arithmetic; struct { diff --git a/src/analyze.cpp b/src/analyze.cpp index 83a576554a..0c493ebda1 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -6361,7 +6361,8 @@ uint32_t zig_llvm_fn_key_hash(ZigLLVMFnKey x) { case ZigLLVMFnIdOverflowArithmetic: return ((uint32_t)(x.data.overflow_arithmetic.bit_count) * 87135777) + ((uint32_t)(x.data.overflow_arithmetic.add_sub_mul) * 31640542) + - ((uint32_t)(x.data.overflow_arithmetic.is_signed) ? 
1062315172 : 314955820) + + x.data.overflow_arithmetic.vector_len * 1435156945; } zig_unreachable(); } @@ -6387,7 +6388,8 @@ bool zig_llvm_fn_key_eql(ZigLLVMFnKey a, ZigLLVMFnKey b) { case ZigLLVMFnIdOverflowArithmetic: return (a.data.overflow_arithmetic.bit_count == b.data.overflow_arithmetic.bit_count) && (a.data.overflow_arithmetic.add_sub_mul == b.data.overflow_arithmetic.add_sub_mul) && - (a.data.overflow_arithmetic.is_signed == b.data.overflow_arithmetic.is_signed); + (a.data.overflow_arithmetic.is_signed == b.data.overflow_arithmetic.is_signed) && + (a.data.overflow_arithmetic.vector_len == b.data.overflow_arithmetic.vector_len); } zig_unreachable(); } diff --git a/src/codegen.cpp b/src/codegen.cpp index 3bfd7cdfc5..e45280b0d1 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -715,38 +715,59 @@ static void clear_debug_source_node(CodeGen *g) { ZigLLVMClearCurrentDebugLocation(g->builder); } -static LLVMValueRef get_arithmetic_overflow_fn(CodeGen *g, ZigType *type_entry, +static LLVMValueRef get_arithmetic_overflow_fn(CodeGen *g, ZigType *operand_type, const char *signed_name, const char *unsigned_name) { + ZigType *int_type = (operand_type->id == ZigTypeIdVector) ? operand_type->data.vector.elem_type : operand_type; char fn_name[64]; - assert(type_entry->id == ZigTypeIdInt); - const char *signed_str = type_entry->data.integral.is_signed ? signed_name : unsigned_name; - sprintf(fn_name, "llvm.%s.with.overflow.i%" PRIu32, signed_str, type_entry->data.integral.bit_count); + assert(int_type->id == ZigTypeIdInt); + const char *signed_str = int_type->data.integral.is_signed ? 
signed_name : unsigned_name; - LLVMTypeRef return_elem_types[] = { - type_entry->type_ref, - LLVMInt1Type(), - }; LLVMTypeRef param_types[] = { - type_entry->type_ref, - type_entry->type_ref, + operand_type->type_ref, + operand_type->type_ref, }; - LLVMTypeRef return_struct_type = LLVMStructType(return_elem_types, 2, false); - LLVMTypeRef fn_type = LLVMFunctionType(return_struct_type, param_types, 2, false); - LLVMValueRef fn_val = LLVMAddFunction(g->module, fn_name, fn_type); - assert(LLVMGetIntrinsicID(fn_val)); - return fn_val; + + if (operand_type->id == ZigTypeIdVector) { + sprintf(fn_name, "llvm.%s.with.overflow.v%" PRIu32 "i%" PRIu32, signed_str, + operand_type->data.vector.len, int_type->data.integral.bit_count); + + LLVMTypeRef return_elem_types[] = { + operand_type->type_ref, + LLVMVectorType(LLVMInt1Type(), operand_type->data.vector.len), + }; + LLVMTypeRef return_struct_type = LLVMStructType(return_elem_types, 2, false); + LLVMTypeRef fn_type = LLVMFunctionType(return_struct_type, param_types, 2, false); + LLVMValueRef fn_val = LLVMAddFunction(g->module, fn_name, fn_type); + assert(LLVMGetIntrinsicID(fn_val)); + return fn_val; + } else { + sprintf(fn_name, "llvm.%s.with.overflow.i%" PRIu32, signed_str, int_type->data.integral.bit_count); + + LLVMTypeRef return_elem_types[] = { + operand_type->type_ref, + LLVMInt1Type(), + }; + LLVMTypeRef return_struct_type = LLVMStructType(return_elem_types, 2, false); + LLVMTypeRef fn_type = LLVMFunctionType(return_struct_type, param_types, 2, false); + LLVMValueRef fn_val = LLVMAddFunction(g->module, fn_name, fn_type); + assert(LLVMGetIntrinsicID(fn_val)); + return fn_val; + } } -static LLVMValueRef get_int_overflow_fn(CodeGen *g, ZigType *type_entry, AddSubMul add_sub_mul) { - assert(type_entry->id == ZigTypeIdInt); +static LLVMValueRef get_int_overflow_fn(CodeGen *g, ZigType *operand_type, AddSubMul add_sub_mul) { + ZigType *int_type = (operand_type->id == ZigTypeIdVector) ? 
operand_type->data.vector.elem_type : operand_type; + assert(int_type->id == ZigTypeIdInt); ZigLLVMFnKey key = {}; key.id = ZigLLVMFnIdOverflowArithmetic; - key.data.overflow_arithmetic.is_signed = type_entry->data.integral.is_signed; + key.data.overflow_arithmetic.is_signed = int_type->data.integral.is_signed; key.data.overflow_arithmetic.add_sub_mul = add_sub_mul; - key.data.overflow_arithmetic.bit_count = (uint32_t)type_entry->data.integral.bit_count; + key.data.overflow_arithmetic.bit_count = (uint32_t)int_type->data.integral.bit_count; + key.data.overflow_arithmetic.vector_len = (operand_type->id == ZigTypeIdVector) ? + operand_type->data.vector.len : 0; auto existing_entry = g->llvm_fn_table.maybe_get(key); if (existing_entry) @@ -755,13 +776,13 @@ static LLVMValueRef get_int_overflow_fn(CodeGen *g, ZigType *type_entry, AddSubM LLVMValueRef fn_val; switch (add_sub_mul) { case AddSubMulAdd: - fn_val = get_arithmetic_overflow_fn(g, type_entry, "sadd", "uadd"); + fn_val = get_arithmetic_overflow_fn(g, operand_type, "sadd", "uadd"); break; case AddSubMulSub: - fn_val = get_arithmetic_overflow_fn(g, type_entry, "ssub", "usub"); + fn_val = get_arithmetic_overflow_fn(g, operand_type, "ssub", "usub"); break; case AddSubMulMul: - fn_val = get_arithmetic_overflow_fn(g, type_entry, "smul", "umul"); + fn_val = get_arithmetic_overflow_fn(g, operand_type, "smul", "umul"); break; } @@ -1752,17 +1773,28 @@ static LLVMValueRef gen_widen_or_shorten(CodeGen *g, bool want_runtime_safety, Z } } -static LLVMValueRef gen_overflow_op(CodeGen *g, ZigType *type_entry, AddSubMul op, +static LLVMValueRef gen_overflow_op(CodeGen *g, ZigType *operand_type, AddSubMul op, LLVMValueRef val1, LLVMValueRef val2) { - LLVMValueRef fn_val = get_int_overflow_fn(g, type_entry, op); + LLVMValueRef fn_val = get_int_overflow_fn(g, operand_type, op); LLVMValueRef params[] = { val1, val2, }; LLVMValueRef result_struct = LLVMBuildCall(g->builder, fn_val, params, 2, ""); LLVMValueRef result = 
LLVMBuildExtractValue(g->builder, result_struct, 0, ""); - LLVMValueRef overflow_bit = LLVMBuildExtractValue(g->builder, result_struct, 1, ""); + + LLVMValueRef overflow_bit; + if (operand_type->id == ZigTypeIdVector) { + LLVMValueRef overflow_vector = LLVMBuildExtractValue(g->builder, result_struct, 1, ""); + LLVMTypeRef bigger_int_type_ref = LLVMIntType(operand_type->data.vector.len); + LLVMValueRef bitcasted_overflow = LLVMBuildBitCast(g->builder, overflow_vector, bigger_int_type_ref, ""); + LLVMValueRef zero = LLVMConstNull(bigger_int_type_ref); + overflow_bit = LLVMBuildICmp(g->builder, LLVMIntNE, bitcasted_overflow, zero, ""); + } else { + overflow_bit = LLVMBuildExtractValue(g->builder, result_struct, 1, ""); + } + LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail"); LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk"); LLVMBuildCondBr(g->builder, overflow_bit, fail_block, ok_block); @@ -2608,7 +2640,8 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable, (op_id == IrBinOpAdd || op_id == IrBinOpSub) && op1->value.type->data.pointer.ptr_len == PtrLenUnknown) ); - ZigType *type_entry = op1->value.type; + ZigType *operand_type = op1->value.type; + ZigType *scalar_type = (operand_type->id == ZigTypeIdVector) ? 
operand_type->data.vector.elem_type : operand_type; bool want_runtime_safety = bin_op_instruction->safety_check_on && ir_want_runtime_safety(g, &bin_op_instruction->base); @@ -2634,17 +2667,17 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable, case IrBinOpCmpGreaterThan: case IrBinOpCmpLessOrEq: case IrBinOpCmpGreaterOrEq: - if (type_entry->id == ZigTypeIdFloat) { + if (scalar_type->id == ZigTypeIdFloat) { ZigLLVMSetFastMath(g->builder, ir_want_fast_math(g, &bin_op_instruction->base)); LLVMRealPredicate pred = cmp_op_to_real_predicate(op_id); return LLVMBuildFCmp(g->builder, pred, op1_value, op2_value, ""); - } else if (type_entry->id == ZigTypeIdInt) { - LLVMIntPredicate pred = cmp_op_to_int_predicate(op_id, type_entry->data.integral.is_signed); + } else if (scalar_type->id == ZigTypeIdInt) { + LLVMIntPredicate pred = cmp_op_to_int_predicate(op_id, scalar_type->data.integral.is_signed); return LLVMBuildICmp(g->builder, pred, op1_value, op2_value, ""); - } else if (type_entry->id == ZigTypeIdEnum || - type_entry->id == ZigTypeIdErrorSet || - type_entry->id == ZigTypeIdBool || - get_codegen_ptr_type(type_entry) != nullptr) + } else if (scalar_type->id == ZigTypeIdEnum || + scalar_type->id == ZigTypeIdErrorSet || + scalar_type->id == ZigTypeIdBool || + get_codegen_ptr_type(scalar_type) != nullptr) { LLVMIntPredicate pred = cmp_op_to_int_predicate(op_id, false); return LLVMBuildICmp(g->builder, pred, op1_value, op2_value, ""); @@ -2665,23 +2698,16 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable, static const BuildBinOpFunc signed_op[3] = { LLVMBuildNSWAdd, LLVMBuildNSWSub, LLVMBuildNSWMul }; static const BuildBinOpFunc unsigned_op[3] = { LLVMBuildNUWAdd, LLVMBuildNUWSub, LLVMBuildNUWMul }; - bool is_vector = type_entry->id == ZigTypeIdVector; bool is_wrapping = (op_id == IrBinOpSubWrap || op_id == IrBinOpAddWrap || op_id == IrBinOpMultWrap); AddSubMul add_sub_mul = op_id == IrBinOpAdd || op_id == IrBinOpAddWrap ? 
AddSubMulAdd : op_id == IrBinOpSub || op_id == IrBinOpSubWrap ? AddSubMulSub : AddSubMulMul; - // The code that is generated for vectors and scalars are the same, - // so we can just set type_entry to the vectors elem_type an avoid - // a lot of repeated code. - if (is_vector) - type_entry = type_entry->data.vector.elem_type; - - if (type_entry->id == ZigTypeIdPointer) { - assert(type_entry->data.pointer.ptr_len == PtrLenUnknown); + if (scalar_type->id == ZigTypeIdPointer) { + assert(scalar_type->data.pointer.ptr_len == PtrLenUnknown); LLVMValueRef subscript_value; - if (is_vector) + if (operand_type->id == ZigTypeIdVector) zig_panic("TODO: Implement vector operations on pointers."); switch (add_sub_mul) { @@ -2697,17 +2723,15 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable, // TODO runtime safety return LLVMBuildInBoundsGEP(g->builder, op1_value, &subscript_value, 1, ""); - } else if (type_entry->id == ZigTypeIdFloat) { + } else if (scalar_type->id == ZigTypeIdFloat) { ZigLLVMSetFastMath(g->builder, ir_want_fast_math(g, &bin_op_instruction->base)); return float_op[add_sub_mul](g->builder, op1_value, op2_value, ""); - } else if (type_entry->id == ZigTypeIdInt) { + } else if (scalar_type->id == ZigTypeIdInt) { if (is_wrapping) { return wrap_op[add_sub_mul](g->builder, op1_value, op2_value, ""); } else if (want_runtime_safety) { - if (is_vector) - zig_panic("TODO: Implement runtime safety vector operations."); - return gen_overflow_op(g, type_entry, add_sub_mul, op1_value, op2_value); - } else if (type_entry->data.integral.is_signed) { + return gen_overflow_op(g, operand_type, add_sub_mul, op1_value, op2_value); + } else if (scalar_type->data.integral.is_signed) { return signed_op[add_sub_mul](g->builder, op1_value, op2_value, ""); } else { return unsigned_op[add_sub_mul](g->builder, op1_value, op2_value, ""); @@ -2725,15 +2749,14 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable, case IrBinOpBitShiftLeftLossy: 
case IrBinOpBitShiftLeftExact: { - assert(type_entry->id == ZigTypeIdInt); - LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value.type, - type_entry, op2_value); + assert(scalar_type->id == ZigTypeIdInt); + LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value.type, scalar_type, op2_value); bool is_sloppy = (op_id == IrBinOpBitShiftLeftLossy); if (is_sloppy) { return LLVMBuildShl(g->builder, op1_value, op2_casted, ""); } else if (want_runtime_safety) { - return gen_overflow_shl_op(g, type_entry, op1_value, op2_casted); - } else if (type_entry->data.integral.is_signed) { + return gen_overflow_shl_op(g, scalar_type, op1_value, op2_casted); + } else if (scalar_type->data.integral.is_signed) { return ZigLLVMBuildNSWShl(g->builder, op1_value, op2_casted, ""); } else { return ZigLLVMBuildNUWShl(g->builder, op1_value, op2_casted, ""); @@ -2742,19 +2765,18 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable, case IrBinOpBitShiftRightLossy: case IrBinOpBitShiftRightExact: { - assert(type_entry->id == ZigTypeIdInt); - LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value.type, - type_entry, op2_value); + assert(scalar_type->id == ZigTypeIdInt); + LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value.type, scalar_type, op2_value); bool is_sloppy = (op_id == IrBinOpBitShiftRightLossy); if (is_sloppy) { - if (type_entry->data.integral.is_signed) { + if (scalar_type->data.integral.is_signed) { return LLVMBuildAShr(g->builder, op1_value, op2_casted, ""); } else { return LLVMBuildLShr(g->builder, op1_value, op2_casted, ""); } } else if (want_runtime_safety) { - return gen_overflow_shr_op(g, type_entry, op1_value, op2_casted); - } else if (type_entry->data.integral.is_signed) { + return gen_overflow_shr_op(g, scalar_type, op1_value, op2_casted); + } else if (scalar_type->data.integral.is_signed) { return ZigLLVMBuildAShrExact(g->builder, op1_value, op2_casted, ""); } else { return 
ZigLLVMBuildLShrExact(g->builder, op1_value, op2_casted, ""); @@ -2762,22 +2784,22 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable, } case IrBinOpDivUnspecified: return gen_div(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base), - op1_value, op2_value, type_entry, DivKindFloat); + op1_value, op2_value, scalar_type, DivKindFloat); case IrBinOpDivExact: return gen_div(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base), - op1_value, op2_value, type_entry, DivKindExact); + op1_value, op2_value, scalar_type, DivKindExact); case IrBinOpDivTrunc: return gen_div(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base), - op1_value, op2_value, type_entry, DivKindTrunc); + op1_value, op2_value, scalar_type, DivKindTrunc); case IrBinOpDivFloor: return gen_div(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base), - op1_value, op2_value, type_entry, DivKindFloor); + op1_value, op2_value, scalar_type, DivKindFloor); case IrBinOpRemRem: return gen_rem(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base), - op1_value, op2_value, type_entry, RemKindRem); + op1_value, op2_value, scalar_type, RemKindRem); case IrBinOpRemMod: return gen_rem(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base), - op1_value, op2_value, type_entry, RemKindMod); + op1_value, op2_value, scalar_type, RemKindMod); } zig_unreachable(); }