diff --git a/src/all_types.hpp b/src/all_types.hpp index 7887c06158..464a1d6ba4 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1612,6 +1612,7 @@ enum BuiltinFnId { BuiltinFnIdIntType, BuiltinFnIdVectorType, BuiltinFnIdShuffle, + BuiltinFnIdSplat, BuiltinFnIdSetCold, BuiltinFnIdSetRuntimeSafety, BuiltinFnIdSetFloatMode, @@ -2431,6 +2432,7 @@ enum IrInstructionId { IrInstructionIdIntType, IrInstructionIdVectorType, IrInstructionIdShuffleVector, + IrInstructionIdSplat, IrInstructionIdBoolNot, IrInstructionIdMemset, IrInstructionIdMemcpy, @@ -3681,6 +3683,13 @@ struct IrInstructionShuffleVector { IrInstruction *mask; // This is in zig-format, not llvm format }; +struct IrInstructionSplat { + IrInstruction base; + + IrInstruction *len; + IrInstruction *scalar; +}; + struct IrInstructionAssertZero { IrInstruction base; diff --git a/src/codegen.cpp b/src/codegen.cpp index 54c02b288a..49681c20c1 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -4619,6 +4619,20 @@ static LLVMValueRef ir_render_shuffle_vector(CodeGen *g, IrExecutable *executabl llvm_mask_value, ""); } +static LLVMValueRef ir_render_splat(CodeGen *g, IrExecutable *executable, IrInstructionSplat *instruction) { + uint64_t len = bigint_as_u64(&instruction->len->value.data.x_bigint); + LLVMValueRef wrapped_scalar_undef = LLVMGetUndef(instruction->base.value.type->llvm_type); + LLVMValueRef wrapped_scalar = LLVMBuildInsertElement(g->builder, wrapped_scalar_undef, + ir_llvm_value(g, instruction->scalar), + LLVMConstInt(LLVMInt32Type(), 0, false), + ""); + return LLVMBuildShuffleVector(g->builder, + wrapped_scalar, + wrapped_scalar_undef, + LLVMConstNull(LLVMVectorType(g->builtin_types.entry_u32->llvm_type, (uint32_t)len)), + ""); +} + static LLVMValueRef ir_render_pop_count(CodeGen *g, IrExecutable *executable, IrInstructionPopCount *instruction) { ZigType *int_type = instruction->op->value.type; LLVMValueRef fn_val = get_int_builtin_fn(g, int_type, BuiltinFnIdPopCount); @@ -6146,6 +6160,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, return ir_render_spill_end(g, executable, (IrInstructionSpillEnd *)instruction); case IrInstructionIdShuffleVector: return ir_render_shuffle_vector(g, executable, (IrInstructionShuffleVector *) instruction); + case IrInstructionIdSplat: + return ir_render_splat(g, executable, (IrInstructionSplat *) instruction); } zig_unreachable(); } @@ -7837,6 +7853,7 @@ static void define_builtin_fns(CodeGen *g) { create_builtin_fn(g, BuiltinFnIdIntType, "IntType", 2); // TODO rename to Int create_builtin_fn(g, BuiltinFnIdVectorType, "Vector", 2); create_builtin_fn(g, BuiltinFnIdShuffle, "shuffle", 4); + create_builtin_fn(g, BuiltinFnIdSplat, "splat", 2); create_builtin_fn(g, BuiltinFnIdSetCold, "setCold", 1); create_builtin_fn(g, BuiltinFnIdSetRuntimeSafety, "setRuntimeSafety", 1); create_builtin_fn(g, BuiltinFnIdSetFloatMode, "setFloatMode", 1); diff --git a/src/ir.cpp b/src/ir.cpp index 1eba53ef45..8fca50c6f7 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -721,6 +721,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionShuffleVector *) return IrInstructionIdShuffleVector; } +static constexpr IrInstructionId ir_instruction_id(IrInstructionSplat *) { + return IrInstructionIdSplat; +} + static constexpr IrInstructionId ir_instruction_id(IrInstructionBoolNot *) { return IrInstructionIdBoolNot; } @@ -2300,6 +2304,19 @@ static IrInstruction *ir_build_shuffle_vector(IrBuilder *irb, Scope *scope, AstN return &instruction->base; } +static IrInstruction *ir_build_splat(IrBuilder *irb, Scope *scope, AstNode *source_node, + IrInstruction *len, IrInstruction *scalar) +{ + IrInstructionSplat *instruction = ir_build_instruction(irb, scope, source_node); + instruction->len = len; + instruction->scalar = scalar; + + ir_ref_instruction(len, irb->current_basic_block); + ir_ref_instruction(scalar, irb->current_basic_block); + + return &instruction->base; +} + static IrInstruction *ir_build_bool_not(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *value) { IrInstructionBoolNot *instruction = ir_build_instruction(irb, scope, source_node); instruction->value = value; @@ -4985,6 +5002,22 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo arg0_value, arg1_value, arg2_value, arg3_value); return ir_lval_wrap(irb, scope, shuffle_vector, lval, result_loc); } + case BuiltinFnIdSplat: + { + AstNode *arg0_node = node->data.fn_call_expr.params.at(0); + IrInstruction *arg0_value = ir_gen_node(irb, arg0_node, scope); + if (arg0_value == irb->codegen->invalid_instruction) + return arg0_value; + + AstNode *arg1_node = node->data.fn_call_expr.params.at(1); + IrInstruction *arg1_value = ir_gen_node(irb, arg1_node, scope); + if (arg1_value == irb->codegen->invalid_instruction) + return arg1_value; + + IrInstruction *splat = ir_build_splat(irb, scope, node, + arg0_value, arg1_value); + return ir_lval_wrap(irb, scope, splat, lval, result_loc); + } case BuiltinFnIdMemcpy: { AstNode *arg0_node = node->data.fn_call_expr.params.at(0); @@ -22324,6 +22357,52 @@ static IrInstruction *ir_analyze_instruction_shuffle_vector(IrAnalyze *ira, IrIn return ir_analyze_shuffle_vector(ira, &instruction->base, scalar_type, a, b, mask); } +static IrInstruction *ir_analyze_instruction_splat(IrAnalyze *ira, IrInstructionSplat *instruction) { + IrInstruction *len = instruction->len->child; + if (type_is_invalid(len->value.type)) + return ira->codegen->invalid_instruction; + + IrInstruction *scalar = instruction->scalar->child; + if (type_is_invalid(scalar->value.type)) + return ira->codegen->invalid_instruction; + + uint64_t len_int; + if (!ir_resolve_unsigned(ira, len, ira->codegen->builtin_types.entry_u32, &len_int)) { + ir_add_error(ira, len, + buf_sprintf("splat length must be comptime")); + return ira->codegen->invalid_instruction; + } + + if (!is_valid_vector_elem_type(scalar->value.type)) { + ir_add_error(ira, len, + buf_sprintf("vector element type must be integer, float, bool, or pointer; '%s' is invalid", + buf_ptr(&scalar->value.type->name))); + return ira->codegen->invalid_instruction; + } + + ZigType *return_type = get_vector_type(ira->codegen, len_int, scalar->value.type); + + if (instr_is_comptime(scalar)) { + IrInstruction *result = ir_const_undef(ira, scalar, return_type); + result->value.data.x_array.data.s_none.elements = + allocate(len_int); + for (uint32_t i = 0; i < len_int; i++) { + result->value.data.x_array.data.s_none.elements[i] = + scalar->value; + } + result->value.type = return_type; + result->value.special = ConstValSpecialStatic; + return result; + } + + IrInstruction *result = ir_build_splat(&ira->new_irb, + instruction->base.scope, instruction->base.source_node, + instruction->len->child, instruction->scalar->child); + result->value.type = return_type; + result->value.special = ConstValSpecialRuntime; + return result; +} + static IrInstruction *ir_analyze_instruction_bool_not(IrAnalyze *ira, IrInstructionBoolNot *instruction) { IrInstruction *value = instruction->value->child; if (type_is_invalid(value->value.type)) @@ -25908,6 +25987,8 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction return ir_analyze_instruction_vector_type(ira, (IrInstructionVectorType *)instruction); case IrInstructionIdShuffleVector: return ir_analyze_instruction_shuffle_vector(ira, (IrInstructionShuffleVector *)instruction); + case IrInstructionIdSplat: + return ir_analyze_instruction_splat(ira, (IrInstructionSplat *)instruction); case IrInstructionIdBoolNot: return ir_analyze_instruction_bool_not(ira, (IrInstructionBoolNot *)instruction); case IrInstructionIdMemset: @@ -26244,6 +26325,7 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdIntType: case IrInstructionIdVectorType: case IrInstructionIdShuffleVector: + case IrInstructionIdSplat: case IrInstructionIdBoolNot: case IrInstructionIdSliceSrc: case IrInstructionIdMemberCount: diff --git a/src/ir_print.cpp b/src/ir_print.cpp index 8561ed4508..0dee7d342a 100644 --- a/src/ir_print.cpp +++ b/src/ir_print.cpp @@ -44,6 +44,8 @@ static const char* ir_instruction_type_str(IrInstruction* instruction) { return "Invalid"; case IrInstructionIdShuffleVector: return "Shuffle"; + case IrInstructionIdSplat: + return "Splat"; case IrInstructionIdDeclVarSrc: return "DeclVarSrc"; case IrInstructionIdDeclVarGen: @@ -1222,6 +1224,14 @@ static void ir_print_shuffle_vector(IrPrint *irp, IrInstructionShuffleVector *in fprintf(irp->f, ")"); } +static void ir_print_splat(IrPrint *irp, IrInstructionSplat *instruction) { + fprintf(irp->f, "@splat("); + ir_print_other_instruction(irp, instruction->len); + fprintf(irp->f, ", "); + ir_print_other_instruction(irp, instruction->scalar); + fprintf(irp->f, ")"); +} + static void ir_print_bool_not(IrPrint *irp, IrInstructionBoolNot *instruction) { fprintf(irp->f, "! "); ir_print_other_instruction(irp, instruction->value); @@ -2160,6 +2170,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction, bool case IrInstructionIdShuffleVector: ir_print_shuffle_vector(irp, (IrInstructionShuffleVector *)instruction); break; + case IrInstructionIdSplat: + ir_print_splat(irp, (IrInstructionSplat *)instruction); + break; case IrInstructionIdBoolNot: ir_print_bool_not(irp, (IrInstructionBoolNot *)instruction); break; diff --git a/test/compile_errors.zig b/test/compile_errors.zig index 1fe3fc58ab..2909bffc3b 100644 --- a/test/compile_errors.zig +++ b/test/compile_errors.zig @@ -6507,6 +6507,16 @@ pub fn addCases(cases: *tests.CompileErrorContext) void { "tmp.zig:2:26: error: vector element type must be integer, float, bool, or pointer; '@Vector(4, u8)' is invalid", ); + cases.addTest( + "bad @splat type", + \\export fn entry() void { + \\ const c = 4; + \\ var v = @splat(4, c); + \\} + , + "tmp.zig:3:20: error: vector element type must be integer, float, bool, or pointer; 'comptime_int' is invalid", + ); + cases.add("compileLog of tagged enum doesn't crash the compiler", \\const Bar = union(enum(u32)) { \\ X: i32 = 1 diff --git a/test/stage1/behavior/vector.zig b/test/stage1/behavior/vector.zig index 27277b5e52..88a332d87b 100644 --- a/test/stage1/behavior/vector.zig +++ b/test/stage1/behavior/vector.zig @@ -35,12 +35,12 @@ test "vector bin compares with mem.eql" { fn doTheTest() void { var v: @Vector(4, i32) = [4]i32{ 2147483647, -2, 30, 40 }; var x: @Vector(4, i32) = [4]i32{ 1, 2147483647, 30, 4 }; - expect(mem.eql(bool, ([4]bool)(v == x), [4]bool{ false, false, true, false})); - expect(mem.eql(bool, ([4]bool)(v != x), [4]bool{ true, true, false, true})); - expect(mem.eql(bool, ([4]bool)(v < x), [4]bool{ false, true, false, false})); - expect(mem.eql(bool, ([4]bool)(v > x), [4]bool{ true, false, false, true})); - expect(mem.eql(bool, ([4]bool)(v <= x), [4]bool{ false, true, true, false})); - expect(mem.eql(bool, ([4]bool)(v >= x), [4]bool{ true, false, true, true})); + expect(mem.eql(bool, ([4]bool)(v == x), [4]bool{ false, false, true, false })); + expect(mem.eql(bool, ([4]bool)(v != x), [4]bool{ true, true, false, true })); + expect(mem.eql(bool, ([4]bool)(v < x), [4]bool{ false, true, false, false })); + expect(mem.eql(bool, ([4]bool)(v > x), [4]bool{ true, false, false, true })); + expect(mem.eql(bool, ([4]bool)(v <= x), [4]bool{ false, true, true, false })); + expect(mem.eql(bool, ([4]bool)(v >= x), [4]bool{ true, false, true, true })); } }; S.doTheTest(); @@ -114,22 +114,22 @@ test "vector casts of sizes not divisable by 8" { const S = struct { fn doTheTest() void { { - var v: @Vector(4, u3) = [4]u3{ 5, 2, 3, 0}; + var v: @Vector(4, u3) = [4]u3{ 5, 2, 3, 0 }; var x: [4]u3 = v; expect(mem.eql(u3, x, ([4]u3)(v))); } { - var v: @Vector(4, u2) = [4]u2{ 1, 2, 3, 0}; + var v: @Vector(4, u2) = [4]u2{ 1, 2, 3, 0 }; var x: [4]u2 = v; expect(mem.eql(u2, x, ([4]u2)(v))); } { - var v: @Vector(4, u1) = [4]u1{ 1, 0, 1, 0}; + var v: @Vector(4, u1) = [4]u1{ 1, 0, 1, 0 }; var x: [4]u1 = v; expect(mem.eql(u1, x, ([4]u1)(v))); } { - var v: @Vector(4, bool) = [4]bool{ false, false, true, false}; + var v: @Vector(4, bool) = [4]bool{ false, false, true, false }; var x: [4]bool = v; expect(mem.eql(bool, x, ([4]bool)(v))); } @@ -138,3 +138,19 @@ test "vector casts of sizes not divisable by 8" { S.doTheTest(); comptime S.doTheTest(); } + +test "vector @splat" { + const S = struct { + fn doTheTest() void { + var v: u32 = 5; + var x = @splat(4, v); + expect(@typeOf(x) == @Vector(4, u32)); + expect(x[0] == 5); + expect(x[1] == 5); + expect(x[2] == 5); + expect(x[3] == 5); + } + }; + S.doTheTest(); + comptime S.doTheTest(); +}