stage2: basic inferred error set support

* Inferred error sets are stored in the return Type of the function,
   owned by the Module.Fn. So it cleans up that memory in deinit().
 * Sema: update the inferred error set in zirRetErrValue
   - Update relevant code in wrapErrorUnion
 * C backend: improve some some instructions to take advantage of
   liveness analysis to avoid being emitted when unused.
 * C backend: when an error union has a payload type with no runtime
   bits, emit the error union as the same type as the error set.
This commit is contained in:
Andrew Kelley 2021-07-07 20:47:21 -07:00
parent 5c8bd443d9
commit c2e66d9bab
5 changed files with 113 additions and 21 deletions

View File

@ -777,8 +777,19 @@ pub const Fn = struct {
}
pub fn deinit(func: *Fn, gpa: *Allocator) void {
_ = func;
_ = gpa;
if (func.getInferredErrorSet()) |map| {
map.deinit(gpa);
}
}
pub fn getInferredErrorSet(func: *Fn) ?*std.StringHashMapUnmanaged(void) {
const ret_ty = func.owner_decl.ty.fnReturnType();
if (ret_ty.zigTypeTag() == .ErrorUnion) {
if (ret_ty.errorUnionSet().castTag(.error_set_inferred)) |payload| {
return &payload.data.map;
}
}
return null;
}
};

View File

