diff --git a/src/all_types.hpp b/src/all_types.hpp index c9f905118f..f13a5577d2 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -398,6 +398,7 @@ struct LazyValueErrUnionType { IrAnalyze *ira; IrInstruction *err_set_type; IrInstruction *payload_type; + Buf *type_name; }; struct ConstExprValue { @@ -2407,6 +2408,7 @@ enum IrInstructionId { IrInstructionIdPhi, IrInstructionIdUnOp, IrInstructionIdBinOp, + IrInstructionIdMergeErrSets, IrInstructionIdLoadPtr, IrInstructionIdLoadPtrGen, IrInstructionIdStorePtr, @@ -2713,7 +2715,6 @@ enum IrBinOp { IrBinOpRemMod, IrBinOpArrayCat, IrBinOpArrayMult, - IrBinOpMergeErrorSets, }; struct IrInstructionBinOp { @@ -2725,6 +2726,14 @@ struct IrInstructionBinOp { bool safety_check_on; }; +struct IrInstructionMergeErrSets { + IrInstruction base; + + IrInstruction *op1; + IrInstruction *op2; + Buf *type_name; +}; + struct IrInstructionLoadPtr { IrInstruction base; @@ -3633,6 +3642,7 @@ struct IrInstructionErrorUnion { IrInstruction *err_set; IrInstruction *payload; + Buf *type_name; }; struct IrInstructionAtomicRmw { diff --git a/src/codegen.cpp b/src/codegen.cpp index 3dbd0b8538..c7acdc992a 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -2776,7 +2776,6 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable, case IrBinOpArrayCat: case IrBinOpArrayMult: case IrBinOpRemUnspecified: - case IrBinOpMergeErrorSets: zig_unreachable(); case IrBinOpBoolOr: return LLVMBuildOr(g->builder, op1_value, op2_value, ""); @@ -6040,6 +6039,7 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, case IrInstructionIdAllocaGen: case IrInstructionIdAwaitSrc: case IrInstructionIdSplatSrc: + case IrInstructionIdMergeErrSets: zig_unreachable(); case IrInstructionIdDeclVarGen: diff --git a/src/ir.cpp b/src/ir.cpp index f29870e039..ad5ffcd951 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -198,6 +198,8 @@ static IrInstruction *ir_gen_union_init_expr(IrBuilder *irb, Scope *scope, AstNo IrInstruction *union_type, IrInstruction *field_name, AstNode *expr_node, LVal lval, ResultLoc *parent_result_loc); static void ir_reset_result(ResultLoc *result_loc); +static Buf *get_anon_type_name(CodeGen *codegen, IrExecutable *exec, const char *kind_name, + Scope *scope, AstNode *source_node, Buf *out_bare_name); static ConstExprValue *const_ptr_pointee_unchecked(CodeGen *g, ConstExprValue *const_val) { assert(get_src_ptr_type(const_val->type) != nullptr); @@ -469,6 +471,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionBinOp *) { return IrInstructionIdBinOp; } +static constexpr IrInstructionId ir_instruction_id(IrInstructionMergeErrSets *) { + return IrInstructionIdMergeErrSets; +} + static constexpr IrInstructionId ir_instruction_id(IrInstructionExport *) { return IrInstructionIdExport; } @@ -1290,6 +1296,20 @@ static IrInstruction *ir_build_bin_op(IrBuilder *irb, Scope *scope, AstNode *sou return &bin_op_instruction->base; } +static IrInstruction *ir_build_merge_err_sets(IrBuilder *irb, Scope *scope, AstNode *source_node, + IrInstruction *op1, IrInstruction *op2, Buf *type_name) +{ + IrInstructionMergeErrSets *merge_err_sets_instruction = ir_build_instruction(irb, scope, source_node); + merge_err_sets_instruction->op1 = op1; + merge_err_sets_instruction->op2 = op2; + merge_err_sets_instruction->type_name = type_name; + + ir_ref_instruction(op1, irb->current_basic_block); + ir_ref_instruction(op2, irb->current_basic_block); + + return &merge_err_sets_instruction->base; +} + static IrInstruction *ir_build_var_ptr_x(IrBuilder *irb, Scope *scope, AstNode *source_node, ZigVar *var, ScopeFnDef *crossed_fndef_scope) { @@ -3894,6 +3914,20 @@ static IrInstruction *ir_gen_bin_op_id(IrBuilder *irb, Scope *scope, AstNode *no return ir_build_bin_op(irb, scope, node, op_id, op1, op2, true); } +static IrInstruction *ir_gen_merge_err_sets(IrBuilder *irb, Scope *scope, AstNode *node) { + IrInstruction *op1 = ir_gen_node(irb, node->data.bin_op_expr.op1, scope); + IrInstruction *op2 = ir_gen_node(irb, node->data.bin_op_expr.op2, scope); + + if (op1 == irb->codegen->invalid_instruction || op2 == irb->codegen->invalid_instruction) + return irb->codegen->invalid_instruction; + + // TODO only pass type_name when the || operator is the top level AST node in the var decl expr + Buf bare_name = BUF_INIT; + Buf *type_name = get_anon_type_name(irb->codegen, irb->exec, "error", scope, node, &bare_name); + + return ir_build_merge_err_sets(irb, scope, node, op1, op2, type_name); +} + static IrInstruction *ir_gen_assign(IrBuilder *irb, Scope *scope, AstNode *node) { IrInstruction *lvalue = ir_gen_node_extra(irb, node->data.bin_op_expr.op1, scope, LValPtr, nullptr); if (lvalue == irb->codegen->invalid_instruction) @@ -3913,6 +3947,19 @@ static IrInstruction *ir_gen_assign(IrBuilder *irb, Scope *scope, AstNode *node) return ir_build_const_void(irb, scope, node); } +static IrInstruction *ir_gen_assign_merge_err_sets(IrBuilder *irb, Scope *scope, AstNode *node) { + IrInstruction *lvalue = ir_gen_node_extra(irb, node->data.bin_op_expr.op1, scope, LValPtr, nullptr); + if (lvalue == irb->codegen->invalid_instruction) + return lvalue; + IrInstruction *op1 = ir_build_load_ptr(irb, scope, node->data.bin_op_expr.op1, lvalue); + IrInstruction *op2 = ir_gen_node(irb, node->data.bin_op_expr.op2, scope); + if (op2 == irb->codegen->invalid_instruction) + return op2; + IrInstruction *result = ir_build_merge_err_sets(irb, scope, node, op1, op2, nullptr); + ir_build_store_ptr(irb, scope, node, lvalue, result); + return ir_build_const_void(irb, scope, node); +} + static IrInstruction *ir_gen_assign_op(IrBuilder *irb, Scope *scope, AstNode *node, IrBinOp op_id) { IrInstruction *lvalue = ir_gen_node_extra(irb, node->data.bin_op_expr.op1, scope, LValPtr, nullptr); if (lvalue == irb->codegen->invalid_instruction) @@ -4153,7 +4200,7 @@ static IrInstruction *ir_gen_bin_op(IrBuilder *irb, Scope *scope, AstNode *node, case BinOpTypeAssignBitOr: return ir_lval_wrap(irb, scope, ir_gen_assign_op(irb, scope, node, IrBinOpBinOr), lval, result_loc); case BinOpTypeAssignMergeErrorSets: - return ir_lval_wrap(irb, scope, ir_gen_assign_op(irb, scope, node, IrBinOpMergeErrorSets), lval, result_loc); + return ir_lval_wrap(irb, scope, ir_gen_assign_merge_err_sets(irb, scope, node), lval, result_loc); case BinOpTypeBoolOr: return ir_lval_wrap(irb, scope, ir_gen_bool_or(irb, scope, node), lval, result_loc); case BinOpTypeBoolAnd: @@ -4201,7 +4248,7 @@ static IrInstruction *ir_gen_bin_op(IrBuilder *irb, Scope *scope, AstNode *node, case BinOpTypeArrayMult: return ir_lval_wrap(irb, scope, ir_gen_bin_op_id(irb, scope, node, IrBinOpArrayMult), lval, result_loc); case BinOpTypeMergeErrorSets: - return ir_lval_wrap(irb, scope, ir_gen_bin_op_id(irb, scope, node, IrBinOpMergeErrorSets), lval, result_loc); + return ir_lval_wrap(irb, scope, ir_gen_merge_err_sets(irb, scope, node), lval, result_loc); case BinOpTypeUnwrapOptional: return ir_gen_orelse(irb, scope, node, lval, result_loc); case BinOpTypeErrorUnion: @@ -7859,7 +7906,9 @@ static IrInstruction *ir_gen_container_decl(IrBuilder *irb, Scope *parent_scope, } // errors should be populated with set1's values -static ZigType *get_error_set_union(CodeGen *g, ErrorTableEntry **errors, ZigType *set1, ZigType *set2) { +static ZigType *get_error_set_union(CodeGen *g, ErrorTableEntry **errors, ZigType *set1, ZigType *set2, + Buf *type_name) +{ assert(set1->id == ZigTypeIdErrorSet); assert(set2->id == ZigTypeIdErrorSet); @@ -7867,8 +7916,12 @@ static ZigType *get_error_set_union(CodeGen *g, ErrorTableEntry **errors, ZigTyp err_set_type->size_in_bits = g->builtin_types.entry_global_error_set->size_in_bits; err_set_type->abi_align = g->builtin_types.entry_global_error_set->abi_align; err_set_type->abi_size = g->builtin_types.entry_global_error_set->abi_size; - buf_resize(&err_set_type->name, 0); - buf_appendf(&err_set_type->name, "error{"); + if (type_name == nullptr) { + buf_resize(&err_set_type->name, 0); + buf_appendf(&err_set_type->name, "error{"); + } else { + buf_init_from_buf(&err_set_type->name, type_name); + } for (uint32_t i = 0, count = set1->data.error_set.err_count; i < count; i += 1) { assert(errors[set1->data.error_set.errors[i]->value] == set1->data.error_set.errors[i]); @@ -7885,21 +7938,27 @@ static ZigType *get_error_set_union(CodeGen *g, ErrorTableEntry **errors, ZigTyp err_set_type->data.error_set.err_count = count; err_set_type->data.error_set.errors = allocate(count); + bool need_comma = false; for (uint32_t i = 0; i < set1->data.error_set.err_count; i += 1) { ErrorTableEntry *error_entry = set1->data.error_set.errors[i]; - buf_appendf(&err_set_type->name, "%s,", buf_ptr(&error_entry->name)); + if (type_name == nullptr) { + const char *comma = need_comma ? "," : ""; + need_comma = true; + buf_appendf(&err_set_type->name, "%s%s", comma, buf_ptr(&error_entry->name)); + } err_set_type->data.error_set.errors[i] = error_entry; } uint32_t index = set1->data.error_set.err_count; - bool need_comma = false; for (uint32_t i = 0; i < set2->data.error_set.err_count; i += 1) { ErrorTableEntry *error_entry = set2->data.error_set.errors[i]; if (errors[error_entry->value] == nullptr) { errors[error_entry->value] = error_entry; - const char *comma = need_comma ? "," : ""; - need_comma = true; - buf_appendf(&err_set_type->name, "%s%s", comma, buf_ptr(&error_entry->name)); + if (type_name == nullptr) { + const char *comma = need_comma ? "," : ""; + need_comma = true; + buf_appendf(&err_set_type->name, "%s%s", comma, buf_ptr(&error_entry->name)); + } err_set_type->data.error_set.errors[index] = error_entry; index += 1; } @@ -7907,7 +7966,9 @@ static ZigType *get_error_set_union(CodeGen *g, ErrorTableEntry **errors, ZigTyp assert(index == count); assert(count != 0); - buf_appendf(&err_set_type->name, "}"); + if (type_name == nullptr) { + buf_appendf(&err_set_type->name, "}"); + } return err_set_type; @@ -9967,7 +10028,7 @@ static ZigType *ir_resolve_peer_types(IrAnalyze *ira, AstNode *source_node, ZigT } // neither of them are supersets. so we invent a new error set type that is a union of both of them - err_set_type = get_error_set_union(ira->codegen, errors, cur_type, err_set_type); + err_set_type = get_error_set_union(ira->codegen, errors, cur_type, err_set_type, nullptr); assert(errors != nullptr); continue; } else if (cur_type->id == ZigTypeIdErrorUnion) { @@ -10018,7 +10079,7 @@ static ZigType *ir_resolve_peer_types(IrAnalyze *ira, AstNode *source_node, ZigT } // not a subset. invent new error set type, union of both of them - err_set_type = get_error_set_union(ira->codegen, errors, cur_err_set_type, err_set_type); + err_set_type = get_error_set_union(ira->codegen, errors, cur_err_set_type, err_set_type, nullptr); prev_inst = cur_inst; assert(errors != nullptr); continue; @@ -10074,7 +10135,7 @@ static ZigType *ir_resolve_peer_types(IrAnalyze *ira, AstNode *source_node, ZigT continue; } // not a subset. invent new error set type, union of both of them - err_set_type = get_error_set_union(ira->codegen, errors, err_set_type, cur_type); + err_set_type = get_error_set_union(ira->codegen, errors, err_set_type, cur_type, nullptr); assert(errors != nullptr); continue; } @@ -10160,7 +10221,7 @@ static ZigType *ir_resolve_peer_types(IrAnalyze *ira, AstNode *source_node, ZigT continue; } - err_set_type = get_error_set_union(ira->codegen, errors, cur_err_set_type, prev_err_set_type); + err_set_type = get_error_set_union(ira->codegen, errors, cur_err_set_type, prev_err_set_type, nullptr); continue; } } @@ -10286,7 +10347,7 @@ static ZigType *ir_resolve_peer_types(IrAnalyze *ira, AstNode *source_node, ZigT update_errors_helper(ira->codegen, &errors, &errors_count); - err_set_type = get_error_set_union(ira->codegen, errors, err_set_type, cur_err_set_type); + err_set_type = get_error_set_union(ira->codegen, errors, err_set_type, cur_err_set_type, nullptr); } prev_inst = cur_inst; continue; @@ -13795,7 +13856,6 @@ static ErrorMsg *ir_eval_math_op_scalar(IrAnalyze *ira, IrInstruction *source_in case IrBinOpArrayCat: case IrBinOpArrayMult: case IrBinOpRemUnspecified: - case IrBinOpMergeErrorSets: zig_unreachable(); case IrBinOpBinOr: assert(is_int); @@ -14102,7 +14162,6 @@ static bool ok_float_op(IrBinOp op) { case IrBinOpRemUnspecified: case IrBinOpArrayCat: case IrBinOpArrayMult: - case IrBinOpMergeErrorSets: return false; } zig_unreachable(); @@ -14603,7 +14662,9 @@ static IrInstruction *ir_analyze_array_mult(IrAnalyze *ira, IrInstructionBinOp * return result; } -static IrInstruction *ir_analyze_merge_error_sets(IrAnalyze *ira, IrInstructionBinOp *instruction) { +static IrInstruction *ir_analyze_instruction_merge_err_sets(IrAnalyze *ira, + IrInstructionMergeErrSets *instruction) +{ ZigType *op1_type = ir_resolve_error_set_type(ira, &instruction->base, instruction->op1->child); if (type_is_invalid(op1_type)) return ira->codegen->invalid_instruction; @@ -14632,12 +14693,13 @@ static IrInstruction *ir_analyze_merge_error_sets(IrAnalyze *ira, IrInstructionB assert(errors[error_entry->value] == nullptr); errors[error_entry->value] = error_entry; } - ZigType *result_type = get_error_set_union(ira->codegen, errors, op1_type, op2_type); + ZigType *result_type = get_error_set_union(ira->codegen, errors, op1_type, op2_type, instruction->type_name); free(errors); return ir_const_type(ira, &instruction->base, result_type); } + static IrInstruction *ir_analyze_instruction_bin_op(IrAnalyze *ira, IrInstructionBinOp *bin_op_instruction) { IrBinOp op_id = bin_op_instruction->op_id; switch (op_id) { @@ -14679,8 +14741,6 @@ static IrInstruction *ir_analyze_instruction_bin_op(IrAnalyze *ira, IrInstructio return ir_analyze_array_cat(ira, bin_op_instruction); case IrBinOpArrayMult: return ir_analyze_array_mult(ira, bin_op_instruction); - case IrBinOpMergeErrorSets: - return ir_analyze_merge_error_sets(ira, bin_op_instruction); } zig_unreachable(); } @@ -25945,6 +26005,8 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction return ir_analyze_instruction_un_op(ira, (IrInstructionUnOp *)instruction); case IrInstructionIdBinOp: return ir_analyze_instruction_bin_op(ira, (IrInstructionBinOp *)instruction); + case IrInstructionIdMergeErrSets: + return ir_analyze_instruction_merge_err_sets(ira, (IrInstructionMergeErrSets *)instruction); case IrInstructionIdDeclVarSrc: return ir_analyze_instruction_decl_var(ira, (IrInstructionDeclVarSrc *)instruction); case IrInstructionIdLoadPtr: @@ -26370,6 +26432,7 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdPhi: case IrInstructionIdUnOp: case IrInstructionIdBinOp: + case IrInstructionIdMergeErrSets: case IrInstructionIdLoadPtr: case IrInstructionIdConst: case IrInstructionIdCast: diff --git a/src/ir_print.cpp b/src/ir_print.cpp index aae65d50a9..ecd8248d69 100644 --- a/src/ir_print.cpp +++ b/src/ir_print.cpp @@ -70,6 +70,8 @@ static const char* ir_instruction_type_str(IrInstruction* instruction) { return "UnOp"; case IrInstructionIdBinOp: return "BinOp"; + case IrInstructionIdMergeErrSets: + return "MergeErrSets"; case IrInstructionIdLoadPtr: return "LoadPtr"; case IrInstructionIdLoadPtrGen: @@ -497,8 +499,6 @@ static const char *ir_bin_op_id_str(IrBinOp op_id) { return "++"; case IrBinOpArrayMult: return "**"; - case IrBinOpMergeErrorSets: - return "||"; } zig_unreachable(); } @@ -535,6 +535,15 @@ static void ir_print_bin_op(IrPrint *irp, IrInstructionBinOp *bin_op_instruction } } +static void ir_print_merge_err_sets(IrPrint *irp, IrInstructionMergeErrSets *instruction) { + ir_print_other_instruction(irp, instruction->op1); + fprintf(irp->f, " || "); + ir_print_other_instruction(irp, instruction->op2); + if (instruction->type_name != nullptr) { + fprintf(irp->f, " // name=%s", buf_ptr(instruction->type_name)); + } +} + static void ir_print_decl_var_src(IrPrint *irp, IrInstructionDeclVarSrc *decl_var_instruction) { const char *var_or_const = decl_var_instruction->var->gen_is_const ? "const" : "var"; const char *name = decl_var_instruction->var->name; @@ -1974,6 +1983,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction, bool case IrInstructionIdBinOp: ir_print_bin_op(irp, (IrInstructionBinOp *)instruction); break; + case IrInstructionIdMergeErrSets: + ir_print_merge_err_sets(irp, (IrInstructionMergeErrSets *)instruction); + break; case IrInstructionIdDeclVarSrc: ir_print_decl_var_src(irp, (IrInstructionDeclVarSrc *)instruction); break;