std.sort.block: add safety check for lessThan return value

This commit is contained in:
Ali Chraghi 2023-06-24 17:31:50 +03:30 committed by Andrew Kelley
parent 88284c124a
commit 6bd5479306
6 changed files with 45 additions and 21 deletions

View File

@ -1,5 +1,4 @@
const std = @import("std"); const std = @import("std");
const mem = std.mem;
pub const Token = struct { pub const Token = struct {
id: Id, id: Id,

View File

@ -9,18 +9,12 @@ const mem = std.mem;
/// You can pass `struct { []const u8 }` (only keys) tuples if `V` is `void`. /// You can pass `struct { []const u8 }` (only keys) tuples if `V` is `void`.
pub fn ComptimeStringMap(comptime V: type, comptime kvs_list: anytype) type { pub fn ComptimeStringMap(comptime V: type, comptime kvs_list: anytype) type {
const precomputed = comptime blk: { const precomputed = comptime blk: {
@setEvalBranchQuota(2000); @setEvalBranchQuota(1500);
const KV = struct { const KV = struct {
key: []const u8, key: []const u8,
value: V, value: V,
}; };
var sorted_kvs: [kvs_list.len]KV = undefined; var sorted_kvs: [kvs_list.len]KV = undefined;
const lenAsc = (struct {
fn lenAsc(context: void, a: KV, b: KV) bool {
_ = context;
return a.key.len < b.key.len;
}
}).lenAsc;
for (kvs_list, 0..) |kv, i| { for (kvs_list, 0..) |kv, i| {
if (V != void) { if (V != void) {
sorted_kvs[i] = .{ .key = kv.@"0", .value = kv.@"1" }; sorted_kvs[i] = .{ .key = kv.@"0", .value = kv.@"1" };
@ -28,7 +22,20 @@ pub fn ComptimeStringMap(comptime V: type, comptime kvs_list: anytype) type {
sorted_kvs[i] = .{ .key = kv.@"0", .value = {} }; sorted_kvs[i] = .{ .key = kv.@"0", .value = {} };
} }
} }
mem.sort(KV, &sorted_kvs, {}, lenAsc);
const SortContext = struct {
kvs: []KV,
pub fn lessThan(ctx: @This(), a: usize, b: usize) bool {
return ctx.kvs[a].key.len < ctx.kvs[b].key.len;
}
pub fn swap(ctx: @This(), a: usize, b: usize) void {
return std.mem.swap(KV, &ctx.kvs[a], &ctx.kvs[b]);
}
};
mem.sortUnstableContext(0, sorted_kvs.len, SortContext{ .kvs = &sorted_kvs });
const min_len = sorted_kvs[0].key.len; const min_len = sorted_kvs[0].key.len;
const max_len = sorted_kvs[sorted_kvs.len - 1].key.len; const max_len = sorted_kvs[sorted_kvs.len - 1].key.len;
var len_indexes: [max_len + 1]usize = undefined; var len_indexes: [max_len + 1]usize = undefined;

View File

@ -1289,10 +1289,6 @@ test "std.enums.ensureIndexer" {
}); });
} }
fn ascByValue(ctx: void, comptime a: EnumField, comptime b: EnumField) bool {
_ = ctx;
return a.value < b.value;
}
pub fn EnumIndexer(comptime E: type) type { pub fn EnumIndexer(comptime E: type) type {
if (!@typeInfo(E).Enum.is_exhaustive) { if (!@typeInfo(E).Enum.is_exhaustive) {
@compileError("Cannot create an enum indexer for a non-exhaustive enum."); @compileError("Cannot create an enum indexer for a non-exhaustive enum.");
@ -1300,7 +1296,10 @@ pub fn EnumIndexer(comptime E: type) type {
const const_fields = std.meta.fields(E); const const_fields = std.meta.fields(E);
var fields = const_fields[0..const_fields.len].*; var fields = const_fields[0..const_fields.len].*;
if (fields.len == 0) { const min = fields[0].value;
const max = fields[fields.len - 1].value;
const fields_len = fields.len;
if (fields_len == 0) {
return struct { return struct {
pub const Key = E; pub const Key = E;
pub const count: usize = 0; pub const count: usize = 0;
@ -1314,10 +1313,20 @@ pub fn EnumIndexer(comptime E: type) type {
} }
}; };
} }
std.mem.sort(EnumField, &fields, {}, ascByValue);
const min = fields[0].value; const SortContext = struct {
const max = fields[fields.len - 1].value; fields: []EnumField,
const fields_len = fields.len;
pub fn lessThan(comptime ctx: @This(), comptime a: usize, comptime b: usize) bool {
return ctx.fields[a].value < ctx.fields[b].value;
}
pub fn swap(comptime ctx: @This(), comptime a: usize, comptime b: usize) void {
return std.mem.swap(EnumField, &ctx.fields[a], &ctx.fields[b]);
}
};
std.sort.insertionContext(0, fields_len, SortContext{ .fields = &fields });
if (max - min == fields.len - 1) { if (max - min == fields.len - 1) {
return struct { return struct {
pub const Key = E; pub const Key = E;

View File

@ -366,7 +366,7 @@ test "sort with context in the middle of a slice" {
const slice = buf[0..case[0].len]; const slice = buf[0..case[0].len];
@memcpy(slice, case[0]); @memcpy(slice, case[0]);
sortFn(range.start, range.end, Context{ .items = slice }); sortFn(range.start, range.end, Context{ .items = slice });
try testing.expectEqualSlices(i32, slice[range.start..range.end], case[1][range.start..range.end]); try testing.expectEqualSlices(i32, case[1][range.start..range.end], slice[range.start..range.end]);
} }
} }
} }

View File

@ -1,3 +1,4 @@
const builtin = @import("builtin");
const std = @import("../std.zig"); const std = @import("../std.zig");
const sort = std.sort; const sort = std.sort;
const math = std.math; const math = std.math;
@ -100,8 +101,16 @@ pub fn block(
comptime T: type, comptime T: type,
items: []T, items: []T,
context: anytype, context: anytype,
comptime lessThan: fn (@TypeOf(context), lhs: T, rhs: T) bool, comptime lessThanFn: fn (@TypeOf(context), lhs: T, rhs: T) bool,
) void { ) void {
const lessThan = if (builtin.mode == .Debug) struct {
fn lessThan(ctx: @TypeOf(context), lhs: T, rhs: T) bool {
const lt = lessThanFn(ctx, lhs, rhs);
const gt = lessThanFn(ctx, rhs, lhs);
std.debug.assert(!(lt and gt));
return lt;
}
}.lessThan else lessThanFn;
// Implementation ported from https://github.com/BonzaiThePenguin/WikiSort/blob/master/WikiSort.c // Implementation ported from https://github.com/BonzaiThePenguin/WikiSort/blob/master/WikiSort.c
var cache: [512]T = undefined; var cache: [512]T = undefined;

View File

@ -767,7 +767,7 @@ fn estimateInstructionLength(prefix: Prefix, encoding: Encoding, ops: []const Op
} }
const mnemonic_to_encodings_map = init: { const mnemonic_to_encodings_map = init: {
@setEvalBranchQuota(30_000); @setEvalBranchQuota(50_000);
const encodings = @import("encodings.zig"); const encodings = @import("encodings.zig");
var entries = encodings.table; var entries = encodings.table;
std.mem.sort(encodings.Entry, &entries, {}, struct { std.mem.sort(encodings.Entry, &entries, {}, struct {