// WebGPU Shading Language (WGSL) compute shader: elementwise f16 addition.
// Half-precision element type; the device must expose the WebGPU
// "shader-f16" feature for this directive to compile.
enable f16;

// First input operand, read-only.
@group(0) @binding(0)
var<storage, read> A: array<f16>;

// Second input operand, read-only.
@group(0) @binding(1)
var<storage, read> B: array<f16>;

// Output buffer, written by the shader.
@group(0) @binding(2)
var<storage, read_write> C: array<f16>;

// Host-provided element-count bound for the kernel.
@group(0) @binding(3)
var<uniform> size: u32;
// Invocations per workgroup. Shared by @workgroup_size and the loop stride so
// the two values cannot drift apart.
const WORKGROUP_SIZE: u32 = 256u;

// Elementwise f16 addition: C[i] = A[i] + B[i] for every in-bounds i < size.
//
// Each invocation starts at its global x index and advances by the total
// number of invocations along x in the dispatch. Any number of workgroups
// along x therefore covers all elements.
//
// NOTE(review): only the x dimension of the dispatch is used for indexing. If
// the host dispatches with y or z > 1, those extra workgroups revisit the same
// indices (redundant but identical writes). TODO confirm the host dispatches 1-D.
@compute @workgroup_size(WORKGROUP_SIZE)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>
) {
    // Never trust the host-supplied bound beyond the actual buffer lengths.
    // In WGSL, an out-of-bounds store to a runtime-sized array may be redirected
    // to some other location in the same buffer. An oversized `size` could
    // therefore overwrite valid elements of C instead of being dropped.
    let count = min(size, min(arrayLength(&A), min(arrayLength(&B), arrayLength(&C))));

    // Total invocations along x across the whole dispatch.
    let stride = num_workgroups.x * WORKGROUP_SIZE;

    for (var i = global_id.x; i < count; i += stride) {
        C[i] = A[i] + B[i];
    }
}