Added slice to TensorStatic

This commit is contained in:
adrien 2026-05-14 01:28:24 +02:00
parent 6ba1e664c1
commit b959f5f28a

View File

@ -211,6 +211,47 @@ pub fn TensorStatic(
return .{ .data = -self.data }; return .{ .data = -self.data };
} }
/// Extract sub-tensor by half-open ranges [start, end) per axis.
/// All bounds comptime. Dims and scales preserved.
/// TODO: Make it usable with negative index
pub inline fn slice(
self: *const Self,
comptime ranges: [rank]struct { start: usize, end: usize },
) blk: {
var ns: [rank]comptime_int = undefined;
for (0..rank) |i| {
const s = ranges[i].start;
const e = ranges[i].end;
if (s >= e) @compileError("slice: start must be < end");
if (e > shape[i]) @compileError("slice: end exceeds shape");
ns[i] = e - s;
}
const new_shape: [rank]comptime_int = ns;
break :blk TensorStatic(T, dims.argsOpt(), scales.argsOpt(), &new_shape);
} {
const new_shape: [rank]comptime_int = comptime blk: {
var ns: [rank]comptime_int = undefined;
for (0..rank) |i| ns[i] = ranges[i].end - ranges[i].start;
break :blk ns;
};
const ResultType = TensorStatic(T, dims.argsOpt(), scales.argsOpt(), &new_shape);
const src: [total]T = self.data;
var dst: [ResultType.total]T = undefined;
for (0..ResultType.total) |flat| {
var src_flat: usize = 0;
inline for (0..rank) |i| {
const coord = (flat / ResultType.strides_arr[i]) % new_shape[i];
src_flat += (coord + ranges[i].start) * strides_arr[i];
}
dst[flat] = src[src_flat];
}
return .{ .data = dst };
}
/// Convert to a compatible Tensor type. /// Convert to a compatible Tensor type.
/// Dimension mismatch compile error. /// Dimension mismatch compile error.
/// Dest.shape must equal self.shape, or total == 1 -> splat to Dest shape (scalar pattern). /// Dest.shape must equal self.shape, or total == 1 -> splat to Dest shape (scalar pattern).
@ -1116,3 +1157,116 @@ test "Tensor strides_arr correctness" {
try std.testing.expectEqual(4, T3.strides_arr[1]); try std.testing.expectEqual(4, T3.strides_arr[1]);
try std.testing.expectEqual(1, T3.strides_arr[2]); try std.testing.expectEqual(1, T3.strides_arr[2]);
} }
test "Slice 1D basic" {
const Vec = TensorStatic(i32, .{}, .{}, &.{5});
var v = Vec{ .data = .{ 10, 20, 30, 40, 50 } };
const s = v.slice(.{.{ .start = 1, .end = 4 }});
try std.testing.expectEqual(3, @TypeOf(s).total);
try std.testing.expectEqual(20, s.data[0]);
try std.testing.expectEqual(30, s.data[1]);
try std.testing.expectEqual(40, s.data[2]);
}
test "Slice 1D full range" {
const Vec = TensorStatic(f32, .{}, .{}, &.{4});
const v = Vec{ .data = .{ 1.0, 2.0, 3.0, 4.0 } };
const s = v.slice(.{.{ .start = 0, .end = 4 }});
try std.testing.expectEqual(4, @TypeOf(s).total);
inline for (0..4) |i| try std.testing.expectEqual(v.data[i], s.data[i]);
}
test "Slice 1D single element" {
const Vec = TensorStatic(i64, .{}, .{}, &.{6});
const v = Vec{ .data = .{ 5, 10, 15, 20, 25, 30 } };
const s = v.slice(.{.{ .start = 3, .end = 4 }});
try std.testing.expectEqual(1, @TypeOf(s).total);
try std.testing.expectEqual(20, s.data[0]);
}
test "Slice 1D preserves dims and scales" {
const Meter = TensorStatic(i128, .{ .L = 1 }, .{ .L = .k }, &.{5});
const v = Meter{ .data = .{ 1, 2, 3, 4, 5 } };
const s = v.slice(.{.{ .start = 0, .end = 3 }});
const S = @TypeOf(s);
try std.testing.expectEqual(1, S.dims.get(.L));
try std.testing.expectEqual(Meter.scales.get(.L), S.scales.get(.L));
}
test "Slice 2D rows" {
const Mat = TensorStatic(i32, .{}, .{}, &.{ 4, 3 });
const m = Mat{ .data = .{
1, 2, 3,
4, 5, 6,
7, 8, 9,
10, 11, 12,
} };
// rows [1,3), all cols
const s = m.slice(.{ .{ .start = 1, .end = 3 }, .{ .start = 0, .end = 3 } });
try std.testing.expectEqual(6, @TypeOf(s).total);
try std.testing.expectEqual(4, s.data[0]);
try std.testing.expectEqual(5, s.data[1]);
try std.testing.expectEqual(6, s.data[2]);
try std.testing.expectEqual(7, s.data[3]);
try std.testing.expectEqual(8, s.data[4]);
try std.testing.expectEqual(9, s.data[5]);
}
test "Slice 2D cols" {
const Mat = TensorStatic(i32, .{}, .{}, &.{ 3, 4 });
const m = Mat{ .data = .{
1, 2, 3, 4,
5, 6, 7, 8,
9, 10, 11, 12,
} };
// all rows, cols [1,3)
const s = m.slice(.{ .{ .start = 0, .end = 3 }, .{ .start = 1, .end = 3 } });
const S = @TypeOf(s);
try std.testing.expectEqual(3, S.shape[0]);
try std.testing.expectEqual(2, S.shape[1]);
try std.testing.expectEqual(2, s.data[0]);
try std.testing.expectEqual(3, s.data[1]);
try std.testing.expectEqual(6, s.data[2]);
try std.testing.expectEqual(7, s.data[3]);
try std.testing.expectEqual(10, s.data[4]);
try std.testing.expectEqual(11, s.data[5]);
}
test "Slice 2D subblock" {
const Mat = TensorStatic(f64, .{}, .{}, &.{ 4, 4 });
const m = Mat{ .data = .{
1, 2, 3, 4,
5, 6, 7, 8,
9, 10, 11, 12,
13, 14, 15, 16,
} };
// centre 2x2
const s = m.slice(.{ .{ .start = 1, .end = 3 }, .{ .start = 1, .end = 3 } });
try std.testing.expectEqual(4, @TypeOf(s).total);
try std.testing.expectApproxEqAbs(6.0, s.data[0], 1e-9);
try std.testing.expectApproxEqAbs(7.0, s.data[1], 1e-9);
try std.testing.expectApproxEqAbs(10.0, s.data[2], 1e-9);
try std.testing.expectApproxEqAbs(11.0, s.data[3], 1e-9);
}
test "Slice then add" {
const Meter = TensorStatic(i32, .{ .L = 1 }, .{}, &.{5});
const a = Meter{ .data = .{ 1, 2, 3, 4, 5 } };
const b = Meter{ .data = .{ 10, 20, 30, 40, 50 } };
const sa = a.slice(.{.{ .start = 0, .end = 3 }});
const sb = b.slice(.{.{ .start = 2, .end = 5 }});
const r = sa.add(sb);
try std.testing.expectEqual(31, r.data[0]); // 1+30
try std.testing.expectEqual(42, r.data[1]); // 2+40
try std.testing.expectEqual(53, r.data[2]); // 3+50
}
test "Slice then scale convert" {
const KiloMeter = TensorStatic(i64, .{ .L = 1 }, .{ .L = .k }, &.{4});
const Meter = TensorStatic(i64, .{ .L = 1 }, .{}, &.{2});
const v = KiloMeter{ .data = .{ 1, 2, 3, 4 } };
const s = v.slice(.{.{ .start = 1, .end = 3 }}); // {2, 3} km
const converted = s.to(Meter);
try std.testing.expectEqual(2000, converted.data[0]);
try std.testing.expectEqual(3000, converted.data[1]);
}