diff --git a/doc/langref.html.in b/doc/langref.html.in
index 9e98bb0114..a805f9a2ac 100644
--- a/doc/langref.html.in
+++ b/doc/langref.html.in
@@ -7125,6 +7125,7 @@ fn func(y: *i32) void {
an integer or an enum.
{#header_close#}
+
{#header_open|@bitCast#}
{#syntax#}@bitCast(comptime DestType: type, value: anytype) DestType{#endsyntax#}
@@ -8177,6 +8178,15 @@ test "@wasmMemoryGrow" {
a calling function, the returned address will apply to the calling function.
{#header_close#}
+
+ {#header_open|@select#}
+ {#syntax#}@select(comptime T: type, pred: std.meta.Vector(len, bool), a: std.meta.Vector(len, T), b: std.meta.Vector(len, T)) std.meta.Vector(len, T){#endsyntax#}
+
+ Selects values element-wise from {#syntax#}a{#endsyntax#} or {#syntax#}b{#endsyntax#} based on {#syntax#}pred{#endsyntax#}. If {#syntax#}pred[i]{#endsyntax#} is {#syntax#}true{#endsyntax#}, the corresponding element in the result will be {#syntax#}a[i]{#endsyntax#} and otherwise {#syntax#}b[i]{#endsyntax#}.
+
+ {#see_also|SIMD|Vectors#}
+ {#header_close#}
+
{#header_open|@setAlignStack#}
{#syntax#}@setAlignStack(comptime alignment: u29){#endsyntax#}
diff --git a/src/AstGen.zig b/src/AstGen.zig
index a00b43288b..656b960c2f 100644
--- a/src/AstGen.zig
+++ b/src/AstGen.zig
@@ -2090,6 +2090,7 @@ fn unusedResultExpr(gz: *GenZir, scope: *Scope, statement: ast.Node.Index) Inner
.splat,
.reduce,
.shuffle,
+ .select,
.atomic_load,
.atomic_rmw,
.atomic_store,
@@ -7375,6 +7376,15 @@ fn builtinCall(
});
return rvalue(gz, rl, result, node);
},
+ .select => {
+ const result = try gz.addPlNode(.select, node, Zir.Inst.Select{
+ .elem_type = try typeExpr(gz, scope, params[0]),
+ .pred = try expr(gz, scope, .none, params[1]),
+ .a = try expr(gz, scope, .none, params[2]),
+ .b = try expr(gz, scope, .none, params[3]),
+ });
+ return rvalue(gz, rl, result, node);
+ },
.async_call => {
const result = try gz.addPlNode(.builtin_async_call, node, Zir.Inst.AsyncCall{
.frame_buffer = try expr(gz, scope, .none, params[0]),
diff --git a/src/BuiltinFn.zig b/src/BuiltinFn.zig
index 3964512838..07371b3192 100644
--- a/src/BuiltinFn.zig
+++ b/src/BuiltinFn.zig
@@ -69,6 +69,7 @@ pub const Tag = enum {
ptr_to_int,
rem,
return_address,
+ select,
set_align_stack,
set_cold,
set_eval_branch_quota,
@@ -601,6 +602,13 @@ pub const list = list: {
.param_count = 0,
},
},
+ .{
+ "@select",
+ .{
+ .tag = .select,
+ .param_count = 4,
+ },
+ },
.{
"@setAlignStack",
.{
diff --git a/src/Sema.zig b/src/Sema.zig
index 3df512596f..6f8975d086 100644
--- a/src/Sema.zig
+++ b/src/Sema.zig
@@ -338,6 +338,7 @@ pub fn analyzeBody(
.splat => try sema.zirSplat(block, inst),
.reduce => try sema.zirReduce(block, inst),
.shuffle => try sema.zirShuffle(block, inst),
+ .select => try sema.zirSelect(block, inst),
.atomic_load => try sema.zirAtomicLoad(block, inst),
.atomic_rmw => try sema.zirAtomicRmw(block, inst),
.atomic_store => try sema.zirAtomicStore(block, inst),
@@ -6099,6 +6100,12 @@ fn zirShuffle(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) CompileErr
return sema.mod.fail(&block.base, src, "TODO: Sema.zirShuffle", .{});
}
+fn zirSelect(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
+ const inst_data = sema.code.instructions.items(.data)[inst].pl_node;
+ const src = inst_data.src();
+ return sema.mod.fail(&block.base, src, "TODO: Sema.zirSelect", .{});
+}
+
fn zirAtomicLoad(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
const inst_data = sema.code.instructions.items(.data)[inst].pl_node;
const src = inst_data.src();
diff --git a/src/Zir.zig b/src/Zir.zig
index 00f66ef507..003b43d9e0 100644
--- a/src/Zir.zig
+++ b/src/Zir.zig
@@ -890,6 +890,9 @@ pub const Inst = struct {
/// Implements the `@shuffle` builtin.
/// Uses the `pl_node` union field with payload `Shuffle`.
shuffle,
+ /// Implements the `@select` builtin.
+ /// Uses the `pl_node` union field with payload `Select`.
+ select,
/// Implements the `@atomicLoad` builtin.
/// Uses the `pl_node` union field with payload `Bin`.
atomic_load,
@@ -1181,6 +1184,7 @@ pub const Inst = struct {
.splat,
.reduce,
.shuffle,
+ .select,
.atomic_load,
.atomic_rmw,
.atomic_store,
@@ -1451,6 +1455,7 @@ pub const Inst = struct {
.splat = .pl_node,
.reduce = .pl_node,
.shuffle = .pl_node,
+ .select = .pl_node,
.atomic_load = .pl_node,
.atomic_rmw = .pl_node,
.atomic_store = .pl_node,
@@ -2725,6 +2730,13 @@ pub const Inst = struct {
mask: Ref,
};
+ pub const Select = struct {
+ elem_type: Ref,
+ pred: Ref,
+ a: Ref,
+ b: Ref,
+ };
+
pub const AsyncCall = struct {
frame_buffer: Ref,
result_ptr: Ref,
@@ -2935,6 +2947,7 @@ const Writer = struct {
.cmpxchg_strong,
.cmpxchg_weak,
.shuffle,
+ .select,
.atomic_rmw,
.atomic_store,
.mul_add,
diff --git a/src/stage1/all_types.hpp b/src/stage1/all_types.hpp
index ee06c15cf0..b4b2740dec 100644
--- a/src/stage1/all_types.hpp
+++ b/src/stage1/all_types.hpp
@@ -1755,6 +1755,7 @@ enum BuiltinFnId {
BuiltinFnIdIntToEnum,
BuiltinFnIdVectorType,
BuiltinFnIdShuffle,
+ BuiltinFnIdSelect,
BuiltinFnIdSplat,
BuiltinFnIdSetCold,
BuiltinFnIdSetRuntimeSafety,
@@ -2544,6 +2545,7 @@ enum Stage1ZirInstId : uint8_t {
Stage1ZirInstIdBoolToInt,
Stage1ZirInstIdVectorType,
Stage1ZirInstIdShuffleVector,
+ Stage1ZirInstIdSelect,
Stage1ZirInstIdSplat,
Stage1ZirInstIdBoolNot,
Stage1ZirInstIdMemset,
@@ -2664,6 +2666,7 @@ enum Stage1AirInstId : uint8_t {
Stage1AirInstIdReduce,
Stage1AirInstIdTruncate,
Stage1AirInstIdShuffleVector,
+ Stage1AirInstIdSelect,
Stage1AirInstIdSplat,
Stage1AirInstIdBoolNot,
Stage1AirInstIdMemset,
@@ -4295,6 +4298,23 @@ struct Stage1AirInstShuffleVector {
Stage1AirInst *mask; // This is in zig-format, not llvm format
};
+struct Stage1ZirInstSelect {
+ Stage1ZirInst base;
+
+ Stage1ZirInst *scalar_type;
+ Stage1ZirInst *pred; // This is in zig-format, not llvm format
+ Stage1ZirInst *a;
+ Stage1ZirInst *b;
+};
+
+struct Stage1AirInstSelect {
+ Stage1AirInst base;
+
+ Stage1AirInst *pred; // This is in zig-format, not llvm format
+ Stage1AirInst *a;
+ Stage1AirInst *b;
+};
+
struct Stage1ZirInstSplat {
Stage1ZirInst base;
diff --git a/src/stage1/astgen.cpp b/src/stage1/astgen.cpp
index cc8ccefad7..e8f68c82bb 100644
--- a/src/stage1/astgen.cpp
+++ b/src/stage1/astgen.cpp
@@ -196,6 +196,8 @@ void destroy_instruction_src(Stage1ZirInst *inst) {
return heap::c_allocator.destroy(reinterpret_cast(inst));
case Stage1ZirInstIdShuffleVector:
return heap::c_allocator.destroy(reinterpret_cast(inst));
+ case Stage1ZirInstIdSelect:
+ return heap::c_allocator.destroy(reinterpret_cast(inst));
case Stage1ZirInstIdSplat:
return heap::c_allocator.destroy(reinterpret_cast(inst));
case Stage1ZirInstIdBoolNot:
@@ -651,6 +653,10 @@ static constexpr Stage1ZirInstId ir_inst_id(Stage1ZirInstShuffleVector *) {
return Stage1ZirInstIdShuffleVector;
}
+static constexpr Stage1ZirInstId ir_inst_id(Stage1ZirInstSelect *) {
+ return Stage1ZirInstIdSelect;
+}
+
static constexpr Stage1ZirInstId ir_inst_id(Stage1ZirInstSplat *) {
return Stage1ZirInstIdSplat;
}
@@ -2037,6 +2043,22 @@ static Stage1ZirInst *ir_build_shuffle_vector(Stage1AstGen *ag, Scope *scope, As
return &instruction->base;
}
+static Stage1ZirInst *ir_build_select(Stage1AstGen *ag, Scope *scope, AstNode *source_node,
+ Stage1ZirInst *scalar_type, Stage1ZirInst *pred, Stage1ZirInst *a, Stage1ZirInst *b)
+{
+ Stage1ZirInstSelect *instruction = ir_build_instruction(ag, scope, source_node);
+ instruction->scalar_type = scalar_type;
+ instruction->pred = pred;
+ instruction->a = a;
+ instruction->b = b;
+
+ ir_ref_instruction(pred, ag->current_basic_block);
+ ir_ref_instruction(a, ag->current_basic_block);
+ ir_ref_instruction(b, ag->current_basic_block);
+
+ return &instruction->base;
+}
+
static Stage1ZirInst *ir_build_splat_src(Stage1AstGen *ag, Scope *scope, AstNode *source_node,
Stage1ZirInst *len, Stage1ZirInst *scalar)
{
@@ -4619,6 +4641,35 @@ static Stage1ZirInst *astgen_builtin_fn_call(Stage1AstGen *ag, Scope *scope, Ast
arg0_value, arg1_value, arg2_value, arg3_value);
return ir_lval_wrap(ag, scope, shuffle_vector, lval, result_loc);
}
+ case BuiltinFnIdSelect:
+ {
+ // Used for the type expr
+ Scope *comptime_scope = create_comptime_scope(ag->codegen, node, scope);
+
+ AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
+ Stage1ZirInst *arg0_value = astgen_node(ag, arg0_node, comptime_scope);
+ if (arg0_value == ag->codegen->invalid_inst_src)
+ return arg0_value;
+
+ AstNode *arg1_node = node->data.fn_call_expr.params.at(1);
+ Stage1ZirInst *arg1_value = astgen_node(ag, arg1_node, scope);
+ if (arg0_value == ag->codegen->invalid_inst_src)
+ return arg1_value;
+
+ AstNode *arg2_node = node->data.fn_call_expr.params.at(2);
+ Stage1ZirInst *arg2_value = astgen_node(ag, arg2_node, scope);
+ if (arg1_value == ag->codegen->invalid_inst_src)
+ return arg2_value;
+
+ AstNode *arg3_node = node->data.fn_call_expr.params.at(3);
+ Stage1ZirInst *arg3_value = astgen_node(ag, arg3_node, scope);
+ if (arg2_value == ag->codegen->invalid_inst_src)
+ return arg3_value;
+
+ Stage1ZirInst *select = ir_build_select(ag, scope, node,
+ arg0_value, arg1_value, arg2_value, arg3_value);
+ return ir_lval_wrap(ag, scope, select, lval, result_loc);
+ }
case BuiltinFnIdSplat:
{
AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
diff --git a/src/stage1/codegen.cpp b/src/stage1/codegen.cpp
index 562327d500..fc2651c8f7 100644
--- a/src/stage1/codegen.cpp
+++ b/src/stage1/codegen.cpp
@@ -5162,6 +5162,13 @@ static LLVMValueRef ir_render_shuffle_vector(CodeGen *g, Stage1Air *executable,
llvm_mask_value, "");
}
+static LLVMValueRef ir_render_select(CodeGen *g, Stage1Air *executable, Stage1AirInstSelect *instruction) {
+ LLVMValueRef pred = ir_llvm_value(g, instruction->pred);
+ LLVMValueRef a = ir_llvm_value(g, instruction->a);
+ LLVMValueRef b = ir_llvm_value(g, instruction->b);
+ return LLVMBuildSelect(g->builder, pred, a, b, "");
+}
+
static LLVMValueRef ir_render_splat(CodeGen *g, Stage1Air *executable, Stage1AirInstSplat *instruction) {
ZigType *result_type = instruction->base.value->type;
ir_assert(result_type->id == ZigTypeIdVector, &instruction->base);
@@ -7015,6 +7022,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, Stage1Air *executable, Sta
return ir_render_spill_end(g, executable, (Stage1AirInstSpillEnd *)instruction);
case Stage1AirInstIdShuffleVector:
return ir_render_shuffle_vector(g, executable, (Stage1AirInstShuffleVector *) instruction);
+ case Stage1AirInstIdSelect:
+ return ir_render_select(g, executable, (Stage1AirInstSelect *) instruction);
case Stage1AirInstIdSplat:
return ir_render_splat(g, executable, (Stage1AirInstSplat *) instruction);
case Stage1AirInstIdVectorExtractElem:
@@ -8920,6 +8929,7 @@ static void define_builtin_fns(CodeGen *g) {
create_builtin_fn(g, BuiltinFnIdCompileLog, "compileLog", SIZE_MAX);
create_builtin_fn(g, BuiltinFnIdVectorType, "Vector", 2);
create_builtin_fn(g, BuiltinFnIdShuffle, "shuffle", 4);
+ create_builtin_fn(g, BuiltinFnIdSelect, "select", 4);
create_builtin_fn(g, BuiltinFnIdSplat, "splat", 2);
create_builtin_fn(g, BuiltinFnIdSetCold, "setCold", 1);
create_builtin_fn(g, BuiltinFnIdSetRuntimeSafety, "setRuntimeSafety", 1);
diff --git a/src/stage1/ir.cpp b/src/stage1/ir.cpp
index a6d1ff3007..9a2e50fef7 100644
--- a/src/stage1/ir.cpp
+++ b/src/stage1/ir.cpp
@@ -355,6 +355,8 @@ void destroy_instruction_gen(Stage1AirInst *inst) {
return heap::c_allocator.destroy(reinterpret_cast(inst));
case Stage1AirInstIdShuffleVector:
return heap::c_allocator.destroy(reinterpret_cast(inst));
+ case Stage1AirInstIdSelect:
+ return heap::c_allocator.destroy(reinterpret_cast(inst));
case Stage1AirInstIdSplat:
return heap::c_allocator.destroy(reinterpret_cast(inst));
case Stage1AirInstIdBoolNot:
@@ -901,6 +903,10 @@ static constexpr Stage1AirInstId ir_inst_id(Stage1AirInstShuffleVector *) {
return Stage1AirInstIdShuffleVector;
}
+static constexpr Stage1AirInstId ir_inst_id(Stage1AirInstSelect *) {
+ return Stage1AirInstIdSelect;
+}
+
static constexpr Stage1AirInstId ir_inst_id(Stage1AirInstSplat *) {
return Stage1AirInstIdSplat;
}
@@ -1756,6 +1762,22 @@ static Stage1AirInst *ir_build_shuffle_vector_gen(IrAnalyze *ira, Scope *scope,
return &inst->base;
}
+static Stage1AirInst *ir_build_select_gen(IrAnalyze *ira, Scope *scope, AstNode *source_node,
+ ZigType *result_type, Stage1AirInst *pred, Stage1AirInst *a, Stage1AirInst *b)
+{
+ Stage1AirInstSelect *inst = ir_build_inst_gen(&ira->new_irb, scope, source_node);
+ inst->base.value->type = result_type;
+ inst->pred = pred;
+ inst->a = a;
+ inst->b = b;
+
+ ir_ref_inst_gen(pred);
+ ir_ref_inst_gen(a);
+ ir_ref_inst_gen(b);
+
+ return &inst->base;
+}
+
static Stage1AirInst *ir_build_splat_gen(IrAnalyze *ira, Scope *scope, AstNode *source_node, ZigType *result_type,
Stage1AirInst *scalar)
{
@@ -20318,6 +20340,100 @@ static Stage1AirInst *ir_analyze_instruction_shuffle_vector(IrAnalyze *ira, Stag
return ir_analyze_shuffle_vector(ira, instruction->base.scope, instruction->base.source_node, scalar_type, a, b, mask);
}
+static Stage1AirInst *ir_analyze_instruction_select(IrAnalyze *ira, Stage1ZirInstSelect *instruction) {
+ Error err;
+
+ ZigType *scalar_type = ir_resolve_vector_elem_type(ira, instruction->scalar_type->child);
+ if (type_is_invalid(scalar_type))
+ return ira->codegen->invalid_inst_gen;
+
+ if ((err = ir_validate_vector_elem_type(ira, instruction->base.source_node, scalar_type)))
+ return ira->codegen->invalid_inst_gen;
+
+ Stage1AirInst *pred = instruction->pred->child;
+ if (type_is_invalid(pred->value->type))
+ return ira->codegen->invalid_inst_gen;
+
+ Stage1AirInst *a = instruction->a->child;
+ if (type_is_invalid(a->value->type))
+ return ira->codegen->invalid_inst_gen;
+
+ Stage1AirInst *b = instruction->b->child;
+ if (type_is_invalid(b->value->type))
+ return ira->codegen->invalid_inst_gen;
+
+ if (pred->value->type->id != ZigTypeIdVector) {
+ ir_add_error(ira, pred,
+ buf_sprintf("expected vector type, found '%s'",
+ buf_ptr(&pred->value->type->name)));
+ return ira->codegen->invalid_inst_gen;
+ }
+
+ uint32_t pred_len = pred->value->type->data.vector.len;
+ pred = ir_implicit_cast(ira, pred, get_vector_type(ira->codegen, pred_len,
+ ira->codegen->builtin_types.entry_bool));
+ if (type_is_invalid(pred->value->type))
+ return ira->codegen->invalid_inst_gen;
+
+ if (a->value->type->id != ZigTypeIdVector) {
+ ir_add_error(ira, a,
+ buf_sprintf("expected vector type, found '%s'",
+ buf_ptr(&a->value->type->name)));
+ return ira->codegen->invalid_inst_gen;
+ }
+
+ if (b->value->type->id != ZigTypeIdVector) {
+ ir_add_error(ira, b,
+ buf_sprintf("expected vector type, found '%s'",
+ buf_ptr(&b->value->type->name)));
+ return ira->codegen->invalid_inst_gen;
+ }
+
+ ZigType *result_type = get_vector_type(ira->codegen, pred_len, scalar_type);
+
+ a = ir_implicit_cast(ira, a, result_type);
+ if (type_is_invalid(a->value->type))
+ return ira->codegen->invalid_inst_gen;
+
+ b = ir_implicit_cast(ira, b, result_type);
+ if (type_is_invalid(a->value->type))
+ return ira->codegen->invalid_inst_gen;
+
+ if (instr_is_comptime(pred) && instr_is_comptime(a) && instr_is_comptime(b)) {
+ ZigValue *pred_val = ir_resolve_const(ira, pred, UndefBad);
+ if (pred_val == nullptr)
+ return ira->codegen->invalid_inst_gen;
+
+ ZigValue *a_val = ir_resolve_const(ira, a, UndefBad);
+ if (a_val == nullptr)
+ return ira->codegen->invalid_inst_gen;
+
+ ZigValue *b_val = ir_resolve_const(ira, b, UndefBad);
+ if (b_val == nullptr)
+ return ira->codegen->invalid_inst_gen;
+
+ expand_undef_array(ira->codegen, a_val);
+ expand_undef_array(ira->codegen, b_val);
+
+ Stage1AirInst *result = ir_const(ira, instruction->base.scope, instruction->base.source_node, result_type);
+ result->value->data.x_array.data.s_none.elements = ira->codegen->pass1_arena->allocate(pred_len);
+
+ for (uint64_t i = 0; i < pred_len; i += 1) {
+ ZigValue *dst_elem_val = &result->value->data.x_array.data.s_none.elements[i];
+ ZigValue *pred_elem_val = &pred_val->data.x_array.data.s_none.elements[i];
+ ZigValue *a_elem_val = &a_val->data.x_array.data.s_none.elements[i];
+ ZigValue *b_elem_val = &b_val->data.x_array.data.s_none.elements[i];
+ ZigValue *result_elem_val = pred_elem_val->data.x_bool ? a_elem_val : b_elem_val;
+ copy_const_val(ira->codegen, dst_elem_val, result_elem_val);
+ }
+
+ result->value->special = ConstValSpecialStatic;
+ return result;
+ }
+
+ return ir_build_select_gen(ira, instruction->base.scope, instruction->base.source_node, result_type, pred, a, b);
+}
+
static Stage1AirInst *ir_analyze_instruction_splat(IrAnalyze *ira, Stage1ZirInstSplat *instruction) {
Error err;
@@ -24595,7 +24711,9 @@ static Stage1AirInst *ir_analyze_instruction_base(IrAnalyze *ira, Stage1ZirInst
return ir_analyze_instruction_vector_type(ira, (Stage1ZirInstVectorType *)instruction);
case Stage1ZirInstIdShuffleVector:
return ir_analyze_instruction_shuffle_vector(ira, (Stage1ZirInstShuffleVector *)instruction);
- case Stage1ZirInstIdSplat:
+ case Stage1ZirInstIdSelect:
+ return ir_analyze_instruction_select(ira, (Stage1ZirInstSelect *)instruction);
+ case Stage1ZirInstIdSplat:
return ir_analyze_instruction_splat(ira, (Stage1ZirInstSplat *)instruction);
case Stage1ZirInstIdBoolNot:
return ir_analyze_instruction_bool_not(ira, (Stage1ZirInstBoolNot *)instruction);
@@ -24931,6 +25049,7 @@ bool ir_inst_gen_has_side_effects(Stage1AirInst *instruction) {
case Stage1AirInstIdUnionTag:
case Stage1AirInstIdTruncate:
case Stage1AirInstIdShuffleVector:
+ case Stage1AirInstIdSelect:
case Stage1AirInstIdSplat:
case Stage1AirInstIdBoolNot:
case Stage1AirInstIdReturnAddress:
@@ -25084,6 +25203,7 @@ bool ir_inst_src_has_side_effects(Stage1ZirInst *instruction) {
case Stage1ZirInstIdTruncate:
case Stage1ZirInstIdVectorType:
case Stage1ZirInstIdShuffleVector:
+ case Stage1ZirInstIdSelect:
case Stage1ZirInstIdSplat:
case Stage1ZirInstIdBoolNot:
case Stage1ZirInstIdSlice:
@@ -25751,7 +25871,7 @@ static Error ir_resolve_lazy_recurse_array(AstNode *source_node, ZigValue *val,
static Error ir_resolve_lazy_recurse(AstNode *source_node, ZigValue *val) {
Error err;
- if ((err = ir_resolve_lazy_raw(source_node, val)))
+ if ((err = ir_resolve_lazy_raw(source_node, val)))
return err;
assert(val->special != ConstValSpecialRuntime);
assert(val->special != ConstValSpecialLazy);
diff --git a/src/stage1/ir_print.cpp b/src/stage1/ir_print.cpp
index bec9521bc2..e83da8565c 100644
--- a/src/stage1/ir_print.cpp
+++ b/src/stage1/ir_print.cpp
@@ -93,6 +93,8 @@ const char* ir_inst_src_type_str(Stage1ZirInstId id) {
return "SrcInvalid";
case Stage1ZirInstIdShuffleVector:
return "SrcShuffle";
+ case Stage1ZirInstIdSelect:
+ return "SrcSelect";
case Stage1ZirInstIdSplat:
return "SrcSplat";
case Stage1ZirInstIdDeclVar:
@@ -379,6 +381,8 @@ const char* ir_inst_gen_type_str(Stage1AirInstId id) {
return "GenInvalid";
case Stage1AirInstIdShuffleVector:
return "GenShuffle";
+ case Stage1AirInstIdSelect:
+ return "GenSelect";
case Stage1AirInstIdSplat:
return "GenSplat";
case Stage1AirInstIdDeclVar:
@@ -1722,6 +1726,28 @@ static void ir_print_shuffle_vector(IrPrintGen *irp, Stage1AirInstShuffleVector
fprintf(irp->f, ")");
}
+static void ir_print_select(IrPrintSrc *irp, Stage1ZirInstSelect *instruction) {
+ fprintf(irp->f, "@select(");
+ ir_print_other_inst_src(irp, instruction->scalar_type);
+ fprintf(irp->f, ", ");
+ ir_print_other_inst_src(irp, instruction->pred);
+ fprintf(irp->f, ", ");
+ ir_print_other_inst_src(irp, instruction->a);
+ fprintf(irp->f, ", ");
+ ir_print_other_inst_src(irp, instruction->b);
+ fprintf(irp->f, ")");
+}
+
+static void ir_print_select(IrPrintGen *irp, Stage1AirInstSelect *instruction) {
+ fprintf(irp->f, "@select(");
+ ir_print_other_inst_gen(irp, instruction->pred);
+ fprintf(irp->f, ", ");
+ ir_print_other_inst_gen(irp, instruction->a);
+ fprintf(irp->f, ", ");
+ ir_print_other_inst_gen(irp, instruction->b);
+ fprintf(irp->f, ")");
+}
+
static void ir_print_splat_src(IrPrintSrc *irp, Stage1ZirInstSplat *instruction) {
fprintf(irp->f, "@splat(");
ir_print_other_inst_src(irp, instruction->len);
@@ -2836,6 +2862,9 @@ static void ir_print_inst_src(IrPrintSrc *irp, Stage1ZirInst *instruction, bool
case Stage1ZirInstIdShuffleVector:
ir_print_shuffle_vector(irp, (Stage1ZirInstShuffleVector *)instruction);
break;
+ case Stage1ZirInstIdSelect:
+ ir_print_select(irp, (Stage1ZirInstSelect *)instruction);
+ break;
case Stage1ZirInstIdSplat:
ir_print_splat_src(irp, (Stage1ZirInstSplat *)instruction);
break;
@@ -3178,6 +3207,9 @@ static void ir_print_inst_gen(IrPrintGen *irp, Stage1AirInst *instruction, bool
case Stage1AirInstIdShuffleVector:
ir_print_shuffle_vector(irp, (Stage1AirInstShuffleVector *)instruction);
break;
+ case Stage1AirInstIdSelect:
+ ir_print_select(irp, (Stage1AirInstSelect *)instruction);
+ break;
case Stage1AirInstIdSplat:
ir_print_splat_gen(irp, (Stage1AirInstSplat *)instruction);
break;
diff --git a/test/behavior.zig b/test/behavior.zig
index 0055638335..8459e499d7 100644
--- a/test/behavior.zig
+++ b/test/behavior.zig
@@ -118,6 +118,7 @@ test {
_ = @import("behavior/ref_var_in_if_after_if_2nd_switch_prong.zig");
_ = @import("behavior/reflection.zig");
_ = @import("behavior/shuffle.zig");
+ _ = @import("behavior/select.zig");
_ = @import("behavior/sizeof_and_typeof.zig");
_ = @import("behavior/slice.zig");
_ = @import("behavior/slice_sentinel_comptime.zig");
diff --git a/test/behavior/select.zig b/test/behavior/select.zig
new file mode 100644
index 0000000000..5c69094413
--- /dev/null
+++ b/test/behavior/select.zig
@@ -0,0 +1,25 @@
+const std = @import("std");
+const builtin = @import("builtin");
+const mem = std.mem;
+const expect = std.testing.expect;
+const Vector = std.meta.Vector;
+
+test "@select" {
+ const S = struct {
+ fn doTheTest() !void {
+ var a: Vector(4, bool) = [4]bool{ true, false, true, false };
+ var b: Vector(4, i32) = [4]i32{ -1, 4, 999, -31 };
+ var c: Vector(4, i32) = [4]i32{ -5, 1, 0, 1234 };
+ var abc = @select(i32, a, b, c);
+ try expect(mem.eql(i32, &@as([4]i32, abc), &[4]i32{ -1, 1, 999, 1234 }));
+
+ var x: Vector(4, bool) = [4]bool{ false, false, false, true };
+ var y: Vector(4, f32) = [4]f32{ 0.001, 33.4, 836, -3381.233 };
+ var z: Vector(4, f32) = [4]f32{ 0.0, 312.1, -145.9, 9993.55 };
+ var xyz = @select(f32, x, y, z);
+ try expect(mem.eql(f32, &@as([4]f32, xyz), &[4]f32{ 0.0, 312.1, -145.9, -3381.233 }));
+ }
+ };
+ try S.doTheTest();
+ comptime try S.doTheTest();
+}