From 828d23956d44b71f8b2394c6d7ab08c23d22fcc3 Mon Sep 17 00:00:00 2001 From: Veikka Tuominen Date: Thu, 14 Dec 2023 16:03:44 +0200 Subject: [PATCH] std.heap: add runtime safety for calling `stackFallback(N).get` multiple times Closes #16344 --- deps/aro/aro/Compilation.zig | 16 ++++++++++------ lib/std/heap.zig | 35 ++++++++++++++++++++++++++++------- 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/deps/aro/aro/Compilation.zig b/deps/aro/aro/Compilation.zig index 56ba1bb680..fa4b89f410 100644 --- a/deps/aro/aro/Compilation.zig +++ b/deps/aro/aro/Compilation.zig @@ -1350,9 +1350,10 @@ pub fn hasInclude( } var stack_fallback = std.heap.stackFallback(path_buf_stack_limit, comp.gpa); + const sf_allocator = stack_fallback.get(); - while (try it.nextWithFile(filename, stack_fallback.get())) |found| { - defer stack_fallback.get().free(found.path); + while (try it.nextWithFile(filename, sf_allocator)) |found| { + defer sf_allocator.free(found.path); if (!std.meta.isError(cwd.access(found.path, .{}))) return true; } return false; @@ -1411,9 +1412,10 @@ pub fn findEmbed( }; var it = IncludeDirIterator{ .comp = comp, .cwd_source_id = cwd_source_id }; var stack_fallback = std.heap.stackFallback(path_buf_stack_limit, comp.gpa); + const sf_allocator = stack_fallback.get(); - while (try it.nextWithFile(filename, stack_fallback.get())) |found| { - defer stack_fallback.get().free(found.path); + while (try it.nextWithFile(filename, sf_allocator)) |found| { + defer sf_allocator.free(found.path); if (comp.getFileContents(found.path, limit)) |some| return some else |err| switch (err) { @@ -1457,8 +1459,10 @@ pub fn findInclude( } var stack_fallback = std.heap.stackFallback(path_buf_stack_limit, comp.gpa); - while (try it.nextWithFile(filename, stack_fallback.get())) |found| { - defer stack_fallback.get().free(found.path); + const sf_allocator = stack_fallback.get(); + + while (try it.nextWithFile(filename, sf_allocator)) |found| { + defer sf_allocator.free(found.path); if (comp.addSourceFromPathExtra(found.path, found.kind)) |some| { if (it.tried_ms_cwd) { try comp.addDiagnostic(.{ diff --git a/lib/std/heap.zig b/lib/std/heap.zig index ec8679f6b2..3dc4f83396 100644 --- a/lib/std/heap.zig +++ b/lib/std/heap.zig @@ -521,10 +521,16 @@ pub fn StackFallbackAllocator(comptime size: usize) type { buffer: [size]u8, fallback_allocator: Allocator, fixed_buffer_allocator: FixedBufferAllocator, + get_called: if (std.debug.runtime_safety) bool else void = + if (std.debug.runtime_safety) false else {}, /// This function both fetches a `Allocator` interface to this /// allocator *and* resets the internal buffer allocator. pub fn get(self: *Self) Allocator { + if (std.debug.runtime_safety) { + assert(!self.get_called); // `get` called multiple times; instead use `const allocator = stackFallback(N).get();` + self.get_called = true; + } self.fixed_buffer_allocator = FixedBufferAllocator.init(self.buffer[0..]); return .{ .ptr = self, @@ -536,6 +542,12 @@ pub fn StackFallbackAllocator(comptime size: usize) type { }; } + /// Unlike most std allocators `StackFallbackAllocator` modifies + /// its internal state before returning an implementation of + /// the`Allocator` interface and therefore also doesn't use + /// the usual `.allocator()` method. + pub const allocator = @compileError("use 'const allocator = stackFallback(N).get();' instead"); + fn alloc( ctx: *anyopaque, len: usize, @@ -675,13 +687,22 @@ test "FixedBufferAllocator.reset" { } test "StackFallbackAllocator" { - const fallback_allocator = page_allocator; - var stack_allocator = stackFallback(4096, fallback_allocator); - - try testAllocator(stack_allocator.get()); - try testAllocatorAligned(stack_allocator.get()); - try testAllocatorLargeAlignment(stack_allocator.get()); - try testAllocatorAlignedShrink(stack_allocator.get()); + { + var stack_allocator = stackFallback(4096, std.testing.allocator); + try testAllocator(stack_allocator.get()); + } + { + var stack_allocator = stackFallback(4096, std.testing.allocator); + try testAllocatorAligned(stack_allocator.get()); + } + { + var stack_allocator = stackFallback(4096, std.testing.allocator); + try testAllocatorLargeAlignment(stack_allocator.get()); + } + { + var stack_allocator = stackFallback(4096, std.testing.allocator); + try testAllocatorAlignedShrink(stack_allocator.get()); + } } test "FixedBufferAllocator Reuse memory on realloc" {