diff --git a/doc/langref.html.in b/doc/langref.html.in index 9e98bb0114..a805f9a2ac 100644 --- a/doc/langref.html.in +++ b/doc/langref.html.in @@ -7125,6 +7125,7 @@ fn func(y: *i32) void { an integer or an enum.

{#header_close#} + {#header_open|@bitCast#}
{#syntax#}@bitCast(comptime DestType: type, value: anytype) DestType{#endsyntax#}

@@ -8177,6 +8178,15 @@ test "@wasmMemoryGrow" { a calling function, the returned address will apply to the calling function.

{#header_close#} + + {#header_open|@select#} +
{#syntax#}@select(comptime T: type, pred: std.meta.Vector(len, bool), a: std.meta.Vector(len, T), b: std.meta.Vector(len, T)) std.meta.Vector(len, T){#endsyntax#}
+

+ Selects values element-wise from {#syntax#}a{#endsyntax#} or {#syntax#}b{#endsyntax#} based on {#syntax#}pred{#endsyntax#}. If {#syntax#}pred[i]{#endsyntax#} is {#syntax#}true{#endsyntax#}, the corresponding element in the result will be {#syntax#}a[i]{#endsyntax#} and otherwise {#syntax#}b[i]{#endsyntax#}. +

+ {#see_also|SIMD|Vectors#} + {#header_close#} + {#header_open|@setAlignStack#}
{#syntax#}@setAlignStack(comptime alignment: u29){#endsyntax#}

diff --git a/src/AstGen.zig b/src/AstGen.zig index a00b43288b..656b960c2f 100644 --- a/src/AstGen.zig +++ b/src/AstGen.zig @@ -2090,6 +2090,7 @@ fn unusedResultExpr(gz: *GenZir, scope: *Scope, statement: ast.Node.Index) Inner .splat, .reduce, .shuffle, + .select, .atomic_load, .atomic_rmw, .atomic_store, @@ -7375,6 +7376,15 @@ fn builtinCall( }); return rvalue(gz, rl, result, node); }, + .select => { + const result = try gz.addPlNode(.select, node, Zir.Inst.Select{ + .elem_type = try typeExpr(gz, scope, params[0]), + .pred = try expr(gz, scope, .none, params[1]), + .a = try expr(gz, scope, .none, params[2]), + .b = try expr(gz, scope, .none, params[3]), + }); + return rvalue(gz, rl, result, node); + }, .async_call => { const result = try gz.addPlNode(.builtin_async_call, node, Zir.Inst.AsyncCall{ .frame_buffer = try expr(gz, scope, .none, params[0]), diff --git a/src/BuiltinFn.zig b/src/BuiltinFn.zig index 3964512838..07371b3192 100644 --- a/src/BuiltinFn.zig +++ b/src/BuiltinFn.zig @@ -69,6 +69,7 @@ pub const Tag = enum { ptr_to_int, rem, return_address, + select, set_align_stack, set_cold, set_eval_branch_quota, @@ -601,6 +602,13 @@ pub const list = list: { .param_count = 0, }, }, + .{ + "@select", + .{ + .tag = .select, + .param_count = 4, + }, + }, .{ "@setAlignStack", .{ diff --git a/src/Sema.zig b/src/Sema.zig index 3df512596f..6f8975d086 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -338,6 +338,7 @@ pub fn analyzeBody( .splat => try sema.zirSplat(block, inst), .reduce => try sema.zirReduce(block, inst), .shuffle => try sema.zirShuffle(block, inst), + .select => try sema.zirSelect(block, inst), .atomic_load => try sema.zirAtomicLoad(block, inst), .atomic_rmw => try sema.zirAtomicRmw(block, inst), .atomic_store => try sema.zirAtomicStore(block, inst), @@ -6099,6 +6100,12 @@ fn zirShuffle(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) CompileErr return sema.mod.fail(&block.base, src, "TODO: Sema.zirShuffle", .{}); } +fn zirSelect(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { + const inst_data = sema.code.instructions.items(.data)[inst].pl_node; + const src = inst_data.src(); + return sema.mod.fail(&block.base, src, "TODO: Sema.zirSelect", .{}); +} + fn zirAtomicLoad(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { const inst_data = sema.code.instructions.items(.data)[inst].pl_node; const src = inst_data.src(); diff --git a/src/Zir.zig b/src/Zir.zig index 00f66ef507..003b43d9e0 100644 --- a/src/Zir.zig +++ b/src/Zir.zig @@ -890,6 +890,9 @@ pub const Inst = struct { /// Implements the `@shuffle` builtin. /// Uses the `pl_node` union field with payload `Shuffle`. shuffle, + /// Implements the `@select` builtin. + /// Uses the `pl_node` union field with payload `Select`. + select, /// Implements the `@atomicLoad` builtin. /// Uses the `pl_node` union field with payload `Bin`. atomic_load, @@ -1181,6 +1184,7 @@ pub const Inst = struct { .splat, .reduce, .shuffle, + .select, .atomic_load, .atomic_rmw, .atomic_store, @@ -1451,6 +1455,7 @@ pub const Inst = struct { .splat = .pl_node, .reduce = .pl_node, .shuffle = .pl_node, + .select = .pl_node, .atomic_load = .pl_node, .atomic_rmw = .pl_node, .atomic_store = .pl_node, @@ -2725,6 +2730,13 @@ pub const Inst = struct { mask: Ref, }; + pub const Select = struct { + elem_type: Ref, + pred: Ref, + a: Ref, + b: Ref, + }; + pub const AsyncCall = struct { frame_buffer: Ref, result_ptr: Ref, @@ -2935,6 +2947,7 @@ const Writer = struct { .cmpxchg_strong, .cmpxchg_weak, .shuffle, + .select, .atomic_rmw, .atomic_store, .mul_add, diff --git a/src/stage1/all_types.hpp b/src/stage1/all_types.hpp index ee06c15cf0..b4b2740dec 100644 --- a/src/stage1/all_types.hpp +++ b/src/stage1/all_types.hpp @@ -1755,6 +1755,7 @@ enum BuiltinFnId { BuiltinFnIdIntToEnum, BuiltinFnIdVectorType, BuiltinFnIdShuffle, + BuiltinFnIdSelect, BuiltinFnIdSplat, BuiltinFnIdSetCold, BuiltinFnIdSetRuntimeSafety, @@ -2544,6 +2545,7 @@ enum Stage1ZirInstId : uint8_t { Stage1ZirInstIdBoolToInt, Stage1ZirInstIdVectorType, Stage1ZirInstIdShuffleVector, + Stage1ZirInstIdSelect, Stage1ZirInstIdSplat, Stage1ZirInstIdBoolNot, Stage1ZirInstIdMemset, @@ -2664,6 +2666,7 @@ enum Stage1AirInstId : uint8_t { Stage1AirInstIdReduce, Stage1AirInstIdTruncate, Stage1AirInstIdShuffleVector, + Stage1AirInstIdSelect, Stage1AirInstIdSplat, Stage1AirInstIdBoolNot, Stage1AirInstIdMemset, @@ -4295,6 +4298,23 @@ struct Stage1AirInstShuffleVector { Stage1AirInst *mask; // This is in zig-format, not llvm format }; +struct Stage1ZirInstSelect { + Stage1ZirInst base; + + Stage1ZirInst *scalar_type; + Stage1ZirInst *pred; // This is in zig-format, not llvm format + Stage1ZirInst *a; + Stage1ZirInst *b; +}; + +struct Stage1AirInstSelect { + Stage1AirInst base; + + Stage1AirInst *pred; // This is in zig-format, not llvm format + Stage1AirInst *a; + Stage1AirInst *b; +}; + struct Stage1ZirInstSplat { Stage1ZirInst base; diff --git a/src/stage1/astgen.cpp b/src/stage1/astgen.cpp index cc8ccefad7..e8f68c82bb 100644 --- a/src/stage1/astgen.cpp +++ b/src/stage1/astgen.cpp @@ -196,6 +196,8 @@ void destroy_instruction_src(Stage1ZirInst *inst) { return heap::c_allocator.destroy(reinterpret_cast(inst)); case Stage1ZirInstIdShuffleVector: return heap::c_allocator.destroy(reinterpret_cast(inst)); + case Stage1ZirInstIdSelect: + return heap::c_allocator.destroy(reinterpret_cast(inst)); case Stage1ZirInstIdSplat: return heap::c_allocator.destroy(reinterpret_cast(inst)); case Stage1ZirInstIdBoolNot: @@ -651,6 +653,10 @@ static constexpr Stage1ZirInstId ir_inst_id(Stage1ZirInstShuffleVector *) { return Stage1ZirInstIdShuffleVector; } +static constexpr Stage1ZirInstId ir_inst_id(Stage1ZirInstSelect *) { + return Stage1ZirInstIdSelect; +} + static constexpr Stage1ZirInstId ir_inst_id(Stage1ZirInstSplat *) { return Stage1ZirInstIdSplat; } @@ -2037,6 +2043,22 @@ static Stage1ZirInst *ir_build_shuffle_vector(Stage1AstGen *ag, Scope *scope, As return &instruction->base; } +static Stage1ZirInst *ir_build_select(Stage1AstGen *ag, Scope *scope, AstNode *source_node, + Stage1ZirInst *scalar_type, Stage1ZirInst *pred, Stage1ZirInst *a, Stage1ZirInst *b) +{ + Stage1ZirInstSelect *instruction = ir_build_instruction(ag, scope, source_node); + instruction->scalar_type = scalar_type; + instruction->pred = pred; + instruction->a = a; + instruction->b = b; + + ir_ref_instruction(pred, ag->current_basic_block); + ir_ref_instruction(a, ag->current_basic_block); + ir_ref_instruction(b, ag->current_basic_block); + + return &instruction->base; +} + static Stage1ZirInst *ir_build_splat_src(Stage1AstGen *ag, Scope *scope, AstNode *source_node, Stage1ZirInst *len, Stage1ZirInst *scalar) { @@ -4619,6 +4641,35 @@ static Stage1ZirInst *astgen_builtin_fn_call(Stage1AstGen *ag, Scope *scope, Ast arg0_value, arg1_value, arg2_value, arg3_value); return ir_lval_wrap(ag, scope, shuffle_vector, lval, result_loc); } + case BuiltinFnIdSelect: + { + // Used for the type expr + Scope *comptime_scope = create_comptime_scope(ag->codegen, node, scope); + + AstNode *arg0_node = node->data.fn_call_expr.params.at(0); + Stage1ZirInst *arg0_value = astgen_node(ag, arg0_node, comptime_scope); + if (arg0_value == ag->codegen->invalid_inst_src) + return arg0_value; + + AstNode *arg1_node = node->data.fn_call_expr.params.at(1); + Stage1ZirInst *arg1_value = astgen_node(ag, arg1_node, scope); + if (arg0_value == ag->codegen->invalid_inst_src) + return arg1_value; + + AstNode *arg2_node = node->data.fn_call_expr.params.at(2); + Stage1ZirInst *arg2_value = astgen_node(ag, arg2_node, scope); + if (arg1_value == ag->codegen->invalid_inst_src) + return arg2_value; + + AstNode *arg3_node = node->data.fn_call_expr.params.at(3); + Stage1ZirInst *arg3_value = astgen_node(ag, arg3_node, scope); + if (arg2_value == ag->codegen->invalid_inst_src) + return arg3_value; + + Stage1ZirInst *select = ir_build_select(ag, scope, node, + arg0_value, arg1_value, arg2_value, arg3_value); + return ir_lval_wrap(ag, scope, select, lval, result_loc); + } case BuiltinFnIdSplat: { AstNode *arg0_node = node->data.fn_call_expr.params.at(0); diff --git a/src/stage1/codegen.cpp b/src/stage1/codegen.cpp index 562327d500..fc2651c8f7 100644 --- a/src/stage1/codegen.cpp +++ b/src/stage1/codegen.cpp @@ -5162,6 +5162,13 @@ static LLVMValueRef ir_render_shuffle_vector(CodeGen *g, Stage1Air *executable, llvm_mask_value, ""); } +static LLVMValueRef ir_render_select(CodeGen *g, Stage1Air *executable, Stage1AirInstSelect *instruction) { + LLVMValueRef pred = ir_llvm_value(g, instruction->pred); + LLVMValueRef a = ir_llvm_value(g, instruction->a); + LLVMValueRef b = ir_llvm_value(g, instruction->b); + return LLVMBuildSelect(g->builder, pred, a, b, ""); +} + static LLVMValueRef ir_render_splat(CodeGen *g, Stage1Air *executable, Stage1AirInstSplat *instruction) { ZigType *result_type = instruction->base.value->type; ir_assert(result_type->id == ZigTypeIdVector, &instruction->base); @@ -7015,6 +7022,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, Stage1Air *executable, Sta return ir_render_spill_end(g, executable, (Stage1AirInstSpillEnd *)instruction); case Stage1AirInstIdShuffleVector: return ir_render_shuffle_vector(g, executable, (Stage1AirInstShuffleVector *) instruction); + case Stage1AirInstIdSelect: + return ir_render_select(g, executable, (Stage1AirInstSelect *) instruction); case Stage1AirInstIdSplat: return ir_render_splat(g, executable, (Stage1AirInstSplat *) instruction); case Stage1AirInstIdVectorExtractElem: @@ -8920,6 +8929,7 @@ static void define_builtin_fns(CodeGen *g) { create_builtin_fn(g, BuiltinFnIdCompileLog, "compileLog", SIZE_MAX); create_builtin_fn(g, BuiltinFnIdVectorType, "Vector", 2); create_builtin_fn(g, BuiltinFnIdShuffle, "shuffle", 4); + create_builtin_fn(g, BuiltinFnIdSelect, "select", 4); create_builtin_fn(g, BuiltinFnIdSplat, "splat", 2); create_builtin_fn(g, BuiltinFnIdSetCold, "setCold", 1); create_builtin_fn(g, BuiltinFnIdSetRuntimeSafety, "setRuntimeSafety", 1); diff --git a/src/stage1/ir.cpp b/src/stage1/ir.cpp index a6d1ff3007..9a2e50fef7 100644 --- a/src/stage1/ir.cpp +++ b/src/stage1/ir.cpp @@ -355,6 +355,8 @@ void destroy_instruction_gen(Stage1AirInst *inst) { return heap::c_allocator.destroy(reinterpret_cast(inst)); case Stage1AirInstIdShuffleVector: return heap::c_allocator.destroy(reinterpret_cast(inst)); + case Stage1AirInstIdSelect: + return heap::c_allocator.destroy(reinterpret_cast(inst)); case Stage1AirInstIdSplat: return heap::c_allocator.destroy(reinterpret_cast(inst)); case Stage1AirInstIdBoolNot: @@ -901,6 +903,10 @@ static constexpr Stage1AirInstId ir_inst_id(Stage1AirInstShuffleVector *) { return Stage1AirInstIdShuffleVector; } +static constexpr Stage1AirInstId ir_inst_id(Stage1AirInstSelect *) { + return Stage1AirInstIdSelect; +} + static constexpr Stage1AirInstId ir_inst_id(Stage1AirInstSplat *) { return Stage1AirInstIdSplat; } @@ -1756,6 +1762,22 @@ static Stage1AirInst *ir_build_shuffle_vector_gen(IrAnalyze *ira, Scope *scope, return &inst->base; } +static Stage1AirInst *ir_build_select_gen(IrAnalyze *ira, Scope *scope, AstNode *source_node, + ZigType *result_type, Stage1AirInst *pred, Stage1AirInst *a, Stage1AirInst *b) +{ + Stage1AirInstSelect *inst = ir_build_inst_gen(&ira->new_irb, scope, source_node); + inst->base.value->type = result_type; + inst->pred = pred; + inst->a = a; + inst->b = b; + + ir_ref_inst_gen(pred); + ir_ref_inst_gen(a); + ir_ref_inst_gen(b); + + return &inst->base; +} + static Stage1AirInst *ir_build_splat_gen(IrAnalyze *ira, Scope *scope, AstNode *source_node, ZigType *result_type, Stage1AirInst *scalar) { @@ -20318,6 +20340,100 @@ static Stage1AirInst *ir_analyze_instruction_shuffle_vector(IrAnalyze *ira, Stag return ir_analyze_shuffle_vector(ira, instruction->base.scope, instruction->base.source_node, scalar_type, a, b, mask); } +static Stage1AirInst *ir_analyze_instruction_select(IrAnalyze *ira, Stage1ZirInstSelect *instruction) { + Error err; + + ZigType *scalar_type = ir_resolve_vector_elem_type(ira, instruction->scalar_type->child); + if (type_is_invalid(scalar_type)) + return ira->codegen->invalid_inst_gen; + + if ((err = ir_validate_vector_elem_type(ira, instruction->base.source_node, scalar_type))) + return ira->codegen->invalid_inst_gen; + + Stage1AirInst *pred = instruction->pred->child; + if (type_is_invalid(pred->value->type)) + return ira->codegen->invalid_inst_gen; + + Stage1AirInst *a = instruction->a->child; + if (type_is_invalid(a->value->type)) + return ira->codegen->invalid_inst_gen; + + Stage1AirInst *b = instruction->b->child; + if (type_is_invalid(b->value->type)) + return ira->codegen->invalid_inst_gen; + + if (pred->value->type->id != ZigTypeIdVector) { + ir_add_error(ira, pred, + buf_sprintf("expected vector type, found '%s'", + buf_ptr(&pred->value->type->name))); + return ira->codegen->invalid_inst_gen; + } + + uint32_t pred_len = pred->value->type->data.vector.len; + pred = ir_implicit_cast(ira, pred, get_vector_type(ira->codegen, pred_len, + ira->codegen->builtin_types.entry_bool)); + if (type_is_invalid(pred->value->type)) + return ira->codegen->invalid_inst_gen; + + if (a->value->type->id != ZigTypeIdVector) { + ir_add_error(ira, a, + buf_sprintf("expected vector type, found '%s'", + buf_ptr(&a->value->type->name))); + return ira->codegen->invalid_inst_gen; + } + + if (b->value->type->id != ZigTypeIdVector) { + ir_add_error(ira, b, + buf_sprintf("expected vector type, found '%s'", + buf_ptr(&b->value->type->name))); + return ira->codegen->invalid_inst_gen; + } + + ZigType *result_type = get_vector_type(ira->codegen, pred_len, scalar_type); + + a = ir_implicit_cast(ira, a, result_type); + if (type_is_invalid(a->value->type)) + return ira->codegen->invalid_inst_gen; + + b = ir_implicit_cast(ira, b, result_type); + if (type_is_invalid(a->value->type)) + return ira->codegen->invalid_inst_gen; + + if (instr_is_comptime(pred) && instr_is_comptime(a) && instr_is_comptime(b)) { + ZigValue *pred_val = ir_resolve_const(ira, pred, UndefBad); + if (pred_val == nullptr) + return ira->codegen->invalid_inst_gen; + + ZigValue *a_val = ir_resolve_const(ira, a, UndefBad); + if (a_val == nullptr) + return ira->codegen->invalid_inst_gen; + + ZigValue *b_val = ir_resolve_const(ira, b, UndefBad); + if (b_val == nullptr) + return ira->codegen->invalid_inst_gen; + + expand_undef_array(ira->codegen, a_val); + expand_undef_array(ira->codegen, b_val); + + Stage1AirInst *result = ir_const(ira, instruction->base.scope, instruction->base.source_node, result_type); + result->value->data.x_array.data.s_none.elements = ira->codegen->pass1_arena->allocate(pred_len); + + for (uint64_t i = 0; i < pred_len; i += 1) { + ZigValue *dst_elem_val = &result->value->data.x_array.data.s_none.elements[i]; + ZigValue *pred_elem_val = &pred_val->data.x_array.data.s_none.elements[i]; + ZigValue *a_elem_val = &a_val->data.x_array.data.s_none.elements[i]; + ZigValue *b_elem_val = &b_val->data.x_array.data.s_none.elements[i]; + ZigValue *result_elem_val = pred_elem_val->data.x_bool ? a_elem_val : b_elem_val; + copy_const_val(ira->codegen, dst_elem_val, result_elem_val); + } + + result->value->special = ConstValSpecialStatic; + return result; + } + + return ir_build_select_gen(ira, instruction->base.scope, instruction->base.source_node, result_type, pred, a, b); +} + static Stage1AirInst *ir_analyze_instruction_splat(IrAnalyze *ira, Stage1ZirInstSplat *instruction) { Error err; @@ -24595,7 +24711,9 @@ static Stage1AirInst *ir_analyze_instruction_base(IrAnalyze *ira, Stage1ZirInst return ir_analyze_instruction_vector_type(ira, (Stage1ZirInstVectorType *)instruction); case Stage1ZirInstIdShuffleVector: return ir_analyze_instruction_shuffle_vector(ira, (Stage1ZirInstShuffleVector *)instruction); - case Stage1ZirInstIdSplat: + case Stage1ZirInstIdSelect: + return ir_analyze_instruction_select(ira, (Stage1ZirInstSelect *)instruction); + case Stage1ZirInstIdSplat: return ir_analyze_instruction_splat(ira, (Stage1ZirInstSplat *)instruction); case Stage1ZirInstIdBoolNot: return ir_analyze_instruction_bool_not(ira, (Stage1ZirInstBoolNot *)instruction); @@ -24931,6 +25049,7 @@ bool ir_inst_gen_has_side_effects(Stage1AirInst *instruction) { case Stage1AirInstIdUnionTag: case Stage1AirInstIdTruncate: case Stage1AirInstIdShuffleVector: + case Stage1AirInstIdSelect: case Stage1AirInstIdSplat: case Stage1AirInstIdBoolNot: case Stage1AirInstIdReturnAddress: @@ -25084,6 +25203,7 @@ bool ir_inst_src_has_side_effects(Stage1ZirInst *instruction) { case Stage1ZirInstIdTruncate: case Stage1ZirInstIdVectorType: case Stage1ZirInstIdShuffleVector: + case Stage1ZirInstIdSelect: case Stage1ZirInstIdSplat: case Stage1ZirInstIdBoolNot: case Stage1ZirInstIdSlice: @@ -25751,7 +25871,7 @@ static Error ir_resolve_lazy_recurse_array(AstNode *source_node, ZigValue *val, static Error ir_resolve_lazy_recurse(AstNode *source_node, ZigValue *val) { Error err; - if ((err = ir_resolve_lazy_raw(source_node, val))) + if ((err = ir_resolve_lazy_raw(source_node, val))) return err; assert(val->special != ConstValSpecialRuntime); assert(val->special != ConstValSpecialLazy); diff --git a/src/stage1/ir_print.cpp b/src/stage1/ir_print.cpp index bec9521bc2..e83da8565c 100644 --- a/src/stage1/ir_print.cpp +++ b/src/stage1/ir_print.cpp @@ -93,6 +93,8 @@ const char* ir_inst_src_type_str(Stage1ZirInstId id) { return "SrcInvalid"; case Stage1ZirInstIdShuffleVector: return "SrcShuffle"; + case Stage1ZirInstIdSelect: + return "SrcSelect"; case Stage1ZirInstIdSplat: return "SrcSplat"; case Stage1ZirInstIdDeclVar: @@ -379,6 +381,8 @@ const char* ir_inst_gen_type_str(Stage1AirInstId id) { return "GenInvalid"; case Stage1AirInstIdShuffleVector: return "GenShuffle"; + case Stage1AirInstIdSelect: + return "GenSelect"; case Stage1AirInstIdSplat: return "GenSplat"; case Stage1AirInstIdDeclVar: @@ -1722,6 +1726,28 @@ static void ir_print_shuffle_vector(IrPrintGen *irp, Stage1AirInstShuffleVector fprintf(irp->f, ")"); } +static void ir_print_select(IrPrintSrc *irp, Stage1ZirInstSelect *instruction) { + fprintf(irp->f, "@select("); + ir_print_other_inst_src(irp, instruction->scalar_type); + fprintf(irp->f, ", "); + ir_print_other_inst_src(irp, instruction->pred); + fprintf(irp->f, ", "); + ir_print_other_inst_src(irp, instruction->a); + fprintf(irp->f, ", "); + ir_print_other_inst_src(irp, instruction->b); + fprintf(irp->f, ")"); +} + +static void ir_print_select(IrPrintGen *irp, Stage1AirInstSelect *instruction) { + fprintf(irp->f, "@select("); + ir_print_other_inst_gen(irp, instruction->pred); + fprintf(irp->f, ", "); + ir_print_other_inst_gen(irp, instruction->a); + fprintf(irp->f, ", "); + ir_print_other_inst_gen(irp, instruction->b); + fprintf(irp->f, ")"); +} + static void ir_print_splat_src(IrPrintSrc *irp, Stage1ZirInstSplat *instruction) { fprintf(irp->f, "@splat("); ir_print_other_inst_src(irp, instruction->len); @@ -2836,6 +2862,9 @@ static void ir_print_inst_src(IrPrintSrc *irp, Stage1ZirInst *instruction, bool case Stage1ZirInstIdShuffleVector: ir_print_shuffle_vector(irp, (Stage1ZirInstShuffleVector *)instruction); break; + case Stage1ZirInstIdSelect: + ir_print_select(irp, (Stage1ZirInstSelect *)instruction); + break; case Stage1ZirInstIdSplat: ir_print_splat_src(irp, (Stage1ZirInstSplat *)instruction); break; @@ -3178,6 +3207,9 @@ static void ir_print_inst_gen(IrPrintGen *irp, Stage1AirInst *instruction, bool case Stage1AirInstIdShuffleVector: ir_print_shuffle_vector(irp, (Stage1AirInstShuffleVector *)instruction); break; + case Stage1AirInstIdSelect: + ir_print_select(irp, (Stage1AirInstSelect *)instruction); + break; case Stage1AirInstIdSplat: ir_print_splat_gen(irp, (Stage1AirInstSplat *)instruction); break; diff --git a/test/behavior.zig b/test/behavior.zig index 0055638335..8459e499d7 100644 --- a/test/behavior.zig +++ b/test/behavior.zig @@ -118,6 +118,7 @@ test { _ = @import("behavior/ref_var_in_if_after_if_2nd_switch_prong.zig"); _ = @import("behavior/reflection.zig"); _ = @import("behavior/shuffle.zig"); + _ = @import("behavior/select.zig"); _ = @import("behavior/sizeof_and_typeof.zig"); _ = @import("behavior/slice.zig"); _ = @import("behavior/slice_sentinel_comptime.zig"); diff --git a/test/behavior/select.zig b/test/behavior/select.zig new file mode 100644 index 0000000000..5c69094413 --- /dev/null +++ b/test/behavior/select.zig @@ -0,0 +1,25 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const mem = std.mem; +const expect = std.testing.expect; +const Vector = std.meta.Vector; + +test "@select" { + const S = struct { + fn doTheTest() !void { + var a: Vector(4, bool) = [4]bool{ true, false, true, false }; + var b: Vector(4, i32) = [4]i32{ -1, 4, 999, -31 }; + var c: Vector(4, i32) = [4]i32{ -5, 1, 0, 1234 }; + var abc = @select(i32, a, b, c); + try expect(mem.eql(i32, &@as([4]i32, abc), &[4]i32{ -1, 1, 999, 1234 })); + + var x: Vector(4, bool) = [4]bool{ false, false, false, true }; + var y: Vector(4, f32) = [4]f32{ 0.001, 33.4, 836, -3381.233 }; + var z: Vector(4, f32) = [4]f32{ 0.0, 312.1, -145.9, 9993.55 }; + var xyz = @select(f32, x, y, z); + try expect(mem.eql(f32, &@as([4]f32, xyz), &[4]f32{ 0.0, 312.1, -145.9, -3381.233 })); + } + }; + try S.doTheTest(); + comptime try S.doTheTest(); +}