stage2: add support for Nvptx target

sample command:

/home/guw/github/zig/stage2/bin/zig build-obj cuda_kernel.zig -target nvptx64-cuda -O ReleaseSafe
this will create a kernel.ptx

expose PtxKernel call convention from LLVM
kernels are `export fn f() callconv(.PtxKernel)`
This commit is contained in:
gwenzek 2022-02-05 15:33:00 +01:00 committed by GitHub
parent fbc06f9c91
commit 0e1afb4d98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 185 additions and 8 deletions

View File

@ -147,6 +147,7 @@ pub const CallingConvention = enum {
AAPCS,
AAPCSVFP,
SysV,
PtxKernel,
};
/// This data structure is used by the Zig language code generation and

View File

@ -579,6 +579,8 @@ pub const Target = struct {
raw,
/// Plan 9 from Bell Labs
plan9,
/// Nvidia PTX format
nvptx,
pub fn fileExt(of: ObjectFormat, cpu_arch: Cpu.Arch) [:0]const u8 {
return switch (of) {
@ -589,6 +591,7 @@ pub const Target = struct {
.hex => ".ihex",
.raw => ".bin",
.plan9 => plan9Ext(cpu_arch),
.nvptx => ".ptx",
};
}
};
@ -1388,6 +1391,7 @@ pub const Target = struct {
else => return switch (cpu_arch) {
.wasm32, .wasm64 => .wasm,
.spirv32, .spirv64 => .spirv,
.nvptx, .nvptx64 => .nvptx,
else => .elf,
},
};

View File

@ -181,6 +181,7 @@ pub fn binNameAlloc(allocator: std.mem.Allocator, options: BinNameOptions) error
.Obj => return std.fmt.allocPrint(allocator, "{s}{s}", .{ root_name, ofmt.fileExt(target.cpu.arch) }),
.Lib => return std.fmt.allocPrint(allocator, "{s}{s}.a", .{ target.libPrefix(), root_name }),
},
.nvptx => return std.fmt.allocPrint(allocator, "{s}", .{root_name}),
}
}

View File

@ -4242,7 +4242,7 @@ fn scanDecl(iter: *ScanDeclIter, decl_sub_index: usize, flags: u4) SemaError!voi
// in `Decl` to notice that the line number did not change.
mod.comp.work_queue.writeItemAssumeCapacity(.{ .update_line_number = decl });
},
.c, .wasm, .spirv => {},
.c, .wasm, .spirv, .nvptx => {},
}
}
}
@ -4316,6 +4316,7 @@ pub fn clearDecl(
.c => .{ .c = {} },
.wasm => .{ .wasm = link.File.Wasm.DeclBlock.empty },
.spirv => .{ .spirv = {} },
.nvptx => .{ .nvptx = {} },
};
decl.fn_link = switch (mod.comp.bin_file.tag) {
.coff => .{ .coff = {} },
@ -4325,6 +4326,7 @@ pub fn clearDecl(
.c => .{ .c = {} },
.wasm => .{ .wasm = link.File.Wasm.FnData.empty },
.spirv => .{ .spirv = .{} },
.nvptx => .{ .nvptx = .{} },
};
}
if (decl.getInnerNamespace()) |namespace| {
@ -4652,6 +4654,7 @@ pub fn allocateNewDecl(
.c => .{ .c = {} },
.wasm => .{ .wasm = link.File.Wasm.DeclBlock.empty },
.spirv => .{ .spirv = {} },
.nvptx => .{ .nvptx = {} },
},
.fn_link = switch (mod.comp.bin_file.tag) {
.coff => .{ .coff = {} },
@ -4661,6 +4664,7 @@ pub fn allocateNewDecl(
.c => .{ .c = {} },
.wasm => .{ .wasm = link.File.Wasm.FnData.empty },
.spirv => .{ .spirv = .{} },
.nvptx => .{ .nvptx = .{} },
},
.generation = 0,
.is_pub = false,

View File

@ -3724,6 +3724,7 @@ pub fn analyzeExport(
.c => .{ .c = {} },
.wasm => .{ .wasm = {} },
.spirv => .{ .spirv = {} },
.nvptx => .{ .nvptx = {} },
},
.owner_decl = owner_decl,
.src_decl = src_decl,

View File

@ -378,7 +378,7 @@ pub const Object = struct {
const mod = comp.bin_file.options.module.?;
const cache_dir = mod.zig_cache_artifact_directory;
const emit_bin_path: ?[*:0]const u8 = if (comp.bin_file.options.emit) |emit|
var emit_bin_path: ?[*:0]const u8 = if (comp.bin_file.options.emit) |emit|
try emit.basenamePath(arena, try arena.dupeZ(u8, comp.bin_file.intermediary_basename.?))
else
null;
@ -5078,6 +5078,10 @@ fn toLlvmCallConv(cc: std.builtin.CallingConvention, target: std.Target) llvm.Ca
},
.Signal => .AVR_SIGNAL,
.SysV => .X86_64_SysV,
.PtxKernel => return switch (target.cpu.arch) {
.nvptx, .nvptx64 => .PTX_Kernel,
else => unreachable,
},
};
}

