From 6534f2ef4f8161f4121326f19bc3cf89324f62c5 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 17 Oct 2021 14:50:47 -0700 Subject: [PATCH] stage2: implement error wrapping * Sema: fix returned operands not coercing to the function return type in some cases. - When returning an error or an error union from a function with an inferred error set, it will now populate the inferred error set. - Implement error set coercion for the common case of inferred error set to inferred error set, without forcing a full resolution. * LLVM backend: update instruction lowering that handles error unions to respect `isByRef`. - Also implement `wrap_err_union_err`. --- src/Module.zig | 4 +++ src/Sema.zig | 77 ++++++++++++++++++++++++++--------------- src/codegen/llvm.zig | 48 +++++++++++++++++-------- src/type.zig | 44 +++++++++++++++++++++-- test/behavior/error.zig | 18 ++++++++++ 5 files changed, 146 insertions(+), 45 deletions(-) diff --git a/src/Module.zig b/src/Module.zig index a42ec3c2e1..f52e1c8ef7 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -782,6 +782,10 @@ pub const ErrorSet = struct { /// The length is given by `names_len`. names_ptr: [*]const []const u8, + pub fn names(self: ErrorSet) []const []const u8 { + return self.names_ptr[0..self.names_len]; + } + pub fn srcLoc(self: ErrorSet) SrcLoc { return .{ .file_scope = self.owner_decl.getFileScope(), diff --git a/src/Sema.zig b/src/Sema.zig index 229ae054b2..cfc541453a 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -4845,6 +4845,8 @@ fn funcCommon( const error_set_ty = try Type.Tag.error_set_inferred.create(sema.arena, .{ .func = new_func, .map = .{}, + .functions = .{}, + .is_anyerror = false, }); break :blk try Type.Tag.error_union.create(sema.arena, .{ .error_set = error_set_ty, @@ -8466,19 +8468,13 @@ fn zirRetErrValue( const err_name = inst_data.get(sema.code); const src = inst_data.src(); - // Add the error tag to the inferred error set of the in-scope function. - if (sema.fn_ret_ty.zigTypeTag() == .ErrorUnion) { - if (sema.fn_ret_ty.errorUnionSet().castTag(.error_set_inferred)) |payload| { - _ = try payload.data.map.getOrPut(sema.gpa, err_name); - } - } // Return the error code from the function. const kv = try sema.mod.getErrorValue(err_name); const result_inst = try sema.addConstant( try Type.Tag.error_set_single.create(sema.arena, kv.key), try Value.Tag.@"error".create(sema.arena, .{ .name = kv.key }), ); - return sema.analyzeRet(block, result_inst, src, true); + return sema.analyzeRet(block, result_inst, src); } fn zirRetCoerce( @@ -8493,7 +8489,7 @@ fn zirRetCoerce( const operand = sema.resolveInst(inst_data.operand); const src = inst_data.src(); - return sema.analyzeRet(block, operand, src, true); + return sema.analyzeRet(block, operand, src); } fn zirRetNode(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Zir.Inst.Index { @@ -8504,11 +8500,7 @@ fn zirRetNode(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Zir const operand = sema.resolveInst(inst_data.operand); const src = inst_data.src(); - // TODO: we pass false here for the `need_coercion` boolean, but I'm pretty sure we need - // to remove this parameter entirely. Observe the problem by looking at the incorrect compile - // error that occurs when a behavior test case being executed at comptime fails, e.g. - // `test { comptime foo(); } fn foo() { try expect(false); }` - return sema.analyzeRet(block, operand, src, false); + return sema.analyzeRet(block, operand, src); } fn zirRetLoad(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Zir.Inst.Index { @@ -8521,7 +8513,7 @@ fn zirRetLoad(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Zir if (block.is_comptime or block.inlining != null) { const operand = try sema.analyzeLoad(block, src, ret_ptr, src); - return sema.analyzeRet(block, operand, src, false); + return sema.analyzeRet(block, operand, src); } try sema.requireRuntimeBlock(block, src); _ = try block.addUnOp(.ret_load, ret_ptr); @@ -8533,12 +8525,25 @@ fn analyzeRet( block: *Block, uncasted_operand: Air.Inst.Ref, src: LazySrcLoc, - need_coercion: bool, ) CompileError!Zir.Inst.Index { - const operand = if (!need_coercion) - uncasted_operand - else - try sema.coerce(block, sema.fn_ret_ty, uncasted_operand, src); + // Special case for returning an error to an inferred error set; we need to + // add the error tag to the inferred error set of the in-scope function, so + // that the coercion below works correctly. + if (sema.fn_ret_ty.zigTypeTag() == .ErrorUnion) { + if (sema.fn_ret_ty.errorUnionSet().castTag(.error_set_inferred)) |payload| { + const op_ty = sema.typeOf(uncasted_operand); + switch (op_ty.zigTypeTag()) { + .ErrorSet => { + try payload.data.addErrorSet(sema.gpa, op_ty); + }, + .ErrorUnion => { + try payload.data.addErrorSet(sema.gpa, op_ty.errorUnionSet()); + }, + else => {}, + } + } + } + const operand = try sema.coerce(block, sema.fn_ret_ty, uncasted_operand, src); if (block.inlining) |inlining| { if (block.is_comptime) { @@ -11605,14 +11610,30 @@ fn coerce( // T to E!T or E to E!T return sema.wrapErrorUnion(block, dest_ty, inst, inst_src); }, - .ErrorSet => { - // Coercion to `anyerror`. - // TODO If the dest type tag is not `anyerror` it still could - // resolve to anyerror. `dest_ty` needs to have inferred error set resolution - // happen before this check. - if (dest_ty.tag() == .anyerror and inst_ty.zigTypeTag() == .ErrorSet) { - return sema.coerceErrSetToAnyError(block, inst, inst_src); - } + .ErrorSet => switch (inst_ty.zigTypeTag()) { + .ErrorSet => { + // Coercion to `anyerror`. Note that this check can return false positives + // in case the error sets did not get resolved. + if (dest_ty.isAnyError()) { + return sema.coerceCompatibleErrorSets(block, inst, inst_src); + } + // If both are inferred error sets of functions, and + // the dest includes the source function, the coercion is OK. + // This check is important because it works without forcing a full resolution + // of inferred error sets. + if (inst_ty.castTag(.error_set_inferred)) |src_payload| { + if (dest_ty.castTag(.error_set_inferred)) |dst_payload| { + const src_func = src_payload.data.func; + const dst_func = dst_payload.data.func; + + if (src_func == dst_func or dst_payload.data.functions.contains(src_func)) { + return sema.coerceCompatibleErrorSets(block, inst, inst_src); + } + } + } + // TODO full error set resolution and compare sets by names. + }, + else => {}, }, .Union => switch (inst_ty.zigTypeTag()) { .Enum, .EnumLiteral => return sema.coerceEnumToUnion(block, dest_ty, dest_ty_src, inst, inst_src), @@ -12245,7 +12266,7 @@ fn coerceVectorToArray( return block.addTyOp(.bitcast, array_ty, vector); } -fn coerceErrSetToAnyError( +fn coerceCompatibleErrorSets( sema: *Sema, block: *Block, err_set: Air.Inst.Ref, diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index 4349fce4d9..026046b29f 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -2301,8 +2301,7 @@ pub const FuncGen = struct { op: llvm.IntPredicate, operand_is_ptr: bool, ) !?*const llvm.Value { - if (self.liveness.isUnused(inst)) - return null; + if (self.liveness.isUnused(inst)) return null; const un_op = self.air.instructions.items(.data)[inst].un_op; const operand = try self.resolveInst(un_op); @@ -2363,22 +2362,16 @@ pub const FuncGen = struct { inst: Air.Inst.Index, operand_is_ptr: bool, ) !?*const llvm.Value { - if (self.liveness.isUnused(inst)) - return null; + if (self.liveness.isUnused(inst)) return null; const ty_op = self.air.instructions.items(.data)[inst].ty_op; const operand = try self.resolveInst(ty_op.operand); const err_union_ty = self.air.typeOf(ty_op.operand); const payload_ty = err_union_ty.errorUnionPayload(); - - if (!payload_ty.hasCodeGenBits()) { - return null; - } - - if (operand_is_ptr) { + if (!payload_ty.hasCodeGenBits()) return null; + if (operand_is_ptr or isByRef(payload_ty)) { return self.builder.buildStructGEP(operand, 1, ""); } - return self.builder.buildExtractValue(operand, 1, ""); } @@ -2400,7 +2393,7 @@ pub const FuncGen = struct { return self.builder.buildLoad(operand, ""); } - if (operand_is_ptr) { + if (operand_is_ptr or isByRef(payload_ty)) { const err_field_ptr = self.builder.buildStructGEP(operand, 0, ""); return self.builder.buildLoad(err_field_ptr, ""); } @@ -2469,10 +2462,35 @@ pub const FuncGen = struct { } fn airWrapErrUnionErr(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value { - if (self.liveness.isUnused(inst)) - return null; + if (self.liveness.isUnused(inst)) return null; - return self.todo("implement llvm codegen for 'airWrapErrUnionErr'", .{}); + const ty_op = self.air.instructions.items(.data)[inst].ty_op; + const err_un_ty = self.air.typeOfIndex(inst); + const payload_ty = err_un_ty.errorUnionPayload(); + const operand = try self.resolveInst(ty_op.operand); + if (!payload_ty.hasCodeGenBits()) { + return operand; + } + const err_un_llvm_ty = try self.dg.llvmType(err_un_ty); + if (isByRef(err_un_ty)) { + const result_ptr = self.buildAlloca(err_un_llvm_ty); + const err_ptr = self.builder.buildStructGEP(result_ptr, 0, ""); + _ = self.builder.buildStore(operand, err_ptr); + const payload_ptr = self.builder.buildStructGEP(result_ptr, 1, ""); + var ptr_ty_payload: Type.Payload.ElemType = .{ + .base = .{ .tag = .single_mut_pointer }, + .data = payload_ty, + }; + const payload_ptr_ty = Type.initPayload(&ptr_ty_payload.base); + // TODO store undef to payload_ptr + _ = payload_ptr; + _ = payload_ptr_ty; + return result_ptr; + } + + const partial = self.builder.buildInsertValue(err_un_llvm_ty.getUndef(), operand, 0, ""); + // TODO set payload bytes to undef + return partial; } fn airMin(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value { diff --git a/src/type.zig b/src/type.zig index be3fc1df67..31f6d57ad4 100644 --- a/src/type.zig +++ b/src/type.zig @@ -2619,6 +2619,17 @@ pub const Type = extern union { }; } + /// Returns true if it is an error set that includes anyerror, false otherwise. + /// Note that the result may be a false negative if the type did not get error set + /// resolution prior to this call. + pub fn isAnyError(ty: Type) bool { + return switch (ty.tag()) { + .anyerror => true, + .error_set_inferred => ty.castTag(.error_set_inferred).?.data.is_anyerror, + else => false, + }; + } + /// Asserts the type is an array or vector. pub fn arrayLen(ty: Type) u64 { return switch (ty.tag()) { @@ -3871,10 +3882,39 @@ pub const Type = extern union { pub const base_tag = Tag.error_set_inferred; base: Payload = Payload{ .tag = base_tag }, - data: struct { + data: Data, + + pub const Data = struct { func: *Module.Fn, + /// Direct additions to the inferred error set via `return error.Foo;`. map: std.StringHashMapUnmanaged(void), - }, + /// Other functions with inferred error sets which this error set includes. + functions: std.AutoHashMapUnmanaged(*Module.Fn, void), + is_anyerror: bool, + + pub fn addErrorSet(self: *Data, gpa: *Allocator, err_set_ty: Type) !void { + switch (err_set_ty.tag()) { + .error_set => { + const names = err_set_ty.castTag(.error_set).?.data.names(); + for (names) |name| { + try self.map.put(gpa, name, {}); + } + }, + .error_set_single => { + const name = err_set_ty.castTag(.error_set_single).?.data; + try self.map.put(gpa, name, {}); + }, + .error_set_inferred => { + const func = err_set_ty.castTag(.error_set_inferred).?.data.func; + try self.functions.put(gpa, func, {}); + }, + .anyerror => { + self.is_anyerror = true; + }, + else => unreachable, + } + } + }; }; pub const Pointer = struct { diff --git a/test/behavior/error.zig b/test/behavior/error.zig index fe647ee7c5..6edb973e36 100644 --- a/test/behavior/error.zig +++ b/test/behavior/error.zig @@ -31,3 +31,21 @@ test "empty error union" { const x = error{} || error{}; _ = x; } + +pub fn foo() anyerror!i32 { + const x = try bar(); + return x + 1; +} + +pub fn bar() anyerror!i32 { + return 13; +} + +pub fn baz() anyerror!i32 { + const y = foo() catch 1234; + return y + 1; +} + +test "error wrapping" { + try expect((baz() catch unreachable) == 15); +}