zig-wgpu/src/reference.zig
2026-05-15 09:44:53 +02:00

240 lines
10 KiB
Zig
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Minimal WebGPU compute in Zig: element-wise matrix addition
// Uses wgpu-native C bindings.
// Build: see ../build.zig
//
// Data flow:
// CPU (mat_a, mat_b) → GPU storage buffers → compute shader → GPU buf_c
// → staging buffer (mapped) → CPU read → print
const std = @import("std");
const c = @cImport(@cInclude("wgpu.h"));
// ── Config ────────────────────────────────────────────────────────────────────
/// Number of matrix rows.
const ROWS: u32 = 4;
/// Number of matrix columns.
const COLS: u32 = 4;
/// Total number of f32 elements in one matrix (16 for 4×4).
const N: u32 = ROWS * COLS;
/// Size in bytes of one matrix, used for every buffer size and copy length.
const BUF_BYTES: u32 = N * @sizeOf(f32);
// ── WGSL Compute Shader ───────────────────────────────────────────────────────
// Element-wise addition: each invocation writes mat_c[idx] = mat_a[idx] + mat_b[idx].
//
// Bindings (all in group 0): 0 = mat_a (read), 1 = mat_b (read), 2 = mat_c (read_write).
//
// workgroup_size(4,4) matches one full 4×4 matrix → dispatch(1,1,1).
// NOTE: the workgroup size and the row stride `4u` are hard-coded in the WGSL
// text; they are not derived from ROWS/COLS, so keep them in sync by hand.
// The arrayLength() guard skips any invocation whose index falls past mat_c.
const SHADER =
\\@group(0) @binding(0) var<storage, read> mat_a : array<f32>;
\\@group(0) @binding(1) var<storage, read> mat_b : array<f32>;
\\@group(0) @binding(2) var<storage, read_write> mat_c : array<f32>;
\\
\\@compute @workgroup_size(4, 4)
\\fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
\\ let idx = gid.y * 4u + gid.x;
\\ if (idx < arrayLength(&mat_c)) {
\\ mat_c[idx] = mat_a[idx] + mat_b[idx];
\\ }
\\}
;
// ── Callback state ────────────────────────────────────────────────────────────
/// Shared state that the adapter/device request callbacks write into.
/// `main` passes a pointer to one instance as `userdata1` and reads the
/// fields back after the request has been processed.
const Ctx = struct {
/// Set by `onAdapter` on success; remains null if the request failed.
adapter: c.WGPUAdapter = null,
/// Set by `onDevice` on success; remains null if the request failed.
device: c.WGPUDevice = null,
};
/// Completion callback for `wgpuInstanceRequestAdapter`.
///
/// `userdata1` must point at a `Ctx`. On success the adapter handle is stored
/// in `Ctx.adapter`; on any other status the failure is logged and the context
/// is left untouched. The message string and second userdata are ignored.
fn onAdapter(
    status: c.WGPURequestAdapterStatus,
    adapter: c.WGPUAdapter,
    _: c.WGPUStringView,
    userdata1: ?*anyopaque,
    _: ?*anyopaque,
) callconv(.c) void {
    if (status == c.WGPURequestAdapterStatus_Success) {
        // Recover the typed context from the opaque userdata pointer.
        const state: *Ctx = @ptrCast(@alignCast(userdata1.?));
        state.adapter = adapter;
    } else {
        std.log.err("Adapter request failed (status={d})", .{status});
    }
}
/// Completion callback for `wgpuAdapterRequestDevice`.
///
/// `userdata1` must point at a `Ctx`. On success the device handle is stored
/// in `Ctx.device`; on any other status the failure is logged and the context
/// is left untouched, so the caller sees a null device. The message string and
/// second userdata are ignored.
fn onDevice(
    status: c.WGPURequestDeviceStatus,
    device: c.WGPUDevice,
    _: c.WGPUStringView,
    userdata1: ?*anyopaque,
    _: ?*anyopaque,
) callconv(.c) void {
    if (status != c.WGPURequestDeviceStatus_Success) {
        std.log.err("Device request failed (status={d})", .{status});
        return;
    }
    // Recover the typed context from the opaque userdata pointer.
    const ctx: *Ctx = @ptrCast(@alignCast(userdata1.?));
    ctx.device = device;
}
/// Outcome of an asynchronous buffer map, filled in by `onMapped`.
///
/// Completion and success are tracked separately so the polling loop in
/// `main` terminates even when the map fails.
const MapState = struct {
    /// Set once the callback has run, whatever the status.
    done: bool = false,
    /// True only if the callback reported `WGPUMapAsyncStatus_Success`.
    ok: bool = false,
};

