diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index bfd2c2da30..b6ccbec0b7 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -887,11 +887,13 @@ pub const DeclGen = struct { .breakpoint => return, .cond_br => return self.airCondBr(inst), .constant => unreachable, + .const_ty => unreachable, .dbg_stmt => return self.airDbgStmt(inst), .loop => return self.airLoop(inst), .ret => return self.airRet(inst), .ret_load => return self.airRetLoad(inst), .store => return self.airStore(inst), + .switch_br => return self.airSwitchBr(inst), .unreach => return self.airUnreach(), .assembly => try self.airAssembly(inst), @@ -1679,6 +1681,121 @@ pub const DeclGen = struct { }); } + fn airSwitchBr(self: *DeclGen, inst: Air.Inst.Index) !void { + const target = self.getTarget(); + const pl_op = self.air.instructions.items(.data)[inst].pl_op; + const cond = try self.resolve(pl_op.operand); + const cond_ty = self.air.typeOf(pl_op.operand); + const switch_br = self.air.extraData(Air.SwitchBr, pl_op.payload); + + const cond_words: u32 = switch (cond_ty.zigTypeTag()) { + .Int => blk: { + const bits = cond_ty.intInfo(target).bits; + const backing_bits = self.backingIntBits(bits) orelse { + return self.todo("implement composite int switch", .{}); + }; + break :blk if (backing_bits <= 32) 1 else 2; + }, + .Enum => blk: { + var buffer: Type.Payload.Bits = undefined; + const int_ty = cond_ty.intTagType(&buffer); + const int_info = int_ty.intInfo(target); + const backing_bits = self.backingIntBits(int_info.bits) orelse { + return self.todo("implement composite int switch", .{}); + }; + break :blk if (backing_bits <= 32) 1 else 2; + }, + else => return self.todo("implement switch for type {s}", .{@tagName(cond_ty.zigTypeTag())}), // TODO: Figure out which types apply here, and work around them as we can only do integers. + }; + + const num_cases = switch_br.data.cases_len; + + // Compute the total number of arms that we need. + // Zig switches are grouped by condition, so we need to loop through all of them + const num_conditions = blk: { + var extra_index: usize = switch_br.end; + var case_i: u32 = 0; + var num_conditions: u32 = 0; + while (case_i < num_cases) : (case_i += 1) { + const case = self.air.extraData(Air.SwitchBr.Case, extra_index); + const case_body = self.air.extra[case.end + case.data.items_len ..][0..case.data.body_len]; + extra_index = case.end + case.data.items_len + case_body.len; + num_conditions += case.data.items_len; + } + break :blk num_conditions; + }; + + // First, pre-allocate the labels for the cases. + const first_case_label = self.spv.allocIds(num_cases); + // We always need the default case - if zig has none, we will generate unreachable there. + const default = self.spv.allocId(); + + // Emit the instruction before generating the blocks. + try self.func.body.emitRaw(self.spv.gpa, .OpSwitch, 2 + (cond_words + 1) * num_conditions); + self.func.body.writeOperand(IdRef, cond); + self.func.body.writeOperand(IdRef, default.toRef()); + + // Emit each of the cases + { + var extra_index: usize = switch_br.end; + var case_i: u32 = 0; + while (case_i < num_cases) : (case_i += 1) { + // SPIR-V needs a literal here, which' width depends on the case condition. + const case = self.air.extraData(Air.SwitchBr.Case, extra_index); + const items = @ptrCast([]const Air.Inst.Ref, self.air.extra[case.end..][0..case.data.items_len]); + const case_body = self.air.extra[case.end + items.len ..][0..case.data.body_len]; + extra_index = case.end + case.data.items_len + case_body.len; + + const label = IdRef{ .id = first_case_label.id + case_i }; + + for (items) |item| { + const value = self.air.value(item) orelse { + return self.todo("switch on runtime value???", .{}); + }; + const int_val = switch (cond_ty.zigTypeTag()) { + .Int => if (cond_ty.isSignedInt()) @bitCast(u64, value.toSignedInt()) else value.toUnsignedInt(target), + .Enum => blk: { + var int_buffer: Value.Payload.U64 = undefined; + // TODO: figure out of cond_ty is correct (something with enum literals) + break :blk value.enumToInt(cond_ty, &int_buffer).toUnsignedInt(target); // TODO: composite integer constants + }, + else => unreachable, + }; + const int_lit: spec.LiteralContextDependentNumber = switch (cond_words) { + 1 => .{ .uint32 = @intCast(u32, int_val) }, + 2 => .{ .uint64 = int_val }, + else => unreachable, + }; + self.func.body.writeOperand(spec.LiteralContextDependentNumber, int_lit); + self.func.body.writeOperand(IdRef, label); + } + } + } + + // Now, finally, we can start emitting each of the cases. + var extra_index: usize = switch_br.end; + var case_i: u32 = 0; + while (case_i < num_cases) : (case_i += 1) { + const case = self.air.extraData(Air.SwitchBr.Case, extra_index); + const items = @ptrCast([]const Air.Inst.Ref, self.air.extra[case.end..][0..case.data.items_len]); + const case_body = self.air.extra[case.end + items.len ..][0..case.data.body_len]; + extra_index = case.end + case.data.items_len + case_body.len; + + const label = IdResult{ .id = first_case_label.id + case_i }; + + try self.beginSpvBlock(label); + try self.genBody(case_body); + } + + const else_body = self.air.extra[extra_index..][0..switch_br.data.else_body_len]; + try self.beginSpvBlock(default); + if (else_body.len != 0) { + try self.genBody(else_body); + } else { + try self.func.body.emit(self.spv.gpa, .OpUnreachable, {}); + } + } + fn airUnreach(self: *DeclGen) !void { try self.func.body.emit(self.spv.gpa, .OpUnreachable, {}); } diff --git a/src/codegen/spirv/Module.zig b/src/codegen/spirv/Module.zig index f6c4cd735e..ab9d0588ca 100644 --- a/src/codegen/spirv/Module.zig +++ b/src/codegen/spirv/Module.zig @@ -132,6 +132,11 @@ pub fn allocId(self: *Module) spec.IdResult { return .{ .id = self.next_result_id }; } +pub fn allocIds(self: *Module, n: u32) spec.IdResult { + defer self.next_result_id += n; + return .{ .id = self.next_result_id }; +} + pub fn idBound(self: Module) Word { return self.next_result_id; }