From f276fd0f3728bf1a43b185e3e2d33d593309cb2f Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 14 Nov 2017 23:53:53 -0500 Subject: [PATCH] basic union support See #144 --- src/all_types.hpp | 43 +++++++- src/analyze.cpp | 223 ++++++++++++++++++++++++++++++++++++++-- src/analyze.hpp | 1 + src/codegen.cpp | 101 +++++++++++++++++- src/ir.cpp | 171 ++++++++++++++++++++++++++++-- src/ir_print.cpp | 22 ++++ src/zig_llvm.cpp | 4 + src/zig_llvm.hpp | 1 + test/cases/union.zig | 13 +++ test/compile_errors.zig | 6 +- 10 files changed, 559 insertions(+), 26 deletions(-) diff --git a/src/all_types.hpp b/src/all_types.hpp index 0b4efd8415..ca6c214af8 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -73,6 +73,7 @@ enum ConstParentId { ConstParentIdNone, ConstParentIdStruct, ConstParentIdArray, + ConstParentIdUnion, }; struct ConstParent { @@ -87,6 +88,9 @@ struct ConstParent { ConstExprValue *struct_val; size_t field_index; } p_struct; + struct { + ConstExprValue *union_val; + } p_union; } data; }; @@ -100,6 +104,11 @@ struct ConstStructValue { ConstParent parent; }; +struct ConstUnionValue { + ConstExprValue *value; + ConstParent parent; +}; + enum ConstArraySpecial { ConstArraySpecialNone, ConstArraySpecialUndef, @@ -238,6 +247,7 @@ struct ConstExprValue { ErrorTableEntry *x_pure_err; ConstEnumValue x_enum; ConstStructValue x_struct; + ConstUnionValue x_union; ConstArrayValue x_array; ConstPtrValue x_ptr; ImportTableEntry *x_import; @@ -336,6 +346,12 @@ struct TypeEnumField { uint32_t gen_index; }; +struct TypeUnionField { + Buf *name; + TypeTableEntry *type_entry; + uint32_t gen_index; +}; + enum NodeType { NodeTypeRoot, NodeTypeFnProto, @@ -1026,9 +1042,9 @@ struct TypeTableEntryUnion { ContainerLayout layout; uint32_t src_field_count; uint32_t gen_field_count; - TypeStructField *fields; - uint64_t size_bytes; + TypeUnionField *fields; bool is_invalid; // true if any fields are invalid + ScopeDecls *decls_scope; // set this flag temporarily to detect infinite loops @@ -1039,6 +1055,10 @@ struct TypeTableEntryUnion { bool zero_bits_loop_flag; bool zero_bits_known; + uint32_t abi_alignment; // also figured out with zero_bits pass + + uint32_t size_bytes; + TypeTableEntry *most_aligned_union_member; }; struct FnGenParamInfo { @@ -1796,6 +1816,7 @@ enum IrInstructionId { IrInstructionIdFieldPtr, IrInstructionIdStructFieldPtr, IrInstructionIdEnumFieldPtr, + IrInstructionIdUnionFieldPtr, IrInstructionIdElemPtr, IrInstructionIdVarPtr, IrInstructionIdCall, @@ -1805,6 +1826,7 @@ enum IrInstructionId { IrInstructionIdContainerInitList, IrInstructionIdContainerInitFields, IrInstructionIdStructInit, + IrInstructionIdUnionInit, IrInstructionIdUnreachable, IrInstructionIdTypeOf, IrInstructionIdToPtrType, @@ -2060,6 +2082,14 @@ struct IrInstructionEnumFieldPtr { bool is_const; }; +struct IrInstructionUnionFieldPtr { + IrInstruction base; + + IrInstruction *union_ptr; + TypeUnionField *field; + bool is_const; +}; + struct IrInstructionElemPtr { IrInstruction base; @@ -2150,6 +2180,15 @@ struct IrInstructionStructInit { LLVMValueRef tmp_ptr; }; +struct IrInstructionUnionInit { + IrInstruction base; + + TypeTableEntry *union_type; + TypeUnionField *field; + IrInstruction *init_value; + LLVMValueRef tmp_ptr; +}; + struct IrInstructionUnreachable { IrInstruction base; }; diff --git a/src/analyze.cpp b/src/analyze.cpp index 343b1ecdb0..2f7eecaff4 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -992,18 +992,22 @@ TypeTableEntry *get_partial_container_type(CodeGen *g, Scope *scope, ContainerKi TypeTableEntryId type_id = container_to_type(kind); TypeTableEntry *entry = new_container_type_entry(type_id, decl_node, scope); + unsigned dwarf_kind; switch (kind) { case ContainerKindStruct: entry->data.structure.decl_node = decl_node; entry->data.structure.layout = layout; + dwarf_kind = ZigLLVMTag_DW_structure_type(); break; case ContainerKindEnum: entry->data.enumeration.decl_node = decl_node; entry->data.enumeration.layout = layout; + dwarf_kind = ZigLLVMTag_DW_structure_type(); break; case ContainerKindUnion: entry->data.unionation.decl_node = decl_node; entry->data.unionation.layout = layout; + dwarf_kind = ZigLLVMTag_DW_union_type(); break; } @@ -1012,7 +1016,7 @@ TypeTableEntry *get_partial_container_type(CodeGen *g, Scope *scope, ContainerKi ImportTableEntry *import = get_scope_import(scope); entry->type_ref = LLVMStructCreateNamed(LLVMGetGlobalContext(), name); entry->di_type = ZigLLVMCreateReplaceableCompositeType(g->dbuilder, - ZigLLVMTag_DW_structure_type(), name, + dwarf_kind, name, ZigLLVMFileToScope(import->di_file), import->di_file, (unsigned)(line + 1)); buf_init_from_str(&entry->name, name); @@ -1285,7 +1289,7 @@ static void resolve_enum_type(CodeGen *g, TypeTableEntry *enum_type) { return; resolve_enum_zero_bits(g, enum_type); - if (enum_type->data.enumeration.is_invalid) + if (type_is_invalid(enum_type)) return; AstNode *decl_node = enum_type->data.enumeration.decl_node; @@ -1834,7 +1838,140 @@ static void resolve_struct_type(CodeGen *g, TypeTableEntry *struct_type) { } static void resolve_union_type(CodeGen *g, TypeTableEntry *union_type) { - zig_panic("TODO"); + assert(union_type->id == TypeTableEntryIdUnion); + + if (union_type->data.unionation.complete) + return; + + resolve_union_zero_bits(g, union_type); + if (type_is_invalid(union_type)) + return; + + AstNode *decl_node = union_type->data.unionation.decl_node; + + if (union_type->data.unionation.embedded_in_current) { + if (!union_type->data.unionation.reported_infinite_err) { + union_type->data.unionation.reported_infinite_err = true; + add_node_error(g, decl_node, buf_sprintf("union '%s' contains itself", buf_ptr(&union_type->name))); + } + return; + } + + assert(!union_type->data.unionation.zero_bits_loop_flag); + assert(decl_node->type == NodeTypeContainerDecl); + assert(union_type->di_type); + + uint32_t field_count = union_type->data.unionation.src_field_count; + + assert(union_type->data.unionation.fields); + + uint32_t gen_field_count = union_type->data.unionation.gen_field_count; + ZigLLVMDIType **union_inner_di_types = allocate(gen_field_count); + + TypeTableEntry *most_aligned_union_member = nullptr; + uint64_t size_of_most_aligned_member_in_bits = 0; + uint64_t biggest_align_in_bits = 0; + uint64_t biggest_size_in_bits = 0; + + Scope *scope = &union_type->data.unionation.decls_scope->base; + ImportTableEntry *import = get_scope_import(scope); + + // set temporary flag + union_type->data.unionation.embedded_in_current = true; + + for (uint32_t i = 0; i < field_count; i += 1) { + AstNode *field_node = decl_node->data.container_decl.fields.at(i); + TypeUnionField *type_union_field = &union_type->data.unionation.fields[i]; + TypeTableEntry *field_type = type_union_field->type_entry; + + ensure_complete_type(g, field_type); + if (type_is_invalid(field_type)) { + union_type->data.unionation.is_invalid = true; + continue; + } + + if (!type_has_bits(field_type)) + continue; + + uint64_t store_size_in_bits = 8*LLVMStoreSizeOfType(g->target_data_ref, field_type->type_ref); + uint64_t abi_align_in_bits = 8*LLVMABIAlignmentOfType(g->target_data_ref, field_type->type_ref); + + assert(store_size_in_bits > 0); + assert(abi_align_in_bits > 0); + + union_inner_di_types[type_union_field->gen_index] = ZigLLVMCreateDebugMemberType(g->dbuilder, + ZigLLVMTypeToScope(union_type->di_type), buf_ptr(type_union_field->name), + import->di_file, (unsigned)(field_node->line + 1), + store_size_in_bits, + abi_align_in_bits, + 0, + 0, field_type->di_type); + + biggest_size_in_bits = max(biggest_size_in_bits, store_size_in_bits); + + if (!most_aligned_union_member || abi_align_in_bits > biggest_align_in_bits) { + most_aligned_union_member = field_type; + biggest_align_in_bits = abi_align_in_bits; + size_of_most_aligned_member_in_bits = store_size_in_bits; + } + } + + // unset temporary flag + union_type->data.unionation.embedded_in_current = false; + union_type->data.unionation.complete = true; + union_type->data.unionation.size_bytes = biggest_size_in_bits / 8; + union_type->data.unionation.most_aligned_union_member = most_aligned_union_member; + + if (union_type->data.unionation.is_invalid) + return; + + if (union_type->zero_bits) { + union_type->type_ref = LLVMVoidType(); + + uint64_t debug_size_in_bits = 0; + uint64_t debug_align_in_bits = 0; + ZigLLVMDIType **di_root_members = nullptr; + size_t debug_member_count = 0; + ZigLLVMDIType *replacement_di_type = ZigLLVMCreateDebugUnionType(g->dbuilder, + ZigLLVMFileToScope(import->di_file), + buf_ptr(&union_type->name), + import->di_file, (unsigned)(decl_node->line + 1), + debug_size_in_bits, + debug_align_in_bits, + 0, di_root_members, (int)debug_member_count, 0, ""); + + ZigLLVMReplaceTemporary(g->dbuilder, union_type->di_type, replacement_di_type); + union_type->di_type = replacement_di_type; + return; + } + + assert(most_aligned_union_member != nullptr); + + // create llvm type for union + uint64_t padding_in_bits = biggest_size_in_bits - size_of_most_aligned_member_in_bits; + if (padding_in_bits > 0) { + TypeTableEntry *u8_type = get_int_type(g, false, 8); + TypeTableEntry *padding_array = get_array_type(g, u8_type, padding_in_bits / 8); + LLVMTypeRef union_element_types[] = { + most_aligned_union_member->type_ref, + padding_array->type_ref, + }; + LLVMStructSetBody(union_type->type_ref, union_element_types, 2, false); + } else { + LLVMStructSetBody(union_type->type_ref, &most_aligned_union_member->type_ref, 1, false); + } + + assert(8*LLVMABIAlignmentOfType(g->target_data_ref, union_type->type_ref) >= biggest_align_in_bits); + assert(8*LLVMStoreSizeOfType(g->target_data_ref, union_type->type_ref) >= biggest_size_in_bits); + + // create debug type for union + ZigLLVMDIType *replacement_di_type = ZigLLVMCreateDebugUnionType(g->dbuilder, + ZigLLVMFileToScope(import->di_file), buf_ptr(&union_type->name), + import->di_file, (unsigned)(decl_node->line + 1), + biggest_size_in_bits, biggest_align_in_bits, 0, union_inner_di_types, + gen_field_count, 0, ""); + ZigLLVMReplaceTemporary(g->dbuilder, union_type->di_type, replacement_di_type); + union_type->di_type = replacement_di_type; } static void resolve_enum_zero_bits(CodeGen *g, TypeTableEntry *enum_type) { @@ -1873,7 +2010,7 @@ static void resolve_enum_zero_bits(CodeGen *g, TypeTableEntry *enum_type) { type_enum_field->value = i; type_ensure_zero_bits_known(g, field_type); - if (field_type->id == TypeTableEntryIdInvalid) { + if (type_is_invalid(field_type)) { enum_type->data.enumeration.is_invalid = true; continue; } @@ -1980,7 +2117,66 @@ static void resolve_struct_zero_bits(CodeGen *g, TypeTableEntry *struct_type) { } static void resolve_union_zero_bits(CodeGen *g, TypeTableEntry *union_type) { - zig_panic("TODO resolve_union_zero_bits"); + assert(union_type->id == TypeTableEntryIdUnion); + + if (union_type->data.unionation.zero_bits_known) + return; + + if (union_type->data.unionation.zero_bits_loop_flag) { + union_type->data.unionation.zero_bits_known = true; + return; + } + + union_type->data.unionation.zero_bits_loop_flag = true; + + AstNode *decl_node = union_type->data.unionation.decl_node; + assert(decl_node->type == NodeTypeContainerDecl); + assert(union_type->di_type); + + assert(!union_type->data.unionation.fields); + uint32_t field_count = (uint32_t)decl_node->data.container_decl.fields.length; + union_type->data.unionation.src_field_count = field_count; + union_type->data.unionation.fields = allocate(field_count); + + uint32_t biggest_align_bytes = 0; + + Scope *scope = &union_type->data.unionation.decls_scope->base; + + uint32_t gen_field_index = 0; + for (uint32_t i = 0; i < field_count; i += 1) { + AstNode *field_node = decl_node->data.container_decl.fields.at(i); + TypeUnionField *type_union_field = &union_type->data.unionation.fields[i]; + type_union_field->name = field_node->data.struct_field.name; + TypeTableEntry *field_type = analyze_type_expr(g, scope, field_node->data.struct_field.type); + type_union_field->type_entry = field_type; + + type_ensure_zero_bits_known(g, field_type); + if (type_is_invalid(field_type)) { + union_type->data.unionation.is_invalid = true; + continue; + } + + if (!type_has_bits(field_type)) + continue; + + type_union_field->gen_index = gen_field_index; + gen_field_index += 1; + + uint32_t field_align_bytes = get_abi_alignment(g, field_type); + if (field_align_bytes > biggest_align_bytes) { + biggest_align_bytes = field_align_bytes; + } + } + + union_type->data.unionation.zero_bits_loop_flag = false; + union_type->data.unionation.gen_field_count = gen_field_index; + union_type->zero_bits = (gen_field_index == 0); + union_type->data.unionation.zero_bits_known = true; + + // also compute abi_alignment + if (!union_type->zero_bits) { + union_type->data.unionation.abi_alignment = biggest_align_bytes; + } } static void get_fully_qualified_decl_name_internal(Buf *buf, Scope *scope, uint8_t sep) { @@ -2851,6 +3047,18 @@ TypeStructField *find_struct_type_field(TypeTableEntry *type_entry, Buf *name) { return nullptr; } +TypeUnionField *find_union_type_field(TypeTableEntry *type_entry, Buf *name) { + assert(type_entry->id == TypeTableEntryIdUnion); + assert(type_entry->data.unionation.complete); + for (uint32_t i = 0; i < type_entry->data.unionation.src_field_count; i += 1) { + TypeUnionField *field = &type_entry->data.unionation.fields[i]; + if (buf_eql_buf(field->name, name)) { + return field; + } + } + return nullptr; +} + static bool is_container(TypeTableEntry *type_entry) { switch (type_entry->id) { case TypeTableEntryIdInvalid: @@ -4703,6 +4911,8 @@ ConstParent *get_const_val_parent(CodeGen *g, ConstExprValue *value) { return &value->data.x_array.s_none.parent; } else if (type_entry->id == TypeTableEntryIdStruct) { return &value->data.x_struct.parent; + } else if (type_entry->id == TypeTableEntryIdUnion) { + return &value->data.x_union.parent; } return nullptr; } @@ -4914,7 +5124,8 @@ uint32_t get_abi_alignment(CodeGen *g, TypeTableEntry *type_entry) { assert(type_entry->data.enumeration.abi_alignment != 0); return type_entry->data.enumeration.abi_alignment; } else if (type_entry->id == TypeTableEntryIdUnion) { - zig_panic("TODO"); + assert(type_entry->data.unionation.abi_alignment != 0); + return type_entry->data.unionation.abi_alignment; } else if (type_entry->id == TypeTableEntryIdOpaque) { return 1; } else { diff --git a/src/analyze.hpp b/src/analyze.hpp index 4f56640592..b2464af9a0 100644 --- a/src/analyze.hpp +++ b/src/analyze.hpp @@ -63,6 +63,7 @@ void resolve_container_type(CodeGen *g, TypeTableEntry *type_entry); TypeStructField *find_struct_type_field(TypeTableEntry *type_entry, Buf *name); ScopeDecls *get_container_scope(TypeTableEntry *type_entry); TypeEnumField *find_enum_type_field(TypeTableEntry *enum_type, Buf *name); +TypeUnionField *find_union_type_field(TypeTableEntry *type_entry, Buf *name); bool is_container_ref(TypeTableEntry *type_entry); void scan_decls(CodeGen *g, ScopeDecls *decls_scope, AstNode *node); void scan_import(CodeGen *g, ImportTableEntry *import); diff --git a/src/codegen.cpp b/src/codegen.cpp index 17c0ffc653..fc949f2ecd 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -2393,6 +2393,27 @@ static LLVMValueRef ir_render_enum_field_ptr(CodeGen *g, IrExecutable *executabl return bitcasted_union_field_ptr; } +static LLVMValueRef ir_render_union_field_ptr(CodeGen *g, IrExecutable *executable, + IrInstructionUnionFieldPtr *instruction) +{ + TypeTableEntry *union_ptr_type = instruction->union_ptr->value.type; + assert(union_ptr_type->id == TypeTableEntryIdPointer); + TypeTableEntry *union_type = union_ptr_type->data.pointer.child_type; + assert(union_type->id == TypeTableEntryIdUnion); + + TypeUnionField *field = instruction->field; + + if (!type_has_bits(field->type_entry)) + return nullptr; + + LLVMValueRef union_ptr = ir_llvm_value(g, instruction->union_ptr); + LLVMTypeRef field_type_ref = LLVMPointerType(field->type_entry->type_ref, 0); + LLVMValueRef union_field_ptr = LLVMBuildStructGEP(g->builder, union_ptr, 0, ""); + LLVMValueRef bitcasted_union_field_ptr = LLVMBuildBitCast(g->builder, union_field_ptr, field_type_ref, ""); + + return bitcasted_union_field_ptr; +} + static size_t find_asm_index(CodeGen *g, AstNode *node, AsmToken *tok) { const char *ptr = buf_ptr(node->data.asm_expr.asm_template) + tok->start + 2; size_t len = tok->end - tok->start - 2; @@ -3365,6 +3386,25 @@ static LLVMValueRef ir_render_struct_init(CodeGen *g, IrExecutable *executable, return instruction->tmp_ptr; } +static LLVMValueRef ir_render_union_init(CodeGen *g, IrExecutable *executable, IrInstructionUnionInit *instruction) { + TypeUnionField *type_union_field = instruction->field; + + assert(type_has_bits(type_union_field->type_entry)); + + LLVMValueRef field_ptr = LLVMBuildStructGEP(g->builder, instruction->tmp_ptr, (unsigned)0, ""); + LLVMValueRef value = ir_llvm_value(g, instruction->init_value); + + uint32_t field_align_bytes = get_abi_alignment(g, type_union_field->type_entry); + + TypeTableEntry *ptr_type = get_pointer_to_type_extra(g, type_union_field->type_entry, + false, false, field_align_bytes, + 0, 0); + + gen_assign_raw(g, field_ptr, ptr_type, value); + + return instruction->tmp_ptr; +} + static LLVMValueRef ir_render_container_init_list(CodeGen *g, IrExecutable *executable, IrInstructionContainerInitList *instruction) { @@ -3486,6 +3526,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, return ir_render_struct_field_ptr(g, executable, (IrInstructionStructFieldPtr *)instruction); case IrInstructionIdEnumFieldPtr: return ir_render_enum_field_ptr(g, executable, (IrInstructionEnumFieldPtr *)instruction); + case IrInstructionIdUnionFieldPtr: + return ir_render_union_field_ptr(g, executable, (IrInstructionUnionFieldPtr *)instruction); case IrInstructionIdAsm: return ir_render_asm(g, executable, (IrInstructionAsm *)instruction); case IrInstructionIdTestNonNull: @@ -3544,6 +3586,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, return ir_render_init_enum(g, executable, (IrInstructionInitEnum *)instruction); case IrInstructionIdStructInit: return ir_render_struct_init(g, executable, (IrInstructionStructInit *)instruction); + case IrInstructionIdUnionInit: + return ir_render_union_init(g, executable, (IrInstructionUnionInit *)instruction); case IrInstructionIdPtrCast: return ir_render_ptr_cast(g, executable, (IrInstructionPtrCast *)instruction); case IrInstructionIdBitCast: @@ -3595,6 +3639,7 @@ static void ir_render(CodeGen *g, FnTableEntry *fn_entry) { static LLVMValueRef gen_const_ptr_struct_recursive(CodeGen *g, ConstExprValue *struct_const_val, size_t field_index); static LLVMValueRef gen_const_ptr_array_recursive(CodeGen *g, ConstExprValue *array_const_val, size_t index); +static LLVMValueRef gen_const_ptr_union_recursive(CodeGen *g, ConstExprValue *array_const_val); static LLVMValueRef gen_parent_ptr(CodeGen *g, ConstExprValue *val, ConstParent *parent) { switch (parent->id) { @@ -3608,6 +3653,8 @@ static LLVMValueRef gen_parent_ptr(CodeGen *g, ConstExprValue *val, ConstParent case ConstParentIdArray: return gen_const_ptr_array_recursive(g, parent->data.p_array.array_val, parent->data.p_array.elem_index); + case ConstParentIdUnion: + return gen_const_ptr_union_recursive(g, parent->data.p_union.union_val); } zig_unreachable(); } @@ -3637,6 +3684,18 @@ static LLVMValueRef gen_const_ptr_struct_recursive(CodeGen *g, ConstExprValue *s return LLVMConstInBoundsGEP(base_ptr, indices, 2); } +static LLVMValueRef gen_const_ptr_union_recursive(CodeGen *g, ConstExprValue *union_const_val) { + ConstParent *parent = &union_const_val->data.x_union.parent; + LLVMValueRef base_ptr = gen_parent_ptr(g, union_const_val, parent); + + TypeTableEntry *u32 = g->builtin_types.entry_u32; + LLVMValueRef indices[] = { + LLVMConstNull(u32->type_ref), + LLVMConstInt(u32->type_ref, 0, false), + }; + return LLVMConstInBoundsGEP(base_ptr, indices, 2); +} + static LLVMValueRef pack_const_int(CodeGen *g, LLVMTypeRef big_int_type_ref, ConstExprValue *const_val) { switch (const_val->special) { case ConstValSpecialRuntime: @@ -3872,10 +3931,6 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) { return LLVMConstNamedStruct(type_entry->type_ref, fields, type_entry->data.structure.gen_field_count); } } - case TypeTableEntryIdUnion: - { - zig_panic("TODO"); - } case TypeTableEntryIdArray: { uint64_t len = type_entry->data.array.len; @@ -3898,6 +3953,41 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) { return LLVMConstArray(element_type_ref, values, (unsigned)len); } } + case TypeTableEntryIdUnion: + { + LLVMTypeRef union_type_ref = type_entry->type_ref; + ConstExprValue *payload_value = const_val->data.x_union.value; + assert(payload_value != nullptr); + + if (!type_has_bits(payload_value->type)) { + return LLVMGetUndef(union_type_ref); + } + + uint64_t field_type_bytes = LLVMStoreSizeOfType(g->target_data_ref, payload_value->type->type_ref); + uint64_t pad_bytes = type_entry->data.unionation.size_bytes - field_type_bytes; + + LLVMValueRef correctly_typed_value = gen_const_val(g, payload_value); + + bool make_unnamed_struct = is_llvm_value_unnamed_type(payload_value->type, correctly_typed_value) || + payload_value->type != type_entry->data.unionation.most_aligned_union_member; + + unsigned field_count; + LLVMValueRef fields[2]; + fields[0] = correctly_typed_value; + if (pad_bytes == 0) { + field_count = 1; + } else { + fields[0] = correctly_typed_value; + fields[1] = LLVMGetUndef(LLVMArrayType(LLVMInt8Type(), (unsigned)pad_bytes)); + field_count = 2; + } + + if (make_unnamed_struct) { + return LLVMConstStruct(fields, field_count, false); + } else { + return LLVMConstNamedStruct(type_entry->type_ref, fields, field_count); + } + } case TypeTableEntryIdEnum: { LLVMTypeRef tag_type_ref = type_entry->data.enumeration.tag_type->type_ref; @@ -4376,6 +4466,9 @@ static void do_code_gen(CodeGen *g) { } else if (instruction->id == IrInstructionIdStructInit) { IrInstructionStructInit *struct_init_instruction = (IrInstructionStructInit *)instruction; slot = &struct_init_instruction->tmp_ptr; + } else if (instruction->id == IrInstructionIdUnionInit) { + IrInstructionUnionInit *union_init_instruction = (IrInstructionUnionInit *)instruction; + slot = &union_init_instruction->tmp_ptr; } else if (instruction->id == IrInstructionIdCall) { IrInstructionCall *call_instruction = (IrInstructionCall *)instruction; slot = &call_instruction->tmp_ptr; diff --git a/src/ir.cpp b/src/ir.cpp index fdaced6806..6df6b4f828 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -227,6 +227,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionEnumFieldPtr *) return IrInstructionIdEnumFieldPtr; } +static constexpr IrInstructionId ir_instruction_id(IrInstructionUnionFieldPtr *) { + return IrInstructionIdUnionFieldPtr; +} + static constexpr IrInstructionId ir_instruction_id(IrInstructionElemPtr *) { return IrInstructionIdElemPtr; } @@ -351,6 +355,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionStructInit *) { return IrInstructionIdStructInit; } +static constexpr IrInstructionId ir_instruction_id(IrInstructionUnionInit *) { + return IrInstructionIdUnionInit; +} + static constexpr IrInstructionId ir_instruction_id(IrInstructionMinValue *) { return IrInstructionIdMinValue; } @@ -922,6 +930,27 @@ static IrInstruction *ir_build_enum_field_ptr_from(IrBuilder *irb, IrInstruction return new_instruction; } +static IrInstruction *ir_build_union_field_ptr(IrBuilder *irb, Scope *scope, AstNode *source_node, + IrInstruction *union_ptr, TypeUnionField *field) +{ + IrInstructionUnionFieldPtr *instruction = ir_build_instruction(irb, scope, source_node); + instruction->union_ptr = union_ptr; + instruction->field = field; + + ir_ref_instruction(union_ptr, irb->current_basic_block); + + return &instruction->base; +} + +static IrInstruction *ir_build_union_field_ptr_from(IrBuilder *irb, IrInstruction *old_instruction, + IrInstruction *union_ptr, TypeUnionField *type_union_field) +{ + IrInstruction *new_instruction = ir_build_union_field_ptr(irb, old_instruction->scope, + old_instruction->source_node, union_ptr, type_union_field); + ir_link_new_instruction(new_instruction, old_instruction); + return new_instruction; +} + static IrInstruction *ir_build_call(IrBuilder *irb, Scope *scope, AstNode *source_node, FnTableEntry *fn_entry, IrInstruction *fn_ref, size_t arg_count, IrInstruction **args, bool is_comptime, bool is_inline) @@ -1112,6 +1141,28 @@ static IrInstruction *ir_build_struct_init_from(IrBuilder *irb, IrInstruction *o return new_instruction; } +static IrInstruction *ir_build_union_init(IrBuilder *irb, Scope *scope, AstNode *source_node, + TypeTableEntry *union_type, TypeUnionField *field, IrInstruction *init_value) +{ + IrInstructionUnionInit *union_init_instruction = ir_build_instruction(irb, scope, source_node); + union_init_instruction->union_type = union_type; + union_init_instruction->field = field; + union_init_instruction->init_value = init_value; + + ir_ref_instruction(init_value, irb->current_basic_block); + + return &union_init_instruction->base; +} + +static IrInstruction *ir_build_union_init_from(IrBuilder *irb, IrInstruction *old_instruction, + TypeTableEntry *union_type, TypeUnionField *field, IrInstruction *init_value) +{ + IrInstruction *new_instruction = ir_build_union_init(irb, old_instruction->scope, + old_instruction->source_node, union_type, field, init_value); + ir_link_new_instruction(new_instruction, old_instruction); + return new_instruction; +} + static IrInstruction *ir_build_unreachable(IrBuilder *irb, Scope *scope, AstNode *source_node) { IrInstructionUnreachable *unreachable_instruction = ir_build_instruction(irb, scope, source_node); @@ -2422,6 +2473,13 @@ static IrInstruction *ir_instruction_enumfieldptr_get_dep(IrInstructionEnumField } } +static IrInstruction *ir_instruction_unionfieldptr_get_dep(IrInstructionUnionFieldPtr *instruction, size_t index) { + switch (index) { + case 0: return instruction->union_ptr; + default: return nullptr; + } +} + static IrInstruction *ir_instruction_elemptr_get_dep(IrInstructionElemPtr *instruction, size_t index) { switch (index) { case 0: return instruction->array_ptr; @@ -2485,6 +2543,13 @@ static IrInstruction *ir_instruction_structinit_get_dep(IrInstructionStructInit return nullptr; } +static IrInstruction *ir_instruction_unioninit_get_dep(IrInstructionUnionInit *instruction, size_t index) { + switch (index) { + case 0: return instruction->init_value; + default: return nullptr; + } +} + static IrInstruction *ir_instruction_unreachable_get_dep(IrInstructionUnreachable *instruction, size_t index) { return nullptr; } @@ -3099,6 +3164,8 @@ static IrInstruction *ir_instruction_get_dep(IrInstruction *instruction, size_t return ir_instruction_structfieldptr_get_dep((IrInstructionStructFieldPtr *) instruction, index); case IrInstructionIdEnumFieldPtr: return ir_instruction_enumfieldptr_get_dep((IrInstructionEnumFieldPtr *) instruction, index); + case IrInstructionIdUnionFieldPtr: + return ir_instruction_unionfieldptr_get_dep((IrInstructionUnionFieldPtr *) instruction, index); case IrInstructionIdElemPtr: return ir_instruction_elemptr_get_dep((IrInstructionElemPtr *) instruction, index); case IrInstructionIdVarPtr: @@ -3117,6 +3184,8 @@ static IrInstruction *ir_instruction_get_dep(IrInstruction *instruction, size_t return ir_instruction_containerinitfields_get_dep((IrInstructionContainerInitFields *) instruction, index); case IrInstructionIdStructInit: return ir_instruction_structinit_get_dep((IrInstructionStructInit *) instruction, index); + case IrInstructionIdUnionInit: + return ir_instruction_unioninit_get_dep((IrInstructionUnionInit *) instruction, index); case IrInstructionIdUnreachable: return ir_instruction_unreachable_get_dep((IrInstructionUnreachable *) instruction, index); case IrInstructionIdTypeOf: @@ -11417,8 +11486,20 @@ static TypeTableEntry *ir_analyze_container_member_access_inner(IrAnalyze *ira, return ir_analyze_ref(ira, &field_ptr_instruction->base, bound_fn_value, true, false); } } + const char *prefix_name; + if (is_slice(bare_struct_type)) { + prefix_name = ""; + } else if (bare_struct_type->id == TypeTableEntryIdStruct) { + prefix_name = "struct "; + } else if (bare_struct_type->id == TypeTableEntryIdEnum) { + prefix_name = "enum "; + } else if (bare_struct_type->id == TypeTableEntryIdUnion) { + prefix_name = "union "; + } else { + prefix_name = ""; + } ir_add_error_node(ira, field_ptr_instruction->base.source_node, - buf_sprintf("no member named '%s' in '%s'", buf_ptr(field_name), buf_ptr(&bare_struct_type->name))); + buf_sprintf("no member named '%s' in %s'%s'", buf_ptr(field_name), prefix_name, buf_ptr(&bare_struct_type->name))); return ira->codegen->builtin_types.entry_invalid; } @@ -11428,14 +11509,13 @@ static TypeTableEntry *ir_analyze_container_field_ptr(IrAnalyze *ira, Buf *field { TypeTableEntry *bare_type = container_ref_type(container_type); ensure_complete_type(ira->codegen, bare_type); + if (type_is_invalid(bare_type)) + return ira->codegen->builtin_types.entry_invalid; assert(container_ptr->value.type->id == TypeTableEntryIdPointer); bool is_const = container_ptr->value.type->data.pointer.is_const; bool is_volatile = container_ptr->value.type->data.pointer.is_volatile; if (bare_type->id == TypeTableEntryIdStruct) { - if (bare_type->data.structure.is_invalid) - return ira->codegen->builtin_types.entry_invalid; - TypeStructField *field = find_struct_type_field(bare_type, field_name); if (field) { bool is_packed = (bare_type->data.structure.layout == ContainerLayoutPacked); @@ -11476,9 +11556,6 @@ static TypeTableEntry *ir_analyze_container_field_ptr(IrAnalyze *ira, Buf *field field_ptr_instruction, container_ptr, container_type); } } else if (bare_type->id == TypeTableEntryIdEnum) { - if (bare_type->data.enumeration.is_invalid) - return ira->codegen->builtin_types.entry_invalid; - TypeEnumField *field = find_enum_type_field(bare_type, field_name); if (field) { ir_build_enum_field_ptr_from(&ira->new_irb, &field_ptr_instruction->base, container_ptr, field); @@ -11489,7 +11566,15 @@ static TypeTableEntry *ir_analyze_container_field_ptr(IrAnalyze *ira, Buf *field field_ptr_instruction, container_ptr, container_type); } } else if (bare_type->id == TypeTableEntryIdUnion) { - zig_panic("TODO"); + TypeUnionField *field = find_union_type_field(bare_type, field_name); + if (field) { + ir_build_union_field_ptr_from(&ira->new_irb, &field_ptr_instruction->base, container_ptr, field); + return get_pointer_to_type_extra(ira->codegen, field->type_entry, is_const, is_volatile, + get_abi_alignment(ira->codegen, field->type_entry), 0, 0); + } else { + return ir_analyze_container_member_access_inner(ira, bare_type, field_name, + field_ptr_instruction, container_ptr, container_type); + } } else { zig_unreachable(); } @@ -13033,9 +13118,70 @@ static TypeTableEntry *ir_analyze_instruction_ref(IrAnalyze *ira, IrInstructionR return ir_analyze_ref(ira, &ref_instruction->base, value, ref_instruction->is_const, ref_instruction->is_volatile); } +static TypeTableEntry *ir_analyze_container_init_fields_union(IrAnalyze *ira, IrInstruction *instruction, + TypeTableEntry *container_type, size_t instr_field_count, IrInstructionContainerInitFieldsField *fields) +{ + assert(container_type->id == TypeTableEntryIdUnion); + + ensure_complete_type(ira->codegen, container_type); + + if (instr_field_count != 1) { + ir_add_error(ira, instruction, + buf_sprintf("union initialization expects exactly one field")); + return ira->codegen->builtin_types.entry_invalid; + } + + IrInstructionContainerInitFieldsField *field = &fields[0]; + IrInstruction *field_value = field->value->other; + if (type_is_invalid(field_value->value.type)) + return ira->codegen->builtin_types.entry_invalid; + + TypeUnionField *type_field = find_union_type_field(container_type, field->name); + if (!type_field) { + ir_add_error_node(ira, field->source_node, + buf_sprintf("no member named '%s' in union '%s'", + buf_ptr(field->name), buf_ptr(&container_type->name))); + return ira->codegen->builtin_types.entry_invalid; + } + + if (type_is_invalid(type_field->type_entry)) + return ira->codegen->builtin_types.entry_invalid; + + IrInstruction *casted_field_value = ir_implicit_cast(ira, field_value, type_field->type_entry); + if (casted_field_value == ira->codegen->invalid_instruction) + return ira->codegen->builtin_types.entry_invalid; + + bool is_comptime = ir_should_inline(ira->new_irb.exec, instruction->scope); + if (is_comptime || casted_field_value->value.special != ConstValSpecialRuntime) { + ConstExprValue *field_val = ir_resolve_const(ira, casted_field_value, UndefOk); + if (!field_val) + return ira->codegen->builtin_types.entry_invalid; + + ConstExprValue *out_val = ir_build_const_from(ira, instruction); + out_val->data.x_union.value = field_val; + + ConstParent *parent = get_const_val_parent(ira->codegen, field_val); + if (parent != nullptr) { + parent->id = ConstParentIdUnion; + parent->data.p_union.union_val = out_val; + } + + return container_type; + } + + IrInstruction *new_instruction = ir_build_union_init_from(&ira->new_irb, instruction, + container_type, type_field, casted_field_value); + + ir_add_alloca(ira, new_instruction, container_type); + return container_type; +} + static TypeTableEntry *ir_analyze_container_init_fields(IrAnalyze *ira, IrInstruction *instruction, TypeTableEntry *container_type, size_t instr_field_count, IrInstructionContainerInitFieldsField *fields) { + if (container_type->id == TypeTableEntryIdUnion) { + return ir_analyze_container_init_fields_union(ira, instruction, container_type, instr_field_count, fields); + } if (container_type->id != TypeTableEntryIdStruct || is_slice(container_type)) { ir_add_error(ira, instruction, buf_sprintf("type '%s' does not support struct initialization syntax", @@ -13043,8 +13189,7 @@ static TypeTableEntry *ir_analyze_container_init_fields(IrAnalyze *ira, IrInstru return ira->codegen->builtin_types.entry_invalid; } - if (!type_is_complete(container_type)) - resolve_container_type(ira->codegen, container_type); + ensure_complete_type(ira->codegen, container_type); size_t actual_field_count = container_type->data.structure.src_field_count; @@ -13070,7 +13215,7 @@ static TypeTableEntry *ir_analyze_container_init_fields(IrAnalyze *ira, IrInstru TypeStructField *type_field = find_struct_type_field(container_type, field->name); if (!type_field) { ir_add_error_node(ira, field->source_node, - buf_sprintf("no member named '%s' in '%s'", + buf_sprintf("no member named '%s' in struct '%s'", buf_ptr(field->name), buf_ptr(&container_type->name))); return ira->codegen->builtin_types.entry_invalid; } @@ -15657,8 +15802,10 @@ static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructi case IrInstructionIdIntToErr: case IrInstructionIdErrToInt: case IrInstructionIdStructInit: + case IrInstructionIdUnionInit: case IrInstructionIdStructFieldPtr: case IrInstructionIdEnumFieldPtr: + case IrInstructionIdUnionFieldPtr: case IrInstructionIdInitEnum: case IrInstructionIdMaybeWrap: case IrInstructionIdErrWrapCode: @@ -15968,6 +16115,7 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdContainerInitList: case IrInstructionIdContainerInitFields: case IrInstructionIdStructInit: + case IrInstructionIdUnionInit: case IrInstructionIdFieldPtr: case IrInstructionIdElemPtr: case IrInstructionIdVarPtr: @@ -15977,6 +16125,7 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdArrayLen: case IrInstructionIdStructFieldPtr: case IrInstructionIdEnumFieldPtr: + case IrInstructionIdUnionFieldPtr: case IrInstructionIdArrayType: case IrInstructionIdSliceType: case IrInstructionIdSizeOf: diff --git a/src/ir_print.cpp b/src/ir_print.cpp index 1c60d68628..55ad3ceb6c 100644 --- a/src/ir_print.cpp +++ b/src/ir_print.cpp @@ -290,6 +290,15 @@ static void ir_print_struct_init(IrPrint *irp, IrInstructionStructInit *instruct fprintf(irp->f, "} // struct init"); } +static void ir_print_union_init(IrPrint *irp, IrInstructionUnionInit *instruction) { + Buf *field_name = instruction->field->name; + + fprintf(irp->f, "%s {", buf_ptr(&instruction->union_type->name)); + fprintf(irp->f, ".%s = ", buf_ptr(field_name)); + ir_print_other_instruction(irp, instruction->init_value); + fprintf(irp->f, "} // union init"); +} + static void ir_print_unreachable(IrPrint *irp, IrInstructionUnreachable *instruction) { fprintf(irp->f, "unreachable"); } @@ -359,6 +368,13 @@ static void ir_print_enum_field_ptr(IrPrint *irp, IrInstructionEnumFieldPtr *ins fprintf(irp->f, ")"); } +static void ir_print_union_field_ptr(IrPrint *irp, IrInstructionUnionFieldPtr *instruction) { + fprintf(irp->f, "@UnionFieldPtr(&"); + ir_print_other_instruction(irp, instruction->union_ptr); + fprintf(irp->f, ".%s", buf_ptr(instruction->field->name)); + fprintf(irp->f, ")"); +} + static void ir_print_set_debug_safety(IrPrint *irp, IrInstructionSetDebugSafety *instruction) { fprintf(irp->f, "@setDebugSafety("); ir_print_other_instruction(irp, instruction->scope_value); @@ -1023,6 +1039,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { case IrInstructionIdStructInit: ir_print_struct_init(irp, (IrInstructionStructInit *)instruction); break; + case IrInstructionIdUnionInit: + ir_print_union_init(irp, (IrInstructionUnionInit *)instruction); + break; case IrInstructionIdUnreachable: ir_print_unreachable(irp, (IrInstructionUnreachable *)instruction); break; @@ -1056,6 +1075,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { case IrInstructionIdEnumFieldPtr: ir_print_enum_field_ptr(irp, (IrInstructionEnumFieldPtr *)instruction); break; + case IrInstructionIdUnionFieldPtr: + ir_print_union_field_ptr(irp, (IrInstructionUnionFieldPtr *)instruction); + break; case IrInstructionIdSetDebugSafety: ir_print_set_debug_safety(irp, (IrInstructionSetDebugSafety *)instruction); break; diff --git a/src/zig_llvm.cpp b/src/zig_llvm.cpp index 086c1d6f96..658de77b31 100644 --- a/src/zig_llvm.cpp +++ b/src/zig_llvm.cpp @@ -403,6 +403,10 @@ unsigned ZigLLVMTag_DW_structure_type(void) { return dwarf::DW_TAG_structure_type; } +unsigned ZigLLVMTag_DW_union_type(void) { + return dwarf::DW_TAG_union_type; +} + ZigLLVMDIBuilder *ZigLLVMCreateDIBuilder(LLVMModuleRef module, bool allow_unresolved) { DIBuilder *di_builder = new DIBuilder(*unwrap(module), allow_unresolved); return reinterpret_cast(di_builder); diff --git a/src/zig_llvm.hpp b/src/zig_llvm.hpp index d7c9784e79..b72b6a889f 100644 --- a/src/zig_llvm.hpp +++ b/src/zig_llvm.hpp @@ -117,6 +117,7 @@ unsigned ZigLLVMEncoding_DW_ATE_signed_char(void); unsigned ZigLLVMLang_DW_LANG_C99(void); unsigned ZigLLVMTag_DW_variable(void); unsigned ZigLLVMTag_DW_structure_type(void); +unsigned ZigLLVMTag_DW_union_type(void); ZigLLVMDIBuilder *ZigLLVMCreateDIBuilder(LLVMModuleRef module, bool allow_unresolved); void ZigLLVMAddModuleDebugInfoFlag(LLVMModuleRef module); diff --git a/test/cases/union.zig b/test/cases/union.zig index 4b8ccb7245..74bda8db6a 100644 --- a/test/cases/union.zig +++ b/test/cases/union.zig @@ -31,3 +31,16 @@ test "unions embedded in aggregate types" { else => unreachable, } } + + +const Foo = union { + float: f64, + int: i32, +}; + +test "basic unions" { + var foo = Foo { .int = 1 }; + assert(foo.int == 1); + foo.float = 12.34; + assert(foo.float == 12.34); +} diff --git a/test/compile_errors.zig b/test/compile_errors.zig index fa90661158..9e15333750 100644 --- a/test/compile_errors.zig +++ b/test/compile_errors.zig @@ -389,8 +389,8 @@ pub fn addCases(cases: &tests.CompileErrorContext) { \\ const y = a.bar; \\} , - ".tmp_source.zig:4:6: error: no member named 'foo' in 'A'", - ".tmp_source.zig:5:16: error: no member named 'bar' in 'A'"); + ".tmp_source.zig:4:6: error: no member named 'foo' in struct 'A'", + ".tmp_source.zig:5:16: error: no member named 'bar' in struct 'A'"); cases.add("redefinition of struct", \\const A = struct { x : i32, }; @@ -454,7 +454,7 @@ pub fn addCases(cases: &tests.CompileErrorContext) { \\ .foo = 42, \\ }; \\} - , ".tmp_source.zig:10:9: error: no member named 'foo' in 'A'"); + , ".tmp_source.zig:10:9: error: no member named 'foo' in struct 'A'"); cases.add("invalid break expression", \\export fn f() {