Legalize: implement scalarization of @select

This commit is contained in:
Jacob Young 2025-05-30 18:04:30 -04:00 committed by mlugg
parent 32a57bfeaa
commit b48d6ff619
No known key found for this signature in database
GPG Key ID: 3F5B7DCCBF4AF02E
5 changed files with 102 additions and 21 deletions

View File

@ -368,9 +368,6 @@ pub fn countElementsWithValue(vec: anytype, value: std.meta.Child(@TypeOf(vec)))
} }
test "vector searching" { test "vector searching" {
if (builtin.zig_backend == .stage2_x86_64 and
!comptime std.Target.x86.featureSetHas(builtin.cpu.features, .ssse3)) return error.SkipZigTest;
const base = @Vector(8, u32){ 6, 4, 7, 4, 4, 2, 3, 7 }; const base = @Vector(8, u32){ 6, 4, 7, 4, 4, 2, 3, 7 };
try std.testing.expectEqual(@as(?u3, 1), firstIndexOfValue(base, 4)); try std.testing.expectEqual(@as(?u3, 1), firstIndexOfValue(base, 4));

View File

@ -74,6 +74,7 @@ pub const Feature = enum {
scalarize_int_from_float, scalarize_int_from_float,
scalarize_int_from_float_optimized, scalarize_int_from_float_optimized,
scalarize_float_from_int, scalarize_float_from_int,
scalarize_select,
scalarize_mul_add, scalarize_mul_add,
/// Legalize (shift lhs, (splat rhs)) -> (shift lhs, rhs) /// Legalize (shift lhs, (splat rhs)) -> (shift lhs, rhs)
@ -167,6 +168,7 @@ pub const Feature = enum {
.int_from_float => .scalarize_int_from_float, .int_from_float => .scalarize_int_from_float,
.int_from_float_optimized => .scalarize_int_from_float_optimized, .int_from_float_optimized => .scalarize_int_from_float_optimized,
.float_from_int => .scalarize_float_from_int, .float_from_int => .scalarize_float_from_int,
.select => .scalarize_select,
.mul_add => .scalarize_mul_add, .mul_add => .scalarize_mul_add,
}; };
} }
@ -520,7 +522,9 @@ fn legalizeBody(l: *Legalize, body_start: usize, body_len: usize) Error!void {
}, },
.splat, .splat,
.shuffle, .shuffle,
=> {},
.select, .select,
=> if (l.features.contains(.scalarize_select)) continue :inst try l.scalarize(inst, .select_pl_op_bin),
.memset, .memset,
.memset_safe, .memset_safe,
.memcpy, .memcpy,
@ -568,7 +572,7 @@ fn legalizeBody(l: *Legalize, body_start: usize, body_len: usize) Error!void {
} }
} }
const ScalarizeDataTag = enum { un_op, ty_op, bin_op, ty_pl_vector_cmp, pl_op_bin }; const ScalarizeDataTag = enum { un_op, ty_op, bin_op, ty_pl_vector_cmp, pl_op_bin, select_pl_op_bin };
inline fn scalarize(l: *Legalize, orig_inst: Air.Inst.Index, comptime data_tag: ScalarizeDataTag) Error!Air.Inst.Tag { inline fn scalarize(l: *Legalize, orig_inst: Air.Inst.Index, comptime data_tag: ScalarizeDataTag) Error!Air.Inst.Tag {
return l.replaceInst(orig_inst, .block, try l.scalarizeBlockPayload(orig_inst, data_tag)); return l.replaceInst(orig_inst, .block, try l.scalarizeBlockPayload(orig_inst, data_tag));
} }
@ -584,6 +588,7 @@ fn scalarizeBlockPayload(l: *Legalize, orig_inst: Air.Inst.Index, comptime data_
.un_op, .ty_op => 1, .un_op, .ty_op => 1,
.bin_op, .ty_pl_vector_cmp => 2, .bin_op, .ty_pl_vector_cmp => 2,
.pl_op_bin => 3, .pl_op_bin => 3,
.select_pl_op_bin => 6,
} + 9 } + 9
]Air.Inst.Index = undefined; ]Air.Inst.Index = undefined;
try l.air_instructions.ensureUnusedCapacity(zcu.gpa, inst_buf.len); try l.air_instructions.ensureUnusedCapacity(zcu.gpa, inst_buf.len);
@ -722,23 +727,67 @@ fn scalarizeBlockPayload(l: *Legalize, orig_inst: Air.Inst.Index, comptime data_
} }, } },
}); });
}, },
.select_pl_op_bin => {
const extra = l.extraData(Air.Bin, orig.data.pl_op.payload).data;
var res_elem: Result = .init(l, l.typeOf(extra.lhs).scalarType(zcu), &loop.block);
res_elem.block = .init(loop.block.stealCapacity(6));
{
var select_cond_br: CondBr = .init(l, res_elem.block.add(l, .{
.tag = .array_elem_val,
.data = .{ .bin_op = .{
.lhs = orig.data.pl_op.operand,
.rhs = cur_index_inst.toRef(),
} },
}).toRef(), &res_elem.block, .{});
select_cond_br.then_block = .init(res_elem.block.stealRemainingCapacity());
{
_ = select_cond_br.then_block.add(l, .{
.tag = .br,
.data = .{ .br = .{
.block_inst = res_elem.inst,
.operand = select_cond_br.then_block.add(l, .{
.tag = .array_elem_val,
.data = .{ .bin_op = .{
.lhs = extra.lhs,
.rhs = cur_index_inst.toRef(),
} },
}).toRef(),
} },
});
}
select_cond_br.else_block = .init(select_cond_br.then_block.stealRemainingCapacity());
{
_ = select_cond_br.else_block.add(l, .{
.tag = .br,
.data = .{ .br = .{
.block_inst = res_elem.inst,
.operand = select_cond_br.else_block.add(l, .{
.tag = .array_elem_val,
.data = .{ .bin_op = .{
.lhs = extra.rhs,
.rhs = cur_index_inst.toRef(),
} },
}).toRef(),
} },
});
}
try select_cond_br.finish(l);
}
try res_elem.finish(l);
break :res_elem res_elem.inst;
},
}.toRef(), }.toRef(),
}), }),
} }, } },
}); });
var loop_cond_br: CondBr = .init( var loop_cond_br: CondBr = .init(l, (try loop.block.addCmp(
l,
(try loop.block.addCmp(
l, l,
.lt, .lt,
cur_index_inst.toRef(), cur_index_inst.toRef(),
try pt.intRef(.usize, res_ty.vectorLen(zcu) - 1), try pt.intRef(.usize, res_ty.vectorLen(zcu) - 1),
.{}, .{},
)).toRef(), )).toRef(), &loop.block, .{});
&loop.block,
.{},
);
loop_cond_br.then_block = .init(loop.block.stealRemainingCapacity()); loop_cond_br.then_block = .init(loop.block.stealRemainingCapacity());
{ {
_ = loop_cond_br.then_block.add(l, .{ _ = loop_cond_br.then_block.add(l, .{
@ -1138,9 +1187,21 @@ const Block = struct {
/// This is useful when you've provided a buffer big enough for all your instructions, but you are /// This is useful when you've provided a buffer big enough for all your instructions, but you are
/// now starting a new block and some of them need to live there instead. /// now starting a new block and some of them need to live there instead.
fn stealRemainingCapacity(b: *Block) []Air.Inst.Index { fn stealRemainingCapacity(b: *Block) []Air.Inst.Index {
const remaining = b.instructions[b.len..]; return b.stealFrom(b.len);
b.instructions = b.instructions[0..b.len]; }
return remaining;
/// Carves the final `len` slots off the end of `b.instructions` and returns them as a separate
/// buffer; `b.instructions` is shortened so that it no longer contains them.
/// Use this when one buffer was sized up front for every instruction, but a known number of
/// those instructions need to live in a new block instead.
fn stealCapacity(b: *Block, len: usize) []Air.Inst.Index {
    const start = b.instructions.len - len;
    return b.stealFrom(start);
}
/// Splits off the tail of `b.instructions` beginning at index `start`: returns
/// `b.instructions[start..]` and, on return, shortens `b.instructions` to `start` elements.
/// Asserts that `start >= b.len`.
fn stealFrom(b: *Block, start: usize) []Air.Inst.Index {
assert(start >= b.len);
// The `defer` runs after the returned slice has been evaluated, so the slice expression
// below still sees the full, unshortened `b.instructions`.
defer b.instructions.len = start;
return b.instructions[start..];
} }
fn body(b: *const Block) []const Air.Inst.Index { fn body(b: *const Block) []const Air.Inst.Index {
@ -1149,6 +1210,31 @@ const Block = struct {
} }
}; };
/// An AIR `.block` instruction that is still being built, paired with the `Block` that
/// collects the instructions of its body.
const Result = struct {
    /// The `.block` instruction; its `ty_pl.payload` is `undefined` until `finish` is called.
    inst: Air.Inst.Index,
    /// The instructions that make up the body of `inst`.
    block: Block,

    /// Adds a `.block` instruction of type `ty` to `parent_block` and wraps it in a `Result`.
    /// The returned value has `block` set to `undefined`; it is the caller's responsibility
    /// to initialize it before calling `finish`.
    fn init(l: *Legalize, ty: Type, parent_block: *Block) Result {
        const block_inst = parent_block.add(l, .{
            .tag = .block,
            .data = .{ .ty_pl = .{
                .ty = Air.internedToRef(ty.toIntern()),
                .payload = undefined,
            } },
        });
        return .{ .inst = block_inst, .block = undefined };
    }

    /// Passes the body accumulated in `block` to `addBlockBody` and stores the resulting
    /// payload into the `.block` instruction created by `init`.
    fn finish(res: Result, l: *Legalize) Error!void {
        const payload = try l.addBlockBody(res.block.body());
        l.air_instructions.items(.data)[@intFromEnum(res.inst)].ty_pl.payload = payload;
    }
};
const Loop = struct { const Loop = struct {
inst: Air.Inst.Index, inst: Air.Inst.Index,
block: Block, block: Block,

View File

@ -2529,6 +2529,7 @@ pub fn destroy(comp: *Compilation) void {
pub fn clearMiscFailures(comp: *Compilation) void { pub fn clearMiscFailures(comp: *Compilation) void {
comp.alloc_failure_occurred = false; comp.alloc_failure_occurred = false;
comp.link_diags.flags = .{};
for (comp.misc_failures.values()) |*value| { for (comp.misc_failures.values()) |*value| {
value.deinit(comp.gpa); value.deinit(comp.gpa);
} }
@ -2795,7 +2796,6 @@ pub fn update(comp: *Compilation, main_progress_node: std.Progress.Node) !void {
if (anyErrors(comp)) { if (anyErrors(comp)) {
// Skip flushing and keep source files loaded for error reporting. // Skip flushing and keep source files loaded for error reporting.
comp.link_diags.flags = .{};
return; return;
} }

View File

@ -84,6 +84,7 @@ pub fn legalizeFeatures(target: *const std.Target) *const Air.Legalize.Features
.scalarize_int_from_float = use_old, .scalarize_int_from_float = use_old,
.scalarize_int_from_float_optimized = use_old, .scalarize_int_from_float_optimized = use_old,
.scalarize_float_from_int = use_old, .scalarize_float_from_int = use_old,
.scalarize_select = true,
.scalarize_mul_add = use_old, .scalarize_mul_add = use_old,
.unsplat_shift_rhs = false, .unsplat_shift_rhs = false,

View File

@ -41,8 +41,6 @@ test "@select arrays" {
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_x86_64 and
!comptime std.Target.x86.featureSetHas(builtin.cpu.features, .avx2)) return error.SkipZigTest;
try comptime selectArrays(); try comptime selectArrays();
try selectArrays(); try selectArrays();
@ -70,7 +68,6 @@ fn selectArrays() !void {
test "@select compare result" { test "@select compare result" {
if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_llvm and builtin.cpu.arch == .hexagon) return error.SkipZigTest; if (builtin.zig_backend == .stage2_llvm and builtin.cpu.arch == .hexagon) return error.SkipZigTest;
const S = struct { const S = struct {