sema: coerce inputs to vectors in zirSelect

This commit is contained in:
John Schmidt 2022-03-24 23:27:23 +01:00
parent f47db0a0db
commit cd46daf7d0
2 changed files with 46 additions and 30 deletions

View File

@ -14805,6 +14805,7 @@ fn analyzeShuffle(
fn zirSelect(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
const inst_data = sema.code.instructions.items(.data)[inst].pl_node;
const extra = sema.code.extraData(Zir.Inst.Select, inst_data.payload_index).data;
const target = sema.mod.getTarget();
const elem_ty_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node };
const pred_src: LazySrcLoc = .{ .node_offset_builtin_call_arg1 = inst_data.src_node };
@ -14813,35 +14814,21 @@ fn zirSelect(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.
const elem_ty = try sema.resolveType(block, elem_ty_src, extra.elem_type);
try sema.checkVectorElemType(block, elem_ty_src, elem_ty);
const pred = sema.resolveInst(extra.pred);
const a = sema.resolveInst(extra.a);
const b = sema.resolveInst(extra.b);
const target = sema.mod.getTarget();
const pred_uncoerced = sema.resolveInst(extra.pred);
const pred_ty = sema.typeOf(pred_uncoerced);
const pred_ty = sema.typeOf(pred);
switch (try pred_ty.zigTypeTagOrPoison()) {
.Vector => {
const scalar_ty = pred_ty.childType();
if (!scalar_ty.eql(Type.bool, target)) {
const bool_vec_ty = try Type.vector(sema.arena, pred_ty.vectorLen(), Type.bool);
return sema.fail(block, pred_src, "Expected '{}', found '{}'", .{ bool_vec_ty.fmt(target), pred_ty.fmt(target) });
}
},
else => return sema.fail(block, pred_src, "Expected vector type, found '{}'", .{pred_ty.fmt(target)}),
}
const vec_len_u64 = switch (try pred_ty.zigTypeTagOrPoison()) {
.Vector, .Array => pred_ty.arrayLen(),
else => return sema.fail(block, pred_src, "expected vector or array, found '{}'", .{pred_ty.fmt(target)}),
};
const vec_len = try sema.usizeCast(block, pred_src, vec_len_u64);
const bool_vec_ty = try Type.vector(sema.arena, vec_len, Type.bool);
const pred = try sema.coerce(block, bool_vec_ty, pred_uncoerced, pred_src);
const vec_len = pred_ty.vectorLen();
const vec_ty = try Type.vector(sema.arena, vec_len, elem_ty);
const a_ty = sema.typeOf(a);
if (!a_ty.eql(vec_ty, target)) {
return sema.fail(block, a_src, "Expected '{}', found '{}'", .{ vec_ty.fmt(target), a_ty.fmt(target) });
}
const b_ty = sema.typeOf(b);
if (!b_ty.eql(vec_ty, target)) {
return sema.fail(block, b_src, "Expected '{}', found '{}'", .{ vec_ty.fmt(target), b_ty.fmt(target) });
}
const a = try sema.coerce(block, vec_ty, sema.resolveInst(extra.a), a_src);
const b = try sema.coerce(block, vec_ty, sema.resolveInst(extra.b), b_src);
const maybe_pred = try sema.resolveMaybeUndefVal(block, pred_src, pred);
const maybe_a = try sema.resolveMaybeUndefVal(block, a_src, a);

View File

@ -3,18 +3,18 @@ const builtin = @import("builtin");
const mem = std.mem;
const expect = std.testing.expect;
test "@select" {
test "@select vectors" {
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
try doTheTest();
comptime try doTheTest();
comptime try selectVectors();
try selectVectors();
}
fn doTheTest() !void {
fn selectVectors() !void {
var a = @Vector(4, bool){ true, false, true, false };
var b = @Vector(4, i32){ -1, 4, 999, -31 };
var c = @Vector(4, i32){ -5, 1, 0, 1234 };
@ -30,3 +30,32 @@ fn doTheTest() !void {
var xyz = @select(f32, x, y, z);
try expect(mem.eql(f32, &@as([4]f32, xyz), &[4]f32{ 0.0, 312.1, -145.9, -3381.233 }));
}
test "@select arrays" {
if (builtin.zig_backend == .stage1) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
comptime try selectArrays();
try selectArrays();
}
fn selectArrays() !void {
var a = [4]bool{ false, true, false, true };
var b = [4]usize{ 0, 1, 2, 3 };
var c = [4]usize{ 4, 5, 6, 7 };
var abc = @select(usize, a, b, c);
try expect(abc[0] == 4);
try expect(abc[1] == 1);
try expect(abc[2] == 6);
try expect(abc[3] == 3);
var x = [4]bool{ false, false, false, true };
var y = [4]f32{ 0.001, 33.4, 836, -3381.233 };
var z = [4]f32{ 0.0, 312.1, -145.9, 9993.55 };
var xyz = @select(f32, x, y, z);
try expect(mem.eql(f32, &@as([4]f32, xyz), &[4]f32{ 0.0, 312.1, -145.9, -3381.233 }));
}