From c96f9a017a863e1f8cb610b7caba60ce93ab5616 Mon Sep 17 00:00:00 2001 From: mlugg Date: Tue, 8 Oct 2024 23:37:16 +0100 Subject: [PATCH] Sema: implement @splat for arrays Resolves: #20433 --- lib/std/zig/AstGen.zig | 4 +- lib/std/zig/Zir.zig | 12 +-- src/Sema.zig | 73 +++++++++++++++---- src/print_zir.zig | 2 +- test/behavior/array.zig | 67 +++++++++++++++++ .../compile_errors/splat_bad_result_type.zig | 7 ++ .../splat_result_type_non_vector.zig | 9 --- 7 files changed, 140 insertions(+), 34 deletions(-) create mode 100644 test/cases/compile_errors/splat_bad_result_type.zig delete mode 100644 test/cases/compile_errors/splat_result_type_non_vector.zig diff --git a/lib/std/zig/AstGen.zig b/lib/std/zig/AstGen.zig index cff9b4e37a..7e11f8d44b 100644 --- a/lib/std/zig/AstGen.zig +++ b/lib/std/zig/AstGen.zig @@ -2716,7 +2716,7 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As .array_type_sentinel, .elem_type, .indexable_ptr_elem_type, - .vector_elem_type, + .vec_arr_elem_type, .vector_type, .indexable_ptr_len, .anyframe_type, @@ -9529,7 +9529,7 @@ fn builtinCall( .splat => { const result_type = try ri.rl.resultTypeForCast(gz, node, builtin_name); - const elem_type = try gz.addUnNode(.vector_elem_type, result_type, node); + const elem_type = try gz.addUnNode(.vec_arr_elem_type, result_type, node); const scalar = try expr(gz, scope, .{ .rl = .{ .ty = elem_type } }, params[0]); const result = try gz.addPlNode(.splat, node, Zir.Inst.Bin{ .lhs = result_type, diff --git a/lib/std/zig/Zir.zig b/lib/std/zig/Zir.zig index a644bd69e4..9c76f72663 100644 --- a/lib/std/zig/Zir.zig +++ b/lib/std/zig/Zir.zig @@ -247,9 +247,9 @@ pub const Inst = struct { /// element type. Emits a compile error if the type is not an indexable pointer. /// Uses the `un_node` field. indexable_ptr_elem_type, - /// Given a vector type, returns its element type. + /// Given a vector or array type, returns its element type. /// Uses the `un_node` field. - vector_elem_type, + vec_arr_elem_type, /// Given a pointer to an indexable object, returns the len property. This is /// used by for loops. This instruction also emits a for-loop specific compile /// error if the indexable object is not indexable. @@ -1065,7 +1065,7 @@ pub const Inst = struct { .vector_type, .elem_type, .indexable_ptr_elem_type, - .vector_elem_type, + .vec_arr_elem_type, .indexable_ptr_len, .anyframe_type, .as_node, @@ -1375,7 +1375,7 @@ pub const Inst = struct { .vector_type, .elem_type, .indexable_ptr_elem_type, - .vector_elem_type, + .vec_arr_elem_type, .indexable_ptr_len, .anyframe_type, .as_node, @@ -1607,7 +1607,7 @@ pub const Inst = struct { .vector_type = .pl_node, .elem_type = .un_node, .indexable_ptr_elem_type = .un_node, - .vector_elem_type = .un_node, + .vec_arr_elem_type = .un_node, .indexable_ptr_len = .un_node, .anyframe_type = .un_node, .as_node = .pl_node, @@ -3781,7 +3781,7 @@ fn findDeclsInner( .vector_type, .elem_type, .indexable_ptr_elem_type, - .vector_elem_type, + .vec_arr_elem_type, .indexable_ptr_len, .anyframe_type, .as_node, diff --git a/src/Sema.zig b/src/Sema.zig index 7060128e6d..1de0d73808 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -1087,7 +1087,7 @@ fn analyzeBodyInner( .elem_val_imm => try sema.zirElemValImm(block, inst), .elem_type => try sema.zirElemType(block, inst), .indexable_ptr_elem_type => try sema.zirIndexablePtrElemType(block, inst), - .vector_elem_type => try sema.zirVectorElemType(block, inst), + .vec_arr_elem_type => try sema.zirVecArrElemType(block, inst), .enum_literal => try sema.zirEnumLiteral(block, inst), .decl_literal => try sema.zirDeclLiteral(block, inst, true), .decl_literal_no_coerce => try sema.zirDeclLiteral(block, inst, false), @@ -2046,7 +2046,7 @@ fn genericPoisonReason(sema: *Sema, block: *Block, ref: Zir.Inst.Ref) GenericPoi const bin = sema.code.instructions.items(.data)[@intFromEnum(inst)].bin; cur = bin.lhs; }, - .indexable_ptr_elem_type, .vector_elem_type => { + .indexable_ptr_elem_type, .vec_arr_elem_type => { const un_node = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node; cur = un_node.operand; }, @@ -8603,7 +8603,7 @@ fn zirIndexablePtrElemType(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Com return Air.internedToRef(elem_ty.toIntern()); } -fn zirVectorElemType(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { +fn zirVecArrElemType(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { const pt = sema.pt; const zcu = pt.zcu; const un_node = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node; @@ -8615,8 +8615,9 @@ fn zirVectorElemType(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileEr error.GenericPoison => return .generic_poison_type, else => |e| return e, }; - if (!vec_ty.isVector(zcu)) { - return sema.fail(block, block.nodeOffset(un_node.src_node), "expected vector type, found '{}'", .{vec_ty.fmt(pt)}); + switch (vec_ty.zigTypeTag(zcu)) { + .array, .vector => {}, + else => return sema.fail(block, block.nodeOffset(un_node.src_node), "expected array or vector type, found '{}'", .{vec_ty.fmt(pt)}), } return Air.internedToRef(vec_ty.childType(zcu).toIntern()); } @@ -24804,26 +24805,66 @@ fn zirSplat(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.I const scalar_src = block.builtinCallArgSrc(inst_data.src_node, 0); const dest_ty = try sema.resolveDestType(block, src, extra.lhs, .remove_eu_opt, "@splat"); - if (!dest_ty.isVector(zcu)) return sema.fail(block, src, "expected vector type, found '{}'", .{dest_ty.fmt(pt)}); - - if (!dest_ty.hasRuntimeBits(zcu)) { - const empty_aggregate = try pt.intern(.{ .aggregate = .{ - .ty = dest_ty.toIntern(), - .storage = .{ .elems = &[_]InternPool.Index{} }, - } }); - return Air.internedToRef(empty_aggregate); + switch (dest_ty.zigTypeTag(zcu)) { + .array, .vector => {}, + else => return sema.fail(block, src, "expected array or vector type, found '{}'", .{dest_ty.fmt(pt)}), } const operand = try sema.resolveInst(extra.rhs); const scalar_ty = dest_ty.childType(zcu); const scalar = try sema.coerce(block, scalar_ty, operand, scalar_src); + + const len = try sema.usizeCast(block, src, dest_ty.arrayLen(zcu)); + + // `len == 0` because `[0:s]T` always has a comptime-known splat. + if (!dest_ty.hasRuntimeBits(zcu) or len == 0) { + const empty_aggregate = try pt.intern(.{ .aggregate = .{ + .ty = dest_ty.toIntern(), + .storage = .{ .elems = &.{} }, + } }); + return Air.internedToRef(empty_aggregate); + } + + const maybe_sentinel = dest_ty.sentinel(zcu); + if (try sema.resolveValue(scalar)) |scalar_val| { - if (scalar_val.isUndef(zcu)) return pt.undefRef(dest_ty); - return Air.internedToRef((try sema.splat(dest_ty, scalar_val)).toIntern()); + if (scalar_val.isUndef(zcu) and maybe_sentinel == null) { + return pt.undefRef(dest_ty); + } + // TODO: I didn't want to put `.aggregate` on a separate line here; `zig fmt` bugs have forced my hand + return Air.internedToRef(try pt.intern(.{ + .aggregate = .{ + .ty = dest_ty.toIntern(), + .storage = s: { + full: { + if (dest_ty.zigTypeTag(zcu) == .vector) break :full; + const sentinel = maybe_sentinel orelse break :full; + if (sentinel.toIntern() == scalar_val.toIntern()) break :full; + // This is a array with non-zero length and a sentinel which does not match the element. + // We have to use the full `elems` representation. + const elems = try sema.arena.alloc(InternPool.Index, len + 1); + @memset(elems[0..len], scalar_val.toIntern()); + elems[len] = sentinel.toIntern(); + break :s .{ .elems = elems }; + } + break :s .{ .repeated_elem = scalar_val.toIntern() }; + }, + }, + })); } try sema.requireRuntimeBlock(block, src, scalar_src); - return block.addTyOp(.splat, dest_ty, scalar); + + switch (dest_ty.zigTypeTag(zcu)) { + .array => { + const elems = try sema.arena.alloc(Air.Inst.Ref, len + @intFromBool(maybe_sentinel != null)); + @memset(elems[0..len], scalar); + if (maybe_sentinel) |s| elems[len] = Air.internedToRef(s.toIntern()); + return block.addAggregateInit(dest_ty, elems); + }, + .vector => return block.addTyOp(.splat, dest_ty, scalar), + else => unreachable, + } } fn zirReduce(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { diff --git a/src/print_zir.zig b/src/print_zir.zig index a16fcbb2fd..5467dcd27a 100644 --- a/src/print_zir.zig +++ b/src/print_zir.zig @@ -203,7 +203,7 @@ const Writer = struct { .alloc_comptime_mut, .elem_type, .indexable_ptr_elem_type, - .vector_elem_type, + .vec_arr_elem_type, .indexable_ptr_len, .anyframe_type, .bit_not, diff --git a/test/behavior/array.zig b/test/behavior/array.zig index a01e624a5d..17b8667238 100644 --- a/test/behavior/array.zig +++ b/test/behavior/array.zig @@ -1021,3 +1021,70 @@ test "runtime index of array of zero-bit values" { try std.testing.expect(result.index == 0); try std.testing.expect(result.value == {}); } + +test "@splat array" { + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO + const S = struct { + fn doTheTest(comptime T: type, x: T) !void { + const arr: [10]T = @splat(x); + for (arr) |elem| { + try expectEqual(x, elem); + } + } + }; + + try S.doTheTest(u32, 123); + try comptime S.doTheTest(u32, 123); + + const Foo = struct { x: u8 }; + try S.doTheTest(Foo, .{ .x = 10 }); + try comptime S.doTheTest(Foo, .{ .x = 10 }); +} + +test "@splat array with sentinel" { + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO + const S = struct { + fn doTheTest(comptime T: type, x: T, comptime s: T) !void { + const arr: [10:s]T = @splat(x); + for (arr) |elem| { + try expectEqual(x, elem); + } + const ptr: [*]const T = &arr; + try expectEqual(s, ptr[10]); // sentinel correct + } + }; + + try S.doTheTest(u32, 100, 42); + try comptime S.doTheTest(u32, 100, 42); + + try S.doTheTest(?*anyopaque, @ptrFromInt(0x1000), null); + try comptime S.doTheTest(?*anyopaque, @ptrFromInt(0x1000), null); +} + +test "@splat zero-length array" { + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO + const S = struct { + fn doTheTest(comptime T: type, comptime s: T) !void { + var runtime_undef: T = undefined; + runtime_undef = undefined; + // The array should be comptime-known despite the `@splat` operand being runtime-known. + const arr: [0:s]T = @splat(runtime_undef); + const ptr: [*]const T = &arr; + comptime assert(ptr[0] == s); + } + }; + + try S.doTheTest(u32, 42); + try comptime S.doTheTest(u32, 42); + + try S.doTheTest(?*anyopaque, null); + try comptime S.doTheTest(?*anyopaque, null); +} diff --git a/test/cases/compile_errors/splat_bad_result_type.zig b/test/cases/compile_errors/splat_bad_result_type.zig new file mode 100644 index 0000000000..60d4b35d45 --- /dev/null +++ b/test/cases/compile_errors/splat_bad_result_type.zig @@ -0,0 +1,7 @@ +export fn f() void { + _ = @as(u32, @splat(5)); +} + +// error +// +// :2:18: error: expected array or vector type, found 'u32' diff --git a/test/cases/compile_errors/splat_result_type_non_vector.zig b/test/cases/compile_errors/splat_result_type_non_vector.zig deleted file mode 100644 index dbff8dc041..0000000000 --- a/test/cases/compile_errors/splat_result_type_non_vector.zig +++ /dev/null @@ -1,9 +0,0 @@ -export fn f() void { - _ = @as(u32, @splat(5)); -} - -// error -// backend=stage2 -// target=native -// -// :2:18: error: expected vector type, found 'u32'