From f276fd0f3728bf1a43b185e3e2d33d593309cb2f Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 14 Nov 2017 23:53:53 -0500 Subject: [PATCH 01/34] basic union support See #144 --- src/all_types.hpp | 43 +++++++- src/analyze.cpp | 223 ++++++++++++++++++++++++++++++++++++++-- src/analyze.hpp | 1 + src/codegen.cpp | 101 +++++++++++++++++- src/ir.cpp | 171 ++++++++++++++++++++++++++++-- src/ir_print.cpp | 22 ++++ src/zig_llvm.cpp | 4 + src/zig_llvm.hpp | 1 + test/cases/union.zig | 13 +++ test/compile_errors.zig | 6 +- 10 files changed, 559 insertions(+), 26 deletions(-) diff --git a/src/all_types.hpp b/src/all_types.hpp index 0b4efd8415..ca6c214af8 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -73,6 +73,7 @@ enum ConstParentId { ConstParentIdNone, ConstParentIdStruct, ConstParentIdArray, + ConstParentIdUnion, }; struct ConstParent { @@ -87,6 +88,9 @@ struct ConstParent { ConstExprValue *struct_val; size_t field_index; } p_struct; + struct { + ConstExprValue *union_val; + } p_union; } data; }; @@ -100,6 +104,11 @@ struct ConstStructValue { ConstParent parent; }; +struct ConstUnionValue { + ConstExprValue *value; + ConstParent parent; +}; + enum ConstArraySpecial { ConstArraySpecialNone, ConstArraySpecialUndef, @@ -238,6 +247,7 @@ struct ConstExprValue { ErrorTableEntry *x_pure_err; ConstEnumValue x_enum; ConstStructValue x_struct; + ConstUnionValue x_union; ConstArrayValue x_array; ConstPtrValue x_ptr; ImportTableEntry *x_import; @@ -336,6 +346,12 @@ struct TypeEnumField { uint32_t gen_index; }; +struct TypeUnionField { + Buf *name; + TypeTableEntry *type_entry; + uint32_t gen_index; +}; + enum NodeType { NodeTypeRoot, NodeTypeFnProto, @@ -1026,9 +1042,9 @@ struct TypeTableEntryUnion { ContainerLayout layout; uint32_t src_field_count; uint32_t gen_field_count; - TypeStructField *fields; - uint64_t size_bytes; + TypeUnionField *fields; bool is_invalid; // true if any fields are invalid + ScopeDecls *decls_scope; // set this flag temporarily to detect 
infinite loops @@ -1039,6 +1055,10 @@ struct TypeTableEntryUnion { bool zero_bits_loop_flag; bool zero_bits_known; + uint32_t abi_alignment; // also figured out with zero_bits pass + + uint32_t size_bytes; + TypeTableEntry *most_aligned_union_member; }; struct FnGenParamInfo { @@ -1796,6 +1816,7 @@ enum IrInstructionId { IrInstructionIdFieldPtr, IrInstructionIdStructFieldPtr, IrInstructionIdEnumFieldPtr, + IrInstructionIdUnionFieldPtr, IrInstructionIdElemPtr, IrInstructionIdVarPtr, IrInstructionIdCall, @@ -1805,6 +1826,7 @@ enum IrInstructionId { IrInstructionIdContainerInitList, IrInstructionIdContainerInitFields, IrInstructionIdStructInit, + IrInstructionIdUnionInit, IrInstructionIdUnreachable, IrInstructionIdTypeOf, IrInstructionIdToPtrType, @@ -2060,6 +2082,14 @@ struct IrInstructionEnumFieldPtr { bool is_const; }; +struct IrInstructionUnionFieldPtr { + IrInstruction base; + + IrInstruction *union_ptr; + TypeUnionField *field; + bool is_const; +}; + struct IrInstructionElemPtr { IrInstruction base; @@ -2150,6 +2180,15 @@ struct IrInstructionStructInit { LLVMValueRef tmp_ptr; }; +struct IrInstructionUnionInit { + IrInstruction base; + + TypeTableEntry *union_type; + TypeUnionField *field; + IrInstruction *init_value; + LLVMValueRef tmp_ptr; +}; + struct IrInstructionUnreachable { IrInstruction base; }; diff --git a/src/analyze.cpp b/src/analyze.cpp index 343b1ecdb0..2f7eecaff4 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -992,18 +992,22 @@ TypeTableEntry *get_partial_container_type(CodeGen *g, Scope *scope, ContainerKi TypeTableEntryId type_id = container_to_type(kind); TypeTableEntry *entry = new_container_type_entry(type_id, decl_node, scope); + unsigned dwarf_kind; switch (kind) { case ContainerKindStruct: entry->data.structure.decl_node = decl_node; entry->data.structure.layout = layout; + dwarf_kind = ZigLLVMTag_DW_structure_type(); break; case ContainerKindEnum: entry->data.enumeration.decl_node = decl_node; entry->data.enumeration.layout = 
layout; + dwarf_kind = ZigLLVMTag_DW_structure_type(); break; case ContainerKindUnion: entry->data.unionation.decl_node = decl_node; entry->data.unionation.layout = layout; + dwarf_kind = ZigLLVMTag_DW_union_type(); break; } @@ -1012,7 +1016,7 @@ TypeTableEntry *get_partial_container_type(CodeGen *g, Scope *scope, ContainerKi ImportTableEntry *import = get_scope_import(scope); entry->type_ref = LLVMStructCreateNamed(LLVMGetGlobalContext(), name); entry->di_type = ZigLLVMCreateReplaceableCompositeType(g->dbuilder, - ZigLLVMTag_DW_structure_type(), name, + dwarf_kind, name, ZigLLVMFileToScope(import->di_file), import->di_file, (unsigned)(line + 1)); buf_init_from_str(&entry->name, name); @@ -1285,7 +1289,7 @@ static void resolve_enum_type(CodeGen *g, TypeTableEntry *enum_type) { return; resolve_enum_zero_bits(g, enum_type); - if (enum_type->data.enumeration.is_invalid) + if (type_is_invalid(enum_type)) return; AstNode *decl_node = enum_type->data.enumeration.decl_node; @@ -1834,7 +1838,140 @@ static void resolve_struct_type(CodeGen *g, TypeTableEntry *struct_type) { } static void resolve_union_type(CodeGen *g, TypeTableEntry *union_type) { - zig_panic("TODO"); + assert(union_type->id == TypeTableEntryIdUnion); + + if (union_type->data.unionation.complete) + return; + + resolve_union_zero_bits(g, union_type); + if (type_is_invalid(union_type)) + return; + + AstNode *decl_node = union_type->data.unionation.decl_node; + + if (union_type->data.unionation.embedded_in_current) { + if (!union_type->data.unionation.reported_infinite_err) { + union_type->data.unionation.reported_infinite_err = true; + add_node_error(g, decl_node, buf_sprintf("union '%s' contains itself", buf_ptr(&union_type->name))); + } + return; + } + + assert(!union_type->data.unionation.zero_bits_loop_flag); + assert(decl_node->type == NodeTypeContainerDecl); + assert(union_type->di_type); + + uint32_t field_count = union_type->data.unionation.src_field_count; + + 
assert(union_type->data.unionation.fields); + + uint32_t gen_field_count = union_type->data.unionation.gen_field_count; + ZigLLVMDIType **union_inner_di_types = allocate(gen_field_count); + + TypeTableEntry *most_aligned_union_member = nullptr; + uint64_t size_of_most_aligned_member_in_bits = 0; + uint64_t biggest_align_in_bits = 0; + uint64_t biggest_size_in_bits = 0; + + Scope *scope = &union_type->data.unionation.decls_scope->base; + ImportTableEntry *import = get_scope_import(scope); + + // set temporary flag + union_type->data.unionation.embedded_in_current = true; + + for (uint32_t i = 0; i < field_count; i += 1) { + AstNode *field_node = decl_node->data.container_decl.fields.at(i); + TypeUnionField *type_union_field = &union_type->data.unionation.fields[i]; + TypeTableEntry *field_type = type_union_field->type_entry; + + ensure_complete_type(g, field_type); + if (type_is_invalid(field_type)) { + union_type->data.unionation.is_invalid = true; + continue; + } + + if (!type_has_bits(field_type)) + continue; + + uint64_t store_size_in_bits = 8*LLVMStoreSizeOfType(g->target_data_ref, field_type->type_ref); + uint64_t abi_align_in_bits = 8*LLVMABIAlignmentOfType(g->target_data_ref, field_type->type_ref); + + assert(store_size_in_bits > 0); + assert(abi_align_in_bits > 0); + + union_inner_di_types[type_union_field->gen_index] = ZigLLVMCreateDebugMemberType(g->dbuilder, + ZigLLVMTypeToScope(union_type->di_type), buf_ptr(type_union_field->name), + import->di_file, (unsigned)(field_node->line + 1), + store_size_in_bits, + abi_align_in_bits, + 0, + 0, field_type->di_type); + + biggest_size_in_bits = max(biggest_size_in_bits, store_size_in_bits); + + if (!most_aligned_union_member || abi_align_in_bits > biggest_align_in_bits) { + most_aligned_union_member = field_type; + biggest_align_in_bits = abi_align_in_bits; + size_of_most_aligned_member_in_bits = store_size_in_bits; + } + } + + // unset temporary flag + union_type->data.unionation.embedded_in_current = false; + 
union_type->data.unionation.complete = true; + union_type->data.unionation.size_bytes = biggest_size_in_bits / 8; + union_type->data.unionation.most_aligned_union_member = most_aligned_union_member; + + if (union_type->data.unionation.is_invalid) + return; + + if (union_type->zero_bits) { + union_type->type_ref = LLVMVoidType(); + + uint64_t debug_size_in_bits = 0; + uint64_t debug_align_in_bits = 0; + ZigLLVMDIType **di_root_members = nullptr; + size_t debug_member_count = 0; + ZigLLVMDIType *replacement_di_type = ZigLLVMCreateDebugUnionType(g->dbuilder, + ZigLLVMFileToScope(import->di_file), + buf_ptr(&union_type->name), + import->di_file, (unsigned)(decl_node->line + 1), + debug_size_in_bits, + debug_align_in_bits, + 0, di_root_members, (int)debug_member_count, 0, ""); + + ZigLLVMReplaceTemporary(g->dbuilder, union_type->di_type, replacement_di_type); + union_type->di_type = replacement_di_type; + return; + } + + assert(most_aligned_union_member != nullptr); + + // create llvm type for union + uint64_t padding_in_bits = biggest_size_in_bits - size_of_most_aligned_member_in_bits; + if (padding_in_bits > 0) { + TypeTableEntry *u8_type = get_int_type(g, false, 8); + TypeTableEntry *padding_array = get_array_type(g, u8_type, padding_in_bits / 8); + LLVMTypeRef union_element_types[] = { + most_aligned_union_member->type_ref, + padding_array->type_ref, + }; + LLVMStructSetBody(union_type->type_ref, union_element_types, 2, false); + } else { + LLVMStructSetBody(union_type->type_ref, &most_aligned_union_member->type_ref, 1, false); + } + + assert(8*LLVMABIAlignmentOfType(g->target_data_ref, union_type->type_ref) >= biggest_align_in_bits); + assert(8*LLVMStoreSizeOfType(g->target_data_ref, union_type->type_ref) >= biggest_size_in_bits); + + // create debug type for union + ZigLLVMDIType *replacement_di_type = ZigLLVMCreateDebugUnionType(g->dbuilder, + ZigLLVMFileToScope(import->di_file), buf_ptr(&union_type->name), + import->di_file, (unsigned)(decl_node->line + 1), + 
biggest_size_in_bits, biggest_align_in_bits, 0, union_inner_di_types, + gen_field_count, 0, ""); + ZigLLVMReplaceTemporary(g->dbuilder, union_type->di_type, replacement_di_type); + union_type->di_type = replacement_di_type; } static void resolve_enum_zero_bits(CodeGen *g, TypeTableEntry *enum_type) { @@ -1873,7 +2010,7 @@ static void resolve_enum_zero_bits(CodeGen *g, TypeTableEntry *enum_type) { type_enum_field->value = i; type_ensure_zero_bits_known(g, field_type); - if (field_type->id == TypeTableEntryIdInvalid) { + if (type_is_invalid(field_type)) { enum_type->data.enumeration.is_invalid = true; continue; } @@ -1980,7 +2117,66 @@ static void resolve_struct_zero_bits(CodeGen *g, TypeTableEntry *struct_type) { } static void resolve_union_zero_bits(CodeGen *g, TypeTableEntry *union_type) { - zig_panic("TODO resolve_union_zero_bits"); + assert(union_type->id == TypeTableEntryIdUnion); + + if (union_type->data.unionation.zero_bits_known) + return; + + if (union_type->data.unionation.zero_bits_loop_flag) { + union_type->data.unionation.zero_bits_known = true; + return; + } + + union_type->data.unionation.zero_bits_loop_flag = true; + + AstNode *decl_node = union_type->data.unionation.decl_node; + assert(decl_node->type == NodeTypeContainerDecl); + assert(union_type->di_type); + + assert(!union_type->data.unionation.fields); + uint32_t field_count = (uint32_t)decl_node->data.container_decl.fields.length; + union_type->data.unionation.src_field_count = field_count; + union_type->data.unionation.fields = allocate(field_count); + + uint32_t biggest_align_bytes = 0; + + Scope *scope = &union_type->data.unionation.decls_scope->base; + + uint32_t gen_field_index = 0; + for (uint32_t i = 0; i < field_count; i += 1) { + AstNode *field_node = decl_node->data.container_decl.fields.at(i); + TypeUnionField *type_union_field = &union_type->data.unionation.fields[i]; + type_union_field->name = field_node->data.struct_field.name; + TypeTableEntry *field_type = analyze_type_expr(g, 
scope, field_node->data.struct_field.type); + type_union_field->type_entry = field_type; + + type_ensure_zero_bits_known(g, field_type); + if (type_is_invalid(field_type)) { + union_type->data.unionation.is_invalid = true; + continue; + } + + if (!type_has_bits(field_type)) + continue; + + type_union_field->gen_index = gen_field_index; + gen_field_index += 1; + + uint32_t field_align_bytes = get_abi_alignment(g, field_type); + if (field_align_bytes > biggest_align_bytes) { + biggest_align_bytes = field_align_bytes; + } + } + + union_type->data.unionation.zero_bits_loop_flag = false; + union_type->data.unionation.gen_field_count = gen_field_index; + union_type->zero_bits = (gen_field_index == 0); + union_type->data.unionation.zero_bits_known = true; + + // also compute abi_alignment + if (!union_type->zero_bits) { + union_type->data.unionation.abi_alignment = biggest_align_bytes; + } } static void get_fully_qualified_decl_name_internal(Buf *buf, Scope *scope, uint8_t sep) { @@ -2851,6 +3047,18 @@ TypeStructField *find_struct_type_field(TypeTableEntry *type_entry, Buf *name) { return nullptr; } +TypeUnionField *find_union_type_field(TypeTableEntry *type_entry, Buf *name) { + assert(type_entry->id == TypeTableEntryIdUnion); + assert(type_entry->data.unionation.complete); + for (uint32_t i = 0; i < type_entry->data.unionation.src_field_count; i += 1) { + TypeUnionField *field = &type_entry->data.unionation.fields[i]; + if (buf_eql_buf(field->name, name)) { + return field; + } + } + return nullptr; +} + static bool is_container(TypeTableEntry *type_entry) { switch (type_entry->id) { case TypeTableEntryIdInvalid: @@ -4703,6 +4911,8 @@ ConstParent *get_const_val_parent(CodeGen *g, ConstExprValue *value) { return &value->data.x_array.s_none.parent; } else if (type_entry->id == TypeTableEntryIdStruct) { return &value->data.x_struct.parent; + } else if (type_entry->id == TypeTableEntryIdUnion) { + return &value->data.x_union.parent; } return nullptr; } @@ -4914,7 +5124,8 @@ 
uint32_t get_abi_alignment(CodeGen *g, TypeTableEntry *type_entry) { assert(type_entry->data.enumeration.abi_alignment != 0); return type_entry->data.enumeration.abi_alignment; } else if (type_entry->id == TypeTableEntryIdUnion) { - zig_panic("TODO"); + assert(type_entry->data.unionation.abi_alignment != 0); + return type_entry->data.unionation.abi_alignment; } else if (type_entry->id == TypeTableEntryIdOpaque) { return 1; } else { diff --git a/src/analyze.hpp b/src/analyze.hpp index 4f56640592..b2464af9a0 100644 --- a/src/analyze.hpp +++ b/src/analyze.hpp @@ -63,6 +63,7 @@ void resolve_container_type(CodeGen *g, TypeTableEntry *type_entry); TypeStructField *find_struct_type_field(TypeTableEntry *type_entry, Buf *name); ScopeDecls *get_container_scope(TypeTableEntry *type_entry); TypeEnumField *find_enum_type_field(TypeTableEntry *enum_type, Buf *name); +TypeUnionField *find_union_type_field(TypeTableEntry *type_entry, Buf *name); bool is_container_ref(TypeTableEntry *type_entry); void scan_decls(CodeGen *g, ScopeDecls *decls_scope, AstNode *node); void scan_import(CodeGen *g, ImportTableEntry *import); diff --git a/src/codegen.cpp b/src/codegen.cpp index 17c0ffc653..fc949f2ecd 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -2393,6 +2393,27 @@ static LLVMValueRef ir_render_enum_field_ptr(CodeGen *g, IrExecutable *executabl return bitcasted_union_field_ptr; } +static LLVMValueRef ir_render_union_field_ptr(CodeGen *g, IrExecutable *executable, + IrInstructionUnionFieldPtr *instruction) +{ + TypeTableEntry *union_ptr_type = instruction->union_ptr->value.type; + assert(union_ptr_type->id == TypeTableEntryIdPointer); + TypeTableEntry *union_type = union_ptr_type->data.pointer.child_type; + assert(union_type->id == TypeTableEntryIdUnion); + + TypeUnionField *field = instruction->field; + + if (!type_has_bits(field->type_entry)) + return nullptr; + + LLVMValueRef union_ptr = ir_llvm_value(g, instruction->union_ptr); + LLVMTypeRef field_type_ref = 
LLVMPointerType(field->type_entry->type_ref, 0); + LLVMValueRef union_field_ptr = LLVMBuildStructGEP(g->builder, union_ptr, 0, ""); + LLVMValueRef bitcasted_union_field_ptr = LLVMBuildBitCast(g->builder, union_field_ptr, field_type_ref, ""); + + return bitcasted_union_field_ptr; +} + static size_t find_asm_index(CodeGen *g, AstNode *node, AsmToken *tok) { const char *ptr = buf_ptr(node->data.asm_expr.asm_template) + tok->start + 2; size_t len = tok->end - tok->start - 2; @@ -3365,6 +3386,25 @@ static LLVMValueRef ir_render_struct_init(CodeGen *g, IrExecutable *executable, return instruction->tmp_ptr; } +static LLVMValueRef ir_render_union_init(CodeGen *g, IrExecutable *executable, IrInstructionUnionInit *instruction) { + TypeUnionField *type_union_field = instruction->field; + + assert(type_has_bits(type_union_field->type_entry)); + + LLVMValueRef field_ptr = LLVMBuildStructGEP(g->builder, instruction->tmp_ptr, (unsigned)0, ""); + LLVMValueRef value = ir_llvm_value(g, instruction->init_value); + + uint32_t field_align_bytes = get_abi_alignment(g, type_union_field->type_entry); + + TypeTableEntry *ptr_type = get_pointer_to_type_extra(g, type_union_field->type_entry, + false, false, field_align_bytes, + 0, 0); + + gen_assign_raw(g, field_ptr, ptr_type, value); + + return instruction->tmp_ptr; +} + static LLVMValueRef ir_render_container_init_list(CodeGen *g, IrExecutable *executable, IrInstructionContainerInitList *instruction) { @@ -3486,6 +3526,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, return ir_render_struct_field_ptr(g, executable, (IrInstructionStructFieldPtr *)instruction); case IrInstructionIdEnumFieldPtr: return ir_render_enum_field_ptr(g, executable, (IrInstructionEnumFieldPtr *)instruction); + case IrInstructionIdUnionFieldPtr: + return ir_render_union_field_ptr(g, executable, (IrInstructionUnionFieldPtr *)instruction); case IrInstructionIdAsm: return ir_render_asm(g, executable, (IrInstructionAsm *)instruction); 
case IrInstructionIdTestNonNull: @@ -3544,6 +3586,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, return ir_render_init_enum(g, executable, (IrInstructionInitEnum *)instruction); case IrInstructionIdStructInit: return ir_render_struct_init(g, executable, (IrInstructionStructInit *)instruction); + case IrInstructionIdUnionInit: + return ir_render_union_init(g, executable, (IrInstructionUnionInit *)instruction); case IrInstructionIdPtrCast: return ir_render_ptr_cast(g, executable, (IrInstructionPtrCast *)instruction); case IrInstructionIdBitCast: @@ -3595,6 +3639,7 @@ static void ir_render(CodeGen *g, FnTableEntry *fn_entry) { static LLVMValueRef gen_const_ptr_struct_recursive(CodeGen *g, ConstExprValue *struct_const_val, size_t field_index); static LLVMValueRef gen_const_ptr_array_recursive(CodeGen *g, ConstExprValue *array_const_val, size_t index); +static LLVMValueRef gen_const_ptr_union_recursive(CodeGen *g, ConstExprValue *array_const_val); static LLVMValueRef gen_parent_ptr(CodeGen *g, ConstExprValue *val, ConstParent *parent) { switch (parent->id) { @@ -3608,6 +3653,8 @@ static LLVMValueRef gen_parent_ptr(CodeGen *g, ConstExprValue *val, ConstParent case ConstParentIdArray: return gen_const_ptr_array_recursive(g, parent->data.p_array.array_val, parent->data.p_array.elem_index); + case ConstParentIdUnion: + return gen_const_ptr_union_recursive(g, parent->data.p_union.union_val); } zig_unreachable(); } @@ -3637,6 +3684,18 @@ static LLVMValueRef gen_const_ptr_struct_recursive(CodeGen *g, ConstExprValue *s return LLVMConstInBoundsGEP(base_ptr, indices, 2); } +static LLVMValueRef gen_const_ptr_union_recursive(CodeGen *g, ConstExprValue *union_const_val) { + ConstParent *parent = &union_const_val->data.x_union.parent; + LLVMValueRef base_ptr = gen_parent_ptr(g, union_const_val, parent); + + TypeTableEntry *u32 = g->builtin_types.entry_u32; + LLVMValueRef indices[] = { + LLVMConstNull(u32->type_ref), + LLVMConstInt(u32->type_ref, 
0, false), + }; + return LLVMConstInBoundsGEP(base_ptr, indices, 2); +} + static LLVMValueRef pack_const_int(CodeGen *g, LLVMTypeRef big_int_type_ref, ConstExprValue *const_val) { switch (const_val->special) { case ConstValSpecialRuntime: @@ -3872,10 +3931,6 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) { return LLVMConstNamedStruct(type_entry->type_ref, fields, type_entry->data.structure.gen_field_count); } } - case TypeTableEntryIdUnion: - { - zig_panic("TODO"); - } case TypeTableEntryIdArray: { uint64_t len = type_entry->data.array.len; @@ -3898,6 +3953,41 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) { return LLVMConstArray(element_type_ref, values, (unsigned)len); } } + case TypeTableEntryIdUnion: + { + LLVMTypeRef union_type_ref = type_entry->type_ref; + ConstExprValue *payload_value = const_val->data.x_union.value; + assert(payload_value != nullptr); + + if (!type_has_bits(payload_value->type)) { + return LLVMGetUndef(union_type_ref); + } + + uint64_t field_type_bytes = LLVMStoreSizeOfType(g->target_data_ref, payload_value->type->type_ref); + uint64_t pad_bytes = type_entry->data.unionation.size_bytes - field_type_bytes; + + LLVMValueRef correctly_typed_value = gen_const_val(g, payload_value); + + bool make_unnamed_struct = is_llvm_value_unnamed_type(payload_value->type, correctly_typed_value) || + payload_value->type != type_entry->data.unionation.most_aligned_union_member; + + unsigned field_count; + LLVMValueRef fields[2]; + fields[0] = correctly_typed_value; + if (pad_bytes == 0) { + field_count = 1; + } else { + fields[0] = correctly_typed_value; + fields[1] = LLVMGetUndef(LLVMArrayType(LLVMInt8Type(), (unsigned)pad_bytes)); + field_count = 2; + } + + if (make_unnamed_struct) { + return LLVMConstStruct(fields, field_count, false); + } else { + return LLVMConstNamedStruct(type_entry->type_ref, fields, field_count); + } + } case TypeTableEntryIdEnum: { LLVMTypeRef tag_type_ref = 
type_entry->data.enumeration.tag_type->type_ref; @@ -4376,6 +4466,9 @@ static void do_code_gen(CodeGen *g) { } else if (instruction->id == IrInstructionIdStructInit) { IrInstructionStructInit *struct_init_instruction = (IrInstructionStructInit *)instruction; slot = &struct_init_instruction->tmp_ptr; + } else if (instruction->id == IrInstructionIdUnionInit) { + IrInstructionUnionInit *union_init_instruction = (IrInstructionUnionInit *)instruction; + slot = &union_init_instruction->tmp_ptr; } else if (instruction->id == IrInstructionIdCall) { IrInstructionCall *call_instruction = (IrInstructionCall *)instruction; slot = &call_instruction->tmp_ptr; diff --git a/src/ir.cpp b/src/ir.cpp index fdaced6806..6df6b4f828 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -227,6 +227,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionEnumFieldPtr *) return IrInstructionIdEnumFieldPtr; } +static constexpr IrInstructionId ir_instruction_id(IrInstructionUnionFieldPtr *) { + return IrInstructionIdUnionFieldPtr; +} + static constexpr IrInstructionId ir_instruction_id(IrInstructionElemPtr *) { return IrInstructionIdElemPtr; } @@ -351,6 +355,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionStructInit *) { return IrInstructionIdStructInit; } +static constexpr IrInstructionId ir_instruction_id(IrInstructionUnionInit *) { + return IrInstructionIdUnionInit; +} + static constexpr IrInstructionId ir_instruction_id(IrInstructionMinValue *) { return IrInstructionIdMinValue; } @@ -922,6 +930,27 @@ static IrInstruction *ir_build_enum_field_ptr_from(IrBuilder *irb, IrInstruction return new_instruction; } +static IrInstruction *ir_build_union_field_ptr(IrBuilder *irb, Scope *scope, AstNode *source_node, + IrInstruction *union_ptr, TypeUnionField *field) +{ + IrInstructionUnionFieldPtr *instruction = ir_build_instruction(irb, scope, source_node); + instruction->union_ptr = union_ptr; + instruction->field = field; + + ir_ref_instruction(union_ptr, 
irb->current_basic_block); + + return &instruction->base; +} + +static IrInstruction *ir_build_union_field_ptr_from(IrBuilder *irb, IrInstruction *old_instruction, + IrInstruction *union_ptr, TypeUnionField *type_union_field) +{ + IrInstruction *new_instruction = ir_build_union_field_ptr(irb, old_instruction->scope, + old_instruction->source_node, union_ptr, type_union_field); + ir_link_new_instruction(new_instruction, old_instruction); + return new_instruction; +} + static IrInstruction *ir_build_call(IrBuilder *irb, Scope *scope, AstNode *source_node, FnTableEntry *fn_entry, IrInstruction *fn_ref, size_t arg_count, IrInstruction **args, bool is_comptime, bool is_inline) @@ -1112,6 +1141,28 @@ static IrInstruction *ir_build_struct_init_from(IrBuilder *irb, IrInstruction *o return new_instruction; } +static IrInstruction *ir_build_union_init(IrBuilder *irb, Scope *scope, AstNode *source_node, + TypeTableEntry *union_type, TypeUnionField *field, IrInstruction *init_value) +{ + IrInstructionUnionInit *union_init_instruction = ir_build_instruction(irb, scope, source_node); + union_init_instruction->union_type = union_type; + union_init_instruction->field = field; + union_init_instruction->init_value = init_value; + + ir_ref_instruction(init_value, irb->current_basic_block); + + return &union_init_instruction->base; +} + +static IrInstruction *ir_build_union_init_from(IrBuilder *irb, IrInstruction *old_instruction, + TypeTableEntry *union_type, TypeUnionField *field, IrInstruction *init_value) +{ + IrInstruction *new_instruction = ir_build_union_init(irb, old_instruction->scope, + old_instruction->source_node, union_type, field, init_value); + ir_link_new_instruction(new_instruction, old_instruction); + return new_instruction; +} + static IrInstruction *ir_build_unreachable(IrBuilder *irb, Scope *scope, AstNode *source_node) { IrInstructionUnreachable *unreachable_instruction = ir_build_instruction(irb, scope, source_node); @@ -2422,6 +2473,13 @@ static IrInstruction 
*ir_instruction_enumfieldptr_get_dep(IrInstructionEnumField } } +static IrInstruction *ir_instruction_unionfieldptr_get_dep(IrInstructionUnionFieldPtr *instruction, size_t index) { + switch (index) { + case 0: return instruction->union_ptr; + default: return nullptr; + } +} + static IrInstruction *ir_instruction_elemptr_get_dep(IrInstructionElemPtr *instruction, size_t index) { switch (index) { case 0: return instruction->array_ptr; @@ -2485,6 +2543,13 @@ static IrInstruction *ir_instruction_structinit_get_dep(IrInstructionStructInit return nullptr; } +static IrInstruction *ir_instruction_unioninit_get_dep(IrInstructionUnionInit *instruction, size_t index) { + switch (index) { + case 0: return instruction->init_value; + default: return nullptr; + } +} + static IrInstruction *ir_instruction_unreachable_get_dep(IrInstructionUnreachable *instruction, size_t index) { return nullptr; } @@ -3099,6 +3164,8 @@ static IrInstruction *ir_instruction_get_dep(IrInstruction *instruction, size_t return ir_instruction_structfieldptr_get_dep((IrInstructionStructFieldPtr *) instruction, index); case IrInstructionIdEnumFieldPtr: return ir_instruction_enumfieldptr_get_dep((IrInstructionEnumFieldPtr *) instruction, index); + case IrInstructionIdUnionFieldPtr: + return ir_instruction_unionfieldptr_get_dep((IrInstructionUnionFieldPtr *) instruction, index); case IrInstructionIdElemPtr: return ir_instruction_elemptr_get_dep((IrInstructionElemPtr *) instruction, index); case IrInstructionIdVarPtr: @@ -3117,6 +3184,8 @@ static IrInstruction *ir_instruction_get_dep(IrInstruction *instruction, size_t return ir_instruction_containerinitfields_get_dep((IrInstructionContainerInitFields *) instruction, index); case IrInstructionIdStructInit: return ir_instruction_structinit_get_dep((IrInstructionStructInit *) instruction, index); + case IrInstructionIdUnionInit: + return ir_instruction_unioninit_get_dep((IrInstructionUnionInit *) instruction, index); case IrInstructionIdUnreachable: return 
ir_instruction_unreachable_get_dep((IrInstructionUnreachable *) instruction, index); case IrInstructionIdTypeOf: @@ -11417,8 +11486,20 @@ static TypeTableEntry *ir_analyze_container_member_access_inner(IrAnalyze *ira, return ir_analyze_ref(ira, &field_ptr_instruction->base, bound_fn_value, true, false); } } + const char *prefix_name; + if (is_slice(bare_struct_type)) { + prefix_name = ""; + } else if (bare_struct_type->id == TypeTableEntryIdStruct) { + prefix_name = "struct "; + } else if (bare_struct_type->id == TypeTableEntryIdEnum) { + prefix_name = "enum "; + } else if (bare_struct_type->id == TypeTableEntryIdUnion) { + prefix_name = "union "; + } else { + prefix_name = ""; + } ir_add_error_node(ira, field_ptr_instruction->base.source_node, - buf_sprintf("no member named '%s' in '%s'", buf_ptr(field_name), buf_ptr(&bare_struct_type->name))); + buf_sprintf("no member named '%s' in %s'%s'", buf_ptr(field_name), prefix_name, buf_ptr(&bare_struct_type->name))); return ira->codegen->builtin_types.entry_invalid; } @@ -11428,14 +11509,13 @@ static TypeTableEntry *ir_analyze_container_field_ptr(IrAnalyze *ira, Buf *field { TypeTableEntry *bare_type = container_ref_type(container_type); ensure_complete_type(ira->codegen, bare_type); + if (type_is_invalid(bare_type)) + return ira->codegen->builtin_types.entry_invalid; assert(container_ptr->value.type->id == TypeTableEntryIdPointer); bool is_const = container_ptr->value.type->data.pointer.is_const; bool is_volatile = container_ptr->value.type->data.pointer.is_volatile; if (bare_type->id == TypeTableEntryIdStruct) { - if (bare_type->data.structure.is_invalid) - return ira->codegen->builtin_types.entry_invalid; - TypeStructField *field = find_struct_type_field(bare_type, field_name); if (field) { bool is_packed = (bare_type->data.structure.layout == ContainerLayoutPacked); @@ -11476,9 +11556,6 @@ static TypeTableEntry *ir_analyze_container_field_ptr(IrAnalyze *ira, Buf *field field_ptr_instruction, container_ptr, 
container_type); } } else if (bare_type->id == TypeTableEntryIdEnum) { - if (bare_type->data.enumeration.is_invalid) - return ira->codegen->builtin_types.entry_invalid; - TypeEnumField *field = find_enum_type_field(bare_type, field_name); if (field) { ir_build_enum_field_ptr_from(&ira->new_irb, &field_ptr_instruction->base, container_ptr, field); @@ -11489,7 +11566,15 @@ static TypeTableEntry *ir_analyze_container_field_ptr(IrAnalyze *ira, Buf *field field_ptr_instruction, container_ptr, container_type); } } else if (bare_type->id == TypeTableEntryIdUnion) { - zig_panic("TODO"); + TypeUnionField *field = find_union_type_field(bare_type, field_name); + if (field) { + ir_build_union_field_ptr_from(&ira->new_irb, &field_ptr_instruction->base, container_ptr, field); + return get_pointer_to_type_extra(ira->codegen, field->type_entry, is_const, is_volatile, + get_abi_alignment(ira->codegen, field->type_entry), 0, 0); + } else { + return ir_analyze_container_member_access_inner(ira, bare_type, field_name, + field_ptr_instruction, container_ptr, container_type); + } } else { zig_unreachable(); } @@ -13033,9 +13118,70 @@ static TypeTableEntry *ir_analyze_instruction_ref(IrAnalyze *ira, IrInstructionR return ir_analyze_ref(ira, &ref_instruction->base, value, ref_instruction->is_const, ref_instruction->is_volatile); } +static TypeTableEntry *ir_analyze_container_init_fields_union(IrAnalyze *ira, IrInstruction *instruction, + TypeTableEntry *container_type, size_t instr_field_count, IrInstructionContainerInitFieldsField *fields) +{ + assert(container_type->id == TypeTableEntryIdUnion); + + ensure_complete_type(ira->codegen, container_type); + + if (instr_field_count != 1) { + ir_add_error(ira, instruction, + buf_sprintf("union initialization expects exactly one field")); + return ira->codegen->builtin_types.entry_invalid; + } + + IrInstructionContainerInitFieldsField *field = &fields[0]; + IrInstruction *field_value = field->value->other; + if 
(type_is_invalid(field_value->value.type)) + return ira->codegen->builtin_types.entry_invalid; + + TypeUnionField *type_field = find_union_type_field(container_type, field->name); + if (!type_field) { + ir_add_error_node(ira, field->source_node, + buf_sprintf("no member named '%s' in union '%s'", + buf_ptr(field->name), buf_ptr(&container_type->name))); + return ira->codegen->builtin_types.entry_invalid; + } + + if (type_is_invalid(type_field->type_entry)) + return ira->codegen->builtin_types.entry_invalid; + + IrInstruction *casted_field_value = ir_implicit_cast(ira, field_value, type_field->type_entry); + if (casted_field_value == ira->codegen->invalid_instruction) + return ira->codegen->builtin_types.entry_invalid; + + bool is_comptime = ir_should_inline(ira->new_irb.exec, instruction->scope); + if (is_comptime || casted_field_value->value.special != ConstValSpecialRuntime) { + ConstExprValue *field_val = ir_resolve_const(ira, casted_field_value, UndefOk); + if (!field_val) + return ira->codegen->builtin_types.entry_invalid; + + ConstExprValue *out_val = ir_build_const_from(ira, instruction); + out_val->data.x_union.value = field_val; + + ConstParent *parent = get_const_val_parent(ira->codegen, field_val); + if (parent != nullptr) { + parent->id = ConstParentIdUnion; + parent->data.p_union.union_val = out_val; + } + + return container_type; + } + + IrInstruction *new_instruction = ir_build_union_init_from(&ira->new_irb, instruction, + container_type, type_field, casted_field_value); + + ir_add_alloca(ira, new_instruction, container_type); + return container_type; +} + static TypeTableEntry *ir_analyze_container_init_fields(IrAnalyze *ira, IrInstruction *instruction, TypeTableEntry *container_type, size_t instr_field_count, IrInstructionContainerInitFieldsField *fields) { + if (container_type->id == TypeTableEntryIdUnion) { + return ir_analyze_container_init_fields_union(ira, instruction, container_type, instr_field_count, fields); + } if (container_type->id != 
TypeTableEntryIdStruct || is_slice(container_type)) { ir_add_error(ira, instruction, buf_sprintf("type '%s' does not support struct initialization syntax", @@ -13043,8 +13189,7 @@ static TypeTableEntry *ir_analyze_container_init_fields(IrAnalyze *ira, IrInstru return ira->codegen->builtin_types.entry_invalid; } - if (!type_is_complete(container_type)) - resolve_container_type(ira->codegen, container_type); + ensure_complete_type(ira->codegen, container_type); size_t actual_field_count = container_type->data.structure.src_field_count; @@ -13070,7 +13215,7 @@ static TypeTableEntry *ir_analyze_container_init_fields(IrAnalyze *ira, IrInstru TypeStructField *type_field = find_struct_type_field(container_type, field->name); if (!type_field) { ir_add_error_node(ira, field->source_node, - buf_sprintf("no member named '%s' in '%s'", + buf_sprintf("no member named '%s' in struct '%s'", buf_ptr(field->name), buf_ptr(&container_type->name))); return ira->codegen->builtin_types.entry_invalid; } @@ -15657,8 +15802,10 @@ static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructi case IrInstructionIdIntToErr: case IrInstructionIdErrToInt: case IrInstructionIdStructInit: + case IrInstructionIdUnionInit: case IrInstructionIdStructFieldPtr: case IrInstructionIdEnumFieldPtr: + case IrInstructionIdUnionFieldPtr: case IrInstructionIdInitEnum: case IrInstructionIdMaybeWrap: case IrInstructionIdErrWrapCode: @@ -15968,6 +16115,7 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdContainerInitList: case IrInstructionIdContainerInitFields: case IrInstructionIdStructInit: + case IrInstructionIdUnionInit: case IrInstructionIdFieldPtr: case IrInstructionIdElemPtr: case IrInstructionIdVarPtr: @@ -15977,6 +16125,7 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdArrayLen: case IrInstructionIdStructFieldPtr: case IrInstructionIdEnumFieldPtr: + case IrInstructionIdUnionFieldPtr: case IrInstructionIdArrayType: case 
IrInstructionIdSliceType: case IrInstructionIdSizeOf: diff --git a/src/ir_print.cpp b/src/ir_print.cpp index 1c60d68628..55ad3ceb6c 100644 --- a/src/ir_print.cpp +++ b/src/ir_print.cpp @@ -290,6 +290,15 @@ static void ir_print_struct_init(IrPrint *irp, IrInstructionStructInit *instruct fprintf(irp->f, "} // struct init"); } +static void ir_print_union_init(IrPrint *irp, IrInstructionUnionInit *instruction) { + Buf *field_name = instruction->field->name; + + fprintf(irp->f, "%s {", buf_ptr(&instruction->union_type->name)); + fprintf(irp->f, ".%s = ", buf_ptr(field_name)); + ir_print_other_instruction(irp, instruction->init_value); + fprintf(irp->f, "} // union init"); +} + static void ir_print_unreachable(IrPrint *irp, IrInstructionUnreachable *instruction) { fprintf(irp->f, "unreachable"); } @@ -359,6 +368,13 @@ static void ir_print_enum_field_ptr(IrPrint *irp, IrInstructionEnumFieldPtr *ins fprintf(irp->f, ")"); } +static void ir_print_union_field_ptr(IrPrint *irp, IrInstructionUnionFieldPtr *instruction) { + fprintf(irp->f, "@UnionFieldPtr(&"); + ir_print_other_instruction(irp, instruction->union_ptr); + fprintf(irp->f, ".%s", buf_ptr(instruction->field->name)); + fprintf(irp->f, ")"); +} + static void ir_print_set_debug_safety(IrPrint *irp, IrInstructionSetDebugSafety *instruction) { fprintf(irp->f, "@setDebugSafety("); ir_print_other_instruction(irp, instruction->scope_value); @@ -1023,6 +1039,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { case IrInstructionIdStructInit: ir_print_struct_init(irp, (IrInstructionStructInit *)instruction); break; + case IrInstructionIdUnionInit: + ir_print_union_init(irp, (IrInstructionUnionInit *)instruction); + break; case IrInstructionIdUnreachable: ir_print_unreachable(irp, (IrInstructionUnreachable *)instruction); break; @@ -1056,6 +1075,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { case IrInstructionIdEnumFieldPtr: ir_print_enum_field_ptr(irp, 
(IrInstructionEnumFieldPtr *)instruction); break; + case IrInstructionIdUnionFieldPtr: + ir_print_union_field_ptr(irp, (IrInstructionUnionFieldPtr *)instruction); + break; case IrInstructionIdSetDebugSafety: ir_print_set_debug_safety(irp, (IrInstructionSetDebugSafety *)instruction); break; diff --git a/src/zig_llvm.cpp b/src/zig_llvm.cpp index 086c1d6f96..658de77b31 100644 --- a/src/zig_llvm.cpp +++ b/src/zig_llvm.cpp @@ -403,6 +403,10 @@ unsigned ZigLLVMTag_DW_structure_type(void) { return dwarf::DW_TAG_structure_type; } +unsigned ZigLLVMTag_DW_union_type(void) { + return dwarf::DW_TAG_union_type; +} + ZigLLVMDIBuilder *ZigLLVMCreateDIBuilder(LLVMModuleRef module, bool allow_unresolved) { DIBuilder *di_builder = new DIBuilder(*unwrap(module), allow_unresolved); return reinterpret_cast(di_builder); diff --git a/src/zig_llvm.hpp b/src/zig_llvm.hpp index d7c9784e79..b72b6a889f 100644 --- a/src/zig_llvm.hpp +++ b/src/zig_llvm.hpp @@ -117,6 +117,7 @@ unsigned ZigLLVMEncoding_DW_ATE_signed_char(void); unsigned ZigLLVMLang_DW_LANG_C99(void); unsigned ZigLLVMTag_DW_variable(void); unsigned ZigLLVMTag_DW_structure_type(void); +unsigned ZigLLVMTag_DW_union_type(void); ZigLLVMDIBuilder *ZigLLVMCreateDIBuilder(LLVMModuleRef module, bool allow_unresolved); void ZigLLVMAddModuleDebugInfoFlag(LLVMModuleRef module); diff --git a/test/cases/union.zig b/test/cases/union.zig index 4b8ccb7245..74bda8db6a 100644 --- a/test/cases/union.zig +++ b/test/cases/union.zig @@ -31,3 +31,16 @@ test "unions embedded in aggregate types" { else => unreachable, } } + + +const Foo = union { + float: f64, + int: i32, +}; + +test "basic unions" { + var foo = Foo { .int = 1 }; + assert(foo.int == 1); + foo.float = 12.34; + assert(foo.float == 12.34); +} diff --git a/test/compile_errors.zig b/test/compile_errors.zig index fa90661158..9e15333750 100644 --- a/test/compile_errors.zig +++ b/test/compile_errors.zig @@ -389,8 +389,8 @@ pub fn addCases(cases: &tests.CompileErrorContext) { \\ const y = a.bar; 
\\} , - ".tmp_source.zig:4:6: error: no member named 'foo' in 'A'", - ".tmp_source.zig:5:16: error: no member named 'bar' in 'A'"); + ".tmp_source.zig:4:6: error: no member named 'foo' in struct 'A'", + ".tmp_source.zig:5:16: error: no member named 'bar' in struct 'A'"); cases.add("redefinition of struct", \\const A = struct { x : i32, }; @@ -454,7 +454,7 @@ pub fn addCases(cases: &tests.CompileErrorContext) { \\ .foo = 42, \\ }; \\} - , ".tmp_source.zig:10:9: error: no member named 'foo' in 'A'"); + , ".tmp_source.zig:10:9: error: no member named 'foo' in struct 'A'"); cases.add("invalid break expression", \\export fn f() { From 018cbff438cedc19d0ad18021619ec7ede997307 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 15 Nov 2017 22:52:47 -0500 Subject: [PATCH 02/34] unions have a secret field for the type See #144 --- src/all_types.hpp | 12 +++- src/analyze.cpp | 140 +++++++++++++++++++++++++++++++++++++++---- src/codegen.cpp | 61 +++++++++++++------ test/cases/union.zig | 13 ++++ 4 files changed, 195 insertions(+), 31 deletions(-) diff --git a/src/all_types.hpp b/src/all_types.hpp index ca6c214af8..a6a3f4e1e5 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1037,6 +1037,9 @@ struct TypeTableEntryEnumTag { LLVMValueRef name_table; }; +uint32_t type_ptr_hash(const TypeTableEntry *ptr); +bool type_ptr_eql(const TypeTableEntry *a, const TypeTableEntry *b); + struct TypeTableEntryUnion { AstNode *decl_node; ContainerLayout layout; @@ -1044,6 +1047,8 @@ struct TypeTableEntryUnion { uint32_t gen_field_count; TypeUnionField *fields; bool is_invalid; // true if any fields are invalid + TypeTableEntry *tag_type; + LLVMTypeRef union_type_ref; ScopeDecls *decls_scope; @@ -1057,8 +1062,13 @@ struct TypeTableEntryUnion { bool zero_bits_known; uint32_t abi_alignment; // also figured out with zero_bits pass - uint32_t size_bytes; + size_t gen_union_index; + size_t gen_tag_index; + + uint32_t union_size_bytes; TypeTableEntry *most_aligned_union_member; + + 
HashMap distinct_types = {}; }; struct FnGenParamInfo { diff --git a/src/analyze.cpp b/src/analyze.cpp index 2f7eecaff4..0f8414aaf8 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -992,26 +992,23 @@ TypeTableEntry *get_partial_container_type(CodeGen *g, Scope *scope, ContainerKi TypeTableEntryId type_id = container_to_type(kind); TypeTableEntry *entry = new_container_type_entry(type_id, decl_node, scope); - unsigned dwarf_kind; switch (kind) { case ContainerKindStruct: entry->data.structure.decl_node = decl_node; entry->data.structure.layout = layout; - dwarf_kind = ZigLLVMTag_DW_structure_type(); break; case ContainerKindEnum: entry->data.enumeration.decl_node = decl_node; entry->data.enumeration.layout = layout; - dwarf_kind = ZigLLVMTag_DW_structure_type(); break; case ContainerKindUnion: entry->data.unionation.decl_node = decl_node; entry->data.unionation.layout = layout; - dwarf_kind = ZigLLVMTag_DW_union_type(); break; } size_t line = decl_node ? decl_node->line : 0; + unsigned dwarf_kind = ZigLLVMTag_DW_structure_type(); ImportTableEntry *import = get_scope_import(scope); entry->type_ref = LLVMStructCreateNamed(LLVMGetGlobalContext(), name); @@ -1873,6 +1870,11 @@ static void resolve_union_type(CodeGen *g, TypeTableEntry *union_type) { uint64_t biggest_align_in_bits = 0; uint64_t biggest_size_in_bits = 0; + bool auto_layout = (union_type->data.unionation.layout == ContainerLayoutAuto); + ZigLLVMDIEnumerator **di_enumerators = allocate(field_count); + auto distinct_types = &union_type->data.unionation.distinct_types; + distinct_types->init(4); + Scope *scope = &union_type->data.unionation.decls_scope->base; ImportTableEntry *import = get_scope_import(scope); @@ -1893,6 +1895,11 @@ static void resolve_union_type(CodeGen *g, TypeTableEntry *union_type) { if (!type_has_bits(field_type)) continue; + size_t distinct_type_index = distinct_types->size(); + if (distinct_types->put_unique(field_type, distinct_type_index) == nullptr) { + di_enumerators[i] = 
ZigLLVMCreateDebugEnumerator(g->dbuilder, buf_ptr(&field_type->name), distinct_type_index); + } + uint64_t store_size_in_bits = 8*LLVMStoreSizeOfType(g->target_data_ref, field_type->type_ref); uint64_t abi_align_in_bits = 8*LLVMABIAlignmentOfType(g->target_data_ref, field_type->type_ref); @@ -1919,7 +1926,7 @@ static void resolve_union_type(CodeGen *g, TypeTableEntry *union_type) { // unset temporary flag union_type->data.unionation.embedded_in_current = false; union_type->data.unionation.complete = true; - union_type->data.unionation.size_bytes = biggest_size_in_bits / 8; + union_type->data.unionation.union_size_bytes = biggest_size_in_bits / 8; union_type->data.unionation.most_aligned_union_member = most_aligned_union_member; if (union_type->data.unionation.is_invalid) @@ -1947,8 +1954,42 @@ static void resolve_union_type(CodeGen *g, TypeTableEntry *union_type) { assert(most_aligned_union_member != nullptr); - // create llvm type for union + bool want_safety = (distinct_types->size() > 1) && auto_layout; uint64_t padding_in_bits = biggest_size_in_bits - size_of_most_aligned_member_in_bits; + + + if (!want_safety) { + if (padding_in_bits > 0) { + TypeTableEntry *u8_type = get_int_type(g, false, 8); + TypeTableEntry *padding_array = get_array_type(g, u8_type, padding_in_bits / 8); + LLVMTypeRef union_element_types[] = { + most_aligned_union_member->type_ref, + padding_array->type_ref, + }; + LLVMStructSetBody(union_type->type_ref, union_element_types, 2, false); + } else { + LLVMStructSetBody(union_type->type_ref, &most_aligned_union_member->type_ref, 1, false); + } + union_type->data.unionation.union_type_ref = union_type->type_ref; + union_type->data.unionation.gen_tag_index = SIZE_MAX; + union_type->data.unionation.gen_union_index = SIZE_MAX; + + assert(8*LLVMABIAlignmentOfType(g->target_data_ref, union_type->type_ref) >= biggest_align_in_bits); + assert(8*LLVMStoreSizeOfType(g->target_data_ref, union_type->type_ref) >= biggest_size_in_bits); + + // create debug 
type for union + ZigLLVMDIType *replacement_di_type = ZigLLVMCreateDebugUnionType(g->dbuilder, + ZigLLVMFileToScope(import->di_file), buf_ptr(&union_type->name), + import->di_file, (unsigned)(decl_node->line + 1), + biggest_size_in_bits, biggest_align_in_bits, 0, union_inner_di_types, + gen_field_count, 0, ""); + + ZigLLVMReplaceTemporary(g->dbuilder, union_type->di_type, replacement_di_type); + union_type->di_type = replacement_di_type; + return; + } + + LLVMTypeRef union_type_ref; if (padding_in_bits > 0) { TypeTableEntry *u8_type = get_int_type(g, false, 8); TypeTableEntry *padding_array = get_array_type(g, u8_type, padding_in_bits / 8); @@ -1956,20 +1997,87 @@ static void resolve_union_type(CodeGen *g, TypeTableEntry *union_type) { most_aligned_union_member->type_ref, padding_array->type_ref, }; - LLVMStructSetBody(union_type->type_ref, union_element_types, 2, false); + union_type_ref = LLVMStructType(union_element_types, 2, false); } else { - LLVMStructSetBody(union_type->type_ref, &most_aligned_union_member->type_ref, 1, false); + union_type_ref = most_aligned_union_member->type_ref; + } + union_type->data.unionation.union_type_ref = union_type_ref; + + assert(8*LLVMABIAlignmentOfType(g->target_data_ref, union_type_ref) >= biggest_align_in_bits); + assert(8*LLVMStoreSizeOfType(g->target_data_ref, union_type_ref) >= biggest_size_in_bits); + + // create llvm type for root struct + TypeTableEntry *tag_int_type = get_smallest_unsigned_int_type(g, distinct_types->size() - 1); + TypeTableEntry *tag_type_entry = tag_int_type; + union_type->data.unionation.tag_type = tag_type_entry; + uint64_t align_of_tag_in_bits = 8*LLVMABIAlignmentOfType(g->target_data_ref, tag_int_type->type_ref); + + if (align_of_tag_in_bits >= biggest_align_in_bits) { + union_type->data.unionation.gen_tag_index = 0; + union_type->data.unionation.gen_union_index = 1; + } else { + union_type->data.unionation.gen_union_index = 0; + union_type->data.unionation.gen_tag_index = 1; } - 
assert(8*LLVMABIAlignmentOfType(g->target_data_ref, union_type->type_ref) >= biggest_align_in_bits); - assert(8*LLVMStoreSizeOfType(g->target_data_ref, union_type->type_ref) >= biggest_size_in_bits); + LLVMTypeRef root_struct_element_types[2]; + root_struct_element_types[union_type->data.unionation.gen_tag_index] = tag_type_entry->type_ref; + root_struct_element_types[union_type->data.unionation.gen_union_index] = union_type_ref; + LLVMStructSetBody(union_type->type_ref, root_struct_element_types, 2, false); + + + // create debug type for root struct + + // create debug type for tag + uint64_t tag_debug_size_in_bits = 8*LLVMStoreSizeOfType(g->target_data_ref, tag_type_entry->type_ref); + uint64_t tag_debug_align_in_bits = 8*LLVMABIAlignmentOfType(g->target_data_ref, tag_type_entry->type_ref); + ZigLLVMDIType *tag_di_type = ZigLLVMCreateDebugEnumerationType(g->dbuilder, + ZigLLVMTypeToScope(union_type->di_type), "AnonEnum", + import->di_file, (unsigned)(decl_node->line + 1), + tag_debug_size_in_bits, tag_debug_align_in_bits, di_enumerators, distinct_types->size(), + tag_type_entry->di_type, ""); // create debug type for union - ZigLLVMDIType *replacement_di_type = ZigLLVMCreateDebugUnionType(g->dbuilder, - ZigLLVMFileToScope(import->di_file), buf_ptr(&union_type->name), + ZigLLVMDIType *union_di_type = ZigLLVMCreateDebugUnionType(g->dbuilder, + ZigLLVMTypeToScope(union_type->di_type), "AnonUnion", import->di_file, (unsigned)(decl_node->line + 1), biggest_size_in_bits, biggest_align_in_bits, 0, union_inner_di_types, gen_field_count, 0, ""); + + uint64_t union_offset_in_bits = 8*LLVMOffsetOfElement(g->target_data_ref, union_type->type_ref, + union_type->data.unionation.gen_union_index); + uint64_t tag_offset_in_bits = 8*LLVMOffsetOfElement(g->target_data_ref, union_type->type_ref, + union_type->data.unionation.gen_tag_index); + + ZigLLVMDIType *union_member_di_type = ZigLLVMCreateDebugMemberType(g->dbuilder, + ZigLLVMTypeToScope(union_type->di_type), "union_field", + 
import->di_file, (unsigned)(decl_node->line + 1), + biggest_size_in_bits, + biggest_align_in_bits, + union_offset_in_bits, + 0, union_di_type); + ZigLLVMDIType *tag_member_di_type = ZigLLVMCreateDebugMemberType(g->dbuilder, + ZigLLVMTypeToScope(union_type->di_type), "tag_field", + import->di_file, (unsigned)(decl_node->line + 1), + tag_debug_size_in_bits, + tag_debug_align_in_bits, + tag_offset_in_bits, + 0, tag_di_type); + + ZigLLVMDIType *di_root_members[2]; + di_root_members[union_type->data.unionation.gen_tag_index] = tag_member_di_type; + di_root_members[union_type->data.unionation.gen_union_index] = union_member_di_type; + + uint64_t debug_size_in_bits = 8*LLVMStoreSizeOfType(g->target_data_ref, union_type->type_ref); + uint64_t debug_align_in_bits = 8*LLVMABISizeOfType(g->target_data_ref, union_type->type_ref); + ZigLLVMDIType *replacement_di_type = ZigLLVMCreateDebugStructType(g->dbuilder, + ZigLLVMFileToScope(import->di_file), + buf_ptr(&union_type->name), + import->di_file, (unsigned)(decl_node->line + 1), + debug_size_in_bits, + debug_align_in_bits, + 0, nullptr, di_root_members, 2, 0, nullptr, ""); + ZigLLVMReplaceTemporary(g->dbuilder, union_type->di_type, replacement_di_type); union_type->di_type = replacement_di_type; } @@ -5140,3 +5248,11 @@ TypeTableEntry *get_align_amt_type(CodeGen *g) { } return g->align_amt_type; } + +uint32_t type_ptr_hash(const TypeTableEntry *ptr) { + return hash_ptr((void*)ptr); +} + +bool type_ptr_eql(const TypeTableEntry *a, const TypeTableEntry *b) { + return a == b; +} diff --git a/src/codegen.cpp b/src/codegen.cpp index fc949f2ecd..26a2590e44 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -2408,9 +2408,15 @@ static LLVMValueRef ir_render_union_field_ptr(CodeGen *g, IrExecutable *executab LLVMValueRef union_ptr = ir_llvm_value(g, instruction->union_ptr); LLVMTypeRef field_type_ref = LLVMPointerType(field->type_entry->type_ref, 0); - LLVMValueRef union_field_ptr = LLVMBuildStructGEP(g->builder, union_ptr, 0, ""); - 
LLVMValueRef bitcasted_union_field_ptr = LLVMBuildBitCast(g->builder, union_field_ptr, field_type_ref, ""); + if (union_type->data.unionation.gen_tag_index == SIZE_MAX) { + LLVMValueRef union_field_ptr = LLVMBuildStructGEP(g->builder, union_ptr, 0, ""); + LLVMValueRef bitcasted_union_field_ptr = LLVMBuildBitCast(g->builder, union_field_ptr, field_type_ref, ""); + return bitcasted_union_field_ptr; + } + + LLVMValueRef union_field_ptr = LLVMBuildStructGEP(g->builder, union_ptr, union_type->data.unionation.gen_union_index, ""); + LLVMValueRef bitcasted_union_field_ptr = LLVMBuildBitCast(g->builder, union_field_ptr, field_type_ref, ""); return bitcasted_union_field_ptr; } @@ -3955,7 +3961,7 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) { } case TypeTableEntryIdUnion: { - LLVMTypeRef union_type_ref = type_entry->type_ref; + LLVMTypeRef union_type_ref = type_entry->data.unionation.union_type_ref; ConstExprValue *payload_value = const_val->data.x_union.value; assert(payload_value != nullptr); @@ -3964,29 +3970,48 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) { } uint64_t field_type_bytes = LLVMStoreSizeOfType(g->target_data_ref, payload_value->type->type_ref); - uint64_t pad_bytes = type_entry->data.unionation.size_bytes - field_type_bytes; - + uint64_t pad_bytes = type_entry->data.unionation.union_size_bytes - field_type_bytes; LLVMValueRef correctly_typed_value = gen_const_val(g, payload_value); - bool make_unnamed_struct = is_llvm_value_unnamed_type(payload_value->type, correctly_typed_value) || payload_value->type != type_entry->data.unionation.most_aligned_union_member; - unsigned field_count; - LLVMValueRef fields[2]; - fields[0] = correctly_typed_value; - if (pad_bytes == 0) { - field_count = 1; - } else { + LLVMValueRef union_value_ref; + { + unsigned field_count; + LLVMValueRef fields[2]; fields[0] = correctly_typed_value; - fields[1] = LLVMGetUndef(LLVMArrayType(LLVMInt8Type(), (unsigned)pad_bytes)); - 
field_count = 2; + if (pad_bytes == 0) { + field_count = 1; + } else { + fields[0] = correctly_typed_value; + fields[1] = LLVMGetUndef(LLVMArrayType(LLVMInt8Type(), (unsigned)pad_bytes)); + field_count = 2; + } + + if (make_unnamed_struct || type_entry->data.unionation.gen_tag_index != SIZE_MAX) { + union_value_ref = LLVMConstStruct(fields, field_count, false); + } else { + union_value_ref = LLVMConstNamedStruct(union_type_ref, fields, field_count); + } } - if (make_unnamed_struct) { - return LLVMConstStruct(fields, field_count, false); - } else { - return LLVMConstNamedStruct(type_entry->type_ref, fields, field_count); + if (type_entry->data.unionation.gen_tag_index == SIZE_MAX) { + return union_value_ref; } + + size_t distinct_type_index = type_entry->data.unionation.distinct_types.get(const_val->data.x_union.value->type); + LLVMValueRef tag_value = LLVMConstInt(type_entry->data.unionation.tag_type->type_ref, distinct_type_index, false); + + LLVMValueRef fields[2]; + fields[type_entry->data.unionation.gen_union_index] = union_value_ref; + fields[type_entry->data.unionation.gen_tag_index] = tag_value; + + if (make_unnamed_struct) { + return LLVMConstStruct(fields, 2, false); + } else { + return LLVMConstNamedStruct(type_entry->type_ref, fields, 2); + } + } case TypeTableEntryIdEnum: { diff --git a/test/cases/union.zig b/test/cases/union.zig index 74bda8db6a..377374c157 100644 --- a/test/cases/union.zig +++ b/test/cases/union.zig @@ -44,3 +44,16 @@ test "basic unions" { foo.float = 12.34; assert(foo.float == 12.34); } + + +const FooExtern = extern union { + float: f64, + int: i32, +}; + +test "basic extern unions" { + var foo = FooExtern { .int = 1 }; + assert(foo.int == 1); + foo.float = 12.34; + assert(foo.float == 12.34); +} From f12d36641f67564d2103f75ed7a5445219197db5 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 16 Nov 2017 10:06:58 -0500 Subject: [PATCH 03/34] union secret field is the tag index instead of distinct type index See #144 --- 
src/all_types.hpp | 6 +++--- src/analyze.cpp | 18 ++++++++---------- src/codegen.cpp | 5 ++--- src/ir.cpp | 3 ++- 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/src/all_types.hpp b/src/all_types.hpp index a6a3f4e1e5..86c9720f69 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -105,7 +105,8 @@ struct ConstStructValue { }; struct ConstUnionValue { - ConstExprValue *value; + uint64_t tag; + ConstExprValue *payload; ConstParent parent; }; @@ -349,6 +350,7 @@ struct TypeEnumField { struct TypeUnionField { Buf *name; TypeTableEntry *type_entry; + uint32_t value; uint32_t gen_index; }; @@ -1067,8 +1069,6 @@ struct TypeTableEntryUnion { uint32_t union_size_bytes; TypeTableEntry *most_aligned_union_member; - - HashMap distinct_types = {}; }; struct FnGenParamInfo { diff --git a/src/analyze.cpp b/src/analyze.cpp index 0f8414aaf8..ebad9fe0cb 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -1872,8 +1872,6 @@ static void resolve_union_type(CodeGen *g, TypeTableEntry *union_type) { bool auto_layout = (union_type->data.unionation.layout == ContainerLayoutAuto); ZigLLVMDIEnumerator **di_enumerators = allocate(field_count); - auto distinct_types = &union_type->data.unionation.distinct_types; - distinct_types->init(4); Scope *scope = &union_type->data.unionation.decls_scope->base; ImportTableEntry *import = get_scope_import(scope); @@ -1895,10 +1893,7 @@ static void resolve_union_type(CodeGen *g, TypeTableEntry *union_type) { if (!type_has_bits(field_type)) continue; - size_t distinct_type_index = distinct_types->size(); - if (distinct_types->put_unique(field_type, distinct_type_index) == nullptr) { - di_enumerators[i] = ZigLLVMCreateDebugEnumerator(g->dbuilder, buf_ptr(&field_type->name), distinct_type_index); - } + di_enumerators[i] = ZigLLVMCreateDebugEnumerator(g->dbuilder, buf_ptr(type_union_field->name), i); uint64_t store_size_in_bits = 8*LLVMStoreSizeOfType(g->target_data_ref, field_type->type_ref); uint64_t abi_align_in_bits = 
8*LLVMABIAlignmentOfType(g->target_data_ref, field_type->type_ref); @@ -1954,7 +1949,7 @@ static void resolve_union_type(CodeGen *g, TypeTableEntry *union_type) { assert(most_aligned_union_member != nullptr); - bool want_safety = (distinct_types->size() > 1) && auto_layout; + bool want_safety = auto_layout && (field_count >= 2); uint64_t padding_in_bits = biggest_size_in_bits - size_of_most_aligned_member_in_bits; @@ -2007,7 +2002,7 @@ static void resolve_union_type(CodeGen *g, TypeTableEntry *union_type) { assert(8*LLVMStoreSizeOfType(g->target_data_ref, union_type_ref) >= biggest_size_in_bits); // create llvm type for root struct - TypeTableEntry *tag_int_type = get_smallest_unsigned_int_type(g, distinct_types->size() - 1); + TypeTableEntry *tag_int_type = get_smallest_unsigned_int_type(g, field_count - 1); TypeTableEntry *tag_type_entry = tag_int_type; union_type->data.unionation.tag_type = tag_type_entry; uint64_t align_of_tag_in_bits = 8*LLVMABIAlignmentOfType(g->target_data_ref, tag_int_type->type_ref); @@ -2034,7 +2029,7 @@ static void resolve_union_type(CodeGen *g, TypeTableEntry *union_type) { ZigLLVMDIType *tag_di_type = ZigLLVMCreateDebugEnumerationType(g->dbuilder, ZigLLVMTypeToScope(union_type->di_type), "AnonEnum", import->di_file, (unsigned)(decl_node->line + 1), - tag_debug_size_in_bits, tag_debug_align_in_bits, di_enumerators, distinct_types->size(), + tag_debug_size_in_bits, tag_debug_align_in_bits, di_enumerators, field_count, tag_type_entry->di_type, ""); // create debug type for union @@ -2257,6 +2252,7 @@ static void resolve_union_zero_bits(CodeGen *g, TypeTableEntry *union_type) { type_union_field->name = field_node->data.struct_field.name; TypeTableEntry *field_type = analyze_type_expr(g, scope, field_node->data.struct_field.type); type_union_field->type_entry = field_type; + type_union_field->value = i; type_ensure_zero_bits_known(g, field_type); if (type_is_invalid(field_type)) { @@ -2276,9 +2272,11 @@ static void 
resolve_union_zero_bits(CodeGen *g, TypeTableEntry *union_type) { } } + bool auto_layout = (union_type->data.unionation.layout == ContainerLayoutAuto); + union_type->data.unionation.zero_bits_loop_flag = false; union_type->data.unionation.gen_field_count = gen_field_index; - union_type->zero_bits = (gen_field_index == 0); + union_type->zero_bits = (gen_field_index == 0 && (field_count < 2 || !auto_layout)); union_type->data.unionation.zero_bits_known = true; // also compute abi_alignment diff --git a/src/codegen.cpp b/src/codegen.cpp index 26a2590e44..3777c3a87a 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -3962,7 +3962,7 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) { case TypeTableEntryIdUnion: { LLVMTypeRef union_type_ref = type_entry->data.unionation.union_type_ref; - ConstExprValue *payload_value = const_val->data.x_union.value; + ConstExprValue *payload_value = const_val->data.x_union.payload; assert(payload_value != nullptr); if (!type_has_bits(payload_value->type)) { @@ -3999,8 +3999,7 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) { return union_value_ref; } - size_t distinct_type_index = type_entry->data.unionation.distinct_types.get(const_val->data.x_union.value->type); - LLVMValueRef tag_value = LLVMConstInt(type_entry->data.unionation.tag_type->type_ref, distinct_type_index, false); + LLVMValueRef tag_value = LLVMConstInt(type_entry->data.unionation.tag_type->type_ref, const_val->data.x_union.tag, false); LLVMValueRef fields[2]; fields[type_entry->data.unionation.gen_union_index] = union_value_ref; diff --git a/src/ir.cpp b/src/ir.cpp index 6df6b4f828..fa59aa03f2 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -13158,7 +13158,8 @@ static TypeTableEntry *ir_analyze_container_init_fields_union(IrAnalyze *ira, Ir return ira->codegen->builtin_types.entry_invalid; ConstExprValue *out_val = ir_build_const_from(ira, instruction); - out_val->data.x_union.value = field_val; + 
out_val->data.x_union.payload = field_val; + out_val->data.x_union.tag = type_field->value; ConstParent *parent = get_const_val_parent(ira->codegen, field_val); if (parent != nullptr) { From e26ccd5166000f81a589c446d04102c21045bff6 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 16 Nov 2017 21:15:15 -0500 Subject: [PATCH 04/34] debug safety for unions --- src/all_types.hpp | 1 + src/codegen.cpp | 37 ++++++++++++++++++++++++++----------- test/cases/union.zig | 2 +- test/debug_safety.zig | 20 ++++++++++++++++++++ 4 files changed, 48 insertions(+), 12 deletions(-) diff --git a/src/all_types.hpp b/src/all_types.hpp index 86c9720f69..2b09131bef 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1317,6 +1317,7 @@ enum PanicMsgId { PanicMsgIdUnwrapMaybeFail, PanicMsgIdInvalidErrorCode, PanicMsgIdIncorrectAlignment, + PanicMsgIdBadUnionField, PanicMsgIdCount, }; diff --git a/src/codegen.cpp b/src/codegen.cpp index 3777c3a87a..eb56d26cae 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -810,6 +810,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) { return buf_create_from_str("invalid error code"); case PanicMsgIdIncorrectAlignment: return buf_create_from_str("incorrect alignment"); + case PanicMsgIdBadUnionField: + return buf_create_from_str("access of inactive union field"); } zig_unreachable(); } @@ -2415,6 +2417,23 @@ static LLVMValueRef ir_render_union_field_ptr(CodeGen *g, IrExecutable *executab return bitcasted_union_field_ptr; } + if (ir_want_debug_safety(g, &instruction->base)) { + LLVMValueRef tag_field_ptr = LLVMBuildStructGEP(g->builder, union_ptr, union_type->data.unionation.gen_tag_index, ""); + LLVMValueRef tag_value = gen_load_untyped(g, tag_field_ptr, 0, false, ""); + LLVMValueRef expected_tag_value = LLVMConstInt(union_type->data.unionation.tag_type->type_ref, + field->value, false); + + LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "UnionCheckOk"); + LLVMBasicBlockRef bad_block = LLVMAppendBasicBlock(g->cur_fn_val, 
"UnionCheckFail"); + LLVMValueRef ok_val = LLVMBuildICmp(g->builder, LLVMIntEQ, tag_value, expected_tag_value, ""); + LLVMBuildCondBr(g->builder, ok_val, ok_block, bad_block); + + LLVMPositionBuilderAtEnd(g->builder, bad_block); + gen_debug_safety_crash(g, PanicMsgIdBadUnionField); + + LLVMPositionBuilderAtEnd(g->builder, ok_block); + } + LLVMValueRef union_field_ptr = LLVMBuildStructGEP(g->builder, union_ptr, union_type->data.unionation.gen_union_index, ""); LLVMValueRef bitcasted_union_field_ptr = LLVMBuildBitCast(g->builder, union_field_ptr, field_type_ref, ""); return bitcasted_union_field_ptr; @@ -3977,21 +3996,17 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) { LLVMValueRef union_value_ref; { - unsigned field_count; - LLVMValueRef fields[2]; - fields[0] = correctly_typed_value; if (pad_bytes == 0) { - field_count = 1; + union_value_ref = correctly_typed_value; } else { + LLVMValueRef fields[2]; fields[0] = correctly_typed_value; fields[1] = LLVMGetUndef(LLVMArrayType(LLVMInt8Type(), (unsigned)pad_bytes)); - field_count = 2; - } - - if (make_unnamed_struct || type_entry->data.unionation.gen_tag_index != SIZE_MAX) { - union_value_ref = LLVMConstStruct(fields, field_count, false); - } else { - union_value_ref = LLVMConstNamedStruct(union_type_ref, fields, field_count); + if (make_unnamed_struct || type_entry->data.unionation.gen_tag_index != SIZE_MAX) { + union_value_ref = LLVMConstStruct(fields, 2, false); + } else { + union_value_ref = LLVMConstNamedStruct(union_type_ref, fields, 2); + } } } diff --git a/test/cases/union.zig b/test/cases/union.zig index 377374c157..1abebb3b30 100644 --- a/test/cases/union.zig +++ b/test/cases/union.zig @@ -41,7 +41,7 @@ const Foo = union { test "basic unions" { var foo = Foo { .int = 1 }; assert(foo.int == 1); - foo.float = 12.34; + foo = Foo {.float = 12.34}; assert(foo.float == 12.34); } diff --git a/test/debug_safety.zig b/test/debug_safety.zig index 9e9ff98349..36f8d020c3 100644 --- 
a/test/debug_safety.zig +++ b/test/debug_safety.zig @@ -260,4 +260,24 @@ pub fn addCases(cases: &tests.CompareOutputContext) { \\ return int_slice[0]; \\} ); + + cases.addDebugSafety("bad union field access", + \\pub fn panic(message: []const u8) -> noreturn { + \\ @import("std").os.exit(126); + \\} + \\ + \\const Foo = union { + \\ float: f32, + \\ int: u32, + \\}; + \\ + \\pub fn main() -> %void { + \\ var f = Foo { .int = 42 }; + \\ bar(&f); + \\} + \\ + \\fn bar(f: &Foo) { + \\ f.float = 12.34; + \\} + ); } From 5d2ba056c801f46a07182a05c07887e06fd197fa Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 16 Nov 2017 22:06:08 -0500 Subject: [PATCH 05/34] fix codegen for union init with runtime value see #144 --- src/codegen.cpp | 27 ++++++++++++++++++++++----- test/cases/union.zig | 18 ++++++++++++++++++ 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/src/codegen.cpp b/src/codegen.cpp index eb56d26cae..680f5a9e35 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -3414,17 +3414,34 @@ static LLVMValueRef ir_render_struct_init(CodeGen *g, IrExecutable *executable, static LLVMValueRef ir_render_union_init(CodeGen *g, IrExecutable *executable, IrInstructionUnionInit *instruction) { TypeUnionField *type_union_field = instruction->field; - assert(type_has_bits(type_union_field->type_entry)); - - LLVMValueRef field_ptr = LLVMBuildStructGEP(g->builder, instruction->tmp_ptr, (unsigned)0, ""); - LLVMValueRef value = ir_llvm_value(g, instruction->init_value); + if (!type_has_bits(type_union_field->type_entry)) + return nullptr; uint32_t field_align_bytes = get_abi_alignment(g, type_union_field->type_entry); - TypeTableEntry *ptr_type = get_pointer_to_type_extra(g, type_union_field->type_entry, false, false, field_align_bytes, 0, 0); + LLVMValueRef uncasted_union_ptr; + // Even if safety is off in this block, if the union type has the safety field, we have to populate it + // correctly. Otherwise safety code somewhere other than here could fail. 
+ TypeTableEntry *union_type = instruction->union_type; + if (union_type->data.unionation.gen_tag_index != SIZE_MAX) { + LLVMValueRef tag_field_ptr = LLVMBuildStructGEP(g->builder, instruction->tmp_ptr, + union_type->data.unionation.gen_tag_index, ""); + LLVMValueRef tag_value = LLVMConstInt(union_type->data.unionation.tag_type->type_ref, + type_union_field->value, false); + gen_store_untyped(g, tag_value, tag_field_ptr, 0, false); + + uncasted_union_ptr = LLVMBuildStructGEP(g->builder, instruction->tmp_ptr, + (unsigned)union_type->data.unionation.gen_union_index, ""); + } else { + uncasted_union_ptr = LLVMBuildStructGEP(g->builder, instruction->tmp_ptr, (unsigned)0, ""); + } + + LLVMValueRef field_ptr = LLVMBuildBitCast(g->builder, uncasted_union_ptr, ptr_type->type_ref, ""); + LLVMValueRef value = ir_llvm_value(g, instruction->init_value); + gen_assign_raw(g, field_ptr, ptr_type, value); return instruction->tmp_ptr; diff --git a/test/cases/union.zig b/test/cases/union.zig index 1abebb3b30..4044721582 100644 --- a/test/cases/union.zig +++ b/test/cases/union.zig @@ -45,6 +45,23 @@ test "basic unions" { assert(foo.float == 12.34); } +test "init union with runtime value" { + var foo: Foo = undefined; + + setFloat(&foo, 12.34); + assert(foo.float == 12.34); + + setInt(&foo, 42); + assert(foo.int == 42); +} + +fn setFloat(foo: &Foo, x: f64) { + *foo = Foo { .float = x }; +} + +fn setInt(foo: &Foo, x: i32) { + *foo = Foo { .int = x }; +} const FooExtern = extern union { float: f64, @@ -57,3 +74,4 @@ test "basic extern unions" { foo.float = 12.34; assert(foo.float == 12.34); } + From 1473eb9ae0196db729716aa0d29f5ce263412307 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 16 Nov 2017 22:13:20 -0500 Subject: [PATCH 06/34] add documentation placeholders for unions --- doc/langref.html.in | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/doc/langref.html.in b/doc/langref.html.in index 9aa142fc46..446011541b 100644 --- a/doc/langref.html.in +++ 
b/doc/langref.html.in @@ -75,6 +75,7 @@
  • Slices
  • struct
  • enum
  • +
  • union
  • switch
  • while
  • for
  • @@ -209,6 +210,7 @@
  • Invalid Error Code
  • Invalid Enum Cast
  • Incorrect Pointer Alignment
  • +
  • Wrong Union Field Access
  • Memory
  • @@ -2189,6 +2191,8 @@ Test 4/4 enum builtins...OK
  • @enumTagName
  • @memberCount
  • +

    union

    +

    TODO union documentation

    switch

    const assert = @import("std").debug.assert;
     const builtin = @import("builtin");
    @@ -5117,6 +5121,9 @@ comptime {
           

    Incorrect Pointer Alignment

    TODO

    +

    Wrong Union Field Access

    +

    TODO

    +

    Memory

    TODO: explain no default allocator in zig

    TODO: show how to use the allocator interface

    From a7d07d412c4995c7858cd65558edce64ccb91911 Mon Sep 17 00:00:00 2001 From: dimenus Date: Wed, 15 Nov 2017 14:46:49 -0600 Subject: [PATCH 07/34] Added DLL loading capability in windows to the std lib. --- std/os/index.zig | 2 ++ std/os/windows/index.zig | 6 ++++++ std/os/windows/util.zig | 23 +++++++++++++++++++++++ 3 files changed, 31 insertions(+) diff --git a/std/os/index.zig b/std/os/index.zig index c3f33c4ccd..34de6f3bbf 100644 --- a/std/os/index.zig +++ b/std/os/index.zig @@ -31,6 +31,8 @@ pub const windowsWaitSingle = windows_util.windowsWaitSingle; pub const windowsWrite = windows_util.windowsWrite; pub const windowsIsCygwinPty = windows_util.windowsIsCygwinPty; pub const windowsOpen = windows_util.windowsOpen; +pub const windowsLoadDll = windows_util.windowsLoadDll; +pub const windowsUnloadDll = windows_util.windowsUnloadDll; pub const createWindowsEnvBlock = windows_util.createWindowsEnvBlock; pub const FileHandle = if (is_windows) windows.HANDLE else i32; diff --git a/std/os/windows/index.zig b/std/os/windows/index.zig index 913cc79801..0ce3794cdc 100644 --- a/std/os/windows/index.zig +++ b/std/os/windows/index.zig @@ -84,6 +84,11 @@ pub extern "kernel32" stdcallcc fn WriteFile(in_hFile: HANDLE, in_lpBuffer: &con in_nNumberOfBytesToWrite: DWORD, out_lpNumberOfBytesWritten: ?&DWORD, in_out_lpOverlapped: ?&OVERLAPPED) -> BOOL; +//TODO: call unicode versions instead of relying on ANSI code page +pub extern "kernel32" stdcallcc fn LoadLibraryA(lpLibFileName: LPCSTR) -> ?HMODULE; + +pub extern "kernel32" stdcallcc fn FreeLibrary(hModule: HMODULE) -> BOOL; + pub extern "user32" stdcallcc fn MessageBoxA(hWnd: ?HANDLE, lpText: ?LPCTSTR, lpCaption: ?LPCTSTR, uType: UINT) -> c_int; pub const PROV_RSA_FULL = 1; @@ -97,6 +102,7 @@ pub const FLOAT = f32; pub const HANDLE = &c_void; pub const HCRYPTPROV = ULONG_PTR; pub const HINSTANCE = &@OpaqueType(); +pub const HMODULE = &@OpaqueType(); pub const INT = c_int; pub const LPBYTE = &BYTE; pub const LPCH = &CHAR; 
diff --git a/std/os/windows/util.zig b/std/os/windows/util.zig index de2babe8d7..b3fc095d43 100644 --- a/std/os/windows/util.zig +++ b/std/os/windows/util.zig @@ -4,6 +4,7 @@ const windows = std.os.windows; const assert = std.debug.assert; const mem = std.mem; const BufMap = std.BufMap; +const cstr = std.cstr; error WaitAbandoned; error WaitTimeOut; @@ -149,3 +150,25 @@ pub fn createWindowsEnvBlock(allocator: &mem.Allocator, env_map: &const BufMap) result[i] = 0; return result; } + +error DllNotFound; +pub fn windowsLoadDll(allocator: &mem.Allocator, dll_path: []const u8) -> %windows.HMODULE { + const padded_buff = %return cstr.addNullByte(allocator, dll_path); + defer allocator.free(padded_buff); + return windows.LoadLibraryA(padded_buff.ptr) ?? error.DllNotFound; +} + +pub fn windowsUnloadDll(hModule: windows.HMODULE) { + assert(windows.FreeLibrary(hModule)!= 0); +} + + +test "InvalidDll" { + const DllName = "asdf.dll"; + const allocator = std.debug.global_allocator; + const handle = os.windowsLoadDll(allocator, DllName) %% |err| { + assert(err == error.DllNotFound); + return; + }; +} + From b50c676f760272b014e3c6320b13b979bfc864c7 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 16 Nov 2017 23:54:33 -0500 Subject: [PATCH 08/34] add parse-c support for unions --- src/parsec.cpp | 26 ++++++++++++++++++-------- test/parsec.zig | 15 +++++++++++++++ 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/parsec.cpp b/src/parsec.cpp index adcae0946b..4ffee98cc7 100644 --- a/src/parsec.cpp +++ b/src/parsec.cpp @@ -680,11 +680,10 @@ static AstNode *trans_type(Context *c, const Type *ty, const SourceLocation &sou const ElaboratedType *elaborated_ty = static_cast(ty); switch (elaborated_ty->getKeyword()) { case ETK_Struct: - return trans_qual_type(c, elaborated_ty->getNamedType(), source_loc); case ETK_Enum: + case ETK_Union: return trans_qual_type(c, elaborated_ty->getNamedType(), source_loc); case ETK_Interface: - case ETK_Union: case ETK_Class: case 
ETK_Typename: case ETK_None: @@ -2946,15 +2945,24 @@ static AstNode *resolve_record_decl(Context *c, const RecordDecl *record_decl) { const char *raw_name = decl_name(record_decl); - if (!record_decl->isStruct()) { - emit_warning(c, record_decl->getLocation(), "skipping record %s, not a struct", raw_name); + const char *container_kind_name; + ContainerKind container_kind; + if (record_decl->isUnion()) { + container_kind_name = "union"; + container_kind = ContainerKindUnion; + } else if (record_decl->isStruct()) { + container_kind_name = "struct"; + container_kind = ContainerKindStruct; + } else { + emit_warning(c, record_decl->getLocation(), "skipping record %s, not a struct or union", raw_name); c->decl_table.put(record_decl->getCanonicalDecl(), nullptr); return nullptr; } bool is_anonymous = record_decl->isAnonymousStructOrUnion() || raw_name[0] == 0; Buf *bare_name = is_anonymous ? nullptr : buf_create_from_str(raw_name); - Buf *full_type_name = (bare_name == nullptr) ? nullptr : buf_sprintf("struct_%s", buf_ptr(bare_name)); + Buf *full_type_name = (bare_name == nullptr) ? + nullptr : buf_sprintf("%s_%s", container_kind_name, buf_ptr(bare_name)); RecordDecl *record_def = record_decl->getDefinition(); if (record_def == nullptr) { @@ -2970,14 +2978,15 @@ static AstNode *resolve_record_decl(Context *c, const RecordDecl *record_decl) { const FieldDecl *field_decl = *it; if (field_decl->isBitField()) { - emit_warning(c, field_decl->getLocation(), "struct %s demoted to opaque type - has bitfield", + emit_warning(c, field_decl->getLocation(), "%s %s demoted to opaque type - has bitfield", + container_kind_name, is_anonymous ? 
"(anon)" : buf_ptr(bare_name)); return demote_struct_to_opaque(c, record_decl, full_type_name, bare_name); } } AstNode *struct_node = trans_create_node(c, NodeTypeContainerDecl); - struct_node->data.container_decl.kind = ContainerKindStruct; + struct_node->data.container_decl.kind = container_kind; struct_node->data.container_decl.layout = ContainerLayoutExtern; // TODO handle attribute packed @@ -3004,7 +3013,8 @@ static AstNode *resolve_record_decl(Context *c, const RecordDecl *record_decl) { if (field_node->data.struct_field.type == nullptr) { emit_warning(c, field_decl->getLocation(), - "struct %s demoted to opaque type - unresolved type", + "%s %s demoted to opaque type - unresolved type", + container_kind_name, is_anonymous ? "(anon)" : buf_ptr(bare_name)); return demote_struct_to_opaque(c, record_decl, full_type_name, bare_name); diff --git a/test/parsec.zig b/test/parsec.zig index f9e90cb705..1fe55290bb 100644 --- a/test/parsec.zig +++ b/test/parsec.zig @@ -857,6 +857,21 @@ pub fn addCases(cases: &tests.ParseCContext) { \\ (*(??x)) = 1; \\} ); + + cases.add("simple union", + \\union Foo { + \\ int x; + \\ double y; + \\}; + , + \\pub const union_Foo = extern union { + \\ x: c_int, + \\ y: f64, + \\}; + , + \\pub const Foo = union_Foo; + ); + } From 339d48ac1558dcd1977574372becd21f7fc4a075 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 17 Nov 2017 12:11:03 -0500 Subject: [PATCH 09/34] parse-c: support address of operator --- src/parsec.cpp | 12 +++++++++--- test/parsec.zig | 13 +++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/src/parsec.cpp b/src/parsec.cpp index 4ffee98cc7..79ba2ab990 100644 --- a/src/parsec.cpp +++ b/src/parsec.cpp @@ -1586,12 +1586,18 @@ static AstNode *trans_unary_operator(Context *c, bool result_used, AstNode *bloc emit_warning(c, stmt->getLocStart(), "TODO handle C translation UO_PreDec"); return nullptr; case UO_AddrOf: - emit_warning(c, stmt->getLocStart(), "TODO handle C translation UO_AddrOf"); 
- return nullptr; + { + AstNode *value_node = trans_expr(c, result_used, block, stmt->getSubExpr(), TransLValue); + if (value_node == nullptr) + return value_node; + return trans_create_node_addr_of(c, false, false, value_node); + } case UO_Deref: { - bool is_fn_ptr = qual_type_is_fn_ptr(c, stmt->getSubExpr()->getType()); AstNode *value_node = trans_expr(c, result_used, block, stmt->getSubExpr(), TransRValue); + if (value_node == nullptr) + return nullptr; + bool is_fn_ptr = qual_type_is_fn_ptr(c, stmt->getSubExpr()->getType()); if (is_fn_ptr) return value_node; AstNode *unwrapped = trans_create_node_prefix_op(c, PrefixOpUnwrapMaybe, value_node); diff --git a/test/parsec.zig b/test/parsec.zig index 1fe55290bb..f830131262 100644 --- a/test/parsec.zig +++ b/test/parsec.zig @@ -872,6 +872,19 @@ pub fn addCases(cases: &tests.ParseCContext) { \\pub const Foo = union_Foo; ); + cases.add("address of operator", + \\int foo(void) { + \\ int x = 1234; + \\ int *ptr = &x; + \\ return *ptr; + \\} + , + \\pub fn foo() -> c_int { + \\ var x: c_int = 1234; + \\ var ptr: ?&c_int = &x; + \\ return *(??ptr); + \\} + ); } From a44283b0b2e585d7e15d7c8e6574411b75c12a0a Mon Sep 17 00:00:00 2001 From: Josh Wolfe Date: Fri, 17 Nov 2017 23:42:21 -0700 Subject: [PATCH 10/34] rework std.base64 api * rename decode to decodeExactUnsafe. * add decodeExact, which checks for invalid chars and padding. * add decodeWithIgnore, which also allows ignoring chars. * alphabets are supplied to the decoders with their char-to-index mapping already built, which enables it to be done at comptime. * all decode/encode apis except decodeWithIgnore require dest to be the exactly correct length. This is calculated by a calc function corresponding to each api. These apis no longer return the dest parameter. * for decodeWithIgnore, an exact size cannot be known a priori. Instead, a calc function gives an upperbound, and a runtime error is returned in case of overflow. 
decodeWithIgnore returns the number of bytes written to dest. closes #611 --- doc/langref.html.in | 11 +- example/mix_o_files/base64.zig | 4 +- std/base64.zig | 484 ++++++++++++++++++++++++++------- std/os/index.zig | 6 +- 4 files changed, 394 insertions(+), 111 deletions(-) diff --git a/doc/langref.html.in b/doc/langref.html.in index 446011541b..d3ca772aa6 100644 --- a/doc/langref.html.in +++ b/doc/langref.html.in @@ -5412,10 +5412,13 @@ const c = @cImport({ export fn decode_base_64(dest_ptr: &u8, dest_len: usize, source_ptr: &const u8, source_len: usize) -> usize { - const src = source_ptr[0...source_len]; - const dest = dest_ptr[0...dest_len]; - return base64.decode(dest, src).len; -}
    + const src = source_ptr[0..source_len]; + const dest = dest_ptr[0..dest_len]; + const decoded_size = base64.calcDecodedSizeExactUnsafe(src, base64.standard_pad_char); + base64.decodeExactUnsafe(dest[0..decoded_size], src, base64.standard_alphabet_unsafe); + return decoded_size; +} +

    test.c

    // This header is generated by zig from base64.zig
     #include "base64.h"
    diff --git a/example/mix_o_files/base64.zig b/example/mix_o_files/base64.zig
    index 10438ad077..a7cdc9d439 100644
    --- a/example/mix_o_files/base64.zig
    +++ b/example/mix_o_files/base64.zig
    @@ -3,5 +3,7 @@ const base64 = @import("std").base64;
     export fn decode_base_64(dest_ptr: &u8, dest_len: usize, source_ptr: &const u8, source_len: usize) -> usize {
         const src = source_ptr[0..source_len];
         const dest = dest_ptr[0..dest_len];
    -    return base64.decode(dest, src).len;
    +    const decoded_size = base64.calcDecodedSizeExactUnsafe(src, base64.standard_pad_char);
    +    base64.decodeExactUnsafe(dest[0..decoded_size], src, base64.standard_alphabet_unsafe);
    +    return decoded_size;
     }
    diff --git a/std/base64.zig b/std/base64.zig
    index 5a57e7777f..84b442212a 100644
    --- a/std/base64.zig
    +++ b/std/base64.zig
    @@ -1,100 +1,332 @@
     const assert = @import("debug.zig").assert;
     const mem = @import("mem.zig");
     
    -pub const standard_alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=";
    +pub const standard_alphabet_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    +pub const standard_pad_char = '=';
     
    -pub fn encode(dest: []u8, source: []const u8) -> []u8 {
    -    return encodeWithAlphabet(dest, source, standard_alphabet);
    +/// ceil(source_len * 4/3)
    +pub fn calcEncodedSize(source_len: usize) -> usize {
    +    return @divTrunc(source_len + 2, 3) * 4;
     }
     
    -/// invalid characters in source are allowed, but they cause the value of dest to be undefined.
    -pub fn decode(dest: []u8, source: []const u8) -> []u8 {
    -    return decodeWithAlphabet(dest, source, standard_alphabet);
    -}
    -
    -pub fn encodeWithAlphabet(dest: []u8, source: []const u8, alphabet: []const u8) -> []u8 {
    -    assert(alphabet.len == 65);
    -    assert(dest.len >= calcEncodedSize(source.len));
    +/// dest.len must be what you get from ::calcEncodedSize.
    +/// It is assumed that alphabet_chars and pad_char are all unique characters.
    +pub fn encode(dest: []u8, source: []const u8, alphabet_chars: []const u8, pad_char: u8) {
    +    assert(alphabet_chars.len == 64);
    +    assert(dest.len == calcEncodedSize(source.len));
     
         var i: usize = 0;
         var out_index: usize = 0;
         while (i + 2 < source.len) : (i += 3) {
    -        dest[out_index] = alphabet[(source[i] >> 2) & 0x3f];
    +        dest[out_index] = alphabet_chars[(source[i] >> 2) & 0x3f];
             out_index += 1;
     
    -        dest[out_index] = alphabet[((source[i] & 0x3) << 4) |
    +        dest[out_index] = alphabet_chars[((source[i] & 0x3) << 4) |
                               ((source[i + 1] & 0xf0) >> 4)];
             out_index += 1;
     
    -        dest[out_index] = alphabet[((source[i + 1] & 0xf) << 2) |
    +        dest[out_index] = alphabet_chars[((source[i + 1] & 0xf) << 2) |
                               ((source[i + 2] & 0xc0) >> 6)];
             out_index += 1;
     
    -        dest[out_index] = alphabet[source[i + 2] & 0x3f];
    +        dest[out_index] = alphabet_chars[source[i + 2] & 0x3f];
             out_index += 1;
         }
     
         if (i < source.len) {
    -        dest[out_index] = alphabet[(source[i] >> 2) & 0x3f];
    +        dest[out_index] = alphabet_chars[(source[i] >> 2) & 0x3f];
             out_index += 1;
     
             if (i + 1 == source.len) {
    -            dest[out_index] = alphabet[(source[i] & 0x3) << 4];
    +            dest[out_index] = alphabet_chars[(source[i] & 0x3) << 4];
                 out_index += 1;
     
    -            dest[out_index] = alphabet[64];
    +            dest[out_index] = pad_char;
                 out_index += 1;
             } else {
    -            dest[out_index] = alphabet[((source[i] & 0x3) << 4) |
    +            dest[out_index] = alphabet_chars[((source[i] & 0x3) << 4) |
                                   ((source[i + 1] & 0xf0) >> 4)];
                 out_index += 1;
     
    -            dest[out_index] = alphabet[(source[i + 1] & 0xf) << 2];
    +            dest[out_index] = alphabet_chars[(source[i + 1] & 0xf) << 2];
                 out_index += 1;
             }
     
    -        dest[out_index] = alphabet[64];
    +        dest[out_index] = pad_char;
             out_index += 1;
         }
    -
    -    return dest[0..out_index];
     }
     
    -/// invalid characters in source are allowed, but they cause the value of dest to be undefined.
    -pub fn decodeWithAlphabet(dest: []u8, source: []const u8, alphabet: []const u8) -> []u8 {
    -    assert(alphabet.len == 65);
    +pub const standard_alphabet = Base64Alphabet.init(standard_alphabet_chars, standard_pad_char);
     
    -    var ascii6 = []u8{64} ** 256;
    -    for (alphabet) |c, i| {
    -        ascii6[c] = u8(i);
    +/// For use with ::decodeExact.
    +pub const Base64Alphabet = struct {
    +    /// e.g. 'A' => 0.
    +    /// undefined for any value not in the 64 alphabet chars.
    +    char_to_index: [256]u8,
    +    /// true only for the 64 chars in the alphabet, not the pad char.
    +    char_in_alphabet: [256]bool,
    +    pad_char: u8,
    +
    +    pub fn init(alphabet_chars: []const u8, pad_char: u8) -> Base64Alphabet {
    +        assert(alphabet_chars.len == 64);
    +
    +        var result = Base64Alphabet{
    +            .char_to_index = undefined,
    +            .char_in_alphabet = []bool{false} ** 256,
    +            .pad_char = pad_char,
    +        };
    +
    +        for (alphabet_chars) |c, i| {
    +            assert(!result.char_in_alphabet[c]);
    +            assert(c != pad_char);
    +
    +            result.char_to_index[c] = u8(i);
    +            result.char_in_alphabet[c] = true;
    +        }
    +
    +        return result;
    +    }
    +};
    +
    +error InvalidPadding;
    +/// For use with ::decodeExact.
    +/// If the encoded buffer is detected to be invalid, returns error.InvalidPadding.
    +pub fn calcDecodedSizeExact(encoded: []const u8, pad_char: u8) -> %usize {
    +    if (encoded.len % 4 != 0) return error.InvalidPadding;
    +    return calcDecodedSizeExactUnsafe(encoded, pad_char);
    +}
    +
    +error InvalidCharacter;
    +/// dest.len must be what you get from ::calcDecodedSizeExact.
    +/// invalid characters result in error.InvalidCharacter.
    +/// invalid padding results in error.InvalidPadding.
    +pub fn decodeExact(dest: []u8, source: []const u8, alphabet: &const Base64Alphabet) -> %void {
    +    assert(dest.len == %%calcDecodedSizeExact(source, alphabet.pad_char));
    +    assert(source.len % 4 == 0);
    +
    +    var src_cursor: usize = 0;
    +    var dest_cursor: usize = 0;
    +
    +    while (src_cursor < source.len) : (src_cursor += 4) {
    +        if (!alphabet.char_in_alphabet[source[src_cursor + 0]]) return error.InvalidCharacter;
    +        if (!alphabet.char_in_alphabet[source[src_cursor + 1]]) return error.InvalidCharacter;
    +        if (src_cursor < source.len - 4 or source[src_cursor + 3] != alphabet.pad_char) {
    +            // common case
    +            if (!alphabet.char_in_alphabet[source[src_cursor + 2]]) return error.InvalidCharacter;
    +            if (!alphabet.char_in_alphabet[source[src_cursor + 3]]) return error.InvalidCharacter;
    +            dest[dest_cursor + 0] = alphabet.char_to_index[source[src_cursor + 0]] << 2 |
    +                                    alphabet.char_to_index[source[src_cursor + 1]] >> 4;
    +            dest[dest_cursor + 1] = alphabet.char_to_index[source[src_cursor + 1]] << 4 |
    +                                    alphabet.char_to_index[source[src_cursor + 2]] >> 2;
    +            dest[dest_cursor + 2] = alphabet.char_to_index[source[src_cursor + 2]] << 6 |
    +                                    alphabet.char_to_index[source[src_cursor + 3]];
    +            dest_cursor += 3;
    +        } else if (source[src_cursor + 2] != alphabet.pad_char) {
    +            // one pad char
    +            if (!alphabet.char_in_alphabet[source[src_cursor + 2]]) return error.InvalidCharacter;
    +            dest[dest_cursor + 0] = alphabet.char_to_index[source[src_cursor + 0]] << 2 |
    +                                    alphabet.char_to_index[source[src_cursor + 1]] >> 4;
    +            dest[dest_cursor + 1] = alphabet.char_to_index[source[src_cursor + 1]] << 4 |
    +                                    alphabet.char_to_index[source[src_cursor + 2]] >> 2;
    +            if (alphabet.char_to_index[source[src_cursor + 2]] << 6 != 0) return error.InvalidPadding;
    +            dest_cursor += 2;
    +        } else {
    +            // two pad chars
    +            dest[dest_cursor + 0] = alphabet.char_to_index[source[src_cursor + 0]] << 2 |
    +                                    alphabet.char_to_index[source[src_cursor + 1]] >> 4;
    +            if (alphabet.char_to_index[source[src_cursor + 1]] << 4 != 0) return error.InvalidPadding;
    +            dest_cursor += 1;
    +        }
         }
     
    -    return decodeWithAscii6BitMap(dest, source, ascii6[0..], alphabet[64]);
    +    assert(src_cursor == source.len);
    +    assert(dest_cursor == dest.len);
     }
     
    -pub fn decodeWithAscii6BitMap(dest: []u8, source: []const u8, ascii6: []const u8, pad_char: u8) -> []u8 {
    -    assert(ascii6.len == 256);
    -    assert(dest.len >= calcExactDecodedSizeWithPadChar(source, pad_char));
    +/// For use with ::decodeWithIgnore.
    +pub const Base64AlphabetWithIgnore = struct {
    +    alphabet: Base64Alphabet,
    +    char_is_ignored: [256]bool,
    +    pub fn init(alphabet_chars: []const u8, pad_char: u8, ignore_chars: []const u8) -> Base64AlphabetWithIgnore {
    +        var result = Base64AlphabetWithIgnore {
    +            .alphabet = Base64Alphabet.init(alphabet_chars, pad_char),
    +            .char_is_ignored = []bool{false} ** 256,
    +        };
    +
    +        for (ignore_chars) |c| {
    +            assert(!result.alphabet.char_in_alphabet[c]);
    +            assert(!result.char_is_ignored[c]);
    +            assert(result.alphabet.pad_char != c);
    +            result.char_is_ignored[c] = true;
    +        }
    +
    +        return result;
    +    }
    +};
    +
    +/// For use with ::decodeWithIgnore.
    +/// If no characters end up being ignored, this will be the exact decoded size.
    +pub fn calcDecodedSizeUpperBound(encoded_len: usize) -> %usize {
    +    return @divTrunc(encoded_len, 4) * 3;
    +}
    +
    +error OutputTooSmall;
    +/// Invalid characters that are not ignored results in error.InvalidCharacter.
    +/// Invalid padding results in error.InvalidPadding.
    +/// Decoding more data than can fit in dest results in error.OutputTooSmall. See also ::calcDecodedSizeUpperBound.
    +/// Returns the number of bytes writen to dest.
    +pub fn decodeWithIgnore(dest: []u8, source: []const u8, alphabet_with_ignore: &const Base64AlphabetWithIgnore) -> %usize {
    +    const alphabet = &const alphabet_with_ignore.alphabet;
    +
    +    var src_cursor: usize = 0;
    +    var dest_cursor: usize = 0;
    +
    +    while (true) {
    +        // get the next 4 chars, if available
    +        var next_4_chars: [4]u8 = undefined;
    +        var available_chars: usize = 0;
    +        var pad_char_count: usize = 0;
    +        while (available_chars < 4 and src_cursor < source.len) {
    +            var c = source[src_cursor];
    +            src_cursor += 1;
    +
    +            if (alphabet.char_in_alphabet[c]) {
    +                // normal char
    +                next_4_chars[available_chars] = c;
    +                available_chars += 1;
    +            } else if (alphabet_with_ignore.char_is_ignored[c]) {
    +                // we're told to skip this one
    +                continue;
    +            } else if (c == alphabet.pad_char) {
    +                // the padding has begun. count the pad chars.
    +                pad_char_count += 1;
    +                while (src_cursor < source.len) {
    +                    c = source[src_cursor];
    +                    src_cursor += 1;
    +                    if (c == alphabet.pad_char) {
    +                        pad_char_count += 1;
    +                        if (pad_char_count > 2) return error.InvalidCharacter;
    +                    } else if (alphabet_with_ignore.char_is_ignored[c]) {
    +                        // we can even ignore chars during the padding
    +                        continue;
    +                    } else return error.InvalidCharacter;
    +                }
    +                break;
    +            } else return error.InvalidCharacter;
    +        }
    +
    +        switch (available_chars) {
    +            4 => {
    +                // common case
    +                if (dest_cursor + 3 > dest.len) return error.OutputTooSmall;
    +                assert(pad_char_count == 0);
    +                dest[dest_cursor + 0] = alphabet.char_to_index[next_4_chars[0]] << 2 |
    +                                        alphabet.char_to_index[next_4_chars[1]] >> 4;
    +                dest[dest_cursor + 1] = alphabet.char_to_index[next_4_chars[1]] << 4 |
    +                                        alphabet.char_to_index[next_4_chars[2]] >> 2;
    +                dest[dest_cursor + 2] = alphabet.char_to_index[next_4_chars[2]] << 6 |
    +                                        alphabet.char_to_index[next_4_chars[3]];
    +                dest_cursor += 3;
    +                continue;
    +            },
    +            3 => {
    +                if (dest_cursor + 2 > dest.len) return error.OutputTooSmall;
    +                if (pad_char_count != 1) return error.InvalidPadding;
    +                dest[dest_cursor + 0] = alphabet.char_to_index[next_4_chars[0]] << 2 |
    +                                        alphabet.char_to_index[next_4_chars[1]] >> 4;
    +                dest[dest_cursor + 1] = alphabet.char_to_index[next_4_chars[1]] << 4 |
    +                                        alphabet.char_to_index[next_4_chars[2]] >> 2;
    +                if (alphabet.char_to_index[next_4_chars[2]] << 6 != 0) return error.InvalidPadding;
    +                dest_cursor += 2;
    +                break;
    +            },
    +            2 => {
    +                if (dest_cursor + 1 > dest.len) return error.OutputTooSmall;
    +                if (pad_char_count != 2) return error.InvalidPadding;
    +                dest[dest_cursor + 0] = alphabet.char_to_index[next_4_chars[0]] << 2 |
    +                                        alphabet.char_to_index[next_4_chars[1]] >> 4;
    +                if (alphabet.char_to_index[next_4_chars[1]] << 4 != 0) return error.InvalidPadding;
    +                dest_cursor += 1;
    +                break;
    +            },
    +            1 => {
    +                return error.InvalidPadding;
    +            },
    +            0 => {
    +                if (pad_char_count != 0) return error.InvalidPadding;
    +                break;
    +            },
    +            else => unreachable,
    +        }
    +    }
    +
    +    assert(src_cursor == source.len);
    +
    +    return dest_cursor;
    +}
    +
    +pub const standard_alphabet_unsafe = Base64AlphabetUnsafe.init(standard_alphabet_chars, standard_pad_char);
    +
    +/// For use with ::decodeExactUnsafe.
    +pub const Base64AlphabetUnsafe = struct {
    +    /// e.g. 'A' => 0.
    +    /// undefined for any value not in the 64 alphabet chars.
    +    char_to_index: [256]u8,
    +    pad_char: u8,
    +
    +    pub fn init(alphabet_chars: []const u8, pad_char: u8) -> Base64AlphabetUnsafe {
    +        assert(alphabet_chars.len == 64);
    +        var result = Base64AlphabetUnsafe {
    +            .char_to_index = undefined,
    +            .pad_char = pad_char,
    +        };
    +        for (alphabet_chars) |c, i| {
    +            assert(c != pad_char);
    +            result.char_to_index[c] = u8(i);
    +        }
    +        return result;
    +    }
    +};
    +
    +/// For use with ::decodeExactUnsafe.
    +/// The encoded buffer must be valid.
    +pub fn calcDecodedSizeExactUnsafe(encoded: []const u8, pad_char: u8) -> usize {
    +    if (encoded.len == 0) return 0;
    +    var result = @divExact(encoded.len, 4) * 3;
    +    if (encoded[encoded.len - 1] == pad_char) {
    +        result -= 1;
    +        if (encoded[encoded.len - 2] == pad_char) {
    +            result -= 1;
    +        }
    +    }
    +    return result;
    +}
    +
    +/// dest.len must be what you get from ::calcDecodedSizeExactUnsafe.
    +/// invalid characters or padding will result in undefined values.
    +pub fn decodeExactUnsafe(dest: []u8, source: []const u8, alphabet: &const Base64AlphabetUnsafe) {
    +    assert(dest.len == calcDecodedSizeExactUnsafe(source, alphabet.pad_char));
     
         var src_index: usize = 0;
         var dest_index: usize = 0;
         var in_buf_len: usize = source.len;
     
    -    while (in_buf_len > 0 and source[in_buf_len - 1] == pad_char) {
    +    while (in_buf_len > 0 and source[in_buf_len - 1] == alphabet.pad_char) {
             in_buf_len -= 1;
         }
     
         while (in_buf_len > 4) {
    -        dest[dest_index] = ascii6[source[src_index + 0]] << 2 |
    -                   ascii6[source[src_index + 1]] >> 4;
    +        dest[dest_index] = alphabet.char_to_index[source[src_index + 0]] << 2 |
    +                           alphabet.char_to_index[source[src_index + 1]] >> 4;
             dest_index += 1;
     
    -        dest[dest_index] = ascii6[source[src_index + 1]] << 4 |
    -                   ascii6[source[src_index + 2]] >> 2;
    +        dest[dest_index] = alphabet.char_to_index[source[src_index + 1]] << 4 |
    +                           alphabet.char_to_index[source[src_index + 2]] >> 2;
             dest_index += 1;
     
    -        dest[dest_index] = ascii6[source[src_index + 2]] << 6 |
    -                   ascii6[source[src_index + 3]];
    +        dest[dest_index] = alphabet.char_to_index[source[src_index + 2]] << 6 |
    +                           alphabet.char_to_index[source[src_index + 3]];
             dest_index += 1;
     
             src_index += 4;
    @@ -102,85 +334,131 @@ pub fn decodeWithAscii6BitMap(dest: []u8, source: []const u8, ascii6: []const u8
         }
     
         if (in_buf_len > 1) {
    -        dest[dest_index] = ascii6[source[src_index + 0]] << 2 |
    -                   ascii6[source[src_index + 1]] >> 4;
    +        dest[dest_index] = alphabet.char_to_index[source[src_index + 0]] << 2 |
    +                           alphabet.char_to_index[source[src_index + 1]] >> 4;
             dest_index += 1;
         }
         if (in_buf_len > 2) {
    -        dest[dest_index] = ascii6[source[src_index + 1]] << 4 |
    -                   ascii6[source[src_index + 2]] >> 2;
    +        dest[dest_index] = alphabet.char_to_index[source[src_index + 1]] << 4 |
    +                           alphabet.char_to_index[source[src_index + 2]] >> 2;
             dest_index += 1;
         }
         if (in_buf_len > 3) {
    -        dest[dest_index] = ascii6[source[src_index + 2]] << 6 |
    -                   ascii6[source[src_index + 3]];
    +        dest[dest_index] = alphabet.char_to_index[source[src_index + 2]] << 6 |
    +                           alphabet.char_to_index[source[src_index + 3]];
             dest_index += 1;
         }
    -
    -    return dest[0..dest_index];
    -}
    -
    -pub fn calcEncodedSize(source_len: usize) -> usize {
    -    return (((source_len * 4) / 3 + 3) / 4) * 4;
    -}
    -
    -/// Computes the upper bound of the decoded size based only on the encoded length.
    -/// To compute the exact decoded size, see ::calcExactDecodedSize
    -pub fn calcMaxDecodedSize(encoded_len: usize) -> usize {
    -    return @divExact(encoded_len * 3,  4);
    -}
    -
    -/// Computes the number of decoded bytes there will be. This function must
    -/// be given the encoded buffer because there might be padding
    -/// bytes at the end ('=' in the standard alphabet)
    -pub fn calcExactDecodedSize(encoded: []const u8) -> usize {
    -    return calcExactDecodedSizeWithAlphabet(encoded, standard_alphabet);
    -}
    -
    -pub fn calcExactDecodedSizeWithAlphabet(encoded: []const u8, alphabet: []const u8) -> usize {
    -    assert(alphabet.len == 65);
    -    return calcExactDecodedSizeWithPadChar(encoded, alphabet[64]);
    -}
    -
    -pub fn calcExactDecodedSizeWithPadChar(encoded: []const u8, pad_char: u8) -> usize {
    -    var buf_len = encoded.len;
    -
    -    while (buf_len > 0 and encoded[buf_len - 1] == pad_char) {
    -        buf_len -= 1;
    -    }
    -
    -    return (buf_len * 3) / 4;
     }
     
     test "base64" {
    -    testBase64();
    -    comptime testBase64();
    +    @setEvalBranchQuota(5000);
    +    %%testBase64();
    +    comptime %%testBase64();
     }
     
    -fn testBase64() {
    -    testBase64Case("", "");
    -    testBase64Case("f", "Zg==");
    -    testBase64Case("fo", "Zm8=");
    -    testBase64Case("foo", "Zm9v");
    -    testBase64Case("foob", "Zm9vYg==");
    -    testBase64Case("fooba", "Zm9vYmE=");
    -    testBase64Case("foobar", "Zm9vYmFy");
    +fn testBase64() -> %void {
    +    %return testAllApis("",       "");
    +    %return testAllApis("f",      "Zg==");
    +    %return testAllApis("fo",     "Zm8=");
    +    %return testAllApis("foo",    "Zm9v");
    +    %return testAllApis("foob",   "Zm9vYg==");
    +    %return testAllApis("fooba",  "Zm9vYmE=");
    +    %return testAllApis("foobar", "Zm9vYmFy");
    +
    +    %return testDecodeIgnoreSpace("",       " ");
    +    %return testDecodeIgnoreSpace("f",      "Z g= =");
    +    %return testDecodeIgnoreSpace("fo",     "    Zm8=");
    +    %return testDecodeIgnoreSpace("foo",    "Zm9v    ");
    +    %return testDecodeIgnoreSpace("foob",   "Zm9vYg = = ");
    +    %return testDecodeIgnoreSpace("fooba",  "Zm9v YmE=");
    +    %return testDecodeIgnoreSpace("foobar", " Z m 9 v Y m F y ");
    +
    +    // test getting some api errors
    +    %return testError("A",    error.InvalidPadding);
    +    %return testError("AA",   error.InvalidPadding);
    +    %return testError("AAA",  error.InvalidPadding);
    +    %return testError("A..A", error.InvalidCharacter);
    +    %return testError("AA=A", error.InvalidCharacter);
    +    %return testError("AA/=", error.InvalidPadding);
    +    %return testError("A/==", error.InvalidPadding);
    +    %return testError("A===", error.InvalidCharacter);
    +    %return testError("====", error.InvalidCharacter);
    +
    +    %return testOutputTooSmallError("AA==");
    +    %return testOutputTooSmallError("AAA=");
    +    %return testOutputTooSmallError("AAAA");
    +    %return testOutputTooSmallError("AAAAAA==");
     }
     
    -fn testBase64Case(expected_decoded: []const u8, expected_encoded: []const u8) {
    -    const calculated_decoded_len = calcExactDecodedSize(expected_encoded);
    -    assert(calculated_decoded_len == expected_decoded.len);
    +fn testAllApis(expected_decoded: []const u8, expected_encoded: []const u8) -> %void {
    +    // encode
    +    {
    +        var buffer: [0x100]u8 = undefined;
    +        var encoded = buffer[0..calcEncodedSize(expected_decoded.len)];
    +        encode(encoded, expected_decoded, standard_alphabet_chars, standard_pad_char);
    +        assert(mem.eql(u8, encoded, expected_encoded));
    +    }
     
    -    const calculated_encoded_len = calcEncodedSize(expected_decoded.len);
    -    assert(calculated_encoded_len == expected_encoded.len);
    +    // decodeExact
    +    {
    +        var buffer: [0x100]u8 = undefined;
    +        var decoded = buffer[0..%return calcDecodedSizeExact(expected_encoded, standard_pad_char)];
    +        %return decodeExact(decoded, expected_encoded, standard_alphabet);
    +        assert(mem.eql(u8, decoded, expected_decoded));
    +    }
     
    -    var buf: [100]u8 = undefined;
    +    // decodeWithIgnore
    +    {
    +        const standard_alphabet_ignore_nothing = Base64AlphabetWithIgnore.init(
    +            standard_alphabet_chars, standard_pad_char, "");
    +        var buffer: [0x100]u8 = undefined;
    +        var decoded = buffer[0..%return calcDecodedSizeUpperBound(expected_encoded.len)];
    +        var written = %return decodeWithIgnore(decoded, expected_encoded, standard_alphabet_ignore_nothing);
    +        assert(written <= decoded.len);
    +        assert(mem.eql(u8, decoded[0..written], expected_decoded));
    +    }
     
    -    const actual_decoded = decode(buf[0..], expected_encoded);
    -    assert(actual_decoded.len == expected_decoded.len);
    -    assert(mem.eql(u8, expected_decoded, actual_decoded));
    -
    -    const actual_encoded = encode(buf[0..], expected_decoded);
    -    assert(actual_encoded.len == expected_encoded.len);
    -    assert(mem.eql(u8, expected_encoded, actual_encoded));
    +    // decodeExactUnsafe
    +    {
    +        var buffer: [0x100]u8 = undefined;
    +        var decoded = buffer[0..calcDecodedSizeExactUnsafe(expected_encoded, standard_pad_char)];
    +        decodeExactUnsafe(decoded, expected_encoded, standard_alphabet_unsafe);
    +        assert(mem.eql(u8, decoded, expected_decoded));
    +    }
    +}
    +
    +fn testDecodeIgnoreSpace(expected_decoded: []const u8, encoded: []const u8) -> %void {
    +    const standard_alphabet_ignore_space = Base64AlphabetWithIgnore.init(
    +        standard_alphabet_chars, standard_pad_char, " ");
    +    var buffer: [0x100]u8 = undefined;
    +    var decoded = buffer[0..%return calcDecodedSizeUpperBound(encoded.len)];
    +    var written = %return decodeWithIgnore(decoded, encoded, standard_alphabet_ignore_space);
    +    assert(mem.eql(u8, decoded[0..written], expected_decoded));
    +}
    +
    +error ExpectedError;
    +fn testError(encoded: []const u8, expected_err: error) -> %void {
    +    const standard_alphabet_ignore_space = Base64AlphabetWithIgnore.init(
    +        standard_alphabet_chars, standard_pad_char, " ");
    +    var buffer: [0x100]u8 = undefined;
    +    if (calcDecodedSizeExact(encoded, standard_pad_char)) |decoded_size| {
    +        var decoded = buffer[0..decoded_size];
    +        if (decodeExact(decoded, encoded, standard_alphabet)) |_| {
    +            return error.ExpectedError;
    +        } else |err| if (err != expected_err) return err;
    +    } else |err| if (err != expected_err) return err;
    +
    +    if (decodeWithIgnore(buffer[0..], encoded, standard_alphabet_ignore_space)) |_| {
    +        return error.ExpectedError;
    +    } else |err| if (err != expected_err) return err;
    +}
    +
    +fn testOutputTooSmallError(encoded: []const u8) -> %void {
    +    const standard_alphabet_ignore_space = Base64AlphabetWithIgnore.init(
    +        standard_alphabet_chars, standard_pad_char, " ");
    +    var buffer: [0x100]u8 = undefined;
    +    var decoded = buffer[0..calcDecodedSizeExactUnsafe(encoded, standard_pad_char) - 1];
    +    if (decodeWithIgnore(decoded, encoded, standard_alphabet_ignore_space)) |_| {
    +        return error.ExpectedError;
    +    } else |err| if (err != error.OutputTooSmall) return err;
     }
    diff --git a/std/os/index.zig b/std/os/index.zig
    index 34de6f3bbf..872564224c 100644
    --- a/std/os/index.zig
    +++ b/std/os/index.zig
    @@ -622,7 +622,7 @@ pub fn symLinkPosix(allocator: &Allocator, existing_path: []const u8, new_path:
     }
     
     // here we replace the standard +/ with -_ so that it can be used in a file name
    -const b64_fs_alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_=";
    +const b64_fs_alphabet_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
     
     pub fn atomicSymLink(allocator: &Allocator, existing_path: []const u8, new_path: []const u8) -> %void {
         if (symLink(allocator, existing_path, new_path)) {
    @@ -639,7 +639,7 @@ pub fn atomicSymLink(allocator: &Allocator, existing_path: []const u8, new_path:
         mem.copy(u8, tmp_path[0..], new_path);
         while (true) {
             %return getRandomBytes(rand_buf[0..]);
    -        _ = base64.encodeWithAlphabet(tmp_path[new_path.len..], rand_buf, b64_fs_alphabet);
    +        base64.encode(tmp_path[new_path.len..], rand_buf, b64_fs_alphabet_chars, base64.standard_pad_char);
             if (symLink(allocator, existing_path, tmp_path)) {
                 return rename(allocator, tmp_path, new_path);
             } else |err| {
    @@ -721,7 +721,7 @@ pub fn copyFileMode(allocator: &Allocator, source_path: []const u8, dest_path: [
         defer allocator.free(tmp_path);
         mem.copy(u8, tmp_path[0..], dest_path);
         %return getRandomBytes(rand_buf[0..]);
    -    _ = base64.encodeWithAlphabet(tmp_path[dest_path.len..], rand_buf, b64_fs_alphabet);
    +    base64.encode(tmp_path[dest_path.len..], rand_buf, b64_fs_alphabet_chars, base64.standard_pad_char);
     
         var out_file = %return io.File.openWriteMode(tmp_path, mode, allocator);
         defer out_file.close();
    
    From afbbdb2c67127985cadae7244348665ece8b2f25 Mon Sep 17 00:00:00 2001
    From: Josh Wolfe 
    Date: Mon, 20 Nov 2017 21:36:18 -0700
    Subject: [PATCH 11/34] move base64 functions into structs
    
    ---
     doc/langref.html.in            |   5 +-
     example/mix_o_files/base64.zig |   5 +-
     std/base64.zig                 | 585 +++++++++++++++++----------------
     std/os/index.zig               |  12 +-
     4 files changed, 316 insertions(+), 291 deletions(-)
    
    diff --git a/doc/langref.html.in b/doc/langref.html.in
    index d3ca772aa6..e8f76e230b 100644
    --- a/doc/langref.html.in
    +++ b/doc/langref.html.in
    @@ -5414,8 +5414,9 @@ export fn decode_base_64(dest_ptr: &u8, dest_len: usize,
     {
         const src = source_ptr[0..source_len];
         const dest = dest_ptr[0..dest_len];
    -    const decoded_size = base64.calcDecodedSizeExactUnsafe(src, base64.standard_pad_char);
    -    base64.decodeExactUnsafe(dest[0..decoded_size], src, base64.standard_alphabet_unsafe);
    +    const base64_decoder = base64.standard_decoder_unsafe;
    +    const decoded_size = base64_decoder.calcSize(src);
    +    base64_decoder.decode(dest[0..decoded_size], src);
         return decoded_size;
     }
     
    diff --git a/example/mix_o_files/base64.zig b/example/mix_o_files/base64.zig index a7cdc9d439..49c9bc6012 100644 --- a/example/mix_o_files/base64.zig +++ b/example/mix_o_files/base64.zig @@ -3,7 +3,8 @@ const base64 = @import("std").base64; export fn decode_base_64(dest_ptr: &u8, dest_len: usize, source_ptr: &const u8, source_len: usize) -> usize { const src = source_ptr[0..source_len]; const dest = dest_ptr[0..dest_len]; - const decoded_size = base64.calcDecodedSizeExactUnsafe(src, base64.standard_pad_char); - base64.decodeExactUnsafe(dest[0..decoded_size], src, base64.standard_alphabet_unsafe); + const base64_decoder = base64.standard_decoder_unsafe; + const decoded_size = base64_decoder.calcSize(src); + base64_decoder.decode(dest[0..decoded_size], src); return decoded_size; } diff --git a/std/base64.zig b/std/base64.zig index 84b442212a..25e438c4fb 100644 --- a/std/base64.zig +++ b/std/base64.zig @@ -3,64 +3,85 @@ const mem = @import("mem.zig"); pub const standard_alphabet_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; pub const standard_pad_char = '='; +pub const standard_encoder = Base64Encoder.init(standard_alphabet_chars, standard_pad_char); -/// ceil(source_len * 4/3) -pub fn calcEncodedSize(source_len: usize) -> usize { - return @divTrunc(source_len + 2, 3) * 4; -} +pub const Base64Encoder = struct { + alphabet_chars: []const u8, + pad_char: u8, -/// dest.len must be what you get from ::calcEncodedSize. -/// It is assumed that alphabet_chars and pad_char are all unique characters. -pub fn encode(dest: []u8, source: []const u8, alphabet_chars: []const u8, pad_char: u8) { - assert(alphabet_chars.len == 64); - assert(dest.len == calcEncodedSize(source.len)); + /// a bunch of assertions, then simply pass the data right through. 
+ pub fn init(alphabet_chars: []const u8, pad_char: u8) -> Base64Encoder { + assert(alphabet_chars.len == 64); + var char_in_alphabet = []bool{false} ** 256; + for (alphabet_chars) |c| { + assert(!char_in_alphabet[c]); + assert(c != pad_char); + char_in_alphabet[c] = true; + } - var i: usize = 0; - var out_index: usize = 0; - while (i + 2 < source.len) : (i += 3) { - dest[out_index] = alphabet_chars[(source[i] >> 2) & 0x3f]; - out_index += 1; - - dest[out_index] = alphabet_chars[((source[i] & 0x3) << 4) | - ((source[i + 1] & 0xf0) >> 4)]; - out_index += 1; - - dest[out_index] = alphabet_chars[((source[i + 1] & 0xf) << 2) | - ((source[i + 2] & 0xc0) >> 6)]; - out_index += 1; - - dest[out_index] = alphabet_chars[source[i + 2] & 0x3f]; - out_index += 1; + return Base64Encoder{ + .alphabet_chars = alphabet_chars, + .pad_char = pad_char, + }; } - if (i < source.len) { - dest[out_index] = alphabet_chars[(source[i] >> 2) & 0x3f]; - out_index += 1; + /// ceil(source_len * 4/3) + pub fn calcSize(source_len: usize) -> usize { + return @divTrunc(source_len + 2, 3) * 4; + } - if (i + 1 == source.len) { - dest[out_index] = alphabet_chars[(source[i] & 0x3) << 4]; + /// dest.len must be what you get from ::calcSize. 
+ pub fn encode(encoder: &const Base64Encoder, dest: []u8, source: []const u8) { + assert(dest.len == Base64Encoder.calcSize(source.len)); + + var i: usize = 0; + var out_index: usize = 0; + while (i + 2 < source.len) : (i += 3) { + dest[out_index] = encoder.alphabet_chars[(source[i] >> 2) & 0x3f]; out_index += 1; - dest[out_index] = pad_char; - out_index += 1; - } else { - dest[out_index] = alphabet_chars[((source[i] & 0x3) << 4) | + dest[out_index] = encoder.alphabet_chars[((source[i] & 0x3) << 4) | ((source[i + 1] & 0xf0) >> 4)]; out_index += 1; - dest[out_index] = alphabet_chars[(source[i + 1] & 0xf) << 2]; + dest[out_index] = encoder.alphabet_chars[((source[i + 1] & 0xf) << 2) | + ((source[i + 2] & 0xc0) >> 6)]; + out_index += 1; + + dest[out_index] = encoder.alphabet_chars[source[i + 2] & 0x3f]; out_index += 1; } - dest[out_index] = pad_char; - out_index += 1; + if (i < source.len) { + dest[out_index] = encoder.alphabet_chars[(source[i] >> 2) & 0x3f]; + out_index += 1; + + if (i + 1 == source.len) { + dest[out_index] = encoder.alphabet_chars[(source[i] & 0x3) << 4]; + out_index += 1; + + dest[out_index] = encoder.pad_char; + out_index += 1; + } else { + dest[out_index] = encoder.alphabet_chars[((source[i] & 0x3) << 4) | + ((source[i + 1] & 0xf0) >> 4)]; + out_index += 1; + + dest[out_index] = encoder.alphabet_chars[(source[i + 1] & 0xf) << 2]; + out_index += 1; + } + + dest[out_index] = encoder.pad_char; + out_index += 1; + } } -} +}; -pub const standard_alphabet = Base64Alphabet.init(standard_alphabet_chars, standard_pad_char); +pub const standard_decoder = Base64Decoder.init(standard_alphabet_chars, standard_pad_char); +error InvalidPadding; +error InvalidCharacter; -/// For use with ::decodeExact. -pub const Base64Alphabet = struct { +pub const Base64Decoder = struct { /// e.g. 'A' => 0. /// undefined for any value not in the 64 alphabet chars. 
char_to_index: [256]u8, @@ -68,10 +89,10 @@ pub const Base64Alphabet = struct { char_in_alphabet: [256]bool, pad_char: u8, - pub fn init(alphabet_chars: []const u8, pad_char: u8) -> Base64Alphabet { + pub fn init(alphabet_chars: []const u8, pad_char: u8) -> Base64Decoder { assert(alphabet_chars.len == 64); - var result = Base64Alphabet{ + var result = Base64Decoder{ .char_to_index = undefined, .char_in_alphabet = []bool{false} ** 256, .pad_char = pad_char, @@ -87,197 +108,193 @@ pub const Base64Alphabet = struct { return result; } -}; -error InvalidPadding; -/// For use with ::decodeExact. -/// If the encoded buffer is detected to be invalid, returns error.InvalidPadding. -pub fn calcDecodedSizeExact(encoded: []const u8, pad_char: u8) -> %usize { - if (encoded.len % 4 != 0) return error.InvalidPadding; - return calcDecodedSizeExactUnsafe(encoded, pad_char); -} - -error InvalidCharacter; -/// dest.len must be what you get from ::calcDecodedSizeExact. -/// invalid characters result in error.InvalidCharacter. -/// invalid padding results in error.InvalidPadding. 
-pub fn decodeExact(dest: []u8, source: []const u8, alphabet: &const Base64Alphabet) -> %void { - assert(dest.len == %%calcDecodedSizeExact(source, alphabet.pad_char)); - assert(source.len % 4 == 0); - - var src_cursor: usize = 0; - var dest_cursor: usize = 0; - - while (src_cursor < source.len) : (src_cursor += 4) { - if (!alphabet.char_in_alphabet[source[src_cursor + 0]]) return error.InvalidCharacter; - if (!alphabet.char_in_alphabet[source[src_cursor + 1]]) return error.InvalidCharacter; - if (src_cursor < source.len - 4 or source[src_cursor + 3] != alphabet.pad_char) { - // common case - if (!alphabet.char_in_alphabet[source[src_cursor + 2]]) return error.InvalidCharacter; - if (!alphabet.char_in_alphabet[source[src_cursor + 3]]) return error.InvalidCharacter; - dest[dest_cursor + 0] = alphabet.char_to_index[source[src_cursor + 0]] << 2 | - alphabet.char_to_index[source[src_cursor + 1]] >> 4; - dest[dest_cursor + 1] = alphabet.char_to_index[source[src_cursor + 1]] << 4 | - alphabet.char_to_index[source[src_cursor + 2]] >> 2; - dest[dest_cursor + 2] = alphabet.char_to_index[source[src_cursor + 2]] << 6 | - alphabet.char_to_index[source[src_cursor + 3]]; - dest_cursor += 3; - } else if (source[src_cursor + 2] != alphabet.pad_char) { - // one pad char - if (!alphabet.char_in_alphabet[source[src_cursor + 2]]) return error.InvalidCharacter; - dest[dest_cursor + 0] = alphabet.char_to_index[source[src_cursor + 0]] << 2 | - alphabet.char_to_index[source[src_cursor + 1]] >> 4; - dest[dest_cursor + 1] = alphabet.char_to_index[source[src_cursor + 1]] << 4 | - alphabet.char_to_index[source[src_cursor + 2]] >> 2; - if (alphabet.char_to_index[source[src_cursor + 2]] << 6 != 0) return error.InvalidPadding; - dest_cursor += 2; - } else { - // two pad chars - dest[dest_cursor + 0] = alphabet.char_to_index[source[src_cursor + 0]] << 2 | - alphabet.char_to_index[source[src_cursor + 1]] >> 4; - if (alphabet.char_to_index[source[src_cursor + 1]] << 4 != 0) return 
error.InvalidPadding; - dest_cursor += 1; - } + /// If the encoded buffer is detected to be invalid, returns error.InvalidPadding. + pub fn calcSize(decoder: &const Base64Decoder, source: []const u8) -> %usize { + if (source.len % 4 != 0) return error.InvalidPadding; + return calcDecodedSizeExactUnsafe(source, decoder.pad_char); } - assert(src_cursor == source.len); - assert(dest_cursor == dest.len); -} + /// dest.len must be what you get from ::calcSize. + /// invalid characters result in error.InvalidCharacter. + /// invalid padding results in error.InvalidPadding. + pub fn decode(decoder: &const Base64Decoder, dest: []u8, source: []const u8) -> %void { + assert(dest.len == %%decoder.calcSize(source)); + assert(source.len % 4 == 0); -/// For use with ::decodeWithIgnore. -pub const Base64AlphabetWithIgnore = struct { - alphabet: Base64Alphabet, + var src_cursor: usize = 0; + var dest_cursor: usize = 0; + + while (src_cursor < source.len) : (src_cursor += 4) { + if (!decoder.char_in_alphabet[source[src_cursor + 0]]) return error.InvalidCharacter; + if (!decoder.char_in_alphabet[source[src_cursor + 1]]) return error.InvalidCharacter; + if (src_cursor < source.len - 4 or source[src_cursor + 3] != decoder.pad_char) { + // common case + if (!decoder.char_in_alphabet[source[src_cursor + 2]]) return error.InvalidCharacter; + if (!decoder.char_in_alphabet[source[src_cursor + 3]]) return error.InvalidCharacter; + dest[dest_cursor + 0] = decoder.char_to_index[source[src_cursor + 0]] << 2 | + decoder.char_to_index[source[src_cursor + 1]] >> 4; + dest[dest_cursor + 1] = decoder.char_to_index[source[src_cursor + 1]] << 4 | + decoder.char_to_index[source[src_cursor + 2]] >> 2; + dest[dest_cursor + 2] = decoder.char_to_index[source[src_cursor + 2]] << 6 | + decoder.char_to_index[source[src_cursor + 3]]; + dest_cursor += 3; + } else if (source[src_cursor + 2] != decoder.pad_char) { + // one pad char + if (!decoder.char_in_alphabet[source[src_cursor + 2]]) return 
error.InvalidCharacter; + dest[dest_cursor + 0] = decoder.char_to_index[source[src_cursor + 0]] << 2 | + decoder.char_to_index[source[src_cursor + 1]] >> 4; + dest[dest_cursor + 1] = decoder.char_to_index[source[src_cursor + 1]] << 4 | + decoder.char_to_index[source[src_cursor + 2]] >> 2; + if (decoder.char_to_index[source[src_cursor + 2]] << 6 != 0) return error.InvalidPadding; + dest_cursor += 2; + } else { + // two pad chars + dest[dest_cursor + 0] = decoder.char_to_index[source[src_cursor + 0]] << 2 | + decoder.char_to_index[source[src_cursor + 1]] >> 4; + if (decoder.char_to_index[source[src_cursor + 1]] << 4 != 0) return error.InvalidPadding; + dest_cursor += 1; + } + } + + assert(src_cursor == source.len); + assert(dest_cursor == dest.len); + } +}; + +error OutputTooSmall; + +pub const Base64DecoderWithIgnore = struct { + decoder: Base64Decoder, char_is_ignored: [256]bool, - pub fn init(alphabet_chars: []const u8, pad_char: u8, ignore_chars: []const u8) -> Base64AlphabetWithIgnore { - var result = Base64AlphabetWithIgnore { - .alphabet = Base64Alphabet.init(alphabet_chars, pad_char), + pub fn init(alphabet_chars: []const u8, pad_char: u8, ignore_chars: []const u8) -> Base64DecoderWithIgnore { + var result = Base64DecoderWithIgnore { + .decoder = Base64Decoder.init(alphabet_chars, pad_char), .char_is_ignored = []bool{false} ** 256, }; for (ignore_chars) |c| { - assert(!result.alphabet.char_in_alphabet[c]); + assert(!result.decoder.char_in_alphabet[c]); assert(!result.char_is_ignored[c]); - assert(result.alphabet.pad_char != c); + assert(result.decoder.pad_char != c); result.char_is_ignored[c] = true; } return result; } -}; -/// For use with ::decodeWithIgnore. -/// If no characters end up being ignored, this will be the exact decoded size. -pub fn calcDecodedSizeUpperBound(encoded_len: usize) -> %usize { - return @divTrunc(encoded_len, 4) * 3; -} - -error OutputTooSmall; -/// Invalid characters that are not ignored results in error.InvalidCharacter. 
-/// Invalid padding results in error.InvalidPadding. -/// Decoding more data than can fit in dest results in error.OutputTooSmall. See also ::calcDecodedSizeUpperBound. -/// Returns the number of bytes writen to dest. -pub fn decodeWithIgnore(dest: []u8, source: []const u8, alphabet_with_ignore: &const Base64AlphabetWithIgnore) -> %usize { - const alphabet = &const alphabet_with_ignore.alphabet; - - var src_cursor: usize = 0; - var dest_cursor: usize = 0; - - while (true) { - // get the next 4 chars, if available - var next_4_chars: [4]u8 = undefined; - var available_chars: usize = 0; - var pad_char_count: usize = 0; - while (available_chars < 4 and src_cursor < source.len) { - var c = source[src_cursor]; - src_cursor += 1; - - if (alphabet.char_in_alphabet[c]) { - // normal char - next_4_chars[available_chars] = c; - available_chars += 1; - } else if (alphabet_with_ignore.char_is_ignored[c]) { - // we're told to skip this one - continue; - } else if (c == alphabet.pad_char) { - // the padding has begun. count the pad chars. 
- pad_char_count += 1; - while (src_cursor < source.len) { - c = source[src_cursor]; - src_cursor += 1; - if (c == alphabet.pad_char) { - pad_char_count += 1; - if (pad_char_count > 2) return error.InvalidCharacter; - } else if (alphabet_with_ignore.char_is_ignored[c]) { - // we can even ignore chars during the padding - continue; - } else return error.InvalidCharacter; - } - break; - } else return error.InvalidCharacter; - } - - switch (available_chars) { - 4 => { - // common case - if (dest_cursor + 3 > dest.len) return error.OutputTooSmall; - assert(pad_char_count == 0); - dest[dest_cursor + 0] = alphabet.char_to_index[next_4_chars[0]] << 2 | - alphabet.char_to_index[next_4_chars[1]] >> 4; - dest[dest_cursor + 1] = alphabet.char_to_index[next_4_chars[1]] << 4 | - alphabet.char_to_index[next_4_chars[2]] >> 2; - dest[dest_cursor + 2] = alphabet.char_to_index[next_4_chars[2]] << 6 | - alphabet.char_to_index[next_4_chars[3]]; - dest_cursor += 3; - continue; - }, - 3 => { - if (dest_cursor + 2 > dest.len) return error.OutputTooSmall; - if (pad_char_count != 1) return error.InvalidPadding; - dest[dest_cursor + 0] = alphabet.char_to_index[next_4_chars[0]] << 2 | - alphabet.char_to_index[next_4_chars[1]] >> 4; - dest[dest_cursor + 1] = alphabet.char_to_index[next_4_chars[1]] << 4 | - alphabet.char_to_index[next_4_chars[2]] >> 2; - if (alphabet.char_to_index[next_4_chars[2]] << 6 != 0) return error.InvalidPadding; - dest_cursor += 2; - break; - }, - 2 => { - if (dest_cursor + 1 > dest.len) return error.OutputTooSmall; - if (pad_char_count != 2) return error.InvalidPadding; - dest[dest_cursor + 0] = alphabet.char_to_index[next_4_chars[0]] << 2 | - alphabet.char_to_index[next_4_chars[1]] >> 4; - if (alphabet.char_to_index[next_4_chars[1]] << 4 != 0) return error.InvalidPadding; - dest_cursor += 1; - break; - }, - 1 => { - return error.InvalidPadding; - }, - 0 => { - if (pad_char_count != 0) return error.InvalidPadding; - break; - }, - else => unreachable, - } + /// If no 
characters end up being ignored or padding, this will be the exact decoded size. + pub fn calcSizeUpperBound(encoded_len: usize) -> %usize { + return @divTrunc(encoded_len, 4) * 3; } - assert(src_cursor == source.len); + /// Invalid characters that are not ignored result in error.InvalidCharacter. + /// Invalid padding results in error.InvalidPadding. + /// Decoding more data than can fit in dest results in error.OutputTooSmall. See also ::calcSizeUpperBound. + /// Returns the number of bytes written to dest. + pub fn decode(decoder_with_ignore: &const Base64DecoderWithIgnore, dest: []u8, source: []const u8) -> %usize { + const decoder = &const decoder_with_ignore.decoder; - return dest_cursor; -} + var src_cursor: usize = 0; + var dest_cursor: usize = 0; -pub const standard_alphabet_unsafe = Base64AlphabetUnsafe.init(standard_alphabet_chars, standard_pad_char); + while (true) { + // get the next 4 chars, if available + var next_4_chars: [4]u8 = undefined; + var available_chars: usize = 0; + var pad_char_count: usize = 0; + while (available_chars < 4 and src_cursor < source.len) { + var c = source[src_cursor]; + src_cursor += 1; -/// For use with ::decodeExactUnsafe. -pub const Base64AlphabetUnsafe = struct { + if (decoder.char_in_alphabet[c]) { + // normal char + next_4_chars[available_chars] = c; + available_chars += 1; + } else if (decoder_with_ignore.char_is_ignored[c]) { + // we're told to skip this one + continue; + } else if (c == decoder.pad_char) { + // the padding has begun. count the pad chars. 
+ pad_char_count += 1; + while (src_cursor < source.len) { + c = source[src_cursor]; + src_cursor += 1; + if (c == decoder.pad_char) { + pad_char_count += 1; + if (pad_char_count > 2) return error.InvalidCharacter; + } else if (decoder_with_ignore.char_is_ignored[c]) { + // we can even ignore chars during the padding + continue; + } else return error.InvalidCharacter; + } + break; + } else return error.InvalidCharacter; + } + + switch (available_chars) { + 4 => { + // common case + if (dest_cursor + 3 > dest.len) return error.OutputTooSmall; + assert(pad_char_count == 0); + dest[dest_cursor + 0] = decoder.char_to_index[next_4_chars[0]] << 2 | + decoder.char_to_index[next_4_chars[1]] >> 4; + dest[dest_cursor + 1] = decoder.char_to_index[next_4_chars[1]] << 4 | + decoder.char_to_index[next_4_chars[2]] >> 2; + dest[dest_cursor + 2] = decoder.char_to_index[next_4_chars[2]] << 6 | + decoder.char_to_index[next_4_chars[3]]; + dest_cursor += 3; + continue; + }, + 3 => { + if (dest_cursor + 2 > dest.len) return error.OutputTooSmall; + if (pad_char_count != 1) return error.InvalidPadding; + dest[dest_cursor + 0] = decoder.char_to_index[next_4_chars[0]] << 2 | + decoder.char_to_index[next_4_chars[1]] >> 4; + dest[dest_cursor + 1] = decoder.char_to_index[next_4_chars[1]] << 4 | + decoder.char_to_index[next_4_chars[2]] >> 2; + if (decoder.char_to_index[next_4_chars[2]] << 6 != 0) return error.InvalidPadding; + dest_cursor += 2; + break; + }, + 2 => { + if (dest_cursor + 1 > dest.len) return error.OutputTooSmall; + if (pad_char_count != 2) return error.InvalidPadding; + dest[dest_cursor + 0] = decoder.char_to_index[next_4_chars[0]] << 2 | + decoder.char_to_index[next_4_chars[1]] >> 4; + if (decoder.char_to_index[next_4_chars[1]] << 4 != 0) return error.InvalidPadding; + dest_cursor += 1; + break; + }, + 1 => { + return error.InvalidPadding; + }, + 0 => { + if (pad_char_count != 0) return error.InvalidPadding; + break; + }, + else => unreachable, + } + } + + assert(src_cursor == 
source.len); + + return dest_cursor; + } +}; + + +pub const standard_decoder_unsafe = Base64DecoderUnsafe.init(standard_alphabet_chars, standard_pad_char); + +pub const Base64DecoderUnsafe = struct { /// e.g. 'A' => 0. /// undefined for any value not in the 64 alphabet chars. char_to_index: [256]u8, pad_char: u8, - pub fn init(alphabet_chars: []const u8, pad_char: u8) -> Base64AlphabetUnsafe { + pub fn init(alphabet_chars: []const u8, pad_char: u8) -> Base64DecoderUnsafe { assert(alphabet_chars.len == 64); - var result = Base64AlphabetUnsafe { + var result = Base64DecoderUnsafe { .char_to_index = undefined, .pad_char = pad_char, }; @@ -287,68 +304,72 @@ pub const Base64AlphabetUnsafe = struct { } return result; } + + /// The source buffer must be valid. + pub fn calcSize(decoder: &const Base64DecoderUnsafe, source: []const u8) -> usize { + return calcDecodedSizeExactUnsafe(source, decoder.pad_char); + } + + /// dest.len must be what you get from ::calcSize. + /// invalid characters or padding will result in undefined values. 
+ pub fn decode(decoder: &const Base64DecoderUnsafe, dest: []u8, source: []const u8) { + assert(dest.len == decoder.calcSize(source)); + + var src_index: usize = 0; + var dest_index: usize = 0; + var in_buf_len: usize = source.len; + + while (in_buf_len > 0 and source[in_buf_len - 1] == decoder.pad_char) { + in_buf_len -= 1; + } + + while (in_buf_len > 4) { + dest[dest_index] = decoder.char_to_index[source[src_index + 0]] << 2 | + decoder.char_to_index[source[src_index + 1]] >> 4; + dest_index += 1; + + dest[dest_index] = decoder.char_to_index[source[src_index + 1]] << 4 | + decoder.char_to_index[source[src_index + 2]] >> 2; + dest_index += 1; + + dest[dest_index] = decoder.char_to_index[source[src_index + 2]] << 6 | + decoder.char_to_index[source[src_index + 3]]; + dest_index += 1; + + src_index += 4; + in_buf_len -= 4; + } + + if (in_buf_len > 1) { + dest[dest_index] = decoder.char_to_index[source[src_index + 0]] << 2 | + decoder.char_to_index[source[src_index + 1]] >> 4; + dest_index += 1; + } + if (in_buf_len > 2) { + dest[dest_index] = decoder.char_to_index[source[src_index + 1]] << 4 | + decoder.char_to_index[source[src_index + 2]] >> 2; + dest_index += 1; + } + if (in_buf_len > 3) { + dest[dest_index] = decoder.char_to_index[source[src_index + 2]] << 6 | + decoder.char_to_index[source[src_index + 3]]; + dest_index += 1; + } + } }; -/// For use with ::decodeExactUnsafe. -/// The encoded buffer must be valid. 
-pub fn calcDecodedSizeExactUnsafe(encoded: []const u8, pad_char: u8) -> usize { - if (encoded.len == 0) return 0; - var result = @divExact(encoded.len, 4) * 3; - if (encoded[encoded.len - 1] == pad_char) { +fn calcDecodedSizeExactUnsafe(source: []const u8, pad_char: u8) -> usize { + if (source.len == 0) return 0; + var result = @divExact(source.len, 4) * 3; + if (source[source.len - 1] == pad_char) { result -= 1; - if (encoded[encoded.len - 2] == pad_char) { + if (source[source.len - 2] == pad_char) { result -= 1; } } return result; } -/// dest.len must be what you get from ::calcDecodedSizeExactUnsafe. -/// invalid characters or padding will result in undefined values. -pub fn decodeExactUnsafe(dest: []u8, source: []const u8, alphabet: &const Base64AlphabetUnsafe) { - assert(dest.len == calcDecodedSizeExactUnsafe(source, alphabet.pad_char)); - - var src_index: usize = 0; - var dest_index: usize = 0; - var in_buf_len: usize = source.len; - - while (in_buf_len > 0 and source[in_buf_len - 1] == alphabet.pad_char) { - in_buf_len -= 1; - } - - while (in_buf_len > 4) { - dest[dest_index] = alphabet.char_to_index[source[src_index + 0]] << 2 | - alphabet.char_to_index[source[src_index + 1]] >> 4; - dest_index += 1; - - dest[dest_index] = alphabet.char_to_index[source[src_index + 1]] << 4 | - alphabet.char_to_index[source[src_index + 2]] >> 2; - dest_index += 1; - - dest[dest_index] = alphabet.char_to_index[source[src_index + 2]] << 6 | - alphabet.char_to_index[source[src_index + 3]]; - dest_index += 1; - - src_index += 4; - in_buf_len -= 4; - } - - if (in_buf_len > 1) { - dest[dest_index] = alphabet.char_to_index[source[src_index + 0]] << 2 | - alphabet.char_to_index[source[src_index + 1]] >> 4; - dest_index += 1; - } - if (in_buf_len > 2) { - dest[dest_index] = alphabet.char_to_index[source[src_index + 1]] << 4 | - alphabet.char_to_index[source[src_index + 2]] >> 2; - dest_index += 1; - } - if (in_buf_len > 3) { - dest[dest_index] = 
alphabet.char_to_index[source[src_index + 2]] << 6 | - alphabet.char_to_index[source[src_index + 3]]; - dest_index += 1; - } -} test "base64" { @setEvalBranchQuota(5000); @@ -391,74 +412,74 @@ fn testBase64() -> %void { } fn testAllApis(expected_decoded: []const u8, expected_encoded: []const u8) -> %void { - // encode + // Base64Encoder { var buffer: [0x100]u8 = undefined; - var encoded = buffer[0..calcEncodedSize(expected_decoded.len)]; - encode(encoded, expected_decoded, standard_alphabet_chars, standard_pad_char); + var encoded = buffer[0..Base64Encoder.calcSize(expected_decoded.len)]; + standard_encoder.encode(encoded, expected_decoded); assert(mem.eql(u8, encoded, expected_encoded)); } - // decodeExact + // Base64Decoder { var buffer: [0x100]u8 = undefined; - var decoded = buffer[0..%return calcDecodedSizeExact(expected_encoded, standard_pad_char)]; - %return decodeExact(decoded, expected_encoded, standard_alphabet); + var decoded = buffer[0..%return standard_decoder.calcSize(expected_encoded)]; + %return standard_decoder.decode(decoded, expected_encoded); assert(mem.eql(u8, decoded, expected_decoded)); } - // decodeWithIgnore + // Base64DecoderWithIgnore { - const standard_alphabet_ignore_nothing = Base64AlphabetWithIgnore.init( + const standard_decoder_ignore_nothing = Base64DecoderWithIgnore.init( standard_alphabet_chars, standard_pad_char, ""); var buffer: [0x100]u8 = undefined; - var decoded = buffer[0..%return calcDecodedSizeUpperBound(expected_encoded.len)]; - var written = %return decodeWithIgnore(decoded, expected_encoded, standard_alphabet_ignore_nothing); + var decoded = buffer[0..%return Base64DecoderWithIgnore.calcSizeUpperBound(expected_encoded.len)]; + var written = %return standard_decoder_ignore_nothing.decode(decoded, expected_encoded); assert(written <= decoded.len); assert(mem.eql(u8, decoded[0..written], expected_decoded)); } - // decodeExactUnsafe + // Base64DecoderUnsafe { var buffer: [0x100]u8 = undefined; - var decoded = 
buffer[0..calcDecodedSizeExactUnsafe(expected_encoded, standard_pad_char)]; - decodeExactUnsafe(decoded, expected_encoded, standard_alphabet_unsafe); + var decoded = buffer[0..standard_decoder_unsafe.calcSize(expected_encoded)]; + standard_decoder_unsafe.decode(decoded, expected_encoded); assert(mem.eql(u8, decoded, expected_decoded)); } } fn testDecodeIgnoreSpace(expected_decoded: []const u8, encoded: []const u8) -> %void { - const standard_alphabet_ignore_space = Base64AlphabetWithIgnore.init( + const standard_decoder_ignore_space = Base64DecoderWithIgnore.init( standard_alphabet_chars, standard_pad_char, " "); var buffer: [0x100]u8 = undefined; - var decoded = buffer[0..%return calcDecodedSizeUpperBound(encoded.len)]; - var written = %return decodeWithIgnore(decoded, encoded, standard_alphabet_ignore_space); + var decoded = buffer[0..%return Base64DecoderWithIgnore.calcSizeUpperBound(encoded.len)]; + var written = %return standard_decoder_ignore_space.decode(decoded, encoded); assert(mem.eql(u8, decoded[0..written], expected_decoded)); } error ExpectedError; fn testError(encoded: []const u8, expected_err: error) -> %void { - const standard_alphabet_ignore_space = Base64AlphabetWithIgnore.init( + const standard_decoder_ignore_space = Base64DecoderWithIgnore.init( standard_alphabet_chars, standard_pad_char, " "); var buffer: [0x100]u8 = undefined; - if (calcDecodedSizeExact(encoded, standard_pad_char)) |decoded_size| { + if (standard_decoder.calcSize(encoded)) |decoded_size| { var decoded = buffer[0..decoded_size]; - if (decodeExact(decoded, encoded, standard_alphabet)) |_| { + if (standard_decoder.decode(decoded, encoded)) |_| { return error.ExpectedError; } else |err| if (err != expected_err) return err; } else |err| if (err != expected_err) return err; - if (decodeWithIgnore(buffer[0..], encoded, standard_alphabet_ignore_space)) |_| { + if (standard_decoder_ignore_space.decode(buffer[0..], encoded)) |_| { return error.ExpectedError; } else |err| if (err != 
expected_err) return err; } fn testOutputTooSmallError(encoded: []const u8) -> %void { - const standard_alphabet_ignore_space = Base64AlphabetWithIgnore.init( + const standard_decoder_ignore_space = Base64DecoderWithIgnore.init( standard_alphabet_chars, standard_pad_char, " "); var buffer: [0x100]u8 = undefined; var decoded = buffer[0..calcDecodedSizeExactUnsafe(encoded, standard_pad_char) - 1]; - if (decodeWithIgnore(decoded, encoded, standard_alphabet_ignore_space)) |_| { + if (standard_decoder_ignore_space.decode(decoded, encoded)) |_| { return error.ExpectedError; } else |err| if (err != error.OutputTooSmall) return err; } diff --git a/std/os/index.zig b/std/os/index.zig index 872564224c..e6a5fc4d15 100644 --- a/std/os/index.zig +++ b/std/os/index.zig @@ -622,7 +622,9 @@ pub fn symLinkPosix(allocator: &Allocator, existing_path: []const u8, new_path: } // here we replace the standard +/ with -_ so that it can be used in a file name -const b64_fs_alphabet_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; +const b64_fs_encoder = base64.Base64Encoder.init( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_", + base64.standard_pad_char); pub fn atomicSymLink(allocator: &Allocator, existing_path: []const u8, new_path: []const u8) -> %void { if (symLink(allocator, existing_path, new_path)) { @@ -634,12 +636,12 @@ pub fn atomicSymLink(allocator: &Allocator, existing_path: []const u8, new_path: } var rand_buf: [12]u8 = undefined; - const tmp_path = %return allocator.alloc(u8, new_path.len + base64.calcEncodedSize(rand_buf.len)); + const tmp_path = %return allocator.alloc(u8, new_path.len + base64.Base64Encoder.calcSize(rand_buf.len)); defer allocator.free(tmp_path); mem.copy(u8, tmp_path[0..], new_path); while (true) { %return getRandomBytes(rand_buf[0..]); - base64.encode(tmp_path[new_path.len..], rand_buf, b64_fs_alphabet_chars, base64.standard_pad_char); + b64_fs_encoder.encode(tmp_path[new_path.len..], rand_buf); if 
(symLink(allocator, existing_path, tmp_path)) { return rename(allocator, tmp_path, new_path); } else |err| { @@ -717,11 +719,11 @@ pub fn copyFile(allocator: &Allocator, source_path: []const u8, dest_path: []con /// Guaranteed to be atomic. pub fn copyFileMode(allocator: &Allocator, source_path: []const u8, dest_path: []const u8, mode: usize) -> %void { var rand_buf: [12]u8 = undefined; - const tmp_path = %return allocator.alloc(u8, dest_path.len + base64.calcEncodedSize(rand_buf.len)); + const tmp_path = %return allocator.alloc(u8, dest_path.len + base64.Base64Encoder.calcSize(rand_buf.len)); defer allocator.free(tmp_path); mem.copy(u8, tmp_path[0..], dest_path); %return getRandomBytes(rand_buf[0..]); - base64.encode(tmp_path[dest_path.len..], rand_buf, b64_fs_alphabet_chars, base64.standard_pad_char); + b64_fs_encoder.encode(tmp_path[dest_path.len..], rand_buf); var out_file = %return io.File.openWriteMode(tmp_path, mode, allocator); defer out_file.close(); From 5a25505668bac9aed0ad8f3b23fe81c6aff29b71 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 24 Nov 2017 14:56:05 -0500 Subject: [PATCH 12/34] rename "parsec" to "translate-c" --- CMakeLists.txt | 2 +- build.zig | 2 +- ci/travis_osx_script | 2 +- src-self-hosted/main.zig | 2 +- src/analyze.cpp | 4 +-- src/codegen.cpp | 4 +-- src/codegen.hpp | 2 +- src/ir.cpp | 2 +- src/main.cpp | 22 ++++++------- src/{parsec.cpp => translate_c.cpp} | 2 +- src/{parsec.hpp => translate_c.hpp} | 0 test/tests.zig | 48 ++++++++++++++-------------- test/{parsec.zig => translate_c.zig} | 2 +- 13 files changed, 47 insertions(+), 47 deletions(-) rename src/{parsec.cpp => translate_c.cpp} (99%) rename src/{parsec.hpp => translate_c.hpp} (100%) rename test/{parsec.zig => translate_c.zig} (99%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8906798c69..72c480cd40 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -339,7 +339,7 @@ set(ZIG_SOURCES "${CMAKE_SOURCE_DIR}/src/target.cpp" "${CMAKE_SOURCE_DIR}/src/tokenizer.cpp" 
"${CMAKE_SOURCE_DIR}/src/util.cpp" - "${CMAKE_SOURCE_DIR}/src/parsec.cpp" + "${CMAKE_SOURCE_DIR}/src/translate_c.cpp" "${CMAKE_SOURCE_DIR}/src/zig_llvm.cpp" ) diff --git a/build.zig b/build.zig index 79fefd5c7a..1c1d4f832e 100644 --- a/build.zig +++ b/build.zig @@ -58,5 +58,5 @@ pub fn build(b: &Builder) { test_step.dependOn(tests.addCompileErrorTests(b, test_filter)); test_step.dependOn(tests.addAssembleAndLinkTests(b, test_filter)); test_step.dependOn(tests.addDebugSafetyTests(b, test_filter)); - test_step.dependOn(tests.addParseCTests(b, test_filter)); + test_step.dependOn(tests.addTranslateCTests(b, test_filter)); } diff --git a/ci/travis_osx_script b/ci/travis_osx_script index 1f345c9c50..a55132395b 100755 --- a/ci/travis_osx_script +++ b/ci/travis_osx_script @@ -22,4 +22,4 @@ make install ./zig build --build-file ../build.zig test-compile-errors --verbose ./zig build --build-file ../build.zig test-asm-link --verbose ./zig build --build-file ../build.zig test-debug-safety --verbose -./zig build --build-file ../build.zig test-parsec --verbose +./zig build --build-file ../build.zig test-translate-c --verbose diff --git a/src-self-hosted/main.zig b/src-self-hosted/main.zig index 0716ad433b..71180b2001 100644 --- a/src-self-hosted/main.zig +++ b/src-self-hosted/main.zig @@ -208,7 +208,7 @@ fn printUsage(outstream: &io.OutStream) -> %void { \\ build-exe [source] create executable from source or object files \\ build-lib [source] create library from source or object files \\ build-obj [source] create object from source or assembly - \\ parsec [source] convert c code to zig code + \\ translate-c [source] convert c code to zig code \\ targets list available compilation targets \\ test [source] create and run a test build \\ version print version number and exit diff --git a/src/analyze.cpp b/src/analyze.cpp index ebad9fe0cb..7a9df874d4 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -28,7 +28,7 @@ static void resolve_union_zero_bits(CodeGen *g, TypeTableEntry 
*union_type); ErrorMsg *add_node_error(CodeGen *g, AstNode *node, Buf *msg) { if (node->owner->c_import_node != nullptr) { - // if this happens, then parsec generated code that + // if this happens, then translate_c generated code that // failed semantic analysis, which isn't supposed to happen ErrorMsg *err = add_node_error(g, node->owner->c_import_node, buf_sprintf("compiler bug: @cImport generated invalid zig code")); @@ -48,7 +48,7 @@ ErrorMsg *add_node_error(CodeGen *g, AstNode *node, Buf *msg) { ErrorMsg *add_error_note(CodeGen *g, ErrorMsg *parent_msg, AstNode *node, Buf *msg) { if (node->owner->c_import_node != nullptr) { - // if this happens, then parsec generated code that + // if this happens, then translate_c generated code that // failed semantic analysis, which isn't supposed to happen Buf *note_path = buf_create_from_str("?.c"); diff --git a/src/codegen.cpp b/src/codegen.cpp index 680f5a9e35..24d24a91e5 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -15,7 +15,7 @@ #include "ir.hpp" #include "link.hpp" #include "os.hpp" -#include "parsec.hpp" +#include "translate_c.hpp" #include "target.hpp" #include "zig_llvm.hpp" @@ -5353,7 +5353,7 @@ static void init(CodeGen *g) { define_builtin_compile_vars(g); } -void codegen_parsec(CodeGen *g, Buf *full_path) { +void codegen_translate_c(CodeGen *g, Buf *full_path) { find_libc_include_path(g); Buf *src_basename = buf_alloc(); diff --git a/src/codegen.hpp b/src/codegen.hpp index b71a7fa651..b29cadee55 100644 --- a/src/codegen.hpp +++ b/src/codegen.hpp @@ -56,7 +56,7 @@ PackageTableEntry *codegen_create_package(CodeGen *g, const char *root_src_dir, void codegen_add_assembly(CodeGen *g, Buf *path); void codegen_add_object(CodeGen *g, Buf *object_path); -void codegen_parsec(CodeGen *g, Buf *path); +void codegen_translate_c(CodeGen *g, Buf *path); #endif diff --git a/src/ir.cpp b/src/ir.cpp index fa59aa03f2..c81de7fa7a 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -11,7 +11,7 @@ #include "ir.hpp" #include 
"ir_print.hpp" #include "os.hpp" -#include "parsec.hpp" +#include "translate_c.hpp" #include "range_set.hpp" #include "softfloat.hpp" diff --git a/src/main.cpp b/src/main.cpp index 13da71f9e8..60d2750bde 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -23,7 +23,7 @@ static int usage(const char *arg0) { " build-exe [source] create executable from source or object files\n" " build-lib [source] create library from source or object files\n" " build-obj [source] create object from source or assembly\n" - " parsec [source] convert c code to zig code\n" + " translate-c [source] convert c code to zig code\n" " targets list available compilation targets\n" " test [source] create and run a test build\n" " version print version number and exit\n" @@ -229,7 +229,7 @@ enum Cmd { CmdTest, CmdVersion, CmdZen, - CmdParseC, + CmdTranslateC, CmdTargets, }; @@ -632,8 +632,8 @@ int main(int argc, char **argv) { cmd = CmdVersion; } else if (strcmp(arg, "zen") == 0) { cmd = CmdZen; - } else if (strcmp(arg, "parsec") == 0) { - cmd = CmdParseC; + } else if (strcmp(arg, "translate-c") == 0) { + cmd = CmdTranslateC; } else if (strcmp(arg, "test") == 0) { cmd = CmdTest; out_type = OutTypeExe; @@ -646,7 +646,7 @@ int main(int argc, char **argv) { } else { switch (cmd) { case CmdBuild: - case CmdParseC: + case CmdTranslateC: case CmdTest: if (!in_file) { in_file = arg; @@ -703,13 +703,13 @@ int main(int argc, char **argv) { switch (cmd) { case CmdBuild: - case CmdParseC: + case CmdTranslateC: case CmdTest: { if (cmd == CmdBuild && !in_file && objects.length == 0 && asm_files.length == 0) { fprintf(stderr, "Expected source file argument or at least one --object or --assembly argument.\n"); return usage(arg0); - } else if ((cmd == CmdParseC || cmd == CmdTest) && !in_file) { + } else if ((cmd == CmdTranslateC || cmd == CmdTest) && !in_file) { fprintf(stderr, "Expected source file argument.\n"); return usage(arg0); } else if (cmd == CmdBuild && out_type == OutTypeObj && objects.length != 0) { @@ 
-719,7 +719,7 @@ int main(int argc, char **argv) { assert(cmd != CmdBuild || out_type != OutTypeUnknown); - bool need_name = (cmd == CmdBuild || cmd == CmdParseC); + bool need_name = (cmd == CmdBuild || cmd == CmdTranslateC); Buf *in_file_buf = nullptr; @@ -742,7 +742,7 @@ int main(int argc, char **argv) { return usage(arg0); } - Buf *zig_root_source_file = (cmd == CmdParseC) ? nullptr : in_file_buf; + Buf *zig_root_source_file = (cmd == CmdTranslateC) ? nullptr : in_file_buf; Buf *full_cache_dir = buf_alloc(); os_path_resolve(buf_create_from_str("."), @@ -841,8 +841,8 @@ int main(int argc, char **argv) { if (timing_info) codegen_print_timing_report(g, stdout); return EXIT_SUCCESS; - } else if (cmd == CmdParseC) { - codegen_parsec(g, in_file_buf); + } else if (cmd == CmdTranslateC) { + codegen_translate_c(g, in_file_buf); ast_render(g, stdout, g->root_import->root, 4); if (timing_info) codegen_print_timing_report(g, stdout); diff --git a/src/parsec.cpp b/src/translate_c.cpp similarity index 99% rename from src/parsec.cpp rename to src/translate_c.cpp index 79ba2ab990..a76432fa57 100644 --- a/src/parsec.cpp +++ b/src/translate_c.cpp @@ -11,7 +11,7 @@ #include "error.hpp" #include "ir.hpp" #include "os.hpp" -#include "parsec.hpp" +#include "translate_c.hpp" #include "parser.hpp" diff --git a/src/parsec.hpp b/src/translate_c.hpp similarity index 100% rename from src/parsec.hpp rename to src/translate_c.hpp diff --git a/test/tests.zig b/test/tests.zig index 20b57c7573..73d9646552 100644 --- a/test/tests.zig +++ b/test/tests.zig @@ -18,7 +18,7 @@ const build_examples = @import("build_examples.zig"); const compile_errors = @import("compile_errors.zig"); const assemble_and_link = @import("assemble_and_link.zig"); const debug_safety = @import("debug_safety.zig"); -const parsec = @import("parsec.zig"); +const translate_c = @import("translate_c.zig"); const TestTarget = struct { os: builtin.Os, @@ -123,16 +123,16 @@ pub fn addAssembleAndLinkTests(b: &build.Builder, 
test_filter: ?[]const u8) -> & return cases.step; } -pub fn addParseCTests(b: &build.Builder, test_filter: ?[]const u8) -> &build.Step { - const cases = %%b.allocator.create(ParseCContext); - *cases = ParseCContext { +pub fn addTranslateCTests(b: &build.Builder, test_filter: ?[]const u8) -> &build.Step { + const cases = %%b.allocator.create(TranslateCContext); + *cases = TranslateCContext { .b = b, - .step = b.step("test-parsec", "Run the C header file parsing tests"), + .step = b.step("test-translate-c", "Run the C header file parsing tests"), .test_index = 0, .test_filter = test_filter, }; - parsec.addCases(cases); + translate_c.addCases(cases); return cases.step; } @@ -770,7 +770,7 @@ pub const BuildExamplesContext = struct { } }; -pub const ParseCContext = struct { +pub const TranslateCContext = struct { b: &build.Builder, step: &build.Step, test_index: usize, @@ -799,17 +799,17 @@ pub const ParseCContext = struct { } }; - const ParseCCmpOutputStep = struct { + const TranslateCCmpOutputStep = struct { step: build.Step, - context: &ParseCContext, + context: &TranslateCContext, name: []const u8, test_index: usize, case: &const TestCase, - pub fn create(context: &ParseCContext, name: []const u8, case: &const TestCase) -> &ParseCCmpOutputStep { + pub fn create(context: &TranslateCContext, name: []const u8, case: &const TestCase) -> &TranslateCCmpOutputStep { const allocator = context.b.allocator; - const ptr = %%allocator.create(ParseCCmpOutputStep); - *ptr = ParseCCmpOutputStep { + const ptr = %%allocator.create(TranslateCCmpOutputStep); + *ptr = TranslateCCmpOutputStep { .step = build.Step.init("ParseCCmpOutput", allocator, make), .context = context, .name = name, @@ -821,7 +821,7 @@ pub const ParseCContext = struct { } fn make(step: &build.Step) -> %void { - const self = @fieldParentPtr(ParseCCmpOutputStep, "step", step); + const self = @fieldParentPtr(TranslateCCmpOutputStep, "step", step); const b = self.context.b; const root_src = %%os.path.join(b.allocator, 
b.cache_root, self.case.sources.items[0].filename); @@ -829,7 +829,7 @@ pub const ParseCContext = struct { var zig_args = ArrayList([]const u8).init(b.allocator); %%zig_args.append(b.zig_exe); - %%zig_args.append("parsec"); + %%zig_args.append("translate-c"); %%zig_args.append(b.pathFromRoot(root_src)); warn("Test {}/{} {}...", self.test_index+1, self.context.test_index, self.name); @@ -882,7 +882,7 @@ pub const ParseCContext = struct { if (stderr.len != 0 and !self.case.allow_warnings) { warn( - \\====== parsec emitted warnings: ============ + \\====== translate-c emitted warnings: ======= \\{} \\============================================ \\ @@ -914,7 +914,7 @@ pub const ParseCContext = struct { warn("\n"); } - pub fn create(self: &ParseCContext, allow_warnings: bool, filename: []const u8, name: []const u8, + pub fn create(self: &TranslateCContext, allow_warnings: bool, filename: []const u8, name: []const u8, source: []const u8, expected_lines: ...) -> &TestCase { const tc = %%self.b.allocator.create(TestCase); @@ -932,37 +932,37 @@ pub const ParseCContext = struct { return tc; } - pub fn add(self: &ParseCContext, name: []const u8, source: []const u8, expected_lines: ...) { + pub fn add(self: &TranslateCContext, name: []const u8, source: []const u8, expected_lines: ...) { const tc = self.create(false, "source.h", name, source, expected_lines); self.addCase(tc); } - pub fn addC(self: &ParseCContext, name: []const u8, source: []const u8, expected_lines: ...) { + pub fn addC(self: &TranslateCContext, name: []const u8, source: []const u8, expected_lines: ...) { const tc = self.create(false, "source.c", name, source, expected_lines); self.addCase(tc); } - pub fn addAllowWarnings(self: &ParseCContext, name: []const u8, source: []const u8, expected_lines: ...) { + pub fn addAllowWarnings(self: &TranslateCContext, name: []const u8, source: []const u8, expected_lines: ...) 
{ const tc = self.create(true, "source.h", name, source, expected_lines); self.addCase(tc); } - pub fn addCase(self: &ParseCContext, case: &const TestCase) { + pub fn addCase(self: &TranslateCContext, case: &const TestCase) { const b = self.b; - const annotated_case_name = %%fmt.allocPrint(self.b.allocator, "parsec {}", case.name); + const annotated_case_name = %%fmt.allocPrint(self.b.allocator, "translate-c {}", case.name); if (self.test_filter) |filter| { if (mem.indexOf(u8, annotated_case_name, filter) == null) return; } - const parsec_and_cmp = ParseCCmpOutputStep.create(self, annotated_case_name, case); - self.step.dependOn(&parsec_and_cmp.step); + const translate_c_and_cmp = TranslateCCmpOutputStep.create(self, annotated_case_name, case); + self.step.dependOn(&translate_c_and_cmp.step); for (case.sources.toSliceConst()) |src_file| { const expanded_src_path = %%os.path.join(b.allocator, b.cache_root, src_file.filename); const write_src = b.addWriteFile(expanded_src_path, src_file.source); - parsec_and_cmp.step.dependOn(&write_src.step); + translate_c_and_cmp.step.dependOn(&write_src.step); } } }; diff --git a/test/parsec.zig b/test/translate_c.zig similarity index 99% rename from test/parsec.zig rename to test/translate_c.zig index f830131262..6394950fd0 100644 --- a/test/parsec.zig +++ b/test/translate_c.zig @@ -1,6 +1,6 @@ const tests = @import("tests.zig"); -pub fn addCases(cases: &tests.ParseCContext) { +pub fn addCases(cases: &tests.TranslateCContext) { cases.addAllowWarnings("simple data types", \\#include \\int foo(char a, unsigned char b, signed char c); From 741504862c7094dc91d3bf4c4da58e9236a7633a Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 24 Nov 2017 15:06:12 -0500 Subject: [PATCH 13/34] update homepage docs --- doc/home.html.in | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/doc/home.html.in b/doc/home.html.in index 002ac0b70a..3b2ef3407b 100644 --- a/doc/home.html.in +++ b/doc/home.html.in @@ -70,10 +70,14 @@
  • Mersenne Twister Random Number Generator
  • Hello World

    -
    const io = @import("std").io;
    +    
    const std = @import("std");
     
     pub fn main() -> %void {
    -    %return io.stdout.printf("Hello, world!\n");
    +    // If this program is run without stdout attached, exit with an error.
    +    var stdout_file = %return std.io.getStdOut();
    +    // If this program encounters pipe failure when printing to stdout, exit
    +    // with an error.
    +    %return stdout_file.write("Hello, world!\n");
     }

    Build this with:

    zig build-exe hello.zig
    From 68312afcdf5a75ddc95dc21cb2f610780c0b69a6 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 24 Nov 2017 16:36:39 -0500 Subject: [PATCH 14/34] translate-c: support pre increment and decrement operators --- src/translate_c.cpp | 128 ++++++++++++++++++++++++++++++------------- test/translate_c.zig | 44 +++++++++++++++ 2 files changed, 133 insertions(+), 39 deletions(-) diff --git a/src/translate_c.cpp b/src/translate_c.cpp index a76432fa57..9233b7d1a0 100644 --- a/src/translate_c.cpp +++ b/src/translate_c.cpp @@ -1524,47 +1524,93 @@ static AstNode *trans_create_post_crement(Context *c, bool result_used, AstNode trans_expr(c, true, block, op_expr, TransLValue), assign_op, trans_create_node_unsigned(c, 1)); - } else { - // worst case - // c: expr++ - // zig: { - // zig: const _ref = &expr; - // zig: const _tmp = *_ref; - // zig: *_ref += 1; - // zig: _tmp - // zig: } - AstNode *child_block = trans_create_node(c, NodeTypeBlock); + } + // worst case + // c: expr++ + // zig: { + // zig: const _ref = &expr; + // zig: const _tmp = *_ref; + // zig: *_ref += 1; + // zig: _tmp + // zig: } + AstNode *child_block = trans_create_node(c, NodeTypeBlock); - // const _ref = &expr; - AstNode *expr = trans_expr(c, true, child_block, op_expr, TransLValue); - if (expr == nullptr) return nullptr; - AstNode *addr_of_expr = trans_create_node_addr_of(c, false, false, expr); - // TODO: avoid name collisions with generated variable names - Buf* ref_var_name = buf_create_from_str("_ref"); - AstNode *ref_var_decl = trans_create_node_var_decl_local(c, true, ref_var_name, nullptr, addr_of_expr); - child_block->data.block.statements.append(ref_var_decl); + // const _ref = &expr; + AstNode *expr = trans_expr(c, true, child_block, op_expr, TransLValue); + if (expr == nullptr) return nullptr; + AstNode *addr_of_expr = trans_create_node_addr_of(c, false, false, expr); + // TODO: avoid name collisions with generated variable names + Buf* ref_var_name = buf_create_from_str("_ref"); + AstNode 
*ref_var_decl = trans_create_node_var_decl_local(c, true, ref_var_name, nullptr, addr_of_expr); + child_block->data.block.statements.append(ref_var_decl); - // const _tmp = *_ref; - Buf* tmp_var_name = buf_create_from_str("_tmp"); - AstNode *tmp_var_decl = trans_create_node_var_decl_local(c, true, tmp_var_name, nullptr, - trans_create_node_prefix_op(c, PrefixOpDereference, - trans_create_node_symbol(c, ref_var_name))); - child_block->data.block.statements.append(tmp_var_decl); + // const _tmp = *_ref; + Buf* tmp_var_name = buf_create_from_str("_tmp"); + AstNode *tmp_var_decl = trans_create_node_var_decl_local(c, true, tmp_var_name, nullptr, + trans_create_node_prefix_op(c, PrefixOpDereference, + trans_create_node_symbol(c, ref_var_name))); + child_block->data.block.statements.append(tmp_var_decl); - // *_ref += 1; - AstNode *assign_statement = trans_create_node_bin_op(c, - trans_create_node_prefix_op(c, PrefixOpDereference, - trans_create_node_symbol(c, ref_var_name)), + // *_ref += 1; + AstNode *assign_statement = trans_create_node_bin_op(c, + trans_create_node_prefix_op(c, PrefixOpDereference, + trans_create_node_symbol(c, ref_var_name)), + assign_op, + trans_create_node_unsigned(c, 1)); + child_block->data.block.statements.append(assign_statement); + + // _tmp + child_block->data.block.statements.append(trans_create_node_symbol(c, tmp_var_name)); + child_block->data.block.last_statement_is_result_expression = true; + + return child_block; +} + +static AstNode *trans_create_pre_crement(Context *c, bool result_used, AstNode *block, UnaryOperator *stmt, BinOpType assign_op) { + Expr *op_expr = stmt->getSubExpr(); + + if (!result_used) { + // common case + // c: ++expr + // zig: expr += 1 + return trans_create_node_bin_op(c, + trans_expr(c, true, block, op_expr, TransLValue), assign_op, trans_create_node_unsigned(c, 1)); - child_block->data.block.statements.append(assign_statement); - - // _tmp - child_block->data.block.statements.append(trans_create_node_symbol(c, 
tmp_var_name)); - child_block->data.block.last_statement_is_result_expression = true; - - return child_block; } + // worst case + // c: ++expr + // zig: { + // zig: const _ref = &expr; + // zig: *_ref += 1; + // zig: *_ref + // zig: } + AstNode *child_block = trans_create_node(c, NodeTypeBlock); + + // const _ref = &expr; + AstNode *expr = trans_expr(c, true, child_block, op_expr, TransLValue); + if (expr == nullptr) return nullptr; + AstNode *addr_of_expr = trans_create_node_addr_of(c, false, false, expr); + // TODO: avoid name collisions with generated variable names + Buf* ref_var_name = buf_create_from_str("_ref"); + AstNode *ref_var_decl = trans_create_node_var_decl_local(c, true, ref_var_name, nullptr, addr_of_expr); + child_block->data.block.statements.append(ref_var_decl); + + // *_ref += 1; + AstNode *assign_statement = trans_create_node_bin_op(c, + trans_create_node_prefix_op(c, PrefixOpDereference, + trans_create_node_symbol(c, ref_var_name)), + assign_op, + trans_create_node_unsigned(c, 1)); + child_block->data.block.statements.append(assign_statement); + + // *_ref + AstNode *deref_expr = trans_create_node_prefix_op(c, PrefixOpDereference, + trans_create_node_symbol(c, ref_var_name)); + child_block->data.block.statements.append(deref_expr); + child_block->data.block.last_statement_is_result_expression = true; + + return child_block; } static AstNode *trans_unary_operator(Context *c, bool result_used, AstNode *block, UnaryOperator *stmt) { @@ -1580,11 +1626,15 @@ static AstNode *trans_unary_operator(Context *c, bool result_used, AstNode *bloc else return trans_create_post_crement(c, result_used, block, stmt, BinOpTypeAssignMinus); case UO_PreInc: - emit_warning(c, stmt->getLocStart(), "TODO handle C translation UO_PreInc"); - return nullptr; + if (qual_type_has_wrapping_overflow(c, stmt->getType())) + return trans_create_pre_crement(c, result_used, block, stmt, BinOpTypeAssignPlusWrap); + else + return trans_create_pre_crement(c, result_used, block, 
stmt, BinOpTypeAssignPlus); case UO_PreDec: - emit_warning(c, stmt->getLocStart(), "TODO handle C translation UO_PreDec"); - return nullptr; + if (qual_type_has_wrapping_overflow(c, stmt->getType())) + return trans_create_pre_crement(c, result_used, block, stmt, BinOpTypeAssignMinusWrap); + else + return trans_create_pre_crement(c, result_used, block, stmt, BinOpTypeAssignMinus); case UO_AddrOf: { AstNode *value_node = trans_expr(c, result_used, block, stmt->getSubExpr(), TransLValue); diff --git a/test/translate_c.zig b/test/translate_c.zig index 6394950fd0..4e84576c64 100644 --- a/test/translate_c.zig +++ b/test/translate_c.zig @@ -805,6 +805,50 @@ pub fn addCases(cases: &tests.TranslateCContext) { \\} ); + cases.addC("pre increment/decrement", + \\void foo(void) { + \\ int i = 0; + \\ unsigned u = 0; + \\ ++i; + \\ --i; + \\ ++u; + \\ --u; + \\ i = ++i; + \\ i = --i; + \\ u = ++u; + \\ u = --u; + \\} + , + \\export fn foo() { + \\ var i: c_int = 0; + \\ var u: c_uint = c_uint(0); + \\ i += 1; + \\ i -= 1; + \\ u +%= 1; + \\ u -%= 1; + \\ i = { + \\ const _ref = &i; + \\ (*_ref) += 1; + \\ *_ref + \\ }; + \\ i = { + \\ const _ref = &i; + \\ (*_ref) -= 1; + \\ *_ref + \\ }; + \\ u = { + \\ const _ref = &u; + \\ (*_ref) +%= 1; + \\ *_ref + \\ }; + \\ u = { + \\ const _ref = &u; + \\ (*_ref) -%= 1; + \\ *_ref + \\ }; + \\} + ); + cases.addC("do loop", \\void foo(void) { \\ int a = 2; From 40480c7cdc84697219e8b2434772709ad948ed4d Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 24 Nov 2017 19:26:05 -0500 Subject: [PATCH 15/34] translate-c supports string literals --- src/translate_c.cpp | 75 ++++++++++++++++++++++++++++++++++++++++---- test/translate_c.zig | 10 ++++++ 2 files changed, 79 insertions(+), 6 deletions(-) diff --git a/src/translate_c.cpp b/src/translate_c.cpp index 9233b7d1a0..bb1c81f429 100644 --- a/src/translate_c.cpp +++ b/src/translate_c.cpp @@ -324,6 +324,10 @@ static AstNode *add_global_var(Context *c, Buf *var_name, AstNode *value_node) { 
return node; } +static Buf *string_ref_to_buf(StringRef string_ref) { + return buf_create_from_mem((const char *)string_ref.bytes_begin(), string_ref.size()); +} + static const char *decl_name(const Decl *decl) { const NamedDecl *named_decl = static_cast(decl); return (const char *)named_decl->getName().bytes_begin(); @@ -339,6 +343,44 @@ static AstNode *trans_create_node_apint(Context *c, const llvm::APSInt &aps_int) static AstNode *trans_qual_type(Context *c, QualType qt, const SourceLocation &source_loc); +static QualType get_expr_qual_type(Context *c, const Expr *expr) { + // String literals in C are `char *` but they should really be `const char *`. + if (expr->getStmtClass() == Stmt::ImplicitCastExprClass) { + const ImplicitCastExpr *cast_expr = static_cast(expr); + if (cast_expr->getCastKind() == CK_ArrayToPointerDecay) { + const Expr *sub_expr = cast_expr->getSubExpr(); + if (sub_expr->getStmtClass() == Stmt::StringLiteralClass) { + QualType array_qt = sub_expr->getType(); + const ArrayType *array_type = static_cast(array_qt.getTypePtr()); + QualType pointee_qt = array_type->getElementType(); + pointee_qt.addConst(); + return c->ctx->getPointerType(pointee_qt); + } + } + } + return expr->getType(); +} + +static AstNode *get_expr_type(Context *c, const Expr *expr) { + return trans_qual_type(c, get_expr_qual_type(c, expr), expr->getLocStart()); +} + +static bool expr_types_equal(Context *c, const Expr *expr1, const Expr *expr2) { + QualType t1 = get_expr_qual_type(c, expr1); + QualType t2 = get_expr_qual_type(c, expr2); + + if (t1.isConstQualified() != t2.isConstQualified()) { + return false; + } + if (t1.isVolatileQualified() != t2.isVolatileQualified()) { + return false; + } + if (t1.isRestrictQualified() != t2.isRestrictQualified()) { + return false; + } + return t1.getTypePtr() == t2.getTypePtr(); +} + static bool is_c_void_type(AstNode *node) { return (node->type == NodeTypeSymbol && buf_eql_str(node->data.symbol_expr.symbol, "c_void")); } @@ -1335,7 
+1377,11 @@ static AstNode *trans_implicit_cast_expr(Context *c, AstNode *block, ImplicitCas if (target_node == nullptr) return nullptr; - AstNode *dest_type_node = trans_qual_type(c, stmt->getType(), stmt->getLocStart()); + if (expr_types_equal(c, stmt, stmt->getSubExpr())) { + return target_node; + } + + AstNode *dest_type_node = get_expr_type(c, stmt); AstNode *node = trans_create_node_builtin_fn_call_str(c, "ptrCast"); node->data.fn_call_expr.params.append(dest_type_node); @@ -2126,6 +2172,24 @@ static AstNode *trans_do_loop(Context *c, AstNode *block, DoStmt *stmt) { return while_node; } +static AstNode *trans_string_literal(Context *c, AstNode *block, StringLiteral *stmt) { + switch (stmt->getKind()) { + case StringLiteral::Ascii: + case StringLiteral::UTF8: + return trans_create_node_str_lit_c(c, string_ref_to_buf(stmt->getString())); + case StringLiteral::UTF16: + emit_warning(c, stmt->getLocStart(), "TODO support UTF16 string literals"); + return nullptr; + case StringLiteral::UTF32: + emit_warning(c, stmt->getLocStart(), "TODO support UTF32 string literals"); + return nullptr; + case StringLiteral::Wide: + emit_warning(c, stmt->getLocStart(), "TODO support wide string literals"); + return nullptr; + } + zig_unreachable(); +} + static AstNode *trans_stmt(Context *c, bool result_used, AstNode *block, Stmt *stmt, TransLRValue lrvalue) { Stmt::StmtClass sc = stmt->getStmtClass(); switch (sc) { @@ -2167,6 +2231,8 @@ static AstNode *trans_stmt(Context *c, bool result_used, AstNode *block, Stmt *s return trans_unary_expr_or_type_trait_expr(c, block, (UnaryExprOrTypeTraitExpr *)stmt); case Stmt::DoStmtClass: return trans_do_loop(c, block, (DoStmt *)stmt); + case Stmt::StringLiteralClass: + return trans_string_literal(c, block, (StringLiteral *)stmt); case Stmt::CaseStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CaseStmtClass"); return nullptr; @@ -2488,9 +2554,6 @@ static AstNode *trans_stmt(Context *c, bool result_used, AstNode *block, Stmt *s 
case Stmt::StmtExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C StmtExprClass"); return nullptr; - case Stmt::StringLiteralClass: - emit_warning(c, stmt->getLocStart(), "TODO handle C StringLiteralClass"); - return nullptr; case Stmt::SubstNonTypeTemplateParmExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C SubstNonTypeTemplateParmExprClass"); return nullptr; @@ -3530,7 +3593,7 @@ int parse_h_file(ImportTableEntry *import, ZigList *errors, const ch break; } StringRef msg_str_ref = it->getMessage(); - Buf *msg = buf_create_from_str((const char *)msg_str_ref.bytes_begin()); + Buf *msg = string_ref_to_buf(msg_str_ref); FullSourceLoc fsl = it->getLocation(); if (fsl.hasManager()) { FileID file_id = fsl.getFileID(); @@ -3543,7 +3606,7 @@ int parse_h_file(ImportTableEntry *import, ZigList *errors, const ch if (filename.empty()) { path = buf_alloc(); } else { - path = buf_create_from_mem((const char *)filename.bytes_begin(), filename.size()); + path = string_ref_to_buf(filename); } ErrorMsg *err_msg = err_msg_create_with_offset(path, line, column, offset, source, msg); diff --git a/test/translate_c.zig b/test/translate_c.zig index 4e84576c64..08626f6e90 100644 --- a/test/translate_c.zig +++ b/test/translate_c.zig @@ -929,6 +929,16 @@ pub fn addCases(cases: &tests.TranslateCContext) { \\ return *(??ptr); \\} ); + + cases.add("string literal", + \\const char *foo(void) { + \\ return "bar"; + \\} + , + \\pub fn foo() -> ?&const u8 { + \\ return c"bar"; + \\} + ); } From cd36baf530d0c312884f485ed619438c1de91ef1 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 24 Nov 2017 22:04:24 -0500 Subject: [PATCH 16/34] fix assertion failed when invalid type encountered --- src/analyze.cpp | 2 +- src/ir.cpp | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/analyze.cpp b/src/analyze.cpp index 7a9df874d4..1c223c63f7 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -338,7 +338,7 @@ TypeTableEntry 
*get_smallest_unsigned_int_type(CodeGen *g, uint64_t x) { TypeTableEntry *get_pointer_to_type_extra(CodeGen *g, TypeTableEntry *child_type, bool is_const, bool is_volatile, uint32_t byte_alignment, uint32_t bit_offset, uint32_t unaligned_bit_count) { - assert(child_type->id != TypeTableEntryIdInvalid); + assert(!type_is_invalid(child_type)); TypeId type_id = {}; TypeTableEntry **parent_pointer = nullptr; diff --git a/src/ir.cpp b/src/ir.cpp index c81de7fa7a..35e6b3f8c6 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -7966,6 +7966,9 @@ static IrInstruction *ir_get_const_ptr(IrAnalyze *ira, IrInstruction *instructio ConstExprValue *const_val = &const_instr->value; const_val->type = pointee_type; type_ensure_zero_bits_known(ira->codegen, type_entry); + if (type_is_invalid(type_entry)) { + return ira->codegen->invalid_instruction; + } const_val->data.x_type = get_pointer_to_type_extra(ira->codegen, type_entry, ptr_is_const, ptr_is_volatile, get_abi_alignment(ira->codegen, type_entry), 0, 0); return const_instr; From 18eb3c5f90900627b0582b7e341f9db9177fcc09 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sat, 25 Nov 2017 00:25:47 -0500 Subject: [PATCH 17/34] translate-c supports returning void --- src/translate_c.cpp | 3 +-- test/translate_c.zig | 10 ++++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/translate_c.cpp b/src/translate_c.cpp index bb1c81f429..9f80d4e898 100644 --- a/src/translate_c.cpp +++ b/src/translate_c.cpp @@ -937,8 +937,7 @@ static AstNode *trans_compound_stmt(Context *c, AstNode *parent, CompoundStmt *s static AstNode *trans_return_stmt(Context *c, AstNode *block, ReturnStmt *stmt) { Expr *value_expr = stmt->getRetValue(); if (value_expr == nullptr) { - emit_warning(c, stmt->getLocStart(), "TODO handle C return void"); - return nullptr; + return trans_create_node(c, NodeTypeReturnExpr); } else { AstNode *return_node = trans_create_node(c, NodeTypeReturnExpr); return_node->data.return_expr.expr = trans_expr(c, true, block, 
value_expr, TransRValue); diff --git a/test/translate_c.zig b/test/translate_c.zig index 08626f6e90..feb44f2cf0 100644 --- a/test/translate_c.zig +++ b/test/translate_c.zig @@ -939,6 +939,16 @@ pub fn addCases(cases: &tests.TranslateCContext) { \\ return c"bar"; \\} ); + + cases.add("return void", + \\void foo(void) { + \\ return; + \\} + , + \\pub fn foo() { + \\ return; + \\} + ); } From bf20b260ce1b76b3b3d2c8a1bc1819eb3871bb00 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sat, 25 Nov 2017 00:57:48 -0500 Subject: [PATCH 18/34] translate-c supports for loops --- src/translate_c.cpp | 53 +++++++++++++++++++++++++++++++++++++++----- test/translate_c.zig | 13 +++++++++++ 2 files changed, 60 insertions(+), 6 deletions(-) diff --git a/src/translate_c.cpp b/src/translate_c.cpp index 9f80d4e898..6c55aa6afc 100644 --- a/src/translate_c.cpp +++ b/src/translate_c.cpp @@ -2119,9 +2119,6 @@ static AstNode *trans_unary_expr_or_type_trait_expr(Context *c, AstNode *block, } static AstNode *trans_do_loop(Context *c, AstNode *block, DoStmt *stmt) { - stmt->getBody(); - stmt->getCond(); - AstNode *while_node = trans_create_node(c, NodeTypeWhileExpr); AstNode *true_node = trans_create_node(c, NodeTypeBoolLiteral); @@ -2171,6 +2168,51 @@ static AstNode *trans_do_loop(Context *c, AstNode *block, DoStmt *stmt) { return while_node; } +static AstNode *trans_for_loop(Context *c, AstNode *block, ForStmt *stmt) { + AstNode *loop_block_node; + AstNode *while_node = trans_create_node(c, NodeTypeWhileExpr); + Stmt *init_stmt = stmt->getInit(); + if (init_stmt == nullptr) { + loop_block_node = while_node; + } else { + loop_block_node = trans_create_node(c, NodeTypeBlock); + + AstNode *vars_node = trans_stmt(c, false, loop_block_node, init_stmt, TransRValue); + if (vars_node == nullptr) + return nullptr; + if (vars_node != skip_add_to_block_node) + loop_block_node->data.block.statements.append(vars_node); + + loop_block_node->data.block.statements.append(while_node); + } + + Stmt *cond_stmt 
= stmt->getCond(); + if (cond_stmt == nullptr) { + AstNode *true_node = trans_create_node(c, NodeTypeBoolLiteral); + true_node->data.bool_literal.value = true; + while_node->data.while_expr.condition = true_node; + } else { + while_node->data.while_expr.condition = trans_stmt(c, false, loop_block_node, cond_stmt, TransRValue); + if (while_node->data.while_expr.condition == nullptr) + return nullptr; + } + + Stmt *inc_stmt = stmt->getInc(); + if (inc_stmt != nullptr) { + AstNode *inc_node = trans_stmt(c, false, loop_block_node, inc_stmt, TransRValue); + if (inc_node == nullptr) + return nullptr; + while_node->data.while_expr.continue_expr = inc_node; + } + + AstNode *child_statement = trans_stmt(c, false, loop_block_node, stmt->getBody(), TransRValue); + if (child_statement == nullptr) + return nullptr; + while_node->data.while_expr.body = child_statement; + + return loop_block_node; +} + static AstNode *trans_string_literal(Context *c, AstNode *block, StringLiteral *stmt) { switch (stmt->getKind()) { case StringLiteral::Ascii: @@ -2230,6 +2272,8 @@ static AstNode *trans_stmt(Context *c, bool result_used, AstNode *block, Stmt *s return trans_unary_expr_or_type_trait_expr(c, block, (UnaryExprOrTypeTraitExpr *)stmt); case Stmt::DoStmtClass: return trans_do_loop(c, block, (DoStmt *)stmt); + case Stmt::ForStmtClass: + return trans_for_loop(c, block, (ForStmt *)stmt); case Stmt::StringLiteralClass: return trans_string_literal(c, block, (StringLiteral *)stmt); case Stmt::CaseStmtClass: @@ -2568,9 +2612,6 @@ static AstNode *trans_stmt(Context *c, bool result_used, AstNode *block, Stmt *s case Stmt::VAArgExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C VAArgExprClass"); return nullptr; - case Stmt::ForStmtClass: - emit_warning(c, stmt->getLocStart(), "TODO handle C ForStmtClass"); - return nullptr; case Stmt::GotoStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C GotoStmtClass"); return nullptr; diff --git a/test/translate_c.zig 
b/test/translate_c.zig index feb44f2cf0..80ccdabc74 100644 --- a/test/translate_c.zig +++ b/test/translate_c.zig @@ -949,6 +949,19 @@ pub fn addCases(cases: &tests.TranslateCContext) { \\ return; \\} ); + + cases.add("for loop", + \\void foo(void) { + \\ for (int i = 0; i < 10; i += 1) { } + \\} + , + \\pub fn foo() { + \\ { + \\ var i: c_int = 0; + \\ while (i < 10) : (i += 1) {}; + \\ }; + \\} + ); } From b390929826062f81b9f796d4cf46a72aa23d291a Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sat, 25 Nov 2017 11:56:17 -0500 Subject: [PATCH 19/34] translate-c supports break and continue --- src/translate_c.cpp | 18 ++++++++++++------ test/translate_c.zig | 28 ++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/src/translate_c.cpp b/src/translate_c.cpp index 6c55aa6afc..00309e8bc1 100644 --- a/src/translate_c.cpp +++ b/src/translate_c.cpp @@ -2231,6 +2231,14 @@ static AstNode *trans_string_literal(Context *c, AstNode *block, StringLiteral * zig_unreachable(); } +static AstNode *trans_break_stmt(Context *c, AstNode *block, BreakStmt *stmt) { + return trans_create_node(c, NodeTypeBreak); +} + +static AstNode *trans_continue_stmt(Context *c, AstNode *block, ContinueStmt *stmt) { + return trans_create_node(c, NodeTypeContinue); +} + static AstNode *trans_stmt(Context *c, bool result_used, AstNode *block, Stmt *stmt, TransLRValue lrvalue) { Stmt::StmtClass sc = stmt->getStmtClass(); switch (sc) { @@ -2276,6 +2284,10 @@ static AstNode *trans_stmt(Context *c, bool result_used, AstNode *block, Stmt *s return trans_for_loop(c, block, (ForStmt *)stmt); case Stmt::StringLiteralClass: return trans_string_literal(c, block, (StringLiteral *)stmt); + case Stmt::BreakStmtClass: + return trans_break_stmt(c, block, (BreakStmt *)stmt); + case Stmt::ContinueStmtClass: + return trans_continue_stmt(c, block, (ContinueStmt *)stmt); case Stmt::CaseStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CaseStmtClass"); return nullptr; @@ 
-2297,9 +2309,6 @@ static AstNode *trans_stmt(Context *c, bool result_used, AstNode *block, Stmt *s case Stmt::AttributedStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C AttributedStmtClass"); return nullptr; - case Stmt::BreakStmtClass: - emit_warning(c, stmt->getLocStart(), "TODO handle C BreakStmtClass"); - return nullptr; case Stmt::CXXCatchStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXCatchStmtClass"); return nullptr; @@ -2312,9 +2321,6 @@ static AstNode *trans_stmt(Context *c, bool result_used, AstNode *block, Stmt *s case Stmt::CapturedStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CapturedStmtClass"); return nullptr; - case Stmt::ContinueStmtClass: - emit_warning(c, stmt->getLocStart(), "TODO handle C ContinueStmtClass"); - return nullptr; case Stmt::CoreturnStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CoreturnStmtClass"); return nullptr; diff --git a/test/translate_c.zig b/test/translate_c.zig index 80ccdabc74..152928cbe9 100644 --- a/test/translate_c.zig +++ b/test/translate_c.zig @@ -962,6 +962,34 @@ pub fn addCases(cases: &tests.TranslateCContext) { \\ }; \\} ); + + cases.add("break statement", + \\void foo(void) { + \\ for (;;) { + \\ break; + \\ } + \\} + , + \\pub fn foo() { + \\ while (true) { + \\ break; + \\ }; + \\} + ); + + cases.add("continue statement", + \\void foo(void) { + \\ for (;;) { + \\ continue; + \\ } + \\} + , + \\pub fn foo() { + \\ while (true) { + \\ continue; + \\ }; + \\} + ); } From 48ebb65cc7a976599c0a2b9e6647ea057727cf21 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sat, 25 Nov 2017 16:34:08 -0500 Subject: [PATCH 20/34] add an assert to catch corrupted memory --- src/ir.cpp | 54 ++++++++++++++++++++++++++++++++---------------------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/src/ir.cpp b/src/ir.cpp index 35e6b3f8c6..a58374dc8c 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -10351,30 +10351,40 @@ static IrInstruction 
*ir_get_var_ptr(IrAnalyze *ira, IrInstruction *instruction, bool is_const = (var->value->type->id == TypeTableEntryIdMetaType) ? is_const_ptr : var->src_is_const; bool is_volatile = (var->value->type->id == TypeTableEntryIdMetaType) ? is_volatile_ptr : false; - if (mem_slot && mem_slot->special != ConstValSpecialRuntime) { - ConstPtrMut ptr_mut; - if (comptime_var_mem) { - ptr_mut = ConstPtrMutComptimeVar; - } else if (var->gen_is_const) { - ptr_mut = ConstPtrMutComptimeConst; - } else { - assert(!comptime_var_mem); - ptr_mut = ConstPtrMutRuntimeVar; + if (mem_slot != nullptr) { + switch (mem_slot->special) { + case ConstValSpecialRuntime: + goto no_mem_slot; + case ConstValSpecialStatic: // fallthrough + case ConstValSpecialUndef: { + ConstPtrMut ptr_mut; + if (comptime_var_mem) { + ptr_mut = ConstPtrMutComptimeVar; + } else if (var->gen_is_const) { + ptr_mut = ConstPtrMutComptimeConst; + } else { + assert(!comptime_var_mem); + ptr_mut = ConstPtrMutRuntimeVar; + } + return ir_get_const_ptr(ira, instruction, mem_slot, var->value->type, + ptr_mut, is_const, is_volatile, var->align_bytes); + } } - return ir_get_const_ptr(ira, instruction, mem_slot, var->value->type, - ptr_mut, is_const, is_volatile, var->align_bytes); - } else { - IrInstruction *var_ptr_instruction = ir_build_var_ptr(&ira->new_irb, - instruction->scope, instruction->source_node, var, is_const, is_volatile); - var_ptr_instruction->value.type = get_pointer_to_type_extra(ira->codegen, var->value->type, - var->src_is_const, is_volatile, var->align_bytes, 0, 0); - type_ensure_zero_bits_known(ira->codegen, var->value->type); - - bool in_fn_scope = (scope_fn_entry(var->parent_scope) != nullptr); - var_ptr_instruction->value.data.rh_ptr = in_fn_scope ? 
RuntimeHintPtrStack : RuntimeHintPtrNonStack; - - return var_ptr_instruction; + zig_unreachable(); } + +no_mem_slot: + + IrInstruction *var_ptr_instruction = ir_build_var_ptr(&ira->new_irb, + instruction->scope, instruction->source_node, var, is_const, is_volatile); + var_ptr_instruction->value.type = get_pointer_to_type_extra(ira->codegen, var->value->type, + var->src_is_const, is_volatile, var->align_bytes, 0, 0); + type_ensure_zero_bits_known(ira->codegen, var->value->type); + + bool in_fn_scope = (scope_fn_entry(var->parent_scope) != nullptr); + var_ptr_instruction->value.data.rh_ptr = in_fn_scope ? RuntimeHintPtrStack : RuntimeHintPtrNonStack; + + return var_ptr_instruction; } static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *call_instruction, From a2afcae9ff1f2eed5c6a19a2bd63af288b9171ad Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sat, 25 Nov 2017 18:16:33 -0500 Subject: [PATCH 21/34] fix crash when constant inside comptime function has compile error closes #625 --- src/all_types.hpp | 3 +++ src/ir.cpp | 27 +++++++++++++++++++++++---- test/compile_errors.zig | 19 +++++++++++++++++++ 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/all_types.hpp b/src/all_types.hpp index 2b09131bef..2fccf08e88 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -36,6 +36,7 @@ struct IrInstructionCast; struct IrBasicBlock; struct ScopeDecls; struct ZigWindowsSDK; +struct Tld; struct IrGotoItem { AstNode *source_node; @@ -59,7 +60,9 @@ struct IrExecutable { Buf *c_import_buf; AstNode *source_node; IrExecutable *parent_exec; + IrExecutable *source_exec; Scope *begin_scope; + ZigList tld_list; }; enum OutType { diff --git a/src/ir.cpp b/src/ir.cpp index a58374dc8c..f632a261f6 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -6332,6 +6332,9 @@ static IrInstruction *ir_gen_container_decl(IrBuilder *irb, Scope *parent_scope, } irb->codegen->resolve_queue.append(&tld_container->base); + // Add this to the list to mark as invalid if 
analyzing this exec fails. + irb->exec->tld_list.append(&tld_container->base); + return ir_build_const_type(irb, parent_scope, node, container_type); } @@ -6554,6 +6557,20 @@ static bool ir_goto_pass2(IrBuilder *irb) { return true; } +static void invalidate_exec(IrExecutable *exec) { + if (exec->invalid) + return; + + exec->invalid = true; + + for (size_t i = 0; i < exec->tld_list.length; i += 1) { + exec->tld_list.items[i]->resolution = TldResolutionInvalid; + } + + if (exec->source_exec != nullptr) + invalidate_exec(exec->source_exec); +} + bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_executable) { assert(node->owner); @@ -6577,7 +6594,7 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec } if (!ir_goto_pass2(irb)) { - irb->exec->invalid = true; + invalidate_exec(ir_executable); return false; } @@ -6603,7 +6620,7 @@ static void add_call_stack_errors(CodeGen *codegen, IrExecutable *exec, ErrorMsg } static ErrorMsg *exec_add_error_node(CodeGen *codegen, IrExecutable *exec, AstNode *source_node, Buf *msg) { - exec->invalid = true; + invalidate_exec(exec); ErrorMsg *err_msg = add_node_error(codegen, source_node, msg); if (exec->parent_exec) { add_call_stack_errors(codegen, exec, err_msg, 10); @@ -8056,6 +8073,7 @@ IrInstruction *ir_eval_const_value(CodeGen *codegen, Scope *scope, AstNode *node IrExecutable analyzed_executable = {0}; analyzed_executable.source_node = source_node; analyzed_executable.parent_exec = parent_exec; + analyzed_executable.source_exec = &ir_executable; analyzed_executable.name = exec_name; analyzed_executable.is_inline = true; analyzed_executable.fn_entry = fn_entry; @@ -10490,10 +10508,11 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal result = ir_eval_const_value(ira->codegen, exec_scope, body_node, return_type, ira->new_irb.exec->backward_branch_count, ira->new_irb.exec->backward_branch_quota, fn_entry, nullptr, 
call_instruction->base.source_node, nullptr, ira->new_irb.exec); - if (type_is_invalid(result->value.type)) - return ira->codegen->builtin_types.entry_invalid; ira->codegen->memoized_fn_eval_table.put(exec_scope, result); + + if (type_is_invalid(result->value.type)) + return ira->codegen->builtin_types.entry_invalid; } ConstExprValue *out_val = ir_build_const_from(ira, &call_instruction->base); diff --git a/test/compile_errors.zig b/test/compile_errors.zig index 9e15333750..3ef4a63e5f 100644 --- a/test/compile_errors.zig +++ b/test/compile_errors.zig @@ -2343,4 +2343,23 @@ pub fn addCases(cases: &tests.CompileErrorContext) { \\pub extern fn foo(format: &const u8, ...); , ".tmp_source.zig:2:9: error: expected type '&const u8', found '[5]u8'"); + + cases.add("constant inside comptime function has compile error", + \\const ContextAllocator = MemoryPool(usize); + \\ + \\pub fn MemoryPool(comptime T: type) -> type { + \\ const free_list_t = @compileError("aoeu"); + \\ + \\ struct { + \\ free_list: free_list_t, + \\ } + \\} + \\ + \\export fn entry() { + \\ var allocator: ContextAllocator = undefined; + \\} + , + ".tmp_source.zig:4:25: error: aoeu", + ".tmp_source.zig:1:36: note: called from here", + ".tmp_source.zig:12:20: note: referenced here"); } From df0e875856023e38af8b88dd94ef97f347d0a5e3 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sat, 25 Nov 2017 20:34:05 -0500 Subject: [PATCH 22/34] translate-c: introduce the concept of scopes in preparation to implement switch and solve variable name collisions --- src/translate_c.cpp | 635 +++++++++++++++++++++++++++---------------- test/translate_c.zig | 51 ++++ 2 files changed, 452 insertions(+), 234 deletions(-) diff --git a/src/translate_c.cpp b/src/translate_c.cpp index 00309e8bc1..4b86bc2dea 100644 --- a/src/translate_c.cpp +++ b/src/translate_c.cpp @@ -23,6 +23,8 @@ using namespace clang; +struct TransScope; + struct MacroSymbol { Buf *name; Buf *value; @@ -52,12 +54,59 @@ struct Context { ASTContext *ctx; 
HashMap ptr_params; + TransScope *child_scope; // TODO refactor out }; +enum ResultUsed { + ResultUsedNo, + ResultUsedYes, +}; + +enum TransLRValue { + TransLValue, + TransRValue, +}; + +enum TransScopeId { + TransScopeIdSwitch, + TransScopeIdVar, + TransScopeIdBlock, +}; + +struct TransScope { + TransScopeId id; + TransScope *parent; +}; + +struct TransScopeSwitch { + TransScope base; + AstNode *switch_node; +}; + +struct TransScopeVar { + TransScope base; + Buf *c_name; + Buf *zig_name; +}; + +struct TransScopeBlock { + TransScope base; + AstNode *node; +}; + +static AstNode *const skip_add_to_block_node = (AstNode *) 0x2; + +static TransScopeBlock *trans_scope_block_create(Context *c, TransScope *parent_scope); +static TransScopeVar *trans_scope_var_create(Context *c, TransScope *parent_scope, Buf *wanted_name); +//static TransScopeSwitch *trans_scope_switch_create(Context *c, TransScope *parent_scope); + static AstNode *resolve_record_decl(Context *c, const RecordDecl *record_decl); static AstNode *resolve_enum_decl(Context *c, const EnumDecl *enum_decl); static AstNode *resolve_typedef_decl(Context *c, const TypedefNameDecl *typedef_decl); +static AstNode *trans_stmt(Context *c, ResultUsed result_used, TransScope *scope, const Stmt *stmt, TransLRValue lrval); +static AstNode *trans_qual_type(Context *c, QualType qt, const SourceLocation &source_loc); + ATTRIBUTE_PRINTF(3, 4) static void emit_warning(Context *c, const SourceLocation &sl, const char *format, ...) 
{ @@ -165,8 +214,8 @@ static AstNode *trans_create_node_bin_op(Context *c, AstNode *lhs_node, BinOpTyp return node; } -static AstNode *maybe_suppress_result(Context *c, bool result_used, AstNode *node) { - if (result_used) return node; +static AstNode *maybe_suppress_result(Context *c, ResultUsed result_used, AstNode *node) { + if (result_used == ResultUsedYes) return node; return trans_create_node_bin_op(c, trans_create_node_symbol_str(c, "_"), BinOpTypeAssign, @@ -341,8 +390,6 @@ static AstNode *trans_create_node_apint(Context *c, const llvm::APSInt &aps_int) } -static AstNode *trans_qual_type(Context *c, QualType qt, const SourceLocation &source_loc); - static QualType get_expr_qual_type(Context *c, const Expr *expr) { // String literals in C are `char *` but they should really be `const char *`. if (expr->getStmtClass() == Stmt::ImplicitCastExprClass) { @@ -573,16 +620,8 @@ static bool qual_type_has_wrapping_overflow(Context *c, QualType qt) { } } -enum TransLRValue { - TransLValue, - TransRValue, -}; - -static AstNode *trans_stmt(Context *c, bool result_used, AstNode *block, Stmt *stmt, TransLRValue lrval); -static AstNode *const skip_add_to_block_node = (AstNode *) 0x2; - -static AstNode *trans_expr(Context *c, bool result_used, AstNode *block, Expr *expr, TransLRValue lrval) { - return trans_stmt(c, result_used, block, expr, lrval); +static AstNode *trans_expr(Context *c, ResultUsed result_used, TransScope *scope, const Expr *expr, TransLRValue lrval) { + return trans_stmt(c, result_used, scope, expr, lrval); } static AstNode *trans_type(Context *c, const Type *ty, const SourceLocation &source_loc) { @@ -922,32 +961,33 @@ static AstNode *trans_qual_type(Context *c, QualType qt, const SourceLocation &s return trans_type(c, qt.getTypePtr(), source_loc); } -static AstNode *trans_compound_stmt(Context *c, AstNode *parent, CompoundStmt *stmt) { - AstNode *child_block = trans_create_node(c, NodeTypeBlock); - for (CompoundStmt::body_iterator it = 
stmt->body_begin(), end_it = stmt->body_end(); it != end_it; ++it) { - AstNode *child_node = trans_stmt(c, false, child_block, *it, TransRValue); +static AstNode *trans_compound_stmt(Context *c, TransScope *parent_scope, const CompoundStmt *stmt) { + TransScopeBlock *child_scope_block = trans_scope_block_create(c, parent_scope); + for (CompoundStmt::const_body_iterator it = stmt->body_begin(), end_it = stmt->body_end(); it != end_it; ++it) { + AstNode *child_node = trans_stmt(c, ResultUsedNo, &child_scope_block->base, *it, TransRValue); if (child_node == nullptr) return nullptr; if (child_node != skip_add_to_block_node) - child_block->data.block.statements.append(child_node); + child_scope_block->node->data.block.statements.append(child_node); } - return child_block; + c->child_scope = &child_scope_block->base; + return child_scope_block->node; } -static AstNode *trans_return_stmt(Context *c, AstNode *block, ReturnStmt *stmt) { - Expr *value_expr = stmt->getRetValue(); +static AstNode *trans_return_stmt(Context *c, TransScope *scope, const ReturnStmt *stmt) { + const Expr *value_expr = stmt->getRetValue(); if (value_expr == nullptr) { return trans_create_node(c, NodeTypeReturnExpr); } else { AstNode *return_node = trans_create_node(c, NodeTypeReturnExpr); - return_node->data.return_expr.expr = trans_expr(c, true, block, value_expr, TransRValue); + return_node->data.return_expr.expr = trans_expr(c, ResultUsedYes, scope, value_expr, TransRValue); if (return_node->data.return_expr.expr == nullptr) return nullptr; return return_node; } } -static AstNode *trans_integer_literal(Context *c, IntegerLiteral *stmt) { +static AstNode *trans_integer_literal(Context *c, const IntegerLiteral *stmt) { llvm::APSInt result; if (!stmt->EvaluateAsInt(result, *c->ctx)) { emit_warning(c, stmt->getLocStart(), "invalid integer literal"); @@ -956,54 +996,56 @@ static AstNode *trans_integer_literal(Context *c, IntegerLiteral *stmt) { return trans_create_node_apint(c, result); } -static 
AstNode *trans_conditional_operator(Context *c, bool result_used, AstNode *block, ConditionalOperator *stmt) { +static AstNode *trans_conditional_operator(Context *c, ResultUsed result_used, TransScope *scope, + const ConditionalOperator *stmt) +{ AstNode *node = trans_create_node(c, NodeTypeIfBoolExpr); Expr *cond_expr = stmt->getCond(); Expr *true_expr = stmt->getTrueExpr(); Expr *false_expr = stmt->getFalseExpr(); - node->data.if_bool_expr.condition = trans_expr(c, true, block, cond_expr, TransRValue); + node->data.if_bool_expr.condition = trans_expr(c, ResultUsedYes, scope, cond_expr, TransRValue); if (node->data.if_bool_expr.condition == nullptr) return nullptr; - node->data.if_bool_expr.then_block = trans_expr(c, result_used, block, true_expr, TransRValue); + node->data.if_bool_expr.then_block = trans_expr(c, result_used, scope, true_expr, TransRValue); if (node->data.if_bool_expr.then_block == nullptr) return nullptr; - node->data.if_bool_expr.else_node = trans_expr(c, result_used, block, false_expr, TransRValue); + node->data.if_bool_expr.else_node = trans_expr(c, result_used, scope, false_expr, TransRValue); if (node->data.if_bool_expr.else_node == nullptr) return nullptr; return maybe_suppress_result(c, result_used, node); } -static AstNode *trans_create_bin_op(Context *c, AstNode *block, Expr *lhs, BinOpType bin_op, Expr *rhs) { +static AstNode *trans_create_bin_op(Context *c, TransScope *scope, Expr *lhs, BinOpType bin_op, Expr *rhs) { AstNode *node = trans_create_node(c, NodeTypeBinOpExpr); node->data.bin_op_expr.bin_op = bin_op; - node->data.bin_op_expr.op1 = trans_expr(c, true, block, lhs, TransRValue); + node->data.bin_op_expr.op1 = trans_expr(c, ResultUsedYes, scope, lhs, TransRValue); if (node->data.bin_op_expr.op1 == nullptr) return nullptr; - node->data.bin_op_expr.op2 = trans_expr(c, true, block, rhs, TransRValue); + node->data.bin_op_expr.op2 = trans_expr(c, ResultUsedYes, scope, rhs, TransRValue); if (node->data.bin_op_expr.op2 == nullptr) 
return nullptr; return node; } -static AstNode *trans_create_assign(Context *c, bool result_used, AstNode *block, Expr *lhs, Expr *rhs) { - if (!result_used) { +static AstNode *trans_create_assign(Context *c, ResultUsed result_used, TransScope *scope, Expr *lhs, Expr *rhs) { + if (result_used == ResultUsedNo) { // common case AstNode *node = trans_create_node(c, NodeTypeBinOpExpr); node->data.bin_op_expr.bin_op = BinOpTypeAssign; - node->data.bin_op_expr.op1 = trans_expr(c, true, block, lhs, TransLValue); + node->data.bin_op_expr.op1 = trans_expr(c, ResultUsedYes, scope, lhs, TransLValue); if (node->data.bin_op_expr.op1 == nullptr) return nullptr; - node->data.bin_op_expr.op2 = trans_expr(c, true, block, rhs, TransRValue); + node->data.bin_op_expr.op2 = trans_expr(c, ResultUsedYes, scope, rhs, TransRValue); if (node->data.bin_op_expr.op2 == nullptr) return nullptr; @@ -1017,47 +1059,49 @@ static AstNode *trans_create_assign(Context *c, bool result_used, AstNode *block // zig: _tmp // zig: } - AstNode *child_block = trans_create_node(c, NodeTypeBlock); + TransScopeBlock *child_scope = trans_scope_block_create(c, scope); // const _tmp = rhs; - AstNode *rhs_node = trans_expr(c, true, child_block, rhs, TransRValue); + AstNode *rhs_node = trans_expr(c, ResultUsedYes, &child_scope->base, rhs, TransRValue); if (rhs_node == nullptr) return nullptr; // TODO: avoid name collisions with generated variable names Buf* tmp_var_name = buf_create_from_str("_tmp"); AstNode *tmp_var_decl = trans_create_node_var_decl_local(c, true, tmp_var_name, nullptr, rhs_node); - child_block->data.block.statements.append(tmp_var_decl); + child_scope->node->data.block.statements.append(tmp_var_decl); // lhs = _tmp; - AstNode *lhs_node = trans_expr(c, true, child_block, lhs, TransLValue); + AstNode *lhs_node = trans_expr(c, ResultUsedYes, &child_scope->base, lhs, TransLValue); if (lhs_node == nullptr) return nullptr; - child_block->data.block.statements.append( + 
child_scope->node->data.block.statements.append( trans_create_node_bin_op(c, lhs_node, BinOpTypeAssign, trans_create_node_symbol(c, tmp_var_name))); // _tmp - child_block->data.block.statements.append(trans_create_node_symbol(c, tmp_var_name)); - child_block->data.block.last_statement_is_result_expression = true; + child_scope->node->data.block.statements.append(trans_create_node_symbol(c, tmp_var_name)); + child_scope->node->data.block.last_statement_is_result_expression = true; - return child_block; + return child_scope->node; } } -static AstNode *trans_create_shift_op(Context *c, AstNode *block, QualType result_type, Expr *lhs_expr, BinOpType bin_op, Expr *rhs_expr) { +static AstNode *trans_create_shift_op(Context *c, TransScope *scope, QualType result_type, + Expr *lhs_expr, BinOpType bin_op, Expr *rhs_expr) +{ const SourceLocation &rhs_location = rhs_expr->getLocStart(); AstNode *rhs_type = qual_type_to_log2_int_ref(c, result_type, rhs_location); // lhs >> u5(rh) - AstNode *lhs = trans_expr(c, true, block, lhs_expr, TransLValue); + AstNode *lhs = trans_expr(c, ResultUsedYes, scope, lhs_expr, TransLValue); if (lhs == nullptr) return nullptr; - AstNode *rhs = trans_expr(c, true, block, rhs_expr, TransRValue); + AstNode *rhs = trans_expr(c, ResultUsedYes, scope, rhs_expr, TransRValue); if (rhs == nullptr) return nullptr; AstNode *coerced_rhs = trans_create_node_fn_call_1(c, rhs_type, rhs); return trans_create_node_bin_op(c, lhs, bin_op, coerced_rhs); } -static AstNode *trans_binary_operator(Context *c, bool result_used, AstNode *block, BinaryOperator *stmt) { +static AstNode *trans_binary_operator(Context *c, ResultUsed result_used, TransScope *scope, const BinaryOperator *stmt) { switch (stmt->getOpcode()) { case BO_PtrMemD: emit_warning(c, stmt->getLocStart(), "TODO handle more C binary operators: BO_PtrMemD"); @@ -1066,20 +1110,20 @@ static AstNode *trans_binary_operator(Context *c, bool result_used, AstNode *blo emit_warning(c, stmt->getLocStart(), "TODO 
handle more C binary operators: BO_PtrMemI"); return nullptr; case BO_Mul: - return trans_create_bin_op(c, block, stmt->getLHS(), + return trans_create_bin_op(c, scope, stmt->getLHS(), qual_type_has_wrapping_overflow(c, stmt->getType()) ? BinOpTypeMultWrap : BinOpTypeMult, stmt->getRHS()); case BO_Div: if (qual_type_has_wrapping_overflow(c, stmt->getType())) { // unsigned/float division uses the operator - return trans_create_bin_op(c, block, stmt->getLHS(), BinOpTypeDiv, stmt->getRHS()); + return trans_create_bin_op(c, scope, stmt->getLHS(), BinOpTypeDiv, stmt->getRHS()); } else { // signed integer division uses @divTrunc AstNode *fn_call = trans_create_node_builtin_fn_call_str(c, "divTrunc"); - AstNode *lhs = trans_expr(c, true, block, stmt->getLHS(), TransLValue); + AstNode *lhs = trans_expr(c, ResultUsedYes, scope, stmt->getLHS(), TransLValue); if (lhs == nullptr) return nullptr; fn_call->data.fn_call_expr.params.append(lhs); - AstNode *rhs = trans_expr(c, true, block, stmt->getRHS(), TransLValue); + AstNode *rhs = trans_expr(c, ResultUsedYes, scope, stmt->getRHS(), TransLValue); if (rhs == nullptr) return nullptr; fn_call->data.fn_call_expr.params.append(rhs); return fn_call; @@ -1087,66 +1131,70 @@ static AstNode *trans_binary_operator(Context *c, bool result_used, AstNode *blo case BO_Rem: if (qual_type_has_wrapping_overflow(c, stmt->getType())) { // unsigned/float division uses the operator - return trans_create_bin_op(c, block, stmt->getLHS(), BinOpTypeMod, stmt->getRHS()); + return trans_create_bin_op(c, scope, stmt->getLHS(), BinOpTypeMod, stmt->getRHS()); } else { // signed integer division uses @rem AstNode *fn_call = trans_create_node_builtin_fn_call_str(c, "rem"); - AstNode *lhs = trans_expr(c, true, block, stmt->getLHS(), TransLValue); + AstNode *lhs = trans_expr(c, ResultUsedYes, scope, stmt->getLHS(), TransLValue); if (lhs == nullptr) return nullptr; fn_call->data.fn_call_expr.params.append(lhs); - AstNode *rhs = trans_expr(c, true, block, 
stmt->getRHS(), TransLValue); + AstNode *rhs = trans_expr(c, ResultUsedYes, scope, stmt->getRHS(), TransLValue); if (rhs == nullptr) return nullptr; fn_call->data.fn_call_expr.params.append(rhs); return fn_call; } case BO_Add: - return trans_create_bin_op(c, block, stmt->getLHS(), + return trans_create_bin_op(c, scope, stmt->getLHS(), qual_type_has_wrapping_overflow(c, stmt->getType()) ? BinOpTypeAddWrap : BinOpTypeAdd, stmt->getRHS()); case BO_Sub: - return trans_create_bin_op(c, block, stmt->getLHS(), + return trans_create_bin_op(c, scope, stmt->getLHS(), qual_type_has_wrapping_overflow(c, stmt->getType()) ? BinOpTypeSubWrap : BinOpTypeSub, stmt->getRHS()); case BO_Shl: - return trans_create_shift_op(c, block, stmt->getType(), stmt->getLHS(), BinOpTypeBitShiftLeft, stmt->getRHS()); + return trans_create_shift_op(c, scope, stmt->getType(), stmt->getLHS(), BinOpTypeBitShiftLeft, stmt->getRHS()); case BO_Shr: - return trans_create_shift_op(c, block, stmt->getType(), stmt->getLHS(), BinOpTypeBitShiftRight, stmt->getRHS()); + return trans_create_shift_op(c, scope, stmt->getType(), stmt->getLHS(), BinOpTypeBitShiftRight, stmt->getRHS()); case BO_LT: - return trans_create_bin_op(c, block, stmt->getLHS(), BinOpTypeCmpLessThan, stmt->getRHS()); + return trans_create_bin_op(c, scope, stmt->getLHS(), BinOpTypeCmpLessThan, stmt->getRHS()); case BO_GT: - return trans_create_bin_op(c, block, stmt->getLHS(), BinOpTypeCmpGreaterThan, stmt->getRHS()); + return trans_create_bin_op(c, scope, stmt->getLHS(), BinOpTypeCmpGreaterThan, stmt->getRHS()); case BO_LE: - return trans_create_bin_op(c, block, stmt->getLHS(), BinOpTypeCmpLessOrEq, stmt->getRHS()); + return trans_create_bin_op(c, scope, stmt->getLHS(), BinOpTypeCmpLessOrEq, stmt->getRHS()); case BO_GE: - return trans_create_bin_op(c, block, stmt->getLHS(), BinOpTypeCmpGreaterOrEq, stmt->getRHS()); + return trans_create_bin_op(c, scope, stmt->getLHS(), BinOpTypeCmpGreaterOrEq, stmt->getRHS()); case BO_EQ: - return 
trans_create_bin_op(c, block, stmt->getLHS(), BinOpTypeCmpEq, stmt->getRHS()); + return trans_create_bin_op(c, scope, stmt->getLHS(), BinOpTypeCmpEq, stmt->getRHS()); case BO_NE: - return trans_create_bin_op(c, block, stmt->getLHS(), BinOpTypeCmpNotEq, stmt->getRHS()); + return trans_create_bin_op(c, scope, stmt->getLHS(), BinOpTypeCmpNotEq, stmt->getRHS()); case BO_And: - return trans_create_bin_op(c, block, stmt->getLHS(), BinOpTypeBinAnd, stmt->getRHS()); + return trans_create_bin_op(c, scope, stmt->getLHS(), BinOpTypeBinAnd, stmt->getRHS()); case BO_Xor: - return trans_create_bin_op(c, block, stmt->getLHS(), BinOpTypeBinXor, stmt->getRHS()); + return trans_create_bin_op(c, scope, stmt->getLHS(), BinOpTypeBinXor, stmt->getRHS()); case BO_Or: - return trans_create_bin_op(c, block, stmt->getLHS(), BinOpTypeBinOr, stmt->getRHS()); + return trans_create_bin_op(c, scope, stmt->getLHS(), BinOpTypeBinOr, stmt->getRHS()); case BO_LAnd: - return trans_create_bin_op(c, block, stmt->getLHS(), BinOpTypeBoolAnd, stmt->getRHS()); + return trans_create_bin_op(c, scope, stmt->getLHS(), BinOpTypeBoolAnd, stmt->getRHS()); case BO_LOr: // TODO: int vs bool - return trans_create_bin_op(c, block, stmt->getLHS(), BinOpTypeBoolOr, stmt->getRHS()); + return trans_create_bin_op(c, scope, stmt->getLHS(), BinOpTypeBoolOr, stmt->getRHS()); case BO_Assign: - return trans_create_assign(c, result_used, block, stmt->getLHS(), stmt->getRHS()); + return trans_create_assign(c, result_used, scope, stmt->getLHS(), stmt->getRHS()); case BO_Comma: { - block = trans_create_node(c, NodeTypeBlock); - AstNode *lhs = trans_expr(c, false, block, stmt->getLHS(), TransRValue); - if (lhs == nullptr) return nullptr; - block->data.block.statements.append(maybe_suppress_result(c, false, lhs)); - AstNode *rhs = trans_expr(c, result_used, block, stmt->getRHS(), TransRValue); - if (rhs == nullptr) return nullptr; - block->data.block.statements.append(maybe_suppress_result(c, result_used, rhs)); - 
block->data.block.last_statement_is_result_expression = true; - return block; + TransScopeBlock *scope_block = trans_scope_block_create(c, scope); + AstNode *lhs = trans_expr(c, ResultUsedNo, &scope_block->base, stmt->getLHS(), TransRValue); + if (lhs == nullptr) + return nullptr; + scope_block->node->data.block.statements.append(maybe_suppress_result(c, ResultUsedNo, lhs)); + + AstNode *rhs = trans_expr(c, result_used, &scope_block->base, stmt->getRHS(), TransRValue); + if (rhs == nullptr) + return nullptr; + scope_block->node->data.block.statements.append(maybe_suppress_result(c, result_used, rhs)); + + scope_block->node->data.block.last_statement_is_result_expression = true; + return scope_block->node; } case BO_MulAssign: case BO_DivAssign: @@ -1164,18 +1212,20 @@ static AstNode *trans_binary_operator(Context *c, bool result_used, AstNode *blo zig_unreachable(); } -static AstNode *trans_create_compound_assign_shift(Context *c, bool result_used, AstNode *block, CompoundAssignOperator *stmt, BinOpType assign_op, BinOpType bin_op) { +static AstNode *trans_create_compound_assign_shift(Context *c, ResultUsed result_used, TransScope *scope, + const CompoundAssignOperator *stmt, BinOpType assign_op, BinOpType bin_op) +{ const SourceLocation &rhs_location = stmt->getRHS()->getLocStart(); AstNode *rhs_type = qual_type_to_log2_int_ref(c, stmt->getComputationLHSType(), rhs_location); bool use_intermediate_casts = stmt->getComputationLHSType().getTypePtr() != stmt->getComputationResultType().getTypePtr(); - if (!use_intermediate_casts && !result_used) { + if (!use_intermediate_casts && result_used == ResultUsedNo) { // simple common case, where the C and Zig are identical: // lhs >>= rhs - AstNode *lhs = trans_expr(c, true, block, stmt->getLHS(), TransLValue); + AstNode *lhs = trans_expr(c, ResultUsedYes, scope, stmt->getLHS(), TransLValue); if (lhs == nullptr) return nullptr; - AstNode *rhs = trans_expr(c, true, block, stmt->getRHS(), TransRValue); + AstNode *rhs = 
trans_expr(c, ResultUsedYes, scope, stmt->getRHS(), TransRValue); if (rhs == nullptr) return nullptr; AstNode *coerced_rhs = trans_create_node_fn_call_1(c, rhs_type, rhs); @@ -1190,20 +1240,20 @@ static AstNode *trans_create_compound_assign_shift(Context *c, bool result_used, // zig: } // where u5 is the appropriate type - AstNode *child_block = trans_create_node(c, NodeTypeBlock); + TransScopeBlock *child_scope = trans_scope_block_create(c, scope); // const _ref = &lhs; - AstNode *lhs = trans_expr(c, true, child_block, stmt->getLHS(), TransLValue); + AstNode *lhs = trans_expr(c, ResultUsedYes, &child_scope->base, stmt->getLHS(), TransLValue); if (lhs == nullptr) return nullptr; AstNode *addr_of_lhs = trans_create_node_addr_of(c, false, false, lhs); // TODO: avoid name collisions with generated variable names Buf* tmp_var_name = buf_create_from_str("_ref"); AstNode *tmp_var_decl = trans_create_node_var_decl_local(c, true, tmp_var_name, nullptr, addr_of_lhs); - child_block->data.block.statements.append(tmp_var_decl); + child_scope->node->data.block.statements.append(tmp_var_decl); // *_ref = result_type(operation_type(*_ref) >> u5(rhs)); - AstNode *rhs = trans_expr(c, true, child_block, stmt->getRHS(), TransRValue); + AstNode *rhs = trans_expr(c, ResultUsedYes, &child_scope->base, stmt->getRHS(), TransRValue); if (rhs == nullptr) return nullptr; AstNode *coerced_rhs = trans_create_node_fn_call_1(c, rhs_type, rhs); @@ -1220,27 +1270,29 @@ static AstNode *trans_create_compound_assign_shift(Context *c, bool result_used, trans_create_node_symbol(c, tmp_var_name))), bin_op, coerced_rhs))); - child_block->data.block.statements.append(assign_statement); + child_scope->node->data.block.statements.append(assign_statement); - if (result_used) { + if (result_used == ResultUsedYes) { // *_ref - child_block->data.block.statements.append( + child_scope->node->data.block.statements.append( trans_create_node_prefix_op(c, PrefixOpDereference, trans_create_node_symbol(c, 
tmp_var_name))); - child_block->data.block.last_statement_is_result_expression = true; + child_scope->node->data.block.last_statement_is_result_expression = true; } - return child_block; + return child_scope->node; } } -static AstNode *trans_create_compound_assign(Context *c, bool result_used, AstNode *block, CompoundAssignOperator *stmt, BinOpType assign_op, BinOpType bin_op) { - if (!result_used) { +static AstNode *trans_create_compound_assign(Context *c, ResultUsed result_used, TransScope *scope, + const CompoundAssignOperator *stmt, BinOpType assign_op, BinOpType bin_op) +{ + if (result_used == ResultUsedNo) { // simple common case, where the C and Zig are identical: // lhs += rhs - AstNode *lhs = trans_expr(c, true, block, stmt->getLHS(), TransLValue); + AstNode *lhs = trans_expr(c, ResultUsedYes, scope, stmt->getLHS(), TransLValue); if (lhs == nullptr) return nullptr; - AstNode *rhs = trans_expr(c, true, block, stmt->getRHS(), TransRValue); + AstNode *rhs = trans_expr(c, ResultUsedYes, scope, stmt->getRHS(), TransRValue); if (rhs == nullptr) return nullptr; return trans_create_node_bin_op(c, lhs, assign_op, rhs); } else { @@ -1252,20 +1304,20 @@ static AstNode *trans_create_compound_assign(Context *c, bool result_used, AstNo // zig: *_ref // zig: } - AstNode *child_block = trans_create_node(c, NodeTypeBlock); + TransScopeBlock *child_scope = trans_scope_block_create(c, scope); // const _ref = &lhs; - AstNode *lhs = trans_expr(c, true, child_block, stmt->getLHS(), TransLValue); + AstNode *lhs = trans_expr(c, ResultUsedYes, &child_scope->base, stmt->getLHS(), TransLValue); if (lhs == nullptr) return nullptr; AstNode *addr_of_lhs = trans_create_node_addr_of(c, false, false, lhs); // TODO: avoid name collisions with generated variable names Buf* tmp_var_name = buf_create_from_str("_ref"); AstNode *tmp_var_decl = trans_create_node_var_decl_local(c, true, tmp_var_name, nullptr, addr_of_lhs); - child_block->data.block.statements.append(tmp_var_decl); + 
child_scope->node->data.block.statements.append(tmp_var_decl); // *_ref = *_ref + rhs; - AstNode *rhs = trans_expr(c, true, child_block, stmt->getRHS(), TransRValue); + AstNode *rhs = trans_expr(c, ResultUsedYes, &child_scope->base, stmt->getRHS(), TransRValue); if (rhs == nullptr) return nullptr; AstNode *assign_statement = trans_create_node_bin_op(c, @@ -1277,26 +1329,28 @@ static AstNode *trans_create_compound_assign(Context *c, bool result_used, AstNo trans_create_node_symbol(c, tmp_var_name)), bin_op, rhs)); - child_block->data.block.statements.append(assign_statement); + child_scope->node->data.block.statements.append(assign_statement); // *_ref - child_block->data.block.statements.append( + child_scope->node->data.block.statements.append( trans_create_node_prefix_op(c, PrefixOpDereference, trans_create_node_symbol(c, tmp_var_name))); - child_block->data.block.last_statement_is_result_expression = true; + child_scope->node->data.block.last_statement_is_result_expression = true; - return child_block; + return child_scope->node; } } -static AstNode *trans_compound_assign_operator(Context *c, bool result_used, AstNode *block, CompoundAssignOperator *stmt) { +static AstNode *trans_compound_assign_operator(Context *c, ResultUsed result_used, TransScope *scope, + const CompoundAssignOperator *stmt) +{ switch (stmt->getOpcode()) { case BO_MulAssign: if (qual_type_has_wrapping_overflow(c, stmt->getType())) - return trans_create_compound_assign(c, result_used, block, stmt, BinOpTypeAssignTimesWrap, BinOpTypeMultWrap); + return trans_create_compound_assign(c, result_used, scope, stmt, BinOpTypeAssignTimesWrap, BinOpTypeMultWrap); else - return trans_create_compound_assign(c, result_used, block, stmt, BinOpTypeAssignTimes, BinOpTypeMult); + return trans_create_compound_assign(c, result_used, scope, stmt, BinOpTypeAssignTimes, BinOpTypeMult); case BO_DivAssign: emit_warning(c, stmt->getLocStart(), "TODO handle more C compound assign operators: BO_DivAssign"); return 
nullptr; @@ -1305,24 +1359,24 @@ static AstNode *trans_compound_assign_operator(Context *c, bool result_used, Ast return nullptr; case BO_AddAssign: if (qual_type_has_wrapping_overflow(c, stmt->getType())) - return trans_create_compound_assign(c, result_used, block, stmt, BinOpTypeAssignPlusWrap, BinOpTypeAddWrap); + return trans_create_compound_assign(c, result_used, scope, stmt, BinOpTypeAssignPlusWrap, BinOpTypeAddWrap); else - return trans_create_compound_assign(c, result_used, block, stmt, BinOpTypeAssignPlus, BinOpTypeAdd); + return trans_create_compound_assign(c, result_used, scope, stmt, BinOpTypeAssignPlus, BinOpTypeAdd); case BO_SubAssign: if (qual_type_has_wrapping_overflow(c, stmt->getType())) - return trans_create_compound_assign(c, result_used, block, stmt, BinOpTypeAssignMinusWrap, BinOpTypeSubWrap); + return trans_create_compound_assign(c, result_used, scope, stmt, BinOpTypeAssignMinusWrap, BinOpTypeSubWrap); else - return trans_create_compound_assign(c, result_used, block, stmt, BinOpTypeAssignMinus, BinOpTypeSub); + return trans_create_compound_assign(c, result_used, scope, stmt, BinOpTypeAssignMinus, BinOpTypeSub); case BO_ShlAssign: - return trans_create_compound_assign_shift(c, result_used, block, stmt, BinOpTypeAssignBitShiftLeft, BinOpTypeBitShiftLeft); + return trans_create_compound_assign_shift(c, result_used, scope, stmt, BinOpTypeAssignBitShiftLeft, BinOpTypeBitShiftLeft); case BO_ShrAssign: - return trans_create_compound_assign_shift(c, result_used, block, stmt, BinOpTypeAssignBitShiftRight, BinOpTypeBitShiftRight); + return trans_create_compound_assign_shift(c, result_used, scope, stmt, BinOpTypeAssignBitShiftRight, BinOpTypeBitShiftRight); case BO_AndAssign: - return trans_create_compound_assign(c, result_used, block, stmt, BinOpTypeAssignBitAnd, BinOpTypeBinAnd); + return trans_create_compound_assign(c, result_used, scope, stmt, BinOpTypeAssignBitAnd, BinOpTypeBinAnd); case BO_XorAssign: - return trans_create_compound_assign(c, 
result_used, block, stmt, BinOpTypeAssignBitXor, BinOpTypeBinXor); + return trans_create_compound_assign(c, result_used, scope, stmt, BinOpTypeAssignBitXor, BinOpTypeBinXor); case BO_OrAssign: - return trans_create_compound_assign(c, result_used, block, stmt, BinOpTypeAssignBitOr, BinOpTypeBinOr); + return trans_create_compound_assign(c, result_used, scope, stmt, BinOpTypeAssignBitOr, BinOpTypeBinOr); case BO_PtrMemD: case BO_PtrMemI: case BO_Assign: @@ -1351,13 +1405,13 @@ static AstNode *trans_compound_assign_operator(Context *c, bool result_used, Ast zig_unreachable(); } -static AstNode *trans_implicit_cast_expr(Context *c, AstNode *block, ImplicitCastExpr *stmt) { +static AstNode *trans_implicit_cast_expr(Context *c, TransScope *scope, const ImplicitCastExpr *stmt) { switch (stmt->getCastKind()) { case CK_LValueToRValue: - return trans_expr(c, true, block, stmt->getSubExpr(), TransRValue); + return trans_expr(c, ResultUsedYes, scope, stmt->getSubExpr(), TransRValue); case CK_IntegralCast: { - AstNode *target_node = trans_expr(c, true, block, stmt->getSubExpr(), TransRValue); + AstNode *target_node = trans_expr(c, ResultUsedYes, scope, stmt->getSubExpr(), TransRValue); if (target_node == nullptr) return nullptr; return trans_c_cast(c, stmt->getExprLoc(), stmt->getType(), target_node); @@ -1365,14 +1419,14 @@ static AstNode *trans_implicit_cast_expr(Context *c, AstNode *block, ImplicitCas case CK_FunctionToPointerDecay: case CK_ArrayToPointerDecay: { - AstNode *target_node = trans_expr(c, true, block, stmt->getSubExpr(), TransRValue); + AstNode *target_node = trans_expr(c, ResultUsedYes, scope, stmt->getSubExpr(), TransRValue); if (target_node == nullptr) return nullptr; return target_node; } case CK_BitCast: { - AstNode *target_node = trans_expr(c, true, block, stmt->getSubExpr(), TransRValue); + AstNode *target_node = trans_expr(c, ResultUsedYes, scope, stmt->getSubExpr(), TransRValue); if (target_node == nullptr) return nullptr; @@ -1549,8 +1603,8 @@ static 
AstNode *trans_implicit_cast_expr(Context *c, AstNode *block, ImplicitCas zig_unreachable(); } -static AstNode *trans_decl_ref_expr(Context *c, DeclRefExpr *stmt, TransLRValue lrval) { - ValueDecl *value_decl = stmt->getDecl(); +static AstNode *trans_decl_ref_expr(Context *c, const DeclRefExpr *stmt, TransLRValue lrval) { + const ValueDecl *value_decl = stmt->getDecl(); Buf *symbol_name = buf_create_from_str(decl_name(value_decl)); if (lrval == TransLValue) { c->ptr_params.put(symbol_name, true); @@ -1558,15 +1612,17 @@ static AstNode *trans_decl_ref_expr(Context *c, DeclRefExpr *stmt, TransLRValue return trans_create_node_symbol(c, symbol_name); } -static AstNode *trans_create_post_crement(Context *c, bool result_used, AstNode *block, UnaryOperator *stmt, BinOpType assign_op) { +static AstNode *trans_create_post_crement(Context *c, ResultUsed result_used, TransScope *scope, + const UnaryOperator *stmt, BinOpType assign_op) +{ Expr *op_expr = stmt->getSubExpr(); - if (!result_used) { + if (result_used == ResultUsedNo) { // common case // c: expr++ // zig: expr += 1 return trans_create_node_bin_op(c, - trans_expr(c, true, block, op_expr, TransLValue), + trans_expr(c, ResultUsedYes, scope, op_expr, TransLValue), assign_op, trans_create_node_unsigned(c, 1)); } @@ -1578,23 +1634,23 @@ static AstNode *trans_create_post_crement(Context *c, bool result_used, AstNode // zig: *_ref += 1; // zig: _tmp // zig: } - AstNode *child_block = trans_create_node(c, NodeTypeBlock); + TransScopeBlock *child_scope = trans_scope_block_create(c, scope); // const _ref = &expr; - AstNode *expr = trans_expr(c, true, child_block, op_expr, TransLValue); + AstNode *expr = trans_expr(c, ResultUsedYes, &child_scope->base, op_expr, TransLValue); if (expr == nullptr) return nullptr; AstNode *addr_of_expr = trans_create_node_addr_of(c, false, false, expr); // TODO: avoid name collisions with generated variable names Buf* ref_var_name = buf_create_from_str("_ref"); AstNode *ref_var_decl = 
trans_create_node_var_decl_local(c, true, ref_var_name, nullptr, addr_of_expr); - child_block->data.block.statements.append(ref_var_decl); + child_scope->node->data.block.statements.append(ref_var_decl); // const _tmp = *_ref; Buf* tmp_var_name = buf_create_from_str("_tmp"); AstNode *tmp_var_decl = trans_create_node_var_decl_local(c, true, tmp_var_name, nullptr, trans_create_node_prefix_op(c, PrefixOpDereference, trans_create_node_symbol(c, ref_var_name))); - child_block->data.block.statements.append(tmp_var_decl); + child_scope->node->data.block.statements.append(tmp_var_decl); // *_ref += 1; AstNode *assign_statement = trans_create_node_bin_op(c, @@ -1602,24 +1658,26 @@ static AstNode *trans_create_post_crement(Context *c, bool result_used, AstNode trans_create_node_symbol(c, ref_var_name)), assign_op, trans_create_node_unsigned(c, 1)); - child_block->data.block.statements.append(assign_statement); + child_scope->node->data.block.statements.append(assign_statement); // _tmp - child_block->data.block.statements.append(trans_create_node_symbol(c, tmp_var_name)); - child_block->data.block.last_statement_is_result_expression = true; + child_scope->node->data.block.statements.append(trans_create_node_symbol(c, tmp_var_name)); + child_scope->node->data.block.last_statement_is_result_expression = true; - return child_block; + return child_scope->node; } -static AstNode *trans_create_pre_crement(Context *c, bool result_used, AstNode *block, UnaryOperator *stmt, BinOpType assign_op) { +static AstNode *trans_create_pre_crement(Context *c, ResultUsed result_used, TransScope *scope, + const UnaryOperator *stmt, BinOpType assign_op) +{ Expr *op_expr = stmt->getSubExpr(); - if (!result_used) { + if (result_used == ResultUsedNo) { // common case // c: ++expr // zig: expr += 1 return trans_create_node_bin_op(c, - trans_expr(c, true, block, op_expr, TransLValue), + trans_expr(c, ResultUsedYes, scope, op_expr, TransLValue), assign_op, trans_create_node_unsigned(c, 1)); } @@ 
-1630,16 +1688,16 @@ static AstNode *trans_create_pre_crement(Context *c, bool result_used, AstNode * // zig: *_ref += 1; // zig: *_ref // zig: } - AstNode *child_block = trans_create_node(c, NodeTypeBlock); + TransScopeBlock *child_scope = trans_scope_block_create(c, scope); // const _ref = &expr; - AstNode *expr = trans_expr(c, true, child_block, op_expr, TransLValue); + AstNode *expr = trans_expr(c, ResultUsedYes, &child_scope->base, op_expr, TransLValue); if (expr == nullptr) return nullptr; AstNode *addr_of_expr = trans_create_node_addr_of(c, false, false, expr); // TODO: avoid name collisions with generated variable names Buf* ref_var_name = buf_create_from_str("_ref"); AstNode *ref_var_decl = trans_create_node_var_decl_local(c, true, ref_var_name, nullptr, addr_of_expr); - child_block->data.block.statements.append(ref_var_decl); + child_scope->node->data.block.statements.append(ref_var_decl); // *_ref += 1; AstNode *assign_statement = trans_create_node_bin_op(c, @@ -1647,49 +1705,49 @@ static AstNode *trans_create_pre_crement(Context *c, bool result_used, AstNode * trans_create_node_symbol(c, ref_var_name)), assign_op, trans_create_node_unsigned(c, 1)); - child_block->data.block.statements.append(assign_statement); + child_scope->node->data.block.statements.append(assign_statement); // *_ref AstNode *deref_expr = trans_create_node_prefix_op(c, PrefixOpDereference, trans_create_node_symbol(c, ref_var_name)); - child_block->data.block.statements.append(deref_expr); - child_block->data.block.last_statement_is_result_expression = true; + child_scope->node->data.block.statements.append(deref_expr); + child_scope->node->data.block.last_statement_is_result_expression = true; - return child_block; + return child_scope->node; } -static AstNode *trans_unary_operator(Context *c, bool result_used, AstNode *block, UnaryOperator *stmt) { +static AstNode *trans_unary_operator(Context *c, ResultUsed result_used, TransScope *scope, const UnaryOperator *stmt) { switch 
(stmt->getOpcode()) { case UO_PostInc: if (qual_type_has_wrapping_overflow(c, stmt->getType())) - return trans_create_post_crement(c, result_used, block, stmt, BinOpTypeAssignPlusWrap); + return trans_create_post_crement(c, result_used, scope, stmt, BinOpTypeAssignPlusWrap); else - return trans_create_post_crement(c, result_used, block, stmt, BinOpTypeAssignPlus); + return trans_create_post_crement(c, result_used, scope, stmt, BinOpTypeAssignPlus); case UO_PostDec: if (qual_type_has_wrapping_overflow(c, stmt->getType())) - return trans_create_post_crement(c, result_used, block, stmt, BinOpTypeAssignMinusWrap); + return trans_create_post_crement(c, result_used, scope, stmt, BinOpTypeAssignMinusWrap); else - return trans_create_post_crement(c, result_used, block, stmt, BinOpTypeAssignMinus); + return trans_create_post_crement(c, result_used, scope, stmt, BinOpTypeAssignMinus); case UO_PreInc: if (qual_type_has_wrapping_overflow(c, stmt->getType())) - return trans_create_pre_crement(c, result_used, block, stmt, BinOpTypeAssignPlusWrap); + return trans_create_pre_crement(c, result_used, scope, stmt, BinOpTypeAssignPlusWrap); else - return trans_create_pre_crement(c, result_used, block, stmt, BinOpTypeAssignPlus); + return trans_create_pre_crement(c, result_used, scope, stmt, BinOpTypeAssignPlus); case UO_PreDec: if (qual_type_has_wrapping_overflow(c, stmt->getType())) - return trans_create_pre_crement(c, result_used, block, stmt, BinOpTypeAssignMinusWrap); + return trans_create_pre_crement(c, result_used, scope, stmt, BinOpTypeAssignMinusWrap); else - return trans_create_pre_crement(c, result_used, block, stmt, BinOpTypeAssignMinus); + return trans_create_pre_crement(c, result_used, scope, stmt, BinOpTypeAssignMinus); case UO_AddrOf: { - AstNode *value_node = trans_expr(c, result_used, block, stmt->getSubExpr(), TransLValue); + AstNode *value_node = trans_expr(c, result_used, scope, stmt->getSubExpr(), TransLValue); if (value_node == nullptr) return value_node; return 
trans_create_node_addr_of(c, false, false, value_node); } case UO_Deref: { - AstNode *value_node = trans_expr(c, result_used, block, stmt->getSubExpr(), TransRValue); + AstNode *value_node = trans_expr(c, result_used, scope, stmt->getSubExpr(), TransRValue); if (value_node == nullptr) return nullptr; bool is_fn_ptr = qual_type_is_fn_ptr(c, stmt->getSubExpr()->getType()); @@ -1708,7 +1766,7 @@ static AstNode *trans_unary_operator(Context *c, bool result_used, AstNode *bloc AstNode *node = trans_create_node(c, NodeTypePrefixOpExpr); node->data.prefix_op_expr.prefix_op = PrefixOpNegation; - node->data.prefix_op_expr.primary_expr = trans_expr(c, true, block, op_expr, TransRValue); + node->data.prefix_op_expr.primary_expr = trans_expr(c, ResultUsedYes, scope, op_expr, TransRValue); if (node->data.prefix_op_expr.primary_expr == nullptr) return nullptr; @@ -1718,7 +1776,7 @@ static AstNode *trans_unary_operator(Context *c, bool result_used, AstNode *bloc AstNode *node = trans_create_node(c, NodeTypeBinOpExpr); node->data.bin_op_expr.op1 = trans_create_node_unsigned(c, 0); - node->data.bin_op_expr.op2 = trans_expr(c, true, block, op_expr, TransRValue); + node->data.bin_op_expr.op2 = trans_expr(c, ResultUsedYes, scope, op_expr, TransRValue); if (node->data.bin_op_expr.op2 == nullptr) return nullptr; @@ -1751,7 +1809,7 @@ static AstNode *trans_unary_operator(Context *c, bool result_used, AstNode *bloc zig_unreachable(); } -static AstNode *trans_local_declaration(Context *c, AstNode *block, DeclStmt *stmt) { +static AstNode *trans_local_declaration(Context *c, TransScope *scope, const DeclStmt *stmt) { for (auto iter = stmt->decl_begin(); iter != stmt->decl_end(); iter++) { Decl *decl = *iter; switch (decl->getKind()) { @@ -1760,7 +1818,7 @@ static AstNode *trans_local_declaration(Context *c, AstNode *block, DeclStmt *st QualType qual_type = var_decl->getTypeSourceInfo()->getType(); AstNode *init_node = nullptr; if (var_decl->hasInit()) { - init_node = trans_expr(c, true, 
block, var_decl->getInit(), TransRValue); + init_node = trans_expr(c, ResultUsedYes, scope, var_decl->getInit(), TransRValue); if (init_node == nullptr) return nullptr; @@ -1773,7 +1831,10 @@ static AstNode *trans_local_declaration(Context *c, AstNode *block, DeclStmt *st AstNode *node = trans_create_node_var_decl_local(c, qual_type.isConstQualified(), symbol_name, type_node, init_node); - block->data.block.statements.append(node); + + assert(scope->id == TransScopeIdBlock); + TransScopeBlock *scope_block = (TransScopeBlock *)scope; + scope_block->node->data.block.statements.append(node); continue; } case Decl::AccessSpec: @@ -2000,37 +2061,37 @@ static AstNode *trans_local_declaration(Context *c, AstNode *block, DeclStmt *st return skip_add_to_block_node; } -static AstNode *trans_while_loop(Context *c, AstNode *block, WhileStmt *stmt) { +static AstNode *trans_while_loop(Context *c, TransScope *scope, const WhileStmt *stmt) { AstNode *while_node = trans_create_node(c, NodeTypeWhileExpr); - while_node->data.while_expr.condition = trans_expr(c, true, block, stmt->getCond(), TransRValue); + while_node->data.while_expr.condition = trans_expr(c, ResultUsedYes, scope, stmt->getCond(), TransRValue); if (while_node->data.while_expr.condition == nullptr) return nullptr; - while_node->data.while_expr.body = trans_stmt(c, false, block, stmt->getBody(), TransRValue); + while_node->data.while_expr.body = trans_stmt(c, ResultUsedNo, scope, stmt->getBody(), TransRValue); if (while_node->data.while_expr.body == nullptr) return nullptr; return while_node; } -static AstNode *trans_if_statement(Context *c, AstNode *block, IfStmt *stmt) { +static AstNode *trans_if_statement(Context *c, TransScope *scope, const IfStmt *stmt) { // if (c) t // if (c) t else e AstNode *if_node = trans_create_node(c, NodeTypeIfBoolExpr); // TODO: condition != 0 - AstNode *condition_node = trans_expr(c, true, block, stmt->getCond(), TransRValue); + AstNode *condition_node = trans_expr(c, ResultUsedYes, 
scope, stmt->getCond(), TransRValue); if (condition_node == nullptr) return nullptr; if_node->data.if_bool_expr.condition = condition_node; - if_node->data.if_bool_expr.then_block = trans_stmt(c, false, block, stmt->getThen(), TransRValue); + if_node->data.if_bool_expr.then_block = trans_stmt(c, ResultUsedNo, scope, stmt->getThen(), TransRValue); if (if_node->data.if_bool_expr.then_block == nullptr) return nullptr; if (stmt->getElse() != nullptr) { - if_node->data.if_bool_expr.else_node = trans_stmt(c, false, block, stmt->getElse(), TransRValue); + if_node->data.if_bool_expr.else_node = trans_stmt(c, ResultUsedNo, scope, stmt->getElse(), TransRValue); if (if_node->data.if_bool_expr.else_node == nullptr) return nullptr; } @@ -2038,10 +2099,10 @@ static AstNode *trans_if_statement(Context *c, AstNode *block, IfStmt *stmt) { return if_node; } -static AstNode *trans_call_expr(Context *c, bool result_used, AstNode *block, CallExpr *stmt) { +static AstNode *trans_call_expr(Context *c, ResultUsed result_used, TransScope *scope, const CallExpr *stmt) { AstNode *node = trans_create_node(c, NodeTypeFnCallExpr); - AstNode *callee_raw_node = trans_expr(c, true, block, stmt->getCallee(), TransRValue); + AstNode *callee_raw_node = trans_expr(c, ResultUsedYes, scope, stmt->getCallee(), TransRValue); if (callee_raw_node == nullptr) return nullptr; @@ -2055,9 +2116,9 @@ static AstNode *trans_call_expr(Context *c, bool result_used, AstNode *block, Ca node->data.fn_call_expr.fn_ref_expr = callee_node; unsigned num_args = stmt->getNumArgs(); - Expr **args = stmt->getArgs(); + const Expr * const* args = stmt->getArgs(); for (unsigned i = 0; i < num_args; i += 1) { - AstNode *arg_node = trans_expr(c, true, block, args[i], TransRValue); + AstNode *arg_node = trans_expr(c, ResultUsedYes, scope, args[i], TransRValue); if (arg_node == nullptr) return nullptr; @@ -2067,8 +2128,8 @@ static AstNode *trans_call_expr(Context *c, bool result_used, AstNode *block, Ca return node; } -static AstNode 
*trans_member_expr(Context *c, AstNode *block, MemberExpr *stmt) { - AstNode *container_node = trans_expr(c, true, block, stmt->getBase(), TransRValue); +static AstNode *trans_member_expr(Context *c, TransScope *scope, const MemberExpr *stmt) { + AstNode *container_node = trans_expr(c, ResultUsedYes, scope, stmt->getBase(), TransRValue); if (container_node == nullptr) return nullptr; @@ -2082,12 +2143,12 @@ static AstNode *trans_member_expr(Context *c, AstNode *block, MemberExpr *stmt) return node; } -static AstNode *trans_array_subscript_expr(Context *c, AstNode *block, ArraySubscriptExpr *stmt) { - AstNode *container_node = trans_expr(c, true, block, stmt->getBase(), TransRValue); +static AstNode *trans_array_subscript_expr(Context *c, TransScope *scope, const ArraySubscriptExpr *stmt) { + AstNode *container_node = trans_expr(c, ResultUsedYes, scope, stmt->getBase(), TransRValue); if (container_node == nullptr) return nullptr; - AstNode *idx_node = trans_expr(c, true, block, stmt->getIdx(), TransRValue); + AstNode *idx_node = trans_expr(c, ResultUsedYes, scope, stmt->getIdx(), TransRValue); if (idx_node == nullptr) return nullptr; @@ -2098,17 +2159,19 @@ static AstNode *trans_array_subscript_expr(Context *c, AstNode *block, ArraySubs return node; } -static AstNode *trans_c_style_cast_expr(Context *c, bool result_used, AstNode *block, - CStyleCastExpr *stmt, TransLRValue lrvalue) +static AstNode *trans_c_style_cast_expr(Context *c, ResultUsed result_used, TransScope *scope, + const CStyleCastExpr *stmt, TransLRValue lrvalue) { - AstNode *sub_expr_node = trans_expr(c, result_used, block, stmt->getSubExpr(), lrvalue); + AstNode *sub_expr_node = trans_expr(c, result_used, scope, stmt->getSubExpr(), lrvalue); if (sub_expr_node == nullptr) return nullptr; return trans_c_cast(c, stmt->getLocStart(), stmt->getType(), sub_expr_node); } -static AstNode *trans_unary_expr_or_type_trait_expr(Context *c, AstNode *block, UnaryExprOrTypeTraitExpr *stmt) { +static AstNode 
*trans_unary_expr_or_type_trait_expr(Context *c, TransScope *scope, + const UnaryExprOrTypeTraitExpr *stmt) +{ AstNode *type_node = trans_qual_type(c, stmt->getTypeOfArgument(), stmt->getLocStart()); if (type_node == nullptr) return nullptr; @@ -2118,7 +2181,7 @@ static AstNode *trans_unary_expr_or_type_trait_expr(Context *c, AstNode *block, return node; } -static AstNode *trans_do_loop(Context *c, AstNode *block, DoStmt *stmt) { +static AstNode *trans_do_loop(Context *c, TransScope *parent_scope, const DoStmt *stmt) { AstNode *while_node = trans_create_node(c, NodeTypeWhileExpr); AstNode *true_node = trans_create_node(c, NodeTypeBoolLiteral); @@ -2126,6 +2189,7 @@ static AstNode *trans_do_loop(Context *c, AstNode *block, DoStmt *stmt) { while_node->data.while_expr.condition = true_node; AstNode *body_node; + TransScope *child_scope; if (stmt->getBody()->getStmtClass() == Stmt::CompoundStmtClass) { // there's already a block in C, so we'll append our condition to it. // c: do { @@ -2137,9 +2201,10 @@ static AstNode *trans_do_loop(Context *c, AstNode *block, DoStmt *stmt) { // zig: b; // zig: if (!cond) break; // zig: } - body_node = trans_stmt(c, false, block, stmt->getBody(), TransRValue); + body_node = trans_stmt(c, ResultUsedNo, parent_scope, stmt->getBody(), TransRValue); if (body_node == nullptr) return nullptr; assert(body_node->type == NodeTypeBlock); + child_scope = c->child_scope; } else { // the C statement is without a block, so we need to create a block to contain it. 
// c: do @@ -2149,63 +2214,103 @@ static AstNode *trans_do_loop(Context *c, AstNode *block, DoStmt *stmt) { // zig: a; // zig: if (!cond) break; // zig: } - body_node = trans_create_node(c, NodeTypeBlock); - AstNode *child_statement = trans_stmt(c, false, body_node, stmt->getBody(), TransRValue); + TransScopeBlock *child_block_scope = trans_scope_block_create(c, parent_scope); + body_node = child_block_scope->node; + child_scope = &child_block_scope->base; + AstNode *child_statement = trans_stmt(c, ResultUsedNo, child_scope, stmt->getBody(), TransRValue); if (child_statement == nullptr) return nullptr; - body_node->data.block.statements.append(child_statement); + child_block_scope->node->data.block.statements.append(child_statement); } // if (!cond) break; - AstNode *condition_node = trans_expr(c, true, body_node, stmt->getCond(), TransRValue); + AstNode *condition_node = trans_expr(c, ResultUsedYes, child_scope, stmt->getCond(), TransRValue); if (condition_node == nullptr) return nullptr; AstNode *terminator_node = trans_create_node(c, NodeTypeIfBoolExpr); terminator_node->data.if_bool_expr.condition = trans_create_node_prefix_op(c, PrefixOpBoolNot, condition_node); terminator_node->data.if_bool_expr.then_block = trans_create_node(c, NodeTypeBreak); - body_node->data.block.statements.append(terminator_node); + + assert(child_scope->id == TransScopeIdBlock); + TransScopeBlock *child_block_scope = (TransScopeBlock *)child_scope; + + child_block_scope->node->data.block.statements.append(terminator_node); while_node->data.while_expr.body = body_node; return while_node; } -static AstNode *trans_for_loop(Context *c, AstNode *block, ForStmt *stmt) { +//static AstNode *trans_switch_stmt(Context *c, TransScope *scope, const SwitchStmt *stmt) { +// AstNode *switch_block_node = trans_create_node(c, NodeTypeBlock); +// AstNode *switch_node = trans_create_node(c, NodeTypeSwitchExpr); +// const DeclStmt *var_decl_stmt = stmt->getConditionVariableDeclStmt(); +// if 
(var_decl_stmt != nullptr) { +// AstNode *vars_node = trans_stmt(c, ResultUsedNo, switch_block_node, var_decl_stmt, TransRValue); +// if (vars_node == nullptr) +// return nullptr; +// if (vars_node != skip_add_to_block_node) +// switch_block_node->data.block.statements.append(vars_node); +// } +// switch_block_node->data.block.statements.append(switch_node); +// +// const Expr *cond_expr = stmt->getCond(); +// assert(cond_expr != nullptr); +// +// AstNode *expr_node = trans_expr(c, ResultUsedYes, switch_block_node, cond_expr, TransRValue); +// if (expr_node == nullptr) +// return nullptr; +// switch_node->data.switch_expr.expr = expr_node; +// +// AstNode *body_node = trans_stmt(c, ResultUsedNo, switch_block_node, stmt->getBody(), TransRValue); +// if (body_node == nullptr) +// return nullptr; +// if (body_node != skip_add_to_block_node) +// switch_block_node->data.block.statements.append(body_node); +// +// return switch_block_node; +//} + +static AstNode *trans_for_loop(Context *c, TransScope *parent_scope, const ForStmt *stmt) { AstNode *loop_block_node; + TransScope *condition_scope; AstNode *while_node = trans_create_node(c, NodeTypeWhileExpr); - Stmt *init_stmt = stmt->getInit(); + const Stmt *init_stmt = stmt->getInit(); if (init_stmt == nullptr) { loop_block_node = while_node; + condition_scope = parent_scope; } else { - loop_block_node = trans_create_node(c, NodeTypeBlock); + TransScopeBlock *child_scope = trans_scope_block_create(c, parent_scope); + loop_block_node = child_scope->node; + condition_scope = &child_scope->base; - AstNode *vars_node = trans_stmt(c, false, loop_block_node, init_stmt, TransRValue); + AstNode *vars_node = trans_stmt(c, ResultUsedNo, &child_scope->base, init_stmt, TransRValue); if (vars_node == nullptr) return nullptr; if (vars_node != skip_add_to_block_node) - loop_block_node->data.block.statements.append(vars_node); + child_scope->node->data.block.statements.append(vars_node); - 
loop_block_node->data.block.statements.append(while_node); + child_scope->node->data.block.statements.append(while_node); } - Stmt *cond_stmt = stmt->getCond(); + const Stmt *cond_stmt = stmt->getCond(); if (cond_stmt == nullptr) { AstNode *true_node = trans_create_node(c, NodeTypeBoolLiteral); true_node->data.bool_literal.value = true; while_node->data.while_expr.condition = true_node; } else { - while_node->data.while_expr.condition = trans_stmt(c, false, loop_block_node, cond_stmt, TransRValue); + while_node->data.while_expr.condition = trans_stmt(c, ResultUsedNo, condition_scope, cond_stmt, TransRValue); if (while_node->data.while_expr.condition == nullptr) return nullptr; } - Stmt *inc_stmt = stmt->getInc(); + const Stmt *inc_stmt = stmt->getInc(); if (inc_stmt != nullptr) { - AstNode *inc_node = trans_stmt(c, false, loop_block_node, inc_stmt, TransRValue); + AstNode *inc_node = trans_stmt(c, ResultUsedNo, condition_scope, inc_stmt, TransRValue); if (inc_node == nullptr) return nullptr; while_node->data.while_expr.continue_expr = inc_node; } - AstNode *child_statement = trans_stmt(c, false, loop_block_node, stmt->getBody(), TransRValue); + AstNode *child_statement = trans_stmt(c, ResultUsedNo, condition_scope, stmt->getBody(), TransRValue); if (child_statement == nullptr) return nullptr; while_node->data.while_expr.body = child_statement; @@ -2213,7 +2318,7 @@ static AstNode *trans_for_loop(Context *c, AstNode *block, ForStmt *stmt) { return loop_block_node; } -static AstNode *trans_string_literal(Context *c, AstNode *block, StringLiteral *stmt) { +static AstNode *trans_string_literal(Context *c, TransScope *scope, const StringLiteral *stmt) { switch (stmt->getKind()) { case StringLiteral::Ascii: case StringLiteral::UTF8: @@ -2231,72 +2336,75 @@ static AstNode *trans_string_literal(Context *c, AstNode *block, StringLiteral * zig_unreachable(); } -static AstNode *trans_break_stmt(Context *c, AstNode *block, BreakStmt *stmt) { +static AstNode 
*trans_break_stmt(Context *c, TransScope *scope, const BreakStmt *stmt) { return trans_create_node(c, NodeTypeBreak); } -static AstNode *trans_continue_stmt(Context *c, AstNode *block, ContinueStmt *stmt) { +static AstNode *trans_continue_stmt(Context *c, TransScope *scope, const ContinueStmt *stmt) { return trans_create_node(c, NodeTypeContinue); } -static AstNode *trans_stmt(Context *c, bool result_used, AstNode *block, Stmt *stmt, TransLRValue lrvalue) { +static AstNode *trans_stmt(Context *c, ResultUsed result_used, TransScope *scope, const Stmt *stmt, TransLRValue lrvalue) { + c->child_scope = scope; // TODO refactor out Stmt::StmtClass sc = stmt->getStmtClass(); switch (sc) { case Stmt::ReturnStmtClass: - return trans_return_stmt(c, block, (ReturnStmt *)stmt); + return trans_return_stmt(c, scope, (const ReturnStmt *)stmt); case Stmt::CompoundStmtClass: - return trans_compound_stmt(c, block, (CompoundStmt *)stmt); + return trans_compound_stmt(c, scope, (const CompoundStmt *)stmt); case Stmt::IntegerLiteralClass: - return trans_integer_literal(c, (IntegerLiteral *)stmt); + return trans_integer_literal(c, (const IntegerLiteral *)stmt); case Stmt::ConditionalOperatorClass: - return trans_conditional_operator(c, result_used, block, (ConditionalOperator *)stmt); + return trans_conditional_operator(c, result_used, scope, (const ConditionalOperator *)stmt); case Stmt::BinaryOperatorClass: - return trans_binary_operator(c, result_used, block, (BinaryOperator *)stmt); + return trans_binary_operator(c, result_used, scope, (const BinaryOperator *)stmt); case Stmt::CompoundAssignOperatorClass: - return trans_compound_assign_operator(c, result_used, block, (CompoundAssignOperator *)stmt); + return trans_compound_assign_operator(c, result_used, scope, (const CompoundAssignOperator *)stmt); case Stmt::ImplicitCastExprClass: - return trans_implicit_cast_expr(c, block, (ImplicitCastExpr *)stmt); + return trans_implicit_cast_expr(c, scope, (const ImplicitCastExpr *)stmt); case 
Stmt::DeclRefExprClass: - return trans_decl_ref_expr(c, (DeclRefExpr *)stmt, lrvalue); + return trans_decl_ref_expr(c, (const DeclRefExpr *)stmt, lrvalue); case Stmt::UnaryOperatorClass: - return trans_unary_operator(c, result_used, block, (UnaryOperator *)stmt); + return trans_unary_operator(c, result_used, scope, (const UnaryOperator *)stmt); case Stmt::DeclStmtClass: - return trans_local_declaration(c, block, (DeclStmt *)stmt); + return trans_local_declaration(c, scope, (const DeclStmt *)stmt); case Stmt::WhileStmtClass: - return trans_while_loop(c, block, (WhileStmt *)stmt); + return trans_while_loop(c, scope, (const WhileStmt *)stmt); case Stmt::IfStmtClass: - return trans_if_statement(c, block, (IfStmt *)stmt); + return trans_if_statement(c, scope, (const IfStmt *)stmt); case Stmt::CallExprClass: - return trans_call_expr(c, result_used, block, (CallExpr *)stmt); + return trans_call_expr(c, result_used, scope, (const CallExpr *)stmt); case Stmt::NullStmtClass: return skip_add_to_block_node; case Stmt::MemberExprClass: - return trans_member_expr(c, block, (MemberExpr *)stmt); + return trans_member_expr(c, scope, (const MemberExpr *)stmt); case Stmt::ArraySubscriptExprClass: - return trans_array_subscript_expr(c, block, (ArraySubscriptExpr *)stmt); + return trans_array_subscript_expr(c, scope, (const ArraySubscriptExpr *)stmt); case Stmt::CStyleCastExprClass: - return trans_c_style_cast_expr(c, result_used, block, (CStyleCastExpr *)stmt, lrvalue); + return trans_c_style_cast_expr(c, result_used, scope, (const CStyleCastExpr *)stmt, lrvalue); case Stmt::UnaryExprOrTypeTraitExprClass: - return trans_unary_expr_or_type_trait_expr(c, block, (UnaryExprOrTypeTraitExpr *)stmt); + return trans_unary_expr_or_type_trait_expr(c, scope, (const UnaryExprOrTypeTraitExpr *)stmt); case Stmt::DoStmtClass: - return trans_do_loop(c, block, (DoStmt *)stmt); + return trans_do_loop(c, scope, (const DoStmt *)stmt); case Stmt::ForStmtClass: - return trans_for_loop(c, block, (ForStmt 
*)stmt); + return trans_for_loop(c, scope, (const ForStmt *)stmt); case Stmt::StringLiteralClass: - return trans_string_literal(c, block, (StringLiteral *)stmt); + return trans_string_literal(c, scope, (const StringLiteral *)stmt); case Stmt::BreakStmtClass: - return trans_break_stmt(c, block, (BreakStmt *)stmt); + return trans_break_stmt(c, scope, (const BreakStmt *)stmt); case Stmt::ContinueStmtClass: - return trans_continue_stmt(c, block, (ContinueStmt *)stmt); + return trans_continue_stmt(c, scope, (const ContinueStmt *)stmt); +// case Stmt::SwitchStmtClass: +// return trans_switch_stmt(c, scope, (const SwitchStmt *)stmt); + case Stmt::SwitchStmtClass: + emit_warning(c, stmt->getLocStart(), "TODO handle C SwitchStmtClass"); + return nullptr; case Stmt::CaseStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CaseStmtClass"); return nullptr; case Stmt::DefaultStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C DefaultStmtClass"); return nullptr; - case Stmt::SwitchStmtClass: - emit_warning(c, stmt->getLocStart(), "TODO handle C SwitchStmtClass"); - return nullptr; case Stmt::NoStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C NoStmtClass"); return nullptr; @@ -2584,7 +2692,7 @@ static AstNode *trans_stmt(Context *c, bool result_used, AstNode *block, Stmt *s emit_warning(c, stmt->getLocStart(), "TODO handle C PackExpansionExprClass"); return nullptr; case Stmt::ParenExprClass: - return trans_expr(c, result_used, block, ((ParenExpr*)stmt)->getSubExpr(), lrvalue); + return trans_expr(c, result_used, scope, ((const ParenExpr*)stmt)->getSubExpr(), lrvalue); case Stmt::ParenListExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ParenListExprClass"); return nullptr; @@ -2838,10 +2946,13 @@ static void visit_fn_decl(Context *c, const FunctionDecl *fn_decl) { return; } + TransScope *scope = nullptr; + for (size_t i = 0; i < proto_node->data.fn_proto.params.length; i += 1) { AstNode *param_node = 
proto_node->data.fn_proto.params.at(i); const ParmVarDecl *param = fn_decl->getParamDecl(i); const char *name = decl_name(param); + Buf *proto_param_name; if (strlen(name) != 0) { proto_param_name = buf_create_from_str(name); @@ -2851,7 +2962,11 @@ static void visit_fn_decl(Context *c, const FunctionDecl *fn_decl) { proto_param_name = buf_sprintf("arg%" ZIG_PRI_usize "", i); } } - param_node->data.param_decl.name = proto_param_name; + + TransScopeVar *scope_var = trans_scope_var_create(c, scope, proto_param_name); + scope = &scope_var->base; + + param_node->data.param_decl.name = scope_var->zig_name; } if (!fn_decl->hasBody()) { @@ -2863,7 +2978,7 @@ static void visit_fn_decl(Context *c, const FunctionDecl *fn_decl) { // actual function definition with body c->ptr_params.clear(); Stmt *body = fn_decl->getBody(); - AstNode *actual_body_node = trans_stmt(c, false, nullptr, body, TransRValue); + AstNode *actual_body_node = trans_stmt(c, ResultUsedNo, scope, body, TransRValue); assert(actual_body_node != skip_add_to_block_node); if (actual_body_node == nullptr) { emit_warning(c, fn_decl->getLocation(), "unable to translate function"); @@ -3301,14 +3416,66 @@ static bool decl_visitor(void *context, const Decl *decl) { return true; } -static bool name_exists(Context *c, Buf *name) { +static bool name_exists_global(Context *c, Buf *name) { return get_global(c, name) != nullptr; } +static bool name_exists_scope(Context *c, Buf *name, TransScope *scope) { + while (scope != nullptr) { + if (scope->id == TransScopeIdVar) { + TransScopeVar *var_scope = (TransScopeVar *)scope; + if (buf_eql_buf(name, var_scope->zig_name)) { + return true; + } + } + scope = scope->parent; + } + return name_exists_global(c, name); +} + +static Buf *get_unique_name(Context *c, Buf *name, TransScope *scope) { + Buf *proposed_name = name; + int count = 0; + while (name_exists_scope(c, proposed_name, scope)) { + if (proposed_name == name) { + proposed_name = buf_alloc(); + } + 
buf_resize(proposed_name, 0); + buf_appendf(proposed_name, "%s_%d", buf_ptr(name), count); + count += 1; + } + return proposed_name; +} + +static TransScopeBlock *trans_scope_block_create(Context *c, TransScope *parent_scope) { + TransScopeBlock *result = allocate(1); + result->base.id = TransScopeIdBlock; + result->base.parent = parent_scope; + result->node = trans_create_node(c, NodeTypeBlock); + return result; +} + +static TransScopeVar *trans_scope_var_create(Context *c, TransScope *parent_scope, Buf *wanted_name) { + TransScopeVar *result = allocate(1); + result->base.id = TransScopeIdVar; + result->base.parent = parent_scope; + result->c_name = wanted_name; + result->zig_name = get_unique_name(c, wanted_name, parent_scope); + return result; +} + +//static TransScopeSwitch *trans_scope_switch_create(Context *c, TransScope *parent_scope) { +// TransScopeSwitch *result = allocate(1); +// result->base.id = TransScopeIdSwitch; +// result->base.parent = parent_scope; +// result->switch_node = trans_create_node(c, NodeTypeSwitchExpr); +// return result; +//} + static void render_aliases(Context *c) { for (size_t i = 0; i < c->aliases.length; i += 1) { Alias *alias = &c->aliases.at(i); - if (name_exists(c, alias->new_name)) + if (name_exists_global(c, alias->new_name)) continue; add_global_var(c, alias->new_name, trans_create_node_symbol(c, alias->canon_name)); @@ -3438,7 +3605,7 @@ static void process_symbol_macros(Context *c) { // Check if this macro aliases another top level declaration AstNode *existing_node = get_global(c, ms.value); - if (!existing_node || name_exists(c, ms.name)) + if (!existing_node || name_exists_global(c, ms.name)) continue; // If a macro aliases a global variable which is a function pointer, we conclude that @@ -3487,7 +3654,7 @@ static void process_preprocessor_entities(Context *c, ASTUnit &unit) { continue; } Buf *name = buf_create_from_str(raw_name); - if (name_exists(c, name)) { + if (name_exists_global(c, name)) { continue; } diff 
--git a/test/translate_c.zig b/test/translate_c.zig index 152928cbe9..b957ac4b05 100644 --- a/test/translate_c.zig +++ b/test/translate_c.zig @@ -963,6 +963,16 @@ pub fn addCases(cases: &tests.TranslateCContext) { \\} ); + cases.add("empty for loop", + \\void foo(void) { + \\ for (;;) { } + \\} + , + \\pub fn foo() { + \\ while (true) {}; + \\} + ); + cases.add("break statement", \\void foo(void) { \\ for (;;) { @@ -990,6 +1000,47 @@ pub fn addCases(cases: &tests.TranslateCContext) { \\ }; \\} ); + + //cases.add("switch statement", + // \\int foo(int x) { + // \\ switch (x) { + // \\ case 1: + // \\ x += 1; + // \\ case 2: + // \\ break; + // \\ case 3: + // \\ case 4: + // \\ return x + 1; + // \\ default: + // \\ return 10; + // \\ } + // \\ return x + 13; + // \\} + //, + // \\fn foo(_x: i32) -> i32 { + // \\ var x = _x; + // \\ switch (x) { + // \\ 1 => goto switch_case_1; + // \\ 2 => goto switch_case_2; + // \\ 3 => goto switch_case_3; + // \\ 4 => goto switch_case_4; + // \\ else => goto switch_default; + // \\ } + // \\switch_case_1: + // \\ x += 1; + // \\ goto switch_case_2; + // \\switch_case_2: + // \\ goto switch_end; + // \\switch_case_3: + // \\ goto switch_case_4; + // \\switch_case_4: + // \\ return x += 1; + // \\switch_default: + // \\ return 10; + // \\switch_end: + // \\ return x + 13; + // \\} + //); } From 687e3592919940bc3c77a22adbee065fb5c1bc7e Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sat, 25 Nov 2017 22:16:50 -0500 Subject: [PATCH 23/34] translate-c: avoid global state and introduce var decl scopes in preparation to implement switch and solve variable name collisions --- src/translate_c.cpp | 767 +++++++++++++++++++++++++------------------- 1 file changed, 433 insertions(+), 334 deletions(-) diff --git a/src/translate_c.cpp b/src/translate_c.cpp index 4b86bc2dea..9e6676063d 100644 --- a/src/translate_c.cpp +++ b/src/translate_c.cpp @@ -23,8 +23,6 @@ using namespace clang; -struct TransScope; - struct MacroSymbol { Buf *name; Buf 
*value; @@ -54,7 +52,6 @@ struct Context { ASTContext *ctx; HashMap ptr_params; - TransScope *child_scope; // TODO refactor out }; enum ResultUsed { @@ -71,6 +68,7 @@ enum TransScopeId { TransScopeIdSwitch, TransScopeIdVar, TransScopeIdBlock, + TransScopeIdRoot, }; struct TransScope { @@ -94,17 +92,27 @@ struct TransScopeBlock { AstNode *node; }; -static AstNode *const skip_add_to_block_node = (AstNode *) 0x2; +struct TransScopeRoot { + TransScope base; +}; +static TransScopeRoot *trans_scope_root_create(Context *c); static TransScopeBlock *trans_scope_block_create(Context *c, TransScope *parent_scope); static TransScopeVar *trans_scope_var_create(Context *c, TransScope *parent_scope, Buf *wanted_name); //static TransScopeSwitch *trans_scope_switch_create(Context *c, TransScope *parent_scope); +static TransScopeBlock *trans_scope_block_find(TransScope *scope); + static AstNode *resolve_record_decl(Context *c, const RecordDecl *record_decl); static AstNode *resolve_enum_decl(Context *c, const EnumDecl *enum_decl); static AstNode *resolve_typedef_decl(Context *c, const TypedefNameDecl *typedef_decl); -static AstNode *trans_stmt(Context *c, ResultUsed result_used, TransScope *scope, const Stmt *stmt, TransLRValue lrval); +static int trans_stmt_extra(Context *c, TransScope *scope, const Stmt *stmt, + ResultUsed result_used, TransLRValue lrval, + AstNode **out_node, TransScope **out_child_scope, + TransScope **out_node_scope); +static TransScope *trans_stmt(Context *c, TransScope *scope, const Stmt *stmt, AstNode **out_node); +static AstNode *trans_expr(Context *c, ResultUsed result_used, TransScope *scope, const Expr *expr, TransLRValue lrval); static AstNode *trans_qual_type(Context *c, QualType qt, const SourceLocation &source_loc); @@ -620,10 +628,6 @@ static bool qual_type_has_wrapping_overflow(Context *c, QualType qt) { } } -static AstNode *trans_expr(Context *c, ResultUsed result_used, TransScope *scope, const Expr *expr, TransLRValue lrval) { - return 
trans_stmt(c, result_used, scope, expr, lrval); -} - static AstNode *trans_type(Context *c, const Type *ty, const SourceLocation &source_loc) { switch (ty->getTypeClass()) { case Type::Builtin: @@ -961,16 +965,22 @@ static AstNode *trans_qual_type(Context *c, QualType qt, const SourceLocation &s return trans_type(c, qt.getTypePtr(), source_loc); } -static AstNode *trans_compound_stmt(Context *c, TransScope *parent_scope, const CompoundStmt *stmt) { - TransScopeBlock *child_scope_block = trans_scope_block_create(c, parent_scope); +static AstNode *trans_compound_stmt(Context *c, TransScope *scope, const CompoundStmt *stmt, + TransScope **out_node_scope) +{ + TransScopeBlock *child_scope_block = trans_scope_block_create(c, scope); + scope = &child_scope_block->base; for (CompoundStmt::const_body_iterator it = stmt->body_begin(), end_it = stmt->body_end(); it != end_it; ++it) { - AstNode *child_node = trans_stmt(c, ResultUsedNo, &child_scope_block->base, *it, TransRValue); - if (child_node == nullptr) + AstNode *child_node; + scope = trans_stmt(c, scope, *it, &child_node); + if (scope == nullptr) return nullptr; - if (child_node != skip_add_to_block_node) + if (child_node != nullptr) child_scope_block->node->data.block.statements.append(child_node); } - c->child_scope = &child_scope_block->base; + if (out_node_scope != nullptr) { + *out_node_scope = &child_scope_block->base; + } return child_scope_block->node; } @@ -1809,7 +1819,15 @@ static AstNode *trans_unary_operator(Context *c, ResultUsed result_used, TransSc zig_unreachable(); } -static AstNode *trans_local_declaration(Context *c, TransScope *scope, const DeclStmt *stmt) { +static int trans_local_declaration(Context *c, TransScope *scope, const DeclStmt *stmt, + AstNode **out_node, TransScope **out_scope) +{ + // declarations are added via the scope + *out_node = nullptr; + + TransScopeBlock *scope_block = trans_scope_block_find(scope); + assert(scope_block != nullptr); + for (auto iter = stmt->decl_begin(); iter 
!= stmt->decl_end(); iter++) { Decl *decl = *iter; switch (decl->getKind()) { @@ -1820,245 +1838,246 @@ static AstNode *trans_local_declaration(Context *c, TransScope *scope, const Dec if (var_decl->hasInit()) { init_node = trans_expr(c, ResultUsedYes, scope, var_decl->getInit(), TransRValue); if (init_node == nullptr) - return nullptr; + return ErrorUnexpected; } AstNode *type_node = trans_qual_type(c, qual_type, stmt->getLocStart()); if (type_node == nullptr) - return nullptr; + return ErrorUnexpected; - Buf *symbol_name = buf_create_from_str(decl_name(var_decl)); + Buf *c_symbol_name = buf_create_from_str(decl_name(var_decl)); + + TransScopeVar *var_scope = trans_scope_var_create(c, scope, c_symbol_name); + scope = &var_scope->base; AstNode *node = trans_create_node_var_decl_local(c, qual_type.isConstQualified(), - symbol_name, type_node, init_node); + var_scope->zig_name, type_node, init_node); - assert(scope->id == TransScopeIdBlock); - TransScopeBlock *scope_block = (TransScopeBlock *)scope; scope_block->node->data.block.statements.append(node); continue; } case Decl::AccessSpec: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind AccessSpec"); - return nullptr; + return ErrorUnexpected; case Decl::Block: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind Block"); - return nullptr; + return ErrorUnexpected; case Decl::Captured: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind Captured"); - return nullptr; + return ErrorUnexpected; case Decl::ClassScopeFunctionSpecialization: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind ClassScopeFunctionSpecialization"); - return nullptr; + return ErrorUnexpected; case Decl::Empty: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind Empty"); - return nullptr; + return ErrorUnexpected; case Decl::Export: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind Export"); - return nullptr; + return ErrorUnexpected; case Decl::ExternCContext: emit_warning(c, 
stmt->getLocStart(), "TODO handle decl kind ExternCContext"); - return nullptr; + return ErrorUnexpected; case Decl::FileScopeAsm: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind FileScopeAsm"); - return nullptr; + return ErrorUnexpected; case Decl::Friend: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind Friend"); - return nullptr; + return ErrorUnexpected; case Decl::FriendTemplate: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind FriendTemplate"); - return nullptr; + return ErrorUnexpected; case Decl::Import: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind Import"); - return nullptr; + return ErrorUnexpected; case Decl::LinkageSpec: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind LinkageSpec"); - return nullptr; + return ErrorUnexpected; case Decl::Label: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind Label"); - return nullptr; + return ErrorUnexpected; case Decl::Namespace: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind Namespace"); - return nullptr; + return ErrorUnexpected; case Decl::NamespaceAlias: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind NamespaceAlias"); - return nullptr; + return ErrorUnexpected; case Decl::ObjCCompatibleAlias: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind ObjCCompatibleAlias"); - return nullptr; + return ErrorUnexpected; case Decl::ObjCCategory: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind ObjCCategory"); - return nullptr; + return ErrorUnexpected; case Decl::ObjCCategoryImpl: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind ObjCCategoryImpl"); - return nullptr; + return ErrorUnexpected; case Decl::ObjCImplementation: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind ObjCImplementation"); - return nullptr; + return ErrorUnexpected; case Decl::ObjCInterface: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind ObjCInterface"); - return nullptr; + return ErrorUnexpected; case 
Decl::ObjCProtocol: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind ObjCProtocol"); - return nullptr; + return ErrorUnexpected; case Decl::ObjCMethod: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind ObjCMethod"); - return nullptr; + return ErrorUnexpected; case Decl::ObjCProperty: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind ObjCProperty"); - return nullptr; + return ErrorUnexpected; case Decl::BuiltinTemplate: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind BuiltinTemplate"); - return nullptr; + return ErrorUnexpected; case Decl::ClassTemplate: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind ClassTemplate"); - return nullptr; + return ErrorUnexpected; case Decl::FunctionTemplate: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind FunctionTemplate"); - return nullptr; + return ErrorUnexpected; case Decl::TypeAliasTemplate: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind TypeAliasTemplate"); - return nullptr; + return ErrorUnexpected; case Decl::VarTemplate: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind VarTemplate"); - return nullptr; + return ErrorUnexpected; case Decl::TemplateTemplateParm: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind TemplateTemplateParm"); - return nullptr; + return ErrorUnexpected; case Decl::Enum: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind Enum"); - return nullptr; + return ErrorUnexpected; case Decl::Record: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind Record"); - return nullptr; + return ErrorUnexpected; case Decl::CXXRecord: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind CXXRecord"); - return nullptr; + return ErrorUnexpected; case Decl::ClassTemplateSpecialization: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind ClassTemplateSpecialization"); - return nullptr; + return ErrorUnexpected; case Decl::ClassTemplatePartialSpecialization: emit_warning(c, stmt->getLocStart(), 
"TODO handle decl kind ClassTemplatePartialSpecialization"); - return nullptr; + return ErrorUnexpected; case Decl::TemplateTypeParm: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind TemplateTypeParm"); - return nullptr; + return ErrorUnexpected; case Decl::ObjCTypeParam: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind ObjCTypeParam"); - return nullptr; + return ErrorUnexpected; case Decl::TypeAlias: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind TypeAlias"); - return nullptr; + return ErrorUnexpected; case Decl::Typedef: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind Typedef"); - return nullptr; + return ErrorUnexpected; case Decl::UnresolvedUsingTypename: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind UnresolvedUsingTypename"); - return nullptr; + return ErrorUnexpected; case Decl::Using: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind Using"); - return nullptr; + return ErrorUnexpected; case Decl::UsingDirective: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind UsingDirective"); - return nullptr; + return ErrorUnexpected; case Decl::UsingPack: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind UsingPack"); - return nullptr; + return ErrorUnexpected; case Decl::UsingShadow: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind UsingShadow"); - return nullptr; + return ErrorUnexpected; case Decl::ConstructorUsingShadow: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind ConstructorUsingShadow"); - return nullptr; + return ErrorUnexpected; case Decl::Binding: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind Binding"); - return nullptr; + return ErrorUnexpected; case Decl::Field: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind Field"); - return nullptr; + return ErrorUnexpected; case Decl::ObjCAtDefsField: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind ObjCAtDefsField"); - return nullptr; + return ErrorUnexpected; case 
Decl::ObjCIvar: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind ObjCIvar"); - return nullptr; + return ErrorUnexpected; case Decl::Function: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind Function"); - return nullptr; + return ErrorUnexpected; case Decl::CXXDeductionGuide: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind CXXDeductionGuide"); - return nullptr; + return ErrorUnexpected; case Decl::CXXMethod: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind CXXMethod"); - return nullptr; + return ErrorUnexpected; case Decl::CXXConstructor: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind CXXConstructor"); - return nullptr; + return ErrorUnexpected; case Decl::CXXConversion: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind CXXConversion"); - return nullptr; + return ErrorUnexpected; case Decl::CXXDestructor: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind CXXDestructor"); - return nullptr; + return ErrorUnexpected; case Decl::MSProperty: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind MSProperty"); - return nullptr; + return ErrorUnexpected; case Decl::NonTypeTemplateParm: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind NonTypeTemplateParm"); - return nullptr; + return ErrorUnexpected; case Decl::Decomposition: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind Decomposition"); - return nullptr; + return ErrorUnexpected; case Decl::ImplicitParam: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind ImplicitParam"); - return nullptr; + return ErrorUnexpected; case Decl::OMPCapturedExpr: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind OMPCapturedExpr"); - return nullptr; + return ErrorUnexpected; case Decl::ParmVar: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind ParmVar"); - return nullptr; + return ErrorUnexpected; case Decl::VarTemplateSpecialization: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind 
VarTemplateSpecialization"); - return nullptr; + return ErrorUnexpected; case Decl::VarTemplatePartialSpecialization: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind VarTemplatePartialSpecialization"); - return nullptr; + return ErrorUnexpected; case Decl::EnumConstant: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind EnumConstant"); - return nullptr; + return ErrorUnexpected; case Decl::IndirectField: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind IndirectField"); - return nullptr; + return ErrorUnexpected; case Decl::OMPDeclareReduction: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind OMPDeclareReduction"); - return nullptr; + return ErrorUnexpected; case Decl::UnresolvedUsingValue: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind UnresolvedUsingValue"); - return nullptr; + return ErrorUnexpected; case Decl::OMPThreadPrivate: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind OMPThreadPrivate"); - return nullptr; + return ErrorUnexpected; case Decl::ObjCPropertyImpl: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind ObjCPropertyImpl"); - return nullptr; + return ErrorUnexpected; case Decl::PragmaComment: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind PragmaComment"); - return nullptr; + return ErrorUnexpected; case Decl::PragmaDetectMismatch: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind PragmaDetectMismatch"); - return nullptr; + return ErrorUnexpected; case Decl::StaticAssert: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind StaticAssert"); - return nullptr; + return ErrorUnexpected; case Decl::TranslationUnit: emit_warning(c, stmt->getLocStart(), "TODO handle decl kind TranslationUnit"); - return nullptr; + return ErrorUnexpected; } zig_unreachable(); } - // declarations were already added - return skip_add_to_block_node; + *out_scope = scope; + return ErrorNone; } static AstNode *trans_while_loop(Context *c, TransScope *scope, const WhileStmt 
*stmt) { @@ -2068,8 +2087,8 @@ static AstNode *trans_while_loop(Context *c, TransScope *scope, const WhileStmt if (while_node->data.while_expr.condition == nullptr) return nullptr; - while_node->data.while_expr.body = trans_stmt(c, ResultUsedNo, scope, stmt->getBody(), TransRValue); - if (while_node->data.while_expr.body == nullptr) + TransScope *body_scope = trans_stmt(c, scope, stmt->getBody(), &while_node->data.while_expr.body); + if (body_scope == nullptr) return nullptr; return while_node; @@ -2086,13 +2105,13 @@ static AstNode *trans_if_statement(Context *c, TransScope *scope, const IfStmt * return nullptr; if_node->data.if_bool_expr.condition = condition_node; - if_node->data.if_bool_expr.then_block = trans_stmt(c, ResultUsedNo, scope, stmt->getThen(), TransRValue); - if (if_node->data.if_bool_expr.then_block == nullptr) + TransScope *then_scope = trans_stmt(c, scope, stmt->getThen(), &if_node->data.if_bool_expr.then_block); + if (then_scope == nullptr) return nullptr; if (stmt->getElse() != nullptr) { - if_node->data.if_bool_expr.else_node = trans_stmt(c, ResultUsedNo, scope, stmt->getElse(), TransRValue); - if (if_node->data.if_bool_expr.else_node == nullptr) + TransScope *else_scope = trans_stmt(c, scope, stmt->getElse(), &if_node->data.if_bool_expr.else_node); + if (else_scope == nullptr) return nullptr; } @@ -2201,10 +2220,14 @@ static AstNode *trans_do_loop(Context *c, TransScope *parent_scope, const DoStmt // zig: b; // zig: if (!cond) break; // zig: } - body_node = trans_stmt(c, ResultUsedNo, parent_scope, stmt->getBody(), TransRValue); - if (body_node == nullptr) return nullptr; + + // We call the low level function so that we can set child_scope to the scope of the generated block. 
+ if (trans_stmt_extra(c, parent_scope, stmt->getBody(), ResultUsedNo, TransRValue, &body_node, + nullptr, &child_scope)) + { + return nullptr; + } assert(body_node->type == NodeTypeBlock); - child_scope = c->child_scope; } else { // the C statement is without a block, so we need to create a block to contain it. // c: do @@ -2216,10 +2239,10 @@ static AstNode *trans_do_loop(Context *c, TransScope *parent_scope, const DoStmt // zig: } TransScopeBlock *child_block_scope = trans_scope_block_create(c, parent_scope); body_node = child_block_scope->node; - child_scope = &child_block_scope->base; - AstNode *child_statement = trans_stmt(c, ResultUsedNo, child_scope, stmt->getBody(), TransRValue); - if (child_statement == nullptr) return nullptr; - child_block_scope->node->data.block.statements.append(child_statement); + AstNode *child_statement; + child_scope = trans_stmt(c, &child_block_scope->base, stmt->getBody(), &child_statement); + if (child_scope == nullptr) return nullptr; + body_node->data.block.statements.append(child_statement); } // if (!cond) break; @@ -2229,10 +2252,7 @@ static AstNode *trans_do_loop(Context *c, TransScope *parent_scope, const DoStmt terminator_node->data.if_bool_expr.condition = trans_create_node_prefix_op(c, PrefixOpBoolNot, condition_node); terminator_node->data.if_bool_expr.then_block = trans_create_node(c, NodeTypeBreak); - assert(child_scope->id == TransScopeIdBlock); - TransScopeBlock *child_block_scope = (TransScopeBlock *)child_scope; - - child_block_scope->node->data.block.statements.append(terminator_node); + body_node->data.block.statements.append(terminator_node); while_node->data.while_expr.body = body_node; @@ -2244,10 +2264,10 @@ static AstNode *trans_do_loop(Context *c, TransScope *parent_scope, const DoStmt // AstNode *switch_node = trans_create_node(c, NodeTypeSwitchExpr); // const DeclStmt *var_decl_stmt = stmt->getConditionVariableDeclStmt(); // if (var_decl_stmt != nullptr) { -// AstNode *vars_node = trans_stmt(c, 
ResultUsedNo, switch_block_node, var_decl_stmt, TransRValue); +// AstNode *vars_node = trans_stmt(c, switch_block_node, var_decl_stmt); // if (vars_node == nullptr) // return nullptr; -// if (vars_node != skip_add_to_block_node) +// if (vars_node != nullptr) // switch_block_node->data.block.statements.append(vars_node); // } // switch_block_node->data.block.statements.append(switch_node); @@ -2260,10 +2280,10 @@ static AstNode *trans_do_loop(Context *c, TransScope *parent_scope, const DoStmt // return nullptr; // switch_node->data.switch_expr.expr = expr_node; // -// AstNode *body_node = trans_stmt(c, ResultUsedNo, switch_block_node, stmt->getBody(), TransRValue); +// AstNode *body_node = trans_stmt(c, switch_block_node, stmt->getBody()); // if (body_node == nullptr) // return nullptr; -// if (body_node != skip_add_to_block_node) +// if (body_node != nullptr) // switch_block_node->data.block.statements.append(body_node); // // return switch_block_node; @@ -2271,21 +2291,22 @@ static AstNode *trans_do_loop(Context *c, TransScope *parent_scope, const DoStmt static AstNode *trans_for_loop(Context *c, TransScope *parent_scope, const ForStmt *stmt) { AstNode *loop_block_node; - TransScope *condition_scope; + TransScope *inner_scope; AstNode *while_node = trans_create_node(c, NodeTypeWhileExpr); const Stmt *init_stmt = stmt->getInit(); if (init_stmt == nullptr) { loop_block_node = while_node; - condition_scope = parent_scope; + inner_scope = parent_scope; } else { TransScopeBlock *child_scope = trans_scope_block_create(c, parent_scope); loop_block_node = child_scope->node; - condition_scope = &child_scope->base; + inner_scope = &child_scope->base; - AstNode *vars_node = trans_stmt(c, ResultUsedNo, &child_scope->base, init_stmt, TransRValue); - if (vars_node == nullptr) + AstNode *vars_node; + inner_scope = trans_stmt(c, &child_scope->base, init_stmt, &vars_node); + if (inner_scope == nullptr) return nullptr; - if (vars_node != skip_add_to_block_node) + if (vars_node != 
nullptr) child_scope->node->data.block.statements.append(vars_node); child_scope->node->data.block.statements.append(while_node); @@ -2297,21 +2318,23 @@ static AstNode *trans_for_loop(Context *c, TransScope *parent_scope, const ForSt true_node->data.bool_literal.value = true; while_node->data.while_expr.condition = true_node; } else { - while_node->data.while_expr.condition = trans_stmt(c, ResultUsedNo, condition_scope, cond_stmt, TransRValue); - if (while_node->data.while_expr.condition == nullptr) + TransScope *cond_scope = trans_stmt(c, inner_scope, cond_stmt, &while_node->data.while_expr.condition); + if (cond_scope == nullptr) return nullptr; } const Stmt *inc_stmt = stmt->getInc(); if (inc_stmt != nullptr) { - AstNode *inc_node = trans_stmt(c, ResultUsedNo, condition_scope, inc_stmt, TransRValue); - if (inc_node == nullptr) + AstNode *inc_node; + TransScope *inc_scope = trans_stmt(c, inner_scope, inc_stmt, &inc_node); + if (inc_scope == nullptr) return nullptr; while_node->data.while_expr.continue_expr = inc_node; } - AstNode *child_statement = trans_stmt(c, ResultUsedNo, condition_scope, stmt->getBody(), TransRValue); - if (child_statement == nullptr) + AstNode *child_statement; + TransScope *body_scope = trans_stmt(c, inner_scope, stmt->getBody(), &child_statement); + if (body_scope == nullptr) return nullptr; while_node->data.while_expr.body = child_statement; @@ -2344,578 +2367,636 @@ static AstNode *trans_continue_stmt(Context *c, TransScope *scope, const Continu return trans_create_node(c, NodeTypeContinue); } -static AstNode *trans_stmt(Context *c, ResultUsed result_used, TransScope *scope, const Stmt *stmt, TransLRValue lrvalue) { - c->child_scope = scope; // TODO refactor out +static int wrap_stmt(AstNode **out_node, TransScope **out_scope, TransScope *in_scope, AstNode *result_node) { + if (result_node == nullptr) + return ErrorUnexpected; + *out_node = result_node; + if (out_scope != nullptr) { + *out_scope = in_scope; + } + return ErrorNone; +} + 
+static int trans_stmt_extra(Context *c, TransScope *scope, const Stmt *stmt, + ResultUsed result_used, TransLRValue lrvalue, + AstNode **out_node, TransScope **out_child_scope, + TransScope **out_node_scope) +{ Stmt::StmtClass sc = stmt->getStmtClass(); switch (sc) { case Stmt::ReturnStmtClass: - return trans_return_stmt(c, scope, (const ReturnStmt *)stmt); + return wrap_stmt(out_node, out_child_scope, scope, + trans_return_stmt(c, scope, (const ReturnStmt *)stmt)); case Stmt::CompoundStmtClass: - return trans_compound_stmt(c, scope, (const CompoundStmt *)stmt); + return wrap_stmt(out_node, out_child_scope, scope, + trans_compound_stmt(c, scope, (const CompoundStmt *)stmt, out_node_scope)); case Stmt::IntegerLiteralClass: - return trans_integer_literal(c, (const IntegerLiteral *)stmt); + return wrap_stmt(out_node, out_child_scope, scope, + trans_integer_literal(c, (const IntegerLiteral *)stmt)); case Stmt::ConditionalOperatorClass: - return trans_conditional_operator(c, result_used, scope, (const ConditionalOperator *)stmt); + return wrap_stmt(out_node, out_child_scope, scope, + trans_conditional_operator(c, result_used, scope, (const ConditionalOperator *)stmt)); case Stmt::BinaryOperatorClass: - return trans_binary_operator(c, result_used, scope, (const BinaryOperator *)stmt); + return wrap_stmt(out_node, out_child_scope, scope, + trans_binary_operator(c, result_used, scope, (const BinaryOperator *)stmt)); case Stmt::CompoundAssignOperatorClass: - return trans_compound_assign_operator(c, result_used, scope, (const CompoundAssignOperator *)stmt); + return wrap_stmt(out_node, out_child_scope, scope, + trans_compound_assign_operator(c, result_used, scope, (const CompoundAssignOperator *)stmt)); case Stmt::ImplicitCastExprClass: - return trans_implicit_cast_expr(c, scope, (const ImplicitCastExpr *)stmt); + return wrap_stmt(out_node, out_child_scope, scope, + trans_implicit_cast_expr(c, scope, (const ImplicitCastExpr *)stmt)); case Stmt::DeclRefExprClass: - return 
trans_decl_ref_expr(c, (const DeclRefExpr *)stmt, lrvalue); + return wrap_stmt(out_node, out_child_scope, scope, + trans_decl_ref_expr(c, (const DeclRefExpr *)stmt, lrvalue)); case Stmt::UnaryOperatorClass: - return trans_unary_operator(c, result_used, scope, (const UnaryOperator *)stmt); + return wrap_stmt(out_node, out_child_scope, scope, + trans_unary_operator(c, result_used, scope, (const UnaryOperator *)stmt)); case Stmt::DeclStmtClass: - return trans_local_declaration(c, scope, (const DeclStmt *)stmt); + return trans_local_declaration(c, scope, (const DeclStmt *)stmt, out_node, out_child_scope); case Stmt::WhileStmtClass: - return trans_while_loop(c, scope, (const WhileStmt *)stmt); + return wrap_stmt(out_node, out_child_scope, scope, + trans_while_loop(c, scope, (const WhileStmt *)stmt)); case Stmt::IfStmtClass: - return trans_if_statement(c, scope, (const IfStmt *)stmt); + return wrap_stmt(out_node, out_child_scope, scope, + trans_if_statement(c, scope, (const IfStmt *)stmt)); case Stmt::CallExprClass: - return trans_call_expr(c, result_used, scope, (const CallExpr *)stmt); + return wrap_stmt(out_node, out_child_scope, scope, + trans_call_expr(c, result_used, scope, (const CallExpr *)stmt)); case Stmt::NullStmtClass: - return skip_add_to_block_node; + *out_node = nullptr; + *out_child_scope = scope; + return ErrorNone; case Stmt::MemberExprClass: - return trans_member_expr(c, scope, (const MemberExpr *)stmt); + return wrap_stmt(out_node, out_child_scope, scope, + trans_member_expr(c, scope, (const MemberExpr *)stmt)); case Stmt::ArraySubscriptExprClass: - return trans_array_subscript_expr(c, scope, (const ArraySubscriptExpr *)stmt); + return wrap_stmt(out_node, out_child_scope, scope, + trans_array_subscript_expr(c, scope, (const ArraySubscriptExpr *)stmt)); case Stmt::CStyleCastExprClass: - return trans_c_style_cast_expr(c, result_used, scope, (const CStyleCastExpr *)stmt, lrvalue); + return wrap_stmt(out_node, out_child_scope, scope, + 
trans_c_style_cast_expr(c, result_used, scope, (const CStyleCastExpr *)stmt, lrvalue)); case Stmt::UnaryExprOrTypeTraitExprClass: - return trans_unary_expr_or_type_trait_expr(c, scope, (const UnaryExprOrTypeTraitExpr *)stmt); + return wrap_stmt(out_node, out_child_scope, scope, + trans_unary_expr_or_type_trait_expr(c, scope, (const UnaryExprOrTypeTraitExpr *)stmt)); case Stmt::DoStmtClass: - return trans_do_loop(c, scope, (const DoStmt *)stmt); + return wrap_stmt(out_node, out_child_scope, scope, + trans_do_loop(c, scope, (const DoStmt *)stmt)); case Stmt::ForStmtClass: - return trans_for_loop(c, scope, (const ForStmt *)stmt); + return wrap_stmt(out_node, out_child_scope, scope, + trans_for_loop(c, scope, (const ForStmt *)stmt)); case Stmt::StringLiteralClass: - return trans_string_literal(c, scope, (const StringLiteral *)stmt); + return wrap_stmt(out_node, out_child_scope, scope, + trans_string_literal(c, scope, (const StringLiteral *)stmt)); case Stmt::BreakStmtClass: - return trans_break_stmt(c, scope, (const BreakStmt *)stmt); + return wrap_stmt(out_node, out_child_scope, scope, + trans_break_stmt(c, scope, (const BreakStmt *)stmt)); case Stmt::ContinueStmtClass: - return trans_continue_stmt(c, scope, (const ContinueStmt *)stmt); + return wrap_stmt(out_node, out_child_scope, scope, + trans_continue_stmt(c, scope, (const ContinueStmt *)stmt)); + case Stmt::ParenExprClass: + return wrap_stmt(out_node, out_child_scope, scope, + trans_expr(c, result_used, scope, ((const ParenExpr*)stmt)->getSubExpr(), lrvalue)); // case Stmt::SwitchStmtClass: -// return trans_switch_stmt(c, scope, (const SwitchStmt *)stmt); +// return wrap_stmt(out_node, out_child_scope, scope, +// trans_switch_stmt(c, scope, (const SwitchStmt *)stmt)); case Stmt::SwitchStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C SwitchStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CaseStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CaseStmtClass"); - return 
nullptr; + return ErrorUnexpected; case Stmt::DefaultStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C DefaultStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::NoStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C NoStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::GCCAsmStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C GCCAsmStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::MSAsmStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C MSAsmStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::AttributedStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C AttributedStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXCatchStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXCatchStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXForRangeStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXForRangeStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXTryStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXTryStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CapturedStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CapturedStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CoreturnStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CoreturnStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CoroutineBodyStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CoroutineBodyStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::BinaryConditionalOperatorClass: emit_warning(c, stmt->getLocStart(), "TODO handle C BinaryConditionalOperatorClass"); - return nullptr; + return ErrorUnexpected; case Stmt::AddrLabelExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C AddrLabelExprClass"); - return nullptr; + return ErrorUnexpected; case 
Stmt::ArrayInitIndexExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ArrayInitIndexExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ArrayInitLoopExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ArrayInitLoopExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ArrayTypeTraitExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ArrayTypeTraitExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::AsTypeExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C AsTypeExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::AtomicExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C AtomicExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::BlockExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C BlockExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXBindTemporaryExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXBindTemporaryExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXBoolLiteralExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXBoolLiteralExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXConstructExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXConstructExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXTemporaryObjectExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXTemporaryObjectExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXDefaultArgExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXDefaultArgExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXDefaultInitExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXDefaultInitExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXDeleteExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXDeleteExprClass"); - return nullptr; + return 
ErrorUnexpected; case Stmt::CXXDependentScopeMemberExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXDependentScopeMemberExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXFoldExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXFoldExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXInheritedCtorInitExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXInheritedCtorInitExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXNewExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXNewExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXNoexceptExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXNoexceptExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXNullPtrLiteralExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXNullPtrLiteralExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXPseudoDestructorExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXPseudoDestructorExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXScalarValueInitExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXScalarValueInitExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXStdInitializerListExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXStdInitializerListExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXThisExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXThisExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXThrowExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXThrowExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXTypeidExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXTypeidExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXUnresolvedConstructExprClass: emit_warning(c, 
stmt->getLocStart(), "TODO handle C CXXUnresolvedConstructExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXUuidofExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXUuidofExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CUDAKernelCallExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CUDAKernelCallExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXMemberCallExprClass: - (void)result_used; emit_warning(c, stmt->getLocStart(), "TODO handle C CXXMemberCallExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXOperatorCallExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXOperatorCallExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::UserDefinedLiteralClass: emit_warning(c, stmt->getLocStart(), "TODO handle C UserDefinedLiteralClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXFunctionalCastExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXFunctionalCastExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXConstCastExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXConstCastExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXDynamicCastExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXDynamicCastExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXReinterpretCastExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXReinterpretCastExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CXXStaticCastExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CXXStaticCastExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCBridgedCastExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCBridgedCastExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CharacterLiteralClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CharacterLiteralClass"); - 
return nullptr; + return ErrorUnexpected; case Stmt::ChooseExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ChooseExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CompoundLiteralExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CompoundLiteralExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ConvertVectorExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ConvertVectorExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CoawaitExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CoawaitExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::CoyieldExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C CoyieldExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::DependentCoawaitExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C DependentCoawaitExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::DependentScopeDeclRefExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C DependentScopeDeclRefExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::DesignatedInitExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C DesignatedInitExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::DesignatedInitUpdateExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C DesignatedInitUpdateExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ExprWithCleanupsClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ExprWithCleanupsClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ExpressionTraitExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ExpressionTraitExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ExtVectorElementExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ExtVectorElementExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::FloatingLiteralClass: emit_warning(c, 
stmt->getLocStart(), "TODO handle C FloatingLiteralClass"); - return nullptr; + return ErrorUnexpected; case Stmt::FunctionParmPackExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C FunctionParmPackExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::GNUNullExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C GNUNullExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::GenericSelectionExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C GenericSelectionExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ImaginaryLiteralClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ImaginaryLiteralClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ImplicitValueInitExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ImplicitValueInitExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::InitListExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C InitListExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::LambdaExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C LambdaExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::MSPropertyRefExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C MSPropertyRefExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::MSPropertySubscriptExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C MSPropertySubscriptExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::MaterializeTemporaryExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C MaterializeTemporaryExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::NoInitExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C NoInitExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPArraySectionExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPArraySectionExprClass"); - return nullptr; + return ErrorUnexpected; case 
Stmt::ObjCArrayLiteralClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCArrayLiteralClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCAvailabilityCheckExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCAvailabilityCheckExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCBoolLiteralExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCBoolLiteralExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCBoxedExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCBoxedExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCDictionaryLiteralClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCDictionaryLiteralClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCEncodeExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCEncodeExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCIndirectCopyRestoreExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCIndirectCopyRestoreExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCIsaExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCIsaExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCIvarRefExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCIvarRefExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCMessageExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCMessageExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCPropertyRefExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCPropertyRefExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCProtocolExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCProtocolExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCSelectorExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C 
ObjCSelectorExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCStringLiteralClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCStringLiteralClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCSubscriptRefExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCSubscriptRefExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OffsetOfExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OffsetOfExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OpaqueValueExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OpaqueValueExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::UnresolvedLookupExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C UnresolvedLookupExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::UnresolvedMemberExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C UnresolvedMemberExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::PackExpansionExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C PackExpansionExprClass"); - return nullptr; - case Stmt::ParenExprClass: - return trans_expr(c, result_used, scope, ((const ParenExpr*)stmt)->getSubExpr(), lrvalue); + return ErrorUnexpected; case Stmt::ParenListExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ParenListExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::PredefinedExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C PredefinedExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::PseudoObjectExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C PseudoObjectExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ShuffleVectorExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ShuffleVectorExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::SizeOfPackExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C 
SizeOfPackExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::StmtExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C StmtExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::SubstNonTypeTemplateParmExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C SubstNonTypeTemplateParmExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::SubstNonTypeTemplateParmPackExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C SubstNonTypeTemplateParmPackExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::TypeTraitExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C TypeTraitExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::TypoExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C TypoExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::VAArgExprClass: emit_warning(c, stmt->getLocStart(), "TODO handle C VAArgExprClass"); - return nullptr; + return ErrorUnexpected; case Stmt::GotoStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C GotoStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::IndirectGotoStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C IndirectGotoStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::LabelStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C LabelStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::MSDependentExistsStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C MSDependentExistsStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPAtomicDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPAtomicDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPBarrierDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPBarrierDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPCancelDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO 
handle C OMPCancelDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPCancellationPointDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPCancellationPointDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPCriticalDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPCriticalDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPFlushDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPFlushDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPDistributeDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPDistributeDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPDistributeParallelForDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPDistributeParallelForDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPDistributeParallelForSimdDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPDistributeParallelForSimdDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPDistributeSimdDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPDistributeSimdDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPForDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPForDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPForSimdDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPForSimdDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPParallelForDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPParallelForDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPParallelForSimdDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPParallelForSimdDirectiveClass"); - return nullptr; + return ErrorUnexpected; case 
Stmt::OMPSimdDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPSimdDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTargetParallelForSimdDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTargetParallelForSimdDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTargetSimdDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTargetSimdDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTargetTeamsDistributeDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTargetTeamsDistributeDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTargetTeamsDistributeParallelForDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTargetTeamsDistributeParallelForDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTargetTeamsDistributeParallelForSimdDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTargetTeamsDistributeParallelForSimdDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTargetTeamsDistributeSimdDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTargetTeamsDistributeSimdDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTaskLoopDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTaskLoopDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTaskLoopSimdDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTaskLoopSimdDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTeamsDistributeDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTeamsDistributeDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTeamsDistributeParallelForDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTeamsDistributeParallelForDirectiveClass"); - return 
nullptr; + return ErrorUnexpected; case Stmt::OMPTeamsDistributeParallelForSimdDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTeamsDistributeParallelForSimdDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTeamsDistributeSimdDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTeamsDistributeSimdDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPMasterDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPMasterDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPOrderedDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPOrderedDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPParallelDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPParallelDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPParallelSectionsDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPParallelSectionsDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPSectionDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPSectionDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPSectionsDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPSectionsDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPSingleDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPSingleDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTargetDataDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTargetDataDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTargetDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTargetDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTargetEnterDataDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C 
OMPTargetEnterDataDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTargetExitDataDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTargetExitDataDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTargetParallelDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTargetParallelDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTargetParallelForDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTargetParallelForDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTargetTeamsDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTargetTeamsDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTargetUpdateDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTargetUpdateDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTaskDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTaskDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTaskgroupDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTaskgroupDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTaskwaitDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTaskwaitDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTaskyieldDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTaskyieldDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::OMPTeamsDirectiveClass: emit_warning(c, stmt->getLocStart(), "TODO handle C OMPTeamsDirectiveClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCAtCatchStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCAtCatchStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCAtFinallyStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C 
ObjCAtFinallyStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCAtSynchronizedStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCAtSynchronizedStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCAtThrowStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCAtThrowStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCAtTryStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCAtTryStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCAutoreleasePoolStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCAutoreleasePoolStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::ObjCForCollectionStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C ObjCForCollectionStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::SEHExceptStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C SEHExceptStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::SEHFinallyStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C SEHFinallyStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::SEHLeaveStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C SEHLeaveStmtClass"); - return nullptr; + return ErrorUnexpected; case Stmt::SEHTryStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C SEHTryStmtClass"); - return nullptr; + return ErrorUnexpected; } zig_unreachable(); } +// Returns null if there was an error +static AstNode *trans_expr(Context *c, ResultUsed result_used, TransScope *scope, const Expr *expr, + TransLRValue lrval) +{ + AstNode *result_node; + if (trans_stmt_extra(c, scope, expr, result_used, lrval, &result_node, nullptr, nullptr)) { + return nullptr; + } + return result_node; +} + +// Statements have no result and no concept of L or R value. 
+// Returns child scope, or null if there was an error +static TransScope *trans_stmt(Context *c, TransScope *scope, const Stmt *stmt, AstNode **out_node) { + TransScope *child_scope; + if (trans_stmt_extra(c, scope, stmt, ResultUsedNo, TransRValue, out_node, &child_scope, nullptr)) { + return nullptr; + } + return child_scope; +} + static void visit_fn_decl(Context *c, const FunctionDecl *fn_decl) { Buf *fn_name = buf_create_from_str(decl_name(fn_decl)); @@ -2946,7 +3027,8 @@ static void visit_fn_decl(Context *c, const FunctionDecl *fn_decl) { return; } - TransScope *scope = nullptr; + TransScopeRoot *root_scope = trans_scope_root_create(c); + TransScope *scope = &root_scope->base; for (size_t i = 0; i < proto_node->data.fn_proto.params.length; i += 1) { AstNode *param_node = proto_node->data.fn_proto.params.at(i); @@ -2978,16 +3060,17 @@ static void visit_fn_decl(Context *c, const FunctionDecl *fn_decl) { // actual function definition with body c->ptr_params.clear(); Stmt *body = fn_decl->getBody(); - AstNode *actual_body_node = trans_stmt(c, ResultUsedNo, scope, body, TransRValue); - assert(actual_body_node != skip_add_to_block_node); - if (actual_body_node == nullptr) { + AstNode *actual_body_node; + TransScope *result_scope = trans_stmt(c, scope, body, &actual_body_node); + if (result_scope == nullptr) { emit_warning(c, fn_decl->getLocation(), "unable to translate function"); return; } + assert(actual_body_node != nullptr); + assert(actual_body_node->type == NodeTypeBlock); // it worked - assert(actual_body_node->type == NodeTypeBlock); AstNode *body_node_with_param_inits = trans_create_node(c, NodeTypeBlock); for (size_t i = 0; i < proto_node->data.fn_proto.params.length; i += 1) { @@ -3447,6 +3530,12 @@ static Buf *get_unique_name(Context *c, Buf *name, TransScope *scope) { return proposed_name; } +static TransScopeRoot *trans_scope_root_create(Context *c) { + TransScopeRoot *result = allocate(1); + result->base.id = TransScopeIdRoot; + return result; +} + 
static TransScopeBlock *trans_scope_block_create(Context *c, TransScope *parent_scope) { TransScopeBlock *result = allocate(1); result->base.id = TransScopeIdBlock; @@ -3472,6 +3561,16 @@ static TransScopeVar *trans_scope_var_create(Context *c, TransScope *parent_scop // return result; //} +static TransScopeBlock *trans_scope_block_find(TransScope *scope) { + while (scope != nullptr) { + if (scope->id == TransScopeIdBlock) { + return (TransScopeBlock *)scope; + } + scope = scope->parent; + } + return nullptr; +} + static void render_aliases(Context *c) { for (size_t i = 0; i < c->aliases.length; i += 1) { Alias *alias = &c->aliases.at(i); From 1b0e90f70b4dc26c2ba96b7b5709a3ff269bb48a Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 26 Nov 2017 00:58:11 -0500 Subject: [PATCH 24/34] translate-c supports switch statements --- src/translate_c.cpp | 327 +++++++++++++++++++++++++++++++------------ test/translate_c.zig | 80 +++++------ 2 files changed, 280 insertions(+), 127 deletions(-) diff --git a/src/translate_c.cpp b/src/translate_c.cpp index 9e6676063d..27066a08b6 100644 --- a/src/translate_c.cpp +++ b/src/translate_c.cpp @@ -69,6 +69,7 @@ enum TransScopeId { TransScopeIdVar, TransScopeIdBlock, TransScopeIdRoot, + TransScopeIdWhile, }; struct TransScope { @@ -79,6 +80,8 @@ struct TransScope { struct TransScopeSwitch { TransScope base; AstNode *switch_node; + uint32_t case_index; + bool found_default; }; struct TransScopeVar { @@ -96,12 +99,19 @@ struct TransScopeRoot { TransScope base; }; +struct TransScopeWhile { + TransScope base; + AstNode *node; +}; + static TransScopeRoot *trans_scope_root_create(Context *c); +static TransScopeWhile *trans_scope_while_create(Context *c, TransScope *parent_scope); static TransScopeBlock *trans_scope_block_create(Context *c, TransScope *parent_scope); static TransScopeVar *trans_scope_var_create(Context *c, TransScope *parent_scope, Buf *wanted_name); -//static TransScopeSwitch *trans_scope_switch_create(Context *c, 
TransScope *parent_scope); +static TransScopeSwitch *trans_scope_switch_create(Context *c, TransScope *parent_scope); static TransScopeBlock *trans_scope_block_find(TransScope *scope); +static TransScopeSwitch *trans_scope_switch_find(TransScope *scope); static AstNode *resolve_record_decl(Context *c, const RecordDecl *record_decl); static AstNode *resolve_enum_decl(Context *c, const EnumDecl *enum_decl); @@ -238,6 +248,12 @@ static AstNode *trans_create_node_addr_of(Context *c, bool is_const, bool is_vol return node; } +static AstNode *trans_create_node_bool(Context *c, bool value) { + AstNode *bool_node = trans_create_node(c, NodeTypeBoolLiteral); + bool_node->data.bool_literal.value = value; + return bool_node; +} + static AstNode *trans_create_node_str_lit_c(Context *c, Buf *buf) { AstNode *node = trans_create_node(c, NodeTypeStringLiteral); node->data.string_literal.buf = buf; @@ -965,22 +981,30 @@ static AstNode *trans_qual_type(Context *c, QualType qt, const SourceLocation &s return trans_type(c, qt.getTypePtr(), source_loc); } -static AstNode *trans_compound_stmt(Context *c, TransScope *scope, const CompoundStmt *stmt, - TransScope **out_node_scope) +static int trans_compound_stmt_inline(Context *c, TransScope *scope, const CompoundStmt *stmt, + AstNode *block_node, TransScope **out_node_scope) { - TransScopeBlock *child_scope_block = trans_scope_block_create(c, scope); - scope = &child_scope_block->base; + assert(block_node->type == NodeTypeBlock); for (CompoundStmt::const_body_iterator it = stmt->body_begin(), end_it = stmt->body_end(); it != end_it; ++it) { AstNode *child_node; scope = trans_stmt(c, scope, *it, &child_node); if (scope == nullptr) - return nullptr; + return ErrorUnexpected; if (child_node != nullptr) - child_scope_block->node->data.block.statements.append(child_node); + block_node->data.block.statements.append(child_node); } if (out_node_scope != nullptr) { - *out_node_scope = &child_scope_block->base; + *out_node_scope = scope; } + 
return ErrorNone; +} + +static AstNode *trans_compound_stmt(Context *c, TransScope *scope, const CompoundStmt *stmt, + TransScope **out_node_scope) +{ + TransScopeBlock *child_scope_block = trans_scope_block_create(c, scope); + if (trans_compound_stmt_inline(c, &child_scope_block->base, stmt, child_scope_block->node, out_node_scope)) + return nullptr; return child_scope_block->node; } @@ -2081,17 +2105,18 @@ static int trans_local_declaration(Context *c, TransScope *scope, const DeclStmt } static AstNode *trans_while_loop(Context *c, TransScope *scope, const WhileStmt *stmt) { - AstNode *while_node = trans_create_node(c, NodeTypeWhileExpr); + TransScopeWhile *while_scope = trans_scope_while_create(c, scope); - while_node->data.while_expr.condition = trans_expr(c, ResultUsedYes, scope, stmt->getCond(), TransRValue); - if (while_node->data.while_expr.condition == nullptr) + while_scope->node->data.while_expr.condition = trans_expr(c, ResultUsedYes, scope, stmt->getCond(), TransRValue); + if (while_scope->node->data.while_expr.condition == nullptr) return nullptr; - TransScope *body_scope = trans_stmt(c, scope, stmt->getBody(), &while_node->data.while_expr.body); + TransScope *body_scope = trans_stmt(c, &while_scope->base, stmt->getBody(), + &while_scope->node->data.while_expr.body); if (body_scope == nullptr) return nullptr; - return while_node; + return while_scope->node; } static AstNode *trans_if_statement(Context *c, TransScope *scope, const IfStmt *stmt) { @@ -2201,11 +2226,9 @@ static AstNode *trans_unary_expr_or_type_trait_expr(Context *c, TransScope *scop } static AstNode *trans_do_loop(Context *c, TransScope *parent_scope, const DoStmt *stmt) { - AstNode *while_node = trans_create_node(c, NodeTypeWhileExpr); + TransScopeWhile *while_scope = trans_scope_while_create(c, parent_scope); - AstNode *true_node = trans_create_node(c, NodeTypeBoolLiteral); - true_node->data.bool_literal.value = true; - while_node->data.while_expr.condition = true_node; + 
while_scope->node->data.while_expr.condition = trans_create_node_bool(c, true); AstNode *body_node; TransScope *child_scope; @@ -2222,7 +2245,7 @@ static AstNode *trans_do_loop(Context *c, TransScope *parent_scope, const DoStmt // zig: } // We call the low level function so that we can set child_scope to the scope of the generated block. - if (trans_stmt_extra(c, parent_scope, stmt->getBody(), ResultUsedNo, TransRValue, &body_node, + if (trans_stmt_extra(c, &while_scope->base, stmt->getBody(), ResultUsedNo, TransRValue, &body_node, nullptr, &child_scope)) { return nullptr; @@ -2237,7 +2260,7 @@ static AstNode *trans_do_loop(Context *c, TransScope *parent_scope, const DoStmt // zig: a; // zig: if (!cond) break; // zig: } - TransScopeBlock *child_block_scope = trans_scope_block_create(c, parent_scope); + TransScopeBlock *child_block_scope = trans_scope_block_create(c, &while_scope->base); body_node = child_block_scope->node; AstNode *child_statement; child_scope = trans_stmt(c, &child_block_scope->base, stmt->getBody(), &child_statement); @@ -2254,89 +2277,206 @@ static AstNode *trans_do_loop(Context *c, TransScope *parent_scope, const DoStmt body_node->data.block.statements.append(terminator_node); - while_node->data.while_expr.body = body_node; + while_scope->node->data.while_expr.body = body_node; - return while_node; + return while_scope->node; } -//static AstNode *trans_switch_stmt(Context *c, TransScope *scope, const SwitchStmt *stmt) { -// AstNode *switch_block_node = trans_create_node(c, NodeTypeBlock); -// AstNode *switch_node = trans_create_node(c, NodeTypeSwitchExpr); -// const DeclStmt *var_decl_stmt = stmt->getConditionVariableDeclStmt(); -// if (var_decl_stmt != nullptr) { -// AstNode *vars_node = trans_stmt(c, switch_block_node, var_decl_stmt); -// if (vars_node == nullptr) -// return nullptr; -// if (vars_node != nullptr) -// switch_block_node->data.block.statements.append(vars_node); -// } -// 
switch_block_node->data.block.statements.append(switch_node); -// -// const Expr *cond_expr = stmt->getCond(); -// assert(cond_expr != nullptr); -// -// AstNode *expr_node = trans_expr(c, ResultUsedYes, switch_block_node, cond_expr, TransRValue); -// if (expr_node == nullptr) -// return nullptr; -// switch_node->data.switch_expr.expr = expr_node; -// -// AstNode *body_node = trans_stmt(c, switch_block_node, stmt->getBody()); -// if (body_node == nullptr) -// return nullptr; -// if (body_node != nullptr) -// switch_block_node->data.block.statements.append(body_node); -// -// return switch_block_node; -//} +static AstNode *trans_switch_stmt(Context *c, TransScope *parent_scope, const SwitchStmt *stmt) { + TransScopeWhile *while_scope = trans_scope_while_create(c, parent_scope); + while_scope->node->data.while_expr.condition = trans_create_node_bool(c, true); + + TransScopeBlock *block_scope = trans_scope_block_create(c, &while_scope->base); + while_scope->node->data.while_expr.body = block_scope->node; + + TransScopeSwitch *switch_scope; + + const DeclStmt *var_decl_stmt = stmt->getConditionVariableDeclStmt(); + if (var_decl_stmt == nullptr) { + switch_scope = trans_scope_switch_create(c, &block_scope->base); + } else { + AstNode *vars_node; + TransScope *var_scope = trans_stmt(c, &block_scope->base, var_decl_stmt, &vars_node); + if (var_scope == nullptr) + return nullptr; + if (vars_node != nullptr) + block_scope->node->data.block.statements.append(vars_node); + switch_scope = trans_scope_switch_create(c, var_scope); + } + block_scope->node->data.block.statements.append(switch_scope->switch_node); + + const Expr *cond_expr = stmt->getCond(); + assert(cond_expr != nullptr); + + AstNode *expr_node = trans_expr(c, ResultUsedYes, &block_scope->base, cond_expr, TransRValue); + if (expr_node == nullptr) + return nullptr; + switch_scope->switch_node->data.switch_expr.expr = expr_node; + + AstNode *body_node; + const Stmt *body_stmt = stmt->getBody(); + if 
(body_stmt->getStmtClass() == Stmt::CompoundStmtClass) { + if (trans_compound_stmt_inline(c, &switch_scope->base, (const CompoundStmt *)body_stmt, + block_scope->node, nullptr)) + { + return nullptr; + } + } else { + TransScope *body_scope = trans_stmt(c, &switch_scope->base, body_stmt, &body_node); + if (body_scope == nullptr) + return nullptr; + if (body_node != nullptr) + block_scope->node->data.block.statements.append(body_node); + } + + if (!switch_scope->found_default && !stmt->isAllEnumCasesCovered()) { + AstNode *prong_node = trans_create_node(c, NodeTypeSwitchProng); + prong_node->data.switch_prong.expr = trans_create_node(c, NodeTypeBreak); + switch_scope->switch_node->data.switch_expr.prongs.append(prong_node); + } + + // This is necessary if the last switch case "falls through" the end of the switch block + block_scope->node->data.block.statements.append(trans_create_node(c, NodeTypeBreak)); + + return while_scope->node; +} + +static int trans_switch_case(Context *c, TransScope *parent_scope, const CaseStmt *stmt, AstNode **out_node, + TransScope **out_scope) +{ + *out_node = nullptr; + + if (stmt->getRHS() != nullptr) { + emit_warning(c, stmt->getLocStart(), "TODO support GNU switch case a ... 
b extension"); + return ErrorUnexpected; + } + + TransScopeSwitch *switch_scope = trans_scope_switch_find(parent_scope); + assert(switch_scope != nullptr); + + Buf *label_name = buf_sprintf("case_%" PRIu32, switch_scope->case_index); + switch_scope->case_index += 1; + + { + // Add the prong + AstNode *prong_node = trans_create_node(c, NodeTypeSwitchProng); + AstNode *item_node = trans_expr(c, ResultUsedYes, &switch_scope->base, stmt->getLHS(), TransRValue); + if (item_node == nullptr) + return ErrorUnexpected; + prong_node->data.switch_prong.items.append(item_node); + + AstNode *goto_node = trans_create_node(c, NodeTypeGoto); + goto_node->data.goto_expr.name = label_name; + prong_node->data.switch_prong.expr = goto_node; + + switch_scope->switch_node->data.switch_expr.prongs.append(prong_node); + } + + AstNode *label_node = trans_create_node(c, NodeTypeLabel); + label_node->data.label.name = label_name; + + TransScopeBlock *scope_block = trans_scope_block_find(parent_scope); + scope_block->node->data.block.statements.append(label_node); + + AstNode *sub_stmt_node; + TransScope *new_scope = trans_stmt(c, parent_scope, stmt->getSubStmt(), &sub_stmt_node); + if (new_scope == nullptr) + return ErrorUnexpected; + if (sub_stmt_node != nullptr) + scope_block->node->data.block.statements.append(sub_stmt_node); + + *out_scope = new_scope; + return ErrorNone; +} + +static int trans_switch_default(Context *c, TransScope *parent_scope, const DefaultStmt *stmt, AstNode **out_node, + TransScope **out_scope) +{ + *out_node = nullptr; + + TransScopeSwitch *switch_scope = trans_scope_switch_find(parent_scope); + assert(switch_scope != nullptr); + + Buf *label_name = buf_sprintf("default"); + + AstNode *label_node = trans_create_node(c, NodeTypeLabel); + label_node->data.label.name = label_name; + + { + // Add the prong + AstNode *prong_node = trans_create_node(c, NodeTypeSwitchProng); + + AstNode *goto_node = trans_create_node(c, NodeTypeGoto); + goto_node->data.goto_expr.name = 
label_name; + prong_node->data.switch_prong.expr = goto_node; + + switch_scope->switch_node->data.switch_expr.prongs.append(prong_node); + switch_scope->found_default = true; + } + + TransScopeBlock *scope_block = trans_scope_block_find(parent_scope); + scope_block->node->data.block.statements.append(label_node); + + AstNode *sub_stmt_node; + TransScope *new_scope = trans_stmt(c, parent_scope, stmt->getSubStmt(), &sub_stmt_node); + if (new_scope == nullptr) + return ErrorUnexpected; + if (sub_stmt_node != nullptr) + scope_block->node->data.block.statements.append(sub_stmt_node); + + *out_scope = new_scope; + return ErrorNone; +} static AstNode *trans_for_loop(Context *c, TransScope *parent_scope, const ForStmt *stmt) { AstNode *loop_block_node; - TransScope *inner_scope; - AstNode *while_node = trans_create_node(c, NodeTypeWhileExpr); + TransScopeWhile *while_scope; + TransScope *cond_scope; const Stmt *init_stmt = stmt->getInit(); if (init_stmt == nullptr) { - loop_block_node = while_node; - inner_scope = parent_scope; + while_scope = trans_scope_while_create(c, parent_scope); + loop_block_node = while_scope->node; + cond_scope = parent_scope; } else { TransScopeBlock *child_scope = trans_scope_block_create(c, parent_scope); loop_block_node = child_scope->node; - inner_scope = &child_scope->base; AstNode *vars_node; - inner_scope = trans_stmt(c, &child_scope->base, init_stmt, &vars_node); - if (inner_scope == nullptr) + cond_scope = trans_stmt(c, &child_scope->base, init_stmt, &vars_node); + if (cond_scope == nullptr) return nullptr; if (vars_node != nullptr) child_scope->node->data.block.statements.append(vars_node); - child_scope->node->data.block.statements.append(while_node); + while_scope = trans_scope_while_create(c, cond_scope); + + child_scope->node->data.block.statements.append(while_scope->node); } const Stmt *cond_stmt = stmt->getCond(); if (cond_stmt == nullptr) { - AstNode *true_node = trans_create_node(c, NodeTypeBoolLiteral); - 
true_node->data.bool_literal.value = true; - while_node->data.while_expr.condition = true_node; + while_scope->node->data.while_expr.condition = trans_create_node_bool(c, true); } else { - TransScope *cond_scope = trans_stmt(c, inner_scope, cond_stmt, &while_node->data.while_expr.condition); - if (cond_scope == nullptr) + TransScope *end_cond_scope = trans_stmt(c, cond_scope, cond_stmt, + &while_scope->node->data.while_expr.condition); + if (end_cond_scope == nullptr) return nullptr; } const Stmt *inc_stmt = stmt->getInc(); if (inc_stmt != nullptr) { AstNode *inc_node; - TransScope *inc_scope = trans_stmt(c, inner_scope, inc_stmt, &inc_node); + TransScope *inc_scope = trans_stmt(c, cond_scope, inc_stmt, &inc_node); if (inc_scope == nullptr) return nullptr; - while_node->data.while_expr.continue_expr = inc_node; + while_scope->node->data.while_expr.continue_expr = inc_node; } - AstNode *child_statement; - TransScope *body_scope = trans_stmt(c, inner_scope, stmt->getBody(), &child_statement); + AstNode *body_statement; + TransScope *body_scope = trans_stmt(c, &while_scope->base, stmt->getBody(), &body_statement); if (body_scope == nullptr) return nullptr; - while_node->data.while_expr.body = child_statement; + while_scope->node->data.while_expr.body = body_statement; return loop_block_node; } @@ -2371,9 +2511,8 @@ static int wrap_stmt(AstNode **out_node, TransScope **out_scope, TransScope *in_ if (result_node == nullptr) return ErrorUnexpected; *out_node = result_node; - if (out_scope != nullptr) { + if (out_scope != nullptr) *out_scope = in_scope; - } return ErrorNone; } @@ -2456,18 +2595,13 @@ static int trans_stmt_extra(Context *c, TransScope *scope, const Stmt *stmt, case Stmt::ParenExprClass: return wrap_stmt(out_node, out_child_scope, scope, trans_expr(c, result_used, scope, ((const ParenExpr*)stmt)->getSubExpr(), lrvalue)); -// case Stmt::SwitchStmtClass: -// return wrap_stmt(out_node, out_child_scope, scope, -// trans_switch_stmt(c, scope, (const SwitchStmt 
*)stmt)); case Stmt::SwitchStmtClass: - emit_warning(c, stmt->getLocStart(), "TODO handle C SwitchStmtClass"); - return ErrorUnexpected; + return wrap_stmt(out_node, out_child_scope, scope, + trans_switch_stmt(c, scope, (const SwitchStmt *)stmt)); case Stmt::CaseStmtClass: - emit_warning(c, stmt->getLocStart(), "TODO handle C CaseStmtClass"); - return ErrorUnexpected; + return trans_switch_case(c, scope, (const CaseStmt *)stmt, out_node, out_child_scope); case Stmt::DefaultStmtClass: - emit_warning(c, stmt->getLocStart(), "TODO handle C DefaultStmtClass"); - return ErrorUnexpected; + return trans_switch_default(c, scope, (const DefaultStmt *)stmt, out_node, out_child_scope); case Stmt::NoStmtClass: emit_warning(c, stmt->getLocStart(), "TODO handle C NoStmtClass"); return ErrorUnexpected; @@ -2981,7 +3115,8 @@ static AstNode *trans_expr(Context *c, ResultUsed result_used, TransScope *scope TransLRValue lrval) { AstNode *result_node; - if (trans_stmt_extra(c, scope, expr, result_used, lrval, &result_node, nullptr, nullptr)) { + TransScope *result_scope; + if (trans_stmt_extra(c, scope, expr, result_used, lrval, &result_node, &result_scope, nullptr)) { return nullptr; } return result_node; @@ -3536,6 +3671,14 @@ static TransScopeRoot *trans_scope_root_create(Context *c) { return result; } +static TransScopeWhile *trans_scope_while_create(Context *c, TransScope *parent_scope) { + TransScopeWhile *result = allocate(1); + result->base.id = TransScopeIdWhile; + result->base.parent = parent_scope; + result->node = trans_create_node(c, NodeTypeWhileExpr); + return result; +} + static TransScopeBlock *trans_scope_block_create(Context *c, TransScope *parent_scope) { TransScopeBlock *result = allocate(1); result->base.id = TransScopeIdBlock; @@ -3553,13 +3696,13 @@ static TransScopeVar *trans_scope_var_create(Context *c, TransScope *parent_scop return result; } -//static TransScopeSwitch *trans_scope_switch_create(Context *c, TransScope *parent_scope) { -// TransScopeSwitch 
*result = allocate(1); -// result->base.id = TransScopeIdSwitch; -// result->base.parent = parent_scope; -// result->switch_node = trans_create_node(c, NodeTypeSwitchExpr); -// return result; -//} +static TransScopeSwitch *trans_scope_switch_create(Context *c, TransScope *parent_scope) { + TransScopeSwitch *result = allocate(1); + result->base.id = TransScopeIdSwitch; + result->base.parent = parent_scope; + result->switch_node = trans_create_node(c, NodeTypeSwitchExpr); + return result; +} static TransScopeBlock *trans_scope_block_find(TransScope *scope) { while (scope != nullptr) { @@ -3571,6 +3714,16 @@ static TransScopeBlock *trans_scope_block_find(TransScope *scope) { return nullptr; } +static TransScopeSwitch *trans_scope_switch_find(TransScope *scope) { + while (scope != nullptr) { + if (scope->id == TransScopeIdSwitch) { + return (TransScopeSwitch *)scope; + } + scope = scope->parent; + } + return nullptr; +} + static void render_aliases(Context *c) { for (size_t i = 0; i < c->aliases.length; i += 1) { Alias *alias = &c->aliases.at(i); diff --git a/test/translate_c.zig b/test/translate_c.zig index b957ac4b05..e9f5e7de42 100644 --- a/test/translate_c.zig +++ b/test/translate_c.zig @@ -1001,46 +1001,46 @@ pub fn addCases(cases: &tests.TranslateCContext) { \\} ); - //cases.add("switch statement", - // \\int foo(int x) { - // \\ switch (x) { - // \\ case 1: - // \\ x += 1; - // \\ case 2: - // \\ break; - // \\ case 3: - // \\ case 4: - // \\ return x + 1; - // \\ default: - // \\ return 10; - // \\ } - // \\ return x + 13; - // \\} - //, - // \\fn foo(_x: i32) -> i32 { - // \\ var x = _x; - // \\ switch (x) { - // \\ 1 => goto switch_case_1; - // \\ 2 => goto switch_case_2; - // \\ 3 => goto switch_case_3; - // \\ 4 => goto switch_case_4; - // \\ else => goto switch_default; - // \\ } - // \\switch_case_1: - // \\ x += 1; - // \\ goto switch_case_2; - // \\switch_case_2: - // \\ goto switch_end; - // \\switch_case_3: - // \\ goto switch_case_4; - // 
\\switch_case_4: - // \\ return x += 1; - // \\switch_default: - // \\ return 10; - // \\switch_end: - // \\ return x + 13; - // \\} - //); + cases.add("switch statement", + \\int foo(int x) { + \\ switch (x) { + \\ case 1: + \\ x += 1; + \\ case 2: + \\ break; + \\ case 3: + \\ case 4: + \\ return x + 1; + \\ default: + \\ return 10; + \\ } + \\ return x + 13; + \\} + , + \\fn foo(_arg_x: c_int) -> c_int { + \\ var x = _arg_x; + \\ while (true) { + \\ switch (x) { + \\ 1 => goto case_0, + \\ 2 => goto case_1, + \\ 3 => goto case_2, + \\ 4 => goto case_3, + \\ else => goto default, + \\ }; + \\ case_0: + \\ x += 1; + \\ case_1: + \\ break; + \\ case_2: + \\ case_3: + \\ return x + 1; + \\ default: + \\ return 10; + \\ break; + \\ }; + \\ return x + 13; + \\} + ); } From aa2ca3f02c11c133964d793847c925e3bd131b27 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 26 Nov 2017 15:58:49 -0500 Subject: [PATCH 25/34] translate-c: better way to translate switch previously `continue` would be handled incorrectly --- src/translate_c.cpp | 60 ++++++++++++++++++++++++++++---------------- test/translate_c.zig | 7 +++--- 2 files changed, 42 insertions(+), 25 deletions(-) diff --git a/src/translate_c.cpp b/src/translate_c.cpp index 27066a08b6..3c9ff6e264 100644 --- a/src/translate_c.cpp +++ b/src/translate_c.cpp @@ -82,6 +82,7 @@ struct TransScopeSwitch { AstNode *switch_node; uint32_t case_index; bool found_default; + Buf *end_label_name; }; struct TransScopeVar { @@ -248,6 +249,18 @@ static AstNode *trans_create_node_addr_of(Context *c, bool is_const, bool is_vol return node; } +static AstNode *trans_create_node_goto(Context *c, Buf *label_name) { + AstNode *goto_node = trans_create_node(c, NodeTypeGoto); + goto_node->data.goto_expr.name = label_name; + return goto_node; +} + +static AstNode *trans_create_node_label(Context *c, Buf *label_name) { + AstNode *label_node = trans_create_node(c, NodeTypeLabel); + label_node->data.label.name = label_name; + return label_node; +} 
+ static AstNode *trans_create_node_bool(Context *c, bool value) { AstNode *bool_node = trans_create_node(c, NodeTypeBoolLiteral); bool_node->data.bool_literal.value = value; @@ -2283,11 +2296,7 @@ static AstNode *trans_do_loop(Context *c, TransScope *parent_scope, const DoStmt } static AstNode *trans_switch_stmt(Context *c, TransScope *parent_scope, const SwitchStmt *stmt) { - TransScopeWhile *while_scope = trans_scope_while_create(c, parent_scope); - while_scope->node->data.while_expr.condition = trans_create_node_bool(c, true); - - TransScopeBlock *block_scope = trans_scope_block_create(c, &while_scope->base); - while_scope->node->data.while_expr.body = block_scope->node; + TransScopeBlock *block_scope = trans_scope_block_create(c, parent_scope); TransScopeSwitch *switch_scope; @@ -2305,6 +2314,10 @@ static AstNode *trans_switch_stmt(Context *c, TransScope *parent_scope, const Sw } block_scope->node->data.block.statements.append(switch_scope->switch_node); + // TODO avoid name collisions + Buf *end_label_name = buf_create_from_str("end"); + switch_scope->end_label_name = end_label_name; + const Expr *cond_expr = stmt->getCond(); assert(cond_expr != nullptr); @@ -2336,9 +2349,11 @@ static AstNode *trans_switch_stmt(Context *c, TransScope *parent_scope, const Sw } // This is necessary if the last switch case "falls through" the end of the switch block - block_scope->node->data.block.statements.append(trans_create_node(c, NodeTypeBreak)); + block_scope->node->data.block.statements.append(trans_create_node_goto(c, end_label_name)); - return while_scope->node; + block_scope->node->data.block.statements.append(trans_create_node_label(c, end_label_name)); + + return block_scope->node; } static int trans_switch_case(Context *c, TransScope *parent_scope, const CaseStmt *stmt, AstNode **out_node, @@ -2365,18 +2380,13 @@ static int trans_switch_case(Context *c, TransScope *parent_scope, const CaseStm return ErrorUnexpected; 
prong_node->data.switch_prong.items.append(item_node); - AstNode *goto_node = trans_create_node(c, NodeTypeGoto); - goto_node->data.goto_expr.name = label_name; - prong_node->data.switch_prong.expr = goto_node; + prong_node->data.switch_prong.expr = trans_create_node_goto(c, label_name); switch_scope->switch_node->data.switch_expr.prongs.append(prong_node); } - AstNode *label_node = trans_create_node(c, NodeTypeLabel); - label_node->data.label.name = label_name; - TransScopeBlock *scope_block = trans_scope_block_find(parent_scope); - scope_block->node->data.block.statements.append(label_node); + scope_block->node->data.block.statements.append(trans_create_node_label(c, label_name)); AstNode *sub_stmt_node; TransScope *new_scope = trans_stmt(c, parent_scope, stmt->getSubStmt(), &sub_stmt_node); @@ -2399,23 +2409,19 @@ static int trans_switch_default(Context *c, TransScope *parent_scope, const Defa Buf *label_name = buf_sprintf("default"); - AstNode *label_node = trans_create_node(c, NodeTypeLabel); - label_node->data.label.name = label_name; - { // Add the prong AstNode *prong_node = trans_create_node(c, NodeTypeSwitchProng); - AstNode *goto_node = trans_create_node(c, NodeTypeGoto); - goto_node->data.goto_expr.name = label_name; - prong_node->data.switch_prong.expr = goto_node; + prong_node->data.switch_prong.expr = trans_create_node_goto(c, label_name); switch_scope->switch_node->data.switch_expr.prongs.append(prong_node); switch_scope->found_default = true; } TransScopeBlock *scope_block = trans_scope_block_find(parent_scope); - scope_block->node->data.block.statements.append(label_node); + scope_block->node->data.block.statements.append(trans_create_node_label(c, label_name)); + AstNode *sub_stmt_node; TransScope *new_scope = trans_stmt(c, parent_scope, stmt->getSubStmt(), &sub_stmt_node); @@ -2500,7 +2506,17 @@ static AstNode *trans_string_literal(Context *c, TransScope *scope, const String } static AstNode *trans_break_stmt(Context *c, TransScope *scope, const 
BreakStmt *stmt) { - return trans_create_node(c, NodeTypeBreak); + TransScope *cur_scope = scope; + while (cur_scope != nullptr) { + if (cur_scope->id == TransScopeIdWhile) { + return trans_create_node(c, NodeTypeBreak); + } else if (cur_scope->id == TransScopeIdSwitch) { + TransScopeSwitch *switch_scope = (TransScopeSwitch *)cur_scope; + return trans_create_node_goto(c, switch_scope->end_label_name); + } + cur_scope = cur_scope->parent; + } + zig_unreachable(); } static AstNode *trans_continue_stmt(Context *c, TransScope *scope, const ContinueStmt *stmt) { diff --git a/test/translate_c.zig b/test/translate_c.zig index e9f5e7de42..4323a867f8 100644 --- a/test/translate_c.zig +++ b/test/translate_c.zig @@ -1019,7 +1019,7 @@ pub fn addCases(cases: &tests.TranslateCContext) { , \\fn foo(_arg_x: c_int) -> c_int { \\ var x = _arg_x; - \\ while (true) { + \\ { \\ switch (x) { \\ 1 => goto case_0, \\ 2 => goto case_1, @@ -1030,13 +1030,14 @@ pub fn addCases(cases: &tests.TranslateCContext) { \\ case_0: \\ x += 1; \\ case_1: - \\ break; + \\ goto end; \\ case_2: \\ case_3: \\ return x + 1; \\ default: \\ return 10; - \\ break; + \\ goto end; + \\ end: \\ }; \\ return x + 13; \\} From 9a8545d5903ff60de16f0ddab2b9a4c4c1a798a7 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 26 Nov 2017 16:03:56 -0500 Subject: [PATCH 26/34] translate-c: fix translation when no default switch case --- src/translate_c.cpp | 2 +- test/translate_c.zig | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/src/translate_c.cpp b/src/translate_c.cpp index 3c9ff6e264..d56673d39a 100644 --- a/src/translate_c.cpp +++ b/src/translate_c.cpp @@ -2344,7 +2344,7 @@ static AstNode *trans_switch_stmt(Context *c, TransScope *parent_scope, const Sw if (!switch_scope->found_default && !stmt->isAllEnumCasesCovered()) { AstNode *prong_node = trans_create_node(c, NodeTypeSwitchProng); - prong_node->data.switch_prong.expr = trans_create_node(c, 
NodeTypeBreak); + prong_node->data.switch_prong.expr = trans_create_node_goto(c, end_label_name); switch_scope->switch_node->data.switch_expr.prongs.append(prong_node); } diff --git a/test/translate_c.zig b/test/translate_c.zig index 4323a867f8..4d00307d21 100644 --- a/test/translate_c.zig +++ b/test/translate_c.zig @@ -1042,6 +1042,44 @@ pub fn addCases(cases: &tests.TranslateCContext) { \\ return x + 13; \\} ); + + cases.add("switch statement with no default", + \\int foo(int x) { + \\ switch (x) { + \\ case 1: + \\ x += 1; + \\ case 2: + \\ break; + \\ case 3: + \\ case 4: + \\ return x + 1; + \\ } + \\ return x + 13; + \\} + , + \\fn foo(_arg_x: c_int) -> c_int { + \\ var x = _arg_x; + \\ { + \\ switch (x) { + \\ 1 => goto case_0, + \\ 2 => goto case_1, + \\ 3 => goto case_2, + \\ 4 => goto case_3, + \\ else => goto end, + \\ }; + \\ case_0: + \\ x += 1; + \\ case_1: + \\ goto end; + \\ case_2: + \\ case_3: + \\ return x + 1; + \\ goto end; + \\ end: + \\ }; + \\ return x + 13; + \\} + ); } From 93fac5f257a80c2ca0abd30aedbeae300f6460f8 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 26 Nov 2017 17:30:43 -0500 Subject: [PATCH 27/34] translate-c: support variable name shadowing --- src/translate_c.cpp | 24 +++++++++++++++++++----- test/translate_c.zig | 20 ++++++++++++++++++++ 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/src/translate_c.cpp b/src/translate_c.cpp index d56673d39a..e3d001fe33 100644 --- a/src/translate_c.cpp +++ b/src/translate_c.cpp @@ -157,6 +157,19 @@ static void add_global_weak_alias(Context *c, Buf *new_name, Buf *canon_name) { alias->canon_name = canon_name; } +static Buf *trans_lookup_zig_symbol(Context *c, TransScope *scope, Buf *c_symbol_name) { + while (scope != nullptr) { + if (scope->id == TransScopeIdVar) { + TransScopeVar *var_scope = (TransScopeVar *)scope; + if (buf_eql_buf(var_scope->c_name, c_symbol_name)) { + return var_scope->zig_name; + } + } + scope = scope->parent; + } + return c_symbol_name; +} + 
static AstNode * trans_create_node(Context *c, NodeType id) { AstNode *node = allocate(1); node->type = id; @@ -1650,13 +1663,14 @@ static AstNode *trans_implicit_cast_expr(Context *c, TransScope *scope, const Im zig_unreachable(); } -static AstNode *trans_decl_ref_expr(Context *c, const DeclRefExpr *stmt, TransLRValue lrval) { +static AstNode *trans_decl_ref_expr(Context *c, TransScope *scope, const DeclRefExpr *stmt, TransLRValue lrval) { const ValueDecl *value_decl = stmt->getDecl(); - Buf *symbol_name = buf_create_from_str(decl_name(value_decl)); + Buf *c_symbol_name = buf_create_from_str(decl_name(value_decl)); + Buf *zig_symbol_name = trans_lookup_zig_symbol(c, scope, c_symbol_name); if (lrval == TransLValue) { - c->ptr_params.put(symbol_name, true); + c->ptr_params.put(zig_symbol_name, true); } - return trans_create_node_symbol(c, symbol_name); + return trans_create_node_symbol(c, zig_symbol_name); } static AstNode *trans_create_post_crement(Context *c, ResultUsed result_used, TransScope *scope, @@ -2562,7 +2576,7 @@ static int trans_stmt_extra(Context *c, TransScope *scope, const Stmt *stmt, trans_implicit_cast_expr(c, scope, (const ImplicitCastExpr *)stmt)); case Stmt::DeclRefExprClass: return wrap_stmt(out_node, out_child_scope, scope, - trans_decl_ref_expr(c, (const DeclRefExpr *)stmt, lrvalue)); + trans_decl_ref_expr(c, scope, (const DeclRefExpr *)stmt, lrvalue)); case Stmt::UnaryOperatorClass: return wrap_stmt(out_node, out_child_scope, scope, trans_unary_operator(c, result_used, scope, (const UnaryOperator *)stmt)); diff --git a/test/translate_c.zig b/test/translate_c.zig index 4d00307d21..527c5831a9 100644 --- a/test/translate_c.zig +++ b/test/translate_c.zig @@ -1080,6 +1080,26 @@ pub fn addCases(cases: &tests.TranslateCContext) { \\ return x + 13; \\} ); + + cases.add("variable name shadowing", + \\int foo(void) { + \\ int x = 1; + \\ { + \\ int x = 2; + \\ x += 1; + \\ } + \\ return x; + \\} + , + \\pub fn foo() -> c_int { + \\ var x: c_int = 1; + 
\\ { + \\ var x_0: c_int = 2; + \\ x_0 += 1; + \\ }; + \\ return x; + \\} + ); } From 671183fa9a0be28851002d07ad7ddf0d3bd29b46 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 26 Nov 2017 20:05:55 -0500 Subject: [PATCH 28/34] translate-c: support pointer casting also avoid some unnecessary casts --- src/translate_c.cpp | 118 +++++++++++++++++++++++++++++-------------- test/translate_c.zig | 33 ++++++------ 2 files changed, 97 insertions(+), 54 deletions(-) diff --git a/src/translate_c.cpp b/src/translate_c.cpp index e3d001fe33..b123fcb56e 100644 --- a/src/translate_c.cpp +++ b/src/translate_c.cpp @@ -440,6 +440,10 @@ static AstNode *trans_create_node_apint(Context *c, const llvm::APSInt &aps_int) } +static const Type *qual_type_canon(QualType qt) { + return qt.getCanonicalType().getTypePtr(); +} + static QualType get_expr_qual_type(Context *c, const Expr *expr) { // String literals in C are `char *` but they should really be `const char *`. if (expr->getStmtClass() == Stmt::ImplicitCastExprClass) { @@ -462,10 +466,7 @@ static AstNode *get_expr_type(Context *c, const Expr *expr) { return trans_qual_type(c, get_expr_qual_type(c, expr), expr->getLocStart()); } -static bool expr_types_equal(Context *c, const Expr *expr1, const Expr *expr2) { - QualType t1 = get_expr_qual_type(c, expr1); - QualType t2 = get_expr_qual_type(c, expr2); - +static bool qual_types_equal(QualType t1, QualType t2) { if (t1.isConstQualified() != t2.isConstQualified()) { return false; } @@ -482,26 +483,27 @@ static bool is_c_void_type(AstNode *node) { return (node->type == NodeTypeSymbol && buf_eql_str(node->data.symbol_expr.symbol, "c_void")); } -static AstNode* trans_c_cast(Context *c, const SourceLocation &source_location, const QualType &qt, AstNode *expr) { - // TODO: maybe widen to increase size - // TODO: maybe bitcast to change sign - // TODO: maybe truncate to reduce size - return trans_create_node_fn_call_1(c, trans_qual_type(c, qt, source_location), expr); +static bool 
expr_types_equal(Context *c, const Expr *expr1, const Expr *expr2) { + QualType t1 = get_expr_qual_type(c, expr1); + QualType t2 = get_expr_qual_type(c, expr2); + + return qual_types_equal(t1, t2); } -static bool qual_type_is_fn_ptr(Context *c, const QualType &qt) { - const Type *ty = qt.getTypePtr(); +static bool qual_type_is_ptr(QualType qt) { + const Type *ty = qual_type_canon(qt); + return ty->getTypeClass() == Type::Pointer; +} + +static bool qual_type_is_fn_ptr(Context *c, QualType qt) { + const Type *ty = qual_type_canon(qt); if (ty->getTypeClass() != Type::Pointer) { return false; } const PointerType *pointer_ty = static_cast(ty); QualType child_qt = pointer_ty->getPointeeType(); const Type *child_ty = child_qt.getTypePtr(); - if (child_ty->getTypeClass() != Type::Paren) { - return false; - } - const ParenType *paren_ty = static_cast(child_ty); - return paren_ty->getInnerType().getTypePtr()->getTypeClass() == Type::FunctionProto; + return child_ty->getTypeClass() == Type::FunctionProto; } static uint32_t qual_type_int_bit_width(Context *c, const QualType &qt, const SourceLocation &source_loc) { @@ -594,17 +596,26 @@ static bool qual_type_child_is_fn_proto(const QualType &qt) { return false; } -static QualType resolve_any_typedef(Context *c, QualType qt) { - const Type * ty = qt.getTypePtr(); - if (ty->getTypeClass() != Type::Typedef) - return qt; - const TypedefType *typedef_ty = static_cast(ty); - const TypedefNameDecl *typedef_decl = typedef_ty->getDecl(); - return typedef_decl->getUnderlyingType(); +static AstNode* trans_c_cast(Context *c, const SourceLocation &source_location, QualType dest_type, + QualType src_type, AstNode *expr) +{ + if (qual_types_equal(dest_type, src_type)) { + return expr; + } + if (qual_type_is_ptr(dest_type) && qual_type_is_ptr(src_type)) { + AstNode *ptr_cast_node = trans_create_node_builtin_fn_call_str(c, "ptrCast"); + ptr_cast_node->data.fn_call_expr.params.append(trans_qual_type(c, dest_type, source_location)); + 
ptr_cast_node->data.fn_call_expr.params.append(expr); + return ptr_cast_node; + } + // TODO: maybe widen to increase size + // TODO: maybe bitcast to change sign + // TODO: maybe truncate to reduce size + return trans_create_node_fn_call_1(c, trans_qual_type(c, dest_type, source_location), expr); } static bool c_is_signed_integer(Context *c, QualType qt) { - const Type *c_type = resolve_any_typedef(c, qt).getTypePtr(); + const Type *c_type = qual_type_canon(qt); if (c_type->getTypeClass() != Type::Builtin) return false; const BuiltinType *builtin_ty = static_cast(c_type); @@ -623,7 +634,7 @@ static bool c_is_signed_integer(Context *c, QualType qt) { } static bool c_is_unsigned_integer(Context *c, QualType qt) { - const Type *c_type = resolve_any_typedef(c, qt).getTypePtr(); + const Type *c_type = qual_type_canon(qt); if (c_type->getTypeClass() != Type::Builtin) return false; const BuiltinType *builtin_ty = static_cast(c_type); @@ -891,6 +902,11 @@ static AstNode *trans_type(Context *c, const Type *ty, const SourceLocation &sou return nullptr; } // convert c_void to actual void (only for return type) + // we do want to look at the AstNode instead of QualType, because + // if they do something like: + // typedef Foo void; + // void foo(void) -> Foo; + // we want to keep the return type AST node. if (is_c_void_type(proto_node->data.fn_proto.return_type)) { proto_node->data.fn_proto.return_type = nullptr; } @@ -1317,19 +1333,28 @@ static AstNode *trans_create_compound_assign_shift(Context *c, ResultUsed result if (rhs == nullptr) return nullptr; AstNode *coerced_rhs = trans_create_node_fn_call_1(c, rhs_type, rhs); + // operation_type(*_ref) + AstNode *operation_type_cast = trans_c_cast(c, rhs_location, + stmt->getComputationLHSType(), + stmt->getLHS()->getType(), + trans_create_node_prefix_op(c, PrefixOpDereference, + trans_create_node_symbol(c, tmp_var_name))); + + // result_type(... 
>> u5(rhs)) + AstNode *result_type_cast = trans_c_cast(c, rhs_location, + stmt->getComputationResultType(), + stmt->getComputationLHSType(), + trans_create_node_bin_op(c, + operation_type_cast, + bin_op, + coerced_rhs)); + + // *_ref = ... AstNode *assign_statement = trans_create_node_bin_op(c, trans_create_node_prefix_op(c, PrefixOpDereference, trans_create_node_symbol(c, tmp_var_name)), - BinOpTypeAssign, - trans_c_cast(c, rhs_location, - stmt->getComputationResultType(), - trans_create_node_bin_op(c, - trans_c_cast(c, rhs_location, - stmt->getComputationLHSType(), - trans_create_node_prefix_op(c, PrefixOpDereference, - trans_create_node_symbol(c, tmp_var_name))), - bin_op, - coerced_rhs))); + BinOpTypeAssign, result_type_cast); + child_scope->node->data.block.statements.append(assign_statement); if (result_used == ResultUsedYes) { @@ -1474,7 +1499,8 @@ static AstNode *trans_implicit_cast_expr(Context *c, TransScope *scope, const Im AstNode *target_node = trans_expr(c, ResultUsedYes, scope, stmt->getSubExpr(), TransRValue); if (target_node == nullptr) return nullptr; - return trans_c_cast(c, stmt->getExprLoc(), stmt->getType(), target_node); + return trans_c_cast(c, stmt->getExprLoc(), stmt->getType(), + stmt->getSubExpr()->getType(), target_node); } case CK_FunctionToPointerDecay: case CK_ArrayToPointerDecay: @@ -2177,9 +2203,23 @@ static AstNode *trans_call_expr(Context *c, ResultUsed result_used, TransScope * if (callee_raw_node == nullptr) return nullptr; - AstNode *callee_node; + AstNode *callee_node = nullptr; if (qual_type_is_fn_ptr(c, stmt->getCallee()->getType())) { - callee_node = trans_create_node_prefix_op(c, PrefixOpUnwrapMaybe, callee_raw_node); + if (stmt->getCallee()->getStmtClass() == Stmt::ImplicitCastExprClass) { + const ImplicitCastExpr *implicit_cast = static_cast(stmt->getCallee()); + if (implicit_cast->getCastKind() == CK_FunctionToPointerDecay) { + if (implicit_cast->getSubExpr()->getStmtClass() == Stmt::DeclRefExprClass) { + const 
DeclRefExpr *decl_ref = static_cast(implicit_cast->getSubExpr()); + const Decl *decl = decl_ref->getFoundDecl(); + if (decl->getKind() == Decl::Function) { + callee_node = callee_raw_node; + } + } + } + } + if (callee_node == nullptr) { + callee_node = trans_create_node_prefix_op(c, PrefixOpUnwrapMaybe, callee_raw_node); + } } else { callee_node = callee_raw_node; } @@ -2237,7 +2277,7 @@ static AstNode *trans_c_style_cast_expr(Context *c, ResultUsed result_used, Tran if (sub_expr_node == nullptr) return nullptr; - return trans_c_cast(c, stmt->getLocStart(), stmt->getType(), sub_expr_node); + return trans_c_cast(c, stmt->getLocStart(), stmt->getType(), stmt->getSubExpr()->getType(), sub_expr_node); } static AstNode *trans_unary_expr_or_type_trait_expr(Context *c, TransScope *scope, diff --git a/test/translate_c.zig b/test/translate_c.zig index 527c5831a9..7bae341a67 100644 --- a/test/translate_c.zig +++ b/test/translate_c.zig @@ -677,12 +677,12 @@ pub fn addCases(cases: &tests.TranslateCContext) { \\ }; \\ a >>= @import("std").math.Log2Int(c_int)({ \\ const _ref = &a; - \\ (*_ref) = c_int(c_int(*_ref) >> @import("std").math.Log2Int(c_int)(1)); + \\ (*_ref) = ((*_ref) >> @import("std").math.Log2Int(c_int)(1)); \\ *_ref \\ }); \\ a <<= @import("std").math.Log2Int(c_int)({ \\ const _ref = &a; - \\ (*_ref) = c_int(c_int(*_ref) << @import("std").math.Log2Int(c_int)(1)); + \\ (*_ref) = ((*_ref) << @import("std").math.Log2Int(c_int)(1)); \\ *_ref \\ }); \\} @@ -735,12 +735,12 @@ pub fn addCases(cases: &tests.TranslateCContext) { \\ }; \\ a >>= @import("std").math.Log2Int(c_uint)({ \\ const _ref = &a; - \\ (*_ref) = c_uint(c_uint(*_ref) >> @import("std").math.Log2Int(c_uint)(1)); + \\ (*_ref) = ((*_ref) >> @import("std").math.Log2Int(c_uint)(1)); \\ *_ref \\ }); \\ a <<= @import("std").math.Log2Int(c_uint)({ \\ const _ref = &a; - \\ (*_ref) = c_uint(c_uint(*_ref) << @import("std").math.Log2Int(c_uint)(1)); + \\ (*_ref) = ((*_ref) << @import("std").math.Log2Int(c_uint)(1)); 
\\ *_ref \\ }); \\} @@ -878,17 +878,21 @@ pub fn addCases(cases: &tests.TranslateCContext) { cases.addC("deref function pointer", \\void foo(void) {} + \\void baz(void) {} \\void bar(void) { \\ void(*f)(void) = foo; \\ f(); \\ (*(f))(); + \\ baz(); \\} , \\export fn foo() {} + \\export fn baz() {} \\export fn bar() { \\ var f: ?extern fn() = foo; \\ (??f)(); \\ (??f)(); + \\ baz(); \\} ); @@ -1100,15 +1104,14 @@ pub fn addCases(cases: &tests.TranslateCContext) { \\ return x; \\} ); + + cases.add("pointer casting", + \\float *ptrcast(int *a) { + \\ return (float *)a; + \\} + , + \\fn ptrcast(a: ?&c_int) -> ?&f32 { + \\ return @ptrCast(?&f32, a); + \\} + ); } - - - -// TODO -//float *ptrcast(int *a) { -// return (float *)a; -//} -// should translate to -// fn ptrcast(a: ?&c_int) -> ?&f32 { -// return @ptrCast(?&f32, a); -// } From 04472f57be5e91e82adf9346e71c1421725716d5 Mon Sep 17 00:00:00 2001 From: dimenus Date: Wed, 22 Nov 2017 10:01:43 -0600 Subject: [PATCH 29/34] Added support for exporting of C field expressions --- src/c_tokenizer.cpp | 7 ++++--- src/c_tokenizer.hpp | 1 + src/ir.cpp | 3 ++- src/translate_c.cpp | 30 +++++++++++++++++++++++++++--- test/translate_c.zig | 38 ++++++++++++++++++++++++++++++++++++-- 5 files changed, 70 insertions(+), 9 deletions(-) diff --git a/src/c_tokenizer.cpp b/src/c_tokenizer.cpp index 044831f72e..e5322e2b0f 100644 --- a/src/c_tokenizer.cpp +++ b/src/c_tokenizer.cpp @@ -216,9 +216,8 @@ void tokenize_c_macro(CTokenize *ctok, const uint8_t *c) { buf_append_char(&ctok->buf, '0'); break; case '.': - begin_token(ctok, CTokIdNumLitFloat); - ctok->state = CTokStateFloat; - buf_init_from_str(&ctok->buf, "0."); + begin_token(ctok, CTokIdDot); + end_token(ctok); break; case '(': begin_token(ctok, CTokIdLParen); @@ -238,6 +237,8 @@ void tokenize_c_macro(CTokenize *ctok, const uint8_t *c) { break; case CTokStateFloat: switch (*c) { + case '.': + break; case 'e': case 'E': buf_append_char(&ctok->buf, 'e'); diff --git a/src/c_tokenizer.hpp 
b/src/c_tokenizer.hpp index 8eea6c56c7..a3df2b94af 100644 --- a/src/c_tokenizer.hpp +++ b/src/c_tokenizer.hpp @@ -21,6 +21,7 @@ enum CTokId { CTokIdLParen, CTokIdRParen, CTokIdEOF, + CTokIdDot, }; enum CNumLitSuffix { diff --git a/src/ir.cpp b/src/ir.cpp index f632a261f6..7c15b48bee 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -6302,8 +6302,9 @@ static Buf *get_anon_type_name(CodeGen *codegen, IrExecutable *exec, const char buf_appendf(name, ")"); return name; } else { + //Note: C-imports do not have valid location information return buf_sprintf("(anonymous %s at %s:%" ZIG_PRI_usize ":%" ZIG_PRI_usize ")", kind_name, - buf_ptr(source_node->owner->path), source_node->line + 1, source_node->column + 1); + (source_node->owner->path != nullptr) ? buf_ptr(source_node->owner->path) : "(null)", source_node->line + 1, source_node->column + 1); } } } diff --git a/src/translate_c.cpp b/src/translate_c.cpp index 27066a08b6..326cfeb633 100644 --- a/src/translate_c.cpp +++ b/src/translate_c.cpp @@ -3442,7 +3442,6 @@ static AstNode *resolve_record_decl(Context *c, const RecordDecl *record_decl) { } const char *raw_name = decl_name(record_decl); - const char *container_kind_name; ContainerKind container_kind; if (record_decl->isUnion()) { @@ -3794,9 +3793,33 @@ static AstNode *parse_ctok(Context *c, CTokenize *ctok, size_t *tok_i) { return parse_ctok_num_lit(c, ctok, tok_i, false); case CTokIdSymbol: { - *tok_i += 1; + bool need_symbol = false; + CTokId curr_id = CTokIdSymbol; Buf *symbol_name = buf_create_from_buf(&tok->data.symbol); - return trans_create_node_symbol(c, symbol_name); + AstNode *curr_node = trans_create_node_symbol(c, symbol_name); + AstNode *parent_node = curr_node; + do { + *tok_i += 1; + CTok* curr_tok = &ctok->tokens.at(*tok_i); + if (need_symbol) { + if (curr_tok->id == CTokIdSymbol) { + symbol_name = buf_create_from_buf(&curr_tok->data.symbol); + curr_node = trans_create_node_field_access(c, parent_node, buf_create_from_buf(symbol_name)); + parent_node = 
curr_node; + need_symbol = false; + } else { + return nullptr; + } + } else { + if (curr_tok->id == CTokIdDot) { + need_symbol = true; + continue; + } else { + break; + } + } + } while (curr_id != CTokIdEOF); + return curr_node; } case CTokIdLParen: { @@ -3810,6 +3833,7 @@ static AstNode *parse_ctok(Context *c, CTokenize *ctok, size_t *tok_i) { *tok_i += 1; return inner_node; } + case CTokIdDot: case CTokIdEOF: case CTokIdRParen: // not able to make sense of this diff --git a/test/translate_c.zig b/test/translate_c.zig index e9f5e7de42..67c1d84310 100644 --- a/test/translate_c.zig +++ b/test/translate_c.zig @@ -1041,10 +1041,44 @@ pub fn addCases(cases: &tests.TranslateCContext) { \\ return x + 13; \\} ); + + cases.add("macros with field targets", + \\typedef unsigned int GLbitfield; + \\typedef void (*PFNGLCLEARPROC) (GLbitfield mask); + \\typedef void(*OpenGLProc)(void); + \\union OpenGLProcs { + \\ OpenGLProc ptr[1]; + \\ struct { + \\ PFNGLCLEARPROC Clear; + \\ } gl; + \\}; + \\extern union OpenGLProcs glProcs; + \\#define glClearUnion glProcs.gl.Clear + \\#define glClearPFN PFNGLCLEARPROC + , + \\pub const GLbitfield = c_uint; + , + \\pub const PFNGLCLEARPROC = ?extern fn(GLbitfield); + , + \\pub const OpenGLProc = ?extern fn(); + , + \\pub const union_OpenGLProcs = extern union { + \\ ptr: [1]OpenGLProc, + \\ gl: extern struct { + \\ Clear: PFNGLCLEARPROC, + \\ }, + \\}; + , + \\pub extern var glProcs: union_OpenGLProcs; + , + \\pub const glClearPFN = PFNGLCLEARPROC; + , + \\pub const glClearUnion = glProcs.gl.Clear; + , + \\pub const OpenGLProcs = union_OpenGLProcs; + ); } - - // TODO //float *ptrcast(int *a) { // return (float *)a; From 3e8fd245473a6e532fa4f9b42c51bd8257c15b3c Mon Sep 17 00:00:00 2001 From: Mason Remaley Date: Mon, 27 Nov 2017 21:00:05 -0500 Subject: [PATCH 30/34] Implements translation for the prefix not operator (#628) --- src/translate_c.cpp | 9 +++++++-- test/translate_c.zig | 10 ++++++++++ 2 files changed, 17 insertions(+), 2 
deletions(-) diff --git a/src/translate_c.cpp b/src/translate_c.cpp index b123fcb56e..23c3ff9821 100644 --- a/src/translate_c.cpp +++ b/src/translate_c.cpp @@ -1875,8 +1875,13 @@ static AstNode *trans_unary_operator(Context *c, ResultUsed result_used, TransSc } } case UO_Not: - emit_warning(c, stmt->getLocStart(), "TODO handle C translation UO_Not"); - return nullptr; + { + Expr *op_expr = stmt->getSubExpr(); + AstNode *sub_node = trans_expr(c, ResultUsedYes, scope, op_expr, TransRValue); + if (sub_node == nullptr) + return nullptr; + return trans_create_node_prefix_op(c, PrefixOpBinNot, sub_node); + } case UO_LNot: emit_warning(c, stmt->getLocStart(), "TODO handle C translation UO_LNot"); return nullptr; diff --git a/test/translate_c.zig b/test/translate_c.zig index 7bae341a67..198d813af1 100644 --- a/test/translate_c.zig +++ b/test/translate_c.zig @@ -1114,4 +1114,14 @@ pub fn addCases(cases: &tests.TranslateCContext) { \\ return @ptrCast(?&f32, a); \\} ); + + cases.add("bin not", + \\int foo(int x) { + \\ return ~x; + \\} + , + \\pub fn foo(x: c_int) -> c_int { + \\ return ~x; + \\} + ); } From 1ab84a27d374e666463c606dc1cd1c4972b52a74 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 28 Nov 2017 00:32:32 -0500 Subject: [PATCH 31/34] translate-c: fix sometimes getting (no file) warnings Thanks to Mason Remaley for testing the fix. 
--- src/translate_c.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/translate_c.cpp b/src/translate_c.cpp index 23c3ff9821..72ac7b3697 100644 --- a/src/translate_c.cpp +++ b/src/translate_c.cpp @@ -138,7 +138,7 @@ static void emit_warning(Context *c, const SourceLocation &sl, const char *forma Buf *msg = buf_vprintf(format, ap); va_end(ap); - StringRef filename = c->source_manager->getFilename(sl); + StringRef filename = c->source_manager->getFilename(c->source_manager->getSpellingLoc(sl)); const char *filename_bytes = (const char *)filename.bytes_begin(); Buf *path; if (filename_bytes) { From e745544dacc5bda010fc65e5c8b81cb3b5249223 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 28 Nov 2017 02:37:00 -0500 Subject: [PATCH 32/34] translate-c: detect macros referencing field lookup as fn calls which assert the fn ptr is non-null --- src/c_tokenizer.cpp | 1 + src/translate_c.cpp | 130 +++++++++++++++++++++++++++++-------------- test/translate_c.zig | 4 +- 3 files changed, 91 insertions(+), 44 deletions(-) diff --git a/src/c_tokenizer.cpp b/src/c_tokenizer.cpp index e5322e2b0f..6be2cf991e 100644 --- a/src/c_tokenizer.cpp +++ b/src/c_tokenizer.cpp @@ -120,6 +120,7 @@ static void begin_token(CTokenize *ctok, CTokId id) { case CTokIdLParen: case CTokIdRParen: case CTokIdEOF: + case CTokIdDot: break; } } diff --git a/src/translate_c.cpp b/src/translate_c.cpp index 16fe4b52d1..d50cc1d315 100644 --- a/src/translate_c.cpp +++ b/src/translate_c.cpp @@ -23,11 +23,6 @@ using namespace clang; -struct MacroSymbol { - Buf *name; - Buf *value; -}; - struct Alias { Buf *new_name; Buf *canon_name; @@ -44,7 +39,6 @@ struct Context { HashMap global_table; SourceManager *source_manager; ZigList aliases; - ZigList macro_symbols; AstNode *source_node; bool warnings_on; @@ -351,8 +345,7 @@ static AstNode *trans_create_node_var_decl_local(Context *c, bool is_const, Buf return trans_create_node_var_decl(c, VisibModPrivate, is_const, var_name, type_node, 
init_node); } - -static AstNode *trans_create_node_inline_fn(Context *c, Buf *fn_name, Buf *var_name, AstNode *src_proto_node) { +static AstNode *trans_create_node_inline_fn(Context *c, Buf *fn_name, AstNode *ref_node, AstNode *src_proto_node) { AstNode *fn_def = trans_create_node(c, NodeTypeFnDef); AstNode *fn_proto = trans_create_node(c, NodeTypeFnProto); fn_proto->data.fn_proto.visib_mod = c->visib_mod; @@ -363,7 +356,7 @@ static AstNode *trans_create_node_inline_fn(Context *c, Buf *fn_name, Buf *var_n fn_def->data.fn_def.fn_proto = fn_proto; fn_proto->data.fn_proto.fn_def_node = fn_def; - AstNode *unwrap_node = trans_create_node_prefix_op(c, PrefixOpUnwrapMaybe, trans_create_node_symbol(c, var_name)); + AstNode *unwrap_node = trans_create_node_prefix_op(c, PrefixOpUnwrapMaybe, ref_node); AstNode *fn_call_node = trans_create_node(c, NodeTypeFnCallExpr); fn_call_node->data.fn_call_expr.fn_ref_expr = unwrap_node; @@ -3808,6 +3801,83 @@ static void render_aliases(Context *c) { } } +static AstNode *trans_lookup_ast_container_typeof(Context *c, AstNode *ref_node); + +static AstNode *trans_lookup_ast_container(Context *c, AstNode *type_node) { + if (type_node == nullptr) { + return nullptr; + } else if (type_node->type == NodeTypeContainerDecl) { + return type_node; + } else if (type_node->type == NodeTypePrefixOpExpr) { + return type_node; + } else if (type_node->type == NodeTypeSymbol) { + AstNode *existing_node = get_global(c, type_node->data.symbol_expr.symbol); + if (existing_node == nullptr) + return nullptr; + if (existing_node->type != NodeTypeVariableDeclaration) + return nullptr; + return trans_lookup_ast_container(c, existing_node->data.variable_declaration.expr); + } else if (type_node->type == NodeTypeFieldAccessExpr) { + AstNode *container_node = trans_lookup_ast_container_typeof(c, type_node->data.field_access_expr.struct_expr); + if (container_node == nullptr) + return nullptr; + if (container_node->type != NodeTypeContainerDecl) + return 
container_node; + + for (size_t i = 0; i < container_node->data.container_decl.fields.length; i += 1) { + AstNode *field_node = container_node->data.container_decl.fields.items[i]; + if (buf_eql_buf(field_node->data.struct_field.name, type_node->data.field_access_expr.field_name)) { + return trans_lookup_ast_container(c, field_node->data.struct_field.type); + } + } + return nullptr; + } else { + return nullptr; + } +} + +static AstNode *trans_lookup_ast_container_typeof(Context *c, AstNode *ref_node) { + if (ref_node->type == NodeTypeSymbol) { + AstNode *existing_node = get_global(c, ref_node->data.symbol_expr.symbol); + if (existing_node == nullptr) + return nullptr; + if (existing_node->type != NodeTypeVariableDeclaration) + return nullptr; + return trans_lookup_ast_container(c, existing_node->data.variable_declaration.type); + } else if (ref_node->type == NodeTypeFieldAccessExpr) { + AstNode *container_node = trans_lookup_ast_container_typeof(c, ref_node->data.field_access_expr.struct_expr); + if (container_node == nullptr) + return nullptr; + if (container_node->type != NodeTypeContainerDecl) + return container_node; + for (size_t i = 0; i < container_node->data.container_decl.fields.length; i += 1) { + AstNode *field_node = container_node->data.container_decl.fields.items[i]; + if (buf_eql_buf(field_node->data.struct_field.name, ref_node->data.field_access_expr.field_name)) { + return trans_lookup_ast_container(c, field_node->data.struct_field.type); + } + } + return nullptr; + } else { + return nullptr; + } +} + +static AstNode *trans_lookup_ast_maybe_fn(Context *c, AstNode *ref_node) { + AstNode *prefix_node = trans_lookup_ast_container_typeof(c, ref_node); + if (prefix_node == nullptr) + return nullptr; + if (prefix_node->type != NodeTypePrefixOpExpr) + return nullptr; + if (prefix_node->data.prefix_op_expr.prefix_op != PrefixOpMaybe) + return nullptr; + + AstNode *fn_proto_node = prefix_node->data.prefix_op_expr.primary_expr; + if (fn_proto_node->type != 
NodeTypeFnProto) + return nullptr; + + return fn_proto_node; +} + static void render_macros(Context *c) { auto it = c->macro_table.entry_iterator(); for (;;) { @@ -3815,9 +3885,16 @@ static void render_macros(Context *c) { if (!entry) break; + AstNode *proto_node; AstNode *value_node = entry->value; if (value_node->type == NodeTypeFnDef) { add_top_level_decl(c, value_node->data.fn_def.fn_proto->data.fn_proto.name, value_node); + } else if ((proto_node = trans_lookup_ast_maybe_fn(c, value_node))) { + // If a macro aliases a global variable which is a function pointer, we conclude that + // the macro is intended to represent a function that assumes the function pointer + // variable is non-null and calls it. + AstNode *inline_fn_node = trans_create_node_inline_fn(c, entry->key, value_node, proto_node); + add_top_level_decl(c, entry->key, inline_fn_node); } else { add_global_var(c, entry->key, value_node); } @@ -3944,40 +4021,8 @@ static void process_macro(Context *c, CTokenize *ctok, Buf *name, const char *ch if (buf_eql_buf(name, symbol_name)) { return; } - c->macro_symbols.append({name, symbol_name}); - } else { - c->macro_table.put(name, result_node); - } -} - -static void process_symbol_macros(Context *c) { - for (size_t i = 0; i < c->macro_symbols.length; i += 1) { - MacroSymbol ms = c->macro_symbols.at(i); - - // Check if this macro aliases another top level declaration - AstNode *existing_node = get_global(c, ms.value); - if (!existing_node || name_exists_global(c, ms.name)) - continue; - - // If a macro aliases a global variable which is a function pointer, we conclude that - // the macro is intended to represent a function that assumes the function pointer - // variable is non-null and calls it. 
- if (existing_node->type == NodeTypeVariableDeclaration) { - AstNode *var_type = existing_node->data.variable_declaration.type; - if (var_type != nullptr && var_type->type == NodeTypePrefixOpExpr && - var_type->data.prefix_op_expr.prefix_op == PrefixOpMaybe) - { - AstNode *fn_proto_node = var_type->data.prefix_op_expr.primary_expr; - if (fn_proto_node->type == NodeTypeFnProto) { - AstNode *inline_fn_node = trans_create_node_inline_fn(c, ms.name, ms.value, fn_proto_node); - c->macro_table.put(ms.name, inline_fn_node); - continue; - } - } - } - - add_global_var(c, ms.name, trans_create_node_symbol(c, ms.value)); } + c->macro_table.put(name, result_node); } static void process_preprocessor_entities(Context *c, ASTUnit &unit) { @@ -4194,7 +4239,6 @@ int parse_h_file(ImportTableEntry *import, ZigList *errors, const ch process_preprocessor_entities(c, *ast_unit); - process_symbol_macros(c); render_macros(c); render_aliases(c); diff --git a/test/translate_c.zig b/test/translate_c.zig index 76751e46d7..2fe4944de6 100644 --- a/test/translate_c.zig +++ b/test/translate_c.zig @@ -1078,7 +1078,9 @@ pub fn addCases(cases: &tests.TranslateCContext) { , \\pub const glClearPFN = PFNGLCLEARPROC; , - \\pub const glClearUnion = glProcs.gl.Clear; + \\pub inline fn glClearUnion(arg0: GLbitfield) { + \\ (??glProcs.gl.Clear)(arg0) + \\} , \\pub const OpenGLProcs = union_OpenGLProcs; ); From 26096e79d1610218106faa75f7dbb10b4b5bbe5a Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 28 Nov 2017 03:17:28 -0500 Subject: [PATCH 33/34] translate-c: fix clobbering primitive types --- src/translate_c.cpp | 3 +++ test/translate_c.zig | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/src/translate_c.cpp b/src/translate_c.cpp index d50cc1d315..3295560af3 100644 --- a/src/translate_c.cpp +++ b/src/translate_c.cpp @@ -400,6 +400,9 @@ static AstNode *get_global(Context *c, Buf *name) { if (entry) return entry->value; } + if (c->codegen->primitive_type_table.maybe_get(name) != 
nullptr) { + return trans_create_node_symbol(c, name); + } return nullptr; } diff --git a/test/translate_c.zig b/test/translate_c.zig index 2fe4944de6..8a8d1d334b 100644 --- a/test/translate_c.zig +++ b/test/translate_c.zig @@ -1162,4 +1162,14 @@ pub fn addCases(cases: &tests.TranslateCContext) { \\ return ~x; \\} ); + + cases.add("primitive types included in defined symbols", + \\int foo(int u32) { + \\ return u32; + \\} + , + \\pub fn foo(u32_0: c_int) -> c_int { + \\ return u32_0; + \\} + ); } From 70662830044418fc2d637c166fc100fe72d60fcf Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 28 Nov 2017 23:44:45 -0500 Subject: [PATCH 34/34] translate-c: support const ptr initializer --- src/translate_c.cpp | 176 ++++++++++++++++++++++++++++++------------- test/translate_c.zig | 6 ++ 2 files changed, 129 insertions(+), 53 deletions(-) diff --git a/src/translate_c.cpp b/src/translate_c.cpp index 3295560af3..f5c1ef5810 100644 --- a/src/translate_c.cpp +++ b/src/translate_c.cpp @@ -28,36 +28,6 @@ struct Alias { Buf *canon_name; }; -struct Context { - ImportTableEntry *import; - ZigList *errors; - VisibMod visib_mod; - VisibMod export_visib_mod; - AstNode *root; - HashMap decl_table; - HashMap macro_table; - HashMap global_table; - SourceManager *source_manager; - ZigList aliases; - AstNode *source_node; - bool warnings_on; - - CodeGen *codegen; - ASTContext *ctx; - - HashMap ptr_params; -}; - -enum ResultUsed { - ResultUsedNo, - ResultUsedYes, -}; - -enum TransLRValue { - TransLValue, - TransRValue, -}; - enum TransScopeId { TransScopeIdSwitch, TransScopeIdVar, @@ -99,6 +69,37 @@ struct TransScopeWhile { AstNode *node; }; +struct Context { + ImportTableEntry *import; + ZigList *errors; + VisibMod visib_mod; + VisibMod export_visib_mod; + AstNode *root; + HashMap decl_table; + HashMap macro_table; + HashMap global_table; + SourceManager *source_manager; + ZigList aliases; + AstNode *source_node; + bool warnings_on; + + CodeGen *codegen; + ASTContext *ctx; + + 
TransScopeRoot *global_scope; + HashMap ptr_params; +}; + +enum ResultUsed { + ResultUsedNo, + ResultUsedYes, +}; + +enum TransLRValue { + TransLValue, + TransRValue, +}; + static TransScopeRoot *trans_scope_root_create(Context *c); static TransScopeWhile *trans_scope_while_create(Context *c, TransScope *parent_scope); static TransScopeBlock *trans_scope_block_create(Context *c, TransScope *parent_scope); @@ -3233,8 +3234,7 @@ static void visit_fn_decl(Context *c, const FunctionDecl *fn_decl) { return; } - TransScopeRoot *root_scope = trans_scope_root_create(c); - TransScope *scope = &root_scope->base; + TransScope *scope = &c->global_scope->base; for (size_t i = 0; i < proto_node->data.fn_proto.params.length; i += 1) { AstNode *param_node = proto_node->data.fn_proto.params.at(i); @@ -3600,6 +3600,93 @@ static AstNode *resolve_record_decl(Context *c, const RecordDecl *record_decl) { } } +static AstNode *trans_ap_value(Context *c, APValue *ap_value, QualType qt, const SourceLocation &source_loc) { /* Translates an evaluated initializer value (APValue) of C type `qt` into an AST node. Returns nullptr, after emitting a warning at `source_loc`, for every value kind that is not handled (and when a nested element or filler fails to translate). Recurses for array elements and the array filler. */ + switch (ap_value->getKind()) { + case APValue::Int: + return trans_create_node_apint(c, ap_value->getInt()); + case APValue::Uninitialized: + return trans_create_node(c, NodeTypeUndefinedLiteral); + case APValue::Array: { /* result shape: init_node ++ (filler_arr_1 ** leftover_count) */ + emit_warning(c, source_loc, "TODO add a test case for this code"); + + unsigned init_count = ap_value->getArrayInitializedElts(); + unsigned all_count = ap_value->getArraySize(); + unsigned leftover_count = all_count - init_count; + AstNode *init_node = trans_create_node(c, NodeTypeContainerInitExpr); + AstNode *arr_type_node = trans_qual_type(c, qt, source_loc); + init_node->data.container_init_expr.type = arr_type_node; + init_node->data.container_init_expr.kind = ContainerInitKindArray; + + QualType child_qt = qt.getTypePtr()->getLocallyUnqualifiedSingleStepDesugaredType(); /* NOTE(review): used as the element type for the recursive calls - confirm this yields the array's element type */ + + for (size_t i = 0; i < init_count; i += 1) { + APValue &elem_ap_val = ap_value->getArrayInitializedElt(i); + AstNode *elem_node = trans_ap_value(c, &elem_ap_val, child_qt, 
source_loc); + if (elem_node == nullptr) + return nullptr; + init_node->data.container_init_expr.entries.append(elem_node); + } + if (leftover_count == 0) { + return init_node; + } + + APValue &filler_ap_val = ap_value->getArrayFiller(); + AstNode *filler_node = trans_ap_value(c, &filler_ap_val, child_qt, source_loc); + if (filler_node == nullptr) + return nullptr; + + AstNode *filler_arr_1 = trans_create_node(c, NodeTypeContainerInitExpr); /* one-entry array literal holding the filler. The fields below must be set on filler_arr_1, not on init_node: otherwise the filler lands in the initialized-element list and this node stays empty. NOTE(review): arr_type_node is reused as in the original statements - confirm the intended type for this one-entry literal. */ + filler_arr_1->data.container_init_expr.type = arr_type_node; + filler_arr_1->data.container_init_expr.kind = ContainerInitKindArray; + filler_arr_1->data.container_init_expr.entries.append(filler_node); + + AstNode *rhs_node; + if (leftover_count == 1) { + rhs_node = filler_arr_1; + } else { + AstNode *amt_node = trans_create_node_unsigned(c, leftover_count); + rhs_node = trans_create_node_bin_op(c, filler_arr_1, BinOpTypeArrayMult, amt_node); + } + + return trans_create_node_bin_op(c, init_node, BinOpTypeArrayCat, rhs_node); + } + case APValue::LValue: { + const APValue::LValueBase lval_base = ap_value->getLValueBase(); + if (const Expr *expr = lval_base.dyn_cast()) { + return trans_expr(c, ResultUsedYes, &c->global_scope->base, expr, TransRValue); + } + //const ValueDecl *value_decl = lval_base.get(); + emit_warning(c, source_loc, "TODO handle initializer LValue ValueDecl"); + return nullptr; + } + case APValue::Float: + emit_warning(c, source_loc, "unsupported initializer value kind: Float"); + return nullptr; + case APValue::ComplexInt: + emit_warning(c, source_loc, "unsupported initializer value kind: ComplexInt"); + return nullptr; + case APValue::ComplexFloat: + emit_warning(c, source_loc, "unsupported initializer value kind: ComplexFloat"); + return nullptr; + case APValue::Vector: + emit_warning(c, source_loc, "unsupported initializer value kind: Vector"); + return nullptr; + case APValue::Struct: + emit_warning(c, source_loc, "unsupported initializer value kind: Struct"); + return nullptr; + case APValue::Union: + 
emit_warning(c, source_loc, "unsupported initializer value kind: Union"); + return nullptr; + case APValue::MemberPointer: + emit_warning(c, source_loc, "unsupported initializer value kind: MemberPointer"); + return nullptr; + case APValue::AddrLabelDiff: + emit_warning(c, source_loc, "unsupported initializer value kind: AddrLabelDiff"); + return nullptr; + } + zig_unreachable(); +} + static void visit_var_decl(Context *c, const VarDecl *var_decl) { Buf *name = buf_create_from_str(decl_name(var_decl)); @@ -3636,27 +3723,9 @@ static void visit_var_decl(Context *c, const VarDecl *var_decl) { "ignoring variable '%s' - unable to evaluate initializer", buf_ptr(name)); return; } - switch (ap_value->getKind()) { - case APValue::Int: - init_node = trans_create_node_apint(c, ap_value->getInt()); - break; - case APValue::Uninitialized: - init_node = trans_create_node(c, NodeTypeUndefinedLiteral); - break; - case APValue::Float: - case APValue::ComplexInt: - case APValue::ComplexFloat: - case APValue::LValue: - case APValue::Vector: - case APValue::Array: - case APValue::Struct: - case APValue::Union: - case APValue::MemberPointer: - case APValue::AddrLabelDiff: - emit_warning(c, var_decl->getLocation(), - "ignoring variable '%s' - unrecognized initializer value kind", buf_ptr(name)); - return; - } + init_node = trans_ap_value(c, ap_value, qt, var_decl->getLocation()); + if (init_node == nullptr) + return; } else { init_node = trans_create_node(c, NodeTypeUndefinedLiteral); } @@ -4101,6 +4170,7 @@ int parse_h_file(ImportTableEntry *import, ZigList *errors, const ch c->ptr_params.init(8); c->codegen = codegen; c->source_node = source_node; + c->global_scope = trans_scope_root_create(c); ZigList clang_argv = {0}; diff --git a/test/translate_c.zig b/test/translate_c.zig index 8a8d1d334b..90b99b5faf 100644 --- a/test/translate_c.zig +++ b/test/translate_c.zig @@ -1172,4 +1172,10 @@ pub fn addCases(cases: &tests.TranslateCContext) { \\ return u32_0; \\} ); + + cases.add("const ptr 
initializer", + \\static const char *v0 = "0.0.0"; + , + \\pub var v0: ?&const u8 = c"0.0.0"; + ); }