ir: Fix error checking for vector ops

The extra logic that's needed was lost during a refactoring, now it
should be fine.
This commit is contained in:
LemonBoy 2020-04-05 10:40:41 +02:00 committed by Andrew Kelley
parent 0f964e1910
commit f6cdc94a50
No known key found for this signature in database
GPG Key ID: 7C5F548F728501A9

View File

@ -155,7 +155,6 @@ static LLVMValueRef gen_await_early_return(CodeGen *g, IrInstGen *source_instr,
LLVMValueRef target_frame_ptr, ZigType *result_type, ZigType *ptr_result_type,
LLVMValueRef result_loc, bool non_async);
static Error get_tmp_filename(CodeGen *g, Buf *out, Buf *suffix);
static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val);
static void addLLVMAttr(LLVMValueRef val, LLVMAttributeIndex attr_index, const char *attr_name) {
unsigned kind_id = LLVMGetEnumAttributeKindForName(attr_name, strlen(attr_name));
@ -2536,6 +2535,36 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutableGen *executable, Ir
return nullptr;
}
enum class ScalarizePredicate {
// Returns true iff all the elements in the vector are 1.
// Equivalent to folding all the bits with `and`.
All,
// Returns true iff there's at least one element in the vector that is 1.
// Equivalent to folding all the bits with `or`.
Any,
};
// Collapses a <N x i1> vector into a single i1 according to the given predicate
static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val, ScalarizePredicate predicate) {
assert(LLVMGetTypeKind(LLVMTypeOf(val)) == LLVMVectorTypeKind);
LLVMTypeRef scalar_type = LLVMIntType(LLVMGetVectorSize(LLVMTypeOf(val)));
LLVMValueRef casted = LLVMBuildBitCast(g->builder, val, scalar_type, "");
switch (predicate) {
case ScalarizePredicate::Any: {
LLVMValueRef all_zeros = LLVMConstNull(scalar_type);
return LLVMBuildICmp(g->builder, LLVMIntNE, casted, all_zeros, "");
}
case ScalarizePredicate::All: {
LLVMValueRef all_ones = LLVMConstAllOnes(scalar_type);
return LLVMBuildICmp(g->builder, LLVMIntEQ, casted, all_ones, "");
}
}
zig_unreachable();
}
static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type,
LLVMValueRef val1, LLVMValueRef val2)
{
@ -2560,7 +2589,7 @@ static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type,
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
if (operand_type->id == ZigTypeIdVector) {
ok_bit = scalarize_cmp_result(g, ok_bit);
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
@ -2591,7 +2620,7 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *operand_type,
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
if (operand_type->id == ZigTypeIdVector) {
ok_bit = scalarize_cmp_result(g, ok_bit);
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
@ -2647,16 +2676,6 @@ static LLVMValueRef bigint_to_llvm_const(LLVMTypeRef type_ref, BigInt *bigint) {
}
}
// Collapses a <N x i1> vector into a single i1 whose value is 1 iff all the
// vector elements are 1
static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val) {
assert(LLVMGetTypeKind(LLVMTypeOf(val)) == LLVMVectorTypeKind);
LLVMTypeRef scalar_type = LLVMIntType(LLVMGetVectorSize(LLVMTypeOf(val)));
LLVMValueRef all_ones = LLVMConstAllOnes(scalar_type);
LLVMValueRef casted = LLVMBuildBitCast(g->builder, val, scalar_type, "");
return LLVMBuildICmp(g->builder, LLVMIntEQ, casted, all_ones, "");
}
static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast_math,
LLVMValueRef val1, LLVMValueRef val2, ZigType *operand_type, DivKind div_kind)
{
@ -2678,7 +2697,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
}
if (operand_type->id == ZigTypeIdVector) {
is_zero_bit = scalarize_cmp_result(g, is_zero_bit);
is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any);
}
LLVMBasicBlockRef div_zero_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroFail");
@ -2703,7 +2722,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMValueRef den_is_neg_1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, neg_1_value, "");
LLVMValueRef overflow_fail_bit = LLVMBuildAnd(g->builder, num_is_int_min, den_is_neg_1, "");
if (operand_type->id == ZigTypeIdVector) {
overflow_fail_bit = scalarize_cmp_result(g, overflow_fail_bit);
overflow_fail_bit = scalarize_cmp_result(g, overflow_fail_bit, ScalarizePredicate::Any);
}
LLVMBuildCondBr(g->builder, overflow_fail_bit, overflow_fail_block, overflow_ok_block);
@ -2728,7 +2747,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
LLVMValueRef ok_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, floored, result, "");
if (operand_type->id == ZigTypeIdVector) {
ok_bit = scalarize_cmp_result(g, ok_bit);
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
@ -2745,7 +2764,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMBasicBlockRef end_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivTruncEnd");
LLVMValueRef ltz = LLVMBuildFCmp(g->builder, LLVMRealOLT, val1, zero, "");
if (operand_type->id == ZigTypeIdVector) {
ltz = scalarize_cmp_result(g, ltz);
ltz = scalarize_cmp_result(g, ltz, ScalarizePredicate::Any);
}
LLVMBuildCondBr(g->builder, ltz, ltz_block, gez_block);
@ -2797,7 +2816,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, remainder_val, zero, "");
if (operand_type->id == ZigTypeIdVector) {
ok_bit = scalarize_cmp_result(g, ok_bit);
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
@ -2861,7 +2880,7 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast
}
if (operand_type->id == ZigTypeIdVector) {
is_zero_bit = scalarize_cmp_result(g, is_zero_bit);
is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any);
}
LLVMBasicBlockRef rem_zero_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "RemZeroOk");
@ -2918,7 +2937,7 @@ static void gen_shift_rhs_check(CodeGen *g, ZigType *lhs_type, ZigType *rhs_type
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckOk");
LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, "");
if (rhs_type->id == ZigTypeIdVector) {
less_than_bit = scalarize_cmp_result(g, less_than_bit);
less_than_bit = scalarize_cmp_result(g, less_than_bit, ScalarizePredicate::Any);
}
LLVMBuildCondBr(g->builder, less_than_bit, ok_block, fail_block);