diff --git a/src/Module.zig b/src/Module.zig index c847fadc17..c4c2b634f7 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -6607,6 +6607,7 @@ pub fn unionFieldNormalAlignment(mod: *Module, u: InternPool.UnionType, field_in return field_ty.abiAlignment(mod); } +/// Returns the index of the active field, given the current tag value pub fn unionTagFieldIndex(mod: *Module, u: InternPool.UnionType, enum_tag: Value) ?u32 { const ip = &mod.intern_pool; if (enum_tag.toIntern() == .none) return null; diff --git a/src/Sema.zig b/src/Sema.zig index 27dd7221fc..d626de4d18 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -27258,7 +27258,7 @@ fn unionFieldVal( return sema.failWithOwnedErrorMsg(block, msg); } }, - .Packed, .Extern => { + .Packed, .Extern => |layout| { if (tag_matches) { return Air.internedToRef(un.val); } else { @@ -27267,7 +27267,7 @@ fn unionFieldVal( else union_ty.unionFieldType(un.tag.toValue(), mod).?; - if (try sema.bitCastVal(block, src, un.val.toValue(), old_ty, field_ty, 0)) |new_val| { + if (try sema.bitCastUnionFieldVal(block, src, un.val.toValue(), old_ty, field_ty, layout)) |new_val| { return Air.internedToRef(new_val.toIntern()); } } @@ -29788,13 +29788,19 @@ fn storePtrVal( error.IllDefinedMemoryLayout => unreachable, // Sema was supposed to emit a compile error already error.Unimplemented => return sema.fail(block, src, "TODO: implement writeToMemory for type '{}'", .{mut_kit.ty.fmt(mod)}), }; - operand_val.writeToMemory(operand_ty, mod, buffer[reinterpret.byte_offset..]) catch |err| switch (err) { - error.OutOfMemory => return error.OutOfMemory, - error.ReinterpretDeclRef => unreachable, - error.IllDefinedMemoryLayout => unreachable, // Sema was supposed to emit a compile error already - error.Unimplemented => return sema.fail(block, src, "TODO: implement writeToMemory for type '{}'", .{operand_ty.fmt(mod)}), - }; - + if (reinterpret.write_packed) { + operand_val.writeToPackedMemory(operand_ty, mod, buffer[reinterpret.byte_offset..], 0) catch |err| switch (err) { + error.OutOfMemory => return error.OutOfMemory, + error.ReinterpretDeclRef => unreachable, + }; + } else { + operand_val.writeToMemory(operand_ty, mod, buffer[reinterpret.byte_offset..]) catch |err| switch (err) { + error.OutOfMemory => return error.OutOfMemory, + error.ReinterpretDeclRef => unreachable, + error.IllDefinedMemoryLayout => unreachable, // Sema was supposed to emit a compile error already + error.Unimplemented => return sema.fail(block, src, "TODO: implement writeToMemory for type '{}'", .{operand_ty.fmt(mod)}), + }; + } const val = Value.readFromMemory(mut_kit.ty, mod, buffer, sema.arena) catch |err| switch (err) { error.OutOfMemory => return error.OutOfMemory, error.IllDefinedMemoryLayout => unreachable, @@ -29826,6 +29832,8 @@ const ComptimePtrMutationKit = struct { reinterpret: struct { val_ptr: *Value, byte_offset: usize, + /// If set, write the operand to packed memory + write_packed: bool = false, }, /// If the root decl could not be used as parent, this means `ty` is the type that /// caused that by not having a well-defined layout. @@ -30189,21 +30197,43 @@ fn beginComptimePtrMutation( ); }, .@"union" => { - // We need to set the active field of the union. - const union_tag_ty = base_child_ty.unionTagTypeHypothetical(mod); - const payload = &val_ptr.castTag(.@"union").?.data; - payload.tag = try mod.enumValueFieldIndex(union_tag_ty, field_index); + const layout = base_child_ty.containerLayout(mod); - return beginComptimePtrMutationInner( - sema, - block, - src, - parent.ty.structFieldType(field_index, mod), - &payload.val, - ptr_elem_ty, - parent.mut_decl, - ); + const tag_type = base_child_ty.unionTagTypeHypothetical(mod); + const hypothetical_tag = try mod.enumValueFieldIndex(tag_type, field_index); + if (layout == .Auto or (payload.tag != null and hypothetical_tag.eql(payload.tag.?, tag_type, mod))) { + // We need to set the active field of the union. + payload.tag = hypothetical_tag; + + const field_ty = parent.ty.structFieldType(field_index, mod); + return beginComptimePtrMutationInner( + sema, + block, + src, + field_ty, + &payload.val, + ptr_elem_ty, + parent.mut_decl, + ); + } else { + // Writing to a different field (a different or unknown tag is active) requires reinterpreting + // memory of the entire union, which requires knowing its abiSize. + try sema.resolveTypeLayout(parent.ty); + + // This union value no longer has a well-defined tag type. + // The reinterpretation will read it back out as .none. + payload.val = try payload.val.unintern(sema.arena, mod); + return ComptimePtrMutationKit{ + .mut_decl = parent.mut_decl, + .pointee = .{ .reinterpret = .{ + .val_ptr = val_ptr, + .byte_offset = 0, + .write_packed = layout == .Packed, + } }, + .ty = parent.ty, + }; + } }, .slice => switch (field_index) { Value.slice_ptr_index => return beginComptimePtrMutationInner( @@ -30704,6 +30734,7 @@ fn bitCastVal( // For types with well-defined memory layouts, we serialize them a byte buffer, // then deserialize to the new type. const abi_size = try sema.usizeCast(block, src, old_ty.abiSize(mod)); + const buffer = try sema.gpa.alloc(u8, abi_size); defer sema.gpa.free(buffer); val.writeToMemory(old_ty, mod, buffer) catch |err| switch (err) { @@ -30720,6 +30751,63 @@ fn bitCastVal( }; } +fn bitCastUnionFieldVal( + sema: *Sema, + block: *Block, + src: LazySrcLoc, + val: Value, + old_ty: Type, + field_ty: Type, + layout: std.builtin.Type.ContainerLayout, +) !?Value { + const mod = sema.mod; + if (old_ty.eql(field_ty, mod)) return val; + + const old_size = try sema.usizeCast(block, src, old_ty.abiSize(mod)); + const field_size = try sema.usizeCast(block, src, field_ty.abiSize(mod)); + const endian = mod.getTarget().cpu.arch.endian(); + + const buffer = try sema.gpa.alloc(u8, @max(old_size, field_size)); + defer sema.gpa.free(buffer); + + // Reading a larger value means we need to reinterpret from undefined bytes. + const offset = switch (layout) { + .Extern => offset: { + if (field_size > old_size) @memset(buffer[old_size..], 0xaa); + val.writeToMemory(old_ty, mod, buffer) catch |err| switch (err) { + error.OutOfMemory => return error.OutOfMemory, + error.ReinterpretDeclRef => return null, + error.IllDefinedMemoryLayout => unreachable, // Sema was supposed to emit a compile error already + error.Unimplemented => return sema.fail(block, src, "TODO: implement writeToMemory for type '{}'", .{old_ty.fmt(mod)}), + }; + break :offset 0; + }, + .Packed => offset: { + if (field_size > old_size) { + const min_size = @max(old_size, 1); + switch (endian) { + .Little => @memset(buffer[min_size - 1 ..], 0xaa), + .Big => @memset(buffer[0 .. buffer.len - min_size + 1], 0xaa), + } + } + + val.writeToPackedMemory(old_ty, mod, buffer, 0) catch |err| switch (err) { + error.OutOfMemory => return error.OutOfMemory, + error.ReinterpretDeclRef => return null, + }; + + break :offset if (endian == .Big) buffer.len - field_size else 0; + }, + .Auto => unreachable, + }; + + return Value.readFromMemory(field_ty, mod, buffer[offset..], sema.arena) catch |err| switch (err) { + error.OutOfMemory => return error.OutOfMemory, + error.IllDefinedMemoryLayout => unreachable, + error.Unimplemented => return sema.fail(block, src, "TODO: implement readFromMemory for type '{}'", .{field_ty.fmt(mod)}), + }; +} + fn coerceArrayPtrToSlice( sema: *Sema, block: *Block, diff --git a/src/TypedValue.zig b/src/TypedValue.zig index cf705cdf89..cef22543a3 100644 --- a/src/TypedValue.zig +++ b/src/TypedValue.zig @@ -84,22 +84,27 @@ pub fn print( if (level == 0) { return writer.writeAll(".{ ... }"); } - const union_val = val.castTag(.@"union").?.data; + const payload = val.castTag(.@"union").?.data; try writer.writeAll(".{ "); - if (union_val.tag.toIntern() != .none) { + if (payload.tag) |tag| { try print(.{ .ty = ip.indexToKey(ty.toIntern()).union_type.enum_tag_ty.toType(), - .val = union_val.tag, + .val = tag, }, writer, level - 1, mod); try writer.writeAll(" = "); - const field_ty = ty.unionFieldType(union_val.tag, mod).?; + const field_ty = ty.unionFieldType(tag, mod).?; try print(.{ .ty = field_ty, - .val = union_val.val, + .val = payload.val, }, writer, level - 1, mod); } else { - return writer.writeAll("(unknown tag)"); + try writer.writeAll("(unknown tag) = "); + const backing_ty = try ty.unionBackingType(mod); + try print(.{ + .ty = backing_ty, + .val = payload.val, + }, writer, level - 1, mod); } return writer.writeAll(" }"); @@ -421,7 +426,12 @@ pub fn print( .val = un.val.toValue(), }, writer, level - 1, mod); } else { - try writer.writeAll("(unknown tag)"); + try writer.writeAll("(unknown tag) = "); + const backing_ty = try ty.unionBackingType(mod); + try print(.{ + .ty = backing_ty, + .val = un.val.toValue(), + }, writer, level - 1, mod); } } else try writer.writeAll("..."); return writer.writeAll(" }"); diff --git a/src/type.zig b/src/type.zig index 78514fe3d5..87b29731f6 100644 --- a/src/type.zig +++ b/src/type.zig @@ -1954,6 +1954,16 @@ pub const Type = struct { return true; } + /// Returns the type used for backing storage of this union during comptime operations. + /// Asserts the type is either an extern or packed union. + pub fn unionBackingType(ty: Type, mod: *Module) !Type { + return switch (ty.containerLayout(mod)) { + .Extern => try mod.arrayType(.{ .len = ty.abiSize(mod), .child = .u8_type }), + .Packed => try mod.intType(.unsigned, @intCast(ty.bitSize(mod))), + .Auto => unreachable, + }; + } + pub fn unionGetLayout(ty: Type, mod: *Module) Module.UnionLayout { const ip = &mod.intern_pool; const union_type = ip.indexToKey(ty.toIntern()).union_type; diff --git a/src/value.zig b/src/value.zig index 48a2f0fca2..d37e039abf 100644 --- a/src/value.zig +++ b/src/value.zig @@ -327,11 +327,19 @@ pub const Value = struct { }, .@"union" => { const pl = val.castTag(.@"union").?.data; - return mod.intern(.{ .un = .{ - .ty = ty.toIntern(), - .tag = try pl.tag.intern(ty.unionTagTypeHypothetical(mod), mod), - .val = try pl.val.intern(ty.unionFieldType(pl.tag, mod).?, mod), - } }); + if (pl.tag) |pl_tag| { + return mod.intern(.{ .un = .{ + .ty = ty.toIntern(), + .tag = try pl_tag.intern(ty.unionTagTypeHypothetical(mod), mod), + .val = try pl.val.intern(ty.unionFieldType(pl_tag, mod).?, mod), + } }); + } else { + return mod.intern(.{ .un = .{ + .ty = ty.toIntern(), + .tag = .none, + .val = try pl.val.intern(try ty.unionBackingType(mod), mod), + } }); + } }, } } @@ -399,10 +407,7 @@ pub const Value = struct { .un => |un| Tag.@"union".create(arena, .{ // toValue asserts that the value cannot be .none which is valid on unions. - .tag = .{ - .ip_index = un.tag, - .legacy = undefined, - }, + .tag = if (un.tag == .none) null else un.tag.toValue(), .val = un.val.toValue(), }), @@ -709,21 +714,22 @@ pub const Value = struct { .Union => switch (ty.containerLayout(mod)) { .Auto => return error.IllDefinedMemoryLayout, // Sema is supposed to have emitted a compile error already .Extern => { - const union_obj = mod.typeToUnion(ty).?; if (val.unionTag(mod)) |union_tag| { + const union_obj = mod.typeToUnion(ty).?; const field_index = mod.unionTagFieldIndex(union_obj, union_tag).?; const field_type = union_obj.field_types.get(&mod.intern_pool)[field_index].toType(); const field_val = try val.fieldValue(mod, field_index); const byte_count = @as(usize, @intCast(field_type.abiSize(mod))); return writeToMemory(field_val, field_type, mod, buffer[0..byte_count]); } else { - const union_size = ty.abiSize(mod); - const array_type = try mod.arrayType(.{ .len = union_size, .child = .u8_type }); - return writeToMemory(val.unionValue(mod), array_type, mod, buffer[0..@as(usize, @intCast(union_size))]); + const backing_ty = try ty.unionBackingType(mod); + const byte_count: usize = @intCast(backing_ty.abiSize(mod)); + return writeToMemory(val.unionValue(mod), backing_ty, mod, buffer[0..byte_count]); } }, .Packed => { - const byte_count = (@as(usize, @intCast(ty.bitSize(mod))) + 7) / 8; + const backing_ty = try ty.unionBackingType(mod); + const byte_count: usize = @intCast(backing_ty.abiSize(mod)); return writeToPackedMemory(val, ty, mod, buffer[0..byte_count], 0); }, }, @@ -842,9 +848,8 @@ pub const Value = struct { const field_val = try val.fieldValue(mod, field_index); return field_val.writeToPackedMemory(field_type, mod, buffer, bit_offset); } else { - const union_bits: u16 = @intCast(ty.bitSize(mod)); - const int_ty = try mod.intType(.unsigned, union_bits); - return val.unionValue(mod).writeToPackedMemory(int_ty, mod, buffer, bit_offset); + const backing_ty = try ty.unionBackingType(mod); + return val.unionValue(mod).writeToPackedMemory(backing_ty, mod, buffer, bit_offset); } }, } @@ -1146,10 +1151,8 @@ pub const Value = struct { .Union => switch (ty.containerLayout(mod)) { .Auto, .Extern => unreachable, // Handled by non-packed readFromMemory .Packed => { - const union_bits: u16 = @intCast(ty.bitSize(mod)); - assert(union_bits != 0); - const int_ty = try mod.intType(.unsigned, union_bits); - const val = (try readFromPackedMemory(int_ty, mod, buffer, bit_offset, arena)).toIntern(); + const backing_ty = try ty.unionBackingType(mod); + const val = (try readFromPackedMemory(backing_ty, mod, buffer, bit_offset, arena)).toIntern(); return (try mod.intern(.{ .un = .{ .ty = ty.toIntern(), .tag = .none, @@ -4017,7 +4020,7 @@ pub const Value = struct { data: Data, pub const Data = struct { - tag: Value, + tag: ?Value, val: Value, }; }; diff --git a/test/behavior/comptime_memory.zig b/test/behavior/comptime_memory.zig index 8a28d40743..d6b3ae3993 100644 --- a/test/behavior/comptime_memory.zig +++ b/test/behavior/comptime_memory.zig @@ -455,54 +455,3 @@ test "type pun null pointer-like optional" { // note that expectEqual hides the bug try testing.expect(@as(*const ?*i8, @ptrCast(&p)).* == null); } - -test "reinterpret extern union" { - { - const U = extern union { - a: u32, - b: u8 align(8), - }; - - comptime var u: U = undefined; - comptime @memset(std.mem.asBytes(&u), 42); - try comptime testing.expect(0x2a2a2a2a == u.a); - try comptime testing.expect(42 == u.b); - try testing.expectEqual(@as(u32, 0x2a2a2a2a), u.a); - try testing.expectEqual(42, u.b); - } -} - -test "reinterpret packed union" { - { - const U = packed union { - a: u32, - b: u8 align(8), - }; - - comptime var u: U = undefined; - comptime @memset(std.mem.asBytes(&u), 42); - try comptime testing.expect(0x2a2a2a2a == u.a); - try comptime testing.expect(0x2a == u.b); - try testing.expectEqual(@as(u32, 0x2a2a2a2a), u.a); - try testing.expectEqual(0x2a, u.b); - } - - { - const U = packed union { - a: u7, - b: u1, - }; - - const S = packed struct { - lsb: U, - msb: U, - }; - - comptime var s: S = undefined; - comptime @memset(std.mem.asBytes(&s), 0xaa); - try comptime testing.expectEqual(@as(u7, 0x2a), s.lsb.a); - try comptime testing.expectEqual(@as(u1, 0), s.lsb.b); - try comptime testing.expectEqual(@as(u7, 0x55), s.msb.a); - try comptime testing.expectEqual(@as(u1, 1), s.msb.b); - } -} diff --git a/test/behavior/union.zig b/test/behavior/union.zig index 48d4ac9a68..b7f586f624 100644 --- a/test/behavior/union.zig +++ b/test/behavior/union.zig @@ -1,5 +1,6 @@ const builtin = @import("builtin"); const std = @import("std"); +const endian = builtin.cpu.arch.endian(); const expect = std.testing.expect; const assert = std.debug.assert; const expectEqual = std.testing.expectEqual; @@ -1660,15 +1661,220 @@ test "union with 128 bit integer" { } } -test "memset extern union at comptime" { +test "memset extern union" { const U = extern union { foo: u8, + bar: u32, }; - const u = comptime blk: { - var u: U = undefined; - @memset(std.mem.asBytes(&u), 0); - u.foo = 0; - break :blk u; + + const S = struct { + fn doTheTest() !void { + var u: U = undefined; + @memset(std.mem.asBytes(&u), 0); + try expectEqual(@as(u8, 0), u.foo); + try expectEqual(@as(u32, 0), u.bar); + } }; - try expect(u.foo == 0); + + try comptime S.doTheTest(); + try S.doTheTest(); +} + +test "memset packed union" { + const U = packed union { + a: u32, + b: u8, + }; + + const S = struct { + fn doTheTest() !void { + var u: U = undefined; + @memset(std.mem.asBytes(&u), 42); + try expectEqual(@as(u32, 0x2a2a2a2a), u.a); + try expectEqual(@as(u8, 0x2a), u.b); + } + }; + + try comptime S.doTheTest(); + + if (builtin.cpu.arch.isWasm()) return error.SkipZigTest; // TODO + try S.doTheTest(); +} + +fn littleToNativeEndian(comptime T: type, v: T) T { + return if (endian == .Little) v else @byteSwap(v); +} + +test "reinterpret extern union" { + const U = extern union { + foo: u8, + baz: u32 align(8), + bar: u32, + }; + + const S = struct { + fn doTheTest() !void { + { + // Undefined initialization + const u = blk: { + var u: U = undefined; + @memset(std.mem.asBytes(&u), 0); + u.bar = 0xbbbbbbbb; + u.foo = 0x2a; + break :blk u; + }; + + try expectEqual(@as(u8, 0x2a), u.foo); + try expectEqual(littleToNativeEndian(u32, 0xbbbbbb2a), u.bar); + try expectEqual(littleToNativeEndian(u32, 0xbbbbbb2a), u.baz); + } + + { + // Union initialization + var u: U = .{ + .foo = 0x2a, + }; + + { + const expected, const mask = switch (endian) { + .Little => .{ 0x2a, 0xff }, + .Big => .{ 0x2a000000, 0xff000000 }, + }; + + try expectEqual(@as(u8, 0x2a), u.foo); + try expectEqual(@as(u32, expected), u.bar & mask); + try expectEqual(@as(u32, expected), u.baz & mask); + } + + // Writing to a larger field + u.baz = 0xbbbbbbbb; + try expectEqual(@as(u8, 0xbb), u.foo); + try expectEqual(@as(u32, 0xbbbbbbbb), u.bar); + try expectEqual(@as(u32, 0xbbbbbbbb), u.baz); + + // Writing to the same field + u.baz = 0xcccccccc; + try expectEqual(@as(u8, 0xcc), u.foo); + try expectEqual(@as(u32, 0xcccccccc), u.bar); + try expectEqual(@as(u32, 0xcccccccc), u.baz); + + // Writing to a smaller field + u.foo = 0xdd; + try expectEqual(@as(u8, 0xdd), u.foo); + try expectEqual(littleToNativeEndian(u32, 0xccccccdd), u.bar); + try expectEqual(littleToNativeEndian(u32, 0xccccccdd), u.baz); + } + } + }; + + try comptime S.doTheTest(); + + if (builtin.zig_backend == .stage2_llvm) return error.SkipZigTest; // TODO + try S.doTheTest(); +} + +test "reinterpret packed union" { + const U = packed union { + foo: u8, + bar: u29, + baz: u64, + qux: u12, + }; + + const S = struct { + fn doTheTest() !void { + { + const u = blk: { + var u: U = undefined; + @memset(std.mem.asBytes(&u), 0); + u.baz = 0xbbbbbbbb; + u.qux = 0xe2a; + break :blk u; + }; + + try expectEqual(@as(u8, 0x2a), u.foo); + try expectEqual(@as(u12, 0xe2a), u.qux); + + // https://github.com/ziglang/zig/issues/17360 + if (@inComptime()) { + try expectEqual(@as(u29, 0x1bbbbe2a), u.bar); + try expectEqual(@as(u64, 0xbbbbbe2a), u.baz); + } + } + + { + // Union initialization + var u: U = .{ + .qux = 0xe2a, + }; + try expectEqual(@as(u8, 0x2a), u.foo); + try expectEqual(@as(u12, 0xe2a), u.qux); + try expectEqual(@as(u29, 0xe2a), u.bar & 0xfff); + try expectEqual(@as(u64, 0xe2a), u.baz & 0xfff); + + // Writing to a larger field + u.baz = 0xbbbbbbbb; + try expectEqual(@as(u8, 0xbb), u.foo); + try expectEqual(@as(u12, 0xbbb), u.qux); + try expectEqual(@as(u29, 0x1bbbbbbb), u.bar); + try expectEqual(@as(u64, 0xbbbbbbbb), u.baz); + + // Writing to the same field + u.baz = 0xcccccccc; + try expectEqual(@as(u8, 0xcc), u.foo); + try expectEqual(@as(u12, 0xccc), u.qux); + try expectEqual(@as(u29, 0x0ccccccc), u.bar); + try expectEqual(@as(u64, 0xcccccccc), u.baz); + + // Writing to a smaller field + u.foo = 0xdd; + try expectEqual(@as(u8, 0xdd), u.foo); + try expectEqual(@as(u12, 0xcdd), u.qux); + try expectEqual(@as(u29, 0x0cccccdd), u.bar); + try expectEqual(@as(u64, 0xccccccdd), u.baz); + } + } + }; + + try comptime S.doTheTest(); + + if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO + if (builtin.cpu.arch.isPPC()) return error.SkipZigTest; // TODO + if (builtin.cpu.arch.isWasm()) return error.SkipZigTest; // TODO + try S.doTheTest(); +} + +test "reinterpret packed union inside packed struct" { + const U = packed union { + a: u7, + b: u1, + }; + + const V = packed struct { + lo: U, + hi: U, + }; + + const S = struct { + fn doTheTest() !void { + var v: V = undefined; + @memset(std.mem.asBytes(&v), 0x55); + try expectEqual(@as(u7, 0x55), v.lo.a); + try expectEqual(@as(u1, 1), v.lo.b); + try expectEqual(@as(u7, 0x2a), v.hi.a); + try expectEqual(@as(u1, 0), v.hi.b); + + v.lo.b = 0; + try expectEqual(@as(u7, 0x54), v.lo.a); + try expectEqual(@as(u1, 0), v.lo.b); + v.hi.b = 1; + try expectEqual(@as(u7, 0x2b), v.hi.a); + try expectEqual(@as(u1, 1), v.hi.b); + } + }; + + try comptime S.doTheTest(); + + if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO + try S.doTheTest(); }