spirv: cache for floats

This commit is contained in:
Robin Voetter 2023-05-29 14:10:02 +02:00
parent b2a984cda6
commit aade6f1195
No known key found for this signature in database
GPG Key ID: E755662F227CB468

View File

@ -56,11 +56,40 @@ const Tag = enum {
type_array,
// -- Values
/// Value of type f16
/// data is value
float16,
/// Value of type f32
/// data is value
float32,
/// Value of type f64
/// data is payload to Float16
float64,
const SimpleType = enum { void, bool };
const VectorType = Key.VectorType;
const ArrayType = Key.ArrayType;
const Float64 = struct {
// Low-order 32 bits of the value.
low: u32,
// High-order 32 bits of the value.
high: u32,
fn encode(value: f64) Float64 {
const bits = @bitCast(u64, value);
return .{
.low = @truncate(u32, bits),
.high = @truncate(u32, bits >> 32),
};
}
fn decode(self: Float64) f64 {
const bits = @as(u64, self.low) | (@as(u64, self.high) << 32);
return @bitCast(f64, bits);
}
};
};
pub const Ref = enum(u32) { _ };
@ -79,6 +108,7 @@ pub const Key = union(enum) {
array_type: ArrayType,
// -- values
float: Float,
pub const IntType = std.builtin.Type.Int;
pub const FloatType = std.builtin.Type.Float;
@ -98,9 +128,33 @@ pub const Key = union(enum) {
stride: u32 = 0,
};
/// Represents a numberic value of some type.
pub const Float = struct {
/// The type: 16, 32, or 64-bit float.
ty: Ref,
/// The actual value.
value: Value,
pub const Value = union(enum) {
float16: f16,
float32: f32,
float64: f64,
};
};
fn hash(self: Key) u32 {
var hasher = std.hash.Wyhash.init(0);
std.hash.autoHash(&hasher, self);
switch (self) {
.float => |float| {
std.hash.autoHash(&hasher, float.ty);
switch (float.value) {
.float16 => |value| std.hash.autoHash(&hasher, @bitCast(u16, value)),
.float32 => |value| std.hash.autoHash(&hasher, @bitCast(u32, value)),
.float64 => |value| std.hash.autoHash(&hasher, @bitCast(u64, value)),
}
},
inline else => |key| std.hash.autoHash(&hasher, key),
}
return @truncate(u32, hasher.final());
}
@ -141,7 +195,7 @@ pub fn deinit(self: *Self, spv: *const Module) void {
/// This function returns a spir-v section of (only) constant and type instructions.
/// Additionally, decorations, debug names, etc, are all directly emitted into the
/// `spv` module. The section is allocated with `spv.gpa`.
pub fn materialize(self: *Self, spv: *Module) !Section {
pub fn materialize(self: *const Self, spv: *Module) !Section {
var section = Section{};
errdefer section.deinit(spv.gpa);
for (self.items.items(.result_id), 0..) |result_id, index| {
@ -151,7 +205,7 @@ pub fn materialize(self: *Self, spv: *Module) !Section {
}
fn emit(
self: *Self,
self: *const Self,
spv: *Module,
result_id: IdResult,
ref: Ref,
@ -206,9 +260,34 @@ fn emit(
try spv.decorate(result_id, .{ .ArrayStride = .{ .array_stride = array.stride } });
}
},
.float => |float| {
const ty_id = self.resultId(float.ty);
const lit: spec.LiteralContextDependentNumber = switch (float.value) {
.float16 => |value| .{ .uint32 = @bitCast(u16, value) },
.float32 => |value| .{ .float32 = value },
.float64 => |value| .{ .float64 = value },
};
try section.emit(spv.gpa, .OpConstant, .{
.id_result_type = ty_id,
.id_result = result_id,
.value = lit,
});
},
}
}
/// Get the ref for a key that has already been added to the cache.
fn get(self: *const Self, key: Key) Ref {
const adapter: Key.Adapter = .{ .self = self };
const index = self.map.getIndexAdapted(key, adapter).?;
return @intToEnum(Ref, index);
}
/// Get the result-id for a key that has already been added to the cache.
fn getId(self: *const Self, key: Key) IdResult {
return self.resultId(self.get(key));
}
/// Add a key to this cache. Returns a reference to the key that
/// was added. The corresponding result-id can be queried using
/// self.resultId with the result.
@ -251,6 +330,24 @@ pub fn resolve(self: *Self, spv: *Module, key: Key) !Ref {
.result_id = result_id,
.data = try self.addExtra(spv, array),
},
.float => |float| switch (self.lookup(float.ty).float_type.bits) {
16 => .{
.tag = .float16,
.result_id = result_id,
.data = @bitCast(u16, float.value.float16),
},
32 => .{
.tag = .float32,
.result_id = result_id,
.data = @bitCast(u32, float.value.float32),
},
64 => .{
.tag = .float64,
.result_id = result_id,
.data = try self.addExtra(spv, Tag.Float64.encode(float.value.float64)),
},
else => unreachable,
},
};
try self.items.append(spv.gpa, item);
@ -285,6 +382,18 @@ pub fn lookup(self: *const Self, ref: Ref) Key {
} },
.type_vector => .{ .vector_type = self.extraData(Tag.VectorType, data) },
.type_array => .{ .array_type = self.extraData(Tag.ArrayType, data) },
.float16 => .{ .float = .{
.ty = self.get(.{ .float_type = .{ .bits = 16 } }),
.value = .{ .float16 = @bitCast(f16, @intCast(u16, data)) },
} },
.float32 => .{ .float = .{
.ty = self.get(.{ .float_type = .{ .bits = 32 } }),
.value = .{ .float32 = @bitCast(f32, data) },
} },
.float64 => .{ .float = .{
.ty = self.get(.{ .float_type = .{ .bits = 32 } }),
.value = .{ .float64 = self.extraData(Tag.Float64, data).decode() },
} },
};
}