Merge pull request #19337 from Snektron/spirv-globals

spirv: rework generic global
Robin Voetter 2024-03-19 09:34:59 +01:00 committed by GitHub
commit 7057bffc14
23 changed files with 15003 additions and 3363 deletions

========================================

@@ -30,6 +30,8 @@ const SpvAssembler = @import("spirv/Assembler.zig");
 const InstMap = std.AutoHashMapUnmanaged(Air.Inst.Index, IdRef);
 
+pub const zig_call_abi_ver = 3;
+
 /// We want to store some extra facts about types as mapped from Zig to SPIR-V.
 /// This structure is used to keep that extra information, as well as
 /// the cached reference to the type.
@@ -252,15 +254,18 @@ pub const Object = struct {
     /// Note: Function does not actually generate the decl, it just allocates an index.
     pub fn resolveDecl(self: *Object, mod: *Module, decl_index: InternPool.DeclIndex) !SpvModule.Decl.Index {
         const decl = mod.declPtr(decl_index);
-        assert(decl.has_tv);
+        // TODO: Do we need to handle a situation where this is false?
         try mod.markDeclAlive(decl);

         const entry = try self.decl_link.getOrPut(self.gpa, decl_index);
         if (!entry.found_existing) {
             // TODO: Extern fn?
-            const kind: SpvModule.DeclKind = if (decl.val.isFuncBody(mod))
+            const kind: SpvModule.Decl.Kind = if (decl.val.isFuncBody(mod))
                 .func
-            else
-                .global;
+            else switch (decl.@"addrspace") {
+                .generic => .invocation_global,
+                else => .global,
+            };
             entry.value_ptr.* = try self.spv.allocDecl(kind);
         }
@@ -443,40 +448,37 @@ const DeclGen = struct {
        return self.inst_results.get(index).?; // Assertion means instruction does not dominate usage.
    }

-   fn resolveAnonDecl(self: *DeclGen, val: InternPool.Index, storage_class: StorageClass) !IdRef {
+   fn resolveAnonDecl(self: *DeclGen, val: InternPool.Index) !IdRef {
        // TODO: This cannot be a function at this point, but it should probably be handled anyway.
+       const mod = self.module;
+       const ty = Type.fromInterned(mod.intern_pool.typeOf(val));
+       const decl_ptr_ty_ref = try self.ptrType(ty, .Generic);
+
        const spv_decl_index = blk: {
-           const entry = try self.object.anon_decl_link.getOrPut(self.object.gpa, .{ val, storage_class });
+           const entry = try self.object.anon_decl_link.getOrPut(self.object.gpa, .{ val, .Function });
            if (entry.found_existing) {
-               try self.addFunctionDep(entry.value_ptr.*, storage_class);
-               return self.spv.declPtr(entry.value_ptr.*).result_id;
+               try self.addFunctionDep(entry.value_ptr.*, .Function);
+               const result_id = self.spv.declPtr(entry.value_ptr.*).result_id;
+               return try self.castToGeneric(self.typeId(decl_ptr_ty_ref), result_id);
            }

-           const spv_decl_index = try self.spv.allocDecl(.global);
-           try self.addFunctionDep(spv_decl_index, storage_class);
+           const spv_decl_index = try self.spv.allocDecl(.invocation_global);
+           try self.addFunctionDep(spv_decl_index, .Function);
            entry.value_ptr.* = spv_decl_index;
            break :blk spv_decl_index;
        };

-       const mod = self.module;
-       const ty = Type.fromInterned(mod.intern_pool.typeOf(val));
-       const ptr_ty_ref = try self.ptrType(ty, storage_class);
-       const var_id = self.spv.declPtr(spv_decl_index).result_id;
-
-       const section = &self.spv.sections.types_globals_constants;
-       try section.emit(self.spv.gpa, .OpVariable, .{
-           .id_result_type = self.typeId(ptr_ty_ref),
-           .id_result = var_id,
-           .storage_class = storage_class,
-       });
-
        // TODO: At some point we will be able to generate this all constant here, but then all of
        // constant() will need to be implemented such that it doesn't generate any at-runtime code.
        // NOTE: Because this is a global, we really only want to initialize it once. Therefore the
-       // constant lowering of this value will need to be deferred to some other function, which
-       // is then added to the list of initializers using endGlobal().
+       // constant lowering of this value will need to be deferred to an initializer similar to
+       // other globals.
+       const result_id = self.spv.declPtr(spv_decl_index).result_id;

+       {
        // Save the current state so that we can temporarily generate into a different function.
        // TODO: This should probably be made a little more robust.
        const func = self.func;
@@ -487,9 +489,6 @@ const DeclGen = struct {
        self.func = .{};
        defer self.func.deinit(self.gpa);

-       // TODO: Merge this with genDecl?
-       const begin = self.spv.beginGlobal();
        const void_ty_ref = try self.resolveType(Type.void, .direct);
        const initializer_proto_ty_ref = try self.spv.resolve(.{ .function_type = .{
            .return_type = void_ty_ref,
@@ -497,6 +496,7 @@
        } });

        const initializer_id = self.spv.allocId();
        try self.func.prologue.emit(self.spv.gpa, .OpFunction, .{
            .id_result_type = self.typeId(void_ty_ref),
            .id_result = initializer_id,
@@ -511,19 +511,27 @@
        const val_id = try self.constant(ty, Value.fromInterned(val), .indirect);
        try self.func.body.emit(self.spv.gpa, .OpStore, .{
-           .pointer = var_id,
+           .pointer = result_id,
            .object = val_id,
        });

-       self.spv.endGlobal(spv_decl_index, begin, var_id, initializer_id);
        try self.func.body.emit(self.spv.gpa, .OpReturn, {});
        try self.func.body.emit(self.spv.gpa, .OpFunctionEnd, {});
        try self.spv.addFunction(spv_decl_index, self.func);

-       try self.spv.debugNameFmt(var_id, "__anon_{d}", .{@intFromEnum(val)});
        try self.spv.debugNameFmt(initializer_id, "initializer of __anon_{d}", .{@intFromEnum(val)});

-       return var_id;
+       const fn_decl_ptr_ty_ref = try self.ptrType(ty, .Function);
+       try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpExtInst, .{
+           .id_result_type = self.typeId(fn_decl_ptr_ty_ref),
+           .id_result = result_id,
+           .set = try self.spv.importInstructionSet(.zig),
+           .instruction = .{ .inst = 0 }, // TODO: Put this definition somewhere...
+           .id_ref_4 = &.{initializer_id},
+       });
+       }
+
+       return try self.castToGeneric(self.typeId(decl_ptr_ty_ref), result_id);
    }

    fn addFunctionDep(self: *DeclGen, decl_index: SpvModule.Decl.Index, storage_class: StorageClass) !void {
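
The branch structure above is the usual allocate-then-record caching pattern: the decl index is stored in `anon_decl_link` before the value's initializer is lowered, so every later reference to the same value reuses the one `invocation_global` instead of emitting a second one. A minimal self-contained sketch of that pattern, with hypothetical stand-in types (a bare `DeclIndex` and a `u64` value key in place of the real `InternPool`/`SpvModule` machinery):

const std = @import("std");

const DeclIndex = enum(u32) { _ };

fn resolveExample(
    gpa: std.mem.Allocator,
    link: *std.AutoHashMapUnmanaged(u64, DeclIndex),
    next_index: *u32,
    val: u64,
) !DeclIndex {
    const entry = try link.getOrPut(gpa, val);
    if (entry.found_existing) return entry.value_ptr.*;
    const index: DeclIndex = @enumFromInt(next_index.*);
    next_index.* += 1;
    entry.value_ptr.* = index; // record before lowering the initializer
    // ... lower `val` into an initializer function here ...
    return index;
}

test "repeated resolution reuses one decl" {
    const gpa = std.testing.allocator;
    var link: std.AutoHashMapUnmanaged(u64, DeclIndex) = .{};
    defer link.deinit(gpa);
    var next_index: u32 = 0;
    const a = try resolveExample(gpa, &link, &next_index, 123);
    const b = try resolveExample(gpa, &link, &next_index, 123);
    try std.testing.expectEqual(a, b);
    try std.testing.expectEqual(@as(u32, 1), next_index);
}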
@@ -768,18 +776,75 @@
        };
    }

-   /// Construct a composite value at runtime. If the parameters are in direct
-   /// representation, then the result is also in direct representation. Otherwise,
-   /// if the parameters are in indirect representation, then the result is too.
-   fn constructComposite(self: *DeclGen, ty: Type, constituents: []const IdRef) !IdRef {
-       const constituents_id = self.spv.allocId();
-       const type_id = try self.resolveType(ty, .direct);
-       try self.func.body.emit(self.spv.gpa, .OpCompositeConstruct, .{
-           .id_result_type = self.typeId(type_id),
-           .id_result = constituents_id,
-           .constituents = constituents,
-       });
-       return constituents_id;
+   /// Construct a struct at runtime.
+   /// ty must be a struct type.
+   /// Constituents should be in `indirect` representation (as the elements of a struct should be).
+   /// Result is in `direct` representation.
+   fn constructStruct(self: *DeclGen, ty: Type, types: []const Type, constituents: []const IdRef) !IdRef {
+       assert(types.len == constituents.len);
+       // The Khronos LLVM-SPIRV translator crashes because it cannot construct structs whose
+       // operands are not constant.
+       // See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1349
+       // For now, just initialize the struct by setting the fields manually...
+       // TODO: Make this OpCompositeConstruct when we can
+       const ptr_composite_id = try self.alloc(ty, .{ .storage_class = .Function });
+       for (constituents, types, 0..) |constitent_id, member_ty, index| {
+           const ptr_member_ty_ref = try self.ptrType(member_ty, .Function);
+           const ptr_id = try self.accessChain(ptr_member_ty_ref, ptr_composite_id, &.{@as(u32, @intCast(index))});
+           try self.func.body.emit(self.spv.gpa, .OpStore, .{
+               .pointer = ptr_id,
+               .object = constitent_id,
+           });
+       }
+       return try self.load(ty, ptr_composite_id, .{});
+   }
+
+   /// Construct a vector at runtime.
+   /// ty must be a vector type.
+   /// Constituents should be in `indirect` representation (as the elements of a vector should be).
+   /// Result is in `direct` representation.
+   fn constructVector(self: *DeclGen, ty: Type, constituents: []const IdRef) !IdRef {
+       // The Khronos LLVM-SPIRV translator crashes because it cannot construct structs whose
+       // operands are not constant.
+       // See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1349
+       // For now, just initialize the vector by setting the elements manually...
+       // TODO: Make this OpCompositeConstruct when we can
+       const mod = self.module;
+       const ptr_composite_id = try self.alloc(ty, .{ .storage_class = .Function });
+       const ptr_elem_ty_ref = try self.ptrType(ty.elemType2(mod), .Function);
+       for (constituents, 0..) |constitent_id, index| {
+           const ptr_id = try self.accessChain(ptr_elem_ty_ref, ptr_composite_id, &.{@as(u32, @intCast(index))});
+           try self.func.body.emit(self.spv.gpa, .OpStore, .{
+               .pointer = ptr_id,
+               .object = constitent_id,
+           });
+       }
+       return try self.load(ty, ptr_composite_id, .{});
+   }
+
+   /// Construct an array at runtime.
+   /// ty must be an array type.
+   /// Constituents should be in `indirect` representation (as the elements of an array should be).
+   /// Result is in `direct` representation.
+   fn constructArray(self: *DeclGen, ty: Type, constituents: []const IdRef) !IdRef {
+       // The Khronos LLVM-SPIRV translator crashes because it cannot construct structs whose
+       // operands are not constant.
+       // See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1349
+       // For now, just initialize the array by setting the elements manually...
+       // TODO: Make this OpCompositeConstruct when we can
+       const mod = self.module;
+       const ptr_composite_id = try self.alloc(ty, .{ .storage_class = .Function });
+       const ptr_elem_ty_ref = try self.ptrType(ty.elemType2(mod), .Function);
+       for (constituents, 0..) |constitent_id, index| {
+           const ptr_id = try self.accessChain(ptr_elem_ty_ref, ptr_composite_id, &.{@as(u32, @intCast(index))});
+           try self.func.body.emit(self.spv.gpa, .OpStore, .{
+               .pointer = ptr_id,
+               .object = constitent_id,
+           });
+       }
+       return try self.load(ty, ptr_composite_id, .{});
    }

    /// This function generates a load for a constant in direct (ie, non-memory) representation.
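
All three helpers above share the same translator workaround. A rough Zig analogy (an intuition aid under assumed names, not backend code): building the aggregate in a single expression corresponds to OpCompositeConstruct, while the workaround writes each field through a Function-storage temporary (OpAccessChain plus OpStore per field) and reloads the whole value with OpLoad:

const std = @import("std");

const Pair = struct { a: u32, b: u32 };

// OpCompositeConstruct equivalent: result = .{ .a = x, .b = y };
// Workaround equivalent, as lowered above:
fn constructPair(x: u32, y: u32) Pair {
    var tmp: Pair = undefined; // OpVariable, Function storage class (self.alloc)
    tmp.a = x; // OpAccessChain + OpStore
    tmp.b = y; // OpAccessChain + OpStore
    return tmp; // OpLoad of the whole composite (self.load)
}

test "field-by-field construction yields the same value" {
    try std.testing.expectEqual(Pair{ .a = 1, .b = 2 }, constructPair(1, 2));
}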
@@ -887,15 +952,18 @@
        });

        var constituents: [2]IdRef = undefined;
+       var types: [2]Type = undefined;
        if (eu_layout.error_first) {
            constituents[0] = try self.constant(err_ty, err_val, .indirect);
            constituents[1] = try self.constant(payload_ty, payload_val, .indirect);
+           types = .{ err_ty, payload_ty };
        } else {
            constituents[0] = try self.constant(payload_ty, payload_val, .indirect);
            constituents[1] = try self.constant(err_ty, err_val, .indirect);
+           types = .{ payload_ty, err_ty };
        }

-       return try self.constructComposite(ty, &constituents);
+       return try self.constructStruct(ty, &types, &constituents);
    },
    .enum_tag => {
        const int_val = try val.intFromEnum(ty, mod);
@@ -907,7 +975,11 @@
        const ptr_ty = ty.slicePtrFieldType(mod);
        const ptr_id = try self.constantPtr(ptr_ty, Value.fromInterned(slice.ptr));
        const len_id = try self.constant(Type.usize, Value.fromInterned(slice.len), .indirect);
-       return self.constructComposite(ty, &.{ ptr_id, len_id });
+       return self.constructStruct(
+           ty,
+           &.{ ptr_ty, Type.usize },
+           &.{ ptr_id, len_id },
+       );
    },
    .opt => {
        const payload_ty = ty.optionalChild(mod);
@@ -934,7 +1006,11 @@
        else
            try self.spv.constUndef(try self.resolveType(payload_ty, .indirect));

-       return try self.constructComposite(ty, &.{ payload_id, has_pl_id });
+       return try self.constructStruct(
+           ty,
+           &.{ payload_ty, Type.bool },
+           &.{ payload_id, has_pl_id },
+       );
    },
    .aggregate => |aggregate| switch (ip.indexToKey(ty.ip_index)) {
        inline .array_type, .vector_type => |array_type, tag| {
@@ -971,9 +1047,9 @@
                const sentinel = Value.fromInterned(array_type.sentinel);
                constituents[constituents.len - 1] = try self.constant(elem_ty, sentinel, .indirect);
            }
-           return self.constructComposite(ty, constituents);
+           return self.constructArray(ty, constituents);
        },
-       inline .vector_type => return self.constructComposite(ty, constituents),
+       inline .vector_type => return self.constructVector(ty, constituents),
        else => unreachable,
    }
},
@@ -983,6 +1059,9 @@
            return self.todo("packed struct constants", .{});
        }

+       var types = std.ArrayList(Type).init(self.gpa);
+       defer types.deinit();
        var constituents = std.ArrayList(IdRef).init(self.gpa);
        defer constituents.deinit();
@@ -998,10 +1077,11 @@
            const field_val = try val.fieldValue(mod, field_index);
            const field_id = try self.constant(field_ty, field_val, .indirect);

+           try types.append(field_ty);
            try constituents.append(field_id);
        }

-       return try self.constructComposite(ty, constituents.items);
+       return try self.constructStruct(ty, types.items, constituents.items);
    },
    .anon_struct_type => unreachable, // TODO
    else => unreachable,
@@ -1107,19 +1187,10 @@
        unreachable; // TODO
    }

-   const final_storage_class = self.spvStorageClass(ty.ptrAddressSpace(mod));
-   const actual_storage_class = switch (final_storage_class) {
-       .Generic => .CrossWorkgroup,
-       else => |other| other,
-   };
-   const decl_id = try self.resolveAnonDecl(decl_val, actual_storage_class);
-   const decl_ptr_ty_ref = try self.ptrType(decl_ty, final_storage_class);
-   const ptr_id = switch (final_storage_class) {
-       .Generic => try self.castToGeneric(self.typeId(decl_ptr_ty_ref), decl_id),
-       else => decl_id,
-   };
+   // Anon decl refs are always generic.
+   assert(ty.ptrAddressSpace(mod) == .generic);
+   const decl_ptr_ty_ref = try self.ptrType(decl_ty, .Generic);
+   const ptr_id = try self.resolveAnonDecl(decl_val);

    if (decl_ptr_ty_ref != ty_ref) {
        // Differing pointer types, insert a cast.
@@ -1157,8 +1228,13 @@
    }

    const spv_decl_index = try self.object.resolveDecl(mod, decl_index);
-   const decl_id = self.spv.declPtr(spv_decl_index).result_id;
+   const spv_decl = self.spv.declPtr(spv_decl_index);
+   const decl_id = switch (spv_decl.kind) {
+       .func => unreachable, // TODO: Is this possible?
+       .global, .invocation_global => spv_decl.result_id,
+   };

    const final_storage_class = self.spvStorageClass(decl.@"addrspace");
    try self.addFunctionDep(spv_decl_index, final_storage_class);
@@ -1437,6 +1513,13 @@
    if (self.type_map.get(ty.toIntern())) |info| return info.ty_ref;

    const fn_info = mod.typeToFunc(ty).?;
+
+   comptime assert(zig_call_abi_ver == 3);
+   switch (fn_info.cc) {
+       .Unspecified, .Kernel, .Fragment, .Vertex, .C => {},
+       else => unreachable, // TODO
+   }
+
    // TODO: Put this somewhere in Sema.zig
    if (fn_info.is_var_args)
        return self.fail("VarArgs functions are unsupported for SPIR-V", .{});
@@ -1841,7 +1924,7 @@
            for (wip.results) |*result| {
                result.* = try wip.dg.convertToIndirect(wip.ty, result.*);
            }
-           return try wip.dg.constructComposite(wip.result_ty, wip.results);
+           return try wip.dg.constructArray(wip.result_ty, wip.results);
        } else {
            return wip.results[0];
        }
@@ -1884,13 +1967,15 @@
    /// (anyerror!void has the same layout as anyerror).
    /// Each test declaration generates a function like.
    /// %anyerror = OpTypeInt 0 16
+   /// %p_invocation_globals_struct_ty = ...
    /// %p_anyerror = OpTypePointer CrossWorkgroup %anyerror
-   /// %K = OpTypeFunction %void %p_anyerror
+   /// %K = OpTypeFunction %void %p_invocation_globals_struct_ty %p_anyerror
    ///
    /// %test = OpFunction %void %K
+   /// %p_invocation_globals = OpFunctionParameter p_invocation_globals_struct_ty
    /// %p_err = OpFunctionParameter %p_anyerror
    /// %lbl = OpLabel
-   /// %result = OpFunctionCall %anyerror %func
+   /// %result = OpFunctionCall %anyerror %func %p_invocation_globals
    /// OpStore %p_err %result
    /// OpFunctionEnd
    /// TODO is to also write out the error as a function call parameter, and to somehow fetch
@@ -1900,10 +1985,12 @@
        const ptr_anyerror_ty_ref = try self.ptrType(Type.anyerror, .CrossWorkgroup);
        const void_ty_ref = try self.resolveType(Type.void, .direct);

-       const kernel_proto_ty_ref = try self.spv.resolve(.{ .function_type = .{
-           .return_type = void_ty_ref,
-           .parameters = &.{ptr_anyerror_ty_ref},
-       } });
+       const kernel_proto_ty_ref = try self.spv.resolve(.{
+           .function_type = .{
+               .return_type = void_ty_ref,
+               .parameters = &.{ptr_anyerror_ty_ref},
+           },
+       });

        const test_id = self.spv.declPtr(spv_test_decl_index).result_id;
@@ -1954,26 +2041,26 @@
        const ip = &mod.intern_pool;
        const decl = mod.declPtr(self.decl_index);
        const spv_decl_index = try self.object.resolveDecl(mod, self.decl_index);
-       const target = self.getTarget();
-       const decl_id = self.spv.declPtr(spv_decl_index).result_id;
-       if (decl.val.getFunction(mod)) |_| {
+       const result_id = self.spv.declPtr(spv_decl_index).result_id;
+       switch (self.spv.declPtr(spv_decl_index).kind) {
+           .func => {
            assert(decl.ty.zigTypeTag(mod) == .Fn);
            const fn_info = mod.typeToFunc(decl.ty).?;
            const return_ty_ref = try self.resolveFnReturnType(Type.fromInterned(fn_info.return_type));

-           const prototype_id = try self.resolveTypeId(decl.ty);
+           const prototype_ty_ref = try self.resolveType(decl.ty, .direct);
            try self.func.prologue.emit(self.spv.gpa, .OpFunction, .{
                .id_result_type = self.typeId(return_ty_ref),
-               .id_result = decl_id,
+               .id_result = result_id,
                .function_control = switch (fn_info.cc) {
                    .Inline => .{ .Inline = true },
                    else => .{},
                },
-               .function_type = prototype_id,
+               .function_type = self.typeId(prototype_ty_ref),
            });

+           comptime assert(zig_call_abi_ver == 3);
            try self.args.ensureUnusedCapacity(self.gpa, fn_info.param_types.len);
            for (fn_info.param_types.get(ip)) |param_ty_index| {
                const param_ty = Type.fromInterned(param_ty_index);
@@ -2015,14 +2102,40 @@
            try self.spv.addFunction(spv_decl_index, self.func);

            const fqn = ip.stringToSlice(try decl.fullyQualifiedName(self.module));
-           try self.spv.debugName(decl_id, fqn);
+           try self.spv.debugName(result_id, fqn);

            // Temporarily generate a test kernel declaration if this is a test function.
            if (self.module.test_functions.contains(self.decl_index)) {
                try self.generateTestEntryPoint(fqn, spv_decl_index);
            }
-       } else {
-           const opt_init_val: ?Value = blk: {
+       },
+       .global => {
+           const maybe_init_val: ?Value = blk: {
+               if (decl.val.getVariable(mod)) |payload| {
+                   if (payload.is_extern) break :blk null;
+                   break :blk Value.fromInterned(payload.init);
+               }
+               break :blk decl.val;
+           };
+           assert(maybe_init_val == null); // TODO
+
+           const final_storage_class = self.spvStorageClass(decl.@"addrspace");
+           assert(final_storage_class != .Generic); // These should be instance globals
+
+           const ptr_ty_ref = try self.ptrType(decl.ty, final_storage_class);
+           try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpVariable, .{
+               .id_result_type = self.typeId(ptr_ty_ref),
+               .id_result = result_id,
+               .storage_class = final_storage_class,
+           });
+
+           const fqn = ip.stringToSlice(try decl.fullyQualifiedName(self.module));
+           try self.spv.debugName(result_id, fqn);
+
+           try self.spv.declareDeclDeps(spv_decl_index, &.{});
+       },
+       .invocation_global => {
+           const maybe_init_val: ?Value = blk: {
                if (decl.val.getVariable(mod)) |payload| {
                    if (payload.is_extern) break :blk null;
                    break :blk Value.fromInterned(payload.init);
@@ -2030,40 +2143,18 @@
                break :blk decl.val;
            };

-           // Generate the actual variable for the global...
-           const final_storage_class = self.spvStorageClass(decl.@"addrspace");
-           const actual_storage_class = blk: {
-               if (target.os.tag != .vulkan) {
-                   break :blk switch (final_storage_class) {
-                       .Generic => .CrossWorkgroup,
-                       else => final_storage_class,
-                   };
-               }
-               break :blk final_storage_class;
-           };
-           const ptr_ty_ref = try self.ptrType(decl.ty, actual_storage_class);
-           const begin = self.spv.beginGlobal();
-           try self.spv.globals.section.emit(self.spv.gpa, .OpVariable, .{
-               .id_result_type = self.typeId(ptr_ty_ref),
-               .id_result = decl_id,
-               .storage_class = actual_storage_class,
-           });
-           const fqn = ip.stringToSlice(try decl.fullyQualifiedName(self.module));
-           try self.spv.debugName(decl_id, fqn);
-           if (opt_init_val) |init_val| {
-               // Currently, initializers for CrossWorkgroup variables is not implemented
-               // in Mesa. Therefore we generate an initialization kernel instead.
+           try self.spv.declareDeclDeps(spv_decl_index, &.{});
+
+           const ptr_ty_ref = try self.ptrType(decl.ty, .Function);
+
+           if (maybe_init_val) |init_val| {
+               // TODO: Combine with resolveAnonDecl?
                const void_ty_ref = try self.resolveType(Type.void, .direct);
                const initializer_proto_ty_ref = try self.spv.resolve(.{ .function_type = .{
                    .return_type = void_ty_ref,
                    .parameters = &.{},
                } });

-               // Now emit the instructions that initialize the variable.
                const initializer_id = self.spv.allocId();
                try self.func.prologue.emit(self.spv.gpa, .OpFunction, .{
                    .id_result_type = self.typeId(void_ty_ref),
@@ -2071,6 +2162,7 @@
                    .function_control = .{},
                    .function_type = self.typeId(initializer_proto_ty_ref),
                });
+
                const root_block_id = self.spv.allocId();
                try self.func.prologue.emit(self.spv.gpa, .OpLabel, .{
                    .id_result = root_block_id,
@@ -2079,22 +2171,34 @@
                const val_id = try self.constant(decl.ty, init_val, .indirect);
                try self.func.body.emit(self.spv.gpa, .OpStore, .{
-                   .pointer = decl_id,
+                   .pointer = result_id,
                    .object = val_id,
                });

-               // TODO: We should be able to get rid of this by now...
-               self.spv.endGlobal(spv_decl_index, begin, decl_id, initializer_id);
                try self.func.body.emit(self.spv.gpa, .OpReturn, {});
                try self.func.body.emit(self.spv.gpa, .OpFunctionEnd, {});
                try self.spv.addFunction(spv_decl_index, self.func);

+               const fqn = ip.stringToSlice(try decl.fullyQualifiedName(self.module));
                try self.spv.debugNameFmt(initializer_id, "initializer of {s}", .{fqn});
+
+               try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpExtInst, .{
+                   .id_result_type = self.typeId(ptr_ty_ref),
+                   .id_result = result_id,
+                   .set = try self.spv.importInstructionSet(.zig),
+                   .instruction = .{ .inst = 0 }, // TODO: Put this definition somewhere...
+                   .id_ref_4 = &.{initializer_id},
+               });
            } else {
-               self.spv.endGlobal(spv_decl_index, begin, decl_id, null);
-               try self.spv.declareDeclDeps(spv_decl_index, &.{});
+               try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpExtInst, .{
+                   .id_result_type = self.typeId(ptr_ty_ref),
+                   .id_result = result_id,
+                   .set = try self.spv.importInstructionSet(.zig),
+                   .instruction = .{ .inst = 0 }, // TODO: Put this definition somewhere...
+                   .id_ref_4 = &.{},
+               });
            }
+       },
        }
    }
@@ -2487,8 +2591,8 @@
        else => unreachable,
    };
    const set_id = switch (target.os.tag) {
-       .opencl => try self.spv.importInstructionSet(.opencl),
-       .vulkan => try self.spv.importInstructionSet(.glsl),
+       .opencl => try self.spv.importInstructionSet(.@"OpenCL.std"),
+       .vulkan => try self.spv.importInstructionSet(.@"GLSL.std.450"),
        else => unreachable,
    };
@@ -2662,8 +2766,8 @@
        else => unreachable,
    };
    const set_id = switch (target.os.tag) {
-       .opencl => try self.spv.importInstructionSet(.opencl),
-       .vulkan => try self.spv.importInstructionSet(.glsl),
+       .opencl => try self.spv.importInstructionSet(.@"OpenCL.std"),
+       .vulkan => try self.spv.importInstructionSet(.@"GLSL.std.450"),
        else => unreachable,
    };
@@ -2792,8 +2896,9 @@
        ov_id.* = try self.intFromBool(wip_ov.ty_ref, overflowed_id);
    }
-   return try self.constructComposite(
+   return try self.constructStruct(
        result_ty,
+       &.{ operand_ty, ov_ty },
        &.{ try wip_result.finalize(), try wip_ov.finalize() },
    );
}
@@ -2885,8 +2990,9 @@
        ov_id.* = try self.intFromBool(wip_ov.ty_ref, overflowed_id);
    }
-   return try self.constructComposite(
+   return try self.constructStruct(
        result_ty,
+       &.{ operand_ty, ov_ty },
        &.{ try wip_result.finalize(), try wip_ov.finalize() },
    );
}
@@ -3201,35 +3307,76 @@
            else
                try self.convertToDirect(Type.bool, rhs_id);

-           const valid_cmp_id = try self.cmp(op, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id);
            if (!payload_ty.hasRuntimeBitsIgnoreComptime(mod)) {
-               return valid_cmp_id;
+               return try self.cmp(op, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id);
            }

-           // TODO: Should we short circuit here? It shouldn't affect correctness, but
-           // perhaps it will generate more efficient code.
+           // a = lhs_valid
+           // b = rhs_valid
+           // c = lhs_pl == rhs_pl
+           //
+           // For op == .eq we have:
+           //   a == b && a -> c
+           // = a == b && (!a || c)
+           //
+           // For op == .neq we have:
+           //   !(a == b && a -> c)
+           // = a != b || !(a -> c)
+           // = a != b || !(!a || c)
+           // = a != b || a && !c
            const lhs_pl_id = try self.extractField(payload_ty, lhs_id, 0);
            const rhs_pl_id = try self.extractField(payload_ty, rhs_id, 0);
-           const pl_cmp_id = try self.cmp(op, Type.bool, payload_ty, lhs_pl_id, rhs_pl_id);
-
-           // op == .eq => lhs_valid == rhs_valid && lhs_pl == rhs_pl
-           // op == .neq => lhs_valid != rhs_valid || lhs_pl != rhs_pl
-           const result_id = self.spv.allocId();
-           const args = .{
-               .id_result_type = self.typeId(bool_ty_ref),
-               .id_result = result_id,
-               .operand_1 = valid_cmp_id,
-               .operand_2 = pl_cmp_id,
-           };
            switch (op) {
-               .eq => try self.func.body.emit(self.spv.gpa, .OpLogicalAnd, args),
-               .neq => try self.func.body.emit(self.spv.gpa, .OpLogicalOr, args),
+               .eq => {
+                   const valid_eq_id = try self.cmp(.eq, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id);
+                   const pl_eq_id = try self.cmp(op, Type.bool, payload_ty, lhs_pl_id, rhs_pl_id);
+
+                   const lhs_not_valid_id = self.spv.allocId();
+                   try self.func.body.emit(self.spv.gpa, .OpLogicalNot, .{
+                       .id_result_type = self.typeId(bool_ty_ref),
+                       .id_result = lhs_not_valid_id,
+                       .operand = lhs_valid_id,
+                   });
+
+                   const impl_id = self.spv.allocId();
+                   try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{
+                       .id_result_type = self.typeId(bool_ty_ref),
+                       .id_result = impl_id,
+                       .operand_1 = lhs_not_valid_id,
+                       .operand_2 = pl_eq_id,
+                   });
+
+                   const result_id = self.spv.allocId();
+                   try self.func.body.emit(self.spv.gpa, .OpLogicalAnd, .{
+                       .id_result_type = self.typeId(bool_ty_ref),
+                       .id_result = result_id,
+                       .operand_1 = valid_eq_id,
+                       .operand_2 = impl_id,
+                   });
+                   return result_id;
+               },
+               .neq => {
+                   const valid_neq_id = try self.cmp(.neq, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id);
+                   const pl_neq_id = try self.cmp(op, Type.bool, payload_ty, lhs_pl_id, rhs_pl_id);
+
+                   const impl_id = self.spv.allocId();
+                   try self.func.body.emit(self.spv.gpa, .OpLogicalAnd, .{
+                       .id_result_type = self.typeId(bool_ty_ref),
+                       .id_result = impl_id,
+                       .operand_1 = lhs_valid_id,
+                       .operand_2 = pl_neq_id,
+                   });
+
+                   const result_id = self.spv.allocId();
+                   try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{
+                       .id_result_type = self.typeId(bool_ty_ref),
+                       .id_result = result_id,
+                       .operand_1 = valid_neq_id,
+                       .operand_2 = impl_id,
+                   });
+                   return result_id;
+               },
                else => unreachable,
            }
-           return result_id;
        },
        .Vector => {
            var wip = try self.elementWise(result_ty, true);
@@ -3588,7 +3735,11 @@
        // Convert the pointer-to-array to a pointer to the first element.
        try self.accessChain(elem_ptr_ty_ref, array_ptr_id, &.{0});

-       return try self.constructComposite(slice_ty, &.{ elem_ptr_id, len_id });
+       return try self.constructStruct(
+           slice_ty,
+           &.{ elem_ptr_ty, Type.usize },
+           &.{ elem_ptr_id, len_id },
+       );
    }

    fn airSlice(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -3596,11 +3747,16 @@
        const bin_op = self.air.extraData(Air.Bin, ty_pl.payload).data;
        const ptr_id = try self.resolve(bin_op.lhs);
        const len_id = try self.resolve(bin_op.rhs);
+       const ptr_ty = self.typeOf(bin_op.lhs);
        const slice_ty = self.typeOfIndex(inst);

        // Note: Types should not need to be converted to direct, these types
        // dont need to be converted.

-       return try self.constructComposite(slice_ty, &.{ ptr_id, len_id });
+       return try self.constructStruct(
+           slice_ty,
+           &.{ ptr_ty, Type.usize },
+           &.{ ptr_id, len_id },
+       );
    }

    fn airAggregateInit(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -3618,6 +3774,8 @@
            unreachable; // TODO
        }

+       const types = try self.gpa.alloc(Type, elements.len);
+       defer self.gpa.free(types);
        const constituents = try self.gpa.alloc(IdRef, elements.len);
        defer self.gpa.free(constituents);
        var index: usize = 0;
@@ -3629,6 +3787,7 @@
                assert(Type.fromInterned(field_ty).hasRuntimeBits(mod));

                const id = try self.resolve(element);
+               types[index] = Type.fromInterned(field_ty);
                constituents[index] = try self.convertToIndirect(Type.fromInterned(field_ty), id);
                index += 1;
            }
@@ -3643,6 +3802,7 @@
                assert(field_ty.hasRuntimeBitsIgnoreComptime(mod));

                const id = try self.resolve(element);
+               types[index] = field_ty;
                constituents[index] = try self.convertToIndirect(field_ty, id);
                index += 1;
            }
@@ -3650,7 +3810,11 @@
            else => unreachable,
        }

-       return try self.constructComposite(result_ty, constituents[0..index]);
+       return try self.constructStruct(
+           result_ty,
+           types[0..index],
+           constituents[0..index],
+       );
    },
    .Vector => {
        const n_elems = result_ty.vectorLen(mod);
@@ -3662,7 +3826,7 @@
            elem_ids[i] = try self.convertToIndirect(result_ty.childType(mod), id);
        }

-       return try self.constructComposite(result_ty, elem_ids);
+       return try self.constructVector(result_ty, elem_ids);
    },
    .Array => {
        const array_info = result_ty.arrayInfo(mod);
@@ -3679,7 +3843,7 @@
            elem_ids[n_elems - 1] = try self.constant(array_info.elem_type, sentinel_val, .indirect);
        }

-       return try self.constructComposite(result_ty, elem_ids);
+       return try self.constructArray(result_ty, elem_ids);
    },
    else => unreachable,
}
@@ -4792,7 +4956,11 @@
    members[eu_layout.errorFieldIndex()] = operand_id;
    members[eu_layout.payloadFieldIndex()] = try self.spv.constUndef(payload_ty_ref);

-   return try self.constructComposite(err_union_ty, &members);
+   var types: [2]Type = undefined;
+   types[eu_layout.errorFieldIndex()] = Type.anyerror;
+   types[eu_layout.payloadFieldIndex()] = payload_ty;
+
+   return try self.constructStruct(err_union_ty, &types, &members);
}

fn airWrapErrUnionPayload(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -4811,7 +4979,11 @@
    members[eu_layout.errorFieldIndex()] = try self.constInt(err_ty_ref, 0);
    members[eu_layout.payloadFieldIndex()] = try self.convertToIndirect(payload_ty, operand_id);

-   return try self.constructComposite(err_union_ty, &members);
+   var types: [2]Type = undefined;
+   types[eu_layout.errorFieldIndex()] = Type.anyerror;
+   types[eu_layout.payloadFieldIndex()] = payload_ty;
+
+   return try self.constructStruct(err_union_ty, &types, &members);
}

fn airIsNull(self: *DeclGen, inst: Air.Inst.Index, is_pointer: bool, pred: enum { is_null, is_non_null }) !?IdRef {
@@ -4978,7 +5150,8 @@
    const payload_id = try self.convertToIndirect(payload_ty, operand_id);
    const members = [_]IdRef{ payload_id, try self.constBool(true, .indirect) };
-   return try self.constructComposite(optional_ty, &members);
+   const types = [_]Type{ payload_ty, Type.bool };
+   return try self.constructStruct(optional_ty, &types, &members);
}

fn airSwitchBr(self: *DeclGen, inst: Air.Inst.Index) !void {
@@ -5058,7 +5231,7 @@
    const case_body = self.air.extra[case.end + items.len ..][0..case.data.body_len];
    extra_index = case.end + case.data.items_len + case_body.len;

-   const label = IdRef{ .id = @intCast(first_case_label.id + case_i) };
+   const label: IdRef = @enumFromInt(@intFromEnum(first_case_label) + case_i);

    for (items) |item| {
        const value = (try self.air.value(item, mod)) orelse unreachable;
@@ -5072,7 +5245,7 @@
        else => unreachable,
    };
    const int_lit: spec.LiteralContextDependentNumber = switch (cond_words) {
-       1 => .{ .uint32 = @as(u32, @intCast(int_val)) },
+       1 => .{ .uint32 = @intCast(int_val) },
        2 => .{ .uint64 = int_val },
        else => unreachable,
    };
@@ -5097,7 +5270,7 @@
    const case_body: []const Air.Inst.Index = @ptrCast(self.air.extra[case.end + items.len ..][0..case.data.body_len]);
    extra_index = case.end + case.data.items_len + case_body.len;

-   const label = IdResult{ .id = @intCast(first_case_label.id + case_i) };
+   const label: IdResult = @enumFromInt(@intFromEnum(first_case_label) + case_i);

    try self.beginSpvBlock(label);
@@ -5327,9 +5500,9 @@
    const result_id = self.spv.allocId();
    const callee_id = try self.resolve(pl_op.operand);

+   comptime assert(zig_call_abi_ver == 3);
    const params = try self.gpa.alloc(spec.IdRef, args.len);
    defer self.gpa.free(params);
-
    var n_params: usize = 0;
    for (args) |arg| {
        // Note: resolve() might emit instructions, so we need to call it

========================================

@@ -194,6 +194,11 @@ inst: struct {
/// This map maps results to their tracked values.
value_map: AsmValueMap = .{},

+/// This set is used to quickly transform from an opcode name to the
+/// index in its instruction set. The index of the key is the
+/// index in `spec.InstructionSet.core.instructions()`.
+instruction_map: std.StringArrayHashMapUnmanaged(void) = .{},
+
/// Free the resources owned by this assembler.
pub fn deinit(self: *Assembler) void {
    for (self.errors.items) |err| {
@@ -204,9 +209,20 @@ pub fn deinit(self: *Assembler) void {
    self.inst.operands.deinit(self.gpa);
    self.inst.string_bytes.deinit(self.gpa);
    self.value_map.deinit(self.gpa);
+   self.instruction_map.deinit(self.gpa);
}

pub fn assemble(self: *Assembler) Error!void {
+   // Populate the opcode map if it isn't already
+   if (self.instruction_map.count() == 0) {
+       const instructions = spec.InstructionSet.core.instructions();
+       try self.instruction_map.ensureUnusedCapacity(self.gpa, @intCast(instructions.len));
+       for (spec.InstructionSet.core.instructions(), 0..) |inst, i| {
+           const entry = try self.instruction_map.getOrPut(self.gpa, inst.name);
+           assert(entry.index == i);
+       }
+   }
+
    try self.tokenize();
    while (!self.testToken(.eof)) {
        try self.parseInstruction();
@@ -475,12 +491,14 @@ fn parseInstruction(self: *Assembler) !void {
    }

    const opcode_text = self.tokenText(opcode_tok);
-   @setEvalBranchQuota(10000);
-   self.inst.opcode = std.meta.stringToEnum(Opcode, opcode_text) orelse {
+   const index = self.instruction_map.getIndex(opcode_text) orelse {
        return self.fail(opcode_tok.start, "invalid opcode '{s}'", .{opcode_text});
    };

-   const expected_operands = self.inst.opcode.operands();
+   const inst = spec.InstructionSet.core.instructions()[index];
+   self.inst.opcode = @enumFromInt(inst.opcode);
+
+   const expected_operands = inst.operands;
    // This is a loop because the result-id is not always the first operand.
    const requires_lhs_result = for (expected_operands) |op| {
        if (op.kind == .IdResult) break true;
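
The trick in this change is that a `StringArrayHashMapUnmanaged(void)` is an ordered string set, so an entry's insertion index doubles as its value and can index the parallel `instructions()` array. A minimal sketch with a few hand-picked opcode names (the names are illustrative):

const std = @import("std");

test "string set whose entry index is the value" {
    const gpa = std.testing.allocator;
    var map: std.StringArrayHashMapUnmanaged(void) = .{};
    defer map.deinit(gpa);

    // Illustrative opcode names; the real populate loop iterates
    // spec.InstructionSet.core.instructions() in order.
    const names = [_][]const u8{ "OpNop", "OpUndef", "OpSource" };
    for (names) |name| _ = try map.getOrPut(gpa, name);

    // getIndex recovers the insertion order, which then indexes the
    // parallel instruction array.
    try std.testing.expectEqual(@as(?usize, 1), map.getIndex("OpUndef"));
    try std.testing.expectEqual(@as(?usize, null), map.getIndex("OpBogus"));
}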

========================================

@@ -134,7 +134,10 @@ const Tag = enum {
    /// data is (bool) type
    bool_false,

-   const SimpleType = enum { void, bool };
+   const SimpleType = enum {
+       void,
+       bool,
+   };

    const VectorType = Key.VectorType;
    const ArrayType = Key.ArrayType;
@@ -287,11 +290,12 @@ pub const Key = union(enum) {
    pub const PointerType = struct {
        storage_class: StorageClass,
        child_type: Ref,
+       /// Ref to a .fwd_ptr_type.
        fwd: Ref,

        // TODO: Decorations:
        // - Alignment
-       // - ArrayStride,
-       // - MaxByteOffset,
+       // - ArrayStride
+       // - MaxByteOffset
    };

    pub const ForwardPointerType = struct {
@@ -728,6 +732,9 @@ pub fn resolve(self: *Self, spv: *Module, key: Key) !Ref {
        // },
        .ptr_type => |ptr| Item{
            .tag = .type_ptr_simple,
+           // For this variant we need to steal the ID of the forward-declaration, instead
+           // of allocating one manually. This will make sure that we get a single result-id
+           // for any possibly forward-declared pointer type.
            .result_id = self.resultId(ptr.fwd),
            .data = try self.addExtra(spv, Tag.SimplePointerType{
                .storage_class = ptr.storage_class,
@@ -896,24 +903,6 @@ pub fn lookup(self: *const Self, ref: Ref) Key {
            },
        };
    },
-   // .type_ptr_generic => .{
-   //     .ptr_type = .{
-   //         .storage_class = .Generic,
-   //         .child_type = @enumFromInt(data),
-   //     },
-   // },
-   // .type_ptr_crosswgp => .{
-   //     .ptr_type = .{
-   //         .storage_class = .CrossWorkgroup,
-   //         .child_type = @enumFromInt(data),
-   //     },
-   // },
-   // .type_ptr_function => .{
-   //     .ptr_type = .{
-   //         .storage_class = .Function,
-   //         .child_type = @enumFromInt(data),
-   //     },
-   // },
    .type_ptr_simple => {
        const payload = self.extraData(Tag.SimplePointerType, data);
        return .{

========================================

@@ -72,9 +72,20 @@ pub const Decl = struct {
    /// Index to refer to a Decl by.
    pub const Index = enum(u32) { _ };

-   /// The result-id to be used for this declaration. This is the final result-id
-   /// of the decl, which may be an OpFunction, OpVariable, or the result of a sequence
-   /// of OpSpecConstantOp operations.
+   /// Useful to tell what kind of decl this is, and hold the result-id or field index
+   /// to be used for this decl.
+   pub const Kind = enum {
+       func,
+       global,
+       invocation_global,
+   };
+
+   /// See comment on Kind.
+   kind: Kind,
+
+   /// The result-id associated to this decl. The specific meaning of this depends on `kind`:
+   /// - For `func`, this is the result-id of the associated OpFunction instruction.
+   /// - For `global`, this is the result-id of the associated OpVariable instruction.
+   /// - For `invocation_global`, this is the result-id of the associated InvocationGlobal instruction.
    result_id: IdRef,
    /// The offset of the first dependency of this decl in the `decl_deps` array.
    begin_dep: u32,
@@ -82,20 +93,6 @@
    end_dep: u32,
};

-/// Globals must be kept in order: operations involving globals must be ordered
-/// so that the global declaration precedes any usage.
-pub const Global = struct {
-    /// This is the result-id of the OpVariable instruction that declares the global.
-    result_id: IdRef,
-    /// The offset into `self.globals.section` of the first instruction of this global
-    /// declaration.
-    begin_inst: u32,
-    /// The past-end offset into `self.flobals.section`.
-    end_inst: u32,
-    /// The result-id of the function that initializes this value.
-    initializer_id: ?IdRef,
-};
-
/// This models a kernel entry point.
pub const EntryPoint = struct {
    /// The declaration that should be exported.
@@ -165,18 +162,8 @@ decl_deps: std.ArrayListUnmanaged(Decl.Index) = .{},
/// The list of entry points that should be exported from this module.
entry_points: std.ArrayListUnmanaged(EntryPoint) = .{},

-/// The fields in this structure help to maintain the required order for global variables.
-globals: struct {
-    /// Set of globals, referred to by Decl.Index.
-    globals: std.AutoArrayHashMapUnmanaged(Decl.Index, Global) = .{},
-    /// This pseudo-section contains the initialization code for all the globals. Instructions from
-    /// here are reordered when flushing the module. Its contents should be part of the
-    /// `types_globals_constants` SPIR-V section when the module is emitted.
-    section: Section = .{},
-} = .{},
-
/// The list of extended instruction sets that should be imported.
-extended_instruction_set: std.AutoHashMapUnmanaged(ExtendedInstructionSet, IdRef) = .{},
+extended_instruction_set: std.AutoHashMapUnmanaged(spec.InstructionSet, IdRef) = .{},

pub fn init(gpa: Allocator) Module {
    return .{
@@ -205,9 +192,6 @@ pub fn deinit(self: *Module) void {
    self.entry_points.deinit(self.gpa);

-   self.globals.globals.deinit(self.gpa);
-   self.globals.section.deinit(self.gpa);
-
    self.extended_instruction_set.deinit(self.gpa);

    self.* = undefined;
@@ -215,12 +199,12 @@ pub fn deinit(self: *Module) void {
pub fn allocId(self: *Module) spec.IdResult {
    defer self.next_result_id += 1;
-   return .{ .id = self.next_result_id };
+   return @enumFromInt(self.next_result_id);
}

pub fn allocIds(self: *Module, n: u32) spec.IdResult {
    defer self.next_result_id += n;
-   return .{ .id = self.next_result_id };
+   return @enumFromInt(self.next_result_id);
}

pub fn idBound(self: Module) Word {
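
The `.{ .id = ... }` to `@enumFromInt` change implies that result ids are now a non-exhaustive enum over the word type rather than a struct; that is an assumption read off this diff. A small sketch under that assumption, including the `first_case_label + case_i` arithmetic the switch lowering uses elsewhere in this commit:

const std = @import("std");

// Assumed shape: a non-exhaustive enum(u32) instead of struct { id: u32 }.
const IdResult = enum(u32) { _ };

test "id arithmetic through enum conversions" {
    const first_case_label: IdResult = @enumFromInt(42);
    const case_i: u32 = 2;
    // New-style spelling of `first_case_label.id + case_i`:
    const label: IdResult = @enumFromInt(@intFromEnum(first_case_label) + case_i);
    try std.testing.expectEqual(@as(u32, 44), @intFromEnum(label));
}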
@@ -243,46 +227,6 @@ pub fn resolveString(self: *Module, str: []const u8) !CacheString {
    return try self.cache.addString(self, str);
}

-fn orderGlobalsInto(
-    self: *Module,
-    decl_index: Decl.Index,
-    section: *Section,
-    seen: *std.DynamicBitSetUnmanaged,
-) !void {
-    const decl = self.declPtr(decl_index);
-    const deps = self.decl_deps.items[decl.begin_dep..decl.end_dep];
-    const global = self.globalPtr(decl_index).?;
-    const insts = self.globals.section.instructions.items[global.begin_inst..global.end_inst];
-
-    seen.set(@intFromEnum(decl_index));
-
-    for (deps) |dep| {
-        if (!seen.isSet(@intFromEnum(dep))) {
-            try self.orderGlobalsInto(dep, section, seen);
-        }
-    }
-
-    try section.instructions.appendSlice(self.gpa, insts);
-}
-
-fn orderGlobals(self: *Module) !Section {
-    const globals = self.globals.globals.keys();
-
-    var seen = try std.DynamicBitSetUnmanaged.initEmpty(self.gpa, self.decls.items.len);
-    defer seen.deinit(self.gpa);
-
-    var ordered_globals = Section{};
-    errdefer ordered_globals.deinit(self.gpa);
-
-    for (globals) |decl_index| {
-        if (!seen.isSet(@intFromEnum(decl_index))) {
-            try self.orderGlobalsInto(decl_index, &ordered_globals, &seen);
-        }
-    }
-
-    return ordered_globals;
-}
-
fn addEntryPointDeps(
    self: *Module,
    decl_index: Decl.Index,
@@ -298,8 +242,8 @@ fn addEntryPointDeps(
    seen.set(@intFromEnum(decl_index));

-   if (self.globalPtr(decl_index)) |global| {
-       try interface.append(global.result_id);
+   if (decl.kind == .global) {
+       try interface.append(decl.result_id);
    }

    for (deps) |dep| {
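
Only plain `global`s own an OpVariable that belongs in an OpEntryPoint interface; `func` and `invocation_global` decls are skipped. A self-contained sketch of that filter with stand-in types (assumed shapes, not the real Module code):

const std = @import("std");

const Kind = enum { func, global, invocation_global };
const Decl = struct { kind: Kind, result_id: u32 };

fn collectInterface(decls: []const Decl, interface: *std.ArrayList(u32)) !void {
    for (decls) |decl| {
        if (decl.kind == .global) try interface.append(decl.result_id);
    }
}

test "only plain globals land in the interface" {
    var interface = std.ArrayList(u32).init(std.testing.allocator);
    defer interface.deinit();
    const decls = [_]Decl{
        .{ .kind = .func, .result_id = 1 },
        .{ .kind = .global, .result_id = 2 },
        .{ .kind = .invocation_global, .result_id = 3 },
    };
    try collectInterface(&decls, &interface);
    try std.testing.expectEqualSlices(u32, &.{2}, interface.items);
}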
@@ -335,81 +279,9 @@ fn entryPoints(self: *Module) !Section {
    return entry_points;
}

-/// Generate a function that calls all initialization functions,
-/// in unspecified order (an order should not be required here).
-/// It generated as follows:
-/// %init = OpFunction %void None
-/// foreach %initializer:
-///     OpFunctionCall %initializer
-/// OpReturn
-/// OpFunctionEnd
-fn initializer(self: *Module, entry_points: *Section) !Section {
-    var section = Section{};
-    errdefer section.deinit(self.gpa);
-
-    // const void_ty_ref = try self.resolveType(Type.void, .direct);
-    const void_ty_ref = try self.resolve(.void_type);
-    const void_ty_id = self.resultId(void_ty_ref);
-    const init_proto_ty_ref = try self.resolve(.{ .function_type = .{
-        .return_type = void_ty_ref,
-        .parameters = &.{},
-    } });
-
-    const init_id = self.allocId();
-    try section.emit(self.gpa, .OpFunction, .{
-        .id_result_type = void_ty_id,
-        .id_result = init_id,
-        .function_control = .{},
-        .function_type = self.resultId(init_proto_ty_ref),
-    });
-    try section.emit(self.gpa, .OpLabel, .{
-        .id_result = self.allocId(),
-    });
-
-    var seen = try std.DynamicBitSetUnmanaged.initEmpty(self.gpa, self.decls.items.len);
-    defer seen.deinit(self.gpa);
-
-    var interface = std.ArrayList(IdRef).init(self.gpa);
-    defer interface.deinit();
-
-    for (self.globals.globals.keys(), self.globals.globals.values()) |decl_index, global| {
-        try self.addEntryPointDeps(decl_index, &seen, &interface);
-        if (global.initializer_id) |initializer_id| {
-            try section.emit(self.gpa, .OpFunctionCall, .{
-                .id_result_type = void_ty_id,
-                .id_result = self.allocId(),
-                .function = initializer_id,
-            });
-        }
-    }
-
-    try section.emit(self.gpa, .OpReturn, {});
-    try section.emit(self.gpa, .OpFunctionEnd, {});
-
-    try entry_points.emit(self.gpa, .OpEntryPoint, .{
-        // TODO: Rusticl does not support this because its poorly defined.
-        // Do we need to generate a workaround here?
-        .execution_model = .Kernel,
-        .entry_point = init_id,
-        .name = "zig global initializer",
-        .interface = interface.items,
-    });
-
-    try self.sections.execution_modes.emit(self.gpa, .OpExecutionMode, .{
-        .entry_point = init_id,
-        .mode = .Initializer,
-    });
-
-    return section;
-}
-
-/// Emit this module as a spir-v binary.
-pub fn flush(self: *Module, file: std.fs.File, target: std.Target) !void {
+pub fn finalize(self: *Module, a: Allocator, target: std.Target) ![]Word {
    // See SPIR-V Spec section 2.3, "Physical Layout of a SPIR-V Module and Instruction"

-   // TODO: Audit calls to allocId() in this function to make it idempotent.
-   // TODO: Perform topological sort on the globals.
-   var globals = try self.orderGlobals();
-   defer globals.deinit(self.gpa);
-
    var entry_points = try self.entryPoints();
    defer entry_points.deinit(self.gpa);
@@ -417,13 +289,6 @@ pub fn flush(self: *Module, file: std.fs.File, target: std.Target) !void {
    var types_constants = try self.cache.materialize(self);
    defer types_constants.deinit(self.gpa);

-   // // TODO: Pass global variables as function parameters
-   // var init_func = if (target.os.tag != .vulkan)
-   //     try self.initializer(&entry_points)
-   // else
-   //     Section{};
-   // defer init_func.deinit(self.gpa);
-
    const header = [_]Word{
        spec.magic_number,
        // TODO: From cpu features
@ -436,7 +301,7 @@ pub fn flush(self: *Module, file: std.fs.File, target: std.Target) !void {
else => 4, else => 4,
}, },
}), }),
0, // TODO: Register Zig compiler magic number. spec.zig_generator_id,
self.idBound(), self.idBound(),
0, // Schema (currently reserved for future use) 0, // Schema (currently reserved for future use)
}; };
@ -468,30 +333,23 @@ pub fn flush(self: *Module, file: std.fs.File, target: std.Target) !void {
self.sections.annotations.toWords(), self.sections.annotations.toWords(),
types_constants.toWords(), types_constants.toWords(),
self.sections.types_globals_constants.toWords(), self.sections.types_globals_constants.toWords(),
globals.toWords(),
self.sections.functions.toWords(), self.sections.functions.toWords(),
}; };
if (builtin.zig_backend == .stage2_x86_64) { var total_result_size: usize = 0;
for (buffers) |buf| { for (buffers) |buffer| {
try file.writeAll(std.mem.sliceAsBytes(buf)); total_result_size += buffer.len;
} }
} else { const result = try a.alloc(Word, total_result_size);
// miscompiles with x86_64 backend errdefer a.free(result);
var iovc_buffers: [buffers.len]std.os.iovec_const = undefined;
var file_size: u64 = 0; var offset: usize = 0;
for (&iovc_buffers, 0..) |*iovc, i| { for (buffers) |buffer| {
// Note, since spir-v supports both little and big endian we can ignore byte order here and @memcpy(result[offset..][0..buffer.len], buffer);
// just treat the words as a sequence of bytes. offset += buffer.len;
const bytes = std.mem.sliceAsBytes(buffers[i]);
iovc.* = .{ .iov_base = bytes.ptr, .iov_len = bytes.len };
file_size += bytes.len;
} }
try file.seekTo(0); return result;
try file.setEndPos(file_size);
try file.pwritevAll(&iovc_buffers, 0);
}
} }
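With flush() gone, callers now receive the module words from finalize() and write them out themselves. A minimal sketch of the new flow (mirroring flushModule in link/SpirV.zig; `spv`, `arena`, `target`, and `file` are assumed from the surrounding context):

const module = try spv.finalize(arena, target);
errdefer arena.free(module);
// SPIR-V supports both little and big endian encodings, so the words can be
// written out as raw bytes.
try file.writeAll(std.mem.sliceAsBytes(module));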
/// Merge the sections making up a function declaration into this module. /// Merge the sections making up a function declaration into this module.
@ -501,23 +359,17 @@ pub fn addFunction(self: *Module, decl_index: Decl.Index, func: Fn) !void {
try self.declareDeclDeps(decl_index, func.decl_deps.keys()); try self.declareDeclDeps(decl_index, func.decl_deps.keys());
} }
pub const ExtendedInstructionSet = enum {
glsl,
opencl,
};
/// Imports or returns the existing id of an extended instruction set /// Imports or returns the existing id of an extended instruction set
pub fn importInstructionSet(self: *Module, set: ExtendedInstructionSet) !IdRef { pub fn importInstructionSet(self: *Module, set: spec.InstructionSet) !IdRef {
assert(set != .core);
const gop = try self.extended_instruction_set.getOrPut(self.gpa, set); const gop = try self.extended_instruction_set.getOrPut(self.gpa, set);
if (gop.found_existing) return gop.value_ptr.*; if (gop.found_existing) return gop.value_ptr.*;
const result_id = self.allocId(); const result_id = self.allocId();
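// Note: the import name is the tag name itself. The tags of spec.InstructionSet
// are assumed to spell out the literal set names (e.g. .@"OpenCL.std"), the same
// property that BinaryModule.Parser relies on via std.meta.stringToEnum.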
try self.sections.extended_instruction_set.emit(self.gpa, .OpExtInstImport, .{ try self.sections.extended_instruction_set.emit(self.gpa, .OpExtInstImport, .{
.id_result = result_id, .id_result = result_id,
.name = switch (set) { .name = @tagName(set),
.glsl => "GLSL.std.450",
.opencl => "OpenCL.std",
},
}); });
gop.value_ptr.* = result_id; gop.value_ptr.* = result_id;
@ -631,40 +483,21 @@ pub fn decorateMember(
}); });
} }
pub const DeclKind = enum { pub fn allocDecl(self: *Module, kind: Decl.Kind) !Decl.Index {
func,
global,
};
pub fn allocDecl(self: *Module, kind: DeclKind) !Decl.Index {
try self.decls.append(self.gpa, .{ try self.decls.append(self.gpa, .{
.kind = kind,
.result_id = self.allocId(), .result_id = self.allocId(),
.begin_dep = undefined, .begin_dep = undefined,
.end_dep = undefined, .end_dep = undefined,
}); });
const index = @as(Decl.Index, @enumFromInt(@as(u32, @intCast(self.decls.items.len - 1))));
switch (kind) {
.func => {},
// If the decl represents a global, also allocate a global node.
.global => try self.globals.globals.putNoClobber(self.gpa, index, .{
.result_id = undefined,
.begin_inst = undefined,
.end_inst = undefined,
.initializer_id = undefined,
}),
}
return index; return @as(Decl.Index, @enumFromInt(@as(u32, @intCast(self.decls.items.len - 1))));
} }
pub fn declPtr(self: *Module, index: Decl.Index) *Decl { pub fn declPtr(self: *Module, index: Decl.Index) *Decl {
return &self.decls.items[@intFromEnum(index)]; return &self.decls.items[@intFromEnum(index)];
} }
pub fn globalPtr(self: *Module, index: Decl.Index) ?*Global {
return self.globals.globals.getPtr(index);
}
/// Declare ALL dependencies for a decl. /// Declare ALL dependencies for a decl.
pub fn declareDeclDeps(self: *Module, decl_index: Decl.Index, deps: []const Decl.Index) !void { pub fn declareDeclDeps(self: *Module, decl_index: Decl.Index, deps: []const Decl.Index) !void {
const begin_dep = @as(u32, @intCast(self.decl_deps.items.len)); const begin_dep = @as(u32, @intCast(self.decl_deps.items.len));
@ -676,26 +509,9 @@ pub fn declareDeclDeps(self: *Module, decl_index: Decl.Index, deps: []const Decl
decl.end_dep = end_dep; decl.end_dep = end_dep;
} }
pub fn beginGlobal(self: *Module) u32 { /// Declare a SPIR-V function as an entry point. This causes an extra wrapper
return @as(u32, @intCast(self.globals.section.instructions.items.len)); /// function to be generated, which is then exported as the real entry point. The purpose of this
} /// wrapper is to allocate and initialize the structure holding the instance globals.
pub fn endGlobal(
self: *Module,
global_index: Decl.Index,
begin_inst: u32,
result_id: IdRef,
initializer_id: ?IdRef,
) void {
const global = self.globalPtr(global_index).?;
global.* = .{
.result_id = result_id,
.begin_inst = begin_inst,
.end_inst = @intCast(self.globals.section.instructions.items.len),
.initializer_id = initializer_id,
};
}
pub fn declareEntryPoint( pub fn declareEntryPoint(
self: *Module, self: *Module,
decl_index: Decl.Index, decl_index: Decl.Index,

View File

@ -53,6 +53,17 @@ pub fn emitRaw(
section.writeWord((@as(Word, @intCast(word_count << 16))) | @intFromEnum(opcode)); section.writeWord((@as(Word, @intCast(word_count << 16))) | @intFromEnum(opcode));
} }
/// Write an entire instruction, including all operands
pub fn emitRawInstruction(
section: *Section,
allocator: Allocator,
opcode: Opcode,
operands: []const Word,
) !void {
try section.emitRaw(allocator, opcode, operands.len);
section.writeWords(operands);
}
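This is convenient for passes that copy parsed instructions verbatim. A minimal usage sketch (assuming a section, an allocator gpa, and a parsed inst with an opcode and raw operands, as produced by the BinaryModule parser below):

try section.emitRawInstruction(gpa, inst.opcode, inst.operands);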
pub fn emit( pub fn emit(
section: *Section, section: *Section,
allocator: Allocator, allocator: Allocator,
@ -123,7 +134,7 @@ fn writeOperands(section: *Section, comptime Operands: type, operands: Operands)
pub fn writeOperand(section: *Section, comptime Operand: type, operand: Operand) void { pub fn writeOperand(section: *Section, comptime Operand: type, operand: Operand) void {
switch (Operand) { switch (Operand) {
spec.IdResult => section.writeWord(operand.id), spec.IdResult => section.writeWord(@intFromEnum(operand)),
spec.LiteralInteger => section.writeWord(operand), spec.LiteralInteger => section.writeWord(operand),
@ -138,9 +149,9 @@ pub fn writeOperand(section: *Section, comptime Operand: type, operand: Operand)
// instruction in which it is used. // instruction in which it is used.
spec.LiteralSpecConstantOpInteger => section.writeWord(@intFromEnum(operand.opcode)), spec.LiteralSpecConstantOpInteger => section.writeWord(@intFromEnum(operand.opcode)),
spec.PairLiteralIntegerIdRef => section.writeWords(&.{ operand.value, operand.label.id }), spec.PairLiteralIntegerIdRef => section.writeWords(&.{ operand.value, @intFromEnum(operand.label) }),
spec.PairIdRefLiteralInteger => section.writeWords(&.{ operand.target.id, operand.member }), spec.PairIdRefLiteralInteger => section.writeWords(&.{ @intFromEnum(operand.target), operand.member }),
spec.PairIdRefIdRef => section.writeWords(&.{ operand[0].id, operand[1].id }), spec.PairIdRefIdRef => section.writeWords(&.{ @intFromEnum(operand[0]), @intFromEnum(operand[1]) }),
else => switch (@typeInfo(Operand)) { else => switch (@typeInfo(Operand)) {
.Enum => section.writeWord(@intFromEnum(operand)), .Enum => section.writeWord(@intFromEnum(operand)),
@ -338,8 +349,8 @@ test "SPIR-V Section emit() - simple" {
defer section.deinit(std.testing.allocator); defer section.deinit(std.testing.allocator);
try section.emit(std.testing.allocator, .OpUndef, .{ try section.emit(std.testing.allocator, .OpUndef, .{
.id_result_type = .{ .id = 0 }, .id_result_type = @enumFromInt(0),
.id_result = .{ .id = 1 }, .id_result = @enumFromInt(1),
}); });
try testing.expectEqualSlices(Word, &.{ try testing.expectEqualSlices(Word, &.{
@ -356,7 +367,7 @@ test "SPIR-V Section emit() - string" {
try section.emit(std.testing.allocator, .OpSource, .{ try section.emit(std.testing.allocator, .OpSource, .{
.source_language = .Unknown, .source_language = .Unknown,
.version = 123, .version = 123,
.file = .{ .id = 456 }, .file = @enumFromInt(456),
.source = "pub fn main() void {}", .source = "pub fn main() void {}",
}); });
@ -381,8 +392,8 @@ test "SPIR-V Section emit() - extended mask" {
defer section.deinit(std.testing.allocator); defer section.deinit(std.testing.allocator);
try section.emit(std.testing.allocator, .OpLoopMerge, .{ try section.emit(std.testing.allocator, .OpLoopMerge, .{
.merge_block = .{ .id = 10 }, .merge_block = @enumFromInt(10),
.continue_target = .{ .id = 20 }, .continue_target = @enumFromInt(20),
.loop_control = .{ .loop_control = .{
.Unroll = true, .Unroll = true,
.DependencyLength = .{ .DependencyLength = .{
@ -405,7 +416,7 @@ test "SPIR-V Section emit() - extended union" {
defer section.deinit(std.testing.allocator); defer section.deinit(std.testing.allocator);
try section.emit(std.testing.allocator, .OpExecutionMode, .{ try section.emit(std.testing.allocator, .OpExecutionMode, .{
.entry_point = .{ .id = 888 }, .entry_point = @enumFromInt(888),
.mode = .{ .mode = .{
.LocalSize = .{ .x_size = 4, .y_size = 8, .z_size = 16 }, .LocalSize = .{ .x_size = 4, .y_size = 8, .z_size = 16 },
}, },

View File

@ -0,0 +1,13 @@
{
"version": 0,
"revision": 0,
"instructions": [
{
"opname": "InvocationGlobal",
"opcode": 0,
"operands": [
{ "kind": "IdRef", "name": "initializer function" }
]
}
]
}
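This grammar defines the non-semantic "zig" instruction set with a single InvocationGlobal instruction, which codegen emits as an ordinary OpExtInst. A sketch of the encoding via emitRawInstruction (global_ty_id, global_id, zig_set_id, and initializer_id are placeholder names; the operand layout matches what lower_invocation_globals.zig parses):

try section.emitRawInstruction(gpa, .OpExtInst, &.{
    @intFromEnum(global_ty_id), // result type: the type of the global
    @intFromEnum(global_id), // result-id of the invocation global
    @intFromEnum(zig_set_id), // result-id of OpExtInstImport "zig"
    0, // instruction 0 = InvocationGlobal
    @intFromEnum(initializer_id), // optional initializer function
});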

File diff suppressed because it is too large

View File

@ -39,8 +39,12 @@ const Liveness = @import("../Liveness.zig");
const Value = @import("../Value.zig"); const Value = @import("../Value.zig");
const SpvModule = @import("../codegen/spirv/Module.zig"); const SpvModule = @import("../codegen/spirv/Module.zig");
const Section = @import("../codegen/spirv/Section.zig");
const spec = @import("../codegen/spirv/spec.zig"); const spec = @import("../codegen/spirv/spec.zig");
const IdResult = spec.IdResult; const IdResult = spec.IdResult;
const Word = spec.Word;
const BinaryModule = @import("SpirV/BinaryModule.zig");
base: link.File, base: link.File,
@ -163,7 +167,8 @@ pub fn updateExports(
.Vertex => spec.ExecutionModel.Vertex, .Vertex => spec.ExecutionModel.Vertex,
.Fragment => spec.ExecutionModel.Fragment, .Fragment => spec.ExecutionModel.Fragment,
.Kernel => spec.ExecutionModel.Kernel, .Kernel => spec.ExecutionModel.Kernel,
else => return, .C => return, // TODO: What to do here?
else => unreachable,
}; };
const is_vulkan = target.os.tag == .vulkan; const is_vulkan = target.os.tag == .vulkan;
@ -197,8 +202,6 @@ pub fn flushModule(self: *SpirV, arena: Allocator, prog_node: *std.Progress.Node
@panic("Attempted to compile for architecture that was disabled by build configuration"); @panic("Attempted to compile for architecture that was disabled by build configuration");
} }
_ = arena; // Has the same lifetime as the call to Compilation.update.
const tracy = trace(@src()); const tracy = trace(@src());
defer tracy.end(); defer tracy.end();
@ -223,9 +226,9 @@ pub fn flushModule(self: *SpirV, arena: Allocator, prog_node: *std.Progress.Node
defer error_info.deinit(); defer error_info.deinit();
try error_info.appendSlice("zig_errors"); try error_info.appendSlice("zig_errors");
const module = self.base.comp.module.?; const mod = self.base.comp.module.?;
for (module.global_error_set.keys()) |name_nts| { for (mod.global_error_set.keys()) |name_nts| {
const name = module.intern_pool.stringToSlice(name_nts); const name = mod.intern_pool.stringToSlice(name_nts);
// Errors can contain pretty much any character - to encode them in a string we must escape // Errors can contain pretty much any character - to encode them in a string we must escape
// them somehow. Easiest here is to use some established scheme, one which also preserves the // them somehow. Easiest here is to use some established scheme, one which also preserves the
// name if it contains no strange characters is nice for debugging. URI encoding fits the bill. // name if it contains no strange characters is nice for debugging. URI encoding fits the bill.
@ -239,7 +242,34 @@ pub fn flushModule(self: *SpirV, arena: Allocator, prog_node: *std.Progress.Node
.extension = error_info.items, .extension = error_info.items,
}); });
try spv.flush(self.base.file.?, target); const module = try spv.finalize(arena, target);
errdefer arena.free(module);
const linked_module = self.linkModule(arena, module) catch |err| switch (err) {
error.OutOfMemory => return error.OutOfMemory,
else => |other| {
log.err("error while linking: {s}\n", .{@errorName(other)});
return error.FlushFailure;
},
};
try self.base.file.?.writeAll(std.mem.sliceAsBytes(linked_module));
}
fn linkModule(self: *SpirV, a: Allocator, module: []Word) ![]Word {
_ = self;
const lower_invocation_globals = @import("SpirV/lower_invocation_globals.zig");
const prune_unused = @import("SpirV/prune_unused.zig");
var parser = try BinaryModule.Parser.init(a);
defer parser.deinit();
var binary = try parser.parse(module);
try lower_invocation_globals.run(&parser, &binary);
try prune_unused.run(&parser, &binary);
return binary.finalize(a);
} }
fn writeCapabilities(spv: *SpvModule, target: std.Target) !void { fn writeCapabilities(spv: *SpvModule, target: std.Target) !void {

View File

@ -0,0 +1,461 @@
const std = @import("std");
const assert = std.debug.assert;
const Allocator = std.mem.Allocator;
const log = std.log.scoped(.spirv_parse);
const spec = @import("../../codegen/spirv/spec.zig");
const Opcode = spec.Opcode;
const Word = spec.Word;
const InstructionSet = spec.InstructionSet;
const ResultId = spec.IdResult;
const BinaryModule = @This();
pub const header_words = 5;
/// The module SPIR-V version.
version: spec.Version,
/// The generator magic number.
generator_magic: u32,
/// The result-id bound of this SPIR-V module.
id_bound: u32,
/// The instructions of this module. This does not contain the header.
instructions: []const Word,
/// Maps OpExtInstImport result-ids to their InstructionSet.
ext_inst_map: std.AutoHashMapUnmanaged(ResultId, InstructionSet),
/// This map contains the width of arithmetic types (OpTypeInt and
/// OpTypeFloat). We need this information to correctly parse the operands
/// of Op(Spec)Constant and OpSwitch.
arith_type_width: std.AutoHashMapUnmanaged(ResultId, u16),
/// The starting offsets of some sections
sections: struct {
functions: usize,
},
pub fn deinit(self: *BinaryModule, a: Allocator) void {
self.ext_inst_map.deinit(a);
self.arith_type_width.deinit(a);
self.* = undefined;
}
pub fn iterateInstructions(self: BinaryModule) Instruction.Iterator {
return Instruction.Iterator.init(self.instructions, 0);
}
pub fn iterateInstructionsFrom(self: BinaryModule, offset: usize) Instruction.Iterator {
return Instruction.Iterator.init(self.instructions, offset);
}
pub fn instructionAt(self: BinaryModule, offset: usize) Instruction {
var it = self.iterateInstructionsFrom(offset);
return it.next().?;
}
pub fn finalize(self: BinaryModule, a: Allocator) ![]Word {
const result = try a.alloc(Word, 5 + self.instructions.len);
errdefer a.free(result);
result[0] = spec.magic_number;
result[1] = @bitCast(self.version);
result[2] = spec.zig_generator_id;
result[3] = self.id_bound;
result[4] = 0; // Schema
@memcpy(result[5..], self.instructions);
return result;
}
/// Errors that can be raised when the module is not correct.
/// Note that the parser doesn't validate SPIR-V modules by a
/// long shot. It only yields errors that critically prevent
/// further analysis of the module.
pub const ParseError = error{
/// Raised when the module doesn't start with the SPIR-V magic.
/// This usually means that the module isn't actually SPIR-V.
InvalidMagic,
/// Raised when the module has an invalid "physical" format:
/// For example when the header is incomplete, or an instruction
/// has an illegal format.
InvalidPhysicalFormat,
/// OpExtInstImport was used with an unknown extension string.
InvalidExtInstImport,
/// The module had an instruction with an invalid (unknown) opcode.
InvalidOpcode,
/// An instruction's operands did not conform to the SPIR-V specification
/// for that instruction.
InvalidOperands,
/// A result-id was declared more than once.
DuplicateId,
/// Some ID did not resolve.
InvalidId,
/// Parser ran out of memory.
OutOfMemory,
};
pub const Instruction = struct {
pub const Iterator = struct {
words: []const Word,
index: usize = 0,
offset: usize = 0,
pub fn init(words: []const Word, start_offset: usize) Iterator {
return .{ .words = words, .offset = start_offset };
}
pub fn next(self: *Iterator) ?Instruction {
if (self.offset >= self.words.len) return null;
const instruction_len = self.words[self.offset] >> 16;
defer self.offset += instruction_len;
defer self.index += 1;
assert(instruction_len != 0 and self.offset < self.words.len); // Verified in BinaryModule.parse.
return Instruction{
.opcode = @enumFromInt(self.words[self.offset] & 0xFFFF),
.index = self.index,
.offset = self.offset,
.operands = self.words[self.offset..][1..instruction_len],
};
}
};
/// The opcode for this instruction.
opcode: Opcode,
/// The instruction's index.
index: usize,
/// The instruction's word offset in the module.
offset: usize,
/// The raw (unparsed) operands for this instruction.
operands: []const Word,
};
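Together with Instruction.Iterator, consumers walk a parsed module instruction by instruction. A minimal sketch (assuming a parsed binary: BinaryModule):

var it = binary.iterateInstructions();
while (it.next()) |inst| {
    // inst.opcode, inst.index, inst.offset, and inst.operands are available here.
    if (inst.opcode == .OpFunctionEnd) break;
}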
/// This parser contains information (acceleration tables)
/// that can be persisted across different modules. This is
/// used to initialize the module, and is also used when
/// further analyzing it.
pub const Parser = struct {
/// The allocator used to allocate this parser's structures,
/// and also the structures of any parsed module.
a: Allocator,
/// Maps (instruction set, opcode) => instruction index (for instruction set)
opcode_table: std.AutoHashMapUnmanaged(u32, u16) = .{},
pub fn init(a: Allocator) !Parser {
var self = Parser{
.a = a,
};
errdefer self.deinit();
inline for (std.meta.tags(InstructionSet)) |set| {
const instructions = set.instructions();
try self.opcode_table.ensureUnusedCapacity(a, @intCast(instructions.len));
for (instructions, 0..) |inst, i| {
// Note: Some instructions may alias another. In this case we don't really care
// which one is first: they all (should) have the same operands anyway. Just pick
// the first, which is usually the core, KHR or EXT variant.
const entry = self.opcode_table.getOrPutAssumeCapacity(mapSetAndOpcode(set, @intCast(inst.opcode)));
if (!entry.found_existing) {
entry.value_ptr.* = @intCast(i);
}
}
}
return self;
}
pub fn deinit(self: *Parser) void {
self.opcode_table.deinit(self.a);
}
fn mapSetAndOpcode(set: InstructionSet, opcode: u16) u32 {
return (@as(u32, @intFromEnum(set)) << 16) | opcode;
}
pub fn getInstSpec(self: Parser, opcode: Opcode) ?spec.Instruction {
const index = self.opcode_table.get(mapSetAndOpcode(.core, @intFromEnum(opcode))) orelse return null;
return InstructionSet.core.instructions()[index];
}
pub fn parse(self: *Parser, module: []const u32) ParseError!BinaryModule {
// Check the length before reading the magic, so that an undersized module
// cannot cause an out-of-bounds read.
if (module.len < header_words) {
log.err("module only has {}/{} header words", .{ module.len, header_words });
return error.InvalidPhysicalFormat;
} else if (module[0] != spec.magic_number) {
return error.InvalidMagic;
}
var binary = BinaryModule{
.version = @bitCast(module[1]),
.generator_magic = module[2],
.id_bound = module[3],
.instructions = module[header_words..],
.ext_inst_map = .{},
.arith_type_width = .{},
.sections = undefined,
};
var maybe_function_section: ?usize = null;
// First pass through the module to verify basic structure and
// to gather some initial information for more detailed analysis.
// We want to check some things that Instruction.Iterator is no good for,
// so just iterate manually.
var offset: usize = 0;
while (offset < binary.instructions.len) {
const len = binary.instructions[offset] >> 16;
if (len == 0 or len + offset > binary.instructions.len) {
log.err("invalid instruction format: len={}, end={}, module len={}", .{ len, len + offset, binary.instructions.len });
return error.InvalidPhysicalFormat;
}
defer offset += len;
// We can't really efficiently use non-exhaustive enums here, because we would
// need to manually write out all valid cases. Since we have this map anyway, just
// use that.
const opcode: Opcode = @enumFromInt(@as(u16, @truncate(binary.instructions[offset])));
const inst_spec = self.getInstSpec(opcode) orelse {
log.err("invalid opcode for core set: {}", .{@intFromEnum(opcode)});
return error.InvalidOpcode;
};
const operands = binary.instructions[offset..][1..len];
switch (opcode) {
.OpExtInstImport => {
const set_name = std.mem.sliceTo(std.mem.sliceAsBytes(operands[1..]), 0);
const set = std.meta.stringToEnum(InstructionSet, set_name) orelse {
log.err("invalid instruction set '{s}'", .{set_name});
return error.InvalidExtInstImport;
};
if (set == .core) return error.InvalidExtInstImport;
try binary.ext_inst_map.put(self.a, @enumFromInt(operands[0]), set);
},
.OpTypeInt, .OpTypeFloat => {
const entry = try binary.arith_type_width.getOrPut(self.a, @enumFromInt(operands[0]));
if (entry.found_existing) return error.DuplicateId;
entry.value_ptr.* = std.math.cast(u16, operands[1]) orelse return error.InvalidOperands;
},
.OpFunction => if (maybe_function_section == null) {
maybe_function_section = offset;
},
else => {},
}
// OpSwitch takes a value as argument, not an OpType... hence we need to populate arith_type_width
// with ALL operations that return an int or float.
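// For example, after `%v = OpLoad %u64 %p`, %v is recorded with width 64,
// so an OpSwitch on %v is later parsed with two-word case literals.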
const spec_operands = inst_spec.operands;
if (spec_operands.len >= 2 and
spec_operands[0].kind == .IdResultType and
spec_operands[1].kind == .IdResult)
{
if (operands.len < 2) return error.InvalidOperands;
if (binary.arith_type_width.get(@enumFromInt(operands[0]))) |width| {
const entry = try binary.arith_type_width.getOrPut(self.a, @enumFromInt(operands[1]));
if (entry.found_existing) return error.DuplicateId;
entry.value_ptr.* = width;
}
}
}
binary.sections = .{
.functions = maybe_function_section orelse binary.instructions.len,
};
return binary;
}
/// Parse offsets in the instruction that contain result-ids.
/// Returned offsets are relative to inst.operands.
/// Results are appended to an ArrayList to amortize allocations.
pub fn parseInstructionResultIds(
self: *Parser,
binary: BinaryModule,
inst: Instruction,
offsets: *std.ArrayList(u16),
) !void {
const index = self.opcode_table.get(mapSetAndOpcode(.core, @intFromEnum(inst.opcode))).?;
const operands = InstructionSet.core.instructions()[index].operands;
var offset: usize = 0;
switch (inst.opcode) {
.OpSpecConstantOp => {
assert(operands[0].kind == .IdResultType);
assert(operands[1].kind == .IdResult);
offset = try self.parseOperandsResultIds(binary, inst, operands[0..2], offset, offsets);
if (offset >= inst.operands.len) return error.InvalidPhysicalFormat;
const spec_opcode = std.math.cast(u16, inst.operands[offset]) orelse return error.InvalidPhysicalFormat;
const spec_index = self.opcode_table.get(mapSetAndOpcode(.core, spec_opcode)) orelse
return error.InvalidPhysicalFormat;
const spec_operands = InstructionSet.core.instructions()[spec_index].operands;
assert(spec_operands[0].kind == .IdResultType);
assert(spec_operands[1].kind == .IdResult);
offset = try self.parseOperandsResultIds(binary, inst, spec_operands[2..], offset + 1, offsets);
},
.OpExtInst => {
assert(operands[0].kind == .IdResultType);
assert(operands[1].kind == .IdResult);
offset = try self.parseOperandsResultIds(binary, inst, operands[0..2], offset, offsets);
if (offset + 1 >= inst.operands.len) return error.InvalidPhysicalFormat;
const set_id: ResultId = @enumFromInt(inst.operands[offset]);
try offsets.append(@intCast(offset));
const set = binary.ext_inst_map.get(set_id) orelse {
log.err("invalid instruction set {}", .{@intFromEnum(set_id)});
return error.InvalidId;
};
const ext_opcode = std.math.cast(u16, inst.operands[offset + 1]) orelse return error.InvalidPhysicalFormat;
const ext_index = self.opcode_table.get(mapSetAndOpcode(set, ext_opcode)) orelse
return error.InvalidPhysicalFormat;
const ext_operands = set.instructions()[ext_index].operands;
offset = try self.parseOperandsResultIds(binary, inst, ext_operands, offset + 2, offsets);
},
else => {
offset = try self.parseOperandsResultIds(binary, inst, operands, offset, offsets);
},
}
if (offset != inst.operands.len) return error.InvalidPhysicalFormat;
}
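Callers typically reuse one offsets list across instructions. A minimal sketch (assuming parser, binary, and an allocator gpa; this mirrors how the passes below drive the parser):

var result_id_offsets = std.ArrayList(u16).init(gpa);
defer result_id_offsets.deinit();
var it = binary.iterateInstructions();
while (it.next()) |inst| {
    result_id_offsets.items.len = 0; // Reuse the allocation between instructions.
    try parser.parseInstructionResultIds(binary, inst, &result_id_offsets);
    for (result_id_offsets.items) |off| {
        const id: ResultId = @enumFromInt(inst.operands[off]);
        _ = id; // Inspect or rewrite the referenced result-id here.
    }
}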
fn parseOperandsResultIds(
self: *Parser,
binary: BinaryModule,
inst: Instruction,
operands: []const spec.Operand,
start_offset: usize,
offsets: *std.ArrayList(u16),
) !usize {
var offset = start_offset;
for (operands) |operand| {
offset = try self.parseOperandResultIds(binary, inst, operand, offset, offsets);
}
return offset;
}
fn parseOperandResultIds(
self: *Parser,
binary: BinaryModule,
inst: Instruction,
operand: spec.Operand,
start_offset: usize,
offsets: *std.ArrayList(u16),
) !usize {
var offset = start_offset;
switch (operand.quantifier) {
.variadic => while (offset < inst.operands.len) {
offset = try self.parseOperandKindResultIds(binary, inst, operand.kind, offset, offsets);
},
.optional => if (offset < inst.operands.len) {
offset = try self.parseOperandKindResultIds(binary, inst, operand.kind, offset, offsets);
},
.required => {
offset = try self.parseOperandKindResultIds(binary, inst, operand.kind, offset, offsets);
},
}
return offset;
}
fn parseOperandKindResultIds(
self: *Parser,
binary: BinaryModule,
inst: Instruction,
kind: spec.OperandKind,
start_offset: usize,
offsets: *std.ArrayList(u16),
) !usize {
var offset = start_offset;
if (offset >= inst.operands.len) return error.InvalidPhysicalFormat;
switch (kind.category()) {
.bit_enum => {
const mask = inst.operands[offset];
offset += 1;
for (kind.enumerants()) |enumerant| {
if ((mask & enumerant.value) != 0) {
for (enumerant.parameters) |param_kind| {
offset = try self.parseOperandKindResultIds(binary, inst, param_kind, offset, offsets);
}
}
}
},
.value_enum => {
const value = inst.operands[offset];
offset += 1;
for (kind.enumerants()) |enumerant| {
if (value == enumerant.value) {
for (enumerant.parameters) |param_kind| {
offset = try self.parseOperandKindResultIds(binary, inst, param_kind, offset, offsets);
}
break;
}
}
},
.id => {
try offsets.append(@intCast(offset));
offset += 1;
},
else => switch (kind) {
.LiteralInteger, .LiteralFloat => offset += 1,
.LiteralString => while (true) {
if (offset >= inst.operands.len) return error.InvalidPhysicalFormat;
const word = inst.operands[offset];
offset += 1;
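// SPIR-V string literals are nul-terminated and packed little-endian
// into words, so the string ends at the first word containing a zero byte.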
if (word & 0xFF000000 == 0 or
word & 0x00FF0000 == 0 or
word & 0x0000FF00 == 0 or
word & 0x000000FF == 0)
{
break;
}
},
.LiteralContextDependentNumber => {
assert(inst.opcode == .OpConstant or inst.opcode == .OpSpecConstantOp);
const bit_width = binary.arith_type_width.get(@enumFromInt(inst.operands[0])) orelse {
log.err("invalid LiteralContextDependentNumber type {}", .{inst.operands[0]});
return error.InvalidId;
};
offset += switch (bit_width) {
1...32 => 1,
33...64 => 2,
else => unreachable,
};
},
.LiteralExtInstInteger => unreachable,
.LiteralSpecConstantOpInteger => unreachable,
.PairLiteralIntegerIdRef => { // Switch case
assert(inst.opcode == .OpSwitch);
const bit_width = binary.arith_type_width.get(@enumFromInt(inst.operands[0])) orelse {
log.err("invalid OpSwitch type {}", .{inst.operands[0]});
return error.InvalidId;
};
offset += switch (bit_width) {
1...32 => 1,
33...64 => 2,
else => unreachable,
};
try offsets.append(@intCast(offset));
offset += 1;
},
.PairIdRefLiteralInteger => {
try offsets.append(@intCast(offset));
offset += 2;
},
.PairIdRefIdRef => {
try offsets.append(@intCast(offset));
try offsets.append(@intCast(offset + 1));
offset += 2;
},
else => unreachable,
},
}
return offset;
}
};

View File

@ -0,0 +1,700 @@
const std = @import("std");
const Allocator = std.mem.Allocator;
const assert = std.debug.assert;
const log = std.log.scoped(.spirv_link);
const BinaryModule = @import("BinaryModule.zig");
const Section = @import("../../codegen/spirv/Section.zig");
const spec = @import("../../codegen/spirv/spec.zig");
const ResultId = spec.IdResult;
const Word = spec.Word;
/// This structure contains all the information that we need to parse from the module in
/// order to run this pass, as well as some functions to ease its use.
const ModuleInfo = struct {
/// Information about a particular function.
const Fn = struct {
/// The index of the first callee in `callee_store`.
first_callee: usize,
/// The return type id of this function
return_type: ResultId,
/// The parameter types of this function
param_types: []const ResultId,
/// The set of (result-ids of) invocation globals that are accessed
/// in this function directly, or, after resolution, in this
/// function or any of its callees.
invocation_globals: std.AutoArrayHashMapUnmanaged(ResultId, void),
};
/// Information about a particular invocation global
const InvocationGlobal = struct {
/// The list of invocation globals that this invocation global
/// depends on.
dependencies: std.AutoArrayHashMapUnmanaged(ResultId, void),
/// The invocation global's type
ty: ResultId,
/// Initializer function. May be `none`.
/// Note that if the initializer is `none`, then `dependencies` is empty.
initializer: ResultId,
};
/// Maps function result-id -> Fn information structure.
functions: std.AutoArrayHashMapUnmanaged(ResultId, Fn),
/// Set of OpFunction result-ids in this module.
entry_points: std.AutoArrayHashMapUnmanaged(ResultId, void),
/// For each function, a list of function result-ids that it calls.
callee_store: []const ResultId,
/// Maps each invocation global result-id to a type-id.
invocation_globals: std.AutoArrayHashMapUnmanaged(ResultId, InvocationGlobal),
/// Fetch the list of callees per function. Guaranteed to contain only unique IDs.
fn callees(self: ModuleInfo, fn_id: ResultId) []const ResultId {
const fn_index = self.functions.getIndex(fn_id).?;
const values = self.functions.values();
const first_callee = values[fn_index].first_callee;
if (fn_index == values.len - 1) {
return self.callee_store[first_callee..];
} else {
const next_first_callee = values[fn_index + 1].first_callee;
return self.callee_store[first_callee..next_first_callee];
}
}
/// Extract most of the required information from the binary. The remaining info is
/// constructed by `resolve()`.
fn parse(
arena: Allocator,
parser: *BinaryModule.Parser,
binary: BinaryModule,
) BinaryModule.ParseError!ModuleInfo {
var entry_points = std.AutoArrayHashMap(ResultId, void).init(arena);
var functions = std.AutoArrayHashMap(ResultId, Fn).init(arena);
var fn_types = std.AutoHashMap(ResultId, struct {
return_type: ResultId,
param_types: []const ResultId,
}).init(arena);
var calls = std.AutoArrayHashMap(ResultId, void).init(arena);
var callee_store = std.ArrayList(ResultId).init(arena);
var function_invocation_globals = std.AutoArrayHashMap(ResultId, void).init(arena);
var result_id_offsets = std.ArrayList(u16).init(arena);
var invocation_globals = std.AutoArrayHashMap(ResultId, InvocationGlobal).init(arena);
var maybe_current_function: ?ResultId = null;
var fn_ty_id: ResultId = undefined;
var it = binary.iterateInstructions();
while (it.next()) |inst| {
result_id_offsets.items.len = 0;
try parser.parseInstructionResultIds(binary, inst, &result_id_offsets);
switch (inst.opcode) {
.OpEntryPoint => {
const entry_point: ResultId = @enumFromInt(inst.operands[1]);
const entry = try entry_points.getOrPut(entry_point);
if (entry.found_existing) {
log.err("Entry point type {} has duplicate definition", .{entry_point});
return error.DuplicateId;
}
},
.OpTypeFunction => {
const fn_type: ResultId = @enumFromInt(inst.operands[0]);
const return_type: ResultId = @enumFromInt(inst.operands[1]);
const param_types: []const ResultId = @ptrCast(inst.operands[2..]);
const entry = try fn_types.getOrPut(fn_type);
if (entry.found_existing) {
log.err("Function type {} has duplicate definition", .{fn_type});
return error.DuplicateId;
}
entry.value_ptr.* = .{
.return_type = return_type,
.param_types = param_types,
};
},
.OpExtInst => {
// Note: format and set are already verified by parseInstructionResultIds().
const global_type: ResultId = @enumFromInt(inst.operands[0]);
const result_id: ResultId = @enumFromInt(inst.operands[1]);
const set_id: ResultId = @enumFromInt(inst.operands[2]);
const set_inst = inst.operands[3];
const set = binary.ext_inst_map.get(set_id).?;
if (set == .zig and set_inst == 0) {
const initializer: ResultId = if (inst.operands.len >= 5)
@enumFromInt(inst.operands[4])
else
.none;
try invocation_globals.put(result_id, .{
.dependencies = .{},
.ty = global_type,
.initializer = initializer,
});
}
},
.OpFunction => {
if (maybe_current_function) |current_function| {
log.err("OpFunction {} does not have an OpFunctionEnd", .{current_function});
return error.InvalidPhysicalFormat;
}
maybe_current_function = @enumFromInt(inst.operands[1]);
fn_ty_id = @enumFromInt(inst.operands[3]);
function_invocation_globals.clearRetainingCapacity();
},
.OpFunctionCall => {
const callee: ResultId = @enumFromInt(inst.operands[2]);
try calls.put(callee, {});
},
.OpFunctionEnd => {
const current_function = maybe_current_function orelse {
log.err("encountered OpFunctionEnd without corresponding OpFunction", .{});
return error.InvalidPhysicalFormat;
};
const entry = try functions.getOrPut(current_function);
if (entry.found_existing) {
log.err("Function {} has duplicate definition", .{current_function});
return error.DuplicateId;
}
const first_callee = callee_store.items.len;
try callee_store.appendSlice(calls.keys());
const fn_type = fn_types.get(fn_ty_id) orelse {
log.err("Function {} has invalid OpFunction type", .{current_function});
return error.InvalidId;
};
entry.value_ptr.* = .{
.first_callee = first_callee,
.return_type = fn_type.return_type,
.param_types = fn_type.param_types,
.invocation_globals = try function_invocation_globals.unmanaged.clone(arena),
};
maybe_current_function = null;
calls.clearRetainingCapacity();
},
else => {},
}
for (result_id_offsets.items) |off| {
const result_id: ResultId = @enumFromInt(inst.operands[off]);
if (invocation_globals.contains(result_id)) {
try function_invocation_globals.put(result_id, {});
}
}
}
if (maybe_current_function) |current_function| {
log.err("OpFunction {} does not have an OpFunctionEnd", .{current_function});
return error.InvalidPhysicalFormat;
}
return ModuleInfo{
.functions = functions.unmanaged,
.entry_points = entry_points.unmanaged,
.callee_store = callee_store.items,
.invocation_globals = invocation_globals.unmanaged,
};
}
/// Derive the remaining info from the structures filled in by parsing.
fn resolve(self: *ModuleInfo, arena: Allocator) !void {
try self.resolveInvocationGlobalUsage(arena);
try self.resolveInvocationGlobalDependencies(arena);
}
/// For each function, extend the list of `invocation_globals` with the
/// invocation globals that all of its callees (transitively) use.
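/// For example, if kernel %f calls %g and %g uses invocation global %x,
/// then after this step %x is also a member of %f's set.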
fn resolveInvocationGlobalUsage(self: *ModuleInfo, arena: Allocator) !void {
var seen = try std.DynamicBitSetUnmanaged.initEmpty(arena, self.functions.count());
for (self.functions.keys()) |id| {
try self.resolveInvocationGlobalUsageStep(arena, id, &seen);
}
}
fn resolveInvocationGlobalUsageStep(
self: *ModuleInfo,
arena: Allocator,
id: ResultId,
seen: *std.DynamicBitSetUnmanaged,
) !void {
const index = self.functions.getIndex(id) orelse {
log.err("function calls invalid function {}", .{id});
return error.InvalidId;
};
if (seen.isSet(index)) {
return;
}
seen.set(index);
const info = &self.functions.values()[index];
for (self.callees(id)) |callee| {
try self.resolveInvocationGlobalUsageStep(arena, callee, seen);
const callee_info = self.functions.get(callee).?;
for (callee_info.invocation_globals.keys()) |global| {
try info.invocation_globals.put(arena, global, {});
}
}
}
/// For each invocation global, populate and fully resolve the `dependencies` set.
/// This requires `resolveInvocationGlobalUsage()` to be already done.
fn resolveInvocationGlobalDependencies(
self: *ModuleInfo,
arena: Allocator,
) !void {
var seen = try std.DynamicBitSetUnmanaged.initEmpty(arena, self.invocation_globals.count());
for (self.invocation_globals.keys()) |id| {
try self.resolveInvocationGlobalDependenciesStep(arena, id, &seen);
}
}
fn resolveInvocationGlobalDependenciesStep(
self: *ModuleInfo,
arena: Allocator,
id: ResultId,
seen: *std.DynamicBitSetUnmanaged,
) !void {
const index = self.invocation_globals.getIndex(id) orelse {
log.err("invalid invocation global {}", .{id});
return error.InvalidId;
};
if (seen.isSet(index)) {
return;
}
seen.set(index);
const info = &self.invocation_globals.values()[index];
if (info.initializer == .none) {
return;
}
const initializer = self.functions.get(info.initializer) orelse {
log.err("invocation global {} has invalid initializer {}", .{ id, info.initializer });
return error.InvalidId;
};
for (initializer.invocation_globals.keys()) |dependency| {
if (dependency == id) {
// The initializer's set of used invocation globals includes the global
// that it initializes, so we need to skip that case.
continue;
}
try info.dependencies.put(arena, dependency, {});
try self.resolveInvocationGlobalDependenciesStep(arena, dependency, seen);
const dep_info = self.invocation_globals.getPtr(dependency).?;
for (dep_info.dependencies.keys()) |global| {
try info.dependencies.put(arena, global, {});
}
}
}
};
const ModuleBuilder = struct {
const FunctionType = struct {
return_type: ResultId,
param_types: []const ResultId,
const Context = struct {
pub fn hash(_: @This(), ty: FunctionType) u32 {
var hasher = std.hash.Wyhash.init(0);
hasher.update(std.mem.asBytes(&ty.return_type));
hasher.update(std.mem.sliceAsBytes(ty.param_types));
return @truncate(hasher.final());
}
pub fn eql(_: @This(), a: FunctionType, b: FunctionType, _: usize) bool {
if (a.return_type != b.return_type) return false;
return std.mem.eql(ResultId, a.param_types, b.param_types);
}
};
};
const FunctionNewInfo = struct {
/// This is here just so that we don't need to allocate the new
/// param_types multiple times.
new_function_type: ResultId,
/// The first ID of the parameters for the invocation globals.
/// Each global is allocated an id here according to its index in
/// `ModuleInfo.Fn.invocation_globals`.
global_id_base: u32,
fn invocationGlobalId(self: FunctionNewInfo, index: usize) ResultId {
return @enumFromInt(self.global_id_base + @as(u32, @intCast(index)));
}
};
arena: Allocator,
section: Section,
/// The ID bound of the new module.
id_bound: u32,
/// The first ID of the new entry points. Entry points are allocated from
/// here according to their index in `info.entry_points`.
entry_point_new_id_base: u32,
/// A set of all function types in the new program. SPIR-V mandates that these are unique,
/// and until a general type deduplication pass is implemented, we just deduplicate them here with this map.
function_types: std.ArrayHashMapUnmanaged(FunctionType, ResultId, FunctionType.Context, true) = .{},
/// Maps functions to new information required for creating the module
function_new_info: std.AutoArrayHashMapUnmanaged(ResultId, FunctionNewInfo) = .{},
/// Offset of the functions section in the new binary.
new_functions_section: ?usize,
fn init(arena: Allocator, binary: BinaryModule, info: ModuleInfo) !ModuleBuilder {
var self = ModuleBuilder{
.arena = arena,
.section = .{},
.id_bound = binary.id_bound,
.entry_point_new_id_base = undefined,
.new_functions_section = null,
};
self.entry_point_new_id_base = @intFromEnum(self.allocIds(@intCast(info.entry_points.count())));
return self;
}
fn allocId(self: *ModuleBuilder) ResultId {
return self.allocIds(1);
}
fn allocIds(self: *ModuleBuilder, n: u32) ResultId {
defer self.id_bound += n;
return @enumFromInt(self.id_bound);
}
fn finalize(self: *ModuleBuilder, a: Allocator, binary: *BinaryModule) !void {
binary.id_bound = self.id_bound;
binary.instructions = try a.dupe(Word, self.section.instructions.items);
// Nothing is removed in this pass so we don't need to change any of the maps,
// just make sure the section is updated.
binary.sections.functions = self.new_functions_section orelse binary.instructions.len;
}
/// Process everything from `binary` up to the first function and emit it into the builder.
fn processPreamble(self: *ModuleBuilder, binary: BinaryModule, info: ModuleInfo) !void {
var it = binary.iterateInstructions();
while (it.next()) |inst| {
switch (inst.opcode) {
.OpExtInst => {
const set_id: ResultId = @enumFromInt(inst.operands[2]);
const set_inst = inst.operands[3];
const set = binary.ext_inst_map.get(set_id).?;
if (set == .zig and set_inst == 0) {
continue;
}
},
.OpEntryPoint => {
const original_id: ResultId = @enumFromInt(inst.operands[1]);
const new_id_index = info.entry_points.getIndex(original_id).?;
const new_id: ResultId = @enumFromInt(self.entry_point_new_id_base + new_id_index);
try self.section.emitRaw(self.arena, .OpEntryPoint, inst.operands.len);
self.section.writeWord(inst.operands[0]);
self.section.writeOperand(ResultId, new_id);
self.section.writeWords(inst.operands[2..]);
continue;
},
.OpTypeFunction => {
// Re-emitted in `emitFunctionTypes()`. We can do this because
// an OpTypeFunction may not currently be used anywhere other than
// directly with an OpFunction. For now we ignore Intel's function
// pointers extension; that is not a problem with a generalized
// pass anyway.
continue;
},
.OpFunction => break,
else => {},
}
try self.section.emitRawInstruction(self.arena, inst.opcode, inst.operands);
}
}
/// Derive the new information required for further emitting this module.
fn deriveNewFnInfo(self: *ModuleBuilder, info: ModuleInfo) !void {
for (info.functions.keys(), info.functions.values()) |func, fn_info| {
const invocation_global_count = fn_info.invocation_globals.count();
const new_param_types = try self.arena.alloc(ResultId, fn_info.param_types.len + invocation_global_count);
for (fn_info.invocation_globals.keys(), 0..) |global, i| {
new_param_types[i] = info.invocation_globals.get(global).?.ty;
}
@memcpy(new_param_types[invocation_global_count..], fn_info.param_types);
const new_type = try self.internFunctionType(fn_info.return_type, new_param_types);
try self.function_new_info.put(self.arena, func, .{
.new_function_type = new_type,
.global_id_base = @intFromEnum(self.allocIds(@intCast(invocation_global_count))),
});
}
}
/// Emit the new function types, which include the parameters for the invocation globals.
/// Currently, this function re-emits ALL function types to ensure that there are
/// no duplicates in the final program.
/// TODO: The above should be resolved by a generalized deduplication pass, and then
/// we only need to emit the new function pointers type here.
fn emitFunctionTypes(self: *ModuleBuilder, info: ModuleInfo) !void {
// TODO: Handle decorators. Function types usually don't have those
// though, but stuff like OpName could be a possibility.
// Entry points retain their old function type, so make sure to emit
// those in the `function_types` set.
for (info.entry_points.keys()) |func| {
const fn_info = info.functions.get(func).?;
_ = try self.internFunctionType(fn_info.return_type, fn_info.param_types);
}
for (self.function_types.keys(), self.function_types.values()) |fn_type, result_id| {
try self.section.emit(self.arena, .OpTypeFunction, .{
.id_result = result_id,
.return_type = fn_type.return_type,
.id_ref_2 = fn_type.param_types,
});
}
}
fn internFunctionType(self: *ModuleBuilder, return_type: ResultId, param_types: []const ResultId) !ResultId {
const entry = try self.function_types.getOrPut(self.arena, .{
.return_type = return_type,
.param_types = param_types,
});
if (!entry.found_existing) {
const new_id = self.allocId();
entry.value_ptr.* = new_id;
}
return entry.value_ptr.*;
}
/// Rewrite the module's functions and emit them with the new parameter types.
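/// For example (a sketch): if %f uses invocation global %g, a call
///   %r = OpFunctionCall %T %f %a
/// in a rewritten caller becomes
///   %r = OpFunctionCall %T %f %g_param %a
/// where %g_param is the caller's own parameter that carries %g.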
fn rewriteFunctions(
self: *ModuleBuilder,
parser: *BinaryModule.Parser,
binary: BinaryModule,
info: ModuleInfo,
) !void {
var result_id_offsets = std.ArrayList(u16).init(self.arena);
var operands = std.ArrayList(u32).init(self.arena);
var maybe_current_function: ?ResultId = null;
var it = binary.iterateInstructionsFrom(binary.sections.functions);
self.new_functions_section = self.section.instructions.items.len;
while (it.next()) |inst| {
result_id_offsets.items.len = 0;
try parser.parseInstructionResultIds(binary, inst, &result_id_offsets);
operands.items.len = 0;
try operands.appendSlice(inst.operands);
// Replace the result-ids with the global's new result-id if required.
for (result_id_offsets.items) |off| {
const result_id: ResultId = @enumFromInt(operands.items[off]);
if (info.invocation_globals.contains(result_id)) {
const func = maybe_current_function.?;
const new_info = self.function_new_info.get(func).?;
const fn_info = info.functions.get(func).?;
const index = fn_info.invocation_globals.getIndex(result_id).?;
operands.items[off] = @intFromEnum(new_info.invocationGlobalId(index));
}
}
switch (inst.opcode) {
.OpFunction => {
// Re-declare the function with the new parameters.
const func: ResultId = @enumFromInt(operands.items[1]);
const fn_info = info.functions.get(func).?;
const new_info = self.function_new_info.get(func).?;
try self.section.emitRaw(self.arena, .OpFunction, 4);
self.section.writeOperand(ResultId, fn_info.return_type);
self.section.writeOperand(ResultId, func);
self.section.writeWord(operands.items[2]);
self.section.writeOperand(ResultId, new_info.new_function_type);
// Emit the OpFunctionParameters for the invocation globals. The function's
// actual parameters are emitted unchanged from their original form, so
// we don't need to handle those here.
for (fn_info.invocation_globals.keys(), 0..) |global, index| {
const ty = info.invocation_globals.get(global).?.ty;
const id = new_info.invocationGlobalId(index);
try self.section.emit(self.arena, .OpFunctionParameter, .{
.id_result_type = ty,
.id_result = id,
});
}
maybe_current_function = func;
},
.OpFunctionCall => {
// Add the required invocation globals to the function's new parameter list.
const caller = maybe_current_function.?;
const callee: ResultId = @enumFromInt(operands.items[2]);
const caller_info = info.functions.get(caller).?;
const callee_info = info.functions.get(callee).?;
const caller_new_info = self.function_new_info.get(caller).?;
const total_params = callee_info.invocation_globals.count() + callee_info.param_types.len;
try self.section.emitRaw(self.arena, .OpFunctionCall, 3 + total_params);
self.section.writeWord(operands.items[0]); // Copy result type-id
self.section.writeWord(operands.items[1]); // Copy result-id
self.section.writeOperand(ResultId, callee);
// Add the new arguments
for (callee_info.invocation_globals.keys()) |global| {
const caller_global_index = caller_info.invocation_globals.getIndex(global).?;
const id = caller_new_info.invocationGlobalId(caller_global_index);
self.section.writeOperand(ResultId, id);
}
// Add the original arguments
self.section.writeWords(operands.items[3..]);
},
else => {
try self.section.emitRawInstruction(self.arena, inst.opcode, operands.items);
},
}
}
}
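/// Emit a fresh wrapper function for each entry point. A sketch of the code
/// emitted below for one entry point:
///   %ep = OpFunction %ret_ty None %original_fn_ty
///   %params... = OpFunctionParameter (the kernel's original parameters)
///   OpLabel
///   %globals... = OpVariable Function (one per required invocation global)
///   OpFunctionCall each global's initializer, passing the globals it needs
///   OpFunctionCall the original kernel, passing %globals... and %params...
///   OpReturn
///   OpFunctionEnd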
fn emitNewEntryPoints(self: *ModuleBuilder, info: ModuleInfo) !void {
var all_function_invocation_globals = std.AutoArrayHashMap(ResultId, void).init(self.arena);
for (info.entry_points.keys(), 0..) |func, entry_point_index| {
const fn_info = info.functions.get(func).?;
const ep_id: ResultId = @enumFromInt(self.entry_point_new_id_base + @as(u32, @intCast(entry_point_index)));
const fn_type = self.function_types.get(.{
.return_type = fn_info.return_type,
.param_types = fn_info.param_types,
}).?;
try self.section.emit(self.arena, .OpFunction, .{
.id_result_type = fn_info.return_type,
.id_result = ep_id,
.function_control = .{}, // TODO: Copy the attributes from the original function maybe?
.function_type = fn_type,
});
// Emit OpFunctionParameter instructions for the original kernel's parameters.
const params_id_base: u32 = @intFromEnum(self.allocIds(@intCast(fn_info.param_types.len)));
for (fn_info.param_types, 0..) |param_type, i| {
const id: ResultId = @enumFromInt(params_id_base + @as(u32, @intCast(i)));
try self.section.emit(self.arena, .OpFunctionParameter, .{
.id_result_type = param_type,
.id_result = id,
});
}
try self.section.emit(self.arena, .OpLabel, .{
.id_result = self.allocId(),
});
// Besides the IDs of the main kernel, we also need the
// dependencies of the globals.
// Just quickly construct that set here.
all_function_invocation_globals.clearRetainingCapacity();
for (fn_info.invocation_globals.keys()) |global| {
try all_function_invocation_globals.put(global, {});
const global_info = info.invocation_globals.get(global).?;
for (global_info.dependencies.keys()) |dependency| {
try all_function_invocation_globals.put(dependency, {});
}
}
// Declare the IDs of the invocation globals.
const global_id_base: u32 = @intFromEnum(self.allocIds(@intCast(all_function_invocation_globals.count())));
for (all_function_invocation_globals.keys(), 0..) |global, i| {
const global_info = info.invocation_globals.get(global).?;
const id: ResultId = @enumFromInt(global_id_base + @as(u32, @intCast(i)));
try self.section.emit(self.arena, .OpVariable, .{
.id_result_type = global_info.ty,
.id_result = id,
.storage_class = .Function,
.initializer = null,
});
}
// Call initializers for invocation globals that need it
for (all_function_invocation_globals.keys()) |global| {
const global_info = info.invocation_globals.get(global).?;
if (global_info.initializer == .none) continue;
const initializer_info = info.functions.get(global_info.initializer).?;
assert(initializer_info.param_types.len == 0);
try self.callWithGlobalsAndLinearParams(
all_function_invocation_globals,
global_info.initializer,
initializer_info,
global_id_base,
undefined,
);
}
// Call the main kernel entry
try self.callWithGlobalsAndLinearParams(
all_function_invocation_globals,
func,
fn_info,
global_id_base,
params_id_base,
);
try self.section.emit(self.arena, .OpReturn, {});
try self.section.emit(self.arena, .OpFunctionEnd, {});
}
}
fn callWithGlobalsAndLinearParams(
self: *ModuleBuilder,
all_globals: std.AutoArrayHashMap(ResultId, void),
func: ResultId,
callee_info: ModuleInfo.Fn,
global_id_base: u32,
params_id_base: u32,
) !void {
const total_arguments = callee_info.invocation_globals.count() + callee_info.param_types.len;
try self.section.emitRaw(self.arena, .OpFunctionCall, 3 + total_arguments);
self.section.writeOperand(ResultId, callee_info.return_type);
self.section.writeOperand(ResultId, self.allocId());
self.section.writeOperand(ResultId, func);
// Add the invocation globals
for (callee_info.invocation_globals.keys()) |global| {
const index = all_globals.getIndex(global).?;
const id: ResultId = @enumFromInt(global_id_base + @as(u32, @intCast(index)));
self.section.writeOperand(ResultId, id);
}
// Add the arguments
for (0..callee_info.param_types.len) |index| {
const id: ResultId = @enumFromInt(params_id_base + @as(u32, @intCast(index)));
self.section.writeOperand(ResultId, id);
}
}
};
pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void {
var arena = std.heap.ArenaAllocator.init(parser.a);
defer arena.deinit();
const a = arena.allocator();
var info = try ModuleInfo.parse(a, parser, binary.*);
try info.resolve(a);
var builder = try ModuleBuilder.init(a, binary.*, info);
try builder.deriveNewFnInfo(info);
try builder.processPreamble(binary.*, info);
try builder.emitFunctionTypes(info);
try builder.rewriteFunctions(parser, binary.*, info);
try builder.emitNewEntryPoints(info);
try builder.finalize(parser.a, binary);
}

View File

@ -0,0 +1,354 @@
//! This pass performs simple pruning of unused things:
//! - Instructions at global scope
//! - Functions
//! Debug info and nonsemantic instructions are not handled;
//! this pass is mainly intended for cleaning up leftover
//! stuff from codegen and other passes that is generated
//! but not actually used.
const std = @import("std");
const Allocator = std.mem.Allocator;
const assert = std.debug.assert;
const log = std.log.scoped(.spirv_link);
const BinaryModule = @import("BinaryModule.zig");
const Section = @import("../../codegen/spirv/Section.zig");
const spec = @import("../../codegen/spirv/spec.zig");
const Opcode = spec.Opcode;
const ResultId = spec.IdResult;
const Word = spec.Word;
/// Return whether a particular opcode's instruction can be pruned.
/// These are idempotent instructions at global scope and instructions
/// within functions that do not have any side effects.
/// The opcodes that return true here do not necessarily need to
/// have an .IdResult. If they don't, then they are regarded
/// as 'decoration'-style instructions that don't keep their
/// operands alive, but that are still emitted if their operands are.
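/// For example, canPrune(.OpConstant) is true (an unused constant is simply
/// dropped), while canPrune(.OpStore) is false (a store has side effects and
/// also keeps its operands alive).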
fn canPrune(op: Opcode) bool {
// This list should eventually be as complete as possible, but just
// covering the common instructions gives a good effort/effect ratio.
// When adding items to this list, also check whether the
// instruction requires any special control flow rules (like
// with labels and control flow and stuff) and whether the
// instruction has any non-trivial side effects (like OpLoad
// with the Volatile memory semantics).
return switch (op.class()) {
.TypeDeclaration,
.Conversion,
.Arithmetic,
.RelationalAndLogical,
.Bit,
=> true,
else => switch (op) {
.OpFunction,
.OpUndef,
.OpString,
.OpName,
.OpMemberName,
// Prune OpConstant* instructions but
// retain OpSpecConstant declaration instructions
.OpConstantTrue,
.OpConstantFalse,
.OpConstant,
.OpConstantComposite,
.OpConstantSampler,
.OpConstantNull,
.OpSpecConstantOp,
// Prune ext inst import instructions, but not
// ext inst instructions themselves, because
// we don't know if they might have side effects.
.OpExtInstImport,
=> true,
else => false,
},
};
}
const ModuleInfo = struct {
const Fn = struct {
/// The index of the first callee in `callee_store`.
first_callee: usize,
};
/// Maps function result-id -> Fn information structure.
functions: std.AutoArrayHashMapUnmanaged(ResultId, Fn),
/// For each function, a list of function result-ids that it calls.
callee_store: []const ResultId,
/// Maps each result-id to the offset at which its defining instruction appears in the source module.
result_id_to_code_offset: std.AutoArrayHashMapUnmanaged(ResultId, usize),
/// Fetch the list of callees per function. Guaranteed to contain only unique IDs.
fn callees(self: ModuleInfo, fn_id: ResultId) []const ResultId {
const fn_index = self.functions.getIndex(fn_id).?;
const values = self.functions.values();
const first_callee = values[fn_index].first_callee;
if (fn_index == values.len - 1) {
return self.callee_store[first_callee..];
} else {
const next_first_callee = values[fn_index + 1].first_callee;
return self.callee_store[first_callee..next_first_callee];
}
}
/// Extract the information required to run this pass from the binary.
// TODO: Should the contents of this function be merged with that of lower_invocation_globals.zig?
// Many of the contents are the same...
fn parse(
arena: Allocator,
parser: *BinaryModule.Parser,
binary: BinaryModule,
) !ModuleInfo {
var functions = std.AutoArrayHashMap(ResultId, Fn).init(arena);
var calls = std.AutoArrayHashMap(ResultId, void).init(arena);
var callee_store = std.ArrayList(ResultId).init(arena);
var result_id_to_code_offset = std.AutoArrayHashMap(ResultId, usize).init(arena);
var maybe_current_function: ?ResultId = null;
var it = binary.iterateInstructions();
while (it.next()) |inst| {
const inst_spec = parser.getInstSpec(inst.opcode).?;
// Result-id can only be the first or second operand
const maybe_result_id: ?ResultId = for (0..2) |i| {
if (inst_spec.operands.len > i and inst_spec.operands[i].kind == .IdResult) {
break @enumFromInt(inst.operands[i]);
}
} else null;
// Only add result-ids of functions and anything outside a function.
// Result-ids declared inside functions cannot be reached from outside them,
// and we don't care about the internals of functions here anyway.
// Note that in the case of OpFunction, `maybe_current_function` is
// also `null`, because it is set below.
if (maybe_result_id) |result_id| {
try result_id_to_code_offset.put(result_id, inst.offset);
}
switch (inst.opcode) {
.OpFunction => {
if (maybe_current_function) |current_function| {
log.err("OpFunction {} does not have an OpFunctionEnd", .{current_function});
return error.InvalidPhysicalFormat;
}
maybe_current_function = @enumFromInt(inst.operands[1]);
},
.OpFunctionCall => {
const callee: ResultId = @enumFromInt(inst.operands[2]);
try calls.put(callee, {});
},
.OpFunctionEnd => {
const current_function = maybe_current_function orelse {
log.err("encountered OpFunctionEnd without corresponding OpFunction", .{});
return error.InvalidPhysicalFormat;
};
const entry = try functions.getOrPut(current_function);
if (entry.found_existing) {
log.err("Function {} has duplicate definition", .{current_function});
return error.DuplicateId;
}
const first_callee = callee_store.items.len;
try callee_store.appendSlice(calls.keys());
entry.value_ptr.* = .{
.first_callee = first_callee,
};
maybe_current_function = null;
calls.clearRetainingCapacity();
},
else => {},
}
}
if (maybe_current_function) |current_function| {
log.err("OpFunction {} does not have an OpFunctionEnd", .{current_function});
return error.InvalidPhysicalFormat;
}
return ModuleInfo{
.functions = functions.unmanaged,
.callee_store = callee_store.items,
.result_id_to_code_offset = result_id_to_code_offset.unmanaged,
};
}
};
const AliveMarker = struct {
parser: *BinaryModule.Parser,
binary: BinaryModule,
info: ModuleInfo,
result_id_offsets: std.ArrayList(u16),
alive: std.DynamicBitSetUnmanaged,
fn markAlive(self: *AliveMarker, result_id: ResultId) BinaryModule.ParseError!void {
const index = self.info.result_id_to_code_offset.getIndex(result_id) orelse {
log.err("undefined result-id {}", .{result_id});
return error.InvalidId;
};
if (self.alive.isSet(index)) {
return;
}
self.alive.set(index);
const offset = self.info.result_id_to_code_offset.values()[index];
const inst = self.binary.instructionAt(offset);
if (inst.opcode == .OpFunction) {
try self.markFunctionAlive(inst);
} else {
try self.markInstructionAlive(inst);
}
}
fn markFunctionAlive(
self: *AliveMarker,
func_inst: BinaryModule.Instruction,
) !void {
// Go through the function body and mark the
// operands of each instruction alive.
var it = self.binary.iterateInstructionsFrom(func_inst.offset);
try self.markInstructionAlive(it.next().?);
while (it.next()) |inst| {
if (inst.opcode == .OpFunctionEnd) {
break;
}
if (!canPrune(inst.opcode)) {
try self.markInstructionAlive(inst);
}
}
}
fn markInstructionAlive(
self: *AliveMarker,
inst: BinaryModule.Instruction,
) !void {
const start_offset = self.result_id_offsets.items.len;
try self.parser.parseInstructionResultIds(self.binary, inst, &self.result_id_offsets);
const end_offset = self.result_id_offsets.items.len;
// The recursive markAlive() calls below may append to self.result_id_offsets
// and reallocate its buffer, so iterate by index rather than over a slice.
var i = start_offset;
while (i < end_offset) : (i += 1) {
const offset = self.result_id_offsets.items[i];
try self.markAlive(@enumFromInt(inst.operands[offset]));
}
}
};
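// Illustrative note: markAlive(), markFunctionAlive() and markInstructionAlive()
// are mutually recursive. Marking an instruction alive marks every result-id it
// references, and marking a function's result-id alive pulls in its whole body;
// the `alive` bit set doubles as the visited set that terminates the recursion.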
fn removeIdsFromMap(a: Allocator, map: anytype, info: ModuleInfo, alive_marker: AliveMarker) !void {
var to_remove = std.ArrayList(ResultId).init(a);
var it = map.iterator();
while (it.next()) |entry| {
const id = entry.key_ptr.*;
const index = info.result_id_to_code_offset.getIndex(id).?;
if (!alive_marker.alive.isSet(index)) {
try to_remove.append(id);
}
}
for (to_remove.items) |id| {
assert(map.remove(id));
}
}
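// Illustrative note: removal is two-phase (collect the dead ids, then remove
// them) because removing entries while iterating the map would invalidate the
// iterator.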
pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void {
var arena = std.heap.ArenaAllocator.init(parser.a);
defer arena.deinit();
const a = arena.allocator();
const info = try ModuleInfo.parse(a, parser, binary.*);
var alive_marker = AliveMarker{
.parser = parser,
.binary = binary.*,
.info = info,
.result_id_offsets = std.ArrayList(u16).init(a),
.alive = try std.DynamicBitSetUnmanaged.initEmpty(a, info.result_id_to_code_offset.count()),
};
// Mark initial stuff as alive
{
var it = binary.iterateInstructions();
while (it.next()) |inst| {
if (inst.opcode == .OpFunction) {
// No need to process further.
break;
} else if (!canPrune(inst.opcode)) {
try alive_marker.markInstructionAlive(inst);
}
}
}
var section = Section{};
var new_functions_section: ?usize = null;
var it = binary.iterateInstructions();
skip: while (it.next()) |inst| {
const inst_spec = parser.getInstSpec(inst.opcode).?;
reemit: {
if (!canPrune(inst.opcode)) {
break :reemit;
}
// Result-id can only be the first or second operand
const result_id: ResultId = for (0..2) |i| {
if (inst_spec.operands.len > i and inst_spec.operands[i].kind == .IdResult) {
break @enumFromInt(inst.operands[i]);
}
} else {
// Instruction can be pruned but doesn't have a result id.
// Check all operands to see if they are alive, and emit it only if so.
alive_marker.result_id_offsets.items.len = 0;
try parser.parseInstructionResultIds(binary.*, inst, &alive_marker.result_id_offsets);
for (alive_marker.result_id_offsets.items) |offset| {
const id: ResultId = @enumFromInt(inst.operands[offset]);
const index = info.result_id_to_code_offset.getIndex(id).?;
if (!alive_marker.alive.isSet(index)) {
continue :skip;
}
}
break :reemit;
};
const index = info.result_id_to_code_offset.getIndex(result_id).?;
if (alive_marker.alive.isSet(index)) {
break :reemit;
}
if (inst.opcode != .OpFunction) {
// Instruction can be pruned and it's not alive, so skip it.
continue :skip;
}
// We're at the start of a function that can be pruned, so skip everything until
// we encounter an OpFunctionEnd.
while (it.next()) |body_inst| {
if (body_inst.opcode == .OpFunctionEnd)
break;
}
continue :skip;
}
if (inst.opcode == .OpFunction and new_functions_section == null) {
new_functions_section = section.instructions.items.len;
}
try section.emitRawInstruction(a, inst.opcode, inst.operands);
}
// This pass might have pruned ext inst imports or arith types; update
// those maps to maintain consistency.
try removeIdsFromMap(a, &binary.ext_inst_map, info, alive_marker);
try removeIdsFromMap(a, &binary.arith_type_width, info, alive_marker);
binary.instructions = try parser.a.dupe(Word, section.toWords());
binary.sections.functions = new_functions_section orelse binary.instructions.len;
}
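// Illustrative only (the construction of `parser` and `binary` is hypothetical
// here; it lives elsewhere in the self-hosted linker): the pass rewrites the
// module in place, so it composes with the other binary passes:
//   try run(&parser, &binary);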

View File

@ -18,6 +18,7 @@ test "global variable alignment" {
test "large alignment of local constant" { test "large alignment of local constant" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // flaky
const x: f32 align(128) = 12.34; const x: f32 align(128) = 12.34;
try std.testing.expect(@intFromPtr(&x) % 128 == 0); try std.testing.expect(@intFromPtr(&x) % 128 == 0);

View File

@ -757,6 +757,7 @@ test "extern variable with non-pointer opaque type" {
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO
@export(var_to_export, .{ .name = "opaque_extern_var" }); @export(var_to_export, .{ .name = "opaque_extern_var" });
try expect(@as(*align(1) u32, @ptrCast(&opaque_extern_var)).* == 42); try expect(@as(*align(1) u32, @ptrCast(&opaque_extern_var)).* == 42);

View File

@ -5,7 +5,6 @@ var result: []const u8 = "wrong";
test "pass string literal byvalue to a generic var param" { test "pass string literal byvalue to a generic var param" {
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
start(); start();
blowUpStack(10); blowUpStack(10);

View File

@ -1252,7 +1252,6 @@ test "implicit cast from *T to ?*anyopaque" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
var a: u8 = 1; var a: u8 = 1;
incrementVoidPtrValue(&a); incrementVoidPtrValue(&a);
@ -2035,6 +2034,8 @@ test "peer type resolution: tuple pointer and optional slice" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
// Miscompilation on Intel's OpenCL CPU runtime.
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // flaky
var a: ?[:0]const u8 = null; var a: ?[:0]const u8 = null;
var b = &.{ @as(u8, 'x'), @as(u8, 'y'), @as(u8, 'z') }; var b = &.{ @as(u8, 'x'), @as(u8, 'y'), @as(u8, 'z') };

View File

@ -124,7 +124,6 @@ test "debug info for optional error set" {
test "implicit cast to optional to error union to return result loc" { test "implicit cast to optional to error union to return result loc" {
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
const S = struct { const S = struct {
fn entry() !void { fn entry() !void {
@ -951,7 +950,6 @@ test "returning an error union containing a type with no runtime bits" {
test "try used in recursive function with inferred error set" { test "try used in recursive function with inferred error set" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO
const Value = union(enum) { const Value = union(enum) {
values: []const @This(), values: []const @This(),

View File

@ -127,7 +127,6 @@ test "cmp f16" {
test "cmp f32/f64" { test "cmp f32/f64" {
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest; if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
try testCmp(f32); try testCmp(f32);
try comptime testCmp(f32); try comptime testCmp(f32);
@ -979,7 +978,6 @@ test "@abs f32/f64" {
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
try testFabs(f32); try testFabs(f32);
try comptime testFabs(f32); try comptime testFabs(f32);

View File

@ -28,7 +28,6 @@ pub const EmptyStruct = struct {};
test "optional pointer to size zero struct" { test "optional pointer to size zero struct" {
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
var e = EmptyStruct{}; var e = EmptyStruct{};
const o: ?*EmptyStruct = &e; const o: ?*EmptyStruct = &e;
@ -36,8 +35,6 @@ test "optional pointer to size zero struct" {
} }
test "equality compare optional pointers" { test "equality compare optional pointers" {
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
try testNullPtrsEql(); try testNullPtrsEql();
try comptime testNullPtrsEql(); try comptime testNullPtrsEql();
} }

View File

@ -216,7 +216,6 @@ test "assign null directly to C pointer and test null equality" {
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
var x: [*c]i32 = null; var x: [*c]i32 = null;
_ = &x; _ = &x;

View File

@ -372,7 +372,6 @@ test "load vector elements via comptime index" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
const S = struct { const S = struct {
fn doTheTest() !void { fn doTheTest() !void {
@ -394,7 +393,6 @@ test "store vector elements via comptime index" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
const S = struct { const S = struct {
fn doTheTest() !void { fn doTheTest() !void {

View File

@ -38,8 +38,6 @@ fn staticWhileLoop2() i32 {
} }
test "while with continue expression" { test "while with continue expression" {
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO
var sum: i32 = 0; var sum: i32 = 0;
{ {
var i: i32 = 0; var i: i32 = 0;

View File

@ -1,45 +1,138 @@
const std = @import("std"); const std = @import("std");
const g = @import("spirv/grammar.zig");
const Allocator = std.mem.Allocator; const Allocator = std.mem.Allocator;
const g = @import("spirv/grammar.zig");
const CoreRegistry = g.CoreRegistry;
const ExtensionRegistry = g.ExtensionRegistry;
const Instruction = g.Instruction;
const OperandKind = g.OperandKind;
const Enumerant = g.Enumerant;
const Operand = g.Operand;
const ExtendedStructSet = std.StringHashMap(void); const ExtendedStructSet = std.StringHashMap(void);
const Extension = struct {
name: []const u8,
spec: ExtensionRegistry,
};
const CmpInst = struct {
fn lt(_: CmpInst, a: Instruction, b: Instruction) bool {
return a.opcode < b.opcode;
}
};
const StringPair = struct { []const u8, []const u8 };
const StringPairContext = struct {
pub fn hash(_: @This(), a: StringPair) u32 {
var hasher = std.hash.Wyhash.init(0);
const x, const y = a;
hasher.update(x);
hasher.update(y);
return @truncate(hasher.final());
}
pub fn eql(_: @This(), a: StringPair, b: StringPair, b_index: usize) bool {
_ = b_index;
const a_x, const a_y = a;
const b_x, const b_y = b;
return std.mem.eql(u8, a_x, b_x) and std.mem.eql(u8, a_y, b_y);
}
};
const OperandKindMap = std.ArrayHashMap(StringPair, OperandKind, StringPairContext, true);
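// Illustrative only (the value names are hypothetical): keys pair an instruction
// set name with a kind name, so identically named kinds from different extensions
// can coexist in one map:
//   try all_operand_kinds.putNoClobber(.{ "core", "Decoration" }, core_decoration);
//   try all_operand_kinds.putNoClobber(.{ "DebugInfo", "DebugInfoFlags" }, ext_flags);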
/// Khronos does not define these names explicitly anywhere, so
/// we need to hardcode them (like the official tooling does).
/// See https://github.com/KhronosGroup/SPIRV-Registry/
const set_names = std.ComptimeStringMap([]const u8, .{
.{ "opencl.std.100", "OpenCL.std" },
.{ "glsl.std.450", "GLSL.std.450" },
.{ "opencl.debuginfo.100", "OpenCL.DebugInfo.100" },
.{ "spv-amd-shader-ballot", "SPV_AMD_shader_ballot" },
.{ "nonsemantic.shader.debuginfo.100", "NonSemantic.Shader.DebugInfo.100" },
.{ "nonsemantic.vkspreflection", "NonSemantic.VkspReflection" },
.{ "nonsemantic.clspvreflection", "NonSemantic.ClspvReflection.6" }, // This version needs to be handled manually
.{ "spv-amd-gcn-shader", "SPV_AMD_gcn_shader" },
.{ "spv-amd-shader-trinary-minmax", "SPV_AMD_shader_trinary_minmax" },
.{ "debuginfo", "DebugInfo" },
.{ "nonsemantic.debugprintf", "NonSemantic.DebugPrintf" },
.{ "spv-amd-shader-explicit-vertex-parameter", "SPV_AMD_shader_explicit_vertex_parameter" },
.{ "nonsemantic.debugbreak", "NonSemantic.DebugBreak" },
.{ "zig", "zig" },
});
pub fn main() !void { pub fn main() !void {
var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator); var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
defer arena.deinit(); defer arena.deinit();
const allocator = arena.allocator(); const a = arena.allocator();
const args = try std.process.argsAlloc(allocator); const args = try std.process.argsAlloc(a);
if (args.len != 2) { if (args.len != 3) {
usageAndExit(std.io.getStdErr(), args[0], 1); usageAndExit(args[0], 1);
} }
const spec_path = args[1]; const json_path = try std.fs.path.join(a, &.{ args[1], "include/spirv/unified1/" });
const spec = try std.fs.cwd().readFileAlloc(allocator, spec_path, std.math.maxInt(usize)); const dir = try std.fs.cwd().openDir(json_path, .{ .iterate = true });
const core_spec = try readRegistry(CoreRegistry, a, dir, "spirv.core.grammar.json");
std.sort.block(Instruction, core_spec.instructions, CmpInst{}, CmpInst.lt);
var exts = std.ArrayList(Extension).init(a);
var it = dir.iterate();
while (try it.next()) |entry| {
if (entry.kind != .file) {
continue;
}
try readExtRegistry(&exts, a, dir, entry.name);
}
try readExtRegistry(&exts, a, std.fs.cwd(), args[2]);
var bw = std.io.bufferedWriter(std.io.getStdOut().writer());
try render(bw.writer(), a, core_spec, exts.items);
try bw.flush();
}
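// Illustrative only (paths hypothetical): the generator reads the grammar JSON
// files from a SPIRV-Headers checkout plus Zig's own extinst grammar, and prints
// the generated spec to stdout:
//   zig run tools/gen_spirv_spec.zig -- ../SPIRV-Headers src/codegen/spirv/extinst.zig.grammar.json > spec.zig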
fn readExtRegistry(exts: *std.ArrayList(Extension), a: Allocator, dir: std.fs.Dir, sub_path: []const u8) !void {
const filename = std.fs.path.basename(sub_path);
if (!std.mem.startsWith(u8, filename, "extinst.")) {
return;
}
std.debug.assert(std.mem.endsWith(u8, filename, ".grammar.json"));
const name = filename["extinst.".len .. filename.len - ".grammar.json".len];
const spec = try readRegistry(ExtensionRegistry, a, dir, sub_path);
std.sort.block(Instruction, spec.instructions, CmpInst{}, CmpInst.lt);
try exts.append(.{ .name = set_names.get(name).?, .spec = spec });
}
fn readRegistry(comptime RegistryType: type, a: Allocator, dir: std.fs.Dir, path: []const u8) !RegistryType {
const spec = try dir.readFileAlloc(a, path, std.math.maxInt(usize));
// Required for json parsing. // Required for json parsing.
@setEvalBranchQuota(10000); @setEvalBranchQuota(10000);
var scanner = std.json.Scanner.initCompleteInput(allocator, spec); var scanner = std.json.Scanner.initCompleteInput(a, spec);
var diagnostics = std.json.Diagnostics{}; var diagnostics = std.json.Diagnostics{};
scanner.enableDiagnostics(&diagnostics); scanner.enableDiagnostics(&diagnostics);
const parsed = std.json.parseFromTokenSource(g.CoreRegistry, allocator, &scanner, .{}) catch |err| { const parsed = std.json.parseFromTokenSource(RegistryType, a, &scanner, .{}) catch |err| {
std.debug.print("line,col: {},{}\n", .{ diagnostics.getLine(), diagnostics.getColumn() }); std.debug.print("{s}:{}:{}:\n", .{ path, diagnostics.getLine(), diagnostics.getColumn() });
return err; return err;
}; };
return parsed.value;
var bw = std.io.bufferedWriter(std.io.getStdOut().writer());
try render(bw.writer(), allocator, parsed.value);
try bw.flush();
} }
/// Returns a set with types that require an extra struct for the `Instruction` interface /// Returns a set with types that require an extra struct for the `Instruction` interface
/// to the spir-v spec, or whether the original type can be used. /// to the spir-v spec, or whether the original type can be used.
fn extendedStructs( fn extendedStructs(
arena: Allocator, a: Allocator,
kinds: []const g.OperandKind, kinds: []const OperandKind,
) !ExtendedStructSet { ) !ExtendedStructSet {
var map = ExtendedStructSet.init(arena); var map = ExtendedStructSet.init(a);
try map.ensureTotalCapacity(@as(u32, @intCast(kinds.len))); try map.ensureTotalCapacity(@as(u32, @intCast(kinds.len)));
for (kinds) |kind| { for (kinds) |kind| {
@ -73,10 +166,12 @@ fn tagPriorityScore(tag: []const u8) usize {
} }
} }
fn render(writer: anytype, allocator: Allocator, registry: g.CoreRegistry) !void { fn render(writer: anytype, a: Allocator, registry: CoreRegistry, extensions: []const Extension) !void {
try writer.writeAll( try writer.writeAll(
\\//! This file is auto-generated by tools/gen_spirv_spec.zig. \\//! This file is auto-generated by tools/gen_spirv_spec.zig.
\\ \\
\\const std = @import("std");
\\
\\pub const Version = packed struct(Word) { \\pub const Version = packed struct(Word) {
\\ padding: u8 = 0, \\ padding: u8 = 0,
\\ minor: u8, \\ minor: u8,
@ -89,8 +184,21 @@ fn render(writer: anytype, allocator: Allocator, registry: g.CoreRegistry) !void
\\}; \\};
\\ \\
\\pub const Word = u32; \\pub const Word = u32;
\\pub const IdResult = struct{ \\pub const IdResult = enum(Word) {
\\ id: Word, \\ none,
\\ _,
\\
\\ pub fn format(
\\ self: IdResult,
\\ comptime _: []const u8,
\\ _: std.fmt.FormatOptions,
\\ writer: anytype,
\\ ) @TypeOf(writer).Error!void {
\\ switch (self) {
\\ .none => try writer.writeAll("(none)"),
\\ else => try writer.print("%{}", .{@intFromEnum(self)}),
\\ }
\\ }
\\}; \\};
\\pub const IdResultType = IdResult; \\pub const IdResultType = IdResult;
\\pub const IdRef = IdResult; \\pub const IdRef = IdResult;
@ -99,6 +207,7 @@ fn render(writer: anytype, allocator: Allocator, registry: g.CoreRegistry) !void
\\pub const IdScope = IdRef; \\pub const IdScope = IdRef;
\\ \\
\\pub const LiteralInteger = Word; \\pub const LiteralInteger = Word;
\\pub const LiteralFloat = Word;
\\pub const LiteralString = []const u8; \\pub const LiteralString = []const u8;
\\pub const LiteralContextDependentNumber = union(enum) { \\pub const LiteralContextDependentNumber = union(enum) {
\\ int32: i32, \\ int32: i32,
@ -139,6 +248,13 @@ fn render(writer: anytype, allocator: Allocator, registry: g.CoreRegistry) !void
\\ parameters: []const OperandKind, \\ parameters: []const OperandKind,
\\}; \\};
\\ \\
\\pub const Instruction = struct {
\\ name: []const u8,
\\ opcode: Word,
\\ operands: []const Operand,
\\};
\\
\\pub const zig_generator_id: Word = 41;
\\ \\
); );
@ -151,15 +267,123 @@ fn render(writer: anytype, allocator: Allocator, registry: g.CoreRegistry) !void
.{ registry.major_version, registry.minor_version, registry.revision, registry.magic_number }, .{ registry.major_version, registry.minor_version, registry.revision, registry.magic_number },
); );
const extended_structs = try extendedStructs(allocator, registry.operand_kinds); // Merge the operand kinds from all extensions together.
try renderClass(writer, allocator, registry.instructions); // var all_operand_kinds = std.ArrayList(OperandKind).init(a);
try renderOperandKind(writer, registry.operand_kinds); // try all_operand_kinds.appendSlice(registry.operand_kinds);
try renderOpcodes(writer, allocator, registry.instructions, extended_structs); var all_operand_kinds = OperandKindMap.init(a);
try renderOperandKinds(writer, allocator, registry.operand_kinds, extended_structs); for (registry.operand_kinds) |kind| {
try all_operand_kinds.putNoClobber(.{ "core", kind.kind }, kind);
}
for (extensions) |ext| {
// Note: extensions may define the same operand kind, with different
// parameters. Instead of trying to merge them, just discriminate them
// using the name of the extension. This is similar to what
// the official headers do.
try all_operand_kinds.ensureUnusedCapacity(ext.spec.operand_kinds.len);
for (ext.spec.operand_kinds) |kind| {
var new_kind = kind;
new_kind.kind = try std.mem.join(a, ".", &.{ ext.name, kind.kind });
try all_operand_kinds.putNoClobber(.{ ext.name, kind.kind }, new_kind);
}
}
const extended_structs = try extendedStructs(a, all_operand_kinds.values());
// Note: extensions don't seem to have class.
try renderClass(writer, a, registry.instructions);
try renderOperandKind(writer, all_operand_kinds.values());
try renderOpcodes(writer, a, registry.instructions, extended_structs);
try renderOperandKinds(writer, a, all_operand_kinds.values(), extended_structs);
try renderInstructionSet(writer, a, registry, extensions, all_operand_kinds);
} }
fn renderClass(writer: anytype, allocator: Allocator, instructions: []const g.Instruction) !void { fn renderInstructionSet(
var class_map = std.StringArrayHashMap(void).init(allocator); writer: anytype,
a: Allocator,
core: CoreRegistry,
extensions: []const Extension,
all_operand_kinds: OperandKindMap,
) !void {
_ = a;
try writer.writeAll(
\\pub const InstructionSet = enum {
\\ core,
);
for (extensions) |ext| {
try writer.print("{},\n", .{std.zig.fmtId(ext.name)});
}
try writer.writeAll(
\\
\\ pub fn instructions(self: InstructionSet) []const Instruction {
\\ return switch (self) {
\\
);
try renderInstructionsCase(writer, "core", core.instructions, all_operand_kinds);
for (extensions) |ext| {
try renderInstructionsCase(writer, ext.name, ext.spec.instructions, all_operand_kinds);
}
try writer.writeAll(
\\ };
\\ }
\\};
\\
);
}
fn renderInstructionsCase(
writer: anytype,
set_name: []const u8,
instructions: []const Instruction,
all_operand_kinds: OperandKindMap,
) !void {
// Note: theoretically we could deduplicate the instructions and give every instruction
// a list of aliases, but there aren't many aliases in total and that would add more
// overhead overall. We will just filter the aliases out when needed.
try writer.print(".{} => &[_]Instruction{{\n", .{std.zig.fmtId(set_name)});
for (instructions) |inst| {
try writer.print(
\\.{{
\\ .name = "{s}",
\\ .opcode = {},
\\ .operands = &[_]Operand{{
\\
, .{ inst.opname, inst.opcode });
for (inst.operands) |operand| {
const quantifier = if (operand.quantifier) |q|
switch (q) {
.@"?" => "optional",
.@"*" => "variadic",
}
else
"required";
const kind = all_operand_kinds.get(.{ set_name, operand.kind }) orelse
all_operand_kinds.get(.{ "core", operand.kind }).?;
try writer.print(".{{.kind = .{}, .quantifier = .{s}}},\n", .{ std.zig.fmtId(kind.kind), quantifier });
}
try writer.writeAll(
\\ },
\\},
\\
);
}
try writer.writeAll(
\\},
\\
);
}
fn renderClass(writer: anytype, a: Allocator, instructions: []const Instruction) !void {
var class_map = std.StringArrayHashMap(void).init(a);
for (instructions) |inst| { for (instructions) |inst| {
if (std.mem.eql(u8, inst.class.?, "@exclude")) { if (std.mem.eql(u8, inst.class.?, "@exclude")) {
@ -173,7 +397,7 @@ fn renderClass(writer: anytype, allocator: Allocator, instructions: []const g.In
try renderInstructionClass(writer, class); try renderInstructionClass(writer, class);
try writer.writeAll(",\n"); try writer.writeAll(",\n");
} }
try writer.writeAll("};\n"); try writer.writeAll("};\n\n");
} }
fn renderInstructionClass(writer: anytype, class: []const u8) !void { fn renderInstructionClass(writer: anytype, class: []const u8) !void {
@ -192,15 +416,20 @@ fn renderInstructionClass(writer: anytype, class: []const u8) !void {
} }
} }
fn renderOperandKind(writer: anytype, operands: []const g.OperandKind) !void { fn renderOperandKind(writer: anytype, operands: []const OperandKind) !void {
try writer.writeAll("pub const OperandKind = enum {\n"); try writer.writeAll(
\\pub const OperandKind = enum {
\\ Opcode,
\\
);
for (operands) |operand| { for (operands) |operand| {
try writer.print("{},\n", .{std.zig.fmtId(operand.kind)}); try writer.print("{},\n", .{std.zig.fmtId(operand.kind)});
} }
try writer.writeAll( try writer.writeAll(
\\ \\
\\pub fn category(self: OperandKind) OperandCategory { \\pub fn category(self: OperandKind) OperandCategory {
\\return switch (self) { \\ return switch (self) {
\\ .Opcode => .literal,
\\ \\
); );
for (operands) |operand| { for (operands) |operand| {
@ -214,10 +443,11 @@ fn renderOperandKind(writer: anytype, operands: []const g.OperandKind) !void {
try writer.print(".{} => .{s},\n", .{ std.zig.fmtId(operand.kind), cat }); try writer.print(".{} => .{s},\n", .{ std.zig.fmtId(operand.kind), cat });
} }
try writer.writeAll( try writer.writeAll(
\\}; \\ };
\\} \\}
\\pub fn enumerants(self: OperandKind) []const Enumerant { \\pub fn enumerants(self: OperandKind) []const Enumerant {
\\return switch (self) { \\ return switch (self) {
\\ .Opcode => unreachable,
\\ \\
); );
for (operands) |operand| { for (operands) |operand| {
@ -242,7 +472,7 @@ fn renderOperandKind(writer: anytype, operands: []const g.OperandKind) !void {
try writer.writeAll("};\n}\n};\n"); try writer.writeAll("};\n}\n};\n");
} }
fn renderEnumerant(writer: anytype, enumerant: g.Enumerant) !void { fn renderEnumerant(writer: anytype, enumerant: Enumerant) !void {
try writer.print(".{{.name = \"{s}\", .value = ", .{enumerant.enumerant}); try writer.print(".{{.name = \"{s}\", .value = ", .{enumerant.enumerant});
switch (enumerant.value) { switch (enumerant.value) {
.bitflag => |flag| try writer.writeAll(flag), .bitflag => |flag| try writer.writeAll(flag),
@ -260,14 +490,14 @@ fn renderEnumerant(writer: anytype, enumerant: g.Enumerant) !void {
fn renderOpcodes( fn renderOpcodes(
writer: anytype, writer: anytype,
allocator: Allocator, a: Allocator,
instructions: []const g.Instruction, instructions: []const Instruction,
extended_structs: ExtendedStructSet, extended_structs: ExtendedStructSet,
) !void { ) !void {
var inst_map = std.AutoArrayHashMap(u32, usize).init(allocator); var inst_map = std.AutoArrayHashMap(u32, usize).init(a);
try inst_map.ensureTotalCapacity(instructions.len); try inst_map.ensureTotalCapacity(instructions.len);
var aliases = std.ArrayList(struct { inst: usize, alias: usize }).init(allocator); var aliases = std.ArrayList(struct { inst: usize, alias: usize }).init(a);
try aliases.ensureTotalCapacity(instructions.len); try aliases.ensureTotalCapacity(instructions.len);
for (instructions, 0..) |inst, i| { for (instructions, 0..) |inst, i| {
@ -302,7 +532,9 @@ fn renderOpcodes(
try writer.print("{} = {},\n", .{ std.zig.fmtId(inst.opname), inst.opcode }); try writer.print("{} = {},\n", .{ std.zig.fmtId(inst.opname), inst.opcode });
} }
try writer.writeByte('\n'); try writer.writeAll(
\\
);
for (aliases.items) |alias| { for (aliases.items) |alias| {
try writer.print("pub const {} = Opcode.{};\n", .{ try writer.print("pub const {} = Opcode.{};\n", .{
@ -314,7 +546,7 @@ fn renderOpcodes(
try writer.writeAll( try writer.writeAll(
\\ \\
\\pub fn Operands(comptime self: Opcode) type { \\pub fn Operands(comptime self: Opcode) type {
\\return switch (self) { \\ return switch (self) {
\\ \\
); );
@ -324,35 +556,10 @@ fn renderOpcodes(
} }
try writer.writeAll( try writer.writeAll(
\\}; \\ };
\\}
\\pub fn operands(self: Opcode) []const Operand {
\\return switch (self) {
\\
);
for (instructions_indices) |i| {
const inst = instructions[i];
try writer.print(".{} => &[_]Operand{{", .{std.zig.fmtId(inst.opname)});
for (inst.operands) |operand| {
const quantifier = if (operand.quantifier) |q|
switch (q) {
.@"?" => "optional",
.@"*" => "variadic",
}
else
"required";
try writer.print(".{{.kind = .{s}, .quantifier = .{s}}},", .{ operand.kind, quantifier });
}
try writer.writeAll("},\n");
}
try writer.writeAll(
\\};
\\} \\}
\\pub fn class(self: Opcode) Class { \\pub fn class(self: Opcode) Class {
\\return switch (self) { \\ return switch (self) {
\\ \\
); );
@ -363,19 +570,24 @@ fn renderOpcodes(
try writer.writeAll(",\n"); try writer.writeAll(",\n");
} }
try writer.writeAll("};\n}\n};\n"); try writer.writeAll(
\\ };
\\}
\\};
\\
);
} }
fn renderOperandKinds( fn renderOperandKinds(
writer: anytype, writer: anytype,
allocator: Allocator, a: Allocator,
kinds: []const g.OperandKind, kinds: []const OperandKind,
extended_structs: ExtendedStructSet, extended_structs: ExtendedStructSet,
) !void { ) !void {
for (kinds) |kind| { for (kinds) |kind| {
switch (kind.category) { switch (kind.category) {
.ValueEnum => try renderValueEnum(writer, allocator, kind, extended_structs), .ValueEnum => try renderValueEnum(writer, a, kind, extended_structs),
.BitEnum => try renderBitEnum(writer, allocator, kind, extended_structs), .BitEnum => try renderBitEnum(writer, a, kind, extended_structs),
else => {}, else => {},
} }
} }
@ -383,20 +595,26 @@ fn renderOperandKinds(
fn renderValueEnum( fn renderValueEnum(
writer: anytype, writer: anytype,
allocator: Allocator, a: Allocator,
enumeration: g.OperandKind, enumeration: OperandKind,
extended_structs: ExtendedStructSet, extended_structs: ExtendedStructSet,
) !void { ) !void {
const enumerants = enumeration.enumerants orelse return error.InvalidRegistry; const enumerants = enumeration.enumerants orelse return error.InvalidRegistry;
var enum_map = std.AutoArrayHashMap(u32, usize).init(allocator); var enum_map = std.AutoArrayHashMap(u32, usize).init(a);
try enum_map.ensureTotalCapacity(enumerants.len); try enum_map.ensureTotalCapacity(enumerants.len);
var aliases = std.ArrayList(struct { enumerant: usize, alias: usize }).init(allocator); var aliases = std.ArrayList(struct { enumerant: usize, alias: usize }).init(a);
try aliases.ensureTotalCapacity(enumerants.len); try aliases.ensureTotalCapacity(enumerants.len);
for (enumerants, 0..) |enumerant, i| { for (enumerants, 0..) |enumerant, i| {
const result = enum_map.getOrPutAssumeCapacity(enumerant.value.int); try writer.context.flush();
const value: u31 = switch (enumerant.value) {
.int => |value| value,
// Some extensions declare ints as strings
.bitflag => |value| try std.fmt.parseInt(u31, value, 10),
};
const result = enum_map.getOrPutAssumeCapacity(value);
if (!result.found_existing) { if (!result.found_existing) {
result.value_ptr.* = i; result.value_ptr.* = i;
continue; continue;
@ -422,9 +640,12 @@ fn renderValueEnum(
for (enum_indices) |i| { for (enum_indices) |i| {
const enumerant = enumerants[i]; const enumerant = enumerants[i];
if (enumerant.value != .int) return error.InvalidRegistry; // if (enumerant.value != .int) return error.InvalidRegistry;
try writer.print("{} = {},\n", .{ std.zig.fmtId(enumerant.enumerant), enumerant.value.int }); switch (enumerant.value) {
.int => |value| try writer.print("{} = {},\n", .{ std.zig.fmtId(enumerant.enumerant), value }),
.bitflag => |value| try writer.print("{} = {s},\n", .{ std.zig.fmtId(enumerant.enumerant), value }),
}
} }
try writer.writeByte('\n'); try writer.writeByte('\n');
@ -454,8 +675,8 @@ fn renderValueEnum(
fn renderBitEnum( fn renderBitEnum(
writer: anytype, writer: anytype,
allocator: Allocator, a: Allocator,
enumeration: g.OperandKind, enumeration: OperandKind,
extended_structs: ExtendedStructSet, extended_structs: ExtendedStructSet,
) !void { ) !void {
try writer.print("pub const {s} = packed struct {{\n", .{std.zig.fmtId(enumeration.kind)}); try writer.print("pub const {s} = packed struct {{\n", .{std.zig.fmtId(enumeration.kind)});
@ -463,7 +684,7 @@ fn renderBitEnum(
var flags_by_bitpos = [_]?usize{null} ** 32; var flags_by_bitpos = [_]?usize{null} ** 32;
const enumerants = enumeration.enumerants orelse return error.InvalidRegistry; const enumerants = enumeration.enumerants orelse return error.InvalidRegistry;
var aliases = std.ArrayList(struct { flag: usize, alias: u5 }).init(allocator); var aliases = std.ArrayList(struct { flag: usize, alias: u5 }).init(a);
try aliases.ensureTotalCapacity(enumerants.len); try aliases.ensureTotalCapacity(enumerants.len);
for (enumerants, 0..) |enumerant, i| { for (enumerants, 0..) |enumerant, i| {
@ -471,6 +692,10 @@ fn renderBitEnum(
const value = try parseHexInt(enumerant.value.bitflag); const value = try parseHexInt(enumerant.value.bitflag);
if (value == 0) { if (value == 0) {
continue; // Skip 'none' items continue; // Skip 'none' items
} else if (std.mem.eql(u8, enumerant.enumerant, "FlagIsPublic")) {
// This flag is special and poorly defined in the json files.
// Just skip it for now.
continue;
} }
std.debug.assert(@popCount(value) == 1); std.debug.assert(@popCount(value) == 1);
@ -540,7 +765,7 @@ fn renderOperand(
mask, mask,
}, },
field_name: []const u8, field_name: []const u8,
parameters: []const g.Operand, parameters: []const Operand,
extended_structs: ExtendedStructSet, extended_structs: ExtendedStructSet,
) !void { ) !void {
if (kind == .instruction) { if (kind == .instruction) {
@ -606,7 +831,7 @@ fn renderOperand(
try writer.writeAll(",\n"); try writer.writeAll(",\n");
} }
fn renderFieldName(writer: anytype, operands: []const g.Operand, field_index: usize) !void { fn renderFieldName(writer: anytype, operands: []const Operand, field_index: usize) !void {
const operand = operands[field_index]; const operand = operands[field_index];
// Should be enough for all names - adjust as needed. // Should be enough for all names - adjust as needed.
@ -673,16 +898,16 @@ fn parseHexInt(text: []const u8) !u31 {
return try std.fmt.parseInt(u31, text[prefix.len..], 16); return try std.fmt.parseInt(u31, text[prefix.len..], 16);
} }
fn usageAndExit(file: std.fs.File, arg0: []const u8, code: u8) noreturn { fn usageAndExit(arg0: []const u8, code: u8) noreturn {
file.writer().print( std.io.getStdErr().writer().print(
\\Usage: {s} <spirv json spec> \\Usage: {s} <SPIRV-Headers repository path> <path/to/zig/src/codegen/spirv/extinst.zig.grammar.json>
\\ \\
\\Generates Zig bindings for a SPIR-V specification .json (either core or \\Generates Zig bindings for SPIR-V specifications found in the SPIRV-Headers
\\extinst versions). The result, printed to stdout, should be used to update \\repository. The result, printed to stdout, should be used to update
\\files in src/codegen/spirv. Don't forget to format the output. \\files in src/codegen/spirv. Don't forget to format the output.
\\ \\
\\The relevant specifications can be obtained from the SPIR-V registry: \\<SPIRV-Headers repository path> should point to a clone of
\\https://github.com/KhronosGroup/SPIRV-Headers/blob/master/include/spirv/unified1/ \\https://github.com/KhronosGroup/SPIRV-Headers/
\\ \\
, .{arg0}) catch std.process.exit(1); , .{arg0}) catch std.process.exit(1);
std.process.exit(code); std.process.exit(code);

View File

@ -22,8 +22,8 @@ pub const CoreRegistry = struct {
}; };
pub const ExtensionRegistry = struct { pub const ExtensionRegistry = struct {
copyright: [][]const u8, copyright: ?[][]const u8 = null,
version: u32, version: ?u32 = null,
revision: u32, revision: u32,
instructions: []Instruction, instructions: []Instruction,
operand_kinds: []OperandKind = &[_]OperandKind{}, operand_kinds: []OperandKind = &[_]OperandKind{},
@ -40,6 +40,8 @@ pub const Instruction = struct {
opcode: u32, opcode: u32,
operands: []Operand = &[_]Operand{}, operands: []Operand = &[_]Operand{},
capabilities: [][]const u8 = &[_][]const u8{}, capabilities: [][]const u8 = &[_][]const u8{},
// DebugModuleINTEL has this...
capability: ?[]const u8 = null,
extensions: [][]const u8 = &[_][]const u8{}, extensions: [][]const u8 = &[_][]const u8{},
version: ?[]const u8 = null, version: ?[]const u8 = null,