# Tensor GPU: Memory & Pipeline Strategy

**Recommended approach:** a lazy operation list, ping-pong buffers, and a single command buffer submitted per `compute()`.

---

## Architecture

**Problem with eager pipelines:**
```
m1.add(m2)     → dispatch + sync point (slow)
.mul(5)        → dispatch + sync point (slow)
.sub(m3)       → dispatch + sync point (slow)
Result: 3× GPU kernel submission overhead. Many intermediate buffers.
```

**Better: Build graph, execute once:**
```
m1.add(m2).mul(5).sub(m3)   // build operation list
                .compute()   // ONE command buffer, all ops
```

---

## Implementation

```zig
const std = @import("std");
const c = @cImport(@cInclude("wgpu.h"));

/// One recorded element-wise step of a lazy `TensorGPU` chain.
pub const Operation = union(enum) {
    add: Binary,
    mul: Scalar,
    sub: Binary,
    div: Scalar,

    /// Payload for steps that combine the tensor with a second tensor.
    pub const Binary = struct { other: *TensorGPU };
    /// Payload for steps that combine the tensor with one f32 constant.
    pub const Scalar = struct { scalar: f32 };
};

/// A 2-D f32 tensor backed by one WebGPU storage buffer, evaluated lazily.
///
/// `add`/`sub`/`mul`/`div` only record an `Operation`. `compute()` encodes every
/// recorded op into a single command buffer (one compute pass per op),
/// ping-ponging between `buffer` and a temporary buffer, and submits it once.
pub const TensorGPU = struct {
    gpu: *AllocatorGPU,
    buffer: c.WGPUBuffer,
    shape: [2]u32, // rows, cols
    element_count: u32,
    buf_bytes: u32,

    /// Operations recorded since the last successful `compute()`, in call order.
    operations: std.ArrayList(Operation),
    /// `false` while `operations` holds work that has not been submitted yet.
    is_computed: bool,
    allocator: std.mem.Allocator,

    /// Creates a `rows × cols` tensor whose storage buffer holds
    /// `rows * cols` f32 values, with an empty operation list.
    ///
    /// Errors: `error.TensorTooLarge` if the byte size does not fit in `u32`,
    /// `error.BufferCreate` if wgpu returns no buffer, or `error.OutOfMemory`.
    pub fn init(gpu: *AllocatorGPU, shape: [2]u32, allocator: std.mem.Allocator) !TensorGPU {
        // Checked arithmetic: a silent wrap would create an undersized buffer.
        const element_count = std.math.mul(u32, shape[0], shape[1]) catch return error.TensorTooLarge;
        const buf_bytes = std.math.mul(u32, element_count, @sizeOf(f32)) catch return error.TensorTooLarge;

        const buffer = c.wgpuDeviceCreateBuffer(gpu.device, &.{
            .usage = c.WGPUBufferUsage_Storage
                   | c.WGPUBufferUsage_CopySrc
                   | c.WGPUBufferUsage_CopyDst,
            .size = buf_bytes,
        }) orelse return error.BufferCreate;
        // Don't leak the GPU buffer if the host-side allocation below fails.
        errdefer c.wgpuBufferRelease(buffer);

        return .{
            .gpu = gpu,
            .buffer = buffer,
            .shape = shape,
            .element_count = element_count,
            .buf_bytes = buf_bytes,
            .operations = try std.ArrayList(Operation).initCapacity(allocator, 8),
            .is_computed = true, // nothing recorded yet
            .allocator = allocator,
        };
    }

    /// Releases the GPU buffer and the operation list. The tensor is unusable afterwards.
    pub fn deinit(self: *TensorGPU) void {
        c.wgpuBufferRelease(self.buffer);
        self.operations.deinit();
        self.* = undefined;
    }

    /// Records `self = self + other` (element-wise). Checked at `compute()`:
    /// `other` must have the same shape, must not be `self`, and must have no
    /// pending operations.
    pub fn add(self: *TensorGPU, other: *TensorGPU) *TensorGPU {
        return self.record(.{ .add = .{ .other = other } });
    }

    /// Records `self = self * scalar` (element-wise).
    pub fn mul(self: *TensorGPU, scalar: f32) *TensorGPU {
        return self.record(.{ .mul = .{ .scalar = scalar } });
    }

    /// Records `self = self - other` (element-wise); same operand rules as `add`.
    pub fn sub(self: *TensorGPU, other: *TensorGPU) *TensorGPU {
        return self.record(.{ .sub = .{ .other = other } });
    }

    /// Records `self = self / scalar` (element-wise).
    pub fn div(self: *TensorGPU, scalar: f32) *TensorGPU {
        return self.record(.{ .div = .{ .scalar = scalar } });
    }

    /// Appends `op` to the pending list and returns `self` for chaining.
    fn record(self: *TensorGPU, op: Operation) *TensorGPU {
        // The builders return `*TensorGPU` for chaining, so they cannot propagate
        // error.OutOfMemory; fail loudly rather than `catch unreachable`, which is
        // undefined behaviour in release builds.
        self.operations.append(op) catch @panic("TensorGPU: out of memory while recording an operation");
        self.is_computed = false;
        return self;
    }

    /// Encodes every recorded operation into one command buffer and submits it.
    ///
    /// Does nothing if no operations are pending. The submission is not awaited:
    /// `buffer` holds the result once the queue has executed the submitted work.
    /// On success the operation list is cleared; on error it is left intact.
    ///
    /// Errors: `error.AliasedOperand`, `error.ShapeMismatch`,
    /// `error.OperandNotComputed` (see `validateOperands`), plus wgpu
    /// object-creation failures (`error.TempBuffer`, `error.Encoder`, ...).
    pub fn compute(self: *TensorGPU) !void {
        if (self.is_computed or self.operations.items.len == 0) return;

        try self.validateOperands();

        // wgpu rejects zero-sized buffer bindings; an empty tensor has nothing to do.
        if (self.element_count == 0) {
            self.finishOperations();
            return;
        }

        // Ping-pong partner for `self.buffer`. Released when this call returns;
        // wgpuBufferRelease drops a reference rather than destroying the buffer.
        const buf_temp = c.wgpuDeviceCreateBuffer(self.gpu.device, &.{
            .usage = c.WGPUBufferUsage_Storage
                   | c.WGPUBufferUsage_CopySrc
                   | c.WGPUBufferUsage_CopyDst,
            .size = self.buf_bytes,
        }) orelse return error.TempBuffer;
        defer c.wgpuBufferRelease(buf_temp);

        // Build ONE command encoder for all operations.
        const encoder = c.wgpuDeviceCreateCommandEncoder(self.gpu.device, null)
            orelse return error.Encoder;
        defer c.wgpuCommandEncoderRelease(encoder);

        // Explicit optional handle types: `orelse` above yields non-optional
        // handles, which could not be swapped with the optional `self.buffer`.
        var buf_read: c.WGPUBuffer = self.buffer;
        var buf_write: c.WGPUBuffer = buf_temp;
        for (self.operations.items) |op| {
            try self.encodeOp(encoder, op, buf_read, buf_write);
            // The op's output becomes the next op's input.
            std.mem.swap(c.WGPUBuffer, &buf_read, &buf_write);
        }

        // The first op writes the temp buffer, so after an odd number of ops
        // the result lives there and must be copied back.
        if (self.operations.items.len % 2 == 1) {
            c.wgpuCommandEncoderCopyBufferToBuffer(
                encoder, buf_temp, 0, self.buffer, 0, self.buf_bytes,
            );
        }

        const cmdbuf: c.WGPUCommandBuffer = c.wgpuCommandEncoderFinish(encoder, null)
            orelse return error.CommandBuffer;
        defer c.wgpuCommandBufferRelease(cmdbuf);

        c.wgpuQueueSubmit(self.gpu.queue, 1, &cmdbuf);

        self.finishOperations();
    }

    /// Marks the pending list as submitted, keeping its capacity for the next chain.
    fn finishOperations(self: *TensorGPU) void {
        self.operations.clearRetainingCapacity();
        self.is_computed = true;
    }

    /// Checks every binary operand before any GPU work is encoded.
    fn validateOperands(self: *TensorGPU) !void {
        for (self.operations.items) |op| {
            const other: *TensorGPU = switch (op) {
                .add => |p| p.other,
                .sub => |p| p.other,
                .mul, .div => continue,
            };
            // `self.buffer` is overwritten during the ping-pong, so using it as an
            // operand would read an intermediate value, or bind one buffer as both
            // read-only and read_write in the same dispatch.
            if (other == self) return error.AliasedOperand;
            // The operand is bound with `self.buf_bytes`, so sizes must agree.
            if (!std.mem.eql(u32, &other.shape, &self.shape)) return error.ShapeMismatch;
            // A pending operand's buffer does not hold its chained result yet.
            if (!other.is_computed) return error.OperandNotComputed;
        }
    }

    /// Encodes one compute pass that applies `op` element-wise, reading
    /// `buf_in` (and the operand tensor for binary ops) and writing `buf_out`.
    fn encodeOp(
        self: *const TensorGPU,
        encoder: c.WGPUCommandEncoder,
        op: Operation,
        buf_in: c.WGPUBuffer,
        buf_out: c.WGPUBuffer,
    ) !void {
        const shader_code = switch (op) {
            .add => SHADER_ADD,
            .mul => SHADER_MUL,
            .sub => SHADER_SUB,
            .div => SHADER_DIV,
        };
        // Second tensor for binary ops, bound at binding 2. Each tag gets its own
        // prong: reading `op.add` while `op` is `.sub` is illegal for a tagged union.
        const operand: ?*TensorGPU = switch (op) {
            .add => |p| p.other,
            .sub => |p| p.other,
            .mul, .div => null,
        };
        // Scalar for unary ops, fed to the shaders' `override scalar` constant.
        const scalar: ?f32 = switch (op) {
            .mul => |p| p.scalar,
            .div => |p| p.scalar,
            .add, .sub => null,
        };

        var wgsl_src = c.WGPUShaderSourceWGSL{
            .chain = .{ .sType = c.WGPUSType_ShaderSourceWGSL },
            .code = sv(shader_code),
        };

        const shader = c.wgpuDeviceCreateShaderModule(self.gpu.device, &.{
            .nextInChain = @ptrCast(&wgsl_src),
        }) orelse return error.Shader;
        defer c.wgpuShaderModuleRelease(shader);

        const constants = [_]c.WGPUConstantEntry{
            .{ .key = sv("scalar"), .value = scalar orelse 0 },
        };

        // NOTE(review): the module/pipeline is rebuilt for every op on every
        // compute(); caching per (op kind, scalar) would avoid repeated compiles.
        const pipeline = c.wgpuDeviceCreateComputePipeline(self.gpu.device, &.{
            .compute = .{
                .module = shader,
                .entryPoint = sv("main"),
                // Only the scalar shaders declare `scalar`; WebGPU rejects
                // constants whose key the shader does not declare.
                .constantCount = @intFromBool(scalar != null),
                .constants = &constants,
            },
        }) orelse return error.Pipeline;
        defer c.wgpuComputePipelineRelease(pipeline);

        // Auto layout derived from the shader's bindings.
        const bgl = c.wgpuComputePipelineGetBindGroupLayout(pipeline, 0);
        defer c.wgpuBindGroupLayoutRelease(bgl);

        const entries = [3]c.WGPUBindGroupEntry{
            .{ .binding = 0, .buffer = buf_in, .size = self.buf_bytes },
            .{ .binding = 1, .buffer = buf_out, .size = self.buf_bytes },
            .{ .binding = 2, .buffer = if (operand) |t| t.buffer else null, .size = self.buf_bytes },
        };
        // Unary shaders declare no binding 2, so only the first two entries are passed.
        const entry_count: u32 = if (operand != null) 3 else 2;

        const bind_group = c.wgpuDeviceCreateBindGroup(self.gpu.device, &.{
            .layout = bgl,
            // The C descriptor takes a pointer plus `entryCount`, not a Zig slice.
            .entries = &entries,
            .entryCount = entry_count,
        }) orelse return error.BindGroup;
        defer c.wgpuBindGroupRelease(bind_group);

        const pass = c.wgpuCommandEncoderBeginComputePass(encoder, null)
            orelse return error.ComputePass;
        defer c.wgpuComputePassEncoderRelease(pass);

        c.wgpuComputePassEncoderSetPipeline(pass, pipeline);
        c.wgpuComputePassEncoderSetBindGroup(pass, 0, bind_group, 0, null);

        // Enough 4×4 workgroups (matching `@workgroup_size(4, 4)`) to cover
        // cols × rows threads. NOTE(review): each count must stay within the
        // device's maxComputeWorkgroupsPerDimension limit.
        const workgroups_x = (self.shape[1] + 3) / 4;
        const workgroups_y = (self.shape[0] + 3) / 4;
        c.wgpuComputePassEncoderDispatchWorkgroups(pass, workgroups_x, workgroups_y, 1);

        c.wgpuComputePassEncoderEnd(pass);
    }
};

// ── Shaders ──────────────────────────────────────────────────────────────────
//
// All kernels are element-wise, so each thread only needs a unique flat index:
// the dispatched thread grid (4 * num_workgroups.x threads wide) is linearised
// row by row, and threads past the end of the buffer return early. This stays
// correct when rows/cols are not multiples of 4. Unary kernels take their
// operand through the pipeline-overridable constant `scalar`.

const SHADER_ADD =
    \\@group(0) @binding(0) var<storage, read>       mat_a : array<f32>;
    \\@group(0) @binding(1) var<storage, read_write> mat_c : array<f32>;
    \\@group(0) @binding(2) var<storage, read>       mat_b : array<f32>;
    \\@compute @workgroup_size(4, 4)
    \\fn main(@builtin(global_invocation_id) gid : vec3<u32>,
    \\        @builtin(num_workgroups) nwg : vec3<u32>) {
    \\    let idx = gid.y * (nwg.x * 4u) + gid.x;
    \\    if (idx >= arrayLength(&mat_c)) { return; }
    \\    mat_c[idx] = mat_a[idx] + mat_b[idx];
    \\}
;

const SHADER_MUL =
    \\override scalar : f32 = 1.0;
    \\@group(0) @binding(0) var<storage, read>       mat_a : array<f32>;
    \\@group(0) @binding(1) var<storage, read_write> mat_c : array<f32>;
    \\@compute @workgroup_size(4, 4)
    \\fn main(@builtin(global_invocation_id) gid : vec3<u32>,
    \\        @builtin(num_workgroups) nwg : vec3<u32>) {
    \\    let idx = gid.y * (nwg.x * 4u) + gid.x;
    \\    if (idx >= arrayLength(&mat_c)) { return; }
    \\    mat_c[idx] = mat_a[idx] * scalar;
    \\}
;

const SHADER_SUB =
    \\@group(0) @binding(0) var<storage, read>       mat_a : array<f32>;
    \\@group(0) @binding(1) var<storage, read_write> mat_c : array<f32>;
    \\@group(0) @binding(2) var<storage, read>       mat_b : array<f32>;
    \\@compute @workgroup_size(4, 4)
    \\fn main(@builtin(global_invocation_id) gid : vec3<u32>,
    \\        @builtin(num_workgroups) nwg : vec3<u32>) {
    \\    let idx = gid.y * (nwg.x * 4u) + gid.x;
    \\    if (idx >= arrayLength(&mat_c)) { return; }
    \\    mat_c[idx] = mat_a[idx] - mat_b[idx];
    \\}
;

const SHADER_DIV =
    \\override scalar : f32 = 1.0;
    \\@group(0) @binding(0) var<storage, read>       mat_a : array<f32>;
    \\@group(0) @binding(1) var<storage, read_write> mat_c : array<f32>;
    \\@compute @workgroup_size(4, 4)
    \\fn main(@builtin(global_invocation_id) gid : vec3<u32>,
    \\        @builtin(num_workgroups) nwg : vec3<u32>) {
    \\    let idx = gid.y * (nwg.x * 4u) + gid.x;
    \\    if (idx >= arrayLength(&mat_c)) { return; }
    \\    mat_c[idx] = mat_a[idx] / scalar;
    \\}
;
```

---

## Usage

```zig
var gpu_alloc = try AllocatorGPU.init(allocator);
defer gpu_alloc.deinit();

// Register each cleanup right after its init so an early failure cannot leak.
var m1 = try TensorGPU.init(&gpu_alloc, .{ 4, 4 }, allocator);
defer m1.deinit();
var m2 = try TensorGPU.init(&gpu_alloc, .{ 4, 4 }, allocator);
defer m2.deinit();

// Chain: lazy, no GPU work until compute(). Operands are passed by pointer.
// Avoid using the receiver as its own operand (e.g. m1.sub(&m1)): its buffer
// is overwritten during the ping-pong. compute() can fail, so handle its error.
try m1.add(&m2).mul(5).sub(&m2).compute(); // ← NOW encodes + submits all ops at once

// compute() only submits: m1.buffer holds the final result once the queue has
// executed the work; read it back through a mapped staging buffer.
```

---

## Memory Breakdown

| Buffer | Lifetime | Size | Notes |
|--------|----------|------|-------|
| `m1.buffer` | Persistent (user owns) | N×4 bytes | Input + final output |
| `m2.buffer` | Persistent (user owns) | N×4 bytes | Input (read-only) |
| `buf_temp` (ping-pong) | compute() scope | N×4 bytes | Allocated/freed per compute() |

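**Peak GPU memory for the 3-op chain above:** 2 persistent buffers (`m1`, `m2`) + 1 temp = 3× the data size, independent of chain length. An eager pipeline would instead add one buffer per intermediate result.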

---

## Key Points

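- **One command buffer:** each op is encoded as its own compute pass in a single command buffer, which is submitted once (this batches submission; it is not kernel fusion)
- **Ping-pong:** swap buf_read ↔ buf_write after each op (one temp buffer per `compute()`, no per-op allocations)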
- **Lazy:** operations queued until `.compute()` called
- **No intermediate tensors:** user doesn't allocate intermediate results
- **Per-compute cleanup:** temp buffer freed immediately after execution

A 100-op chain therefore has the same peak: the tensors you allocated plus one temp buffer.
