mirror of
https://github.com/ziglang/zig.git
synced 2026-02-03 13:13:40 +00:00
spirv: cache for floats
This commit is contained in:
parent
b2a984cda6
commit
aade6f1195
@ -56,11 +56,40 @@ const Tag = enum {
|
||||
type_array,
|
||||
|
||||
// -- Values
|
||||
/// Value of type f16
|
||||
/// data is value
|
||||
float16,
|
||||
/// Value of type f32
|
||||
/// data is value
|
||||
float32,
|
||||
/// Value of type f64
|
||||
/// data is payload to Float16
|
||||
float64,
|
||||
|
||||
const SimpleType = enum { void, bool };
|
||||
|
||||
const VectorType = Key.VectorType;
|
||||
const ArrayType = Key.ArrayType;
|
||||
|
||||
const Float64 = struct {
|
||||
// Low-order 32 bits of the value.
|
||||
low: u32,
|
||||
// High-order 32 bits of the value.
|
||||
high: u32,
|
||||
|
||||
fn encode(value: f64) Float64 {
|
||||
const bits = @bitCast(u64, value);
|
||||
return .{
|
||||
.low = @truncate(u32, bits),
|
||||
.high = @truncate(u32, bits >> 32),
|
||||
};
|
||||
}
|
||||
|
||||
fn decode(self: Float64) f64 {
|
||||
const bits = @as(u64, self.low) | (@as(u64, self.high) << 32);
|
||||
return @bitCast(f64, bits);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
pub const Ref = enum(u32) { _ };
|
||||
@ -79,6 +108,7 @@ pub const Key = union(enum) {
|
||||
array_type: ArrayType,
|
||||
|
||||
// -- values
|
||||
float: Float,
|
||||
|
||||
pub const IntType = std.builtin.Type.Int;
|
||||
pub const FloatType = std.builtin.Type.Float;
|
||||
@ -98,9 +128,33 @@ pub const Key = union(enum) {
|
||||
stride: u32 = 0,
|
||||
};
|
||||
|
||||
/// Represents a numberic value of some type.
|
||||
pub const Float = struct {
|
||||
/// The type: 16, 32, or 64-bit float.
|
||||
ty: Ref,
|
||||
/// The actual value.
|
||||
value: Value,
|
||||
|
||||
pub const Value = union(enum) {
|
||||
float16: f16,
|
||||
float32: f32,
|
||||
float64: f64,
|
||||
};
|
||||
};
|
||||
|
||||
fn hash(self: Key) u32 {
|
||||
var hasher = std.hash.Wyhash.init(0);
|
||||
std.hash.autoHash(&hasher, self);
|
||||
switch (self) {
|
||||
.float => |float| {
|
||||
std.hash.autoHash(&hasher, float.ty);
|
||||
switch (float.value) {
|
||||
.float16 => |value| std.hash.autoHash(&hasher, @bitCast(u16, value)),
|
||||
.float32 => |value| std.hash.autoHash(&hasher, @bitCast(u32, value)),
|
||||
.float64 => |value| std.hash.autoHash(&hasher, @bitCast(u64, value)),
|
||||
}
|
||||
},
|
||||
inline else => |key| std.hash.autoHash(&hasher, key),
|
||||
}
|
||||
return @truncate(u32, hasher.final());
|
||||
}
|
||||
|
||||
@ -141,7 +195,7 @@ pub fn deinit(self: *Self, spv: *const Module) void {
|
||||
/// This function returns a spir-v section of (only) constant and type instructions.
|
||||
/// Additionally, decorations, debug names, etc, are all directly emitted into the
|
||||
/// `spv` module. The section is allocated with `spv.gpa`.
|
||||
pub fn materialize(self: *Self, spv: *Module) !Section {
|
||||
pub fn materialize(self: *const Self, spv: *Module) !Section {
|
||||
var section = Section{};
|
||||
errdefer section.deinit(spv.gpa);
|
||||
for (self.items.items(.result_id), 0..) |result_id, index| {
|
||||
@ -151,7 +205,7 @@ pub fn materialize(self: *Self, spv: *Module) !Section {
|
||||
}
|
||||
|
||||
fn emit(
|
||||
self: *Self,
|
||||
self: *const Self,
|
||||
spv: *Module,
|
||||
result_id: IdResult,
|
||||
ref: Ref,
|
||||
@ -206,9 +260,34 @@ fn emit(
|
||||
try spv.decorate(result_id, .{ .ArrayStride = .{ .array_stride = array.stride } });
|
||||
}
|
||||
},
|
||||
.float => |float| {
|
||||
const ty_id = self.resultId(float.ty);
|
||||
const lit: spec.LiteralContextDependentNumber = switch (float.value) {
|
||||
.float16 => |value| .{ .uint32 = @bitCast(u16, value) },
|
||||
.float32 => |value| .{ .float32 = value },
|
||||
.float64 => |value| .{ .float64 = value },
|
||||
};
|
||||
try section.emit(spv.gpa, .OpConstant, .{
|
||||
.id_result_type = ty_id,
|
||||
.id_result = result_id,
|
||||
.value = lit,
|
||||
});
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the ref for a key that has already been added to the cache.
|
||||
fn get(self: *const Self, key: Key) Ref {
|
||||
const adapter: Key.Adapter = .{ .self = self };
|
||||
const index = self.map.getIndexAdapted(key, adapter).?;
|
||||
return @intToEnum(Ref, index);
|
||||
}
|
||||
|
||||
/// Get the result-id for a key that has already been added to the cache.
|
||||
fn getId(self: *const Self, key: Key) IdResult {
|
||||
return self.resultId(self.get(key));
|
||||
}
|
||||
|
||||
/// Add a key to this cache. Returns a reference to the key that
|
||||
/// was added. The corresponding result-id can be queried using
|
||||
/// self.resultId with the result.
|
||||
@ -251,6 +330,24 @@ pub fn resolve(self: *Self, spv: *Module, key: Key) !Ref {
|
||||
.result_id = result_id,
|
||||
.data = try self.addExtra(spv, array),
|
||||
},
|
||||
.float => |float| switch (self.lookup(float.ty).float_type.bits) {
|
||||
16 => .{
|
||||
.tag = .float16,
|
||||
.result_id = result_id,
|
||||
.data = @bitCast(u16, float.value.float16),
|
||||
},
|
||||
32 => .{
|
||||
.tag = .float32,
|
||||
.result_id = result_id,
|
||||
.data = @bitCast(u32, float.value.float32),
|
||||
},
|
||||
64 => .{
|
||||
.tag = .float64,
|
||||
.result_id = result_id,
|
||||
.data = try self.addExtra(spv, Tag.Float64.encode(float.value.float64)),
|
||||
},
|
||||
else => unreachable,
|
||||
},
|
||||
};
|
||||
try self.items.append(spv.gpa, item);
|
||||
|
||||
@ -285,6 +382,18 @@ pub fn lookup(self: *const Self, ref: Ref) Key {
|
||||
} },
|
||||
.type_vector => .{ .vector_type = self.extraData(Tag.VectorType, data) },
|
||||
.type_array => .{ .array_type = self.extraData(Tag.ArrayType, data) },
|
||||
.float16 => .{ .float = .{
|
||||
.ty = self.get(.{ .float_type = .{ .bits = 16 } }),
|
||||
.value = .{ .float16 = @bitCast(f16, @intCast(u16, data)) },
|
||||
} },
|
||||
.float32 => .{ .float = .{
|
||||
.ty = self.get(.{ .float_type = .{ .bits = 32 } }),
|
||||
.value = .{ .float32 = @bitCast(f32, data) },
|
||||
} },
|
||||
.float64 => .{ .float = .{
|
||||
.ty = self.get(.{ .float_type = .{ .bits = 32 } }),
|
||||
.value = .{ .float64 = self.extraData(Tag.Float64, data).decode() },
|
||||
} },
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user