diff --git a/src/TensorStatic.zig b/src/TensorStatic.zig index 82b9d34..d63bb47 100644 --- a/src/TensorStatic.zig +++ b/src/TensorStatic.zig @@ -77,7 +77,7 @@ pub fn TensorStatic( return @as([*]T, @ptrCast(&self.data))[0..total]; } - /// Element-wise add. Dimensions must match; scales resolve to finer. + /// Element-wise add. Dimensions must match; scales resolve to finer. /// RHS must have the same shape as self, or total == 1 (broadcast). pub inline fn add(self: *const Self, rhs: anytype) TensorStatic( T, @@ -99,7 +99,7 @@ pub fn TensorStatic( return .{ .data = if (comptime sh.isInt(T)) l +| r else l + r }; } - /// Element-wise sub. Dimensions must match; scales resolve to finer. + /// Element-wise sub. Dimensions must match; scales resolve to finer. /// RHS must have the same shape as self, or total == 1 (broadcast). pub inline fn sub(self: *const Self, rhs: anytype) TensorStatic( T, @@ -121,7 +121,7 @@ pub fn TensorStatic( return .{ .data = if (comptime sh.isInt(T)) l -| r else l - r }; } - /// Element-wise multiply. Dimension exponents summed. + /// Element-wise multiply. Dimension exponents summed. /// Shape {1} RHS is automatically broadcast across all elements. pub inline fn mul(self: *const Self, rhs: anytype) TensorStatic( T, @@ -142,7 +142,7 @@ pub fn TensorStatic( return .{ .data = if (comptime sh.isInt(T)) l *| r else l * r }; } - /// Element-wise divide. Dimension exponents subtracted. + /// Element-wise divide. Dimension exponents subtracted. /// Shape {1} RHS is automatically broadcast across all elements. pub inline fn div(self: *const Self, rhs: anytype) TensorStatic( T, @@ -213,17 +213,26 @@ pub fn TensorStatic( /// Extract sub-tensor by half-open ranges [start, end) per axis. /// All bounds comptime. Dims and scales preserved. - /// TODO: Make it usable with negative index + /// Negative indices count from end: -1 = last element. pub inline fn slice( self: *const Self, - comptime ranges: [rank]struct { start: usize, end: usize }, + comptime ranges: [rank]struct { start: ?isize = null, end: ?isize = null }, ) blk: { var ns: [rank]comptime_int = undefined; for (0..rank) |i| { - const s = ranges[i].start; - const e = ranges[i].end; + const dim = @as(isize, @intCast(shape[i])); + const s: isize = blk2: { + const raw = ranges[i].start orelse 0; + break :blk2 if (raw < 0) raw + dim else raw; + }; + const e: isize = blk2: { + const raw = ranges[i].end orelse dim; + break :blk2 if (raw < 0) raw + dim else raw; + }; + if (s < 0) @compileError("slice: start out of bounds after normalization"); + if (e < 0) @compileError("slice: end out of bounds after normalization"); if (s >= e) @compileError("slice: start must be < end"); - if (e > shape[i]) @compileError("slice: end exceeds shape"); + if (e > dim) @compileError("slice: end exceeds shape"); ns[i] = e - s; } const new_shape: [rank]comptime_int = ns; @@ -231,24 +240,30 @@ pub fn TensorStatic( } { const new_shape: [rank]comptime_int = comptime blk: { var ns: [rank]comptime_int = undefined; - for (0..rank) |i| ns[i] = ranges[i].end - ranges[i].start; + for (0..rank) |i| { + const dim = @as(isize, @intCast(shape[i])); + const raw_s = ranges[i].start orelse 0; + const raw_e = ranges[i].end orelse dim; + const s: isize = if (raw_s < 0) raw_s + dim else raw_s; + const e: isize = if (raw_e < 0) raw_e + dim else raw_e; + ns[i] = e - s; + } break :blk ns; }; - const ResultType = TensorStatic(T, dims.argsOpt(), scales.argsOpt(), &new_shape); - const src: [total]T = self.data; var dst: [ResultType.total]T = undefined; - for (0..ResultType.total) |flat| { var src_flat: usize = 0; inline for (0..rank) |i| { + const dim = @as(isize, @intCast(shape[i])); + const raw_s = ranges[i].start orelse 0; + const s: isize = if (raw_s < 0) raw_s + dim else raw_s; const coord = (flat / ResultType.strides_arr[i]) % new_shape[i]; - src_flat += (coord + ranges[i].start) * strides_arr[i]; + src_flat += (coord + @as(usize, @intCast(s))) * strides_arr[i]; } dst[flat] = src[src_flat]; } - return .{ .data = dst }; } @@ -1270,3 +1285,70 @@ test "Slice then scale convert" { try std.testing.expectEqual(2000, converted.data[0]); try std.testing.expectEqual(3000, converted.data[1]); } + +test "Slice 1D negative start" { + const Vec = TensorStatic(i32, .{}, .{}, &.{5}); + const v = Vec{ .data = .{ 10, 20, 30, 40, 50 } }; + const s = v.slice(.{.{ .start = -3, .end = 5 }}); // [2,5) → 30,40,50 + try std.testing.expectEqual(3, @TypeOf(s).total); + try std.testing.expectEqual(30, s.data[0]); + try std.testing.expectEqual(40, s.data[1]); + try std.testing.expectEqual(50, s.data[2]); +} + +test "Slice 1D negative end" { + const Vec = TensorStatic(i32, .{}, .{}, &.{5}); + const v = Vec{ .data = .{ 10, 20, 30, 40, 50 } }; + const s = v.slice(.{.{ .start = 1, .end = -1 }}); // [1,4) → 20,30,40 + try std.testing.expectEqual(3, @TypeOf(s).total); + try std.testing.expectEqual(20, s.data[0]); + try std.testing.expectEqual(30, s.data[1]); + try std.testing.expectEqual(40, s.data[2]); +} + +test "Slice 1D both negative" { + const Vec = TensorStatic(i64, .{}, .{}, &.{6}); + const v = Vec{ .data = .{ 5, 10, 15, 20, 25, 30 } }; + const s = v.slice(.{.{ .start = -4, .end = -1 }}); // [2,5) → 15,20,25 + try std.testing.expectEqual(3, @TypeOf(s).total); + try std.testing.expectEqual(15, s.data[0]); + try std.testing.expectEqual(20, s.data[1]); + try std.testing.expectEqual(25, s.data[2]); +} + +test "Slice 1D null start" { + const Vec = TensorStatic(i32, .{}, .{}, &.{5}); + const v = Vec{ .data = .{ 10, 20, 30, 40, 50 } }; + const s = v.slice(.{.{ .end = -2 }}); // [:-2] → 10,20,30 + try std.testing.expectEqual(3, @TypeOf(s).total); + try std.testing.expectEqual(10, s.data[0]); + try std.testing.expectEqual(20, s.data[1]); + try std.testing.expectEqual(30, s.data[2]); +} + +test "Slice 1D null end" { + const Vec = TensorStatic(i32, .{}, .{}, &.{5}); + const v = Vec{ .data = .{ 10, 20, 30, 40, 50 } }; + const s = v.slice(.{.{ .start = -3 }}); // [-3:] → 30,40,50 + try std.testing.expectEqual(3, @TypeOf(s).total); + try std.testing.expectEqual(30, s.data[0]); + try std.testing.expectEqual(40, s.data[1]); + try std.testing.expectEqual(50, s.data[2]); +} + +test "Slice 2D negative & null indices" { + const Mat = TensorStatic(i32, .{}, .{}, &.{ 4, 4 }); + const m = Mat{ .data = .{ + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + } }; + // last 2 rows, last 2 cols → same as subblock test [2,4)x[2,4) + const s = m.slice(.{ .{ .start = -2, .end = 4 }, .{ .start = -2 } }); + try std.testing.expectEqual(4, @TypeOf(s).total); + try std.testing.expectEqual(11, s.data[0]); + try std.testing.expectEqual(12, s.data[1]); + try std.testing.expectEqual(15, s.data[2]); + try std.testing.expectEqual(16, s.data[3]); +}