@ -3139,7 +3139,10 @@ fn funcCommon(
}
const return_type = if (!inferred_error_set) bare_return_type else blk: {
const error_set_ty = try Type.Tag.error_set_inferred.create(sema.arena, new_func);
const error_set_ty = try Type.Tag.error_set_inferred.create(sema.arena, .{
.func = new_func,
.map = .{},
});
break :blk try Type.Tag.error_union.create(sema.arena, .{
.error_set = error_set_ty,
.payload = bare_return_type,
@ -5424,12 +5427,8 @@ fn zirRetErrValue(
// Add the error tag to the inferred error set of the in-scope function.
if (sema.func) |func| {
const fn_ty = func.owner_decl.ty;
const fn_ret_ty = fn_ty.fnReturnType();
if (fn_ret_ty.zigTypeTag() == .ErrorUnion and
fn_ret_ty.errorUnionSet().tag() == .error_set_inferred)
{
return sema.mod.fail(&block.base, src, "TODO: Sema.zirRetErrValue", .{});
if (func.getInferredErrorSet()) |map| {
_ = try map.getOrPut(sema.gpa, err_name);
}
}
// Return the error code from the function.
@ -7535,6 +7534,18 @@ fn wrapErrorUnion(sema: *Sema, block: *Scope.Block, dest_type: Type, inst: *Inst
);
}
},
.error_set_inferred => {
const expected_name = val.castTag(.@"error").?.data.name;
const map = &err_union.data.error_set.castTag(.error_set_inferred).?.data.map;
if (!map.contains(expected_name)) {
return sema.mod.fail(
&block.base,
inst.src,
"expected type '{}', found type '{}'",
.{ err_union.data.error_set, inst.ty },
);
}
},
else => unreachable,
}

View File

@ -360,6 +360,12 @@ pub const DeclGen = struct {
const error_type = t.errorUnionSet();
const payload_type = t.errorUnionChild();
const data = val.castTag(.error_union).?.data;
if (!payload_type.hasCodeGenBits()) {
// We use the error type directly as the type.
return dg.renderValue(writer, error_type, data);
}
try writer.writeByte('(');
try dg.renderType(writer, t);
try writer.writeAll("){");
@ -604,6 +610,10 @@ pub const DeclGen = struct {
const child_type = t.errorUnionChild();
const err_set_type = t.errorUnionSet();
if (!child_type.hasCodeGenBits()) {
return dg.renderType(w, err_set_type);
}
var buffer = std.ArrayList(u8).init(dg.typedefs.allocator);
defer buffer.deinit();
const bw = buffer.writer();
@ -613,7 +623,7 @@ pub const DeclGen = struct {
try bw.writeAll(" payload; uint16_t error; } ");
const name_index = buffer.items.len;
if (err_set_type.castTag(.error_set_inferred)) |inf_err_set_payload| {
const func = inf_err_set_payload.data;
const func = inf_err_set_payload.data.func;
try bw.print("zig_E_{s};\n", .{func.owner_decl.name});
} else {
try bw.print("zig_E_{s}_{s};\n", .{
@ -895,10 +905,10 @@ pub fn genBody(o: *Object, body: ir.Body) error{ AnalysisFail, OutOfMemory }!voi
.ref => try genRef(o, inst.castTag(.ref).?),
.struct_field_ptr => try genStructFieldPtr(o, inst.castTag(.struct_field_ptr).?),
.is_err => try genIsErr(o, inst.castTag(.is_err).?, "", "!="),
.is_non_err => try genIsErr(o, inst.castTag(.is_non_err).?, "", "=="),
.is_err_ptr => try genIsErr(o, inst.castTag(.is_err_ptr).?, "[0]", "!="),
.is_non_err_ptr => try genIsErr(o, inst.castTag(.is_non_err_ptr).?, "[0]", "=="),
.is_err => try genIsErr(o, inst.castTag(.is_err).?, "", ".", "!="),
.is_non_err => try genIsErr(o, inst.castTag(.is_non_err).?, "", ".", "=="),
.is_err_ptr => try genIsErr(o, inst.castTag(.is_err_ptr).?, "*", "->", "!="),
.is_non_err_ptr => try genIsErr(o, inst.castTag(.is_non_err_ptr).?, "*", "->", "=="),
.unwrap_errunion_payload => try genUnwrapErrUnionPay(o, inst.castTag(.unwrap_errunion_payload).?),
.unwrap_errunion_err => try genUnwrapErrUnionErr(o, inst.castTag(.unwrap_errunion_err).?),
@ -1384,9 +1394,25 @@ fn genStructFieldPtr(o: *Object, inst: *Inst.StructFieldPtr) !CValue {
// *(E!T) -> E NOT *E
fn genUnwrapErrUnionErr(o: *Object, inst: *Inst.UnOp) !CValue {
if (inst.base.isUnused())
return CValue.none;
const writer = o.writer();
const operand = try o.resolveInst(inst.operand);
const payload_ty = inst.operand.ty.errorUnionChild();
if (!payload_ty.hasCodeGenBits()) {
if (inst.operand.ty.zigTypeTag() == .Pointer) {
const local = try o.allocLocal(inst.base.ty, .Const);
try writer.writeAll(" = *");
try o.writeCValue(writer, operand);
try writer.writeAll(";\n");
return local;
} else {
return operand;
}
}
const maybe_deref = if (inst.operand.ty.zigTypeTag() == .Pointer) "->" else ".";
const local = try o.allocLocal(inst.base.ty, .Const);
@ -1396,10 +1422,19 @@ fn genUnwrapErrUnionErr(o: *Object, inst: *Inst.UnOp) !CValue {
try writer.print("){s}error;\n", .{maybe_deref});
return local;
}
fn genUnwrapErrUnionPay(o: *Object, inst: *Inst.UnOp) !CValue {
if (inst.base.isUnused())
return CValue.none;
const writer = o.writer();
const operand = try o.resolveInst(inst.operand);
const payload_ty = inst.operand.ty.errorUnionChild();
if (!payload_ty.hasCodeGenBits()) {
return CValue.none;
}
const maybe_deref = if (inst.operand.ty.zigTypeTag() == .Pointer) "->" else ".";
const maybe_addrof = if (inst.base.ty.zigTypeTag() == .Pointer) "&" else "";
@ -1448,14 +1483,26 @@ fn genWrapErrUnionPay(o: *Object, inst: *Inst.UnOp) !CValue {
return local;
}
fn genIsErr(o: *Object, inst: *Inst.UnOp, deref_suffix: []const u8, op_str: []const u8) !CValue {
fn genIsErr(
o: *Object,
inst: *Inst.UnOp,
deref_prefix: [*:0]const u8,
deref_suffix: [*:0]const u8,
op_str: [*:0]const u8,
) !CValue {
const writer = o.writer();
const operand = try o.resolveInst(inst.operand);
const local = try o.allocLocal(Type.initTag(.bool), .Const);
try writer.writeAll(" = (");
try o.writeCValue(writer, operand);
try writer.print("){s}.error {s} 0;\n", .{ deref_suffix, op_str });
const payload_ty = inst.operand.ty.errorUnionChild();
if (!payload_ty.hasCodeGenBits()) {
try writer.print(" = {s}", .{deref_prefix});
try o.writeCValue(writer, operand);
try writer.print(" {s} 0;\n", .{op_str});
} else {
try writer.writeAll(" = ");
try o.writeCValue(writer, operand);
try writer.print("{s}error {s} 0;\n", .{ deref_suffix, op_str });
}
return local;
}

View File

@ -1041,7 +1041,7 @@ pub const Type = extern union {
return writer.writeAll(std.mem.spanZ(error_set.owner_decl.name));
},
.error_set_inferred => {
const func = ty.castTag(.error_set_inferred).?.data;
const func = ty.castTag(.error_set_inferred).?.data.func;
return writer.print("(inferred error set of {s})", .{func.owner_decl.name});
},
.error_set_single => {
@ -3154,7 +3154,10 @@ pub const Type = extern union {
pub const base_tag = Tag.error_set_inferred;
base: Payload = Payload{ .tag = base_tag },
data: *Module.Fn,
data: struct {
func: *Module.Fn,
map: std.StringHashMapUnmanaged(void),
},
};
pub const Pointer = struct {

View File

@ -804,6 +804,26 @@ pub fn addCases(ctx: *TestContext) !void {
});
}
{
var case = ctx.exeFromCompiledC("inferred error sets", .{});
case.addCompareOutput(
\\pub export fn main() c_int {
\\ if (foo()) |_| {
\\ @panic("test fail");
\\ } else |err| {
\\ if (err != error.ItBroke) {
\\ @panic("test fail");
\\ }
\\ }
\\ return 0;
\\}
\\fn foo() !void {
\\ return error.ItBroke;
\\}
, "");
}
ctx.h("simple header", linux_x64,
\\export fn start() void{}
,