diff --git a/doc/langref.html.in b/doc/langref.html.in
index 818e0b5fe4..28d8745a81 100644
--- a/doc/langref.html.in
+++ b/doc/langref.html.in
@@ -3016,6 +3016,7 @@ test "switch on tagged union" {
A: u32,
C: Point,
D,
+ E: u32,
};
var a = Item{ .C = Point{ .x = 1, .y = 2 } };
@@ -3023,8 +3024,9 @@ test "switch on tagged union" {
// Switching on more complex enums is allowed.
const b = switch (a) {
// A capture group is allowed on a match, and will return the enum
- // value matched.
- Item.A => |item| item,
+ // value matched. If the payload types of both cases are the same
+ // they can be put into the same switch prong.
+ Item.A, Item.E => |item| item,
// A reference to the matched value can be obtained using `*` syntax.
Item.C => |*item| blk: {
diff --git a/src/ir.cpp b/src/ir.cpp
index 6b19ce2909..23035fa66d 100644
--- a/src/ir.cpp
+++ b/src/ir.cpp
@@ -19229,24 +19229,53 @@ static IrInstruction *ir_analyze_instruction_switch_var(IrAnalyze *ira, IrInstru
ZigType *enum_type = target_type->data.unionation.tag_type;
assert(enum_type != nullptr);
assert(enum_type->id == ZigTypeIdEnum);
+ assert(instruction->prongs_len > 0);
- if (instruction->prongs_len != 1) {
- return target_value_ptr;
+ IrInstruction *first_prong_value = instruction->prongs_ptr[0]->child;
+ if (type_is_invalid(first_prong_value->value.type))
+ return ira->codegen->invalid_instruction;
+
+ IrInstruction *first_casted_prong_value = ir_implicit_cast(ira, first_prong_value, enum_type);
+ if (type_is_invalid(first_casted_prong_value->value.type))
+ return ira->codegen->invalid_instruction;
+
+ ConstExprValue *first_prong_val = ir_resolve_const(ira, first_casted_prong_value, UndefBad);
+ if (first_prong_val == nullptr)
+ return ira->codegen->invalid_instruction;
+
+ TypeUnionField *first_field = find_union_field_by_tag(target_type, &first_prong_val->data.x_enum_tag);
+
+ ErrorMsg *invalid_payload_msg = nullptr;
+ for (size_t prong_i = 1; prong_i < instruction->prongs_len; prong_i += 1) {
+ IrInstruction *this_prong_inst = instruction->prongs_ptr[prong_i]->child;
+ if (type_is_invalid(this_prong_inst->value.type))
+ return ira->codegen->invalid_instruction;
+
+ IrInstruction *this_casted_prong_value = ir_implicit_cast(ira, this_prong_inst, enum_type);
+ if (type_is_invalid(this_casted_prong_value->value.type))
+ return ira->codegen->invalid_instruction;
+
+ ConstExprValue *this_prong = ir_resolve_const(ira, this_casted_prong_value, UndefBad);
+ if (this_prong == nullptr)
+ return ira->codegen->invalid_instruction;
+
+ TypeUnionField *payload_field = find_union_field_by_tag(target_type, &this_prong->data.x_enum_tag);
+ ZigType *payload_type = payload_field->type_entry;
+ if (first_field->type_entry != payload_type) {
+ if (invalid_payload_msg == nullptr) {
+ invalid_payload_msg = ir_add_error(ira, &instruction->base,
+ buf_sprintf("capture group with incompatible types"));
+ add_error_note(ira->codegen, invalid_payload_msg, first_prong_value->source_node,
+ buf_sprintf("type '%s' here", buf_ptr(&first_field->type_entry->name)));
+ }
+ add_error_note(ira->codegen, invalid_payload_msg, this_prong_inst->source_node,
+ buf_sprintf("type '%s' here", buf_ptr(&payload_field->type_entry->name)));
+ }
}
- IrInstruction *prong_value = instruction->prongs_ptr[0]->child;
- if (type_is_invalid(prong_value->value.type))
+ if (invalid_payload_msg != nullptr) {
return ira->codegen->invalid_instruction;
-
- IrInstruction *casted_prong_value = ir_implicit_cast(ira, prong_value, enum_type);
- if (type_is_invalid(casted_prong_value->value.type))
- return ira->codegen->invalid_instruction;
-
- ConstExprValue *prong_val = ir_resolve_const(ira, casted_prong_value, UndefBad);
- if (!prong_val)
- return ira->codegen->invalid_instruction;
-
- TypeUnionField *field = find_union_field_by_tag(target_type, &prong_val->data.x_enum_tag);
+ }
if (instr_is_comptime(target_value_ptr)) {
ConstExprValue *target_val_ptr = ir_resolve_const(ira, target_value_ptr, UndefBad);
@@ -19258,7 +19287,7 @@ static IrInstruction *ir_analyze_instruction_switch_var(IrAnalyze *ira, IrInstru
return ira->codegen->invalid_instruction;
IrInstruction *result = ir_const(ira, &instruction->base,
- get_pointer_to_type(ira->codegen, field->type_entry,
+ get_pointer_to_type(ira->codegen, first_field->type_entry,
target_val_ptr->type->data.pointer.is_const));
ConstExprValue *out_val = &result->value;
out_val->data.x_ptr.special = ConstPtrSpecialRef;
@@ -19268,8 +19297,8 @@ static IrInstruction *ir_analyze_instruction_switch_var(IrAnalyze *ira, IrInstru
}
IrInstruction *result = ir_build_union_field_ptr(&ira->new_irb,
- instruction->base.scope, instruction->base.source_node, target_value_ptr, field, false, false);
- result->value.type = get_pointer_to_type(ira->codegen, field->type_entry,
+ instruction->base.scope, instruction->base.source_node, target_value_ptr, first_field, false, false);
+ result->value.type = get_pointer_to_type(ira->codegen, first_field->type_entry,
target_value_ptr->value.type->data.pointer.is_const);
return result;
} else if (target_type->id == ZigTypeIdErrorSet) {
@@ -22977,11 +23006,11 @@ static IrInstruction *ir_analyze_instruction_mul_add(IrAnalyze *ira, IrInstructi
IrInstruction *type_value = instruction->type_value->child;
if (type_is_invalid(type_value->value.type))
return ira->codegen->invalid_instruction;
-
+
ZigType *expr_type = ir_resolve_type(ira, type_value);
if (type_is_invalid(expr_type))
return ira->codegen->invalid_instruction;
-
+
// Only allow float types, and vectors of floats.
ZigType *float_type = (expr_type->id == ZigTypeIdVector) ? expr_type->data.vector.elem_type : expr_type;
if (float_type->id != ZigTypeIdFloat) {
@@ -25082,7 +25111,7 @@ static IrInstruction *ir_analyze_instruction_float_op(IrAnalyze *ira, IrInstruct
IrInstruction *type = instruction->type->child;
if (type_is_invalid(type->value.type))
return ira->codegen->invalid_instruction;
-
+
ZigType *expr_type = ir_resolve_type(ira, type);
if (type_is_invalid(expr_type))
return ira->codegen->invalid_instruction;
diff --git a/test/compile_errors.zig b/test/compile_errors.zig
index c411ba46f6..df4e38583c 100644
--- a/test/compile_errors.zig
+++ b/test/compile_errors.zig
@@ -2,6 +2,24 @@ const tests = @import("tests.zig");
const builtin = @import("builtin");
pub fn addCases(cases: *tests.CompileErrorContext) void {
+ cases.add(
+ "capture group on switch prong with incompatible payload types",
+ \\const Union = union(enum) {
+ \\ A: usize,
+ \\ B: isize,
+ \\};
+ \\comptime {
+ \\ var u = Union{ .A = 8 };
+ \\ switch (u) {
+ \\ .A, .B => |e| unreachable,
+ \\ }
+ \\}
+ ,
+ "tmp.zig:8:20: error: capture group with incompatible types",
+ "tmp.zig:8:9: note: type 'usize' here",
+ "tmp.zig:8:13: note: type 'isize' here",
+ );
+
cases.add(
"wrong type to @hasField",
\\export fn entry() bool {
diff --git a/test/stage1/behavior/switch.zig b/test/stage1/behavior/switch.zig
index 12e026d0ba..806f51b28e 100644
--- a/test/stage1/behavior/switch.zig
+++ b/test/stage1/behavior/switch.zig
@@ -391,3 +391,37 @@ test "switch with null and T peer types and inferred result location type" {
S.doTheTest(1);
comptime S.doTheTest(1);
}
+
+test "switch prongs with cases with identical payload types" {
+ const Union = union(enum) {
+ A: usize,
+ B: isize,
+ C: usize,
+ };
+ const S = struct {
+ fn doTheTest() void {
+ doTheSwitch1(Union{ .A = 8 });
+ doTheSwitch2(Union{ .B = -8 });
+ }
+ fn doTheSwitch1(u: Union) void {
+ switch (u) {
+ .A, .C => |e| {
+ expect(@typeOf(e) == usize);
+ expect(e == 8);
+ },
+ .B => |e| @panic("fail"),
+ }
+ }
+ fn doTheSwitch2(u: Union) void {
+ switch (u) {
+ .A, .C => |e| @panic("fail"),
+ .B => |e| {
+ expect(@typeOf(e) == isize);
+ expect(e == -8);
+ },
+ }
+ }
+ };
+ S.doTheTest();
+ comptime S.doTheTest();
+}