From 5b5243b5b7f5d01267afa1241c5847b3507694c1 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 26 Jun 2025 16:58:58 -0700 Subject: [PATCH] std: fix drain bugs in Writer.Allocating and net.Stream --- lib/std/io/Writer.zig | 11 ++++--- lib/std/net.zig | 73 ++++++++++++++++++++++++------------------- 2 files changed, 48 insertions(+), 36 deletions(-) diff --git a/lib/std/io/Writer.zig b/lib/std/io/Writer.zig index 80cef5eb85..7a7fe2de7d 100644 --- a/lib/std/io/Writer.zig +++ b/lib/std/io/Writer.zig @@ -2172,6 +2172,7 @@ pub const Allocating = struct { } fn drain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { + if (data.len == 0) return 0; // flush const a: *Allocating = @fieldParentPtr("interface", w); const gpa = a.allocator; const pattern = data[data.len - 1]; @@ -2179,14 +2180,16 @@ pub const Allocating = struct { var list = a.toArrayList(); defer setArrayList(a, list); const start_len = list.items.len; - for (data[0 .. data.len - 1]) |bytes| { + for (data) |bytes| { list.ensureUnusedCapacity(gpa, bytes.len + splat_len) catch return error.WriteFailed; list.appendSliceAssumeCapacity(bytes); } - switch (pattern.len) { + if (splat == 0) { + list.items.len -= pattern.len; + } else switch (pattern.len) { 0 => {}, - 1 => list.appendNTimesAssumeCapacity(pattern[0], splat), - else => for (0..splat) |_| list.appendSliceAssumeCapacity(pattern), + 1 => list.appendNTimesAssumeCapacity(pattern[0], splat - 1), + else => for (0..splat - 1) |_| list.appendSliceAssumeCapacity(pattern), } return list.items.len - start_len; } diff --git a/lib/std/net.zig b/lib/std/net.zig index b9f69bcaff..ee2f8f1d4a 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -2104,7 +2104,6 @@ pub const Stream = struct { fn drain(io_w: *io.Writer, data: []const []const u8, splat: usize) io.Writer.Error!usize { const w: *Writer = @fieldParentPtr("interface", io_w); const buffered = io_w.buffered(); - var splat_buffer: [splat_buffer_len]u8 = undefined; var iovecs: [max_buffers_len]std.posix.iovec_const = undefined; var msg: posix.msghdr_const = msg: { var i: usize = 0; @@ -2115,7 +2114,7 @@ pub const Stream = struct { }; i += 1; } - for (data[0..data.len]) |bytes| { + for (data) |bytes| { // OS checks ptr addr before length so zero length vectors must be omitted. if (bytes.len == 0) continue; iovecs[i] = .{ @@ -2135,38 +2134,48 @@ pub const Stream = struct { .flags = 0, }; }; - const pattern = data[data.len - 1]; - switch (splat) { - 0 => msg.iovlen -= 1, - 1 => {}, - else => switch (pattern.len) { - 0 => {}, - 1 => { - // Replace the 1-byte buffer with a bigger one. - const memset_len = @min(splat_buffer.len, splat); - const buf = splat_buffer[0..memset_len]; - @memset(buf, pattern[0]); - iovecs[msg.iovlen - 1] = .{ .base = buf.ptr, .len = buf.len }; - var remaining_splat = splat - buf.len; - while (remaining_splat > splat_buffer.len and msg.iovlen < iovecs.len) { - iovecs[msg.iovlen] = .{ .base = &splat_buffer, .len = splat_buffer.len }; - remaining_splat -= splat_buffer.len; - msg.iovlen += 1; - } - if (remaining_splat > 0 and msg.iovlen < iovecs.len) { - iovecs[msg.iovlen] = .{ .base = &splat_buffer, .len = remaining_splat }; - msg.iovlen += 1; - } + if (data.len != 0) { + const pattern = data[data.len - 1]; + switch (splat) { + 0 => if (msg.iovlen != 0 and iovecs[msg.iovlen - 1].base == data[data.len - 1].ptr) { + msg.iovlen -= 1; }, - else => for (0..splat - 1) |_| { - if (iovecs.len - msg.iovlen == 0) break; - iovecs[msg.iovlen] = .{ - .base = pattern.ptr, - .len = pattern.len, - }; - msg.iovlen += 1; + 1 => {}, + else => switch (pattern.len) { + 0 => {}, + 1 => memset: { + // Replace the 1-byte buffer with a bigger one. + if (msg.iovlen != 0 and iovecs[msg.iovlen - 1].base == data[data.len - 1].ptr) + msg.iovlen -= 1; + if (iovecs.len - msg.iovlen == 0) break :memset; + const splat_buffer = io_w.buffer[io_w.end..]; + const memset_len = @min(splat_buffer.len, splat); + const buf = splat_buffer[0..memset_len]; + @memset(buf, pattern[0]); + iovecs[msg.iovlen] = .{ .base = buf.ptr, .len = buf.len }; + msg.iovlen += 1; + var remaining_splat = splat - buf.len; + while (remaining_splat > splat_buffer.len and iovecs.len - msg.iovlen != 0) { + assert(buf.len == splat_buffer.len); + iovecs[msg.iovlen] = .{ .base = splat_buffer.ptr, .len = splat_buffer.len }; + msg.iovlen += 1; + remaining_splat -= splat_buffer.len; + } + if (remaining_splat > 0 and iovecs.len - msg.iovlen != 0) { + iovecs[msg.iovlen] = .{ .base = splat_buffer.ptr, .len = remaining_splat }; + msg.iovlen += 1; + } + }, + else => for (0..splat - 1) |_| { + if (iovecs.len - msg.iovlen == 0) break; + iovecs[msg.iovlen] = .{ + .base = pattern.ptr, + .len = pattern.len, + }; + msg.iovlen += 1; + }, }, - }, + } } const flags = posix.MSG.NOSIGNAL; return io_w.consume(std.posix.sendmsg(w.file_writer.file.handle, &msg, flags) catch |err| {