/// Completion callback for `wgpuBufferMapAsync`.
///
/// `userdata1` must point at a `MapState`. Records whether the map succeeded
/// and marks the request as finished; a failure is also logged.
fn onMapped(
    status: c.WGPUMapAsyncStatus,
    _: c.WGPUStringView,
    userdata1: ?*anyopaque,
    _: ?*anyopaque,
) callconv(.c) void {
    const state: *MapState = @ptrCast(@alignCast(userdata1.?));
    state.ok = (status == c.WGPUMapAsyncStatus_Success);
    if (!state.ok) std.log.err("Buffer map failed (status={d})", .{status});
    // Written last so a poller that sees `done` also sees the final `ok`.
    state.done = true;
}

/// Wraps a Zig slice as a `WGPUStringView` (pointer + explicit length).
/// The view borrows `s`; the slice must outlive every use of the view.
fn sv(s: []const u8) c.WGPUStringView {
    return .{ .data = s.ptr, .length = s.len };
}

// ── Main ──────────────────────────────────────────────────────────────────────
/// Runs the whole example: creates a device, uploads two N-element matrices,
/// adds them on the GPU, reads the result back and prints it to stderr.
///
/// Returns an error if any WebGPU object cannot be created or the staging
/// buffer cannot be mapped. Every handle acquired here is released on all exit
/// paths via `defer` registered immediately after acquisition.
pub fn main() !void {
    // 1. Instance ──────────────────────────────────────────────────────────────
    const instance = c.wgpuCreateInstance(&std.mem.zeroes(c.WGPUInstanceDescriptor)) orelse
        return error.NoInstance;
    defer c.wgpuInstanceRelease(instance);

    // 2. Adapter (async → process events) ──────────────────────────────────────
    var ctx = Ctx{};
    _ = c.wgpuInstanceRequestAdapter(
        instance,
        &.{ .powerPreference = c.WGPUPowerPreference_HighPerformance },
        .{ .callback = onAdapter, .userdata1 = &ctx },
    );
    c.wgpuInstanceProcessEvents(instance); // drive callbacks
    const adapter = ctx.adapter orelse return error.NoAdapter;
    defer c.wgpuAdapterRelease(adapter);

    // 3. Device ────────────────────────────────────────────────────────────────
    _ = c.wgpuAdapterRequestDevice(adapter, null, .{ .callback = onDevice, .userdata1 = &ctx });
    c.wgpuInstanceProcessEvents(instance);
    const device = ctx.device orelse return error.NoDevice;
    defer c.wgpuDeviceRelease(device);

    const queue = c.wgpuDeviceGetQueue(device) orelse return error.NoQueue;
    defer c.wgpuQueueRelease(queue);

    // 4. Input data ────────────────────────────────────────────────────────────
    // mat_a[i] = i        (0 … 15)
    // mat_b[i] = 15 − i   → every element of mat_c should equal 15
    var mat_a: [N]f32 = undefined;
    var mat_b: [N]f32 = undefined;
    for (0..N) |i| {
        mat_a[i] = @floatFromInt(i);
        mat_b[i] = @floatFromInt(N - 1 - i);
    }

    // 5. GPU buffers ───────────────────────────────────────────────────────────
    // Each release is deferred right after its creation so that a failure
    // creating a later buffer does not leak the earlier ones.
    const buf_a = c.wgpuDeviceCreateBuffer(device, &.{
        .usage = c.WGPUBufferUsage_Storage | c.WGPUBufferUsage_CopyDst,
        .size = BUF_BYTES,
    }) orelse return error.BufferA;
    defer c.wgpuBufferRelease(buf_a);

    const buf_b = c.wgpuDeviceCreateBuffer(device, &.{
        .usage = c.WGPUBufferUsage_Storage | c.WGPUBufferUsage_CopyDst,
        .size = BUF_BYTES,
    }) orelse return error.BufferB;
    defer c.wgpuBufferRelease(buf_b);

    // buf_c: GPU-only result; staging: CPU-readable copy
    const buf_c = c.wgpuDeviceCreateBuffer(device, &.{
        .usage = c.WGPUBufferUsage_Storage | c.WGPUBufferUsage_CopySrc,
        .size = BUF_BYTES,
    }) orelse return error.BufferC;
    defer c.wgpuBufferRelease(buf_c);

    const buf_staging = c.wgpuDeviceCreateBuffer(device, &.{
        .usage = c.WGPUBufferUsage_MapRead | c.WGPUBufferUsage_CopyDst,
        .size = BUF_BYTES,
    }) orelse return error.BufferStaging;
    defer c.wgpuBufferRelease(buf_staging);

    // Upload inputs
    c.wgpuQueueWriteBuffer(queue, buf_a, 0, &mat_a, BUF_BYTES);
    c.wgpuQueueWriteBuffer(queue, buf_b, 0, &mat_b, BUF_BYTES);

    // 6. Shader module ─────────────────────────────────────────────────────────
    // The WGSL source is attached through the descriptor's extension chain.
    var wgsl_src = c.WGPUShaderSourceWGSL{
        .chain = .{ .sType = c.WGPUSType_ShaderSourceWGSL },
        .code = sv(SHADER),
    };
    const shader = c.wgpuDeviceCreateShaderModule(device, &.{
        .nextInChain = @ptrCast(&wgsl_src),
    }) orelse return error.Shader;
    defer c.wgpuShaderModuleRelease(shader);

    // 7. Compute pipeline (no explicit layout passed) ──────────────────────────
    const pipeline = c.wgpuDeviceCreateComputePipeline(device, &.{
        .compute = .{
            .module = shader,
            .entryPoint = sv("main"),
        },
    }) orelse return error.Pipeline;
    defer c.wgpuComputePipelineRelease(pipeline);

    // 8. Bind group ────────────────────────────────────────────────────────────
    const bgl = c.wgpuComputePipelineGetBindGroupLayout(pipeline, 0) orelse
        return error.BindGroupLayout;
    defer c.wgpuBindGroupLayoutRelease(bgl);

    const entries = [_]c.WGPUBindGroupEntry{
        .{ .binding = 0, .buffer = buf_a, .offset = 0, .size = BUF_BYTES },
        .{ .binding = 1, .buffer = buf_b, .offset = 0, .size = BUF_BYTES },
        .{ .binding = 2, .buffer = buf_c, .offset = 0, .size = BUF_BYTES },
    };
    const bind_group = c.wgpuDeviceCreateBindGroup(device, &.{
        .layout = bgl,
        .entries = &entries,
        .entryCount = entries.len,
    }) orelse return error.BindGroup;
    defer c.wgpuBindGroupRelease(bind_group);

    // 9. Encode compute pass + buffer copy ─────────────────────────────────────
    const encoder = c.wgpuDeviceCreateCommandEncoder(device, null) orelse
        return error.Encoder;
    defer c.wgpuCommandEncoderRelease(encoder);

    const pass = c.wgpuCommandEncoderBeginComputePass(encoder, null) orelse
        return error.ComputePass;
    c.wgpuComputePassEncoderSetPipeline(pass, pipeline);
    c.wgpuComputePassEncoderSetBindGroup(pass, 0, bind_group, 0, null);
    // dispatch(1,1,1): one workgroup of size (4,4) covers the whole 4×4 matrix
    c.wgpuComputePassEncoderDispatchWorkgroups(pass, 1, 1, 1);
    c.wgpuComputePassEncoderEnd(pass);
    c.wgpuComputePassEncoderRelease(pass);

    // Copy result buffer → CPU-readable staging buffer
    c.wgpuCommandEncoderCopyBufferToBuffer(encoder, buf_c, 0, buf_staging, 0, BUF_BYTES);

    // Kept as the C handle type (an optional pointer) so `&cmdbuf` matches the
    // array-of-handles parameter of wgpuQueueSubmit; null is rejected by hand.
    const cmdbuf: c.WGPUCommandBuffer = c.wgpuCommandEncoderFinish(encoder, null);
    if (cmdbuf == null) return error.CommandBuffer;
    defer c.wgpuCommandBufferRelease(cmdbuf);

    // 10. Submit ───────────────────────────────────────────────────────────────
    c.wgpuQueueSubmit(queue, 1, &cmdbuf);

    // 11. Map staging buffer back to CPU ───────────────────────────────────────
    var map_state = MapState{};
    _ = c.wgpuBufferMapAsync(
        buf_staging,
        c.WGPUMapMode_Read,
        0,
        BUF_BYTES,
        .{ .callback = onMapped, .userdata1 = &map_state },
    );
    // Poll the device until the callback has run. Looping on `done` rather
    // than on success means a failed map ends the loop instead of hanging.
    while (!map_state.done) _ = c.wgpuDevicePoll(device, 1, null);
    if (!map_state.ok) return error.MapFailed;
    // Registered only after a successful map; runs before the buffer release.
    defer c.wgpuBufferUnmap(buf_staging);

    const raw = c.wgpuBufferGetConstMappedRange(buf_staging, 0, BUF_BYTES) orelse
        return error.MappedRange;
    const ptr: [*]const f32 = @ptrCast(@alignCast(raw));
    const result = ptr[0..N];

    // 12. Print ────────────────────────────────────────────────────────────────
    std.debug.print("\nmat_a + mat_b ({d}×{d}):\n", .{ ROWS, COLS });
    for (0..ROWS) |r| {
        for (0..COLS) |col|
            std.debug.print("{d:6.0}", .{result[r * COLS + col]});
        std.debug.print("\n", .{});
    }
    // Expected output: every cell = 15.0
}