Merge pull request #12878 from gwenzek/ptx

Update Nvptx backend for Zig 0.10
This commit is contained in:
Andrew Kelley 2022-10-15 13:53:04 -04:00 committed by GitHub
commit feab1ebe1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 91 additions and 31 deletions

View File

@ -833,6 +833,7 @@ pub fn default_panic(msg: []const u8, error_return_trace: ?*StackTrace, ret_addr
// Didn't have boot_services, just fallback to whatever.
std.os.abort();
},
.cuda => std.os.abort(),
else => {
const first_trace_addr = ret_addr orelse @returnAddress();
std.debug.panicImpl(error_return_trace, first_trace_addr, msg);

View File

@ -500,10 +500,16 @@ pub fn abort() noreturn {
@breakpoint();
exit(1);
}
if (builtin.os.tag == .cuda) {
// TODO: introduce `@trap` instead of abusing https://github.com/ziglang/zig/issues/2291
@"llvm.trap"();
}
system.abort();
}
extern fn @"llvm.trap"() noreturn;
pub const RaiseError = UnexpectedError;
pub fn raise(sig: u8) RaiseError!void {

View File

@ -951,6 +951,13 @@ pub const Target = struct {
};
}
pub fn isNvptx(arch: Arch) bool {
return switch (arch) {
.nvptx, .nvptx64 => true,
else => false,
};
}
pub fn parseCpuModel(arch: Arch, cpu_name: []const u8) !*const Cpu.Model {
for (arch.allCpuModels()) |cpu| {
if (mem.eql(u8, cpu_name, cpu.name)) {

View File

@ -720,6 +720,15 @@ pub const Decl = struct {
var buffer = std.ArrayList(u8).init(mod.gpa);
defer buffer.deinit();
try decl.renderFullyQualifiedName(mod, buffer.writer());
// Sanitize the name for nvptx which is more restrictive.
if (mod.comp.bin_file.options.target.cpu.arch.isNvptx()) {
for (buffer.items) |*byte| switch (byte.*) {
'{', '}', '*', '[', ']', '(', ')', ',', ' ', '\'' => byte.* = '_',
else => {},
};
}
return buffer.toOwnedSliceSentinel(0);
}

View File

@ -18202,12 +18202,6 @@ fn zirAddrSpaceCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.Inst
else
dest_ptr_ty;
if (try sema.resolveMaybeUndefVal(block, ptr_src, ptr)) |val| {
// Pointer value should compatible with both address spaces.
// TODO: Figure out why this generates an invalid bitcast.
return sema.addConstant(dest_ty, val);
}
try sema.requireRuntimeBlock(block, src, ptr_src);
// TODO: Address space cast safety?
@ -21397,7 +21391,12 @@ fn validateExternType(
},
.Fn => {
if (position != .other) return false;
return !Type.fnCallingConventionAllowsZigTypes(ty.fnCallingConvention());
return switch (ty.fnCallingConvention()) {
// For now we want to authorize PTX kernel to use zig objects, even if we end up exposing the ABI.
// The goal is to experiment with more integrated CPU/GPU code.
.PtxKernel => true,
else => !Type.fnCallingConventionAllowsZigTypes(ty.fnCallingConvention()),
};
},
.Enum => {
var buf: Type.Payload.Bits = undefined;

View File

@ -28,10 +28,7 @@ pub fn createEmpty(gpa: Allocator, options: link.Options) !*NvPtx {
if (!build_options.have_llvm) return error.PtxArchNotSupported;
if (!options.use_llvm) return error.PtxArchNotSupported;
switch (options.target.cpu.arch) {
.nvptx, .nvptx64 => {},
else => return error.PtxArchNotSupported,
}
if (!options.target.cpu.arch.isNvptx()) return error.PtxArchNotSupported;
switch (options.target.os.tag) {
// TODO: does it also work with nvcl ?
@ -59,9 +56,8 @@ pub fn openPath(allocator: Allocator, sub_path: []const u8, options: link.Option
if (!options.use_llvm) return error.PtxArchNotSupported;
assert(options.target.ofmt == .nvptx);
const nvptx = try createEmpty(allocator, options);
log.info("Opening .ptx target file {s}", .{sub_path});
return nvptx;
log.debug("Opening .ptx target file {s}", .{sub_path});
return createEmpty(allocator, options);
}
pub fn deinit(self: *NvPtx) void {
@ -109,13 +105,19 @@ pub fn flushModule(self: *NvPtx, comp: *Compilation, prog_node: *std.Progress.No
const tracy = trace(@src());
defer tracy.end();
var hack_comp = comp;
if (comp.bin_file.options.emit) |emit| {
hack_comp.emit_asm = .{
.directory = emit.directory,
.basename = comp.bin_file.intermediary_basename.?,
};
hack_comp.bin_file.options.emit = null;
const outfile = comp.bin_file.options.emit.?;
// We modify 'comp' before passing it to LLVM, but restore value afterwards.
// We tell LLVM to not try to build a .o, only an "assembly" file.
// This is required by the LLVM PTX backend.
comp.bin_file.options.emit = null;
comp.emit_asm = .{
.directory = outfile.directory,
.basename = comp.bin_file.intermediary_basename.?,
};
defer {
comp.bin_file.options.emit = outfile;
comp.emit_asm = null;
}
return try self.llvm_object.flushModule(hack_comp, prog_node);
try self.llvm_object.flushModule(comp, prog_node);
}

View File

@ -411,7 +411,11 @@ pub fn classifyCompilerRtLibName(target: std.Target, name: []const u8) CompilerR
}
pub fn hasDebugInfo(target: std.Target) bool {
_ = target;
if (target.cpu.arch.isNvptx()) {
// TODO: not sure how to test "ptx >= 7.5" with featureset
return std.Target.nvptx.featureSetHas(target.cpu.features, .ptx75);
}
return true;
}
@ -651,7 +655,7 @@ pub fn addrSpaceCastIsValid(
const arch = target.cpu.arch;
switch (arch) {
.x86_64, .i386 => return arch.supportsAddressSpace(from) and arch.supportsAddressSpace(to),
.amdgcn => {
.nvptx64, .nvptx, .amdgcn => {
const to_generic = arch.supportsAddressSpace(from) and to == .generic;
const from_generic = arch.supportsAddressSpace(to) and from == .generic;
return to_generic or from_generic;

View File

@ -4,6 +4,5 @@ const TestContext = @import("../src/test.zig").TestContext;
pub fn addCases(ctx: *TestContext) !void {
try @import("compile_errors.zig").addCases(ctx);
try @import("stage2/cbe.zig").addCases(ctx);
// https://github.com/ziglang/zig/issues/10968
//try @import("stage2/nvptx.zig").addCases(ctx);
try @import("stage2/nvptx.zig").addCases(ctx);
}

View File

@ -23,11 +23,10 @@ pub fn addCases(ctx: *TestContext) !void {
var case = addPtx(ctx, "nvptx: read special registers");
case.compiles(
\\fn threadIdX() usize {
\\ var tid = asm volatile ("mov.u32 \t$0, %tid.x;"
\\ : [ret] "=r" (-> u32),
\\ );
\\ return @as(usize, tid);
\\fn threadIdX() u32 {
\\ return asm ("mov.u32 \t%[r], %tid.x;"
\\ : [r] "=r" (-> u32),
\\ );
\\}
\\
\\pub export fn special_reg(a: []const i32, out: []i32) callconv(.PtxKernel) void {
@ -49,6 +48,38 @@ pub fn addCases(ctx: *TestContext) !void {
\\}
);
}
{
var case = addPtx(ctx, "nvptx: reduce in shared mem");
case.compiles(
\\fn threadIdX() u32 {
\\ return asm ("mov.u32 \t%[r], %tid.x;"
\\ : [r] "=r" (-> u32),
\\ );
\\}
\\
\\ var _sdata: [1024]f32 addrspace(.shared) = undefined;
\\ pub export fn reduceSum(d_x: []const f32, out: *f32) callconv(.PtxKernel) void {
\\ var sdata = @addrSpaceCast(.generic, &_sdata);
\\ const tid: u32 = threadIdX();
\\ var sum = d_x[tid];
\\ sdata[tid] = sum;
\\ asm volatile ("bar.sync \t0;");
\\ var s: u32 = 512;
\\ while (s > 0) : (s = s >> 1) {
\\ if (tid < s) {
\\ sum += sdata[tid + s];
\\ sdata[tid] = sum;
\\ }
\\ asm volatile ("bar.sync \t0;");
\\ }
\\
\\ if (tid == 0) {
\\ out.* = sum;
\\ }
\\ }
);
}
}
const nvptx_target = std.zig.CrossTarget{
@ -68,6 +99,8 @@ pub fn addPtx(
.files = std.ArrayList(TestContext.File).init(ctx.cases.allocator),
.link_libc = false,
.backend = .llvm,
// Bug in Debug mode
.optimize_mode = .ReleaseSafe,
}) catch @panic("out of memory");
return &ctx.cases.items[ctx.cases.items.len - 1];
}