From c0aa4a1a42b3e0d312bd274799be67d60a1c0238 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 27 Sep 2021 19:48:42 -0700 Subject: [PATCH] stage2: implement basic unions * AIR instructions struct_field_ptr and related functions now are also emitted by the frontend for unions. Backends must inspect the type of the pointer operand to lower the instructions correctly. - These will be renamed to `agg_field_ptr` (short for "aggregate") in the future. * Introduce the new `set_union_tag` AIR instruction. * Introduce `Module.EnumNumbered` and associated `Type` methods. This is for enums which have no decls, but do have the possibility of overriding the integer tag type and tag values. * Sema: Implement support for union tag types in both the auto-generated and explicitly-provided cases, as well as explicitly provided enum tag values in union declarations. * LLVM backend: implement lowering union types, union field pointer instructions, and the new `set_union_tag` instruction. --- src/Air.zig | 14 +- src/Liveness.zig | 1 + src/Module.zig | 107 ++++++++++- src/Sema.zig | 320 +++++++++++++++++++++++++++------ src/codegen.zig | 9 + src/codegen/c.zig | 16 ++ src/codegen/llvm.zig | 80 ++++++++- src/print_air.zig | 1 + src/type.zig | 170 +++++++++--------- test/behavior/union.zig | 12 ++ test/behavior/union_stage1.zig | 7 - 11 files changed, 576 insertions(+), 161 deletions(-) diff --git a/src/Air.zig b/src/Air.zig index 4341271f3a..40070dccfb 100644 --- a/src/Air.zig +++ b/src/Air.zig @@ -270,19 +270,26 @@ pub const Inst = struct { /// wrap from E to E!T /// Uses the `ty_op` field. wrap_errunion_err, - /// Given a pointer to a struct and a field index, returns a pointer to the field. + /// Given a pointer to a struct or union and a field index, returns a pointer to the field. /// Uses the `ty_pl` field, payload is `StructField`. + /// TODO rename to `agg_field_ptr`. struct_field_ptr, - /// Given a pointer to a struct, returns a pointer to the field. + /// Given a pointer to a struct or union, returns a pointer to the field. /// The field index is the number at the end of the name. /// Uses `ty_op` field. + /// TODO rename to `agg_field_ptr_index_X` struct_field_ptr_index_0, struct_field_ptr_index_1, struct_field_ptr_index_2, struct_field_ptr_index_3, - /// Given a byval struct and a field index, returns the field byval. + /// Given a byval struct or union and a field index, returns the field byval. /// Uses the `ty_pl` field, payload is `StructField`. + /// TODO rename to `agg_field_val` struct_field_val, + /// Given a pointer to a tagged union, set its tag to the provided value. + /// Result type is always void. + /// Uses the `bin_op` field. LHS is union pointer, RHS is new tag value. + set_union_tag, /// Given a slice value, return the length. /// Result type is always usize. /// Uses the `ty_op` field. @@ -643,6 +650,7 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index) Type { .atomic_store_seq_cst, .memset, .memcpy, + .set_union_tag, => return Type.initTag(.void), .ptrtoint, diff --git a/src/Liveness.zig b/src/Liveness.zig index 42ab1ab351..9a7126d135 100644 --- a/src/Liveness.zig +++ b/src/Liveness.zig @@ -256,6 +256,7 @@ fn analyzeInst( .atomic_store_monotonic, .atomic_store_release, .atomic_store_seq_cst, + .set_union_tag, => { const o = inst_datas[inst].bin_op; return trackOperands(a, new_set, inst, main_tomb, .{ o.lhs, o.rhs, .none }); diff --git a/src/Module.zig b/src/Module.zig index dbece09255..83bbbb6366 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -859,6 +859,36 @@ pub const EnumSimple = struct { } }; +/// Represents the data that an enum declaration provides, when there are no +/// declarations. However an integer tag type is provided, and the enum tag values +/// are explicitly provided. +pub const EnumNumbered = struct { + /// The Decl that corresponds to the enum itself. + owner_decl: *Decl, + /// An integer type which is used for the numerical value of the enum. + /// Whether zig chooses this type or the user specifies it, it is stored here. + tag_ty: Type, + /// Set of field names in declaration order. + fields: NameMap, + /// Maps integer tag value to field index. + /// Entries are in declaration order, same as `fields`. + /// If this hash map is empty, it means the enum tags are auto-numbered. + values: ValueMap, + /// Offset from `owner_decl`, points to the enum decl AST node. + node_offset: i32, + + pub const NameMap = EnumFull.NameMap; + pub const ValueMap = EnumFull.ValueMap; + + pub fn srcLoc(self: EnumNumbered) SrcLoc { + return .{ + .file_scope = self.owner_decl.getFileScope(), + .parent_decl_node = self.owner_decl.src_node, + .lazy = .{ .node_offset = self.node_offset }, + }; + } +}; + /// Represents the data that an enum declaration provides, when there is /// at least one tag value explicitly specified, or at least one declaration. pub const EnumFull = struct { @@ -868,16 +898,17 @@ pub const EnumFull = struct { /// Whether zig chooses this type or the user specifies it, it is stored here. tag_ty: Type, /// Set of field names in declaration order. - fields: std.StringArrayHashMapUnmanaged(void), + fields: NameMap, /// Maps integer tag value to field index. /// Entries are in declaration order, same as `fields`. /// If this hash map is empty, it means the enum tags are auto-numbered. values: ValueMap, - /// Represents the declarations inside this struct. + /// Represents the declarations inside this enum. namespace: Scope.Namespace, /// Offset from `owner_decl`, points to the enum decl AST node. node_offset: i32, + pub const NameMap = std.StringArrayHashMapUnmanaged(void); pub const ValueMap = std.ArrayHashMapUnmanaged(Value, void, Value.ArrayHashContext, false); pub fn srcLoc(self: EnumFull) SrcLoc { @@ -933,6 +964,44 @@ pub const Union = struct { .lazy = .{ .node_offset = self.node_offset }, }; } + + pub fn haveFieldTypes(u: Union) bool { + return switch (u.status) { + .none, + .field_types_wip, + => false, + .have_field_types, + .layout_wip, + .have_layout, + => true, + }; + } + + pub fn onlyTagHasCodegenBits(u: Union) bool { + assert(u.haveFieldTypes()); + for (u.fields.values()) |field| { + if (field.ty.hasCodeGenBits()) return false; + } + return true; + } + + pub fn mostAlignedField(u: Union, target: Target) u32 { + assert(u.haveFieldTypes()); + var most_alignment: u64 = 0; + var most_index: usize = undefined; + for (u.fields.values()) |field, i| { + if (!field.ty.hasCodeGenBits()) continue; + const field_align = if (field.abi_align.tag() == .abi_align_default) + field.ty.abiAlignment(target) + else + field.abi_align.toUnsignedInt(); + if (field_align > most_alignment) { + most_alignment = field_align; + most_index = i; + } + } + return @intCast(u32, most_index); + } }; /// Some Fn struct memory is owned by the Decl's TypedValue.Managed arena allocator. @@ -1543,6 +1612,40 @@ pub const Scope = struct { }); } + pub fn addStructFieldPtr( + block: *Block, + struct_ptr: Air.Inst.Ref, + field_index: u32, + ptr_field_ty: Type, + ) !Air.Inst.Ref { + const ty = try block.sema.addType(ptr_field_ty); + const tag: Air.Inst.Tag = switch (field_index) { + 0 => .struct_field_ptr_index_0, + 1 => .struct_field_ptr_index_1, + 2 => .struct_field_ptr_index_2, + 3 => .struct_field_ptr_index_3, + else => { + return block.addInst(.{ + .tag = .struct_field_ptr, + .data = .{ .ty_pl = .{ + .ty = ty, + .payload = try block.sema.addExtra(Air.StructField{ + .struct_operand = struct_ptr, + .field_index = @intCast(u32, field_index), + }), + } }, + }); + }, + }; + return block.addInst(.{ + .tag = tag, + .data = .{ .ty_op = .{ + .ty = ty, + .operand = struct_ptr, + } }, + }); + } + pub fn addInst(block: *Block, inst: Air.Inst) error{OutOfMemory}!Air.Inst.Ref { return Air.indexToRef(try block.addInstAsIndex(inst)); } diff --git a/src/Sema.zig b/src/Sema.zig index 533252d682..f076389797 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -1625,7 +1625,7 @@ fn zirAllocMut(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) CompileEr if (block.is_comptime) { return sema.analyzeComptimeAlloc(block, var_type); } - try sema.validateVarType(block, ty_src, var_type); + try sema.validateVarType(block, ty_src, var_type, false); const ptr_type = try Type.ptr(sema.arena, .{ .pointee_type = var_type, .@"addrspace" = target_util.defaultAddressSpace(sema.mod.getTarget(), .local), @@ -1711,7 +1711,7 @@ fn zirResolveInferredAlloc(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Inde const peer_inst_list = inferred_alloc.data.stored_inst_list.items; const final_elem_ty = try sema.resolvePeerTypes(block, ty_src, peer_inst_list, .none); if (var_is_mut) { - try sema.validateVarType(block, ty_src, final_elem_ty); + try sema.validateVarType(block, ty_src, final_elem_ty, false); } // Change it to a normal alloc. const final_ptr_ty = try Type.ptr(sema.arena, .{ @@ -1730,19 +1730,82 @@ fn zirValidateStructInitPtr(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Ind const tracy = trace(@src()); defer tracy.end(); - const gpa = sema.gpa; - const mod = sema.mod; const validate_inst = sema.code.instructions.items(.data)[inst].pl_node; - const struct_init_src = validate_inst.src(); + const init_src = validate_inst.src(); const validate_extra = sema.code.extraData(Zir.Inst.Block, validate_inst.payload_index); const instrs = sema.code.extra[validate_extra.end..][0..validate_extra.data.body_len]; + const field_ptr_data = sema.code.instructions.items(.data)[instrs[0]].pl_node; + const field_ptr_extra = sema.code.extraData(Zir.Inst.Field, field_ptr_data.payload_index).data; + const object_ptr = sema.resolveInst(field_ptr_extra.lhs); + const agg_ty = sema.typeOf(object_ptr).elemType(); + switch (agg_ty.zigTypeTag()) { + .Struct => return sema.validateStructInitPtr( + block, + agg_ty.castTag(.@"struct").?.data, + init_src, + instrs, + ), + .Union => return sema.validateUnionInitPtr( + block, + agg_ty.cast(Type.Payload.Union).?.data, + init_src, + instrs, + object_ptr, + ), + else => unreachable, + } +} - const struct_obj: *Module.Struct = s: { - const field_ptr_data = sema.code.instructions.items(.data)[instrs[0]].pl_node; - const field_ptr_extra = sema.code.extraData(Zir.Inst.Field, field_ptr_data.payload_index).data; - const object_ptr = sema.resolveInst(field_ptr_extra.lhs); - break :s sema.typeOf(object_ptr).elemType().castTag(.@"struct").?.data; - }; +fn validateUnionInitPtr( + sema: *Sema, + block: *Scope.Block, + union_obj: *Module.Union, + init_src: LazySrcLoc, + instrs: []const Zir.Inst.Index, + union_ptr: Air.Inst.Ref, +) CompileError!void { + const mod = sema.mod; + + if (instrs.len != 1) { + // TODO add note for other field + // TODO add note for union declared here + return mod.fail(&block.base, init_src, "only one union field can be active at once", .{}); + } + + const field_ptr = instrs[0]; + const field_ptr_data = sema.code.instructions.items(.data)[field_ptr].pl_node; + const field_src: LazySrcLoc = .{ .node_offset_back2tok = field_ptr_data.src_node }; + const field_ptr_extra = sema.code.extraData(Zir.Inst.Field, field_ptr_data.payload_index).data; + const field_name = sema.code.nullTerminatedString(field_ptr_extra.field_name_start); + const field_index_big = union_obj.fields.getIndex(field_name) orelse + return sema.failWithBadUnionFieldAccess(block, union_obj, field_src, field_name); + const field_index = @intCast(u32, field_index_big); + + // TODO here we need to go back and see if we need to convert the union + // to a comptime-known value. This will involve editing the AIR code we have + // generated so far - in particular deleting some runtime pointer bitcast + // instructions which are not actually needed if the initialization expression + // ends up being comptime-known. + + // Otherwise, we set the new union tag now. + const new_tag = try sema.addConstant( + union_obj.tag_ty, + try Value.Tag.enum_field_index.create(sema.arena, field_index), + ); + + try sema.requireRuntimeBlock(block, init_src); + _ = try block.addBinOp(.set_union_tag, union_ptr, new_tag); +} + +fn validateStructInitPtr( + sema: *Sema, + block: *Scope.Block, + struct_obj: *Module.Struct, + init_src: LazySrcLoc, + instrs: []const Zir.Inst.Index, +) CompileError!void { + const gpa = sema.gpa; + const mod = sema.mod; // Maps field index to field_ptr index of where it was already initialized. const found_fields = try gpa.alloc(Zir.Inst.Index, struct_obj.fields.count()); @@ -1781,9 +1844,9 @@ fn zirValidateStructInitPtr(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Ind const template = "missing struct field: {s}"; const args = .{field_name}; if (root_msg) |msg| { - try mod.errNote(&block.base, struct_init_src, msg, template, args); + try mod.errNote(&block.base, init_src, msg, template, args); } else { - root_msg = try mod.errMsg(&block.base, struct_init_src, template, args); + root_msg = try mod.errMsg(&block.base, init_src, template, args); } } if (root_msg) |msg| { @@ -8037,7 +8100,7 @@ fn checkAtomicOperandType( const max_atomic_bits = target_util.largestAtomicBits(target); const int_ty = switch (ty.zigTypeTag()) { .Int => ty, - .Enum => ty.enumTagType(&buffer), + .Enum => ty.intTagType(&buffer), .Float => { const bit_count = ty.floatBits(target); if (bit_count > max_atomic_bits) { @@ -8621,11 +8684,7 @@ fn zirVarExtended( return sema.failWithNeededComptime(block, init_src); } else Value.initTag(.unreachable_value); - if (!var_ty.isValidVarType(small.is_extern)) { - return sema.mod.fail(&block.base, mut_src, "variable of type '{}' must be const", .{ - var_ty, - }); - } + try sema.validateVarType(block, mut_src, var_ty, small.is_extern); if (lib_name != null) { // Look at the sema code for functions which has this logic, it just needs to @@ -8810,9 +8869,54 @@ fn requireIntegerType(sema: *Sema, block: *Scope.Block, src: LazySrcLoc, ty: Typ } } -fn validateVarType(sema: *Sema, block: *Scope.Block, src: LazySrcLoc, ty: Type) !void { - if (!ty.isValidVarType(false)) { - return sema.mod.fail(&block.base, src, "variable of type '{}' must be const or comptime", .{ty}); +/// Emit a compile error if type cannot be used for a runtime variable. +fn validateVarType( + sema: *Sema, + block: *Scope.Block, + src: LazySrcLoc, + var_ty: Type, + is_extern: bool, +) CompileError!void { + var ty = var_ty; + const ok: bool = while (true) switch (ty.zigTypeTag()) { + .Bool, + .Int, + .Float, + .ErrorSet, + .Enum, + .Frame, + .AnyFrame, + => break true, + + .BoundFn, + .ComptimeFloat, + .ComptimeInt, + .EnumLiteral, + .NoReturn, + .Type, + .Void, + .Undefined, + .Null, + => break false, + + .Opaque => break is_extern, + + .Optional => { + var buf: Type.Payload.ElemType = undefined; + const child_ty = ty.optionalChild(&buf); + return validateVarType(sema, block, src, child_ty, is_extern); + }, + .Pointer, .Array, .Vector => ty = ty.elemType(), + .ErrorUnion => ty = ty.errorUnionPayload(), + + .Fn => @panic("TODO fn validateVarType"), + .Struct, .Union => { + const resolved_ty = try sema.resolveTypeFields(block, src, ty); + break !resolved_ty.requiresComptime(); + }, + } else unreachable; // TODO should not need else unreachable + if (!ok) { + return sema.mod.fail(&block.base, src, "variable of type '{}' must be const or comptime", .{var_ty}); } } @@ -9393,8 +9497,9 @@ fn structFieldPtr( const struct_ty = try sema.resolveTypeFields(block, src, unresolved_struct_ty); const struct_obj = struct_ty.castTag(.@"struct").?.data; - const field_index = struct_obj.fields.getIndex(field_name) orelse + const field_index_big = struct_obj.fields.getIndex(field_name) orelse return sema.failWithBadFieldAccess(block, struct_obj, field_name_src, field_name); + const field_index = @intCast(u32, field_index_big); const field = struct_obj.fields.values()[field_index]; const ptr_field_ty = try Type.ptr(arena, .{ .pointee_type = field.ty, @@ -9413,31 +9518,7 @@ fn structFieldPtr( } try sema.requireRuntimeBlock(block, src); - const tag: Air.Inst.Tag = switch (field_index) { - 0 => .struct_field_ptr_index_0, - 1 => .struct_field_ptr_index_1, - 2 => .struct_field_ptr_index_2, - 3 => .struct_field_ptr_index_3, - else => { - return block.addInst(.{ - .tag = .struct_field_ptr, - .data = .{ .ty_pl = .{ - .ty = try sema.addType(ptr_field_ty), - .payload = try sema.addExtra(Air.StructField{ - .struct_operand = struct_ptr, - .field_index = @intCast(u32, field_index), - }), - } }, - }); - }, - }; - return block.addInst(.{ - .tag = tag, - .data = .{ .ty_op = .{ - .ty = try sema.addType(ptr_field_ty), - .operand = struct_ptr, - } }, - }); + return block.addStructFieldPtr(struct_ptr, field_index, ptr_field_ty); } fn structFieldVal( @@ -9487,7 +9568,6 @@ fn unionFieldPtr( field_name_src: LazySrcLoc, unresolved_union_ty: Type, ) CompileError!Air.Inst.Ref { - const mod = sema.mod; const arena = sema.arena; assert(unresolved_union_ty.zigTypeTag() == .Union); @@ -9495,8 +9575,9 @@ fn unionFieldPtr( const union_ty = try sema.resolveTypeFields(block, src, unresolved_union_ty); const union_obj = union_ty.cast(Type.Payload.Union).?.data; - const field_index = union_obj.fields.getIndex(field_name) orelse + const field_index_big = union_obj.fields.getIndex(field_name) orelse return sema.failWithBadUnionFieldAccess(block, union_obj, field_name_src, field_name); + const field_index = @intCast(u32, field_index_big); const field = union_obj.fields.values()[field_index]; const ptr_field_ty = try Type.ptr(arena, .{ @@ -9517,7 +9598,7 @@ fn unionFieldPtr( } try sema.requireRuntimeBlock(block, src); - return mod.fail(&block.base, src, "TODO implement runtime union field access", .{}); + return block.addStructFieldPtr(union_ptr, field_index, ptr_field_ty); } fn unionFieldVal( @@ -11160,6 +11241,28 @@ fn analyzeUnionFields( if (body.len != 0) { _ = try sema.analyzeBody(block, body); } + var int_tag_ty: Type = undefined; + var enum_field_names: ?*Module.EnumNumbered.NameMap = null; + var enum_value_map: ?*Module.EnumNumbered.ValueMap = null; + if (tag_type_ref != .none) { + const provided_ty = try sema.resolveType(block, src, tag_type_ref); + if (small.auto_enum_tag) { + // The provided type is an integer type and we must construct the enum tag type here. + int_tag_ty = provided_ty; + union_obj.tag_ty = try sema.generateUnionTagTypeNumbered(block, fields_len, provided_ty); + enum_field_names = &union_obj.tag_ty.castTag(.enum_numbered).?.data.fields; + enum_value_map = &union_obj.tag_ty.castTag(.enum_numbered).?.data.values; + } else { + // The provided type is the enum tag type. + union_obj.tag_ty = provided_ty; + } + } else { + // If auto_enum_tag is false, this is an untagged union. However, for semantic analysis + // purposes, we still auto-generate an enum tag type the same way. That the union is + // untagged is represented by the Type tag (union vs union_tagged). + union_obj.tag_ty = try sema.generateUnionTagTypeSimple(block, fields_len); + enum_field_names = &union_obj.tag_ty.castTag(.enum_simple).?.data.fields; + } const bits_per_field = 4; const fields_per_u32 = 32 / bits_per_field; @@ -11198,12 +11301,25 @@ fn analyzeUnionFields( break :blk align_ref; } else .none; - if (has_tag) { + const tag_ref: Zir.Inst.Ref = if (has_tag) blk: { + const tag_ref = @intToEnum(Zir.Inst.Ref, zir.extra[extra_index]); extra_index += 1; + break :blk tag_ref; + } else .none; + + if (enum_value_map) |map| { + const tag_src = src; // TODO better source location + const coerced = try sema.coerce(block, int_tag_ty, tag_ref, tag_src); + const val = try sema.resolveConstValue(block, tag_src, coerced); + map.putAssumeCapacityContext(val, {}, .{ .ty = int_tag_ty }); } // This string needs to outlive the ZIR code. const field_name = try decl_arena.allocator.dupe(u8, field_name_zir); + if (enum_field_names) |set| { + set.putAssumeCapacity(field_name, {}); + } + const field_ty: Type = if (field_type_ref == .none) Type.initTag(.void) else @@ -11225,11 +11341,84 @@ fn analyzeUnionFields( // But only resolve the source location if we need to emit a compile error. const abi_align_val = (try sema.resolveInstConst(block, src, align_ref)).val; gop.value_ptr.abi_align = try abi_align_val.copy(&decl_arena.allocator); + } else { + gop.value_ptr.abi_align = Value.initTag(.abi_align_default); } } +} - // TODO resolve the union tag_type_ref - _ = tag_type_ref; +fn generateUnionTagTypeNumbered( + sema: *Sema, + block: *Scope.Block, + fields_len: u32, + int_ty: Type, +) !Type { + const mod = sema.mod; + + var new_decl_arena = std.heap.ArenaAllocator.init(sema.gpa); + errdefer new_decl_arena.deinit(); + + const enum_obj = try new_decl_arena.allocator.create(Module.EnumNumbered); + const enum_ty_payload = try new_decl_arena.allocator.create(Type.Payload.EnumNumbered); + enum_ty_payload.* = .{ + .base = .{ .tag = .enum_numbered }, + .data = enum_obj, + }; + const enum_ty = Type.initPayload(&enum_ty_payload.base); + const enum_val = try Value.Tag.ty.create(&new_decl_arena.allocator, enum_ty); + // TODO better type name + const new_decl = try mod.createAnonymousDecl(&block.base, .{ + .ty = Type.initTag(.type), + .val = enum_val, + }); + new_decl.owns_tv = true; + errdefer sema.mod.deleteAnonDecl(&block.base, new_decl); + + enum_obj.* = .{ + .owner_decl = new_decl, + .tag_ty = int_ty, + .fields = .{}, + .values = .{}, + .node_offset = 0, + }; + // Here we pre-allocate the maps using the decl arena. + try enum_obj.fields.ensureTotalCapacity(&new_decl_arena.allocator, fields_len); + try enum_obj.values.ensureTotalCapacityContext(&new_decl_arena.allocator, fields_len, .{ .ty = int_ty }); + try new_decl.finalizeNewArena(&new_decl_arena); + return enum_ty; +} + +fn generateUnionTagTypeSimple(sema: *Sema, block: *Scope.Block, fields_len: u32) !Type { + const mod = sema.mod; + + var new_decl_arena = std.heap.ArenaAllocator.init(sema.gpa); + errdefer new_decl_arena.deinit(); + + const enum_obj = try new_decl_arena.allocator.create(Module.EnumSimple); + const enum_ty_payload = try new_decl_arena.allocator.create(Type.Payload.EnumSimple); + enum_ty_payload.* = .{ + .base = .{ .tag = .enum_simple }, + .data = enum_obj, + }; + const enum_ty = Type.initPayload(&enum_ty_payload.base); + const enum_val = try Value.Tag.ty.create(&new_decl_arena.allocator, enum_ty); + // TODO better type name + const new_decl = try mod.createAnonymousDecl(&block.base, .{ + .ty = Type.initTag(.type), + .val = enum_val, + }); + new_decl.owns_tv = true; + errdefer sema.mod.deleteAnonDecl(&block.base, new_decl); + + enum_obj.* = .{ + .owner_decl = new_decl, + .fields = .{}, + .node_offset = 0, + }; + // Here we pre-allocate the maps using the decl arena. + try enum_obj.fields.ensureTotalCapacity(&new_decl_arena.allocator, fields_len); + try new_decl.finalizeNewArena(&new_decl_arena); + return enum_ty; } fn getBuiltin( @@ -11367,11 +11556,28 @@ fn typeHasOnePossibleValue( } return Value.initTag(.empty_struct_value); }, + .enum_numbered => { + const resolved_ty = try sema.resolveTypeFields(block, src, ty); + const enum_obj = resolved_ty.castTag(.enum_numbered).?.data; + if (enum_obj.fields.count() == 1) { + if (enum_obj.values.count() == 0) { + return Value.initTag(.zero); // auto-numbered + } else { + return enum_obj.values.keys()[0]; + } + } else { + return null; + } + }, .enum_full => { const resolved_ty = try sema.resolveTypeFields(block, src, ty); - const enum_full = resolved_ty.castTag(.enum_full).?.data; - if (enum_full.fields.count() == 1) { - return enum_full.values.keys()[0]; + const enum_obj = resolved_ty.castTag(.enum_full).?.data; + if (enum_obj.fields.count() == 1) { + if (enum_obj.values.count() == 0) { + return Value.initTag(.zero); // auto-numbered + } else { + return enum_obj.values.keys()[0]; + } } else { return null; } diff --git a/src/codegen.zig b/src/codegen.zig index 102f8d4985..6a605edca9 100644 --- a/src/codegen.zig +++ b/src/codegen.zig @@ -889,6 +889,7 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type { .atomic_load => try self.airAtomicLoad(inst), .memcpy => try self.airMemcpy(inst), .memset => try self.airMemset(inst), + .set_union_tag => try self.airSetUnionTag(inst), .atomic_store_unordered => try self.airAtomicStore(inst, .Unordered), .atomic_store_monotonic => try self.airAtomicStore(inst, .Monotonic), @@ -1543,6 +1544,14 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type { return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none }); } + fn airSetUnionTag(self: *Self, inst: Air.Inst.Index) !void { + const bin_op = self.air.instructions.items(.data)[inst].bin_op; + const result: MCValue = switch (arch) { + else => return self.fail("TODO implement airSetUnionTag for {}", .{self.target.cpu.arch}), + }; + return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none }); + } + fn reuseOperand(self: *Self, inst: Air.Inst.Index, operand: Air.Inst.Ref, op_index: Liveness.OperandInt, mcv: MCValue) bool { if (!self.liveness.operandDies(inst, op_index)) return false; diff --git a/src/codegen/c.zig b/src/codegen/c.zig index 16b13db292..fc0c86b8f1 100644 --- a/src/codegen/c.zig +++ b/src/codegen/c.zig @@ -955,6 +955,7 @@ fn genBody(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, OutO .atomic_load => try airAtomicLoad(f, inst), .memset => try airMemset(f, inst), .memcpy => try airMemcpy(f, inst), + .set_union_tag => try airSetUnionTag(f, inst), .int_to_float, .float_to_int, @@ -2080,6 +2081,21 @@ fn airMemcpy(f: *Function, inst: Air.Inst.Index) !CValue { return CValue.none; } +fn airSetUnionTag(f: *Function, inst: Air.Inst.Index) !CValue { + const bin_op = f.air.instructions.items(.data)[inst].bin_op; + const union_ptr = try f.resolveInst(bin_op.lhs); + const new_tag = try f.resolveInst(bin_op.rhs); + const writer = f.object.writer(); + + try writer.writeAll("*"); + try f.writeCValue(writer, union_ptr); + try writer.writeAll(" = "); + try f.writeCValue(writer, new_tag); + try writer.writeAll(";\n"); + + return CValue.none; +} + fn toMemoryOrder(order: std.builtin.AtomicOrder) [:0]const u8 { return switch (order) { .Unordered => "memory_order_relaxed", diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index b15834c963..ab164b5d91 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -735,7 +735,7 @@ pub const DeclGen = struct { }, .Enum => { var buffer: Type.Payload.Bits = undefined; - const int_ty = t.enumTagType(&buffer); + const int_ty = t.intTagType(&buffer); const bit_count = int_ty.intInfo(self.module.getTarget()).bits; return self.context.intType(bit_count); }, @@ -812,6 +812,29 @@ pub const DeclGen = struct { .False, ); }, + .Union => { + const union_obj = t.castTag(.@"union").?.data; + assert(union_obj.haveFieldTypes()); + + const enum_tag_ty = union_obj.tag_ty; + const enum_tag_llvm_ty = try self.llvmType(enum_tag_ty); + + if (union_obj.onlyTagHasCodegenBits()) { + return enum_tag_llvm_ty; + } + + const target = self.module.getTarget(); + const most_aligned_field_index = union_obj.mostAlignedField(target); + const most_aligned_field = union_obj.fields.values()[most_aligned_field_index]; + // TODO handle when the most aligned field is different than the + // biggest sized field. + + const llvm_fields = [_]*const llvm.Type{ + try self.llvmType(most_aligned_field.ty), + enum_tag_llvm_ty, + }; + return self.context.structType(&llvm_fields, llvm_fields.len, .False); + }, .Fn => { const ret_ty = try self.llvmType(t.fnReturnType()); const params_len = t.fnParamLen(); @@ -840,7 +863,6 @@ pub const DeclGen = struct { .BoundFn => @panic("TODO remove BoundFn from the language"), - .Union, .Opaque, .Frame, .AnyFrame, @@ -1131,7 +1153,7 @@ pub const DeclGen = struct { var buffer: Type.Payload.Bits = undefined; const int_ty = switch (ty.zigTypeTag()) { .Int => ty, - .Enum => ty.enumTagType(&buffer), + .Enum => ty.intTagType(&buffer), .Float => { if (!is_rmw_xchg) return null; return dg.context.intType(@intCast(c_uint, ty.abiSize(target) * 8)); @@ -1281,6 +1303,7 @@ pub const FuncGen = struct { .atomic_load => try self.airAtomicLoad(inst), .memset => try self.airMemset(inst), .memcpy => try self.airMemcpy(inst), + .set_union_tag => try self.airSetUnionTag(inst), .atomic_store_unordered => try self.airAtomicStore(inst, .Unordered), .atomic_store_monotonic => try self.airAtomicStore(inst, .Monotonic), @@ -1381,7 +1404,7 @@ pub const FuncGen = struct { const int_ty = switch (operand_ty.zigTypeTag()) { .Enum => blk: { var buffer: Type.Payload.Bits = undefined; - const int_ty = operand_ty.enumTagType(&buffer); + const int_ty = operand_ty.intTagType(&buffer); break :blk int_ty; }, .Int, .Bool, .Pointer, .ErrorSet => operand_ty, @@ -1660,8 +1683,9 @@ pub const FuncGen = struct { const ty_pl = self.air.instructions.items(.data)[inst].ty_pl; const struct_field = self.air.extraData(Air.StructField, ty_pl.payload).data; const struct_ptr = try self.resolveInst(struct_field.struct_operand); + const struct_ptr_ty = self.air.typeOf(struct_field.struct_operand); const field_index = @intCast(c_uint, struct_field.field_index); - return self.builder.buildStructGEP(struct_ptr, field_index, ""); + return self.fieldPtr(inst, struct_ptr, struct_ptr_ty, field_index); } fn airStructFieldPtrIndex(self: *FuncGen, inst: Air.Inst.Index, field_index: c_uint) !?*const llvm.Value { @@ -1670,7 +1694,8 @@ pub const FuncGen = struct { const ty_op = self.air.instructions.items(.data)[inst].ty_op; const struct_ptr = try self.resolveInst(ty_op.operand); - return self.builder.buildStructGEP(struct_ptr, field_index, ""); + const struct_ptr_ty = self.air.typeOf(ty_op.operand); + return self.fieldPtr(inst, struct_ptr, struct_ptr_ty, field_index); } fn airStructFieldVal(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value { @@ -2521,6 +2546,49 @@ pub const FuncGen = struct { return null; } + fn airSetUnionTag(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value { + const bin_op = self.air.instructions.items(.data)[inst].bin_op; + const union_ptr = try self.resolveInst(bin_op.lhs); + // TODO handle when onlyTagHasCodegenBits() == true + const new_tag = try self.resolveInst(bin_op.rhs); + const tag_field_ptr = self.builder.buildStructGEP(union_ptr, 1, ""); + + _ = self.builder.buildStore(new_tag, tag_field_ptr); + return null; + } + + fn fieldPtr( + self: *FuncGen, + inst: Air.Inst.Index, + struct_ptr: *const llvm.Value, + struct_ptr_ty: Type, + field_index: c_uint, + ) !?*const llvm.Value { + const struct_ty = struct_ptr_ty.childType(); + switch (struct_ty.zigTypeTag()) { + .Struct => return self.builder.buildStructGEP(struct_ptr, field_index, ""), + .Union => return self.unionFieldPtr(inst, struct_ptr, struct_ty, field_index), + else => unreachable, + } + } + + fn unionFieldPtr( + self: *FuncGen, + inst: Air.Inst.Index, + union_ptr: *const llvm.Value, + union_ty: Type, + field_index: c_uint, + ) !?*const llvm.Value { + const union_obj = union_ty.cast(Type.Payload.Union).?.data; + const field = &union_obj.fields.values()[field_index]; + const result_llvm_ty = try self.dg.llvmType(self.air.typeOfIndex(inst)); + if (!field.ty.hasCodeGenBits()) { + return null; + } + const union_field_ptr = self.builder.buildStructGEP(union_ptr, 0, ""); + return self.builder.buildBitCast(union_field_ptr, result_llvm_ty, ""); + } + fn getIntrinsic(self: *FuncGen, name: []const u8) *const llvm.Value { const id = llvm.lookupIntrinsicID(name.ptr, name.len); assert(id != 0); diff --git a/src/print_air.zig b/src/print_air.zig index fa384baae0..e735d03bd3 100644 --- a/src/print_air.zig +++ b/src/print_air.zig @@ -130,6 +130,7 @@ const Writer = struct { .ptr_ptr_elem_val, .shl, .shr, + .set_union_tag, => try w.writeBinOp(s, inst), .is_null, diff --git a/src/type.zig b/src/type.zig index 48c65c1008..bb798959f4 100644 --- a/src/type.zig +++ b/src/type.zig @@ -124,6 +124,7 @@ pub const Type = extern union { .enum_full, .enum_nonexhaustive, .enum_simple, + .enum_numbered, .atomic_order, .atomic_rmw_op, .calling_convention, @@ -874,6 +875,7 @@ pub const Type = extern union { .@"struct" => return self.copyPayloadShallow(allocator, Payload.Struct), .@"union", .union_tagged => return self.copyPayloadShallow(allocator, Payload.Union), .enum_simple => return self.copyPayloadShallow(allocator, Payload.EnumSimple), + .enum_numbered => return self.copyPayloadShallow(allocator, Payload.EnumNumbered), .enum_full, .enum_nonexhaustive => return self.copyPayloadShallow(allocator, Payload.EnumFull), .@"opaque" => return self.copyPayloadShallow(allocator, Payload.Opaque), } @@ -958,6 +960,10 @@ pub const Type = extern union { const enum_simple = ty.castTag(.enum_simple).?.data; return enum_simple.owner_decl.renderFullyQualifiedName(writer); }, + .enum_numbered => { + const enum_numbered = ty.castTag(.enum_numbered).?.data; + return enum_numbered.owner_decl.renderFullyQualifiedName(writer); + }, .@"opaque" => { // TODO use declaration name return writer.writeAll("opaque {}"); @@ -1268,6 +1274,7 @@ pub const Type = extern union { .@"union", .union_tagged, .enum_simple, + .enum_numbered, .enum_full, .enum_nonexhaustive, => false, // TODO some of these should be `true` depending on their child types @@ -1421,7 +1428,7 @@ pub const Type = extern union { const enum_simple = self.castTag(.enum_simple).?.data; return enum_simple.fields.count() >= 2; }, - .enum_nonexhaustive => { + .enum_numbered, .enum_nonexhaustive => { var buffer: Payload.Bits = undefined; const int_tag_ty = self.intTagType(&buffer); return int_tag_ty.hasCodeGenBits(); @@ -1682,7 +1689,7 @@ pub const Type = extern union { assert(biggest != 0); return biggest; }, - .enum_full, .enum_nonexhaustive, .enum_simple => { + .enum_full, .enum_nonexhaustive, .enum_simple, .enum_numbered => { var buffer: Payload.Bits = undefined; const int_tag_ty = self.intTagType(&buffer); return int_tag_ty.abiAlignment(target); @@ -1781,7 +1788,7 @@ pub const Type = extern union { } return size; }, - .enum_simple, .enum_full, .enum_nonexhaustive => { + .enum_simple, .enum_full, .enum_nonexhaustive, .enum_numbered => { var buffer: Payload.Bits = undefined; const int_tag_ty = self.intTagType(&buffer); return int_tag_ty.abiSize(target); @@ -1948,7 +1955,7 @@ pub const Type = extern union { .@"struct" => { @panic("TODO bitSize struct"); }, - .enum_simple, .enum_full, .enum_nonexhaustive => { + .enum_simple, .enum_full, .enum_nonexhaustive, .enum_numbered => { var buffer: Payload.Bits = undefined; const int_tag_ty = self.intTagType(&buffer); return int_tag_ty.bitSize(target); @@ -2094,23 +2101,6 @@ pub const Type = extern union { }; } - /// Asserts the type is an enum. - pub fn intTagType(self: Type, buffer: *Payload.Bits) Type { - switch (self.tag()) { - .enum_full, .enum_nonexhaustive => return self.cast(Payload.EnumFull).?.data.tag_ty, - .enum_simple => { - const enum_simple = self.castTag(.enum_simple).?.data; - const bits = std.math.log2_int_ceil(usize, enum_simple.fields.count()); - buffer.* = .{ - .base = .{ .tag = .int_unsigned }, - .data = bits, - }; - return Type.initPayload(&buffer.base); - }, - else => unreachable, - } - } - pub fn isSinglePointer(self: Type) bool { return switch (self.tag()) { .single_const_pointer, @@ -2363,48 +2353,6 @@ pub const Type = extern union { } } - /// Returns if type can be used for a runtime variable - pub fn isValidVarType(self: Type, is_extern: bool) bool { - var ty = self; - while (true) switch (ty.zigTypeTag()) { - .Bool, - .Int, - .Float, - .ErrorSet, - .Enum, - .Frame, - .AnyFrame, - => return true, - - .Opaque => return is_extern, - .BoundFn, - .ComptimeFloat, - .ComptimeInt, - .EnumLiteral, - .NoReturn, - .Type, - .Void, - .Undefined, - .Null, - => return false, - - .Optional => { - var buf: Payload.ElemType = undefined; - return ty.optionalChild(&buf).isValidVarType(is_extern); - }, - .Pointer, .Array, .Vector => ty = ty.elemType(), - .ErrorUnion => ty = ty.errorUnionPayload(), - - .Fn => @panic("TODO fn isValidVarType"), - .Struct => { - // TODO this is not always correct; introduce lazy value mechanism - // and here we need to force a resolve of "type requires comptime". - return true; - }, - .Union => @panic("TODO union isValidVarType"), - }; - } - pub fn childType(ty: Type) Type { return switch (ty.tag()) { .vector => ty.castTag(.vector).?.data.elem_type, @@ -2530,6 +2478,15 @@ pub const Type = extern union { } } + /// Returns the tag type of a union, if the type is a union and it has a tag type. + /// Otherwise, returns `null`. + pub fn unionTagType(ty: Type) ?Type { + return switch (ty.tag()) { + .union_tagged => ty.castTag(.union_tagged).?.data.tag_ty, + else => null, + }; + } + /// Asserts that the type is an error union. pub fn errorUnionPayload(self: Type) Type { return switch (self.tag()) { @@ -3000,6 +2957,7 @@ pub const Type = extern union { } }, .enum_nonexhaustive => ty = ty.castTag(.enum_nonexhaustive).?.data.tag_ty, + .enum_numbered => ty = ty.castTag(.enum_numbered).?.data.tag_ty, .@"union" => { return null; // TODO }, @@ -3114,31 +3072,21 @@ pub const Type = extern union { } } - /// Returns the integer tag type of the enum. - pub fn enumTagType(ty: Type, buffer: *Payload.Bits) Type { - switch (ty.tag()) { - .enum_full, .enum_nonexhaustive => { - const enum_full = ty.cast(Payload.EnumFull).?.data; - return enum_full.tag_ty; - }, + /// Asserts the type is an enum or a union. + /// TODO support unions + pub fn intTagType(self: Type, buffer: *Payload.Bits) Type { + switch (self.tag()) { + .enum_full, .enum_nonexhaustive => return self.cast(Payload.EnumFull).?.data.tag_ty, + .enum_numbered => return self.castTag(.enum_numbered).?.data.tag_ty, .enum_simple => { - const enum_simple = ty.castTag(.enum_simple).?.data; + const enum_simple = self.castTag(.enum_simple).?.data; + const bits = std.math.log2_int_ceil(usize, enum_simple.fields.count()); buffer.* = .{ .base = .{ .tag = .int_unsigned }, - .data = std.math.log2_int_ceil(usize, enum_simple.fields.count()), + .data = bits, }; return Type.initPayload(&buffer.base); }, - .atomic_order, - .atomic_rmw_op, - .calling_convention, - .float_mode, - .reduce_op, - .call_options, - .export_options, - .extern_options, - => @panic("TODO resolve std.builtin types"), - else => unreachable, } } @@ -3156,10 +3104,8 @@ pub const Type = extern union { const enum_full = ty.cast(Payload.EnumFull).?.data; return enum_full.fields.count(); }, - .enum_simple => { - const enum_simple = ty.castTag(.enum_simple).?.data; - return enum_simple.fields.count(); - }, + .enum_simple => return ty.castTag(.enum_simple).?.data.fields.count(), + .enum_numbered => return ty.castTag(.enum_numbered).?.data.fields.count(), .atomic_order, .atomic_rmw_op, .calling_convention, @@ -3185,6 +3131,10 @@ pub const Type = extern union { const enum_simple = ty.castTag(.enum_simple).?.data; return enum_simple.fields.keys()[field_index]; }, + .enum_numbered => { + const enum_numbered = ty.castTag(.enum_numbered).?.data; + return enum_numbered.fields.keys()[field_index]; + }, .atomic_order, .atomic_rmw_op, .calling_convention, @@ -3209,6 +3159,10 @@ pub const Type = extern union { const enum_simple = ty.castTag(.enum_simple).?.data; return enum_simple.fields.getIndex(field_name); }, + .enum_numbered => { + const enum_numbered = ty.castTag(.enum_numbered).?.data; + return enum_numbered.fields.getIndex(field_name); + }, .atomic_order, .atomic_rmw_op, .calling_convention, @@ -3252,6 +3206,15 @@ pub const Type = extern union { return enum_full.values.getIndexContext(enum_tag, .{ .ty = tag_ty }); } }, + .enum_numbered => { + const enum_obj = ty.castTag(.enum_numbered).?.data; + const tag_ty = enum_obj.tag_ty; + if (enum_obj.values.count() == 0) { + return S.fieldWithRange(tag_ty, enum_tag, enum_obj.fields.count()); + } else { + return enum_obj.values.getIndexContext(enum_tag, .{ .ty = tag_ty }); + } + }, .enum_simple => { const enum_simple = ty.castTag(.enum_simple).?.data; const fields_len = enum_simple.fields.count(); @@ -3303,6 +3266,7 @@ pub const Type = extern union { const enum_full = ty.cast(Payload.EnumFull).?.data; return enum_full.srcLoc(); }, + .enum_numbered => return ty.castTag(.enum_numbered).?.data.srcLoc(), .enum_simple => { const enum_simple = ty.castTag(.enum_simple).?.data; return enum_simple.srcLoc(); @@ -3340,6 +3304,7 @@ pub const Type = extern union { const enum_full = ty.cast(Payload.EnumFull).?.data; return enum_full.owner_decl; }, + .enum_numbered => return ty.castTag(.enum_numbered).?.data.owner_decl, .enum_simple => { const enum_simple = ty.castTag(.enum_simple).?.data; return enum_simple.owner_decl; @@ -3397,6 +3362,15 @@ pub const Type = extern union { return enum_full.values.containsContext(int, .{ .ty = tag_ty }); } }, + .enum_numbered => { + const enum_obj = ty.castTag(.enum_numbered).?.data; + const tag_ty = enum_obj.tag_ty; + if (enum_obj.values.count() == 0) { + return S.intInRange(tag_ty, int, enum_obj.fields.count()); + } else { + return enum_obj.values.containsContext(int, .{ .ty = tag_ty }); + } + }, .enum_simple => { const enum_simple = ty.castTag(.enum_simple).?.data; const fields_len = enum_simple.fields.count(); @@ -3534,6 +3508,7 @@ pub const Type = extern union { @"union", union_tagged, enum_simple, + enum_numbered, enum_full, enum_nonexhaustive, @@ -3642,6 +3617,7 @@ pub const Type = extern union { .@"union", .union_tagged => Payload.Union, .enum_full, .enum_nonexhaustive => Payload.EnumFull, .enum_simple => Payload.EnumSimple, + .enum_numbered => Payload.EnumNumbered, .empty_struct => Payload.ContainerScope, }; } @@ -3818,6 +3794,11 @@ pub const Type = extern union { base: Payload = .{ .tag = .enum_simple }, data: *Module.EnumSimple, }; + + pub const EnumNumbered = struct { + base: Payload = .{ .tag = .enum_numbered }, + data: *Module.EnumNumbered, + }; }; pub fn ptr(arena: *Allocator, d: Payload.Pointer.Data) !Type { @@ -3850,6 +3831,23 @@ pub const Type = extern union { }; return Type.initPayload(&type_payload.base); } + + pub fn smallestUnsignedInt(arena: *Allocator, max: u64) !Type { + const bits = bits: { + if (max == 0) break :bits 0; + const base = std.math.log2(max); + const upper = (@as(u64, 1) << base) - 1; + break :bits base + @boolToInt(upper < max); + }; + return switch (bits) { + 1 => initTag(.u1), + 8 => initTag(.u8), + 16 => initTag(.u16), + 32 => initTag(.u32), + 64 => initTag(.u64), + else => return Tag.int_unsigned.create(arena, bits), + }; + } }; pub const CType = enum { diff --git a/test/behavior/union.zig b/test/behavior/union.zig index 14b5e374dd..6b8705e044 100644 --- a/test/behavior/union.zig +++ b/test/behavior/union.zig @@ -2,3 +2,15 @@ const std = @import("std"); const expect = std.testing.expect; const expectEqual = std.testing.expectEqual; const Tag = std.meta.Tag; + +const Foo = union { + float: f64, + int: i32, +}; + +test "basic unions" { + var foo = Foo{ .int = 1 }; + try expect(foo.int == 1); + foo = Foo{ .float = 12.34 }; + try expect(foo.float == 12.34); +} diff --git a/test/behavior/union_stage1.zig b/test/behavior/union_stage1.zig index 086bd981cd..5741858d51 100644 --- a/test/behavior/union_stage1.zig +++ b/test/behavior/union_stage1.zig @@ -39,13 +39,6 @@ const Foo = union { int: i32, }; -test "basic unions" { - var foo = Foo{ .int = 1 }; - try expect(foo.int == 1); - foo = Foo{ .float = 12.34 }; - try expect(foo.float == 12.34); -} - test "comptime union field access" { comptime { var foo = Foo{ .int = 0 };