From 19b82ca7ab7c44c3a1f4d0c6cbaf43bc0bded43a Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 21 Apr 2025 14:09:47 -0700 Subject: [PATCH] std.http.Server: implement chunked request parsing --- lib/compiler/std-docs.zig | 2 +- lib/std/http/Server.zig | 440 ++++++++++++++++++++-------------- lib/std/io/BufferedReader.zig | 92 +++++-- lib/std/io/BufferedWriter.zig | 4 +- lib/std/io/Reader.zig | 8 +- 5 files changed, 339 insertions(+), 207 deletions(-) diff --git a/lib/compiler/std-docs.zig b/lib/compiler/std-docs.zig index 172da8e5e1..b9d96b9823 100644 --- a/lib/compiler/std-docs.zig +++ b/lib/compiler/std-docs.zig @@ -428,7 +428,7 @@ fn receiveWasmMessage( }, else => { // Ignore other messages. - try br.discard(header.bytes_len); + try br.discardAll(header.bytes_len); }, } } diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index c0a3d26485..ae1199fd27 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -22,9 +22,6 @@ out: *std.io.BufferedWriter, state: State, head_parse_err: Request.Head.ParseError, -/// being deleted... -next_request_start: usize = 0, - pub const State = enum { /// The connection is available to be used for the first time, or reused. ready, @@ -95,6 +92,8 @@ pub fn receiveHead(s: *Server) ReceiveHeadError!Request { if (hp.state == .finished) return .{ .server = s, .head_end = head_end, + .trailers_len = 0, + .read_err = null, .head = Request.Head.parse(buf[0..head_end]) catch |err| { s.head_parse_err = err; return error.HttpHeadersInvalid; @@ -108,11 +107,38 @@ pub const Request = struct { server: *Server, /// Index into `Server.in` internal buffer. head_end: usize, + /// Number of bytes of HTTP trailers. These are at the end of a + /// transfer-encoding: chunked message. + trailers_len: usize, head: Head, reader_state: union { remaining_content_length: u64, - chunk_parser: http.ChunkParser, + remaining_chunk_len: RemainingChunkLen, }, + read_err: ?ReadError, + + pub const ReadError = error{ + HttpChunkInvalid, + HttpHeadersOversize, + }; + + pub const max_chunk_header_len = 22; + + pub const RemainingChunkLen = enum(u64) { + head = 0, + n = 1, + rn = 2, + done = std.math.maxInt(u64), + _, + + pub fn init(integer: u64) RemainingChunkLen { + return @enumFromInt(integer); + } + + pub fn int(rcl: RemainingChunkLen) u64 { + return @intFromEnum(rcl); + } + }; pub const Compression = union(enum) { deflate: std.compress.zlib.Decompressor, @@ -559,177 +585,240 @@ pub const Request = struct { }; } - pub const ReadError = net.Stream.ReadError || error{ - HttpChunkInvalid, - HttpHeadersOversize, - }; - - fn contentLengthReader_read( + fn contentLengthRead( ctx: ?*anyopaque, bw: *std.io.BufferedWriter, limit: std.io.Reader.Limit, - ) std.io.Reader.Error!usize { + ) std.io.Reader.RwError!usize { const request: *Request = @alignCast(@ptrCast(ctx)); - _ = request; - _ = bw; - _ = limit; - @panic("TODO"); - } - - fn contentLengthReader_readVec(ctx: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize { - const request: *Request = @alignCast(@ptrCast(ctx)); - _ = request; - _ = data; - @panic("TODO"); - } - - fn contentLengthReader_discard(ctx: ?*anyopaque, limit: std.io.Reader.Limit) std.io.Reader.Error!usize { - const request: *Request = @alignCast(@ptrCast(ctx)); - _ = request; - _ = limit; - @panic("TODO"); - } - - fn chunkedReader_read( - ctx: ?*anyopaque, - bw: *std.io.BufferedWriter, - limit: std.io.Reader.Limit, - ) std.io.Reader.Error!usize { - const request: *Request = @alignCast(@ptrCast(ctx)); - _ = request; - _ = bw; - _ = limit; - @panic("TODO"); - } - - fn chunkedReader_readVec(ctx: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize { - const request: *Request = @alignCast(@ptrCast(ctx)); - _ = request; - _ = data; - @panic("TODO"); - } - - fn chunkedReader_discard(ctx: ?*anyopaque, limit: std.io.Reader.Limit) std.io.Reader.Error!usize { - const request: *Request = @alignCast(@ptrCast(ctx)); - _ = request; - _ = limit; - @panic("TODO"); - } - - fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize { - const request: *Request = @alignCast(@ptrCast(context)); - const s = request.server; - const remaining_content_length = &request.reader_state.remaining_content_length; - if (remaining_content_length.* == 0) { - s.state = .ready; - return 0; + const remaining = remaining_content_length.*; + const server = request.server; + if (remaining == 0) { + server.state = .ready; + return error.EndOfStream; } - assert(s.state == .receiving_body); - const available = try fill(s, request.head_end); - const len = @min(remaining_content_length.*, available.len, buffer.len); - @memcpy(buffer[0..len], available[0..len]); - remaining_content_length.* -= len; - s.next_request_start += len; - if (remaining_content_length.* == 0) - s.state = .ready; - return len; + const n = try server.in.read(bw, limit.min(.limited(remaining))); + const new_remaining = remaining - n; + remaining_content_length.* = new_remaining; + return n; } - fn fill(s: *Server, head_end: usize) ReadError![]u8 { - const available = s.read_buffer[s.next_request_start..s.read_buffer_len]; - if (available.len > 0) return available; - s.next_request_start = head_end; - s.read_buffer_len = head_end + try s.connection.stream.read(s.read_buffer[head_end..]); - return s.read_buffer[head_end..s.read_buffer_len]; - } - - fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize { + fn contentLengthReadVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize { const request: *Request = @alignCast(@ptrCast(context)); - const s = request.server; + const remaining_content_length = &request.reader_state.remaining_content_length; + const server = request.server; + const remaining = remaining_content_length.*; + if (remaining == 0) { + server.state = .ready; + return error.EndOfStream; + } + const n = try server.in.readVecLimit(data, .limited(remaining)); + const new_remaining = remaining - n; + remaining_content_length.* = new_remaining; + return n; + } - const cp = &request.reader_state.chunk_parser; - const head_end = request.head_end; + fn contentLengthDiscard(ctx: ?*anyopaque, limit: std.io.Reader.Limit) std.io.Reader.Error!usize { + const request: *Request = @alignCast(@ptrCast(ctx)); + const remaining_content_length = &request.reader_state.remaining_content_length; + const server = request.server; + const remaining = remaining_content_length.*; + if (remaining == 0) { + server.state = .ready; + return error.EndOfStream; + } + const n = try server.in.discard(limit.min(.limited(remaining))); + const new_remaining = remaining - n; + remaining_content_length.* = new_remaining; + return n; + } - // Protect against returning 0 before the end of stream. - var out_end: usize = 0; - while (out_end == 0) { - switch (cp.state) { - .invalid => return 0, - .data => { - assert(s.state == .receiving_body); - const available = try fill(s, head_end); - const len = @min(cp.chunk_len, available.len, buffer.len); - @memcpy(buffer[0..len], available[0..len]); - cp.chunk_len -= len; - if (cp.chunk_len == 0) - cp.state = .data_suffix; - out_end += len; - s.next_request_start += len; - continue; - }, - else => { - assert(s.state == .receiving_body); - const available = try fill(s, head_end); - const n = cp.feed(available); - switch (cp.state) { - .invalid => return error.HttpChunkInvalid, - .data => { - if (cp.chunk_len == 0) { - // The next bytes in the stream are trailers, - // or \r\n to indicate end of chunked body. - // - // This function must append the trailers at - // head_end so that headers and trailers are - // together. - // - // Since returning 0 would indicate end of - // stream, this function must read all the - // trailers before returning. - if (s.next_request_start > head_end) rebase(s, head_end); - var hp: http.HeadParser = .{}; - { - const bytes = s.read_buffer[head_end..s.read_buffer_len]; - const end = hp.feed(bytes); - if (hp.state == .finished) { - cp.state = .invalid; - s.state = .ready; - s.next_request_start = s.read_buffer_len - bytes.len + end; - return out_end; - } - } - while (true) { - const buf = s.read_buffer[s.read_buffer_len..]; - if (buf.len == 0) - return error.HttpHeadersOversize; - const read_n = try s.connection.stream.read(buf); - s.read_buffer_len += read_n; - const bytes = buf[0..read_n]; - const end = hp.feed(bytes); - if (hp.state == .finished) { - cp.state = .invalid; - s.state = .ready; - s.next_request_start = s.read_buffer_len - bytes.len + end; - return out_end; - } - } - } - const data = available[n..]; - const len = @min(cp.chunk_len, data.len, buffer.len); - @memcpy(buffer[0..len], data[0..len]); - cp.chunk_len -= len; - if (cp.chunk_len == 0) - cp.state = .data_suffix; - out_end += len; - s.next_request_start += n + len; - continue; - }, - else => continue, + fn chunkedRead( + ctx: ?*anyopaque, + bw: *std.io.BufferedWriter, + limit: std.io.Reader.Limit, + ) std.io.Reader.RwError!usize { + const request: *Request = @alignCast(@ptrCast(ctx)); + const chunk_len_ptr = &request.reader_state.remaining_chunk_len; + const in = request.server.in; + len: switch (chunk_len_ptr.*) { + .head => { + var cp: http.ChunkParser = .init; + const i = cp.feed(in.bufferContents()); + switch (cp.state) { + .invalid => return request.failRead(error.HttpChunkInvalid), + .data => { + if (i > max_chunk_header_len) return request.failRead(error.HttpChunkInvalid); + in.toss(i); + }, + else => { + try in.fill(max_chunk_header_len); + const next_i = cp.feed(in.bufferContents()[i..]); + if (cp.state != .data) return request.failRead(error.HttpChunkInvalid); + const header_len = i + next_i; + if (header_len > max_chunk_header_len) return request.failRead(error.HttpChunkInvalid); + in.toss(header_len); + }, + } + if (cp.chunk_len == 0) return parseTrailers(request, 0); + const n = try in.read(bw, limit.min(.limited(cp.chunk_len))); + chunk_len_ptr.* = .init(cp.chunk_len + 2 - n); + return n; + }, + .n => { + if ((try in.peekByte()) != '\n') return request.failRead(error.HttpChunkInvalid); + in.toss(1); + continue :len .head; + }, + .rn => { + const rn = try in.peekArray(2); + if (rn[0] != '\r' or rn[1] != '\n') return request.failRead(error.HttpChunkInvalid); + in.toss(2); + continue :len .head; + }, + else => |remaining_chunk_len| { + const n = try in.read(bw, limit.min(.limited(@intFromEnum(remaining_chunk_len) - 2))); + chunk_len_ptr.* = .init(@intFromEnum(remaining_chunk_len) - n); + return n; + }, + .done => return error.EndOfStream, + } + } + + fn chunkedReadVec(ctx: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize { + const request: *Request = @alignCast(@ptrCast(ctx)); + const chunk_len_ptr = &request.reader_state.remaining_chunk_len; + const in = request.server.in; + var already_requested_more = false; + var amt_read: usize = 0; + data: for (data) |d| { + len: switch (chunk_len_ptr.*) { + .head => { + var cp: http.ChunkParser = .init; + const available_buffer = in.bufferContents(); + const i = cp.feed(available_buffer); + if (cp.state == .invalid) return request.failRead(error.HttpChunkInvalid); + if (i == available_buffer.len) { + if (already_requested_more) { + chunk_len_ptr.* = .head; + return amt_read; + } + already_requested_more = true; + try in.fill(max_chunk_header_len); + const next_i = cp.feed(in.bufferContents()[i..]); + if (cp.state != .data) return request.failRead(error.HttpChunkInvalid); + const header_len = i + next_i; + if (header_len > max_chunk_header_len) return request.failRead(error.HttpChunkInvalid); + in.toss(header_len); + } else { + if (i > max_chunk_header_len) return request.failRead(error.HttpChunkInvalid); + in.toss(i); } + if (cp.chunk_len == 0) return parseTrailers(request, amt_read); + continue :len .init(cp.chunk_len + 2); }, + .n => { + if (in.bufferContents().len < 1) already_requested_more = true; + if ((try in.takeByte()) != '\n') return request.failRead(error.HttpChunkInvalid); + continue :len .head; + }, + .rn => { + if (in.bufferContents().len < 2) already_requested_more = true; + const rn = try in.takeArray(2); + if (rn[0] != '\r' or rn[1] != '\n') return request.failRead(error.HttpChunkInvalid); + continue :len .head; + }, + else => |remaining_chunk_len| { + const available_buffer = in.bufferContents(); + const copy_len = @min(available_buffer.len, d.len, remaining_chunk_len.int() - 2); + @memcpy(d[0..copy_len], available_buffer[0..copy_len]); + amt_read += copy_len; + in.toss(copy_len); + const next_chunk_len: RemainingChunkLen = .init(remaining_chunk_len.int() - copy_len); + if (copy_len == d.len) { + chunk_len_ptr.* = next_chunk_len; + continue :data; + } + if (already_requested_more) { + chunk_len_ptr.* = next_chunk_len; + return amt_read; + } + already_requested_more = true; + try in.fill(3); + continue :len next_chunk_len; + }, + .done => return error.EndOfStream, + } + } + return amt_read; + } + + fn chunkedDiscard(ctx: ?*anyopaque, limit: std.io.Reader.Limit) std.io.Reader.Error!usize { + const request: *Request = @alignCast(@ptrCast(ctx)); + const chunk_len_ptr = &request.reader_state.remaining_chunk_len; + const in = request.server.in; + len: switch (chunk_len_ptr.*) { + .head => { + var cp: http.ChunkParser = .init; + const i = cp.feed(in.bufferContents()); + switch (cp.state) { + .invalid => return request.failRead(error.HttpChunkInvalid), + .data => { + if (i > max_chunk_header_len) return request.failRead(error.HttpChunkInvalid); + in.toss(i); + }, + else => { + try in.fill(max_chunk_header_len); + const next_i = cp.feed(in.bufferContents()[i..]); + if (cp.state != .data) return request.failRead(error.HttpChunkInvalid); + const header_len = i + next_i; + if (header_len > max_chunk_header_len) return request.failRead(error.HttpChunkInvalid); + in.toss(header_len); + }, + } + if (cp.chunk_len == 0) return parseTrailers(request, 0); + const n = try in.discard(limit.min(.limited(cp.chunk_len))); + chunk_len_ptr.* = .init(cp.chunk_len + 2 - n); + return n; + }, + .n => { + if ((try in.peekByte()) != '\n') return request.failRead(error.HttpChunkInvalid); + in.toss(1); + continue :len .head; + }, + .rn => { + const rn = try in.peekArray(2); + if (rn[0] != '\r' or rn[1] != '\n') return request.failRead(error.HttpChunkInvalid); + in.toss(2); + continue :len .head; + }, + else => |remaining_chunk_len| { + const n = try in.discard(limit.min(.limited(remaining_chunk_len.int() - 2))); + chunk_len_ptr.* = .init(remaining_chunk_len.int() - n); + return n; + }, + .done => return error.EndOfStream, + } + } + + /// Called when next bytes in the stream are trailers, or "\r\n" to indicate + /// end of chunked body. + fn parseTrailers(request: *Request, amt_read: usize) std.io.Reader.Error!usize { + const in = request.server.in; + var hp: http.HeadParser = .{}; + var trailers_len: usize = 0; + while (true) { + if (trailers_len >= in.buffer.len) return request.failRead(error.HttpHeadersOversize); + try in.fill(trailers_len + 1); + trailers_len += hp.feed(in.bufferContents()[trailers_len..]); + if (hp.state == .finished) { + request.reader_state.remaining_chunk_len = .done; + request.server.state = .ready; + request.trailers_len = trailers_len; + return amt_read; } } - return out_end; } pub const ReaderError = error{ @@ -752,7 +841,6 @@ pub const Request = struct { const s = request.server; assert(s.state == .received_head); s.state = .receiving_body; - s.next_request_start = request.head_end; if (request.head.expect) |expect| { if (mem.eql(u8, expect, "100-continue")) { @@ -765,13 +853,13 @@ pub const Request = struct { switch (request.head.transfer_encoding) { .chunked => { - request.reader_state = .{ .chunk_parser = http.ChunkParser.init }; + request.reader_state = .{ .remaining_chunk_len = .head }; return .{ .context = request, .vtable = &.{ - .read = &chunkedReader_read, - .readVec = &chunkedReader_readVec, - .discard = &chunkedReader_discard, + .read = &chunkedRead, + .readVec = &chunkedReadVec, + .discard = &chunkedDiscard, }, }; }, @@ -782,9 +870,9 @@ pub const Request = struct { return .{ .context = request, .vtable = &.{ - .read = &contentLengthReader_read, - .readVec = &contentLengthReader_readVec, - .discard = &contentLengthReader_discard, + .read = &contentLengthRead, + .readVec = &contentLengthReadVec, + .discard = &contentLengthDiscard, }, }; }, @@ -822,6 +910,11 @@ pub const Request = struct { } return false; } + + fn failRead(r: *Request, err: ReadError) error{ReadFailed} { + r.read_err = err; + return error.ReadFailed; + } }; pub const Response = struct { @@ -1165,14 +1258,3 @@ pub const Response = struct { }; } }; - -fn rebase(s: *Server, index: usize) void { - const leftover = s.read_buffer[s.next_request_start..s.read_buffer_len]; - const dest = s.read_buffer[index..][0..leftover.len]; - if (leftover.len <= s.next_request_start - index) { - @memcpy(dest, leftover); - } else { - mem.copyBackwards(u8, dest, leftover); - } - s.read_buffer_len = index + leftover.len; -} diff --git a/lib/std/io/BufferedReader.zig b/lib/std/io/BufferedReader.zig index 2a76af547d..de911d7f03 100644 --- a/lib/std/io/BufferedReader.zig +++ b/lib/std/io/BufferedReader.zig @@ -43,14 +43,26 @@ pub fn reader(br: *BufferedReader) Reader { .vtable = &.{ .read = passthruRead, .readVec = passthruReadVec, + .discard = passthruDiscard, }, }; } +/// Equivalent semantics to `std.io.Reader.VTable.readVec`. pub fn readVec(br: *BufferedReader, data: []const []u8) Reader.Error!usize { return passthruReadVec(br, data); } +/// Equivalent semantics to `std.io.Reader.VTable.read`. +pub fn read(br: *BufferedReader, bw: *BufferedWriter, limit: Reader.Limit) Reader.RwError!usize { + return passthruRead(br, bw, limit); +} + +/// Equivalent semantics to `std.io.Reader.VTable.discard`. +pub fn discard(br: *BufferedReader, limit: Reader.Limit) Reader.Error!usize { + return passthruDiscard(br, limit); +} + pub fn readVecAll(br: *BufferedReader, data: [][]u8) Reader.Error!void { var index: usize = 0; var truncate: usize = 0; @@ -68,10 +80,6 @@ pub fn readVecAll(br: *BufferedReader, data: [][]u8) Reader.Error!void { } } -pub fn read(br: *BufferedReader, bw: *BufferedWriter, limit: Reader.Limit) Reader.RwError!usize { - return passthruRead(br, bw, limit); -} - /// "Pump" data from the reader to the writer. pub fn readAll(br: *BufferedReader, bw: *BufferedWriter, limit: Reader.Limit) Reader.RwError!void { var remaining = limit; @@ -81,21 +89,46 @@ pub fn readAll(br: *BufferedReader, bw: *BufferedWriter, limit: Reader.Limit) Re } } -fn passthruRead(ctx: ?*anyopaque, bw: *BufferedWriter, limit: Reader.Limit) Reader.RwError!usize { - const br: *BufferedReader = @alignCast(@ptrCast(ctx)); - const buffer = br.buffer[0..br.end]; - const buffered = buffer[br.seek..]; - const limited = buffered[0..limit.minInt(buffered.len)]; - if (limited.len > 0) { - const n = try bw.write(limited); +/// Equivalent to `readVec` but reads at most `limit` bytes. +pub fn readVecLimit(br: *BufferedReader, data: []const []u8, limit: Reader.Limit) Reader.Error!usize { + _ = br; + _ = data; + _ = limit; + @panic("TODO"); +} + +fn passthruRead(context: ?*anyopaque, bw: *BufferedWriter, limit: Reader.Limit) Reader.RwError!usize { + const br: *BufferedReader = @alignCast(@ptrCast(context)); + const buffer = limit.slice(br.buffer[br.end..br.seek]); + if (buffer.len > 0) { + const n = try bw.write(buffer); br.seek += n; return n; } return br.unbuffered_reader.read(bw, limit); } -fn passthruReadVec(ctx: ?*anyopaque, data: []const []u8) Reader.Error!usize { - const br: *BufferedReader = @alignCast(@ptrCast(ctx)); +fn passthruDiscard(context: ?*anyopaque, limit: Reader.Limit) Reader.Error!usize { + const br: *BufferedReader = @alignCast(@ptrCast(context)); + const buffered_len = br.end - br.seek; + if (limit.toInt()) |n| { + if (buffered_len >= n) { + br.seek += n; + return n; + } + br.seek = 0; + br.end = 0; + const additional = try br.unbuffered_reader.discard(.limited(n - buffered_len)); + return n + additional; + } + const n = try br.unbuffered_reader.discard(.unlimited); + br.seek = 0; + br.end = 0; + return buffered_len + n; +} + +fn passthruReadVec(context: ?*anyopaque, data: []const []u8) Reader.Error!usize { + const br: *BufferedReader = @alignCast(@ptrCast(context)); var total: usize = 0; for (data, 0..) |buf, i| { const buffered = br.buffer[br.seek..br.end]; @@ -171,14 +204,14 @@ pub fn peek(br: *BufferedReader, n: usize) Reader.Error![]u8 { } /// Returns all the next buffered bytes from `unbuffered_reader`, after filling -/// the buffer to ensure it contains at least `min_len` bytes. +/// the buffer to ensure it contains at least `n` bytes. /// /// Invalidates previously returned values from `peek` and `peekGreedy`. /// /// Asserts that the `BufferedReader` was initialized with a buffer capacity at -/// least as big as `min_len`. +/// least as big as `n`. /// -/// If there are fewer than `min_len` bytes left in the stream, `error.EndOfStream` +/// If there are fewer than `n` bytes left in the stream, `error.EndOfStream` /// is returned instead. /// /// See also: @@ -253,7 +286,8 @@ pub fn peekArray(br: *BufferedReader, comptime n: usize) Reader.Error!*[n]u8 { /// * `toss` /// * `discardRemaining` /// * `discardShort` -pub fn discard(br: *BufferedReader, n: usize) Reader.Error!void { +/// * `discard` +pub fn discardAll(br: *BufferedReader, n: usize) Reader.Error!void { if ((try br.discardShort(n)) != n) return error.EndOfStream; } @@ -265,9 +299,9 @@ pub fn discard(br: *BufferedReader, n: usize) Reader.Error!void { /// if the stream reached the end. /// /// See also: -/// * `discard` -/// * `toss` +/// * `discardAll` /// * `discardRemaining` +/// * `discard` pub fn discardShort(br: *BufferedReader, n: usize) Reader.ShortError!usize { const proposed_seek = br.seek + n; if (proposed_seek <= br.end) { @@ -609,18 +643,30 @@ pub fn fill(br: *BufferedReader, n: usize) Reader.Error!void { } } -/// Reads 1 byte from the stream or returns `error.EndOfStream`. -pub fn takeByte(br: *BufferedReader) Reader.Error!u8 { +/// Returns the next byte from the stream or returns `error.EndOfStream`. +/// +/// Does not advance the seek position. +/// +/// Asserts the buffer capacity is nonzero. +pub fn peekByte(br: *BufferedReader) Reader.Error!u8 { const buffer = br.buffer[0..br.end]; const seek = br.seek; if (seek >= buffer.len) { @branchHint(.unlikely); try fill(br, 1); } - br.seek = seek + 1; return buffer[seek]; } +/// Reads 1 byte from the stream or returns `error.EndOfStream`. +/// +/// Asserts the buffer capacity is nonzero. +pub fn takeByte(br: *BufferedReader) Reader.Error!u8 { + const result = try peekByte(br); + br.seek += 1; + return result; +} + /// Same as `takeByte` except the returned byte is signed. pub fn takeByteSigned(br: *BufferedReader) Reader.Error!i8 { return @bitCast(try br.takeByte()); @@ -813,7 +859,7 @@ test peekArray { return error.Unimplemented; } -test discard { +test discardAll { var br: BufferedReader = undefined; br.initFixed("foobar"); try br.discard(3); diff --git a/lib/std/io/BufferedWriter.zig b/lib/std/io/BufferedWriter.zig index d9b97bab24..bea74b743d 100644 --- a/lib/std/io/BufferedWriter.zig +++ b/lib/std/io/BufferedWriter.zig @@ -558,8 +558,8 @@ fn passthruWriteFile( const remaining_buffers = buffers[1..]; const send_trailers_len: usize = @min(trailers.len, remaining_buffers.len); @memcpy(remaining_buffers[0..send_trailers_len], trailers[0..send_trailers_len]); - const send_headers_len = 1; - const send_buffers = buffers[0 .. send_headers_len + send_trailers_len]; + const send_headers_len = @intFromBool(end != 0); + const send_buffers = buffers[1 - send_headers_len .. 1 + send_trailers_len]; const n = try bw.unbuffered_writer.writeFile(file, offset, limit, send_buffers, send_headers_len); if (n < end) { @branchHint(.unlikely); diff --git a/lib/std/io/Reader.zig b/lib/std/io/Reader.zig index b0e65980fc..1d9ecb561b 100644 --- a/lib/std/io/Reader.zig +++ b/lib/std/io/Reader.zig @@ -126,7 +126,9 @@ pub const Limit = enum(usize) { }; pub fn read(r: Reader, bw: *BufferedWriter, limit: Limit) RwError!usize { - return r.vtable.read(r.context, bw, limit); + const n = try r.vtable.read(r.context, bw, limit); + assert(n <= @intFromEnum(limit)); + return n; } pub fn readVec(r: Reader, data: []const []u8) Error!usize { @@ -134,7 +136,9 @@ pub fn readVec(r: Reader, data: []const []u8) Error!usize { } pub fn discard(r: Reader, limit: Limit) Error!usize { - return r.vtable.discard(r.context, limit); + const n = try r.vtable.discard(r.context, limit); + assert(n <= @intFromEnum(limit)); + return n; } /// Returns total number of bytes written to `bw`.