Removed GpuBuffer limits

This commit is contained in:
adrien 2026-05-16 00:26:07 +02:00
parent 90a7cf946f
commit cfc1069309
5 changed files with 97 additions and 21 deletions

View File

@ -2,6 +2,13 @@ const std = @import("std");
const sh = @import("shaders.zig"); const sh = @import("shaders.zig");
const c = @import("c.zig").c; const c = @import("c.zig").c;
pub const GpuConfig = struct {
/// Absolute max footprint of a single Tensor buffer in bytes.
max_tensor_buffer_bytes: u64,
/// Absolute max slice size readable inside a single compute binding in bytes.
max_tensor_binding_bytes: u64,
};
const GpuAllocator = @This(); const GpuAllocator = @This();
cpu_allocator: std.mem.Allocator, cpu_allocator: std.mem.Allocator,
@ -9,6 +16,7 @@ instance: c.WGPUInstance,
adapter: c.WGPUAdapter, adapter: c.WGPUAdapter,
device: c.WGPUDevice, device: c.WGPUDevice,
queue: c.WGPUQueue, queue: c.WGPUQueue,
config: GpuConfig,
tracked_buffers: std.AutoHashMap(c.WGPUBuffer, void), tracked_buffers: std.AutoHashMap(c.WGPUBuffer, void),
@ -32,20 +40,42 @@ pub fn init(cpu_allocator: std.mem.Allocator) !GpuAllocator {
const adapter = ctx.adapter orelse return error.NoAdapter; const adapter = ctx.adapter orelse return error.NoAdapter;
errdefer c.wgpuAdapterRelease(adapter); errdefer c.wgpuAdapterRelease(adapter);
// --- QUERY HARDWARE LIMITS ---
var supported_limits = std.mem.zeroes(c.WGPULimits);
supported_limits.nextInChain = null;
// Fetch what your physical graphic card can actually handle
if (c.wgpuAdapterGetLimits(adapter, &supported_limits) != 1) return error.FailedToGetAdapterLimits;
const device_descriptor = c.WGPUDeviceDescriptor{
.nextInChain = null,
.label = sv("TensorCompilerDevice"),
.requiredFeatureCount = 0,
.requiredFeatures = null,
.requiredLimits = &supported_limits,
};
_ = c.wgpuAdapterRequestDevice( _ = c.wgpuAdapterRequestDevice(
adapter, adapter,
null, &device_descriptor,
.{ .callback = onDevice, .userdata1 = &ctx }, .{ .callback = onDevice, .userdata1 = &ctx },
); );
c.wgpuInstanceProcessEvents(instance); c.wgpuInstanceProcessEvents(instance);
const device = ctx.device orelse return error.NoDevice; const device = ctx.device orelse return error.NoDevice;
// Package configurations into the struct
const config = GpuConfig{
.max_tensor_buffer_bytes = supported_limits.maxBufferSize,
.max_tensor_binding_bytes = supported_limits.maxStorageBufferBindingSize,
};
return .{ return .{
.cpu_allocator = cpu_allocator, .cpu_allocator = cpu_allocator,
.instance = instance, .instance = instance,
.adapter = adapter, .adapter = adapter,
.device = device, .device = device,
.queue = c.wgpuDeviceGetQueue(device), .queue = c.wgpuDeviceGetQueue(device),
.config = config,
.tracked_buffers = .init(cpu_allocator), .tracked_buffers = .init(cpu_allocator),
}; };
} }

View File

