slice can now use null value like python [2:]
This commit is contained in:
parent
b959f5f28a
commit
f702c1e09a
@ -77,7 +77,7 @@ pub fn TensorStatic(
|
|||||||
return @as([*]T, @ptrCast(&self.data))[0..total];
|
return @as([*]T, @ptrCast(&self.data))[0..total];
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Element-wise add. Dimensions must match; scales resolve to finer.
|
/// Element-wise add. Dimensions must match; scales resolve to finer.
|
||||||
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
||||||
pub inline fn add(self: *const Self, rhs: anytype) TensorStatic(
|
pub inline fn add(self: *const Self, rhs: anytype) TensorStatic(
|
||||||
T,
|
T,
|
||||||
@ -99,7 +99,7 @@ pub fn TensorStatic(
|
|||||||
return .{ .data = if (comptime sh.isInt(T)) l +| r else l + r };
|
return .{ .data = if (comptime sh.isInt(T)) l +| r else l + r };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Element-wise sub. Dimensions must match; scales resolve to finer.
|
/// Element-wise sub. Dimensions must match; scales resolve to finer.
|
||||||
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
||||||
pub inline fn sub(self: *const Self, rhs: anytype) TensorStatic(
|
pub inline fn sub(self: *const Self, rhs: anytype) TensorStatic(
|
||||||
T,
|
T,
|
||||||
@ -121,7 +121,7 @@ pub fn TensorStatic(
|
|||||||
return .{ .data = if (comptime sh.isInt(T)) l -| r else l - r };
|
return .{ .data = if (comptime sh.isInt(T)) l -| r else l - r };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Element-wise multiply. Dimension exponents summed.
|
/// Element-wise multiply. Dimension exponents summed.
|
||||||
/// Shape {1} RHS is automatically broadcast across all elements.
|
/// Shape {1} RHS is automatically broadcast across all elements.
|
||||||
pub inline fn mul(self: *const Self, rhs: anytype) TensorStatic(
|
pub inline fn mul(self: *const Self, rhs: anytype) TensorStatic(
|
||||||
T,
|
T,
|
||||||
@ -142,7 +142,7 @@ pub fn TensorStatic(
|
|||||||
return .{ .data = if (comptime sh.isInt(T)) l *| r else l * r };
|
return .{ .data = if (comptime sh.isInt(T)) l *| r else l * r };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Element-wise divide. Dimension exponents subtracted.
|
/// Element-wise divide. Dimension exponents subtracted.
|
||||||
/// Shape {1} RHS is automatically broadcast across all elements.
|
/// Shape {1} RHS is automatically broadcast across all elements.
|
||||||
pub inline fn div(self: *const Self, rhs: anytype) TensorStatic(
|
pub inline fn div(self: *const Self, rhs: anytype) TensorStatic(
|
||||||
T,
|
T,
|
||||||
@ -213,17 +213,26 @@ pub fn TensorStatic(
|
|||||||
|
|
||||||
/// Extract sub-tensor by half-open ranges [start, end) per axis.
|
/// Extract sub-tensor by half-open ranges [start, end) per axis.
|
||||||
/// All bounds comptime. Dims and scales preserved.
|
/// All bounds comptime. Dims and scales preserved.
|
||||||
/// TODO: Make it usable with negative index
|
/// Negative indices count from end: -1 = last element.
|
||||||
pub inline fn slice(
|
pub inline fn slice(
|
||||||
self: *const Self,
|
self: *const Self,
|
||||||
comptime ranges: [rank]struct { start: usize, end: usize },
|
comptime ranges: [rank]struct { start: ?isize = null, end: ?isize = null },
|
||||||
) blk: {
|
) blk: {
|
||||||
var ns: [rank]comptime_int = undefined;
|
var ns: [rank]comptime_int = undefined;
|
||||||
for (0..rank) |i| {
|
for (0..rank) |i| {
|
||||||
const s = ranges[i].start;
|
const dim = @as(isize, @intCast(shape[i]));
|
||||||
const e = ranges[i].end;
|
const s: isize = blk2: {
|
||||||
|
const raw = ranges[i].start orelse 0;
|
||||||
|
break :blk2 if (raw < 0) raw + dim else raw;
|
||||||
|
};
|
||||||
|
const e: isize = blk2: {
|
||||||
|
const raw = ranges[i].end orelse dim;
|
||||||
|
break :blk2 if (raw < 0) raw + dim else raw;
|
||||||
|
};
|
||||||
|
if (s < 0) @compileError("slice: start out of bounds after normalization");
|
||||||
|
if (e < 0) @compileError("slice: end out of bounds after normalization");
|
||||||
if (s >= e) @compileError("slice: start must be < end");
|
if (s >= e) @compileError("slice: start must be < end");
|
||||||
if (e > shape[i]) @compileError("slice: end exceeds shape");
|
if (e > dim) @compileError("slice: end exceeds shape");
|
||||||
ns[i] = e - s;
|
ns[i] = e - s;
|
||||||
}
|
}
|
||||||
const new_shape: [rank]comptime_int = ns;
|
const new_shape: [rank]comptime_int = ns;
|
||||||
@ -231,24 +240,30 @@ pub fn TensorStatic(
|
|||||||
} {
|
} {
|
||||||
const new_shape: [rank]comptime_int = comptime blk: {
|
const new_shape: [rank]comptime_int = comptime blk: {
|
||||||
var ns: [rank]comptime_int = undefined;
|
var ns: [rank]comptime_int = undefined;
|
||||||
for (0..rank) |i| ns[i] = ranges[i].end - ranges[i].start;
|
for (0..rank) |i| {
|
||||||
|
const dim = @as(isize, @intCast(shape[i]));
|
||||||
|
const raw_s = ranges[i].start orelse 0;
|
||||||
|
const raw_e = ranges[i].end orelse dim;
|
||||||
|
const s: isize = if (raw_s < 0) raw_s + dim else raw_s;
|
||||||
|
const e: isize = if (raw_e < 0) raw_e + dim else raw_e;
|
||||||
|
ns[i] = e - s;
|
||||||
|
}
|
||||||
break :blk ns;
|
break :blk ns;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ResultType = TensorStatic(T, dims.argsOpt(), scales.argsOpt(), &new_shape);
|
const ResultType = TensorStatic(T, dims.argsOpt(), scales.argsOpt(), &new_shape);
|
||||||
|
|
||||||
const src: [total]T = self.data;
|
const src: [total]T = self.data;
|
||||||
var dst: [ResultType.total]T = undefined;
|
var dst: [ResultType.total]T = undefined;
|
||||||
|
|
||||||
for (0..ResultType.total) |flat| {
|
for (0..ResultType.total) |flat| {
|
||||||
var src_flat: usize = 0;
|
var src_flat: usize = 0;
|
||||||
inline for (0..rank) |i| {
|
inline for (0..rank) |i| {
|
||||||
|
const dim = @as(isize, @intCast(shape[i]));
|
||||||
|
const raw_s = ranges[i].start orelse 0;
|
||||||
|
const s: isize = if (raw_s < 0) raw_s + dim else raw_s;
|
||||||
const coord = (flat / ResultType.strides_arr[i]) % new_shape[i];
|
const coord = (flat / ResultType.strides_arr[i]) % new_shape[i];
|
||||||
src_flat += (coord + ranges[i].start) * strides_arr[i];
|
src_flat += (coord + @as(usize, @intCast(s))) * strides_arr[i];
|
||||||
}
|
}
|
||||||
dst[flat] = src[src_flat];
|
dst[flat] = src[src_flat];
|
||||||
}
|
}
|
||||||
|
|
||||||
return .{ .data = dst };
|
return .{ .data = dst };
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1270,3 +1285,70 @@ test "Slice then scale convert" {
|
|||||||
try std.testing.expectEqual(2000, converted.data[0]);
|
try std.testing.expectEqual(2000, converted.data[0]);
|
||||||
try std.testing.expectEqual(3000, converted.data[1]);
|
try std.testing.expectEqual(3000, converted.data[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test "Slice 1D negative start" {
|
||||||
|
const Vec = TensorStatic(i32, .{}, .{}, &.{5});
|
||||||
|
const v = Vec{ .data = .{ 10, 20, 30, 40, 50 } };
|
||||||
|
const s = v.slice(.{.{ .start = -3, .end = 5 }}); // [2,5) → 30,40,50
|
||||||
|
try std.testing.expectEqual(3, @TypeOf(s).total);
|
||||||
|
try std.testing.expectEqual(30, s.data[0]);
|
||||||
|
try std.testing.expectEqual(40, s.data[1]);
|
||||||
|
try std.testing.expectEqual(50, s.data[2]);
|
||||||
|
}
|
||||||
|
|
||||||
|
test "Slice 1D negative end" {
|
||||||
|
const Vec = TensorStatic(i32, .{}, .{}, &.{5});
|
||||||
|
const v = Vec{ .data = .{ 10, 20, 30, 40, 50 } };
|
||||||
|
const s = v.slice(.{.{ .start = 1, .end = -1 }}); // [1,4) → 20,30,40
|
||||||
|
try std.testing.expectEqual(3, @TypeOf(s).total);
|
||||||
|
try std.testing.expectEqual(20, s.data[0]);
|
||||||
|
try std.testing.expectEqual(30, s.data[1]);
|
||||||
|
try std.testing.expectEqual(40, s.data[2]);
|
||||||
|
}
|
||||||
|
|
||||||
|
test "Slice 1D both negative" {
|
||||||
|
const Vec = TensorStatic(i64, .{}, .{}, &.{6});
|
||||||
|
const v = Vec{ .data = .{ 5, 10, 15, 20, 25, 30 } };
|
||||||
|
const s = v.slice(.{.{ .start = -4, .end = -1 }}); // [2,5) → 15,20,25
|
||||||
|
try std.testing.expectEqual(3, @TypeOf(s).total);
|
||||||
|
try std.testing.expectEqual(15, s.data[0]);
|
||||||
|
try std.testing.expectEqual(20, s.data[1]);
|
||||||
|
try std.testing.expectEqual(25, s.data[2]);
|
||||||
|
}
|
||||||
|
|
||||||
|
test "Slice 1D null start" {
|
||||||
|
const Vec = TensorStatic(i32, .{}, .{}, &.{5});
|
||||||
|
const v = Vec{ .data = .{ 10, 20, 30, 40, 50 } };
|
||||||
|
const s = v.slice(.{.{ .end = -2 }}); // [:-2] → 10,20,30
|
||||||
|
try std.testing.expectEqual(3, @TypeOf(s).total);
|
||||||
|
try std.testing.expectEqual(10, s.data[0]);
|
||||||
|
try std.testing.expectEqual(20, s.data[1]);
|
||||||
|
try std.testing.expectEqual(30, s.data[2]);
|
||||||
|
}
|
||||||
|
|
||||||
|
test "Slice 1D null end" {
|
||||||
|
const Vec = TensorStatic(i32, .{}, .{}, &.{5});
|
||||||
|
const v = Vec{ .data = .{ 10, 20, 30, 40, 50 } };
|
||||||
|
const s = v.slice(.{.{ .start = -3 }}); // [-3:] → 30,40,50
|
||||||
|
try std.testing.expectEqual(3, @TypeOf(s).total);
|
||||||
|
try std.testing.expectEqual(30, s.data[0]);
|
||||||
|
try std.testing.expectEqual(40, s.data[1]);
|
||||||
|
try std.testing.expectEqual(50, s.data[2]);
|
||||||
|
}
|
||||||
|
|
||||||
|
test "Slice 2D negative & null indices" {
|
||||||
|
const Mat = TensorStatic(i32, .{}, .{}, &.{ 4, 4 });
|
||||||
|
const m = Mat{ .data = .{
|
||||||
|
1, 2, 3, 4,
|
||||||
|
5, 6, 7, 8,
|
||||||
|
9, 10, 11, 12,
|
||||||
|
13, 14, 15, 16,
|
||||||
|
} };
|
||||||
|
// last 2 rows, last 2 cols → same as subblock test [2,4)x[2,4)
|
||||||
|
const s = m.slice(.{ .{ .start = -2, .end = 4 }, .{ .start = -2 } });
|
||||||
|
try std.testing.expectEqual(4, @TypeOf(s).total);
|
||||||
|
try std.testing.expectEqual(11, s.data[0]);
|
||||||
|
try std.testing.expectEqual(12, s.data[1]);
|
||||||
|
try std.testing.expectEqual(15, s.data[2]);
|
||||||
|
try std.testing.expectEqual(16, s.data[3]);
|
||||||
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user