mirror of
https://github.com/ziglang/zig.git
synced 2025-12-30 18:13:19 +00:00
Sema: improved AIR when one operand of bool cmp is known
When doing `x == true` or `x == false` it is now lowered as either a no-op or a not, respectively, rather than a cmp instruction. This commit also extracts a zirCmpEq function out from zirCmp, reducing the amount of branching (on is_equality_cmp) in both functions.
This commit is contained in:
parent
507dc1f2e7
commit
6e78c007df
226
src/Sema.zig
226
src/Sema.zig
@ -193,12 +193,12 @@ pub fn analyzeBody(
|
||||
.call_compile_time => try sema.zirCall(block, inst, .compile_time, false),
|
||||
.call_nosuspend => try sema.zirCall(block, inst, .no_async, false),
|
||||
.call_async => try sema.zirCall(block, inst, .async_kw, false),
|
||||
.cmp_eq => try sema.zirCmp(block, inst, .eq),
|
||||
.cmp_gt => try sema.zirCmp(block, inst, .gt),
|
||||
.cmp_gte => try sema.zirCmp(block, inst, .gte),
|
||||
.cmp_lt => try sema.zirCmp(block, inst, .lt),
|
||||
.cmp_lte => try sema.zirCmp(block, inst, .lte),
|
||||
.cmp_neq => try sema.zirCmp(block, inst, .neq),
|
||||
.cmp_eq => try sema.zirCmpEq(block, inst, .eq, .cmp_eq),
|
||||
.cmp_gte => try sema.zirCmp(block, inst, .gte),
|
||||
.cmp_gt => try sema.zirCmp(block, inst, .gt),
|
||||
.cmp_neq => try sema.zirCmpEq(block, inst, .neq, .cmp_neq),
|
||||
.coerce_result_ptr => try sema.zirCoerceResultPtr(block, inst),
|
||||
.decl_ref => try sema.zirDeclRef(block, inst),
|
||||
.decl_val => try sema.zirDeclVal(block, inst),
|
||||
@ -5040,6 +5040,97 @@ fn zirAsm(
|
||||
return asm_air;
|
||||
}
|
||||
|
||||
/// Only called for equality operators. See also `zirCmp`.
|
||||
fn zirCmpEq(
|
||||
sema: *Sema,
|
||||
block: *Scope.Block,
|
||||
inst: Zir.Inst.Index,
|
||||
op: std.math.CompareOperator,
|
||||
air_tag: Air.Inst.Tag,
|
||||
) CompileError!Air.Inst.Ref {
|
||||
const tracy = trace(@src());
|
||||
defer tracy.end();
|
||||
|
||||
const mod = sema.mod;
|
||||
const inst_data = sema.code.instructions.items(.data)[inst].pl_node;
|
||||
const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
|
||||
const src: LazySrcLoc = inst_data.src();
|
||||
const lhs_src: LazySrcLoc = .{ .node_offset_bin_lhs = inst_data.src_node };
|
||||
const rhs_src: LazySrcLoc = .{ .node_offset_bin_rhs = inst_data.src_node };
|
||||
const lhs = sema.resolveInst(extra.lhs);
|
||||
const rhs = sema.resolveInst(extra.rhs);
|
||||
|
||||
const lhs_ty = sema.typeOf(lhs);
|
||||
const rhs_ty = sema.typeOf(rhs);
|
||||
const lhs_ty_tag = lhs_ty.zigTypeTag();
|
||||
const rhs_ty_tag = rhs_ty.zigTypeTag();
|
||||
if (lhs_ty_tag == .Null and rhs_ty_tag == .Null) {
|
||||
// null == null, null != null
|
||||
if (op == .eq) {
|
||||
return Air.Inst.Ref.bool_true;
|
||||
} else {
|
||||
return Air.Inst.Ref.bool_false;
|
||||
}
|
||||
}
|
||||
if (((lhs_ty_tag == .Null and rhs_ty_tag == .Optional) or
|
||||
rhs_ty_tag == .Null and lhs_ty_tag == .Optional))
|
||||
{
|
||||
// comparing null with optionals
|
||||
const opt_operand = if (lhs_ty_tag == .Optional) lhs else rhs;
|
||||
return sema.analyzeIsNull(block, src, opt_operand, op == .neq);
|
||||
}
|
||||
if (((lhs_ty_tag == .Null and rhs_ty.isCPtr()) or (rhs_ty_tag == .Null and lhs_ty.isCPtr()))) {
|
||||
return mod.fail(&block.base, src, "TODO implement C pointer cmp", .{});
|
||||
}
|
||||
if (lhs_ty_tag == .Null or rhs_ty_tag == .Null) {
|
||||
const non_null_type = if (lhs_ty_tag == .Null) rhs_ty else lhs_ty;
|
||||
return mod.fail(&block.base, src, "comparison of '{}' with null", .{non_null_type});
|
||||
}
|
||||
if (((lhs_ty_tag == .EnumLiteral and rhs_ty_tag == .Union) or
|
||||
(rhs_ty_tag == .EnumLiteral and lhs_ty_tag == .Union)))
|
||||
{
|
||||
return mod.fail(&block.base, src, "TODO implement equality comparison between a union's tag value and an enum literal", .{});
|
||||
}
|
||||
if (lhs_ty_tag == .ErrorSet and rhs_ty_tag == .ErrorSet) {
|
||||
const runtime_src: LazySrcLoc = src: {
|
||||
if (try sema.resolveMaybeUndefVal(block, lhs_src, lhs)) |lval| {
|
||||
if (try sema.resolveMaybeUndefVal(block, rhs_src, rhs)) |rval| {
|
||||
if (lval.isUndef() or rval.isUndef()) {
|
||||
return sema.addConstUndef(Type.initTag(.bool));
|
||||
}
|
||||
// TODO optimisation opportunity: evaluate if mem.eql is faster with the names,
|
||||
// or calling to Module.getErrorValue to get the values and then compare them is
|
||||
// faster.
|
||||
const lhs_name = lval.castTag(.@"error").?.data.name;
|
||||
const rhs_name = rval.castTag(.@"error").?.data.name;
|
||||
if (mem.eql(u8, lhs_name, rhs_name) == (op == .eq)) {
|
||||
return Air.Inst.Ref.bool_true;
|
||||
} else {
|
||||
return Air.Inst.Ref.bool_false;
|
||||
}
|
||||
} else {
|
||||
break :src rhs_src;
|
||||
}
|
||||
} else {
|
||||
break :src lhs_src;
|
||||
}
|
||||
};
|
||||
try sema.requireRuntimeBlock(block, runtime_src);
|
||||
return block.addBinOp(air_tag, lhs, rhs);
|
||||
}
|
||||
if (lhs_ty_tag == .Type and rhs_ty_tag == .Type) {
|
||||
const lhs_as_type = try sema.analyzeAsType(block, lhs_src, lhs);
|
||||
const rhs_as_type = try sema.analyzeAsType(block, rhs_src, rhs);
|
||||
if (lhs_as_type.eql(rhs_as_type) == (op == .eq)) {
|
||||
return Air.Inst.Ref.bool_true;
|
||||
} else {
|
||||
return Air.Inst.Ref.bool_false;
|
||||
}
|
||||
}
|
||||
return sema.analyzeCmp(block, src, lhs, rhs, op, lhs_src, rhs_src, true);
|
||||
}
|
||||
|
||||
/// Only called for non-equality operators. See also `zirCmpEq`.
|
||||
fn zirCmp(
|
||||
sema: *Sema,
|
||||
block: *Scope.Block,
|
||||
@ -5049,8 +5140,6 @@ fn zirCmp(
|
||||
const tracy = trace(@src());
|
||||
defer tracy.end();
|
||||
|
||||
const mod = sema.mod;
|
||||
|
||||
const inst_data = sema.code.instructions.items(.data)[inst].pl_node;
|
||||
const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
|
||||
const src: LazySrcLoc = inst_data.src();
|
||||
@ -5058,87 +5147,34 @@ fn zirCmp(
|
||||
const rhs_src: LazySrcLoc = .{ .node_offset_bin_rhs = inst_data.src_node };
|
||||
const lhs = sema.resolveInst(extra.lhs);
|
||||
const rhs = sema.resolveInst(extra.rhs);
|
||||
return sema.analyzeCmp(block, src, lhs, rhs, op, lhs_src, rhs_src, false);
|
||||
}
|
||||
|
||||
const is_equality_cmp = switch (op) {
|
||||
.eq, .neq => true,
|
||||
else => false,
|
||||
};
|
||||
fn analyzeCmp(
|
||||
sema: *Sema,
|
||||
block: *Scope.Block,
|
||||
src: LazySrcLoc,
|
||||
lhs: Air.Inst.Ref,
|
||||
rhs: Air.Inst.Ref,
|
||||
op: std.math.CompareOperator,
|
||||
lhs_src: LazySrcLoc,
|
||||
rhs_src: LazySrcLoc,
|
||||
is_equality_cmp: bool,
|
||||
) CompileError!Air.Inst.Ref {
|
||||
const lhs_ty = sema.typeOf(lhs);
|
||||
const rhs_ty = sema.typeOf(rhs);
|
||||
const lhs_ty_tag = lhs_ty.zigTypeTag();
|
||||
const rhs_ty_tag = rhs_ty.zigTypeTag();
|
||||
if (is_equality_cmp and lhs_ty_tag == .Null and rhs_ty_tag == .Null) {
|
||||
// null == null, null != null
|
||||
if (op == .eq) {
|
||||
return Air.Inst.Ref.bool_true;
|
||||
} else {
|
||||
return Air.Inst.Ref.bool_false;
|
||||
}
|
||||
} else if (is_equality_cmp and
|
||||
((lhs_ty_tag == .Null and rhs_ty_tag == .Optional) or
|
||||
rhs_ty_tag == .Null and lhs_ty_tag == .Optional))
|
||||
{
|
||||
// comparing null with optionals
|
||||
const opt_operand = if (lhs_ty_tag == .Optional) lhs else rhs;
|
||||
return sema.analyzeIsNull(block, src, opt_operand, op == .neq);
|
||||
} else if (is_equality_cmp and
|
||||
((lhs_ty_tag == .Null and rhs_ty.isCPtr()) or (rhs_ty_tag == .Null and lhs_ty.isCPtr())))
|
||||
{
|
||||
return mod.fail(&block.base, src, "TODO implement C pointer cmp", .{});
|
||||
} else if (lhs_ty_tag == .Null or rhs_ty_tag == .Null) {
|
||||
const non_null_type = if (lhs_ty_tag == .Null) rhs_ty else lhs_ty;
|
||||
return mod.fail(&block.base, src, "comparison of '{}' with null", .{non_null_type});
|
||||
} else if (is_equality_cmp and
|
||||
((lhs_ty_tag == .EnumLiteral and rhs_ty_tag == .Union) or
|
||||
(rhs_ty_tag == .EnumLiteral and lhs_ty_tag == .Union)))
|
||||
{
|
||||
return mod.fail(&block.base, src, "TODO implement equality comparison between a union's tag value and an enum literal", .{});
|
||||
} else if (lhs_ty_tag == .ErrorSet and rhs_ty_tag == .ErrorSet) {
|
||||
if (!is_equality_cmp) {
|
||||
return mod.fail(&block.base, src, "{s} operator not allowed for errors", .{@tagName(op)});
|
||||
}
|
||||
if (try sema.resolveMaybeUndefVal(block, lhs_src, lhs)) |lval| {
|
||||
if (try sema.resolveMaybeUndefVal(block, rhs_src, rhs)) |rval| {
|
||||
if (lval.isUndef() or rval.isUndef()) {
|
||||
return sema.addConstUndef(Type.initTag(.bool));
|
||||
}
|
||||
// TODO optimisation opportunity: evaluate if mem.eql is faster with the names,
|
||||
// or calling to Module.getErrorValue to get the values and then compare them is
|
||||
// faster.
|
||||
const lhs_name = lval.castTag(.@"error").?.data.name;
|
||||
const rhs_name = rval.castTag(.@"error").?.data.name;
|
||||
if (mem.eql(u8, lhs_name, rhs_name) == (op == .eq)) {
|
||||
return Air.Inst.Ref.bool_true;
|
||||
} else {
|
||||
return Air.Inst.Ref.bool_false;
|
||||
}
|
||||
}
|
||||
}
|
||||
try sema.requireRuntimeBlock(block, src);
|
||||
const tag: Air.Inst.Tag = if (op == .eq) .cmp_eq else .cmp_neq;
|
||||
return block.addBinOp(tag, lhs, rhs);
|
||||
} else if (lhs_ty.isNumeric() and rhs_ty.isNumeric()) {
|
||||
if (lhs_ty.isNumeric() and rhs_ty.isNumeric()) {
|
||||
// This operation allows any combination of integer and float types, regardless of the
|
||||
// signed-ness, comptime-ness, and bit-width. So peer type resolution is incorrect for
|
||||
// numeric types.
|
||||
return sema.cmpNumeric(block, src, lhs, rhs, op, lhs_src, rhs_src);
|
||||
} else if (lhs_ty_tag == .Type and rhs_ty_tag == .Type) {
|
||||
if (!is_equality_cmp) {
|
||||
return mod.fail(&block.base, src, "{s} operator not allowed for types", .{@tagName(op)});
|
||||
}
|
||||
const lhs_as_type = try sema.analyzeAsType(block, lhs_src, lhs);
|
||||
const rhs_as_type = try sema.analyzeAsType(block, rhs_src, rhs);
|
||||
if (lhs_as_type.eql(rhs_as_type) == (op == .eq)) {
|
||||
return Air.Inst.Ref.bool_true;
|
||||
} else {
|
||||
return Air.Inst.Ref.bool_false;
|
||||
}
|
||||
}
|
||||
|
||||
const instructions = &[_]Air.Inst.Ref{ lhs, rhs };
|
||||
const resolved_type = try sema.resolvePeerTypes(block, src, instructions);
|
||||
if (!resolved_type.isSelfComparable(is_equality_cmp)) {
|
||||
return mod.fail(&block.base, src, "operator not allowed for type '{}'", .{resolved_type});
|
||||
return sema.mod.fail(&block.base, src, "{s} operator not allowed for type '{}'", .{
|
||||
@tagName(op), resolved_type,
|
||||
});
|
||||
}
|
||||
|
||||
const casted_lhs = try sema.coerce(block, resolved_type, lhs, lhs_src);
|
||||
@ -5146,19 +5182,31 @@ fn zirCmp(
|
||||
|
||||
const runtime_src: LazySrcLoc = src: {
|
||||
if (try sema.resolveMaybeUndefVal(block, lhs_src, casted_lhs)) |lhs_val| {
|
||||
if (lhs_val.isUndef()) return sema.addConstUndef(resolved_type);
|
||||
if (try sema.resolveMaybeUndefVal(block, rhs_src, casted_rhs)) |rhs_val| {
|
||||
if (lhs_val.isUndef() or rhs_val.isUndef()) {
|
||||
return sema.addConstUndef(resolved_type);
|
||||
}
|
||||
if (rhs_val.isUndef()) return sema.addConstUndef(resolved_type);
|
||||
|
||||
if (lhs_val.compare(op, rhs_val, resolved_type)) {
|
||||
return Air.Inst.Ref.bool_true;
|
||||
} else {
|
||||
return Air.Inst.Ref.bool_false;
|
||||
}
|
||||
} else {
|
||||
if (resolved_type.zigTypeTag() == .Bool) {
|
||||
// We can lower bool eq/neq more efficiently.
|
||||
return sema.runtimeBoolCmp(block, op, casted_rhs, lhs_val.toBool(), rhs_src);
|
||||
}
|
||||
break :src rhs_src;
|
||||
}
|
||||
} else {
|
||||
// For bools, we still check the other operand, because we can lower
|
||||
// bool eq/neq more efficiently.
|
||||
if (resolved_type.zigTypeTag() == .Bool) {
|
||||
if (try sema.resolveMaybeUndefVal(block, rhs_src, casted_rhs)) |rhs_val| {
|
||||
if (rhs_val.isUndef()) return sema.addConstUndef(resolved_type);
|
||||
return sema.runtimeBoolCmp(block, op, casted_lhs, rhs_val.toBool(), lhs_src);
|
||||
}
|
||||
}
|
||||
break :src lhs_src;
|
||||
}
|
||||
};
|
||||
@ -5176,6 +5224,26 @@ fn zirCmp(
|
||||
return block.addBinOp(tag, casted_lhs, casted_rhs);
|
||||
}
|
||||
|
||||
/// cmp_eq (x, false) => not(x)
|
||||
/// cmp_eq (x, true ) => x
|
||||
/// cmp_neq(x, false) => x
|
||||
/// cmp_neq(x, true ) => not(x)
|
||||
fn runtimeBoolCmp(
|
||||
sema: *Sema,
|
||||
block: *Scope.Block,
|
||||
op: std.math.CompareOperator,
|
||||
lhs: Air.Inst.Ref,
|
||||
rhs: bool,
|
||||
runtime_src: LazySrcLoc,
|
||||
) CompileError!Air.Inst.Ref {
|
||||
if ((op == .neq) == rhs) {
|
||||
try sema.requireRuntimeBlock(block, runtime_src);
|
||||
return block.addTyOp(.not, Type.initTag(.bool), lhs);
|
||||
} else {
|
||||
return lhs;
|
||||
}
|
||||
}
|
||||
|
||||
fn zirSizeOf(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
|
||||
const inst_data = sema.code.instructions.items(.data)[inst].un_node;
|
||||
const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node };
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user