From d4d954abd20d57ff62940993f5b95700ebfbbda7 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 3 Jun 2020 18:41:56 -0400 Subject: [PATCH 1/6] std.sort: give comparator functions a context parameter --- lib/std/comptime_string_map.zig | 48 +-- lib/std/http/headers.zig | 4 +- lib/std/net.zig | 4 +- lib/std/sort.zig | 570 ++++++++++++++++++++------------ 4 files changed, 379 insertions(+), 247 deletions(-) diff --git a/lib/std/comptime_string_map.zig b/lib/std/comptime_string_map.zig index 313f1fcdda..3021f6bc1e 100644 --- a/lib/std/comptime_string_map.zig +++ b/lib/std/comptime_string_map.zig @@ -17,18 +17,18 @@ pub fn ComptimeStringMap(comptime V: type, comptime kvs: var) type { }; var sorted_kvs: [kvs.len]KV = undefined; const lenAsc = (struct { - fn lenAsc(a: KV, b: KV) bool { + fn lenAsc(context: void, a: KV, b: KV) bool { return a.key.len < b.key.len; } }).lenAsc; for (kvs) |kv, i| { if (V != void) { - sorted_kvs[i] = .{.key = kv.@"0", .value = kv.@"1"}; + sorted_kvs[i] = .{ .key = kv.@"0", .value = kv.@"1" }; } else { - sorted_kvs[i] = .{.key = kv.@"0", .value = {}}; + sorted_kvs[i] = .{ .key = kv.@"0", .value = {} }; } } - std.sort.sort(KV, &sorted_kvs, lenAsc); + std.sort.sort(KV, &sorted_kvs, {}, lenAsc); const min_len = sorted_kvs[0].key.len; const max_len = sorted_kvs[sorted_kvs.len - 1].key.len; var len_indexes: [max_len + 1]usize = undefined; @@ -83,11 +83,11 @@ const TestEnum = enum { test "ComptimeStringMap list literal of list literals" { const map = ComptimeStringMap(TestEnum, .{ - .{"these", .D}, - .{"have", .A}, - .{"nothing", .B}, - .{"incommon", .C}, - .{"samelen", .E}, + .{ "these", .D }, + .{ "have", .A }, + .{ "nothing", .B }, + .{ "incommon", .C }, + .{ "samelen", .E }, }); testMap(map); @@ -99,11 +99,11 @@ test "ComptimeStringMap array of structs" { @"1": TestEnum, }; const map = ComptimeStringMap(TestEnum, [_]KV{ - .{.@"0" = "these", .@"1" = .D}, - .{.@"0" = "have", .@"1" = .A}, - .{.@"0" = "nothing", .@"1" = .B}, - .{.@"0" = 
"incommon", .@"1" = .C}, - .{.@"0" = "samelen", .@"1" = .E}, + .{ .@"0" = "these", .@"1" = .D }, + .{ .@"0" = "have", .@"1" = .A }, + .{ .@"0" = "nothing", .@"1" = .B }, + .{ .@"0" = "incommon", .@"1" = .C }, + .{ .@"0" = "samelen", .@"1" = .E }, }); testMap(map); @@ -115,11 +115,11 @@ test "ComptimeStringMap slice of structs" { @"1": TestEnum, }; const slice: []const KV = &[_]KV{ - .{.@"0" = "these", .@"1" = .D}, - .{.@"0" = "have", .@"1" = .A}, - .{.@"0" = "nothing", .@"1" = .B}, - .{.@"0" = "incommon", .@"1" = .C}, - .{.@"0" = "samelen", .@"1" = .E}, + .{ .@"0" = "these", .@"1" = .D }, + .{ .@"0" = "have", .@"1" = .A }, + .{ .@"0" = "nothing", .@"1" = .B }, + .{ .@"0" = "incommon", .@"1" = .C }, + .{ .@"0" = "samelen", .@"1" = .E }, }; const map = ComptimeStringMap(TestEnum, slice); @@ -142,11 +142,11 @@ test "ComptimeStringMap void value type, slice of structs" { @"0": []const u8, }; const slice: []const KV = &[_]KV{ - .{.@"0" = "these"}, - .{.@"0" = "have"}, - .{.@"0" = "nothing"}, - .{.@"0" = "incommon"}, - .{.@"0" = "samelen"}, + .{ .@"0" = "these" }, + .{ .@"0" = "have" }, + .{ .@"0" = "nothing" }, + .{ .@"0" = "incommon" }, + .{ .@"0" = "samelen" }, }; const map = ComptimeStringMap(void, slice); diff --git a/lib/std/http/headers.zig b/lib/std/http/headers.zig index 96a1d2a68d..ba929a446c 100644 --- a/lib/std/http/headers.zig +++ b/lib/std/http/headers.zig @@ -58,7 +58,7 @@ const HeaderEntry = struct { self.never_index = never_index orelse never_index_default(self.name); } - fn compare(a: HeaderEntry, b: HeaderEntry) bool { + fn compare(context: void, a: HeaderEntry, b: HeaderEntry) bool { if (a.name.ptr != b.name.ptr and a.name.len != b.name.len) { // Things beginning with a colon *must* be before others const a_is_colon = a.name[0] == ':'; @@ -342,7 +342,7 @@ pub const Headers = struct { } pub fn sort(self: *Self) void { - std.sort.sort(HeaderEntry, self.data.items, HeaderEntry.compare); + std.sort.sort(HeaderEntry, self.data.items, {}, 
HeaderEntry.compare); self.rebuild_index(); } diff --git a/lib/std/net.zig b/lib/std/net.zig index 919175e41d..229731b617 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -836,7 +836,7 @@ fn linuxLookupName( key |= (MAXADDRS - @intCast(i32, i)) << DAS_ORDER_SHIFT; addr.sortkey = key; } - std.sort.sort(LookupAddr, addrs.span(), addrCmpLessThan); + std.sort.sort(LookupAddr, addrs.span(), {}, addrCmpLessThan); } const Policy = struct { @@ -953,7 +953,7 @@ fn IN6_IS_ADDR_SITELOCAL(a: [16]u8) bool { } // Parameters `b` and `a` swapped to make this descending. -fn addrCmpLessThan(b: LookupAddr, a: LookupAddr) bool { +fn addrCmpLessThan(context: void, b: LookupAddr, a: LookupAddr) bool { return a.sortkey < b.sortkey; } diff --git a/lib/std/sort.zig b/lib/std/sort.zig index 8ed8f2c1c0..cb6162e9b0 100644 --- a/lib/std/sort.zig +++ b/lib/std/sort.zig @@ -5,7 +5,13 @@ const mem = std.mem; const math = std.math; const builtin = @import("builtin"); -pub fn binarySearch(comptime T: type, key: T, items: []const T, comptime compareFn: fn (lhs: T, rhs: T) math.Order) ?usize { +pub fn binarySearch( + comptime T: type, + key: T, + items: []const T, + context: var, + comptime compareFn: fn (context: @TypeOf(context), lhs: T, rhs: T) math.Order, +) ?usize { var left: usize = 0; var right: usize = items.len; @@ -13,7 +19,7 @@ pub fn binarySearch(comptime T: type, key: T, items: []const T, comptime compare // Avoid overflowing in the midpoint calculation const mid = left + (right - left) / 2; // Compare the key with the midpoint element - switch (compareFn(key, items[mid])) { + switch (compareFn(context, key, items[mid])) { .eq => return mid, .gt => left = mid + 1, .lt => right = mid, @@ -23,56 +29,61 @@ pub fn binarySearch(comptime T: type, key: T, items: []const T, comptime compare return null; } -test "std.sort.binarySearch" { +test "binarySearch" { const S = struct { - fn order_u32(lhs: u32, rhs: u32) math.Order { + fn order_u32(context: void, lhs: u32, rhs: u32) math.Order { 
return math.order(lhs, rhs); } - fn order_i32(lhs: i32, rhs: i32) math.Order { + fn order_i32(context: void, lhs: i32, rhs: i32) math.Order { return math.order(lhs, rhs); } }; testing.expectEqual( @as(?usize, null), - binarySearch(u32, 1, &[_]u32{}, S.order_u32), + binarySearch(u32, 1, &[_]u32{}, {}, S.order_u32), ); testing.expectEqual( @as(?usize, 0), - binarySearch(u32, 1, &[_]u32{1}, S.order_u32), + binarySearch(u32, 1, &[_]u32{1}, {}, S.order_u32), ); testing.expectEqual( @as(?usize, null), - binarySearch(u32, 1, &[_]u32{0}, S.order_u32), + binarySearch(u32, 1, &[_]u32{0}, {}, S.order_u32), ); testing.expectEqual( @as(?usize, null), - binarySearch(u32, 0, &[_]u32{1}, S.order_u32), + binarySearch(u32, 0, &[_]u32{1}, {}, S.order_u32), ); testing.expectEqual( @as(?usize, 4), - binarySearch(u32, 5, &[_]u32{ 1, 2, 3, 4, 5 }, S.order_u32), + binarySearch(u32, 5, &[_]u32{ 1, 2, 3, 4, 5 }, {}, S.order_u32), ); testing.expectEqual( @as(?usize, 0), - binarySearch(u32, 2, &[_]u32{ 2, 4, 8, 16, 32, 64 }, S.order_u32), + binarySearch(u32, 2, &[_]u32{ 2, 4, 8, 16, 32, 64 }, {}, S.order_u32), ); testing.expectEqual( @as(?usize, 1), - binarySearch(i32, -4, &[_]i32{ -7, -4, 0, 9, 10 }, S.order_i32), + binarySearch(i32, -4, &[_]i32{ -7, -4, 0, 9, 10 }, {}, S.order_i32), ); testing.expectEqual( @as(?usize, 3), - binarySearch(i32, 98, &[_]i32{ -100, -25, 2, 98, 99, 100 }, S.order_i32), + binarySearch(i32, 98, &[_]i32{ -100, -25, 2, 98, 99, 100 }, {}, S.order_i32), ); } /// Stable in-place sort. O(n) best case, O(pow(n, 2)) worst case. O(1) memory (no allocator required). 
-pub fn insertionSort(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool) void { +pub fn insertionSort( + comptime T: type, + items: []T, + context: var, + comptime lessThan: fn (context: @TypeOf(context), lhs: T, rhs: T) bool, +) void { var i: usize = 1; while (i < items.len) : (i += 1) { const x = items[i]; var j: usize = i; - while (j > 0 and lessThan(x, items[j - 1])) : (j -= 1) { + while (j > 0 and lessThan(context, x, items[j - 1])) : (j -= 1) { items[j] = items[j - 1]; } items[j] = x; @@ -168,20 +179,25 @@ const Pull = struct { /// Stable in-place sort. O(n) best case, O(n*log(n)) worst case and average case. O(1) memory (no allocator required). /// Currently implemented as block sort. -pub fn sort(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool) void { +pub fn sort( + comptime T: type, + items: []T, + context: var, + comptime lessThan: fn (context: @TypeOf(context), lhs: T, rhs: T) bool, +) void { // Implementation ported from https://github.com/BonzaiThePenguin/WikiSort/blob/master/WikiSort.c var cache: [512]T = undefined; if (items.len < 4) { if (items.len == 3) { // hard coded insertion sort - if (lessThan(items[1], items[0])) mem.swap(T, &items[0], &items[1]); - if (lessThan(items[2], items[1])) { + if (lessThan(context, items[1], items[0])) mem.swap(T, &items[0], &items[1]); + if (lessThan(context, items[2], items[1])) { mem.swap(T, &items[1], &items[2]); - if (lessThan(items[1], items[0])) mem.swap(T, &items[0], &items[1]); + if (lessThan(context, items[1], items[0])) mem.swap(T, &items[0], &items[1]); } } else if (items.len == 2) { - if (lessThan(items[1], items[0])) mem.swap(T, &items[0], &items[1]); + if (lessThan(context, items[1], items[0])) mem.swap(T, &items[0], &items[1]); } return; } @@ -197,75 +213,75 @@ pub fn sort(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool) vo const sliced_items = items[range.start..]; switch (range.length()) { 8 => { - swap(T, sliced_items, lessThan, &order, 0, 1); - 
swap(T, sliced_items, lessThan, &order, 2, 3); - swap(T, sliced_items, lessThan, &order, 4, 5); - swap(T, sliced_items, lessThan, &order, 6, 7); - swap(T, sliced_items, lessThan, &order, 0, 2); - swap(T, sliced_items, lessThan, &order, 1, 3); - swap(T, sliced_items, lessThan, &order, 4, 6); - swap(T, sliced_items, lessThan, &order, 5, 7); - swap(T, sliced_items, lessThan, &order, 1, 2); - swap(T, sliced_items, lessThan, &order, 5, 6); - swap(T, sliced_items, lessThan, &order, 0, 4); - swap(T, sliced_items, lessThan, &order, 3, 7); - swap(T, sliced_items, lessThan, &order, 1, 5); - swap(T, sliced_items, lessThan, &order, 2, 6); - swap(T, sliced_items, lessThan, &order, 1, 4); - swap(T, sliced_items, lessThan, &order, 3, 6); - swap(T, sliced_items, lessThan, &order, 2, 4); - swap(T, sliced_items, lessThan, &order, 3, 5); - swap(T, sliced_items, lessThan, &order, 3, 4); + swap(T, sliced_items, context, lessThan, &order, 0, 1); + swap(T, sliced_items, context, lessThan, &order, 2, 3); + swap(T, sliced_items, context, lessThan, &order, 4, 5); + swap(T, sliced_items, context, lessThan, &order, 6, 7); + swap(T, sliced_items, context, lessThan, &order, 0, 2); + swap(T, sliced_items, context, lessThan, &order, 1, 3); + swap(T, sliced_items, context, lessThan, &order, 4, 6); + swap(T, sliced_items, context, lessThan, &order, 5, 7); + swap(T, sliced_items, context, lessThan, &order, 1, 2); + swap(T, sliced_items, context, lessThan, &order, 5, 6); + swap(T, sliced_items, context, lessThan, &order, 0, 4); + swap(T, sliced_items, context, lessThan, &order, 3, 7); + swap(T, sliced_items, context, lessThan, &order, 1, 5); + swap(T, sliced_items, context, lessThan, &order, 2, 6); + swap(T, sliced_items, context, lessThan, &order, 1, 4); + swap(T, sliced_items, context, lessThan, &order, 3, 6); + swap(T, sliced_items, context, lessThan, &order, 2, 4); + swap(T, sliced_items, context, lessThan, &order, 3, 5); + swap(T, sliced_items, context, lessThan, &order, 3, 4); }, 7 => { - 
swap(T, sliced_items, lessThan, &order, 1, 2); - swap(T, sliced_items, lessThan, &order, 3, 4); - swap(T, sliced_items, lessThan, &order, 5, 6); - swap(T, sliced_items, lessThan, &order, 0, 2); - swap(T, sliced_items, lessThan, &order, 3, 5); - swap(T, sliced_items, lessThan, &order, 4, 6); - swap(T, sliced_items, lessThan, &order, 0, 1); - swap(T, sliced_items, lessThan, &order, 4, 5); - swap(T, sliced_items, lessThan, &order, 2, 6); - swap(T, sliced_items, lessThan, &order, 0, 4); - swap(T, sliced_items, lessThan, &order, 1, 5); - swap(T, sliced_items, lessThan, &order, 0, 3); - swap(T, sliced_items, lessThan, &order, 2, 5); - swap(T, sliced_items, lessThan, &order, 1, 3); - swap(T, sliced_items, lessThan, &order, 2, 4); - swap(T, sliced_items, lessThan, &order, 2, 3); + swap(T, sliced_items, context, lessThan, &order, 1, 2); + swap(T, sliced_items, context, lessThan, &order, 3, 4); + swap(T, sliced_items, context, lessThan, &order, 5, 6); + swap(T, sliced_items, context, lessThan, &order, 0, 2); + swap(T, sliced_items, context, lessThan, &order, 3, 5); + swap(T, sliced_items, context, lessThan, &order, 4, 6); + swap(T, sliced_items, context, lessThan, &order, 0, 1); + swap(T, sliced_items, context, lessThan, &order, 4, 5); + swap(T, sliced_items, context, lessThan, &order, 2, 6); + swap(T, sliced_items, context, lessThan, &order, 0, 4); + swap(T, sliced_items, context, lessThan, &order, 1, 5); + swap(T, sliced_items, context, lessThan, &order, 0, 3); + swap(T, sliced_items, context, lessThan, &order, 2, 5); + swap(T, sliced_items, context, lessThan, &order, 1, 3); + swap(T, sliced_items, context, lessThan, &order, 2, 4); + swap(T, sliced_items, context, lessThan, &order, 2, 3); }, 6 => { - swap(T, sliced_items, lessThan, &order, 1, 2); - swap(T, sliced_items, lessThan, &order, 4, 5); - swap(T, sliced_items, lessThan, &order, 0, 2); - swap(T, sliced_items, lessThan, &order, 3, 5); - swap(T, sliced_items, lessThan, &order, 0, 1); - swap(T, sliced_items, lessThan, 
&order, 3, 4); - swap(T, sliced_items, lessThan, &order, 2, 5); - swap(T, sliced_items, lessThan, &order, 0, 3); - swap(T, sliced_items, lessThan, &order, 1, 4); - swap(T, sliced_items, lessThan, &order, 2, 4); - swap(T, sliced_items, lessThan, &order, 1, 3); - swap(T, sliced_items, lessThan, &order, 2, 3); + swap(T, sliced_items, context, lessThan, &order, 1, 2); + swap(T, sliced_items, context, lessThan, &order, 4, 5); + swap(T, sliced_items, context, lessThan, &order, 0, 2); + swap(T, sliced_items, context, lessThan, &order, 3, 5); + swap(T, sliced_items, context, lessThan, &order, 0, 1); + swap(T, sliced_items, context, lessThan, &order, 3, 4); + swap(T, sliced_items, context, lessThan, &order, 2, 5); + swap(T, sliced_items, context, lessThan, &order, 0, 3); + swap(T, sliced_items, context, lessThan, &order, 1, 4); + swap(T, sliced_items, context, lessThan, &order, 2, 4); + swap(T, sliced_items, context, lessThan, &order, 1, 3); + swap(T, sliced_items, context, lessThan, &order, 2, 3); }, 5 => { - swap(T, sliced_items, lessThan, &order, 0, 1); - swap(T, sliced_items, lessThan, &order, 3, 4); - swap(T, sliced_items, lessThan, &order, 2, 4); - swap(T, sliced_items, lessThan, &order, 2, 3); - swap(T, sliced_items, lessThan, &order, 1, 4); - swap(T, sliced_items, lessThan, &order, 0, 3); - swap(T, sliced_items, lessThan, &order, 0, 2); - swap(T, sliced_items, lessThan, &order, 1, 3); - swap(T, sliced_items, lessThan, &order, 1, 2); + swap(T, sliced_items, context, lessThan, &order, 0, 1); + swap(T, sliced_items, context, lessThan, &order, 3, 4); + swap(T, sliced_items, context, lessThan, &order, 2, 4); + swap(T, sliced_items, context, lessThan, &order, 2, 3); + swap(T, sliced_items, context, lessThan, &order, 1, 4); + swap(T, sliced_items, context, lessThan, &order, 0, 3); + swap(T, sliced_items, context, lessThan, &order, 0, 2); + swap(T, sliced_items, context, lessThan, &order, 1, 3); + swap(T, sliced_items, context, lessThan, &order, 1, 2); }, 4 => { - swap(T, 
sliced_items, lessThan, &order, 0, 1); - swap(T, sliced_items, lessThan, &order, 2, 3); - swap(T, sliced_items, lessThan, &order, 0, 2); - swap(T, sliced_items, lessThan, &order, 1, 3); - swap(T, sliced_items, lessThan, &order, 1, 2); + swap(T, sliced_items, context, lessThan, &order, 0, 1); + swap(T, sliced_items, context, lessThan, &order, 2, 3); + swap(T, sliced_items, context, lessThan, &order, 0, 2); + swap(T, sliced_items, context, lessThan, &order, 1, 3); + swap(T, sliced_items, context, lessThan, &order, 1, 2); }, else => {}, } @@ -288,16 +304,16 @@ pub fn sort(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool) vo var A2 = iterator.nextRange(); var B2 = iterator.nextRange(); - if (lessThan(items[B1.end - 1], items[A1.start])) { + if (lessThan(context, items[B1.end - 1], items[A1.start])) { // the two ranges are in reverse order, so copy them in reverse order into the cache mem.copy(T, cache[B1.length()..], items[A1.start..A1.end]); mem.copy(T, cache[0..], items[B1.start..B1.end]); - } else if (lessThan(items[B1.start], items[A1.end - 1])) { + } else if (lessThan(context, items[B1.start], items[A1.end - 1])) { // these two ranges weren't already in order, so merge them into the cache - mergeInto(T, items, A1, B1, lessThan, cache[0..]); + mergeInto(T, items, A1, B1, context, lessThan, cache[0..]); } else { // if A1, B1, A2, and B2 are all in order, skip doing anything else - if (!lessThan(items[B2.start], items[A2.end - 1]) and !lessThan(items[A2.start], items[B1.end - 1])) continue; + if (!lessThan(context, items[B2.start], items[A2.end - 1]) and !lessThan(context, items[A2.start], items[B1.end - 1])) continue; // copy A1 and B1 into the cache in the same order mem.copy(T, cache[0..], items[A1.start..A1.end]); @@ -306,13 +322,13 @@ pub fn sort(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool) vo A1 = Range.init(A1.start, B1.end); // merge A2 and B2 into the cache - if (lessThan(items[B2.end - 1], items[A2.start])) { + if 
(lessThan(context, items[B2.end - 1], items[A2.start])) { // the two ranges are in reverse order, so copy them in reverse order into the cache mem.copy(T, cache[A1.length() + B2.length() ..], items[A2.start..A2.end]); mem.copy(T, cache[A1.length()..], items[B2.start..B2.end]); - } else if (lessThan(items[B2.start], items[A2.end - 1])) { + } else if (lessThan(context, items[B2.start], items[A2.end - 1])) { // these two ranges weren't already in order, so merge them into the cache - mergeInto(T, items, A2, B2, lessThan, cache[A1.length()..]); + mergeInto(T, items, A2, B2, context, lessThan, cache[A1.length()..]); } else { // copy A2 and B2 into the cache in the same order mem.copy(T, cache[A1.length()..], items[A2.start..A2.end]); @@ -324,13 +340,13 @@ pub fn sort(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool) vo const A3 = Range.init(0, A1.length()); const B3 = Range.init(A1.length(), A1.length() + A2.length()); - if (lessThan(cache[B3.end - 1], cache[A3.start])) { + if (lessThan(context, cache[B3.end - 1], cache[A3.start])) { // the two ranges are in reverse order, so copy them in reverse order into the items mem.copy(T, items[A1.start + A2.length() ..], cache[A3.start..A3.end]); mem.copy(T, items[A1.start..], cache[B3.start..B3.end]); - } else if (lessThan(cache[B3.start], cache[A3.end - 1])) { + } else if (lessThan(context, cache[B3.start], cache[A3.end - 1])) { // these two ranges weren't already in order, so merge them back into the items - mergeInto(T, cache[0..], A3, B3, lessThan, items[A1.start..]); + mergeInto(T, cache[0..], A3, B3, context, lessThan, items[A1.start..]); } else { // copy A3 and B3 into the items in the same order mem.copy(T, items[A1.start..], cache[A3.start..A3.end]); @@ -347,13 +363,13 @@ pub fn sort(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool) vo var A = iterator.nextRange(); var B = iterator.nextRange(); - if (lessThan(items[B.end - 1], items[A.start])) { + if (lessThan(context, items[B.end - 1], 
items[A.start])) { // the two ranges are in reverse order, so a simple rotation should fix it mem.rotate(T, items[A.start..B.end], A.length()); - } else if (lessThan(items[B.start], items[A.end - 1])) { + } else if (lessThan(context, items[B.start], items[A.end - 1])) { // these two ranges weren't already in order, so we'll need to merge them! mem.copy(T, cache[0..], items[A.start..A.end]); - mergeExternal(T, items, A, B, lessThan, cache[0..]); + mergeExternal(T, items, A, B, context, lessThan, cache[0..]); } } } @@ -435,7 +451,7 @@ pub fn sort(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool) vo last = index; count += 1; }) { - index = findLastForward(T, items, items[last], Range.init(last + 1, A.end), lessThan, find - count); + index = findLastForward(T, items, items[last], Range.init(last + 1, A.end), context, lessThan, find - count); if (index == A.end) break; } index = last; @@ -493,7 +509,7 @@ pub fn sort(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool) vo last = index - 1; count += 1; }) { - index = findFirstBackward(T, items, items[last], Range.init(B.start, last), lessThan, find - count); + index = findFirstBackward(T, items, items[last], Range.init(B.start, last), context, lessThan, find - count); if (index == B.start) break; } index = last; @@ -558,7 +574,7 @@ pub fn sort(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool) vo index = pull[pull_index].from; count = 1; while (count < length) : (count += 1) { - index = findFirstBackward(T, items, items[index - 1], Range.init(pull[pull_index].to, pull[pull_index].from - (count - 1)), lessThan, length - count); + index = findFirstBackward(T, items, items[index - 1], Range.init(pull[pull_index].to, pull[pull_index].from - (count - 1)), context, lessThan, length - count); const range = Range.init(index + 1, pull[pull_index].from + 1); mem.rotate(T, items[range.start..range.end], range.length() - count); pull[pull_index].from = index + count; @@ -568,7 +584,7 @@ pub 
fn sort(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool) vo index = pull[pull_index].from + 1; count = 1; while (count < length) : (count += 1) { - index = findLastForward(T, items, items[index], Range.init(index, pull[pull_index].to), lessThan, length - count); + index = findLastForward(T, items, items[index], Range.init(index, pull[pull_index].to), context, lessThan, length - count); const range = Range.init(pull[pull_index].from, index - 1); mem.rotate(T, items[range.start..range.end], count); pull[pull_index].from = index - 1 - count; @@ -615,10 +631,10 @@ pub fn sort(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool) vo } } - if (lessThan(items[B.end - 1], items[A.start])) { + if (lessThan(context, items[B.end - 1], items[A.start])) { // the two ranges are in reverse order, so a simple rotation should fix it mem.rotate(T, items[A.start..B.end], A.length()); - } else if (lessThan(items[A.end], items[A.end - 1])) { + } else if (lessThan(context, items[A.end], items[A.end - 1])) { // these two ranges weren't already in order, so we'll need to merge them! var findA: usize = undefined; @@ -656,16 +672,16 @@ pub fn sort(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool) vo while (true) { // if there's a previous B block and the first value of the minimum A block is <= the last value of the previous B block, // then drop that minimum A block behind. or if there are no B blocks left then keep dropping the remaining A blocks. 
- if ((lastB.length() > 0 and !lessThan(items[lastB.end - 1], items[indexA])) or blockB.length() == 0) { + if ((lastB.length() > 0 and !lessThan(context, items[lastB.end - 1], items[indexA])) or blockB.length() == 0) { // figure out where to split the previous B block, and rotate it at the split - const B_split = binaryFirst(T, items, items[indexA], lastB, lessThan); + const B_split = binaryFirst(T, items, items[indexA], lastB, context, lessThan); const B_remaining = lastB.end - B_split; // swap the minimum A block to the beginning of the rolling A blocks var minA = blockA.start; findA = minA + block_size; while (findA < blockA.end) : (findA += block_size) { - if (lessThan(items[findA], items[minA])) { + if (lessThan(context, items[findA], items[minA])) { minA = findA; } } @@ -681,11 +697,11 @@ pub fn sort(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool) vo // or failing that we'll use a strictly in-place merge algorithm (MergeInPlace) if (lastA.length() <= cache.len) { - mergeExternal(T, items, lastA, Range.init(lastA.end, B_split), lessThan, cache[0..]); + mergeExternal(T, items, lastA, Range.init(lastA.end, B_split), context, lessThan, cache[0..]); } else if (buffer2.length() > 0) { - mergeInternal(T, items, lastA, Range.init(lastA.end, B_split), lessThan, buffer2); + mergeInternal(T, items, lastA, Range.init(lastA.end, B_split), context, lessThan, buffer2); } else { - mergeInPlace(T, items, lastA, Range.init(lastA.end, B_split), lessThan); + mergeInPlace(T, items, lastA, Range.init(lastA.end, B_split), context, lessThan); } if (buffer2.length() > 0 or block_size <= cache.len) { @@ -741,11 +757,11 @@ pub fn sort(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool) vo // merge the last A block with the remaining B values if (lastA.length() <= cache.len) { - mergeExternal(T, items, lastA, Range.init(lastA.end, B.end), lessThan, cache[0..]); + mergeExternal(T, items, lastA, Range.init(lastA.end, B.end), context, lessThan, cache[0..]); 
} else if (buffer2.length() > 0) { - mergeInternal(T, items, lastA, Range.init(lastA.end, B.end), lessThan, buffer2); + mergeInternal(T, items, lastA, Range.init(lastA.end, B.end), context, lessThan, buffer2); } else { - mergeInPlace(T, items, lastA, Range.init(lastA.end, B.end), lessThan); + mergeInPlace(T, items, lastA, Range.init(lastA.end, B.end), context, lessThan); } } } @@ -755,7 +771,7 @@ pub fn sort(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool) vo // while an unstable sort like quicksort could be applied here, in benchmarks it was consistently slightly slower than a simple insertion sort, // even for tens of millions of items. this may be because insertion sort is quite fast when the data is already somewhat sorted, like it is here - insertionSort(T, items[buffer2.start..buffer2.end], lessThan); + insertionSort(T, items[buffer2.start..buffer2.end], context, lessThan); pull_index = 0; while (pull_index < 2) : (pull_index += 1) { @@ -764,7 +780,7 @@ pub fn sort(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool) vo // the values were pulled out to the left, so redistribute them back to the right var buffer = Range.init(pull[pull_index].range.start, pull[pull_index].range.start + pull[pull_index].count); while (buffer.length() > 0) { - index = findFirstForward(T, items, items[buffer.start], Range.init(buffer.end, pull[pull_index].range.end), lessThan, unique); + index = findFirstForward(T, items, items[buffer.start], Range.init(buffer.end, pull[pull_index].range.end), context, lessThan, unique); const amount = index - buffer.end; mem.rotate(T, items[buffer.start..index], buffer.length()); buffer.start += (amount + 1); @@ -775,7 +791,7 @@ pub fn sort(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool) vo // the values were pulled out to the right, so redistribute them back to the left var buffer = Range.init(pull[pull_index].range.end - pull[pull_index].count, pull[pull_index].range.end); while (buffer.length() > 
0) { - index = findLastBackward(T, items, items[buffer.end - 1], Range.init(pull[pull_index].range.start, buffer.start), lessThan, unique); + index = findLastBackward(T, items, items[buffer.end - 1], Range.init(pull[pull_index].range.start, buffer.start), context, lessThan, unique); const amount = buffer.start - index; mem.rotate(T, items[index..buffer.end], amount); buffer.start -= amount; @@ -792,7 +808,14 @@ pub fn sort(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool) vo } // merge operation without a buffer -fn mergeInPlace(comptime T: type, items: []T, A_arg: Range, B_arg: Range, lessThan: fn (T, T) bool) void { +fn mergeInPlace( + comptime T: type, + items: []T, + A_arg: Range, + B_arg: Range, + context: var, + comptime lessThan: fn (@TypeOf(context), T, T) bool, +) void { if (A_arg.length() == 0 or B_arg.length() == 0) return; // this just repeatedly binary searches into B and rotates A into position. @@ -818,7 +841,7 @@ fn mergeInPlace(comptime T: type, items: []T, A_arg: Range, B_arg: Range, lessTh while (true) { // find the first place in B where the first item in A needs to be inserted - const mid = binaryFirst(T, items, items[A.start], B, lessThan); + const mid = binaryFirst(T, items, items[A.start], B, context, lessThan); // rotate A into place const amount = mid - A.end; @@ -828,13 +851,21 @@ fn mergeInPlace(comptime T: type, items: []T, A_arg: Range, B_arg: Range, lessTh // calculate the new A and B ranges B.start = mid; A = Range.init(A.start + amount, B.start); - A.start = binaryLast(T, items, items[A.start], A, lessThan); + A.start = binaryLast(T, items, items[A.start], A, context, lessThan); if (A.length() == 0) break; } } // merge operation using an internal buffer -fn mergeInternal(comptime T: type, items: []T, A: Range, B: Range, lessThan: fn (T, T) bool, buffer: Range) void { +fn mergeInternal( + comptime T: type, + items: []T, + A: Range, + B: Range, + context: var, + comptime lessThan: fn (@TypeOf(context), T, T) bool, + 
buffer: Range, +) void { // whenever we find a value to add to the final array, swap it with the value that's already in that spot // when this algorithm is finished, 'buffer' will contain its original contents, but in a different order var A_count: usize = 0; @@ -843,7 +874,7 @@ fn mergeInternal(comptime T: type, items: []T, A: Range, B: Range, lessThan: fn if (B.length() > 0 and A.length() > 0) { while (true) { - if (!lessThan(items[B.start + B_count], items[buffer.start + A_count])) { + if (!lessThan(context, items[B.start + B_count], items[buffer.start + A_count])) { mem.swap(T, &items[A.start + insert], &items[buffer.start + A_count]); A_count += 1; insert += 1; @@ -870,63 +901,102 @@ fn blockSwap(comptime T: type, items: []T, start1: usize, start2: usize, block_s // combine a linear search with a binary search to reduce the number of comparisons in situations // where have some idea as to how many unique values there are and where the next value might be -fn findFirstForward(comptime T: type, items: []T, value: T, range: Range, lessThan: fn (T, T) bool, unique: usize) usize { +fn findFirstForward( + comptime T: type, + items: []T, + value: T, + range: Range, + context: var, + comptime lessThan: fn (@TypeOf(context), T, T) bool, + unique: usize, +) usize { if (range.length() == 0) return range.start; const skip = math.max(range.length() / unique, @as(usize, 1)); var index = range.start + skip; - while (lessThan(items[index - 1], value)) : (index += skip) { + while (lessThan(context, items[index - 1], value)) : (index += skip) { if (index >= range.end - skip) { - return binaryFirst(T, items, value, Range.init(index, range.end), lessThan); + return binaryFirst(T, items, value, Range.init(index, range.end), context, lessThan); } } - return binaryFirst(T, items, value, Range.init(index - skip, index), lessThan); + return binaryFirst(T, items, value, Range.init(index - skip, index), context, lessThan); } -fn findFirstBackward(comptime T: type, items: []T, value: T, 
range: Range, lessThan: fn (T, T) bool, unique: usize) usize { +fn findFirstBackward( + comptime T: type, + items: []T, + value: T, + range: Range, + context: var, + comptime lessThan: fn (@TypeOf(context), T, T) bool, + unique: usize, +) usize { if (range.length() == 0) return range.start; const skip = math.max(range.length() / unique, @as(usize, 1)); var index = range.end - skip; - while (index > range.start and !lessThan(items[index - 1], value)) : (index -= skip) { + while (index > range.start and !lessThan(context, items[index - 1], value)) : (index -= skip) { if (index < range.start + skip) { - return binaryFirst(T, items, value, Range.init(range.start, index), lessThan); + return binaryFirst(T, items, value, Range.init(range.start, index), context, lessThan); } } - return binaryFirst(T, items, value, Range.init(index, index + skip), lessThan); + return binaryFirst(T, items, value, Range.init(index, index + skip), context, lessThan); } -fn findLastForward(comptime T: type, items: []T, value: T, range: Range, lessThan: fn (T, T) bool, unique: usize) usize { +fn findLastForward( + comptime T: type, + items: []T, + value: T, + range: Range, + context: var, + comptime lessThan: fn (@TypeOf(context), T, T) bool, + unique: usize, +) usize { if (range.length() == 0) return range.start; const skip = math.max(range.length() / unique, @as(usize, 1)); var index = range.start + skip; - while (!lessThan(value, items[index - 1])) : (index += skip) { + while (!lessThan(context, value, items[index - 1])) : (index += skip) { if (index >= range.end - skip) { - return binaryLast(T, items, value, Range.init(index, range.end), lessThan); + return binaryLast(T, items, value, Range.init(index, range.end), context, lessThan); } } - return binaryLast(T, items, value, Range.init(index - skip, index), lessThan); + return binaryLast(T, items, value, Range.init(index - skip, index), context, lessThan); } -fn findLastBackward(comptime T: type, items: []T, value: T, range: Range, lessThan: 
fn (T, T) bool, unique: usize) usize { +fn findLastBackward( + comptime T: type, + items: []T, + value: T, + range: Range, + context: var, + comptime lessThan: fn (@TypeOf(context), T, T) bool, + unique: usize, +) usize { if (range.length() == 0) return range.start; const skip = math.max(range.length() / unique, @as(usize, 1)); var index = range.end - skip; - while (index > range.start and lessThan(value, items[index - 1])) : (index -= skip) { + while (index > range.start and lessThan(context, value, items[index - 1])) : (index -= skip) { if (index < range.start + skip) { - return binaryLast(T, items, value, Range.init(range.start, index), lessThan); + return binaryLast(T, items, value, Range.init(range.start, index), context, lessThan); } } - return binaryLast(T, items, value, Range.init(index, index + skip), lessThan); + return binaryLast(T, items, value, Range.init(index, index + skip), context, lessThan); } -fn binaryFirst(comptime T: type, items: []T, value: T, range: Range, lessThan: fn (T, T) bool) usize { +fn binaryFirst( + comptime T: type, + items: []T, + value: T, + range: Range, + context: var, + comptime lessThan: fn (@TypeOf(context), T, T) bool, +) usize { var curr = range.start; var size = range.length(); if (range.start >= range.end) return range.end; @@ -935,14 +1005,21 @@ fn binaryFirst(comptime T: type, items: []T, value: T, range: Range, lessThan: f size /= 2; const mid = items[curr + size]; - if (lessThan(mid, value)) { + if (lessThan(context, mid, value)) { curr += size + offset; } } return curr; } -fn binaryLast(comptime T: type, items: []T, value: T, range: Range, lessThan: fn (T, T) bool) usize { +fn binaryLast( + comptime T: type, + items: []T, + value: T, + range: Range, + context: var, + comptime lessThan: fn (@TypeOf(context), T, T) bool, +) usize { var curr = range.start; var size = range.length(); if (range.start >= range.end) return range.end; @@ -951,14 +1028,22 @@ fn binaryLast(comptime T: type, items: []T, value: T, range: Range, 
lessThan: fn size /= 2; const mid = items[curr + size]; - if (!lessThan(value, mid)) { + if (!lessThan(context, value, mid)) { curr += size + offset; } } return curr; } -fn mergeInto(comptime T: type, from: []T, A: Range, B: Range, lessThan: fn (T, T) bool, into: []T) void { +fn mergeInto( + comptime T: type, + from: []T, + A: Range, + B: Range, + context: var, + comptime lessThan: fn (@TypeOf(context), T, T) bool, + into: []T, +) void { var A_index: usize = A.start; var B_index: usize = B.start; const A_last = A.end; @@ -966,7 +1051,7 @@ fn mergeInto(comptime T: type, from: []T, A: Range, B: Range, lessThan: fn (T, T var insert_index: usize = 0; while (true) { - if (!lessThan(from[B_index], from[A_index])) { + if (!lessThan(context, from[B_index], from[A_index])) { into[insert_index] = from[A_index]; A_index += 1; insert_index += 1; @@ -988,7 +1073,15 @@ fn mergeInto(comptime T: type, from: []T, A: Range, B: Range, lessThan: fn (T, T } } -fn mergeExternal(comptime T: type, items: []T, A: Range, B: Range, lessThan: fn (T, T) bool, cache: []T) void { +fn mergeExternal( + comptime T: type, + items: []T, + A: Range, + B: Range, + context: var, + comptime lessThan: fn (@TypeOf(context), T, T) bool, + cache: []T, +) void { // A fits into the cache, so use that instead of the internal buffer var A_index: usize = 0; var B_index: usize = B.start; @@ -998,7 +1091,7 @@ fn mergeExternal(comptime T: type, items: []T, A: Range, B: Range, lessThan: fn if (B.length() > 0 and A.length() > 0) { while (true) { - if (!lessThan(items[B_index], cache[A_index])) { + if (!lessThan(context, items[B_index], cache[A_index])) { items[insert_index] = cache[A_index]; A_index += 1; insert_index += 1; @@ -1016,17 +1109,25 @@ fn mergeExternal(comptime T: type, items: []T, A: Range, B: Range, lessThan: fn mem.copy(T, items[insert_index..], cache[A_index..A_last]); } -fn swap(comptime T: type, items: []T, lessThan: fn (lhs: T, rhs: T) bool, order: *[8]u8, x: usize, y: usize) void { - if 
(lessThan(items[y], items[x]) or ((order.*)[x] > (order.*)[y] and !lessThan(items[x], items[y]))) { +fn swap( + comptime T: type, + items: []T, + context: var, + comptime lessThan: fn (@TypeOf(context), lhs: T, rhs: T) bool, + order: *[8]u8, + x: usize, + y: usize, +) void { + if (lessThan(context, items[y], items[x]) or ((order.*)[x] > (order.*)[y] and !lessThan(context, items[x], items[y]))) { mem.swap(T, &items[x], &items[y]); mem.swap(u8, &(order.*)[x], &(order.*)[y]); } } -// Use these to generate a comparator function for a given type. e.g. `sort(u8, slice, asc(u8))`. -pub fn asc(comptime T: type) fn (T, T) bool { +/// Use to generate an ascending comparator function for a given type. e.g. `sort(u8, slice, {}, comptime asc(u8))`. +pub fn asc(comptime T: type) fn (void, T, T) bool { const impl = struct { - fn inner(a: T, b: T) bool { + fn inner(context: void, a: T, b: T) bool { return a < b; } }; @@ -1034,9 +1135,10 @@ pub fn asc(comptime T: type) fn (T, T) bool { return impl.inner; } -pub fn desc(comptime T: type) fn (T, T) bool { +/// Use to generate a descending comparator function for a given type. e.g. `sort(u8, slice, {}, comptime desc(u8))`. 
+pub fn desc(comptime T: type) fn (void, T, T) bool { const impl = struct { - fn inner(a: T, b: T) bool { + fn inner(context: void, a: T, b: T) bool { return a > b; } }; @@ -1085,7 +1187,7 @@ fn testStableSort() void { }, }; for (cases) |*case| { - insertionSort(IdAndValue, (case.*)[0..], cmpByValue); + insertionSort(IdAndValue, (case.*)[0..], {}, cmpByValue); for (case.*) |item, i| { testing.expect(item.id == expected[i].id); testing.expect(item.value == expected[i].value); @@ -1096,11 +1198,16 @@ const IdAndValue = struct { id: usize, value: i32, }; -fn cmpByValue(a: IdAndValue, b: IdAndValue) bool { - return asc(i32)(a.value, b.value); +fn cmpByValue(context: void, a: IdAndValue, b: IdAndValue) bool { + return asc_i32(context, a.value, b.value); } -test "std.sort" { +const asc_u8 = asc(u8); +const asc_i32 = asc(i32); +const desc_u8 = desc(u8); +const desc_i32 = desc(i32); + +test "sort" { const u8cases = [_][]const []const u8{ &[_][]const u8{ "", @@ -1132,7 +1239,7 @@ test "std.sort" { var buf: [8]u8 = undefined; const slice = buf[0..case[0].len]; mem.copy(u8, slice, case[0]); - sort(u8, slice, asc(u8)); + sort(u8, slice, {}, asc_u8); testing.expect(mem.eql(u8, slice, case[1])); } @@ -1167,12 +1274,12 @@ test "std.sort" { var buf: [8]i32 = undefined; const slice = buf[0..case[0].len]; mem.copy(i32, slice, case[0]); - sort(i32, slice, asc(i32)); + sort(i32, slice, {}, asc_i32); testing.expect(mem.eql(i32, slice, case[1])); } } -test "std.sort descending" { +test "sort descending" { const rev_cases = [_][]const []const i32{ &[_][]const i32{ &[_]i32{}, @@ -1204,14 +1311,14 @@ test "std.sort descending" { var buf: [8]i32 = undefined; const slice = buf[0..case[0].len]; mem.copy(i32, slice, case[0]); - sort(i32, slice, desc(i32)); + sort(i32, slice, {}, desc_i32); testing.expect(mem.eql(i32, slice, case[1])); } } test "another sort case" { var arr = [_]i32{ 5, 3, 1, 2, 4 }; - sort(i32, arr[0..], asc(i32)); + sort(i32, arr[0..], {}, asc_i32); 
testing.expect(mem.eql(i32, &arr, &[_]i32{ 1, 2, 3, 4, 5 })); } @@ -1236,7 +1343,7 @@ fn fuzzTest(rng: *std.rand.Random) !void { item.id = index; item.value = rng.intRangeLessThan(i32, 0, 100); } - sort(IdAndValue, array, cmpByValue); + sort(IdAndValue, array, {}, cmpByValue); var index: usize = 1; while (index < array.len) : (index += 1) { @@ -1248,7 +1355,12 @@ fn fuzzTest(rng: *std.rand.Random) !void { } } -pub fn argMin(comptime T: type, items: []const T, lessThan: fn (lhs: T, rhs: T) bool) ?usize { +pub fn argMin( + comptime T: type, + items: []const T, + context: var, + comptime lessThan: fn (@TypeOf(context), lhs: T, rhs: T) bool, +) ?usize { if (items.len == 0) { return null; } @@ -1256,7 +1368,7 @@ pub fn argMin(comptime T: type, items: []const T, lessThan: fn (lhs: T, rhs: T) var smallest = items[0]; var smallest_index: usize = 0; for (items[1..]) |item, i| { - if (lessThan(item, smallest)) { + if (lessThan(context, item, smallest)) { smallest = item; smallest_index = i + 1; } @@ -1265,32 +1377,42 @@ pub fn argMin(comptime T: type, items: []const T, lessThan: fn (lhs: T, rhs: T) return smallest_index; } -test "std.sort.argMin" { - testing.expectEqual(@as(?usize, null), argMin(i32, &[_]i32{}, asc(i32))); - testing.expectEqual(@as(?usize, 0), argMin(i32, &[_]i32{1}, asc(i32))); - testing.expectEqual(@as(?usize, 0), argMin(i32, &[_]i32{ 1, 2, 3, 4, 5 }, asc(i32))); - testing.expectEqual(@as(?usize, 3), argMin(i32, &[_]i32{ 9, 3, 8, 2, 5 }, asc(i32))); - testing.expectEqual(@as(?usize, 0), argMin(i32, &[_]i32{ 1, 1, 1, 1, 1 }, asc(i32))); - testing.expectEqual(@as(?usize, 0), argMin(i32, &[_]i32{ -10, 1, 10 }, asc(i32))); - testing.expectEqual(@as(?usize, 3), argMin(i32, &[_]i32{ 6, 3, 5, 7, 6 }, desc(i32))); +test "argMin" { + testing.expectEqual(@as(?usize, null), argMin(i32, &[_]i32{}, {}, asc_i32)); + testing.expectEqual(@as(?usize, 0), argMin(i32, &[_]i32{1}, {}, asc_i32)); + testing.expectEqual(@as(?usize, 0), argMin(i32, &[_]i32{ 1, 2, 3, 4, 5 }, {}, 
asc_i32)); + testing.expectEqual(@as(?usize, 3), argMin(i32, &[_]i32{ 9, 3, 8, 2, 5 }, {}, asc_i32)); + testing.expectEqual(@as(?usize, 0), argMin(i32, &[_]i32{ 1, 1, 1, 1, 1 }, {}, asc_i32)); + testing.expectEqual(@as(?usize, 0), argMin(i32, &[_]i32{ -10, 1, 10 }, {}, asc_i32)); + testing.expectEqual(@as(?usize, 3), argMin(i32, &[_]i32{ 6, 3, 5, 7, 6 }, {}, desc_i32)); } -pub fn min(comptime T: type, items: []const T, lessThan: fn (lhs: T, rhs: T) bool) ?T { - const i = argMin(T, items, lessThan) orelse return null; +pub fn min( + comptime T: type, + items: []const T, + context: var, + comptime lessThan: fn (context: @TypeOf(context), lhs: T, rhs: T) bool, +) ?T { + const i = argMin(T, items, context, lessThan) orelse return null; return items[i]; } -test "std.sort.min" { - testing.expectEqual(@as(?i32, null), min(i32, &[_]i32{}, asc(i32))); - testing.expectEqual(@as(?i32, 1), min(i32, &[_]i32{1}, asc(i32))); - testing.expectEqual(@as(?i32, 1), min(i32, &[_]i32{ 1, 2, 3, 4, 5 }, asc(i32))); - testing.expectEqual(@as(?i32, 2), min(i32, &[_]i32{ 9, 3, 8, 2, 5 }, asc(i32))); - testing.expectEqual(@as(?i32, 1), min(i32, &[_]i32{ 1, 1, 1, 1, 1 }, asc(i32))); - testing.expectEqual(@as(?i32, -10), min(i32, &[_]i32{ -10, 1, 10 }, asc(i32))); - testing.expectEqual(@as(?i32, 7), min(i32, &[_]i32{ 6, 3, 5, 7, 6 }, desc(i32))); +test "min" { + testing.expectEqual(@as(?i32, null), min(i32, &[_]i32{}, {}, asc_i32)); + testing.expectEqual(@as(?i32, 1), min(i32, &[_]i32{1}, {}, asc_i32)); + testing.expectEqual(@as(?i32, 1), min(i32, &[_]i32{ 1, 2, 3, 4, 5 }, {}, asc_i32)); + testing.expectEqual(@as(?i32, 2), min(i32, &[_]i32{ 9, 3, 8, 2, 5 }, {}, asc_i32)); + testing.expectEqual(@as(?i32, 1), min(i32, &[_]i32{ 1, 1, 1, 1, 1 }, {}, asc_i32)); + testing.expectEqual(@as(?i32, -10), min(i32, &[_]i32{ -10, 1, 10 }, {}, asc_i32)); + testing.expectEqual(@as(?i32, 7), min(i32, &[_]i32{ 6, 3, 5, 7, 6 }, {}, desc_i32)); } -pub fn argMax(comptime T: type, items: []const T, lessThan: fn 
(lhs: T, rhs: T) bool) ?usize { +pub fn argMax( + comptime T: type, + items: []const T, + context: var, + comptime lessThan: fn (context: @TypeOf(context), lhs: T, rhs: T) bool, +) ?usize { if (items.len == 0) { return null; } @@ -1298,7 +1420,7 @@ pub fn argMax(comptime T: type, items: []const T, lessThan: fn (lhs: T, rhs: T) var biggest = items[0]; var biggest_index: usize = 0; for (items[1..]) |item, i| { - if (lessThan(biggest, item)) { + if (lessThan(context, biggest, item)) { biggest = item; biggest_index = i + 1; } @@ -1307,35 +1429,45 @@ pub fn argMax(comptime T: type, items: []const T, lessThan: fn (lhs: T, rhs: T) return biggest_index; } -test "std.sort.argMax" { - testing.expectEqual(@as(?usize, null), argMax(i32, &[_]i32{}, asc(i32))); - testing.expectEqual(@as(?usize, 0), argMax(i32, &[_]i32{1}, asc(i32))); - testing.expectEqual(@as(?usize, 4), argMax(i32, &[_]i32{ 1, 2, 3, 4, 5 }, asc(i32))); - testing.expectEqual(@as(?usize, 0), argMax(i32, &[_]i32{ 9, 3, 8, 2, 5 }, asc(i32))); - testing.expectEqual(@as(?usize, 0), argMax(i32, &[_]i32{ 1, 1, 1, 1, 1 }, asc(i32))); - testing.expectEqual(@as(?usize, 2), argMax(i32, &[_]i32{ -10, 1, 10 }, asc(i32))); - testing.expectEqual(@as(?usize, 1), argMax(i32, &[_]i32{ 6, 3, 5, 7, 6 }, desc(i32))); +test "argMax" { + testing.expectEqual(@as(?usize, null), argMax(i32, &[_]i32{}, {}, asc_i32)); + testing.expectEqual(@as(?usize, 0), argMax(i32, &[_]i32{1}, {}, asc_i32)); + testing.expectEqual(@as(?usize, 4), argMax(i32, &[_]i32{ 1, 2, 3, 4, 5 }, {}, asc_i32)); + testing.expectEqual(@as(?usize, 0), argMax(i32, &[_]i32{ 9, 3, 8, 2, 5 }, {}, asc_i32)); + testing.expectEqual(@as(?usize, 0), argMax(i32, &[_]i32{ 1, 1, 1, 1, 1 }, {}, asc_i32)); + testing.expectEqual(@as(?usize, 2), argMax(i32, &[_]i32{ -10, 1, 10 }, {}, asc_i32)); + testing.expectEqual(@as(?usize, 1), argMax(i32, &[_]i32{ 6, 3, 5, 7, 6 }, {}, desc_i32)); } -pub fn max(comptime T: type, items: []const T, lessThan: fn (lhs: T, rhs: T) bool) ?T { - const i = 
argMax(T, items, lessThan) orelse return null; +pub fn max( + comptime T: type, + items: []const T, + context: var, + comptime lessThan: fn (context: @TypeOf(context), lhs: T, rhs: T) bool, +) ?T { + const i = argMax(T, items, context, lessThan) orelse return null; return items[i]; } -test "std.sort.max" { - testing.expectEqual(@as(?i32, null), max(i32, &[_]i32{}, asc(i32))); - testing.expectEqual(@as(?i32, 1), max(i32, &[_]i32{1}, asc(i32))); - testing.expectEqual(@as(?i32, 5), max(i32, &[_]i32{ 1, 2, 3, 4, 5 }, asc(i32))); - testing.expectEqual(@as(?i32, 9), max(i32, &[_]i32{ 9, 3, 8, 2, 5 }, asc(i32))); - testing.expectEqual(@as(?i32, 1), max(i32, &[_]i32{ 1, 1, 1, 1, 1 }, asc(i32))); - testing.expectEqual(@as(?i32, 10), max(i32, &[_]i32{ -10, 1, 10 }, asc(i32))); - testing.expectEqual(@as(?i32, 3), max(i32, &[_]i32{ 6, 3, 5, 7, 6 }, desc(i32))); +test "max" { + testing.expectEqual(@as(?i32, null), max(i32, &[_]i32{}, {}, asc_i32)); + testing.expectEqual(@as(?i32, 1), max(i32, &[_]i32{1}, {}, asc_i32)); + testing.expectEqual(@as(?i32, 5), max(i32, &[_]i32{ 1, 2, 3, 4, 5 }, {}, asc_i32)); + testing.expectEqual(@as(?i32, 9), max(i32, &[_]i32{ 9, 3, 8, 2, 5 }, {}, asc_i32)); + testing.expectEqual(@as(?i32, 1), max(i32, &[_]i32{ 1, 1, 1, 1, 1 }, {}, asc_i32)); + testing.expectEqual(@as(?i32, 10), max(i32, &[_]i32{ -10, 1, 10 }, {}, asc_i32)); + testing.expectEqual(@as(?i32, 3), max(i32, &[_]i32{ 6, 3, 5, 7, 6 }, {}, desc_i32)); } -pub fn isSorted(comptime T: type, items: []const T, lessThan: fn (lhs: T, rhs: T) bool) bool { +pub fn isSorted( + comptime T: type, + items: []const T, + context: var, + comptime lessThan: fn (context: @TypeOf(context), lhs: T, rhs: T) bool, +) bool { var i: usize = 1; while (i < items.len) : (i += 1) { - if (lessThan(items[i], items[i - 1])) { + if (lessThan(context, items[i], items[i - 1])) { return false; } } @@ -1343,29 +1475,29 @@ pub fn isSorted(comptime T: type, items: []const T, lessThan: fn (lhs: T, rhs: T return true; } -test 
"std.sort.isSorted" { - testing.expect(isSorted(i32, &[_]i32{}, asc(i32))); - testing.expect(isSorted(i32, &[_]i32{10}, asc(i32))); - testing.expect(isSorted(i32, &[_]i32{ 1, 2, 3, 4, 5 }, asc(i32))); - testing.expect(isSorted(i32, &[_]i32{ -10, 1, 1, 1, 10 }, asc(i32))); +test "isSorted" { + testing.expect(isSorted(i32, &[_]i32{}, {}, asc_i32)); + testing.expect(isSorted(i32, &[_]i32{10}, {}, asc_i32)); + testing.expect(isSorted(i32, &[_]i32{ 1, 2, 3, 4, 5 }, {}, asc_i32)); + testing.expect(isSorted(i32, &[_]i32{ -10, 1, 1, 1, 10 }, {}, asc_i32)); - testing.expect(isSorted(i32, &[_]i32{}, desc(i32))); - testing.expect(isSorted(i32, &[_]i32{-20}, desc(i32))); - testing.expect(isSorted(i32, &[_]i32{ 3, 2, 1, 0, -1 }, desc(i32))); - testing.expect(isSorted(i32, &[_]i32{ 10, -10 }, desc(i32))); + testing.expect(isSorted(i32, &[_]i32{}, {}, desc_i32)); + testing.expect(isSorted(i32, &[_]i32{-20}, {}, desc_i32)); + testing.expect(isSorted(i32, &[_]i32{ 3, 2, 1, 0, -1 }, {}, desc_i32)); + testing.expect(isSorted(i32, &[_]i32{ 10, -10 }, {}, desc_i32)); - testing.expect(isSorted(i32, &[_]i32{ 1, 1, 1, 1, 1 }, asc(i32))); - testing.expect(isSorted(i32, &[_]i32{ 1, 1, 1, 1, 1 }, desc(i32))); + testing.expect(isSorted(i32, &[_]i32{ 1, 1, 1, 1, 1 }, {}, asc_i32)); + testing.expect(isSorted(i32, &[_]i32{ 1, 1, 1, 1, 1 }, {}, desc_i32)); - testing.expectEqual(false, isSorted(i32, &[_]i32{ 5, 4, 3, 2, 1 }, asc(i32))); - testing.expectEqual(false, isSorted(i32, &[_]i32{ 1, 2, 3, 4, 5 }, desc(i32))); + testing.expectEqual(false, isSorted(i32, &[_]i32{ 5, 4, 3, 2, 1 }, {}, asc_i32)); + testing.expectEqual(false, isSorted(i32, &[_]i32{ 1, 2, 3, 4, 5 }, {}, desc_i32)); - testing.expect(isSorted(u8, "abcd", asc(u8))); - testing.expect(isSorted(u8, "zyxw", desc(u8))); + testing.expect(isSorted(u8, "abcd", {}, asc_u8)); + testing.expect(isSorted(u8, "zyxw", {}, desc_u8)); - testing.expectEqual(false, isSorted(u8, "abcd", desc(u8))); - testing.expectEqual(false, isSorted(u8, "zyxw", 
asc(u8))); + testing.expectEqual(false, isSorted(u8, "abcd", {}, desc_u8)); + testing.expectEqual(false, isSorted(u8, "zyxw", {}, asc_u8)); - testing.expect(isSorted(u8, "ffff", asc(u8))); - testing.expect(isSorted(u8, "ffff", desc(u8))); + testing.expect(isSorted(u8, "ffff", {}, asc_u8)); + testing.expect(isSorted(u8, "ffff", {}, desc_u8)); } From cf654b52d68f20a403965e70371a9ad193370d8c Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 3 Jun 2020 21:53:13 -0400 Subject: [PATCH 2/6] stage2: -femit-zir respects decl names and supports cycles --- src-self-hosted/Module.zig | 24 ++- src-self-hosted/type.zig | 22 +++ src-self-hosted/value.zig | 12 ++ src-self-hosted/zir.zig | 390 ++++++++++++++++++++++++------------- test/stage2/zir.zig | 35 ++-- 5 files changed, 330 insertions(+), 153 deletions(-) diff --git a/src-self-hosted/Module.zig b/src-self-hosted/Module.zig index 5bdd38c693..4ddd286f16 100644 --- a/src-self-hosted/Module.zig +++ b/src-self-hosted/Module.zig @@ -1097,6 +1097,9 @@ fn resolveDecl(self: *Module, scope: *Scope, old_inst: *zir.Inst) InnerError!*De const decl = kv.value; try self.reAnalyzeDecl(decl, old_inst); return decl; + } else if (old_inst.cast(zir.Inst.DeclVal)) |decl_val| { + // This is just a named reference to another decl. 
+ return self.analyzeDeclVal(scope, decl_val); } else { const new_decl = blk: { try self.decl_table.ensureCapacity(self.decl_table.size + 1); @@ -1443,6 +1446,7 @@ fn analyzeInst(self: *Module, scope: *Scope, old_inst: *zir.Inst) InnerError!*In .breakpoint => return self.analyzeInstBreakpoint(scope, old_inst.cast(zir.Inst.Breakpoint).?), .call => return self.analyzeInstCall(scope, old_inst.cast(zir.Inst.Call).?), .declref => return self.analyzeInstDeclRef(scope, old_inst.cast(zir.Inst.DeclRef).?), + .declval => return self.analyzeInstDeclVal(scope, old_inst.cast(zir.Inst.DeclVal).?), .str => { const bytes = old_inst.cast(zir.Inst.Str).?.positionals.bytes; // The bytes references memory inside the ZIR module, which can get deallocated @@ -1501,6 +1505,24 @@ fn analyzeInstDeclRef(self: *Module, scope: *Scope, inst: *zir.Inst.DeclRef) Inn return self.analyzeDeclRef(scope, inst.base.src, decl); } +fn analyzeDeclVal(self: *Module, scope: *Scope, inst: *zir.Inst.DeclVal) InnerError!*Decl { + const decl_name = inst.positionals.name; + // This will need to get more fleshed out when there are proper structs & namespaces. 
+ const zir_module = scope.namespace(); + const src_decl = zir_module.contents.module.findDecl(decl_name) orelse + return self.fail(scope, inst.base.src, "use of undeclared identifier '{}'", .{decl_name}); + + const decl = try self.resolveCompleteDecl(scope, src_decl); + + return decl; +} + +fn analyzeInstDeclVal(self: *Module, scope: *Scope, inst: *zir.Inst.DeclVal) InnerError!*Inst { + const decl = try self.analyzeDeclVal(scope, inst); + const ptr = try self.analyzeDeclRef(scope, inst.base.src, decl); + return self.analyzeDeref(scope, inst.base.src, ptr, inst.base.src); +} + fn analyzeDeclRef(self: *Module, scope: *Scope, src: usize, decl: *Decl) InnerError!*Inst { const decl_tv = try decl.typedValue(); const ty_payload = try scope.arena().create(Type.Payload.SingleConstPointer); @@ -1621,7 +1643,7 @@ fn analyzeInstFnType(self: *Module, scope: *Scope, fntype: *zir.Inst.FnType) Inn } fn analyzeInstPrimitive(self: *Module, scope: *Scope, primitive: *zir.Inst.Primitive) InnerError!*Inst { - return self.constType(scope, primitive.base.src, primitive.positionals.tag.toType()); + return self.constInst(scope, primitive.base.src, primitive.positionals.tag.toTypedValue()); } fn analyzeInstAs(self: *Module, scope: *Scope, as: *zir.Inst.As) InnerError!*Inst { diff --git a/src-self-hosted/type.zig b/src-self-hosted/type.zig index bdce3ba2d8..aa8c000095 100644 --- a/src-self-hosted/type.zig +++ b/src-self-hosted/type.zig @@ -51,6 +51,7 @@ pub const Type = extern union { .comptime_float => return .ComptimeFloat, .noreturn => return .NoReturn, .@"null" => return .Null, + .@"undefined" => return .Undefined, .fn_noreturn_no_args => return .Fn, .fn_naked_noreturn_no_args => return .Fn, @@ -201,6 +202,7 @@ pub const Type = extern union { => return out_stream.writeAll(@tagName(t)), .@"null" => return out_stream.writeAll("@TypeOf(null)"), + .@"undefined" => return out_stream.writeAll("@TypeOf(undefined)"), .const_slice_u8 => return out_stream.writeAll("[]const u8"), 
.fn_noreturn_no_args => return out_stream.writeAll("fn() noreturn"), @@ -265,6 +267,7 @@ pub const Type = extern union { .comptime_float => return Value.initTag(.comptime_float_type), .noreturn => return Value.initTag(.noreturn_type), .@"null" => return Value.initTag(.null_type), + .@"undefined" => return Value.initTag(.undefined_type), .fn_noreturn_no_args => return Value.initTag(.fn_noreturn_no_args_type), .fn_naked_noreturn_no_args => return Value.initTag(.fn_naked_noreturn_no_args_type), .fn_ccc_void_no_args => return Value.initTag(.fn_ccc_void_no_args_type), @@ -318,6 +321,7 @@ pub const Type = extern union { .comptime_float, .noreturn, .@"null", + .@"undefined", => false, }; } @@ -378,6 +382,7 @@ pub const Type = extern union { .comptime_float, .noreturn, .@"null", + .@"undefined", => unreachable, }; } @@ -410,6 +415,7 @@ pub const Type = extern union { .comptime_float, .noreturn, .@"null", + .@"undefined", .array, .array_u8_sentinel_0, .const_slice_u8, @@ -454,6 +460,7 @@ pub const Type = extern union { .comptime_float, .noreturn, .@"null", + .@"undefined", .array, .array_u8_sentinel_0, .single_const_pointer, @@ -498,6 +505,7 @@ pub const Type = extern union { .comptime_float, .noreturn, .@"null", + .@"undefined", .array, .array_u8_sentinel_0, .fn_noreturn_no_args, @@ -543,6 +551,7 @@ pub const Type = extern union { .comptime_float, .noreturn, .@"null", + .@"undefined", .fn_noreturn_no_args, .fn_naked_noreturn_no_args, .fn_ccc_void_no_args, @@ -586,6 +595,7 @@ pub const Type = extern union { .comptime_float, .noreturn, .@"null", + .@"undefined", .fn_noreturn_no_args, .fn_naked_noreturn_no_args, .fn_ccc_void_no_args, @@ -630,6 +640,7 @@ pub const Type = extern union { .comptime_float, .noreturn, .@"null", + .@"undefined", .fn_noreturn_no_args, .fn_naked_noreturn_no_args, .fn_ccc_void_no_args, @@ -662,6 +673,7 @@ pub const Type = extern union { .comptime_float, .noreturn, .@"null", + .@"undefined", .fn_noreturn_no_args, .fn_naked_noreturn_no_args, 
.fn_ccc_void_no_args, @@ -707,6 +719,7 @@ pub const Type = extern union { .comptime_float, .noreturn, .@"null", + .@"undefined", .fn_noreturn_no_args, .fn_naked_noreturn_no_args, .fn_ccc_void_no_args, @@ -781,6 +794,7 @@ pub const Type = extern union { .comptime_float, .noreturn, .@"null", + .@"undefined", .array, .single_const_pointer, .single_const_pointer_to_comptime_int, @@ -826,6 +840,7 @@ pub const Type = extern union { .comptime_float, .noreturn, .@"null", + .@"undefined", .array, .single_const_pointer, .single_const_pointer_to_comptime_int, @@ -870,6 +885,7 @@ pub const Type = extern union { .comptime_float, .noreturn, .@"null", + .@"undefined", .array, .single_const_pointer, .single_const_pointer_to_comptime_int, @@ -914,6 +930,7 @@ pub const Type = extern union { .comptime_float, .noreturn, .@"null", + .@"undefined", .array, .single_const_pointer, .single_const_pointer_to_comptime_int, @@ -958,6 +975,7 @@ pub const Type = extern union { .comptime_float, .noreturn, .@"null", + .@"undefined", .array, .single_const_pointer, .single_const_pointer_to_comptime_int, @@ -1013,6 +1031,7 @@ pub const Type = extern union { .anyerror, .noreturn, .@"null", + .@"undefined", .fn_noreturn_no_args, .fn_naked_noreturn_no_args, .fn_ccc_void_no_args, @@ -1062,6 +1081,7 @@ pub const Type = extern union { .void, .noreturn, .@"null", + .@"undefined", => return true, .int_unsigned => return ty.cast(Payload.IntUnsigned).?.bits == 0, @@ -1115,6 +1135,7 @@ pub const Type = extern union { .void, .noreturn, .@"null", + .@"undefined", .int_unsigned, .int_signed, .array, @@ -1157,6 +1178,7 @@ pub const Type = extern union { comptime_float, noreturn, @"null", + @"undefined", fn_noreturn_no_args, fn_naked_noreturn_no_args, fn_ccc_void_no_args, diff --git a/src-self-hosted/value.zig b/src-self-hosted/value.zig index 2727ad26a5..5660cd760f 100644 --- a/src-self-hosted/value.zig +++ b/src-self-hosted/value.zig @@ -47,6 +47,7 @@ pub const Value = extern union { comptime_float_type, 
noreturn_type, null_type, + undefined_type, fn_noreturn_no_args_type, fn_naked_noreturn_no_args_type, fn_ccc_void_no_args_type, @@ -141,6 +142,7 @@ pub const Value = extern union { .comptime_float_type => return out_stream.writeAll("comptime_float"), .noreturn_type => return out_stream.writeAll("noreturn"), .null_type => return out_stream.writeAll("@TypeOf(null)"), + .undefined_type => return out_stream.writeAll("@TypeOf(undefined)"), .fn_noreturn_no_args_type => return out_stream.writeAll("fn() noreturn"), .fn_naked_noreturn_no_args_type => return out_stream.writeAll("fn() callconv(.Naked) noreturn"), .fn_ccc_void_no_args_type => return out_stream.writeAll("fn() callconv(.C) void"), @@ -225,6 +227,7 @@ pub const Value = extern union { .comptime_float_type => Type.initTag(.comptime_float), .noreturn_type => Type.initTag(.noreturn), .null_type => Type.initTag(.@"null"), + .undefined_type => Type.initTag(.@"undefined"), .fn_noreturn_no_args_type => Type.initTag(.fn_noreturn_no_args), .fn_naked_noreturn_no_args_type => Type.initTag(.fn_naked_noreturn_no_args), .fn_ccc_void_no_args_type => Type.initTag(.fn_ccc_void_no_args), @@ -281,6 +284,7 @@ pub const Value = extern union { .comptime_float_type, .noreturn_type, .null_type, + .undefined_type, .fn_noreturn_no_args_type, .fn_naked_noreturn_no_args_type, .fn_ccc_void_no_args_type, @@ -339,6 +343,7 @@ pub const Value = extern union { .comptime_float_type, .noreturn_type, .null_type, + .undefined_type, .fn_noreturn_no_args_type, .fn_naked_noreturn_no_args_type, .fn_ccc_void_no_args_type, @@ -398,6 +403,7 @@ pub const Value = extern union { .comptime_float_type, .noreturn_type, .null_type, + .undefined_type, .fn_noreturn_no_args_type, .fn_naked_noreturn_no_args_type, .fn_ccc_void_no_args_type, @@ -462,6 +468,7 @@ pub const Value = extern union { .comptime_float_type, .noreturn_type, .null_type, + .undefined_type, .fn_noreturn_no_args_type, .fn_naked_noreturn_no_args_type, .fn_ccc_void_no_args_type, @@ -555,6 +562,7 @@ pub 
const Value = extern union { .comptime_float_type, .noreturn_type, .null_type, + .undefined_type, .fn_noreturn_no_args_type, .fn_naked_noreturn_no_args_type, .fn_ccc_void_no_args_type, @@ -610,6 +618,7 @@ pub const Value = extern union { .comptime_float_type, .noreturn_type, .null_type, + .undefined_type, .fn_noreturn_no_args_type, .fn_naked_noreturn_no_args_type, .fn_ccc_void_no_args_type, @@ -710,6 +719,7 @@ pub const Value = extern union { .comptime_float_type, .noreturn_type, .null_type, + .undefined_type, .fn_noreturn_no_args_type, .fn_naked_noreturn_no_args_type, .fn_ccc_void_no_args_type, @@ -771,6 +781,7 @@ pub const Value = extern union { .comptime_float_type, .noreturn_type, .null_type, + .undefined_type, .fn_noreturn_no_args_type, .fn_naked_noreturn_no_args_type, .fn_ccc_void_no_args_type, @@ -849,6 +860,7 @@ pub const Value = extern union { .comptime_float_type, .noreturn_type, .null_type, + .undefined_type, .fn_noreturn_no_args_type, .fn_naked_noreturn_no_args_type, .fn_ccc_void_no_args_type, diff --git a/src-self-hosted/zir.zig b/src-self-hosted/zir.zig index b8a1421996..749d9d9c87 100644 --- a/src-self-hosted/zir.zig +++ b/src-self-hosted/zir.zig @@ -27,9 +27,11 @@ pub const Inst = struct { pub const Tag = enum { breakpoint, call, - /// Represents a reference to a global decl by name. - /// The syntax `@foo` is equivalent to `declref("foo")`. + /// Represents a pointer to a global decl by name. declref, + /// The syntax `@foo` is equivalent to `declval("foo")`. + /// declval is equivalent to declref followed by deref. 
+ declval, str, int, ptrtoint, @@ -59,6 +61,7 @@ pub const Inst = struct { .breakpoint => Breakpoint, .call => Call, .declref => DeclRef, + .declval => DeclVal, .str => Str, .int => Int, .ptrtoint => PtrToInt, @@ -122,6 +125,16 @@ pub const Inst = struct { kw_args: struct {}, }; + pub const DeclVal = struct { + pub const base_tag = Tag.declval; + base: Inst, + + positionals: struct { + name: []const u8, + }, + kw_args: struct {}, + }; + pub const Str = struct { pub const base_tag = Tag.str; base: Inst, @@ -254,11 +267,11 @@ pub const Inst = struct { base: Inst, positionals: struct { - tag: BuiltinType, + tag: Builtin, }, kw_args: struct {}, - pub const BuiltinType = enum { + pub const Builtin = enum { isize, usize, c_short, @@ -282,32 +295,42 @@ pub const Inst = struct { anyerror, comptime_int, comptime_float, + @"true", + @"false", + @"null", + @"undefined", + void_value, - pub fn toType(self: BuiltinType) Type { + pub fn toTypedValue(self: Builtin) TypedValue { return switch (self) { - .isize => Type.initTag(.isize), - .usize => Type.initTag(.usize), - .c_short => Type.initTag(.c_short), - .c_ushort => Type.initTag(.c_ushort), - .c_int => Type.initTag(.c_int), - .c_uint => Type.initTag(.c_uint), - .c_long => Type.initTag(.c_long), - .c_ulong => Type.initTag(.c_ulong), - .c_longlong => Type.initTag(.c_longlong), - .c_ulonglong => Type.initTag(.c_ulonglong), - .c_longdouble => Type.initTag(.c_longdouble), - .c_void => Type.initTag(.c_void), - .f16 => Type.initTag(.f16), - .f32 => Type.initTag(.f32), - .f64 => Type.initTag(.f64), - .f128 => Type.initTag(.f128), - .bool => Type.initTag(.bool), - .void => Type.initTag(.void), - .noreturn => Type.initTag(.noreturn), - .type => Type.initTag(.type), - .anyerror => Type.initTag(.anyerror), - .comptime_int => Type.initTag(.comptime_int), - .comptime_float => Type.initTag(.comptime_float), + .isize => .{ .ty = Type.initTag(.type), .val = Value.initTag(.isize_type) }, + .usize => .{ .ty = Type.initTag(.type), .val = 
Value.initTag(.usize_type) }, + .c_short => .{ .ty = Type.initTag(.type), .val = Value.initTag(.c_short_type) }, + .c_ushort => .{ .ty = Type.initTag(.type), .val = Value.initTag(.c_ushort_type) }, + .c_int => .{ .ty = Type.initTag(.type), .val = Value.initTag(.c_int_type) }, + .c_uint => .{ .ty = Type.initTag(.type), .val = Value.initTag(.c_uint_type) }, + .c_long => .{ .ty = Type.initTag(.type), .val = Value.initTag(.c_long_type) }, + .c_ulong => .{ .ty = Type.initTag(.type), .val = Value.initTag(.c_ulong_type) }, + .c_longlong => .{ .ty = Type.initTag(.type), .val = Value.initTag(.c_longlong_type) }, + .c_ulonglong => .{ .ty = Type.initTag(.type), .val = Value.initTag(.c_ulonglong_type) }, + .c_longdouble => .{ .ty = Type.initTag(.type), .val = Value.initTag(.c_longdouble_type) }, + .c_void => .{ .ty = Type.initTag(.type), .val = Value.initTag(.c_void_type) }, + .f16 => .{ .ty = Type.initTag(.type), .val = Value.initTag(.f16_type) }, + .f32 => .{ .ty = Type.initTag(.type), .val = Value.initTag(.f32_type) }, + .f64 => .{ .ty = Type.initTag(.type), .val = Value.initTag(.f64_type) }, + .f128 => .{ .ty = Type.initTag(.type), .val = Value.initTag(.f128_type) }, + .bool => .{ .ty = Type.initTag(.type), .val = Value.initTag(.bool_type) }, + .void => .{ .ty = Type.initTag(.type), .val = Value.initTag(.void_type) }, + .noreturn => .{ .ty = Type.initTag(.type), .val = Value.initTag(.noreturn_type) }, + .type => .{ .ty = Type.initTag(.type), .val = Value.initTag(.type_type) }, + .anyerror => .{ .ty = Type.initTag(.type), .val = Value.initTag(.anyerror_type) }, + .comptime_int => .{ .ty = Type.initTag(.type), .val = Value.initTag(.comptime_int_type) }, + .comptime_float => .{ .ty = Type.initTag(.type), .val = Value.initTag(.comptime_float_type) }, + .@"true" => .{ .ty = Type.initTag(.bool), .val = Value.initTag(.bool_true) }, + .@"false" => .{ .ty = Type.initTag(.bool), .val = Value.initTag(.bool_false) }, + .@"null" => .{ .ty = Type.initTag(.@"null"), .val = 
Value.initTag(.null_value) }, + .@"undefined" => .{ .ty = Type.initTag(.@"undefined"), .val = Value.initTag(.undef) }, + .void_value => .{ .ty = Type.initTag(.void), .val = Value.initTag(.the_one_possible_value) }, }; } }; @@ -440,7 +463,7 @@ pub const Module = struct { self.writeToStream(std.heap.page_allocator, std.io.getStdErr().outStream()) catch {}; } - const InstPtrTable = std.AutoHashMap(*Inst, struct { index: usize, fn_body: ?*Module.Body }); + const InstPtrTable = std.AutoHashMap(*Inst, struct { inst: *Inst, index: ?usize }); /// TODO Look into making a table to speed this up. pub fn findDecl(self: Module, name: []const u8) ?*Inst { @@ -462,17 +485,17 @@ pub const Module = struct { try inst_table.ensureCapacity(self.decls.len); for (self.decls) |decl, decl_i| { - try inst_table.putNoClobber(decl, .{ .index = decl_i, .fn_body = null }); + try inst_table.putNoClobber(decl, .{ .inst = decl, .index = null }); if (decl.cast(Inst.Fn)) |fn_inst| { for (fn_inst.positionals.body.instructions) |inst, inst_i| { - try inst_table.putNoClobber(inst, .{ .index = inst_i, .fn_body = &fn_inst.positionals.body }); + try inst_table.putNoClobber(inst, .{ .inst = inst, .index = inst_i }); } } } for (self.decls) |decl, i| { - try stream.print("@{} ", .{i}); + try stream.print("@{} ", .{decl.name}); try self.writeInstToStream(stream, decl, &inst_table); try stream.writeByte('\n'); } @@ -489,6 +512,7 @@ pub const Module = struct { .breakpoint => return self.writeInstToStreamGeneric(stream, .breakpoint, decl, inst_table), .call => return self.writeInstToStreamGeneric(stream, .call, decl, inst_table), .declref => return self.writeInstToStreamGeneric(stream, .declref, decl, inst_table), + .declval => return self.writeInstToStreamGeneric(stream, .declval, decl, inst_table), .str => return self.writeInstToStreamGeneric(stream, .str, decl, inst_table), .int => return self.writeInstToStreamGeneric(stream, .int, decl, inst_table), .ptrtoint => return self.writeInstToStreamGeneric(stream, 
.ptrtoint, decl, inst_table), @@ -587,9 +611,18 @@ pub const Module = struct { } fn writeInstParamToStream(self: Module, stream: var, inst: *Inst, inst_table: *const InstPtrTable) !void { - const info = inst_table.getValue(inst).?; - const prefix = if (info.fn_body == null) "@" else "%"; - try stream.print("{}{}", .{ prefix, info.index }); + if (inst_table.getValue(inst)) |info| { + if (info.index) |i| { + try stream.print("%{}", .{info.index}); + } else { + try stream.print("@{}", .{info.inst.name}); + } + } else if (inst.cast(Inst.DeclVal)) |decl_val| { + try stream.print("@{}", .{decl_val.positionals.name}); + } else { + //try stream.print("?", .{}); + unreachable; + } } }; @@ -964,47 +997,17 @@ const Parser = struct { self.i = src; return self.fail("unrecognized identifier: {}", .{bad_name}); } else { - const name_array = try self.arena.allocator.create(Inst.Str); - name_array.* = .{ + const declval = try self.arena.allocator.create(Inst.DeclVal); + declval.* = .{ .base = .{ .name = try self.generateName(), .src = src, - .tag = Inst.Str.base_tag, + .tag = Inst.DeclVal.base_tag, }, - .positionals = .{ .bytes = ident }, + .positionals = .{ .name = ident }, .kw_args = .{}, }; - const name = try self.arena.allocator.create(Inst.Ref); - name.* = .{ - .base = .{ - .name = try self.generateName(), - .src = src, - .tag = Inst.Ref.base_tag, - }, - .positionals = .{ .operand = &name_array.base }, - .kw_args = .{}, - }; - const declref = try self.arena.allocator.create(Inst.DeclRef); - declref.* = .{ - .base = .{ - .name = try self.generateName(), - .src = src, - .tag = Inst.DeclRef.base_tag, - }, - .positionals = .{ .name = &name.base }, - .kw_args = .{}, - }; - const deref = try self.arena.allocator.create(Inst.Deref); - deref.* = .{ - .base = .{ - .name = try self.generateName(), - .src = src, - .tag = Inst.Deref.base_tag, - }, - .positionals = .{ .ptr = &declref.base }, - .kw_args = .{}, - }; - return &deref.base; + return &declval.base; } }; if (local_ref) { @@ 
-1025,12 +1028,15 @@ pub fn emit(allocator: *Allocator, old_module: IrModule) !Module { var ctx: EmitZIR = .{ .allocator = allocator, .decls = .{}, - .decl_table = std.AutoHashMap(*ir.Inst, *Inst).init(allocator), .arena = std.heap.ArenaAllocator.init(allocator), .old_module = &old_module, + .next_auto_name = 0, + .names = std.StringHashMap(void).init(allocator), + .primitive_table = std.AutoHashMap(Inst.Primitive.Builtin, *Inst).init(allocator), }; defer ctx.decls.deinit(allocator); - defer ctx.decl_table.deinit(); + defer ctx.names.deinit(); + defer ctx.primitive_table.deinit(); errdefer ctx.arena.deinit(); try ctx.emit(); @@ -1046,47 +1052,90 @@ const EmitZIR = struct { arena: std.heap.ArenaAllocator, old_module: *const IrModule, decls: std.ArrayListUnmanaged(*Inst), - decl_table: std.AutoHashMap(*ir.Inst, *Inst), + names: std.StringHashMap(void), + next_auto_name: usize, + primitive_table: std.AutoHashMap(Inst.Primitive.Builtin, *Inst), fn emit(self: *EmitZIR) !void { - var it = self.old_module.decl_exports.iterator(); - while (it.next()) |kv| { - const decl = kv.key; - const exports = kv.value; - const export_value = try self.emitTypedValue(decl.src, decl.typed_value.most_recent.typed_value); - for (exports) |module_export| { - const symbol_name = try self.emitStringLiteral(module_export.src, module_export.options.name); - const export_inst = try self.arena.allocator.create(Inst.Export); - export_inst.* = .{ - .base = .{ - .name = try self.autoName(), - .src = module_export.src, - .tag = Inst.Export.base_tag, - }, - .positionals = .{ - .symbol_name = symbol_name, - .value = export_value, - }, - .kw_args = .{}, - }; - try self.decls.append(self.allocator, &export_inst.base); + // Put all the Decls in a list and sort them by name to avoid nondeterminism introduced + // by the hash table. 
+ var src_decls = std.ArrayList(*IrModule.Decl).init(self.allocator); + defer src_decls.deinit(); + try src_decls.ensureCapacity(self.old_module.decl_table.size); + try self.decls.ensureCapacity(self.allocator, self.old_module.decl_table.size); + try self.names.ensureCapacity(self.old_module.decl_table.size); + + var decl_it = self.old_module.decl_table.iterator(); + while (decl_it.next()) |kv| { + const decl = kv.value; + src_decls.appendAssumeCapacity(decl); + self.names.putAssumeCapacityNoClobber(mem.spanZ(decl.name), {}); + } + std.sort.sort(*IrModule.Decl, src_decls.items, {}, (struct { + fn lessThan(context: void, a: *IrModule.Decl, b: *IrModule.Decl) bool { + return a.src < b.src; + } + }).lessThan); + + // Emit all the decls. + for (src_decls.items) |ir_decl| { + if (self.old_module.export_owners.getValue(ir_decl)) |exports| { + for (exports) |module_export| { + const declval = try self.emitDeclVal(ir_decl.src, mem.spanZ(module_export.exported_decl.name)); + const symbol_name = try self.emitStringLiteral(module_export.src, module_export.options.name); + const export_inst = try self.arena.allocator.create(Inst.Export); + export_inst.* = .{ + .base = .{ + .name = try self.autoName(), + .src = module_export.src, + .tag = Inst.Export.base_tag, + }, + .positionals = .{ + .symbol_name = symbol_name, + .value = declval, + }, + .kw_args = .{}, + }; + try self.decls.append(self.allocator, &export_inst.base); + } + } else { + const new_decl = try self.emitTypedValue(ir_decl.src, ir_decl.typed_value.most_recent.typed_value); + new_decl.name = try self.arena.allocator.dupe(u8, mem.spanZ(ir_decl.name)); } } } - fn resolveInst(self: *EmitZIR, inst_table: *const std.AutoHashMap(*ir.Inst, *Inst), inst: *ir.Inst) !*Inst { + fn resolveInst(self: *EmitZIR, inst_table: *std.AutoHashMap(*ir.Inst, *Inst), inst: *ir.Inst) !*Inst { if (inst.cast(ir.Inst.Constant)) |const_inst| { - if (self.decl_table.getValue(inst)) |decl| { - return decl; - } - const new_decl = try 
self.emitTypedValue(inst.src, .{ .ty = inst.ty, .val = const_inst.val }); - try self.decl_table.putNoClobber(inst, new_decl); + const new_decl = if (const_inst.val.cast(Value.Payload.Function)) |func_pl| blk: { + const owner_decl = func_pl.func.owner_decl; + break :blk try self.emitDeclVal(inst.src, mem.spanZ(owner_decl.name)); + } else if (const_inst.val.cast(Value.Payload.DeclRef)) |declref| blk: { + break :blk try self.emitDeclRef(inst.src, declref.decl); + } else blk: { + break :blk try self.emitTypedValue(inst.src, .{ .ty = inst.ty, .val = const_inst.val }); + }; + try inst_table.putNoClobber(inst, new_decl); return new_decl; } else { return inst_table.getValue(inst).?; } } + fn emitDeclVal(self: *EmitZIR, src: usize, decl_name: []const u8) !*Inst { + const declval = try self.arena.allocator.create(Inst.DeclVal); + declval.* = .{ + .base = .{ + .name = try self.autoName(), + .src = src, + .tag = Inst.DeclVal.base_tag, + }, + .positionals = .{ .name = try self.arena.allocator.dupe(u8, decl_name) }, + .kw_args = .{}, + }; + return &declval.base; + } + fn emitComptimeIntVal(self: *EmitZIR, src: usize, val: Value) !*Inst { const big_int_space = try self.arena.allocator.create(Value.BigIntSpace); const int_inst = try self.arena.allocator.create(Inst.Int); @@ -1105,8 +1154,31 @@ const EmitZIR = struct { return &int_inst.base; } + fn emitDeclRef(self: *EmitZIR, src: usize, decl: *IrModule.Decl) !*Inst { + const declval = try self.emitDeclVal(src, mem.spanZ(decl.name)); + const ref_inst = try self.arena.allocator.create(Inst.Ref); + ref_inst.* = .{ + .base = .{ + .name = try self.autoName(), + .src = src, + .tag = Inst.Ref.base_tag, + }, + .positionals = .{ + .operand = declval, + }, + .kw_args = .{}, + }; + try self.decls.append(self.allocator, &ref_inst.base); + + return &ref_inst.base; + } + fn emitTypedValue(self: *EmitZIR, src: usize, typed_value: TypedValue) Allocator.Error!*Inst { const allocator = &self.arena.allocator; + if 
(typed_value.val.cast(Value.Payload.DeclRef)) |decl_ref| { + const decl = decl_ref.decl; + return self.emitDeclRef(src, decl); + } switch (typed_value.ty.zigTypeTag()) { .Pointer => { const ptr_elem_type = typed_value.ty.elemType(); @@ -1142,7 +1214,6 @@ const EmitZIR = struct { }, .kw_args = .{}, }; - try self.decls.append(self.allocator, &as_inst.base); return &as_inst.base; }, @@ -1182,6 +1253,33 @@ const EmitZIR = struct { try self.decls.append(self.allocator, &fn_inst.base); return &fn_inst.base; }, + .Array => { + // TODO more checks to make sure this can be emitted as a string literal + //const array_elem_type = ptr_elem_type.elemType(); + //if (array_elem_type.eql(Type.initTag(.u8)) and + // ptr_elem_type.hasSentinel(Value.initTag(.zero))) + //{ + //} + const bytes = typed_value.val.toAllocatedBytes(allocator) catch |err| switch (err) { + error.AnalysisFail => unreachable, + else => |e| return e, + }; + const str_inst = try self.arena.allocator.create(Inst.Str); + str_inst.* = .{ + .base = .{ + .name = try self.autoName(), + .src = src, + .tag = Inst.Str.base_tag, + }, + .positionals = .{ + .bytes = bytes, + }, + .kw_args = .{}, + }; + try self.decls.append(self.allocator, &str_inst.base); + return &str_inst.base; + }, + .Void => return self.emitPrimitive(src, .void_value), else => |t| std.debug.panic("TODO implement emitTypedValue for {}", .{@tagName(t)}), } } @@ -1395,30 +1493,30 @@ const EmitZIR = struct { fn emitType(self: *EmitZIR, src: usize, ty: Type) Allocator.Error!*Inst { switch (ty.tag()) { - .isize => return self.emitPrimitiveType(src, .isize), - .usize => return self.emitPrimitiveType(src, .usize), - .c_short => return self.emitPrimitiveType(src, .c_short), - .c_ushort => return self.emitPrimitiveType(src, .c_ushort), - .c_int => return self.emitPrimitiveType(src, .c_int), - .c_uint => return self.emitPrimitiveType(src, .c_uint), - .c_long => return self.emitPrimitiveType(src, .c_long), - .c_ulong => return self.emitPrimitiveType(src, 
.c_ulong), - .c_longlong => return self.emitPrimitiveType(src, .c_longlong), - .c_ulonglong => return self.emitPrimitiveType(src, .c_ulonglong), - .c_longdouble => return self.emitPrimitiveType(src, .c_longdouble), - .c_void => return self.emitPrimitiveType(src, .c_void), - .f16 => return self.emitPrimitiveType(src, .f16), - .f32 => return self.emitPrimitiveType(src, .f32), - .f64 => return self.emitPrimitiveType(src, .f64), - .f128 => return self.emitPrimitiveType(src, .f128), - .anyerror => return self.emitPrimitiveType(src, .anyerror), + .isize => return self.emitPrimitive(src, .isize), + .usize => return self.emitPrimitive(src, .usize), + .c_short => return self.emitPrimitive(src, .c_short), + .c_ushort => return self.emitPrimitive(src, .c_ushort), + .c_int => return self.emitPrimitive(src, .c_int), + .c_uint => return self.emitPrimitive(src, .c_uint), + .c_long => return self.emitPrimitive(src, .c_long), + .c_ulong => return self.emitPrimitive(src, .c_ulong), + .c_longlong => return self.emitPrimitive(src, .c_longlong), + .c_ulonglong => return self.emitPrimitive(src, .c_ulonglong), + .c_longdouble => return self.emitPrimitive(src, .c_longdouble), + .c_void => return self.emitPrimitive(src, .c_void), + .f16 => return self.emitPrimitive(src, .f16), + .f32 => return self.emitPrimitive(src, .f32), + .f64 => return self.emitPrimitive(src, .f64), + .f128 => return self.emitPrimitive(src, .f128), + .anyerror => return self.emitPrimitive(src, .anyerror), else => switch (ty.zigTypeTag()) { - .Bool => return self.emitPrimitiveType(src, .bool), - .Void => return self.emitPrimitiveType(src, .void), - .NoReturn => return self.emitPrimitiveType(src, .noreturn), - .Type => return self.emitPrimitiveType(src, .type), - .ComptimeInt => return self.emitPrimitiveType(src, .comptime_int), - .ComptimeFloat => return self.emitPrimitiveType(src, .comptime_float), + .Bool => return self.emitPrimitive(src, .bool), + .Void => return self.emitPrimitive(src, .void), + .NoReturn => return 
self.emitPrimitive(src, .noreturn), + .Type => return self.emitPrimitive(src, .type), + .ComptimeInt => return self.emitPrimitive(src, .comptime_int), + .ComptimeFloat => return self.emitPrimitive(src, .comptime_float), .Fn => { const param_types = try self.allocator.alloc(Type, ty.fnParamLen()); defer self.allocator.free(param_types); @@ -1453,24 +1551,36 @@ const EmitZIR = struct { } fn autoName(self: *EmitZIR) ![]u8 { - return std.fmt.allocPrint(&self.arena.allocator, "{}", .{self.decls.items.len}); + while (true) { + const proposed_name = try std.fmt.allocPrint(&self.arena.allocator, "unnamed${}", .{self.next_auto_name}); + self.next_auto_name += 1; + const gop = try self.names.getOrPut(proposed_name); + if (!gop.found_existing) { + gop.kv.value = {}; + return proposed_name; + } + } } - fn emitPrimitiveType(self: *EmitZIR, src: usize, tag: Inst.Primitive.BuiltinType) !*Inst { - const primitive_inst = try self.arena.allocator.create(Inst.Primitive); - primitive_inst.* = .{ - .base = .{ - .name = try self.autoName(), - .src = src, - .tag = Inst.Primitive.base_tag, - }, - .positionals = .{ - .tag = tag, - }, - .kw_args = .{}, - }; - try self.decls.append(self.allocator, &primitive_inst.base); - return &primitive_inst.base; + fn emitPrimitive(self: *EmitZIR, src: usize, tag: Inst.Primitive.Builtin) !*Inst { + const gop = try self.primitive_table.getOrPut(tag); + if (!gop.found_existing) { + const primitive_inst = try self.arena.allocator.create(Inst.Primitive); + primitive_inst.* = .{ + .base = .{ + .name = try self.autoName(), + .src = src, + .tag = Inst.Primitive.base_tag, + }, + .positionals = .{ + .tag = tag, + }, + .kw_args = .{}, + }; + try self.decls.append(self.allocator, &primitive_inst.base); + gop.kv.value = &primitive_inst.base; + } + return gop.kv.value; } fn emitStringLiteral(self: *EmitZIR, src: usize, str: []const u8) !*Inst { diff --git a/test/stage2/zir.zig b/test/stage2/zir.zig index f8b9d797d5..7d5e330b89 100644 --- a/test/stage2/zir.zig +++ 
b/test/stage2/zir.zig @@ -21,14 +21,17 @@ pub fn addCases(ctx: *TestContext) void { \\ %11 = return() \\}) , - \\@0 = primitive(void) - \\@1 = fntype([], @0, cc=C) - \\@2 = fn(@1, { + \\@void = primitive(void) + \\@fnty = fntype([], @void, cc=C) + \\@9 = str("entry") + \\@10 = ref(@9) + \\@unnamed$6 = str("entry") + \\@unnamed$7 = ref(@unnamed$6) + \\@unnamed$8 = export(@unnamed$7, @entry) + \\@unnamed$10 = fntype([], @void, cc=C) + \\@entry = fn(@unnamed$10, { \\ %0 = return() \\}) - \\@3 = str("entry") - \\@4 = ref(@3) - \\@5 = export(@4, @2) \\ ); ctx.addZIRTransform("elemptr, add, cmp, condbr, return, breakpoint", linux_x64, @@ -68,14 +71,22 @@ pub fn addCases(ctx: *TestContext) void { \\@10 = ref(@9) \\@11 = export(@10, @entry) , - \\@0 = primitive(void) - \\@1 = fntype([], @0, cc=C) - \\@2 = fn(@1, { + \\@void = primitive(void) + \\@fnty = fntype([], @void, cc=C) + \\@0 = int(0) + \\@1 = int(1) + \\@2 = int(2) + \\@3 = int(3) + \\@unnamed$7 = fntype([], @void, cc=C) + \\@entry = fn(@unnamed$7, { \\ %0 = return() \\}) - \\@3 = str("entry") - \\@4 = ref(@3) - \\@5 = export(@4, @2) + \\@a = str("2\x08\x01\n") + \\@9 = str("entry") + \\@10 = ref(@9) + \\@unnamed$14 = str("entry") + \\@unnamed$15 = ref(@unnamed$14) + \\@unnamed$16 = export(@unnamed$15, @entry) \\ ); From 91930a4ff08d275ec16507aed58a73a02742f831 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 5 Jun 2020 15:49:23 -0400 Subject: [PATCH 3/6] stage2: fix not re-loading source file for updates after errors --- src-self-hosted/Module.zig | 35 ++++++++++++++++++------- src-self-hosted/main.zig | 16 +++++++++++- src-self-hosted/zir.zig | 53 +++++++++++++++++++++++++++++++++++++- 3 files changed, 93 insertions(+), 11 deletions(-) diff --git a/src-self-hosted/Module.zig b/src-self-hosted/Module.zig index 4ddd286f16..7a3a97a2a6 100644 --- a/src-self-hosted/Module.zig +++ b/src-self-hosted/Module.zig @@ -576,6 +576,8 @@ pub fn update(self: *Module) !void { // TODO Use the cache hash file system to 
detect which source files changed. // Here we simulate a full cache miss. // Analyze the root source file now. + // Source files could have been loaded for any reason; to force a refresh we unload now. + self.root_scope.unload(self.allocator); self.analyzeRoot(self.root_scope) catch |err| switch (err) { error.AnalysisFail => { assert(self.totalErrorCount() != 0); @@ -594,8 +596,11 @@ pub fn update(self: *Module) !void { try self.deleteDecl(decl); } - // Unload all the source files from memory. - self.root_scope.unload(self.allocator); + // If there are any errors, we anticipate the source files being loaded + // to report error messages. Otherwise we unload all source files to save memory. + if (self.totalErrorCount() == 0) { + self.root_scope.unload(self.allocator); + } try self.bin_file.flush(); self.link_error_flags = self.bin_file.error_flags; @@ -878,11 +883,11 @@ fn analyzeRoot(self: *Module, root_scope: *Scope.ZIRModule) !void { const decl = kv.value; deleted_decls.removeAssertDiscard(decl); const new_contents_hash = Decl.hashSimpleName(src_decl.contents); + //std.debug.warn("'{}' contents: '{}'\n", .{ src_decl.name, src_decl.contents }); if (!mem.eql(u8, &new_contents_hash, &decl.contents_hash)) { - //std.debug.warn("noticed '{}' source changed\n", .{src_decl.name}); - decl.analysis = .outdated; + //std.debug.warn("'{}' {x} => {x}\n", .{ src_decl.name, decl.contents_hash, new_contents_hash }); + try self.markOutdatedDecl(decl); decl.contents_hash = new_contents_hash; - try self.work_queue.writeItem(.{ .re_analyze_decl = decl }); } } else if (src_decl.cast(zir.Inst.Export)) |export_inst| { try exports_to_resolve.append(&export_inst.base); @@ -923,8 +928,7 @@ fn deleteDecl(self: *Module, decl: *Decl) !void { for (decl.dependants.items) |dep| { dep.removeDependency(decl); if (dep.analysis != .outdated) { - dep.analysis = .outdated; - try self.work_queue.writeItem(.{ .re_analyze_decl = dep }); + try self.markOutdatedDecl(dep); } } self.deleteDeclExports(decl); 
@@ -1083,14 +1087,22 @@ fn reAnalyzeDecl(self: *Module, decl: *Decl, old_inst: *zir.Inst) InnerError!voi .codegen_failure_retryable, .complete, => if (dep.generation != self.generation) { - dep.analysis = .outdated; - try self.work_queue.writeItem(.{ .re_analyze_decl = dep }); + try self.markOutdatedDecl(dep); }, } } } } +fn markOutdatedDecl(self: *Module, decl: *Decl) !void { + //std.debug.warn("mark {} outdated\n", .{decl.name}); + try self.work_queue.writeItem(.{ .re_analyze_decl = decl }); + if (self.failed_decls.remove(decl)) |entry| { + self.allocator.destroy(entry.value); + } + decl.analysis = .outdated; +} + fn resolveDecl(self: *Module, scope: *Scope, old_inst: *zir.Inst) InnerError!*Decl { const hash = Decl.hashSimpleName(old_inst.name); if (self.decl_table.get(hash)) |kv| { @@ -1445,6 +1457,7 @@ fn analyzeInst(self: *Module, scope: *Scope, old_inst: *zir.Inst) InnerError!*In switch (old_inst.tag) { .breakpoint => return self.analyzeInstBreakpoint(scope, old_inst.cast(zir.Inst.Breakpoint).?), .call => return self.analyzeInstCall(scope, old_inst.cast(zir.Inst.Call).?), + .compileerror => return self.analyzeInstCompileError(scope, old_inst.cast(zir.Inst.CompileError).?), .declref => return self.analyzeInstDeclRef(scope, old_inst.cast(zir.Inst.DeclRef).?), .declval => return self.analyzeInstDeclVal(scope, old_inst.cast(zir.Inst.DeclVal).?), .str => { @@ -1484,6 +1497,10 @@ fn analyzeInst(self: *Module, scope: *Scope, old_inst: *zir.Inst) InnerError!*In } } +fn analyzeInstCompileError(self: *Module, scope: *Scope, inst: *zir.Inst.CompileError) InnerError!*Inst { + return self.fail(scope, inst.base.src, "{}", .{inst.positionals.msg}); +} + fn analyzeInstBreakpoint(self: *Module, scope: *Scope, inst: *zir.Inst.Breakpoint) InnerError!*Inst { const b = try self.requireRuntimeBlock(scope, inst.base.src); return self.addNewInstArgs(b, inst.base.src, Type.initTag(.void), Inst.Breakpoint, Inst.Args(Inst.Breakpoint){}); diff --git a/src-self-hosted/main.zig 
b/src-self-hosted/main.zig index 40d3068d3f..eda9dcfc31 100644 --- a/src-self-hosted/main.zig +++ b/src-self-hosted/main.zig @@ -407,7 +407,21 @@ fn buildOutputType( std.debug.warn("-fno-emit-bin not supported yet", .{}); process.exit(1); }, - .yes_default_path => try std.fmt.allocPrint(arena, "{}{}", .{ root_name, target_info.target.exeFileExt() }), + .yes_default_path => switch (output_mode) { + .Exe => try std.fmt.allocPrint(arena, "{}{}", .{ root_name, target_info.target.exeFileExt() }), + .Lib => blk: { + const suffix = switch (link_mode orelse .Static) { + .Static => target_info.target.staticLibSuffix(), + .Dynamic => target_info.target.dynamicLibSuffix(), + }; + break :blk try std.fmt.allocPrint(arena, "{}{}{}", .{ + target_info.target.libPrefix(), + root_name, + suffix, + }); + }, + .Obj => try std.fmt.allocPrint(arena, "{}{}", .{ root_name, target_info.target.oFileExt() }), + }, .yes => |p| p, }; diff --git a/src-self-hosted/zir.zig b/src-self-hosted/zir.zig index 749d9d9c87..a00782771d 100644 --- a/src-self-hosted/zir.zig +++ b/src-self-hosted/zir.zig @@ -27,6 +27,7 @@ pub const Inst = struct { pub const Tag = enum { breakpoint, call, + compileerror, /// Represents a pointer to a global decl by name. declref, /// The syntax `@foo` is equivalent to `declval("foo")`. 
@@ -62,6 +63,7 @@ pub const Inst = struct { .call => Call, .declref => DeclRef, .declval => DeclVal, + .compileerror => CompileError, .str => Str, .int => Int, .ptrtoint => PtrToInt, @@ -135,6 +137,16 @@ pub const Inst = struct { kw_args: struct {}, }; + pub const CompileError = struct { + pub const base_tag = Tag.compileerror; + base: Inst, + + positionals: struct { + msg: []const u8, + }, + kw_args: struct {}, + }; + pub const Str = struct { pub const base_tag = Tag.str; base: Inst, @@ -513,6 +525,7 @@ pub const Module = struct { .call => return self.writeInstToStreamGeneric(stream, .call, decl, inst_table), .declref => return self.writeInstToStreamGeneric(stream, .declref, decl, inst_table), .declval => return self.writeInstToStreamGeneric(stream, .declval, decl, inst_table), + .compileerror => return self.writeInstToStreamGeneric(stream, .compileerror, decl, inst_table), .str => return self.writeInstToStreamGeneric(stream, .str, decl, inst_table), .int => return self.writeInstToStreamGeneric(stream, .int, decl, inst_table), .ptrtoint => return self.writeInstToStreamGeneric(stream, .ptrtoint, decl, inst_table), @@ -917,6 +930,7 @@ const Parser = struct { try requireEatBytes(self, ")"); inst_specific.base.contents = self.source[contents_start..self.i]; + //std.debug.warn("parsed {} = '{}'\n", .{ inst_specific.base.name, inst_specific.base.contents }); return &inst_specific.base; } @@ -1230,7 +1244,44 @@ const EmitZIR = struct { var instructions = std.ArrayList(*Inst).init(self.allocator); defer instructions.deinit(); - try self.emitBody(module_fn.analysis.success, &inst_table, &instructions); + switch (module_fn.analysis) { + .queued => unreachable, + .in_progress => unreachable, + .success => |body| { + try self.emitBody(body, &inst_table, &instructions); + }, + .sema_failure => { + const err_msg = self.old_module.failed_decls.getValue(module_fn.owner_decl).?; + const fail_inst = try self.arena.allocator.create(Inst.CompileError); + fail_inst.* = .{ + .base = .{ 
+ .name = try self.autoName(), + .src = src, + .tag = Inst.CompileError.base_tag, + }, + .positionals = .{ + .msg = try self.arena.allocator.dupe(u8, err_msg.msg), + }, + .kw_args = .{}, + }; + try instructions.append(&fail_inst.base); + }, + .dependency_failure => { + const fail_inst = try self.arena.allocator.create(Inst.CompileError); + fail_inst.* = .{ + .base = .{ + .name = try self.autoName(), + .src = src, + .tag = Inst.CompileError.base_tag, + }, + .positionals = .{ + .msg = try self.arena.allocator.dupe(u8, "depends on another failed Decl"), + }, + .kw_args = .{}, + }; + try instructions.append(&fail_inst.base); + }, + } const fn_type = try self.emitType(src, module_fn.fn_type); From 9ea4965ceb62ea47569d0676bb68475304aad467 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 8 Jun 2020 02:31:04 -0400 Subject: [PATCH 4/6] self-hosted: remove deleted Decls from failed_decls --- src-self-hosted/Module.zig | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src-self-hosted/Module.zig b/src-self-hosted/Module.zig index 7a3a97a2a6..570ae69a63 100644 --- a/src-self-hosted/Module.zig +++ b/src-self-hosted/Module.zig @@ -910,6 +910,8 @@ fn analyzeRoot(self: *Module, root_scope: *Scope.ZIRModule) !void { } fn deleteDecl(self: *Module, decl: *Decl) !void { + try self.deletion_set.ensureCapacity(self.allocator, self.deletion_set.items.len + decl.dependencies.items.len); + //std.debug.warn("deleting decl '{}'\n", .{decl.name}); const name_hash = decl.fullyQualifiedNameHash(); self.decl_table.removeAssertDiscard(name_hash); @@ -921,16 +923,20 @@ fn deleteDecl(self: *Module, decl: *Decl) !void { // another reference to it may turn up. assert(!dep.deletion_flag); dep.deletion_flag = true; - try self.deletion_set.append(self.allocator, dep); + self.deletion_set.appendAssumeCapacity(dep); } } // Anything that depends on this deleted decl certainly needs to be re-analyzed. 
for (decl.dependants.items) |dep| { dep.removeDependency(decl); if (dep.analysis != .outdated) { + // TODO Move this failure possibility to the top of the function. try self.markOutdatedDecl(dep); } } + if (self.failed_decls.remove(decl)) |entry| { + self.allocator.destroy(entry.value); + } self.deleteDeclExports(decl); self.bin_file.freeDecl(decl); decl.destroy(self.allocator); From 47090d234ecc3e50937c918b05e6f039a53d880c Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 8 Jun 2020 15:15:55 -0400 Subject: [PATCH 5/6] stage2: add passing test for compile error in unreferenced cycle --- src-self-hosted/Module.zig | 6 +- src-self-hosted/test.zig | 134 ++++++++++++++++++++++++++++++------- test/stage2/zir.zig | 119 ++++++++++++++++++++++++++++++++ 3 files changed, 230 insertions(+), 29 deletions(-) diff --git a/src-self-hosted/Module.zig b/src-self-hosted/Module.zig index 570ae69a63..4bcc30a65e 100644 --- a/src-self-hosted/Module.zig +++ b/src-self-hosted/Module.zig @@ -673,8 +673,8 @@ pub fn getAllErrorsAlloc(self: *Module) !AllErrors { assert(errors.items.len == self.totalErrorCount()); return AllErrors{ - .arena = arena.state, .list = try arena.allocator.dupe(AllErrors.Message, errors.items), + .arena = arena.state, }; } @@ -935,7 +935,7 @@ fn deleteDecl(self: *Module, decl: *Decl) !void { } } if (self.failed_decls.remove(decl)) |entry| { - self.allocator.destroy(entry.value); + entry.value.destroy(self.allocator); } self.deleteDeclExports(decl); self.bin_file.freeDecl(decl); @@ -1104,7 +1104,7 @@ fn markOutdatedDecl(self: *Module, decl: *Decl) !void { //std.debug.warn("mark {} outdated\n", .{decl.name}); try self.work_queue.writeItem(.{ .re_analyze_decl = decl }); if (self.failed_decls.remove(decl)) |entry| { - self.allocator.destroy(entry.value); + entry.value.destroy(self.allocator); } decl.analysis = .outdated; } diff --git a/src-self-hosted/test.zig b/src-self-hosted/test.zig index 451bba996a..605c973bb9 100644 --- a/src-self-hosted/test.zig +++ 
b/src-self-hosted/test.zig @@ -27,9 +27,32 @@ pub const TestContext = struct { pub const ZIRTransformCase = struct { name: []const u8, - src: [:0]const u8, - expected_zir: []const u8, cross_target: std.zig.CrossTarget, + updates: std.ArrayList(Update), + + pub const Update = struct { + expected: Expected, + src: [:0]const u8, + }; + + pub const Expected = union(enum) { + zir: []const u8, + errors: []const []const u8, + }; + + pub fn addZIR(case: *ZIRTransformCase, src: [:0]const u8, zir_text: []const u8) void { + case.updates.append(.{ + .src = src, + .expected = .{ .zir = zir_text }, + }) catch unreachable; + } + + pub fn addError(case: *ZIRTransformCase, src: [:0]const u8, errors: []const []const u8) void { + case.updates.append(.{ + .src = src, + .expected = .{ .errors = errors }, + }) catch unreachable; + } }; pub fn addZIRCompareOutput( @@ -52,14 +75,32 @@ pub const TestContext = struct { src: [:0]const u8, expected_zir: []const u8, ) void { - ctx.zir_transform_cases.append(.{ + const case = ctx.zir_transform_cases.addOne() catch unreachable; + case.* = .{ .name = name, - .src = src, - .expected_zir = expected_zir, .cross_target = cross_target, + .updates = std.ArrayList(ZIRTransformCase.Update).init(std.heap.page_allocator), + }; + case.updates.append(.{ + .src = src, + .expected = .{ .zir = expected_zir }, }) catch unreachable; } + pub fn addZIRMulti( + ctx: *TestContext, + name: []const u8, + cross_target: std.zig.CrossTarget, + ) *ZIRTransformCase { + const case = ctx.zir_transform_cases.addOne() catch unreachable; + case.* = .{ + .name = name, + .cross_target = cross_target, + .updates = std.ArrayList(ZIRTransformCase.Update).init(std.heap.page_allocator), + }; + return case; + } + fn init(self: *TestContext) !void { self.* = .{ .zir_cmp_output_cases = std.ArrayList(ZIRCompareOutputCase).init(std.heap.page_allocator), @@ -178,13 +219,11 @@ pub const TestContext = struct { var tmp = std.testing.tmpDir(.{}); defer tmp.cleanup(); - var prg_node = 
root_node.start(case.name, 3); - prg_node.activate(); - defer prg_node.end(); + var update_node = root_node.start(case.name, case.updates.items.len); + update_node.activate(); + defer update_node.end(); const tmp_src_path = "test-case.zir"; - try tmp.dir.writeFile(tmp_src_path, case.src); - const root_pkg = try Package.create(allocator, tmp.dir, ".", tmp_src_path); defer root_pkg.destroy(); @@ -198,25 +237,68 @@ pub const TestContext = struct { }); defer module.deinit(); - var module_node = prg_node.start("parse/analysis/codegen", null); - module_node.activate(); - try module.update(); - module_node.end(); + for (case.updates.items) |update| { + var prg_node = update_node.start("", 3); + prg_node.activate(); + defer prg_node.end(); - var emit_node = prg_node.start("emit", null); - emit_node.activate(); - var new_zir_module = try zir.emit(allocator, module); - defer new_zir_module.deinit(allocator); - emit_node.end(); + try tmp.dir.writeFile(tmp_src_path, update.src); - var write_node = prg_node.start("write", null); - write_node.activate(); - var out_zir = std.ArrayList(u8).init(allocator); - defer out_zir.deinit(); - try new_zir_module.writeToStream(allocator, out_zir.outStream()); - write_node.end(); + var module_node = prg_node.start("parse/analysis/codegen", null); + module_node.activate(); + try module.update(); + module_node.end(); - std.testing.expectEqualSlices(u8, case.expected_zir, out_zir.items); + switch (update.expected) { + .zir => |expected_zir| { + var emit_node = prg_node.start("emit", null); + emit_node.activate(); + var new_zir_module = try zir.emit(allocator, module); + defer new_zir_module.deinit(allocator); + emit_node.end(); + + var write_node = prg_node.start("write", null); + write_node.activate(); + var out_zir = std.ArrayList(u8).init(allocator); + defer out_zir.deinit(); + try new_zir_module.writeToStream(allocator, out_zir.outStream()); + write_node.end(); + + std.testing.expectEqualSlices(u8, expected_zir, out_zir.items); + }, + 
.errors => |expected_errors| { + var all_errors = try module.getAllErrorsAlloc(); + defer all_errors.deinit(module.allocator); + for (expected_errors) |expected_error| { + for (all_errors.list) |full_err_msg| { + const text = try std.fmt.allocPrint(allocator, ":{}:{}: error: {}", .{ + full_err_msg.line + 1, + full_err_msg.column + 1, + full_err_msg.msg, + }); + defer allocator.free(text); + if (std.mem.eql(u8, text, expected_error)) { + break; + } + } else { + std.debug.warn( + "{}\nExpected this error:\n================\n{}\n================\nBut found these errors:\n================\n", + .{ case.name, expected_error }, + ); + for (all_errors.list) |full_err_msg| { + std.debug.warn(":{}:{}: error: {}\n", .{ + full_err_msg.line + 1, + full_err_msg.column + 1, + full_err_msg.msg, + }); + } + std.debug.warn("================\nTest failed\n", .{}); + std.process.exit(1); + } + } + }, + } + } } }; diff --git a/test/stage2/zir.zig b/test/stage2/zir.zig index 7d5e330b89..bf5d4b8eae 100644 --- a/test/stage2/zir.zig +++ b/test/stage2/zir.zig @@ -90,6 +90,125 @@ pub fn addCases(ctx: *TestContext) void { \\ ); + { + var case = ctx.addZIRMulti("reference cycle with compile error in the cycle", linux_x64); + case.addZIR( + \\@void = primitive(void) + \\@fnty = fntype([], @void, cc=C) + \\ + \\@9 = str("entry") + \\@10 = ref(@9) + \\@11 = export(@10, @entry) + \\ + \\@entry = fn(@fnty, { + \\ %0 = call(@a, []) + \\ %1 = return() + \\}) + \\ + \\@a = fn(@fnty, { + \\ %0 = call(@b, []) + \\ %1 = return() + \\}) + \\ + \\@b = fn(@fnty, { + \\ %0 = call(@a, []) + \\ %1 = return() + \\}) + , + \\@void = primitive(void) + \\@fnty = fntype([], @void, cc=C) + \\@9 = str("entry") + \\@10 = ref(@9) + \\@unnamed$6 = str("entry") + \\@unnamed$7 = ref(@unnamed$6) + \\@unnamed$8 = export(@unnamed$7, @entry) + \\@unnamed$12 = fntype([], @void, cc=C) + \\@entry = fn(@unnamed$12, { + \\ %0 = call(@a, [], modifier=auto) + \\ %1 = return() + \\}) + \\@unnamed$17 = fntype([], @void, cc=C) + \\@a 
= fn(@unnamed$17, { + \\ %0 = call(@b, [], modifier=auto) + \\ %1 = return() + \\}) + \\@unnamed$22 = fntype([], @void, cc=C) + \\@b = fn(@unnamed$22, { + \\ %0 = call(@a, [], modifier=auto) + \\ %1 = return() + \\}) + \\ + ); + // Now we introduce a compile error + case.addError( + \\@void = primitive(void) + \\@fnty = fntype([], @void, cc=C) + \\ + \\@9 = str("entry") + \\@10 = ref(@9) + \\@11 = export(@10, @entry) + \\ + \\@entry = fn(@fnty, { + \\ %0 = call(@a, []) + \\ %1 = return() + \\}) + \\ + \\@a = fn(@fnty, { + \\ %0 = call(@b, []) + \\ %1 = return() + \\}) + \\ + \\@b = fn(@fnty, { + \\ %9 = compileerror("message") + \\ %0 = call(@a, []) + \\ %1 = return() + \\}) + , + &[_][]const u8{ + ":19:21: error: message", + }, + ); + // Now we remove the call to `a`. `a` and `b` form a cycle, but no entry points are + // referencing either of them. This tests that the cycle is detected, and the error + // goes away. + case.addZIR( + \\@void = primitive(void) + \\@fnty = fntype([], @void, cc=C) + \\ + \\@9 = str("entry") + \\@10 = ref(@9) + \\@11 = export(@10, @entry) + \\ + \\@entry = fn(@fnty, { + \\ %1 = return() + \\}) + \\ + \\@a = fn(@fnty, { + \\ %0 = call(@b, []) + \\ %1 = return() + \\}) + \\ + \\@b = fn(@fnty, { + \\ %9 = compileerror("message") + \\ %0 = call(@a, []) + \\ %1 = return() + \\}) + , + \\@void = primitive(void) + \\@fnty = fntype([], @void, cc=C) + \\@9 = str("entry") + \\@10 = ref(@9) + \\@unnamed$6 = str("entry") + \\@unnamed$7 = ref(@unnamed$6) + \\@unnamed$8 = export(@unnamed$7, @entry) + \\@unnamed$10 = fntype([], @void, cc=C) + \\@entry = fn(@unnamed$10, { + \\ %0 = return() + \\}) + \\ + ); + } + if (std.Target.current.os.tag != .linux or std.Target.current.cpu.arch != .x86_64) { From 05d284c842a5ba21cd836c2b212daa24227a9177 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 8 Jun 2020 16:33:35 -0400 Subject: [PATCH 6/6] update sort callsite to new API --- lib/std/debug.zig | 6 +++--- 1 file changed, 3 insertions(+), 3 
deletions(-) diff --git a/lib/std/debug.zig b/lib/std/debug.zig index 0e234ab419..d0cdf6f4f5 100644 --- a/lib/std/debug.zig +++ b/lib/std/debug.zig @@ -1003,7 +1003,7 @@ fn readMachODebugInfo(allocator: *mem.Allocator, macho_file: File) !ModuleDebugI // Even though lld emits symbols in ascending order, this debug code // should work for programs linked in any valid way. // This sort is so that we can binary search later. - std.sort.sort(MachoSymbol, symbols, MachoSymbol.addressLessThan); + std.sort.sort(MachoSymbol, symbols, {}, MachoSymbol.addressLessThan); return ModuleDebugInfo{ .base_address = undefined, @@ -1058,7 +1058,7 @@ const MachoSymbol = struct { return self.nlist.n_value; } - fn addressLessThan(lhs: MachoSymbol, rhs: MachoSymbol) bool { + fn addressLessThan(context: void, lhs: MachoSymbol, rhs: MachoSymbol) bool { return lhs.address() < rhs.address(); } }; @@ -1300,7 +1300,7 @@ pub const DebugInfo = struct { fs.cwd().openFile(ctx.name, .{ .intended_io_mode = .blocking }) else fs.openSelfExe(.{ .intended_io_mode = .blocking }); - + const elf_file = copy catch |err| switch (err) { error.FileNotFound => return error.MissingDebugInfo, else => return err,