diff --git a/lib/std/zig/parse.zig b/lib/std/zig/parse.zig index 04b45ad20d..db56cef21e 100644 --- a/lib/std/zig/parse.zig +++ b/lib/std/zig/parse.zig @@ -3100,7 +3100,7 @@ const Parser = struct { return identifier; } - /// SwitchProng <- KEYWORD_inline? SwitchCase EQUALRARROW PtrPayload? AssignExpr + /// SwitchProng <- KEYWORD_inline? SwitchCase EQUALRARROW PtrIndexPayload? AssignExpr /// SwitchCase /// <- SwitchItem (COMMA SwitchItem)* COMMA? /// / KEYWORD_else @@ -3123,7 +3123,7 @@ const Parser = struct { } } const arrow_token = try p.expectToken(.equal_angle_bracket_right); - _ = try p.parsePtrPayload(); + _ = try p.parsePtrIndexPayload(); const items = p.scratch.items[scratch_top..]; switch (items.len) { diff --git a/lib/std/zig/parser_test.zig b/lib/std/zig/parser_test.zig index 0e1817ffab..4e155df6d8 100644 --- a/lib/std/zig/parser_test.zig +++ b/lib/std/zig/parser_test.zig @@ -3276,6 +3276,8 @@ test "zig fmt: switch" { \\ switch (u) { \\ Union.Int => |int| {}, \\ Union.Float => |*float| unreachable, + \\ 1 => |a, b| unreachable, + \\ 2 => |*a, b| unreachable, \\ } \\} \\ diff --git a/lib/std/zig/render.zig b/lib/std/zig/render.zig index db4a092b2a..ab009f8390 100644 --- a/lib/std/zig/render.zig +++ b/lib/std/zig/render.zig @@ -1541,13 +1541,17 @@ fn renderSwitchCase( if (switch_case.payload_token) |payload_token| { try renderToken(ais, tree, payload_token - 1, .none); // pipe + const ident = payload_token + @boolToInt(token_tags[payload_token] == .asterisk); if (token_tags[payload_token] == .asterisk) { try renderToken(ais, tree, payload_token, .none); // asterisk - try renderToken(ais, tree, payload_token + 1, .none); // identifier - try renderToken(ais, tree, payload_token + 2, pre_target_space); // pipe + } + try renderToken(ais, tree, ident, .none); // identifier + if (token_tags[ident + 1] == .comma) { + try renderToken(ais, tree, ident + 1, .space); // , + try renderToken(ais, tree, ident + 2, .none); // identifier + try renderToken(ais, tree, ident + 3, pre_target_space); // pipe } else { - try renderToken(ais, tree, payload_token, .none); // identifier - try renderToken(ais, tree, payload_token + 1, pre_target_space); // pipe + try renderToken(ais, tree, ident + 1, pre_target_space); // pipe } } diff --git a/src/AstGen.zig b/src/AstGen.zig index 06557f900b..1920ccacfb 100644 --- a/src/AstGen.zig +++ b/src/AstGen.zig @@ -2373,6 +2373,7 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As .switch_capture_ref, .switch_capture_multi, .switch_capture_multi_ref, + .switch_capture_tag, .struct_init_empty, .struct_init, .struct_init_ref, @@ -6378,8 +6379,12 @@ fn switchExpr( var dbg_var_name: ?u32 = null; var dbg_var_inst: Zir.Inst.Ref = undefined; + var dbg_var_tag_name: ?u32 = null; + var dbg_var_tag_inst: Zir.Inst.Ref = undefined; var capture_inst: Zir.Inst.Index = 0; + var tag_inst: Zir.Inst.Index = 0; var capture_val_scope: Scope.LocalVal = undefined; + var tag_scope: Scope.LocalVal = undefined; const sub_scope = blk: { const payload_token = case.payload_token orelse break :blk &case_scope.base; const ident = if (token_tags[payload_token] == .asterisk) @@ -6387,59 +6392,96 @@ fn switchExpr( else payload_token; const is_ptr = ident != payload_token; - if (mem.eql(u8, tree.tokenSlice(ident), "_")) { + const ident_slice = tree.tokenSlice(ident); + var payload_sub_scope: *Scope = undefined; + if (mem.eql(u8, ident_slice, "_")) { if (is_ptr) { return astgen.failTok(payload_token, "pointer modifier invalid on discard", .{}); } - break :blk &case_scope.base; - } - if (case_node == special_node) { - const capture_tag: Zir.Inst.Tag = if (is_ptr) - .switch_capture_ref - else - .switch_capture; - capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len); - try astgen.instructions.append(gpa, .{ - .tag = capture_tag, - .data = .{ - .switch_capture = .{ - .switch_inst = switch_block, - // Max int communicates that this is the else/underscore prong. - .prong_index = std.math.maxInt(u32), - }, - }, - }); + payload_sub_scope = &case_scope.base; } else { - const is_multi_case_bits: u2 = @boolToInt(is_multi_case); - const is_ptr_bits: u2 = @boolToInt(is_ptr); - const capture_tag: Zir.Inst.Tag = switch ((is_multi_case_bits << 1) | is_ptr_bits) { - 0b00 => .switch_capture, - 0b01 => .switch_capture_ref, - 0b10 => .switch_capture_multi, - 0b11 => .switch_capture_multi_ref, + if (case_node == special_node) { + const capture_tag: Zir.Inst.Tag = if (is_ptr) + .switch_capture_ref + else + .switch_capture; + capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len); + try astgen.instructions.append(gpa, .{ + .tag = capture_tag, + .data = .{ + .switch_capture = .{ + .switch_inst = switch_block, + // Max int communicates that this is the else/underscore prong. + .prong_index = std.math.maxInt(u32), + }, + }, + }); + } else { + const is_multi_case_bits: u2 = @boolToInt(is_multi_case); + const is_ptr_bits: u2 = @boolToInt(is_ptr); + const capture_tag: Zir.Inst.Tag = switch ((is_multi_case_bits << 1) | is_ptr_bits) { + 0b00 => .switch_capture, + 0b01 => .switch_capture_ref, + 0b10 => .switch_capture_multi, + 0b11 => .switch_capture_multi_ref, + }; + const capture_index = if (is_multi_case) multi_case_index else scalar_case_index; + capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len); + try astgen.instructions.append(gpa, .{ + .tag = capture_tag, + .data = .{ .switch_capture = .{ + .switch_inst = switch_block, + .prong_index = capture_index, + } }, + }); + } + const capture_name = try astgen.identAsString(ident); + try astgen.detectLocalShadowing(&case_scope.base, capture_name, ident, ident_slice); + capture_val_scope = .{ + .parent = &case_scope.base, + .gen_zir = &case_scope, + .name = capture_name, + .inst = indexToRef(capture_inst), + .token_src = payload_token, + .id_cat = .@"capture", }; - const capture_index = if (is_multi_case) multi_case_index else scalar_case_index; - capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len); - try astgen.instructions.append(gpa, .{ - .tag = capture_tag, - .data = .{ .switch_capture = .{ - .switch_inst = switch_block, - .prong_index = capture_index, - } }, - }); + dbg_var_name = capture_name; + dbg_var_inst = indexToRef(capture_inst); + payload_sub_scope = &capture_val_scope.base; } - const capture_name = try astgen.identAsString(ident); - capture_val_scope = .{ - .parent = &case_scope.base, + + const tag_token = if (token_tags[ident + 1] == .comma) + ident + 2 + else + break :blk payload_sub_scope; + const tag_slice = tree.tokenSlice(tag_token); + if (mem.eql(u8, tag_slice, "_")) { + return astgen.failTok(tag_token, "discard of tag capture; omit it instead", .{}); + } else if (case.inline_token == null) { + return astgen.failTok(tag_token, "tag capture on non-inline prong", .{}); + } + const tag_name = try astgen.identAsString(tag_token); + try astgen.detectLocalShadowing(payload_sub_scope, tag_name, tag_token, tag_slice); + tag_inst = @intCast(Zir.Inst.Index, astgen.instructions.len); + try astgen.instructions.append(gpa, .{ + .tag = .switch_capture_tag, + .data = .{ .un_tok = .{ + .operand = cond, + .src_tok = case_scope.tokenIndexToRelative(tag_token), + } }, + }); + + tag_scope = .{ + .parent = payload_sub_scope, .gen_zir = &case_scope, - .name = capture_name, - .inst = indexToRef(capture_inst), - .token_src = payload_token, - .id_cat = .@"capture", + .name = tag_name, + .inst = indexToRef(tag_inst), + .token_src = tag_token, + .id_cat = .@"switch tag capture", }; - dbg_var_name = capture_name; - dbg_var_inst = indexToRef(capture_inst); - break :blk &capture_val_scope.base; + dbg_var_tag_name = tag_name; + dbg_var_tag_inst = indexToRef(tag_inst); + break :blk &tag_scope.base; }; const header_index = @intCast(u32, payloads.items.len); @@ -6494,10 +6536,14 @@ fn switchExpr( defer case_scope.unstack(); if (capture_inst != 0) try case_scope.instructions.append(gpa, capture_inst); + if (tag_inst != 0) try case_scope.instructions.append(gpa, tag_inst); try case_scope.addDbgBlockBegin(); if (dbg_var_name) |some| { try case_scope.addDbgVar(.dbg_var_val, some, dbg_var_inst); } + if (dbg_var_tag_name) |some| { + try case_scope.addDbgVar(.dbg_var_val, some, dbg_var_tag_inst); + } const case_result = try expr(&case_scope, sub_scope, block_scope.break_result_loc, case.ast.target_expr); try checkUsed(parent_gz, &case_scope.base, sub_scope); try case_scope.addDbgBlockEnd(); @@ -10073,6 +10119,7 @@ const Scope = struct { @"local constant", @"local variable", @"loop index capture", + @"switch tag capture", @"capture", }; diff --git a/src/Sema.zig b/src/Sema.zig index d27c0095ad..211135c744 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -799,6 +799,7 @@ fn analyzeBodyInner( .switch_capture_ref => try sema.zirSwitchCapture(block, inst, false, true), .switch_capture_multi => try sema.zirSwitchCapture(block, inst, true, false), .switch_capture_multi_ref => try sema.zirSwitchCapture(block, inst, true, true), + .switch_capture_tag => try sema.zirSwitchCaptureTag(block, inst), .type_info => try sema.zirTypeInfo(block, inst), .size_of => try sema.zirSizeOf(block, inst), .bit_size_of => try sema.zirBitSizeOf(block, inst), @@ -9164,6 +9165,33 @@ fn zirSwitchCapture( } } +fn zirSwitchCaptureTag(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { + const zir_datas = sema.code.instructions.items(.data); + const inst_data = zir_datas[inst].un_tok; + const src = inst_data.src(); + + const switch_tag = sema.code.instructions.items(.tag)[Zir.refToIndex(inst_data.operand).?]; + const is_ref = switch_tag == .switch_cond_ref; + const cond_data = zir_datas[Zir.refToIndex(inst_data.operand).?].un_node; + const operand_ptr = try sema.resolveInst(cond_data.operand); + const operand_ptr_ty = sema.typeOf(operand_ptr); + const operand_ty = if (is_ref) operand_ptr_ty.childType() else operand_ptr_ty; + + if (operand_ty.zigTypeTag() != .Union) { + const msg = msg: { + const msg = try sema.errMsg(block, src, "cannot capture tag of non-union type '{}'", .{ + operand_ty.fmt(sema.mod), + }); + errdefer msg.destroy(sema.gpa); + try sema.addDeclaredHereNote(msg, operand_ty); + break :msg msg; + }; + return sema.failWithOwnedErrorMsg(msg); + } + + return block.inline_case_capture; +} + fn zirSwitchCond( sema: *Sema, block: *Block, diff --git a/src/Zir.zig b/src/Zir.zig index 3ce0b3c81a..add8bad801 100644 --- a/src/Zir.zig +++ b/src/Zir.zig @@ -683,6 +683,9 @@ pub const Inst = struct { /// Result is a pointer to the value. /// Uses the `switch_capture` field. switch_capture_multi_ref, + /// Produces the capture value for an inline switch prong tag capture. + /// Uses the `un_tok` field. + switch_capture_tag, /// Given a /// *A returns *A /// *E!A returns *A @@ -1128,6 +1131,7 @@ pub const Inst = struct { .switch_capture_ref, .switch_capture_multi, .switch_capture_multi_ref, + .switch_capture_tag, .switch_block, .switch_cond, .switch_cond_ref, @@ -1422,6 +1426,7 @@ pub const Inst = struct { .switch_capture_ref, .switch_capture_multi, .switch_capture_multi_ref, + .switch_capture_tag, .switch_block, .switch_cond, .switch_cond_ref, @@ -1681,6 +1686,7 @@ pub const Inst = struct { .switch_capture_ref = .switch_capture, .switch_capture_multi = .switch_capture, .switch_capture_multi_ref = .switch_capture, + .switch_capture_tag = .un_tok, .array_base_ptr = .un_node, .field_base_ptr = .un_node, .validate_array_init_ty = .pl_node, diff --git a/src/arch/x86_64/Emit.zig b/src/arch/x86_64/Emit.zig index 0cdc7a4c5f..c1c00d8303 100644 --- a/src/arch/x86_64/Emit.zig +++ b/src/arch/x86_64/Emit.zig @@ -2159,7 +2159,7 @@ const RegisterOrMemory = union(enum) { /// Returns size in bits. fn size(reg_or_mem: RegisterOrMemory) u64 { return switch (reg_or_mem) { - .register => |reg| reg.size(), + .register => |register| register.size(), .memory => |memory| memory.size(), }; } diff --git a/src/print_zir.zig b/src/print_zir.zig index fcd447f707..d383664c16 100644 --- a/src/print_zir.zig +++ b/src/print_zir.zig @@ -237,6 +237,7 @@ const Writer = struct { .ret_tok, .ensure_err_payload_void, .closure_capture, + .switch_capture_tag, => try self.writeUnTok(stream, inst), .bool_br_and, diff --git a/src/stage1/parser.cpp b/src/stage1/parser.cpp index b7cb6cb297..ec02e6fa8b 100644 --- a/src/stage1/parser.cpp +++ b/src/stage1/parser.cpp @@ -2306,17 +2306,17 @@ static Optional ast_parse_ptr_index_payload(ParseContext *pc) { return Optional::some(res); } -// SwitchProng <- KEYWORD_inline? SwitchCase EQUALRARROW PtrPayload? AssignExpr +// SwitchProng <- KEYWORD_inline? SwitchCase EQUALRARROW PtrIndexPayload? AssignExpr static AstNode *ast_parse_switch_prong(ParseContext *pc) { AstNode *res = ast_parse_switch_case(pc); if (res == nullptr) return nullptr; expect_token(pc, TokenIdFatArrow); - Optional opt_payload = ast_parse_ptr_payload(pc); + Optional opt_payload = ast_parse_ptr_index_payload(pc); AstNode *expr = ast_expect(pc, ast_parse_assign_expr); - PtrPayload payload; + PtrIndexPayload payload; assert(res->type == NodeTypeSwitchProng); res->data.switch_prong.expr = expr; if (opt_payload.unwrap(&payload)) { diff --git a/test/behavior/inline_switch.zig b/test/behavior/inline_switch.zig index d7863f8444..11157f20fd 100644 --- a/test/behavior/inline_switch.zig +++ b/test/behavior/inline_switch.zig @@ -47,11 +47,21 @@ test "inline switch unions" { var x: U = .a; switch (x) { - inline .a, .b => |aorb| { - try expect(@TypeOf(aorb) == void or @TypeOf(aorb) == u2); + inline .a, .b => |aorb, tag| { + if (tag == .a) { + try expect(@TypeOf(aorb) == void); + } else { + try expect(tag == .b); + try expect(@TypeOf(aorb) == u2); + } }, - inline .c, .d => |cord| { - try expect(@TypeOf(cord) == u3 or @TypeOf(cord) == u4); + inline .c, .d => |cord, tag| { + if (tag == .c) { + try expect(@TypeOf(cord) == u3); + } else { + try expect(tag == .d); + try expect(@TypeOf(cord) == u4); + } }, } } diff --git a/test/cases/compile_errors/invalid_tag_capture.zig b/test/cases/compile_errors/invalid_tag_capture.zig new file mode 100644 index 0000000000..2cb9135792 --- /dev/null +++ b/test/cases/compile_errors/invalid_tag_capture.zig @@ -0,0 +1,15 @@ +const E = enum { a, b, c, d }; +pub export fn entry() void { + var x: E = .a; + switch (x) { + inline .a, .b => |aorb, d| @compileLog(aorb, d), + inline .c, .d => |*cord| @compileLog(cord), + } +} + +// error +// backend=stage2 +// target=native +// +// :5:33: error: cannot capture tag of non-union type 'tmp.E' +// :1:11: note: enum declared here diff --git a/test/cases/compile_errors/tag_capture_on_non_inline_prong.zig b/test/cases/compile_errors/tag_capture_on_non_inline_prong.zig new file mode 100644 index 0000000000..b525aa4db3 --- /dev/null +++ b/test/cases/compile_errors/tag_capture_on_non_inline_prong.zig @@ -0,0 +1,14 @@ +const E = enum { a, b, c, d }; +pub export fn entry() void { + var x: E = .a; + switch (x) { + .a, .b => |aorb, d| @compileLog(aorb, d), + inline .c, .d => |*cord| @compileLog(cord), + } +} + +// error +// backend=stage2 +// target=native +// +// :5:26: error: tag capture on non-inline prong