From 6e78c007dff96de98c44c52da890cdae3d6e1389 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 30 Jul 2021 17:40:30 -0700 Subject: [PATCH] Sema: improved AIR when one operand of bool cmp is known When doing `x == true` or `x == false` it is now lowered as either a no-op or a not, respectively, rather than a cmp instruction. This commit also extracts a zirCmpEq function out from zirCmp, reducing the amount of branching (on is_equality_cmp) in both functions. --- src/Sema.zig | 226 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 147 insertions(+), 79 deletions(-) diff --git a/src/Sema.zig b/src/Sema.zig index 97a6f8323e..4fa59c4744 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -193,12 +193,12 @@ pub fn analyzeBody( .call_compile_time => try sema.zirCall(block, inst, .compile_time, false), .call_nosuspend => try sema.zirCall(block, inst, .no_async, false), .call_async => try sema.zirCall(block, inst, .async_kw, false), - .cmp_eq => try sema.zirCmp(block, inst, .eq), - .cmp_gt => try sema.zirCmp(block, inst, .gt), - .cmp_gte => try sema.zirCmp(block, inst, .gte), .cmp_lt => try sema.zirCmp(block, inst, .lt), .cmp_lte => try sema.zirCmp(block, inst, .lte), - .cmp_neq => try sema.zirCmp(block, inst, .neq), + .cmp_eq => try sema.zirCmpEq(block, inst, .eq, .cmp_eq), + .cmp_gte => try sema.zirCmp(block, inst, .gte), + .cmp_gt => try sema.zirCmp(block, inst, .gt), + .cmp_neq => try sema.zirCmpEq(block, inst, .neq, .cmp_neq), .coerce_result_ptr => try sema.zirCoerceResultPtr(block, inst), .decl_ref => try sema.zirDeclRef(block, inst), .decl_val => try sema.zirDeclVal(block, inst), @@ -5040,6 +5040,97 @@ fn zirAsm( return asm_air; } +/// Only called for equality operators. See also `zirCmp`. +fn zirCmpEq( + sema: *Sema, + block: *Scope.Block, + inst: Zir.Inst.Index, + op: std.math.CompareOperator, + air_tag: Air.Inst.Tag, +) CompileError!Air.Inst.Ref { + const tracy = trace(@src()); + defer tracy.end(); + + const mod = sema.mod; + const inst_data = sema.code.instructions.items(.data)[inst].pl_node; + const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data; + const src: LazySrcLoc = inst_data.src(); + const lhs_src: LazySrcLoc = .{ .node_offset_bin_lhs = inst_data.src_node }; + const rhs_src: LazySrcLoc = .{ .node_offset_bin_rhs = inst_data.src_node }; + const lhs = sema.resolveInst(extra.lhs); + const rhs = sema.resolveInst(extra.rhs); + + const lhs_ty = sema.typeOf(lhs); + const rhs_ty = sema.typeOf(rhs); + const lhs_ty_tag = lhs_ty.zigTypeTag(); + const rhs_ty_tag = rhs_ty.zigTypeTag(); + if (lhs_ty_tag == .Null and rhs_ty_tag == .Null) { + // null == null, null != null + if (op == .eq) { + return Air.Inst.Ref.bool_true; + } else { + return Air.Inst.Ref.bool_false; + } + } + if (((lhs_ty_tag == .Null and rhs_ty_tag == .Optional) or + rhs_ty_tag == .Null and lhs_ty_tag == .Optional)) + { + // comparing null with optionals + const opt_operand = if (lhs_ty_tag == .Optional) lhs else rhs; + return sema.analyzeIsNull(block, src, opt_operand, op == .neq); + } + if (((lhs_ty_tag == .Null and rhs_ty.isCPtr()) or (rhs_ty_tag == .Null and lhs_ty.isCPtr()))) { + return mod.fail(&block.base, src, "TODO implement C pointer cmp", .{}); + } + if (lhs_ty_tag == .Null or rhs_ty_tag == .Null) { + const non_null_type = if (lhs_ty_tag == .Null) rhs_ty else lhs_ty; + return mod.fail(&block.base, src, "comparison of '{}' with null", .{non_null_type}); + } + if (((lhs_ty_tag == .EnumLiteral and rhs_ty_tag == .Union) or + (rhs_ty_tag == .EnumLiteral and lhs_ty_tag == .Union))) + { + return mod.fail(&block.base, src, "TODO implement equality comparison between a union's tag value and an enum literal", .{}); + } + if (lhs_ty_tag == .ErrorSet and rhs_ty_tag == .ErrorSet) { + const runtime_src: LazySrcLoc = src: { + if (try sema.resolveMaybeUndefVal(block, lhs_src, lhs)) |lval| { + if (try sema.resolveMaybeUndefVal(block, rhs_src, rhs)) |rval| { + if (lval.isUndef() or rval.isUndef()) { + return sema.addConstUndef(Type.initTag(.bool)); + } + // TODO optimisation opportunity: evaluate if mem.eql is faster with the names, + // or calling to Module.getErrorValue to get the values and then compare them is + // faster. + const lhs_name = lval.castTag(.@"error").?.data.name; + const rhs_name = rval.castTag(.@"error").?.data.name; + if (mem.eql(u8, lhs_name, rhs_name) == (op == .eq)) { + return Air.Inst.Ref.bool_true; + } else { + return Air.Inst.Ref.bool_false; + } + } else { + break :src rhs_src; + } + } else { + break :src lhs_src; + } + }; + try sema.requireRuntimeBlock(block, runtime_src); + return block.addBinOp(air_tag, lhs, rhs); + } + if (lhs_ty_tag == .Type and rhs_ty_tag == .Type) { + const lhs_as_type = try sema.analyzeAsType(block, lhs_src, lhs); + const rhs_as_type = try sema.analyzeAsType(block, rhs_src, rhs); + if (lhs_as_type.eql(rhs_as_type) == (op == .eq)) { + return Air.Inst.Ref.bool_true; + } else { + return Air.Inst.Ref.bool_false; + } + } + return sema.analyzeCmp(block, src, lhs, rhs, op, lhs_src, rhs_src, true); +} + +/// Only called for non-equality operators. See also `zirCmpEq`. fn zirCmp( sema: *Sema, block: *Scope.Block, @@ -5049,8 +5140,6 @@ fn zirCmp( const tracy = trace(@src()); defer tracy.end(); - const mod = sema.mod; - const inst_data = sema.code.instructions.items(.data)[inst].pl_node; const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data; const src: LazySrcLoc = inst_data.src(); @@ -5058,87 +5147,34 @@ fn zirCmp( const rhs_src: LazySrcLoc = .{ .node_offset_bin_rhs = inst_data.src_node }; const lhs = sema.resolveInst(extra.lhs); const rhs = sema.resolveInst(extra.rhs); + return sema.analyzeCmp(block, src, lhs, rhs, op, lhs_src, rhs_src, false); +} - const is_equality_cmp = switch (op) { - .eq, .neq => true, - else => false, - }; +fn analyzeCmp( + sema: *Sema, + block: *Scope.Block, + src: LazySrcLoc, + lhs: Air.Inst.Ref, + rhs: Air.Inst.Ref, + op: std.math.CompareOperator, + lhs_src: LazySrcLoc, + rhs_src: LazySrcLoc, + is_equality_cmp: bool, +) CompileError!Air.Inst.Ref { const lhs_ty = sema.typeOf(lhs); const rhs_ty = sema.typeOf(rhs); - const lhs_ty_tag = lhs_ty.zigTypeTag(); - const rhs_ty_tag = rhs_ty.zigTypeTag(); - if (is_equality_cmp and lhs_ty_tag == .Null and rhs_ty_tag == .Null) { - // null == null, null != null - if (op == .eq) { - return Air.Inst.Ref.bool_true; - } else { - return Air.Inst.Ref.bool_false; - } - } else if (is_equality_cmp and - ((lhs_ty_tag == .Null and rhs_ty_tag == .Optional) or - rhs_ty_tag == .Null and lhs_ty_tag == .Optional)) - { - // comparing null with optionals - const opt_operand = if (lhs_ty_tag == .Optional) lhs else rhs; - return sema.analyzeIsNull(block, src, opt_operand, op == .neq); - } else if (is_equality_cmp and - ((lhs_ty_tag == .Null and rhs_ty.isCPtr()) or (rhs_ty_tag == .Null and lhs_ty.isCPtr()))) - { - return mod.fail(&block.base, src, "TODO implement C pointer cmp", .{}); - } else if (lhs_ty_tag == .Null or rhs_ty_tag == .Null) { - const non_null_type = if (lhs_ty_tag == .Null) rhs_ty else lhs_ty; - return mod.fail(&block.base, src, "comparison of '{}' with null", .{non_null_type}); - } else if (is_equality_cmp and - ((lhs_ty_tag == .EnumLiteral and rhs_ty_tag == .Union) or - (rhs_ty_tag == .EnumLiteral and lhs_ty_tag == .Union))) - { - return mod.fail(&block.base, src, "TODO implement equality comparison between a union's tag value and an enum literal", .{}); - } else if (lhs_ty_tag == .ErrorSet and rhs_ty_tag == .ErrorSet) { - if (!is_equality_cmp) { - return mod.fail(&block.base, src, "{s} operator not allowed for errors", .{@tagName(op)}); - } - if (try sema.resolveMaybeUndefVal(block, lhs_src, lhs)) |lval| { - if (try sema.resolveMaybeUndefVal(block, rhs_src, rhs)) |rval| { - if (lval.isUndef() or rval.isUndef()) { - return sema.addConstUndef(Type.initTag(.bool)); - } - // TODO optimisation opportunity: evaluate if mem.eql is faster with the names, - // or calling to Module.getErrorValue to get the values and then compare them is - // faster. - const lhs_name = lval.castTag(.@"error").?.data.name; - const rhs_name = rval.castTag(.@"error").?.data.name; - if (mem.eql(u8, lhs_name, rhs_name) == (op == .eq)) { - return Air.Inst.Ref.bool_true; - } else { - return Air.Inst.Ref.bool_false; - } - } - } - try sema.requireRuntimeBlock(block, src); - const tag: Air.Inst.Tag = if (op == .eq) .cmp_eq else .cmp_neq; - return block.addBinOp(tag, lhs, rhs); - } else if (lhs_ty.isNumeric() and rhs_ty.isNumeric()) { + if (lhs_ty.isNumeric() and rhs_ty.isNumeric()) { // This operation allows any combination of integer and float types, regardless of the // signed-ness, comptime-ness, and bit-width. So peer type resolution is incorrect for // numeric types. return sema.cmpNumeric(block, src, lhs, rhs, op, lhs_src, rhs_src); - } else if (lhs_ty_tag == .Type and rhs_ty_tag == .Type) { - if (!is_equality_cmp) { - return mod.fail(&block.base, src, "{s} operator not allowed for types", .{@tagName(op)}); - } - const lhs_as_type = try sema.analyzeAsType(block, lhs_src, lhs); - const rhs_as_type = try sema.analyzeAsType(block, rhs_src, rhs); - if (lhs_as_type.eql(rhs_as_type) == (op == .eq)) { - return Air.Inst.Ref.bool_true; - } else { - return Air.Inst.Ref.bool_false; - } } - const instructions = &[_]Air.Inst.Ref{ lhs, rhs }; const resolved_type = try sema.resolvePeerTypes(block, src, instructions); if (!resolved_type.isSelfComparable(is_equality_cmp)) { - return mod.fail(&block.base, src, "operator not allowed for type '{}'", .{resolved_type}); + return sema.mod.fail(&block.base, src, "{s} operator not allowed for type '{}'", .{ + @tagName(op), resolved_type, + }); } const casted_lhs = try sema.coerce(block, resolved_type, lhs, lhs_src); @@ -5146,19 +5182,31 @@ fn zirCmp( const runtime_src: LazySrcLoc = src: { if (try sema.resolveMaybeUndefVal(block, lhs_src, casted_lhs)) |lhs_val| { + if (lhs_val.isUndef()) return sema.addConstUndef(resolved_type); if (try sema.resolveMaybeUndefVal(block, rhs_src, casted_rhs)) |rhs_val| { - if (lhs_val.isUndef() or rhs_val.isUndef()) { - return sema.addConstUndef(resolved_type); - } + if (rhs_val.isUndef()) return sema.addConstUndef(resolved_type); + if (lhs_val.compare(op, rhs_val, resolved_type)) { return Air.Inst.Ref.bool_true; } else { return Air.Inst.Ref.bool_false; } } else { + if (resolved_type.zigTypeTag() == .Bool) { + // We can lower bool eq/neq more efficiently. + return sema.runtimeBoolCmp(block, op, casted_rhs, lhs_val.toBool(), rhs_src); + } break :src rhs_src; } } else { + // For bools, we still check the other operand, because we can lower + // bool eq/neq more efficiently. + if (resolved_type.zigTypeTag() == .Bool) { + if (try sema.resolveMaybeUndefVal(block, rhs_src, casted_rhs)) |rhs_val| { + if (rhs_val.isUndef()) return sema.addConstUndef(resolved_type); + return sema.runtimeBoolCmp(block, op, casted_lhs, rhs_val.toBool(), lhs_src); + } + } break :src lhs_src; } }; @@ -5176,6 +5224,26 @@ fn zirCmp( return block.addBinOp(tag, casted_lhs, casted_rhs); } +/// cmp_eq (x, false) => not(x) +/// cmp_eq (x, true ) => x +/// cmp_neq(x, false) => x +/// cmp_neq(x, true ) => not(x) +fn runtimeBoolCmp( + sema: *Sema, + block: *Scope.Block, + op: std.math.CompareOperator, + lhs: Air.Inst.Ref, + rhs: bool, + runtime_src: LazySrcLoc, +) CompileError!Air.Inst.Ref { + if ((op == .neq) == rhs) { + try sema.requireRuntimeBlock(block, runtime_src); + return block.addTyOp(.not, Type.initTag(.bool), lhs); + } else { + return lhs; + } +} + fn zirSizeOf(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { const inst_data = sema.code.instructions.items(.data)[inst].un_node; const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node };