diff --git a/src/GpuAllocator.zig b/src/GpuAllocator.zig index 780d7a1..b3fb604 100644 --- a/src/GpuAllocator.zig +++ b/src/GpuAllocator.zig @@ -20,9 +20,9 @@ config: GpuConfig, tracked_buffers: std.AutoHashMap(c.WGPUBuffer, void), -// Lazily created, cached for lifetime of allocator -_pip_add: c.WGPUComputePipeline = null, -_pip_scale: c.WGPUComputePipeline = null, +pipelines: struct { + add: c.WGPUComputePipeline, +}, pub fn init(cpu_allocator: std.mem.Allocator) !GpuAllocator { const instance = c.wgpuCreateInstance( @@ -77,12 +77,15 @@ pub fn init(cpu_allocator: std.mem.Allocator) !GpuAllocator { .queue = c.wgpuDeviceGetQueue(device), .config = config, .tracked_buffers = .init(cpu_allocator), + .pipelines = .{ + .add = try buildPipeline(device, sh.SHADER_ADD), + }, }; } pub fn deinit(self: *GpuAllocator) void { - if (self._pip_add) |p| c.wgpuComputePipelineRelease(p); - if (self._pip_scale) |p| c.wgpuComputePipelineRelease(p); + inline for (@typeInfo(@TypeOf(self.pipelines)).@"struct".fields) |field| + c.wgpuComputePipelineRelease(@field(self.pipelines, field.name)); var it = self.tracked_buffers.keyIterator(); while (it.next()) |buf_ptr| { @@ -132,18 +135,6 @@ pub fn makeBuffer( }) orelse error.BufferAlloc; } -pub fn pipAdd(self: *GpuAllocator) !c.WGPUComputePipeline { - if (self._pip_add == null) - self._pip_add = try buildPipeline(self.device, sh.SHADER_ADD); - return self._pip_add.?; -} - -pub fn pipScale(self: *GpuAllocator) !c.WGPUComputePipeline { - if (self._pip_scale == null) - self._pip_scale = try buildPipeline(self.device, sh.SHADER_SCALE); - return self._pip_scale.?; -} - /// Poll until GPU work completes. Use after submit if you need CPU sync. pub fn poll(self: *GpuAllocator) void { _ = c.wgpuDevicePoll(self.device, 1, null); diff --git a/src/Mat.zig b/src/Mat.zig index bb27152..e81d4ed 100644 --- a/src/Mat.zig +++ b/src/Mat.zig @@ -57,8 +57,7 @@ pub fn add(self: Mat, gloc: *GpuAllocator, other: Mat) !Mat { const result = try Mat.zeros(gloc, self.rows, self.cols); errdefer result.deinit(); - const pipeline = try gloc.pipAdd(); - try dispatch2in1out(gloc, pipeline, self.buf, other.buf, result.buf, self.byteSize()); + try dispatch2in1out(gloc, gloc.pipelines.add, self.buf, other.buf, result.buf, self.byteSize()); return result; } diff --git a/src/lib.zig b/src/lib.zig index 6f0564c..58c2bfd 100644 --- a/src/lib.zig +++ b/src/lib.zig @@ -1 +1,2 @@ pub const GpuAllocator = @import("GpuAllocator.zig"); +pub const GpuBuffer = @import("GpuBuffer.zig");