diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 22d7560806..2761a967de 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -12,6 +12,7 @@ const LazySrcLoc = Module.LazySrcLoc; const Air = @import("../Air.zig"); const Zir = @import("../Zir.zig"); const Liveness = @import("../Liveness.zig"); +const InternPool = @import("../InternPool.zig"); const spec = @import("spirv/spec.zig"); const Opcode = spec.Opcode; @@ -30,6 +31,15 @@ const SpvAssembler = @import("spirv/Assembler.zig"); const InstMap = std.AutoHashMapUnmanaged(Air.Inst.Index, IdRef); +/// We want to store some extra facts about types as mapped from Zig to SPIR-V. +/// This structure is used to keep that extra information, as well as +/// the cached reference to the type. +const SpvTypeInfo = struct { + ty_ref: CacheRef, +}; + +const TypeMap = std.AutoHashMapUnmanaged(InternPool.Index, SpvTypeInfo); + const IncomingBlock = struct { src_label_id: IdRef, break_value_id: IdRef, @@ -78,6 +88,15 @@ pub const DeclGen = struct { /// A map keeping track of which instruction generated which result-id. inst_results: InstMap = .{}, + /// A map that maps AIR intern pool indices to SPIR-V cache references (which + /// is basically the same thing except for SPIR-V). + /// This map is typically only used for structures that are deemed heavy enough + /// that it is worth to store them here. The SPIR-V module also interns types, + /// and so the main purpose of this map is to avoid recomputation and to + /// cache extra information about the type rather than to aid in validity + /// of the SPIR-V module. + type_map: TypeMap = .{}, + /// We need to keep track of result ids for block labels, as well as the 'incoming' /// blocks for a block. blocks: BlockMap = .{}, @@ -207,6 +226,7 @@ pub const DeclGen = struct { pub fn deinit(self: *DeclGen) void { self.args.deinit(self.gpa); self.inst_results.deinit(self.gpa); + self.type_map.deinit(self.gpa); self.blocks.deinit(self.gpa); self.func.deinit(self.gpa); } @@ -1180,6 +1200,9 @@ pub const DeclGen = struct { return try self.resolveType(union_obj.enum_tag_ty.toType(), .indirect); } + const entry = try self.type_map.getOrPut(self.gpa, ty.toIntern()); + if (entry.found_existing) return entry.value_ptr.ty_ref; + var member_types = std.BoundedArray(CacheRef, 4){}; var member_names = std.BoundedArray(CacheString, 4){}; @@ -1222,10 +1245,16 @@ pub const DeclGen = struct { member_names.appendAssumeCapacity(try self.spv.resolveString("padding")); } - return try self.spv.resolve(.{ .struct_type = .{ + const ty_ref = try self.spv.resolve(.{ .struct_type = .{ .member_types = member_types.slice(), .member_names = member_names.slice(), } }); + + entry.value_ptr.* = .{ + .ty_ref = ty_ref, + }; + + return ty_ref; } /// Turn a Zig type into a SPIR-V Type, and return a reference to it. @@ -1268,15 +1297,26 @@ pub const DeclGen = struct { return try self.spv.resolve(.{ .float_type = .{ .bits = bits } }); }, .Array => { + const entry = try self.type_map.getOrPut(self.gpa, ty.toIntern()); + if (entry.found_existing) return entry.value_ptr.ty_ref; + const elem_ty = ty.childType(mod); - const elem_ty_ref = try self.resolveType(elem_ty, .direct); + const elem_ty_ref = try self.resolveType(elem_ty, .indirect); const total_len = std.math.cast(u32, ty.arrayLenIncludingSentinel(mod)) orelse { return self.fail("array type of {} elements is too large", .{ty.arrayLenIncludingSentinel(mod)}); }; - return self.spv.arrayType(total_len, elem_ty_ref); + const ty_ref = try self.spv.arrayType(total_len, elem_ty_ref); + entry.value_ptr.* = .{ + .ty_ref = ty_ref, + }; + return ty_ref; }, .Fn => switch (repr) { .direct => { + const entry = try self.type_map.getOrPut(self.gpa, ty.toIntern()); + if (entry.found_existing) return entry.value_ptr.ty_ref; + + const ip = &mod.intern_pool; const fn_info = mod.typeToFunc(ty).?; // TODO: Put this somewhere in Sema.zig if (fn_info.is_var_args) @@ -1289,10 +1329,16 @@ pub const DeclGen = struct { } const return_ty_ref = try self.resolveType(fn_info.return_type.toType(), .direct); - return try self.spv.resolve(.{ .function_type = .{ + const ty_ref = try self.spv.resolve(.{ .function_type = .{ .return_type = return_ty_ref, .parameters = param_ty_refs, } }); + + entry.value_ptr.* = .{ + .ty_ref = ty_ref, + }; + + return ty_ref; }, .indirect => { // TODO: Represent function pointers properly. @@ -1338,6 +1384,9 @@ pub const DeclGen = struct { } }); }, .Struct => { + const entry = try self.type_map.getOrPut(self.gpa, ty.toIntern()); + if (entry.found_existing) return entry.value_ptr.ty_ref; + const struct_type = switch (ip.indexToKey(ty.toIntern())) { .anon_struct_type => |tuple| { const member_types = try self.gpa.alloc(CacheRef, tuple.values.len); @@ -1351,9 +1400,14 @@ pub const DeclGen = struct { member_index += 1; } - return try self.spv.resolve(.{ .struct_type = .{ + const ty_ref = try self.spv.resolve(.{ .struct_type = .{ .member_types = member_types[0..member_index], } }); + + entry.value_ptr.* = .{ + .ty_ref = ty_ref, + }; + return ty_ref; }, .struct_type => |struct_type| struct_type, else => unreachable, @@ -1361,7 +1415,6 @@ pub const DeclGen = struct { if (struct_type.layout == .Packed) { return try self.resolveType(struct_type.backingIntType(ip).toType(), .direct); - } var member_types = std.ArrayList(CacheRef).init(self.gpa); defer member_types.deinit(); @@ -1379,11 +1432,16 @@ pub const DeclGen = struct { const name = ip.stringToSlice(try mod.declPtr(struct_type.decl.unwrap().?).getFullyQualifiedName(mod)); - return try self.spv.resolve(.{ .struct_type = .{ + const ty_ref = try self.spv.resolve(.{ .struct_type = .{ .name = try self.spv.resolveString(name), .member_types = member_types.items, .member_names = member_names.items, } }); + + entry.value_ptr.* = .{ + .ty_ref = ty_ref, + }; + return ty_ref; }, .Optional => { const payload_ty = ty.optionalChild(mod); @@ -1400,15 +1458,23 @@ pub const DeclGen = struct { return payload_ty_ref; } + const entry = try self.type_map.getOrPut(self.gpa, ty.toIntern()); + if (entry.found_existing) return entry.value_ptr.ty_ref; + const bool_ty_ref = try self.resolveType(Type.bool, .indirect); - return try self.spv.resolve(.{ .struct_type = .{ + const ty_ref = try self.spv.resolve(.{ .struct_type = .{ .member_types = &.{ payload_ty_ref, bool_ty_ref }, .member_names = &.{ try self.spv.resolveString("payload"), try self.spv.resolveString("valid"), }, } }); + + entry.value_ptr.* = .{ + .ty_ref = ty_ref, + }; + return ty_ref; }, .Union => return try self.resolveUnionType(ty, null), .ErrorSet => return try self.intType(.unsigned, 16), @@ -1421,6 +1487,9 @@ pub const DeclGen = struct { return error_ty_ref; } + const entry = try self.type_map.getOrPut(self.gpa, ty.toIntern()); + if (entry.found_existing) return entry.value_ptr.ty_ref; + const payload_ty_ref = try self.resolveType(payload_ty, .indirect); var member_types: [2]CacheRef = undefined; @@ -1443,10 +1512,15 @@ pub const DeclGen = struct { // TODO: ABI padding? } - return try self.spv.resolve(.{ .struct_type = .{ + const ty_ref = try self.spv.resolve(.{ .struct_type = .{ .member_types = &member_types, .member_names = &member_names, } }); + + entry.value_ptr.* = .{ + .ty_ref = ty_ref, + }; + return ty_ref; }, .Null,