@ -146,18 +146,30 @@ fn dispatch2in1out(
bytes: u64, bytes: u64,
n: usize, n: usize,
) !void { ) !void {
const bgl = c.wgpuComputePipelineGetBindGroupLayout(pipeline, 0); // 1. Create a 4-byte Uniform buffer to hold the u32 size
defer c.wgpuBindGroupLayoutRelease(bgl); const info_buf = try GpuBuffer.init(
gloc,
@sizeOf(u32),
c.WGPUBufferUsage_Uniform | c.WGPUBufferUsage_CopyDst,
);
defer info_buf.deinit(); // Clean up immediately after the pass submits
// 2. Cast the usize 'n' to a u32 and write it to the GPU queue
const size_payload: u32 = @intCast(n);
c.wgpuQueueWriteBuffer(gloc.queue, info_buf.raw, 0, &size_payload, @sizeOf(u32));
// 3. Create the 4 entries matching your WGSL @binding() tags
const entries = [_]c.WGPUBindGroupEntry{ const entries = [_]c.WGPUBindGroupEntry{
.{ .binding = 0, .buffer = buf_a.raw, .offset = 0, .size = bytes }, .{ .binding = 0, .buffer = buf_a.raw, .offset = 0, .size = bytes },
.{ .binding = 1, .buffer = buf_b.raw, .offset = 0, .size = bytes }, .{ .binding = 1, .buffer = buf_b.raw, .offset = 0, .size = bytes },
.{ .binding = 2, .buffer = buf_out.raw, .offset = 0, .size = bytes }, .{ .binding = 2, .buffer = buf_out.raw, .offset = 0, .size = bytes },
.{ .binding = 3, .buffer = info_buf.raw, .offset = 0, .size = @sizeOf(u32) }, // <--- The 4th binding!
}; };
try submitPass(gloc, pipeline, &entries, n); try submitPass(gloc, pipeline, &entries, n);
} }
/// Create bind group, encode pass, submit. workgroup_size=64. /// Create bind group, encode pass, submit.
fn submitPass( fn submitPass(
gloc: *GpuAllocator, gloc: *GpuAllocator,
pipeline: c.WGPUComputePipeline, pipeline: c.WGPUComputePipeline,
@ -179,7 +191,14 @@ fn submitPass(
const pass = c.wgpuCommandEncoderBeginComputePass(enc, null); const pass = c.wgpuCommandEncoderBeginComputePass(enc, null);
c.wgpuComputePassEncoderSetPipeline(pass, pipeline); c.wgpuComputePassEncoderSetPipeline(pass, pipeline);
c.wgpuComputePassEncoderSetBindGroup(pass, 0, bg, 0, null); c.wgpuComputePassEncoderSetBindGroup(pass, 0, bg, 0, null);
c.wgpuComputePassEncoderDispatchWorkgroups(pass, @intCast(ceilDiv(n, 256)), 1, 1);
const WORKGROUP_SIZE = 256;
const MAX_WORKGROUPS = 65535;
const desired_workgroups = ceilDiv(n, WORKGROUP_SIZE);
const dispatch_count = @min(desired_workgroups, MAX_WORKGROUPS);
c.wgpuComputePassEncoderDispatchWorkgroups(pass, @intCast(dispatch_count), 1, 1);
c.wgpuComputePassEncoderEnd(pass); c.wgpuComputePassEncoderEnd(pass);
c.wgpuComputePassEncoderRelease(pass); c.wgpuComputePassEncoderRelease(pass);

View File

@ -7,7 +7,19 @@ pub fn main(init: std.process.Init) !void {
defer gloc.deinit(); defer gloc.deinit();
// Define the sizes you want to benchmark // Define the sizes you want to benchmark
const sizes = [_]usize{ 1, 1024, 4096, 16384, 65536, 262144, 1024 * 1024, 4 * 1024 * 1024 }; const sizes = [_]usize{
1,
1024,
4096,
16384,
65536,
262144,
1024 * 1024,
4 * 1024 * 1024,
4 * 4 * 1024 * 1024,
4 * 4 * 4 * 1024 * 1024,
1024 * 1024 * 1024,
};
// Print table header // Print table header
std.debug.print("\n| Element Count | Size (MB) | Time (ms) | Time (ns) |\n", .{}); std.debug.print("\n| Element Count | Size (MB) | Time (ms) | Time (ns) |\n", .{});
@ -51,11 +63,10 @@ pub fn main(init: std.process.Init) !void {
defer allocator.free(out_scaled); defer allocator.free(out_scaled);
try sum.read(&gloc, out_sum); try sum.read(&gloc, out_sum);
try scaled.read(&gloc, out_scaled);
const duration = start.durationTo(std.Io.Clock.awake.now(init.io)); const duration = start.durationTo(std.Io.Clock.awake.now(init.io));
const ns = duration.toNanoseconds(); const ns = duration.toNanoseconds();
const ms = @as(f64, @floatFromInt(ns)) / 1_000_000.0; const ms = duration.toMilliseconds();
const mb = @as(f64, @floatFromInt(size * @sizeOf(f32))) / (1024.0 * 1024.0); const mb = @as(f64, @floatFromInt(size * @sizeOf(f32))) / (1024.0 * 1024.0);
// Print table row // Print table row

View File

@ -1,16 +1,6 @@
pub const SHADER_ADD = const std = @import("std");
\\@group(0) @binding(0) var<storage, read> a : array<f32>;
\\@group(0) @binding(1) var<storage, read> b : array<f32>; pub const SHADER_ADD = @embedFile("shaders/add.wgsl");
\\@group(0) @binding(2) var<storage, read_write> out : array<f32>;
\\
\\@compute @workgroup_size(256)
\\fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
\\ let i = gid.x;
\\ if (i < arrayLength(&out)) {
\\ out[i] = a[i] + b[i];
\\ }
\\}
;
pub const SHADER_SCALE = pub const SHADER_SCALE =
\\struct Uniforms { scalar : f32 } \\struct Uniforms { scalar : f32 }

26
src/shaders/add.wgsl Normal file
View File

@ -0,0 +1,26 @@
@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;
struct TensorInfo {
size: u32,
};
@group(0) @binding(3) var<uniform> info: TensorInfo;
@compute @workgroup_size(256)
fn main(
@builtin(global_invocation_id) global_id : vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>
) {
// 1. Calculate the total number of threads across the entire grid
let total_threads = num_workgroups.x * 256u;
// 2. Start at this thread's unique global ID
var index = global_id.x;
// 3. Stride through the tensor elements
while (index < info.size) {
C[index] = A[index] + B[index];
index += total_threads; // Jump forward by the total thread count
}
}