From 77ef78a0ef00392c4e157ebc170d6c4d98f586fb Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Sun, 21 Jan 2024 01:39:20 +0100 Subject: [PATCH] spirv: clean up arithmeticTypeInfo a bit - No longer returns an error - Returns more useful vector info --- src/codegen/spirv.zig | 69 +++++++++++++++++++------------------------ 1 file changed, 31 insertions(+), 38 deletions(-) diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 56f1832b5d..cb3d1be8f0 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -373,8 +373,9 @@ const DeclGen = struct { /// For `composite_integer` this is 0 (TODO) backing_bits: u16, - /// Whether the type is a vector. - is_vector: bool, + /// Null if this type is a scalar, or the length + /// of the vector otherwise. + vector_len: ?u32, /// Whether the inner type is signed. Only relevant for integers. signedness: std.builtin.Signedness, @@ -597,32 +598,37 @@ const DeclGen = struct { return self.backingIntBits(ty) == null; } - fn arithmeticTypeInfo(self: *DeclGen, ty: Type) !ArithmeticTypeInfo { + fn arithmeticTypeInfo(self: *DeclGen, ty: Type) ArithmeticTypeInfo { const mod = self.module; const target = self.getTarget(); - return switch (ty.zigTypeTag(mod)) { + var scalar_ty = ty.scalarType(mod); + if (scalar_ty.zigTypeTag(mod) == .Enum) { + scalar_ty = scalar_ty.intTagType(mod); + } + const vector_len = if (ty.isVector(mod)) ty.vectorLen(mod) else null; + return switch (scalar_ty.zigTypeTag(mod)) { .Bool => ArithmeticTypeInfo{ .bits = 1, // Doesn't matter for this class. .backing_bits = self.backingIntBits(1).?, - .is_vector = false, + .vector_len = vector_len, .signedness = .unsigned, // Technically, but doesn't matter for this class. .class = .bool, }, .Float => ArithmeticTypeInfo{ - .bits = ty.floatBits(target), - .backing_bits = ty.floatBits(target), // TODO: F80? - .is_vector = false, + .bits = scalar_ty.floatBits(target), + .backing_bits = scalar_ty.floatBits(target), // TODO: F80? + .vector_len = vector_len, .signedness = .signed, // Technically, but doesn't matter for this class. .class = .float, }, .Int => blk: { - const int_info = ty.intInfo(mod); + const int_info = scalar_ty.intInfo(mod); // TODO: Maybe it's useful to also return this value. const maybe_backing_bits = self.backingIntBits(int_info.bits); break :blk ArithmeticTypeInfo{ .bits = int_info.bits, .backing_bits = maybe_backing_bits orelse 0, - .is_vector = false, + .vector_len = vector_len, .signedness = int_info.signedness, .class = if (maybe_backing_bits) |backing_bits| if (backing_bits == int_info.bits) @@ -633,22 +639,9 @@ const DeclGen = struct { .composite_integer, }; }, - .Enum => return self.arithmeticTypeInfo(ty.intTagType(mod)), - // As of yet, there is no vector support in the self-hosted compiler. - .Vector => blk: { - const child_type = ty.childType(mod); - const child_ty_info = try self.arithmeticTypeInfo(child_type); - break :blk ArithmeticTypeInfo{ - .bits = child_ty_info.bits, - .backing_bits = child_ty_info.backing_bits, - .is_vector = true, - .signedness = child_ty_info.signedness, - .class = child_ty_info.class, - }; - }, - // TODO: For which types is this the case? - // else => self.todo("implement arithmeticTypeInfo for {}", .{ty.fmt(self.module)}), - else => unreachable, + .Enum => unreachable, + .Vector => unreachable, + else => unreachable, // Unhandled arithmetic type }; } @@ -2336,7 +2329,7 @@ const DeclGen = struct { const shift_ty = self.typeOf(bin_op.rhs); const scalar_shift_ty_ref = try self.resolveType(shift_ty.scalarType(mod), .direct); - const info = try self.arithmeticTypeInfo(result_ty); + const info = self.arithmeticTypeInfo(result_ty); switch (info.class) { .composite_integer => return self.todo("shift ops for composite integers", .{}), .integer, .strange_integer => {}, @@ -2393,7 +2386,7 @@ const DeclGen = struct { fn minMax(self: *DeclGen, result_ty: Type, op: std.math.CompareOperator, lhs_id: IdRef, rhs_id: IdRef) !IdRef { const result_ty_ref = try self.resolveType(result_ty, .direct); - const info = try self.arithmeticTypeInfo(result_ty); + const info = self.arithmeticTypeInfo(result_ty); // TODO: Use fmin for OpenCL const cmp_id = try self.cmp(op, Type.bool, result_ty, lhs_id, rhs_id); @@ -2516,7 +2509,7 @@ const DeclGen = struct { ) !IdRef { // Binary operations are generally applicable to both scalar and vector operations // in SPIR-V, but int and float versions of operations require different opcodes. - const info = try self.arithmeticTypeInfo(ty); + const info = self.arithmeticTypeInfo(ty); const opcode_index: usize = switch (info.class) { .composite_integer => { @@ -2579,7 +2572,7 @@ const DeclGen = struct { const bool_ty_ref = try self.resolveType(Type.bool, .direct); - const info = try self.arithmeticTypeInfo(operand_ty); + const info = self.arithmeticTypeInfo(operand_ty); switch (info.class) { .composite_integer => return self.todo("overflow ops for composite integers", .{}), .strange_integer, .integer => {}, @@ -2693,7 +2686,7 @@ const DeclGen = struct { const bool_ty_ref = try self.resolveType(Type.bool, .direct); - const info = try self.arithmeticTypeInfo(operand_ty); + const info = self.arithmeticTypeInfo(operand_ty); switch (info.class) { .composite_integer => return self.todo("overflow shift for composite integers", .{}), .integer, .strange_integer => {}, @@ -2777,7 +2770,7 @@ const DeclGen = struct { const scalar_ty_ref = try self.resolveType(scalar_ty, .direct); const scalar_ty_id = self.typeId(scalar_ty_ref); - const info = try self.arithmeticTypeInfo(operand_ty); + const info = self.arithmeticTypeInfo(operand_ty); var result_id = try self.extractField(scalar_ty, operand, 0); const len = operand_ty.vectorLen(mod); @@ -3093,7 +3086,7 @@ const DeclGen = struct { }; const opcode: Opcode = opcode: { - const info = try self.arithmeticTypeInfo(op_ty); + const info = self.arithmeticTypeInfo(op_ty); const signedness = switch (info.class) { .composite_integer => { return self.todo("binary operations for composite integers", .{}); @@ -3245,8 +3238,8 @@ const DeclGen = struct { const dst_ty = self.typeOfIndex(inst); const dst_ty_ref = try self.resolveType(dst_ty, .direct); - const src_info = try self.arithmeticTypeInfo(src_ty); - const dst_info = try self.arithmeticTypeInfo(dst_ty); + const src_info = self.arithmeticTypeInfo(src_ty); + const dst_info = self.arithmeticTypeInfo(dst_ty); if (src_info.backing_bits == dst_info.backing_bits) { return operand_id; @@ -3302,7 +3295,7 @@ const DeclGen = struct { const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; const operand_ty = self.typeOf(ty_op.operand); const operand_id = try self.resolve(ty_op.operand); - const operand_info = try self.arithmeticTypeInfo(operand_ty); + const operand_info = self.arithmeticTypeInfo(operand_ty); const dest_ty = self.typeOfIndex(inst); const dest_ty_id = try self.resolveTypeId(dest_ty); @@ -3328,7 +3321,7 @@ const DeclGen = struct { const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; const operand_id = try self.resolve(ty_op.operand); const dest_ty = self.typeOfIndex(inst); - const dest_info = try self.arithmeticTypeInfo(dest_ty); + const dest_info = self.arithmeticTypeInfo(dest_ty); const dest_ty_id = try self.resolveTypeId(dest_ty); const result_id = self.spv.allocId(); @@ -3369,7 +3362,7 @@ const DeclGen = struct { const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; const operand_id = try self.resolve(ty_op.operand); const result_ty = self.typeOfIndex(inst); - const info = try self.arithmeticTypeInfo(result_ty); + const info = self.arithmeticTypeInfo(result_ty); var wip = try self.elementWise(result_ty); defer wip.deinit();