diff --git a/lib/std/builtin.zig b/lib/std/builtin.zig index 9e7bfc99ba..39e849fb5e 100644 --- a/lib/std/builtin.zig +++ b/lib/std/builtin.zig @@ -147,6 +147,7 @@ pub const CallingConvention = enum { AAPCS, AAPCSVFP, SysV, + PtxKernel, }; /// This data structure is used by the Zig language code generation and diff --git a/lib/std/target.zig b/lib/std/target.zig index 182690484e..9a2dcfcc66 100644 --- a/lib/std/target.zig +++ b/lib/std/target.zig @@ -579,6 +579,8 @@ pub const Target = struct { raw, /// Plan 9 from Bell Labs plan9, + /// Nvidia PTX format + nvptx, pub fn fileExt(of: ObjectFormat, cpu_arch: Cpu.Arch) [:0]const u8 { return switch (of) { @@ -589,6 +591,7 @@ pub const Target = struct { .hex => ".ihex", .raw => ".bin", .plan9 => plan9Ext(cpu_arch), + .nvptx => ".ptx", }; } }; @@ -1388,6 +1391,7 @@ pub const Target = struct { else => return switch (cpu_arch) { .wasm32, .wasm64 => .wasm, .spirv32, .spirv64 => .spirv, + .nvptx, .nvptx64 => .nvptx, else => .elf, }, }; diff --git a/lib/std/zig.zig b/lib/std/zig.zig index 1420db8ec2..9b8e2294f2 100644 --- a/lib/std/zig.zig +++ b/lib/std/zig.zig @@ -181,6 +181,7 @@ pub fn binNameAlloc(allocator: std.mem.Allocator, options: BinNameOptions) error .Obj => return std.fmt.allocPrint(allocator, "{s}{s}", .{ root_name, ofmt.fileExt(target.cpu.arch) }), .Lib => return std.fmt.allocPrint(allocator, "{s}{s}.a", .{ target.libPrefix(), root_name }), }, + .nvptx => return std.fmt.allocPrint(allocator, "{s}", .{root_name}), } } diff --git a/src/Module.zig b/src/Module.zig index e2e2505927..eeed6b2dc9 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -4242,7 +4242,7 @@ fn scanDecl(iter: *ScanDeclIter, decl_sub_index: usize, flags: u4) SemaError!voi // in `Decl` to notice that the line number did not change. 
mod.comp.work_queue.writeItemAssumeCapacity(.{ .update_line_number = decl }); }, - .c, .wasm, .spirv => {}, + .c, .wasm, .spirv, .nvptx => {}, } } } @@ -4316,6 +4316,7 @@ pub fn clearDecl( .c => .{ .c = {} }, .wasm => .{ .wasm = link.File.Wasm.DeclBlock.empty }, .spirv => .{ .spirv = {} }, + .nvptx => .{ .nvptx = {} }, }; decl.fn_link = switch (mod.comp.bin_file.tag) { .coff => .{ .coff = {} }, @@ -4325,6 +4326,7 @@ pub fn clearDecl( .c => .{ .c = {} }, .wasm => .{ .wasm = link.File.Wasm.FnData.empty }, .spirv => .{ .spirv = .{} }, + .nvptx => .{ .nvptx = .{} }, }; } if (decl.getInnerNamespace()) |namespace| { @@ -4652,6 +4654,7 @@ pub fn allocateNewDecl( .c => .{ .c = {} }, .wasm => .{ .wasm = link.File.Wasm.DeclBlock.empty }, .spirv => .{ .spirv = {} }, + .nvptx => .{ .nvptx = {} }, }, .fn_link = switch (mod.comp.bin_file.tag) { .coff => .{ .coff = {} }, @@ -4661,6 +4664,7 @@ pub fn allocateNewDecl( .c => .{ .c = {} }, .wasm => .{ .wasm = link.File.Wasm.FnData.empty }, .spirv => .{ .spirv = .{} }, + .nvptx => .{ .nvptx = .{} }, }, .generation = 0, .is_pub = false, diff --git a/src/Sema.zig b/src/Sema.zig index c4b3ad8c33..934fa4064b 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -3724,6 +3724,7 @@ pub fn analyzeExport( .c => .{ .c = {} }, .wasm => .{ .wasm = {} }, .spirv => .{ .spirv = {} }, + .nvptx => .{ .nvptx = {} }, }, .owner_decl = owner_decl, .src_decl = src_decl, diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index 81742d4866..08fc3879a9 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -378,7 +378,7 @@ pub const Object = struct { const mod = comp.bin_file.options.module.?; const cache_dir = mod.zig_cache_artifact_directory; - const emit_bin_path: ?[*:0]const u8 = if (comp.bin_file.options.emit) |emit| + var emit_bin_path: ?[*:0]const u8 = if (comp.bin_file.options.emit) |emit| try emit.basenamePath(arena, try arena.dupeZ(u8, comp.bin_file.intermediary_basename.?)) else null; @@ -5078,6 +5078,10 @@ fn toLlvmCallConv(cc: 
std.builtin.CallingConvention, target: std.Target) llvm.Ca }, .Signal => .AVR_SIGNAL, .SysV => .X86_64_SysV, + .PtxKernel => return switch (target.cpu.arch) { + .nvptx, .nvptx64 => .PTX_Kernel, + else => unreachable, + }, }; } diff --git a/src/link.zig b/src/link.zig index 883d79de34..51e7082aa7 100644 --- a/src/link.zig +++ b/src/link.zig @@ -215,6 +215,7 @@ pub const File = struct { c: void, wasm: Wasm.DeclBlock, spirv: void, + nvptx: void, }; pub const LinkFn = union { @@ -225,6 +226,7 @@ pub const File = struct { c: void, wasm: Wasm.FnData, spirv: SpirV.FnData, + nvptx: void, }; pub const Export = union { @@ -235,6 +237,7 @@ pub const File = struct { c: void, wasm: void, spirv: void, + nvptx: void, }; /// For DWARF .debug_info. @@ -274,6 +277,7 @@ pub const File = struct { .plan9 => return &(try Plan9.createEmpty(allocator, options)).base, .c => unreachable, // Reported error earlier. .spirv => &(try SpirV.createEmpty(allocator, options)).base, + .nvptx => &(try NvPtx.createEmpty(allocator, options)).base, .hex => return error.HexObjectFormatUnimplemented, .raw => return error.RawObjectFormatUnimplemented, }; @@ -292,6 +296,7 @@ pub const File = struct { .wasm => &(try Wasm.createEmpty(allocator, options)).base, .c => unreachable, // Reported error earlier. 
.spirv => &(try SpirV.createEmpty(allocator, options)).base, + .nvptx => &(try NvPtx.createEmpty(allocator, options)).base, .hex => return error.HexObjectFormatUnimplemented, .raw => return error.RawObjectFormatUnimplemented, }; @@ -312,6 +317,7 @@ pub const File = struct { .wasm => &(try Wasm.openPath(allocator, sub_path, options)).base, .c => &(try C.openPath(allocator, sub_path, options)).base, .spirv => &(try SpirV.openPath(allocator, sub_path, options)).base, + .nvptx => &(try NvPtx.openPath(allocator, sub_path, options)).base, .hex => return error.HexObjectFormatUnimplemented, .raw => return error.RawObjectFormatUnimplemented, }; @@ -344,7 +350,7 @@ pub const File = struct { .mode = determineMode(base.options), }); }, - .c, .wasm, .spirv => {}, + .c, .wasm, .spirv, .nvptx => {}, } } @@ -389,7 +395,7 @@ pub const File = struct { f.close(); base.file = null; }, - .c, .wasm, .spirv => {}, + .c, .wasm, .spirv, .nvptx => {}, } } @@ -437,6 +443,7 @@ pub const File = struct { .wasm => return @fieldParentPtr(Wasm, "base", base).updateDecl(module, decl), .spirv => return @fieldParentPtr(SpirV, "base", base).updateDecl(module, decl), .plan9 => return @fieldParentPtr(Plan9, "base", base).updateDecl(module, decl), + .nvptx => return @fieldParentPtr(NvPtx, "base", base).updateDecl(module, decl), // zig fmt: on } } @@ -456,6 +463,7 @@ pub const File = struct { .wasm => return @fieldParentPtr(Wasm, "base", base).updateFunc(module, func, air, liveness), .spirv => return @fieldParentPtr(SpirV, "base", base).updateFunc(module, func, air, liveness), .plan9 => return @fieldParentPtr(Plan9, "base", base).updateFunc(module, func, air, liveness), + .nvptx => return @fieldParentPtr(NvPtx, "base", base).updateFunc(module, func, air, liveness), // zig fmt: on } } @@ -471,7 +479,7 @@ pub const File = struct { .macho => return @fieldParentPtr(MachO, "base", base).updateDeclLineNumber(module, decl), .c => return @fieldParentPtr(C, "base", base).updateDeclLineNumber(module, decl), .plan9 
=> @panic("TODO: implement updateDeclLineNumber for plan9"), - .wasm, .spirv => {}, + .wasm, .spirv, .nvptx => {}, } } @@ -493,7 +501,7 @@ pub const File = struct { }, .wasm => return @fieldParentPtr(Wasm, "base", base).allocateDeclIndexes(decl), .plan9 => return @fieldParentPtr(Plan9, "base", base).allocateDeclIndexes(decl), - .c, .spirv => {}, + .c, .spirv, .nvptx => {}, } } @@ -551,6 +559,11 @@ pub const File = struct { parent.deinit(); base.allocator.destroy(parent); }, + .nvptx => { + const parent = @fieldParentPtr(NvPtx, "base", base); + parent.deinit(); + base.allocator.destroy(parent); + }, } } @@ -584,6 +597,7 @@ pub const File = struct { .wasm => return @fieldParentPtr(Wasm, "base", base).flush(comp), .spirv => return @fieldParentPtr(SpirV, "base", base).flush(comp), .plan9 => return @fieldParentPtr(Plan9, "base", base).flush(comp), + .nvptx => return @fieldParentPtr(NvPtx, "base", base).flush(comp), } } @@ -598,6 +612,7 @@ pub const File = struct { .wasm => return @fieldParentPtr(Wasm, "base", base).flushModule(comp), .spirv => return @fieldParentPtr(SpirV, "base", base).flushModule(comp), .plan9 => return @fieldParentPtr(Plan9, "base", base).flushModule(comp), + .nvptx => return @fieldParentPtr(NvPtx, "base", base).flushModule(comp), } } @@ -612,6 +627,7 @@ pub const File = struct { .wasm => @fieldParentPtr(Wasm, "base", base).freeDecl(decl), .spirv => @fieldParentPtr(SpirV, "base", base).freeDecl(decl), .plan9 => @fieldParentPtr(Plan9, "base", base).freeDecl(decl), + .nvptx => @fieldParentPtr(NvPtx, "base", base).freeDecl(decl), } } @@ -622,7 +638,7 @@ pub const File = struct { .macho => return @fieldParentPtr(MachO, "base", base).error_flags, .plan9 => return @fieldParentPtr(Plan9, "base", base).error_flags, .c => return .{ .no_entry_point_found = false }, - .wasm, .spirv => return ErrorFlags{}, + .wasm, .spirv, .nvptx => return ErrorFlags{}, } } @@ -644,6 +660,7 @@ pub const File = struct { .wasm => return @fieldParentPtr(Wasm, "base", 
base).updateDeclExports(module, decl, exports), .spirv => return @fieldParentPtr(SpirV, "base", base).updateDeclExports(module, decl, exports), .plan9 => return @fieldParentPtr(Plan9, "base", base).updateDeclExports(module, decl, exports), + .nvptx => return @fieldParentPtr(NvPtx, "base", base).updateDeclExports(module, decl, exports), } } @@ -656,6 +673,7 @@ pub const File = struct { .c => unreachable, .wasm => unreachable, .spirv => unreachable, + .nvptx => unreachable, } } @@ -851,6 +869,7 @@ pub const File = struct { wasm, spirv, plan9, + nvptx, }; pub const ErrorFlags = struct { @@ -864,6 +883,7 @@ pub const File = struct { pub const MachO = @import("link/MachO.zig"); pub const SpirV = @import("link/SpirV.zig"); pub const Wasm = @import("link/Wasm.zig"); + pub const NvPtx = @import("link/NvPtx.zig"); }; pub fn determineMode(options: Options) fs.File.Mode {
diff --git a/src/link/NvPtx.zig b/src/link/NvPtx.zig
new file mode 100644
index 0000000000..77613cdc1d
--- /dev/null
+++ b/src/link/NvPtx.zig
@@ -0,0 +1,151 @@
+//! NVidia PTX (Parallel Thread Execution)
+//! https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+//! For this we rely on the nvptx backend of LLVM
+//! Kernel functions need to be marked both as "export" and "callconv(.PtxKernel)"
+
+const NvPtx = @This();
+
+const std = @import("std");
+const builtin = @import("builtin");
+
+const Allocator = std.mem.Allocator;
+const assert = std.debug.assert;
+const log = std.log.scoped(.link);
+
+const Module = @import("../Module.zig");
+const Compilation = @import("../Compilation.zig");
+const link = @import("../link.zig");
+const trace = @import("../tracy.zig").trace;
+const build_options = @import("build_options");
+const Air = @import("../Air.zig");
+const Liveness = @import("../Liveness.zig");
+const LlvmObject = @import("../codegen/llvm.zig").Object;
+
+base: link.File,
+/// LLVM object that all code generation is delegated to. `createEmpty` leaves
+/// it `undefined`; it is only assigned by `openPath`.
+llvm_object: *LlvmObject,
+
+/// Allocates an `NvPtx` from `gpa` for an nvptx/nvptx64 target with the cuda OS tag.
+/// `llvm_object` is left `undefined`; `openPath` is what assigns it.
+/// The target is validated before allocating so that no error path leaks.
+pub fn createEmpty(gpa: Allocator, options: link.Options) !*NvPtx {
+    if (!build_options.have_llvm) return error.TODOArchNotSupported;
+
+    switch (options.target.cpu.arch) {
+        .nvptx, .nvptx64 => {},
+        else => return error.TODOArchNotSupported,
+    }
+
+    switch (options.target.os.tag) {
+        // TODO: does it also work with nvcl ?
+        .cuda => {},
+        else => return error.TODOOsNotSupported,
+    }
+
+    const nvptx = try gpa.create(NvPtx);
+    nvptx.* = .{
+        .base = .{
+            .tag = .nvptx,
+            .options = options,
+            .file = null,
+            .allocator = gpa,
+        },
+        .llvm_object = undefined,
+    };
+    return nvptx;
+}
+
+/// Creates an `NvPtx` and the LLVM object backing it.
+/// `sub_path` is only logged; this function does not open any file.
+/// Panics if the compiler was built without LLVM.
+pub fn openPath(allocator: Allocator, sub_path: []const u8, options: link.Options) !*NvPtx {
+    if (!build_options.have_llvm) @panic("nvptx target requires a zig compiler with llvm enabled.");
+    if (!options.use_llvm) return error.TODOArchNotSupported;
+    assert(options.object_format == .nvptx);
+
+    const nvptx = try createEmpty(allocator, options);
+    // `llvm_object` is still undefined here, so on failure free the struct
+    // directly: `nvptx.base.destroy()` would run `deinit`, which destroys
+    // `llvm_object`.
+    errdefer allocator.destroy(nvptx);
+    log.info("Opening .ptx target file {s}", .{sub_path});
+    nvptx.llvm_object = try LlvmObject.create(allocator, options);
+    return nvptx;
+}
+
+/// Destroys the LLVM object. Requires `llvm_object` to have been assigned.
+pub fn deinit(self: *NvPtx) void {
+    if (!build_options.have_llvm) return;
+    self.llvm_object.destroy(self.base.allocator);
+}
+
+/// Forwards the function body to the LLVM object for code generation.
+pub fn updateFunc(self: *NvPtx, module: *Module, func: *Module.Fn, air: Air, liveness: Liveness) !void {
+    if (!build_options.have_llvm) return;
+    return self.llvm_object.updateFunc(module, func, air, liveness);
+}
+
+/// Forwards the declaration to the LLVM object.
+pub fn updateDecl(self: *NvPtx, module: *Module, decl: *Module.Decl) !void {
+    if (!build_options.have_llvm) return;
+    return self.llvm_object.updateDecl(module, decl);
+}
+
+/// Forwards the export list of `decl` to the LLVM object.
+pub fn updateDeclExports(
+    self: *NvPtx,
+    module: *Module,
+    decl: *const Module.Decl,
+    exports: []const *Module.Export,
+) !void {
+    if (!build_options.have_llvm) return;
+    if (build_options.skip_non_native and builtin.object_format != .nvptx) {
+        @panic("Attempted to compile for object format that was disabled by build configuration");
+    }
+    return self.llvm_object.updateDeclExports(module, decl, exports);
+}
+
+/// Forwards the release of `decl` to the LLVM object.
+pub fn freeDecl(self: *NvPtx, decl: *Module.Decl) void {
+    if (!build_options.have_llvm) return;
+    return self.llvm_object.freeDecl(decl);
+}
+
+/// Same as `flushModule`; this backend performs no additional work in `flush`.
+pub fn flush(self: *NvPtx, comp: *Compilation) !void {
+    return self.flushModule(comp);
+}
+
+/// Flushes the module through LLVM, asking for assembly output instead of a
+/// binary: a configured binary emit location is rerouted to `emit_asm` (with
+/// `intermediary_basename` as file name) for the duration of the call.
+/// Both compilation settings are restored before returning.
+pub fn flushModule(self: *NvPtx, comp: *Compilation) !void {
+    if (!build_options.have_llvm) return;
+    if (build_options.skip_non_native) {
+        @panic("Attempted to compile for architecture that was disabled by build configuration");
+    }
+    const tracy = trace(@src());
+    defer tracy.end();
+
+    // `comp` is a pointer shared with the rest of the compiler, so remember the
+    // settings changed below and put them back once LLVM is done, on success
+    // and on error alike.
+    const saved_emit = comp.bin_file.options.emit;
+    const saved_emit_asm = comp.emit_asm;
+    defer {
+        comp.bin_file.options.emit = saved_emit;
+        comp.emit_asm = saved_emit_asm;
+    }
+
+    if (saved_emit) |emit| {
+        comp.emit_asm = .{
+            .directory = emit.directory,
+            .basename = comp.bin_file.intermediary_basename.?,
+        };
+        comp.bin_file.options.emit = null;
+    }
+
+    return self.llvm_object.flushModule(comp);
+}
diff --git a/src/stage1/all_types.hpp b/src/stage1/all_types.hpp index b3c578d95a..36f136c77f 100644 --- a/src/stage1/all_types.hpp +++ b/src/stage1/all_types.hpp @@ -83,7 +83,8 @@ enum CallingConvention { CallingConventionAPCS, CallingConventionAAPCS, CallingConventionAAPCSVFP, - CallingConventionSysV + CallingConventionSysV, + CallingConventionPtxKernel }; // Stage 1 supports only the generic address space diff --git a/src/stage1/analyze.cpp b/src/stage1/analyze.cpp index dfe7452cfc..0dcf1fcc06 100644 --- a/src/stage1/analyze.cpp +++ b/src/stage1/analyze.cpp @@ -991,6 +991,7 @@ const char *calling_convention_name(CallingConvention cc) { case CallingConventionAAPCSVFP: return "AAPCSVFP"; case CallingConventionInline: return "Inline"; case CallingConventionSysV: return "SysV"; + case CallingConventionPtxKernel: return "PtxKernel"; } zig_unreachable(); } @@ -1000,6 +1001,7 @@ bool calling_convention_allows_zig_types(CallingConvention cc) { case CallingConventionUnspecified: case CallingConventionAsync: case CallingConventionInline: + case CallingConventionPtxKernel: return true; case CallingConventionC: case CallingConventionNaked: @@ -2006,6 +2008,15 @@ Error emit_error_unless_callconv_allowed_for_target(CodeGen *g, AstNode *source_ case CallingConventionSysV: if (g->zig_target->arch != ZigLLVM_x86_64) allowed_platforms = "x86_64"; + break; + case CallingConventionPtxKernel: + if (g->zig_target->arch != ZigLLVM_nvptx + && g->zig_target->arch != ZigLLVM_nvptx64) + { + allowed_platforms = "nvptx and nvptx64"; + } + break; + } if (allowed_platforms != nullptr) { add_node_error(g, source_node,
buf_sprintf( @@ -3827,6 +3838,7 @@ static void resolve_decl_fn(CodeGen *g, TldFn *tld_fn) { case CallingConventionAAPCS: case CallingConventionAAPCSVFP: case CallingConventionSysV: + case CallingConventionPtxKernel: add_fn_export(g, fn_table_entry, buf_ptr(&fn_table_entry->symbol_name), GlobalLinkageIdStrong, fn_cc); break; diff --git a/src/stage1/codegen.cpp b/src/stage1/codegen.cpp index 154e982ff9..4e9d6313db 100644 --- a/src/stage1/codegen.cpp +++ b/src/stage1/codegen.cpp @@ -209,6 +209,11 @@ static ZigLLVM_CallingConv get_llvm_cc(CodeGen *g, CallingConvention cc) { case CallingConventionSysV: assert(g->zig_target->arch == ZigLLVM_x86_64); return ZigLLVM_X86_64_SysV; + case CallingConventionPtxKernel: + assert(g->zig_target->arch == ZigLLVM_nvptx || + g->zig_target->arch == ZigLLVM_nvptx64); + return ZigLLVM_PTX_Kernel; + } zig_unreachable(); } @@ -354,6 +359,7 @@ static bool cc_want_sret_attr(CallingConvention cc) { case CallingConventionAAPCS: case CallingConventionAAPCSVFP: case CallingConventionSysV: + case CallingConventionPtxKernel: return true; case CallingConventionAsync: case CallingConventionUnspecified: diff --git a/src/stage1/ir.cpp b/src/stage1/ir.cpp index 5694db22ee..be6226313f 100644 --- a/src/stage1/ir.cpp +++ b/src/stage1/ir.cpp @@ -11666,6 +11666,7 @@ static Stage1AirInst *ir_analyze_instruction_export(IrAnalyze *ira, Stage1ZirIns case CallingConventionAAPCS: case CallingConventionAAPCSVFP: case CallingConventionSysV: + case CallingConventionPtxKernel: add_fn_export(ira->codegen, fn_entry, buf_ptr(symbol_name), global_linkage_id, cc); fn_entry->section_name = section_name; break;