diff --git a/lib/std/Io/Reader/Limited.zig b/lib/std/Io/Reader/Limited.zig index e0d04bb531..8373161387 100644 --- a/lib/std/Io/Reader/Limited.zig +++ b/lib/std/Io/Reader/Limited.zig @@ -28,9 +28,11 @@ pub fn init(reader: *Reader, limit: Limit, buffer: []u8) Limited { fn stream(r: *Reader, w: *Writer, limit: Limit) Reader.StreamError!usize { const l: *Limited = @fieldParentPtr("interface", r); const combined_limit = limit.min(l.remaining); - const n = try l.unlimited.stream(w, combined_limit); - l.remaining = l.remaining.subtract(n).?; - return n; + if (combined_limit.nonzero()) { + const n = try l.unlimited.stream(w, combined_limit); + l.remaining = l.remaining.subtract(n).?; + return n; + } else return error.EndOfStream; } test stream { @@ -49,6 +51,64 @@ test stream { try std.testing.expectEqualStrings("test", result_buf[0..streamed]); } +test "readSliceAll from infinite source" { + const InfSource = struct { + reader: std.Io.Reader, + + pub fn init(buffer: []u8) @This() { + return @This(){ + .reader = .{ + .vtable = &.{ + .stream = streamA, + }, + .buffer = buffer, + .seek = 0, + .end = 0, + }, + }; + } + + fn streamA(io_reader: *std.Io.Reader, w: *std.Io.Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize { + _ = io_reader; + + std.debug.assert(limit.nonzero()); + + const n_bytes_remaining = limit.minInt(2); + for (0..n_bytes_remaining) |_| { + try w.writeByte('A'); + } + return n_bytes_remaining; + } + }; + + // Exact size + { + var inf_buf: [10]u8 = undefined; + var inf_stream = InfSource.init(&inf_buf); + + var limit_buf: [2]u8 = undefined; + var limited: std.Io.Reader.Limited = .init(&inf_stream.reader, .limited(2), &limit_buf); + const limited_reader = &limited.interface; + + var out_buffer: [2]u8 = undefined; + try std.testing.expectEqual({}, limited_reader.readSliceAll(&out_buffer)); + try std.testing.expectEqualStrings("AA", &out_buffer); + } + + // Too large + { + var inf_buf: [10]u8 = undefined; + var inf_stream = InfSource.init(&inf_buf); + + var limit_buf: [2]u8 = undefined; + var limited: std.Io.Reader.Limited = .init(&inf_stream.reader, .limited(2), &limit_buf); + const limited_reader = &limited.interface; + + var out_buffer: [8]u8 = undefined; + try std.testing.expectError(error.EndOfStream, limited_reader.readSliceAll(&out_buffer)); + } +} + fn discard(r: *Reader, limit: Limit) Reader.Error!usize { const l: *Limited = @fieldParentPtr("interface", r); const combined_limit = limit.min(l.remaining);