diff --git a/lib/std/mem.zig b/lib/std/mem.zig index 42b35281e0..8691c5bbad 100644 --- a/lib/std/mem.zig +++ b/lib/std/mem.zig @@ -431,34 +431,48 @@ pub fn zeroInit(comptime T: type, init: anytype) T { .Struct => |struct_info| { switch (@typeInfo(Init)) { .Struct => |init_info| { - var value = std.mem.zeroes(T); + if (init_info.is_tuple) { + if (init_info.fields.len > struct_info.fields.len) { + @compileError("Tuple initializer has more elments than there are fields in `" ++ @typeName(T) ++ "`"); + } + } else { + inline for (init_info.fields) |field| { + if (!@hasField(T, field.name)) { + @compileError("Encountered an initializer for `" ++ field.name ++ "`, but it is not a field of " ++ @typeName(T)); + } + } + } - inline for (struct_info.fields) |field| { - if (field.default_value) |default_value_ptr| { + var value: T = undefined; + + inline for (struct_info.fields) |field, i| { + if (field.is_comptime) { + continue; + } + + if (init_info.is_tuple and init_info.fields.len > i) { + @field(value, field.name) = @field(init, init_info.fields[i].name); + } else if (@hasField(@TypeOf(init), field.name)) { + switch (@typeInfo(field.type)) { + .Struct => { + @field(value, field.name) = zeroInit(field.type, @field(init, field.name)); + }, + else => { + @field(value, field.name) = @field(init, field.name); + }, + } + } else if (field.default_value) |default_value_ptr| { const default_value = @ptrCast(*align(1) const field.type, default_value_ptr).*; @field(value, field.name) = default_value; - } - } - - if (init_info.is_tuple) { - inline for (init_info.fields) |field, i| { - @field(value, struct_info.fields[i].name) = @field(init, field.name); - } - return value; - } - - inline for (init_info.fields) |field| { - if (!@hasField(T, field.name)) { - @compileError("Encountered an initializer for `" ++ field.name ++ "`, but it is not a field of " ++ @typeName(T)); - } - - switch (@typeInfo(field.type)) { - .Struct => { - @field(value, field.name) = zeroInit(field.type, @field(init, field.name)); - }, - else => { - @field(value, field.name) = @field(init, field.name); - }, + } else { + switch (@typeInfo(field.type)) { + .Struct => { + @field(value, field.name) = std.mem.zeroInit(field.type, .{}); + }, + else => { + @field(value, field.name) = std.mem.zeroes(@TypeOf(@field(value, field.name))); + }, + } } } @@ -538,6 +552,24 @@ test "zeroInit" { .foo = 69, .bar = 420, }, b); + + const Baz = struct { + foo: [:0]const u8 = "bar", + }; + + const baz1 = zeroInit(Baz, .{}); + try testing.expectEqual(Baz{}, baz1); + + const baz2 = zeroInit(Baz, .{ .foo = "zab" }); + try testing.expectEqualSlices(u8, "zab", baz2.foo); + + const NestedBaz = struct { + bbb: Baz, + }; + const nested_baz = zeroInit(NestedBaz, .{}); + try testing.expectEqual(NestedBaz{ + .bbb = Baz{}, + }, nested_baz); } /// Compares two slices of numbers lexicographically. O(n).