From b425d887375132a915a5cd2baf7958f273732ee1 Mon Sep 17 00:00:00 2001 From: Guillaume Wenzek Date: Tue, 4 Oct 2022 07:31:36 +0200 Subject: [PATCH] re-enable nvptx tests --- test/cases.zig | 3 +-- test/stage2/nvptx.zig | 41 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/test/cases.zig b/test/cases.zig index 65eec90f1b..412b4cb5e2 100644 --- a/test/cases.zig +++ b/test/cases.zig @@ -4,6 +4,5 @@ const TestContext = @import("../src/test.zig").TestContext; pub fn addCases(ctx: *TestContext) !void { try @import("compile_errors.zig").addCases(ctx); try @import("stage2/cbe.zig").addCases(ctx); - // https://github.com/ziglang/zig/issues/10968 - //try @import("stage2/nvptx.zig").addCases(ctx); + try @import("stage2/nvptx.zig").addCases(ctx); } diff --git a/test/stage2/nvptx.zig b/test/stage2/nvptx.zig index 7182092be7..b41a21ed6f 100644 --- a/test/stage2/nvptx.zig +++ b/test/stage2/nvptx.zig @@ -23,11 +23,10 @@ pub fn addCases(ctx: *TestContext) !void { var case = addPtx(ctx, "nvptx: read special registers"); case.compiles( - \\fn threadIdX() usize { - \\ var tid = asm volatile ("mov.u32 \t$0, %tid.x;" - \\ : [ret] "=r" (-> u32), - \\ ); - \\ return @as(usize, tid); + \\fn threadIdX() u32 { + \\ return asm ("mov.u32 \t%[r], %tid.x;" + \\ : [r] "=r" (-> utid), + \\ ); \\} \\ \\pub export fn special_reg(a: []const i32, out: []i32) callconv(.PtxKernel) void { @@ -49,6 +48,38 @@ pub fn addCases(ctx: *TestContext) !void { \\} ); } + + { + var case = addPtx(ctx, "nvptx: reduce in shared mem"); + case.compiles( + \\fn threadIdX() u32 { + \\ return asm ("mov.u32 \t%[r], %tid.x;" + \\ : [r] "=r" (-> utid), + \\ ); + \\} + \\ + \\ var _sdata: [1024]f32 addrspace(.shared) = undefined; + \\ pub export fn reduceSum(d_x: []const f32, out: *f32) callconv(ptx.Kernel) void { + \\ var sdata = @addrSpaceCast(.generic, &_sdata); + \\ const tid: u32 = threadIdX(); + \\ var sum = d_x[tid]; + \\ sdata[tid] = sum; + \\ asm volatile ("bar.sync \t0;"); + \\ var s: u32 = 512; + \\ while (s > 0) : (s = s >> 1) { + \\ if (tid < s) { + \\ sum += sdata[tid + s]; + \\ sdata[tid] = sum; + \\ } + \\ asm volatile ("bar.sync \t0;"); + \\ } + \\ + \\ if (tid == 0) { + \\ out.* = sum; + \\ } + \\ } + ); + } } const nvptx_target = std.zig.CrossTarget{