From c2e66d9bab396a69514ec7c3c41fb0404e542f21 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 7 Jul 2021 20:47:21 -0700 Subject: [PATCH] stage2: basic inferred error set support * Inferred error sets are stored in the return Type of the function, owned by the Module.Fn. So it cleans up that memory in deinit(). * Sema: update the inferred error set in zirRetErrValue - Update relevant code in wrapErrorUnion * C backend: improve some some instructions to take advantage of liveness analysis to avoid being emitted when unused. * C backend: when an error union has a payload type with no runtime bits, emit the error union as the same type as the error set. --- src/Module.zig | 15 ++++++++-- src/Sema.zig | 25 ++++++++++++----- src/codegen/c.zig | 67 ++++++++++++++++++++++++++++++++++++++------- src/type.zig | 7 +++-- test/stage2/cbe.zig | 20 ++++++++++++++ 5 files changed, 113 insertions(+), 21 deletions(-) diff --git a/src/Module.zig b/src/Module.zig index c48440ccc2..8ae184a377 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -777,8 +777,19 @@ pub const Fn = struct { } pub fn deinit(func: *Fn, gpa: *Allocator) void { - _ = func; - _ = gpa; + if (func.getInferredErrorSet()) |map| { + map.deinit(gpa); + } + } + + pub fn getInferredErrorSet(func: *Fn) ?*std.StringHashMapUnmanaged(void) { + const ret_ty = func.owner_decl.ty.fnReturnType(); + if (ret_ty.zigTypeTag() == .ErrorUnion) { + if (ret_ty.errorUnionSet().castTag(.error_set_inferred)) |payload| { + return &payload.data.map; + } + } + return null; } }; diff --git a/src/Sema.zig b/src/Sema.zig index 86e5f59af6..d7ce9fdf4f 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -3139,7 +3139,10 @@ fn funcCommon( } const return_type = if (!inferred_error_set) bare_return_type else blk: { - const error_set_ty = try Type.Tag.error_set_inferred.create(sema.arena, new_func); + const error_set_ty = try Type.Tag.error_set_inferred.create(sema.arena, .{ + .func = new_func, + .map = .{}, + }); break :blk try Type.Tag.error_union.create(sema.arena, .{ .error_set = error_set_ty, .payload = bare_return_type, @@ -5424,12 +5427,8 @@ fn zirRetErrValue( // Add the error tag to the inferred error set of the in-scope function. if (sema.func) |func| { - const fn_ty = func.owner_decl.ty; - const fn_ret_ty = fn_ty.fnReturnType(); - if (fn_ret_ty.zigTypeTag() == .ErrorUnion and - fn_ret_ty.errorUnionSet().tag() == .error_set_inferred) - { - return sema.mod.fail(&block.base, src, "TODO: Sema.zirRetErrValue", .{}); + if (func.getInferredErrorSet()) |map| { + _ = try map.getOrPut(sema.gpa, err_name); } } // Return the error code from the function. @@ -7535,6 +7534,18 @@ fn wrapErrorUnion(sema: *Sema, block: *Scope.Block, dest_type: Type, inst: *Inst ); } }, + .error_set_inferred => { + const expected_name = val.castTag(.@"error").?.data.name; + const map = &err_union.data.error_set.castTag(.error_set_inferred).?.data.map; + if (!map.contains(expected_name)) { + return sema.mod.fail( + &block.base, + inst.src, + "expected type '{}', found type '{}'", + .{ err_union.data.error_set, inst.ty }, + ); + } + }, else => unreachable, } diff --git a/src/codegen/c.zig b/src/codegen/c.zig index db0e910643..3aaf559802 100644 --- a/src/codegen/c.zig +++ b/src/codegen/c.zig @@ -360,6 +360,12 @@ pub const DeclGen = struct { const error_type = t.errorUnionSet(); const payload_type = t.errorUnionChild(); const data = val.castTag(.error_union).?.data; + + if (!payload_type.hasCodeGenBits()) { + // We use the error type directly as the type. + return dg.renderValue(writer, error_type, data); + } + try writer.writeByte('('); try dg.renderType(writer, t); try writer.writeAll("){"); @@ -604,6 +610,10 @@ pub const DeclGen = struct { const child_type = t.errorUnionChild(); const err_set_type = t.errorUnionSet(); + if (!child_type.hasCodeGenBits()) { + return dg.renderType(w, err_set_type); + } + var buffer = std.ArrayList(u8).init(dg.typedefs.allocator); defer buffer.deinit(); const bw = buffer.writer(); @@ -613,7 +623,7 @@ pub const DeclGen = struct { try bw.writeAll(" payload; uint16_t error; } "); const name_index = buffer.items.len; if (err_set_type.castTag(.error_set_inferred)) |inf_err_set_payload| { - const func = inf_err_set_payload.data; + const func = inf_err_set_payload.data.func; try bw.print("zig_E_{s};\n", .{func.owner_decl.name}); } else { try bw.print("zig_E_{s}_{s};\n", .{ @@ -895,10 +905,10 @@ pub fn genBody(o: *Object, body: ir.Body) error{ AnalysisFail, OutOfMemory }!voi .ref => try genRef(o, inst.castTag(.ref).?), .struct_field_ptr => try genStructFieldPtr(o, inst.castTag(.struct_field_ptr).?), - .is_err => try genIsErr(o, inst.castTag(.is_err).?, "", "!="), - .is_non_err => try genIsErr(o, inst.castTag(.is_non_err).?, "", "=="), - .is_err_ptr => try genIsErr(o, inst.castTag(.is_err_ptr).?, "[0]", "!="), - .is_non_err_ptr => try genIsErr(o, inst.castTag(.is_non_err_ptr).?, "[0]", "=="), + .is_err => try genIsErr(o, inst.castTag(.is_err).?, "", ".", "!="), + .is_non_err => try genIsErr(o, inst.castTag(.is_non_err).?, "", ".", "=="), + .is_err_ptr => try genIsErr(o, inst.castTag(.is_err_ptr).?, "*", "->", "!="), + .is_non_err_ptr => try genIsErr(o, inst.castTag(.is_non_err_ptr).?, "*", "->", "=="), .unwrap_errunion_payload => try genUnwrapErrUnionPay(o, inst.castTag(.unwrap_errunion_payload).?), .unwrap_errunion_err => try genUnwrapErrUnionErr(o, inst.castTag(.unwrap_errunion_err).?), @@ -1384,9 +1394,25 @@ fn genStructFieldPtr(o: *Object, inst: *Inst.StructFieldPtr) !CValue { // *(E!T) -> E NOT *E fn genUnwrapErrUnionErr(o: *Object, inst: *Inst.UnOp) !CValue { + if (inst.base.isUnused()) + return CValue.none; + const writer = o.writer(); const operand = try o.resolveInst(inst.operand); + const payload_ty = inst.operand.ty.errorUnionChild(); + if (!payload_ty.hasCodeGenBits()) { + if (inst.operand.ty.zigTypeTag() == .Pointer) { + const local = try o.allocLocal(inst.base.ty, .Const); + try writer.writeAll(" = *"); + try o.writeCValue(writer, operand); + try writer.writeAll(";\n"); + return local; + } else { + return operand; + } + } + const maybe_deref = if (inst.operand.ty.zigTypeTag() == .Pointer) "->" else "."; const local = try o.allocLocal(inst.base.ty, .Const); @@ -1396,10 +1422,19 @@ fn genUnwrapErrUnionErr(o: *Object, inst: *Inst.UnOp) !CValue { try writer.print("){s}error;\n", .{maybe_deref}); return local; } + fn genUnwrapErrUnionPay(o: *Object, inst: *Inst.UnOp) !CValue { + if (inst.base.isUnused()) + return CValue.none; + const writer = o.writer(); const operand = try o.resolveInst(inst.operand); + const payload_ty = inst.operand.ty.errorUnionChild(); + if (!payload_ty.hasCodeGenBits()) { + return CValue.none; + } + const maybe_deref = if (inst.operand.ty.zigTypeTag() == .Pointer) "->" else "."; const maybe_addrof = if (inst.base.ty.zigTypeTag() == .Pointer) "&" else ""; @@ -1448,14 +1483,26 @@ fn genWrapErrUnionPay(o: *Object, inst: *Inst.UnOp) !CValue { return local; } -fn genIsErr(o: *Object, inst: *Inst.UnOp, deref_suffix: []const u8, op_str: []const u8) !CValue { +fn genIsErr( + o: *Object, + inst: *Inst.UnOp, + deref_prefix: [*:0]const u8, + deref_suffix: [*:0]const u8, + op_str: [*:0]const u8, +) !CValue { const writer = o.writer(); const operand = try o.resolveInst(inst.operand); - const local = try o.allocLocal(Type.initTag(.bool), .Const); - try writer.writeAll(" = ("); - try o.writeCValue(writer, operand); - try writer.print("){s}.error {s} 0;\n", .{ deref_suffix, op_str }); + const payload_ty = inst.operand.ty.errorUnionChild(); + if (!payload_ty.hasCodeGenBits()) { + try writer.print(" = {s}", .{deref_prefix}); + try o.writeCValue(writer, operand); + try writer.print(" {s} 0;\n", .{op_str}); + } else { + try writer.writeAll(" = "); + try o.writeCValue(writer, operand); + try writer.print("{s}error {s} 0;\n", .{ deref_suffix, op_str }); + } return local; } diff --git a/src/type.zig b/src/type.zig index f9385e90bc..e8f0998332 100644 --- a/src/type.zig +++ b/src/type.zig @@ -1041,7 +1041,7 @@ pub const Type = extern union { return writer.writeAll(std.mem.spanZ(error_set.owner_decl.name)); }, .error_set_inferred => { - const func = ty.castTag(.error_set_inferred).?.data; + const func = ty.castTag(.error_set_inferred).?.data.func; return writer.print("(inferred error set of {s})", .{func.owner_decl.name}); }, .error_set_single => { @@ -3154,7 +3154,10 @@ pub const Type = extern union { pub const base_tag = Tag.error_set_inferred; base: Payload = Payload{ .tag = base_tag }, - data: *Module.Fn, + data: struct { + func: *Module.Fn, + map: std.StringHashMapUnmanaged(void), + }, }; pub const Pointer = struct { diff --git a/test/stage2/cbe.zig b/test/stage2/cbe.zig index a064995c13..cbe24d3ec3 100644 --- a/test/stage2/cbe.zig +++ b/test/stage2/cbe.zig @@ -804,6 +804,26 @@ pub fn addCases(ctx: *TestContext) !void { }); } + { + var case = ctx.exeFromCompiledC("inferred error sets", .{}); + + case.addCompareOutput( + \\pub export fn main() c_int { + \\ if (foo()) |_| { + \\ @panic("test fail"); + \\ } else |err| { + \\ if (err != error.ItBroke) { + \\ @panic("test fail"); + \\ } + \\ } + \\ return 0; + \\} + \\fn foo() !void { + \\ return error.ItBroke; + \\} + , ""); + } + ctx.h("simple header", linux_x64, \\export fn start() void{} ,