diff --git a/lib/std/mem.zig b/lib/std/mem.zig index 978ae71b74..59c26cc887 100644 --- a/lib/std/mem.zig +++ b/lib/std/mem.zig @@ -2150,7 +2150,7 @@ pub fn byteSwapAllFields(comptime S: type, ptr: *S) void { } else { byteSwapAllFields(f.type, &@field(ptr, f.name)); }, - .array => byteSwapAllFields(f.type, &@field(ptr, f.name)), + .@"union", .array => byteSwapAllFields(f.type, &@field(ptr, f.name)), .@"enum" => { @field(ptr, f.name) = @enumFromInt(@byteSwap(@intFromEnum(@field(ptr, f.name)))); }, @@ -2164,10 +2164,25 @@ pub fn byteSwapAllFields(comptime S: type, ptr: *S) void { } } }, + .@"union" => |union_info| { + if (union_info.tag_type != null) { + @compileError("byteSwapAllFields expects an untagged union"); + } + + const first_size = @bitSizeOf(union_info.fields[0].type); + inline for (union_info.fields) |field| { + if (@bitSizeOf(field.type) != first_size) { + @compileError("Unable to byte-swap unions with varying field sizes"); + } + } + + const BackingInt = std.meta.Int(.unsigned, @bitSizeOf(S)); + ptr.* = @bitCast(@byteSwap(@as(BackingInt, @bitCast(ptr.*)))); + }, .array => { for (ptr) |*item| { switch (@typeInfo(@TypeOf(item.*))) { - .@"struct", .array => byteSwapAllFields(@TypeOf(item.*), item), + .@"struct", .@"union", .array => byteSwapAllFields(@TypeOf(item.*), item), .@"enum" => { item.* = @enumFromInt(@byteSwap(@intFromEnum(item.*))); }, @@ -2193,6 +2208,7 @@ test byteSwapAllFields { f3: [1]u8, f4: bool, f5: f32, + f6: extern union { f0: u16, f1: u16 }, }; const K = extern struct { f0: u8, @@ -2209,6 +2225,7 @@ test byteSwapAllFields { .f3 = .{0x12}, .f4 = true, .f5 = @as(f32, @bitCast(@as(u32, 0x4640e400))), + .f6 = .{ .f0 = 0x1234 }, }; var k = K{ .f0 = 0x12, @@ -2227,6 +2244,7 @@ test byteSwapAllFields { .f3 = .{0x12}, .f4 = true, .f5 = @as(f32, @bitCast(@as(u32, 0x00e44046))), + .f6 = .{ .f0 = 0x3412 }, }, s); try std.testing.expectEqual(K{ .f0 = 0x12, diff --git a/lib/std/testing.zig b/lib/std/testing.zig index 1c76759479..f52135f237 100644 --- a/lib/std/testing.zig +++ b/lib/std/testing.zig @@ -153,7 +153,18 @@ fn expectEqualInner(comptime T: type, expected: T, actual: T) !void { .@"union" => |union_info| { if (union_info.tag_type == null) { - @compileError("Unable to compare untagged union values for type " ++ @typeName(@TypeOf(actual))); + const first_size = @bitSizeOf(union_info.fields[0].type); + inline for (union_info.fields) |field| { + if (@bitSizeOf(field.type) != first_size) { + @compileError("Unable to compare untagged unions with varying field sizes for type " ++ @typeName(@TypeOf(actual))); + } + } + + const BackingInt = std.meta.Int(.unsigned, @bitSizeOf(T)); + return expectEqual( + @as(BackingInt, @bitCast(expected)), + @as(BackingInt, @bitCast(actual)), + ); } const Tag = std.meta.Tag(@TypeOf(expected));