From d2d97e55ccd2d7c992d01bd05ea52a52fe36776e Mon Sep 17 00:00:00 2001 From: LemonBoy Date: Sat, 14 Mar 2020 20:01:28 +0100 Subject: [PATCH] ir: Support shift left/right on vectors --- src/codegen.cpp | 50 ++++++++++---- src/ir.cpp | 118 +++++++++++++++++++++++--------- test/stage1/behavior/vector.zig | 65 ++++++++++++++++++ 3 files changed, 184 insertions(+), 49 deletions(-) diff --git a/src/codegen.cpp b/src/codegen.cpp index 0fa181b32c..97d960b523 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -155,6 +155,7 @@ static LLVMValueRef gen_await_early_return(CodeGen *g, IrInstGen *source_instr, LLVMValueRef target_frame_ptr, ZigType *result_type, ZigType *ptr_result_type, LLVMValueRef result_loc, bool non_async); static Error get_tmp_filename(CodeGen *g, Buf *out, Buf *suffix); +static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val); static void addLLVMAttr(LLVMValueRef val, LLVMAttributeIndex attr_index, const char *attr_name) { unsigned kind_id = LLVMGetEnumAttributeKindForName(attr_name, strlen(attr_name)); @@ -2535,19 +2536,21 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutableGen *executable, Ir return nullptr; } -static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *type_entry, - LLVMValueRef val1, LLVMValueRef val2) +static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type, + LLVMValueRef val1, LLVMValueRef val2) { // for unsigned left shifting, we do the lossy shift, then logically shift // right the same number of bits // if the values don't match, we have an overflow // for signed left shifting we do the same except arithmetic shift right + ZigType *scalar_type = (operand_type->id == ZigTypeIdVector) ? + operand_type->data.vector.elem_type : operand_type; - assert(type_entry->id == ZigTypeIdInt); + assert(scalar_type->id == ZigTypeIdInt); LLVMValueRef result = LLVMBuildShl(g->builder, val1, val2, ""); LLVMValueRef orig_val; - if (type_entry->data.integral.is_signed) { + if (scalar_type->data.integral.is_signed) { orig_val = LLVMBuildAShr(g->builder, result, val2, ""); } else { orig_val = LLVMBuildLShr(g->builder, result, val2, ""); @@ -2556,6 +2559,9 @@ static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *type_entry, LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk"); LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail"); + if (operand_type->id == ZigTypeIdVector) { + ok_bit = scalarize_cmp_result(g, ok_bit); + } LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block); LLVMPositionBuilderAtEnd(g->builder, fail_block); @@ -2565,13 +2571,16 @@ static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *type_entry, return result; } -static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *type_entry, - LLVMValueRef val1, LLVMValueRef val2) +static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *operand_type, + LLVMValueRef val1, LLVMValueRef val2) { - assert(type_entry->id == ZigTypeIdInt); + ZigType *scalar_type = (operand_type->id == ZigTypeIdVector) ? + operand_type->data.vector.elem_type : operand_type; + + assert(scalar_type->id == ZigTypeIdInt); LLVMValueRef result; - if (type_entry->data.integral.is_signed) { + if (scalar_type->data.integral.is_signed) { result = LLVMBuildAShr(g->builder, val1, val2, ""); } else { result = LLVMBuildLShr(g->builder, val1, val2, ""); @@ -2581,6 +2590,9 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *type_entry, LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk"); LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail"); + if (operand_type->id == ZigTypeIdVector) { + ok_bit = scalarize_cmp_result(g, ok_bit); + } LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block); LLVMPositionBuilderAtEnd(g->builder, fail_block); @@ -2897,11 +2909,17 @@ static void gen_shift_rhs_check(CodeGen *g, ZigType *lhs_type, ZigType *rhs_type // otherwise the check is useful as the allowed values are limited by the // operand type itself if (!is_power_of_2(lhs_type->data.integral.bit_count)) { - LLVMValueRef bit_count_value = LLVMConstInt(get_llvm_type(g, rhs_type), - lhs_type->data.integral.bit_count, false); - LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, ""); + BigInt bit_count_bi = {0}; + bigint_init_unsigned(&bit_count_bi, lhs_type->data.integral.bit_count); + LLVMValueRef bit_count_value = bigint_to_llvm_const(get_llvm_type(g, rhs_type), + &bit_count_bi); + LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckFail"); LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckOk"); + LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, ""); + if (rhs_type->id == ZigTypeIdVector) { + less_than_bit = scalarize_cmp_result(g, less_than_bit); + } LLVMBuildCondBr(g->builder, less_than_bit, ok_block, fail_block); LLVMPositionBuilderAtEnd(g->builder, fail_block); @@ -3018,7 +3036,8 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable, case IrBinOpBitShiftLeftExact: { assert(scalar_type->id == ZigTypeIdInt); - LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value); + LLVMValueRef op2_casted = LLVMBuildZExt(g->builder, op2_value, + LLVMTypeOf(op1_value), "");//gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value); if (want_runtime_safety) { gen_shift_rhs_check(g, scalar_type, op2->value->type, op2_value); @@ -3028,7 +3047,7 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable, if (is_sloppy) { return LLVMBuildShl(g->builder, op1_value, op2_casted, ""); } else if (want_runtime_safety) { - return gen_overflow_shl_op(g, scalar_type, op1_value, op2_casted); + return gen_overflow_shl_op(g, operand_type, op1_value, op2_casted); } else if (scalar_type->data.integral.is_signed) { return ZigLLVMBuildNSWShl(g->builder, op1_value, op2_casted, ""); } else { @@ -3039,7 +3058,8 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable, case IrBinOpBitShiftRightExact: { assert(scalar_type->id == ZigTypeIdInt); - LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value); + LLVMValueRef op2_casted = LLVMBuildZExt(g->builder, op2_value, + LLVMTypeOf(op1_value), "");//gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value); if (want_runtime_safety) { gen_shift_rhs_check(g, scalar_type, op2->value->type, op2_value); @@ -3053,7 +3073,7 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable, return LLVMBuildLShr(g->builder, op1_value, op2_casted, ""); } } else if (want_runtime_safety) { - return gen_overflow_shr_op(g, scalar_type, op1_value, op2_casted); + return gen_overflow_shr_op(g, operand_type, op1_value, op2_casted); } else if (scalar_type->data.integral.is_signed) { return ZigLLVMBuildAShrExact(g->builder, op1_value, op2_casted, ""); } else { diff --git a/src/ir.cpp b/src/ir.cpp index 436db592f2..6fed044c6c 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -283,6 +283,8 @@ static IrInstGen *ir_analyze_union_init(IrAnalyze *ira, IrInst* source_instructi IrInstGen *result_loc); static IrInstGen *ir_analyze_struct_value_field_value(IrAnalyze *ira, IrInst* source_instr, IrInstGen *struct_operand, TypeStructField *field); +static bool value_cmp_numeric_val_any(ZigValue *left, Cmp predicate, ZigValue *right); +static bool value_cmp_numeric_val_all(ZigValue *left, Cmp predicate, ZigValue *right); static void destroy_instruction_src(IrInstSrc *inst) { switch (inst->id) { @@ -16803,7 +16805,6 @@ static IrInstGen *ir_analyze_math_op(IrAnalyze *ira, IrInst* source_instr, ZigValue *scalar_op2_val = &op2_val->data.x_array.data.s_none.elements[i]; ZigValue *scalar_out_val = &out_val->data.x_array.data.s_none.elements[i]; assert(scalar_op1_val->type == scalar_type); - assert(scalar_op2_val->type == scalar_type); assert(scalar_out_val->type == scalar_type); ErrorMsg *msg = ir_eval_math_op_scalar(ira, source_instr, scalar_type, scalar_op1_val, op_id, scalar_op2_val, scalar_out_val); @@ -16828,27 +16829,49 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in if (type_is_invalid(op1->value->type)) return ira->codegen->invalid_inst_gen; - if (op1->value->type->id != ZigTypeIdInt && op1->value->type->id != ZigTypeIdComptimeInt) { - ir_add_error(ira, &bin_op_instruction->op1->base, - buf_sprintf("bit shifting operation expected integer type, found '%s'", - buf_ptr(&op1->value->type->name))); - return ira->codegen->invalid_inst_gen; - } - IrInstGen *op2 = bin_op_instruction->op2->child; if (type_is_invalid(op2->value->type)) return ira->codegen->invalid_inst_gen; - if (op2->value->type->id != ZigTypeIdInt && op2->value->type->id != ZigTypeIdComptimeInt) { + ZigType *op1_type = op1->value->type; + ZigType *op2_type = op2->value->type; + + if (op1_type->id == ZigTypeIdVector && op2_type->id != ZigTypeIdVector) { + ir_add_error(ira, &bin_op_instruction->op1->base, + buf_sprintf("bit shifting operation expected vector type, found '%s'", + buf_ptr(&op2_type->name))); + return ira->codegen->invalid_inst_gen; + } + + if (op1_type->id != ZigTypeIdVector && op2_type->id == ZigTypeIdVector) { + ir_add_error(ira, &bin_op_instruction->op1->base, + buf_sprintf("bit shifting operation expected vector type, found '%s'", + buf_ptr(&op1_type->name))); + return ira->codegen->invalid_inst_gen; + } + + ZigType *op1_scalar_type = (op1_type->id == ZigTypeIdVector) ? + op1_type->data.vector.elem_type : op1_type; + ZigType *op2_scalar_type = (op2_type->id == ZigTypeIdVector) ? + op2_type->data.vector.elem_type : op2_type; + + if (op1_scalar_type->id != ZigTypeIdInt && op1_scalar_type->id != ZigTypeIdComptimeInt) { + ir_add_error(ira, &bin_op_instruction->op1->base, + buf_sprintf("bit shifting operation expected integer type, found '%s'", + buf_ptr(&op1_scalar_type->name))); + return ira->codegen->invalid_inst_gen; + } + + if (op2_scalar_type->id != ZigTypeIdInt && op2_scalar_type->id != ZigTypeIdComptimeInt) { ir_add_error(ira, &bin_op_instruction->op2->base, buf_sprintf("shift amount has to be an integer type, but found '%s'", - buf_ptr(&op2->value->type->name))); + buf_ptr(&op2_scalar_type->name))); return ira->codegen->invalid_inst_gen; } IrInstGen *casted_op2; IrBinOp op_id = bin_op_instruction->op_id; - if (op1->value->type->id == ZigTypeIdComptimeInt) { + if (op1_scalar_type->id == ZigTypeIdComptimeInt) { // comptime_int has no finite bit width casted_op2 = op2; @@ -16874,10 +16897,15 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in return ira->codegen->invalid_inst_gen; } } else { - const unsigned bit_count = op1->value->type->data.integral.bit_count; + const unsigned bit_count = op1_scalar_type->data.integral.bit_count; ZigType *shift_amt_type = get_smallest_unsigned_int_type(ira->codegen, bit_count > 0 ? bit_count - 1 : 0); + if (op1_type->id == ZigTypeIdVector) { + shift_amt_type = get_vector_type(ira->codegen, op1_type->data.vector.len, + shift_amt_type); + } + casted_op2 = ir_implicit_cast(ira, op2, shift_amt_type); if (type_is_invalid(casted_op2->value->type)) return ira->codegen->invalid_inst_gen; @@ -16888,10 +16916,10 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in if (op2_val == nullptr) return ira->codegen->invalid_inst_gen; - BigInt bit_count_value = {0}; - bigint_init_unsigned(&bit_count_value, bit_count); + ZigValue bit_count_value; + init_const_usize(ira->codegen, &bit_count_value, bit_count); - if (bigint_cmp(&op2_val->data.x_bigint, &bit_count_value) != CmpLT) { + if (!value_cmp_numeric_val_all(op2_val, CmpLT, &bit_count_value)) { ErrorMsg* msg = ir_add_error(ira, &bin_op_instruction->base.base, buf_sprintf("RHS of shift is too large for LHS type")); @@ -16910,7 +16938,7 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in if (op2_val == nullptr) return ira->codegen->invalid_inst_gen; - if (bigint_cmp_zero(&op2_val->data.x_bigint) == CmpEQ) + if (value_cmp_numeric_val_all(op2_val, CmpEQ, nullptr)) return ir_analyze_cast(ira, &bin_op_instruction->base.base, op1->value->type, op1); } @@ -16923,7 +16951,7 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in if (op2_val == nullptr) return ira->codegen->invalid_inst_gen; - return ir_analyze_math_op(ira, &bin_op_instruction->base.base, op1->value->type, op1_val, op_id, op2_val); + return ir_analyze_math_op(ira, &bin_op_instruction->base.base, op1_type, op1_val, op_id, op2_val); } return ir_build_bin_op_gen(ira, &bin_op_instruction->base.base, op1->value->type, @@ -16991,31 +17019,53 @@ static bool is_pointer_arithmetic_allowed(ZigType *lhs_type, IrBinOp op) { zig_unreachable(); } -static bool value_cmp_zero_any(ZigValue *value, Cmp predicate) { - assert(value->special == ConstValSpecialStatic); +static bool value_cmp_numeric_val(ZigValue *left, Cmp predicate, ZigValue *right, bool any) { + assert(left->special == ConstValSpecialStatic); + assert(right == nullptr || right->special == ConstValSpecialStatic); - switch (value->type->id) { + switch (left->type->id) { case ZigTypeIdComptimeInt: - case ZigTypeIdInt: - return bigint_cmp_zero(&value->data.x_bigint) == predicate; + case ZigTypeIdInt: { + const Cmp result = right ? + bigint_cmp(&left->data.x_bigint, &right->data.x_bigint) : + bigint_cmp_zero(&left->data.x_bigint); + return result == predicate; + } case ZigTypeIdComptimeFloat: - case ZigTypeIdFloat: - if (float_is_nan(value)) + case ZigTypeIdFloat: { + if (float_is_nan(left)) return false; - return float_cmp_zero(value) == predicate; + if (right != nullptr && float_is_nan(right)) + return false; + + const Cmp result = right ? float_cmp(left, right) : float_cmp_zero(left); + return result == predicate; + } case ZigTypeIdVector: { - for (size_t i = 0; i < value->type->data.vector.len; i++) { - ZigValue *scalar_val = &value->data.x_array.data.s_none.elements[i]; - if (!value_cmp_zero_any(scalar_val, predicate)) - return true; + for (size_t i = 0; i < left->type->data.vector.len; i++) { + ZigValue *scalar_val = &left->data.x_array.data.s_none.elements[i]; + const bool result = value_cmp_numeric_val(scalar_val, predicate, right, any); + + if (any && result) + return true; // This element satisfies the predicate + else if (!any && !result) + return false; // This element doesn't satisfy the predicate } - return false; + return any ? false : true; } default: zig_unreachable(); } } +static bool value_cmp_numeric_val_any(ZigValue *left, Cmp predicate, ZigValue *right) { + return value_cmp_numeric_val(left, predicate, right, true); +} + +static bool value_cmp_numeric_val_all(ZigValue *left, Cmp predicate, ZigValue *right) { + return value_cmp_numeric_val(left, predicate, right, false); +} + static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruction) { Error err; @@ -17165,8 +17215,8 @@ static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruc return ira->codegen->invalid_inst_gen; // Promote division with negative numbers to signed - bool is_signed_div = value_cmp_zero_any(op1_val, CmpLT) || - value_cmp_zero_any(op2_val, CmpLT); + bool is_signed_div = value_cmp_numeric_val_any(op1_val, CmpLT, nullptr) || + value_cmp_numeric_val_any(op2_val, CmpLT, nullptr); if (op_id == IrBinOpDivUnspecified && is_int) { // Default to truncating division and check if it's valid for the @@ -17176,7 +17226,7 @@ static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruc if (is_signed_div) { bool ok = false; - if (value_cmp_zero_any(op2_val, CmpEQ)) { + if (value_cmp_numeric_val_any(op2_val, CmpEQ, nullptr)) { // the division by zero error will be caught later, but we don't have a // division function ambiguity problem. ok = true; @@ -17215,7 +17265,7 @@ static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruc if (is_signed_div) { bool ok = false; - if (value_cmp_zero_any(op2_val, CmpEQ)) { + if (value_cmp_numeric_val_any(op2_val, CmpEQ, nullptr)) { // the division by zero error will be caught later, but we don't have a // division function ambiguity problem. ok = true; diff --git a/test/stage1/behavior/vector.zig b/test/stage1/behavior/vector.zig index f242aa0fbf..f3bc334b84 100644 --- a/test/stage1/behavior/vector.zig +++ b/test/stage1/behavior/vector.zig @@ -1,5 +1,6 @@ const std = @import("std"); const mem = std.mem; +const math = std.math; const expect = std.testing.expect; const expectEqual = std.testing.expectEqual; @@ -376,3 +377,67 @@ test "vector bitwise not operator" { S.doTheTest(); comptime S.doTheTest(); } + +test "vector shift operators" { + const S = struct { + fn doTheTestShift(x: var, y: var) void { + const N = @typeInfo(@TypeOf(x)).Array.len; + const TX = @typeInfo(@TypeOf(x)).Array.child; + const TY = @typeInfo(@TypeOf(y)).Array.child; + + var xv = @as(@Vector(N, TX), x); + var yv = @as(@Vector(N, TY), y); + + var z0 = xv >> yv; + for (@as([N]TX, z0)) |v, i| { + expectEqual(x[i] >> y[i], v); + } + var z1 = xv << yv; + for (@as([N]TX, z1)) |v, i| { + expectEqual(x[i] << y[i], v); + } + } + fn doTheTestShiftExact(x: var, y: var, dir: enum { Left, Right }) void { + const N = @typeInfo(@TypeOf(x)).Array.len; + const TX = @typeInfo(@TypeOf(x)).Array.child; + const TY = @typeInfo(@TypeOf(y)).Array.child; + + var xv = @as(@Vector(N, TX), x); + var yv = @as(@Vector(N, TY), y); + + var z = if (dir == .Left) @shlExact(xv, yv) else @shrExact(xv, yv); + for (@as([N]TX, z)) |v, i| { + const check = if (dir == .Left) x[i] << y[i] else x[i] >> y[i]; + expectEqual(check, v); + } + } + fn doTheTest() void { + doTheTestShift([_]u8{ 0, 2, 4, math.maxInt(u8) }, [_]u3{ 2, 0, 2, 7 }); + doTheTestShift([_]u16{ 0, 2, 4, math.maxInt(u16) }, [_]u4{ 2, 0, 2, 15 }); + doTheTestShift([_]u24{ 0, 2, 4, math.maxInt(u24) }, [_]u5{ 2, 0, 2, 23 }); + doTheTestShift([_]u32{ 0, 2, 4, math.maxInt(u32) }, [_]u5{ 2, 0, 2, 31 }); + doTheTestShift([_]u64{ 0xfe, math.maxInt(u64) }, [_]u6{ 0, 63 }); + + doTheTestShift([_]i8{ 0, 2, 4, math.maxInt(i8) }, [_]u3{ 2, 0, 2, 7 }); + doTheTestShift([_]i16{ 0, 2, 4, math.maxInt(i16) }, [_]u4{ 2, 0, 2, 7 }); + doTheTestShift([_]i24{ 0, 2, 4, math.maxInt(i24) }, [_]u5{ 2, 0, 2, 7 }); + doTheTestShift([_]i32{ 0, 2, 4, math.maxInt(i32) }, [_]u5{ 2, 0, 2, 7 }); + doTheTestShift([_]i64{ 0xfe, math.maxInt(i64) }, [_]u6{ 0, 63 }); + + doTheTestShiftExact([_]u8{ 0, 1, 1 << 7, math.maxInt(u8) ^ 1 }, [_]u3{ 4, 0, 7, 1 }, .Right); + doTheTestShiftExact([_]u16{ 0, 1, 1 << 15, math.maxInt(u16) ^ 1 }, [_]u4{ 4, 0, 15, 1 }, .Right); + doTheTestShiftExact([_]u24{ 0, 1, 1 << 23, math.maxInt(u24) ^ 1 }, [_]u5{ 4, 0, 23, 1 }, .Right); + doTheTestShiftExact([_]u32{ 0, 1, 1 << 31, math.maxInt(u32) ^ 1 }, [_]u5{ 4, 0, 31, 1 }, .Right); + doTheTestShiftExact([_]u64{ 1 << 63, 1 }, [_]u6{ 63, 0 }, .Right); + + doTheTestShiftExact([_]u8{ 0, 1, 1, math.maxInt(u8) ^ (1 << 7) }, [_]u3{ 4, 0, 7, 1 }, .Left); + doTheTestShiftExact([_]u16{ 0, 1, 1, math.maxInt(u16) ^ (1 << 15) }, [_]u4{ 4, 0, 15, 1 }, .Left); + doTheTestShiftExact([_]u24{ 0, 1, 1, math.maxInt(u24) ^ (1 << 23) }, [_]u5{ 4, 0, 23, 1 }, .Left); + doTheTestShiftExact([_]u32{ 0, 1, 1, math.maxInt(u32) ^ (1 << 31) }, [_]u5{ 4, 0, 31, 1 }, .Left); + doTheTestShiftExact([_]u64{ 1 << 63, 1 }, [_]u6{ 0, 63 }, .Left); + } + }; + + S.doTheTest(); + comptime S.doTheTest(); +}