stage1: saturating shl operates using LHS type

Saturating shift left (`<<|`) previously used the `ir_analyze_bin_op_math`
codepath rather than the `ir_analyze_bit_shift` codepath, leading to it
doing peer type resolution (incorrect) instead of using the LHS type as
the number of bits to do the saturating against.

This required implementing SIMD vector support for `@truncate`.

Additionall, this commit adds a compile error for saturating shift left
on a comptime_int.

stage2 does not pass these new behavior tests yet.

closes #10298
This commit is contained in:
Andrew Kelley 2021-12-08 15:19:13 -07:00
parent 64e2bfaa23
commit 38b2d62092
4 changed files with 153 additions and 54 deletions

View File

@ -9900,6 +9900,100 @@ static Stage1AirInst *ir_analyze_math_op(IrAnalyze *ira, Scope *scope, AstNode *
return ir_implicit_cast(ira, result_instruction, type_entry);
}
static Stage1AirInst *ir_analyze_truncate(IrAnalyze *ira, Scope *scope, AstNode *source_node,
ZigType *dest_scalar_type, AstNode *dest_type_node,
Stage1AirInst *operand, AstNode *operand_node)
{
if (dest_scalar_type->id != ZigTypeIdInt &&
dest_scalar_type->id != ZigTypeIdComptimeInt)
{
ir_add_error_node(ira, dest_type_node,
buf_sprintf("expected integer type, found '%s'", buf_ptr(&dest_scalar_type->name)));
return ira->codegen->invalid_inst_gen;
}
ZigType *src_type = operand->value->type;
bool is_vector = (src_type->id == ZigTypeIdVector);
ZigType *src_scalar_type = is_vector ?
src_type->data.vector.elem_type : src_type;
ZigType *dest_type = is_vector ?
get_vector_type(ira->codegen, src_type->data.vector.len, dest_scalar_type) :
dest_scalar_type;
if (src_scalar_type->id != ZigTypeIdInt && src_scalar_type->id != ZigTypeIdComptimeInt) {
ir_add_error_node(ira, operand_node,
buf_sprintf("expected integer type, found '%s'", buf_ptr(&src_scalar_type->name)));
return ira->codegen->invalid_inst_gen;
}
if (dest_scalar_type->id == ZigTypeIdComptimeInt) {
return ir_implicit_cast2(ira, scope, operand_node, operand, dest_type);
}
if (src_scalar_type->id != ZigTypeIdComptimeInt) {
if (src_scalar_type->data.integral.is_signed != dest_scalar_type->data.integral.is_signed) {
const char *sign_str = dest_scalar_type->data.integral.is_signed ? "signed" : "unsigned";
ir_add_error_node(ira, operand_node, buf_sprintf("expected %s integer type, found '%s'", sign_str, buf_ptr(&src_scalar_type->name)));
return ira->codegen->invalid_inst_gen;
} else if (src_scalar_type->data.integral.bit_count > 0 && src_scalar_type->data.integral.bit_count < dest_scalar_type->data.integral.bit_count) {
ir_add_error_node(ira, operand_node, buf_sprintf("type '%s' has fewer bits than destination type '%s'",
buf_ptr(&src_scalar_type->name), buf_ptr(&dest_scalar_type->name)));
return ira->codegen->invalid_inst_gen;
}
}
if (instr_is_comptime(operand)) {
ZigValue *val = ir_resolve_const(ira, operand, UndefBad);
if (val == nullptr)
return ira->codegen->invalid_inst_gen;
if (!is_vector) {
Stage1AirInst *result = ir_const(ira, scope, source_node, dest_type);
bigint_truncate(&result->value->data.x_bigint, &val->data.x_bigint,
dest_scalar_type->data.integral.bit_count,
dest_scalar_type->data.integral.is_signed);
return result;
}
Stage1AirInst *result_instruction = ir_const(ira, scope, source_node, dest_type);
ZigValue *out_val = result_instruction->value;
expand_undef_array(ira->codegen, operand->value);
out_val->special = ConstValSpecialUndef;
expand_undef_array(ira->codegen, out_val);
size_t len = dest_type->data.vector.len;
for (size_t i = 0; i < len; i += 1) {
ZigValue *scalar_operand_val = &operand->value->data.x_array.data.s_none.elements[i];
ZigValue *scalar_out_val = &out_val->data.x_array.data.s_none.elements[i];
assert(scalar_operand_val->type == dest_scalar_type);
assert(scalar_out_val->type == dest_scalar_type);
bigint_truncate(&scalar_out_val->data.x_bigint,
&scalar_operand_val->data.x_bigint,
dest_scalar_type->data.integral.bit_count,
dest_scalar_type->data.integral.is_signed);
scalar_out_val->type = dest_scalar_type;
scalar_out_val->special = ConstValSpecialStatic;
}
out_val->type = dest_type;
out_val->special = ConstValSpecialStatic;
return result_instruction;
}
if (src_scalar_type->data.integral.bit_count == 0 ||
dest_scalar_type->data.integral.bit_count == 0)
{
Stage1AirInst *result = ir_const(ira, scope, source_node, dest_type);
if (!is_vector) {
bigint_init_unsigned(&result->value->data.x_bigint, 0);
}
return result;
}
return ir_build_truncate_gen(ira, scope, source_node, dest_type, operand);
}
static Stage1AirInst *ir_analyze_bit_shift(IrAnalyze *ira, Stage1ZirInstBinOp *bin_op_instruction) {
Stage1AirInst *op1 = bin_op_instruction->op1->child;
if (type_is_invalid(op1->value->type))
@ -9951,6 +10045,12 @@ static Stage1AirInst *ir_analyze_bit_shift(IrAnalyze *ira, Stage1ZirInstBinOp *b
// comptime_int has no finite bit width
casted_op2 = op2;
if (op_id == IrBinOpShlSat) {
ir_add_error_node(ira, bin_op_instruction->base.source_node,
buf_sprintf("saturating shift on a comptime_int which has unlimited bits"));
return ira->codegen->invalid_inst_gen;
}
if (op_id == IrBinOpBitShiftLeftLossy) {
op_id = IrBinOpBitShiftLeftExact;
}
@ -9972,6 +10072,13 @@ static Stage1AirInst *ir_analyze_bit_shift(IrAnalyze *ira, Stage1ZirInstBinOp *b
buf_sprintf("shift by negative value %s", buf_ptr(val_buf)));
return ira->codegen->invalid_inst_gen;
}
} else if (op_id == IrBinOpShlSat) {
casted_op2 = ir_analyze_truncate(ira,
bin_op_instruction->base.scope, bin_op_instruction->base.source_node,
op1_scalar_type, bin_op_instruction->op1->source_node,
op2, bin_op_instruction->op2->source_node);
if (type_is_invalid(casted_op2->value->type))
return ira->codegen->invalid_inst_gen;
} else {
const unsigned bit_count = op1_scalar_type->data.integral.bit_count;
ZigType *shift_amt_type = get_smallest_unsigned_int_type(ira->codegen,
@ -10030,8 +10137,9 @@ static Stage1AirInst *ir_analyze_bit_shift(IrAnalyze *ira, Stage1ZirInstBinOp *b
return ir_analyze_math_op(ira, bin_op_instruction->base.scope, bin_op_instruction->base.source_node, op1_type, op1_val, op_id, op2_val);
}
return ir_build_bin_op_gen(ira, bin_op_instruction->base.scope, bin_op_instruction->base.source_node, op1->value->type,
op_id, op1, casted_op2, bin_op_instruction->safety_check_on);
return ir_build_bin_op_gen(ira,
bin_op_instruction->base.scope, bin_op_instruction->base.source_node,
op1->value->type, op_id, op1, casted_op2, bin_op_instruction->safety_check_on);
}
static bool ok_float_op(IrBinOp op) {
@ -11035,6 +11143,7 @@ static Stage1AirInst *ir_analyze_instruction_bin_op(IrAnalyze *ira, Stage1ZirIns
case IrBinOpBitShiftLeftExact:
case IrBinOpBitShiftRightLossy:
case IrBinOpBitShiftRightExact:
case IrBinOpShlSat:
return ir_analyze_bit_shift(ira, bin_op_instruction);
case IrBinOpBinOr:
case IrBinOpBinXor:
@ -11057,7 +11166,6 @@ static Stage1AirInst *ir_analyze_instruction_bin_op(IrAnalyze *ira, Stage1ZirIns
case IrBinOpAddSat:
case IrBinOpSubSat:
case IrBinOpMultSat:
case IrBinOpShlSat:
return ir_analyze_bin_op_math(ira, bin_op_instruction);
case IrBinOpArrayCat:
return ir_analyze_array_cat(ira, bin_op_instruction);
@ -20017,59 +20125,13 @@ static Stage1AirInst *ir_analyze_instruction_truncate(IrAnalyze *ira, Stage1ZirI
if (type_is_invalid(dest_type))
return ira->codegen->invalid_inst_gen;
if (dest_type->id != ZigTypeIdInt &&
dest_type->id != ZigTypeIdComptimeInt)
{
ir_add_error(ira, dest_type_value, buf_sprintf("expected integer type, found '%s'", buf_ptr(&dest_type->name)));
return ira->codegen->invalid_inst_gen;
}
Stage1AirInst *target = instruction->target->child;
ZigType *src_type = target->value->type;
if (type_is_invalid(src_type))
Stage1AirInst *operand = instruction->target->child;
if (type_is_invalid(operand->value->type))
return ira->codegen->invalid_inst_gen;
if (src_type->id != ZigTypeIdInt &&
src_type->id != ZigTypeIdComptimeInt)
{
ir_add_error(ira, target, buf_sprintf("expected integer type, found '%s'", buf_ptr(&src_type->name)));
return ira->codegen->invalid_inst_gen;
}
if (dest_type->id == ZigTypeIdComptimeInt) {
return ir_implicit_cast2(ira, instruction->target->scope, instruction->target->source_node, target, dest_type);
}
if (src_type->id != ZigTypeIdComptimeInt) {
if (src_type->data.integral.is_signed != dest_type->data.integral.is_signed) {
const char *sign_str = dest_type->data.integral.is_signed ? "signed" : "unsigned";
ir_add_error(ira, target, buf_sprintf("expected %s integer type, found '%s'", sign_str, buf_ptr(&src_type->name)));
return ira->codegen->invalid_inst_gen;
} else if (src_type->data.integral.bit_count > 0 && src_type->data.integral.bit_count < dest_type->data.integral.bit_count) {
ir_add_error(ira, target, buf_sprintf("type '%s' has fewer bits than destination type '%s'",
buf_ptr(&src_type->name), buf_ptr(&dest_type->name)));
return ira->codegen->invalid_inst_gen;
}
}
if (instr_is_comptime(target)) {
ZigValue *val = ir_resolve_const(ira, target, UndefBad);
if (val == nullptr)
return ira->codegen->invalid_inst_gen;
Stage1AirInst *result = ir_const(ira, instruction->base.scope, instruction->base.source_node, dest_type);
bigint_truncate(&result->value->data.x_bigint, &val->data.x_bigint,
dest_type->data.integral.bit_count, dest_type->data.integral.is_signed);
return result;
}
if (src_type->data.integral.bit_count == 0 || dest_type->data.integral.bit_count == 0) {
Stage1AirInst *result = ir_const(ira, instruction->base.scope, instruction->base.source_node, dest_type);
bigint_init_unsigned(&result->value->data.x_bigint, 0);
return result;
}
return ir_build_truncate_gen(ira, instruction->base.scope, instruction->base.source_node, dest_type, target);
return ir_analyze_truncate(ira, instruction->base.scope, instruction->base.source_node,
dest_type, instruction->dest_type->source_node,
operand, instruction->target->source_node);
}
static Stage1AirInst *ir_analyze_int_cast(IrAnalyze *ira, Scope *scope, AstNode *source_node,

View File

@ -171,6 +171,7 @@ test {
_ = @import("behavior/popcount_stage1.zig");
_ = @import("behavior/ptrcast_stage1.zig");
_ = @import("behavior/reflection.zig");
_ = @import("behavior/saturating_arithmetic_stage1.zig");
_ = @import("behavior/select.zig");
_ = @import("behavior/shuffle.zig");
_ = @import("behavior/sizeof_and_typeof_stage1.zig");
@ -181,6 +182,7 @@ test {
_ = @import("behavior/switch_prong_err_enum.zig");
_ = @import("behavior/switch_prong_implicit_cast.zig");
_ = @import("behavior/switch_stage1.zig");
_ = @import("behavior/truncate_stage1.zig");
_ = @import("behavior/try.zig");
_ = @import("behavior/tuple.zig");
_ = @import("behavior/type.zig");

View File

@ -0,0 +1,22 @@
const std = @import("std");
const expect = std.testing.expect;
test "saturating shl uses the LHS type" {
const lhs_const: u8 = 1;
var lhs_var: u8 = 1;
const rhs_const: usize = 8;
var rhs_var: usize = 8;
try expect((lhs_const <<| 8) == 255);
try expect((lhs_const <<| rhs_const) == 255);
try expect((lhs_const <<| rhs_var) == 255);
try expect((lhs_var <<| 8) == 255);
try expect((lhs_var <<| rhs_const) == 255);
try expect((lhs_var <<| rhs_var) == 255);
try expect((@as(u8, 1) <<| 8) == 255);
try expect((@as(u8, 1) <<| rhs_const) == 255);
try expect((@as(u8, 1) <<| rhs_var) == 255);
}

View File

@ -0,0 +1,13 @@
const std = @import("std");
const expect = std.testing.expect;
test "truncate on vectors" {
const S = struct {
fn doTheTest() !void {
var v1: @Vector(4, u16) = .{ 0xaabb, 0xccdd, 0xeeff, 0x1122 };
var v2 = @truncate(u8, v1);
try expect(std.mem.eql(u8, &@as([4]u8, v2), &[4]u8{ 0xbb, 0xdd, 0xff, 0x22 }));
}
};
try S.doTheTest();
}