std: fix drain bugs in Writer.Allocating and net.Stream

This commit is contained in:
Andrew Kelley 2025-06-26 16:58:58 -07:00
parent 9f8486170c
commit 5b5243b5b7
2 changed files with 48 additions and 36 deletions

View File

@ -2172,6 +2172,7 @@ pub const Allocating = struct {
}
fn drain(w: *Writer, data: []const []const u8, splat: usize) Error!usize {
if (data.len == 0) return 0; // flush
const a: *Allocating = @fieldParentPtr("interface", w);
const gpa = a.allocator;
const pattern = data[data.len - 1];
@ -2179,14 +2180,16 @@ pub const Allocating = struct {
var list = a.toArrayList();
defer setArrayList(a, list);
const start_len = list.items.len;
for (data[0 .. data.len - 1]) |bytes| {
for (data) |bytes| {
list.ensureUnusedCapacity(gpa, bytes.len + splat_len) catch return error.WriteFailed;
list.appendSliceAssumeCapacity(bytes);
}
switch (pattern.len) {
if (splat == 0) {
list.items.len -= pattern.len;
} else switch (pattern.len) {
0 => {},
1 => list.appendNTimesAssumeCapacity(pattern[0], splat),
else => for (0..splat) |_| list.appendSliceAssumeCapacity(pattern),
1 => list.appendNTimesAssumeCapacity(pattern[0], splat - 1),
else => for (0..splat - 1) |_| list.appendSliceAssumeCapacity(pattern),
}
return list.items.len - start_len;
}

View File

@ -2104,7 +2104,6 @@ pub const Stream = struct {
fn drain(io_w: *io.Writer, data: []const []const u8, splat: usize) io.Writer.Error!usize {
const w: *Writer = @fieldParentPtr("interface", io_w);
const buffered = io_w.buffered();
var splat_buffer: [splat_buffer_len]u8 = undefined;
var iovecs: [max_buffers_len]std.posix.iovec_const = undefined;
var msg: posix.msghdr_const = msg: {
var i: usize = 0;
@ -2115,7 +2114,7 @@ pub const Stream = struct {
};
i += 1;
}
for (data[0..data.len]) |bytes| {
for (data) |bytes| {
// OS checks ptr addr before length so zero length vectors must be omitted.
if (bytes.len == 0) continue;
iovecs[i] = .{
@ -2135,38 +2134,48 @@ pub const Stream = struct {
.flags = 0,
};
};
const pattern = data[data.len - 1];
switch (splat) {
0 => msg.iovlen -= 1,
1 => {},
else => switch (pattern.len) {
0 => {},
1 => {
// Replace the 1-byte buffer with a bigger one.
const memset_len = @min(splat_buffer.len, splat);
const buf = splat_buffer[0..memset_len];
@memset(buf, pattern[0]);
iovecs[msg.iovlen - 1] = .{ .base = buf.ptr, .len = buf.len };
var remaining_splat = splat - buf.len;
while (remaining_splat > splat_buffer.len and msg.iovlen < iovecs.len) {
iovecs[msg.iovlen] = .{ .base = &splat_buffer, .len = splat_buffer.len };
remaining_splat -= splat_buffer.len;
msg.iovlen += 1;
}
if (remaining_splat > 0 and msg.iovlen < iovecs.len) {
iovecs[msg.iovlen] = .{ .base = &splat_buffer, .len = remaining_splat };
msg.iovlen += 1;
}
if (data.len != 0) {
const pattern = data[data.len - 1];
switch (splat) {
0 => if (msg.iovlen != 0 and iovecs[msg.iovlen - 1].base == data[data.len - 1].ptr) {
msg.iovlen -= 1;
},
else => for (0..splat - 1) |_| {
if (iovecs.len - msg.iovlen == 0) break;
iovecs[msg.iovlen] = .{
.base = pattern.ptr,
.len = pattern.len,
};
msg.iovlen += 1;
1 => {},
else => switch (pattern.len) {
0 => {},
1 => memset: {
// Replace the 1-byte buffer with a bigger one.
if (msg.iovlen != 0 and iovecs[msg.iovlen - 1].base == data[data.len - 1].ptr)
msg.iovlen -= 1;
if (iovecs.len - msg.iovlen == 0) break :memset;
const splat_buffer = io_w.buffer[io_w.end..];
const memset_len = @min(splat_buffer.len, splat);
const buf = splat_buffer[0..memset_len];
@memset(buf, pattern[0]);
iovecs[msg.iovlen] = .{ .base = buf.ptr, .len = buf.len };
msg.iovlen += 1;
var remaining_splat = splat - buf.len;
while (remaining_splat > splat_buffer.len and iovecs.len - msg.iovlen != 0) {
assert(buf.len == splat_buffer.len);
iovecs[msg.iovlen] = .{ .base = splat_buffer.ptr, .len = splat_buffer.len };
msg.iovlen += 1;
remaining_splat -= splat_buffer.len;
}
if (remaining_splat > 0 and iovecs.len - msg.iovlen != 0) {
iovecs[msg.iovlen] = .{ .base = splat_buffer.ptr, .len = remaining_splat };
msg.iovlen += 1;
}
},
else => for (0..splat - 1) |_| {
if (iovecs.len - msg.iovlen == 0) break;
iovecs[msg.iovlen] = .{
.base = pattern.ptr,
.len = pattern.len,
};
msg.iovlen += 1;
},
},
},
}
}
const flags = posix.MSG.NOSIGNAL;
return io_w.consume(std.posix.sendmsg(w.file_writer.file.handle, &msg, flags) catch |err| {