From 15f111a085dcd2da7b91da8e29b89def24cb3c6a Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 3 Jul 2022 13:07:23 -0700 Subject: [PATCH] LLVM: update lowering of saturating shift-left LLVM 14 makes it so that a RHS of saturating shift left produces a poison value if the value is greater than the number of bits of the LHS. Zig now emits code that will check if this is the case and select a saturated LHS value in such case, matching Zig semantics. --- src/codegen/llvm.zig | 32 +++++++++++++++++++++++++++----- src/stage1/codegen.cpp | 33 +++++++++++++++++++++++++-------- 2 files changed, 52 insertions(+), 13 deletions(-) diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index b755ee7d2d..902e4797b8 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -6817,15 +6817,37 @@ pub const FuncGen = struct { const rhs_ty = self.air.typeOf(bin_op.rhs); const lhs_scalar_ty = lhs_ty.scalarType(); const rhs_scalar_ty = rhs_ty.scalarType(); - const tg = self.dg.module.getTarget(); + const lhs_bits = lhs_scalar_ty.bitSize(tg); - const casted_rhs = if (rhs_scalar_ty.bitSize(tg) < lhs_scalar_ty.bitSize(tg)) - self.builder.buildZExt(rhs, try self.dg.lowerType(lhs_ty), "") + const casted_rhs = if (rhs_scalar_ty.bitSize(tg) < lhs_bits) + self.builder.buildZExt(rhs, lhs.typeOf(), "") else rhs; - if (lhs_scalar_ty.isSignedInt()) return self.builder.buildSShlSat(lhs, casted_rhs, ""); - return self.builder.buildUShlSat(lhs, casted_rhs, ""); + + const result = if (lhs_scalar_ty.isSignedInt()) + self.builder.buildSShlSat(lhs, casted_rhs, "") + else + self.builder.buildUShlSat(lhs, casted_rhs, ""); + + // LLVM langref says "If b is (statically or dynamically) equal to or + // larger than the integer bit width of the arguments, the result is a + // poison value." + // However Zig semantics says that saturating shift left can never produce + // undefined; instead it saturates. + const lhs_scalar_llvm_ty = try self.dg.lowerType(lhs_scalar_ty); + const bits = lhs_scalar_llvm_ty.constInt(lhs_bits, .False); + const lhs_max = lhs_scalar_llvm_ty.constAllOnes(); + if (rhs_ty.zigTypeTag() == .Vector) { + const vec_len = rhs_ty.vectorLen(); + const bits_vec = self.builder.buildVectorSplat(vec_len, bits, ""); + const lhs_max_vec = self.builder.buildVectorSplat(vec_len, lhs_max, ""); + const in_range = self.builder.buildICmp(.ULT, rhs, bits_vec, ""); + return self.builder.buildSelect(in_range, result, lhs_max_vec, ""); + } else { + const in_range = self.builder.buildICmp(.ULT, rhs, bits, ""); + return self.builder.buildSelect(in_range, result, lhs_max, ""); + } } fn airShr(self: *FuncGen, inst: Air.Inst.Index, is_exact: bool) !?*const llvm.Value { diff --git a/src/stage1/codegen.cpp b/src/stage1/codegen.cpp index 697f002dfc..bb0cbd9b9f 100644 --- a/src/stage1/codegen.cpp +++ b/src/stage1/codegen.cpp @@ -3868,16 +3868,33 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, Stage1Air *executable, } else { zig_unreachable(); } - case IrBinOpShlSat: - if (scalar_type->id == ZigTypeIdInt) { - if (scalar_type->data.integral.is_signed) { - return ZigLLVMBuildSShlSat(g->builder, op1_value, op2_value, ""); - } else { - return ZigLLVMBuildUShlSat(g->builder, op1_value, op2_value, ""); - } - } else { + case IrBinOpShlSat: { + if (scalar_type->id != ZigTypeIdInt) { zig_unreachable(); } + LLVMValueRef result = scalar_type->data.integral.is_signed ? + ZigLLVMBuildSShlSat(g->builder, op1_value, op2_value, "") : + ZigLLVMBuildUShlSat(g->builder, op1_value, op2_value, ""); + // LLVM langref says "If b is (statically or dynamically) equal to or + // larger than the integer bit width of the arguments, the result is a + // poison value." + // However Zig semantics says that saturating shift left can never produce + // undefined; instead it saturates. + LLVMTypeRef lhs_scalar_llvm_ty = get_llvm_type(g, scalar_type); + LLVMValueRef bits = LLVMConstInt(lhs_scalar_llvm_ty, + scalar_type->data.integral.bit_count, false); + LLVMValueRef lhs_max = LLVMConstAllOnes(lhs_scalar_llvm_ty); + if (operand_type->id == ZigTypeIdVector) { + uint64_t vec_len = operand_type->data.vector.len; + LLVMValueRef bits_vec = LLVMBuildVectorSplat(g->builder, vec_len, bits, ""); + LLVMValueRef lhs_max_vec = LLVMBuildVectorSplat(g->builder, vec_len, lhs_max, ""); + LLVMValueRef in_range = LLVMBuildICmp(g->builder, LLVMIntULT, op2_value, bits_vec, ""); + return LLVMBuildSelect(g->builder, in_range, result, lhs_max_vec, ""); + } else { + LLVMValueRef in_range = LLVMBuildICmp(g->builder, LLVMIntULT, op2_value, bits, ""); + return LLVMBuildSelect(g->builder, in_range, result, lhs_max, ""); + } + } } zig_unreachable(); }