From cd46daf7d047eeceb7690e2739af5952d60c3884 Mon Sep 17 00:00:00 2001 From: John Schmidt Date: Thu, 24 Mar 2022 23:27:23 +0100 Subject: [PATCH] sema: coerce inputs to vectors in zirSelect --- src/Sema.zig | 39 +++++++++++++-------------------------- test/behavior/select.zig | 37 +++++++++++++++++++++++++++++++++---- 2 files changed, 46 insertions(+), 30 deletions(-) diff --git a/src/Sema.zig b/src/Sema.zig index 6971be64bf..27f12485a4 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -14805,6 +14805,7 @@ fn analyzeShuffle( fn zirSelect(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { const inst_data = sema.code.instructions.items(.data)[inst].pl_node; const extra = sema.code.extraData(Zir.Inst.Select, inst_data.payload_index).data; + const target = sema.mod.getTarget(); const elem_ty_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node }; const pred_src: LazySrcLoc = .{ .node_offset_builtin_call_arg1 = inst_data.src_node }; @@ -14813,35 +14814,21 @@ fn zirSelect(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air. const elem_ty = try sema.resolveType(block, elem_ty_src, extra.elem_type); try sema.checkVectorElemType(block, elem_ty_src, elem_ty); - const pred = sema.resolveInst(extra.pred); - const a = sema.resolveInst(extra.a); - const b = sema.resolveInst(extra.b); - const target = sema.mod.getTarget(); + const pred_uncoerced = sema.resolveInst(extra.pred); + const pred_ty = sema.typeOf(pred_uncoerced); - const pred_ty = sema.typeOf(pred); - switch (try pred_ty.zigTypeTagOrPoison()) { - .Vector => { - const scalar_ty = pred_ty.childType(); - if (!scalar_ty.eql(Type.bool, target)) { - const bool_vec_ty = try Type.vector(sema.arena, pred_ty.vectorLen(), Type.bool); - return sema.fail(block, pred_src, "Expected '{}', found '{}'", .{ bool_vec_ty.fmt(target), pred_ty.fmt(target) }); - } - }, - else => return sema.fail(block, pred_src, "Expected vector type, found '{}'", .{pred_ty.fmt(target)}), - } + const vec_len_u64 = switch (try pred_ty.zigTypeTagOrPoison()) { + .Vector, .Array => pred_ty.arrayLen(), + else => return sema.fail(block, pred_src, "expected vector or array, found '{}'", .{pred_ty.fmt(target)}), + }; + const vec_len = try sema.usizeCast(block, pred_src, vec_len_u64); + + const bool_vec_ty = try Type.vector(sema.arena, vec_len, Type.bool); + const pred = try sema.coerce(block, bool_vec_ty, pred_uncoerced, pred_src); - const vec_len = pred_ty.vectorLen(); const vec_ty = try Type.vector(sema.arena, vec_len, elem_ty); - - const a_ty = sema.typeOf(a); - if (!a_ty.eql(vec_ty, target)) { - return sema.fail(block, a_src, "Expected '{}', found '{}'", .{ vec_ty.fmt(target), a_ty.fmt(target) }); - } - - const b_ty = sema.typeOf(b); - if (!b_ty.eql(vec_ty, target)) { - return sema.fail(block, b_src, "Expected '{}', found '{}'", .{ vec_ty.fmt(target), b_ty.fmt(target) }); - } + const a = try sema.coerce(block, vec_ty, sema.resolveInst(extra.a), a_src); + const b = try sema.coerce(block, vec_ty, sema.resolveInst(extra.b), b_src); const maybe_pred = try sema.resolveMaybeUndefVal(block, pred_src, pred); const maybe_a = try sema.resolveMaybeUndefVal(block, a_src, a); diff --git a/test/behavior/select.zig b/test/behavior/select.zig index a1fcfb761a..f731ded09e 100644 --- a/test/behavior/select.zig +++ b/test/behavior/select.zig @@ -3,18 +3,18 @@ const builtin = @import("builtin"); const mem = std.mem; const expect = std.testing.expect; -test "@select" { +test "@select vectors" { if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO - try doTheTest(); - comptime try doTheTest(); + comptime try selectVectors(); + try selectVectors(); } -fn doTheTest() !void { +fn selectVectors() !void { var a = @Vector(4, bool){ true, false, true, false }; var b = @Vector(4, i32){ -1, 4, 999, -31 }; var c = @Vector(4, i32){ -5, 1, 0, 1234 }; @@ -30,3 +30,32 @@ fn doTheTest() !void { var xyz = @select(f32, x, y, z); try expect(mem.eql(f32, &@as([4]f32, xyz), &[4]f32{ 0.0, 312.1, -145.9, -3381.233 })); } + +test "@select arrays" { + if (builtin.zig_backend == .stage1) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO + + comptime try selectArrays(); + try selectArrays(); +} + +fn selectArrays() !void { + var a = [4]bool{ false, true, false, true }; + var b = [4]usize{ 0, 1, 2, 3 }; + var c = [4]usize{ 4, 5, 6, 7 }; + var abc = @select(usize, a, b, c); + try expect(abc[0] == 4); + try expect(abc[1] == 1); + try expect(abc[2] == 6); + try expect(abc[3] == 3); + + var x = [4]bool{ false, false, false, true }; + var y = [4]f32{ 0.001, 33.4, 836, -3381.233 }; + var z = [4]f32{ 0.0, 312.1, -145.9, 9993.55 }; + var xyz = @select(f32, x, y, z); + try expect(mem.eql(f32, &@as([4]f32, xyz), &[4]f32{ 0.0, 312.1, -145.9, -3381.233 })); +}