Patch summary: query adapter limits and expose them as GpuConfig, pass the
element count to the add shader through a uniform at binding 3, cap the
dispatch at 65535 workgroups with a striding loop in add.wgsl, and extend the
benchmark sizes.

NOTE(review): this patch was reconstructed from a copy whose newlines were
collapsed and whose `<...>` template arguments were stripped. Line counts were
checked against every hunk header; indentation and the WGSL access modes on
the read-only bindings are inferred -- confirm against the blob hashes.
NOTE(review): submitPass now clamps the workgroup count, so every shader that
goes through it presumably needs the same striding loop as add.wgsl; the body
of SHADER_SCALE is not visible here -- verify.
NOTE(review): `@intCast(n)` to u32 assumes n fits in 32 bits, and the largest
benchmark size (1024 * 1024 * 1024 f32 elements) is 4 GiB per buffer; neither
is checked against GpuConfig in this patch -- confirm.

diff --git a/src/GpuAllocator.zig b/src/GpuAllocator.zig
index 506a314..780d7a1 100644
--- a/src/GpuAllocator.zig
+++ b/src/GpuAllocator.zig
@@ -2,6 +2,13 @@ const std = @import("std");
 const sh = @import("shaders.zig");
 const c = @import("c.zig").c;
 
+pub const GpuConfig = struct {
+    /// Absolute max footprint of a single Tensor buffer in bytes.
+    max_tensor_buffer_bytes: u64,
+    /// Absolute max slice size readable inside a single compute binding in bytes.
+    max_tensor_binding_bytes: u64,
+};
+
 const GpuAllocator = @This();
 
 cpu_allocator: std.mem.Allocator,
@@ -9,6 +16,7 @@ instance: c.WGPUInstance,
 adapter: c.WGPUAdapter,
 device: c.WGPUDevice,
 queue: c.WGPUQueue,
+config: GpuConfig,
 
 tracked_buffers: std.AutoHashMap(c.WGPUBuffer, void),
 
@@ -32,20 +40,42 @@ pub fn init(cpu_allocator: std.mem.Allocator) !GpuAllocator {
     const adapter = ctx.adapter orelse return error.NoAdapter;
     errdefer c.wgpuAdapterRelease(adapter);
 
+    // --- QUERY HARDWARE LIMITS ---
+    var supported_limits = std.mem.zeroes(c.WGPULimits);
+    supported_limits.nextInChain = null;
+
+    // Fetch what your physical graphic card can actually handle
+    if (c.wgpuAdapterGetLimits(adapter, &supported_limits) != 1) return error.FailedToGetAdapterLimits;
+
+    const device_descriptor = c.WGPUDeviceDescriptor{
+        .nextInChain = null,
+        .label = sv("TensorCompilerDevice"),
+        .requiredFeatureCount = 0,
+        .requiredFeatures = null,
+        .requiredLimits = &supported_limits,
+    };
+
     _ = c.wgpuAdapterRequestDevice(
         adapter,
-        null,
+        &device_descriptor,
         .{ .callback = onDevice, .userdata1 = &ctx },
     );
     c.wgpuInstanceProcessEvents(instance);
     const device = ctx.device orelse return error.NoDevice;
 
+    // Package configurations into the struct
+    const config = GpuConfig{
+        .max_tensor_buffer_bytes = supported_limits.maxBufferSize,
+        .max_tensor_binding_bytes = supported_limits.maxStorageBufferBindingSize,
+    };
+
     return .{
         .cpu_allocator = cpu_allocator,
         .instance = instance,
         .adapter = adapter,
         .device = device,
         .queue = c.wgpuDeviceGetQueue(device),
+        .config = config,
         .tracked_buffers = .init(cpu_allocator),
     };
 }
diff --git a/src/Mat.zig b/src/Mat.zig
index 1ddc2ee..4d00acd 100644
--- a/src/Mat.zig
+++ b/src/Mat.zig
@@ -146,18 +146,30 @@ fn dispatch2in1out(
     bytes: u64,
     n: usize,
 ) !void {
-    const bgl = c.wgpuComputePipelineGetBindGroupLayout(pipeline, 0);
-    defer c.wgpuBindGroupLayoutRelease(bgl);
+    // 1. Create a 4-byte Uniform buffer to hold the u32 size
+    const info_buf = try GpuBuffer.init(
+        gloc,
+        @sizeOf(u32),
+        c.WGPUBufferUsage_Uniform | c.WGPUBufferUsage_CopyDst,
+    );
+    defer info_buf.deinit(); // Clean up immediately after the pass submits
 
+    // 2. Cast the usize 'n' to a u32 and write it to the GPU queue
+    const size_payload: u32 = @intCast(n);
+    c.wgpuQueueWriteBuffer(gloc.queue, info_buf.raw, 0, &size_payload, @sizeOf(u32));
+
+    // 3. Create the 4 entries matching your WGSL @binding() tags
     const entries = [_]c.WGPUBindGroupEntry{
         .{ .binding = 0, .buffer = buf_a.raw, .offset = 0, .size = bytes },
         .{ .binding = 1, .buffer = buf_b.raw, .offset = 0, .size = bytes },
         .{ .binding = 2, .buffer = buf_out.raw, .offset = 0, .size = bytes },
+        .{ .binding = 3, .buffer = info_buf.raw, .offset = 0, .size = @sizeOf(u32) }, // <--- The 4th binding!
     };
+
     try submitPass(gloc, pipeline, &entries, n);
 }
 
