diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 3700b5bce5..3b155f42c3 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -272,7 +272,7 @@ pub const DeclGen = struct { if (!entry.found_existing) { if (decl.val.castTag(.function)) |_| { - entry.value_ptr.* = .{.func = .{ .result_id = result_id }}; + entry.value_ptr.* = .{ .func = .{ .result_id = result_id } }; } else { entry.value_ptr.* = .{ .global = try self.spv.allocGlobal() }; } @@ -418,11 +418,7 @@ pub const DeclGen = struct { fn genUndef(self: *DeclGen, ty_ref: SpvType.Ref) Error!IdRef { const result_id = self.spv.allocId(); - try self.spv.sections.types_globals_constants.emit( - self.spv.gpa, - .OpUndef, - .{ .id_result_type = self.typeId(ty_ref), .id_result = result_id } - ); + try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpUndef, .{ .id_result_type = self.typeId(ty_ref), .id_result = result_id }); return result_id; } @@ -899,11 +895,12 @@ pub const DeclGen = struct { .initializer = constant_struct_id, }); // TODO: Set alignment of OpVariable. - // TODO: We may be able to eliminate this cast. + // TODO: We may be able to eliminate these casts. + const const_ptr_id = try self.makePointerConstant(section, ptr_constant_struct_ty_ref, var_id); try section.emitSpecConstantOp(self.spv.gpa, .OpBitcast, .{ .id_result_type = self.typeId(ptr_ty_ref), .id_result = result_id, - .operand = var_id, + .operand = const_ptr_id, }); } @@ -1267,13 +1264,13 @@ pub const DeclGen = struct { // Similar to unions, we're going to put the most aligned member first. if (error_align > payload_align) { // Put the error first - members.appendAssumeCapacity(.{ .ty = error_ty_ref, .name = "error" }); - members.appendAssumeCapacity(.{ .ty = payload_ty_ref, .name = "payload" }); + members.appendAssumeCapacity(.{ .ty = error_ty_ref, .name = "error" }); + members.appendAssumeCapacity(.{ .ty = payload_ty_ref, .name = "payload" }); // TODO: ABI padding? } else { // Put the payload first. - members.appendAssumeCapacity(.{ .ty = payload_ty_ref, .name = "payload" }); - members.appendAssumeCapacity(.{ .ty = error_ty_ref, .name = "error" }); + members.appendAssumeCapacity(.{ .ty = payload_ty_ref, .name = "payload" }); + members.appendAssumeCapacity(.{ .ty = error_ty_ref, .name = "error" }); // TODO: ABI padding? } @@ -1302,12 +1299,81 @@ pub const DeclGen = struct { }; } + /// The SPIR-V backend is not yet advanced enough to support the std testing infrastructure. + /// In order to be able to run tests, we "temporarily" lower test kernels into separate entry- + /// points. The test executor will then be able to invoke these to run the tests. + /// Note that tests are lowered according to std.builtin.TestFn, which is `fn () anyerror!void`. + /// (anyerror!void has the same layout as anyerror). + /// Each test declaration generates a function like. + /// %anyerror = OpTypeInt 0 16 + /// %p_anyerror = OpTypePointer CrossWorkgroup %anyerror + /// %K = OpTypeFunction %void %p_anyerror + /// + /// %test = OpFunction %void %K + /// %p_err = OpFunctionParameter %p_anyerror + /// %lbl = OpLabel + /// %result = OpFunctionCall %anyerror %func + /// OpStore %p_err %result + /// OpFunctionEnd + /// TODO is to also write out the error as a function call parameter, and to somehow fetch + /// the name of an error in the text executor. + fn generateTestEntryPoint(self: *DeclGen, name: []const u8, func: IdResult) !void { + const anyerror_ty_ref = try self.resolveType(Type.anyerror, .direct); + const ptr_anyerror_ty_ref = try self.spv.ptrType(anyerror_ty_ref, .CrossWorkgroup, null); + const void_ty_ref = try self.resolveType(Type.void, .direct); + + const kernel_proto_ty_ref = blk: { + const proto_payload = try self.spv.arena.create(SpvType.Payload.Function); + proto_payload.* = .{ + .return_type = void_ty_ref, + .parameters = try self.spv.arena.dupe(SpvType.Ref, &.{ptr_anyerror_ty_ref}), + }; + break :blk try self.spv.resolveType(SpvType.initPayload(&proto_payload.base)); + }; + + const kernel_id = self.spv.allocId(); + const error_id = self.spv.allocId(); + const p_error_id = self.spv.allocId(); + + const section = &self.spv.sections.functions; + try section.emit(self.spv.gpa, .OpFunction, .{ + .id_result_type = self.typeId(void_ty_ref), + .id_result = kernel_id, + .function_control = .{}, + .function_type = self.typeId(kernel_proto_ty_ref), + }); + try section.emit(self.spv.gpa, .OpFunctionParameter, .{ + .id_result_type = self.typeId(ptr_anyerror_ty_ref), + .id_result = p_error_id, + }); + try section.emit(self.spv.gpa, .OpLabel, .{ + .id_result = self.spv.allocId(), + }); + try section.emit(self.spv.gpa, .OpFunctionCall, .{ + .id_result_type = self.typeId(anyerror_ty_ref), + .id_result = error_id, + .function = func, + }); + try section.emit(self.spv.gpa, .OpStore, .{ + .pointer = p_error_id, + .object = error_id, + }); + try section.emit(self.spv.gpa, .OpReturn, {}); + try section.emit(self.spv.gpa, .OpFunctionEnd, {}); + + try self.spv.sections.entry_points.emit(self.spv.gpa, .OpEntryPoint, .{ + .execution_model = .Kernel, + .entry_point = kernel_id, + .name = name, + }); + } + fn genDecl(self: *DeclGen) !void { const decl = self.module.declPtr(self.decl_index); const link = try self.resolveDecl(self.decl_index); if (decl.val.castTag(.function)) |_| { - log.debug("genDecl function {s} = {}", .{decl.name, link.func.result_id.id}); + log.debug("genDecl function {s} = {}", .{ decl.name, link.func.result_id.id }); assert(decl.ty.zigTypeTag() == .Fn); const prototype_id = try self.resolveTypeId(decl.ty); @@ -1356,6 +1422,10 @@ pub const DeclGen = struct { .target = link.func.result_id, .name = fqn, }); + + if (self.module.test_functions.contains(self.decl_index)) { + try self.generateTestEntryPoint(fqn, link.func.result_id); + } } else { const init_val = if (decl.val.castTag(.variable)) |payload| payload.data.init @@ -1396,6 +1466,7 @@ pub const DeclGen = struct { const ty_ref = try self.resolveType(decl.ty, .indirect); const ptr_ty_ref = try self.spv.ptrType(ty_ref, storage_class, decl.@"align"); // TODO: Can we eliminate this cast? + // TODO: Const-wash pointer try section.emitSpecConstantOp(self.spv.gpa, .OpPtrCastToGeneric, .{ .id_result_type = self.typeId(ptr_ty_ref), .id_result = global_result_id, @@ -2036,6 +2107,24 @@ pub const DeclGen = struct { return try self.structFieldPtr(result_ptr_ty, struct_ptr_ty, struct_ptr, field_index); } + /// We cannot use an OpVariable directly in an OpSpecConstantOp, but we can + /// after we insert a dummy AccessChain... + /// TODO: Get rid of this + fn makePointerConstant( + self: *DeclGen, + section: *SpvSection, + ptr_ty_ref: SpvType.Ref, + ptr_id: IdRef, + ) !IdRef { + const result_id = self.spv.allocId(); + try section.emitSpecConstantOp(self.spv.gpa, .OpInBoundsAccessChain, .{ + .id_result_type = self.typeId(ptr_ty_ref), + .id_result = result_id, + .base = ptr_id, + }); + return result_id; + } + fn variable( self: *DeclGen, comptime context: enum { function, global }, @@ -2088,11 +2177,14 @@ pub const DeclGen = struct { .pointer = alloc_result_id, }), // TODO: Can we do without this cast or move it to runtime? - else => try section.emitSpecConstantOp(self.spv.gpa, .OpPtrCastToGeneric, .{ - .id_result_type = self.typeId(ptr_ty_ref), - .id_result = result_id, - .pointer = alloc_result_id, - }), + else => { + const const_ptr_id = try self.makePointerConstant(section, actual_ptr_ty_ref, alloc_result_id); + try section.emitSpecConstantOp(self.spv.gpa, .OpPtrCastToGeneric, .{ + .id_result_type = self.typeId(ptr_ty_ref), + .id_result = result_id, + .pointer = const_ptr_id, + }); + }, } } diff --git a/src/codegen/spirv/Module.zig b/src/codegen/spirv/Module.zig index 7501ec5d92..2bf1efeee2 100644 --- a/src/codegen/spirv/Module.zig +++ b/src/codegen/spirv/Module.zig @@ -187,8 +187,8 @@ fn orderGlobalsInto( seen: *std.DynamicBitSetUnmanaged, ) !void { const node = self.globals.nodes.items[@enumToInt(global_index)]; - const deps = self.globals.dependencies.items[node.begin_dep .. node.end_dep]; - const insts = self.globals.section.instructions.items[node.begin_inst .. node.end_inst]; + const deps = self.globals.dependencies.items[node.begin_dep..node.end_dep]; + const insts = self.globals.section.instructions.items[node.begin_inst..node.end_inst]; seen.set(@enumToInt(global_index)); @@ -725,7 +725,7 @@ pub fn allocGlobal(self: *Module) !Global.Index { .begin_inst = undefined, .end_inst = undefined, .begin_dep = undefined, - .end_dep = undefined, + .end_dep = undefined, }); return @intToEnum(Global.Index, @intCast(u32, self.globals.nodes.items.len - 1)); }