From fcb422585c1a9e91933ff998417eb8682a4ffbcc Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Mon, 29 May 2023 20:45:54 +0200 Subject: [PATCH] spirv: translate remaining types --- src/codegen/spirv.zig | 215 ++++++++++++++++++++++-- src/codegen/spirv/Module.zig | 19 +++ src/codegen/spirv/TypeConstantCache.zig | 16 +- 3 files changed, 227 insertions(+), 23 deletions(-) diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index fa429c024b..b0e4c8e950 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -22,7 +22,8 @@ const IdResultType = spec.IdResultType; const StorageClass = spec.StorageClass; const SpvModule = @import("spirv/Module.zig"); -const SpvRef = SpvModule.TypeConstantCache.Ref; +const SpvCacheRef = SpvModule.TypeConstantCache.Ref; +const SpvCacheString = SpvModule.TypeConstantCache.String; const SpvSection = @import("spirv/Section.zig"); const SpvType = @import("spirv/type.zig").Type; @@ -1160,7 +1161,7 @@ pub const DeclGen = struct { return try self.spv.resolveType(try SpvType.int(self.spv.arena, signedness, backing_bits)); } - fn intType2(self: *DeclGen, signedness: std.builtin.Signedness, bits: u16) !SpvRef { + fn intType2(self: *DeclGen, signedness: std.builtin.Signedness, bits: u16) !SpvCacheRef { const backing_bits = self.backingIntBits(bits) orelse { // TODO: Integers too big for any native type are represented as "composite integers": // An array of largestSupportedIntBits. @@ -1177,7 +1178,7 @@ pub const DeclGen = struct { return try self.intType(.unsigned, self.getTarget().ptrBitWidth()); } - fn sizeType2(self: *DeclGen) !SpvRef { + fn sizeType2(self: *DeclGen) !SpvCacheRef { return try self.intType2(.unsigned, self.getTarget().ptrBitWidth()); } @@ -1256,7 +1257,91 @@ pub const DeclGen = struct { return try self.spv.simpleStructType(members.slice()); } - fn resolveType2(self: *DeclGen, ty: Type, repr: Repr) !SpvRef { + /// Generate a union type, optionally with a known field. If the tag alignment is greater + /// than that of the payload, a regular union (non-packed, with both tag and payload), will + /// be generated as follows: + /// If the active field is known: + /// struct { + /// tag: TagType, + /// payload: ActivePayloadType, + /// payload_padding: [payload_size - @sizeOf(ActivePayloadType)]u8, + /// padding: [padding_size]u8, + /// } + /// If the payload alignment is greater than that of the tag: + /// struct { + /// payload: ActivePayloadType, + /// payload_padding: [payload_size - @sizeOf(ActivePayloadType)]u8, + /// tag: TagType, + /// padding: [padding_size]u8, + /// } + /// If the active payload is unknown, it will default back to the most aligned field. This is + /// to make sure that the overal struct has the correct alignment in spir-v. + /// If any of the fields' size is 0, it will be omitted. + /// NOTE: When the active field is set to something other than the most aligned field, the + /// resulting struct will be *underaligned*. + fn resolveUnionType2(self: *DeclGen, ty: Type, maybe_active_field: ?usize) !SpvCacheRef { + const target = self.getTarget(); + const layout = ty.unionGetLayout(target); + const union_ty = ty.cast(Type.Payload.Union).?.data; + + if (union_ty.layout == .Packed) { + return self.todo("packed union types", .{}); + } + + if (layout.payload_size == 0) { + // No payload, so represent this as just the tag type. + return try self.resolveType2(union_ty.tag_ty, .indirect); + } + + var member_types = std.BoundedArray(SpvCacheRef, 4){}; + var member_names = std.BoundedArray(SpvCacheString, 4){}; + + const has_tag = layout.tag_size != 0; + const tag_first = layout.tag_align >= layout.payload_align; + const u8_ty_ref = try self.intType2(.unsigned, 8); // TODO: What if Int8Type is not enabled? + + if (has_tag and tag_first) { + const tag_ty_ref = try self.resolveType2(union_ty.tag_ty, .indirect); + member_types.appendAssumeCapacity(tag_ty_ref); + member_names.appendAssumeCapacity(try self.spv.resolveString("tag")); + } + + const active_field = maybe_active_field orelse layout.most_aligned_field; + const active_field_ty = union_ty.fields.values()[active_field].ty; + + const active_field_size = if (active_field_ty.hasRuntimeBitsIgnoreComptime()) blk: { + const active_payload_ty_ref = try self.resolveType2(active_field_ty, .indirect); + member_types.appendAssumeCapacity(active_payload_ty_ref); + member_names.appendAssumeCapacity(try self.spv.resolveString("payload")); + break :blk active_field_ty.abiSize(target); + } else 0; + + const payload_padding_len = layout.payload_size - active_field_size; + if (payload_padding_len != 0) { + const payload_padding_ty_ref = try self.spv.arrayType2(@intCast(u32, payload_padding_len), u8_ty_ref); + member_types.appendAssumeCapacity(payload_padding_ty_ref); + member_names.appendAssumeCapacity(try self.spv.resolveString("payload_padding")); + } + + if (has_tag and !tag_first) { + const tag_ty_ref = try self.resolveType2(union_ty.tag_ty, .indirect); + member_types.appendAssumeCapacity(tag_ty_ref); + member_names.appendAssumeCapacity(try self.spv.resolveString("tag")); + } + + if (layout.padding != 0) { + const padding_ty_ref = try self.spv.arrayType2(layout.padding, u8_ty_ref); + member_types.appendAssumeCapacity(padding_ty_ref); + member_names.appendAssumeCapacity(try self.spv.resolveString("padding")); + } + + return try self.spv.resolve(.{ .struct_type = .{ + .member_types = member_types.slice(), + .member_names = member_names.slice(), + } }); + } + + fn resolveType2(self: *DeclGen, ty: Type, repr: Repr) Error!SpvCacheRef { const target = self.getTarget(); switch (ty.zigTypeTag()) { .Void, .NoReturn => return try self.spv.resolve(.void_type), @@ -1297,15 +1382,7 @@ pub const DeclGen = struct { const total_len = std.math.cast(u32, ty.arrayLenIncludingSentinel()) orelse { return self.fail("array type of {} elements is too large", .{ty.arrayLenIncludingSentinel()}); }; - const len_ty_ref = try self.intType2(.unsigned, 32); - const len_ref = try self.spv.resolve(.{ .int = .{ - .ty = len_ty_ref, - .value = .{ .uint64 = total_len }, - } }); - return try self.spv.resolve(.{ .array_type = .{ - .element_type = elem_ty_ref, - .length = len_ref, - } }); + return self.spv.arrayType2(total_len, elem_ty_ref); }, .Fn => switch (repr) { .direct => { @@ -1313,7 +1390,7 @@ pub const DeclGen = struct { if (ty.fnIsVarArgs()) return self.fail("VarArgs functions are unsupported for SPIR-V", .{}); - const param_ty_refs = try self.gpa.alloc(SpvRef, ty.fnParamLen()); + const param_ty_refs = try self.gpa.alloc(SpvCacheRef, ty.fnParamLen()); defer self.gpa.free(param_ty_refs); for (param_ty_refs, 0..) |*param_type, i| { param_type.* = try self.resolveType2(ty.fnParamType(i), .direct); @@ -1360,8 +1437,116 @@ pub const DeclGen = struct { .component_count = @intCast(u32, ty.vectorLen()), } }); }, + .Struct => { + if (ty.isSimpleTupleOrAnonStruct()) { + unreachable; // TODO + } - else => unreachable, // TODO + const struct_ty = ty.castTag(.@"struct").?.data; + + if (struct_ty.layout == .Packed) { + return try self.resolveType2(struct_ty.backing_int_ty, .direct); + } + + const member_types = try self.gpa.alloc(SpvCacheRef, struct_ty.fields.count()); + defer self.gpa.free(member_types); + + const member_names = try self.gpa.alloc(SpvCacheString, struct_ty.fields.count()); + defer self.gpa.free(member_names); + + // const members = try self.spv.arena.alloc(SpvType.Payload.Struct.Member, struct_ty.fields.count()); + var member_index: usize = 0; + for (struct_ty.fields.values(), 0..) |field, i| { + if (field.is_comptime or !field.ty.hasRuntimeBits()) continue; + + member_types[member_index] = try self.resolveType2(field.ty, .indirect); + member_names[member_index] = try self.spv.resolveString(struct_ty.fields.keys()[i]); + member_index += 1; + } + + const name = try struct_ty.getFullyQualifiedName(self.module); + defer self.module.gpa.free(name); + + return try self.spv.resolve(.{ .struct_type = .{ + .name = try self.spv.resolveString(name), + .member_types = member_types[0..member_index], + .member_names = member_names[0..member_index], + } }); + }, + .Optional => { + var buf: Type.Payload.ElemType = undefined; + const payload_ty = ty.optionalChild(&buf); + if (!payload_ty.hasRuntimeBitsIgnoreComptime()) { + // Just use a bool. + // Note: Always generate the bool with indirect format, to save on some sanity + // Perform the conversion to a direct bool when the field is extracted. + return try self.resolveType2(Type.bool, .indirect); + } + + const payload_ty_ref = try self.resolveType2(payload_ty, .indirect); + if (ty.optionalReprIsPayload()) { + // Optional is actually a pointer or a slice. + return payload_ty_ref; + } + + const bool_ty_ref = try self.resolveType2(Type.bool, .indirect); + + return try self.spv.resolve(.{ .struct_type = .{ + .member_types = &.{ payload_ty_ref, bool_ty_ref }, + .member_names = &.{ + try self.spv.resolveString("payload"), + try self.spv.resolveString("valid"), + }, + } }); + }, + .Union => return try self.resolveUnionType2(ty, null), + .ErrorSet => return try self.intType2(.unsigned, 16), + .ErrorUnion => { + const payload_ty = ty.errorUnionPayload(); + const error_ty_ref = try self.resolveType2(Type.anyerror, .indirect); + + const eu_layout = self.errorUnionLayout(payload_ty); + if (!eu_layout.payload_has_bits) { + return error_ty_ref; + } + + const payload_ty_ref = try self.resolveType2(payload_ty, .indirect); + + var member_types: [2]SpvCacheRef = undefined; + var member_names: [2]SpvCacheString = undefined; + if (eu_layout.error_first) { + // Put the error first + member_types = .{ error_ty_ref, payload_ty_ref }; + member_names = .{ + try self.spv.resolveString("error"), + try self.spv.resolveString("payload"), + }; + // TODO: ABI padding? + } else { + // Put the payload first. + member_types = .{ payload_ty_ref, error_ty_ref }; + member_names = .{ + try self.spv.resolveString("payload"), + try self.spv.resolveString("error"), + }; + // TODO: ABI padding? + } + + return try self.spv.resolve(.{ .struct_type = .{ + .member_types = &member_types, + .member_names = &member_names, + } }); + }, + + .Null, + .Undefined, + .EnumLiteral, + .ComptimeFloat, + .ComptimeInt, + .Type, + => unreachable, // Must be comptime. + + else => |tag| return self.todo("Implement zig type '{}'", .{tag}), } } diff --git a/src/codegen/spirv/Module.zig b/src/codegen/spirv/Module.zig index 54b868aba2..85d0bbf78c 100644 --- a/src/codegen/spirv/Module.zig +++ b/src/codegen/spirv/Module.zig @@ -235,6 +235,10 @@ pub fn resolveId(self: *Module, key: TypeConstantCache.Key) !IdResult { return self.resultId(try self.resolve(key)); } +pub fn resolveString(self: *Module, str: []const u8) !TypeConstantCache.String { + return try self.tc_cache.addString(self, str); +} + fn orderGlobalsInto( self: *Module, decl_index: Decl.Index, @@ -769,6 +773,21 @@ pub fn simpleStructType(self: *Module, members: []const Type.Payload.Struct.Memb return try self.resolveType(Type.initPayload(&payload.base)); } +pub fn arrayType2(self: *Module, len: u32, elem_ty_ref: TypeConstantCache.Ref) !TypeConstantCache.Ref { + const len_ty_ref = try self.resolve(.{ .int_type = .{ + .signedness = .unsigned, + .bits = 32, + } }); + const len_ref = try self.resolve(.{ .int = .{ + .ty = len_ty_ref, + .value = .{ .uint64 = len }, + } }); + return try self.resolve(.{ .array_type = .{ + .element_type = elem_ty_ref, + .length = len_ref, + } }); +} + pub fn arrayType(self: *Module, len: u32, ty: Type.Ref) !Type.Ref { const payload = try self.arena.create(Type.Payload.Array); payload.* = .{ diff --git a/src/codegen/spirv/TypeConstantCache.zig b/src/codegen/spirv/TypeConstantCache.zig index b8258b8819..d5bc888276 100644 --- a/src/codegen/spirv/TypeConstantCache.zig +++ b/src/codegen/spirv/TypeConstantCache.zig @@ -263,7 +263,7 @@ pub const Key = union(enum) { pub const StructType = struct { // TODO: Decorations. /// The name of the structure. Can be `.none`. - name: String, + name: String = .none, /// The type of each member. member_types: []const Ref, /// Name for each member. May be omitted. @@ -922,14 +922,14 @@ pub const String = enum(u32) { self: *const Self, pub fn eql(ctx: @This(), a: []const u8, _: void, b_index: usize) bool { - const offset = ctx.self.string_map.values()[b_index]; + const offset = ctx.self.strings.values()[b_index]; const b = std.mem.sliceTo(ctx.self.string_bytes.items[offset..], 0); return std.mem.eql(u8, a, b); } pub fn hash(ctx: @This(), a: []const u8) u32 { _ = ctx; - const hasher = std.hash.Wyhash.init(0); + var hasher = std.hash.Wyhash.init(0); hasher.update(a); return @truncate(u32, hasher.final()); } @@ -937,16 +937,16 @@ pub const String = enum(u32) { }; /// Add a string to the cache. Must not contain any 0 values. -pub fn addString(self: *Self, spv: *Module, str: []const u8) String { +pub fn addString(self: *Self, spv: *Module, str: []const u8) !String { assert(std.mem.indexOfScalar(u8, str, 0) == null); const adapter = String.Adapter{ .self = self }; const entry = try self.strings.getOrPutAdapted(spv.gpa, str, adapter); if (!entry.found_existing) { const offset = self.string_bytes.items.len; - try self.string_bytes.ensureUnusedCapacity(1 + str.len); - self.string_bytes.appendAssumeCapacity(str); - self.string_bytes.append(0); - entry.value_ptr.* = offset; + try self.string_bytes.ensureUnusedCapacity(spv.gpa, 1 + str.len); + self.string_bytes.appendSliceAssumeCapacity(str); + self.string_bytes.appendAssumeCapacity(0); + entry.value_ptr.* = @intCast(u32, offset); } return @intToEnum(String, entry.index);