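//! zig-wgpu/src/GpuCompute.zig
//!
//! Compute-pipeline wrapper: `init` compiles a WGSL shader into a pipeline, and
//! `run` binds a tuple of GpuBuffers and dispatches it on the device queue.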

const std = @import("std");
const c = @import("utils.zig").c;
const sv = @import("utils.zig").sv;
const GpuAllocator = @import("GpuAllocator.zig");
const GpuBuffer = @import("GpuBuffer.zig");
const GpuDevice = @import("GpuDevice.zig");
/// Per-buffer metadata for one argument of `run`.
///
/// Entry `i` of `ComputeDef.bindings` describes the `i`-th buffer in the tuple
/// passed to `run`; that buffer is bound at `@binding(i)` of group 0.
pub const Binding = struct {
    /// Byte size of a single element stored in this buffer, e.g. `@sizeOf(f32)`.
    /// Keep it at 0 to exclude this buffer from element-count inference and
    /// from the minimum-size check performed by `run`.
    element_size: u32 = 0,
};
/// Static description of how a compute pipeline is driven: which buffers it
/// binds and how elements are mapped onto workgroups.
pub const ComputeDef = struct {
    /// One entry per buffer argument of `run`, in binding order. The slice is
    /// kept by reference, so the memory behind it must outlive the compute object.
    bindings: []const Binding,
    /// Invocations per workgroup along x. Only used to derive the dispatch size,
    /// `ceil(elements / workgroup_size)`; presumably it has to equal the x
    /// dimension of `@workgroup_size` on the WGSL `main` (not checked here).
    /// Must be non-zero.
    workgroup_size: u32 = 256,
    /// Upper bound on the number of workgroups dispatched along x. A larger
    /// computed dispatch is clamped to this value, leaving the remaining
    /// elements uncovered unless the shader loops over them. The default
    /// appears to mirror WebGPU's default `maxComputeWorkgroupsPerDimension`.
    max_workgroups: u32 = 65535,
    /// When true, `run` binds one extra uniform buffer holding the element count
    /// as a `u32`, at the binding index right after the last buffer argument
    /// (`@binding(bindings.len)`).
    append_info_buffer: bool = true,
};
/// Compute pipeline created in `init` and released by `deinit`.
pip: c.WGPUComputePipeline,
/// Allocator/device wrapper the pipeline was created with; `deinit` frees through it.
gloc: GpuAllocator,
/// Dispatch description passed to `init`; `run` reads its bindings and sizes.
def: ComputeDef,
/// Compile `wgsl` and build a compute pipeline whose entry point is the WGSL
/// function named `main`.
///
/// `def` is stored by value, but `def.bindings` is a slice: the memory it points
/// to must stay valid for as long as the returned object is used. Call `deinit`
/// to release the pipeline.
///
/// Returns `error.Shader` if the shader module cannot be created, and propagates
/// any error from `gloc.allocComputePipeline`.
pub fn init(gloc: GpuAllocator, wgsl: []const u8, def: ComputeDef) !@This() {
    var source = c.WGPUShaderSourceWGSL{
        .code = sv(wgsl),
        .chain = .{ .sType = c.WGPUSType_ShaderSourceWGSL },
    };
    const module = c.wgpuDeviceCreateShaderModule(
        gloc.device.device,
        &.{ .nextInChain = @ptrCast(&source) },
    ) orelse return error.Shader;
    // Only the pipeline is kept; the module handle is released on every exit path.
    defer c.wgpuShaderModuleRelease(module);

    return .{
        .pip = try gloc.allocComputePipeline(.{
            .compute = .{ .module = module, .entryPoint = sv("main") },
        }),
        .gloc = gloc,
        .def = def,
    };
}
/// Release the compute pipeline created by `init`.
///
/// Only the pipeline is freed; `def.bindings` remains owned by the caller.
pub fn deinit(self: @This()) void {
    const allocator = self.gloc;
    allocator.freeComputePipeline(self.pip);
}
/// Bind a tuple of `GpuBuffer`s to group 0 (tuple index `i` -> `@binding(i)`)
/// and dispatch the pipeline once over the inferred element count.
///
/// Example: `try proc.run(gloc, .{ buf_a, buf_b, buf_out });`
///
/// Element count: scanning the bindings in order, it is `buf.size / element_size`
/// of the first binding with a non-zero `element_size` whose buffer holds at
/// least one element. Every binding with a non-zero `element_size` must then
/// hold at least that many elements. If no binding yields a non-zero count
/// (including when no binding declares an `element_size`), nothing is
/// dispatched and no GPU objects are created.
///
/// When `def.append_info_buffer` is set, a 4-byte uniform buffer holding the
/// element count as a `u32` is bound at `@binding(args.len)`.
///
/// The command buffer is submitted with `wgpuQueueSubmit`; this function does
/// not itself wait for the GPU to finish.
///
/// Errors:
/// - `error.InvalidArgumentCount`: tuple length differs from `def.bindings.len`.
/// - `error.TooManyElements`: the inferred element count does not fit in a `u32`.
/// - `error.BufferTooSmall`: a sized binding holds fewer elements than required.
/// - any error from `GpuBuffer.init` or from recording/submitting the pass.
pub fn run(
    self: @This(),
    gloc: GpuAllocator,
    args: anytype,
) !void {
    const type_info = @typeInfo(@TypeOf(args));
    if (type_info != .@"struct" or !type_info.@"struct".is_tuple)
        @compileError("Expected a tuple of GpuBuffers for args. E.g. .{ buf_a, buf_b }");
    const fields = type_info.@"struct".fields;
    // Check argument types before touching any field, so callers get this
    // message instead of an unrelated "no field named 'size'" error.
    inline for (fields) |field| {
        if (field.type != GpuBuffer)
            @compileError("All arguments in the tuple must be of type GpuBuffer");
    }
    if (fields.len != self.def.bindings.len)
        return error.InvalidArgumentCount;

    // Infer the element count from the first sized binding with a non-empty buffer.
    var elements_count: u32 = 0;
    inline for (fields, 0..) |field, i| {
        const el_size = self.def.bindings[i].element_size;
        if (elements_count == 0 and el_size > 0) {
            const count = @field(args, field.name).size / el_size;
            // Checked narrowing: a plain @intCast is UB/panic past 2^32 elements.
            elements_count = std.math.cast(u32, count) orelse return error.TooManyElements;
        }
    }
    // Nothing to dispatch: skip allocating and writing the info buffer entirely.
    if (elements_count == 0) return;

    // One slot per tuple argument plus one for the optional info buffer. The
    // size is comptime-known, so no argument count can overflow it.
    var entries_buf: [fields.len + 1]c.WGPUBindGroupEntry = undefined;
    inline for (fields, 0..) |field, i| {
        const buf: GpuBuffer = @field(args, field.name);
        const el_size = self.def.bindings[i].element_size;
        if (el_size > 0 and buf.size < @as(u64, elements_count) * el_size)
            return error.BufferTooSmall;
        entries_buf[i] = .{
            .binding = @intCast(i),
            .buffer = buf.raw,
            .offset = 0,
            .size = buf.size, // bind the full allocated length
        };
    }

    // Optional uniform holding the element count, bound right after the arguments.
    var info_buf: ?GpuBuffer = null;
    defer if (info_buf) |b| b.deinit();
    if (self.def.append_info_buffer) {
        const info = try GpuBuffer.init(
            gloc,
            @sizeOf(u32),
            .initMany(&.{ .Uniform, .CopyDst }),
        );
        info_buf = info;
        c.wgpuQueueWriteBuffer(gloc.device.queue, info.raw, 0, &elements_count, @sizeOf(u32));
        entries_buf[fields.len] = .{
            .binding = @intCast(fields.len),
            .buffer = info.raw,
            .offset = 0,
            .size = @sizeOf(u32),
        };
    }

    const entry_count = fields.len + @intFromBool(self.def.append_info_buffer);
    try submitPass(
        gloc,
        self.pip,
        entries_buf[0..entry_count],
        elements_count,
        self.def.workgroup_size,
        self.def.max_workgroups,
    );
}
/// Record and submit a single compute pass that dispatches `pipeline` over `n`
/// elements, with `entries` bound as bind group 0.
///
/// The bind group layout is taken from the pipeline itself (group 0). The
/// dispatch is 1-D: `ceil(n / workgroup_size)` workgroups along x, clamped to
/// `max_workgroups`. When clamped, elements beyond
/// `max_workgroups * workgroup_size` are not covered unless the shader loops.
///
/// Does nothing when `n == 0`. Returns `error.InvalidWorkgroupSize` when
/// `workgroup_size` is 0, and `error.BindGroup`, `error.Encoder`,
/// `error.ComputePass` or `error.CommandBuffer` when the corresponding WebGPU
/// object could not be created. Every object created here is released before
/// returning; the submission itself is not awaited.
fn submitPass(
    gloc: GpuAllocator,
    pipeline: c.WGPUComputePipeline,
    entries: []const c.WGPUBindGroupEntry,
    n: usize,
    workgroup_size: u32,
    max_workgroups: u32,
) !void {
    if (n == 0) return;
    // ceilDiv divides by workgroup_size; a zero would panic (or be UB in ReleaseFast).
    if (workgroup_size == 0) return error.InvalidWorkgroupSize;

    const device = gloc.device.device;
    const bgl = c.wgpuComputePipelineGetBindGroupLayout(pipeline, 0);
    defer c.wgpuBindGroupLayoutRelease(bgl);
    const bg = c.wgpuDeviceCreateBindGroup(device, &.{
        .layout = bgl,
        .entries = entries.ptr,
        .entryCount = entries.len,
    }) orelse return error.BindGroup;
    defer c.wgpuBindGroupRelease(bg);

    const enc = c.wgpuDeviceCreateCommandEncoder(device, null) orelse return error.Encoder;
    defer c.wgpuCommandEncoderRelease(enc);

    {
        const pass = c.wgpuCommandEncoderBeginComputePass(enc, null) orelse return error.ComputePass;
        defer c.wgpuComputePassEncoderRelease(pass);
        c.wgpuComputePassEncoderSetPipeline(pass, pipeline);
        c.wgpuComputePassEncoderSetBindGroup(pass, 0, bg, 0, null);
        const dispatch_count = @min(ceilDiv(n, workgroup_size), max_workgroups);
        c.wgpuComputePassEncoderDispatchWorkgroups(pass, @intCast(dispatch_count), 1, 1);
        // The pass must be ended before the encoder is finished below.
        c.wgpuComputePassEncoderEnd(pass);
    }

    const cmd = c.wgpuCommandEncoderFinish(enc, null);
    // Never release or submit a null command buffer.
    if (cmd == null) return error.CommandBuffer;
    defer c.wgpuCommandBufferRelease(cmd);
    c.wgpuQueueSubmit(gloc.device.queue, 1, &cmd);
}
/// Divide `n` by `d`, rounding up (e.g. `ceilDiv(10, 4) == 3`).
///
/// Computed as the floor quotient plus a carry for a non-zero remainder, so it
/// cannot overflow for any `n`; the `(n + d - 1) / d` form overflows near
/// `maxInt(usize)`. Precondition: `d` is non-zero.
fn ceilDiv(n: usize, d: usize) usize {
    return n / d + @intFromBool(n % d != 0);
}