From b959f5f28a7097d33403ba68149af95b5e7138da Mon Sep 17 00:00:00 2001 From: adrien Date: Thu, 14 May 2026 01:28:24 +0200 Subject: [PATCH] Added slice to TensorStatic --- src/TensorStatic.zig | 154 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) diff --git a/src/TensorStatic.zig b/src/TensorStatic.zig index 397dfec..82b9d34 100644 --- a/src/TensorStatic.zig +++ b/src/TensorStatic.zig @@ -211,6 +211,47 @@ pub fn TensorStatic( return .{ .data = -self.data }; } + /// Extract sub-tensor by half-open ranges [start, end) per axis. + /// All bounds comptime. Dims and scales preserved. + /// TODO: Make it usable with negative index + pub inline fn slice( + self: *const Self, + comptime ranges: [rank]struct { start: usize, end: usize }, + ) blk: { + var ns: [rank]comptime_int = undefined; + for (0..rank) |i| { + const s = ranges[i].start; + const e = ranges[i].end; + if (s >= e) @compileError("slice: start must be < end"); + if (e > shape[i]) @compileError("slice: end exceeds shape"); + ns[i] = e - s; + } + const new_shape: [rank]comptime_int = ns; + break :blk TensorStatic(T, dims.argsOpt(), scales.argsOpt(), &new_shape); + } { + const new_shape: [rank]comptime_int = comptime blk: { + var ns: [rank]comptime_int = undefined; + for (0..rank) |i| ns[i] = ranges[i].end - ranges[i].start; + break :blk ns; + }; + + const ResultType = TensorStatic(T, dims.argsOpt(), scales.argsOpt(), &new_shape); + + const src: [total]T = self.data; + var dst: [ResultType.total]T = undefined; + + for (0..ResultType.total) |flat| { + var src_flat: usize = 0; + inline for (0..rank) |i| { + const coord = (flat / ResultType.strides_arr[i]) % new_shape[i]; + src_flat += (coord + ranges[i].start) * strides_arr[i]; + } + dst[flat] = src[src_flat]; + } + + return .{ .data = dst }; + } + /// Convert to a compatible Tensor type. /// • Dimension mismatch → compile error. /// • Dest.shape must equal self.shape, or total == 1 -> splat to Dest shape (scalar pattern). @@ -1116,3 +1157,116 @@ test "Tensor strides_arr correctness" { try std.testing.expectEqual(4, T3.strides_arr[1]); try std.testing.expectEqual(1, T3.strides_arr[2]); } + +test "Slice 1D basic" { + const Vec = TensorStatic(i32, .{}, .{}, &.{5}); + var v = Vec{ .data = .{ 10, 20, 30, 40, 50 } }; + const s = v.slice(.{.{ .start = 1, .end = 4 }}); + try std.testing.expectEqual(3, @TypeOf(s).total); + try std.testing.expectEqual(20, s.data[0]); + try std.testing.expectEqual(30, s.data[1]); + try std.testing.expectEqual(40, s.data[2]); +} + +test "Slice 1D full range" { + const Vec = TensorStatic(f32, .{}, .{}, &.{4}); + const v = Vec{ .data = .{ 1.0, 2.0, 3.0, 4.0 } }; + const s = v.slice(.{.{ .start = 0, .end = 4 }}); + try std.testing.expectEqual(4, @TypeOf(s).total); + inline for (0..4) |i| try std.testing.expectEqual(v.data[i], s.data[i]); +} + +test "Slice 1D single element" { + const Vec = TensorStatic(i64, .{}, .{}, &.{6}); + const v = Vec{ .data = .{ 5, 10, 15, 20, 25, 30 } }; + const s = v.slice(.{.{ .start = 3, .end = 4 }}); + try std.testing.expectEqual(1, @TypeOf(s).total); + try std.testing.expectEqual(20, s.data[0]); +} + +test "Slice 1D preserves dims and scales" { + const Meter = TensorStatic(i128, .{ .L = 1 }, .{ .L = .k }, &.{5}); + const v = Meter{ .data = .{ 1, 2, 3, 4, 5 } }; + const s = v.slice(.{.{ .start = 0, .end = 3 }}); + const S = @TypeOf(s); + try std.testing.expectEqual(1, S.dims.get(.L)); + try std.testing.expectEqual(Meter.scales.get(.L), S.scales.get(.L)); +} + +test "Slice 2D rows" { + const Mat = TensorStatic(i32, .{}, .{}, &.{ 4, 3 }); + const m = Mat{ .data = .{ + 1, 2, 3, + 4, 5, 6, + 7, 8, 9, + 10, 11, 12, + } }; + // rows [1,3), all cols + const s = m.slice(.{ .{ .start = 1, .end = 3 }, .{ .start = 0, .end = 3 } }); + try std.testing.expectEqual(6, @TypeOf(s).total); + try std.testing.expectEqual(4, s.data[0]); + try std.testing.expectEqual(5, s.data[1]); + try std.testing.expectEqual(6, s.data[2]); + try std.testing.expectEqual(7, s.data[3]); + try std.testing.expectEqual(8, s.data[4]); + try std.testing.expectEqual(9, s.data[5]); +} + +test "Slice 2D cols" { + const Mat = TensorStatic(i32, .{}, .{}, &.{ 3, 4 }); + const m = Mat{ .data = .{ + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + } }; + // all rows, cols [1,3) + const s = m.slice(.{ .{ .start = 0, .end = 3 }, .{ .start = 1, .end = 3 } }); + const S = @TypeOf(s); + try std.testing.expectEqual(3, S.shape[0]); + try std.testing.expectEqual(2, S.shape[1]); + try std.testing.expectEqual(2, s.data[0]); + try std.testing.expectEqual(3, s.data[1]); + try std.testing.expectEqual(6, s.data[2]); + try std.testing.expectEqual(7, s.data[3]); + try std.testing.expectEqual(10, s.data[4]); + try std.testing.expectEqual(11, s.data[5]); +} + +test "Slice 2D subblock" { + const Mat = TensorStatic(f64, .{}, .{}, &.{ 4, 4 }); + const m = Mat{ .data = .{ + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + } }; + // centre 2x2 + const s = m.slice(.{ .{ .start = 1, .end = 3 }, .{ .start = 1, .end = 3 } }); + try std.testing.expectEqual(4, @TypeOf(s).total); + try std.testing.expectApproxEqAbs(6.0, s.data[0], 1e-9); + try std.testing.expectApproxEqAbs(7.0, s.data[1], 1e-9); + try std.testing.expectApproxEqAbs(10.0, s.data[2], 1e-9); + try std.testing.expectApproxEqAbs(11.0, s.data[3], 1e-9); +} + +test "Slice then add" { + const Meter = TensorStatic(i32, .{ .L = 1 }, .{}, &.{5}); + const a = Meter{ .data = .{ 1, 2, 3, 4, 5 } }; + const b = Meter{ .data = .{ 10, 20, 30, 40, 50 } }; + const sa = a.slice(.{.{ .start = 0, .end = 3 }}); + const sb = b.slice(.{.{ .start = 2, .end = 5 }}); + const r = sa.add(sb); + try std.testing.expectEqual(31, r.data[0]); // 1+30 + try std.testing.expectEqual(42, r.data[1]); // 2+40 + try std.testing.expectEqual(53, r.data[2]); // 3+50 +} + +test "Slice then scale convert" { + const KiloMeter = TensorStatic(i64, .{ .L = 1 }, .{ .L = .k }, &.{4}); + const Meter = TensorStatic(i64, .{ .L = 1 }, .{}, &.{2}); + const v = KiloMeter{ .data = .{ 1, 2, 3, 4 } }; + const s = v.slice(.{.{ .start = 1, .end = 3 }}); // {2, 3} km + const converted = s.to(Meter); + try std.testing.expectEqual(2000, converted.data[0]); + try std.testing.expectEqual(3000, converted.data[1]); +}