diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index 2ec1cb4708..a248fb1718 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -557,8 +557,8 @@ pub const DeclGen = struct { return self.object.llvm_module; } - fn genDecl(self: *DeclGen) !void { - const decl = self.decl; + fn genDecl(dg: *DeclGen) !void { + const decl = dg.decl; assert(decl.has_tv); log.debug("gen: {s} type: {}, value: {}", .{ decl.name, decl.ty, decl.val }); @@ -567,10 +567,10 @@ pub const DeclGen = struct { _ = func_payload; @panic("TODO llvm backend genDecl function pointer"); } else if (decl.val.castTag(.extern_fn)) |extern_fn| { - _ = try self.resolveLlvmFunction(extern_fn.data); + _ = try dg.resolveLlvmFunction(extern_fn.data); } else { - const target = self.module.getTarget(); - const global = try self.resolveGlobalDecl(decl); + const target = dg.module.getTarget(); + const global = try dg.resolveGlobalDecl(decl); global.setAlignment(decl.getAlignment(target)); assert(decl.has_tv); const init_val = if (decl.val.castTag(.variable)) |payload| init_val: { @@ -581,8 +581,35 @@ pub const DeclGen = struct { break :init_val decl.val; }; if (init_val.tag() != .unreachable_value) { - const llvm_init = try self.genTypedValue(.{ .ty = decl.ty, .val = init_val }); - global.setInitializer(llvm_init); + const llvm_init = try dg.genTypedValue(.{ .ty = decl.ty, .val = init_val }); + if (global.globalGetValueType() == llvm_init.typeOf()) { + global.setInitializer(llvm_init); + } else { + // LLVM does not allow us to change the type of globals. So we must + // create a new global with the correct type, copy all its attributes, + // and then update all references to point to the new global, + // delete the original, and rename the new one to the old one's name. + // This is necessary because LLVM does not support const bitcasting + // a struct with padding bytes, which is needed to lower a const union value + // to LLVM, when a field other than the most-aligned is active. Instead, + // we must lower to an unnamed struct, and pointer cast at usage sites + // of the global. Such an unnamed struct is the cause of the global type + // mismatch, because we don't have the LLVM type until the *value* is created, + // whereas the global needs to be created based on the type alone, because + // lowering the value may reference the global as a pointer. + const new_global = dg.object.llvm_module.addGlobalInAddressSpace( + llvm_init.typeOf(), + "", + dg.llvmAddressSpace(decl.@"addrspace"), + ); + new_global.setLinkage(global.getLinkage()); + new_global.setUnnamedAddr(global.getUnnamedAddress()); + new_global.setAlignment(global.getAlignment()); + new_global.setInitializer(llvm_init); + global.replaceAllUsesWith(new_global); + new_global.takeName(global); + global.deleteGlobal(); + } } } } @@ -1456,9 +1483,15 @@ pub const DeclGen = struct { const layout = tv.ty.unionGetLayout(target); if (layout.payload_size == 0) { - return genTypedValue(dg, .{ .ty = tv.ty.unionTagType().?, .val = tag_and_val.tag }); + return genTypedValue(dg, .{ + .ty = tv.ty.unionTagType().?, + .val = tag_and_val.tag, + }); } - const field_ty = tv.ty.unionFieldType(tag_and_val.tag); + const union_obj = tv.ty.cast(Type.Payload.Union).?.data; + const field_index = union_obj.tag_ty.enumTagFieldIndex(tag_and_val.tag).?; + assert(union_obj.haveFieldTypes()); + const field_ty = union_obj.fields.values()[field_index].ty; const payload = p: { if (!field_ty.hasCodeGenBits()) { const padding_len = @intCast(c_uint, layout.payload_size); @@ -1475,10 +1508,20 @@ pub const DeclGen = struct { }; break :p dg.context.constStruct(&fields, fields.len, .False); }; + + // In this case we must make an unnamed struct because LLVM does + // not support bitcasting our payload struct to the true union payload type. + // Instead we use an unnamed struct and every reference to the global + // must pointer cast to the expected type before accessing the union. + const need_unnamed = layout.most_aligned_field != field_index; + if (layout.tag_size == 0) { - const llvm_payload_ty = llvm_union_ty.structGetTypeAtIndex(0); - const fields: [1]*const llvm.Value = .{payload.constBitCast(llvm_payload_ty)}; - return llvm_union_ty.constNamedStruct(&fields, fields.len); + const fields: [1]*const llvm.Value = .{payload}; + if (need_unnamed) { + return dg.context.constStruct(&fields, fields.len, .False); + } else { + return llvm_union_ty.constNamedStruct(&fields, fields.len); + } } const llvm_tag_value = try genTypedValue(dg, .{ .ty = tv.ty.unionTagType().?, @@ -1486,13 +1529,15 @@ pub const DeclGen = struct { }); var fields: [2]*const llvm.Value = undefined; if (layout.tag_align >= layout.payload_align) { - fields[0] = llvm_tag_value; - fields[1] = payload.constBitCast(llvm_union_ty.structGetTypeAtIndex(1)); + fields = .{ llvm_tag_value, payload }; } else { - fields[0] = payload.constBitCast(llvm_union_ty.structGetTypeAtIndex(0)); - fields[1] = llvm_tag_value; + fields = .{ payload, llvm_tag_value }; + } + if (need_unnamed) { + return dg.context.constStruct(&fields, fields.len, .False); + } else { + return llvm_union_ty.constNamedStruct(&fields, fields.len); } - return llvm_union_ty.constNamedStruct(&fields, fields.len); }, .Vector => switch (tv.val.tag()) { .bytes => { @@ -1859,8 +1904,14 @@ pub const FuncGen = struct { global.setGlobalConstant(.True); global.setUnnamedAddr(.True); global.setAlignment(ty.abiAlignment(target)); - gop.value_ptr.* = global; - return global; + // Because of LLVM limitations for lowering certain types such as unions, + // the type of global constants might not match the type it is supposed to + // be, and so we must bitcast the pointer at the usage sites. + const wanted_llvm_ty = try self.dg.llvmType(ty); + const wanted_llvm_ptr_ty = wanted_llvm_ty.pointerType(0); + const casted_ptr = global.constBitCast(wanted_llvm_ptr_ty); + gop.value_ptr.* = casted_ptr; + return casted_ptr; } fn genBody(self: *FuncGen, body: []const Air.Inst.Index) Error!void { diff --git a/src/codegen/llvm/bindings.zig b/src/codegen/llvm/bindings.zig index 3c85524cc1..3dac5bdca4 100644 --- a/src/codegen/llvm/bindings.zig +++ b/src/codegen/llvm/bindings.zig @@ -184,6 +184,9 @@ pub const Value = opaque { pub const setValueName2 = LLVMSetValueName2; extern fn LLVMSetValueName2(Val: *const Value, Name: [*]const u8, NameLen: usize) void; + pub const takeName = ZigLLVMTakeName; + extern fn ZigLLVMTakeName(new_owner: *const Value, victim: *const Value) void; + pub const deleteFunction = LLVMDeleteFunction; extern fn LLVMDeleteFunction(Fn: *const Value) void; @@ -206,6 +209,21 @@ pub const Value = opaque { return LLVMIsPoison(Val).toBool(); } extern fn LLVMIsPoison(Val: *const Value) Bool; + + pub const replaceAllUsesWith = LLVMReplaceAllUsesWith; + extern fn LLVMReplaceAllUsesWith(OldVal: *const Value, NewVal: *const Value) void; + + pub const globalGetValueType = LLVMGlobalGetValueType; + extern fn LLVMGlobalGetValueType(Global: *const Value) *const Type; + + pub const getLinkage = LLVMGetLinkage; + extern fn LLVMGetLinkage(Global: *const Value) Linkage; + + pub const getUnnamedAddress = LLVMGetUnnamedAddress; + extern fn LLVMGetUnnamedAddress(Global: *const Value) Bool; + + pub const getAlignment = LLVMGetAlignment; + extern fn LLVMGetAlignment(V: *const Value) c_uint; }; pub const Type = opaque { diff --git a/src/zig_llvm.cpp b/src/zig_llvm.cpp index 11cbf38368..aa07a7fed1 100644 --- a/src/zig_llvm.cpp +++ b/src/zig_llvm.cpp @@ -1319,6 +1319,10 @@ LLVMValueRef ZigLLVMBuildFPMulReduce(LLVMBuilderRef B, LLVMValueRef Acc, LLVMVal return wrap(unwrap(B)->CreateFMulReduce(unwrap(Acc), unwrap(Val))); } +void ZigLLVMTakeName(LLVMValueRef new_owner, LLVMValueRef victim) { + unwrap(new_owner)->takeName(unwrap(victim)); +} + static_assert((Triple::ArchType)ZigLLVM_UnknownArch == Triple::UnknownArch, ""); static_assert((Triple::ArchType)ZigLLVM_arm == Triple::arm, ""); static_assert((Triple::ArchType)ZigLLVM_armeb == Triple::armeb, ""); diff --git a/src/zig_llvm.h b/src/zig_llvm.h index d1e4fa2556..2b8156d51d 100644 --- a/src/zig_llvm.h +++ b/src/zig_llvm.h @@ -460,17 +460,19 @@ enum ZigLLVM_ObjectFormatType { ZigLLVM_XCOFF, }; -LLVMValueRef ZigLLVMBuildAndReduce(LLVMBuilderRef B, LLVMValueRef Val); -LLVMValueRef ZigLLVMBuildOrReduce(LLVMBuilderRef B, LLVMValueRef Val); -LLVMValueRef ZigLLVMBuildXorReduce(LLVMBuilderRef B, LLVMValueRef Val); -LLVMValueRef ZigLLVMBuildIntMaxReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed); -LLVMValueRef ZigLLVMBuildIntMinReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed); -LLVMValueRef ZigLLVMBuildFPMaxReduce(LLVMBuilderRef B, LLVMValueRef Val); -LLVMValueRef ZigLLVMBuildFPMinReduce(LLVMBuilderRef B, LLVMValueRef Val); -LLVMValueRef ZigLLVMBuildAddReduce(LLVMBuilderRef B, LLVMValueRef Val); -LLVMValueRef ZigLLVMBuildMulReduce(LLVMBuilderRef B, LLVMValueRef Val); -LLVMValueRef ZigLLVMBuildFPAddReduce(LLVMBuilderRef B, LLVMValueRef Acc, LLVMValueRef Val); -LLVMValueRef ZigLLVMBuildFPMulReduce(LLVMBuilderRef B, LLVMValueRef Acc, LLVMValueRef Val); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildAndReduce(LLVMBuilderRef B, LLVMValueRef Val); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildOrReduce(LLVMBuilderRef B, LLVMValueRef Val); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildXorReduce(LLVMBuilderRef B, LLVMValueRef Val); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildIntMaxReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildIntMinReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildFPMaxReduce(LLVMBuilderRef B, LLVMValueRef Val); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildFPMinReduce(LLVMBuilderRef B, LLVMValueRef Val); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildAddReduce(LLVMBuilderRef B, LLVMValueRef Val); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildMulReduce(LLVMBuilderRef B, LLVMValueRef Val); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildFPAddReduce(LLVMBuilderRef B, LLVMValueRef Acc, LLVMValueRef Val); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildFPMulReduce(LLVMBuilderRef B, LLVMValueRef Acc, LLVMValueRef Val); + +ZIG_EXTERN_C void ZigLLVMTakeName(LLVMValueRef new_owner, LLVMValueRef victim); #define ZigLLVM_DIFlags_Zero 0U #define ZigLLVM_DIFlags_Private 1U diff --git a/test/behavior/switch.zig b/test/behavior/switch.zig index f5fc282b05..16bb890c9e 100644 --- a/test/behavior/switch.zig +++ b/test/behavior/switch.zig @@ -294,3 +294,22 @@ test "switch on union with some prongs capturing" { }; try expect(y == 11); } + +const Number = union(enum) { + One: u64, + Two: u8, + Three: f32, +}; + +const number = Number{ .Three = 1.23 }; + +fn returnsFalse() bool { + switch (number) { + Number.One => |x| return x > 1234, + Number.Two => |x| return x == 'a', + Number.Three => |x| return x > 12.34, + } +} +test "switch on const enum with var" { + try expect(!returnsFalse()); +} diff --git a/test/behavior/switch_stage1.zig b/test/behavior/switch_stage1.zig index 6a86eb9494..1b85d767d5 100644 --- a/test/behavior/switch_stage1.zig +++ b/test/behavior/switch_stage1.zig @@ -3,24 +3,6 @@ const expect = std.testing.expect; const expectError = std.testing.expectError; const expectEqual = std.testing.expectEqual; -const Number = union(enum) { - One: u64, - Two: u8, - Three: f32, -}; - -const number = Number{ .Three = 1.23 }; - -fn returnsFalse() bool { - switch (number) { - Number.One => |x| return x > 1234, - Number.Two => |x| return x == 'a', - Number.Three => |x| return x > 12.34, - } -} -test "switch on const enum with var" { - try expect(!returnsFalse()); -} test "switch all prongs unreachable" { try testAllProngsUnreachable(); comptime try testAllProngsUnreachable(); diff --git a/test/behavior/union.zig b/test/behavior/union.zig index e296f6bbb8..033d2cf40b 100644 --- a/test/behavior/union.zig +++ b/test/behavior/union.zig @@ -71,3 +71,84 @@ test "0-sized extern union definition" { try expect(U.f == 1); } + +const Value = union(enum) { + Int: u64, + Array: [9]u8, +}; + +const Agg = struct { + val1: Value, + val2: Value, +}; + +const v1 = Value{ .Int = 1234 }; +const v2 = Value{ .Array = [_]u8{3} ** 9 }; + +const err = @as(anyerror!Agg, Agg{ + .val1 = v1, + .val2 = v2, +}); + +const array = [_]Value{ v1, v2, v1, v2 }; + +test "unions embedded in aggregate types" { + switch (array[1]) { + Value.Array => |arr| try expect(arr[4] == 3), + else => unreachable, + } + switch ((err catch unreachable).val1) { + Value.Int => |x| try expect(x == 1234), + else => unreachable, + } +} + +test "access a member of tagged union with conflicting enum tag name" { + const Bar = union(enum) { + A: A, + B: B, + + const A = u8; + const B = void; + }; + + comptime try expect(Bar.A == u8); +} + +test "constant tagged union with payload" { + var empty = TaggedUnionWithPayload{ .Empty = {} }; + var full = TaggedUnionWithPayload{ .Full = 13 }; + shouldBeEmpty(empty); + shouldBeNotEmpty(full); +} + +fn shouldBeEmpty(x: TaggedUnionWithPayload) void { + switch (x) { + TaggedUnionWithPayload.Empty => {}, + else => unreachable, + } +} + +fn shouldBeNotEmpty(x: TaggedUnionWithPayload) void { + switch (x) { + TaggedUnionWithPayload.Empty => unreachable, + else => {}, + } +} + +const TaggedUnionWithPayload = union(enum) { + Empty: void, + Full: i32, +}; + +test "union alignment" { + comptime { + try expect(@alignOf(AlignTestTaggedUnion) >= @alignOf([9]u8)); + try expect(@alignOf(AlignTestTaggedUnion) >= @alignOf(u64)); + } +} + +const AlignTestTaggedUnion = union(enum) { + A: [9]u8, + B: u64, +}; diff --git a/test/behavior/union_stage1.zig b/test/behavior/union_stage1.zig index 6a68737ecf..ac69915a39 100644 --- a/test/behavior/union_stage1.zig +++ b/test/behavior/union_stage1.zig @@ -3,37 +3,6 @@ const expect = std.testing.expect; const expectEqual = std.testing.expectEqual; const Tag = std.meta.Tag; -const Value = union(enum) { - Int: u64, - Array: [9]u8, -}; - -const Agg = struct { - val1: Value, - val2: Value, -}; - -const v1 = Value{ .Int = 1234 }; -const v2 = Value{ .Array = [_]u8{3} ** 9 }; - -const err = @as(anyerror!Agg, Agg{ - .val1 = v1, - .val2 = v2, -}); - -const array = [_]Value{ v1, v2, v1, v2 }; - -test "unions embedded in aggregate types" { - switch (array[1]) { - Value.Array => |arr| try expect(arr[4] == 3), - else => unreachable, - } - switch ((err catch unreachable).val1) { - Value.Int => |x| try expect(x == 1234), - else => unreachable, - } -} - const Letter = enum { A, B, C }; const Payload = union(Letter) { A: i32, @@ -202,18 +171,6 @@ const PartialInstWithPayload = union(enum) { Compiled: i32, }; -test "access a member of tagged union with conflicting enum tag name" { - const Bar = union(enum) { - A: A, - B: B, - - const A = u8; - const B = void; - }; - - comptime try expect(Bar.A == u8); -} - test "tagged union initialization with runtime void" { try expect(testTaggedUnionInit({})); } @@ -775,41 +732,3 @@ test "tagged union as return value" { fn returnAnInt(x: i32) TaggedFoo { return TaggedFoo{ .One = x }; } - -test "constant tagged union with payload" { - var empty = TaggedUnionWithPayload{ .Empty = {} }; - var full = TaggedUnionWithPayload{ .Full = 13 }; - shouldBeEmpty(empty); - shouldBeNotEmpty(full); -} - -fn shouldBeEmpty(x: TaggedUnionWithPayload) void { - switch (x) { - TaggedUnionWithPayload.Empty => {}, - else => unreachable, - } -} - -fn shouldBeNotEmpty(x: TaggedUnionWithPayload) void { - switch (x) { - TaggedUnionWithPayload.Empty => unreachable, - else => {}, - } -} - -const TaggedUnionWithPayload = union(enum) { - Empty: void, - Full: i32, -}; - -test "union alignment" { - comptime { - try expect(@alignOf(AlignTestTaggedUnion) >= @alignOf([9]u8)); - try expect(@alignOf(AlignTestTaggedUnion) >= @alignOf(u64)); - } -} - -const AlignTestTaggedUnion = union(enum) { - A: [9]u8, - B: u64, -};