Sema: add @errorCast which works for both error sets and error unions

Closes #17343
This commit is contained in:
Veikka Tuominen 2023-10-01 13:16:02 +03:00
parent d8bfbbbf25
commit 63bd2bff12
14 changed files with 90 additions and 46 deletions

View File

@ -6657,7 +6657,7 @@ test "coercion from homogenous tuple to array" {
<li>{#link|@alignCast#} - make a pointer have more alignment</li>
<li>{#link|@enumFromInt#} - obtain an enum value based on its integer tag value</li>
<li>{#link|@errorFromInt#} - obtain an error code based on its integer value</li>
<li>{#link|@errSetCast#} - convert to a smaller error set</li>
<li>{#link|@errorCast#} - convert to a smaller error set</li>
<li>{#link|@floatCast#} - convert a larger float to a smaller float</li>
<li>{#link|@floatFromInt#} - convert an integer to a float value</li>
<li>{#link|@intCast#} - convert between integer types</li>
@ -8410,10 +8410,10 @@ test "main" {
</p>
{#header_close#}
{#header_open|@errSetCast#}
<pre>{#syntax#}@errSetCast(value: anytype) anytype{#endsyntax#}</pre>
{#header_open|@errorCast#}
<pre>{#syntax#}@errorCast(value: anytype) anytype{#endsyntax#}</pre>
<p>
Converts an error value from one error set to another error set. The return type is the
Converts an error set or error union value from one error set to another error set. The return type is the
inferred result type. Attempting to convert an error which is not in the destination error
set results in safety-protected {#link|Undefined Behavior#}.
</p>
@ -10257,7 +10257,7 @@ const Set2 = error{
C,
};
comptime {
_ = @as(Set2, @errSetCast(Set1.B));
_ = @as(Set2, @errorCast(Set1.B));
}
{#code_end#}
<p>At runtime:</p>
@ -10276,7 +10276,7 @@ pub fn main() void {
foo(Set1.B);
}
fn foo(set1: Set1) void {
const x: Set2 = @errSetCast(set1);
const x: Set2 = @errorCast(set1);
std.debug.print("value: {}\n", .{x});
}
{#code_end#}

View File

@ -446,7 +446,7 @@ pub const ChildProcess = struct {
// has a value greater than 0
if ((fd[0].revents & std.os.POLL.IN) != 0) {
const err_int = try readIntFd(err_pipe[0]);
return @as(SpawnError, @errSetCast(@errorFromInt(err_int)));
return @as(SpawnError, @errorCast(@errorFromInt(err_int)));
}
} else {
// Write maxInt(ErrInt) to the write end of the err_pipe. This is after
@ -459,7 +459,7 @@ pub const ChildProcess = struct {
// Here we potentially return the fork child's error from the parent
// pid.
if (err_int != maxInt(ErrInt)) {
return @as(SpawnError, @errSetCast(@errorFromInt(err_int)));
return @as(SpawnError, @errorCast(@errorFromInt(err_int)));
}
}
}

View File

@ -5419,7 +5419,7 @@ pub fn dl_iterate_phdr(
}
}.callbackC, @as(?*anyopaque, @ptrFromInt(@intFromPtr(&context))))) {
0 => return,
else => |err| return @as(Error, @errSetCast(@errorFromInt(@as(u16, @intCast(err))))), // TODO don't hardcode u16
else => |err| return @as(Error, @errorCast(@errorFromInt(@as(u16, @intCast(err))))), // TODO don't hardcode u16
}
}

View File

@ -1444,7 +1444,7 @@ fn renderBuiltinCall(
const slice = tree.tokenSlice(builtin_token);
const rewrite_two_param_cast = params.len == 2 and for ([_][]const u8{
"@bitCast",
"@errSetCast",
"@errorCast",
"@floatCast",
"@intCast",
"@ptrCast",
@ -1505,6 +1505,8 @@ fn renderBuiltinCall(
try ais.writer().writeAll("@intFromPtr");
} else if (mem.eql(u8, slice, "@fabs")) {
try ais.writer().writeAll("@abs");
} else if (mem.eql(u8, slice, "@errSetCast")) {
try ais.writer().writeAll("@errorCast");
} else {
try renderToken(ais, tree, builtin_token, .none); // @name
}

View File

@ -8454,11 +8454,11 @@ fn builtinCall(
});
return rvalue(gz, ri, result, node);
},
.err_set_cast => {
.error_cast => {
try emitDbgNode(gz, node);
const result = try gz.addExtendedPayload(.err_set_cast, Zir.Inst.BinNode{
.lhs = try ri.rl.resultTypeForCast(gz, node, "@errSetCast"),
const result = try gz.addExtendedPayload(.error_cast, Zir.Inst.BinNode{
.lhs = try ri.rl.resultTypeForCast(gz, node, "@errorCast"),
.rhs = try expr(gz, scope, .{ .rl = .none }, params[0]),
.node = gz.nodeIndexToRelative(node),
});

View File

@ -945,7 +945,7 @@ fn builtinCall(astrl: *AstRlAnnotate, block: ?*Block, ri: ResultInfo, node: Ast.
.float_cast,
.int_cast,
.truncate,
.err_set_cast,
.error_cast,
.ptr_cast,
.align_cast,
.addrspace_cast,

View File

@ -43,7 +43,7 @@ pub const Tag = enum {
error_name,
error_return_trace,
int_from_error,
err_set_cast,
error_cast,
@"export",
@"extern",
fence,
@ -455,9 +455,9 @@ pub const list = list: {
},
},
.{
"@errSetCast",
"@errorCast",
.{
.tag = .err_set_cast,
.tag = .error_cast,
.eval_to_error = .always,
.param_count = 1,
},

View File

@ -1252,7 +1252,7 @@ fn analyzeBodyInner(
.wasm_memory_size => try sema.zirWasmMemorySize( block, extended),
.wasm_memory_grow => try sema.zirWasmMemoryGrow( block, extended),
.prefetch => try sema.zirPrefetch( block, extended),
.err_set_cast => try sema.zirErrSetCast( block, extended),
.error_cast => try sema.zirErrorCast( block, extended),
.await_nosuspend => try sema.zirAwaitNosuspend( block, extended),
.select => try sema.zirSelect( block, extended),
.int_from_error => try sema.zirIntFromError( block, extended),
@ -21747,17 +21747,31 @@ fn ptrFromIntVal(
};
}
fn zirErrSetCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!Air.Inst.Ref {
fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!Air.Inst.Ref {
const mod = sema.mod;
const ip = &mod.intern_pool;
const extra = sema.code.extraData(Zir.Inst.BinNode, extended.operand).data;
const src = LazySrcLoc.nodeOffset(extra.node);
const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = extra.node };
const dest_ty = try sema.resolveDestType(block, src, extra.lhs, .remove_eu_opt, "@errSetCast");
const base_dest_ty = try sema.resolveDestType(block, src, extra.lhs, .remove_opt, "@errorCast");
const operand = try sema.resolveInst(extra.rhs);
const operand_ty = sema.typeOf(operand);
try sema.checkErrorSetType(block, src, dest_ty);
try sema.checkErrorSetType(block, operand_src, operand_ty);
const base_operand_ty = sema.typeOf(operand);
const dest_tag = base_dest_ty.zigTypeTag(mod);
const operand_tag = base_operand_ty.zigTypeTag(mod);
if (dest_tag != operand_tag) {
return sema.fail(block, src, "expected source and destination types to match, found '{s}' and '{s}'", .{
@tagName(operand_tag), @tagName(dest_tag),
});
} else if (dest_tag != .ErrorSet and dest_tag != .ErrorUnion) {
return sema.fail(block, src, "expected error set or error union type, found '{s}'", .{@tagName(dest_tag)});
}
const dest_ty, const operand_ty = if (dest_tag == .ErrorUnion) .{
base_dest_ty.errorUnionSet(mod),
base_operand_ty.errorUnionSet(mod),
} else .{
base_dest_ty,
base_operand_ty,
};
// operand must be defined since it can be an invalid error value
const maybe_operand_val = try sema.resolveDefinedValue(block, operand_src, operand);
@ -21804,8 +21818,15 @@ fn zirErrSetCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstDat
}
if (maybe_operand_val) |val| {
if (!dest_ty.isAnyError(mod)) {
const error_name = mod.intern_pool.indexToKey(val.toIntern()).err.name;
if (!dest_ty.isAnyError(mod)) check: {
const operand_val = mod.intern_pool.indexToKey(val.toIntern());
var error_name: InternPool.NullTerminatedString = undefined;
if (dest_tag == .ErrorUnion) {
if (operand_val.error_union.val != .err_name) break :check;
error_name = operand_val.error_union.val.err_name;
} else {
error_name = operand_val.err.name;
}
if (!Type.errorSetHasFieldIp(ip, dest_ty.toIntern(), error_name)) {
const msg = msg: {
const msg = try sema.errMsg(
@ -21822,16 +21843,29 @@ fn zirErrSetCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstDat
}
}
return Air.internedToRef((try mod.getCoerced(val, dest_ty)).toIntern());
return Air.internedToRef((try mod.getCoerced(val, base_dest_ty)).toIntern());
}
try sema.requireRuntimeBlock(block, src, operand_src);
if (block.wantSafety() and !dest_ty.isAnyError(mod) and sema.mod.backendSupportsFeature(.error_set_has_value)) {
if (dest_tag == .ErrorUnion) {
const err_code = try sema.analyzeErrUnionCode(block, operand_src, operand);
const err_int = try block.addBitCast(Type.err_int, err_code);
const zero_u16 = Air.internedToRef(try mod.intern(.{
.int = .{ .ty = .u16_type, .storage = .{ .u64 = 0 } },
}));
const has_value = try block.addTyOp(.error_set_has_value, dest_ty, err_code);
const is_zero = try block.addBinOp(.cmp_eq, err_int, zero_u16);
const ok = try block.addBinOp(.bit_or, has_value, is_zero);
try sema.addSafetyCheck(block, src, ok, .invalid_error_code);
} else {
const err_int_inst = try block.addBitCast(Type.err_int, operand);
const ok = try block.addTyOp(.error_set_has_value, dest_ty, err_int_inst);
try sema.addSafetyCheck(block, src, ok, .invalid_error_code);
}
return block.addBitCast(dest_ty, operand);
}
return block.addBitCast(base_dest_ty, operand);
}
fn zirPtrCastFull(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!Air.Inst.Ref {
@ -22916,14 +22950,6 @@ fn checkIntOrVectorAllowComptime(
}
}
fn checkErrorSetType(sema: *Sema, block: *Block, src: LazySrcLoc, ty: Type) CompileError!void {
const mod = sema.mod;
switch (ty.zigTypeTag(mod)) {
.ErrorSet => return,
else => return sema.fail(block, src, "expected error set type, found '{}'", .{ty.fmt(mod)}),
}
}
const SimdBinOp = struct {
len: ?usize,
/// Coerced to `result_ty`.

View File

@ -1997,9 +1997,9 @@ pub const Inst = struct {
/// Implements `@setCold`.
/// `operand` is payload index to `UnNode`.
set_cold,
/// Implements the `@errSetCast` builtin.
/// Implements the `@errorCast` builtin.
/// `operand` is payload index to `BinNode`. `lhs` is dest type, `rhs` is operand.
err_set_cast,
error_cast,
/// `operand` is payload index to `UnNode`.
await_nosuspend,
/// Implements `@breakpoint`.

View File

@ -594,7 +594,7 @@ const Writer = struct {
.builtin_extern,
.c_define,
.err_set_cast,
.error_cast,
.wasm_memory_grow,
.prefetch,
.c_va_arg,

View File

@ -228,13 +228,29 @@ const Set1 = error{ A, B };
const Set2 = error{ A, C };
fn testExplicitErrorSetCast(set1: Set1) !void {
var x = @as(Set2, @errSetCast(set1));
var x = @as(Set2, @errorCast(set1));
try expect(@TypeOf(x) == Set2);
var y = @as(Set1, @errSetCast(x));
var y = @as(Set1, @errorCast(x));
try expect(@TypeOf(y) == Set1);
try expect(y == error.A);
}
test "@errorCast on error unions" {
const S = struct {
fn doTheTest() !void {
const casted: error{Bad}!i32 = @errorCast(retErrUnion());
try expect((try casted) == 1234);
}
fn retErrUnion() anyerror!i32 {
return 1234;
}
};
try S.doTheTest();
try comptime S.doTheTest();
}
test "comptime test error for empty error set" {
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO

View File

@ -2,7 +2,7 @@ const Set1 = error{ A, B };
const Set2 = error{ A, C };
comptime {
var x = Set1.B;
var y: Set2 = @errSetCast(x);
var y: Set2 = @errorCast(x);
_ = y;
}

View File

@ -8,7 +8,7 @@ const Set2 = error{
};
comptime {
var x = @intFromError(Set1.B);
var y: Set2 = @errSetCast(@errorFromInt(x));
var y: Set2 = @errorCast(@errorFromInt(x));
_ = y;
}

View File

@ -14,7 +14,7 @@ pub fn main() !void {
return error.TestFailed;
}
fn foo(set1: Set1) Set2 {
return @errSetCast(set1);
return @errorCast(set1);
}
// run
// backend=llvm