From a377bf87ce6f021f958087bcf080425845a7bae6 Mon Sep 17 00:00:00 2001 From: mlugg Date: Fri, 5 May 2023 21:40:04 +0100 Subject: [PATCH 1/8] Zir: remove unnecessary switch_capture_multi instructions By indexing from the very first switch case rather than into scalar and multi cases separately, the instructions for capturing in multi cases become unnecessary, freeing up 2 ZIR tags. --- src/AstGen.zig | 13 ++----- src/Sema.zig | 16 +++------ src/Zir.zig | 92 +++++++++++------------------------------------ src/print_zir.zig | 2 -- 4 files changed, 27 insertions(+), 96 deletions(-) diff --git a/src/AstGen.zig b/src/AstGen.zig index 17cf2aae64..44dae46190 100644 --- a/src/AstGen.zig +++ b/src/AstGen.zig @@ -2614,8 +2614,6 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As .switch_cond_ref, .switch_capture, .switch_capture_ref, - .switch_capture_multi, - .switch_capture_multi_ref, .switch_capture_tag, .struct_init_empty, .struct_init, @@ -6916,15 +6914,8 @@ fn switchExpr( }, }); } else { - const is_multi_case_bits: u2 = @boolToInt(is_multi_case); - const is_ptr_bits: u2 = @boolToInt(is_ptr); - const capture_tag: Zir.Inst.Tag = switch ((is_multi_case_bits << 1) | is_ptr_bits) { - 0b00 => .switch_capture, - 0b01 => .switch_capture_ref, - 0b10 => .switch_capture_multi, - 0b11 => .switch_capture_multi_ref, - }; - const capture_index = if (is_multi_case) multi_case_index else scalar_case_index; + const capture_tag: Zir.Inst.Tag = if (is_ptr) .switch_capture_ref else .switch_capture; + const capture_index = if (is_multi_case) scalar_cases_len + multi_case_index else scalar_case_index; capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len); try astgen.instructions.append(gpa, .{ .tag = capture_tag, diff --git a/src/Sema.zig b/src/Sema.zig index cb69fa92d8..fb656a19a3 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -1017,10 +1017,8 @@ fn analyzeBodyInner( .switch_block => try sema.zirSwitchBlock(block, inst), .switch_cond => try 
sema.zirSwitchCond(block, inst, false), .switch_cond_ref => try sema.zirSwitchCond(block, inst, true), - .switch_capture => try sema.zirSwitchCapture(block, inst, false, false), - .switch_capture_ref => try sema.zirSwitchCapture(block, inst, false, true), - .switch_capture_multi => try sema.zirSwitchCapture(block, inst, true, false), - .switch_capture_multi_ref => try sema.zirSwitchCapture(block, inst, true, true), + .switch_capture => try sema.zirSwitchCapture(block, inst, false), + .switch_capture_ref => try sema.zirSwitchCapture(block, inst, true), .switch_capture_tag => try sema.zirSwitchCaptureTag(block, inst), .type_info => try sema.zirTypeInfo(block, inst), .size_of => try sema.zirSizeOf(block, inst), @@ -10089,7 +10087,6 @@ fn zirSwitchCapture( sema: *Sema, block: *Block, inst: Zir.Inst.Index, - is_multi: bool, is_ref: bool, ) CompileError!Air.Inst.Ref { const tracy = trace(@src()); @@ -10178,12 +10175,7 @@ fn zirSwitchCapture( } } - const items = if (is_multi) - switch_extra.data.getMultiProng(sema.code, switch_extra.end, capture_info.prong_index).items - else - &[_]Zir.Inst.Ref{ - switch_extra.data.getScalarProng(sema.code, switch_extra.end, capture_info.prong_index).item, - }; + const items = switch_extra.data.getProng(sema.code, switch_extra.end, capture_info.prong_index).items; switch (operand_ty.zigTypeTag(mod)) { .Union => { @@ -10252,7 +10244,7 @@ fn zirSwitchCapture( return block.addStructFieldVal(operand, first_field_index, first_field.ty); }, .ErrorSet => { - if (is_multi) { + if (items.len > 1) { var names: Module.Fn.InferredErrorSet.NameMap = .{}; try names.ensureUnusedCapacity(sema.arena, items.len); for (items) |item| { diff --git a/src/Zir.zig b/src/Zir.zig index c3a5f8e09b..aaae5f4dcf 100644 --- a/src/Zir.zig +++ b/src/Zir.zig @@ -687,15 +687,6 @@ pub const Inst = struct { /// If the `prong_index` field is max int, it means this is the capture /// for the else/`_` prong. 
switch_capture_ref, - /// Produces the capture value for a switch prong. - /// The prong is one of the multi cases. - /// Uses the `switch_capture` field. - switch_capture_multi, - /// Produces the capture value for a switch prong. - /// The prong is one of the multi cases. - /// Result is a pointer to the value. - /// Uses the `switch_capture` field. - switch_capture_multi_ref, /// Produces the capture value for an inline switch prong tag capture. /// Uses the `un_tok` field. switch_capture_tag, @@ -1146,8 +1137,6 @@ pub const Inst = struct { .set_eval_branch_quota, .switch_capture, .switch_capture_ref, - .switch_capture_multi, - .switch_capture_multi_ref, .switch_capture_tag, .switch_block, .switch_cond, @@ -1440,8 +1429,6 @@ pub const Inst = struct { .typeof_log2_int_type, .switch_capture, .switch_capture_ref, - .switch_capture_multi, - .switch_capture_multi_ref, .switch_capture_tag, .switch_block, .switch_cond, @@ -1700,8 +1687,6 @@ pub const Inst = struct { .switch_cond_ref = .un_node, .switch_capture = .switch_capture, .switch_capture_ref = .switch_capture, - .switch_capture_multi = .switch_capture, - .switch_capture_multi_ref = .switch_capture, .switch_capture_tag = .un_tok, .array_base_ptr = .un_node, .field_base_ptr = .un_node, @@ -2735,8 +2720,8 @@ pub const Inst = struct { } }; - pub const ScalarProng = struct { - item: Ref, + pub const MultiProng = struct { + items: []const Ref, body: []const Index, }; @@ -2744,56 +2729,13 @@ pub const Inst = struct { /// change the definition of switch_capture instruction to store extra_index /// instead of prong_index. This way, Sema won't be doing O(N^2) iterations /// over the switch prongs. 
- pub fn getScalarProng( - self: SwitchBlock, - zir: Zir, - extra_end: usize, - prong_index: usize, - ) ScalarProng { - var extra_index: usize = extra_end; - - if (self.bits.has_multi_cases) { - extra_index += 1; - } - - if (self.bits.specialProng() != .none) { - const body_len = @truncate(u31, zir.extra[extra_index]); - extra_index += 1; - const body = zir.extra[extra_index..][0..body_len]; - extra_index += body.len; - } - - var scalar_i: usize = 0; - while (true) : (scalar_i += 1) { - const item = @intToEnum(Ref, zir.extra[extra_index]); - extra_index += 1; - const body_len = @truncate(u31, zir.extra[extra_index]); - extra_index += 1; - const body = zir.extra[extra_index..][0..body_len]; - extra_index += body.len; - - if (scalar_i < prong_index) continue; - - return .{ - .item = item, - .body = body, - }; - } - } - - pub const MultiProng = struct { - items: []const Ref, - body: []const Index, - }; - - pub fn getMultiProng( + pub fn getProng( self: SwitchBlock, zir: Zir, extra_end: usize, prong_index: usize, ) MultiProng { - // +1 for self.bits.has_multi_cases == true - var extra_index: usize = extra_end + 1; + var extra_index: usize = extra_end + @boolToInt(self.bits.has_multi_cases); if (self.bits.specialProng() != .none) { const body_len = @truncate(u31, zir.extra[extra_index]); @@ -2802,15 +2744,22 @@ pub const Inst = struct { extra_index += body.len; } - var scalar_i: usize = 0; - while (scalar_i < self.bits.scalar_cases_len) : (scalar_i += 1) { + var cur_idx: usize = 0; + while (cur_idx < self.bits.scalar_cases_len) : (cur_idx += 1) { + const items = zir.refSlice(extra_index, 1); extra_index += 1; const body_len = @truncate(u31, zir.extra[extra_index]); extra_index += 1; + const body = zir.extra[extra_index..][0..body_len]; extra_index += body_len; + if (cur_idx == prong_index) { + return .{ + .items = items, + .body = body, + }; + } } - var multi_i: u32 = 0; - while (true) : (multi_i += 1) { + while (true) : (cur_idx += 1) { const items_len = 
zir.extra[extra_index]; extra_index += 1; const ranges_len = zir.extra[extra_index]; @@ -2825,11 +2774,12 @@ pub const Inst = struct { const body = zir.extra[extra_index..][0..body_len]; extra_index += body_len; - if (multi_i < prong_index) continue; - return .{ - .items = items, - .body = body, - }; + if (cur_idx == prong_index) { + return .{ + .items = items, + .body = body, + }; + } } } }; diff --git a/src/print_zir.zig b/src/print_zir.zig index 6c371b8b8d..eba7f98a26 100644 --- a/src/print_zir.zig +++ b/src/print_zir.zig @@ -438,8 +438,6 @@ const Writer = struct { .switch_capture, .switch_capture_ref, - .switch_capture_multi, - .switch_capture_multi_ref, => try self.writeSwitchCapture(stream, inst), .dbg_stmt => try self.writeDbgStmt(stream, inst), From cebd80032a2dcc9f516f8183d1bfade5d1f12e45 Mon Sep 17 00:00:00 2001 From: mlugg Date: Fri, 5 May 2023 23:04:34 +0100 Subject: [PATCH 2/8] Move switch case value coercion from AstGen to Sema This is in preparation for #2473. Also fixes a bug where switching on bools allows invalid case combinations. --- src/AstGen.zig | 4 +- src/Sema.zig | 250 +++++++++++++++++++++++++++++-------------------- 2 files changed, 150 insertions(+), 104 deletions(-) diff --git a/src/AstGen.zig b/src/AstGen.zig index 44dae46190..0003457c01 100644 --- a/src/AstGen.zig +++ b/src/AstGen.zig @@ -6837,9 +6837,7 @@ fn switchExpr( const cond = try parent_gz.addUnNode(cond_tag, raw_operand, operand_node); // Sema expects a dbg_stmt immediately after switch_cond(_ref) try emitDbgStmt(parent_gz, operand_lc); - // We need the type of the operand to use as the result location for all the prong items. 
- const cond_ty_inst = try parent_gz.addUnNode(.typeof, cond, operand_node); - const item_ri: ResultInfo = .{ .rl = .{ .ty = cond_ty_inst } }; + const item_ri: ResultInfo = .{ .rl = .none }; // This contains the data that goes into the `extra` array for the SwitchBlock/SwitchBlockMulti, // except the first cases_nodes.len slots are a table that indexes payloads later in the array, with diff --git a/src/Sema.zig b/src/Sema.zig index fb656a19a3..1bc34beb83 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -10099,6 +10099,8 @@ fn zirSwitchCapture( const switch_info = zir_datas[capture_info.switch_inst].pl_node; const switch_extra = sema.code.extraData(Zir.Inst.SwitchBlock, switch_info.payload_index); const operand_src: LazySrcLoc = .{ .node_offset_switch_operand = switch_info.src_node }; + const cond = try sema.resolveInst(switch_extra.data.operand); + const cond_ty = sema.typeOf(cond); const cond_inst = Zir.refToIndex(switch_extra.data.operand).?; const cond_info = zir_datas[cond_inst].un_node; const cond_tag = sema.code.instructions.items(.tag)[cond_inst]; @@ -10175,6 +10177,8 @@ fn zirSwitchCapture( } } + // Note that these are the *uncasted* prong items. + // Also note that items from ranges are not included so this only works for non-ranged types. 
const items = switch_extra.data.getProng(sema.code, switch_extra.end, capture_info.prong_index).items; switch (operand_ty.zigTypeTag(mod)) { @@ -10182,7 +10186,8 @@ fn zirSwitchCapture( const union_obj = mod.typeToUnion(operand_ty).?; const first_item = try sema.resolveInst(items[0]); // Previous switch validation ensured this will succeed - const first_item_val = sema.resolveConstValue(block, .unneeded, first_item, "") catch unreachable; + const first_item_coerced = try sema.coerce(block, cond_ty, first_item, .unneeded); + const first_item_val = sema.resolveConstValue(block, .unneeded, first_item_coerced, "") catch unreachable; const first_field_index = @intCast(u32, operand_ty.unionTagFieldIndex(first_item_val, mod).?); const first_field = union_obj.fields.values()[first_field_index]; @@ -10190,7 +10195,8 @@ fn zirSwitchCapture( for (items[1..], 0..) |item, i| { const item_ref = try sema.resolveInst(item); // Previous switch validation ensured this will succeed - const item_val = sema.resolveConstValue(block, .unneeded, item_ref, "") catch unreachable; + const item_coerced = try sema.coerce(block, cond_ty, item_ref, .unneeded); + const item_val = sema.resolveConstValue(block, .unneeded, item_coerced, "") catch unreachable; const field_index = operand_ty.unionTagFieldIndex(item_val, mod).?; const field = union_obj.fields.values()[field_index]; @@ -10406,6 +10412,9 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError break :blk multi_cases_len; } else 0; + var case_vals = try std.ArrayListUnmanaged(Air.Inst.Ref).initCapacity(gpa, scalar_cases_len + 2 * multi_cases_len); + defer case_vals.deinit(gpa); + const special_prong = extra.data.bits.specialProng(); const special: struct { body: []const Zir.Inst.Index, end: usize, is_inline: bool } = switch (special_prong) { .none => .{ .body = &.{}, .end = header_extra_index, .is_inline = false }, @@ -10491,14 +10500,15 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) 
CompileError extra_index += 1; extra_index += body_len; - try sema.validateSwitchItemEnum( + case_vals.appendAssumeCapacity(try sema.validateSwitchItemEnum( block, seen_enum_fields, &range_set, item_ref, + operand_ty, src_node_offset, .{ .scalar = scalar_i }, - ); + )); } } { @@ -10513,15 +10523,17 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError const items = sema.code.refSlice(extra_index, items_len); extra_index += items_len + body_len; + try case_vals.ensureUnusedCapacity(gpa, items.len); for (items, 0..) |item_ref, item_i| { - try sema.validateSwitchItemEnum( + case_vals.appendAssumeCapacity(try sema.validateSwitchItemEnum( block, seen_enum_fields, &range_set, item_ref, + operand_ty, src_node_offset, .{ .multi = .{ .prong = multi_i, .item = @intCast(u32, item_i) } }, - ); + )); } try sema.validateSwitchNoRange(block, ranges_len, operand_ty, src_node_offset); @@ -10588,13 +10600,14 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError extra_index += 1; extra_index += body_len; - try sema.validateSwitchItemError( + case_vals.appendAssumeCapacity(try sema.validateSwitchItemError( block, &seen_errors, item_ref, + operand_ty, src_node_offset, .{ .scalar = scalar_i }, - ); + )); } } { @@ -10609,14 +10622,16 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError const items = sema.code.refSlice(extra_index, items_len); extra_index += items_len + body_len; + try case_vals.ensureUnusedCapacity(gpa, items.len); for (items, 0..) 
|item_ref, item_i| { - try sema.validateSwitchItemError( + case_vals.appendAssumeCapacity(try sema.validateSwitchItemError( block, &seen_errors, item_ref, + operand_ty, src_node_offset, .{ .multi = .{ .prong = multi_i, .item = @intCast(u32, item_i) } }, - ); + )); } try sema.validateSwitchNoRange(block, ranges_len, operand_ty, src_node_offset); @@ -10728,13 +10743,14 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError extra_index += 1; extra_index += body_len; - try sema.validateSwitchItem( + case_vals.appendAssumeCapacity(try sema.validateSwitchItemInt( block, &range_set, item_ref, + operand_ty, src_node_offset, .{ .scalar = scalar_i }, - ); + )); } } { @@ -10749,16 +10765,19 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError const items = sema.code.refSlice(extra_index, items_len); extra_index += items_len; + try case_vals.ensureUnusedCapacity(gpa, items.len); for (items, 0..) |item_ref, item_i| { - try sema.validateSwitchItem( + case_vals.appendAssumeCapacity(try sema.validateSwitchItemInt( block, &range_set, item_ref, + operand_ty, src_node_offset, .{ .multi = .{ .prong = multi_i, .item = @intCast(u32, item_i) } }, - ); + )); } + try case_vals.ensureUnusedCapacity(gpa, 2 * ranges_len); var range_i: u32 = 0; while (range_i < ranges_len) : (range_i += 1) { const item_first = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]); @@ -10766,14 +10785,17 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError const item_last = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]); extra_index += 1; - try sema.validateSwitchRange( + const vals = try sema.validateSwitchRange( block, &range_set, item_first, item_last, + operand_ty, src_node_offset, .{ .range = .{ .prong = multi_i, .item = range_i } }, ); + case_vals.appendAssumeCapacity(vals[0]); + case_vals.appendAssumeCapacity(vals[1]); } extra_index += body_len; @@ -10817,14 +10839,14 @@ fn zirSwitchBlock(sema: *Sema, 
block: *Block, inst: Zir.Inst.Index) CompileError extra_index += 1; extra_index += body_len; - try sema.validateSwitchItemBool( + case_vals.appendAssumeCapacity(try sema.validateSwitchItemBool( block, &true_count, &false_count, item_ref, src_node_offset, .{ .scalar = scalar_i }, - ); + )); } } { @@ -10839,15 +10861,16 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError const items = sema.code.refSlice(extra_index, items_len); extra_index += items_len + body_len; + try case_vals.ensureUnusedCapacity(gpa, items.len); for (items, 0..) |item_ref, item_i| { - try sema.validateSwitchItemBool( + case_vals.appendAssumeCapacity(try sema.validateSwitchItemBool( block, &true_count, &false_count, item_ref, src_node_offset, .{ .multi = .{ .prong = multi_i, .item = @intCast(u32, item_i) } }, - ); + )); } try sema.validateSwitchNoRange(block, ranges_len, operand_ty, src_node_offset); @@ -10899,13 +10922,14 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError extra_index += 1; extra_index += body_len; - try sema.validateSwitchItemSparse( + case_vals.appendAssumeCapacity(try sema.validateSwitchItemSparse( block, &seen_values, item_ref, + operand_ty, src_node_offset, .{ .scalar = scalar_i }, - ); + )); } } { @@ -10920,14 +10944,16 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError const items = sema.code.refSlice(extra_index, items_len); extra_index += items_len + body_len; + try case_vals.ensureUnusedCapacity(gpa, items.len); for (items, 0..) 
|item_ref, item_i| { - try sema.validateSwitchItemSparse( + case_vals.appendAssumeCapacity(try sema.validateSwitchItemSparse( block, &seen_values, item_ref, + operand_ty, src_node_offset, .{ .multi = .{ .prong = multi_i, .item = @intCast(u32, item_i) } }, - ); + )); } try sema.validateSwitchNoRange(block, ranges_len, operand_ty, src_node_offset); @@ -10997,7 +11023,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError { var scalar_i: usize = 0; while (scalar_i < scalar_cases_len) : (scalar_i += 1) { - const item_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]); extra_index += 1; const body_len = @truncate(u31, sema.code.extra[extra_index]); const is_inline = sema.code.extra[extra_index] >> 31 != 0; @@ -11005,8 +11030,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError const body = sema.code.extra[extra_index..][0..body_len]; extra_index += body_len; - const item = try sema.resolveInst(item_ref); - // Validation above ensured these will succeed. 
+ const item = case_vals.items[scalar_i]; const item_val = sema.resolveConstLazyValue(&child_block, .unneeded, item, "") catch unreachable; if (resolved_operand_val.eql(item_val, operand_ty, mod)) { if (is_inline) child_block.inline_case_capture = operand; @@ -11018,6 +11042,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError } { var multi_i: usize = 0; + var case_val_idx: usize = scalar_cases_len; while (multi_i < multi_cases_len) : (multi_i += 1) { const items_len = sema.code.extra[extra_index]; extra_index += 1; @@ -11025,13 +11050,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError extra_index += 1; const body_len = @truncate(u31, sema.code.extra[extra_index]); const is_inline = sema.code.extra[extra_index] >> 31 != 0; - extra_index += 1; - const items = sema.code.refSlice(extra_index, items_len); - extra_index += items_len; + extra_index += 1 + items_len; const body = sema.code.extra[extra_index + 2 * ranges_len ..][0..body_len]; - for (items) |item_ref| { - const item = try sema.resolveInst(item_ref); + const items = case_vals.items[case_val_idx..][0..items_len]; + case_val_idx += items_len; + + for (items) |item| { // Validation above ensured these will succeed. const item_val = sema.resolveConstLazyValue(&child_block, .unneeded, item, "") catch unreachable; if (resolved_operand_val.eql(item_val, operand_ty, mod)) { @@ -11044,16 +11069,15 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError var range_i: usize = 0; while (range_i < ranges_len) : (range_i += 1) { - const item_first = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]); - extra_index += 1; - const item_last = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]); - extra_index += 1; + const range_items = case_vals.items[case_val_idx..][0..2]; + extra_index += 2; + case_val_idx += 2; // Validation above ensured these will succeed. 
- const first_tv = sema.resolveInstConst(&child_block, .unneeded, item_first, "") catch unreachable; - const last_tv = sema.resolveInstConst(&child_block, .unneeded, item_last, "") catch unreachable; - if ((try sema.compareAll(resolved_operand_val, .gte, first_tv.val, operand_ty)) and - (try sema.compareAll(resolved_operand_val, .lte, last_tv.val, operand_ty))) + const first_val = sema.resolveConstValue(&child_block, .unneeded, range_items[0], "") catch unreachable; + const last_val = sema.resolveConstValue(&child_block, .unneeded, range_items[1], "") catch unreachable; + if ((try sema.compareAll(resolved_operand_val, .gte, first_val, operand_ty)) and + (try sema.compareAll(resolved_operand_val, .lte, last_val, operand_ty))) { if (is_inline) child_block.inline_case_capture = operand; if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand); @@ -11115,7 +11139,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError var scalar_i: usize = 0; while (scalar_i < scalar_cases_len) : (scalar_i += 1) { - const item_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]); extra_index += 1; const body_len = @truncate(u31, sema.code.extra[extra_index]); const is_inline = sema.code.extra[extra_index] >> 31 != 0; @@ -11130,7 +11153,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError case_block.wip_capture_scope = wip_captures.scope; case_block.inline_case_capture = .none; - const item = try sema.resolveInst(item_ref); + const item = case_vals.items[scalar_i]; if (is_inline) case_block.inline_case_capture = item; // `item` is already guaranteed to be constant known. 
@@ -11165,6 +11188,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError defer gpa.free(prev_then_body); var cases_len = scalar_cases_len; + var case_val_idx: usize = scalar_cases_len; var multi_i: u32 = 0; while (multi_i < multi_cases_len) : (multi_i += 1) { const items_len = sema.code.extra[extra_index]; @@ -11173,9 +11197,10 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError extra_index += 1; const body_len = @truncate(u31, sema.code.extra[extra_index]); const is_inline = sema.code.extra[extra_index] >> 31 != 0; - extra_index += 1; - const items = sema.code.refSlice(extra_index, items_len); - extra_index += items_len; + extra_index += 1 + items_len; + + const items = case_vals.items[case_val_idx..][0..items_len]; + case_val_idx += items_len; case_block.instructions.shrinkRetainingCapacity(0); case_block.wip_capture_scope = child_block.wip_capture_scope; @@ -11189,14 +11214,14 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError var range_i: u32 = 0; while (range_i < ranges_len) : (range_i += 1) { - const first_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]); - extra_index += 1; - const last_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]); - extra_index += 1; + const range_items = case_vals.items[case_val_idx..][0..2]; + extra_index += 2; + case_val_idx += 2; + + const item_first_ref = range_items[0]; + const item_last_ref = range_items[1]; - const item_first_ref = try sema.resolveInst(first_ref); var item = sema.resolveConstValue(block, .unneeded, item_first_ref, undefined) catch unreachable; - const item_last_ref = try sema.resolveInst(last_ref); const item_last = sema.resolveConstValue(block, .unneeded, item_last_ref, undefined) catch unreachable; while (item.compareScalar(.lte, item_last, operand_ty, mod)) : ({ @@ -11235,10 +11260,9 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError } } - for (items, 0..) 
|item_ref, item_i| { + for (items, 0..) |item, item_i| { cases_len += 1; - const item = try sema.resolveInst(item_ref); case_block.inline_case_capture = item; case_block.instructions.shrinkRetainingCapacity(0); @@ -11287,8 +11311,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError cases_len += 1; const analyze_body = if (union_originally) - for (items) |item_ref| { - const item = try sema.resolveInst(item_ref); + for (items) |item| { const item_val = sema.resolveConstValue(block, .unneeded, item, "") catch unreachable; const field_ty = maybe_union_ty.unionFieldType(item_val, mod); if (field_ty.zigTypeTag(mod) != .NoReturn) break true; @@ -11312,15 +11335,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError cases_extra.appendAssumeCapacity(@intCast(u32, items.len)); cases_extra.appendAssumeCapacity(@intCast(u32, case_block.instructions.items.len)); - for (items) |item_ref| { - const item = try sema.resolveInst(item_ref); + for (items) |item| { cases_extra.appendAssumeCapacity(@enumToInt(item)); } cases_extra.appendSliceAssumeCapacity(case_block.instructions.items); } else { - for (items) |item_ref| { - const item = try sema.resolveInst(item_ref); + for (items) |item| { const cmp_ok = try case_block.addBinOp(if (case_block.float_mode == .Optimized) .cmp_eq_optimized else .cmp_eq, operand, item); if (any_ok != .none) { any_ok = try case_block.addBinOp(.bool_or, any_ok, cmp_ok); @@ -11331,13 +11352,12 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError var range_i: usize = 0; while (range_i < ranges_len) : (range_i += 1) { - const first_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]); - extra_index += 1; - const last_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]); - extra_index += 1; + const range_items = case_vals.items[case_val_idx..][0..2]; + extra_index += 2; + case_val_idx += 2; - const item_first = try sema.resolveInst(first_ref); - const 
item_last = try sema.resolveInst(last_ref); + const item_first = range_items[0]; + const item_last = range_items[1]; // operand >= first and operand <= last const range_first_ok = try case_block.addBinOp( @@ -11696,29 +11716,46 @@ const RangeSetUnhandledIterator = struct { } }; +const ResolvedSwitchItem = struct { + ref: Air.Inst.Ref, + val: InternPool.Index, +}; fn resolveSwitchItemVal( sema: *Sema, block: *Block, item_ref: Zir.Inst.Ref, + /// Coerce `item_ref` to this type. + coerce_ty: Type, switch_node_offset: i32, switch_prong_src: Module.SwitchProngSrc, range_expand: Module.SwitchProngSrc.RangeExpand, -) CompileError!InternPool.Index { +) CompileError!ResolvedSwitchItem { const mod = sema.mod; - const item = try sema.resolveInst(item_ref); + const uncoerced_item = try sema.resolveInst(item_ref); + // Constructing a LazySrcLoc is costly because we only have the switch AST node. // Only if we know for sure we need to report a compile error do we resolve the // full source locations. 
- if (sema.resolveConstLazyValue(block, .unneeded, item, "")) |val| { - return val.toIntern(); - } else |err| switch (err) { + + const item = sema.coerce(block, coerce_ty, uncoerced_item, .unneeded) catch |err| switch (err) { + error.NeededSourceLocation => { + const src = switch_prong_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, range_expand); + _ = try sema.coerce(block, coerce_ty, uncoerced_item, src); + unreachable; + }, + else => |e| return e, + }; + + const val = sema.resolveConstLazyValue(block, .unneeded, item, "") catch |err| switch (err) { error.NeededSourceLocation => { const src = switch_prong_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, range_expand); _ = try sema.resolveConstValue(block, src, item, "switch prong values must be comptime-known"); unreachable; }, else => |e| return e, - } + }; + + return .{ .ref = item, .val = val.toIntern() }; } fn validateSwitchRange( @@ -11727,31 +11764,35 @@ fn validateSwitchRange( range_set: *RangeSet, first_ref: Zir.Inst.Ref, last_ref: Zir.Inst.Ref, + operand_ty: Type, src_node_offset: i32, switch_prong_src: Module.SwitchProngSrc, -) CompileError!void { +) CompileError![2]Air.Inst.Ref { const mod = sema.mod; - const first = try sema.resolveSwitchItemVal(block, first_ref, src_node_offset, switch_prong_src, .first); - const last = try sema.resolveSwitchItemVal(block, last_ref, src_node_offset, switch_prong_src, .last); - if (first.toValue().compareScalar(.gt, last.toValue(), mod.intern_pool.typeOf(first).toType(), mod)) { + const first = try sema.resolveSwitchItemVal(block, first_ref, operand_ty, src_node_offset, switch_prong_src, .first); + const last = try sema.resolveSwitchItemVal(block, last_ref, operand_ty, src_node_offset, switch_prong_src, .last); + if (try first.val.toValue().compareAll(.gt, last.val.toValue(), operand_ty, mod)) { const src = switch_prong_src.resolve(mod, mod.declPtr(block.src_decl), src_node_offset, .first); return sema.fail(block, src, "range start 
value is greater than the end value", .{}); } - const maybe_prev_src = try range_set.add(first, last, switch_prong_src); - return sema.validateSwitchDupe(block, maybe_prev_src, switch_prong_src, src_node_offset); + const maybe_prev_src = try range_set.add(first.val, last.val, switch_prong_src); + try sema.validateSwitchDupe(block, maybe_prev_src, switch_prong_src, src_node_offset); + return .{ first.ref, last.ref }; } -fn validateSwitchItem( +fn validateSwitchItemInt( sema: *Sema, block: *Block, range_set: *RangeSet, item_ref: Zir.Inst.Ref, + operand_ty: Type, src_node_offset: i32, switch_prong_src: Module.SwitchProngSrc, -) CompileError!void { - const item = try sema.resolveSwitchItemVal(block, item_ref, src_node_offset, switch_prong_src, .none); - const maybe_prev_src = try range_set.add(item, item, switch_prong_src); - return sema.validateSwitchDupe(block, maybe_prev_src, switch_prong_src, src_node_offset); +) CompileError!Air.Inst.Ref { + const item = try sema.resolveSwitchItemVal(block, item_ref, operand_ty, src_node_offset, switch_prong_src, .none); + const maybe_prev_src = try range_set.add(item.val, item.val, switch_prong_src); + try sema.validateSwitchDupe(block, maybe_prev_src, switch_prong_src, src_node_offset); + return item.ref; } fn validateSwitchItemEnum( @@ -11760,19 +11801,22 @@ fn validateSwitchItemEnum( seen_fields: []?Module.SwitchProngSrc, range_set: *RangeSet, item_ref: Zir.Inst.Ref, + operand_ty: Type, src_node_offset: i32, switch_prong_src: Module.SwitchProngSrc, -) CompileError!void { +) CompileError!Air.Inst.Ref { const ip = &sema.mod.intern_pool; - const item = try sema.resolveSwitchItemVal(block, item_ref, src_node_offset, switch_prong_src, .none); - const int = ip.indexToKey(item).enum_tag.int; - const field_index = ip.indexToKey(ip.typeOf(item)).enum_type.tagValueIndex(ip, int) orelse { + const item = try sema.resolveSwitchItemVal(block, item_ref, operand_ty, src_node_offset, switch_prong_src, .none); + const int = 
ip.indexToKey(item.val).enum_tag.int; + const field_index = ip.indexToKey(ip.typeOf(item.val)).enum_type.tagValueIndex(ip, int) orelse { const maybe_prev_src = try range_set.add(int, int, switch_prong_src); - return sema.validateSwitchDupe(block, maybe_prev_src, switch_prong_src, src_node_offset); + try sema.validateSwitchDupe(block, maybe_prev_src, switch_prong_src, src_node_offset); + return item.ref; }; const maybe_prev_src = seen_fields[field_index]; seen_fields[field_index] = switch_prong_src; - return sema.validateSwitchDupe(block, maybe_prev_src, switch_prong_src, src_node_offset); + try sema.validateSwitchDupe(block, maybe_prev_src, switch_prong_src, src_node_offset); + return item.ref; } fn validateSwitchItemError( @@ -11780,18 +11824,19 @@ fn validateSwitchItemError( block: *Block, seen_errors: *SwitchErrorSet, item_ref: Zir.Inst.Ref, + operand_ty: Type, src_node_offset: i32, switch_prong_src: Module.SwitchProngSrc, -) CompileError!void { +) CompileError!Air.Inst.Ref { const ip = &sema.mod.intern_pool; - const item = try sema.resolveSwitchItemVal(block, item_ref, src_node_offset, switch_prong_src, .none); - // TODO: Do i need to typecheck here? 
- const error_name = ip.indexToKey(item).err.name; + const item = try sema.resolveSwitchItemVal(block, item_ref, operand_ty, src_node_offset, switch_prong_src, .none); + const error_name = ip.indexToKey(item.val).err.name; const maybe_prev_src = if (try seen_errors.fetchPut(error_name, switch_prong_src)) |prev| prev.value else null; - return sema.validateSwitchDupe(block, maybe_prev_src, switch_prong_src, src_node_offset); + try sema.validateSwitchDupe(block, maybe_prev_src, switch_prong_src, src_node_offset); + return item.ref; } fn validateSwitchDupe( @@ -11834,19 +11879,20 @@ fn validateSwitchItemBool( item_ref: Zir.Inst.Ref, src_node_offset: i32, switch_prong_src: Module.SwitchProngSrc, -) CompileError!void { +) CompileError!Air.Inst.Ref { const mod = sema.mod; - const item = try sema.resolveSwitchItemVal(block, item_ref, src_node_offset, switch_prong_src, .none); - if (item.toValue().toBool()) { + const item = try sema.resolveSwitchItemVal(block, item_ref, Type.bool, src_node_offset, switch_prong_src, .none); + if (item.val.toValue().toBool()) { true_count.* += 1; } else { false_count.* += 1; } - if (true_count.* + false_count.* > 2) { - const block_src_decl = mod.declPtr(block.src_decl); + if (true_count.* > 1 or false_count.* > 1) { + const block_src_decl = sema.mod.declPtr(block.src_decl); const src = switch_prong_src.resolve(mod, block_src_decl, src_node_offset, .none); return sema.fail(block, src, "duplicate switch value", .{}); } + return item.ref; } const ValueSrcMap = std.AutoHashMapUnmanaged(InternPool.Index, Module.SwitchProngSrc); @@ -11856,12 +11902,14 @@ fn validateSwitchItemSparse( block: *Block, seen_values: *ValueSrcMap, item_ref: Zir.Inst.Ref, + operand_ty: Type, src_node_offset: i32, switch_prong_src: Module.SwitchProngSrc, -) CompileError!void { - const item = try sema.resolveSwitchItemVal(block, item_ref, src_node_offset, switch_prong_src, .none); - const kv = (try seen_values.fetchPut(sema.gpa, item, switch_prong_src)) orelse return; - 
return sema.validateSwitchDupe(block, kv.value, switch_prong_src, src_node_offset); +) CompileError!Air.Inst.Ref { + const item = try sema.resolveSwitchItemVal(block, item_ref, operand_ty, src_node_offset, switch_prong_src, .none); + const kv = (try seen_values.fetchPut(sema.gpa, item.val, switch_prong_src)) orelse return item.ref; + try sema.validateSwitchDupe(block, kv.value, switch_prong_src, src_node_offset); + unreachable; } fn validateSwitchNoRange( From 00609e7edbbc949b12e29a8d9911d988b78d7e03 Mon Sep 17 00:00:00 2001 From: mlugg Date: Fri, 26 May 2023 04:10:51 +0100 Subject: [PATCH 3/8] Eliminate switch_capture and switch_capture_ref ZIR tags These tags are unnecessary, as this information can be more efficiently encoded within the switch_block instruction itself. We also use a neat little trick to avoid needing a dummy instruction (like is used for errdefer captures): since the switch_block itself cannot otherwise be referenced within a prong, we can repurpose its index within prongs to refer to the captured value. 
--- src/AstGen.zig | 68 +++-- src/Module.zig | 24 +- src/Sema.zig | 649 +++++++++++++++++++++++++++++++++------------- src/Zir.zig | 109 ++------ src/print_zir.zig | 51 ++-- 5 files changed, 565 insertions(+), 336 deletions(-) diff --git a/src/AstGen.zig b/src/AstGen.zig index 0003457c01..003677bfe5 100644 --- a/src/AstGen.zig +++ b/src/AstGen.zig @@ -2612,8 +2612,6 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As .switch_block, .switch_cond, .switch_cond_ref, - .switch_capture, - .switch_capture_ref, .switch_capture_tag, .struct_init_empty, .struct_init, @@ -6876,17 +6874,22 @@ fn switchExpr( var dbg_var_inst: Zir.Inst.Ref = undefined; var dbg_var_tag_name: ?u32 = null; var dbg_var_tag_inst: Zir.Inst.Ref = undefined; - var capture_inst: Zir.Inst.Index = 0; var tag_inst: Zir.Inst.Index = 0; var capture_val_scope: Scope.LocalVal = undefined; var tag_scope: Scope.LocalVal = undefined; + + var capture: Zir.Inst.SwitchBlock.ProngInfo.Capture = .none; + const sub_scope = blk: { const payload_token = case.payload_token orelse break :blk &case_scope.base; const ident = if (token_tags[payload_token] == .asterisk) payload_token + 1 else payload_token; + const is_ptr = ident != payload_token; + capture = if (is_ptr) .by_ref else .by_val; + const ident_slice = tree.tokenSlice(ident); var payload_sub_scope: *Scope = undefined; if (mem.eql(u8, ident_slice, "_")) { @@ -6895,46 +6898,18 @@ fn switchExpr( } payload_sub_scope = &case_scope.base; } else { - if (case_node == special_node) { - const capture_tag: Zir.Inst.Tag = if (is_ptr) - .switch_capture_ref - else - .switch_capture; - capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len); - try astgen.instructions.append(gpa, .{ - .tag = capture_tag, - .data = .{ - .switch_capture = .{ - .switch_inst = switch_block, - // Max int communicates that this is the else/underscore prong. 
- .prong_index = std.math.maxInt(u32), - }, - }, - }); - } else { - const capture_tag: Zir.Inst.Tag = if (is_ptr) .switch_capture_ref else .switch_capture; - const capture_index = if (is_multi_case) scalar_cases_len + multi_case_index else scalar_case_index; - capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len); - try astgen.instructions.append(gpa, .{ - .tag = capture_tag, - .data = .{ .switch_capture = .{ - .switch_inst = switch_block, - .prong_index = capture_index, - } }, - }); - } const capture_name = try astgen.identAsString(ident); try astgen.detectLocalShadowing(&case_scope.base, capture_name, ident, ident_slice, .capture); capture_val_scope = .{ .parent = &case_scope.base, .gen_zir = &case_scope, .name = capture_name, - .inst = indexToRef(capture_inst), + .inst = indexToRef(switch_block), .token_src = payload_token, .id_cat = .capture, }; dbg_var_name = capture_name; - dbg_var_inst = indexToRef(capture_inst); + dbg_var_inst = indexToRef(switch_block); payload_sub_scope = &capture_val_scope.base; } @@ -7023,7 +6998,6 @@ fn switchExpr( case_scope.instructions_top = parent_gz.instructions.items.len; defer case_scope.unstack(); - if (capture_inst != 0) try case_scope.instructions.append(gpa, capture_inst); if (tag_inst != 0) try case_scope.instructions.append(gpa, tag_inst); try case_scope.addDbgBlockBegin(); if (dbg_var_name) |some| { @@ -7042,10 +7016,28 @@ fn switchExpr( } const case_slice = case_scope.instructionsSlice(); - const body_len = astgen.countBodyLenAfterFixups(case_slice); + // Since we use the switch_block instruction itself to refer to the + // capture, which will not be added to the child block, we need to + // handle ref_table manually. 
+ const refs_len = refs: { + var n: usize = 0; + var check_inst = switch_block; + while (astgen.ref_table.get(check_inst)) |ref_inst| { + n += 1; + check_inst = ref_inst; + } + break :refs n; + }; + const body_len = refs_len + astgen.countBodyLenAfterFixups(case_slice); try payloads.ensureUnusedCapacity(gpa, body_len); - const inline_bit = @as(u32, @boolToInt(case.inline_token != null)) << 31; - payloads.items[body_len_index] = body_len | inline_bit; + payloads.items[body_len_index] = @bitCast(u32, Zir.Inst.SwitchBlock.ProngInfo{ + .body_len = @intCast(u29, body_len), + .capture = capture, + .is_inline = case.inline_token != null, + }); + if (astgen.ref_table.fetchRemove(switch_block)) |kv| { + appendPossiblyRefdBodyInst(astgen, payloads, kv.value); + } appendBodyWithFixupsArrayList(astgen, payloads, case_slice); } } @@ -7092,7 +7084,7 @@ fn switchExpr( end_index += 3 + items_len + 2 * ranges_len; } - const body_len = @truncate(u31, payloads.items[body_len_index]); + const body_len = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, payloads.items[body_len_index]).body_len; end_index += body_len; switch (strat.tag) { diff --git a/src/Module.zig b/src/Module.zig index 61f39a327a..73ab1c3277 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -5871,6 +5871,7 @@ pub const SwitchProngSrc = union(enum) { multi: Multi, range: Multi, multi_capture: u32, + special, pub const Multi = struct { prong: u32, @@ -5908,14 +5909,22 @@ pub const SwitchProngSrc = union(enum) { var scalar_i: u32 = 0; for (case_nodes) |case_node| { const case = tree.fullSwitchCase(case_node).?; - if (case.ast.values.len == 0) - continue; - if (case.ast.values.len == 1 and - node_tags[case.ast.values[0]] == .identifier and - mem.eql(u8, tree.tokenSlice(main_tokens[case.ast.values[0]]), "_")) - { - continue; + + const is_special = special: { + if (case.ast.values.len == 0) break :special true; + if (case.ast.values.len == 1 and node_tags[case.ast.values[0]] == .identifier) { + break :special mem.eql(u8, 
tree.tokenSlice(main_tokens[case.ast.values[0]]), "_"); + } + break :special false; + }; + + if (is_special) { + if (prong_src != .special) continue; + return LazySrcLoc.nodeOffset( + decl.nodeIndexToRelative(case.ast.values[0]), + ); } + const is_multi = case.ast.values.len != 1 or node_tags[case.ast.values[0]] == .switch_range; @@ -5956,6 +5965,7 @@ pub const SwitchProngSrc = union(enum) { range_i += 1; } else unreachable; }, + .special => {}, } if (is_multi) { multi_i += 1; diff --git a/src/Sema.zig b/src/Sema.zig index 1bc34beb83..45920c1e92 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -277,9 +277,6 @@ pub const Block = struct { c_import_buf: ?*std.ArrayList(u8) = null, - /// type of `err` in `else => |err|` - switch_else_err_ty: ?Type = null, - /// Value for switch_capture in an inline case inline_case_capture: Air.Inst.Ref = .none, @@ -397,7 +394,6 @@ pub const Block = struct { .want_safety = parent.want_safety, .float_mode = parent.float_mode, .c_import_buf = parent.c_import_buf, - .switch_else_err_ty = parent.switch_else_err_ty, .error_return_trace_index = parent.error_return_trace_index, }; } @@ -1017,8 +1013,6 @@ fn analyzeBodyInner( .switch_block => try sema.zirSwitchBlock(block, inst), .switch_cond => try sema.zirSwitchCond(block, inst, false), .switch_cond_ref => try sema.zirSwitchCond(block, inst, true), - .switch_capture => try sema.zirSwitchCapture(block, inst, false), - .switch_capture_ref => try sema.zirSwitchCapture(block, inst, true), .switch_capture_tag => try sema.zirSwitchCaptureTag(block, inst), .type_info => try sema.zirTypeInfo(block, inst), .size_of => try sema.zirSizeOf(block, inst), @@ -10083,61 +10077,160 @@ fn zirSliceLength(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError return sema.analyzeSlice(block, src, array_ptr, start, len, sentinel, sentinel_src, ptr_src, start_src, end_src, true); } -fn zirSwitchCapture( +/// Resolve a switch prong which is determined at comptime to have no peers. 
Uses +/// `resolveBlockBody`. Sets up captures as needed. +fn resolveSwitchProngComptime( sema: *Sema, - block: *Block, - inst: Zir.Inst.Index, - is_ref: bool, + parent_block: *Block, + child_block: *Block, + src: LazySrcLoc, + operand: Air.Inst.Ref, + operand_ptr: Air.Inst.Ref, + prong_type: enum { normal, special }, + prong_body: []const Zir.Inst.Index, + capture: Zir.Inst.SwitchBlock.ProngInfo.Capture, + raw_capture_src: Module.SwitchProngSrc, + else_error_ty: ?Type, + case_vals: []const Air.Inst.Ref, + switch_block_inst: Zir.Inst.Index, + merges: *Block.Merges, ) CompileError!Air.Inst.Ref { - const tracy = trace(@src()); - defer tracy.end(); + switch (capture) { + .none => { + return sema.resolveBlockBody(parent_block, src, child_block, prong_body, switch_block_inst, merges); + }, + .by_val, .by_ref => { + const zir_datas = sema.code.instructions.items(.data); + const switch_info = zir_datas[switch_block_inst].pl_node; + + const capture_ref = try sema.analyzeSwitchCapture( + child_block, + capture == .by_ref, + operand, + operand_ptr, + switch_info.src_node, + prong_type == .special, + raw_capture_src, + else_error_ty, + case_vals, + ); + + if (sema.typeOf(capture_ref).isNoReturn(sema.mod)) { + // This prong should be unreachable! + return Air.Inst.Ref.unreachable_value; + } + + sema.inst_map.putAssumeCapacity(switch_block_inst, capture_ref); + defer assert(sema.inst_map.remove(switch_block_inst)); + + return sema.resolveBlockBody(parent_block, src, child_block, prong_body, switch_block_inst, merges); + }, + } +} + +/// Analyze a switch prong which may have peers at runtime. Uses +/// `analyzeBodyRuntimeBreak`. Sets up captures as needed. 
+fn analyzeSwitchProngRuntime( + sema: *Sema, + case_block: *Block, + operand: Air.Inst.Ref, + operand_ptr: Air.Inst.Ref, + prong_type: enum { normal, special }, + prong_body: []const Zir.Inst.Index, + capture: Zir.Inst.SwitchBlock.ProngInfo.Capture, + raw_capture_src: Module.SwitchProngSrc, + else_error_ty: ?Type, + case_vals: []const Air.Inst.Ref, + switch_block_inst: Zir.Inst.Index, +) CompileError!void { + switch (capture) { + .none => { + return sema.analyzeBodyRuntimeBreak(case_block, prong_body); + }, + + .by_val, .by_ref => { + const zir_datas = sema.code.instructions.items(.data); + const switch_info = zir_datas[switch_block_inst].pl_node; + + const capture_ref = try sema.analyzeSwitchCapture( + case_block, + capture == .by_ref, + operand, + operand_ptr, + switch_info.src_node, + prong_type == .special, + raw_capture_src, + else_error_ty, + case_vals, + ); + + if (sema.typeOf(capture_ref).isNoReturn(sema.mod)) { + // No need to analyze any further, the prong is unreachable + return; + } + + sema.inst_map.putAssumeCapacity(switch_block_inst, capture_ref); + defer assert(sema.inst_map.remove(switch_block_inst)); + + return sema.analyzeBodyRuntimeBreak(case_block, prong_body); + }, + } +} + +fn analyzeSwitchCapture( + sema: *Sema, + /// Must be the child block so that `inline_case_capture` is set for inline prongs. + block: *Block, + capture_byref: bool, + /// The raw switch operand value. + operand: Air.Inst.Ref, + /// Pointer to the raw switch operand. May be undefined if `capture_byref` is false. + operand_ptr: Air.Inst.Ref, + switch_node_offset: i32, + /// `true` if this is the `else` or `_` prong of a switch. + is_special_prong: bool, + /// Must use the `scalar`, `special`, or `multi_capture` union field. + raw_capture_src: Module.SwitchProngSrc, + /// If this is the `else` prong of a switch on an error set, this is the + /// type that should be assigned to the capture. If `null`, the prong should + /// be unreachable. 
+ else_error_ty: ?Type, + /// The set of all values which can reach this prong. May be undefined if + /// the prong has `is_special_prong` or contains ranges. + case_vals: []const Air.Inst.Ref, +) CompileError!Air.Inst.Ref { const mod = sema.mod; const gpa = sema.gpa; - const zir_datas = sema.code.instructions.items(.data); - const capture_info = zir_datas[inst].switch_capture; - const switch_info = zir_datas[capture_info.switch_inst].pl_node; - const switch_extra = sema.code.extraData(Zir.Inst.SwitchBlock, switch_info.payload_index); - const operand_src: LazySrcLoc = .{ .node_offset_switch_operand = switch_info.src_node }; - const cond = try sema.resolveInst(switch_extra.data.operand); - const cond_ty = sema.typeOf(cond); - const cond_inst = Zir.refToIndex(switch_extra.data.operand).?; - const cond_info = zir_datas[cond_inst].un_node; - const cond_tag = sema.code.instructions.items(.tag)[cond_inst]; - const operand_is_ref = cond_tag == .switch_cond_ref; - const operand_ptr = try sema.resolveInst(cond_info.operand); - const operand_ptr_ty = sema.typeOf(operand_ptr); - const operand_ty = if (operand_is_ref) operand_ptr_ty.childType(mod) else operand_ptr_ty; + const operand_ty = sema.typeOf(operand); + const operand_ptr_ty = if (capture_byref) sema.typeOf(operand_ptr) else undefined; + const operand_src: LazySrcLoc = .{ .node_offset_switch_operand = switch_node_offset }; if (block.inline_case_capture != .none) { - const item_val = sema.resolveConstValue(block, .unneeded, block.inline_case_capture, undefined) catch unreachable; - const resolved_item_val = try sema.resolveLazyValue(item_val); + const item_val = sema.resolveConstValue(block, .unneeded, block.inline_case_capture, "") catch unreachable; if (operand_ty.zigTypeTag(mod) == .Union) { - const field_index = @intCast(u32, operand_ty.unionTagFieldIndex(resolved_item_val, mod).?); + const field_index = @intCast(u32, operand_ty.unionTagFieldIndex(item_val, mod).?); const union_obj = mod.typeToUnion(operand_ty).?; 
const field_ty = union_obj.fields.values()[field_index].ty; - if (try sema.resolveDefinedValue(block, sema.src, operand_ptr)) |union_val| { - if (is_ref) { + if (capture_byref) { + if (try sema.resolveDefinedValue(block, sema.src, operand_ptr)) |union_ptr| { const ptr_field_ty = try Type.ptr(sema.arena, mod, .{ .pointee_type = field_ty, .mutable = operand_ptr_ty.ptrIsMutable(mod), .@"volatile" = operand_ptr_ty.isVolatilePtr(mod), .@"addrspace" = operand_ptr_ty.ptrAddressSpace(mod), }); - return sema.addConstant(ptr_field_ty, (try mod.intern(.{ .ptr = .{ - .ty = ptr_field_ty.toIntern(), - .addr = .{ .field = .{ - .base = union_val.toIntern(), - .index = field_index, - } }, - } })).toValue()); + return sema.addConstant( + ptr_field_ty, + (try mod.intern(.{ .ptr = .{ + .ty = ptr_field_ty.toIntern(), + .addr = .{ .field = .{ + .base = union_ptr.toIntern(), + .index = field_index, + } }, + } })).toValue(), + ); } - return sema.addConstant( - field_ty, - mod.intern_pool.indexToKey(union_val.toIntern()).un.val.toValue(), - ); - } - if (is_ref) { const ptr_field_ty = try Type.ptr(sema.arena, mod, .{ .pointee_type = field_ty, .mutable = operand_ptr_ty.ptrIsMutable(mod), @@ -10146,29 +10239,27 @@ fn zirSwitchCapture( }); return block.addStructFieldPtr(operand_ptr, field_index, ptr_field_ty); } else { - return block.addStructFieldVal(operand_ptr, field_index, field_ty); + if (try sema.resolveDefinedValue(block, sema.src, operand)) |union_val| { + const tag_and_val = mod.intern_pool.indexToKey(union_val.toIntern()).un; + return sema.addConstant(field_ty, tag_and_val.val.toValue()); + } + return block.addStructFieldVal(operand, field_index, field_ty); } - } else if (is_ref) { - return sema.addConstantMaybeRef(block, operand_ty, resolved_item_val, true); + } else if (capture_byref) { + return sema.addConstantMaybeRef(block, operand_ty, item_val, true); } else { return block.inline_case_capture; } } - const operand = if (operand_is_ref) - try sema.analyzeLoad(block, operand_src, 
operand_ptr, operand_src) - else - operand_ptr; - - if (capture_info.prong_index == std.math.maxInt(@TypeOf(capture_info.prong_index))) { - // It is the else/`_` prong. - if (is_ref) { + if (is_special_prong) { + if (capture_byref) { return operand_ptr; } switch (operand_ty.zigTypeTag(mod)) { - .ErrorSet => if (block.switch_else_err_ty) |some| { - return sema.bitCast(block, some, operand, operand_src, null); + .ErrorSet => if (else_error_ty) |ty| { + return sema.bitCast(block, ty, operand, operand_src, null); } else { try block.addUnreachable(false); return Air.Inst.Ref.unreachable_value; @@ -10177,41 +10268,33 @@ fn zirSwitchCapture( } } - // Note that these are the *uncasted* prong items. - // Also note that items from ranges are not included so this only works for non-ranged types. - const items = switch_extra.data.getProng(sema.code, switch_extra.end, capture_info.prong_index).items; - switch (operand_ty.zigTypeTag(mod)) { .Union => { const union_obj = mod.typeToUnion(operand_ty).?; - const first_item = try sema.resolveInst(items[0]); - // Previous switch validation ensured this will succeed - const first_item_coerced = try sema.coerce(block, cond_ty, first_item, .unneeded); - const first_item_val = sema.resolveConstValue(block, .unneeded, first_item_coerced, "") catch unreachable; + const first_item_val = sema.resolveConstValue(block, .unneeded, case_vals[0], "") catch unreachable; const first_field_index = @intCast(u32, operand_ty.unionTagFieldIndex(first_item_val, mod).?); const first_field = union_obj.fields.values()[first_field_index]; - for (items[1..], 0..) |item, i| { - const item_ref = try sema.resolveInst(item); - // Previous switch validation ensured this will succeed - const item_coerced = try sema.coerce(block, cond_ty, item_ref, .unneeded); - const item_val = sema.resolveConstValue(block, .unneeded, item_coerced, "") catch unreachable; + for (case_vals[1..], 0..) 
|item, i| { + const item_val = sema.resolveConstValue(block, .unneeded, item, "") catch unreachable; const field_index = operand_ty.unionTagFieldIndex(item_val, mod).?; const field = union_obj.fields.values()[field_index]; if (!field.ty.eql(first_field.ty, mod)) { const msg = msg: { - const raw_capture_src = Module.SwitchProngSrc{ .multi_capture = capture_info.prong_index }; - const capture_src = raw_capture_src.resolve(mod, mod.declPtr(block.src_decl), switch_info.src_node, .first); + const capture_src = raw_capture_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .none); const msg = try sema.errMsg(block, capture_src, "capture group with incompatible types", .{}); errdefer msg.destroy(gpa); - const raw_first_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = capture_info.prong_index, .item = 0 } }; - const first_item_src = raw_first_item_src.resolve(mod, mod.declPtr(block.src_decl), switch_info.src_node, .first); - const raw_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = capture_info.prong_index, .item = 1 + @intCast(u32, i) } }; - const item_src = raw_item_src.resolve(mod, mod.declPtr(block.src_decl), switch_info.src_node, .first); + // This must be a multi-prong so this must be a `multi_capture` src + const multi_idx = raw_capture_src.multi_capture; + + const raw_first_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 0 } }; + const first_item_src = raw_first_item_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .first); + const raw_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 1 + @intCast(u32, i) } }; + const item_src = raw_item_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .first); try sema.errNote(block, first_item_src, msg, "type '{}' here", .{first_field.ty.fmt(mod)}); try sema.errNote(block, item_src, msg, "type '{}' here", .{field.ty.fmt(mod)}); break :msg msg; @@ -10220,7 +10303,7 @@ fn zirSwitchCapture( } } - if (is_ref) { + if 
(capture_byref) { const field_ty_ptr = try Type.ptr(sema.arena, mod, .{ .pointee_type = first_field.ty, .@"addrspace" = .generic, @@ -10250,31 +10333,35 @@ fn zirSwitchCapture( return block.addStructFieldVal(operand, first_field_index, first_field.ty); }, .ErrorSet => { - if (items.len > 1) { - var names: Module.Fn.InferredErrorSet.NameMap = .{}; - try names.ensureUnusedCapacity(sema.arena, items.len); - for (items) |item| { - const item_ref = try sema.resolveInst(item); - // Previous switch validation ensured this will succeed - const item_val = sema.resolveConstLazyValue(block, .unneeded, item_ref, "") catch unreachable; - names.putAssumeCapacityNoClobber(item_val.getErrorName(mod).unwrap().?, {}); - } - const else_error_ty = try mod.errorSetFromUnsortedNames(names.keys()); - - return sema.bitCast(block, else_error_ty, operand, operand_src, null); - } else { - const item_ref = try sema.resolveInst(items[0]); - // Previous switch validation ensured this will succeed - const item_val = sema.resolveConstLazyValue(block, .unneeded, item_ref, "") catch unreachable; + if (capture_byref) { + const capture_src = raw_capture_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .none); + return sema.fail( + block, + capture_src, + "error set cannot be captured by reference", + .{}, + ); + } + if (case_vals.len == 1) { + const item_val = sema.resolveConstValue(block, .unneeded, case_vals[0], "") catch unreachable; const item_ty = try mod.singleErrorSetType(item_val.getErrorName(mod).unwrap().?); return sema.bitCast(block, item_ty, operand, operand_src, null); } + + var names: Module.Fn.InferredErrorSet.NameMap = .{}; + try names.ensureUnusedCapacity(sema.arena, case_vals.len); + for (case_vals) |err| { + const err_val = sema.resolveConstValue(block, .unneeded, err, "") catch unreachable; + names.putAssumeCapacityNoClobber(err_val.getErrorName(mod).unwrap().?, {}); + } + const error_ty = try mod.errorSetFromUnsortedNames(names.keys()); + return 
sema.bitCast(block, error_ty, operand, operand_src, null); }, else => { - // In this case the capture value is just the passed-through value of the - // switch condition. - if (is_ref) { + // In this case the capture value is just the passed-through value + // of the switch condition. + if (capture_byref) { return operand_ptr; } else { return operand; @@ -10415,28 +10502,42 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError var case_vals = try std.ArrayListUnmanaged(Air.Inst.Ref).initCapacity(gpa, scalar_cases_len + 2 * multi_cases_len); defer case_vals.deinit(gpa); + const Special = struct { + body: []const Zir.Inst.Index, + end: usize, + capture: Zir.Inst.SwitchBlock.ProngInfo.Capture, + is_inline: bool, + }; + const special_prong = extra.data.bits.specialProng(); - const special: struct { body: []const Zir.Inst.Index, end: usize, is_inline: bool } = switch (special_prong) { - .none => .{ .body = &.{}, .end = header_extra_index, .is_inline = false }, + const special: Special = switch (special_prong) { + .none => .{ .body = &.{}, .end = header_extra_index, .capture = .none, .is_inline = false }, .under, .@"else" => blk: { - const body_len = @truncate(u31, sema.code.extra[header_extra_index]); + const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[header_extra_index]); const extra_body_start = header_extra_index + 1; break :blk .{ - .body = sema.code.extra[extra_body_start..][0..body_len], - .end = extra_body_start + body_len, - .is_inline = sema.code.extra[header_extra_index] >> 31 != 0, + .body = sema.code.extra[extra_body_start..][0..info.body_len], + .end = extra_body_start + info.body_len, + .capture = info.capture, + .is_inline = info.is_inline, }; }, }; - const maybe_union_ty = blk: { + const raw_operand: struct { val: Air.Inst.Ref, ptr: Air.Inst.Ref } = blk: { const zir_tags = sema.code.instructions.items(.tag); const zir_data = sema.code.instructions.items(.data); const cond_index = 
Zir.refToIndex(extra.data.operand).?; - const raw_operand = sema.resolveInst(zir_data[cond_index].un_node.operand) catch unreachable; - const target_ty = sema.typeOf(raw_operand); - break :blk if (zir_tags[cond_index] == .switch_cond_ref) target_ty.childType(mod) else target_ty; + const raw = sema.resolveInst(zir_data[cond_index].un_node.operand) catch unreachable; + if (zir_tags[cond_index] == .switch_cond_ref) { + const val = try sema.analyzeLoad(block, src, raw, operand_src); + break :blk .{ .val = val, .ptr = raw }; + } else { + break :blk .{ .val = raw, .ptr = undefined }; + } }; + + const maybe_union_ty = sema.typeOf(raw_operand.val); const union_originally = maybe_union_ty.zigTypeTag(mod) == .Union; // Duplicate checking variables later also used for `inline else`. @@ -10496,9 +10597,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError while (scalar_i < scalar_cases_len) : (scalar_i += 1) { const item_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]); extra_index += 1; - const body_len = @truncate(u31, sema.code.extra[extra_index]); - extra_index += 1; - extra_index += body_len; + const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]); + extra_index += 1 + info.body_len; case_vals.appendAssumeCapacity(try sema.validateSwitchItemEnum( block, @@ -10518,10 +10618,10 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError extra_index += 1; const ranges_len = sema.code.extra[extra_index]; extra_index += 1; - const body_len = @truncate(u31, sema.code.extra[extra_index]); + const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]); extra_index += 1; const items = sema.code.refSlice(extra_index, items_len); - extra_index += items_len + body_len; + extra_index += items_len + info.body_len; try case_vals.ensureUnusedCapacity(gpa, items.len); for (items, 0..) 
|item_ref, item_i| { @@ -10596,9 +10696,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError while (scalar_i < scalar_cases_len) : (scalar_i += 1) { const item_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]); extra_index += 1; - const body_len = @truncate(u31, sema.code.extra[extra_index]); - extra_index += 1; - extra_index += body_len; + const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]); + extra_index += 1 + info.body_len; case_vals.appendAssumeCapacity(try sema.validateSwitchItemError( block, @@ -10617,10 +10716,10 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError extra_index += 1; const ranges_len = sema.code.extra[extra_index]; extra_index += 1; - const body_len = @truncate(u31, sema.code.extra[extra_index]); + const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]); extra_index += 1; const items = sema.code.refSlice(extra_index, items_len); - extra_index += items_len + body_len; + extra_index += items_len + info.body_len; try case_vals.ensureUnusedCapacity(gpa, items.len); for (items, 0..) 
|item_ref, item_i| { @@ -10694,7 +10793,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError .dbg_block_end, .dbg_stmt, .dbg_var_val, - .switch_capture, .ret_type, .as_node, .ret_node, @@ -10739,9 +10837,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError while (scalar_i < scalar_cases_len) : (scalar_i += 1) { const item_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]); extra_index += 1; - const body_len = @truncate(u31, sema.code.extra[extra_index]); - extra_index += 1; - extra_index += body_len; + const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]); + extra_index += 1 + info.body_len; case_vals.appendAssumeCapacity(try sema.validateSwitchItemInt( block, @@ -10760,7 +10857,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError extra_index += 1; const ranges_len = sema.code.extra[extra_index]; extra_index += 1; - const body_len = @truncate(u31, sema.code.extra[extra_index]); + const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]); extra_index += 1; const items = sema.code.refSlice(extra_index, items_len); extra_index += items_len; @@ -10798,7 +10895,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError case_vals.appendAssumeCapacity(vals[1]); } - extra_index += body_len; + extra_index += info.body_len; } } @@ -10835,9 +10932,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError while (scalar_i < scalar_cases_len) : (scalar_i += 1) { const item_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]); extra_index += 1; - const body_len = @truncate(u31, sema.code.extra[extra_index]); - extra_index += 1; - extra_index += body_len; + const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]); + extra_index += 1 + info.body_len; case_vals.appendAssumeCapacity(try sema.validateSwitchItemBool( block, @@ 
-10856,10 +10952,10 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError extra_index += 1; const ranges_len = sema.code.extra[extra_index]; extra_index += 1; - const body_len = @truncate(u31, sema.code.extra[extra_index]); + const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]); extra_index += 1; const items = sema.code.refSlice(extra_index, items_len); - extra_index += items_len + body_len; + extra_index += items_len + info.body_len; try case_vals.ensureUnusedCapacity(gpa, items.len); for (items, 0..) |item_ref, item_i| { @@ -10918,9 +11014,9 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError while (scalar_i < scalar_cases_len) : (scalar_i += 1) { const item_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]); extra_index += 1; - const body_len = @truncate(u31, sema.code.extra[extra_index]); + const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]); extra_index += 1; - extra_index += body_len; + extra_index += info.body_len; case_vals.appendAssumeCapacity(try sema.validateSwitchItemSparse( block, @@ -10939,10 +11035,10 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError extra_index += 1; const ranges_len = sema.code.extra[extra_index]; extra_index += 1; - const body_len = @truncate(u31, sema.code.extra[extra_index]); + const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]); extra_index += 1; const items = sema.code.refSlice(extra_index, items_len); - extra_index += items_len + body_len; + extra_index += items_len + info.body_len; try case_vals.ensureUnusedCapacity(gpa, items.len); for (items, 0..) 
|item_ref, item_i| { @@ -11006,7 +11102,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError .is_comptime = block.is_comptime, .comptime_reason = block.comptime_reason, .is_typeof = block.is_typeof, - .switch_else_err_ty = else_error_ty, .c_import_buf = block.c_import_buf, .runtime_cond = block.runtime_cond, .runtime_loop = block.runtime_loop, @@ -11024,19 +11119,31 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError var scalar_i: usize = 0; while (scalar_i < scalar_cases_len) : (scalar_i += 1) { extra_index += 1; - const body_len = @truncate(u31, sema.code.extra[extra_index]); - const is_inline = sema.code.extra[extra_index] >> 31 != 0; + const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]); extra_index += 1; - const body = sema.code.extra[extra_index..][0..body_len]; - extra_index += body_len; + const body = sema.code.extra[extra_index..][0..info.body_len]; + extra_index += info.body_len; const item = case_vals.items[scalar_i]; - const item_val = sema.resolveConstLazyValue(&child_block, .unneeded, item, "") catch unreachable; - if (resolved_operand_val.eql(item_val, operand_ty, mod)) { - if (is_inline) child_block.inline_case_capture = operand; - + const item_val = sema.resolveConstValue(&child_block, .unneeded, item, "") catch unreachable; + if (operand_val.eql(item_val, operand_ty, sema.mod)) { + if (info.is_inline) child_block.inline_case_capture = operand; if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand); - return sema.resolveBlockBody(block, src, &child_block, body, inst, merges); + return sema.resolveSwitchProngComptime( + block, + &child_block, + src, + raw_operand.val, + raw_operand.ptr, + .normal, + body, + info.capture, + .{ .scalar = @intCast(u32, scalar_i) }, + else_error_ty, + &.{item}, + inst, + merges, + ); } } } @@ -11048,22 +11155,34 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError 
extra_index += 1; const ranges_len = sema.code.extra[extra_index]; extra_index += 1; - const body_len = @truncate(u31, sema.code.extra[extra_index]); - const is_inline = sema.code.extra[extra_index] >> 31 != 0; + const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]); extra_index += 1 + items_len; - const body = sema.code.extra[extra_index + 2 * ranges_len ..][0..body_len]; + const body = sema.code.extra[extra_index + 2 * ranges_len ..][0..info.body_len]; const items = case_vals.items[case_val_idx..][0..items_len]; case_val_idx += items_len; for (items) |item| { // Validation above ensured these will succeed. - const item_val = sema.resolveConstLazyValue(&child_block, .unneeded, item, "") catch unreachable; - if (resolved_operand_val.eql(item_val, operand_ty, mod)) { - if (is_inline) child_block.inline_case_capture = operand; - + const item_val = sema.resolveConstValue(&child_block, .unneeded, item, "") catch unreachable; + if (operand_val.eql(item_val, operand_ty, sema.mod)) { + if (info.is_inline) child_block.inline_case_capture = operand; if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand); - return sema.resolveBlockBody(block, src, &child_block, body, inst, merges); + return sema.resolveSwitchProngComptime( + block, + &child_block, + src, + raw_operand.val, + raw_operand.ptr, + .normal, + body, + info.capture, + .{ .multi_capture = @intCast(u32, multi_i) }, + else_error_ty, + items, + inst, + merges, + ); } } @@ -11079,13 +11198,27 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError if ((try sema.compareAll(resolved_operand_val, .gte, first_val, operand_ty)) and (try sema.compareAll(resolved_operand_val, .lte, last_val, operand_ty))) { - if (is_inline) child_block.inline_case_capture = operand; + if (info.is_inline) child_block.inline_case_capture = operand; if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand); - return sema.resolveBlockBody(block, src, 
&child_block, body, inst, merges); + return sema.resolveSwitchProngComptime( + block, + &child_block, + src, + raw_operand.val, + raw_operand.ptr, + .normal, + body, + info.capture, + .{ .multi_capture = @intCast(u32, multi_i) }, + else_error_ty, + undefined, + inst, + merges, + ); } } - extra_index += body_len; + extra_index += info.body_len; } } if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, special.body, operand); @@ -11093,7 +11226,22 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError if (empty_enum) { return Air.Inst.Ref.void_value; } - return sema.resolveBlockBody(block, src, &child_block, special.body, inst, merges); + + return sema.resolveSwitchProngComptime( + block, + &child_block, + src, + raw_operand.val, + raw_operand.ptr, + .special, + special.body, + special.capture, + .special, + else_error_ty, + undefined, + inst, + merges, + ); } if (scalar_cases_len + multi_cases_len == 0 and !special.is_inline) { @@ -11113,7 +11261,22 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError const ok = try block.addUnOp(.is_named_enum_value, operand); try sema.addSafetyCheck(block, ok, .corrupt_switch); } - return sema.resolveBlockBody(block, src, &child_block, special.body, inst, merges); + + return sema.resolveSwitchProngComptime( + block, + &child_block, + src, + raw_operand.val, + raw_operand.ptr, + .special, + special.body, + special.capture, + .special, + else_error_ty, + undefined, + inst, + merges, + ); } if (child_block.is_comptime) { @@ -11140,11 +11303,10 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError var scalar_i: usize = 0; while (scalar_i < scalar_cases_len) : (scalar_i += 1) { extra_index += 1; - const body_len = @truncate(u31, sema.code.extra[extra_index]); - const is_inline = sema.code.extra[extra_index] >> 31 != 0; + const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]); extra_index += 1; - const body = 
sema.code.extra[extra_index..][0..body_len]; - extra_index += body_len; + const body = sema.code.extra[extra_index..][0..info.body_len]; + extra_index += info.body_len; var wip_captures = try WipCaptureScope.init(gpa, child_block.wip_capture_scope); defer wip_captures.deinit(); @@ -11154,7 +11316,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError case_block.inline_case_capture = .none; const item = case_vals.items[scalar_i]; - if (is_inline) case_block.inline_case_capture = item; + if (info.is_inline) case_block.inline_case_capture = item; // `item` is already guaranteed to be constant known. const analyze_body = if (union_originally) blk: { @@ -11166,7 +11328,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError if (err_set and try sema.maybeErrorUnwrap(&case_block, body, operand)) { // nothing to do here } else if (analyze_body) { - try sema.analyzeBodyRuntimeBreak(&case_block, body); + try sema.analyzeSwitchProngRuntime( + &case_block, + raw_operand.val, + raw_operand.ptr, + .normal, + body, + info.capture, + .{ .scalar = @intCast(u32, scalar_i) }, + else_error_ty, + &.{item}, + inst, + ); } else { _ = try case_block.addNoOp(.unreach); } @@ -11195,8 +11368,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError extra_index += 1; const ranges_len = sema.code.extra[extra_index]; extra_index += 1; - const body_len = @truncate(u31, sema.code.extra[extra_index]); - const is_inline = sema.code.extra[extra_index] >> 31 != 0; + const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]); extra_index += 1 + items_len; const items = case_vals.items[case_val_idx..][0..items_len]; @@ -11207,9 +11379,9 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError case_block.inline_case_capture = .none; // Generate all possible cases as scalar prongs. 
- if (is_inline) { + if (info.is_inline) { const body_start = extra_index + 2 * ranges_len; - const body = sema.code.extra[body_start..][0..body_len]; + const body = sema.code.extra[body_start..][0..info.body_len]; var emit_bb = false; var range_i: u32 = 0; @@ -11250,7 +11422,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError }; emit_bb = true; - try sema.analyzeBodyRuntimeBreak(&case_block, body); + try sema.analyzeSwitchProngRuntime( + &case_block, + raw_operand.val, + raw_operand.ptr, + .normal, + body, + info.capture, + .{ .multi_capture = multi_i }, + else_error_ty, + undefined, + inst, + ); try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); cases_extra.appendAssumeCapacity(1); // items_len @@ -11286,7 +11469,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError emit_bb = true; if (analyze_body) { - try sema.analyzeBodyRuntimeBreak(&case_block, body); + try sema.analyzeSwitchProngRuntime( + &case_block, + raw_operand.val, + raw_operand.ptr, + .normal, + body, + info.capture, + .{ .multi_capture = multi_i }, + else_error_ty, + &.{item}, + inst, + ); } else { _ = try case_block.addNoOp(.unreach); } @@ -11298,7 +11492,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError cases_extra.appendSliceAssumeCapacity(case_block.instructions.items); } - extra_index += body_len; + extra_index += info.body_len; continue; } @@ -11319,12 +11513,23 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError else true; - const body = sema.code.extra[extra_index..][0..body_len]; - extra_index += body_len; + const body = sema.code.extra[extra_index..][0..info.body_len]; + extra_index += info.body_len; if (err_set and try sema.maybeErrorUnwrap(&case_block, body, operand)) { // nothing to do here } else if (analyze_body) { - try sema.analyzeBodyRuntimeBreak(&case_block, body); + try sema.analyzeSwitchProngRuntime( + &case_block, 
+ raw_operand.val, + raw_operand.ptr, + .normal, + body, + info.capture, + .{ .multi_capture = multi_i }, + else_error_ty, + items, + inst, + ); } else { _ = try case_block.addNoOp(.unreach); } @@ -11397,12 +11602,23 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError case_block.instructions.shrinkRetainingCapacity(0); case_block.wip_capture_scope = wip_captures.scope; - const body = sema.code.extra[extra_index..][0..body_len]; - extra_index += body_len; + const body = sema.code.extra[extra_index..][0..info.body_len]; + extra_index += info.body_len; if (err_set and try sema.maybeErrorUnwrap(&case_block, body, operand)) { // nothing to do here } else { - try sema.analyzeBodyRuntimeBreak(&case_block, body); + try sema.analyzeSwitchProngRuntime( + &case_block, + raw_operand.val, + raw_operand.ptr, + .normal, + body, + info.capture, + .{ .multi_capture = multi_i }, + else_error_ty, + items, + inst, + ); } try wip_captures.finalize(); @@ -11461,7 +11677,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError emit_bb = true; if (analyze_body) { - try sema.analyzeBodyRuntimeBreak(&case_block, special.body); + try sema.analyzeSwitchProngRuntime( + &case_block, + raw_operand.val, + raw_operand.ptr, + .special, + special.body, + special.capture, + .special, + else_error_ty, + &.{item_ref}, + inst, + ); } else { _ = try case_block.addNoOp(.unreach); } @@ -11497,7 +11724,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError if (emit_bb) try sema.emitBackwardBranch(block, special_prong_src); emit_bb = true; - try sema.analyzeBodyRuntimeBreak(&case_block, special.body); + try sema.analyzeSwitchProngRuntime( + &case_block, + raw_operand.val, + raw_operand.ptr, + .special, + special.body, + special.capture, + .special, + else_error_ty, + &.{item_ref}, + inst, + ); try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); cases_extra.appendAssumeCapacity(1); // 
items_len @@ -11520,7 +11758,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError if (emit_bb) try sema.emitBackwardBranch(block, special_prong_src); emit_bb = true; - try sema.analyzeBodyRuntimeBreak(&case_block, special.body); + try sema.analyzeSwitchProngRuntime( + &case_block, + raw_operand.val, + raw_operand.ptr, + .special, + special.body, + special.capture, + .special, + else_error_ty, + &.{item_ref}, + inst, + ); try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); cases_extra.appendAssumeCapacity(1); // items_len @@ -11540,7 +11789,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError if (emit_bb) try sema.emitBackwardBranch(block, special_prong_src); emit_bb = true; - try sema.analyzeBodyRuntimeBreak(&case_block, special.body); + try sema.analyzeSwitchProngRuntime( + &case_block, + raw_operand.val, + raw_operand.ptr, + .special, + special.body, + special.capture, + .special, + else_error_ty, + &.{Air.Inst.Ref.bool_true}, + inst, + ); try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); cases_extra.appendAssumeCapacity(1); // items_len @@ -11558,7 +11818,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError if (emit_bb) try sema.emitBackwardBranch(block, special_prong_src); emit_bb = true; - try sema.analyzeBodyRuntimeBreak(&case_block, special.body); + try sema.analyzeSwitchProngRuntime( + &case_block, + raw_operand.val, + raw_operand.ptr, + .special, + special.body, + special.capture, + .special, + else_error_ty, + &.{Air.Inst.Ref.bool_false}, + inst, + ); try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); cases_extra.appendAssumeCapacity(1); // items_len @@ -11601,7 +11872,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError { // nothing to do here } else if (special.body.len != 0 and analyze_body and !special.is_inline) { - try 
sema.analyzeBodyRuntimeBreak(&case_block, special.body); + try sema.analyzeSwitchProngRuntime( + &case_block, + raw_operand.val, + raw_operand.ptr, + .special, + special.body, + special.capture, + .special, + else_error_ty, + undefined, + inst, + ); } else { // We still need a terminator in this block, but we have proven // that it is unreachable. @@ -11746,7 +12028,7 @@ fn resolveSwitchItemVal( else => |e| return e, }; - const val = sema.resolveConstLazyValue(block, .unneeded, item, "") catch |err| switch (err) { + const maybe_lazy = sema.resolveConstValue(block, .unneeded, item, "") catch |err| switch (err) { error.NeededSourceLocation => { const src = switch_prong_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, range_expand); _ = try sema.resolveConstValue(block, src, item, "switch prong values must be comptime-known"); @@ -11755,7 +12037,12 @@ fn resolveSwitchItemVal( else => |e| return e, }; - return .{ .ref = item, .val = val.toIntern() }; + const val = try sema.resolveLazyValue(maybe_lazy); + const new_item = if (val.toIntern() != maybe_lazy.toIntern()) blk: { + break :blk try sema.addConstant(coerce_ty, val); + } else item; + + return .{ .ref = new_item, .val = val.toIntern() }; } fn validateSwitchRange( diff --git a/src/Zir.zig b/src/Zir.zig index aaae5f4dcf..2a6ce2f047 100644 --- a/src/Zir.zig +++ b/src/Zir.zig @@ -676,17 +676,6 @@ pub const Inst = struct { /// what will be switched on. /// Uses the `un_node` union field. switch_cond_ref, - /// Produces the capture value for a switch prong. - /// Uses the `switch_capture` field. - /// If the `prong_index` field is max int, it means this is the capture - /// for the else/`_` prong. - switch_capture, - /// Produces the capture value for a switch prong. - /// Result is a pointer to the value. - /// Uses the `switch_capture` field. - /// If the `prong_index` field is max int, it means this is the capture - /// for the else/`_` prong. 
- switch_capture_ref, /// Produces the capture value for an inline switch prong tag capture. /// Uses the `un_tok` field. switch_capture_tag, @@ -1135,8 +1124,6 @@ pub const Inst = struct { .typeof_log2_int_type, .resolve_inferred_alloc, .set_eval_branch_quota, - .switch_capture, - .switch_capture_ref, .switch_capture_tag, .switch_block, .switch_cond, @@ -1427,8 +1414,6 @@ pub const Inst = struct { .slice_length, .import, .typeof_log2_int_type, - .switch_capture, - .switch_capture_ref, .switch_capture_tag, .switch_block, .switch_cond, @@ -1685,8 +1670,6 @@ pub const Inst = struct { .switch_block = .pl_node, .switch_cond = .un_node, .switch_cond_ref = .un_node, - .switch_capture = .switch_capture, - .switch_capture_ref = .switch_capture, .switch_capture_tag = .un_tok, .array_base_ptr = .un_node, .field_base_ptr = .un_node, @@ -2254,10 +2237,6 @@ pub const Inst = struct { operand: Ref, payload_index: u32, }, - switch_capture: struct { - switch_inst: Index, - prong_index: u32, - }, dbg_stmt: LineColumn, /// Used for unary operators which reference an inst, /// with an AST node source location. @@ -2327,7 +2306,6 @@ pub const Inst = struct { bool_br, @"unreachable", @"break", - switch_capture, dbg_stmt, inst_node, str_op, @@ -2667,25 +2645,29 @@ pub const Inst = struct { /// 0. multi_cases_len: u32 // If has_multi_cases is set. /// 1. else_body { // If has_else or has_under is set. - /// body_len: u32, - /// body member Index for every body_len + /// info: ProngInfo, + /// body member Index for every info.body_len /// } /// 2. scalar_cases: { // for every scalar_cases_len /// item: Ref, - /// body_len: u32, - /// body member Index for every body_len + /// info: ProngInfo, + /// body member Index for every info.body_len /// } /// 3. 
multi_cases: { // for every multi_cases_len /// items_len: u32, /// ranges_len: u32, - /// body_len: u32, + /// info: ProngInfo, /// item: Ref // for every items_len /// ranges: { // for every ranges_len /// item_first: Ref, /// item_last: Ref, /// } - /// body member Index for every body_len + /// body member Index for every info.body_len /// } + /// + /// When analyzing a case body, the switch instruction itself refers to the + /// captured payload. Whether this is captured by reference or by value + /// depends on the `capture` field of the corresponding prong's `ProngInfo`. pub const SwitchBlock = struct { /// This is always a `switch_cond` or `switch_cond_ref` instruction. /// If it is a `switch_cond_ref` instruction, bits.is_ref is always true. @@ -2697,6 +2679,19 @@ pub const Inst = struct { operand: Ref, bits: Bits, + /// These are stored in trailing data in `extra` for each prong. + pub const ProngInfo = packed struct(u32) { + body_len: u29, + capture: Capture, + is_inline: bool, + + pub const Capture = enum(u2) { + none, + by_val, + by_ref, + }; + }; + pub const Bits = packed struct { /// If true, one or more prongs have multiple items. has_multi_cases: bool, @@ -2724,64 +2719,6 @@ pub const Inst = struct { items: []const Ref, body: []const Index, }; - - /// TODO performance optimization: instead of having this helper method - /// change the definition of switch_capture instruction to store extra_index - /// instead of prong_index. This way, Sema won't be doing O(N^2) iterations - /// over the switch prongs. 
- pub fn getProng( - self: SwitchBlock, - zir: Zir, - extra_end: usize, - prong_index: usize, - ) MultiProng { - var extra_index: usize = extra_end + @boolToInt(self.bits.has_multi_cases); - - if (self.bits.specialProng() != .none) { - const body_len = @truncate(u31, zir.extra[extra_index]); - extra_index += 1; - const body = zir.extra[extra_index..][0..body_len]; - extra_index += body.len; - } - - var cur_idx: usize = 0; - while (cur_idx < self.bits.scalar_cases_len) : (cur_idx += 1) { - const items = zir.refSlice(extra_index, 1); - extra_index += 1; - const body_len = @truncate(u31, zir.extra[extra_index]); - extra_index += 1; - const body = zir.extra[extra_index..][0..body_len]; - extra_index += body_len; - if (cur_idx == prong_index) { - return .{ - .items = items, - .body = body, - }; - } - } - while (true) : (cur_idx += 1) { - const items_len = zir.extra[extra_index]; - extra_index += 1; - const ranges_len = zir.extra[extra_index]; - extra_index += 1; - const body_len = @truncate(u31, zir.extra[extra_index]); - extra_index += 1; - const items = zir.refSlice(extra_index, items_len); - extra_index += items_len; - // Each range has a start and an end. 
- extra_index += 2 * ranges_len; - - const body = zir.extra[extra_index..][0..body_len]; - extra_index += body_len; - - if (cur_idx == prong_index) { - return .{ - .items = items, - .body = body, - }; - } - } - } }; pub const Field = struct { diff --git a/src/print_zir.zig b/src/print_zir.zig index eba7f98a26..203c7daaf3 100644 --- a/src/print_zir.zig +++ b/src/print_zir.zig @@ -436,10 +436,6 @@ const Writer = struct { .@"unreachable" => try self.writeUnreachable(stream, inst), - .switch_capture, - .switch_capture_ref, - => try self.writeSwitchCapture(stream, inst), - .dbg_stmt => try self.writeDbgStmt(stream, inst), .dbg_block_begin, @@ -1913,15 +1909,20 @@ const Writer = struct { else => break :else_prong, }; - const body_len = @truncate(u31, self.code.extra[extra_index]); - const inline_text = if (self.code.extra[extra_index] >> 31 != 0) "inline " else ""; + const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, self.code.extra[extra_index]); + const capture_text = switch (info.capture) { + .none => "", + .by_val => "by_val ", + .by_ref => "by_ref ", + }; + const inline_text = if (info.is_inline) "inline " else ""; extra_index += 1; - const body = self.code.extra[extra_index..][0..body_len]; + const body = self.code.extra[extra_index..][0..info.body_len]; extra_index += body.len; try stream.writeAll(",\n"); try stream.writeByteNTimes(' ', self.indent); - try stream.print("{s}{s} => ", .{ inline_text, prong_name }); + try stream.print("{s}{s}{s} => ", .{ capture_text, inline_text, prong_name }); try self.writeBracedBody(stream, body); } @@ -1931,15 +1932,19 @@ const Writer = struct { while (scalar_i < scalar_cases_len) : (scalar_i += 1) { const item_ref = @intToEnum(Zir.Inst.Ref, self.code.extra[extra_index]); extra_index += 1; - const body_len = @truncate(u31, self.code.extra[extra_index]); - const is_inline = self.code.extra[extra_index] >> 31 != 0; + const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, self.code.extra[extra_index]); extra_index += 1; - const 
body = self.code.extra[extra_index..][0..body_len]; - extra_index += body_len; + const body = self.code.extra[extra_index..][0..info.body_len]; + extra_index += info.body_len; try stream.writeAll(",\n"); try stream.writeByteNTimes(' ', self.indent); - if (is_inline) try stream.writeAll("inline "); + switch (info.capture) { + .none => {}, + .by_val => try stream.writeAll("by_val "), + .by_ref => try stream.writeAll("by_ref "), + } + if (info.is_inline) try stream.writeAll("inline "); try self.writeInstRef(stream, item_ref); try stream.writeAll(" => "); try self.writeBracedBody(stream, body); @@ -1952,15 +1957,19 @@ const Writer = struct { extra_index += 1; const ranges_len = self.code.extra[extra_index]; extra_index += 1; - const body_len = @truncate(u31, self.code.extra[extra_index]); - const is_inline = self.code.extra[extra_index] >> 31 != 0; + const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, self.code.extra[extra_index]); extra_index += 1; const items = self.code.refSlice(extra_index, items_len); extra_index += items_len; try stream.writeAll(",\n"); try stream.writeByteNTimes(' ', self.indent); - if (is_inline) try stream.writeAll("inline "); + switch (info.capture) { + .none => {}, + .by_val => try stream.writeAll("by_val "), + .by_ref => try stream.writeAll("by_ref "), + } + if (info.is_inline) try stream.writeAll("inline "); for (items, 0..) 
|item_ref, item_i| { if (item_i != 0) try stream.writeAll(", "); @@ -1982,8 +1991,8 @@ const Writer = struct { try self.writeInstRef(stream, item_last); } - const body = self.code.extra[extra_index..][0..body_len]; - extra_index += body_len; + const body = self.code.extra[extra_index..][0..info.body_len]; + extra_index += info.body_len; try stream.writeAll(" => "); try self.writeBracedBody(stream, body); } @@ -2435,12 +2444,6 @@ const Writer = struct { try self.writeSrc(stream, src); } - fn writeSwitchCapture(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void { - const inst_data = self.code.instructions.items(.data)[inst].switch_capture; - try self.writeInstIndex(stream, inst_data.switch_inst); - try stream.print(", {d})", .{inst_data.prong_index}); - } - fn writeDbgStmt(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void { const inst_data = self.code.instructions.items(.data)[inst].dbg_stmt; try stream.print("{d}, {d})", .{ inst_data.line + 1, inst_data.column + 1 }); From ec27524da9b4200dc9ea39285e9c4c30cad28a98 Mon Sep 17 00:00:00 2001 From: mlugg Date: Sat, 27 May 2023 06:52:25 +0100 Subject: [PATCH 4/8] Sema: minor refactor to switch prong analysis --- src/Sema.zig | 599 ++++++++++++++++++++++----------------------------- 1 file changed, 260 insertions(+), 339 deletions(-) diff --git a/src/Sema.zig b/src/Sema.zig index 45920c1e92..ec905e72b8 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -10077,298 +10077,284 @@ fn zirSliceLength(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError return sema.analyzeSlice(block, src, array_ptr, start, len, sentinel, sentinel_src, ptr_src, start_src, end_src, true); } -/// Resolve a switch prong which is determined at comptime to have no peers. Uses -/// `resolveBlockBody`. Sets up captures as needed. -fn resolveSwitchProngComptime( +/// Holds common data used when analyzing or resolving switch prong bodies, +/// including setting up captures. 
+const SwitchProngAnalysis = struct { sema: *Sema, + /// The block containing the `switch_block` itself. parent_block: *Block, - child_block: *Block, - src: LazySrcLoc, + /// The raw switch operand value (*not* the condition). Always defined. operand: Air.Inst.Ref, + /// May be `undefined` if no prong has a by-ref capture. operand_ptr: Air.Inst.Ref, - prong_type: enum { normal, special }, - prong_body: []const Zir.Inst.Index, - capture: Zir.Inst.SwitchBlock.ProngInfo.Capture, - raw_capture_src: Module.SwitchProngSrc, + /// If this switch is on an error set, this is the type to assign to the + /// `else` prong. If `null`, the prong should be unreachable. else_error_ty: ?Type, - case_vals: []const Air.Inst.Ref, + /// The index of the `switch_block` instruction itself. switch_block_inst: Zir.Inst.Index, - merges: *Block.Merges, -) CompileError!Air.Inst.Ref { - switch (capture) { - .none => { - return sema.resolveBlockBody(parent_block, src, child_block, prong_body, switch_block_inst, merges); - }, - .by_val, .by_ref => { - const zir_datas = sema.code.instructions.items(.data); - const switch_info = zir_datas[switch_block_inst].pl_node; + /// Resolve a switch prong which is determined at comptime to have no peers. + /// Uses `resolveBlockBody`. Sets up captures as needed. + fn resolveProngComptime( + spa: SwitchProngAnalysis, + child_block: *Block, + prong_type: enum { normal, special }, + prong_body: []const Zir.Inst.Index, + capture: Zir.Inst.SwitchBlock.ProngInfo.Capture, + /// Must use the `scalar`, `special`, or `multi_capture` union field. + raw_capture_src: Module.SwitchProngSrc, + /// The set of all values which can reach this prong. May be undefined + /// if the prong is special or contains ranges. 
+ case_vals: []const Air.Inst.Ref, + merges: *Block.Merges, + ) CompileError!Air.Inst.Ref { + const sema = spa.sema; + const src = sema.code.instructions.items(.data)[spa.switch_block_inst].pl_node.src(); + switch (capture) { + .none => { + return sema.resolveBlockBody(spa.parent_block, src, child_block, prong_body, spa.switch_block_inst, merges); + }, - const capture_ref = try sema.analyzeSwitchCapture( - child_block, - capture == .by_ref, - operand, - operand_ptr, - switch_info.src_node, - prong_type == .special, - raw_capture_src, - else_error_ty, - case_vals, - ); + .by_val, .by_ref => { + const capture_ref = try spa.analyzeCapture( + child_block, + capture == .by_ref, + prong_type == .special, + raw_capture_src, + case_vals, + ); - if (sema.typeOf(capture_ref).isNoReturn(sema.mod)) { - // This prong should be unreachable! - return Air.Inst.Ref.unreachable_value; - } + if (sema.typeOf(capture_ref).isNoReturn(sema.mod)) { + // This prong should be unreachable! + return Air.Inst.Ref.unreachable_value; + } - sema.inst_map.putAssumeCapacity(switch_block_inst, capture_ref); - defer assert(sema.inst_map.remove(switch_block_inst)); + sema.inst_map.putAssumeCapacity(spa.switch_block_inst, capture_ref); + defer assert(sema.inst_map.remove(spa.switch_block_inst)); - return sema.resolveBlockBody(parent_block, src, child_block, prong_body, switch_block_inst, merges); - }, + return sema.resolveBlockBody(spa.parent_block, src, child_block, prong_body, spa.switch_block_inst, merges); + }, + } } -} -/// Analyze a switch prong which may have peers at runtime. Uses -/// `analyzeBodyRuntimeBreak`. Sets up captures as needed. 
-fn analyzeSwitchProngRuntime( - sema: *Sema, - case_block: *Block, - operand: Air.Inst.Ref, - operand_ptr: Air.Inst.Ref, - prong_type: enum { normal, special }, - prong_body: []const Zir.Inst.Index, - capture: Zir.Inst.SwitchBlock.ProngInfo.Capture, - raw_capture_src: Module.SwitchProngSrc, - else_error_ty: ?Type, - case_vals: []const Air.Inst.Ref, - switch_block_inst: Zir.Inst.Index, -) CompileError!void { - switch (capture) { - .none => { - return sema.analyzeBodyRuntimeBreak(case_block, prong_body); - }, + /// Analyze a switch prong which may have peers at runtime. + /// Uses `analyzeBodyRuntimeBreak`. Sets up captures as needed. + fn analyzeProngRuntime( + spa: SwitchProngAnalysis, + case_block: *Block, + prong_type: enum { normal, special }, + prong_body: []const Zir.Inst.Index, + capture: Zir.Inst.SwitchBlock.ProngInfo.Capture, + /// Must use the `scalar`, `special`, or `multi_capture` union field. + raw_capture_src: Module.SwitchProngSrc, + /// The set of all values which can reach this prong. May be undefined + /// if the prong is special or contains ranges. 
+ case_vals: []const Air.Inst.Ref, + ) CompileError!void { + const sema = spa.sema; + switch (capture) { + .none => { + return sema.analyzeBodyRuntimeBreak(case_block, prong_body); + }, - .by_val, .by_ref => { - const zir_datas = sema.code.instructions.items(.data); - const switch_info = zir_datas[switch_block_inst].pl_node; + .by_val, .by_ref => { + const capture_ref = try spa.analyzeCapture( + case_block, + capture == .by_ref, + prong_type == .special, + raw_capture_src, + case_vals, + ); - const capture_ref = try sema.analyzeSwitchCapture( - case_block, - capture == .by_ref, - operand, - operand_ptr, - switch_info.src_node, - prong_type == .special, - raw_capture_src, - else_error_ty, - case_vals, - ); + if (sema.typeOf(capture_ref).isNoReturn(sema.mod)) { + // No need to analyze any further, the prong is unreachable + return; + } - if (sema.typeOf(capture_ref).isNoReturn(sema.mod)) { - // No need to analyze any further, the prong is unreachable - return; - } + sema.inst_map.putAssumeCapacity(spa.switch_block_inst, capture_ref); + defer assert(sema.inst_map.remove(spa.switch_block_inst)); - sema.inst_map.putAssumeCapacity(switch_block_inst, capture_ref); - defer assert(sema.inst_map.remove(switch_block_inst)); - - return sema.analyzeBodyRuntimeBreak(case_block, prong_body); - }, + return sema.analyzeBodyRuntimeBreak(case_block, prong_body); + }, + } } -} -fn analyzeSwitchCapture( - sema: *Sema, - /// Must be the child block so that `inline_case_capture` is set for inline prongs. - block: *Block, - capture_byref: bool, - /// The raw switch operand value. - operand: Air.Inst.Ref, - /// Pointer to the raw switch operand. May be undefined if `capture_byref` is false. - operand_ptr: Air.Inst.Ref, - switch_node_offset: i32, - /// `true` if this is the `else` or `_` prong of a switch. - is_special_prong: bool, - /// Must use the `scalar`, `special`, or `multi_capture` union field. 
- raw_capture_src: Module.SwitchProngSrc, - /// If this is the `else` prong of a switch on an error set, this is the - /// type that should be assigned to the capture. If `null`, the prong should - /// be unreachable. - else_error_ty: ?Type, - /// The set of all values which can reach this prong. May be undefined if - /// the prong has `is_special_prong` or contains ranges. - case_vals: []const Air.Inst.Ref, -) CompileError!Air.Inst.Ref { - const mod = sema.mod; - const gpa = sema.gpa; - const operand_ty = sema.typeOf(operand); - const operand_ptr_ty = if (capture_byref) sema.typeOf(operand_ptr) else undefined; - const operand_src: LazySrcLoc = .{ .node_offset_switch_operand = switch_node_offset }; + fn analyzeCapture( + spa: SwitchProngAnalysis, + block: *Block, + capture_byref: bool, + is_special_prong: bool, + raw_capture_src: Module.SwitchProngSrc, + case_vals: []const Air.Inst.Ref, + ) CompileError!Air.Inst.Ref { + const sema = spa.sema; + const mod = sema.mod; - if (block.inline_case_capture != .none) { - const item_val = sema.resolveConstValue(block, .unneeded, block.inline_case_capture, "") catch unreachable; - if (operand_ty.zigTypeTag(mod) == .Union) { - const field_index = @intCast(u32, operand_ty.unionTagFieldIndex(item_val, mod).?); - const union_obj = mod.typeToUnion(operand_ty).?; - const field_ty = union_obj.fields.values()[field_index].ty; - if (capture_byref) { - if (try sema.resolveDefinedValue(block, sema.src, operand_ptr)) |union_ptr| { + const zir_datas = sema.code.instructions.items(.data); + const switch_node_offset = zir_datas[spa.switch_block_inst].pl_node.src_node; + + const operand_ty = sema.typeOf(spa.operand); + const operand_ptr_ty = if (capture_byref) sema.typeOf(spa.operand_ptr) else undefined; + const operand_src: LazySrcLoc = .{ .node_offset_switch_operand = switch_node_offset }; + + if (block.inline_case_capture != .none) { + const item_val = sema.resolveConstValue(block, .unneeded, block.inline_case_capture, "") catch 
unreachable; + if (operand_ty.zigTypeTag(mod) == .Union) { + const field_index = @intCast(u32, operand_ty.unionTagFieldIndex(item_val, mod).?); + const union_obj = mod.typeToUnion(operand_ty).?; + const field_ty = union_obj.fields.values()[field_index].ty; + if (capture_byref) { const ptr_field_ty = try Type.ptr(sema.arena, mod, .{ .pointee_type = field_ty, .mutable = operand_ptr_ty.ptrIsMutable(mod), .@"volatile" = operand_ptr_ty.isVolatilePtr(mod), .@"addrspace" = operand_ptr_ty.ptrAddressSpace(mod), }); - return sema.addConstant( - ptr_field_ty, - (try mod.intern(.{ .ptr = .{ - .ty = ptr_field_ty.toIntern(), - .addr = .{ .field = .{ - .base = union_ptr.toIntern(), - .index = field_index, - } }, - } })).toValue(), - ); + if (try sema.resolveDefinedValue(block, sema.src, spa.operand_ptr)) |union_ptr| { + return sema.addConstant( + ptr_field_ty, + (try mod.intern(.{ .ptr = .{ + .ty = ptr_field_ty.toIntern(), + .addr = .{ .field = .{ + .base = union_ptr.toIntern(), + .index = field_index, + } }, + } })).toValue(), + ); + } + return block.addStructFieldPtr(spa.operand_ptr, field_index, ptr_field_ty); + } else { + if (try sema.resolveDefinedValue(block, sema.src, spa.operand)) |union_val| { + const tag_and_val = mod.intern_pool.indexToKey(union_val.toIntern()).un; + return sema.addConstant(field_ty, tag_and_val.val.toValue()); + } + return block.addStructFieldVal(spa.operand, field_index, field_ty); } - const ptr_field_ty = try Type.ptr(sema.arena, mod, .{ - .pointee_type = field_ty, - .mutable = operand_ptr_ty.ptrIsMutable(mod), - .@"volatile" = operand_ptr_ty.isVolatilePtr(mod), - .@"addrspace" = operand_ptr_ty.ptrAddressSpace(mod), - }); - return block.addStructFieldPtr(operand_ptr, field_index, ptr_field_ty); + } else if (capture_byref) { + return sema.addConstantMaybeRef(block, operand_ty, item_val, true); } else { - if (try sema.resolveDefinedValue(block, sema.src, operand)) |union_val| { - const tag_and_val = mod.intern_pool.indexToKey(union_val.toIntern()).un; 
- return sema.addConstant(field_ty, tag_and_val.val.toValue()); - } - return block.addStructFieldVal(operand, field_index, field_ty); + return block.inline_case_capture; } - } else if (capture_byref) { - return sema.addConstantMaybeRef(block, operand_ty, item_val, true); - } else { - return block.inline_case_capture; } - } - if (is_special_prong) { - if (capture_byref) { - return operand_ptr; + if (is_special_prong) { + if (capture_byref) { + return spa.operand_ptr; + } + + switch (operand_ty.zigTypeTag(mod)) { + .ErrorSet => if (spa.else_error_ty) |ty| { + return sema.bitCast(block, ty, spa.operand, operand_src, null); + } else { + try block.addUnreachable(false); + return Air.Inst.Ref.unreachable_value; + }, + else => return spa.operand, + } } switch (operand_ty.zigTypeTag(mod)) { - .ErrorSet => if (else_error_ty) |ty| { - return sema.bitCast(block, ty, operand, operand_src, null); - } else { - try block.addUnreachable(false); - return Air.Inst.Ref.unreachable_value; - }, - else => return operand, - } - } + .Union => { + const union_obj = mod.typeToUnion(operand_ty).?; + const first_item_val = sema.resolveConstValue(block, .unneeded, case_vals[0], "") catch unreachable; - switch (operand_ty.zigTypeTag(mod)) { - .Union => { - const union_obj = mod.typeToUnion(operand_ty).?; - const first_item_val = sema.resolveConstValue(block, .unneeded, case_vals[0], "") catch unreachable; + const first_field_index = @intCast(u32, operand_ty.unionTagFieldIndex(first_item_val, mod).?); + const first_field = union_obj.fields.values()[first_field_index]; - const first_field_index = @intCast(u32, operand_ty.unionTagFieldIndex(first_item_val, mod).?); - const first_field = union_obj.fields.values()[first_field_index]; + for (case_vals[1..], 0..) |item, i| { + const item_val = sema.resolveConstValue(block, .unneeded, item, "") catch unreachable; - for (case_vals[1..], 0..) 
|item, i| { - const item_val = sema.resolveConstValue(block, .unneeded, item, "") catch unreachable; + const field_index = operand_ty.unionTagFieldIndex(item_val, mod).?; + const field = union_obj.fields.values()[field_index]; + if (!field.ty.eql(first_field.ty, mod)) { + const msg = msg: { + const capture_src = raw_capture_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .none); - const field_index = operand_ty.unionTagFieldIndex(item_val, mod).?; - const field = union_obj.fields.values()[field_index]; - if (!field.ty.eql(first_field.ty, mod)) { - const msg = msg: { - const capture_src = raw_capture_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .none); + const msg = try sema.errMsg(block, capture_src, "capture group with incompatible types", .{}); + errdefer msg.destroy(sema.gpa); - const msg = try sema.errMsg(block, capture_src, "capture group with incompatible types", .{}); - errdefer msg.destroy(gpa); + // This must be a multi-prong so this must be a `multi_capture` src + const multi_idx = raw_capture_src.multi_capture; - // This must be a multi-prong so this must be a `multi_capture` src - const multi_idx = raw_capture_src.multi_capture; - - const raw_first_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 0 } }; - const first_item_src = raw_first_item_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .first); - const raw_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 1 + @intCast(u32, i) } }; - const item_src = raw_item_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .first); - try sema.errNote(block, first_item_src, msg, "type '{}' here", .{first_field.ty.fmt(mod)}); - try sema.errNote(block, item_src, msg, "type '{}' here", .{field.ty.fmt(mod)}); - break :msg msg; - }; - return sema.failWithOwnedErrorMsg(msg); + const raw_first_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 0 } }; + const first_item_src = 
raw_first_item_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .first); + const raw_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 1 + @intCast(u32, i) } }; + const item_src = raw_item_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .first); + try sema.errNote(block, first_item_src, msg, "type '{}' here", .{first_field.ty.fmt(mod)}); + try sema.errNote(block, item_src, msg, "type '{}' here", .{field.ty.fmt(mod)}); + break :msg msg; + }; + return sema.failWithOwnedErrorMsg(msg); + } } - } - if (capture_byref) { - const field_ty_ptr = try Type.ptr(sema.arena, mod, .{ - .pointee_type = first_field.ty, - .@"addrspace" = .generic, - .mutable = operand_ptr_ty.ptrIsMutable(mod), - }); + if (capture_byref) { + const field_ty_ptr = try Type.ptr(sema.arena, mod, .{ + .pointee_type = first_field.ty, + .@"addrspace" = .generic, + .mutable = operand_ptr_ty.ptrIsMutable(mod), + }); - if (try sema.resolveDefinedValue(block, operand_src, operand_ptr)) |op_ptr_val| { - return sema.addConstant(field_ty_ptr, (try mod.intern(.{ .ptr = .{ - .ty = field_ty_ptr.toIntern(), - .addr = .{ .field = .{ - .base = op_ptr_val.toIntern(), - .index = first_field_index, - } }, - } })).toValue()); + if (try sema.resolveDefinedValue(block, operand_src, spa.operand_ptr)) |op_ptr_val| { + return sema.addConstant(field_ty_ptr, (try mod.intern(.{ .ptr = .{ + .ty = field_ty_ptr.toIntern(), + .addr = .{ .field = .{ + .base = op_ptr_val.toIntern(), + .index = first_field_index, + } }, + } })).toValue()); + } + try sema.requireRuntimeBlock(block, operand_src, null); + return block.addStructFieldPtr(spa.operand_ptr, first_field_index, field_ty_ptr); + } + + if (try sema.resolveDefinedValue(block, operand_src, spa.operand)) |operand_val| { + return sema.addConstant( + first_field.ty, + mod.intern_pool.indexToKey(operand_val.toIntern()).un.val.toValue(), + ); } try sema.requireRuntimeBlock(block, operand_src, null); - return 
block.addStructFieldPtr(operand_ptr, first_field_index, field_ty_ptr); - } + return block.addStructFieldVal(spa.operand, first_field_index, first_field.ty); + }, + .ErrorSet => { + if (capture_byref) { + const capture_src = raw_capture_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .none); + return sema.fail( + block, + capture_src, + "error set cannot be captured by reference", + .{}, + ); + } - if (try sema.resolveDefinedValue(block, operand_src, operand)) |operand_val| { - return sema.addConstant( - first_field.ty, - mod.intern_pool.indexToKey(operand_val.toIntern()).un.val.toValue(), - ); - } - try sema.requireRuntimeBlock(block, operand_src, null); - return block.addStructFieldVal(operand, first_field_index, first_field.ty); - }, - .ErrorSet => { - if (capture_byref) { - const capture_src = raw_capture_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .none); - return sema.fail( - block, - capture_src, - "error set cannot be captured by reference", - .{}, - ); - } + if (case_vals.len == 1) { + const item_val = sema.resolveConstValue(block, .unneeded, case_vals[0], "") catch unreachable; + const item_ty = try mod.singleErrorSetType(item_val.getErrorName(mod).unwrap().?); + return sema.bitCast(block, item_ty, spa.operand, operand_src, null); + } - if (case_vals.len == 1) { - const item_val = sema.resolveConstValue(block, .unneeded, case_vals[0], "") catch unreachable; - const item_ty = try mod.singleErrorSetType(item_val.getErrorName(mod).unwrap().?); - return sema.bitCast(block, item_ty, operand, operand_src, null); - } - - var names: Module.Fn.InferredErrorSet.NameMap = .{}; - try names.ensureUnusedCapacity(sema.arena, case_vals.len); - for (case_vals) |err| { - const err_val = sema.resolveConstValue(block, .unneeded, err, "") catch unreachable; - names.putAssumeCapacityNoClobber(err_val.getErrorName(mod).unwrap().?, {}); - } - const error_ty = try mod.errorSetFromUnsortedNames(names.keys()); - return sema.bitCast(block, 
error_ty, operand, operand_src, null); - }, - else => { - // In this case the capture value is just the passed-through value - // of the switch condition. - if (capture_byref) { - return operand_ptr; - } else { - return operand; - } - }, + var names: Module.Fn.InferredErrorSet.NameMap = .{}; + try names.ensureUnusedCapacity(sema.arena, case_vals.len); + for (case_vals) |err| { + const err_val = sema.resolveConstValue(block, .unneeded, err, "") catch unreachable; + names.putAssumeCapacityNoClobber(err_val.getErrorName(mod).unwrap().?, {}); + } + const error_ty = try mod.errorSetFromUnsortedNames(names.keys()); + return sema.bitCast(block, error_ty, spa.operand, operand_src, null); + }, + else => { + // In this case the capture value is just the passed-through value + // of the switch condition. + if (capture_byref) { + return spa.operand_ptr; + } else { + return spa.operand; + } + }, + } } -} +}; fn zirSwitchCaptureTag(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { const mod = sema.mod; @@ -11075,6 +11061,15 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError }), } + const spa: SwitchProngAnalysis = .{ + .sema = sema, + .parent_block = block, + .operand = raw_operand.val, + .operand_ptr = raw_operand.ptr, + .else_error_ty = else_error_ty, + .switch_block_inst = inst, + }; + const block_inst = @intCast(Air.Inst.Index, sema.air_instructions.len); try sema.air_instructions.append(gpa, .{ .tag = .block, @@ -11129,19 +11124,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError if (operand_val.eql(item_val, operand_ty, sema.mod)) { if (info.is_inline) child_block.inline_case_capture = operand; if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand); - return sema.resolveSwitchProngComptime( - block, + return spa.resolveProngComptime( &child_block, - src, - raw_operand.val, - raw_operand.ptr, .normal, body, info.capture, .{ .scalar = @intCast(u32, scalar_i) 
}, - else_error_ty, &.{item}, - inst, merges, ); } @@ -11168,19 +11157,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError if (operand_val.eql(item_val, operand_ty, sema.mod)) { if (info.is_inline) child_block.inline_case_capture = operand; if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand); - return sema.resolveSwitchProngComptime( - block, + return spa.resolveProngComptime( &child_block, - src, - raw_operand.val, - raw_operand.ptr, .normal, body, info.capture, .{ .multi_capture = @intCast(u32, multi_i) }, - else_error_ty, items, - inst, merges, ); } @@ -11200,19 +11183,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError { if (info.is_inline) child_block.inline_case_capture = operand; if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand); - return sema.resolveSwitchProngComptime( - block, + return spa.resolveProngComptime( &child_block, - src, - raw_operand.val, - raw_operand.ptr, .normal, body, info.capture, .{ .multi_capture = @intCast(u32, multi_i) }, - else_error_ty, - undefined, - inst, + undefined, // case_vals may be undefined for ranges merges, ); } @@ -11227,19 +11204,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError return Air.Inst.Ref.void_value; } - return sema.resolveSwitchProngComptime( - block, + return spa.resolveProngComptime( &child_block, - src, - raw_operand.val, - raw_operand.ptr, .special, special.body, special.capture, .special, - else_error_ty, - undefined, - inst, + undefined, // case_vals may be undefined for special prongs merges, ); } @@ -11262,19 +11233,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError try sema.addSafetyCheck(block, ok, .corrupt_switch); } - return sema.resolveSwitchProngComptime( - block, + return spa.resolveProngComptime( &child_block, - src, - raw_operand.val, - raw_operand.ptr, .special, special.body, special.capture, .special, 
- else_error_ty, - undefined, - inst, + undefined, // case_vals may be undefined for special prongs merges, ); } @@ -11328,17 +11293,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError if (err_set and try sema.maybeErrorUnwrap(&case_block, body, operand)) { // nothing to do here } else if (analyze_body) { - try sema.analyzeSwitchProngRuntime( + try spa.analyzeProngRuntime( &case_block, - raw_operand.val, - raw_operand.ptr, .normal, body, info.capture, .{ .scalar = @intCast(u32, scalar_i) }, - else_error_ty, &.{item}, - inst, ); } else { _ = try case_block.addNoOp(.unreach); @@ -11422,17 +11383,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError }; emit_bb = true; - try sema.analyzeSwitchProngRuntime( + try spa.analyzeProngRuntime( &case_block, - raw_operand.val, - raw_operand.ptr, .normal, body, info.capture, .{ .multi_capture = multi_i }, - else_error_ty, - undefined, - inst, + undefined, // case_vals may be undefined for ranges ); try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); @@ -11469,17 +11426,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError emit_bb = true; if (analyze_body) { - try sema.analyzeSwitchProngRuntime( + try spa.analyzeProngRuntime( &case_block, - raw_operand.val, - raw_operand.ptr, .normal, body, info.capture, .{ .multi_capture = multi_i }, - else_error_ty, &.{item}, - inst, ); } else { _ = try case_block.addNoOp(.unreach); @@ -11518,17 +11471,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError if (err_set and try sema.maybeErrorUnwrap(&case_block, body, operand)) { // nothing to do here } else if (analyze_body) { - try sema.analyzeSwitchProngRuntime( + try spa.analyzeProngRuntime( &case_block, - raw_operand.val, - raw_operand.ptr, .normal, body, info.capture, .{ .multi_capture = multi_i }, - else_error_ty, items, - inst, ); } else { _ = try case_block.addNoOp(.unreach); 
@@ -11607,17 +11556,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError if (err_set and try sema.maybeErrorUnwrap(&case_block, body, operand)) { // nothing to do here } else { - try sema.analyzeSwitchProngRuntime( + try spa.analyzeProngRuntime( &case_block, - raw_operand.val, - raw_operand.ptr, .normal, body, info.capture, .{ .multi_capture = multi_i }, - else_error_ty, items, - inst, ); } @@ -11677,17 +11622,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError emit_bb = true; if (analyze_body) { - try sema.analyzeSwitchProngRuntime( + try spa.analyzeProngRuntime( &case_block, - raw_operand.val, - raw_operand.ptr, .special, special.body, special.capture, .special, - else_error_ty, &.{item_ref}, - inst, ); } else { _ = try case_block.addNoOp(.unreach); @@ -11724,17 +11665,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError if (emit_bb) try sema.emitBackwardBranch(block, special_prong_src); emit_bb = true; - try sema.analyzeSwitchProngRuntime( + try spa.analyzeProngRuntime( &case_block, - raw_operand.val, - raw_operand.ptr, .special, special.body, special.capture, .special, - else_error_ty, &.{item_ref}, - inst, ); try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); @@ -11758,17 +11695,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError if (emit_bb) try sema.emitBackwardBranch(block, special_prong_src); emit_bb = true; - try sema.analyzeSwitchProngRuntime( + try spa.analyzeProngRuntime( &case_block, - raw_operand.val, - raw_operand.ptr, .special, special.body, special.capture, .special, - else_error_ty, &.{item_ref}, - inst, ); try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); @@ -11789,17 +11722,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError if (emit_bb) try sema.emitBackwardBranch(block, special_prong_src); emit_bb = true; - try 
sema.analyzeSwitchProngRuntime( + try spa.analyzeProngRuntime( &case_block, - raw_operand.val, - raw_operand.ptr, .special, special.body, special.capture, .special, - else_error_ty, &.{Air.Inst.Ref.bool_true}, - inst, ); try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); @@ -11818,17 +11747,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError if (emit_bb) try sema.emitBackwardBranch(block, special_prong_src); emit_bb = true; - try sema.analyzeSwitchProngRuntime( + try spa.analyzeProngRuntime( &case_block, - raw_operand.val, - raw_operand.ptr, .special, special.body, special.capture, .special, - else_error_ty, &.{Air.Inst.Ref.bool_false}, - inst, ); try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); @@ -11872,17 +11797,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError { // nothing to do here } else if (special.body.len != 0 and analyze_body and !special.is_inline) { - try sema.analyzeSwitchProngRuntime( + try spa.analyzeProngRuntime( &case_block, - raw_operand.val, - raw_operand.ptr, .special, special.body, special.capture, .special, - else_error_ty, - undefined, - inst, + undefined, // case_vals may be undefined for special prongs ); } else { // We still need a terminator in this block, but we have proven From 39510cc7d1f6035e16aaa4f97991abbaddef463c Mon Sep 17 00:00:00 2001 From: mlugg Date: Sat, 27 May 2023 07:29:55 +0100 Subject: [PATCH 5/8] Eliminate switch_capture_tag ZIR instruction This is a follow-up to a previous commit which eliminated switch_capture and switch_capture_ref. All captures are now handled directly by `switch_block`, which has also eliminated some unnecessary Block data in Sema. 
--- src/AstGen.zig | 62 ++++++++++++---- src/Sema.zig | 182 +++++++++++++++++++++++++++++++--------------- src/Zir.zig | 27 ++++--- src/print_zir.zig | 14 +++- 4 files changed, 194 insertions(+), 91 deletions(-) diff --git a/src/AstGen.zig b/src/AstGen.zig index 003677bfe5..ae2c191285 100644 --- a/src/AstGen.zig +++ b/src/AstGen.zig @@ -2612,7 +2612,6 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As .switch_block, .switch_cond, .switch_cond_ref, - .switch_capture_tag, .struct_init_empty, .struct_init, .struct_init_ref, @@ -2956,7 +2955,7 @@ fn deferStmt( try gz.astgen.instructions.append(gz.astgen.gpa, .{ .tag = .extended, .data = .{ .extended = .{ - .opcode = .errdefer_err_code, + .opcode = .value_placeholder, .small = undefined, .operand = undefined, } }, @@ -6711,6 +6710,7 @@ fn switchExpr( // for the following variables, make note of the special prong AST node index, // and bail out with a compile error if there are multiple special prongs present. var any_payload_is_ref = false; + var any_has_tag_capture = false; var scalar_cases_len: u32 = 0; var multi_cases_len: u32 = 0; var inline_cases_len: u32 = 0; @@ -6721,8 +6721,12 @@ fn switchExpr( for (case_nodes) |case_node| { const case = tree.fullSwitchCase(case_node).?; if (case.payload_token) |payload_token| { - if (token_tags[payload_token] == .asterisk) { + const ident = if (token_tags[payload_token] == .asterisk) blk: { any_payload_is_ref = true; + break :blk payload_token + 1; + } else payload_token; + if (token_tags[ident + 1] == .comma) { + any_has_tag_capture = true; } } // Check for else/`_` prong. 
@@ -6861,6 +6865,20 @@ fn switchExpr( var case_scope = parent_gz.makeSubBlock(&block_scope.base); case_scope.instructions_top = GenZir.unstacked_top; + // If any prong has an inline tag capture, allocate a shared dummy instruction for it + const tag_inst = if (any_has_tag_capture) tag_inst: { + const inst = @intCast(Zir.Inst.Index, astgen.instructions.len); + try astgen.instructions.append(astgen.gpa, .{ + .tag = .extended, + .data = .{ .extended = .{ + .opcode = .value_placeholder, + .small = undefined, + .operand = undefined, + } }, // TODO rename opcode + }); + break :tag_inst inst; + } else undefined; + // In this pass we generate all the item and prong expressions. var multi_case_index: u32 = 0; var scalar_case_index: u32 = 0; @@ -6874,7 +6892,7 @@ fn switchExpr( var dbg_var_inst: Zir.Inst.Ref = undefined; var dbg_var_tag_name: ?u32 = null; var dbg_var_tag_inst: Zir.Inst.Ref = undefined; - var tag_inst: Zir.Inst.Index = 0; + var has_tag_capture = false; var capture_val_scope: Scope.LocalVal = undefined; var tag_scope: Scope.LocalVal = undefined; @@ -6925,14 +6943,9 @@ fn switchExpr( } const tag_name = try astgen.identAsString(tag_token); try astgen.detectLocalShadowing(payload_sub_scope, tag_name, tag_token, tag_slice, .@"switch tag capture"); - tag_inst = @intCast(Zir.Inst.Index, astgen.instructions.len); - try astgen.instructions.append(gpa, .{ - .tag = .switch_capture_tag, - .data = .{ .un_tok = .{ - .operand = cond, - .src_tok = case_scope.tokenIndexToRelative(tag_token), - } }, - }); + + assert(any_has_tag_capture); + has_tag_capture = true; tag_scope = .{ .parent = payload_sub_scope, @@ -6998,7 +7011,6 @@ fn switchExpr( case_scope.instructions_top = parent_gz.instructions.items.len; defer case_scope.unstack(); - if (tag_inst != 0) try case_scope.instructions.append(gpa, tag_inst); try case_scope.addDbgBlockBegin(); if (dbg_var_name) |some| { try case_scope.addDbgVar(.dbg_var_val, some, dbg_var_inst); @@ -7018,7 +7030,8 @@ fn switchExpr( const case_slice 
= case_scope.instructionsSlice(); // Since we use the switch_block instruction itself to refer to the // capture, which will not be added to the child block, we need to - // handle ref_table manually. + // handle ref_table manually, and the same for the inline tag + // capture instruction. const refs_len = refs: { var n: usize = 0; var check_inst = switch_block; @@ -7026,18 +7039,31 @@ fn switchExpr( n += 1; check_inst = ref_inst; } + if (has_tag_capture) { + check_inst = tag_inst; + while (astgen.ref_table.get(check_inst)) |ref_inst| { + n += 1; + check_inst = ref_inst; + } + } break :refs n; }; const body_len = refs_len + astgen.countBodyLenAfterFixups(case_slice); try payloads.ensureUnusedCapacity(gpa, body_len); payloads.items[body_len_index] = @bitCast(u32, Zir.Inst.SwitchBlock.ProngInfo{ - .body_len = @intCast(u29, body_len), + .body_len = @intCast(u28, body_len), .capture = capture, .is_inline = case.inline_token != null, + .has_tag_capture = has_tag_capture, }); if (astgen.ref_table.fetchRemove(switch_block)) |kv| { appendPossiblyRefdBodyInst(astgen, payloads, kv.value); } + if (has_tag_capture) { + if (astgen.ref_table.fetchRemove(tag_inst)) |kv| { + appendPossiblyRefdBodyInst(astgen, payloads, kv.value); + } + } appendBodyWithFixupsArrayList(astgen, payloads, case_slice); } } @@ -7046,6 +7072,7 @@ fn switchExpr( try astgen.extra.ensureUnusedCapacity(gpa, @typeInfo(Zir.Inst.SwitchBlock).Struct.fields.len + @boolToInt(multi_cases_len != 0) + + @boolToInt(any_has_tag_capture) + payloads.items.len - case_table_end); const payload_index = astgen.addExtraAssumeCapacity(Zir.Inst.SwitchBlock{ @@ -7054,6 +7081,7 @@ fn switchExpr( .has_multi_cases = multi_cases_len != 0, .has_else = special_prong == .@"else", .has_under = special_prong == .under, + .any_has_tag_capture = any_has_tag_capture, .scalar_cases_len = @intCast(Zir.Inst.SwitchBlock.Bits.ScalarCasesLen, scalar_cases_len), }, }); @@ -7062,6 +7090,10 @@ fn switchExpr( 
astgen.extra.appendAssumeCapacity(multi_cases_len); } + if (any_has_tag_capture) { + astgen.extra.appendAssumeCapacity(tag_inst); + } + const zir_datas = astgen.instructions.items(.data); const zir_tags = astgen.instructions.items(.tag); diff --git a/src/Sema.zig b/src/Sema.zig index ec905e72b8..723eb9b01a 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -277,9 +277,6 @@ pub const Block = struct { c_import_buf: ?*std.ArrayList(u8) = null, - /// Value for switch_capture in an inline case - inline_case_capture: Air.Inst.Ref = .none, - const ComptimeReason = union(enum) { c_import: struct { block: *Block, @@ -1013,7 +1010,6 @@ fn analyzeBodyInner( .switch_block => try sema.zirSwitchBlock(block, inst), .switch_cond => try sema.zirSwitchCond(block, inst, false), .switch_cond_ref => try sema.zirSwitchCond(block, inst, true), - .switch_capture_tag => try sema.zirSwitchCaptureTag(block, inst), .type_info => try sema.zirTypeInfo(block, inst), .size_of => try sema.zirSizeOf(block, inst), .bit_size_of => try sema.zirBitSizeOf(block, inst), @@ -1217,7 +1213,7 @@ fn analyzeBodyInner( i += 1; continue; }, - .errdefer_err_code => unreachable, // never appears in a body + .value_placeholder => unreachable, // never appears in a body }; }, @@ -10092,6 +10088,9 @@ const SwitchProngAnalysis = struct { else_error_ty: ?Type, /// The index of the `switch_block` instruction itself. switch_block_inst: Zir.Inst.Index, + /// The dummy index into which inline tag captures should be placed. May be + /// undefined if no prong has a tag capture. + tag_capture_inst: Zir.Inst.Index, /// Resolve a switch prong which is determined at comptime to have no peers. /// Uses `resolveBlockBody`. Sets up captures as needed. @@ -10106,10 +10105,23 @@ const SwitchProngAnalysis = struct { /// The set of all values which can reach this prong. May be undefined /// if the prong is special or contains ranges. case_vals: []const Air.Inst.Ref, + /// The inline capture of this prong. 
If this is not an inline prong, + /// this is `.none`. + inline_case_capture: Air.Inst.Ref, + /// Whether this prong has an inline tag capture. If `true`, then + /// `inline_case_capture` cannot be `.none`. + has_tag_capture: bool, merges: *Block.Merges, ) CompileError!Air.Inst.Ref { const sema = spa.sema; const src = sema.code.instructions.items(.data)[spa.switch_block_inst].pl_node.src(); + + if (has_tag_capture) { + const tag_ref = try spa.analyzeTagCapture(child_block, raw_capture_src, inline_case_capture); + sema.inst_map.putAssumeCapacity(spa.tag_capture_inst, tag_ref); + } + defer if (has_tag_capture) assert(sema.inst_map.remove(spa.tag_capture_inst)); + switch (capture) { .none => { return sema.resolveBlockBody(spa.parent_block, src, child_block, prong_body, spa.switch_block_inst, merges); @@ -10122,6 +10134,7 @@ const SwitchProngAnalysis = struct { prong_type == .special, raw_capture_src, case_vals, + inline_case_capture, ); if (sema.typeOf(capture_ref).isNoReturn(sema.mod)) { @@ -10150,8 +10163,21 @@ const SwitchProngAnalysis = struct { /// The set of all values which can reach this prong. May be undefined /// if the prong is special or contains ranges. case_vals: []const Air.Inst.Ref, + /// The inline capture of this prong. If this is not an inline prong, + /// this is `.none`. + inline_case_capture: Air.Inst.Ref, + /// Whether this prong has an inline tag capture. If `true`, then + /// `inline_case_capture` cannot be `.none`. 
+ has_tag_capture: bool, ) CompileError!void { const sema = spa.sema; + + if (has_tag_capture) { + const tag_ref = try spa.analyzeTagCapture(case_block, raw_capture_src, inline_case_capture); + sema.inst_map.putAssumeCapacity(spa.tag_capture_inst, tag_ref); + } + defer if (has_tag_capture) assert(sema.inst_map.remove(spa.tag_capture_inst)); + switch (capture) { .none => { return sema.analyzeBodyRuntimeBreak(case_block, prong_body); @@ -10164,6 +10190,7 @@ const SwitchProngAnalysis = struct { prong_type == .special, raw_capture_src, case_vals, + inline_case_capture, ); if (sema.typeOf(capture_ref).isNoReturn(sema.mod)) { @@ -10179,6 +10206,33 @@ const SwitchProngAnalysis = struct { } } + fn analyzeTagCapture( + spa: SwitchProngAnalysis, + block: *Block, + raw_capture_src: Module.SwitchProngSrc, + inline_case_capture: Air.Inst.Ref, + ) CompileError!Air.Inst.Ref { + const sema = spa.sema; + const mod = sema.mod; + const operand_ty = sema.typeOf(spa.operand); + if (operand_ty.zigTypeTag(mod) != .Union) { + const zir_datas = sema.code.instructions.items(.data); + const switch_node_offset = zir_datas[spa.switch_block_inst].pl_node.src_node; + const capture_src = raw_capture_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .none); + const msg = msg: { + const msg = try sema.errMsg(block, capture_src, "cannot capture tag of non-union type '{}'", .{ + operand_ty.fmt(mod), + }); + errdefer msg.destroy(sema.gpa); + try sema.addDeclaredHereNote(msg, operand_ty); + break :msg msg; + }; + return sema.failWithOwnedErrorMsg(msg); + } + assert(inline_case_capture != .none); + return inline_case_capture; + } + fn analyzeCapture( spa: SwitchProngAnalysis, block: *Block, @@ -10186,6 +10240,7 @@ const SwitchProngAnalysis = struct { is_special_prong: bool, raw_capture_src: Module.SwitchProngSrc, case_vals: []const Air.Inst.Ref, + inline_case_capture: Air.Inst.Ref, ) CompileError!Air.Inst.Ref { const sema = spa.sema; const mod = sema.mod; @@ -10197,8 +10252,8 @@ const 
SwitchProngAnalysis = struct { const operand_ptr_ty = if (capture_byref) sema.typeOf(spa.operand_ptr) else undefined; const operand_src: LazySrcLoc = .{ .node_offset_switch_operand = switch_node_offset }; - if (block.inline_case_capture != .none) { - const item_val = sema.resolveConstValue(block, .unneeded, block.inline_case_capture, "") catch unreachable; + if (inline_case_capture != .none) { + const item_val = sema.resolveConstValue(block, .unneeded, inline_case_capture, "") catch unreachable; if (operand_ty.zigTypeTag(mod) == .Union) { const field_index = @intCast(u32, operand_ty.unionTagFieldIndex(item_val, mod).?); const union_obj = mod.typeToUnion(operand_ty).?; @@ -10233,7 +10288,7 @@ const SwitchProngAnalysis = struct { } else if (capture_byref) { return sema.addConstantMaybeRef(block, operand_ty, item_val, true); } else { - return block.inline_case_capture; + return inline_case_capture; } } @@ -10356,34 +10411,6 @@ const SwitchProngAnalysis = struct { } }; -fn zirSwitchCaptureTag(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { - const mod = sema.mod; - const zir_datas = sema.code.instructions.items(.data); - const inst_data = zir_datas[inst].un_tok; - const src = inst_data.src(); - - const switch_tag = sema.code.instructions.items(.tag)[Zir.refToIndex(inst_data.operand).?]; - const is_ref = switch_tag == .switch_cond_ref; - const cond_data = zir_datas[Zir.refToIndex(inst_data.operand).?].un_node; - const operand_ptr = try sema.resolveInst(cond_data.operand); - const operand_ptr_ty = sema.typeOf(operand_ptr); - const operand_ty = if (is_ref) operand_ptr_ty.childType(mod) else operand_ptr_ty; - - if (operand_ty.zigTypeTag(mod) != .Union) { - const msg = msg: { - const msg = try sema.errMsg(block, src, "cannot capture tag of non-union type '{}'", .{ - operand_ty.fmt(mod), - }); - errdefer msg.destroy(sema.gpa); - try sema.addDeclaredHereNote(msg, operand_ty); - break :msg msg; - }; - return sema.failWithOwnedErrorMsg(msg); - } - 
- return block.inline_case_capture; -} - fn zirSwitchCond( sema: *Sema, block: *Block, @@ -10485,6 +10512,16 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError break :blk multi_cases_len; } else 0; + const tag_capture_inst: Zir.Inst.Index = if (extra.data.bits.any_has_tag_capture) blk: { + const tag_capture_inst = sema.code.extra[header_extra_index]; + header_extra_index += 1; + // SwitchProngAnalysis wants inst_map to have space for the tag capture. + // Note that the normal capture is referred to via the switch block + // index, which there is already necessarily space for. + try sema.inst_map.ensureSpaceForInstructions(gpa, &.{tag_capture_inst}); + break :blk tag_capture_inst; + } else undefined; + var case_vals = try std.ArrayListUnmanaged(Air.Inst.Ref).initCapacity(gpa, scalar_cases_len + 2 * multi_cases_len); defer case_vals.deinit(gpa); @@ -10493,11 +10530,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError end: usize, capture: Zir.Inst.SwitchBlock.ProngInfo.Capture, is_inline: bool, + has_tag_capture: bool, }; const special_prong = extra.data.bits.specialProng(); const special: Special = switch (special_prong) { - .none => .{ .body = &.{}, .end = header_extra_index, .capture = .none, .is_inline = false }, + .none => .{ + .body = &.{}, + .end = header_extra_index, + .capture = .none, + .is_inline = false, + .has_tag_capture = false, + }, .under, .@"else" => blk: { const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[header_extra_index]); const extra_body_start = header_extra_index + 1; @@ -10506,6 +10550,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError .end = extra_body_start + info.body_len, .capture = info.capture, .is_inline = info.is_inline, + .has_tag_capture = info.has_tag_capture, }; }, }; @@ -11068,6 +11113,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError .operand_ptr = raw_operand.ptr, 
.else_error_ty = else_error_ty, .switch_block_inst = inst, + .tag_capture_inst = tag_capture_inst, }; const block_inst = @intCast(Air.Inst.Index, sema.air_instructions.len); @@ -11122,7 +11168,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError const item = case_vals.items[scalar_i]; const item_val = sema.resolveConstValue(&child_block, .unneeded, item, "") catch unreachable; if (operand_val.eql(item_val, operand_ty, sema.mod)) { - if (info.is_inline) child_block.inline_case_capture = operand; if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand); return spa.resolveProngComptime( &child_block, @@ -11131,6 +11176,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError info.capture, .{ .scalar = @intCast(u32, scalar_i) }, &.{item}, + if (info.is_inline) operand else .none, + info.has_tag_capture, merges, ); } @@ -11155,7 +11202,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError // Validation above ensured these will succeed. 
const item_val = sema.resolveConstValue(&child_block, .unneeded, item, "") catch unreachable; if (operand_val.eql(item_val, operand_ty, sema.mod)) { - if (info.is_inline) child_block.inline_case_capture = operand; if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand); return spa.resolveProngComptime( &child_block, @@ -11164,6 +11210,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError info.capture, .{ .multi_capture = @intCast(u32, multi_i) }, items, + if (info.is_inline) operand else .none, + info.has_tag_capture, merges, ); } @@ -11181,7 +11229,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError if ((try sema.compareAll(resolved_operand_val, .gte, first_val, operand_ty)) and (try sema.compareAll(resolved_operand_val, .lte, last_val, operand_ty))) { - if (info.is_inline) child_block.inline_case_capture = operand; if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand); return spa.resolveProngComptime( &child_block, @@ -11190,6 +11237,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError info.capture, .{ .multi_capture = @intCast(u32, multi_i) }, undefined, // case_vals may be undefined for ranges + if (info.is_inline) operand else .none, + info.has_tag_capture, merges, ); } @@ -11199,7 +11248,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError } } if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, special.body, operand); - if (special.is_inline) child_block.inline_case_capture = operand; if (empty_enum) { return Air.Inst.Ref.void_value; } @@ -11211,6 +11259,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError special.capture, .special, undefined, // case_vals may be undefined for special prongs + if (special.is_inline) operand else .none, + special.has_tag_capture, merges, ); } @@ -11240,6 +11290,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, 
inst: Zir.Inst.Index) CompileError special.capture, .special, undefined, // case_vals may be undefined for special prongs + .none, + false, merges, ); } @@ -11278,10 +11330,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError case_block.instructions.shrinkRetainingCapacity(0); case_block.wip_capture_scope = wip_captures.scope; - case_block.inline_case_capture = .none; const item = case_vals.items[scalar_i]; - if (info.is_inline) case_block.inline_case_capture = item; // `item` is already guaranteed to be constant known. const analyze_body = if (union_originally) blk: { @@ -11300,6 +11350,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError info.capture, .{ .scalar = @intCast(u32, scalar_i) }, &.{item}, + if (info.is_inline) item else .none, + info.has_tag_capture, ); } else { _ = try case_block.addNoOp(.unreach); @@ -11337,7 +11389,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError case_block.instructions.shrinkRetainingCapacity(0); case_block.wip_capture_scope = child_block.wip_capture_scope; - case_block.inline_case_capture = .none; // Generate all possible cases as scalar prongs. 
if (info.is_inline) { @@ -11367,7 +11418,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError cases_len += 1; const item_ref = try sema.addConstant(operand_ty, item); - case_block.inline_case_capture = item_ref; case_block.instructions.shrinkRetainingCapacity(0); case_block.wip_capture_scope = child_block.wip_capture_scope; @@ -11390,12 +11440,14 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError info.capture, .{ .multi_capture = multi_i }, undefined, // case_vals may be undefined for ranges + item_ref, + info.has_tag_capture, ); try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); cases_extra.appendAssumeCapacity(1); // items_len cases_extra.appendAssumeCapacity(@intCast(u32, case_block.instructions.items.len)); - cases_extra.appendAssumeCapacity(@enumToInt(case_block.inline_case_capture)); + cases_extra.appendAssumeCapacity(@enumToInt(item_ref)); cases_extra.appendSliceAssumeCapacity(case_block.instructions.items); } } @@ -11403,8 +11455,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError for (items, 0..) 
|item, item_i| { cases_len += 1; - case_block.inline_case_capture = item; - case_block.instructions.shrinkRetainingCapacity(0); case_block.wip_capture_scope = child_block.wip_capture_scope; @@ -11433,6 +11483,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError info.capture, .{ .multi_capture = multi_i }, &.{item}, + item, + info.has_tag_capture, ); } else { _ = try case_block.addNoOp(.unreach); @@ -11441,7 +11493,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); cases_extra.appendAssumeCapacity(1); // items_len cases_extra.appendAssumeCapacity(@intCast(u32, case_block.instructions.items.len)); - cases_extra.appendAssumeCapacity(@enumToInt(case_block.inline_case_capture)); + cases_extra.appendAssumeCapacity(@enumToInt(item)); cases_extra.appendSliceAssumeCapacity(case_block.instructions.items); } @@ -11478,6 +11530,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError info.capture, .{ .multi_capture = multi_i }, items, + .none, + false, ); } else { _ = try case_block.addNoOp(.unreach); @@ -11563,6 +11617,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError info.capture, .{ .multi_capture = multi_i }, items, + .none, + false, ); } @@ -11608,7 +11664,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError const item_val = try mod.enumValueFieldIndex(operand_ty, @intCast(u32, i)); const item_ref = try sema.addConstant(operand_ty, item_val); - case_block.inline_case_capture = item_ref; case_block.instructions.shrinkRetainingCapacity(0); case_block.wip_capture_scope = child_block.wip_capture_scope; @@ -11629,6 +11684,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError special.capture, .special, &.{item_ref}, + item_ref, + special.has_tag_capture, ); } else { _ = try case_block.addNoOp(.unreach); 
@@ -11637,7 +11694,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); cases_extra.appendAssumeCapacity(1); // items_len cases_extra.appendAssumeCapacity(@intCast(u32, case_block.instructions.items.len)); - cases_extra.appendAssumeCapacity(@enumToInt(case_block.inline_case_capture)); + cases_extra.appendAssumeCapacity(@enumToInt(item_ref)); cases_extra.appendSliceAssumeCapacity(case_block.instructions.items); } }, @@ -11657,7 +11714,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError .name = error_name, } }); const item_ref = try sema.addConstant(operand_ty, item_val.toValue()); - case_block.inline_case_capture = item_ref; case_block.instructions.shrinkRetainingCapacity(0); case_block.wip_capture_scope = child_block.wip_capture_scope; @@ -11672,12 +11728,14 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError special.capture, .special, &.{item_ref}, + item_ref, + special.has_tag_capture, ); try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); cases_extra.appendAssumeCapacity(1); // items_len cases_extra.appendAssumeCapacity(@intCast(u32, case_block.instructions.items.len)); - cases_extra.appendAssumeCapacity(@enumToInt(case_block.inline_case_capture)); + cases_extra.appendAssumeCapacity(@enumToInt(item_ref)); cases_extra.appendSliceAssumeCapacity(case_block.instructions.items); } }, @@ -11687,7 +11745,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError cases_len += 1; const item_ref = try sema.addConstant(operand_ty, cur.toValue()); - case_block.inline_case_capture = item_ref; case_block.instructions.shrinkRetainingCapacity(0); case_block.wip_capture_scope = child_block.wip_capture_scope; @@ -11702,19 +11759,20 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError special.capture, .special, 
&.{item_ref}, + item_ref, + special.has_tag_capture, ); try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); cases_extra.appendAssumeCapacity(1); // items_len cases_extra.appendAssumeCapacity(@intCast(u32, case_block.instructions.items.len)); - cases_extra.appendAssumeCapacity(@enumToInt(case_block.inline_case_capture)); + cases_extra.appendAssumeCapacity(@enumToInt(item_ref)); cases_extra.appendSliceAssumeCapacity(case_block.instructions.items); } }, .Bool => { if (true_count == 0) { cases_len += 1; - case_block.inline_case_capture = Air.Inst.Ref.bool_true; case_block.instructions.shrinkRetainingCapacity(0); case_block.wip_capture_scope = child_block.wip_capture_scope; @@ -11729,17 +11787,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError special.capture, .special, &.{Air.Inst.Ref.bool_true}, + Air.Inst.Ref.bool_true, + special.has_tag_capture, ); try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); cases_extra.appendAssumeCapacity(1); // items_len cases_extra.appendAssumeCapacity(@intCast(u32, case_block.instructions.items.len)); - cases_extra.appendAssumeCapacity(@enumToInt(case_block.inline_case_capture)); + cases_extra.appendAssumeCapacity(@enumToInt(Air.Inst.Ref.bool_true)); cases_extra.appendSliceAssumeCapacity(case_block.instructions.items); } if (false_count == 0) { cases_len += 1; - case_block.inline_case_capture = Air.Inst.Ref.bool_false; case_block.instructions.shrinkRetainingCapacity(0); case_block.wip_capture_scope = child_block.wip_capture_scope; @@ -11754,12 +11813,14 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError special.capture, .special, &.{Air.Inst.Ref.bool_false}, + Air.Inst.Ref.bool_false, + special.has_tag_capture, ); try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); cases_extra.appendAssumeCapacity(1); // items_len cases_extra.appendAssumeCapacity(@intCast(u32, 
case_block.instructions.items.len)); - cases_extra.appendAssumeCapacity(@enumToInt(case_block.inline_case_capture)); + cases_extra.appendAssumeCapacity(@enumToInt(Air.Inst.Ref.bool_false)); cases_extra.appendSliceAssumeCapacity(case_block.instructions.items); } }, @@ -11773,7 +11834,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError case_block.instructions.shrinkRetainingCapacity(0); case_block.wip_capture_scope = wip_captures.scope; - case_block.inline_case_capture = .none; if (mod.backendSupportsFeature(.is_named_enum_value) and special.body.len != 0 and block.wantSafety() and operand_ty.zigTypeTag(mod) == .Enum and (!operand_ty.isNonexhaustiveEnum(mod) or union_originally)) @@ -11804,6 +11864,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError special.capture, .special, undefined, // case_vals may be undefined for special prongs + .none, + false, ); } else { // We still need a terminator in this block, but we have proven diff --git a/src/Zir.zig b/src/Zir.zig index 2a6ce2f047..ba883314e8 100644 --- a/src/Zir.zig +++ b/src/Zir.zig @@ -676,9 +676,6 @@ pub const Inst = struct { /// what will be switched on. /// Uses the `un_node` union field. switch_cond_ref, - /// Produces the capture value for an inline switch prong tag capture. - /// Uses the `un_tok` field. 
- switch_capture_tag, /// Given a /// *A returns *A /// *E!A returns *A @@ -1124,7 +1121,6 @@ pub const Inst = struct { .typeof_log2_int_type, .resolve_inferred_alloc, .set_eval_branch_quota, - .switch_capture_tag, .switch_block, .switch_cond, .switch_cond_ref, @@ -1414,7 +1410,6 @@ pub const Inst = struct { .slice_length, .import, .typeof_log2_int_type, - .switch_capture_tag, .switch_block, .switch_cond, .switch_cond_ref, @@ -1670,7 +1665,6 @@ pub const Inst = struct { .switch_block = .pl_node, .switch_cond = .un_node, .switch_cond_ref = .un_node, - .switch_capture_tag = .un_tok, .array_base_ptr = .un_node, .field_base_ptr = .un_node, .validate_array_init_ty = .pl_node, @@ -1996,9 +1990,10 @@ pub const Inst = struct { /// Implements the `@inComptime` builtin. /// `operand` is `src_node: i32`. in_comptime, - /// Used as a placeholder for the capture of an `errdefer`. - /// This is replaced by Sema with the captured value. - errdefer_err_code, + /// Used as a placeholder instruction which is just a dummy index for Sema to replace + /// with a specific value. For instance, this is used for the capture of an `errdefer`. + /// This should never appear in a body. + value_placeholder, pub const InstData = struct { opcode: Extended, @@ -2644,16 +2639,17 @@ pub const Inst = struct { }; /// 0. multi_cases_len: u32 // If has_multi_cases is set. - /// 1. else_body { // If has_else or has_under is set. + /// 1. tag_capture_inst: u32 // If any_has_tag_capture is set. Index of instruction prongs use to refer to the inline tag capture. + /// 2. else_body { // If has_else or has_under is set. /// info: ProngInfo, /// body member Index for every info.body_len /// } - /// 2. scalar_cases: { // for every scalar_cases_len + /// 3. scalar_cases: { // for every scalar_cases_len /// item: Ref, /// info: ProngInfo, /// body member Index for every info.body_len /// } - /// 3. multi_cases: { // for every multi_cases_len + /// 4. 
multi_cases: { // for every multi_cases_len /// items_len: u32, /// ranges_len: u32, /// info: ProngInfo, @@ -2681,9 +2677,10 @@ pub const Inst = struct { /// These are stored in trailing data in `extra` for each prong. pub const ProngInfo = packed struct(u32) { - body_len: u29, + body_len: u28, capture: Capture, is_inline: bool, + has_tag_capture: bool, pub const Capture = enum(u2) { none, @@ -2699,9 +2696,11 @@ pub const Inst = struct { has_else: bool, /// If true, there is an underscore prong. This is mutually exclusive with `has_else`. has_under: bool, + /// If true, at least one prong has an inline tag capture. + any_has_tag_capture: bool, scalar_cases_len: ScalarCasesLen, - pub const ScalarCasesLen = u29; + pub const ScalarCasesLen = u28; pub fn specialProng(bits: Bits) SpecialProng { const has_else: u2 = @boolToInt(bits.has_else); diff --git a/src/print_zir.zig b/src/print_zir.zig index 203c7daaf3..0148cf89ec 100644 --- a/src/print_zir.zig +++ b/src/print_zir.zig @@ -235,7 +235,6 @@ const Writer = struct { .ref, .ret_implicit, .closure_capture, - .switch_capture_tag, => try self.writeUnTok(stream, inst), .bool_br_and, @@ -463,7 +462,7 @@ const Writer = struct { .breakpoint, .c_va_start, .in_comptime, - .errdefer_err_code, + .value_placeholder, => try self.writeExtNode(stream, extended), .builtin_src => { @@ -1897,8 +1896,19 @@ const Writer = struct { break :blk multi_cases_len; } else 0; + const tag_capture_inst: Zir.Inst.Index = if (extra.data.bits.any_has_tag_capture) blk: { + const tag_capture_inst = self.code.extra[extra_index]; + extra_index += 1; + break :blk tag_capture_inst; + } else undefined; + try self.writeInstRef(stream, extra.data.operand); + if (extra.data.bits.any_has_tag_capture) { + try stream.writeAll(", tag_capture="); + try self.writeInstIndex(stream, tag_capture_inst); + } + self.indent += 2; else_prong: { From 85e94fed1e058720460560823ac09d7b64e49b97 Mon Sep 17 00:00:00 2001 From: mlugg Date: Sat, 27 May 2023 13:50:50 +0100 Subject: 
[PATCH 6/8] Eliminate switch_cond[_ref] ZIR tags This finishes the process of consolidating switch expressions in ZIR into as simple and compact a representation as is possible. There are now just two ZIR tags dedicated to switch expressions: switch_block and switch_block_ref, with the latter being for an operand passed by reference. --- src/AstGen.zig | 14 ++++++------ src/Autodoc.zig | 25 ---------------------- src/Sema.zig | 54 ++++++++++++++++++----------------------------- src/Zir.zig | 32 +++++++++------------------- src/print_zir.zig | 6 +++--- 5 files changed, 40 insertions(+), 91 deletions(-) diff --git a/src/AstGen.zig b/src/AstGen.zig index ae2c191285..86bd0fd4f4 100644 --- a/src/AstGen.zig +++ b/src/AstGen.zig @@ -2610,8 +2610,7 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As .slice_length, .import, .switch_block, - .switch_cond, - .switch_cond_ref, + .switch_block_ref, .struct_init_empty, .struct_init, .struct_init_ref, @@ -6835,10 +6834,6 @@ fn switchExpr( const operand_lc = LineColumn{ astgen.source_line - parent_gz.decl_line, astgen.source_column }; const raw_operand = try expr(parent_gz, scope, operand_ri, operand_node); - const cond_tag: Zir.Inst.Tag = if (any_payload_is_ref) .switch_cond_ref else .switch_cond; - const cond = try parent_gz.addUnNode(cond_tag, raw_operand, operand_node); - // Sema expects a dbg_stmt immediately after switch_cond(_ref) - try emitDbgStmt(parent_gz, operand_lc); const item_ri: ResultInfo = .{ .rl = .none }; // This contains the data that goes into the `extra` array for the SwitchBlock/SwitchBlockMulti, @@ -6858,8 +6853,11 @@ fn switchExpr( block_scope.instructions_top = GenZir.unstacked_top; block_scope.setBreakResultInfo(ri); + // Sema expects a dbg_stmt immediately before switch_block(_ref) + try emitDbgStmt(parent_gz, operand_lc); // This gets added to the parent block later, after the item expressions. 
- const switch_block = try parent_gz.makeBlockInst(.switch_block, switch_node); + const switch_tag: Zir.Inst.Tag = if (any_payload_is_ref) .switch_block_ref else .switch_block; + const switch_block = try parent_gz.makeBlockInst(switch_tag, switch_node); // We re-use this same scope for all cases, including the special prong, if any. var case_scope = parent_gz.makeSubBlock(&block_scope.base); @@ -7076,7 +7074,7 @@ fn switchExpr( payloads.items.len - case_table_end); const payload_index = astgen.addExtraAssumeCapacity(Zir.Inst.SwitchBlock{ - .operand = cond, + .operand = raw_operand, .bits = Zir.Inst.SwitchBlock.Bits{ .has_multi_cases = multi_cases_len != 0, .has_else = special_prong == .@"else", diff --git a/src/Autodoc.zig b/src/Autodoc.zig index 1cdb768311..055b8fd989 100644 --- a/src/Autodoc.zig +++ b/src/Autodoc.zig @@ -1993,31 +1993,6 @@ fn walkInstruction( .expr = .{ .switchIndex = switch_index }, }; }, - .switch_cond => { - const un_node = data[inst_index].un_node; - const operand = try self.walkRef( - file, - parent_scope, - parent_src, - un_node.operand, - need_type, - ); - const operand_index = self.exprs.items.len; - try self.exprs.append(self.arena, operand.expr); - - // const ast_index = self.ast_nodes.items.len; - // const sep = "=" ** 200; - // log.debug("{s}", .{sep}); - // log.debug("SWITCH COND", .{}); - // log.debug("ast index = {}", .{ast_index}); - // log.debug("ast previous = {}", .{self.ast_nodes.items[ast_index - 1]}); - // log.debug("{s}", .{sep}); - - return DocData.WalkResult{ - .typeRef = operand.typeRef, - .expr = .{ .typeOf = operand_index }, - }; - }, .typeof => { const un_node = data[inst_index].un_node; diff --git a/src/Sema.zig b/src/Sema.zig index 723eb9b01a..40251a849b 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -1007,9 +1007,8 @@ fn analyzeBodyInner( .slice_start => try sema.zirSliceStart(block, inst), .slice_length => try sema.zirSliceLength(block, inst), .str => try sema.zirStr(block, inst), - .switch_block => try 
sema.zirSwitchBlock(block, inst), - .switch_cond => try sema.zirSwitchCond(block, inst, false), - .switch_cond_ref => try sema.zirSwitchCond(block, inst, true), + .switch_block => try sema.zirSwitchBlock(block, inst, false), + .switch_block_ref => try sema.zirSwitchBlock(block, inst, true), .type_info => try sema.zirTypeInfo(block, inst), .size_of => try sema.zirSizeOf(block, inst), .bit_size_of => try sema.zirBitSizeOf(block, inst), @@ -10411,23 +10410,14 @@ const SwitchProngAnalysis = struct { } }; -fn zirSwitchCond( +fn switchCond( sema: *Sema, block: *Block, - inst: Zir.Inst.Index, - is_ref: bool, + src: LazySrcLoc, + operand: Air.Inst.Ref, ) CompileError!Air.Inst.Ref { const mod = sema.mod; - const inst_data = sema.code.instructions.items(.data)[inst].un_node; - const src = inst_data.src(); - const operand_src: LazySrcLoc = .{ .node_offset_switch_operand = inst_data.src_node }; - const operand_ptr = try sema.resolveInst(inst_data.operand); - const operand = if (is_ref) - try sema.analyzeLoad(block, src, operand_ptr, operand_src) - else - operand_ptr; const operand_ty = sema.typeOf(operand); - switch (operand_ty.zigTypeTag(mod)) { .Type, .Void, @@ -10484,7 +10474,7 @@ fn zirSwitchCond( const SwitchErrorSet = std.AutoHashMap(InternPool.NullTerminatedString, Module.SwitchProngSrc); -fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { +fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_ref: bool) CompileError!Air.Inst.Ref { const tracy = trace(@src()); defer tracy.end(); @@ -10498,10 +10488,21 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError const special_prong_src: LazySrcLoc = .{ .node_offset_switch_special_prong = src_node_offset }; const extra = sema.code.extraData(Zir.Inst.SwitchBlock, inst_data.payload_index); - const operand = try sema.resolveInst(extra.data.operand); - // AstGen guarantees that the instruction immediately following - // 
switch_cond(_ref) is a dbg_stmt - const cond_dbg_node_index = Zir.refToIndex(extra.data.operand).? + 1; + const raw_operand: struct { val: Air.Inst.Ref, ptr: Air.Inst.Ref } = blk: { + const maybe_ptr = try sema.resolveInst(extra.data.operand); + if (operand_is_ref) { + const val = try sema.analyzeLoad(block, src, maybe_ptr, operand_src); + break :blk .{ .val = val, .ptr = maybe_ptr }; + } else { + break :blk .{ .val = maybe_ptr, .ptr = undefined }; + } + }; + + const operand = try sema.switchCond(block, src, raw_operand.val); + + // AstGen guarantees that the instruction immediately preceding + // switch_block(_ref) is a dbg_stmt + const cond_dbg_node_index = inst - 1; var header_extra_index: usize = extra.end; @@ -10555,19 +10556,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError }, }; - const raw_operand: struct { val: Air.Inst.Ref, ptr: Air.Inst.Ref } = blk: { - const zir_tags = sema.code.instructions.items(.tag); - const zir_data = sema.code.instructions.items(.data); - const cond_index = Zir.refToIndex(extra.data.operand).?; - const raw = sema.resolveInst(zir_data[cond_index].un_node.operand) catch unreachable; - if (zir_tags[cond_index] == .switch_cond_ref) { - const val = try sema.analyzeLoad(block, src, raw, operand_src); - break :blk .{ .val = val, .ptr = raw }; - } else { - break :blk .{ .val = raw, .ptr = undefined }; - } - }; - const maybe_union_ty = sema.typeOf(raw_operand.val); const union_originally = maybe_union_ty.zigTypeTag(mod) == .Union; diff --git a/src/Zir.zig b/src/Zir.zig index ba883314e8..81b02bb469 100644 --- a/src/Zir.zig +++ b/src/Zir.zig @@ -667,15 +667,9 @@ pub const Inst = struct { /// A switch expression. Uses the `pl_node` union field. /// AST node is the switch, payload is `SwitchBlock`. switch_block, - /// Produces the value that will be switched on. For example, for - /// integers, it returns the integer with no modifications. For tagged unions, it - /// returns the active enum tag. 
- /// Uses the `un_node` union field. - switch_cond, - /// Same as `switch_cond`, except the input operand is a pointer to - /// what will be switched on. - /// Uses the `un_node` union field. - switch_cond_ref, + /// A switch expression. Uses the `pl_node` union field. + /// AST node is the switch, payload is `SwitchBlock`. Operand is a pointer. + switch_block_ref, /// Given a /// *A returns *A /// *E!A returns *A @@ -1122,8 +1116,7 @@ pub const Inst = struct { .resolve_inferred_alloc, .set_eval_branch_quota, .switch_block, - .switch_cond, - .switch_cond_ref, + .switch_block_ref, .array_base_ptr, .field_base_ptr, .validate_array_init_ty, @@ -1411,8 +1404,7 @@ pub const Inst = struct { .import, .typeof_log2_int_type, .switch_block, - .switch_cond, - .switch_cond_ref, + .switch_block_ref, .array_base_ptr, .field_base_ptr, .struct_init_empty, @@ -1663,8 +1655,7 @@ pub const Inst = struct { .err_union_code_ptr = .un_node, .enum_literal = .str_tok, .switch_block = .pl_node, - .switch_cond = .un_node, - .switch_cond_ref = .un_node, + .switch_block_ref = .pl_node, .array_base_ptr = .un_node, .field_base_ptr = .un_node, .validate_array_init_ty = .pl_node, @@ -2665,13 +2656,10 @@ pub const Inst = struct { /// captured payload. Whether this is captured by reference or by value /// depends on whether the `byref` bit is set for the corresponding body. pub const SwitchBlock = struct { - /// This is always a `switch_cond` or `switch_cond_ref` instruction. - /// If it is a `switch_cond_ref` instruction, bits.is_ref is always true. - /// If it is a `switch_cond` instruction, bits.is_ref is always false. - /// Both `switch_cond` and `switch_cond_ref` return a value, not a pointer, - /// that is useful for the case items, but cannot be used for capture values. - /// For the capture values, Sema is expected to find the operand of this operand - /// and use that. + /// The operand passed to the `switch` expression. 
If this is a + /// `switch_block`, this is the operand value; if `switch_block_ref` it + /// is a pointer to the operand. `switch_block_ref` is always used if + /// any prong has a byref capture. operand: Ref, bits: Bits, diff --git a/src/print_zir.zig b/src/print_zir.zig index 0148cf89ec..35050f1728 100644 --- a/src/print_zir.zig +++ b/src/print_zir.zig @@ -222,8 +222,6 @@ const Writer = struct { .bit_reverse, .@"resume", .@"await", - .switch_cond, - .switch_cond_ref, .array_base_ptr, .field_base_ptr, .validate_struct_init_ty, @@ -388,7 +386,9 @@ const Writer = struct { .error_set_decl_anon => try self.writeErrorSetDecl(stream, inst, .anon), .error_set_decl_func => try self.writeErrorSetDecl(stream, inst, .func), - .switch_block => try self.writeSwitchBlock(stream, inst), + .switch_block, + .switch_block_ref, + => try self.writeSwitchBlock(stream, inst), .field_ptr, .field_ptr_init, From bcb673d94ac09ec381e5b1ad1edf64b603ac68f1 Mon Sep 17 00:00:00 2001 From: mlugg Date: Sun, 28 May 2023 01:45:15 +0100 Subject: [PATCH 7/8] Sema: resolve union payload switch captures with peer type resolution This is a bit harder than it seems at first glance. Actually resolving the type is the easy part: the interesting thing is actually getting the capture value. We split this into three cases: * If all payload types are the same (as is required in status quo), we can just do what we already do: get the first field value. * If all payloads are in-memory coercible to the resolved type, we still fetch the first field, but we also emit a `bitcast` to convert to the resolved type. * Otherwise, we need to handle each case separately. We emit a nested `switch_br` which, for each possible case, gets the corresponding union field, and coerces it to the resolved type. As an optimization, the inner switch's 'else' prong is used for any peer which is in-memory coercible to the target type, and the bitcast approach described above is used. 
Pointer captures have the additional constraint that all payload types must be in-memory coercible to the resolved type. Resolves: #2812 --- src/Sema.zig | 299 +++++++++++++++--- test/behavior/switch.zig | 68 ++++ .../switch_capture_incompatible_types.zig | 27 ++ 3 files changed, 352 insertions(+), 42 deletions(-) create mode 100644 test/cases/compile_errors/switch_capture_incompatible_types.zig diff --git a/src/Sema.zig b/src/Sema.zig index 40251a849b..51acda3810 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -2392,6 +2392,34 @@ fn failWithOwnedErrorMsg(sema: *Sema, err_msg: *Module.ErrorMsg) CompileError { return error.AnalysisFail; } +/// Given an ErrorMsg, modify its message and source location to the given values, turning the +/// original message into a note. Notes on the original message are preserved as further notes. +/// Reference trace is preserved. +fn reparentOwnedErrorMsg( + sema: *Sema, + block: *Block, + src: LazySrcLoc, + msg: *Module.ErrorMsg, + comptime format: []const u8, + args: anytype, +) !void { + const mod = sema.mod; + const src_decl = mod.declPtr(block.src_decl); + const resolved_src = src.toSrcLoc(src_decl, mod); + const msg_str = try std.fmt.allocPrint(mod.gpa, format, args); + + const orig_notes = msg.notes.len; + msg.notes = try sema.gpa.realloc(msg.notes, orig_notes + 1); + std.mem.copyBackwards(Module.ErrorMsg, msg.notes[1..], msg.notes[0..orig_notes]); + msg.notes[0] = .{ + .src_loc = msg.src_loc, + .msg = msg.msg, + }; + + msg.src_loc = resolved_src; + msg.msg = msg_str; +} + const align_ty = Type.u29; fn analyzeAsAlign( @@ -10082,6 +10110,8 @@ const SwitchProngAnalysis = struct { operand: Air.Inst.Ref, /// May be `undefined` if no prong has a by-ref capture. operand_ptr: Air.Inst.Ref, + /// The switch condition value. For unions, `operand` is the union and `cond` is its tag. + cond: Air.Inst.Ref, /// If this switch is on an error set, this is the type to assign to the /// `else` prong. If `null`, the prong should be unreachable. 
else_error_ty: ?Type, @@ -10315,61 +10345,245 @@ const SwitchProngAnalysis = struct { const first_field_index = @intCast(u32, operand_ty.unionTagFieldIndex(first_item_val, mod).?); const first_field = union_obj.fields.values()[first_field_index]; - for (case_vals[1..], 0..) |item, i| { + const field_tys = try sema.arena.alloc(Type, case_vals.len); + for (case_vals, field_tys) |item, *field_ty| { const item_val = sema.resolveConstValue(block, .unneeded, item, "") catch unreachable; - - const field_index = operand_ty.unionTagFieldIndex(item_val, mod).?; - const field = union_obj.fields.values()[field_index]; - if (!field.ty.eql(first_field.ty, mod)) { - const msg = msg: { - const capture_src = raw_capture_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .none); - - const msg = try sema.errMsg(block, capture_src, "capture group with incompatible types", .{}); - errdefer msg.destroy(sema.gpa); - - // This must be a multi-prong so this must be a `multi_capture` src - const multi_idx = raw_capture_src.multi_capture; - - const raw_first_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 0 } }; - const first_item_src = raw_first_item_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .first); - const raw_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 1 + @intCast(u32, i) } }; - const item_src = raw_item_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .first); - try sema.errNote(block, first_item_src, msg, "type '{}' here", .{first_field.ty.fmt(mod)}); - try sema.errNote(block, item_src, msg, "type '{}' here", .{field.ty.fmt(mod)}); - break :msg msg; - }; - return sema.failWithOwnedErrorMsg(msg); - } + const field_idx = @intCast(u32, operand_ty.unionTagFieldIndex(item_val, sema.mod).?); + field_ty.* = union_obj.fields.values()[field_idx].ty; } + // Fast path: if all the operands are the same type already, we don't need to hit + // PTR! 
This will also allow us to emit simpler code. + const same_types = for (field_tys[1..]) |field_ty| { + if (!field_ty.eql(field_tys[0], sema.mod)) break false; + } else true; + + const capture_ty = if (same_types) field_tys[0] else capture_ty: { + // We need values to run PTR on, so make a bunch of undef constants. + const dummy_captures = try sema.arena.alloc(Air.Inst.Ref, case_vals.len); + for (dummy_captures, field_tys) |*dummy, field_ty| { + dummy.* = try sema.addConstUndef(field_ty); + } + + const case_srcs = try sema.arena.alloc(?LazySrcLoc, case_vals.len); + @memset(case_srcs, .unneeded); + + break :capture_ty sema.resolvePeerTypes(block, .unneeded, dummy_captures, .{ .override = case_srcs }) catch |err| switch (err) { + error.NeededSourceLocation => { + // This must be a multi-prong so this must be a `multi_capture` src + const multi_idx = raw_capture_src.multi_capture; + const src_decl_ptr = sema.mod.declPtr(block.src_decl); + for (case_srcs, 0..) |*case_src, i| { + const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(u32, i) } }; + case_src.* = raw_case_src.resolve(mod, src_decl_ptr, switch_node_offset, .none); + } + const capture_src = raw_capture_src.resolve(mod, src_decl_ptr, switch_node_offset, .none); + _ = sema.resolvePeerTypes(block, capture_src, dummy_captures, .{ .override = case_srcs }) catch |err1| switch (err1) { + error.AnalysisFail => { + const msg = sema.err orelse return error.AnalysisFail; + try sema.reparentOwnedErrorMsg(block, capture_src, msg, "capture group with incompatible types", .{}); + return error.AnalysisFail; + }, + else => |e| return e, + }; + unreachable; + }, + else => |e| return e, + }; + }; + + // By-reference captures have some further restrictions which make them easier to emit if (capture_byref) { - const field_ty_ptr = try Type.ptr(sema.arena, mod, .{ - .pointee_type = first_field.ty, - .@"addrspace" = .generic, - .mutable = operand_ptr_ty.ptrIsMutable(mod), + const 
operand_ptr_info = operand_ptr_ty.ptrInfo(mod); + const capture_ptr_ty = try Type.ptr(sema.arena, sema.mod, .{ + .pointee_type = capture_ty, + .@"addrspace" = operand_ptr_info.@"addrspace", + .mutable = operand_ptr_info.mutable, + .@"volatile" = operand_ptr_info.@"volatile", + // TODO: alignment! }); - if (try sema.resolveDefinedValue(block, operand_src, spa.operand_ptr)) |op_ptr_val| { - return sema.addConstant(field_ty_ptr, (try mod.intern(.{ .ptr = .{ - .ty = field_ty_ptr.toIntern(), - .addr = .{ .field = .{ - .base = op_ptr_val.toIntern(), - .index = first_field_index, - } }, - } })).toValue()); + // By-ref captures of heterogeneous types are only allowed if each field + // pointer type is in-memory coercible to the capture pointer type. + if (!same_types) { + for (field_tys, 0..) |field_ty, i| { + const field_ptr_ty = try Type.ptr(sema.arena, sema.mod, .{ + .pointee_type = field_ty, + .@"addrspace" = operand_ptr_info.@"addrspace", + .mutable = operand_ptr_info.mutable, + .@"volatile" = operand_ptr_info.@"volatile", + // TODO: alignment!
+ }); + if (.ok != try sema.coerceInMemoryAllowed(block, capture_ptr_ty, field_ptr_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) { + const multi_idx = raw_capture_src.multi_capture; + const src_decl_ptr = sema.mod.declPtr(block.src_decl); + const capture_src = raw_capture_src.resolve(mod, src_decl_ptr, switch_node_offset, .none); + const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(u32, i) } }; + const case_src = raw_case_src.resolve(mod, src_decl_ptr, switch_node_offset, .none); + const msg = msg: { + const msg = try sema.errMsg(block, capture_src, "capture group with incompatible types", .{}); + errdefer msg.destroy(sema.gpa); + try sema.errNote(block, case_src, msg, "pointer type child '{}' cannot cast into resolved pointer type child '{}'", .{ + field_ty.fmt(sema.mod), + capture_ty.fmt(sema.mod), + }); + try sema.errNote(block, capture_src, msg, "this coercion is only possible when capturing by value", .{}); + break :msg msg; + }; + return sema.failWithOwnedErrorMsg(msg); + } + } } + + if (try sema.resolveDefinedValue(block, operand_src, spa.operand_ptr)) |op_ptr_val| { + if (op_ptr_val.isUndef(mod)) return sema.addConstUndef(capture_ptr_ty); + return sema.addConstant( + capture_ptr_ty, + (try mod.intern(.{ .ptr = .{ + .ty = capture_ptr_ty.toIntern(), + .addr = .{ .field = .{ + .base = op_ptr_val.toIntern(), + .index = first_field_index, + } }, + } })).toValue(), + ); + } + try sema.requireRuntimeBlock(block, operand_src, null); - return block.addStructFieldPtr(spa.operand_ptr, first_field_index, field_ty_ptr); + return block.addStructFieldPtr(spa.operand_ptr, first_field_index, capture_ptr_ty); } if (try sema.resolveDefinedValue(block, operand_src, spa.operand)) |operand_val| { - return sema.addConstant( - first_field.ty, - mod.intern_pool.indexToKey(operand_val.toIntern()).un.val.toValue(), - ); + if (operand_val.isUndef(mod)) return sema.addConstUndef(capture_ty); + const union_val = 
mod.intern_pool.indexToKey(operand_val.toIntern()).un; + if (union_val.tag.toValue().isUndef(mod)) return sema.addConstUndef(capture_ty); + const active_field_idx = @intCast(u32, operand_ty.unionTagFieldIndex(union_val.tag.toValue(), sema.mod).?); + const field_ty = union_obj.fields.values()[active_field_idx].ty; + const uncoerced = try sema.addConstant(field_ty, union_val.val.toValue()); + return sema.coerce(block, capture_ty, uncoerced, operand_src); } + try sema.requireRuntimeBlock(block, operand_src, null); - return block.addStructFieldVal(spa.operand, first_field_index, first_field.ty); + + if (same_types) { + return block.addStructFieldVal(spa.operand, first_field_index, capture_ty); + } + + // We may have to emit a switch block which coerces the operand to the capture type. + // If we can, try to avoid that using in-memory coercions. + const first_non_imc = in_mem: { + for (field_tys, 0..) |field_ty, i| { + if (.ok != try sema.coerceInMemoryAllowed(block, capture_ty, field_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) { + break :in_mem i; + } + } + // All fields are in-memory coercible to the resolved type! + // Just take the first field and bitcast the result. + const uncoerced = try block.addStructFieldVal(spa.operand, first_field_index, first_field.ty); + return block.addBitCast(capture_ty, uncoerced); + }; + + // By-val capture with heterogeneous types which are not all in-memory coercible to + // the resolved capture type. We finally have to fall back to the ugly method. + + // However, let's first track which operands are in-memory coercible. There may well + // be several, and we can squash all of these cases into the same switch prong using + // a simple bitcast. We'll make this the 'else' prong. + + var in_mem_coercible = try std.DynamicBitSet.initFull(sema.arena, field_tys.len); + in_mem_coercible.unset(first_non_imc); + { + const next = first_non_imc + 1; + for (field_tys[next..], next..) 
|field_ty, i| { + if (.ok != try sema.coerceInMemoryAllowed(block, capture_ty, field_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) { + in_mem_coercible.unset(i); + } + } + } + + const capture_block_inst = try block.addInstAsIndex(.{ + .tag = .block, + .data = .{ + .ty_pl = .{ + .ty = try sema.addType(capture_ty), + .payload = undefined, // updated below + }, + }, + }); + + const prong_count = field_tys.len - in_mem_coercible.count(); + + const estimated_extra = prong_count * 6; // 2 for Case, 1 item, probably 3 insts + var cases_extra = try std.ArrayList(u32).initCapacity(sema.gpa, estimated_extra); + defer cases_extra.deinit(); + + { + // Non-bitcast cases + var it = in_mem_coercible.iterator(.{ .kind = .unset }); + while (it.next()) |idx| { + var coerce_block = block.makeSubBlock(); + defer coerce_block.instructions.deinit(sema.gpa); + + const uncoerced = try coerce_block.addStructFieldVal(spa.operand, @intCast(u32, idx), field_tys[idx]); + const coerced = sema.coerce(&coerce_block, capture_ty, uncoerced, .unneeded) catch |err| switch (err) { + error.NeededSourceLocation => { + const multi_idx = raw_capture_src.multi_capture; + const src_decl_ptr = sema.mod.declPtr(block.src_decl); + const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(u32, idx) } }; + const case_src = raw_case_src.resolve(mod, src_decl_ptr, switch_node_offset, .none); + _ = try sema.coerce(&coerce_block, capture_ty, uncoerced, case_src); + unreachable; + }, + else => |e| return e, + }; + _ = try coerce_block.addBr(capture_block_inst, coerced); + + try cases_extra.ensureUnusedCapacity(3 + coerce_block.instructions.items.len); + cases_extra.appendAssumeCapacity(1); // items_len + cases_extra.appendAssumeCapacity(@intCast(u32, coerce_block.instructions.items.len)); // body_len + cases_extra.appendAssumeCapacity(@enumToInt(case_vals[idx])); // item + cases_extra.appendSliceAssumeCapacity(coerce_block.instructions.items); // body + } + } + const 
else_body_len = len: { + // 'else' prong uses a bitcast + var coerce_block = block.makeSubBlock(); + defer coerce_block.instructions.deinit(sema.gpa); + + const first_imc = in_mem_coercible.findFirstSet().?; + const uncoerced = try coerce_block.addStructFieldVal(spa.operand, @intCast(u32, first_imc), field_tys[first_imc]); + const coerced = try coerce_block.addBitCast(capture_ty, uncoerced); + _ = try coerce_block.addBr(capture_block_inst, coerced); + + try cases_extra.appendSlice(coerce_block.instructions.items); + break :len coerce_block.instructions.items.len; + }; + + try sema.air_extra.ensureUnusedCapacity(sema.gpa, @typeInfo(Air.SwitchBr).Struct.fields.len + + cases_extra.items.len + + @typeInfo(Air.Block).Struct.fields.len + + 1); + + const switch_br_inst = @intCast(u32, sema.air_instructions.len); + try sema.air_instructions.append(sema.gpa, .{ + .tag = .switch_br, + .data = .{ .pl_op = .{ + .operand = spa.cond, + .payload = sema.addExtraAssumeCapacity(Air.SwitchBr{ + .cases_len = @intCast(u32, prong_count), + .else_body_len = @intCast(u32, else_body_len), + }), + } }, + }); + sema.air_extra.appendSliceAssumeCapacity(cases_extra.items); + + // Set up block body + sema.air_instructions.items(.data)[capture_block_inst].ty_pl.payload = sema.addExtraAssumeCapacity(Air.Block{ + .body_len = 1, + }); + sema.air_extra.appendAssumeCapacity(switch_br_inst); + + return Air.indexToRef(capture_block_inst); }, .ErrorSet => { if (capture_byref) { @@ -11099,6 +11313,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r .parent_block = block, .operand = raw_operand.val, .operand_ptr = raw_operand.ptr, + .cond = operand, .else_error_ty = else_error_ty, .switch_block_inst = inst, .tag_capture_inst = tag_capture_inst, diff --git a/test/behavior/switch.zig b/test/behavior/switch.zig index 3f6cd37298..72a36c9883 100644 --- a/test/behavior/switch.zig +++ b/test/behavior/switch.zig @@ -1,5 +1,6 @@ const builtin = @import("builtin"); const std = 
@import("std"); +const assert = std.debug.assert; const expect = std.testing.expect; const expectError = std.testing.expectError; const expectEqual = std.testing.expectEqual; @@ -717,3 +718,70 @@ test "comptime inline switch" { try expectEqual(u32, value); } + +test "switch capture peer type resolution" { + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; + + const U = union(enum) { + a: u32, + b: u64, + fn innerVal(u: @This()) u64 { + switch (u) { + .a, .b => |x| return x, + } + } + }; + + try expectEqual(@as(u64, 100), U.innerVal(.{ .a = 100 })); + try expectEqual(@as(u64, 200), U.innerVal(.{ .b = 200 })); +} + +test "switch capture peer type resolution for in-memory coercible payloads" { + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; + + const T1 = c_int; + const T2 = @Type(@typeInfo(T1)); + + comptime assert(T1 != T2); + + const U = union(enum) { + a: T1, + b: T2, + fn innerVal(u: @This()) c_int { + switch (u) { + .a, .b => |x| return x, + } + } + }; + + try expectEqual(@as(c_int, 100), U.innerVal(.{ .a = 100 })); + try expectEqual(@as(c_int, 200), U.innerVal(.{ .b = 200 })); +} + +test "switch pointer capture peer type resolution" { + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; + + const T1 = c_int; + const T2 = @Type(@typeInfo(T1)); + + comptime assert(T1 != T2); + + const U = union(enum) { + a: T1, + b: T2, + fn innerVal(u: *@This()) *c_int { + switch (u.*) { + .a, .b => |*ptr| return ptr, + } + } + }; + + var ua: U = .{ .a = 100 }; + var ub: U = .{ .b = 200 }; + + ua.innerVal().* = 111; + ub.innerVal().* = 222; + + try expectEqual(U{ .a = 111 }, ua); + try expectEqual(U{ .b = 222 }, ub); +} diff --git a/test/cases/compile_errors/switch_capture_incompatible_types.zig b/test/cases/compile_errors/switch_capture_incompatible_types.zig new file mode 100644 index 0000000000..b6de7d5bf5 --- /dev/null +++ b/test/cases/compile_errors/switch_capture_incompatible_types.zig @@ -0,0 +1,27 @@ +export 
fn f() void { + const U = union(enum) { a: u32, b: *u8 }; + var u: U = undefined; + switch (u) { + .a, .b => |val| _ = val, + } +} + +export fn g() void { + const U = union(enum) { a: u64, b: u32 }; + var u: U = undefined; + switch (u) { + .a, .b => |*ptr| _ = ptr, + } +} + +// error +// backend=stage2 +// target=native +// +// :5:20: error: capture group with incompatible types +// :5:20: note: incompatible types: 'u32' and '*u8' +// :5:10: note: type 'u32' here +// :5:14: note: type '*u8' here +// :13:20: error: capture group with incompatible types +// :13:14: note: pointer type child 'u32' cannot cast into resolved pointer type child 'u64' +// :13:20: note: this coercion is only possible when capturing by value From 42dc7539c5b0a39e9b64c5ad92757945b0ca05ad Mon Sep 17 00:00:00 2001 From: mlugg Date: Sun, 28 May 2023 04:31:56 +0100 Subject: [PATCH 8/8] Fix bad source locations in switch capture errors To do this, I expanded SwitchProngSrc a bit. Several of the tags there aren't actually used by any current errors, but they're there for consistency and if we ever need them. Also delete a now-redundant test and fix another. 
--- src/Module.zig | 142 ++++++++++++------ src/Sema.zig | 32 ++-- ..._prong_with_incompatible_payload_types.zig | 21 --- .../switch_on_union_with_no_attached_enum.zig | 4 +- 4 files changed, 120 insertions(+), 79 deletions(-) delete mode 100644 test/cases/compile_errors/capture_group_on_switch_prong_with_incompatible_payload_types.zig diff --git a/src/Module.zig b/src/Module.zig index 73ab1c3277..d60f3919a5 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -2471,12 +2471,23 @@ pub const SrcLoc = struct { } } else unreachable; }, - .node_offset_switch_prong_capture => |node_off| { + .node_offset_switch_prong_capture, + .node_offset_switch_prong_tag_capture, + => |node_off| { const tree = try src_loc.file_scope.getTree(gpa); const case_node = src_loc.declRelativeToNodeIndex(node_off); const case = tree.fullSwitchCase(case_node).?; - const start_tok = case.payload_token.?; const token_tags = tree.tokens.items(.tag); + const start_tok = switch (src_loc.lazy) { + .node_offset_switch_prong_capture => case.payload_token.?, + .node_offset_switch_prong_tag_capture => blk: { + var tok = case.payload_token.?; + if (token_tags[tok] == .asterisk) tok += 1; + tok += 2; // skip over comma + break :blk tok; + }, + else => unreachable, + }; const end_tok = switch (token_tags[start_tok]) { .asterisk => start_tok + 1, else => start_tok, @@ -2957,6 +2968,9 @@ pub const LazySrcLoc = union(enum) { /// The source location points to the capture of a switch_prong. /// The Decl is determined contextually. node_offset_switch_prong_capture: i32, + /// The source location points to the tag capture of a switch_prong. + /// The Decl is determined contextually. + node_offset_switch_prong_tag_capture: i32, /// The source location points to the align expr of a function type /// expression, found by taking this AST node index offset from the containing /// Decl AST node, which points to a function type AST node. 
Next, navigate to @@ -3130,6 +3144,7 @@ pub const LazySrcLoc = union(enum) { .node_offset_switch_special_prong, .node_offset_switch_range, .node_offset_switch_prong_capture, + .node_offset_switch_prong_tag_capture, .node_offset_fn_type_align, .node_offset_fn_type_addrspace, .node_offset_fn_type_section, @@ -5867,11 +5882,26 @@ fn lockAndClearFileCompileError(mod: *Module, file: *File) void { } pub const SwitchProngSrc = union(enum) { + /// The item for a scalar prong. scalar: u32, + /// A given single item for a multi prong. multi: Multi, + /// A given range item for a multi prong. range: Multi, - multi_capture: u32, + /// The item for the special prong. special, + /// The main capture for a scalar prong. + scalar_capture: u32, + /// The main capture for a multi prong. + multi_capture: u32, + /// The main capture for the special prong. + special_capture, + /// The tag capture for a scalar prong. + scalar_tag_capture: u32, + /// The tag capture for a multi prong. + multi_tag_capture: u32, + /// The tag capture for the special prong. 
+ special_tag_capture, pub const Multi = struct { prong: u32, @@ -5887,6 +5917,7 @@ pub const SwitchProngSrc = union(enum) { mod: *Module, decl: *Decl, switch_node_offset: i32, + /// Ignored if `prong_src` is not `.range` range_expand: RangeExpand, ) LazySrcLoc { @setCold(true); @@ -5907,7 +5938,7 @@ pub const SwitchProngSrc = union(enum) { var multi_i: u32 = 0; var scalar_i: u32 = 0; - for (case_nodes) |case_node| { + const case_node = for (case_nodes) |case_node| { const case = tree.fullSwitchCase(case_node).?; const is_special = special: { @@ -5919,60 +5950,85 @@ pub const SwitchProngSrc = union(enum) { }; if (is_special) { - if (prong_src != .special) continue; - return LazySrcLoc.nodeOffset( - decl.nodeIndexToRelative(case.ast.values[0]), - ); + switch (prong_src) { + .special, .special_capture, .special_tag_capture => break case_node, + else => continue, + } } const is_multi = case.ast.values.len != 1 or node_tags[case.ast.values[0]] == .switch_range; switch (prong_src) { - .scalar => |i| if (!is_multi and i == scalar_i) return LazySrcLoc.nodeOffset( - decl.nodeIndexToRelative(case.ast.values[0]), - ), - .multi_capture => |i| if (is_multi and i == multi_i) { - return LazySrcLoc{ .node_offset_switch_prong_capture = decl.nodeIndexToRelative(case_node) }; - }, - .multi => |s| if (is_multi and s.prong == multi_i) { - var item_i: u32 = 0; - for (case.ast.values) |item_node| { - if (node_tags[item_node] == .switch_range) continue; + .scalar, + .scalar_capture, + .scalar_tag_capture, + => |i| if (!is_multi and i == scalar_i) break case_node, - if (item_i == s.item) return LazySrcLoc.nodeOffset( - decl.nodeIndexToRelative(item_node), - ); - item_i += 1; - } else unreachable; - }, - .range => |s| if (is_multi and s.prong == multi_i) { - var range_i: u32 = 0; - for (case.ast.values) |range| { - if (node_tags[range] != .switch_range) continue; + .multi_capture, + .multi_tag_capture, + => |i| if (is_multi and i == multi_i) break case_node, - if (range_i == s.item) switch 
(range_expand) { - .none => return LazySrcLoc.nodeOffset( - decl.nodeIndexToRelative(range), - ), - .first => return LazySrcLoc.nodeOffset( - decl.nodeIndexToRelative(node_datas[range].lhs), - ), - .last => return LazySrcLoc.nodeOffset( - decl.nodeIndexToRelative(node_datas[range].rhs), - ), - }; - range_i += 1; - } else unreachable; - }, - .special => {}, + .multi, + .range, + => |m| if (is_multi and m.prong == multi_i) break case_node, + + .special, + .special_capture, + .special_tag_capture, + => {}, } + if (is_multi) { multi_i += 1; } else { scalar_i += 1; } } else unreachable; + + const case = tree.fullSwitchCase(case_node).?; + + switch (prong_src) { + .scalar, .special => return LazySrcLoc.nodeOffset( + decl.nodeIndexToRelative(case.ast.values[0]), + ), + .multi => |m| { + var item_i: u32 = 0; + for (case.ast.values) |item_node| { + if (node_tags[item_node] == .switch_range) continue; + if (item_i == m.item) return LazySrcLoc.nodeOffset( + decl.nodeIndexToRelative(item_node), + ); + item_i += 1; + } + unreachable; + }, + .range => |m| { + var range_i: u32 = 0; + for (case.ast.values) |range| { + if (node_tags[range] != .switch_range) continue; + if (range_i == m.item) switch (range_expand) { + .none => return LazySrcLoc.nodeOffset( + decl.nodeIndexToRelative(range), + ), + .first => return LazySrcLoc.nodeOffset( + decl.nodeIndexToRelative(node_datas[range].lhs), + ), + .last => return LazySrcLoc.nodeOffset( + decl.nodeIndexToRelative(node_datas[range].rhs), + ), + }; + range_i += 1; + } + unreachable; + }, + .scalar_capture, .multi_capture, .special_capture => { + return .{ .node_offset_switch_prong_capture = decl.nodeIndexToRelative(case_node) }; + }, + .scalar_tag_capture, .multi_tag_capture, .special_tag_capture => { + return .{ .node_offset_switch_prong_tag_capture = decl.nodeIndexToRelative(case_node) }; + }, + } } }; diff --git a/src/Sema.zig b/src/Sema.zig index 51acda3810..7c4968b0e5 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -10129,7 +10129,7 
@@ const SwitchProngAnalysis = struct { prong_type: enum { normal, special }, prong_body: []const Zir.Inst.Index, capture: Zir.Inst.SwitchBlock.ProngInfo.Capture, - /// Must use the `scalar`, `special`, or `multi_capture` union field. + /// Must use the `scalar_capture`, `special_capture`, or `multi_capture` union field. raw_capture_src: Module.SwitchProngSrc, /// The set of all values which can reach this prong. May be undefined /// if the prong is special or contains ranges. @@ -10247,7 +10247,13 @@ const SwitchProngAnalysis = struct { if (operand_ty.zigTypeTag(mod) != .Union) { const zir_datas = sema.code.instructions.items(.data); const switch_node_offset = zir_datas[spa.switch_block_inst].pl_node.src_node; - const capture_src = raw_capture_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .none); + const raw_tag_capture_src: Module.SwitchProngSrc = switch (raw_capture_src) { + .scalar_capture => |i| .{ .scalar_tag_capture = i }, + .multi_capture => |i| .{ .multi_tag_capture = i }, + .special_capture => .special_tag_capture, + else => unreachable, + }; + const capture_src = raw_tag_capture_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .none); const msg = msg: { const msg = try sema.errMsg(block, capture_src, "cannot capture tag of non-union type '{}'", .{ operand_ty.fmt(mod), @@ -10712,7 +10718,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r } }; - const operand = try sema.switchCond(block, src, raw_operand.val); + const operand = try sema.switchCond(block, operand_src, raw_operand.val); // AstGen guarantees that the instruction immediately preceding // switch_block(_ref) is a dbg_stmt @@ -11377,7 +11383,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r .normal, body, info.capture, - .{ .scalar = @intCast(u32, scalar_i) }, + .{ .scalar_capture = @intCast(u32, scalar_i) }, &.{item}, if (info.is_inline) operand else .none, info.has_tag_capture, @@ 
-11460,7 +11466,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r .special, special.body, special.capture, - .special, + .special_capture, undefined, // case_vals may be undefined for special prongs if (special.is_inline) operand else .none, special.has_tag_capture, @@ -11491,7 +11497,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r .special, special.body, special.capture, - .special, + .special_capture, undefined, // case_vals may be undefined for special prongs .none, false, @@ -11551,7 +11557,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r .normal, body, info.capture, - .{ .scalar = @intCast(u32, scalar_i) }, + .{ .scalar_capture = @intCast(u32, scalar_i) }, &.{item}, if (info.is_inline) item else .none, info.has_tag_capture, @@ -11885,7 +11891,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r .special, special.body, special.capture, - .special, + .special_capture, &.{item_ref}, item_ref, special.has_tag_capture, @@ -11929,7 +11935,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r .special, special.body, special.capture, - .special, + .special_capture, &.{item_ref}, item_ref, special.has_tag_capture, @@ -11960,7 +11966,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r .special, special.body, special.capture, - .special, + .special_capture, &.{item_ref}, item_ref, special.has_tag_capture, @@ -11988,7 +11994,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r .special, special.body, special.capture, - .special, + .special_capture, &.{Air.Inst.Ref.bool_true}, Air.Inst.Ref.bool_true, special.has_tag_capture, @@ -12014,7 +12020,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r .special, special.body, special.capture, - .special, + .special_capture, &.{Air.Inst.Ref.bool_false}, 
Air.Inst.Ref.bool_false, special.has_tag_capture, @@ -12065,7 +12071,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r .special, special.body, special.capture, - .special, + .special_capture, undefined, // case_vals may be undefined for special prongs .none, false, diff --git a/test/cases/compile_errors/capture_group_on_switch_prong_with_incompatible_payload_types.zig b/test/cases/compile_errors/capture_group_on_switch_prong_with_incompatible_payload_types.zig deleted file mode 100644 index cff9a58bc6..0000000000 --- a/test/cases/compile_errors/capture_group_on_switch_prong_with_incompatible_payload_types.zig +++ /dev/null @@ -1,21 +0,0 @@ -const Union = union(enum) { - A: usize, - B: isize, -}; -comptime { - var u = Union{ .A = 8 }; - switch (u) { - .A, .B => |e| { - _ = e; - unreachable; - }, - } -} - -// error -// backend=stage2 -// target=native -// -// :8:20: error: capture group with incompatible types -// :8:10: note: type 'usize' here -// :8:14: note: type 'isize' here diff --git a/test/cases/compile_errors/switch_on_union_with_no_attached_enum.zig b/test/cases/compile_errors/switch_on_union_with_no_attached_enum.zig index 4d8742d32e..d9bc2abb91 100644 --- a/test/cases/compile_errors/switch_on_union_with_no_attached_enum.zig +++ b/test/cases/compile_errors/switch_on_union_with_no_attached_enum.zig @@ -4,12 +4,12 @@ const Payload = union { C: bool, }; export fn entry() void { - const a = Payload { .A = 1234 }; + const a = Payload{ .A = 1234 }; foo(&a); } fn foo(a: *const Payload) void { switch (a.*) { - Payload.A => {}, + .A => {}, else => unreachable, } }