From d42c521a962045f3c20ed703be8d95c0f5f70b3d Mon Sep 17 00:00:00 2001 From: adrien Date: Mon, 18 May 2026 10:12:36 +0200 Subject: [PATCH] Added f16 capability --- src/GpuDevice.zig | 23 +++++++++++++++++------ src/Vec.zig | 20 ++++++++++---------- src/bench.zig | 10 +++++----- src/example.zig | 4 ++-- src/shaders/add.wgsl | 8 +++++--- 5 files changed, 39 insertions(+), 26 deletions(-) diff --git a/src/GpuDevice.zig b/src/GpuDevice.zig index 1820b16..b1da5c4 100644 --- a/src/GpuDevice.zig +++ b/src/GpuDevice.zig @@ -36,21 +36,32 @@ pub fn init(config: GpuDeviceConfig) !@This() { const adapter = ctx.adapter orelse return error.NoAdapter; errdefer c.wgpuAdapterRelease(adapter); - // --- QUERY HARDWARE LIMITS --- + var supported_features = std.mem.zeroes(c.WGPUSupportedFeatures); + c.wgpuAdapterGetFeatures(adapter, &supported_features); + var supported_limits = std.mem.zeroes(c.WGPULimits); supported_limits.nextInChain = null; - - // Fetch what your physical graphic card can actually handle if (c.wgpuAdapterGetLimits(adapter, &supported_limits) != 1) return error.FailedToGetAdapterLimits; + var has_f16 = false; + for (0..supported_features.featureCount) |i| { + if (supported_features.features[i] == c.WGPUFeatureName_ShaderF16) { + has_f16 = true; + break; + } + } + + var feature_buf = [_]c.WGPUFeatureName{c.WGPUFeatureName_ShaderF16}; + const required_features: []const c.WGPUFeatureName = + if (has_f16) feature_buf[0..1] else &.{}; + const device_descriptor = c.WGPUDeviceDescriptor{ .nextInChain = null, .label = sv("TensorCompilerDevice"), - .requiredFeatureCount = 0, - .requiredFeatures = null, + .requiredFeatureCount = required_features.len, + .requiredFeatures = if (required_features.len > 0) required_features.ptr else null, .requiredLimits = &supported_limits, }; - _ = c.wgpuAdapterRequestDevice( adapter, &device_descriptor, diff --git a/src/Vec.zig b/src/Vec.zig index 15badb3..81ce9c1 100644 --- a/src/Vec.zig +++ b/src/Vec.zig @@ -15,7 +15,7 @@ pub fn initZero(gloc: *GpuAllocator, len: usize) !Vec { return .{ .buf = try GpuBuffer.init( gloc, - f32, + f16, len, c.WGPUBufferUsage_Storage | c.WGPUBufferUsage_CopyDst | c.WGPUBufferUsage_CopySrc, ), @@ -23,7 +23,7 @@ pub fn initZero(gloc: *GpuAllocator, len: usize) !Vec { }; } -pub fn initLoad(gloc: *GpuAllocator, data: []const f32) !Vec { +pub fn initLoad(gloc: *GpuAllocator, data: []const f16) !Vec { var self = try initZero(gloc, data.len); try self.load(gloc.device, data); return self; @@ -37,15 +37,15 @@ pub fn deinit(self: Vec) void { pub fn load( self: Vec, device: GpuDevice, - data: []const f32, + data: []const f16, ) !void { std.debug.assert(data.len == self.len); - const bytes = data.len * @sizeOf(f32); + const bytes = self.byteSize(); c.wgpuQueueWriteBuffer(device.queue, self.buf.raw, 0, data.ptr, bytes); } pub fn byteSize(self: Vec) u64 { - return @as(u64, self.len) * @sizeOf(f32); + return @as(u64, self.len) * @sizeOf(f16); } pub fn run(self: Vec, gloc: *GpuAllocator, other: Vec, pip: GpuPipeline) !Vec { @@ -60,13 +60,13 @@ pub fn run(self: Vec, gloc: *GpuAllocator, other: Vec, pip: GpuPipeline) !Vec { } /// GPU to CPU. -pub fn read(self: Vec, gloc: *GpuAllocator, alloc: std.mem.Allocator) ![]f32 { - const out = try alloc.alloc(f32, self.len); +pub fn read(self: Vec, gloc: *GpuAllocator, alloc: std.mem.Allocator) ![]f16 { + const out = try alloc.alloc(f16, self.len); const bytes = self.byteSize(); const staging = try GpuBuffer.init( gloc, - f32, + f16, self.len, c.WGPUBufferUsage_MapRead | c.WGPUBufferUsage_CopyDst, ); @@ -88,7 +88,7 @@ pub fn read(self: Vec, gloc: *GpuAllocator, alloc: std.mem.Allocator) ![]f32 { ); while (!mapped) gloc.device.poll(); - const ptr: [*]const f32 = @ptrCast(@alignCast( + const ptr: [*]const f16 = @ptrCast(@alignCast( staging.getConstMappedRange(0, bytes), )); @memcpy(out[0..self.len], ptr[0..self.len]); @@ -122,7 +122,7 @@ fn dispatch2in1out( while (offset < bytes) { // Calculate bounds for the current chunk const current_chunk_bytes = @min(max_chunk_bytes, bytes - offset); - const current_chunk_elements: u32 = @intCast(current_chunk_bytes / @sizeOf(f32)); + const current_chunk_elements: u32 = @intCast(current_chunk_bytes / @sizeOf(f16)); // Create uniform buffer for this specific chunk's size const info_buf = try GpuBuffer.init( diff --git a/src/bench.zig b/src/bench.zig index b7b424f..368e22e 100644 --- a/src/bench.zig +++ b/src/bench.zig @@ -20,8 +20,8 @@ pub fn main(init: std.process.Init) !void { // --- WARM-UP PHASE --- { - var warmup_a = [_]f32{1.0}; - var warmup_b = [_]f32{1.0}; + var warmup_a = [_]f16{1.0}; + var warmup_b = [_]f16{1.0}; const wa = try Vec.initLoad(&gloc, &warmup_a); defer wa.deinit(); const wb = try Vec.initLoad(&gloc, &warmup_b); @@ -56,9 +56,9 @@ pub fn main(init: std.process.Init) !void { for (sizes) |size| { // --- Phase 1: Host Init/Alloc (Outside the iteration loop for pure host prep) --- - const data_a = try allocator.alloc(f32, size); + const data_a = try allocator.alloc(f16, size); defer allocator.free(data_a); - const data_b = try allocator.alloc(f32, size); + const data_b = try allocator.alloc(f16, size); defer allocator.free(data_b); for (0..size) |i| { @@ -117,7 +117,7 @@ pub fn main(init: std.process.Init) !void { // --- Metrics Calculations --- const f_size = @as(f64, @floatFromInt(size)); - const element_bytes = f_size * @as(f64, @floatFromInt(@sizeOf(f32))); + const element_bytes = f_size * @as(f64, @floatFromInt(@sizeOf(f16))); const mb = element_bytes / (1024.0 * 1024.0); // Individual Phase Timings (ms) diff --git a/src/example.zig b/src/example.zig index 560b01e..a27ae7e 100644 --- a/src/example.zig +++ b/src/example.zig @@ -18,9 +18,9 @@ pub fn main(init: std.process.Init) !void { const add_pip = try GpuPipeline.init(device, @embedFile("shaders/add.wgsl")); defer add_pip.deinit(); - const data_a = try allocator.alloc(f32, 1024); + const data_a = try allocator.alloc(f16, 1024); defer allocator.free(data_a); - const data_b = try allocator.alloc(f32, 1024); + const data_b = try allocator.alloc(f16, 1024); defer allocator.free(data_b); for (0..1024) |i| { diff --git a/src/shaders/add.wgsl b/src/shaders/add.wgsl index 58407fc..333e56e 100644 --- a/src/shaders/add.wgsl +++ b/src/shaders/add.wgsl @@ -1,6 +1,8 @@ -@group(0) @binding(0) var A: array; -@group(0) @binding(1) var B: array; -@group(0) @binding(2) var C: array; +enable f16; + +@group(0) @binding(0) var A: array; +@group(0) @binding(1) var B: array; +@group(0) @binding(2) var C: array; struct TensorInfo { size: u32,