std.sort.block: add safety check for lessThan return value

This commit is contained in:
Ali Chraghi 2023-06-24 17:31:50 +03:30 committed by Andrew Kelley
parent 88284c124a
commit 6bd5479306
6 changed files with 45 additions and 21 deletions

View File

@ -1,5 +1,4 @@
const std = @import("std");
const mem = std.mem;
pub const Token = struct {
id: Id,

View File

@ -9,18 +9,12 @@ const mem = std.mem;
/// You can pass `struct { []const u8 }` (only keys) tuples if `V` is `void`.
pub fn ComptimeStringMap(comptime V: type, comptime kvs_list: anytype) type {
const precomputed = comptime blk: {
@setEvalBranchQuota(2000);
@setEvalBranchQuota(1500);
const KV = struct {
key: []const u8,
value: V,
};
var sorted_kvs: [kvs_list.len]KV = undefined;
const lenAsc = (struct {
fn lenAsc(context: void, a: KV, b: KV) bool {
_ = context;
return a.key.len < b.key.len;
}
}).lenAsc;
for (kvs_list, 0..) |kv, i| {
if (V != void) {
sorted_kvs[i] = .{ .key = kv.@"0", .value = kv.@"1" };
@ -28,7 +22,20 @@ pub fn ComptimeStringMap(comptime V: type, comptime kvs_list: anytype) type {
sorted_kvs[i] = .{ .key = kv.@"0", .value = {} };
}
}
mem.sort(KV, &sorted_kvs, {}, lenAsc);
const SortContext = struct {
kvs: []KV,
pub fn lessThan(ctx: @This(), a: usize, b: usize) bool {
return ctx.kvs[a].key.len < ctx.kvs[b].key.len;
}
pub fn swap(ctx: @This(), a: usize, b: usize) void {
return std.mem.swap(KV, &ctx.kvs[a], &ctx.kvs[b]);
}
};
mem.sortUnstableContext(0, sorted_kvs.len, SortContext{ .kvs = &sorted_kvs });
const min_len = sorted_kvs[0].key.len;
const max_len = sorted_kvs[sorted_kvs.len - 1].key.len;
var len_indexes: [max_len + 1]usize = undefined;

View File

@ -1289,10 +1289,6 @@ test "std.enums.ensureIndexer" {
});
}
fn ascByValue(ctx: void, comptime a: EnumField, comptime b: EnumField) bool {
_ = ctx;
return a.value < b.value;
}
pub fn EnumIndexer(comptime E: type) type {
if (!@typeInfo(E).Enum.is_exhaustive) {
@compileError("Cannot create an enum indexer for a non-exhaustive enum.");
@ -1300,7 +1296,10 @@ pub fn EnumIndexer(comptime E: type) type {
const const_fields = std.meta.fields(E);
var fields = const_fields[0..const_fields.len].*;
if (fields.len == 0) {
const min = fields[0].value;
const max = fields[fields.len - 1].value;
const fields_len = fields.len;
if (fields_len == 0) {
return struct {
pub const Key = E;
pub const count: usize = 0;
@ -1314,10 +1313,20 @@ pub fn EnumIndexer(comptime E: type) type {
}
};
}
std.mem.sort(EnumField, &fields, {}, ascByValue);
const min = fields[0].value;
const max = fields[fields.len - 1].value;
const fields_len = fields.len;
const SortContext = struct {
fields: []EnumField,
pub fn lessThan(comptime ctx: @This(), comptime a: usize, comptime b: usize) bool {
return ctx.fields[a].value < ctx.fields[b].value;
}
pub fn swap(comptime ctx: @This(), comptime a: usize, comptime b: usize) void {
return std.mem.swap(EnumField, &ctx.fields[a], &ctx.fields[b]);
}
};
std.sort.insertionContext(0, fields_len, SortContext{ .fields = &fields });
if (max - min == fields.len - 1) {
return struct {
pub const Key = E;

View File

@ -366,7 +366,7 @@ test "sort with context in the middle of a slice" {
const slice = buf[0..case[0].len];
@memcpy(slice, case[0]);
sortFn(range.start, range.end, Context{ .items = slice });
try testing.expectEqualSlices(i32, slice[range.start..range.end], case[1][range.start..range.end]);
try testing.expectEqualSlices(i32, case[1][range.start..range.end], slice[range.start..range.end]);
}
}
}

View File

@ -1,3 +1,4 @@
const builtin = @import("builtin");
const std = @import("../std.zig");
const sort = std.sort;
const math = std.math;
@ -100,8 +101,16 @@ pub fn block(
comptime T: type,
items: []T,
context: anytype,
comptime lessThan: fn (@TypeOf(context), lhs: T, rhs: T) bool,
comptime lessThanFn: fn (@TypeOf(context), lhs: T, rhs: T) bool,
) void {
const lessThan = if (builtin.mode == .Debug) struct {
fn lessThan(ctx: @TypeOf(context), lhs: T, rhs: T) bool {
const lt = lessThanFn(ctx, lhs, rhs);
const gt = lessThanFn(ctx, rhs, lhs);
std.debug.assert(!(lt and gt));
return lt;
}
}.lessThan else lessThanFn;
// Implementation ported from https://github.com/BonzaiThePenguin/WikiSort/blob/master/WikiSort.c
var cache: [512]T = undefined;

View File

@ -767,7 +767,7 @@ fn estimateInstructionLength(prefix: Prefix, encoding: Encoding, ops: []const Op
}
const mnemonic_to_encodings_map = init: {
@setEvalBranchQuota(30_000);
@setEvalBranchQuota(50_000);
const encodings = @import("encodings.zig");
var entries = encodings.table;
std.mem.sort(encodings.Entry, &entries, {}, struct {