From 3eafe3033ef83e5b34e3ccbd6e803c7a046df390 Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Sat, 26 Nov 2022 12:23:07 +0100 Subject: [PATCH] spirv: improve storage efficiency for integer and float types In practice there are only a few variations of these types allowed, so it kind-of makes sense to write them all out. Because the types are hashed this does not actually save all that many bytes in the long run, though. Perhaps some of these types should be pre-registered? --- src/codegen/spirv.zig | 13 +-- src/codegen/spirv/Assembler.zig | 50 +++++------ src/codegen/spirv/Module.zig | 19 ++-- src/codegen/spirv/type.zig | 153 +++++++++++++++++++++++++++++--- 4 files changed, 181 insertions(+), 54 deletions(-) diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 93bb7b19f8..22045a282d 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -451,12 +451,7 @@ pub const DeclGen = struct { return self.todo("Implement {s} composite int type of {} bits", .{ @tagName(signedness), bits }); }; - const payload = try self.spv.arena.create(SpvType.Payload.Int); - payload.* = .{ - .width = backing_bits, - .signedness = signedness, - }; - return try self.spv.resolveType(SpvType.initPayload(&payload.base)); + return try self.spv.resolveType(try SpvType.int(self.spv.arena, signedness, backing_bits)); } /// Turn a Zig type into a SPIR-V Type, and return a reference to it. @@ -495,11 +490,7 @@ pub const DeclGen = struct { return self.fail("Floating point width of {} bits is not supported for the current SPIR-V feature set", .{bits}); } - const payload = try self.spv.arena.create(SpvType.Payload.Float); - payload.* = .{ - .width = bits, - }; - return try self.spv.resolveType(SpvType.initPayload(&payload.base)); + return try self.spv.resolveType(SpvType.float(bits)); }, .Fn => { // TODO: Put this somewhere in Sema.zig diff --git a/src/codegen/spirv/Assembler.zig b/src/codegen/spirv/Assembler.zig index 3a35bc4db9..c4432102a0 100644 --- a/src/codegen/spirv/Assembler.zig +++ b/src/codegen/spirv/Assembler.zig @@ -266,27 +266,28 @@ fn processTypeInstruction(self: *Assembler) !AsmValue { .OpTypeVoid => SpvType.initTag(.void), .OpTypeBool => SpvType.initTag(.bool), .OpTypeInt => blk: { - const payload = try self.spv.arena.create(SpvType.Payload.Int); const signedness: std.builtin.Signedness = switch (operands[2].literal32) { 0 => .unsigned, 1 => .signed, else => { // TODO: Improve source location. - return self.fail(0, "'{}' is not a valid signedness (expected 0 or 1)", .{operands[2].literal32}); + return self.fail(0, "{} is not a valid signedness (expected 0 or 1)", .{operands[2].literal32}); }, }; - payload.* = .{ - .width = operands[1].literal32, - .signedness = signedness, + const width = std.math.cast(u16, operands[1].literal32) orelse { + return self.fail(0, "int type of {} bits is too large", .{operands[1].literal32}); }; - break :blk SpvType.initPayload(&payload.base); + break :blk try SpvType.int(self.spv.arena, signedness, width); }, .OpTypeFloat => blk: { - const payload = try self.spv.arena.create(SpvType.Payload.Float); - payload.* = .{ - .width = operands[1].literal32, - }; - break :blk SpvType.initPayload(&payload.base); + const bits = operands[1].literal32; + switch (bits) { + 16, 32, 64 => {}, + else => { + return self.fail(0, "{} is not a valid bit count for floats (expected 16, 32 or 64)", .{bits}); + }, + } + break :blk SpvType.float(@intCast(u16, bits)); }, .OpTypeVector => blk: { const payload = try self.spv.arena.create(SpvType.Payload.Vector); @@ -754,21 +755,18 @@ fn parseContextDependentNumber(self: *Assembler) !void { const tok = self.currentToken(); const result_type_ref = try self.resolveTypeRef(self.inst.operands.items[0].ref_id); const result_type = self.spv.type_cache.keys()[@enumToInt(result_type_ref)]; - switch (result_type.tag()) { - .int => { - const int = result_type.castTag(.int).?; - try self.parseContextDependentInt(int.signedness, int.width); - }, - .float => { - const width = result_type.castTag(.float).?.width; - switch (width) { - 16 => try self.parseContextDependentFloat(16), - 32 => try self.parseContextDependentFloat(32), - 64 => try self.parseContextDependentFloat(64), - else => return self.fail(tok.start, "cannot parse {}-bit float literal", .{width}), - } - }, - else => return self.fail(tok.start, "cannot parse literal constant {s}", .{@tagName(result_type.tag())}), + if (result_type.isInt()) { + try self.parseContextDependentInt(result_type.intSignedness(), result_type.intFloatBits()); + } else if (result_type.isFloat()) { + const width = result_type.intFloatBits(); + switch (width) { + 16 => try self.parseContextDependentFloat(16), + 32 => try self.parseContextDependentFloat(32), + 64 => try self.parseContextDependentFloat(64), + else => return self.fail(tok.start, "cannot parse {}-bit float literal", .{width}), + } + } else { + return self.fail(tok.start, "cannot parse literal constant {s}", .{@tagName(result_type.tag())}); } } diff --git a/src/codegen/spirv/Module.zig b/src/codegen/spirv/Module.zig index 2b62bcaf0e..6998d13f42 100644 --- a/src/codegen/spirv/Module.zig +++ b/src/codegen/spirv/Module.zig @@ -250,21 +250,30 @@ pub fn emitType(self: *Module, ty: Type) !IdResultType { switch (ty.tag()) { .void => try types.emit(self.gpa, .OpTypeVoid, result_id_operand), .bool => try types.emit(self.gpa, .OpTypeBool, result_id_operand), - .int => { - const signedness: spec.LiteralInteger = switch (ty.payload(.int).signedness) { + .u8, + .u16, + .u32, + .u64, + .i8, + .i16, + .i32, + .i64, + .int, + => { + const signedness: spec.LiteralInteger = switch (ty.intSignedness()) { .unsigned => 0, .signed => 1, }; try types.emit(self.gpa, .OpTypeInt, .{ .id_result = result_id, - .width = ty.payload(.int).width, + .width = ty.intFloatBits(), .signedness = signedness, }); }, - .float => try types.emit(self.gpa, .OpTypeFloat, .{ + .f16, .f32, .f64 => try types.emit(self.gpa, .OpTypeFloat, .{ .id_result = result_id, - .width = ty.payload(.float).width, + .width = ty.intFloatBits(), }), .vector => try types.emit(self.gpa, .OpTypeVector, .{ .id_result = result_id, diff --git a/src/codegen/spirv/type.zig b/src/codegen/spirv/type.zig index 5ec013cb0a..a65ddca01e 100644 --- a/src/codegen/spirv/type.zig +++ b/src/codegen/spirv/type.zig @@ -3,6 +3,8 @@ const std = @import("std"); const assert = std.debug.assert; +const Signedness = std.builtin.Signedness; +const Allocator = std.mem.Allocator; const spec = @import("spec.zig"); @@ -23,6 +25,41 @@ pub const Type = extern union { return .{ .ptr_otherwise = pl }; } + pub fn int(arena: Allocator, signedness: Signedness, bits: u16) !Type { + const bits_and_signedness = switch (signedness) { + .signed => -@as(i32, bits), + .unsigned => @as(i32, bits), + }; + + return switch (bits_and_signedness) { + 8 => initTag(.u8), + 16 => initTag(.u16), + 32 => initTag(.u32), + 64 => initTag(.u64), + -8 => initTag(.i8), + -16 => initTag(.i16), + -32 => initTag(.i32), + -64 => initTag(.i64), + else => { + const int_payload = try arena.create(Payload.Int); + int_payload.* = .{ + .width = bits, + .signedness = signedness, + }; + return initPayload(&int_payload.base); + }, + }; + } + + pub fn float(bits: u16) Type { + return switch (bits) { + 16 => initTag(.f16), + 32 => initTag(.f32), + 64 => initTag(.f64), + else => unreachable, // Enable more types if required. + }; + } + pub fn tag(self: Type) Tag { if (@enumToInt(self.tag_if_small_enough) < Tag.no_payload_count) { return self.tag_if_small_enough; @@ -80,9 +117,19 @@ pub const Type = extern union { .queue, .pipe_storage, .named_barrier, + .u8, + .u16, + .u32, + .u64, + .i8, + .i16, + .i32, + .i64, + .f16, + .f32, + .f64, => return true, .int, - .float, .vector, .matrix, .sampled_image, @@ -132,6 +179,17 @@ pub const Type = extern union { .queue, .pipe_storage, .named_barrier, + .u8, + .u16, + .u32, + .u64, + .i8, + .i16, + .i32, + .i64, + .f16, + .f32, + .f64, => {}, else => self.hashPayload(@field(Tag, field.name), &hasher), } @@ -185,6 +243,53 @@ pub const Type = extern union { }; } + pub fn isInt(self: Type) bool { + return switch (self.tag()) { + .u8, + .u16, + .u32, + .u64, + .i8, + .i16, + .i32, + .i64, + .int, + => true, + else => false, + }; + } + + pub fn isFloat(self: Type) bool { + return switch (self.tag()) { + .f16, .f32, .f64 => true, + else => false, + }; + } + + /// Returns the number of bits that make up an int or float type. + /// Asserts type is either int or float. + pub fn intFloatBits(self: Type) u16 { + return switch (self.tag()) { + .u8, .i8 => 8, + .u16, .i16, .f16 => 16, + .u32, .i32, .f32 => 32, + .u64, .i64, .f64 => 64, + .int => self.payload(.int).width, + else => unreachable, + }; + } + + /// Returns the signedness of an integer type. + /// Asserts that the type is an int. + pub fn intSignedness(self: Type) Signedness { + return switch (self.tag()) { + .u8, .u16, .u32, .u64 => .unsigned, + .i8, .i16, .i32, .i64 => .signed, + .int => self.payload(.int).signedness, + else => unreachable, + }; + } + pub const Tag = enum(usize) { void, bool, @@ -195,10 +300,20 @@ pub const Type = extern union { queue, pipe_storage, named_barrier, + u8, + u16, + u32, + u64, + i8, + i16, + i32, + i64, + f16, + f32, + f64, // After this, the tag requires a payload. int, - float, vector, matrix, image, @@ -211,14 +326,33 @@ pub const Type = extern union { function, pipe, - pub const last_no_payload_tag = Tag.named_barrier; + pub const last_no_payload_tag = Tag.f64; pub const no_payload_count = @enumToInt(last_no_payload_tag) + 1; pub fn Type(comptime t: Tag) type { return switch (t) { - .void, .bool, .sampler, .event, .device_event, .reserve_id, .queue, .pipe_storage, .named_barrier => @compileError("Type Tag " ++ @tagName(t) ++ " has no payload"), + .void, + .bool, + .sampler, + .event, + .device_event, + .reserve_id, + .queue, + .pipe_storage, + .named_barrier, + .u8, + .u16, + .u32, + .u64, + .i8, + .i16, + .i32, + .i64, + .f16, + .f32, + .f64, + => @compileError("Type Tag " ++ @tagName(t) ++ " has no payload"), .int => Payload.Int, - .float => Payload.Float, .vector => Payload.Vector, .matrix => Payload.Matrix, .image => Payload.Image, @@ -239,13 +373,8 @@ pub const Type = extern union { pub const Int = struct { base: Payload = .{ .tag = .int }, - width: u32, - signedness: std.builtin.Signedness, - }; - - pub const Float = struct { - base: Payload = .{ .tag = .float }, - width: u32, + width: u16, + signedness: Signedness, }; pub const Vector = struct {