From 12d5efcbe621b98b5a99c5f84a9d5b605e4acc40 Mon Sep 17 00:00:00 2001 From: John Schmidt Date: Wed, 23 Mar 2022 21:41:35 +0100 Subject: [PATCH 1/3] stage2: implement `@select` --- src/Air.zig | 12 ++++- src/Liveness.zig | 4 ++ src/Sema.zig | 101 ++++++++++++++++++++++++++++++++++- src/arch/aarch64/CodeGen.zig | 8 +++ src/arch/arm/CodeGen.zig | 8 +++ src/arch/riscv64/CodeGen.zig | 8 +++ src/arch/wasm/CodeGen.zig | 11 ++++ src/arch/x86_64/CodeGen.zig | 8 +++ src/codegen/c.zig | 16 ++++++ src/codegen/llvm.zig | 13 +++++ src/print_air.zig | 13 +++++ test/behavior/select.zig | 42 ++++++++------- 12 files changed, 223 insertions(+), 21 deletions(-) diff --git a/src/Air.zig b/src/Air.zig index 404ee8f9b7..4938cb8135 100644 --- a/src/Air.zig +++ b/src/Air.zig @@ -344,7 +344,7 @@ pub const Inst = struct { /// to the storage for the variable. The local may be a const or a var. /// Result type is always void. /// Uses `pl_op`. The payload index is the variable name. It points to the extra - /// array, reinterpreting the bytes there as a null-terminated string. + /// array, reinterpreting the bytes there as a null-terminated string. dbg_var_ptr, /// Same as `dbg_var_ptr` except the local is a const, not a var, and the /// operand is the local's value. @@ -553,6 +553,9 @@ pub const Inst = struct { /// Constructs a vector by selecting elements from `a` and `b` based on `mask`. /// Uses the `ty_pl` field with payload `Shuffle`. shuffle, + /// Constructs a vector element-wise from `a` or `b` based on `pred`. + /// Uses the `ty_pl` field with payload `Select`. + select, /// Given dest ptr, value, and len, set all elements at dest to value. /// Result type is always void. @@ -785,6 +788,12 @@ pub const Shuffle = struct { mask_len: u32, }; +pub const Select = struct { + pred: Inst.Ref, + a: Inst.Ref, + b: Inst.Ref, +}; + pub const VectorCmp = struct { lhs: Inst.Ref, rhs: Inst.Ref, @@ -956,6 +965,7 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index) Type { .cmpxchg_weak, .cmpxchg_strong, .slice, + .select, .shuffle, .aggregate_init, .union_init, diff --git a/src/Liveness.zig b/src/Liveness.zig index 79521e7a94..9e15567d73 100644 --- a/src/Liveness.zig +++ b/src/Liveness.zig @@ -433,6 +433,10 @@ fn analyzeInst( } return extra_tombs.finish(); }, + .select => { + const extra = a.air.extraData(Air.Select, inst_datas[inst].ty_pl.payload).data; + return trackOperands(a, new_set, inst, main_tomb, .{ extra.pred, extra.a, extra.b }); + }, .shuffle => { const extra = a.air.extraData(Air.Shuffle, inst_datas[inst].ty_pl.payload).data; return trackOperands(a, new_set, inst, main_tomb, .{ extra.a, extra.b, .none }); diff --git a/src/Sema.zig b/src/Sema.zig index a42a4caf38..e4d719f346 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -14804,8 +14804,105 @@ fn analyzeShuffle( fn zirSelect(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { const inst_data = sema.code.instructions.items(.data)[inst].pl_node; - const src = inst_data.src(); - return sema.fail(block, src, "TODO: Sema.zirSelect", .{}); + const extra = sema.code.extraData(Zir.Inst.Select, inst_data.payload_index).data; + + const elem_ty_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node }; + const pred_src: LazySrcLoc = .{ .node_offset_builtin_call_arg1 = inst_data.src_node }; + const a_src: LazySrcLoc = .{ .node_offset_builtin_call_arg2 = inst_data.src_node }; + const b_src: LazySrcLoc = .{ .node_offset_builtin_call_arg3 = inst_data.src_node }; + + const elem_ty = try sema.resolveType(block, elem_ty_src, extra.elem_type); + try sema.checkVectorElemType(block, elem_ty_src, elem_ty); + const pred = sema.resolveInst(extra.pred); + const a = sema.resolveInst(extra.a); + const b = sema.resolveInst(extra.b); + const target = sema.mod.getTarget(); + + const pred_ty = sema.typeOf(pred); + switch (try pred_ty.zigTypeTagOrPoison()) { + .Vector => { + const scalar_ty = pred_ty.childType(); + if (!scalar_ty.eql(Type.bool, target)) { + const bool_vec_ty = try Type.vector(sema.arena, pred_ty.vectorLen(), Type.bool); + return sema.fail(block, pred_src, "Expected '{}', found '{}'", .{ bool_vec_ty.fmt(target), pred_ty.fmt(target) }); + } + }, + else => return sema.fail(block, pred_src, "Expected vector type, found '{}'", .{pred_ty.fmt(target)}), + } + + const vec_len = pred_ty.vectorLen(); + const vec_ty = try Type.vector(sema.arena, vec_len, elem_ty); + + const a_ty = sema.typeOf(a); + if (!a_ty.eql(vec_ty, target)) { + return sema.fail(block, a_src, "Expected '{}', found '{}'", .{ vec_ty.fmt(target), a_ty.fmt(target) }); + } + + const b_ty = sema.typeOf(b); + if (!b_ty.eql(vec_ty, target)) { + return sema.fail(block, b_src, "Expected '{}', found '{}'", .{ vec_ty.fmt(target), b_ty.fmt(target) }); + } + + const maybe_pred = try sema.resolveMaybeUndefVal(block, pred_src, pred); + const maybe_a = try sema.resolveMaybeUndefVal(block, a_src, a); + const maybe_b = try sema.resolveMaybeUndefVal(block, b_src, b); + + const runtime_src = if (maybe_pred) |pred_val| rs: { + if (pred_val.isUndef()) return sema.addConstUndef(vec_ty); + + if (maybe_a) |a_val| { + if (a_val.isUndef()) return sema.addConstUndef(vec_ty); + + if (maybe_b) |b_val| { + if (b_val.isUndef()) return sema.addConstUndef(vec_ty); + + var buf: Value.ElemValueBuffer = undefined; + const elems = try sema.gpa.alloc(Value, vec_len); + for (elems) |*elem, i| { + const pred_elem_val = pred_val.elemValueBuffer(i, &buf); + const should_choose_a = pred_elem_val.toBool(); + if (should_choose_a) { + elem.* = a_val.elemValueBuffer(i, &buf); + } else { + elem.* = b_val.elemValueBuffer(i, &buf); + } + } + + return sema.addConstant( + vec_ty, + try Value.Tag.aggregate.create(sema.arena, elems), + ); + } else { + break :rs b_src; + } + } else { + if (maybe_b) |b_val| { + if (b_val.isUndef()) return sema.addConstUndef(vec_ty); + } + break :rs a_src; + } + } else rs: { + if (maybe_a) |a_val| { + if (a_val.isUndef()) return sema.addConstUndef(vec_ty); + } + if (maybe_b) |b_val| { + if (b_val.isUndef()) return sema.addConstUndef(vec_ty); + } + break :rs pred_src; + }; + + try sema.requireRuntimeBlock(block, runtime_src); + return block.addInst(.{ + .tag = .select, + .data = .{ .ty_pl = .{ + .ty = try block.sema.addType(vec_ty), + .payload = try block.sema.addExtra(Air.Select{ + .pred = pred, + .a = a, + .b = b, + }), + } }, + }); } fn zirAtomicLoad(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { diff --git a/src/arch/aarch64/CodeGen.zig b/src/arch/aarch64/CodeGen.zig index 8e502e8943..178dfe2c29 100644 --- a/src/arch/aarch64/CodeGen.zig +++ b/src/arch/aarch64/CodeGen.zig @@ -640,6 +640,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .tag_name => try self.airTagName(inst), .error_name => try self.airErrorName(inst), .splat => try self.airSplat(inst), + .select => try self.airSelect(inst), .shuffle => try self.airShuffle(inst), .reduce => try self.airReduce(inst), .aggregate_init => try self.airAggregateInit(inst), @@ -3746,6 +3747,13 @@ fn airSplat(self: *Self, inst: Air.Inst.Index) !void { return self.finishAir(inst, result, .{ ty_op.operand, .none, .none }); } +fn airSelect(self: *Self, inst: Air.Inst.Index) !void { + const ty_pl = self.air.instructions.items(.data)[inst].ty_pl; + const extra = self.air.extraData(Air.Select, ty_pl.payload).data; + const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airSelect for {}", .{self.target.cpu.arch}); + return self.finishAir(inst, result, .{ extra.pred, extra.a, extra.b }); +} + fn airShuffle(self: *Self, inst: Air.Inst.Index) !void { const ty_op = self.air.instructions.items(.data)[inst].ty_op; const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airShuffle for {}", .{self.target.cpu.arch}); diff --git a/src/arch/arm/CodeGen.zig b/src/arch/arm/CodeGen.zig index 5bbacb8999..7fa116edd1 100644 --- a/src/arch/arm/CodeGen.zig +++ b/src/arch/arm/CodeGen.zig @@ -630,6 +630,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .tag_name => try self.airTagName(inst), .error_name => try self.airErrorName(inst), .splat => try self.airSplat(inst), + .select => try self.airSelect(inst), .shuffle => try self.airShuffle(inst), .reduce => try self.airReduce(inst), .aggregate_init => try self.airAggregateInit(inst), @@ -4323,6 +4324,13 @@ fn airSplat(self: *Self, inst: Air.Inst.Index) !void { return self.finishAir(inst, result, .{ ty_op.operand, .none, .none }); } +fn airSelect(self: *Self, inst: Air.Inst.Index) !void { + const ty_pl = self.air.instructions.items(.data)[inst].ty_pl; + const extra = self.air.extraData(Air.Select, ty_pl.payload).data; + const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airSelect for arm", .{}); + return self.finishAir(inst, result, .{ extra.pred, extra.a, extra.b }); +} + fn airShuffle(self: *Self, inst: Air.Inst.Index) !void { const ty_op = self.air.instructions.items(.data)[inst].ty_op; const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airShuffle for arm", .{}); diff --git a/src/arch/riscv64/CodeGen.zig b/src/arch/riscv64/CodeGen.zig index 8cbd26b7a6..65b3b4904d 100644 --- a/src/arch/riscv64/CodeGen.zig +++ b/src/arch/riscv64/CodeGen.zig @@ -600,6 +600,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .tag_name => try self.airTagName(inst), .error_name => try self.airErrorName(inst), .splat => try self.airSplat(inst), + .select => try self.airSelect(inst), .shuffle => try self.airShuffle(inst), .reduce => try self.airReduce(inst), .aggregate_init => try self.airAggregateInit(inst), @@ -2396,6 +2397,13 @@ fn airSplat(self: *Self, inst: Air.Inst.Index) !void { return self.finishAir(inst, result, .{ ty_op.operand, .none, .none }); } +fn airSelect(self: *Self, inst: Air.Inst.Index) !void { + const ty_pl = self.air.instructions.items(.data)[inst].ty_pl; + const extra = self.air.extraData(Air.Select, ty_pl.payload).data; + const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airSelect for riscv64", .{}); + return self.finishAir(inst, result, .{ extra.pred, extra.a, extra.b }); +} + fn airShuffle(self: *Self, inst: Air.Inst.Index) !void { const ty_op = self.air.instructions.items(.data)[inst].ty_op; const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airShuffle for riscv64", .{}); diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig index f2979d96b1..de5e698eca 100644 --- a/src/arch/wasm/CodeGen.zig +++ b/src/arch/wasm/CodeGen.zig @@ -1371,6 +1371,7 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue { .ret_ptr => self.airRetPtr(inst), .ret_load => self.airRetLoad(inst), .splat => self.airSplat(inst), + .select => self.airSelect(inst), .shuffle => self.airShuffle(inst), .reduce => self.airReduce(inst), .aggregate_init => self.airAggregateInit(inst), @@ -3265,6 +3266,16 @@ fn airSplat(self: *Self, inst: Air.Inst.Index) InnerError!WValue { return self.fail("TODO: Implement wasm airSplat", .{}); } +fn airSelect(self: *Self, inst: Air.Inst.Index) InnerError!WValue { + if (self.liveness.isUnused(inst)) return WValue{ .none = {} }; + + const ty_pl = self.air.instructions.items(.data)[inst].ty_pl; + const ty = try self.resolveInst(ty_pl.ty); + + _ = ty; + return self.fail("TODO: Implement wasm airSelect", .{}); +} + fn airShuffle(self: *Self, inst: Air.Inst.Index) InnerError!WValue { if (self.liveness.isUnused(inst)) return WValue{ .none = {} }; diff --git a/src/arch/x86_64/CodeGen.zig b/src/arch/x86_64/CodeGen.zig index 1e79119997..8a9aaebb22 100644 --- a/src/arch/x86_64/CodeGen.zig +++ b/src/arch/x86_64/CodeGen.zig @@ -721,6 +721,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .tag_name => try self.airTagName(inst), .error_name => try self.airErrorName(inst), .splat => try self.airSplat(inst), + .select => try self.airSelect(inst), .shuffle => try self.airShuffle(inst), .reduce => try self.airReduce(inst), .aggregate_init => try self.airAggregateInit(inst), @@ -5678,6 +5679,13 @@ fn airSplat(self: *Self, inst: Air.Inst.Index) !void { return self.finishAir(inst, result, .{ ty_op.operand, .none, .none }); } +fn airSelect(self: *Self, inst: Air.Inst.Index) !void { + const ty_pl = self.air.instructions.items(.data)[inst].ty_pl; + const extra = self.air.extraData(Air.Select, ty_pl.payload).data; + const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airSelect for x86_64", .{}); + return self.finishAir(inst, result, .{ extra.pred, extra.a, extra.b }); +} + fn airShuffle(self: *Self, inst: Air.Inst.Index) !void { const ty_op = self.air.instructions.items(.data)[inst].ty_op; const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airShuffle for x86_64", .{}); diff --git a/src/codegen/c.zig b/src/codegen/c.zig index 16558c6e04..84bb3a211b 100644 --- a/src/codegen/c.zig +++ b/src/codegen/c.zig @@ -1825,6 +1825,7 @@ fn genBody(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, OutO .tag_name => try airTagName(f, inst), .error_name => try airErrorName(f, inst), .splat => try airSplat(f, inst), + .select => try airSelect(f, inst), .shuffle => try airShuffle(f, inst), .reduce => try airReduce(f, inst), .aggregate_init => try airAggregateInit(f, inst), @@ -3794,6 +3795,21 @@ fn airSplat(f: *Function, inst: Air.Inst.Index) !CValue { return f.fail("TODO: C backend: implement airSplat", .{}); } +fn airSelect(f: *Function, inst: Air.Inst.Index) !CValue { + if (f.liveness.isUnused(inst)) return CValue.none; + + const inst_ty = f.air.typeOfIndex(inst); + const ty_pl = f.air.instructions.items(.data)[inst].ty_pl; + + const writer = f.object.writer(); + const local = try f.allocLocal(inst_ty, .Const); + try writer.writeAll(" = "); + + _ = local; + _ = ty_pl; + return f.fail("TODO: C backend: implement airSelect", .{}); +} + fn airShuffle(f: *Function, inst: Air.Inst.Index) !CValue { if (f.liveness.isUnused(inst)) return CValue.none; diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index c8ca540f59..0f481e0a6d 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -3444,6 +3444,7 @@ pub const FuncGen = struct { .tag_name => try self.airTagName(inst), .error_name => try self.airErrorName(inst), .splat => try self.airSplat(inst), + .select => try self.airSelect(inst), .shuffle => try self.airShuffle(inst), .reduce => try self.airReduce(inst), .aggregate_init => try self.airAggregateInit(inst), @@ -6355,6 +6356,18 @@ pub const FuncGen = struct { return self.builder.buildShuffleVector(op_vector, undef_vector, mask_llvm_ty.constNull(), ""); } + fn airSelect(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value { + if (self.liveness.isUnused(inst)) return null; + + const ty_pl = self.air.instructions.items(.data)[inst].ty_pl; + const extra = self.air.extraData(Air.Select, ty_pl.payload).data; + const pred = try self.resolveInst(extra.pred); + const a = try self.resolveInst(extra.a); + const b = try self.resolveInst(extra.b); + + return self.builder.buildSelect(pred, a, b, ""); + } + fn airShuffle(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value { if (self.liveness.isUnused(inst)) return null; diff --git a/src/print_air.zig b/src/print_air.zig index 5e0179e3cc..a2fe461887 100644 --- a/src/print_air.zig +++ b/src/print_air.zig @@ -264,6 +264,7 @@ const Writer = struct { .wasm_memory_size => try w.writeWasmMemorySize(s, inst), .wasm_memory_grow => try w.writeWasmMemoryGrow(s, inst), .mul_add => try w.writeMulAdd(s, inst), + .select => try w.writeSelect(s, inst), .shuffle => try w.writeShuffle(s, inst), .reduce => try w.writeReduce(s, inst), .cmp_vector => try w.writeCmpVector(s, inst), @@ -396,6 +397,18 @@ const Writer = struct { try s.print(", mask {d}, len {d}", .{ extra.mask, extra.mask_len }); } + fn writeSelect(w: *Writer, s: anytype, inst: Air.Inst.Index) @TypeOf(s).Error!void { + const ty_pl = w.air.instructions.items(.data)[inst].ty_pl; + const extra = w.air.extraData(Air.Select, ty_pl.payload).data; + + try s.print("{}, ", .{w.air.getRefType(ty_pl.ty).fmtDebug()}); + try w.writeOperand(s, inst, 0, extra.pred); + try s.writeAll(", "); + try w.writeOperand(s, inst, 1, extra.a); + try s.writeAll(", "); + try w.writeOperand(s, inst, 2, extra.b); + } + fn writeReduce(w: *Writer, s: anytype, inst: Air.Inst.Index) @TypeOf(s).Error!void { const reduce = w.air.instructions.items(.data)[inst].reduce; diff --git a/test/behavior/select.zig b/test/behavior/select.zig index 8b4cba49bd..a1fcfb761a 100644 --- a/test/behavior/select.zig +++ b/test/behavior/select.zig @@ -4,23 +4,29 @@ const mem = std.mem; const expect = std.testing.expect; test "@select" { - if (@import("builtin").zig_backend != .stage1) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO - const S = struct { - fn doTheTest() !void { - var a: @Vector(4, bool) = [4]bool{ true, false, true, false }; - var b: @Vector(4, i32) = [4]i32{ -1, 4, 999, -31 }; - var c: @Vector(4, i32) = [4]i32{ -5, 1, 0, 1234 }; - var abc = @select(i32, a, b, c); - try expect(mem.eql(i32, &@as([4]i32, abc), &[4]i32{ -1, 1, 999, 1234 })); - - var x: @Vector(4, bool) = [4]bool{ false, false, false, true }; - var y: @Vector(4, f32) = [4]f32{ 0.001, 33.4, 836, -3381.233 }; - var z: @Vector(4, f32) = [4]f32{ 0.0, 312.1, -145.9, 9993.55 }; - var xyz = @select(f32, x, y, z); - try expect(mem.eql(f32, &@as([4]f32, xyz), &[4]f32{ 0.0, 312.1, -145.9, -3381.233 })); - } - }; - try S.doTheTest(); - comptime try S.doTheTest(); + try doTheTest(); + comptime try doTheTest(); +} + +fn doTheTest() !void { + var a = @Vector(4, bool){ true, false, true, false }; + var b = @Vector(4, i32){ -1, 4, 999, -31 }; + var c = @Vector(4, i32){ -5, 1, 0, 1234 }; + var abc = @select(i32, a, b, c); + try expect(abc[0] == -1); + try expect(abc[1] == 1); + try expect(abc[2] == 999); + try expect(abc[3] == 1234); + + var x = @Vector(4, bool){ false, false, false, true }; + var y = @Vector(4, f32){ 0.001, 33.4, 836, -3381.233 }; + var z = @Vector(4, f32){ 0.0, 312.1, -145.9, 9993.55 }; + var xyz = @select(f32, x, y, z); + try expect(mem.eql(f32, &@as([4]f32, xyz), &[4]f32{ 0.0, 312.1, -145.9, -3381.233 })); } From f47db0a0dbf64f1fef0ba86bd9f684a5df29bc3d Mon Sep 17 00:00:00 2001 From: John Schmidt Date: Thu, 24 Mar 2022 23:02:06 +0100 Subject: [PATCH 2/3] sema: use `pl_op` for `@select` --- src/Air.zig | 13 +++++-------- src/Liveness.zig | 5 +++-- src/Sema.zig | 11 +++++------ src/arch/aarch64/CodeGen.zig | 6 +++--- src/arch/arm/CodeGen.zig | 6 +++--- src/arch/riscv64/CodeGen.zig | 6 +++--- src/arch/wasm/CodeGen.zig | 6 +++--- src/arch/x86_64/CodeGen.zig | 6 +++--- src/codegen/llvm.zig | 10 +++++----- src/print_air.zig | 13 +++++++------ 10 files changed, 40 insertions(+), 42 deletions(-) diff --git a/src/Air.zig b/src/Air.zig index 4938cb8135..5120e0fd67 100644 --- a/src/Air.zig +++ b/src/Air.zig @@ -554,7 +554,7 @@ pub const Inst = struct { /// Uses the `ty_pl` field with payload `Shuffle`. shuffle, /// Constructs a vector element-wise from `a` or `b` based on `pred`. - /// Uses the `ty_pl` field with payload `Select`. + /// Uses the `pl_op` field with `pred` as operand, and payload `Bin`. select, /// Given dest ptr, value, and len, set all elements at dest to value. @@ -788,12 +788,6 @@ pub const Shuffle = struct { mask_len: u32, }; -pub const Select = struct { - pred: Inst.Ref, - a: Inst.Ref, - b: Inst.Ref, -}; - pub const VectorCmp = struct { lhs: Inst.Ref, rhs: Inst.Ref, @@ -965,7 +959,6 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index) Type { .cmpxchg_weak, .cmpxchg_strong, .slice, - .select, .shuffle, .aggregate_init, .union_init, @@ -1077,6 +1070,10 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index) Type { .reduce => return air.typeOf(datas[inst].reduce.operand).childType(), .mul_add => return air.typeOf(datas[inst].pl_op.operand), + .select => { + const extra = air.extraData(Air.Bin, datas[inst].pl_op.payload).data; + return air.typeOf(extra.lhs); + }, .add_with_overflow, .sub_with_overflow, diff --git a/src/Liveness.zig b/src/Liveness.zig index 9e15567d73..b9f4e6b33a 100644 --- a/src/Liveness.zig +++ b/src/Liveness.zig @@ -434,8 +434,9 @@ fn analyzeInst( return extra_tombs.finish(); }, .select => { - const extra = a.air.extraData(Air.Select, inst_datas[inst].ty_pl.payload).data; - return trackOperands(a, new_set, inst, main_tomb, .{ extra.pred, extra.a, extra.b }); + const pl_op = inst_datas[inst].pl_op; + const extra = a.air.extraData(Air.Bin, pl_op.payload).data; + return trackOperands(a, new_set, inst, main_tomb, .{ pl_op.operand, extra.lhs, extra.rhs }); }, .shuffle => { const extra = a.air.extraData(Air.Shuffle, inst_datas[inst].ty_pl.payload).data; diff --git a/src/Sema.zig b/src/Sema.zig index e4d719f346..6971be64bf 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -14894,12 +14894,11 @@ fn zirSelect(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air. try sema.requireRuntimeBlock(block, runtime_src); return block.addInst(.{ .tag = .select, - .data = .{ .ty_pl = .{ - .ty = try block.sema.addType(vec_ty), - .payload = try block.sema.addExtra(Air.Select{ - .pred = pred, - .a = a, - .b = b, + .data = .{ .pl_op = .{ + .operand = pred, + .payload = try block.sema.addExtra(Air.Bin{ + .lhs = a, + .rhs = b, }), } }, }); diff --git a/src/arch/aarch64/CodeGen.zig b/src/arch/aarch64/CodeGen.zig index 178dfe2c29..4d931da0b2 100644 --- a/src/arch/aarch64/CodeGen.zig +++ b/src/arch/aarch64/CodeGen.zig @@ -3748,10 +3748,10 @@ fn airSplat(self: *Self, inst: Air.Inst.Index) !void { } fn airSelect(self: *Self, inst: Air.Inst.Index) !void { - const ty_pl = self.air.instructions.items(.data)[inst].ty_pl; - const extra = self.air.extraData(Air.Select, ty_pl.payload).data; + const pl_op = self.air.instructions.items(.data)[inst].pl_op; + const extra = self.air.extraData(Air.Bin, pl_op.payload).data; const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airSelect for {}", .{self.target.cpu.arch}); - return self.finishAir(inst, result, .{ extra.pred, extra.a, extra.b }); + return self.finishAir(inst, result, .{ pl_op.operand, extra.lhs, extra.rhs }); } fn airShuffle(self: *Self, inst: Air.Inst.Index) !void { diff --git a/src/arch/arm/CodeGen.zig b/src/arch/arm/CodeGen.zig index 7fa116edd1..15d0e47d4a 100644 --- a/src/arch/arm/CodeGen.zig +++ b/src/arch/arm/CodeGen.zig @@ -4325,10 +4325,10 @@ fn airSplat(self: *Self, inst: Air.Inst.Index) !void { } fn airSelect(self: *Self, inst: Air.Inst.Index) !void { - const ty_pl = self.air.instructions.items(.data)[inst].ty_pl; - const extra = self.air.extraData(Air.Select, ty_pl.payload).data; + const pl_op = self.air.instructions.items(.data)[inst].pl_op; + const extra = self.air.extraData(Air.Bin, pl_op.payload).data; const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airSelect for arm", .{}); - return self.finishAir(inst, result, .{ extra.pred, extra.a, extra.b }); + return self.finishAir(inst, result, .{ pl_op.operand, extra.lhs, extra.rhs }); } fn airShuffle(self: *Self, inst: Air.Inst.Index) !void { diff --git a/src/arch/riscv64/CodeGen.zig b/src/arch/riscv64/CodeGen.zig index 65b3b4904d..8b697677c1 100644 --- a/src/arch/riscv64/CodeGen.zig +++ b/src/arch/riscv64/CodeGen.zig @@ -2398,10 +2398,10 @@ fn airSplat(self: *Self, inst: Air.Inst.Index) !void { } fn airSelect(self: *Self, inst: Air.Inst.Index) !void { - const ty_pl = self.air.instructions.items(.data)[inst].ty_pl; - const extra = self.air.extraData(Air.Select, ty_pl.payload).data; + const pl_op = self.air.instructions.items(.data)[inst].pl_op; + const extra = self.air.extraData(Air.Bin, pl_op.payload).data; const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airSelect for riscv64", .{}); - return self.finishAir(inst, result, .{ extra.pred, extra.a, extra.b }); + return self.finishAir(inst, result, .{ pl_op.operand, extra.lhs, extra.rhs }); } fn airShuffle(self: *Self, inst: Air.Inst.Index) !void { diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig index de5e698eca..7ab4a33be9 100644 --- a/src/arch/wasm/CodeGen.zig +++ b/src/arch/wasm/CodeGen.zig @@ -3269,10 +3269,10 @@ fn airSplat(self: *Self, inst: Air.Inst.Index) InnerError!WValue { fn airSelect(self: *Self, inst: Air.Inst.Index) InnerError!WValue { if (self.liveness.isUnused(inst)) return WValue{ .none = {} }; - const ty_pl = self.air.instructions.items(.data)[inst].ty_pl; - const ty = try self.resolveInst(ty_pl.ty); + const pl_op = self.air.instructions.items(.data)[inst].pl_op; + const operand = try self.resolveInst(pl_op.operand); - _ = ty; + _ = operand; return self.fail("TODO: Implement wasm airSelect", .{}); } diff --git a/src/arch/x86_64/CodeGen.zig b/src/arch/x86_64/CodeGen.zig index 8a9aaebb22..4efb836c3a 100644 --- a/src/arch/x86_64/CodeGen.zig +++ b/src/arch/x86_64/CodeGen.zig @@ -5680,10 +5680,10 @@ fn airSplat(self: *Self, inst: Air.Inst.Index) !void { } fn airSelect(self: *Self, inst: Air.Inst.Index) !void { - const ty_pl = self.air.instructions.items(.data)[inst].ty_pl; - const extra = self.air.extraData(Air.Select, ty_pl.payload).data; + const pl_op = self.air.instructions.items(.data)[inst].pl_op; + const extra = self.air.extraData(Air.Bin, pl_op.payload).data; const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airSelect for x86_64", .{}); - return self.finishAir(inst, result, .{ extra.pred, extra.a, extra.b }); + return self.finishAir(inst, result, .{ pl_op.operand, extra.lhs, extra.rhs }); } fn airShuffle(self: *Self, inst: Air.Inst.Index) !void { diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index 0f481e0a6d..00074d69d1 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -6359,11 +6359,11 @@ pub const FuncGen = struct { fn airSelect(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value { if (self.liveness.isUnused(inst)) return null; - const ty_pl = self.air.instructions.items(.data)[inst].ty_pl; - const extra = self.air.extraData(Air.Select, ty_pl.payload).data; - const pred = try self.resolveInst(extra.pred); - const a = try self.resolveInst(extra.a); - const b = try self.resolveInst(extra.b); + const pl_op = self.air.instructions.items(.data)[inst].pl_op; + const extra = self.air.extraData(Air.Bin, pl_op.payload).data; + const pred = try self.resolveInst(pl_op.operand); + const a = try self.resolveInst(extra.lhs); + const b = try self.resolveInst(extra.rhs); return self.builder.buildSelect(pred, a, b, ""); } diff --git a/src/print_air.zig b/src/print_air.zig index a2fe461887..f1e51150a6 100644 --- a/src/print_air.zig +++ b/src/print_air.zig @@ -398,15 +398,16 @@ const Writer = struct { } fn writeSelect(w: *Writer, s: anytype, inst: Air.Inst.Index) @TypeOf(s).Error!void { - const ty_pl = w.air.instructions.items(.data)[inst].ty_pl; - const extra = w.air.extraData(Air.Select, ty_pl.payload).data; + const pl_op = w.air.instructions.items(.data)[inst].pl_op; + const extra = w.air.extraData(Air.Bin, pl_op.payload).data; - try s.print("{}, ", .{w.air.getRefType(ty_pl.ty).fmtDebug()}); - try w.writeOperand(s, inst, 0, extra.pred); + const elem_ty = w.air.typeOfIndex(inst).childType(); + try s.print("{}, ", .{elem_ty.fmtDebug()}); + try w.writeOperand(s, inst, 0, pl_op.operand); try s.writeAll(", "); - try w.writeOperand(s, inst, 1, extra.a); + try w.writeOperand(s, inst, 1, extra.lhs); try s.writeAll(", "); - try w.writeOperand(s, inst, 2, extra.b); + try w.writeOperand(s, inst, 2, extra.rhs); } fn writeReduce(w: *Writer, s: anytype, inst: Air.Inst.Index) @TypeOf(s).Error!void { From cd46daf7d047eeceb7690e2739af5952d60c3884 Mon Sep 17 00:00:00 2001 From: John Schmidt Date: Thu, 24 Mar 2022 23:27:23 +0100 Subject: [PATCH 3/3] sema: coerce inputs to vectors in zirSelect --- src/Sema.zig | 39 +++++++++++++-------------------------- test/behavior/select.zig | 37 +++++++++++++++++++++++++++++++++---- 2 files changed, 46 insertions(+), 30 deletions(-) diff --git a/src/Sema.zig b/src/Sema.zig index 6971be64bf..27f12485a4 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -14805,6 +14805,7 @@ fn analyzeShuffle( fn zirSelect(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { const inst_data = sema.code.instructions.items(.data)[inst].pl_node; const extra = sema.code.extraData(Zir.Inst.Select, inst_data.payload_index).data; + const target = sema.mod.getTarget(); const elem_ty_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node }; const pred_src: LazySrcLoc = .{ .node_offset_builtin_call_arg1 = inst_data.src_node }; @@ -14813,35 +14814,21 @@ fn zirSelect(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air. const elem_ty = try sema.resolveType(block, elem_ty_src, extra.elem_type); try sema.checkVectorElemType(block, elem_ty_src, elem_ty); - const pred = sema.resolveInst(extra.pred); - const a = sema.resolveInst(extra.a); - const b = sema.resolveInst(extra.b); - const target = sema.mod.getTarget(); + const pred_uncoerced = sema.resolveInst(extra.pred); + const pred_ty = sema.typeOf(pred_uncoerced); - const pred_ty = sema.typeOf(pred); - switch (try pred_ty.zigTypeTagOrPoison()) { - .Vector => { - const scalar_ty = pred_ty.childType(); - if (!scalar_ty.eql(Type.bool, target)) { - const bool_vec_ty = try Type.vector(sema.arena, pred_ty.vectorLen(), Type.bool); - return sema.fail(block, pred_src, "Expected '{}', found '{}'", .{ bool_vec_ty.fmt(target), pred_ty.fmt(target) }); - } - }, - else => return sema.fail(block, pred_src, "Expected vector type, found '{}'", .{pred_ty.fmt(target)}), - } + const vec_len_u64 = switch (try pred_ty.zigTypeTagOrPoison()) { + .Vector, .Array => pred_ty.arrayLen(), + else => return sema.fail(block, pred_src, "expected vector or array, found '{}'", .{pred_ty.fmt(target)}), + }; + const vec_len = try sema.usizeCast(block, pred_src, vec_len_u64); + + const bool_vec_ty = try Type.vector(sema.arena, vec_len, Type.bool); + const pred = try sema.coerce(block, bool_vec_ty, pred_uncoerced, pred_src); - const vec_len = pred_ty.vectorLen(); const vec_ty = try Type.vector(sema.arena, vec_len, elem_ty); - - const a_ty = sema.typeOf(a); - if (!a_ty.eql(vec_ty, target)) { - return sema.fail(block, a_src, "Expected '{}', found '{}'", .{ vec_ty.fmt(target), a_ty.fmt(target) }); - } - - const b_ty = sema.typeOf(b); - if (!b_ty.eql(vec_ty, target)) { - return sema.fail(block, b_src, "Expected '{}', found '{}'", .{ vec_ty.fmt(target), b_ty.fmt(target) }); - } + const a = try sema.coerce(block, vec_ty, sema.resolveInst(extra.a), a_src); + const b = try sema.coerce(block, vec_ty, sema.resolveInst(extra.b), b_src); const maybe_pred = try sema.resolveMaybeUndefVal(block, pred_src, pred); const maybe_a = try sema.resolveMaybeUndefVal(block, a_src, a); diff --git a/test/behavior/select.zig b/test/behavior/select.zig index a1fcfb761a..f731ded09e 100644 --- a/test/behavior/select.zig +++ b/test/behavior/select.zig @@ -3,18 +3,18 @@ const builtin = @import("builtin"); const mem = std.mem; const expect = std.testing.expect; -test "@select" { +test "@select vectors" { if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO - try doTheTest(); - comptime try doTheTest(); + comptime try selectVectors(); + try selectVectors(); } -fn doTheTest() !void { +fn selectVectors() !void { var a = @Vector(4, bool){ true, false, true, false }; var b = @Vector(4, i32){ -1, 4, 999, -31 }; var c = @Vector(4, i32){ -5, 1, 0, 1234 }; @@ -30,3 +30,32 @@ fn doTheTest() !void { var xyz = @select(f32, x, y, z); try expect(mem.eql(f32, &@as([4]f32, xyz), &[4]f32{ 0.0, 312.1, -145.9, -3381.233 })); } + +test "@select arrays" { + if (builtin.zig_backend == .stage1) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO + + comptime try selectArrays(); + try selectArrays(); +} + +fn selectArrays() !void { + var a = [4]bool{ false, true, false, true }; + var b = [4]usize{ 0, 1, 2, 3 }; + var c = [4]usize{ 4, 5, 6, 7 }; + var abc = @select(usize, a, b, c); + try expect(abc[0] == 4); + try expect(abc[1] == 1); + try expect(abc[2] == 6); + try expect(abc[3] == 3); + + var x = [4]bool{ false, false, false, true }; + var y = [4]f32{ 0.001, 33.4, 836, -3381.233 }; + var z = [4]f32{ 0.0, 312.1, -145.9, 9993.55 }; + var xyz = @select(f32, x, y, z); + try expect(mem.eql(f32, &@as([4]f32, xyz), &[4]f32{ 0.0, 312.1, -145.9, -3381.233 })); +}