Compare commits
23 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
22ffd4fc64 | ||
|
|
9ac3d4d699 | ||
|
|
2215a5d86d | ||
|
|
9deb25b825 | ||
|
|
ceff8ff1bd | ||
|
|
ff21f0ac8b | ||
|
|
5ac9968021 | ||
|
|
8028cf41a5 | ||
|
|
0ef19e18de | ||
|
|
7494595db4 | ||
|
|
91c5c41fc5 | ||
|
|
ba671ee486 | ||
|
|
09d6ca1ff5 | ||
|
|
bcd888d59e | ||
|
|
63e9b6b63d | ||
|
|
957f75243f | ||
|
|
5f833a5e58 | ||
|
|
00e0f5ab73 | ||
|
|
f67e9d709d | ||
|
|
e6d0f62929 | ||
|
|
f702c1e09a | ||
|
|
b959f5f28a | ||
|
|
6ba1e664c1 |
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
||||
zig-out
|
||||
.zig-cache
|
||||
mkdocs.yaml
|
||||
zig-pkg
|
||||
|
||||
16
build.zig
16
build.zig
@ -4,20 +4,27 @@ pub fn build(b: *std.Build) void {
|
||||
const target = b.standardTargetOptions(.{});
|
||||
const optimize = b.standardOptimizeOption(.{});
|
||||
|
||||
// 1. Define the module so other projects can import it
|
||||
_ = b.addModule("dimal", .{
|
||||
.root_source_file = b.path("src/main.zig"),
|
||||
const zig_wgpu = b.dependency("zig_wgpu", .{
|
||||
.target = target,
|
||||
.optimize = optimize,
|
||||
});
|
||||
|
||||
// 1. Define the module so other projects can import it
|
||||
const mod = b.addModule("dimal", .{
|
||||
.root_source_file = b.path("src/lib.zig"),
|
||||
});
|
||||
mod.addImport("gpu", zig_wgpu.module("zig-wgpu"));
|
||||
|
||||
const exe_tests = b.addTest(.{
|
||||
.root_module = b.createModule(.{
|
||||
.root_source_file = b.path("src/main.zig"),
|
||||
.root_source_file = b.path("src/test.zig"),
|
||||
.target = target,
|
||||
.optimize = optimize,
|
||||
}),
|
||||
.test_runner = .{ .path = b.path("test_runner.zig"), .mode = .simple },
|
||||
});
|
||||
|
||||
exe_tests.root_module.addImport("gpu", zig_wgpu.module("zig-wgpu"));
|
||||
const run_exe_tests = b.addRunArtifact(exe_tests);
|
||||
const test_step = b.step("test", "Run tests");
|
||||
test_step.dependOn(&run_exe_tests.step);
|
||||
@ -30,6 +37,7 @@ pub fn build(b: *std.Build) void {
|
||||
.imports = &.{},
|
||||
}),
|
||||
});
|
||||
bench_exe.root_module.addImport("gpu", zig_wgpu.module("zig-wgpu"));
|
||||
|
||||
b.installArtifact(bench_exe);
|
||||
|
||||
|
||||
@ -1,9 +1,14 @@
|
||||
.{
|
||||
.name = .dimal,
|
||||
.version = "0.2.2",
|
||||
.version = "0.3.0",
|
||||
.fingerprint = 0x9453b1ff1e52d858,
|
||||
.minimum_zig_version = "0.16.0",
|
||||
.dependencies = .{},
|
||||
.dependencies = .{
|
||||
.zig_wgpu = .{
|
||||
.url = "git+https://git.bouvais.lu/adrien/zig-wgpu?ref=0.2.2#5f8da0940d77c40eacd39c268d09acbeaea0b2a5",
|
||||
.hash = "zig_wgpu-0.2.0-xsLAy2-s0QPNwR2QNd8ZX2kWiVfV5oB92N3ga1V1Uwpu",
|
||||
},
|
||||
},
|
||||
.paths = .{
|
||||
"build.zig",
|
||||
"build.zig.zon",
|
||||
|
||||
@ -3,7 +3,7 @@ const std = @import("std");
|
||||
// Adjust these imports to match your actual file names
|
||||
const Dimensions = @import("Dimensions.zig");
|
||||
const Scales = @import("Scales.zig");
|
||||
const Tensor = @import("Tensor.zig").TensorStatic;
|
||||
const Tensor = @import("TensorStatic.zig").Tensor;
|
||||
|
||||
fn PhysicalConstant(comptime d: Dimensions.ArgOpts, comptime val: f64, comptime s: Scales.ArgOpts) type {
|
||||
return struct {
|
||||
|
||||
1745
src/TensorAlloc.zig
Normal file
1745
src/TensorAlloc.zig
Normal file
File diff suppressed because it is too large
Load Diff
1775
src/TensorGpu.zig
Normal file
1775
src/TensorGpu.zig
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,6 @@
|
||||
const std = @import("std");
|
||||
const Io = std.Io;
|
||||
const Tensor = @import("Tensor.zig").TensorStatic;
|
||||
const Tensor = @import("Tensor.zig").Tensor;
|
||||
|
||||
var io: Io = undefined;
|
||||
pub fn main(init: std.process.Init) !void {
|
||||
@ -10,23 +10,8 @@ pub fn main(init: std.process.Init) !void {
|
||||
|
||||
io = init.io;
|
||||
|
||||
try vectorSIMDvsNative(f64, &stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
try vectorSIMDvsNative(f32, &stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
try vectorSIMDvsNative(i32, &stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
try vectorSIMDvsNative(i64, &stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
try vectorSIMDvsNative(i128, &stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
|
||||
try bench_Scalar(&stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
try bench_vsNative(&stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
try bench_crossTypeVsNative(&stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
try bench_Vector(&stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
try bench_HighDimTensor(&stdout_writer.interface);
|
||||
@ -169,245 +154,6 @@ fn bench_Scalar(writer: *std.Io.Writer) !void {
|
||||
try writer.print("└──────────────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n", .{});
|
||||
}
|
||||
|
||||
fn bench_vsNative(writer: *std.Io.Writer) !void {
|
||||
const ITERS: usize = 100_000;
|
||||
const SAMPLES: usize = 100;
|
||||
|
||||
const getValT = struct {
|
||||
fn f(comptime TT: type, i: usize) TT {
|
||||
const v = (i % 100) + 1;
|
||||
return if (comptime @typeInfo(TT) == .float) @floatFromInt(v) else @intCast(v);
|
||||
}
|
||||
}.f;
|
||||
|
||||
const Types = .{ i32, i64, i128, f32, f64 };
|
||||
const TNames = .{ "i32", "i64", "i128", "f32", "f64" };
|
||||
// Expanded Ops to match bench_Scalar
|
||||
const Ops = .{ "add", "sub", "mul", "div", "abs", "eq", "gt" };
|
||||
|
||||
try writer.print(
|
||||
\\
|
||||
\\ Scalar vs Native Overhead Analysis
|
||||
\\
|
||||
\\┌───────────┬──────┬───────────┬───────────┬───────────┬───────────────────────┐
|
||||
\\│ Operation │ Type │ Native │ @Vector │ Tensor{{1}} │ Slowdown Nat | Vec │
|
||||
\\├───────────┼──────┼───────────┼───────────┼───────────┼───────────────────────┤
|
||||
\\
|
||||
, .{});
|
||||
|
||||
inline for (Ops, 0..) |op_name, j| {
|
||||
inline for (Types, 0..) |T, tidx| {
|
||||
var native_total_ns: f64 = 0;
|
||||
var vector_total_ns: f64 = 0;
|
||||
var tensor_total_ns: f64 = 0;
|
||||
|
||||
const M = Tensor(T, .{}, .{}, &.{1});
|
||||
|
||||
for (0..SAMPLES) |_| {
|
||||
// --- 1. Benchmark Native ---
|
||||
const n_start = getTime();
|
||||
const a = getValT(T, 10);
|
||||
const b = getValT(T, 2);
|
||||
for (0..ITERS) |_| {
|
||||
// Native logic branch
|
||||
_ = if (comptime std.mem.eql(u8, op_name, "add"))
|
||||
if (comptime @typeInfo(T) == .int) a +| b else a + b
|
||||
else if (comptime std.mem.eql(u8, op_name, "sub"))
|
||||
if (comptime @typeInfo(T) == .int) a -| b else a - b
|
||||
else if (comptime std.mem.eql(u8, op_name, "mul"))
|
||||
if (comptime @typeInfo(T) == .int) a *| b else a * b
|
||||
else if (comptime std.mem.eql(u8, op_name, "div"))
|
||||
if (comptime @typeInfo(T) == .int) @divTrunc(a, b) else a / b
|
||||
else if (comptime std.mem.eql(u8, op_name, "abs"))
|
||||
if (comptime @typeInfo(T) == .int) @abs(a) else @as(T, @abs(a))
|
||||
else if (comptime std.mem.eql(u8, op_name, "eq"))
|
||||
a == b
|
||||
else if (comptime std.mem.eql(u8, op_name, "gt"))
|
||||
a > b
|
||||
else
|
||||
unreachable;
|
||||
}
|
||||
const n_end = getTime();
|
||||
native_total_ns += @as(f64, @floatFromInt(n_start.durationTo(n_end).toNanoseconds()));
|
||||
|
||||
const v_start = getTime();
|
||||
const va = @Vector(1, T){getValT(T, 10)};
|
||||
const vb = @Vector(1, T){getValT(T, 2)};
|
||||
for (0..ITERS) |_| {
|
||||
// Native logic branch
|
||||
_ = if (comptime std.mem.eql(u8, op_name, "add"))
|
||||
if (comptime @typeInfo(T) == .int) va +| vb else va + vb
|
||||
else if (comptime std.mem.eql(u8, op_name, "sub"))
|
||||
if (comptime @typeInfo(T) == .int) va -| vb else va - vb
|
||||
else if (comptime std.mem.eql(u8, op_name, "mul"))
|
||||
if (comptime @typeInfo(T) == .int) va *| vb else va * vb
|
||||
else if (comptime std.mem.eql(u8, op_name, "div"))
|
||||
if (comptime @typeInfo(T) == .int) @divTrunc(va, vb) else va / vb
|
||||
else if (comptime std.mem.eql(u8, op_name, "abs"))
|
||||
if (comptime @typeInfo(T) == .int) @as(T, @intCast(@abs(va[0]))) else @abs(va)
|
||||
else if (comptime std.mem.eql(u8, op_name, "eq"))
|
||||
va == vb
|
||||
else if (comptime std.mem.eql(u8, op_name, "gt"))
|
||||
va > vb
|
||||
else
|
||||
unreachable;
|
||||
}
|
||||
const v_end = getTime();
|
||||
vector_total_ns += @as(f64, @floatFromInt(v_start.durationTo(v_end).toNanoseconds()));
|
||||
|
||||
// --- 2. Benchmark Scalar ---
|
||||
const q_start = getTime();
|
||||
const qa = M.splat(getValT(T, 10));
|
||||
const qb = M.splat(getValT(T, 2));
|
||||
for (0..ITERS) |_| {
|
||||
// Scalar logic branch
|
||||
_ = if (comptime std.mem.eql(u8, op_name, "add"))
|
||||
qa.add(qb)
|
||||
else if (comptime std.mem.eql(u8, op_name, "sub"))
|
||||
qa.sub(qb)
|
||||
else if (comptime std.mem.eql(u8, op_name, "mul"))
|
||||
qa.mul(qb)
|
||||
else if (comptime std.mem.eql(u8, op_name, "div"))
|
||||
qa.div(qb)
|
||||
else if (comptime std.mem.eql(u8, op_name, "abs"))
|
||||
qa.abs()
|
||||
else if (comptime std.mem.eql(u8, op_name, "eq"))
|
||||
qa.eq(qb)
|
||||
else if (comptime std.mem.eql(u8, op_name, "gt"))
|
||||
qa.gt(qb)
|
||||
else
|
||||
unreachable;
|
||||
}
|
||||
const q_end = getTime();
|
||||
tensor_total_ns += @as(f64, @floatFromInt(q_start.durationTo(q_end).toNanoseconds()));
|
||||
}
|
||||
|
||||
const avg_n = (native_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS));
|
||||
const avg_v = (vector_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS));
|
||||
const avg_t = (tensor_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS));
|
||||
const slowdown_nt = avg_t / avg_n;
|
||||
const slowdown_vt = avg_t / avg_v;
|
||||
|
||||
try writer.print("│ {s:<9} │ {s:<4} │ {d:>7.2}ns │ {d:>7.2}ns │ {d:>7.2}ns │ {d:>8.2}x {d:>8.2}x │\n", .{
|
||||
op_name, TNames[tidx], avg_n, avg_v, avg_t, slowdown_nt, slowdown_vt,
|
||||
});
|
||||
}
|
||||
if (j != Ops.len - 1) try writer.print("├───────────┼──────┼───────────┼───────────┼───────────┼───────────────────────┤\n", .{});
|
||||
}
|
||||
|
||||
try writer.print("└───────────┴──────┴───────────┴───────────┴───────────┴───────────────────────┘\n", .{});
|
||||
}
|
||||
|
||||
fn bench_crossTypeVsNative(writer: *std.Io.Writer) !void {
|
||||
const ITERS: usize = 100_000;
|
||||
const SAMPLES: usize = 5;
|
||||
|
||||
const getValT = struct {
|
||||
fn f(comptime TT: type, i: usize) TT {
|
||||
// Keep values safe and non-zero to avoid division by zero or overflows during cross-casting
|
||||
const v = (i % 50) + 1;
|
||||
return if (comptime @typeInfo(TT) == .float) @floatFromInt(v) else @intCast(v);
|
||||
}
|
||||
}.f;
|
||||
|
||||
// Helper for the Native baseline: explicitly casting T2 to T1 before the operation
|
||||
const castTo = struct {
|
||||
fn f(comptime DestT: type, comptime SrcT: type, val: SrcT) DestT {
|
||||
if (comptime DestT == SrcT) return val;
|
||||
const src_info = @typeInfo(SrcT);
|
||||
const dest_info = @typeInfo(DestT);
|
||||
|
||||
if (dest_info == .int and src_info == .int) return @intCast(val);
|
||||
if (dest_info == .float and src_info == .int) return @floatFromInt(val);
|
||||
if (dest_info == .int and src_info == .float) return @intFromFloat(val);
|
||||
if (dest_info == .float and src_info == .float) return @floatCast(val);
|
||||
unreachable;
|
||||
}
|
||||
}.f;
|
||||
|
||||
const Types = .{ i16, i64, i128, f32, f64 };
|
||||
const TNames = .{ "i16", "i64", "i128", "f32", "f64" };
|
||||
const Ops = .{ "add", "mul", "div" };
|
||||
|
||||
try writer.print(
|
||||
\\
|
||||
\\ Cross-Type Overhead Analysis: Scalar vs Native
|
||||
\\
|
||||
\\┌─────────┬──────┬──────┬───────────┬───────────┬───────────┐
|
||||
\\│ Op │ T1 │ T2 │ Native │ Scalar │ Slowdown │
|
||||
\\├─────────┼──────┼──────┼───────────┼───────────┼───────────┤
|
||||
\\
|
||||
, .{});
|
||||
|
||||
inline for (Ops, 0..) |op_name, j| {
|
||||
inline for (Types, 0..) |T1, t1_idx| {
|
||||
inline for (Types, 0..) |T2, t2_idx| {
|
||||
var native_total_ns: f64 = 0;
|
||||
var quantity_total_ns: f64 = 0;
|
||||
|
||||
const M1 = Tensor(T1, .{ .L = 1 }, .{}, &.{1});
|
||||
const M2 = Tensor(T2, .{ .L = 1 }, .{}, &.{1});
|
||||
const S2 = Tensor(T2, .{ .T = 1 }, .{}, &.{1});
|
||||
|
||||
std.mem.doNotOptimizeAway({
|
||||
for (0..SAMPLES) |_| {
|
||||
// --- 1. Benchmark Native (Cast T2 to T1, then math) ---
|
||||
const n_start = getTime();
|
||||
for (0..ITERS) |i| {
|
||||
const a = getValT(T1, i);
|
||||
const b_raw = getValT(T2, 2);
|
||||
const b = castTo(T1, T2, b_raw);
|
||||
|
||||
_ = if (comptime std.mem.eql(u8, op_name, "add"))
|
||||
a + b
|
||||
else if (comptime std.mem.eql(u8, op_name, "mul"))
|
||||
a * b
|
||||
else if (comptime @typeInfo(T1) == .int)
|
||||
@divTrunc(a, b)
|
||||
else
|
||||
a / b;
|
||||
}
|
||||
const n_end = getTime();
|
||||
native_total_ns += @as(f64, @floatFromInt(n_start.durationTo(n_end).toNanoseconds()));
|
||||
|
||||
// --- 2. Benchmark Scalar ---
|
||||
const q_start = getTime();
|
||||
for (0..ITERS) |i| {
|
||||
const qa = M1.splat(getValT(T1, i));
|
||||
const qb = if (comptime std.mem.eql(u8, op_name, "div"))
|
||||
S2.splat(getValT(T2, 2))
|
||||
else
|
||||
M2.splat(getValT(T2, 2));
|
||||
|
||||
_ = if (comptime std.mem.eql(u8, op_name, "add"))
|
||||
qa.add(qb)
|
||||
else if (comptime std.mem.eql(u8, op_name, "mul"))
|
||||
qa.mul(qb)
|
||||
else
|
||||
qa.div(qb);
|
||||
}
|
||||
const q_end = getTime();
|
||||
quantity_total_ns += @as(f64, @floatFromInt(q_start.durationTo(q_end).toNanoseconds()));
|
||||
}
|
||||
|
||||
const avg_n = (native_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS));
|
||||
const avg_q = (quantity_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS));
|
||||
const slowdown = avg_q / avg_n;
|
||||
|
||||
try writer.print("│ {s:<7} │ {s:<4} │ {s:<4} │ {d:>7.2}ns │ {d:>7.2}ns │ {d:>8.2}x │\n", .{
|
||||
op_name, TNames[t1_idx], TNames[t2_idx], avg_n, avg_q, slowdown,
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
if (j != Ops.len - 1) {
|
||||
try writer.print("├─────────┼──────┼──────┼───────────┼───────────┼───────────┤\n", .{});
|
||||
}
|
||||
}
|
||||
|
||||
try writer.print("└─────────┴──────┴──────┴───────────┴───────────┴───────────┘\n", .{});
|
||||
}
|
||||
|
||||
fn bench_Vector(writer: *std.Io.Writer) !void {
|
||||
const ITERS: usize = 10_000;
|
||||
const SAMPLES: usize = 10;
|
||||
@ -446,7 +192,7 @@ fn bench_Vector(writer: *std.Io.Writer) !void {
|
||||
const TNames = .{ "i32", "i64", "i128", "f32", "f64" };
|
||||
const Lengths = .{ 1, 3, 4, 16, 100 };
|
||||
// "cross" is only valid for len=3; other cells will show " --- "
|
||||
const Ops = .{ "add", "div", "mulScalar", "dot", "cross", "product", "pow", "length" };
|
||||
const Ops = .{ "add", "div", "mulScalar", "dot", "product", "pow", "length" };
|
||||
|
||||
inline for (Ops, 0..) |op_name, o_idx| {
|
||||
inline for (Types, TNames) |T, tname| {
|
||||
@ -482,10 +228,6 @@ fn bench_Vector(writer: *std.Io.Writer) !void {
|
||||
} else if (comptime std.mem.eql(u8, op_name, "dot")) {
|
||||
const v2 = V.splat(getVal(T, i +% 5, 63));
|
||||
_ = v1.contract(v2, 0, 0);
|
||||
} else if (comptime std.mem.eql(u8, op_name, "cross")) {
|
||||
// len == 3 guaranteed by the guard above
|
||||
const v2 = V.splat(getVal(T, i +% 5, 63));
|
||||
_ = v1.cross(v2);
|
||||
} else if (comptime std.mem.eql(u8, op_name, "product")) {
|
||||
_ = v1.product();
|
||||
} else if (comptime std.mem.eql(u8, op_name, "pow")) {
|
||||
@ -608,62 +350,3 @@ fn bench_HighDimTensor(writer: *std.Io.Writer) !void {
|
||||
}
|
||||
try writer.print("└─────────────────┴──────┴──────────────┴──────────────┴──────────────┴──────────────┘\n", .{});
|
||||
}
|
||||
|
||||
fn vectorSIMDvsNative(comptime T: type, writer: *std.Io.Writer) !void {
|
||||
const iterations: u64 = 10_000;
|
||||
const lens = [_]u32{ 1, 2, 3, 4, 5, 10, 100, 1_000, 10_000 };
|
||||
|
||||
try writer.print("\nSIMD Speedup Analysis: {s}\n", .{@typeName(T)});
|
||||
try writer.print("┌────────────┬────────────┬────────────┬────────────┐\n", .{});
|
||||
try writer.print("│ Vector Len │ Scalar (us)│ Vector (us)│ Speedup │\n", .{});
|
||||
try writer.print("├────────────┼────────────┼────────────┼────────────┤\n", .{});
|
||||
|
||||
inline for (lens) |vector_len| {
|
||||
// --- Scalar Test ---
|
||||
var scalar_val: T = 10;
|
||||
const start_scalar = getTime();
|
||||
|
||||
var i: u64 = 0;
|
||||
while (i < iterations * vector_len) : (i += 1) {
|
||||
if (comptime @typeInfo(T) == .int)
|
||||
scalar_val = scalar_val +% 1
|
||||
else
|
||||
scalar_val = scalar_val + 1;
|
||||
}
|
||||
const scalar_time = start_scalar.durationTo(getTime()).toMicroseconds();
|
||||
|
||||
// --- Vector Test ---
|
||||
var vector_val: @Vector(vector_len, T) = @splat(20);
|
||||
const start_vector = getTime();
|
||||
|
||||
i = 0;
|
||||
const increment: @Vector(vector_len, T) = @splat(1);
|
||||
while (i < iterations) : (i += 1) {
|
||||
if (comptime @typeInfo(T) == .int)
|
||||
vector_val = vector_val +% increment
|
||||
else
|
||||
vector_val = vector_val + increment;
|
||||
}
|
||||
const vector_time = start_vector.durationTo(getTime()).toMicroseconds();
|
||||
|
||||
// --- Results ---
|
||||
const s_float = @as(f64, @floatFromInt(scalar_time));
|
||||
const v_float = @as(f64, @floatFromInt(vector_time));
|
||||
|
||||
// Speedup = ScalarTime / VectorTime.
|
||||
// > 1.0 means SIMD is faster.
|
||||
const speedup = if (vector_time > 0) s_float / v_float else 0;
|
||||
|
||||
try writer.print("│ {d:<10} │ {d:>10} │ {d:>10} │ {d:>9.2}x │\n", .{
|
||||
vector_len,
|
||||
scalar_time,
|
||||
vector_time,
|
||||
speedup,
|
||||
});
|
||||
try writer.flush();
|
||||
|
||||
std.mem.doNotOptimizeAway(scalar_val);
|
||||
std.mem.doNotOptimizeAway(vector_val);
|
||||
}
|
||||
try writer.print("└────────────┴────────────┴────────────┴────────────┘\n", .{});
|
||||
}
|
||||
|
||||
9
src/lib.zig
Normal file
9
src/lib.zig
Normal file
@ -0,0 +1,9 @@
|
||||
const std = @import("std");
|
||||
|
||||
pub const TensorStatic = @import("TensorStatic.zig").Tensor;
|
||||
pub const TensorAlloc = @import("TensorAlloc.zig").Tensor;
|
||||
pub const TensorGpu = @import("TensorGpu.zig").Tensor;
|
||||
pub const Dimensions = @import("Dimensions.zig");
|
||||
pub const Scales = @import("Scales.zig");
|
||||
pub const Base = @import("Base.zig");
|
||||
pub const UnitParser = @import("UnitParser.zig");
|
||||
15
src/main.zig
15
src/main.zig
@ -1,15 +0,0 @@
|
||||
const std = @import("std");
|
||||
|
||||
pub const Tensor = @import("Tensor.zig").TensorStatic;
|
||||
pub const Dimensions = @import("Dimensions.zig");
|
||||
pub const Scales = @import("Scales.zig");
|
||||
pub const Base = @import("Base.zig");
|
||||
pub const UnitParser = @import("UnitParser.zig");
|
||||
|
||||
test {
|
||||
_ = @import("Tensor.zig");
|
||||
_ = @import("Dimensions.zig");
|
||||
_ = @import("Scales.zig");
|
||||
_ = @import("Base.zig");
|
||||
_ = @import("UnitParser.zig");
|
||||
}
|
||||
@ -4,6 +4,12 @@ const UnitScale = Scales.UnitScale;
|
||||
const Dimensions = @import("Dimensions.zig");
|
||||
const Dimension = Dimensions.Dimension;
|
||||
|
||||
pub const TensorKind = enum { static, alloc, gpu };
|
||||
|
||||
pub fn isTensor(comptime T: type) bool {
|
||||
return comptime @typeInfo(T) == .@"struct" and @hasDecl(T, "ISTENSOR");
|
||||
}
|
||||
|
||||
pub fn shapeTotal(shape: []const comptime_int) usize {
|
||||
var t: comptime_int = 1;
|
||||
for (shape) |s| t *= s;
|
||||
|
||||
9
src/test.zig
Normal file
9
src/test.zig
Normal file
@ -0,0 +1,9 @@
|
||||
test {
|
||||
_ = @import("TensorStatic.zig");
|
||||
_ = @import("TensorAlloc.zig");
|
||||
_ = @import("TensorGpu.zig");
|
||||
_ = @import("Dimensions.zig");
|
||||
_ = @import("Scales.zig");
|
||||
_ = @import("Base.zig");
|
||||
_ = @import("UnitParser.zig");
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user