From 764f19034d9aa74ce2220937d090c60f8f8bf919 Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Sun, 26 Mar 2023 18:56:49 +0200 Subject: [PATCH] spirv: union types/constants Implements lowering union types and constants in the SPIR-V backend. --- src/codegen/spirv.zig | 177 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 165 insertions(+), 12 deletions(-) diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index fff872d359..080f2f645b 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -394,6 +394,16 @@ pub const DeclGen = struct { return result_id; } + fn genUndef(self: *DeclGen, ty_ref: SpvType.Ref) Error!IdRef { + const result_id = self.spv.allocId(); + try self.spv.sections.types_globals_constants.emit( + self.spv.gpa, + .OpUndef, + .{ .id_result_type = self.typeId(ty_ref), .id_result = result_id }, + ); + return result_id; + } + fn constant(self: *DeclGen, ty: Type, val: Value, repr: Repr) Error!IdRef { const result_id = self.spv.allocId(); try self.genConstant(result_id, ty, val, repr); @@ -543,7 +553,7 @@ pub const DeclGen = struct { for (tuple.types, 0..) |field_ty, i| { const field_val = tuple.values[i]; if (field_val.tag() != .unreachable_value or !field_ty.hasRuntimeBits()) continue; - constituents[member_i] = try self.constant(field_ty, field_val, repr); + constituents[member_i] = try self.constant(field_ty, field_val, .indirect); member_i += 1; } @@ -561,7 +571,7 @@ pub const DeclGen = struct { var member_i: usize = 0; for (struct_ty.fields.values(), 0..) |field, i| { if (field.is_comptime or !field.ty.hasRuntimeBits()) continue; - constituents[member_i] = try self.constant(field.ty, field_vals[i], repr); + constituents[member_i] = try self.constant(field.ty, field_vals[i], .indirect); member_i += 1; } @@ -633,6 +643,67 @@ pub const DeclGen = struct { .constituents = &constituents, }); }, + .Union => { + const tag_and_val = val.castTag(.@"union").?.data; + const layout = ty.unionGetLayout(target); + + if (layout.payload_size == 0) { + return try self.genConstant(result_id, ty.unionTagTypeSafety().?, tag_and_val.tag, .indirect); + } + + const union_ty = ty.cast(Type.Payload.Union).?.data; + if (union_ty.layout == .Packed) { + return self.todo("packed union constants", .{}); + } + + const active_field = ty.unionTagFieldIndex(tag_and_val.tag, self.module).?; + const union_ty_ref = try self.resolveUnionType(ty, active_field); + const active_field_ty = union_ty.fields.values()[active_field].ty; + + const tag_first = layout.tag_align >= layout.payload_align; + const u8_ty_ref = try self.intType(.unsigned, 8); + + const tag = if (layout.tag_size != 0) + try self.constant(ty.unionTagTypeSafety().?, tag_and_val.tag, .indirect) + else + null; + + var members = std.BoundedArray(IdRef, 4){}; + + if (tag_first) { + if (tag) |id| members.appendAssumeCapacity(id); + } + + const active_field_size = if (active_field_ty.hasRuntimeBitsIgnoreComptime()) blk: { + const payload = try self.constant(active_field_ty, tag_and_val.val, .indirect); + members.appendAssumeCapacity(payload); + break :blk active_field_ty.abiSize(target); + } else 0; + + const payload_padding_len = layout.payload_size - active_field_size; + if (payload_padding_len != 0) { + const payload_padding_ty_ref = try self.arrayType(@intCast(u32, payload_padding_len), u8_ty_ref); + members.appendAssumeCapacity(try self.genUndef(payload_padding_ty_ref)); + } + + if (!tag_first) { + if (tag) |id| members.appendAssumeCapacity(id); + } + + if (layout.padding != 0) { + const padding_ty_ref = try self.arrayType(layout.padding, u8_ty_ref); + members.appendAssumeCapacity(try self.genUndef(padding_ty_ref)); + } + + try section.emit(self.spv.gpa, .OpSpecConstantComposite, .{ + .id_result_type = self.typeId(union_ty_ref), + .id_result = result_id, + .constituents = members.slice(), + }); + + // TODO: Cast to general union type? Required for pointers only or something? + }, + .Fn => switch (repr) { .direct => unreachable, .indirect => return self.todo("function pointers", .{}), @@ -691,6 +762,91 @@ pub const DeclGen = struct { return self.typeId(type_ref); } + /// Construct an array type which has 'len' elements of 'type' + fn arrayType(self: *DeclGen, len: u32, ty: SpvType.Ref) !SpvType.Ref { + const payload = try self.spv.arena.create(SpvType.Payload.Array); + payload.* = .{ + .element_type = ty, + .length = len, + }; + return try self.spv.resolveType(SpvType.initPayload(&payload.base)); + } + + /// Generate a union type, optionally with a known field. If the tag alignment is greater + /// than that of the payload, a regular union (non-packed, with both tag and payload), will + /// be generated as follows: + /// If the active field is known: + /// struct { + /// tag: TagType, + /// payload: ActivePayloadType, + /// payload_padding: [payload_size - @sizeOf(ActivePayloadType)]u8, + /// padding: [padding_size]u8, + /// } + /// If the payload alignment is greater than that of the tag: + /// struct { + /// payload: ActivePayloadType, + /// payload_padding: [payload_size - @sizeOf(ActivePayloadType)]u8, + /// tag: TagType, + /// padding: [padding_size]u8, + /// } + /// If the active payload is unknown, it will default back to the most aligned field. This is + /// to make sure that the overal struct has the correct alignment in spir-v. + /// If any of the fields' size is 0, it will be omitted. + /// NOTE: When the active field is set to something other than the most aligned field, the + /// resulting struct will be *underaligned*. + fn resolveUnionType(self: *DeclGen, ty: Type, maybe_active_field: ?usize) !SpvType.Ref { + const target = self.getTarget(); + const layout = ty.unionGetLayout(target); + const union_ty = ty.cast(Type.Payload.Union).?.data; + + if (union_ty.layout == .Packed) { + return self.todo("packed union types", .{}); + } + + const tag_ty_ref = try self.resolveType(union_ty.tag_ty, .indirect); + if (layout.payload_size == 0) { + // No payload, so represent this as just the tag type. + return tag_ty_ref; + } + + var members = std.BoundedArray(SpvType.Payload.Struct.Member, 4){}; + + const has_tag = layout.tag_size != 0; + const tag_first = layout.tag_align >= layout.payload_align; + const tag_member = .{ .name = "tag", .ty = tag_ty_ref }; + const u8_ty_ref = try self.intType(.unsigned, 8); // TODO: What if Int8Type is not enabled? + + if (has_tag and tag_first) { + members.appendAssumeCapacity(tag_member); + } + + const active_field = maybe_active_field orelse layout.most_aligned_field; + const active_field_ty = union_ty.fields.values()[active_field].ty; + + const active_field_size = if (active_field_ty.hasRuntimeBitsIgnoreComptime()) blk: { + const active_payload_ty_ref = try self.resolveType(active_field_ty, .indirect); + members.appendAssumeCapacity(.{ .name = "payload", .ty = active_payload_ty_ref }); + break :blk active_field_ty.abiSize(target); + } else 0; + + const payload_padding_len = layout.payload_size - active_field_size; + if (payload_padding_len != 0) { + const payload_padding_ty_ref = try self.arrayType(@intCast(u32, payload_padding_len), u8_ty_ref); + members.appendAssumeCapacity(.{ .name = "padding_payload", .ty = payload_padding_ty_ref }); + } + + if (has_tag and !tag_first) { + members.appendAssumeCapacity(tag_member); + } + + if (layout.padding != 0) { + const padding_ty_ref = try self.arrayType(layout.padding, u8_ty_ref); + members.appendAssumeCapacity(.{ .name = "padding", .ty = padding_ty_ref }); + } + + return try self.simpleStructType(members.slice()); + } + /// Turn a Zig type into a SPIR-V Type, and return a reference to it. fn resolveType(self: *DeclGen, ty: Type, repr: Repr) Error!SpvType.Ref { log.debug("resolveType: ty = {}", .{ty.fmtDebug()}); @@ -733,16 +889,11 @@ pub const DeclGen = struct { }, .Array => { const elem_ty = ty.childType(); + const elem_ty_ref = try self.resolveType(elem_ty, .indirect); const total_len = std.math.cast(u32, ty.arrayLenIncludingSentinel()) orelse { return self.fail("array type of {} elements is too large", .{ty.arrayLenIncludingSentinel()}); }; - - const payload = try self.spv.arena.create(SpvType.Payload.Array); - payload.* = .{ - .element_type = try self.resolveType(elem_ty, repr), - .length = total_len, - }; - return try self.spv.resolveType(SpvType.initPayload(&payload.base)); + return try self.arrayType(total_len, elem_ty_ref); }, .Fn => { // TODO: Put this somewhere in Sema.zig @@ -809,7 +960,7 @@ pub const DeclGen = struct { const field_val = tuple.values[i]; if (field_val.tag() != .unreachable_value or !field_ty.hasRuntimeBitsIgnoreComptime()) continue; members[member_index] = .{ - .ty = try self.resolveType(field_ty, repr), + .ty = try self.resolveType(field_ty, .indirect), }; member_index += 1; } @@ -823,7 +974,7 @@ pub const DeclGen = struct { const struct_ty = ty.castTag(.@"struct").?.data; if (struct_ty.layout == .Packed) { - return try self.resolveType(struct_ty.backing_int_ty, repr); + return try self.resolveType(struct_ty.backing_int_ty, .indirect); } const members = try self.spv.arena.alloc(SpvType.Payload.Struct.Member, struct_ty.fields.count()); @@ -832,7 +983,7 @@ pub const DeclGen = struct { if (field.is_comptime or !field.ty.hasRuntimeBits()) continue; members[member_index] = .{ - .ty = try self.resolveType(field.ty, repr), + .ty = try self.resolveType(field.ty, .indirect), .name = struct_ty.fields.keys()[i], }; member_index += 1; @@ -872,6 +1023,8 @@ pub const DeclGen = struct { .{ .ty = bool_ty_ref, .name = "valid" }, }); }, + .Union => return try self.resolveUnionType(ty, null), + .Null, .Undefined, .EnumLiteral,