sema: refactor error set switch logic

This commit is contained in:
dweiller 2023-11-22 14:59:02 +11:00
parent 4136097566
commit b784f64a6e

View File

@ -11212,16 +11212,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
var case_vals = try std.ArrayListUnmanaged(Air.Inst.Ref).initCapacity(gpa, scalar_cases_len + 2 * multi_cases_len);
defer case_vals.deinit(gpa);
const Special = struct {
body: []const Zir.Inst.Index,
end: usize,
capture: Zir.Inst.SwitchBlock.ProngInfo.Capture,
is_inline: bool,
has_tag_capture: bool,
};
const special_prong = extra.data.bits.specialProng();
const special: Special = switch (special_prong) {
const special: SpecialProng = switch (special_prong) {
.none => .{
.body = &.{},
.end = header_extra_index,
@ -11401,150 +11393,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
);
}
},
.ErrorSet => {
var extra_index: usize = special.end;
{
var scalar_i: u32 = 0;
while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
const item_ref: Zir.Inst.Ref = @enumFromInt(sema.code.extra[extra_index]);
extra_index += 1;
const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]);
extra_index += 1 + info.body_len;
case_vals.appendAssumeCapacity(try sema.validateSwitchItemError(
block,
&seen_errors,
item_ref,
operand_ty,
src_node_offset,
.{ .scalar = scalar_i },
));
}
}
{
var multi_i: u32 = 0;
while (multi_i < multi_cases_len) : (multi_i += 1) {
const items_len = sema.code.extra[extra_index];
extra_index += 1;
const ranges_len = sema.code.extra[extra_index];
extra_index += 1;
const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]);
extra_index += 1;
const items = sema.code.refSlice(extra_index, items_len);
extra_index += items_len + info.body_len;
try case_vals.ensureUnusedCapacity(gpa, items.len);
for (items, 0..) |item_ref, item_i| {
case_vals.appendAssumeCapacity(try sema.validateSwitchItemError(
block,
&seen_errors,
item_ref,
operand_ty,
src_node_offset,
.{ .multi = .{ .prong = multi_i, .item = @intCast(item_i) } },
));
}
try sema.validateSwitchNoRange(block, ranges_len, operand_ty, src_node_offset);
}
}
switch (try sema.resolveInferredErrorSetTy(block, src, operand_ty.toIntern())) {
.anyerror_type => {
if (special_prong != .@"else") {
return sema.fail(
block,
src,
"else prong required when switching on type 'anyerror'",
.{},
);
}
else_error_ty = Type.anyerror;
},
else => |err_set_ty_index| else_validation: {
const error_names = ip.indexToKey(err_set_ty_index).error_set_type.names;
var maybe_msg: ?*Module.ErrorMsg = null;
errdefer if (maybe_msg) |msg| msg.destroy(sema.gpa);
for (error_names.get(ip)) |error_name| {
if (!seen_errors.contains(error_name) and special_prong != .@"else") {
const msg = maybe_msg orelse blk: {
maybe_msg = try sema.errMsg(
block,
src,
"switch must handle all possibilities",
.{},
);
break :blk maybe_msg.?;
};
try sema.errNote(
block,
src,
msg,
"unhandled error value: 'error.{}'",
.{error_name.fmt(ip)},
);
}
}
if (maybe_msg) |msg| {
maybe_msg = null;
try sema.addDeclaredHereNote(msg, operand_ty);
return sema.failWithOwnedErrorMsg(block, msg);
}
if (special_prong == .@"else" and
seen_errors.count() == error_names.len)
{
// In order to enable common patterns for generic code allow simple else bodies
// else => unreachable,
// else => return,
// else => |e| return e,
// even if all the possible errors were already handled.
const tags = sema.code.instructions.items(.tag);
for (special.body) |else_inst| switch (tags[@intFromEnum(else_inst)]) {
.dbg_block_begin,
.dbg_block_end,
.dbg_stmt,
.dbg_var_val,
.ret_type,
.as_node,
.ret_node,
.@"unreachable",
.@"defer",
.defer_err_code,
.err_union_code,
.ret_err_value_code,
.restore_err_ret_index,
.is_non_err,
.ret_is_non_err,
.condbr,
=> {},
else => break,
} else break :else_validation;
return sema.fail(
block,
special_prong_src,
"unreachable else prong; all cases already handled",
.{},
);
}
var names: InferredErrorSet.NameMap = .{};
try names.ensureUnusedCapacity(sema.arena, error_names.len);
for (error_names.get(ip)) |error_name| {
if (seen_errors.contains(error_name)) continue;
names.putAssumeCapacityNoClobber(error_name, {});
}
// No need to keep the hash map metadata correct; here we
// extract the (sorted) keys only.
else_error_ty = try mod.errorSetFromUnsortedNames(names.keys());
},
}
},
.ErrorSet => else_error_ty = try validateErrSetSwitch(
sema,
block,
&seen_errors,
&case_vals,
operand_ty,
inst_data,
scalar_cases_len,
multi_cases_len,
.{ .body = special.body, .end = special.end, .src = special_prong_src },
special_prong == .@"else",
),
.Int, .ComptimeInt => {
var extra_index: usize = special.end;
{
@ -11840,114 +11700,19 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
defer merges.deinit(gpa);
if (try sema.resolveDefinedValue(&child_block, src, operand)) |operand_val| {
const resolved_operand_val = try sema.resolveLazyValue(operand_val);
var extra_index: usize = special.end;
{
var scalar_i: usize = 0;
while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
extra_index += 1;
const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]);
extra_index += 1;
const body = sema.code.bodySlice(extra_index, info.body_len);
extra_index += info.body_len;
const item = case_vals.items[scalar_i];
const item_val = sema.resolveConstDefinedValue(&child_block, .unneeded, item, undefined) catch unreachable;
if (operand_val.eql(item_val, operand_ty, sema.mod)) {
if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand);
return spa.resolveProngComptime(
&child_block,
.normal,
body,
info.capture,
.{ .scalar_capture = @intCast(scalar_i) },
&.{item},
if (info.is_inline) operand else .none,
info.has_tag_capture,
merges,
);
}
}
}
{
var multi_i: usize = 0;
var case_val_idx: usize = scalar_cases_len;
while (multi_i < multi_cases_len) : (multi_i += 1) {
const items_len = sema.code.extra[extra_index];
extra_index += 1;
const ranges_len = sema.code.extra[extra_index];
extra_index += 1;
const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]);
extra_index += 1 + items_len;
const body = sema.code.bodySlice(extra_index + 2 * ranges_len, info.body_len);
const items = case_vals.items[case_val_idx..][0..items_len];
case_val_idx += items_len;
for (items) |item| {
// Validation above ensured these will succeed.
const item_val = sema.resolveConstDefinedValue(&child_block, .unneeded, item, undefined) catch unreachable;
if (operand_val.eql(item_val, operand_ty, sema.mod)) {
if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand);
return spa.resolveProngComptime(
&child_block,
.normal,
body,
info.capture,
.{ .multi_capture = @intCast(multi_i) },
items,
if (info.is_inline) operand else .none,
info.has_tag_capture,
merges,
);
}
}
var range_i: usize = 0;
while (range_i < ranges_len) : (range_i += 1) {
const range_items = case_vals.items[case_val_idx..][0..2];
extra_index += 2;
case_val_idx += 2;
// Validation above ensured these will succeed.
const first_val = sema.resolveConstDefinedValue(&child_block, .unneeded, range_items[0], undefined) catch unreachable;
const last_val = sema.resolveConstDefinedValue(&child_block, .unneeded, range_items[1], undefined) catch unreachable;
if ((try sema.compareAll(resolved_operand_val, .gte, first_val, operand_ty)) and
(try sema.compareAll(resolved_operand_val, .lte, last_val, operand_ty)))
{
if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand);
return spa.resolveProngComptime(
&child_block,
.normal,
body,
info.capture,
.{ .multi_capture = @intCast(multi_i) },
undefined, // case_vals may be undefined for ranges
if (info.is_inline) operand else .none,
info.has_tag_capture,
merges,
);
}
}
extra_index += info.body_len;
}
}
if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, special.body, operand);
if (empty_enum) {
return .void_value;
}
return spa.resolveProngComptime(
return resolveSwitchComptime(
sema,
spa,
&child_block,
.special,
special.body,
special.capture,
.special_capture,
undefined, // case_vals may be undefined for special prongs
if (special.is_inline) operand else .none,
special.has_tag_capture,
merges,
operand,
operand_val,
operand_ty,
special,
case_vals,
scalar_cases_len,
multi_cases_len,
err_set,
empty_enum,
);
}
@ -12593,6 +12358,140 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
return sema.analyzeBlockBody(block, src, &child_block, merges);
}
const SpecialProng = struct {
body: []const Zir.Inst.Index,
end: usize,
capture: Zir.Inst.SwitchBlock.ProngInfo.Capture,
is_inline: bool,
has_tag_capture: bool,
};
fn resolveSwitchComptime(
sema: *Sema,
spa: SwitchProngAnalysis,
child_block: *Block,
cond_operand: Air.Inst.Ref,
operand_val: Value,
operand_ty: Type,
special: SpecialProng,
case_vals: std.ArrayListUnmanaged(Air.Inst.Ref),
scalar_cases_len: u32,
multi_cases_len: u32,
err_set: bool,
empty_enum: bool,
) CompileError!Air.Inst.Ref {
const merges = &child_block.label.?.merges;
const resolved_operand_val = try sema.resolveLazyValue(operand_val);
var extra_index: usize = special.end;
{
var scalar_i: usize = 0;
while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
extra_index += 1;
const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]);
extra_index += 1;
const body = sema.code.bodySlice(extra_index, info.body_len);
extra_index += info.body_len;
const item = case_vals.items[scalar_i];
const item_val = sema.resolveConstDefinedValue(child_block, .unneeded, item, undefined) catch unreachable;
if (operand_val.eql(item_val, operand_ty, sema.mod)) {
if (err_set) try sema.maybeErrorUnwrapComptime(child_block, body, cond_operand);
return spa.resolveProngComptime(
child_block,
.normal,
body,
info.capture,
.{ .scalar_capture = @intCast(scalar_i) },
&.{item},
if (info.is_inline) cond_operand else .none,
info.has_tag_capture,
merges,
);
}
}
}
{
var multi_i: usize = 0;
var case_val_idx: usize = scalar_cases_len;
while (multi_i < multi_cases_len) : (multi_i += 1) {
const items_len = sema.code.extra[extra_index];
extra_index += 1;
const ranges_len = sema.code.extra[extra_index];
extra_index += 1;
const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]);
extra_index += 1 + items_len;
const body = sema.code.bodySlice(extra_index + 2 * ranges_len, info.body_len);
const items = case_vals.items[case_val_idx..][0..items_len];
case_val_idx += items_len;
for (items) |item| {
// Validation above ensured these will succeed.
const item_val = sema.resolveConstDefinedValue(child_block, .unneeded, item, undefined) catch unreachable;
if (operand_val.eql(item_val, operand_ty, sema.mod)) {
if (err_set) try sema.maybeErrorUnwrapComptime(child_block, body, cond_operand);
return spa.resolveProngComptime(
child_block,
.normal,
body,
info.capture,
.{ .multi_capture = @intCast(multi_i) },
items,
if (info.is_inline) cond_operand else .none,
info.has_tag_capture,
merges,
);
}
}
var range_i: usize = 0;
while (range_i < ranges_len) : (range_i += 1) {
const range_items = case_vals.items[case_val_idx..][0..2];
extra_index += 2;
case_val_idx += 2;
// Validation above ensured these will succeed.
const first_val = sema.resolveConstDefinedValue(child_block, .unneeded, range_items[0], undefined) catch unreachable;
const last_val = sema.resolveConstDefinedValue(child_block, .unneeded, range_items[1], undefined) catch unreachable;
if ((try sema.compareAll(resolved_operand_val, .gte, first_val, operand_ty)) and
(try sema.compareAll(resolved_operand_val, .lte, last_val, operand_ty)))
{
if (err_set) try sema.maybeErrorUnwrapComptime(child_block, body, cond_operand);
return spa.resolveProngComptime(
child_block,
.normal,
body,
info.capture,
.{ .multi_capture = @intCast(multi_i) },
undefined, // case_vals may be undefined for ranges
if (info.is_inline) cond_operand else .none,
info.has_tag_capture,
merges,
);
}
}
extra_index += info.body_len;
}
}
if (err_set) try sema.maybeErrorUnwrapComptime(child_block, special.body, cond_operand);
if (empty_enum) {
return .void_value;
}
return spa.resolveProngComptime(
child_block,
.special,
special.body,
special.capture,
.special_capture,
undefined, // case_vals may be undefined for special prongs
if (special.is_inline) cond_operand else .none,
special.has_tag_capture,
merges,
);
}
const RangeSetUnhandledIterator = struct {
mod: *Module,
cur: ?InternPool.Index,
@ -12710,6 +12609,168 @@ fn resolveSwitchItemVal(
return .{ .ref = new_item, .val = val.toIntern() };
}
fn validateErrSetSwitch(
sema: *Sema,
block: *Block,
seen_errors: *SwitchErrorSet,
case_vals: *std.ArrayListUnmanaged(Air.Inst.Ref),
operand_ty: Type,
inst_data: std.meta.FieldType(Zir.Inst.Data, .pl_node),
scalar_cases_len: u32,
multi_cases_len: u32,
else_case: struct { body: []const Zir.Inst.Index, end: usize, src: LazySrcLoc },
has_else: bool,
) CompileError!?Type {
const gpa = sema.gpa;
const mod = sema.mod;
const ip = &mod.intern_pool;
const src_node_offset = inst_data.src_node;
const src = inst_data.src();
var extra_index: usize = else_case.end;
{
var scalar_i: u32 = 0;
while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
const item_ref: Zir.Inst.Ref = @enumFromInt(sema.code.extra[extra_index]);
extra_index += 1;
const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]);
extra_index += 1 + info.body_len;
case_vals.appendAssumeCapacity(try sema.validateSwitchItemError(
block,
seen_errors,
item_ref,
operand_ty,
src_node_offset,
.{ .scalar = scalar_i },
));
}
}
{
var multi_i: u32 = 0;
while (multi_i < multi_cases_len) : (multi_i += 1) {
const items_len = sema.code.extra[extra_index];
extra_index += 1;
const ranges_len = sema.code.extra[extra_index];
extra_index += 1;
const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]);
extra_index += 1;
const items = sema.code.refSlice(extra_index, items_len);
extra_index += items_len + info.body_len;
try case_vals.ensureUnusedCapacity(gpa, items.len);
for (items, 0..) |item_ref, item_i| {
case_vals.appendAssumeCapacity(try sema.validateSwitchItemError(
block,
seen_errors,
item_ref,
operand_ty,
src_node_offset,
.{ .multi = .{ .prong = multi_i, .item = @intCast(item_i) } },
));
}
try sema.validateSwitchNoRange(block, ranges_len, operand_ty, src_node_offset);
}
}
switch (try sema.resolveInferredErrorSetTy(block, src, operand_ty.toIntern())) {
.anyerror_type => {
if (!has_else) {
return sema.fail(
block,
src,
"else prong required when switching on type 'anyerror'",
.{},
);
}
return Type.anyerror;
},
else => |err_set_ty_index| else_validation: {
const error_names = ip.indexToKey(err_set_ty_index).error_set_type.names;
var maybe_msg: ?*Module.ErrorMsg = null;
errdefer if (maybe_msg) |msg| msg.destroy(sema.gpa);
for (error_names.get(ip)) |error_name| {
if (!seen_errors.contains(error_name) and !has_else) {
const msg = maybe_msg orelse blk: {
maybe_msg = try sema.errMsg(
block,
src,
"switch must handle all possibilities",
.{},
);
break :blk maybe_msg.?;
};
try sema.errNote(
block,
src,
msg,
"unhandled error value: 'error.{}'",
.{error_name.fmt(ip)},
);
}
}
if (maybe_msg) |msg| {
maybe_msg = null;
try sema.addDeclaredHereNote(msg, operand_ty);
return sema.failWithOwnedErrorMsg(block, msg);
}
if (has_else and seen_errors.count() == error_names.len) {
// In order to enable common patterns for generic code allow simple else bodies
// else => unreachable,
// else => return,
// else => |e| return e,
// even if all the possible errors were already handled.
const tags = sema.code.instructions.items(.tag);
for (else_case.body) |else_inst| switch (tags[@intFromEnum(else_inst)]) {
.dbg_block_begin,
.dbg_block_end,
.dbg_stmt,
.dbg_var_val,
.ret_type,
.as_node,
.ret_node,
.@"unreachable",
.@"defer",
.defer_err_code,
.err_union_code,
.ret_err_value_code,
.restore_err_ret_index,
.is_non_err,
.ret_is_non_err,
.condbr,
=> {},
else => break,
} else break :else_validation;
return sema.fail(
block,
else_case.src,
"unreachable else prong; all cases already handled",
.{},
);
}
var names: InferredErrorSet.NameMap = .{};
try names.ensureUnusedCapacity(sema.arena, error_names.len);
for (error_names.get(ip)) |error_name| {
if (seen_errors.contains(error_name)) continue;
names.putAssumeCapacityNoClobber(error_name, {});
}
// No need to keep the hash map metadata correct; here we
// extract the (sorted) keys only.
return try mod.errorSetFromUnsortedNames(names.keys());
},
}
return null;
}
fn validateSwitchRange(
sema: *Sema,
block: *Block,