spirv: switch_br lowering

Implements lowering switch statements in the SPIR-V backend.
This commit is contained in:
Robin Voetter 2022-11-27 16:22:01 +01:00
parent 205d928b24
commit e443b1bed7
No known key found for this signature in database
GPG Key ID: E755662F227CB468
2 changed files with 122 additions and 0 deletions

View File

@ -887,11 +887,13 @@ pub const DeclGen = struct {
.breakpoint => return,
.cond_br => return self.airCondBr(inst),
.constant => unreachable,
.const_ty => unreachable,
.dbg_stmt => return self.airDbgStmt(inst),
.loop => return self.airLoop(inst),
.ret => return self.airRet(inst),
.ret_load => return self.airRetLoad(inst),
.store => return self.airStore(inst),
.switch_br => return self.airSwitchBr(inst),
.unreach => return self.airUnreach(),
.assembly => try self.airAssembly(inst),
@ -1679,6 +1681,121 @@ pub const DeclGen = struct {
});
}
fn airSwitchBr(self: *DeclGen, inst: Air.Inst.Index) !void {
const target = self.getTarget();
const pl_op = self.air.instructions.items(.data)[inst].pl_op;
const cond = try self.resolve(pl_op.operand);
const cond_ty = self.air.typeOf(pl_op.operand);
const switch_br = self.air.extraData(Air.SwitchBr, pl_op.payload);
const cond_words: u32 = switch (cond_ty.zigTypeTag()) {
.Int => blk: {
const bits = cond_ty.intInfo(target).bits;
const backing_bits = self.backingIntBits(bits) orelse {
return self.todo("implement composite int switch", .{});
};
break :blk if (backing_bits <= 32) 1 else 2;
},
.Enum => blk: {
var buffer: Type.Payload.Bits = undefined;
const int_ty = cond_ty.intTagType(&buffer);
const int_info = int_ty.intInfo(target);
const backing_bits = self.backingIntBits(int_info.bits) orelse {
return self.todo("implement composite int switch", .{});
};
break :blk if (backing_bits <= 32) 1 else 2;
},
else => return self.todo("implement switch for type {s}", .{@tagName(cond_ty.zigTypeTag())}), // TODO: Figure out which types apply here, and work around them as we can only do integers.
};
const num_cases = switch_br.data.cases_len;
// Compute the total number of arms that we need.
// Zig switches are grouped by condition, so we need to loop through all of them
const num_conditions = blk: {
var extra_index: usize = switch_br.end;
var case_i: u32 = 0;
var num_conditions: u32 = 0;
while (case_i < num_cases) : (case_i += 1) {
const case = self.air.extraData(Air.SwitchBr.Case, extra_index);
const case_body = self.air.extra[case.end + case.data.items_len ..][0..case.data.body_len];
extra_index = case.end + case.data.items_len + case_body.len;
num_conditions += case.data.items_len;
}
break :blk num_conditions;
};
// First, pre-allocate the labels for the cases.
const first_case_label = self.spv.allocIds(num_cases);
// We always need the default case - if zig has none, we will generate unreachable there.
const default = self.spv.allocId();
// Emit the instruction before generating the blocks.
try self.func.body.emitRaw(self.spv.gpa, .OpSwitch, 2 + (cond_words + 1) * num_conditions);
self.func.body.writeOperand(IdRef, cond);
self.func.body.writeOperand(IdRef, default.toRef());
// Emit each of the cases
{
var extra_index: usize = switch_br.end;
var case_i: u32 = 0;
while (case_i < num_cases) : (case_i += 1) {
// SPIR-V needs a literal here, which' width depends on the case condition.
const case = self.air.extraData(Air.SwitchBr.Case, extra_index);
const items = @ptrCast([]const Air.Inst.Ref, self.air.extra[case.end..][0..case.data.items_len]);
const case_body = self.air.extra[case.end + items.len ..][0..case.data.body_len];
extra_index = case.end + case.data.items_len + case_body.len;
const label = IdRef{ .id = first_case_label.id + case_i };
for (items) |item| {
const value = self.air.value(item) orelse {
return self.todo("switch on runtime value???", .{});
};
const int_val = switch (cond_ty.zigTypeTag()) {
.Int => if (cond_ty.isSignedInt()) @bitCast(u64, value.toSignedInt()) else value.toUnsignedInt(target),
.Enum => blk: {
var int_buffer: Value.Payload.U64 = undefined;
// TODO: figure out of cond_ty is correct (something with enum literals)
break :blk value.enumToInt(cond_ty, &int_buffer).toUnsignedInt(target); // TODO: composite integer constants
},
else => unreachable,
};
const int_lit: spec.LiteralContextDependentNumber = switch (cond_words) {
1 => .{ .uint32 = @intCast(u32, int_val) },
2 => .{ .uint64 = int_val },
else => unreachable,
};
self.func.body.writeOperand(spec.LiteralContextDependentNumber, int_lit);
self.func.body.writeOperand(IdRef, label);
}
}
}
// Now, finally, we can start emitting each of the cases.
var extra_index: usize = switch_br.end;
var case_i: u32 = 0;
while (case_i < num_cases) : (case_i += 1) {
const case = self.air.extraData(Air.SwitchBr.Case, extra_index);
const items = @ptrCast([]const Air.Inst.Ref, self.air.extra[case.end..][0..case.data.items_len]);
const case_body = self.air.extra[case.end + items.len ..][0..case.data.body_len];
extra_index = case.end + case.data.items_len + case_body.len;
const label = IdResult{ .id = first_case_label.id + case_i };
try self.beginSpvBlock(label);
try self.genBody(case_body);
}
const else_body = self.air.extra[extra_index..][0..switch_br.data.else_body_len];
try self.beginSpvBlock(default);
if (else_body.len != 0) {
try self.genBody(else_body);
} else {
try self.func.body.emit(self.spv.gpa, .OpUnreachable, {});
}
}
fn airUnreach(self: *DeclGen) !void {
try self.func.body.emit(self.spv.gpa, .OpUnreachable, {});
}

View File

@ -132,6 +132,11 @@ pub fn allocId(self: *Module) spec.IdResult {
return .{ .id = self.next_result_id };
}
pub fn allocIds(self: *Module, n: u32) spec.IdResult {
defer self.next_result_id += n;
return .{ .id = self.next_result_id };
}
pub fn idBound(self: Module) Word {
return self.next_result_id;
}