diff --git a/src/shaders/add.wgsl b/src/shaders/add.wgsl index 333e56e..4288742 100644 --- a/src/shaders/add.wgsl +++ b/src/shaders/add.wgsl @@ -3,11 +3,7 @@ enable f16; @group(0) @binding(0) var A: array; @group(0) @binding(1) var B: array; @group(0) @binding(2) var C: array; - -struct TensorInfo { - size: u32, -}; -@group(0) @binding(3) var info: TensorInfo; +@group(0) @binding(3) var size: u32; @compute @workgroup_size(256) fn main( @@ -21,7 +17,7 @@ fn main( var index = global_id.x; // 3. Stride through the tensor elements - while (index < info.size) { + while (index < size) { C[index] = A[index] + B[index]; index += total_threads; // Jump forward by the total thread count }