Sema: implement comparison analysis for non-numeric types

This commit is contained in:
Andrew Kelley 2021-04-07 12:15:05 -07:00
parent d9c25ec672
commit 18119aae30
3 changed files with 94 additions and 9 deletions

View File

@ -3776,9 +3776,13 @@ fn zirCmp(
const tracy = trace(@src());
defer tracy.end();
const mod = sema.mod;
const inst_data = sema.code.instructions.items(.data)[inst].pl_node;
const extra = sema.code.extraData(zir.Inst.Bin, inst_data.payload_index).data;
const src: LazySrcLoc = inst_data.src();
const lhs_src: LazySrcLoc = .{ .node_offset_bin_lhs = inst_data.src_node };
const rhs_src: LazySrcLoc = .{ .node_offset_bin_rhs = inst_data.src_node };
const lhs = try sema.resolveInst(extra.lhs);
const rhs = try sema.resolveInst(extra.rhs);
@ -3790,7 +3794,7 @@ fn zirCmp(
const rhs_ty_tag = rhs.ty.zigTypeTag();
if (is_equality_cmp and lhs_ty_tag == .Null and rhs_ty_tag == .Null) {
// null == null, null != null
return sema.mod.constBool(sema.arena, src, op == .eq);
return mod.constBool(sema.arena, src, op == .eq);
} else if (is_equality_cmp and
((lhs_ty_tag == .Null and rhs_ty_tag == .Optional) or
rhs_ty_tag == .Null and lhs_ty_tag == .Optional))
@ -3801,23 +3805,23 @@ fn zirCmp(
} else if (is_equality_cmp and
((lhs_ty_tag == .Null and rhs.ty.isCPtr()) or (rhs_ty_tag == .Null and lhs.ty.isCPtr())))
{
return sema.mod.fail(&block.base, src, "TODO implement C pointer cmp", .{});
return mod.fail(&block.base, src, "TODO implement C pointer cmp", .{});
} else if (lhs_ty_tag == .Null or rhs_ty_tag == .Null) {
const non_null_type = if (lhs_ty_tag == .Null) rhs.ty else lhs.ty;
return sema.mod.fail(&block.base, src, "comparison of '{}' with null", .{non_null_type});
return mod.fail(&block.base, src, "comparison of '{}' with null", .{non_null_type});
} else if (is_equality_cmp and
((lhs_ty_tag == .EnumLiteral and rhs_ty_tag == .Union) or
(rhs_ty_tag == .EnumLiteral and lhs_ty_tag == .Union)))
{
return sema.mod.fail(&block.base, src, "TODO implement equality comparison between a union's tag value and an enum literal", .{});
return mod.fail(&block.base, src, "TODO implement equality comparison between a union's tag value and an enum literal", .{});
} else if (lhs_ty_tag == .ErrorSet and rhs_ty_tag == .ErrorSet) {
if (!is_equality_cmp) {
return sema.mod.fail(&block.base, src, "{s} operator not allowed for errors", .{@tagName(op)});
return mod.fail(&block.base, src, "{s} operator not allowed for errors", .{@tagName(op)});
}
if (rhs.value()) |rval| {
if (lhs.value()) |lval| {
// TODO optimisation oppurtunity: evaluate if std.mem.eql is faster with the names, or calling to Module.getErrorValue to get the values and then compare them is faster
return sema.mod.constBool(sema.arena, src, std.mem.eql(u8, lval.castTag(.@"error").?.data.name, rval.castTag(.@"error").?.data.name) == (op == .eq));
return mod.constBool(sema.arena, src, std.mem.eql(u8, lval.castTag(.@"error").?.data.name, rval.castTag(.@"error").?.data.name) == (op == .eq));
}
}
try sema.requireRuntimeBlock(block, src);
@ -3829,11 +3833,30 @@ fn zirCmp(
return sema.cmpNumeric(block, src, lhs, rhs, op);
} else if (lhs_ty_tag == .Type and rhs_ty_tag == .Type) {
if (!is_equality_cmp) {
return sema.mod.fail(&block.base, src, "{s} operator not allowed for types", .{@tagName(op)});
return mod.fail(&block.base, src, "{s} operator not allowed for types", .{@tagName(op)});
}
return sema.mod.constBool(sema.arena, src, lhs.value().?.eql(rhs.value().?) == (op == .eq));
return mod.constBool(sema.arena, src, lhs.value().?.eql(rhs.value().?) == (op == .eq));
}
return sema.mod.fail(&block.base, src, "TODO implement more cmp analysis", .{});
const instructions = &[_]*Inst{ lhs, rhs };
const resolved_type = try sema.resolvePeerTypes(block, src, instructions);
if (!resolved_type.isSelfComparable(is_equality_cmp)) {
return mod.fail(&block.base, src, "operator not allowed for type '{}'", .{resolved_type});
}
const casted_lhs = try sema.coerce(block, resolved_type, lhs, lhs_src);
const casted_rhs = try sema.coerce(block, resolved_type, rhs, rhs_src);
try sema.requireRuntimeBlock(block, src); // TODO try to do it at comptime
const bool_type = Type.initTag(.bool); // TODO handle vectors
const tag: Inst.Tag = switch (op) {
.lt => .cmp_lt,
.lte => .cmp_lte,
.eq => .cmp_eq,
.gte => .cmp_gte,
.gt => .cmp_gt,
.neq => .cmp_neq,
};
return block.addBinOp(src, bool_type, tag, casted_lhs, casted_rhs);
}
fn zirTypeof(sema: *Sema, block: *Scope.Block, inst: zir.Inst.Index) InnerError!*Inst {

View File

@ -107,6 +107,42 @@ pub const Type = extern union {
}
}
pub fn isSelfComparable(ty: Type, is_equality_cmp: bool) bool {
return switch (ty.zigTypeTag()) {
.Int,
.Float,
.ComptimeFloat,
.ComptimeInt,
.Vector, // TODO some vectors require is_equality_cmp==true
=> true,
.Bool,
.Type,
.Void,
.ErrorSet,
.Fn,
.BoundFn,
.Opaque,
.AnyFrame,
.Enum,
.EnumLiteral,
=> is_equality_cmp,
.NoReturn,
.Array,
.Struct,
.Undefined,
.Null,
.ErrorUnion,
.Union,
.Frame,
=> false,
.Pointer => is_equality_cmp or ty.isCPtr(),
.Optional => is_equality_cmp and ty.isAbiPtr(),
};
}
pub fn initTag(comptime small_tag: Tag) Type {
comptime assert(@enumToInt(small_tag) < Tag.no_payload_count);
return .{ .tag_if_small_enough = @enumToInt(small_tag) };
@ -1583,6 +1619,11 @@ pub const Type = extern union {
}
}
/// Returns whether the type is represented as a pointer in the ABI.
pub fn isAbiPtr(self: Type) bool {
@panic("TODO implement this");
}
/// Asserts that the type is an error union.
pub fn errorUnionChild(self: Type) Type {
return switch (self.tag()) {

View File

@ -536,6 +536,27 @@ pub fn addCases(ctx: *TestContext) !void {
, "");
}
{
var case = ctx.exeFromCompiledC("enums", .{});
case.addCompareOutput(
\\const Number = enum { One, Two, Three };
\\
\\export fn main() c_int {
\\ var number1 = Number.One;
\\ var number2: Number = .Two;
\\ const number3 = @intToEnum(Number, 2);
\\ if (number1 == number2) return 1;
\\ if (number2 == number3) return 1;
\\ if (@enumToInt(number1) != 0) return 1;
\\ if (@enumToInt(number2) != 1) return 1;
\\ if (@enumToInt(number3) != 2) return 1;
\\ var x: Number = .Two;
\\ if (number2 != x) return 1;
\\ return 0;
\\}
, "");
}
ctx.c("empty start function", linux_x64,
\\export fn _start() noreturn {
\\ unreachable;