From 4b9b9e725777a8f9e4ea9391beaeea34c834615f Mon Sep 17 00:00:00 2001
From: Andrew Kelley
Date: Tue, 28 Dec 2021 01:52:19 -0700
Subject: [PATCH] stage2: LLVM backend: fix lowering of union constants

Comment from this commit reproduced here:

LLVM does not allow us to change the type of globals. So we must
create a new global with the correct type, copy all its attributes,
and then update all references to point to the new global,
delete the original, and rename the new one to the old one's name.
This is necessary because LLVM does not support const bitcasting
a struct with padding bytes, which is needed to lower a const union value
to LLVM, when a field other than the most-aligned is active. Instead,
we must lower to an unnamed struct, and pointer cast at usage sites
of the global. Such an unnamed struct is the cause of the global type
mismatch, because we don't have the LLVM type until the *value* is created,
whereas the global needs to be created based on the type alone, because
lowering the value may reference the global as a pointer.
--- src/codegen/llvm.zig | 89 ++++++++++++++++++++++++++------- src/codegen/llvm/bindings.zig | 18 +++++++ src/zig_llvm.cpp | 4 ++ src/zig_llvm.h | 24 +++++---- test/behavior/switch.zig | 19 +++++++ test/behavior/switch_stage1.zig | 18 ------- test/behavior/union.zig | 81 ++++++++++++++++++++++++++++++ test/behavior/union_stage1.zig | 81 ------------------------------ 8 files changed, 205 insertions(+), 129 deletions(-) diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index 2ec1cb4708..a248fb1718 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -557,8 +557,8 @@ pub const DeclGen = struct { return self.object.llvm_module; } - fn genDecl(self: *DeclGen) !void { - const decl = self.decl; + fn genDecl(dg: *DeclGen) !void { + const decl = dg.decl; assert(decl.has_tv); log.debug("gen: {s} type: {}, value: {}", .{ decl.name, decl.ty, decl.val }); @@ -567,10 +567,10 @@ pub const DeclGen = struct { _ = func_payload; @panic("TODO llvm backend genDecl function pointer"); } else if (decl.val.castTag(.extern_fn)) |extern_fn| { - _ = try self.resolveLlvmFunction(extern_fn.data); + _ = try dg.resolveLlvmFunction(extern_fn.data); } else { - const target = self.module.getTarget(); - const global = try self.resolveGlobalDecl(decl); + const target = dg.module.getTarget(); + const global = try dg.resolveGlobalDecl(decl); global.setAlignment(decl.getAlignment(target)); assert(decl.has_tv); const init_val = if (decl.val.castTag(.variable)) |payload| init_val: { @@ -581,8 +581,35 @@ pub const DeclGen = struct { break :init_val decl.val; }; if (init_val.tag() != .unreachable_value) { - const llvm_init = try self.genTypedValue(.{ .ty = decl.ty, .val = init_val }); - global.setInitializer(llvm_init); + const llvm_init = try dg.genTypedValue(.{ .ty = decl.ty, .val = init_val }); + if (global.globalGetValueType() == llvm_init.typeOf()) { + global.setInitializer(llvm_init); + } else { + // LLVM does not allow us to change the type of globals. 
So we must + // create a new global with the correct type, copy all its attributes, + // and then update all references to point to the new global, + // delete the original, and rename the new one to the old one's name. + // This is necessary because LLVM does not support const bitcasting + // a struct with padding bytes, which is needed to lower a const union value + // to LLVM, when a field other than the most-aligned is active. Instead, + // we must lower to an unnamed struct, and pointer cast at usage sites + // of the global. Such an unnamed struct is the cause of the global type + // mismatch, because we don't have the LLVM type until the *value* is created, + // whereas the global needs to be created based on the type alone, because + // lowering the value may reference the global as a pointer. + const new_global = dg.object.llvm_module.addGlobalInAddressSpace( + llvm_init.typeOf(), + "", + dg.llvmAddressSpace(decl.@"addrspace"), + ); + new_global.setLinkage(global.getLinkage()); + new_global.setUnnamedAddr(global.getUnnamedAddress()); + new_global.setAlignment(global.getAlignment()); + new_global.setInitializer(llvm_init); + global.replaceAllUsesWith(new_global); + new_global.takeName(global); + global.deleteGlobal(); + } } } } @@ -1456,9 +1483,15 @@ pub const DeclGen = struct { const layout = tv.ty.unionGetLayout(target); if (layout.payload_size == 0) { - return genTypedValue(dg, .{ .ty = tv.ty.unionTagType().?, .val = tag_and_val.tag }); + return genTypedValue(dg, .{ + .ty = tv.ty.unionTagType().?, + .val = tag_and_val.tag, + }); } - const field_ty = tv.ty.unionFieldType(tag_and_val.tag); + const union_obj = tv.ty.cast(Type.Payload.Union).?.data; + const field_index = union_obj.tag_ty.enumTagFieldIndex(tag_and_val.tag).?; + assert(union_obj.haveFieldTypes()); + const field_ty = union_obj.fields.values()[field_index].ty; const payload = p: { if (!field_ty.hasCodeGenBits()) { const padding_len = @intCast(c_uint, layout.payload_size); @@ -1475,10 +1508,20 @@ 
pub const DeclGen = struct { }; break :p dg.context.constStruct(&fields, fields.len, .False); }; + + // In this case we must make an unnamed struct because LLVM does + // not support bitcasting our payload struct to the true union payload type. + // Instead we use an unnamed struct and every reference to the global + // must pointer cast to the expected type before accessing the union. + const need_unnamed = layout.most_aligned_field != field_index; + if (layout.tag_size == 0) { - const llvm_payload_ty = llvm_union_ty.structGetTypeAtIndex(0); - const fields: [1]*const llvm.Value = .{payload.constBitCast(llvm_payload_ty)}; - return llvm_union_ty.constNamedStruct(&fields, fields.len); + const fields: [1]*const llvm.Value = .{payload}; + if (need_unnamed) { + return dg.context.constStruct(&fields, fields.len, .False); + } else { + return llvm_union_ty.constNamedStruct(&fields, fields.len); + } } const llvm_tag_value = try genTypedValue(dg, .{ .ty = tv.ty.unionTagType().?, @@ -1486,13 +1529,15 @@ pub const DeclGen = struct { }); var fields: [2]*const llvm.Value = undefined; if (layout.tag_align >= layout.payload_align) { - fields[0] = llvm_tag_value; - fields[1] = payload.constBitCast(llvm_union_ty.structGetTypeAtIndex(1)); + fields = .{ llvm_tag_value, payload }; } else { - fields[0] = payload.constBitCast(llvm_union_ty.structGetTypeAtIndex(0)); - fields[1] = llvm_tag_value; + fields = .{ payload, llvm_tag_value }; + } + if (need_unnamed) { + return dg.context.constStruct(&fields, fields.len, .False); + } else { + return llvm_union_ty.constNamedStruct(&fields, fields.len); } - return llvm_union_ty.constNamedStruct(&fields, fields.len); }, .Vector => switch (tv.val.tag()) { .bytes => { @@ -1859,8 +1904,14 @@ pub const FuncGen = struct { global.setGlobalConstant(.True); global.setUnnamedAddr(.True); global.setAlignment(ty.abiAlignment(target)); - gop.value_ptr.* = global; - return global; + // Because of LLVM limitations for lowering certain types such as unions, + // 
the type of global constants might not match the type it is supposed to + // be, and so we must bitcast the pointer at the usage sites. + const wanted_llvm_ty = try self.dg.llvmType(ty); + const wanted_llvm_ptr_ty = wanted_llvm_ty.pointerType(0); + const casted_ptr = global.constBitCast(wanted_llvm_ptr_ty); + gop.value_ptr.* = casted_ptr; + return casted_ptr; } fn genBody(self: *FuncGen, body: []const Air.Inst.Index) Error!void { diff --git a/src/codegen/llvm/bindings.zig b/src/codegen/llvm/bindings.zig index 3c85524cc1..3dac5bdca4 100644 --- a/src/codegen/llvm/bindings.zig +++ b/src/codegen/llvm/bindings.zig @@ -184,6 +184,9 @@ pub const Value = opaque { pub const setValueName2 = LLVMSetValueName2; extern fn LLVMSetValueName2(Val: *const Value, Name: [*]const u8, NameLen: usize) void; + pub const takeName = ZigLLVMTakeName; + extern fn ZigLLVMTakeName(new_owner: *const Value, victim: *const Value) void; + pub const deleteFunction = LLVMDeleteFunction; extern fn LLVMDeleteFunction(Fn: *const Value) void; @@ -206,6 +209,21 @@ pub const Value = opaque { return LLVMIsPoison(Val).toBool(); } extern fn LLVMIsPoison(Val: *const Value) Bool; + + pub const replaceAllUsesWith = LLVMReplaceAllUsesWith; + extern fn LLVMReplaceAllUsesWith(OldVal: *const Value, NewVal: *const Value) void; + + pub const globalGetValueType = LLVMGlobalGetValueType; + extern fn LLVMGlobalGetValueType(Global: *const Value) *const Type; + + pub const getLinkage = LLVMGetLinkage; + extern fn LLVMGetLinkage(Global: *const Value) Linkage; + + pub const getUnnamedAddress = LLVMGetUnnamedAddress; + extern fn LLVMGetUnnamedAddress(Global: *const Value) Bool; + + pub const getAlignment = LLVMGetAlignment; + extern fn LLVMGetAlignment(V: *const Value) c_uint; }; pub const Type = opaque { diff --git a/src/zig_llvm.cpp b/src/zig_llvm.cpp index 11cbf38368..aa07a7fed1 100644 --- a/src/zig_llvm.cpp +++ b/src/zig_llvm.cpp @@ -1319,6 +1319,10 @@ LLVMValueRef ZigLLVMBuildFPMulReduce(LLVMBuilderRef B, LLVMValueRef 
Acc, LLVMVal return wrap(unwrap(B)->CreateFMulReduce(unwrap(Acc), unwrap(Val))); } +void ZigLLVMTakeName(LLVMValueRef new_owner, LLVMValueRef victim) { + unwrap(new_owner)->takeName(unwrap(victim)); +} + static_assert((Triple::ArchType)ZigLLVM_UnknownArch == Triple::UnknownArch, ""); static_assert((Triple::ArchType)ZigLLVM_arm == Triple::arm, ""); static_assert((Triple::ArchType)ZigLLVM_armeb == Triple::armeb, ""); diff --git a/src/zig_llvm.h b/src/zig_llvm.h index d1e4fa2556..2b8156d51d 100644 --- a/src/zig_llvm.h +++ b/src/zig_llvm.h @@ -460,17 +460,19 @@ enum ZigLLVM_ObjectFormatType { ZigLLVM_XCOFF, }; -LLVMValueRef ZigLLVMBuildAndReduce(LLVMBuilderRef B, LLVMValueRef Val); -LLVMValueRef ZigLLVMBuildOrReduce(LLVMBuilderRef B, LLVMValueRef Val); -LLVMValueRef ZigLLVMBuildXorReduce(LLVMBuilderRef B, LLVMValueRef Val); -LLVMValueRef ZigLLVMBuildIntMaxReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed); -LLVMValueRef ZigLLVMBuildIntMinReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed); -LLVMValueRef ZigLLVMBuildFPMaxReduce(LLVMBuilderRef B, LLVMValueRef Val); -LLVMValueRef ZigLLVMBuildFPMinReduce(LLVMBuilderRef B, LLVMValueRef Val); -LLVMValueRef ZigLLVMBuildAddReduce(LLVMBuilderRef B, LLVMValueRef Val); -LLVMValueRef ZigLLVMBuildMulReduce(LLVMBuilderRef B, LLVMValueRef Val); -LLVMValueRef ZigLLVMBuildFPAddReduce(LLVMBuilderRef B, LLVMValueRef Acc, LLVMValueRef Val); -LLVMValueRef ZigLLVMBuildFPMulReduce(LLVMBuilderRef B, LLVMValueRef Acc, LLVMValueRef Val); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildAndReduce(LLVMBuilderRef B, LLVMValueRef Val); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildOrReduce(LLVMBuilderRef B, LLVMValueRef Val); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildXorReduce(LLVMBuilderRef B, LLVMValueRef Val); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildIntMaxReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildIntMinReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed); +ZIG_EXTERN_C LLVMValueRef 
ZigLLVMBuildFPMaxReduce(LLVMBuilderRef B, LLVMValueRef Val); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildFPMinReduce(LLVMBuilderRef B, LLVMValueRef Val); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildAddReduce(LLVMBuilderRef B, LLVMValueRef Val); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildMulReduce(LLVMBuilderRef B, LLVMValueRef Val); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildFPAddReduce(LLVMBuilderRef B, LLVMValueRef Acc, LLVMValueRef Val); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildFPMulReduce(LLVMBuilderRef B, LLVMValueRef Acc, LLVMValueRef Val); + +ZIG_EXTERN_C void ZigLLVMTakeName(LLVMValueRef new_owner, LLVMValueRef victim); #define ZigLLVM_DIFlags_Zero 0U #define ZigLLVM_DIFlags_Private 1U diff --git a/test/behavior/switch.zig b/test/behavior/switch.zig index f5fc282b05..16bb890c9e 100644 --- a/test/behavior/switch.zig +++ b/test/behavior/switch.zig @@ -294,3 +294,22 @@ test "switch on union with some prongs capturing" { }; try expect(y == 11); } + +const Number = union(enum) { + One: u64, + Two: u8, + Three: f32, +}; + +const number = Number{ .Three = 1.23 }; + +fn returnsFalse() bool { + switch (number) { + Number.One => |x| return x > 1234, + Number.Two => |x| return x == 'a', + Number.Three => |x| return x > 12.34, + } +} +test "switch on const enum with var" { + try expect(!returnsFalse()); +} diff --git a/test/behavior/switch_stage1.zig b/test/behavior/switch_stage1.zig index 6a86eb9494..1b85d767d5 100644 --- a/test/behavior/switch_stage1.zig +++ b/test/behavior/switch_stage1.zig @@ -3,24 +3,6 @@ const expect = std.testing.expect; const expectError = std.testing.expectError; const expectEqual = std.testing.expectEqual; -const Number = union(enum) { - One: u64, - Two: u8, - Three: f32, -}; - -const number = Number{ .Three = 1.23 }; - -fn returnsFalse() bool { - switch (number) { - Number.One => |x| return x > 1234, - Number.Two => |x| return x == 'a', - Number.Three => |x| return x > 12.34, - } -} -test "switch on const enum with var" { - try expect(!returnsFalse()); -} test 
"switch all prongs unreachable" { try testAllProngsUnreachable(); comptime try testAllProngsUnreachable(); diff --git a/test/behavior/union.zig b/test/behavior/union.zig index e296f6bbb8..033d2cf40b 100644 --- a/test/behavior/union.zig +++ b/test/behavior/union.zig @@ -71,3 +71,84 @@ test "0-sized extern union definition" { try expect(U.f == 1); } + +const Value = union(enum) { + Int: u64, + Array: [9]u8, +}; + +const Agg = struct { + val1: Value, + val2: Value, +}; + +const v1 = Value{ .Int = 1234 }; +const v2 = Value{ .Array = [_]u8{3} ** 9 }; + +const err = @as(anyerror!Agg, Agg{ + .val1 = v1, + .val2 = v2, +}); + +const array = [_]Value{ v1, v2, v1, v2 }; + +test "unions embedded in aggregate types" { + switch (array[1]) { + Value.Array => |arr| try expect(arr[4] == 3), + else => unreachable, + } + switch ((err catch unreachable).val1) { + Value.Int => |x| try expect(x == 1234), + else => unreachable, + } +} + +test "access a member of tagged union with conflicting enum tag name" { + const Bar = union(enum) { + A: A, + B: B, + + const A = u8; + const B = void; + }; + + comptime try expect(Bar.A == u8); +} + +test "constant tagged union with payload" { + var empty = TaggedUnionWithPayload{ .Empty = {} }; + var full = TaggedUnionWithPayload{ .Full = 13 }; + shouldBeEmpty(empty); + shouldBeNotEmpty(full); +} + +fn shouldBeEmpty(x: TaggedUnionWithPayload) void { + switch (x) { + TaggedUnionWithPayload.Empty => {}, + else => unreachable, + } +} + +fn shouldBeNotEmpty(x: TaggedUnionWithPayload) void { + switch (x) { + TaggedUnionWithPayload.Empty => unreachable, + else => {}, + } +} + +const TaggedUnionWithPayload = union(enum) { + Empty: void, + Full: i32, +}; + +test "union alignment" { + comptime { + try expect(@alignOf(AlignTestTaggedUnion) >= @alignOf([9]u8)); + try expect(@alignOf(AlignTestTaggedUnion) >= @alignOf(u64)); + } +} + +const AlignTestTaggedUnion = union(enum) { + A: [9]u8, + B: u64, +}; diff --git a/test/behavior/union_stage1.zig 
b/test/behavior/union_stage1.zig index 6a68737ecf..ac69915a39 100644 --- a/test/behavior/union_stage1.zig +++ b/test/behavior/union_stage1.zig @@ -3,37 +3,6 @@ const expect = std.testing.expect; const expectEqual = std.testing.expectEqual; const Tag = std.meta.Tag; -const Value = union(enum) { - Int: u64, - Array: [9]u8, -}; - -const Agg = struct { - val1: Value, - val2: Value, -}; - -const v1 = Value{ .Int = 1234 }; -const v2 = Value{ .Array = [_]u8{3} ** 9 }; - -const err = @as(anyerror!Agg, Agg{ - .val1 = v1, - .val2 = v2, -}); - -const array = [_]Value{ v1, v2, v1, v2 }; - -test "unions embedded in aggregate types" { - switch (array[1]) { - Value.Array => |arr| try expect(arr[4] == 3), - else => unreachable, - } - switch ((err catch unreachable).val1) { - Value.Int => |x| try expect(x == 1234), - else => unreachable, - } -} - const Letter = enum { A, B, C }; const Payload = union(Letter) { A: i32, @@ -202,18 +171,6 @@ const PartialInstWithPayload = union(enum) { Compiled: i32, }; -test "access a member of tagged union with conflicting enum tag name" { - const Bar = union(enum) { - A: A, - B: B, - - const A = u8; - const B = void; - }; - - comptime try expect(Bar.A == u8); -} - test "tagged union initialization with runtime void" { try expect(testTaggedUnionInit({})); } @@ -775,41 +732,3 @@ test "tagged union as return value" { fn returnAnInt(x: i32) TaggedFoo { return TaggedFoo{ .One = x }; } - -test "constant tagged union with payload" { - var empty = TaggedUnionWithPayload{ .Empty = {} }; - var full = TaggedUnionWithPayload{ .Full = 13 }; - shouldBeEmpty(empty); - shouldBeNotEmpty(full); -} - -fn shouldBeEmpty(x: TaggedUnionWithPayload) void { - switch (x) { - TaggedUnionWithPayload.Empty => {}, - else => unreachable, - } -} - -fn shouldBeNotEmpty(x: TaggedUnionWithPayload) void { - switch (x) { - TaggedUnionWithPayload.Empty => unreachable, - else => {}, - } -} - -const TaggedUnionWithPayload = union(enum) { - Empty: void, - Full: i32, -}; - -test "union 
alignment" { - comptime { - try expect(@alignOf(AlignTestTaggedUnion) >= @alignOf([9]u8)); - try expect(@alignOf(AlignTestTaggedUnion) >= @alignOf(u64)); - } -} - -const AlignTestTaggedUnion = union(enum) { - A: [9]u8, - B: u64, -};