Sema: implement switch validation for enums

This commit is contained in:
Andrew Kelley 2021-04-07 16:39:10 -07:00
parent ccdba774c8
commit 8f28e26e7a
3 changed files with 238 additions and 25 deletions

View File

@ -2729,6 +2729,8 @@ fn analyzeSwitch(
src_node_offset: i32,
) InnerError!*Inst {
const gpa = sema.gpa;
const mod = sema.mod;
const special: struct { body: []const zir.Inst.Index, end: usize } = switch (special_prong) {
.none => .{ .body = &.{}, .end = extra_end },
.under, .@"else" => blk: {
@ -2748,14 +2750,14 @@ fn analyzeSwitch(
// Validate usage of '_' prongs.
if (special_prong == .under and !operand.ty.isExhaustiveEnum()) {
const msg = msg: {
const msg = try sema.mod.errMsg(
const msg = try mod.errMsg(
&block.base,
src,
"'_' prong only allowed when switching on non-exhaustive enums",
.{},
);
errdefer msg.destroy(gpa);
try sema.mod.errNote(
try mod.errNote(
&block.base,
special_prong_src,
msg,
@ -2764,14 +2766,121 @@ fn analyzeSwitch(
);
break :msg msg;
};
return sema.mod.failWithOwnedErrorMsg(&block.base, msg);
return mod.failWithOwnedErrorMsg(&block.base, msg);
}
// Validate for duplicate items, missing else prong, and invalid range.
switch (operand.ty.zigTypeTag()) {
.Enum => return sema.mod.fail(&block.base, src, "TODO validate switch .Enum", .{}),
.ErrorSet => return sema.mod.fail(&block.base, src, "TODO validate switch .ErrorSet", .{}),
.Union => return sema.mod.fail(&block.base, src, "TODO validate switch .Union", .{}),
.Enum => {
var seen_fields = try gpa.alloc(?AstGen.SwitchProngSrc, operand.ty.enumFieldCount());
defer gpa.free(seen_fields);
var extra_index: usize = special.end;
{
var scalar_i: u32 = 0;
while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
const item_ref = @intToEnum(zir.Inst.Ref, sema.code.extra[extra_index]);
extra_index += 1;
const body_len = sema.code.extra[extra_index];
extra_index += 1;
const body = sema.code.extra[extra_index..][0..body_len];
extra_index += body_len;
try sema.validateSwitchItemEnum(
block,
seen_fields,
item_ref,
src_node_offset,
.{ .scalar = scalar_i },
);
}
}
{
var multi_i: u32 = 0;
while (multi_i < multi_cases_len) : (multi_i += 1) {
const items_len = sema.code.extra[extra_index];
extra_index += 1;
const ranges_len = sema.code.extra[extra_index];
extra_index += 1;
const body_len = sema.code.extra[extra_index];
extra_index += 1;
const items = sema.code.refSlice(extra_index, items_len);
extra_index += items_len + body_len;
for (items) |item_ref, item_i| {
try sema.validateSwitchItemEnum(
block,
seen_fields,
item_ref,
src_node_offset,
.{ .multi = .{ .prong = multi_i, .item = @intCast(u32, item_i) } },
);
}
try sema.validateSwitchNoRange(block, ranges_len, operand.ty, src_node_offset);
}
}
const all_tags_handled = for (seen_fields) |seen_src| {
if (seen_src == null) break false;
} else true;
switch (special_prong) {
.none => {
if (!all_tags_handled) {
const msg = msg: {
const msg = try mod.errMsg(
&block.base,
src,
"switch must handle all possibilities",
.{},
);
errdefer msg.destroy(sema.gpa);
try mod.errNoteNonLazy(
operand.ty.declSrcLoc(),
msg,
"enum '{}' declared here",
.{operand.ty},
);
for (seen_fields) |seen_src, i| {
if (seen_src != null) continue;
const field_name = operand.ty.enumFieldName(i);
// TODO have this point to the tag decl instead of here
try mod.errNote(
&block.base,
src,
msg,
"unhandled enumeration value: '{s}",
.{field_name},
);
}
break :msg msg;
};
return mod.failWithOwnedErrorMsg(&block.base, msg);
}
},
.under => {
if (all_tags_handled) return mod.fail(
&block.base,
special_prong_src,
"unreachable '_' prong; all cases already handled",
.{},
);
},
.@"else" => {
if (all_tags_handled) return mod.fail(
&block.base,
special_prong_src,
"unreachable else prong; all cases already handled",
.{},
);
},
}
},
.ErrorSet => return mod.fail(&block.base, src, "TODO validate switch .ErrorSet", .{}),
.Union => return mod.fail(&block.base, src, "TODO validate switch .Union", .{}),
.Int, .ComptimeInt => {
var range_set = RangeSet.init(gpa);
defer range_set.deinit();
@ -2844,11 +2953,11 @@ fn analyzeSwitch(
var arena = std.heap.ArenaAllocator.init(gpa);
defer arena.deinit();
const min_int = try operand.ty.minInt(&arena, sema.mod.getTarget());
const max_int = try operand.ty.maxInt(&arena, sema.mod.getTarget());
const min_int = try operand.ty.minInt(&arena, mod.getTarget());
const max_int = try operand.ty.maxInt(&arena, mod.getTarget());
if (try range_set.spans(min_int, max_int)) {
if (special_prong == .@"else") {
return sema.mod.fail(
return mod.fail(
&block.base,
special_prong_src,
"unreachable else prong; all cases already handled",
@ -2859,7 +2968,7 @@ fn analyzeSwitch(
}
}
if (special_prong != .@"else") {
return sema.mod.fail(
return mod.fail(
&block.base,
src,
"switch must handle all possibilities",
@ -2922,7 +3031,7 @@ fn analyzeSwitch(
switch (special_prong) {
.@"else" => {
if (true_count + false_count == 2) {
return sema.mod.fail(
return mod.fail(
&block.base,
src,
"unreachable else prong; all cases already handled",
@ -2932,7 +3041,7 @@ fn analyzeSwitch(
},
.under, .none => {
if (true_count + false_count < 2) {
return sema.mod.fail(
return mod.fail(
&block.base,
src,
"switch must handle all possibilities",
@ -2944,7 +3053,7 @@ fn analyzeSwitch(
},
.EnumLiteral, .Void, .Fn, .Pointer, .Type => {
if (special_prong != .@"else") {
return sema.mod.fail(
return mod.fail(
&block.base,
src,
"else prong required when switching on type '{}'",
@ -3016,7 +3125,7 @@ fn analyzeSwitch(
.AnyFrame,
.ComptimeFloat,
.Float,
=> return sema.mod.fail(&block.base, operand_src, "invalid switch operand type '{}'", .{
=> return mod.fail(&block.base, operand_src, "invalid switch operand type '{}'", .{
operand.ty,
}),
}
@ -3291,7 +3400,7 @@ fn resolveSwitchItemVal(
switch_node_offset: i32,
switch_prong_src: AstGen.SwitchProngSrc,
range_expand: AstGen.SwitchProngSrc.RangeExpand,
) InnerError!Value {
) InnerError!TypedValue {
const item = try sema.resolveInst(item_ref);
// We have to avoid the other helper functions here because we cannot construct a LazySrcLoc
// because we only have the switch AST node. Only if we know for sure we need to report
@ -3301,7 +3410,7 @@ fn resolveSwitchItemVal(
const src = switch_prong_src.resolve(block.src_decl, switch_node_offset, range_expand);
return sema.failWithUseOfUndef(block, src);
}
return val;
return TypedValue{ .ty = item.ty, .val = val };
}
const src = switch_prong_src.resolve(block.src_decl, switch_node_offset, range_expand);
return sema.failWithNeededComptime(block, src);
@ -3316,8 +3425,8 @@ fn validateSwitchRange(
src_node_offset: i32,
switch_prong_src: AstGen.SwitchProngSrc,
) InnerError!void {
const first_val = try sema.resolveSwitchItemVal(block, first_ref, src_node_offset, switch_prong_src, .first);
const last_val = try sema.resolveSwitchItemVal(block, last_ref, src_node_offset, switch_prong_src, .last);
const first_val = (try sema.resolveSwitchItemVal(block, first_ref, src_node_offset, switch_prong_src, .first)).val;
const last_val = (try sema.resolveSwitchItemVal(block, last_ref, src_node_offset, switch_prong_src, .last)).val;
const maybe_prev_src = try range_set.add(first_val, last_val, switch_prong_src);
return sema.validateSwitchDupe(block, maybe_prev_src, switch_prong_src, src_node_offset);
}
@ -3330,11 +3439,46 @@ fn validateSwitchItem(
src_node_offset: i32,
switch_prong_src: AstGen.SwitchProngSrc,
) InnerError!void {
const item_val = try sema.resolveSwitchItemVal(block, item_ref, src_node_offset, switch_prong_src, .none);
const item_val = (try sema.resolveSwitchItemVal(block, item_ref, src_node_offset, switch_prong_src, .none)).val;
const maybe_prev_src = try range_set.add(item_val, item_val, switch_prong_src);
return sema.validateSwitchDupe(block, maybe_prev_src, switch_prong_src, src_node_offset);
}
fn validateSwitchItemEnum(
sema: *Sema,
block: *Scope.Block,
seen_fields: []?AstGen.SwitchProngSrc,
item_ref: zir.Inst.Ref,
src_node_offset: i32,
switch_prong_src: AstGen.SwitchProngSrc,
) InnerError!void {
const mod = sema.mod;
const item_tv = try sema.resolveSwitchItemVal(block, item_ref, src_node_offset, switch_prong_src, .none);
const field_index = item_tv.ty.enumTagFieldIndex(item_tv.val) orelse {
const msg = msg: {
const src = switch_prong_src.resolve(block.src_decl, src_node_offset, .none);
const msg = try mod.errMsg(
&block.base,
src,
"enum '{}' has no tag with value '{}'",
.{ item_tv.ty, item_tv.val },
);
errdefer msg.destroy(sema.gpa);
try mod.errNoteNonLazy(
item_tv.ty.declSrcLoc(),
msg,
"enum declared here",
.{},
);
break :msg msg;
};
return mod.failWithOwnedErrorMsg(&block.base, msg);
};
const maybe_prev_src = seen_fields[field_index];
seen_fields[field_index] = switch_prong_src;
return sema.validateSwitchDupe(block, maybe_prev_src, switch_prong_src, src_node_offset);
}
fn validateSwitchDupe(
sema: *Sema,
block: *Scope.Block,
@ -3343,17 +3487,18 @@ fn validateSwitchDupe(
src_node_offset: i32,
) InnerError!void {
const prev_prong_src = maybe_prev_src orelse return;
const mod = sema.mod;
const src = switch_prong_src.resolve(block.src_decl, src_node_offset, .none);
const prev_src = prev_prong_src.resolve(block.src_decl, src_node_offset, .none);
const msg = msg: {
const msg = try sema.mod.errMsg(
const msg = try mod.errMsg(
&block.base,
src,
"duplicate switch value",
.{},
);
errdefer msg.destroy(sema.gpa);
try sema.mod.errNote(
try mod.errNote(
&block.base,
prev_src,
msg,
@ -3362,7 +3507,7 @@ fn validateSwitchDupe(
);
break :msg msg;
};
return sema.mod.failWithOwnedErrorMsg(&block.base, msg);
return mod.failWithOwnedErrorMsg(&block.base, msg);
}
fn validateSwitchItemBool(
@ -3374,7 +3519,7 @@ fn validateSwitchItemBool(
src_node_offset: i32,
switch_prong_src: AstGen.SwitchProngSrc,
) InnerError!void {
const item_val = try sema.resolveSwitchItemVal(block, item_ref, src_node_offset, switch_prong_src, .none);
const item_val = (try sema.resolveSwitchItemVal(block, item_ref, src_node_offset, switch_prong_src, .none)).val;
if (item_val.toBool()) {
true_count.* += 1;
} else {
@ -3396,7 +3541,7 @@ fn validateSwitchItemSparse(
src_node_offset: i32,
switch_prong_src: AstGen.SwitchProngSrc,
) InnerError!void {
const item_val = try sema.resolveSwitchItemVal(block, item_ref, src_node_offset, switch_prong_src, .none);
const item_val = (try sema.resolveSwitchItemVal(block, item_ref, src_node_offset, switch_prong_src, .none)).val;
const entry = (try seen_values.fetchPut(item_val, switch_prong_src)) orelse return;
return sema.validateSwitchDupe(block, entry.value, switch_prong_src, src_node_offset);
}

View File

@ -2126,6 +2126,34 @@ pub const Type = extern union {
};
}
pub fn enumFieldCount(ty: Type) usize {
switch (ty.tag()) {
.enum_full, .enum_nonexhaustive => {
const enum_full = ty.cast(Payload.EnumFull).?.data;
return enum_full.fields.count();
},
.enum_simple => {
const enum_simple = ty.castTag(.enum_simple).?.data;
return enum_simple.fields.count();
},
else => unreachable,
}
}
pub fn enumFieldName(ty: Type, field_index: usize) []const u8 {
switch (ty.tag()) {
.enum_full, .enum_nonexhaustive => {
const enum_full = ty.cast(Payload.EnumFull).?.data;
return enum_full.fields.entries.items[field_index].key;
},
.enum_simple => {
const enum_simple = ty.castTag(.enum_simple).?.data;
return enum_simple.fields.entries.items[field_index].key;
},
else => unreachable,
}
}
pub fn enumFieldIndex(ty: Type, field_name: []const u8) ?usize {
switch (ty.tag()) {
.enum_full, .enum_nonexhaustive => {
@ -2140,6 +2168,42 @@ pub const Type = extern union {
}
}
/// Asserts `ty` is an enum. `enum_tag` can either be `enum_field_index` or
/// an integer which represents the enum value. Returns the field index in
/// declaration order, or `null` if `enum_tag` does not match any field.
pub fn enumTagFieldIndex(ty: Type, enum_tag: Value) ?usize {
if (enum_tag.castTag(.enum_field_index)) |payload| {
return @as(usize, payload.data);
}
const S = struct {
fn fieldWithRange(int_val: Value, end: usize) ?usize {
if (int_val.compareWithZero(.lt)) return null;
var end_payload: Value.Payload.U64 = .{
.base = .{ .tag = .int_u64 },
.data = end,
};
const end_val = Value.initPayload(&end_payload.base);
if (int_val.compare(.gte, end_val)) return null;
return int_val.toUnsignedInt();
}
};
switch (ty.tag()) {
.enum_full, .enum_nonexhaustive => {
const enum_full = ty.cast(Payload.EnumFull).?.data;
if (enum_full.values.count() == 0) {
return S.fieldWithRange(enum_tag, enum_full.fields.count());
} else {
return enum_full.values.getIndex(enum_tag);
}
},
.enum_simple => {
const enum_simple = ty.castTag(.enum_simple).?.data;
return S.fieldWithRange(enum_tag, enum_simple.fields.count());
},
else => unreachable,
}
}
pub fn declSrcLoc(ty: Type) Module.SrcLoc {
switch (ty.tag()) {
.enum_full, .enum_nonexhaustive => {

View File

@ -552,7 +552,11 @@ pub fn addCases(ctx: *TestContext) !void {
\\ if (@enumToInt(number3) != 2) return 1;
\\ var x: Number = .Two;
\\ if (number2 != x) return 1;
\\ return 0;
\\ switch (x) {
\\ .One => return 1,
\\ .Two => return 0,
\\ number3 => return 2,
\\ }
\\}
, "");
}