diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 93bb7b19f8..22045a282d 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -451,12 +451,7 @@ pub const DeclGen = struct { return self.todo("Implement {s} composite int type of {} bits", .{ @tagName(signedness), bits }); }; - const payload = try self.spv.arena.create(SpvType.Payload.Int); - payload.* = .{ - .width = backing_bits, - .signedness = signedness, - }; - return try self.spv.resolveType(SpvType.initPayload(&payload.base)); + return try self.spv.resolveType(try SpvType.int(self.spv.arena, signedness, backing_bits)); } /// Turn a Zig type into a SPIR-V Type, and return a reference to it. @@ -495,11 +490,7 @@ pub const DeclGen = struct { return self.fail("Floating point width of {} bits is not supported for the current SPIR-V feature set", .{bits}); } - const payload = try self.spv.arena.create(SpvType.Payload.Float); - payload.* = .{ - .width = bits, - }; - return try self.spv.resolveType(SpvType.initPayload(&payload.base)); + return try self.spv.resolveType(SpvType.float(bits)); }, .Fn => { // TODO: Put this somewhere in Sema.zig diff --git a/src/codegen/spirv/Assembler.zig b/src/codegen/spirv/Assembler.zig index 3a35bc4db9..c4432102a0 100644 --- a/src/codegen/spirv/Assembler.zig +++ b/src/codegen/spirv/Assembler.zig @@ -266,27 +266,28 @@ fn processTypeInstruction(self: *Assembler) !AsmValue { .OpTypeVoid => SpvType.initTag(.void), .OpTypeBool => SpvType.initTag(.bool), .OpTypeInt => blk: { - const payload = try self.spv.arena.create(SpvType.Payload.Int); const signedness: std.builtin.Signedness = switch (operands[2].literal32) { 0 => .unsigned, 1 => .signed, else => { // TODO: Improve source location. - return self.fail(0, "'{}' is not a valid signedness (expected 0 or 1)", .{operands[2].literal32}); + return self.fail(0, "{} is not a valid signedness (expected 0 or 1)", .{operands[2].literal32}); }, }; - payload.* = .{ - .width = operands[1].literal32, - .signedness = signedness, + const width = std.math.cast(u16, operands[1].literal32) orelse { + return self.fail(0, "int type of {} bits is too large", .{operands[1].literal32}); }; - break :blk SpvType.initPayload(&payload.base); + break :blk try SpvType.int(self.spv.arena, signedness, width); }, .OpTypeFloat => blk: { - const payload = try self.spv.arena.create(SpvType.Payload.Float); - payload.* = .{ - .width = operands[1].literal32, - }; - break :blk SpvType.initPayload(&payload.base); + const bits = operands[1].literal32; + switch (bits) { + 16, 32, 64 => {}, + else => { + return self.fail(0, "{} is not a valid bit count for floats (expected 16, 32 or 64)", .{bits}); + }, + } + break :blk SpvType.float(@intCast(u16, bits)); }, .OpTypeVector => blk: { const payload = try self.spv.arena.create(SpvType.Payload.Vector); @@ -754,21 +755,18 @@ fn parseContextDependentNumber(self: *Assembler) !void { const tok = self.currentToken(); const result_type_ref = try self.resolveTypeRef(self.inst.operands.items[0].ref_id); const result_type = self.spv.type_cache.keys()[@enumToInt(result_type_ref)]; - switch (result_type.tag()) { - .int => { - const int = result_type.castTag(.int).?; - try self.parseContextDependentInt(int.signedness, int.width); - }, - .float => { - const width = result_type.castTag(.float).?.width; - switch (width) { - 16 => try self.parseContextDependentFloat(16), - 32 => try self.parseContextDependentFloat(32), - 64 => try self.parseContextDependentFloat(64), - else => return self.fail(tok.start, "cannot parse {}-bit float literal", .{width}), - } - }, - else => return self.fail(tok.start, "cannot parse literal constant {s}", .{@tagName(result_type.tag())}), + if (result_type.isInt()) { + try self.parseContextDependentInt(result_type.intSignedness(), result_type.intFloatBits()); + } else if (result_type.isFloat()) { + const width = result_type.intFloatBits(); + switch (width) { + 16 => try self.parseContextDependentFloat(16), + 32 => try self.parseContextDependentFloat(32), + 64 => try self.parseContextDependentFloat(64), + else => return self.fail(tok.start, "cannot parse {}-bit float literal", .{width}), + } + } else { + return self.fail(tok.start, "cannot parse literal constant {s}", .{@tagName(result_type.tag())}); } } diff --git a/src/codegen/spirv/Module.zig b/src/codegen/spirv/Module.zig index 2b62bcaf0e..6998d13f42 100644 --- a/src/codegen/spirv/Module.zig +++ b/src/codegen/spirv/Module.zig @@ -250,21 +250,30 @@ pub fn emitType(self: *Module, ty: Type) !IdResultType { switch (ty.tag()) { .void => try types.emit(self.gpa, .OpTypeVoid, result_id_operand), .bool => try types.emit(self.gpa, .OpTypeBool, result_id_operand), - .int => { - const signedness: spec.LiteralInteger = switch (ty.payload(.int).signedness) { + .u8, + .u16, + .u32, + .u64, + .i8, + .i16, + .i32, + .i64, + .int, + => { + const signedness: spec.LiteralInteger = switch (ty.intSignedness()) { .unsigned => 0, .signed => 1, }; try types.emit(self.gpa, .OpTypeInt, .{ .id_result = result_id, - .width = ty.payload(.int).width, + .width = ty.intFloatBits(), .signedness = signedness, }); }, - .float => try types.emit(self.gpa, .OpTypeFloat, .{ + .f16, .f32, .f64 => try types.emit(self.gpa, .OpTypeFloat, .{ .id_result = result_id, - .width = ty.payload(.float).width, + .width = ty.intFloatBits(), }), .vector => try types.emit(self.gpa, .OpTypeVector, .{ .id_result = result_id, diff --git a/src/codegen/spirv/type.zig b/src/codegen/spirv/type.zig index 5ec013cb0a..a65ddca01e 100644 --- a/src/codegen/spirv/type.zig +++ b/src/codegen/spirv/type.zig @@ -3,6 +3,8 @@ const std = @import("std"); const assert = std.debug.assert; +const Signedness = std.builtin.Signedness; +const Allocator = std.mem.Allocator; const spec = @import("spec.zig"); @@ -23,6 +25,41 @@ pub const Type = extern union { return .{ .ptr_otherwise = pl }; } + pub fn int(arena: Allocator, signedness: Signedness, bits: u16) !Type { + const bits_and_signedness = switch (signedness) { + .signed => -@as(i32, bits), + .unsigned => @as(i32, bits), + }; + + return switch (bits_and_signedness) { + 8 => initTag(.u8), + 16 => initTag(.u16), + 32 => initTag(.u32), + 64 => initTag(.u64), + -8 => initTag(.i8), + -16 => initTag(.i16), + -32 => initTag(.i32), + -64 => initTag(.i64), + else => { + const int_payload = try arena.create(Payload.Int); + int_payload.* = .{ + .width = bits, + .signedness = signedness, + }; + return initPayload(&int_payload.base); + }, + }; + } + + pub fn float(bits: u16) Type { + return switch (bits) { + 16 => initTag(.f16), + 32 => initTag(.f32), + 64 => initTag(.f64), + else => unreachable, // Enable more types if required. + }; + } + pub fn tag(self: Type) Tag { if (@enumToInt(self.tag_if_small_enough) < Tag.no_payload_count) { return self.tag_if_small_enough; @@ -80,9 +117,19 @@ pub const Type = extern union { .queue, .pipe_storage, .named_barrier, + .u8, + .u16, + .u32, + .u64, + .i8, + .i16, + .i32, + .i64, + .f16, + .f32, + .f64, => return true, .int, - .float, .vector, .matrix, .sampled_image, @@ -132,6 +179,17 @@ pub const Type = extern union { .queue, .pipe_storage, .named_barrier, + .u8, + .u16, + .u32, + .u64, + .i8, + .i16, + .i32, + .i64, + .f16, + .f32, + .f64, => {}, else => self.hashPayload(@field(Tag, field.name), &hasher), } @@ -185,6 +243,53 @@ pub const Type = extern union { }; } + pub fn isInt(self: Type) bool { + return switch (self.tag()) { + .u8, + .u16, + .u32, + .u64, + .i8, + .i16, + .i32, + .i64, + .int, + => true, + else => false, + }; + } + + pub fn isFloat(self: Type) bool { + return switch (self.tag()) { + .f16, .f32, .f64 => true, + else => false, + }; + } + + /// Returns the number of bits that make up an int or float type. + /// Asserts type is either int or float. + pub fn intFloatBits(self: Type) u16 { + return switch (self.tag()) { + .u8, .i8 => 8, + .u16, .i16, .f16 => 16, + .u32, .i32, .f32 => 32, + .u64, .i64, .f64 => 64, + .int => self.payload(.int).width, + else => unreachable, + }; + } + + /// Returns the signedness of an integer type. + /// Asserts that the type is an int. + pub fn intSignedness(self: Type) Signedness { + return switch (self.tag()) { + .u8, .u16, .u32, .u64 => .unsigned, + .i8, .i16, .i32, .i64 => .signed, + .int => self.payload(.int).signedness, + else => unreachable, + }; + } + pub const Tag = enum(usize) { void, bool, @@ -195,10 +300,20 @@ pub const Type = extern union { queue, pipe_storage, named_barrier, + u8, + u16, + u32, + u64, + i8, + i16, + i32, + i64, + f16, + f32, + f64, // After this, the tag requires a payload. int, - float, vector, matrix, image, @@ -211,14 +326,33 @@ pub const Type = extern union { function, pipe, - pub const last_no_payload_tag = Tag.named_barrier; + pub const last_no_payload_tag = Tag.f64; pub const no_payload_count = @enumToInt(last_no_payload_tag) + 1; pub fn Type(comptime t: Tag) type { return switch (t) { - .void, .bool, .sampler, .event, .device_event, .reserve_id, .queue, .pipe_storage, .named_barrier => @compileError("Type Tag " ++ @tagName(t) ++ " has no payload"), + .void, + .bool, + .sampler, + .event, + .device_event, + .reserve_id, + .queue, + .pipe_storage, + .named_barrier, + .u8, + .u16, + .u32, + .u64, + .i8, + .i16, + .i32, + .i64, + .f16, + .f32, + .f64, + => @compileError("Type Tag " ++ @tagName(t) ++ " has no payload"), .int => Payload.Int, - .float => Payload.Float, .vector => Payload.Vector, .matrix => Payload.Matrix, .image => Payload.Image, @@ -239,13 +373,8 @@ pub const Type = extern union { pub const Int = struct { base: Payload = .{ .tag = .int }, - width: u32, - signedness: std.builtin.Signedness, - }; - - pub const Float = struct { - base: Payload = .{ .tag = .float }, - width: u32, + width: u16, + signedness: Signedness, }; pub const Vector = struct {