Added shape comptime check for Tensor add/sub/div/mul
This commit is contained in:
parent
f37a196b15
commit
16d25e7e7e
197
src/Tensor.zig
197
src/Tensor.zig
@ -14,6 +14,15 @@ pub fn shapeTotal(comptime shape: []const usize) usize {
|
||||
return t;
|
||||
}
|
||||
|
||||
/// Check if two shapes are strictly identical.
|
||||
pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool {
|
||||
if (a.len != b.len) return false;
|
||||
for (a, 0..) |v, i| {
|
||||
if (v != b[i]) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Row-major (C-order) strides: strides[i] = product(shape[i+1..]).
|
||||
/// e.g. shape {3, 4} → strides {4, 1}
|
||||
/// shape {2, 3, 4} → strides {12, 4, 1}
|
||||
@ -126,10 +135,6 @@ fn finerScales(comptime T1: type, comptime T2: type) Scales {
|
||||
}
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// File-scope RHS normalisation helpers
|
||||
//
|
||||
// Any bare comptime_int / comptime_float / runtime T used as an arithmetic
|
||||
// or comparison RHS is wrapped into a dimensionless Tensor of shape {1}.
|
||||
// Actual Tensor types are passed through unchanged.
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
fn isTensor(comptime Rhs: type) bool {
|
||||
@ -193,32 +198,6 @@ pub fn printSuperscript(writer: *std.Io.Writer, n: i32) !void {
|
||||
}
|
||||
}
|
||||
|
||||
/// ─────────────────────────────────────────────────────────────────────────────
|
||||
/// Tensor — unified dimensioned ND type.
|
||||
///
|
||||
/// T : element numeric type (f32, f64, i32, i128, …)
|
||||
/// d_opt : SI dimension exponents
|
||||
/// s_opt : unit scales
|
||||
/// shape_ : compile-time shape
|
||||
/// &.{1} → scalar
|
||||
/// &.{3} → 3-vector
|
||||
/// &.{4, 4} → 4×4 matrix
|
||||
/// &.{3, 3, 3} → 3D field
|
||||
///
|
||||
/// Storage: flat @Vector(total, T) where total = product(shape_).
|
||||
/// All arithmetic operates on the flat vector directly → SIMD wherever possible.
|
||||
///
|
||||
/// Shape-related comptime constants exposed on every Tensor type:
|
||||
/// dims : Dimensions — SI exponent struct
|
||||
/// scales : Scales — unit scale struct
|
||||
/// shape : []const usize
|
||||
/// rank : usize = shape.len
|
||||
/// total : usize = product(shape)
|
||||
/// strides_arr : [rank]usize — row-major strides
|
||||
///
|
||||
/// Index helper:
|
||||
/// Tensor.idx(.{row, col}) → flat index (comptime, no runtime cost)
|
||||
/// ─────────────────────────────────────────────────────────────────────────────
|
||||
pub fn Tensor(
|
||||
comptime T: type,
|
||||
comptime d_opt: Dimensions.ArgOpts,
|
||||
@ -226,8 +205,10 @@ pub fn Tensor(
|
||||
comptime shape_: []const usize,
|
||||
) type {
|
||||
comptime {
|
||||
std.debug.assert(shape_.len >= 1);
|
||||
for (shape_) |s| std.debug.assert(s >= 1);
|
||||
if (shape_.len == 0) @compileError("Tensor shape must have at least 1 dimension (rank >= 1).");
|
||||
for (shape_) |s| {
|
||||
if (s == 0) @compileError("Tensor shape dimensions must be strictly >= 1.");
|
||||
}
|
||||
}
|
||||
@setEvalBranchQuota(10_000_000);
|
||||
|
||||
@ -236,7 +217,6 @@ pub fn Tensor(
|
||||
const Vec = @Vector(_total, T);
|
||||
|
||||
return struct {
|
||||
/// Flat SIMD storage. All arithmetic operates here directly.
|
||||
data: Vec,
|
||||
|
||||
const Self = @This();
|
||||
@ -250,27 +230,19 @@ pub fn Tensor(
|
||||
pub const strides_arr: [shape_.len]usize = _strides;
|
||||
pub const ISTENSOR = true;
|
||||
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
// Index helper
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Convert N-D coords (row-major) to flat index — fully comptime.
|
||||
/// Usage: Tensor.idx(.{row, col})
|
||||
pub inline fn idx(comptime coords: [rank]usize) usize {
|
||||
comptime {
|
||||
var flat: usize = 0;
|
||||
for (0..rank) |i| {
|
||||
std.debug.assert(coords[i] < shape[i]);
|
||||
if (coords[i] >= shape[i]) @compileError("idx: Coordinate out of bounds");
|
||||
flat += coords[i] * strides_arr[i];
|
||||
}
|
||||
return flat;
|
||||
}
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
// Constructors
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Broadcast a single value across all elements.
|
||||
pub inline fn splat(v: T) Self {
|
||||
return .{ .data = @splat(v) };
|
||||
@ -279,19 +251,11 @@ pub fn Tensor(
|
||||
pub const zero: Self = splat(0);
|
||||
pub const one: Self = splat(1);
|
||||
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
// GPU readiness
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Return a mutable slice to the flat storage — zero-copy WebGPU buffer mapping.
|
||||
pub inline fn asSlice(self: *Self) []T {
|
||||
return @as([*]T, @ptrCast(&self.data))[0..total];
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
// Internal: RHS normalisation
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
|
||||
inline fn RhsT(comptime Rhs: type) type {
|
||||
return RhsTensorType(T, Rhs);
|
||||
}
|
||||
@ -299,10 +263,6 @@ pub fn Tensor(
|
||||
return toRhsTensor(T, r);
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
// Internal: scalar broadcast (shape {1} → full Vec)
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
|
||||
inline fn broadcastToVec(comptime RhsType: type, r: RhsType) Vec {
|
||||
return if (comptime RhsType.total == 1 and total > 1)
|
||||
@splat(r.data[0])
|
||||
@ -310,12 +270,8 @@ pub fn Tensor(
|
||||
r.data;
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
// Arithmetic
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Element-wise add. Dimensions must match; scales resolve to finer.
|
||||
/// RHS must have the same element count as self, or total == 1 (broadcast).
|
||||
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
||||
pub inline fn add(self: Self, r: anytype) Tensor(
|
||||
T,
|
||||
dims.argsOpt(),
|
||||
@ -326,8 +282,8 @@ pub fn Tensor(
|
||||
const RhsType = @TypeOf(rhs_q);
|
||||
if (comptime !dims.eql(RhsType.dims))
|
||||
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
||||
if (comptime RhsType.total != total and RhsType.total != 1)
|
||||
@compileError("Shape mismatch in add: element counts must match or RHS must be scalar (total=1).");
|
||||
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
|
||||
@compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
|
||||
|
||||
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
||||
@ -339,7 +295,8 @@ pub fn Tensor(
|
||||
return .{ .data = if (comptime isInt(T)) l +| rr else l + rr };
|
||||
}
|
||||
|
||||
/// Element-wise subtract. Dimensions must match; scales resolve to finer.
|
||||
/// Element-wise sub. Dimensions must match; scales resolve to finer.
|
||||
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
||||
pub inline fn sub(self: Self, r: anytype) Tensor(
|
||||
T,
|
||||
dims.argsOpt(),
|
||||
@ -350,8 +307,8 @@ pub fn Tensor(
|
||||
const RhsType = @TypeOf(rhs_q);
|
||||
if (comptime !dims.eql(RhsType.dims))
|
||||
@compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
||||
if (comptime RhsType.total != total and RhsType.total != 1)
|
||||
@compileError("Shape mismatch in sub.");
|
||||
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
|
||||
@compileError("Shape mismatch in sub: element-wise operations require identical shapes, or a scalar RHS.");
|
||||
|
||||
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
||||
@ -373,8 +330,8 @@ pub fn Tensor(
|
||||
) {
|
||||
const rhs_q = rhs(r);
|
||||
const RhsType = @TypeOf(rhs_q);
|
||||
if (comptime RhsType.total != total and RhsType.total != 1)
|
||||
@compileError("Shape mismatch in mul.");
|
||||
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
|
||||
@compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS.");
|
||||
|
||||
const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||
@ -394,8 +351,8 @@ pub fn Tensor(
|
||||
) {
|
||||
const rhs_q = rhs(r);
|
||||
const RhsType = @TypeOf(rhs_q);
|
||||
if (comptime RhsType.total != total and RhsType.total != 1)
|
||||
@compileError("Shape mismatch in div.");
|
||||
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
|
||||
@compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS.");
|
||||
|
||||
const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||
@ -411,10 +368,6 @@ pub fn Tensor(
|
||||
}
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
// Unary
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Absolute value of every element.
|
||||
pub inline fn abs(self: Self) Self {
|
||||
return .{ .data = @bitCast(@abs(self.data)) };
|
||||
@ -469,13 +422,9 @@ pub fn Tensor(
|
||||
return .{ .data = -self.data };
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
// Conversion
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Convert to a compatible Tensor type.
|
||||
/// • Dimension mismatch → compile error.
|
||||
/// • Dest.total must equal self.total, or Dest.total == 1 (scalar pattern).
|
||||
/// • Dest.shape must equal self.shape, or Dest.total == 1 (scalar pattern).
|
||||
/// • Scale ratio is computed fully at comptime; only a SIMD multiply at runtime.
|
||||
pub inline fn to(
|
||||
self: Self,
|
||||
@ -485,20 +434,19 @@ pub fn Tensor(
|
||||
|
||||
if (comptime !dims.eql(ActualDest.dims))
|
||||
@compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str());
|
||||
if (comptime Self == ActualDest) return self;
|
||||
if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape))
|
||||
@compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar.");
|
||||
|
||||
comptime std.debug.assert(Dest.total == total or Dest.total == 1);
|
||||
if (comptime Self == ActualDest) return self;
|
||||
|
||||
const DestT = ActualDest.ValueType;
|
||||
const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
|
||||
const DestVec = @Vector(total, DestT);
|
||||
|
||||
// ── Same numeric type ──────────────────────────────────────
|
||||
if (comptime T == DestT) {
|
||||
if (comptime @typeInfo(T) == .float)
|
||||
return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) };
|
||||
|
||||
// Integer — branch prevents division-by-zero
|
||||
if (comptime ratio >= 1.0) {
|
||||
const mult: T = comptime @intFromFloat(@round(ratio));
|
||||
return .{ .data = self.data *| @as(Vec, @splat(mult)) };
|
||||
@ -517,7 +465,6 @@ pub fn Tensor(
|
||||
}
|
||||
}
|
||||
|
||||
// ── Cross numeric type ─────────────────────────────────────
|
||||
var result: DestVec = undefined;
|
||||
inline for (0..total) |i| {
|
||||
const float_val: f64 = switch (comptime @typeInfo(T)) {
|
||||
@ -535,17 +482,6 @@ pub fn Tensor(
|
||||
return .{ .data = result };
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
// Comparisons
|
||||
//
|
||||
// Return type: bool when total == 1 (scalar semantics)
|
||||
// [total]bool when total > 1 (element-wise, flat-indexed)
|
||||
//
|
||||
// Whole-tensor equality check → eqAll / neAll (always returns bool).
|
||||
// A shape {1} RHS is broadcast automatically, unifying the old
|
||||
// eqScalar / gtScalar / … family into the plain eq / gt / … methods.
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
|
||||
const CmpResult = if (total == 1) bool else [total]bool;
|
||||
|
||||
inline fn cmpResult(v: @Vector(total, bool)) CmpResult {
|
||||
@ -555,6 +491,9 @@ pub fn Tensor(
|
||||
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
|
||||
inline fn resolveScalePair(self: Self, rhs_q: anytype) struct { l: Vec, r: Vec } {
|
||||
const RhsType = @TypeOf(rhs_q);
|
||||
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
|
||||
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
|
||||
|
||||
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
||||
const rr: Vec = blk: {
|
||||
@ -626,26 +565,6 @@ pub fn Tensor(
|
||||
return !self.eqAll(other);
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
// Contraction — generalised dot product / matrix multiply / einsum
|
||||
//
|
||||
// a.contract(b, axis_a, axis_b)
|
||||
//
|
||||
// Sums over dimension `axis_a` of `a` and `axis_b` of `b`.
|
||||
// Requires a.shape[axis_a] == b.shape[axis_b] (checked at comptime).
|
||||
//
|
||||
// Result shape = a.shape \ axis_a ++ b.shape \ axis_b
|
||||
// Result dims = a.dims + b.dims (exponents summed, as in mul)
|
||||
// Result scales = finer of a, b
|
||||
//
|
||||
// Special cases:
|
||||
// rank-1 × rank-1, axis 0 × 0 → dot product (result shape {1})
|
||||
// rank-2 × rank-2, axis 1 × 0 → matrix multiply
|
||||
// rank-1 × rank-2, axis 0 × 0 → vector–matrix product
|
||||
//
|
||||
// All index arithmetic is comptime; runtime cost is the multiply-add loop only.
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
|
||||
pub inline fn contract(
|
||||
self: Self,
|
||||
other: anytype,
|
||||
@ -653,10 +572,10 @@ pub fn Tensor(
|
||||
comptime axis_b: usize,
|
||||
) blk: {
|
||||
const OT = @TypeOf(other);
|
||||
std.debug.assert(axis_a < rank);
|
||||
std.debug.assert(axis_b < OT.rank);
|
||||
std.debug.assert(shape_[axis_a] == OT.shape[axis_b]);
|
||||
// Contracted-away free axes; empty joint → scalar shape {1}
|
||||
if (axis_a >= rank) @compileError("contract: axis_a out of bounds");
|
||||
if (axis_b >= OT.rank) @compileError("contract: axis_b out of bounds");
|
||||
if (shape_[axis_a] != OT.shape[axis_b]) @compileError("contract: shape mismatch at contraction axes");
|
||||
|
||||
const sa = shapeRemoveAxis(shape_, axis_a);
|
||||
const sb = shapeRemoveAxis(OT.shape, axis_b);
|
||||
const rs_raw = shapeCat(&sa, &sb);
|
||||
@ -683,20 +602,16 @@ pub fn Tensor(
|
||||
rs,
|
||||
);
|
||||
|
||||
// Normalise scales before accumulation
|
||||
const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, OT).argsOpt(), shape_);
|
||||
const OtherNorm = Tensor(T, OT.dims.argsOpt(), finerScales(Self, OT).argsOpt(), OT.shape);
|
||||
const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
||||
const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data;
|
||||
|
||||
// Precompute result strides from rs_raw (for coord decoding)
|
||||
const rs_raw_strides = comptime shapeStrides(&rs_raw);
|
||||
|
||||
var result: ResultType = .{ .data = @splat(0) };
|
||||
|
||||
inline for (0..ResultType.total) |res_flat| {
|
||||
// Decode result flat index into free coords using rs_raw layout.
|
||||
// When rs_raw.len == 0, decodeFlatCoords returns [0]usize{} — correct.
|
||||
const res_coords = comptime decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides);
|
||||
|
||||
const a_free: [sa.len]usize = comptime res_coords[0..sa.len].*;
|
||||
@ -704,7 +619,6 @@ pub fn Tensor(
|
||||
|
||||
var acc: T = 0;
|
||||
inline for (0..k) |ki| {
|
||||
// Reinsert the contracted index into free coords → full coord arrays
|
||||
const a_coords = comptime insertAxis(rank, axis_a, ki, &a_free);
|
||||
const b_coords = comptime insertAxis(OT.rank, axis_b, ki, &b_free);
|
||||
const a_flat = comptime encodeFlatCoords(&a_coords, rank, _strides);
|
||||
@ -720,10 +634,6 @@ pub fn Tensor(
|
||||
return result;
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
// Reduction helpers
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Sum of squared elements. Cheaper than length(); use for ordering.
|
||||
pub inline fn lengthSqr(self: Self) T {
|
||||
return @reduce(.Add, self.data * self.data);
|
||||
@ -749,10 +659,6 @@ pub fn Tensor(
|
||||
return .{ .data = .{@reduce(.Mul, self.data)} };
|
||||
}
|
||||
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
// Formatting
|
||||
// ───────────────────────────────────────────────────────────────
|
||||
|
||||
pub fn formatNumber(
|
||||
self: Self,
|
||||
writer: *std.Io.Writer,
|
||||
@ -809,15 +715,6 @@ pub fn Tensor(
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
// Tests
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Naming convention used throughout:
|
||||
// Tensor(T, d, s, &.{1}) → former Scalar
|
||||
// Tensor(T, d, s, &.{N}) → former Vector of length N
|
||||
// .data[0] → former .value()
|
||||
// .mul(x) → former .mulScalar(x) (x may be scalar Tensor or bare number)
|
||||
// .div(x) → former .divScalar(x)
|
||||
// .eq(x) → former .eqScalar(x) (broadcasts when x.total==1)
|
||||
// .contract(other, 0, 0) → former .dot(other) (for rank-1 tensors)
|
||||
// ═════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
// ─── Scalar tests ─────────────────────────────────────────────────────────
|
||||
|
||||
@ -1234,7 +1131,6 @@ test "Vector Comparisons" {
|
||||
}
|
||||
|
||||
test "Vector vs Scalar broadcast comparison" {
|
||||
// Replaces the old eqScalar / gtScalar — now just eq / gt with a shape-{1} rhs.
|
||||
const Meter3 = Tensor(f32, .{ .L = 1 }, .{}, &.{3});
|
||||
const KiloMeter1 = Tensor(f32, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
||||
|
||||
@ -1259,7 +1155,6 @@ test "Vector contract — dot product (rank-1 × rank-1)" {
|
||||
const pos = Meter3{ .data = .{ 10.0, 0.0, 0.0 } };
|
||||
const force = Newton3{ .data = .{ 5.0, 5.0, 0.0 } };
|
||||
|
||||
// work = force · pos
|
||||
const work = force.contract(pos, 0, 0);
|
||||
try std.testing.expectEqual(50.0, work.data[0]);
|
||||
try std.testing.expectEqual(1, @TypeOf(work).dims.get(.M));
|
||||
@ -1268,23 +1163,12 @@ test "Vector contract — dot product (rank-1 × rank-1)" {
|
||||
}
|
||||
|
||||
test "Vector contract — matrix multiply (rank-2 × rank-2)" {
|
||||
// 2×3 matrix multiplied by 3×2 matrix → 2×2 result
|
||||
const A = Tensor(f32, .{}, .{}, &.{ 2, 3 });
|
||||
const B = Tensor(f32, .{}, .{}, &.{ 3, 2 });
|
||||
|
||||
// A = [[1, 2, 3],
|
||||
// [4, 5, 6]]
|
||||
const a = A{ .data = .{ 1, 2, 3, 4, 5, 6 } };
|
||||
// B = [[7, 8],
|
||||
// [9, 10],
|
||||
// [11, 12]]
|
||||
const b = B{ .data = .{ 7, 8, 9, 10, 11, 12 } };
|
||||
|
||||
// C = A @ B (contract over axis 1 of A × axis 0 of B)
|
||||
// C[0][0] = 1*7 + 2*9 + 3*11 = 7 + 18 + 33 = 58
|
||||
// C[0][1] = 1*8 + 2*10 + 3*12 = 8 + 20 + 36 = 64
|
||||
// C[1][0] = 4*7 + 5*9 + 6*11 = 28 + 45 + 66 = 139
|
||||
// C[1][1] = 4*8 + 5*10 + 6*12 = 32 + 50 + 72 = 154
|
||||
const c = a.contract(b, 1, 0);
|
||||
try std.testing.expectEqual(58, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 0 })]);
|
||||
try std.testing.expectEqual(64, c.data[Tensor(f32, .{}, .{}, &.{ 2, 2 }).idx(.{ 0, 1 })]);
|
||||
@ -1371,16 +1255,15 @@ test "Vector eq broadcast on dimensionless" {
|
||||
|
||||
test "Tensor idx helper and matrix access" {
|
||||
const Mat3x3 = Tensor(f32, .{}, .{}, &.{ 3, 3 });
|
||||
// Identity-like: set [0][0]=1, [1][1]=2, [2][2]=3
|
||||
var m: Mat3x3 = Mat3x3.zero;
|
||||
m.data[Mat3x3.idx(.{ 0, 0 })] = 1.0;
|
||||
m.data[Mat3x3.idx(.{ 1, 1 })] = 2.0;
|
||||
m.data[Mat3x3.idx(.{ 2, 2 })] = 3.0;
|
||||
|
||||
try std.testing.expectEqual(1.0, m.data[0]); // [0][0]
|
||||
try std.testing.expectEqual(2.0, m.data[4]); // [1][1] (stride 3 → 1*3+1=4)
|
||||
try std.testing.expectEqual(3.0, m.data[8]); // [2][2] (2*3+2=8)
|
||||
try std.testing.expectEqual(0.0, m.data[1]); // [0][1]
|
||||
try std.testing.expectEqual(1.0, m.data[0]);
|
||||
try std.testing.expectEqual(2.0, m.data[4]);
|
||||
try std.testing.expectEqual(3.0, m.data[8]);
|
||||
try std.testing.expectEqual(0.0, m.data[1]);
|
||||
}
|
||||
|
||||
test "Tensor strides_arr correctness" {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user