-/// Create bind group, encode pass, submit. workgroup_size=64.
+/// Create bind group, encode pass, submit.
 fn submitPass(
     gloc: *GpuAllocator,
     pipeline: c.WGPUComputePipeline,
@@ -179,7 +191,14 @@ fn submitPass(
     const pass = c.wgpuCommandEncoderBeginComputePass(enc, null);
     c.wgpuComputePassEncoderSetPipeline(pass, pipeline);
     c.wgpuComputePassEncoderSetBindGroup(pass, 0, bg, 0, null);
-    c.wgpuComputePassEncoderDispatchWorkgroups(pass, @intCast(ceilDiv(n, 256)), 1, 1);
+
+    const WORKGROUP_SIZE = 256;
+    const MAX_WORKGROUPS = 65535;
+
+    const desired_workgroups = ceilDiv(n, WORKGROUP_SIZE);
+    const dispatch_count = @min(desired_workgroups, MAX_WORKGROUPS);
+
+    c.wgpuComputePassEncoderDispatchWorkgroups(pass, @intCast(dispatch_count), 1, 1);
     c.wgpuComputePassEncoderEnd(pass);
     c.wgpuComputePassEncoderRelease(pass);
 
diff --git a/src/main.zig b/src/main.zig
index 8c38c89..1223ff5 100644
--- a/src/main.zig
+++ b/src/main.zig
@@ -7,7 +7,19 @@ pub fn main(init: std.process.Init) !void {
     defer gloc.deinit();
 
     // Define the sizes you want to benchmark
-    const sizes = [_]usize{ 1, 1024, 4096, 16384, 65536, 262144, 1024 * 1024, 4 * 1024 * 1024 };
+    const sizes = [_]usize{
+        1,
+        1024,
+        4096,
+        16384,
+        65536,
+        262144,
+        1024 * 1024,
+        4 * 1024 * 1024,
+        4 * 4 * 1024 * 1024,
+        4 * 4 * 4 * 1024 * 1024,
+        1024 * 1024 * 1024,
+    };
 
     // Print table header
     std.debug.print("\n| Element Count | Size (MB) | Time (ms) | Time (ns) |\n", .{});
@@ -51,11 +63,10 @@ pub fn main(init: std.process.Init) !void {
         defer allocator.free(out_scaled);
 
         try sum.read(&gloc, out_sum);
-        try scaled.read(&gloc, out_scaled);
 
         const duration = start.durationTo(std.Io.Clock.awake.now(init.io));
         const ns = duration.toNanoseconds();
-        const ms = @as(f64, @floatFromInt(ns)) / 1_000_000.0;
+        const ms = duration.toMilliseconds();
         const mb = @as(f64, @floatFromInt(size * @sizeOf(f32))) / (1024.0 * 1024.0);
 
         // Print table row
diff --git a/src/shaders.zig b/src/shaders.zig
index da489fd..4e21cac 100644
--- a/src/shaders.zig
+++ b/src/shaders.zig
@@ -1,16 +1,6 @@
-pub const SHADER_ADD =
-    \\@group(0) @binding(0) var<storage, read> a : array<f32>;
-    \\@group(0) @binding(1) var<storage, read> b : array<f32>;
-    \\@group(0) @binding(2) var<storage, read_write> out : array<f32>;
-    \\
-    \\@compute @workgroup_size(256)
-    \\fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
-    \\    let i = gid.x;
-    \\    if (i < arrayLength(&out)) {
-    \\        out[i] = a[i] + b[i];
-    \\    }
-    \\}
-;
+const std = @import("std");
+
+pub const SHADER_ADD = @embedFile("shaders/add.wgsl");
 
 pub const SHADER_SCALE =
     \\struct Uniforms { scalar : f32 }
diff --git a/src/shaders/add.wgsl b/src/shaders/add.wgsl
new file mode 100644
index 0000000..58407fc
--- /dev/null
+++ b/src/shaders/add.wgsl
@@ -0,0 +1,26 @@
+@group(0) @binding(0) var<storage, read> A: array<f32>;
+@group(0) @binding(1) var<storage, read> B: array<f32>;
+@group(0) @binding(2) var<storage, read_write> C: array<f32>;
+
+struct TensorInfo {
+    size: u32,
+};
+@group(0) @binding(3) var<uniform> info: TensorInfo;
+
+@compute @workgroup_size(256)
+fn main(
+    @builtin(global_invocation_id) global_id : vec3<u32>,
+    @builtin(num_workgroups) num_workgroups: vec3<u32>
+) {
+    // 1. Calculate the total number of threads across the entire grid
+    let total_threads = num_workgroups.x * 256u;
+
+    // 2. Start at this thread's unique global ID
+    var index = global_id.x;
+
+    // 3. Stride through the tensor elements
+    while (index < info.size) {
+        C[index] = A[index] + B[index];
+        index += total_threads; // Jump forward by the total thread count
+    }
+}