spirv: improve generator

The spirv spec generator now also generates some support information:
Opcode gains a function to query a Zig type representing the operands
of the opcode. The idea is that this will enable a richer interface for
emitting spirv instructions.
This commit is contained in:
Robin Voetter 2022-01-21 15:05:02 +01:00
parent 0e6d2184ca
commit ff042e8006

View File

@ -1,5 +1,8 @@
const std = @import("std"); const std = @import("std");
const g = @import("spirv/grammar.zig"); const g = @import("spirv/grammar.zig");
const Allocator = std.mem.Allocator;
const ExtendedStructSet = std.StringHashMap(void);
pub fn main() !void { pub fn main() !void {
var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator); var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
@ -20,101 +23,308 @@ pub fn main() !void {
var tokens = std.json.TokenStream.init(spec); var tokens = std.json.TokenStream.init(spec);
var registry = try std.json.parse(g.Registry, &tokens, .{ .allocator = allocator }); var registry = try std.json.parse(g.Registry, &tokens, .{ .allocator = allocator });
const core_reg = switch (registry) {
.core => |core_reg| core_reg,
.extension => return error.TODOSpirVExtensionSpec,
};
var bw = std.io.bufferedWriter(std.io.getStdOut().writer()); var bw = std.io.bufferedWriter(std.io.getStdOut().writer());
try render(bw.writer(), registry); try render(bw.writer(), allocator, core_reg);
try bw.flush(); try bw.flush();
} }
fn render(writer: anytype, registry: g.Registry) !void { /// Returns a set with types that require an extra struct for the `Instruction` interface
/// to the spir-v spec, or whether the original type can be used.
fn extendedStructs(
arena: Allocator,
kinds: []const g.OperandKind,
) !ExtendedStructSet {
var map = ExtendedStructSet.init(arena);
try map.ensureTotalCapacity(@intCast(u32, kinds.len));
for (kinds) |kind| {
const enumerants = kind.enumerants orelse continue;
for (enumerants) |enumerant| {
if (enumerant.parameters.len > 0) {
break;
}
} else continue;
map.putAssumeCapacity(kind.kind, {});
}
return map;
}
// Return a score for a particular priority. Duplicate instruction/operand enum values are
// removed by picking the tag with the lowest score to keep, and by making an alias for the
// other. Note that the tag does not need to be just a tag at this point, in which case it
// gets the lowest score automatically anyway.
fn tagPriorityScore(tag: []const u8) usize {
if (tag.len == 0) {
return 1;
} else if (std.mem.eql(u8, tag, "EXT")) {
return 2;
} else if (std.mem.eql(u8, tag, "KHR")) {
return 3;
} else {
return 4;
}
}
fn render(writer: anytype, allocator: Allocator, registry: g.CoreRegistry) !void {
try writer.writeAll( try writer.writeAll(
\\//! This file is auto-generated by tools/gen_spirv_spec.zig. \\//! This file is auto-generated by tools/gen_spirv_spec.zig.
\\ \\
\\const Version = @import("std").builtin.Version; \\const Version = @import("std").builtin.Version;
\\ \\
\\pub const Word = u32;
\\pub const IdResultType = struct{
\\ id: Word,
\\ pub fn toRef(self: IdResultType) IdRef {
\\ return .{.id = self.id};
\\ }
\\};
\\pub const IdResult = struct{
\\ id: Word,
\\ pub fn toRef(self: IdResult) IdRef {
\\ return .{.id = self.id};
\\ }
\\ pub fn toResultType(self: IdResult) IdResultType {
\\ return .{.id = self.id};
\\ }
\\};
\\pub const IdRef = struct{ id: Word };
\\
\\pub const IdMemorySemantics = IdRef;
\\pub const IdScope = IdRef;
\\
\\pub const LiteralInteger = Word;
\\pub const LiteralString = []const u8;
\\pub const LiteralContextDependentNumber = union(enum) {
\\ int32: i32,
\\ uint32: u32,
\\ int64: i64,
\\ uint64: u64,
\\ float32: f32,
\\ float64: f64,
\\};
\\pub const LiteralExtInstInteger = struct{ inst: Word };
\\pub const LiteralSpecConstantOpInteger = struct { opcode: Opcode };
\\pub const PairLiteralIntegerIdRef = struct { value: LiteralInteger, label: IdRef };
\\pub const PairIdRefLiteralInteger = struct { target: IdRef, member: LiteralInteger };
\\pub const PairIdRefIdRef = [2]IdRef;
\\
\\
); );
switch (registry) {
.core => |core_reg| {
try writer.print( try writer.print(
\\pub const version = Version{{ .major = {}, .minor = {}, .patch = {} }}; \\pub const version = Version{{ .major = {}, .minor = {}, .patch = {} }};
\\pub const magic_number: u32 = {s}; \\pub const magic_number: Word = {s};
\\ \\
, ,
.{ core_reg.major_version, core_reg.minor_version, core_reg.revision, core_reg.magic_number }, .{ registry.major_version, registry.minor_version, registry.revision, registry.magic_number },
); );
try renderOpcodes(writer, core_reg.instructions); const extended_structs = try extendedStructs(allocator, registry.operand_kinds);
try renderOperandKinds(writer, core_reg.operand_kinds); try renderOpcodes(writer, allocator, registry.instructions, extended_structs);
}, try renderOperandKinds(writer, allocator, registry.operand_kinds, extended_structs);
.extension => |ext_reg| { }
try writer.print(
\\pub const version = Version{{ .major = {}, .minor = 0, .patch = {} }}; fn renderOpcodes(
\\ writer: anytype,
, allocator: Allocator,
.{ ext_reg.version, ext_reg.revision }, instructions: []const g.Instruction,
); extended_structs: ExtendedStructSet,
try renderOpcodes(writer, ext_reg.instructions); ) !void {
try renderOperandKinds(writer, ext_reg.operand_kinds); var inst_map = std.AutoArrayHashMap(u32, usize).init(allocator);
}, try inst_map.ensureTotalCapacity(instructions.len);
var aliases = std.ArrayList(struct { inst: usize, alias: usize }).init(allocator);
try aliases.ensureTotalCapacity(instructions.len);
for (instructions) |inst, i| {
const result = inst_map.getOrPutAssumeCapacity(inst.opcode);
if (!result.found_existing) {
result.value_ptr.* = i;
continue;
}
const existing = instructions[result.value_ptr.*];
const tag_index = std.mem.indexOfDiff(u8, inst.opname, existing.opname).?;
const inst_priority = tagPriorityScore(inst.opname[tag_index..]);
const existing_priority = tagPriorityScore(existing.opname[tag_index..]);
if (inst_priority < existing_priority) {
aliases.appendAssumeCapacity(.{ .inst = result.value_ptr.*, .alias = i });
result.value_ptr.* = i;
} else {
aliases.appendAssumeCapacity(.{ .inst = i, .alias = result.value_ptr.* });
} }
} }
fn renderOpcodes(writer: anytype, instructions: []const g.Instruction) !void { const instructions_indices = inst_map.values();
try writer.writeAll("pub const Opcode = extern enum(u16) {\n");
for (instructions) |instr| { try writer.writeAll("pub const Opcode = enum(u16) {\n");
try writer.print(" {} = {},\n", .{ std.zig.fmtId(instr.opname), instr.opcode }); for (instructions_indices) |i| {
} const inst = instructions[i];
try writer.writeAll(" _,\n};\n"); try writer.print("{} = {},\n", .{ std.zig.fmtId(inst.opname), inst.opcode });
} }
fn renderOperandKinds(writer: anytype, kinds: []const g.OperandKind) !void { try writer.writeByte('\n');
for (aliases.items) |alias| {
try writer.print("pub const {} = Opcode.{};\n", .{
std.zig.fmtId(instructions[alias.inst].opname),
std.zig.fmtId(instructions[alias.alias].opname),
});
}
try writer.writeAll(
\\
\\pub fn Operands(comptime self: Opcode) type {
\\return switch (self) {
\\
);
for (instructions_indices) |i| {
const inst = instructions[i];
try renderOperand(writer, .instruction, inst.opname, inst.operands, extended_structs);
}
try writer.writeAll("};\n}\n};\n");
_ = extended_structs;
}
fn renderOperandKinds(
writer: anytype,
allocator: Allocator,
kinds: []const g.OperandKind,
extended_structs: ExtendedStructSet,
) !void {
for (kinds) |kind| { for (kinds) |kind| {
switch (kind.category) { switch (kind.category) {
.ValueEnum => try renderValueEnum(writer, kind), .ValueEnum => try renderValueEnum(writer, allocator, kind, extended_structs),
.BitEnum => try renderBitEnum(writer, kind), .BitEnum => try renderBitEnum(writer, allocator, kind, extended_structs),
else => {}, else => {},
} }
} }
} }
fn renderValueEnum(writer: anytype, enumeration: g.OperandKind) !void { fn renderValueEnum(
try writer.print("pub const {s} = extern enum(u32) {{\n", .{enumeration.kind}); writer: anytype,
allocator: Allocator,
enumeration: g.OperandKind,
extended_structs: ExtendedStructSet,
) !void {
const enumerants = enumeration.enumerants orelse return error.InvalidRegistry; const enumerants = enumeration.enumerants orelse return error.InvalidRegistry;
for (enumerants) |enumerant| {
var enum_map = std.AutoArrayHashMap(u32, usize).init(allocator);
try enum_map.ensureTotalCapacity(enumerants.len);
var aliases = std.ArrayList(struct { enumerant: usize, alias: usize }).init(allocator);
try aliases.ensureTotalCapacity(enumerants.len);
for (enumerants) |enumerant, i| {
const result = enum_map.getOrPutAssumeCapacity(enumerant.value.int);
if (!result.found_existing) {
result.value_ptr.* = i;
continue;
}
const existing = enumerants[result.value_ptr.*];
const tag_index = std.mem.indexOfDiff(u8, enumerant.enumerant, existing.enumerant).?;
const enum_priority = tagPriorityScore(enumerant.enumerant[tag_index..]);
const existing_priority = tagPriorityScore(existing.enumerant[tag_index..]);
if (enum_priority < existing_priority) {
aliases.appendAssumeCapacity(.{ .enumerant = result.value_ptr.*, .alias = i });
result.value_ptr.* = i;
} else {
aliases.appendAssumeCapacity(.{ .enumerant = i, .alias = result.value_ptr.* });
}
}
const enum_indices = enum_map.values();
try writer.print("pub const {s} = enum(u32) {{\n", .{std.zig.fmtId(enumeration.kind)});
for (enum_indices) |i| {
const enumerant = enumerants[i];
if (enumerant.value != .int) return error.InvalidRegistry; if (enumerant.value != .int) return error.InvalidRegistry;
try writer.print("{} = {},\n", .{ std.zig.fmtId(enumerant.enumerant), enumerant.value.int }); try writer.print("{} = {},\n", .{ std.zig.fmtId(enumerant.enumerant), enumerant.value.int });
} }
try writer.writeAll(" _,\n};\n"); try writer.writeByte('\n');
for (aliases.items) |alias| {
try writer.print("pub const {} = {}.{};\n", .{
std.zig.fmtId(enumerants[alias.enumerant].enumerant),
std.zig.fmtId(enumeration.kind),
std.zig.fmtId(enumerants[alias.alias].enumerant),
});
} }
fn renderBitEnum(writer: anytype, enumeration: g.OperandKind) !void { if (!extended_structs.contains(enumeration.kind)) {
try writer.print("pub const {s} = packed struct {{\n", .{enumeration.kind}); try writer.writeAll("};\n");
return;
}
var flags_by_bitpos = [_]?[]const u8{null} ** 32; try writer.print("\npub const Extended = union({}) {{\n", .{std.zig.fmtId(enumeration.kind)});
for (enum_indices) |i| {
const enumerant = enumerants[i];
try renderOperand(writer, .@"union", enumerant.enumerant, enumerant.parameters, extended_structs);
}
try writer.writeAll("};\n};\n");
}
fn renderBitEnum(
writer: anytype,
allocator: Allocator,
enumeration: g.OperandKind,
extended_structs: ExtendedStructSet,
) !void {
try writer.print("pub const {s} = packed struct {{\n", .{std.zig.fmtId(enumeration.kind)});
var flags_by_bitpos = [_]?usize{null} ** 32;
const enumerants = enumeration.enumerants orelse return error.InvalidRegistry; const enumerants = enumeration.enumerants orelse return error.InvalidRegistry;
for (enumerants) |enumerant| {
var aliases = std.ArrayList(struct { flag: usize, alias: u5 }).init(allocator);
try aliases.ensureTotalCapacity(enumerants.len);
for (enumerants) |enumerant, i| {
if (enumerant.value != .bitflag) return error.InvalidRegistry; if (enumerant.value != .bitflag) return error.InvalidRegistry;
const value = try parseHexInt(enumerant.value.bitflag); const value = try parseHexInt(enumerant.value.bitflag);
if (@popCount(u32, value) != 1) { if (@popCount(u32, value) == 0) {
continue; // Skip combinations and 'none' items continue; // Skip 'none' items
} }
std.debug.assert(@popCount(u32, value) == 1);
var bitpos = std.math.log2_int(u32, value); var bitpos = std.math.log2_int(u32, value);
if (flags_by_bitpos[bitpos]) |*existing| { if (flags_by_bitpos[bitpos]) |*existing| {
// Keep the shortest const tag_index = std.mem.indexOfDiff(u8, enumerant.enumerant, enumerants[existing.*].enumerant).?;
if (enumerant.enumerant.len < existing.len) const enum_priority = tagPriorityScore(enumerant.enumerant[tag_index..]);
existing.* = enumerant.enumerant; const existing_priority = tagPriorityScore(enumerants[existing.*].enumerant[tag_index..]);
if (enum_priority < existing_priority) {
aliases.appendAssumeCapacity(.{ .flag = existing.*, .alias = bitpos });
existing.* = i;
} else { } else {
flags_by_bitpos[bitpos] = enumerant.enumerant; aliases.appendAssumeCapacity(.{ .flag = i, .alias = bitpos });
}
} else {
flags_by_bitpos[bitpos] = i;
} }
} }
for (flags_by_bitpos) |maybe_flag_name, bitpos| { for (flags_by_bitpos) |maybe_flag_index, bitpos| {
try writer.writeAll(" "); if (maybe_flag_index) |flag_index| {
if (maybe_flag_name) |flag_name| { try writer.print("{}", .{std.zig.fmtId(enumerants[flag_index].enumerant)});
try writer.writeAll(flag_name);
} else { } else {
try writer.print("_reserved_bit_{}", .{bitpos}); try writer.print("_reserved_bit_{}", .{bitpos});
} }
@ -126,7 +336,169 @@ fn renderBitEnum(writer: anytype, enumeration: g.OperandKind) !void {
try writer.writeAll("= false,\n"); try writer.writeAll("= false,\n");
} }
try writer.writeByte('\n');
for (aliases.items) |alias| {
try writer.print("pub const {}: {} = .{{.{} = true}};\n", .{
std.zig.fmtId(enumerants[alias.flag].enumerant),
std.zig.fmtId(enumeration.kind),
std.zig.fmtId(enumerants[flags_by_bitpos[alias.alias].?].enumerant),
});
}
if (!extended_structs.contains(enumeration.kind)) {
try writer.writeAll("};\n"); try writer.writeAll("};\n");
return;
}
try writer.print("\npub const Extended = struct {{\n", .{});
for (flags_by_bitpos) |maybe_flag_index, bitpos| {
const flag_index = maybe_flag_index orelse {
try writer.print("_reserved_bit_{}: bool = false,\n", .{bitpos});
continue;
};
const enumerant = enumerants[flag_index];
try renderOperand(writer, .mask, enumerant.enumerant, enumerant.parameters, extended_structs);
}
try writer.writeAll("};\n};\n");
}
fn renderOperand(
writer: anytype,
kind: enum {
@"union",
instruction,
mask,
},
field_name: []const u8,
parameters: []const g.Operand,
extended_structs: ExtendedStructSet,
) !void {
if (kind == .instruction) {
try writer.writeByte('.');
}
try writer.print("{}", .{std.zig.fmtId(field_name)});
if (parameters.len == 0) {
switch (kind) {
.@"union" => try writer.writeAll(",\n"),
.instruction => try writer.writeAll(" => void,\n"),
.mask => try writer.writeAll(": bool = false,\n"),
}
return;
}
if (kind == .instruction) {
try writer.writeAll(" => ");
} else {
try writer.writeAll(": ");
}
if (kind == .mask) {
try writer.writeByte('?');
}
try writer.writeAll("struct{");
for (parameters) |param, j| {
if (j != 0) {
try writer.writeAll(", ");
}
try renderFieldName(writer, parameters, j);
try writer.writeAll(": ");
if (param.quantifier) |q| {
switch (q) {
.@"?" => try writer.writeByte('?'),
.@"*" => try writer.writeAll("[]const "),
}
}
try writer.print("{}", .{std.zig.fmtId(param.kind)});
if (extended_structs.contains(param.kind)) {
try writer.writeAll(".Extended");
}
if (param.quantifier) |q| {
switch (q) {
.@"?" => try writer.writeAll(" = null"),
.@"*" => try writer.writeAll(" = &.{}"),
}
}
}
try writer.writeAll("}");
if (kind == .mask) {
try writer.writeAll(" = null");
}
try writer.writeAll(",\n");
}
fn renderFieldName(writer: anytype, operands: []const g.Operand, field_index: usize) !void {
const operand = operands[field_index];
// Should be enough for all names - adjust as needed.
var name_buffer = std.BoundedArray(u8, 64){
.buffer = undefined,
};
derive_from_kind: {
// Operand names are often in the json encoded as "'Name'" (with two sets of quotes).
// Additionally, some operands have ~ in them at the end (D~ref~).
const name = std.mem.trim(u8, operand.name, "'~");
if (name.len == 0) {
break :derive_from_kind;
}
// Some names have weird characters in them (like newlines) - skip any such ones.
// Use the same loop to transform to snake-case.
for (name) |c| {
switch (c) {
'a'...'z', '0'...'9' => try name_buffer.append(c),
'A'...'Z' => try name_buffer.append(std.ascii.toLower(c)),
' ', '~' => try name_buffer.append('_'),
else => break :derive_from_kind,
}
}
// Assume there are no duplicate 'name' fields.
try writer.print("{}", .{std.zig.fmtId(name_buffer.slice())});
return;
}
// Translate to snake case.
name_buffer.len = 0;
for (operand.kind) |c, i| {
switch (c) {
'a'...'z', '0'...'9' => try name_buffer.append(c),
'A'...'Z' => if (i > 0 and std.ascii.isLower(operand.kind[i - 1])) {
try name_buffer.appendSlice(&[_]u8{ '_', std.ascii.toLower(c) });
} else {
try name_buffer.append(std.ascii.toLower(c));
},
else => unreachable, // Assume that the name is valid C-syntax (and contains no underscores).
}
}
try writer.print("{}", .{std.zig.fmtId(name_buffer.slice())});
// For fields derived from type name, there could be any amount.
// Simply check against all other fields, and if another similar one exists, add a number.
const need_extra_index = for (operands) |other_operand, i| {
if (i != field_index and std.mem.eql(u8, operand.kind, other_operand.kind)) {
break true;
}
} else false;
if (need_extra_index) {
try writer.print("_{}", .{field_index});
}
} }
fn parseHexInt(text: []const u8) !u31 { fn parseHexInt(text: []const u8) !u31 {
@ -142,7 +514,7 @@ fn usageAndExit(file: std.fs.File, arg0: []const u8, code: u8) noreturn {
\\ \\
\\Generates Zig bindings for a SPIR-V specification .json (either core or \\Generates Zig bindings for a SPIR-V specification .json (either core or
\\extinst versions). The result, printed to stdout, should be used to update \\extinst versions). The result, printed to stdout, should be used to update
\\files in src/codegen/spirv. \\files in src/codegen/spirv. Don't forget to format the output.
\\ \\
\\The relevant specifications can be obtained from the SPIR-V registry: \\The relevant specifications can be obtained from the SPIR-V registry:
\\https://github.com/KhronosGroup/SPIRV-Headers/blob/master/include/spirv/unified1/ \\https://github.com/KhronosGroup/SPIRV-Headers/blob/master/include/spirv/unified1/