From 1b000b90c9a7abde3aeacf29cef73a877da237e1 Mon Sep 17 00:00:00 2001 From: mlugg Date: Fri, 30 Aug 2024 20:29:27 +0100 Subject: [PATCH] Air: direct representation of ranges in switch cases This commit modifies the representation of the AIR `switch_br` instruction to represent ranges in cases. Previously, Sema emitted different AIR in the case of a range, where the `else` branch of the `switch_br` contained a simple `cond_br` for each such case which did a simple range check (`x > a and x < b`). Not only does this add complexity to Sema, which we would like to minimize, but it also gets in the way of the implementation of #8220. That proposal turns certain `switch` statements into a looping construct, and for optimization purposes, we want to lower this to AIR fairly directly (i.e. without involving a `loop` instruction). That means we would ideally like a single instruction to represent the entire `switch` statement, so that we can dispatch back to it with a different operand as in #8220. This is not really possible to do correctly under the status quo system. This commit implements lowering of this new `switch_br` usage in the LLVM and C backends. The C backend just turns any case containing ranges entirely into conditionals, as before. The LLVM backend is a little smarter, and puts scalar items into the `switch` instruction, only using conditionals for the range cases (which direct to the same bb). All remaining self-hosted backends are temporarily regressed in the presence of switch range cases. This functionality will be restored for at least the x86_64 backend before merge. --- src/Air.zig | 12 +- src/Air/types_resolved.zig | 4 + src/Sema.zig | 370 ++++++++++++++--------------------- src/arch/aarch64/CodeGen.zig | 2 + src/arch/arm/CodeGen.zig | 1 + src/arch/riscv64/CodeGen.zig | 2 + src/arch/wasm/CodeGen.zig | 2 + src/arch/x86_64/CodeGen.zig | 2 + src/codegen/c.zig | 58 ++++-- src/codegen/llvm.zig | 56 +++++- src/codegen/spirv.zig | 1 + src/print_air.zig | 6 + 12 files changed, 268 insertions(+), 248 deletions(-) diff --git a/src/Air.zig b/src/Air.zig index d64fe2983d..9db68da89f 100644 --- a/src/Air.zig +++ b/src/Air.zig @@ -1143,10 +1143,12 @@ pub const SwitchBr = struct { else_body_len: u32, /// Trailing: - /// * item: Inst.Ref // for each `items_len`. - /// * instruction index for each `body_len`. + /// * item: Inst.Ref // for each `items_len` + /// * { range_start: Inst.Ref, range_end: Inst.Ref } // for each `ranges_len` + /// * body_inst: Inst.Index // for each `body_len` pub const Case = struct { items_len: u32, + ranges_len: u32, body_len: u32, }; }; @@ -1862,6 +1864,10 @@ pub const UnwrappedSwitch = struct { var extra_index = extra.end; const items: []const Inst.Ref = @ptrCast(it.air.extra[extra_index..][0..extra.data.items_len]); extra_index += items.len; + // TODO: ptrcast from []const Inst.Ref to []const [2]Inst.Ref when supported + const ranges_ptr: [*]const [2]Inst.Ref = @ptrCast(it.air.extra[extra_index..]); + const ranges: []const [2]Inst.Ref = ranges_ptr[0..extra.data.ranges_len]; + extra_index += ranges.len * 2; const body: []const Inst.Index = @ptrCast(it.air.extra[extra_index..][0..extra.data.body_len]); extra_index += body.len; it.extra_index = @intCast(extra_index); @@ -1869,6 +1875,7 @@ pub const UnwrappedSwitch = struct { return .{ .idx = idx, .items = items, + .ranges = ranges, .body = body, }; } @@ -1881,6 +1888,7 @@ pub const UnwrappedSwitch = struct { pub const Case = struct { idx: u32, items: []const Inst.Ref, + ranges: []const [2]Inst.Ref, body: []const Inst.Index, }; }; diff --git a/src/Air/types_resolved.zig b/src/Air/types_resolved.zig index e60f5ef311..3de5aeb3e3 100644 --- a/src/Air/types_resolved.zig +++ b/src/Air/types_resolved.zig @@ -386,6 +386,10 @@ fn checkBody(air: Air, body: []const Air.Inst.Index, zcu: *Zcu) bool { var it = switch_br.iterateCases(); while (it.next()) |case| { for (case.items) |item| if (!checkRef(item, zcu)) return false; + for (case.ranges) |range| { + if (!checkRef(range[0], zcu)) return false; + if (!checkRef(range[1], zcu)) return false; + } if (!checkBody(air, case.body, zcu)) return false; } if (!checkBody(air, it.elseBody(), zcu)) return false; diff --git a/src/Sema.zig b/src/Sema.zig index a72c749f2e..aafd605bc8 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -11353,9 +11353,14 @@ const SwitchProngAnalysis = struct { const coerced = try sema.coerce(&coerce_block, capture_ty, uncoerced, case_src); _ = try coerce_block.addBr(capture_block_inst, coerced); - try cases_extra.ensureUnusedCapacity(3 + coerce_block.instructions.items.len); - cases_extra.appendAssumeCapacity(1); // items_len - cases_extra.appendAssumeCapacity(@intCast(coerce_block.instructions.items.len)); // body_len + try cases_extra.ensureUnusedCapacity(@typeInfo(Air.SwitchBr.Case).@"struct".fields.len + + 1 + // `item`, no ranges + coerce_block.instructions.items.len); + cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = 1, + .ranges_len = 0, + .body_len = @intCast(coerce_block.instructions.items.len), + })); cases_extra.appendAssumeCapacity(@intFromEnum(case_vals[idx])); // item cases_extra.appendSliceAssumeCapacity(@ptrCast(coerce_block.instructions.items)); // body } @@ -12578,21 +12583,18 @@ fn analyzeSwitchRuntimeBlock( }; try branch_hints.append(gpa, prong_hint); - try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); - cases_extra.appendAssumeCapacity(1); // items_len - cases_extra.appendAssumeCapacity(@intCast(case_block.instructions.items.len)); + try cases_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr.Case).@"struct".fields.len + + 1 + // `item`, no ranges + case_block.instructions.items.len); + cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = 1, + .ranges_len = 0, + .body_len = @intCast(case_block.instructions.items.len), + })); cases_extra.appendAssumeCapacity(@intFromEnum(item)); cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); } - var is_first = true; - var prev_cond_br: Air.Inst.Index = undefined; - var prev_hint: std.builtin.BranchHint = undefined; - var first_else_body: []const Air.Inst.Index = &.{}; - defer gpa.free(first_else_body); - var prev_then_body: []const Air.Inst.Index = &.{}; - defer gpa.free(prev_then_body); - var cases_len = scalar_cases_len; var case_val_idx: usize = scalar_cases_len; var multi_i: u32 = 0; @@ -12602,31 +12604,27 @@ fn analyzeSwitchRuntimeBlock( const ranges_len = sema.code.extra[extra_index]; extra_index += 1; const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]); - extra_index += 1 + items_len; + extra_index += 1 + items_len + 2 * ranges_len; const items = case_vals.items[case_val_idx..][0..items_len]; case_val_idx += items_len; + // TODO: @ptrCast slice once Sema supports it + const ranges: []const [2]Air.Inst.Ref = @as([*]const [2]Air.Inst.Ref, @ptrCast(case_vals.items[case_val_idx..]))[0..ranges_len]; + case_val_idx += ranges_len * 2; + + const body = sema.code.bodySlice(extra_index, info.body_len); + extra_index += info.body_len; case_block.instructions.shrinkRetainingCapacity(0); case_block.error_return_trace_index = child_block.error_return_trace_index; // Generate all possible cases as scalar prongs. if (info.is_inline) { - const body_start = extra_index + 2 * ranges_len; - const body = sema.code.bodySlice(body_start, info.body_len); var emit_bb = false; - var range_i: u32 = 0; - while (range_i < ranges_len) : (range_i += 1) { - const range_items = case_vals.items[case_val_idx..][0..2]; - extra_index += 2; - case_val_idx += 2; - - const item_first_ref = range_items[0]; - const item_last_ref = range_items[1]; - - var item = sema.resolveConstDefinedValue(block, LazySrcLoc.unneeded, item_first_ref, undefined) catch unreachable; - const item_last = sema.resolveConstDefinedValue(block, LazySrcLoc.unneeded, item_last_ref, undefined) catch unreachable; + for (ranges, 0..) |range_items, range_i| { + var item = sema.resolveConstDefinedValue(block, LazySrcLoc.unneeded, range_items[0], undefined) catch unreachable; + const item_last = sema.resolveConstDefinedValue(block, LazySrcLoc.unneeded, range_items[1], undefined) catch unreachable; while (item.compareScalar(.lte, item_last, operand_ty, zcu)) : ({ // Previous validation has resolved any possible lazy values. @@ -12664,9 +12662,14 @@ fn analyzeSwitchRuntimeBlock( ); try branch_hints.append(gpa, prong_hint); - try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); - cases_extra.appendAssumeCapacity(1); // items_len - cases_extra.appendAssumeCapacity(@intCast(case_block.instructions.items.len)); + try cases_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr.Case).@"struct".fields.len + + 1 + // `item`, no ranges + case_block.instructions.items.len); + cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = 1, + .ranges_len = 0, + .body_len = @intCast(case_block.instructions.items.len), + })); cases_extra.appendAssumeCapacity(@intFromEnum(item_ref)); cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); @@ -12713,134 +12716,39 @@ fn analyzeSwitchRuntimeBlock( }; try branch_hints.append(gpa, prong_hint); - try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); - cases_extra.appendAssumeCapacity(1); // items_len - cases_extra.appendAssumeCapacity(@intCast(case_block.instructions.items.len)); + try cases_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr.Case).@"struct".fields.len + + 1 + // `item`, no ranges + case_block.instructions.items.len); + cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = 1, + .ranges_len = 0, + .body_len = @intCast(case_block.instructions.items.len), + })); cases_extra.appendAssumeCapacity(@intFromEnum(item)); cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); } - extra_index += info.body_len; continue; } - var any_ok: Air.Inst.Ref = .none; - - // If there are any ranges, we have to put all the items into the - // else prong. Otherwise, we can take advantage of multiple items - // mapping to the same body. - if (ranges_len == 0) { - cases_len += 1; - - const analyze_body = if (union_originally) - for (items) |item| { - const item_val = sema.resolveConstDefinedValue(block, LazySrcLoc.unneeded, item, undefined) catch unreachable; - const field_ty = maybe_union_ty.unionFieldType(item_val, zcu).?; - if (field_ty.zigTypeTag(zcu) != .noreturn) break true; - } else false - else - true; - - const body = sema.code.bodySlice(extra_index, info.body_len); - extra_index += info.body_len; - const prong_hint: std.builtin.BranchHint = if (err_set and - try sema.maybeErrorUnwrap(&case_block, body, operand, operand_src, allow_err_code_unwrap)) - h: { - // nothing to do here. weight against error branch - break :h .unlikely; - } else if (analyze_body) h: { - break :h try spa.analyzeProngRuntime( - &case_block, - .normal, - body, - info.capture, - child_block.src(.{ .switch_capture = .{ - .switch_node_offset = switch_node_offset, - .case_idx = .{ .kind = .multi, .index = @intCast(multi_i) }, - } }), - items, - .none, - false, - ); - } else h: { - _ = try case_block.addNoOp(.unreach); - break :h .none; - }; - - try branch_hints.append(gpa, prong_hint); - try cases_extra.ensureUnusedCapacity(gpa, 2 + items.len + - case_block.instructions.items.len); - - cases_extra.appendAssumeCapacity(@intCast(items.len)); - cases_extra.appendAssumeCapacity(@intCast(case_block.instructions.items.len)); + cases_len += 1; + const analyze_body = if (union_originally) for (items) |item| { - cases_extra.appendAssumeCapacity(@intFromEnum(item)); - } + const item_val = sema.resolveConstDefinedValue(block, LazySrcLoc.unneeded, item, undefined) catch unreachable; + const field_ty = maybe_union_ty.unionFieldType(item_val, zcu).?; + if (field_ty.zigTypeTag(zcu) != .noreturn) break true; + } else false + else + true; - cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); - } else { - for (items) |item| { - const cmp_ok = try case_block.addBinOp(if (case_block.float_mode == .optimized) .cmp_eq_optimized else .cmp_eq, operand, item); - if (any_ok != .none) { - any_ok = try case_block.addBinOp(.bool_or, any_ok, cmp_ok); - } else { - any_ok = cmp_ok; - } - } - - var range_i: usize = 0; - while (range_i < ranges_len) : (range_i += 1) { - const range_items = case_vals.items[case_val_idx..][0..2]; - extra_index += 2; - case_val_idx += 2; - - const item_first = range_items[0]; - const item_last = range_items[1]; - - // operand >= first and operand <= last - const range_first_ok = try case_block.addBinOp( - if (case_block.float_mode == .optimized) .cmp_gte_optimized else .cmp_gte, - operand, - item_first, - ); - const range_last_ok = try case_block.addBinOp( - if (case_block.float_mode == .optimized) .cmp_lte_optimized else .cmp_lte, - operand, - item_last, - ); - const range_ok = try case_block.addBinOp( - .bool_and, - range_first_ok, - range_last_ok, - ); - if (any_ok != .none) { - any_ok = try case_block.addBinOp(.bool_or, any_ok, range_ok); - } else { - any_ok = range_ok; - } - } - - const new_cond_br = try case_block.addInstAsIndex(.{ .tag = .cond_br, .data = .{ - .pl_op = .{ - .operand = any_ok, - .payload = undefined, - }, - } }); - var cond_body = try case_block.instructions.toOwnedSlice(gpa); - defer gpa.free(cond_body); - - case_block.instructions.shrinkRetainingCapacity(0); - case_block.error_return_trace_index = child_block.error_return_trace_index; - - const body = sema.code.bodySlice(extra_index, info.body_len); - extra_index += info.body_len; - const prong_hint: std.builtin.BranchHint = if (err_set and - try sema.maybeErrorUnwrap(&case_block, body, operand, operand_src, allow_err_code_unwrap)) - h: { - // nothing to do here. weight against error branch - break :h .unlikely; - } else try spa.analyzeProngRuntime( + const prong_hint: std.builtin.BranchHint = if (err_set and + try sema.maybeErrorUnwrap(&case_block, body, operand, operand_src, allow_err_code_unwrap)) + h: { + // nothing to do here. weight against error branch + break :h .unlikely; + } else if (analyze_body) h: { + break :h try spa.analyzeProngRuntime( &case_block, .normal, body, @@ -12853,40 +12761,36 @@ fn analyzeSwitchRuntimeBlock( .none, false, ); + } else h: { + _ = try case_block.addNoOp(.unreach); + break :h .none; + }; - if (is_first) { - is_first = false; - first_else_body = cond_body; - cond_body = &.{}; - } else { - try sema.air_extra.ensureUnusedCapacity( - gpa, - @typeInfo(Air.CondBr).@"struct".fields.len + prev_then_body.len + cond_body.len, - ); + try branch_hints.append(gpa, prong_hint); - sema.air_instructions.items(.data)[@intFromEnum(prev_cond_br)].pl_op.payload = sema.addExtraAssumeCapacity(Air.CondBr{ - .then_body_len = @intCast(prev_then_body.len), - .else_body_len = @intCast(cond_body.len), - .branch_hints = .{ - .true = prev_hint, - .false = .none, - // Code coverage is desired for error handling. - .then_cov = .poi, - .else_cov = .poi, - }, - }); - sema.air_extra.appendSliceAssumeCapacity(@ptrCast(prev_then_body)); - sema.air_extra.appendSliceAssumeCapacity(@ptrCast(cond_body)); - } - gpa.free(prev_then_body); - prev_then_body = try case_block.instructions.toOwnedSlice(gpa); - prev_cond_br = new_cond_br; - prev_hint = prong_hint; + try cases_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr.Case).@"struct".fields.len + + items.len + 2 * ranges_len + + case_block.instructions.items.len); + cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = @intCast(items.len), + .ranges_len = @intCast(ranges_len), + .body_len = @intCast(case_block.instructions.items.len), + })); + + for (items) |item| { + cases_extra.appendAssumeCapacity(@intFromEnum(item)); } + for (ranges) |range| { + cases_extra.appendSliceAssumeCapacity(&.{ + @intFromEnum(range[0]), + @intFromEnum(range[1]), + }); + } + + cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); } - var final_else_body: []const Air.Inst.Index = &.{}; - if (special.body.len != 0 or !is_first or case_block.wantSafety()) { + const else_body: []const Air.Inst.Index = if (special.body.len != 0 or case_block.wantSafety()) else_body: { var emit_bb = false; if (special.is_inline) switch (operand_ty.zigTypeTag(zcu)) { .@"enum" => { @@ -12933,9 +12837,14 @@ fn analyzeSwitchRuntimeBlock( }; try branch_hints.append(gpa, prong_hint); - try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); - cases_extra.appendAssumeCapacity(1); // items_len - cases_extra.appendAssumeCapacity(@intCast(case_block.instructions.items.len)); + try cases_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr.Case).@"struct".fields.len + + 1 + // `item`, no ranges + case_block.instructions.items.len); + cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = 1, + .ranges_len = 0, + .body_len = @intCast(case_block.instructions.items.len), + })); cases_extra.appendAssumeCapacity(@intFromEnum(item_ref)); cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); } @@ -12979,9 +12888,14 @@ fn analyzeSwitchRuntimeBlock( ); try branch_hints.append(gpa, prong_hint); - try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); - cases_extra.appendAssumeCapacity(1); // items_len - cases_extra.appendAssumeCapacity(@intCast(case_block.instructions.items.len)); + try cases_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr.Case).@"struct".fields.len + + 1 + // `item`, no ranges + case_block.instructions.items.len); + cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = 1, + .ranges_len = 0, + .body_len = @intCast(case_block.instructions.items.len), + })); cases_extra.appendAssumeCapacity(@intFromEnum(item_ref)); cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); } @@ -13014,9 +12928,14 @@ fn analyzeSwitchRuntimeBlock( ); try branch_hints.append(gpa, prong_hint); - try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); - cases_extra.appendAssumeCapacity(1); // items_len - cases_extra.appendAssumeCapacity(@intCast(case_block.instructions.items.len)); + try cases_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr.Case).@"struct".fields.len + + 1 + // `item`, no ranges + case_block.instructions.items.len); + cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = 1, + .ranges_len = 0, + .body_len = @intCast(case_block.instructions.items.len), + })); cases_extra.appendAssumeCapacity(@intFromEnum(item_ref)); cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); } @@ -13046,9 +12965,14 @@ fn analyzeSwitchRuntimeBlock( ); try branch_hints.append(gpa, prong_hint); - try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); - cases_extra.appendAssumeCapacity(1); // items_len - cases_extra.appendAssumeCapacity(@intCast(case_block.instructions.items.len)); + try cases_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr.Case).@"struct".fields.len + + 1 + // `item`, no ranges + case_block.instructions.items.len); + cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = 1, + .ranges_len = 0, + .body_len = @intCast(case_block.instructions.items.len), + })); cases_extra.appendAssumeCapacity(@intFromEnum(Air.Inst.Ref.bool_true)); cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); } @@ -13076,9 +13000,14 @@ fn analyzeSwitchRuntimeBlock( ); try branch_hints.append(gpa, prong_hint); - try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); - cases_extra.appendAssumeCapacity(1); // items_len - cases_extra.appendAssumeCapacity(@intCast(case_block.instructions.items.len)); + try cases_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr.Case).@"struct".fields.len + + 1 + // `item`, no ranges + case_block.instructions.items.len); + cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = 1, + .ranges_len = 0, + .body_len = @intCast(case_block.instructions.items.len), + })); cases_extra.appendAssumeCapacity(@intFromEnum(Air.Inst.Ref.bool_false)); cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); } @@ -13142,41 +13071,22 @@ fn analyzeSwitchRuntimeBlock( break :h .cold; }; - if (is_first) { - try branch_hints.append(gpa, else_hint); - final_else_body = case_block.instructions.items; - } else { - try branch_hints.append(gpa, .none); // we have the range conditionals first - try sema.air_extra.ensureUnusedCapacity(gpa, prev_then_body.len + - @typeInfo(Air.CondBr).@"struct".fields.len + case_block.instructions.items.len); - - sema.air_instructions.items(.data)[@intFromEnum(prev_cond_br)].pl_op.payload = sema.addExtraAssumeCapacity(Air.CondBr{ - .then_body_len = @intCast(prev_then_body.len), - .else_body_len = @intCast(case_block.instructions.items.len), - .branch_hints = .{ - .true = prev_hint, - .false = else_hint, - .then_cov = .poi, - .else_cov = .poi, - }, - }); - sema.air_extra.appendSliceAssumeCapacity(@ptrCast(prev_then_body)); - sema.air_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); - final_else_body = first_else_body; - } - } else { + try branch_hints.append(gpa, else_hint); + break :else_body case_block.instructions.items; + } else else_body: { try branch_hints.append(gpa, .none); - } + break :else_body &.{}; + }; assert(branch_hints.items.len == cases_len + 1); try sema.air_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr).@"struct".fields.len + - cases_extra.items.len + final_else_body.len + + cases_extra.items.len + else_body.len + (std.math.divCeil(usize, branch_hints.items.len, 10) catch unreachable)); // branch hints const payload_index = sema.addExtraAssumeCapacity(Air.SwitchBr{ .cases_len = @intCast(cases_len), - .else_body_len = @intCast(final_else_body.len), + .else_body_len = @intCast(else_body.len), }); { @@ -13195,7 +13105,7 @@ fn analyzeSwitchRuntimeBlock( } } sema.air_extra.appendSliceAssumeCapacity(@ptrCast(cases_extra.items)); - sema.air_extra.appendSliceAssumeCapacity(@ptrCast(final_else_body)); + sema.air_extra.appendSliceAssumeCapacity(@ptrCast(else_body)); return try child_block.addInst(.{ .tag = .switch_br, @@ -37386,15 +37296,21 @@ pub fn addExtra(sema: *Sema, extra: anytype) Allocator.Error!u32 { } pub fn addExtraAssumeCapacity(sema: *Sema, extra: anytype) u32 { - const fields = std.meta.fields(@TypeOf(extra)); const result: u32 = @intCast(sema.air_extra.items.len); - inline for (fields) |field| { - sema.air_extra.appendAssumeCapacity(switch (field.type) { - u32 => @field(extra, field.name), - i32, Air.CondBr.BranchHints => @bitCast(@field(extra, field.name)), - Air.Inst.Ref, InternPool.Index => @intFromEnum(@field(extra, field.name)), + sema.air_extra.appendSliceAssumeCapacity(&payloadToExtraItems(extra)); + return result; +} + +fn payloadToExtraItems(data: anytype) [@typeInfo(@TypeOf(data)).@"struct".fields.len]u32 { + const fields = @typeInfo(@TypeOf(data)).@"struct".fields; + var result: [fields.len]u32 = undefined; + inline for (&result, fields) |*val, field| { + val.* = switch (field.type) { + u32 => @field(data, field.name), + i32, Air.CondBr.BranchHints => @bitCast(@field(data, field.name)), + Air.Inst.Ref, InternPool.Index => @intFromEnum(@field(data, field.name)), else => @compileError("bad field type: " ++ @typeName(field.type)), - }); + }; } return result; } diff --git a/src/arch/aarch64/CodeGen.zig b/src/arch/aarch64/CodeGen.zig index c61e540c4a..fd5a1a5367 100644 --- a/src/arch/aarch64/CodeGen.zig +++ b/src/arch/aarch64/CodeGen.zig @@ -5105,6 +5105,8 @@ fn airSwitch(self: *Self, inst: Air.Inst.Index) !void { var it = switch_br.iterateCases(); while (it.next()) |case| { + if (case.ranges.len > 0) return self.fail("TODO: switch with ranges", .{}); + // For every item, we compare it to condition and branch into // the prong if they are equal. After we compared to all // items, we branch into the next prong (or if no other prongs diff --git a/src/arch/arm/CodeGen.zig b/src/arch/arm/CodeGen.zig index d693c06ec9..d88029bd6c 100644 --- a/src/arch/arm/CodeGen.zig +++ b/src/arch/arm/CodeGen.zig @@ -5053,6 +5053,7 @@ fn airSwitch(self: *Self, inst: Air.Inst.Index) !void { var it = switch_br.iterateCases(); while (it.next()) |case| { + if (case.ranges.len > 0) return self.fail("TODO: switch with ranges", .{}); // For every item, we compare it to condition and branch into // the prong if they are equal. After we compared to all // items, we branch into the next prong (or if no other prongs diff --git a/src/arch/riscv64/CodeGen.zig b/src/arch/riscv64/CodeGen.zig index a70618a394..d48a99c8d7 100644 --- a/src/arch/riscv64/CodeGen.zig +++ b/src/arch/riscv64/CodeGen.zig @@ -5681,6 +5681,8 @@ fn airSwitchBr(func: *Func, inst: Air.Inst.Index) !void { var it = switch_br.iterateCases(); while (it.next()) |case| { + if (case.ranges.len > 0) return func.fail("TODO: switch with ranges", .{}); + var relocs = try func.gpa.alloc(Mir.Inst.Index, case.items.len); defer func.gpa.free(relocs); diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig index a3a51c7233..7fe8676234 100644 --- a/src/arch/wasm/CodeGen.zig +++ b/src/arch/wasm/CodeGen.zig @@ -4064,6 +4064,8 @@ fn airSwitchBr(func: *CodeGen, inst: Air.Inst.Index) InnerError!void { var it = switch_br.iterateCases(); while (it.next()) |case| { + if (case.ranges.len > 0) return func.fail("TODO: switch with ranges", .{}); + const values = try func.gpa.alloc(CaseValue, case.items.len); errdefer func.gpa.free(values); diff --git a/src/arch/x86_64/CodeGen.zig b/src/arch/x86_64/CodeGen.zig index 1f13c8ae7d..136ad3ca4c 100644 --- a/src/arch/x86_64/CodeGen.zig +++ b/src/arch/x86_64/CodeGen.zig @@ -13695,6 +13695,8 @@ fn airSwitchBr(self: *Self, inst: Air.Inst.Index) !void { var it = switch_br.iterateCases(); while (it.next()) |case| { + if (case.ranges.len > 0) return self.fail("TODO: switch with ranges", .{}); + var relocs = try self.gpa.alloc(Mir.Inst.Index, case.items.len); defer self.gpa.free(relocs); diff --git a/src/codegen/c.zig b/src/codegen/c.zig index f3b8c7e72a..3c97d74cd4 100644 --- a/src/codegen/c.zig +++ b/src/codegen/c.zig @@ -5017,12 +5017,13 @@ fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !CValue { const liveness = try f.liveness.getSwitchBr(gpa, inst, switch_br.cases_len + 1); defer gpa.free(liveness.deaths); - // On the final iteration we do not need to fix any state. This is because, like in the `else` - // branch of a `cond_br`, our parent has to do it for this entire body anyway. - const last_case_i = switch_br.cases_len - @intFromBool(switch_br.else_body_len == 0); - + var any_range_cases = false; var it = switch_br.iterateCases(); while (it.next()) |case| { + if (case.ranges.len > 0) { + any_range_cases = true; + continue; + } for (case.items) |item| { try f.object.indent_writer.insertNewline(); try writer.writeAll("case "); @@ -5041,29 +5042,56 @@ fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !CValue { } try writer.writeByte(' '); - if (case.idx != last_case_i) { - try genBodyResolveState(f, inst, liveness.deaths[case.idx], case.body, false); - } else { - for (liveness.deaths[case.idx]) |death| { - try die(f, inst, death.toRef()); - } - try genBody(f, case.body); - } + try genBodyResolveState(f, inst, liveness.deaths[case.idx], case.body, false); // The case body must be noreturn so we don't need to insert a break. } const else_body = it.elseBody(); try f.object.indent_writer.insertNewline(); + + try writer.writeAll("default: "); + if (any_range_cases) { + // We will iterate the cases again to handle those with ranges, and generate + // code using conditions rather than switch cases for such cases. + it = switch_br.iterateCases(); + while (it.next()) |case| { + if (case.ranges.len == 0) continue; // handled above + + try writer.writeAll("if ("); + for (case.items, 0..) |item, item_i| { + if (item_i != 0) try writer.writeAll(" || "); + try f.writeCValue(writer, condition, .Other); + try writer.writeAll(" == "); + try f.object.dg.renderValue(writer, (try f.air.value(item, pt)).?, .Other); + } + for (case.ranges, 0..) |range, range_i| { + if (case.items.len != 0 or range_i != 0) try writer.writeAll(" || "); + // "(x >= lower && x <= upper)" + try writer.writeByte('('); + try f.writeCValue(writer, condition, .Other); + try writer.writeAll(" >= "); + try f.object.dg.renderValue(writer, (try f.air.value(range[0], pt)).?, .Other); + try writer.writeAll(" && "); + try f.writeCValue(writer, condition, .Other); + try writer.writeAll(" <= "); + try f.object.dg.renderValue(writer, (try f.air.value(range[1], pt)).?, .Other); + try writer.writeByte(')'); + } + try writer.writeAll(") "); + try genBodyResolveState(f, inst, liveness.deaths[case.idx], case.body, false); + } + } + if (else_body.len > 0) { - // Note that this must be the last case (i.e. the `last_case_i` case was not hit above) + // Note that this must be the last case, so we do not need to use `genBodyResolveState` since + // the parent block will do it (because the case body is noreturn). for (liveness.deaths[liveness.deaths.len - 1]) |death| { try die(f, inst, death.toRef()); } - try writer.writeAll("default: "); try genBody(f, else_body); } else { - try writer.writeAll("default: zig_unreachable();"); + try writer.writeAll("zig_unreachable();"); } try f.object.indent_writer.insertNewline(); diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index dc8996afda..25cf2e8dcc 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -6230,7 +6230,15 @@ pub const FuncGen = struct { const cond = try self.resolveInst(switch_br.operand); - const else_block = try self.wip.block(1, "Default"); + // This is not necessarily the actual `else` prong; it first contains conditionals + // for any range cases. It's just the `else` of the LLVM switch. + const llvm_else_block = try self.wip.block(1, "Default"); + + const case_blocks = try self.gpa.alloc(Builder.Function.Block.Index, switch_br.cases_len); + defer self.gpa.free(case_blocks); + // We set incoming as 0 for now, and increment it as we construct the switch. + for (case_blocks) |*b| b.* = try self.wip.block(0, "Case"); + const llvm_usize = try o.lowerType(Type.usize); const cond_int = if (cond.typeOfWip(&self.wip).isPointer(&o.builder)) try self.wip.cast(.ptrtoint, cond, llvm_usize, "") @@ -6294,12 +6302,17 @@ pub const FuncGen = struct { break :weights @enumFromInt(@intFromEnum(tuple)); }; - var wip_switch = try self.wip.@"switch"(cond_int, else_block, llvm_cases_len, weights); + var wip_switch = try self.wip.@"switch"(cond_int, llvm_else_block, llvm_cases_len, weights); defer wip_switch.finish(&self.wip); var it = switch_br.iterateCases(); + var any_ranges = false; while (it.next()) |case| { - const case_block = try self.wip.block(@intCast(case.items.len), "Case"); + if (case.ranges.len > 0) any_ranges = true; + const case_block = case_blocks[case.idx]; + case_block.ptr(&self.wip).incoming += @intCast(case.items.len); + // Handle scalar items, and generate the block. + // We'll generate conditionals for the ranges later on. for (case.items) |item| { const llvm_item = (try self.resolveInst(item)).toConst().?; const llvm_int_item = if (llvm_item.typeOf(&o.builder).isPointer(&o.builder)) @@ -6314,7 +6327,42 @@ pub const FuncGen = struct { } const else_body = it.elseBody(); - self.wip.cursor = .{ .block = else_block }; + self.wip.cursor = .{ .block = llvm_else_block }; + if (any_ranges) { + const cond_ty = self.typeOf(switch_br.operand); + // Add conditionals for the ranges, directing to the relevant bb. + // We don't need to consider `cold` branch hints since that information is stored + // in the target bb body, but we do care about likely/unlikely/unpredictable. + it = switch_br.iterateCases(); + while (it.next()) |case| { + if (case.ranges.len == 0) continue; + const case_block = case_blocks[case.idx]; + const hint = switch_br.getHint(case.idx); + case_block.ptr(&self.wip).incoming += 1; + const next_else_block = try self.wip.block(1, "Default"); + var range_cond: ?Builder.Value = null; + for (case.ranges) |range| { + const llvm_min = try self.resolveInst(range[0]); + const llvm_max = try self.resolveInst(range[1]); + const cond_part = try self.wip.bin( + .@"and", + try self.cmp(.normal, .gte, cond_ty, cond, llvm_min), + try self.cmp(.normal, .lte, cond_ty, cond, llvm_max), + "", + ); + if (range_cond) |prev| { + range_cond = try self.wip.bin(.@"or", prev, cond_part, ""); + } else range_cond = cond_part; + } + _ = try self.wip.brCond(range_cond.?, case_block, next_else_block, switch (hint) { + .none, .cold => .none, + .unpredictable => .unpredictable, + .likely => .then_likely, + .unlikely => .else_likely, + }); + self.wip.cursor = .{ .block = next_else_block }; + } + } if (switch_br.getElseHint() == .cold) _ = try self.wip.callIntrinsicAssumeCold(); if (else_body.len != 0) { try self.genBodyDebugScope(null, else_body, .poi); diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 345e80a23c..b9aa777786 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -6211,6 +6211,7 @@ const NavGen = struct { var num_conditions: u32 = 0; var it = switch_br.iterateCases(); while (it.next()) |case| { + if (case.ranges.len > 0) return self.todo("switch with ranges", .{}); num_conditions += @intCast(case.items.len); } break :blk num_conditions; diff --git a/src/print_air.zig b/src/print_air.zig index 227f362c39..acb5eda07e 100644 --- a/src/print_air.zig +++ b/src/print_air.zig @@ -864,6 +864,12 @@ const Writer = struct { if (item_i != 0) try s.writeAll(", "); try w.writeInstRef(s, item, false); } + for (case.ranges, 0..) |range, range_i| { + if (range_i != 0 or case.items.len != 0) try s.writeAll(", "); + try w.writeInstRef(s, range[0], false); + try s.writeAll("..."); + try w.writeInstRef(s, range[1], false); + } try s.writeAll("] "); const hint = switch_br.getHint(case.idx); if (hint != .none) {