View File

@ -215,6 +215,7 @@ pub const File = struct {
c: void,
wasm: Wasm.DeclBlock,
spirv: void,
nvptx: void,
};
pub const LinkFn = union {
@ -225,6 +226,7 @@ pub const File = struct {
c: void,
wasm: Wasm.FnData,
spirv: SpirV.FnData,
nvptx: void,
};
pub const Export = union {
@ -235,6 +237,7 @@ pub const File = struct {
c: void,
wasm: void,
spirv: void,
nvptx: void,
};
/// For DWARF .debug_info.
@ -274,6 +277,7 @@ pub const File = struct {
.plan9 => return &(try Plan9.createEmpty(allocator, options)).base,
.c => unreachable, // Reported error earlier.
.spirv => &(try SpirV.createEmpty(allocator, options)).base,
.nvptx => &(try NvPtx.createEmpty(allocator, options)).base,
.hex => return error.HexObjectFormatUnimplemented,
.raw => return error.RawObjectFormatUnimplemented,
};
@ -292,6 +296,7 @@ pub const File = struct {
.wasm => &(try Wasm.createEmpty(allocator, options)).base,
.c => unreachable, // Reported error earlier.
.spirv => &(try SpirV.createEmpty(allocator, options)).base,
.nvptx => &(try NvPtx.createEmpty(allocator, options)).base,
.hex => return error.HexObjectFormatUnimplemented,
.raw => return error.RawObjectFormatUnimplemented,
};
@ -312,6 +317,7 @@ pub const File = struct {
.wasm => &(try Wasm.openPath(allocator, sub_path, options)).base,
.c => &(try C.openPath(allocator, sub_path, options)).base,
.spirv => &(try SpirV.openPath(allocator, sub_path, options)).base,
.nvptx => &(try NvPtx.openPath(allocator, sub_path, options)).base,
.hex => return error.HexObjectFormatUnimplemented,
.raw => return error.RawObjectFormatUnimplemented,
};
@ -344,7 +350,7 @@ pub const File = struct {
.mode = determineMode(base.options),
});
},
.c, .wasm, .spirv => {},
.c, .wasm, .spirv, .nvptx => {},
}
}
@ -389,7 +395,7 @@ pub const File = struct {
f.close();
base.file = null;
},
.c, .wasm, .spirv => {},
.c, .wasm, .spirv, .nvptx => {},
}
}
@ -437,6 +443,7 @@ pub const File = struct {
.wasm => return @fieldParentPtr(Wasm, "base", base).updateDecl(module, decl),
.spirv => return @fieldParentPtr(SpirV, "base", base).updateDecl(module, decl),
.plan9 => return @fieldParentPtr(Plan9, "base", base).updateDecl(module, decl),
.nvptx => return @fieldParentPtr(NvPtx, "base", base).updateDecl(module, decl),
// zig fmt: on
}
}
@ -456,6 +463,7 @@ pub const File = struct {
.wasm => return @fieldParentPtr(Wasm, "base", base).updateFunc(module, func, air, liveness),
.spirv => return @fieldParentPtr(SpirV, "base", base).updateFunc(module, func, air, liveness),
.plan9 => return @fieldParentPtr(Plan9, "base", base).updateFunc(module, func, air, liveness),
.nvptx => return @fieldParentPtr(NvPtx, "base", base).updateFunc(module, func, air, liveness),
// zig fmt: on
}
}
@ -471,7 +479,7 @@ pub const File = struct {
.macho => return @fieldParentPtr(MachO, "base", base).updateDeclLineNumber(module, decl),
.c => return @fieldParentPtr(C, "base", base).updateDeclLineNumber(module, decl),
.plan9 => @panic("TODO: implement updateDeclLineNumber for plan9"),
.wasm, .spirv => {},
.wasm, .spirv, .nvptx => {},
}
}
@ -493,7 +501,7 @@ pub const File = struct {
},
.wasm => return @fieldParentPtr(Wasm, "base", base).allocateDeclIndexes(decl),
.plan9 => return @fieldParentPtr(Plan9, "base", base).allocateDeclIndexes(decl),
.c, .spirv => {},
.c, .spirv, .nvptx => {},
}
}
@ -551,6 +559,11 @@ pub const File = struct {
parent.deinit();
base.allocator.destroy(parent);
},
.nvptx => {
const parent = @fieldParentPtr(NvPtx, "base", base);
parent.deinit();
base.allocator.destroy(parent);
},
}
}
@ -584,6 +597,7 @@ pub const File = struct {
.wasm => return @fieldParentPtr(Wasm, "base", base).flush(comp),
.spirv => return @fieldParentPtr(SpirV, "base", base).flush(comp),
.plan9 => return @fieldParentPtr(Plan9, "base", base).flush(comp),
.nvptx => return @fieldParentPtr(NvPtx, "base", base).flush(comp),
}
}
@ -598,6 +612,7 @@ pub const File = struct {
.wasm => return @fieldParentPtr(Wasm, "base", base).flushModule(comp),
.spirv => return @fieldParentPtr(SpirV, "base", base).flushModule(comp),
.plan9 => return @fieldParentPtr(Plan9, "base", base).flushModule(comp),
.nvptx => return @fieldParentPtr(NvPtx, "base", base).flushModule(comp),
}
}
@ -612,6 +627,7 @@ pub const File = struct {
.wasm => @fieldParentPtr(Wasm, "base", base).freeDecl(decl),
.spirv => @fieldParentPtr(SpirV, "base", base).freeDecl(decl),
.plan9 => @fieldParentPtr(Plan9, "base", base).freeDecl(decl),
.nvptx => @fieldParentPtr(NvPtx, "base", base).freeDecl(decl),
}
}
@ -622,7 +638,7 @@ pub const File = struct {
.macho => return @fieldParentPtr(MachO, "base", base).error_flags,
.plan9 => return @fieldParentPtr(Plan9, "base", base).error_flags,
.c => return .{ .no_entry_point_found = false },
.wasm, .spirv => return ErrorFlags{},
.wasm, .spirv, .nvptx => return ErrorFlags{},
}
}
@ -644,6 +660,7 @@ pub const File = struct {
.wasm => return @fieldParentPtr(Wasm, "base", base).updateDeclExports(module, decl, exports),
.spirv => return @fieldParentPtr(SpirV, "base", base).updateDeclExports(module, decl, exports),
.plan9 => return @fieldParentPtr(Plan9, "base", base).updateDeclExports(module, decl, exports),
.nvptx => return @fieldParentPtr(NvPtx, "base", base).updateDeclExports(module, decl, exports),
}
}
@ -656,6 +673,7 @@ pub const File = struct {
.c => unreachable,
.wasm => unreachable,
.spirv => unreachable,
.nvptx => unreachable,
}
}
@ -851,6 +869,7 @@ pub const File = struct {
wasm,
spirv,
plan9,
nvptx,
};
pub const ErrorFlags = struct {
@ -864,6 +883,7 @@ pub const File = struct {
pub const MachO = @import("link/MachO.zig");
pub const SpirV = @import("link/SpirV.zig");
pub const Wasm = @import("link/Wasm.zig");
pub const NvPtx = @import("link/NvPtx.zig");
};
pub fn determineMode(options: Options) fs.File.Mode {

122
src/link/NvPtx.zig Normal file
View File

@ -0,0 +1,122 @@
//! NVidia PTX (Paralle Thread Execution)
//! https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
//! For this we rely on the nvptx backend of LLVM
//! Kernel functions need to be marked both as "export" and "callconv(.PtxKernel)"
const NvPtx = @This();
const std = @import("std");
const builtin = @import("builtin");
const Allocator = std.mem.Allocator;
const assert = std.debug.assert;
const log = std.log.scoped(.link);
const Module = @import("../Module.zig");
const Compilation = @import("../Compilation.zig");
const link = @import("../link.zig");
const trace = @import("../tracy.zig").trace;
const build_options = @import("build_options");
const Air = @import("../Air.zig");
const Liveness = @import("../Liveness.zig");
const LlvmObject = @import("../codegen/llvm.zig").Object;
base: link.File,
llvm_object: *LlvmObject,
pub fn createEmpty(gpa: Allocator, options: link.Options) !*NvPtx {
if (!build_options.have_llvm) return error.TODOArchNotSupported;
const nvptx = try gpa.create(NvPtx);
nvptx.* = .{
.base = .{
.tag = .nvptx,
.options = options,
.file = null,
.allocator = gpa,
},
.llvm_object = undefined,
};
switch (options.target.cpu.arch) {
.nvptx, .nvptx64 => {},
else => return error.TODOArchNotSupported,
}
switch (options.target.os.tag) {
// TODO: does it also work with nvcl ?
.cuda => {},
else => return error.TODOOsNotSupported,
}
return nvptx;
}
pub fn openPath(allocator: Allocator, sub_path: []const u8, options: link.Options) !*NvPtx {
if (!build_options.have_llvm) @panic("nvptx target requires a zig compiler with llvm enabled.");
if (!options.use_llvm) return error.TODOArchNotSupported;
assert(options.object_format == .nvptx);
const nvptx = try createEmpty(allocator, options);
errdefer nvptx.base.destroy();
log.info("Opening .ptx target file {s}", .{sub_path});
nvptx.llvm_object = try LlvmObject.create(allocator, options);
return nvptx;
}
pub fn deinit(self: *NvPtx) void {
if (!build_options.have_llvm) return;
self.llvm_object.destroy(self.base.allocator);
}
pub fn updateFunc(self: *NvPtx, module: *Module, func: *Module.Fn, air: Air, liveness: Liveness) !void {
if (!build_options.have_llvm) return;
try self.llvm_object.updateFunc(module, func, air, liveness);
}
pub fn updateDecl(self: *NvPtx, module: *Module, decl: *Module.Decl) !void {
if (!build_options.have_llvm) return;
return self.llvm_object.updateDecl(module, decl);
}
pub fn updateDeclExports(
self: *NvPtx,
module: *Module,
decl: *const Module.Decl,
exports: []const *Module.Export,
) !void {
if (!build_options.have_llvm) return;
if (build_options.skip_non_native and builtin.object_format != .nvptx) {
@panic("Attempted to compile for object format that was disabled by build configuration");
}
return self.llvm_object.updateDeclExports(module, decl, exports);
}
pub fn freeDecl(self: *NvPtx, decl: *Module.Decl) void {
if (!build_options.have_llvm) return;
return self.llvm_object.freeDecl(decl);
}
pub fn flush(self: *NvPtx, comp: *Compilation) !void {
return self.flushModule(comp);
}
pub fn flushModule(self: *NvPtx, comp: *Compilation) !void {
if (!build_options.have_llvm) return;
if (build_options.skip_non_native) {
@panic("Attempted to compile for architecture that was disabled by build configuration");
}
const tracy = trace(@src());
defer tracy.end();
var hack_comp = comp;
if (comp.bin_file.options.emit) |emit| {
hack_comp.emit_asm = .{
.directory = emit.directory,
.basename = comp.bin_file.intermediary_basename.?,
};
hack_comp.bin_file.options.emit = null;
}
return try self.llvm_object.flushModule(hack_comp);
}

View File

@ -83,7 +83,8 @@ enum CallingConvention {
CallingConventionAPCS,
CallingConventionAAPCS,
CallingConventionAAPCSVFP,
CallingConventionSysV
CallingConventionSysV,
CallingConventionPtxKernel
};
// Stage 1 supports only the generic address space

View File

@ -991,6 +991,7 @@ const char *calling_convention_name(CallingConvention cc) {
case CallingConventionAAPCSVFP: return "AAPCSVFP";
case CallingConventionInline: return "Inline";
case CallingConventionSysV: return "SysV";
case CallingConventionPtxKernel: return "PtxKernel";
}
zig_unreachable();
}
@ -1000,6 +1001,7 @@ bool calling_convention_allows_zig_types(CallingConvention cc) {
case CallingConventionUnspecified:
case CallingConventionAsync:
case CallingConventionInline:
case CallingConventionPtxKernel:
return true;
case CallingConventionC:
case CallingConventionNaked:
@ -2006,6 +2008,15 @@ Error emit_error_unless_callconv_allowed_for_target(CodeGen *g, AstNode *source_
case CallingConventionSysV:
if (g->zig_target->arch != ZigLLVM_x86_64)
allowed_platforms = "x86_64";
break;
case CallingConventionPtxKernel:
if (g->zig_target->arch != ZigLLVM_nvptx
&& g->zig_target->arch != ZigLLVM_nvptx64)
{
allowed_platforms = "nvptx and nvptx64";
}
break;
}
if (allowed_platforms != nullptr) {
add_node_error(g, source_node, buf_sprintf(
@ -3827,6 +3838,7 @@ static void resolve_decl_fn(CodeGen *g, TldFn *tld_fn) {
case CallingConventionAAPCS:
case CallingConventionAAPCSVFP:
case CallingConventionSysV:
case CallingConventionPtxKernel:
add_fn_export(g, fn_table_entry, buf_ptr(&fn_table_entry->symbol_name),
GlobalLinkageIdStrong, fn_cc);
break;

View File

@ -209,6 +209,11 @@ static ZigLLVM_CallingConv get_llvm_cc(CodeGen *g, CallingConvention cc) {
case CallingConventionSysV:
assert(g->zig_target->arch == ZigLLVM_x86_64);
return ZigLLVM_X86_64_SysV;
case CallingConventionPtxKernel:
assert(g->zig_target->arch == ZigLLVM_nvptx ||
g->zig_target->arch == ZigLLVM_nvptx64);
return ZigLLVM_PTX_Kernel;
}
zig_unreachable();
}
@ -354,6 +359,7 @@ static bool cc_want_sret_attr(CallingConvention cc) {
case CallingConventionAAPCS:
case CallingConventionAAPCSVFP:
case CallingConventionSysV:
case CallingConventionPtxKernel:
return true;
case CallingConventionAsync:
case CallingConventionUnspecified:

View File

@ -11666,6 +11666,7 @@ static Stage1AirInst *ir_analyze_instruction_export(IrAnalyze *ira, Stage1ZirIns
case CallingConventionAAPCS:
case CallingConventionAAPCSVFP:
case CallingConventionSysV:
case CallingConventionPtxKernel:
add_fn_export(ira->codegen, fn_entry, buf_ptr(symbol_name), global_linkage_id, cc);
fn_entry->section_name = section_name;
break;