std.io: start removing context from Reader/Writer

rely on the field parent pointer pattern
Andrew Kelley 2025-06-25 13:08:18 -07:00
parent 7d6b5ed510
commit d4545f216a
7 changed files with 220 additions and 188 deletions
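
For context, the field parent pointer pattern replaces the type-erased `context` pointer with an interface struct embedded as a field of the implementation; vtable functions recover the implementation with `@fieldParentPtr`. A minimal sketch of the pattern, using a hypothetical `Sink` interface rather than code from this commit:

const std = @import("std");

const Sink = struct {
    vtable: *const VTable,

    pub const VTable = struct {
        write: *const fn (s: *Sink, bytes: []const u8) usize,
    };
};

const CountingSink = struct {
    bytes_written: usize = 0,
    // The interface is embedded as a field instead of holding a
    // ?*anyopaque context pointer back to the implementation.
    interface: Sink = .{ .vtable = &.{ .write = write } },

    fn write(s: *Sink, bytes: []const u8) usize {
        // Recover the implementation from the embedded field; no
        // @ptrCast/@alignCast of a context pointer required.
        const cs: *CountingSink = @fieldParentPtr("interface", s);
        cs.bytes_written += bytes.len;
        return bytes.len;
    }
};

test "field parent pointer pattern" {
    var cs: CountingSink = .{};
    _ = cs.interface.vtable.write(&cs.interface, "hello");
    try std.testing.expectEqual(@as(usize, 5), cs.bytes_written);
}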


@ -21,11 +21,15 @@ const array = tls.array;
/// here via `reader`.
///
/// The buffer is asserted to have capacity at least `min_buffer_len`.
input: *std.io.Reader,
input: *Reader,
/// Decrypted stream from the server to the client.
reader: Reader,
/// The encrypted stream from the client to the server. Bytes are pushed here
/// via `writer`.
output: *Writer,
/// The plaintext stream from the client to the server.
writer: Writer,
/// Populated when `error.TlsAlert` is returned.
alert: ?tls.Alert = null,
@ -36,14 +40,6 @@ write_seq: u64,
/// When this is true, the stream may still not be at the end because there
/// may be data in the input buffer.
received_close_notify: bool,
/// By default, reaching the end-of-stream when reading from the server will
/// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify
/// message has been received. By setting this flag to `true`, instead, the
/// end-of-stream will be forwarded to the application layer above TLS.
///
/// This makes the application vulnerable to truncation attacks unless the
/// application layer itself verifies that the amount of data received equals
/// the amount of data expected, such as HTTP with the Content-Length header.
allow_truncation_attacks: bool,
application_cipher: tls.ApplicationCipher,
@ -85,7 +81,7 @@ pub const SslKeyLog = struct {
}
};
/// The `std.io.Reader` supplied to `init` requires a buffer capacity
/// The `Reader` supplied to `init` requires a buffer capacity
/// at least this amount.
pub const min_buffer_len = tls.max_ciphertext_record_len;
@ -116,6 +112,20 @@ pub const Options = struct {
/// Only the `writer` field is observed during the handshake (`init`).
/// After that, the other fields are populated.
ssl_key_log: ?*SslKeyLog = null,
/// By default, reaching the end-of-stream when reading from the server will
/// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify
/// message has been received. By setting this flag to `true`, instead, the
/// end-of-stream will be forwarded to the application layer above TLS.
///
/// This makes the application vulnerable to truncation attacks unless the
/// application layer itself verifies that the amount of data received equals
/// the amount of data expected, such as HTTP with the Content-Length header.
allow_truncation_attacks: bool = false,
write_buffer: []u8,
/// Asserted to have capacity at least `min_buffer_len`.
read_buffer: []u8,
/// Populated when `error.TlsAlert` is returned from `init`.
alert: ?*tls.Alert = null,
};
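
A hypothetical call site for the new Options-based init; the wrapper name, parameters, and error handling below are assumptions, not part of this commit:

fn connectTls(
    input: *std.io.Reader, // buffer capacity asserted >= Client.min_buffer_len
    output: *std.io.Writer,
    ca_bundle: std.crypto.Certificate.Bundle,
    read_buffer: []u8, // capacity asserted >= Client.min_buffer_len
    write_buffer: []u8,
    host: []const u8,
) !std.crypto.tls.Client {
    var alert: std.crypto.tls.Alert = undefined;
    return std.crypto.tls.Client.init(input, output, .{
        .host = .{ .explicit = host },
        .ca = .{ .bundle = ca_bundle },
        .read_buffer = read_buffer,
        .write_buffer = write_buffer,
        // With the alert moved into Options as an out-parameter, init
        // populates it before returning error.TlsAlert.
        .alert = &alert,
    }) catch |err| {
        if (err == error.TlsAlert)
            std.log.err("tls alert: {s}", .{@tagName(alert.description)});
        return err;
    };
}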
const InitError = error{
@ -173,14 +183,8 @@ const InitError = error{
/// `host` is only borrowed during this function call.
///
/// `input` is asserted to have buffer capacity at least `min_buffer_len`.
pub fn init(
client: *Client,
input: *std.io.Reader,
output: *Writer,
options: Options,
) InitError!void {
pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client {
assert(input.buffer.len >= min_buffer_len);
client.alert = null;
const host = switch (options.host) {
.no_verification => "",
.explicit => |host| host,
@ -411,7 +415,7 @@ pub fn init(
switch (ct) {
.alert => {
ctd.ensure(2) catch continue :fragment;
client.alert = .{
if (options.alert) |a| a.* = .{
.level = ctd.decode(tls.Alert.Level),
.description = ctd.decode(tls.Alert.Description),
};
@ -852,9 +856,28 @@ pub fn init(
else => unreachable,
},
};
client.* = .{
if (options.ssl_key_log) |ssl_key_log| ssl_key_log.* = .{
.client_key_seq = key_seq,
.server_key_seq = key_seq,
.client_random = client_hello_rand,
.writer = ssl_key_log.writer,
};
return .{
.input = input,
.reader = .{
.buffer = options.read_buffer,
.vtable = &.{ .stream = stream },
.seek = 0,
.end = 0,
},
.output = output,
.writer = .{
.buffer = options.write_buffer,
.vtable = &.{
.drain = drain,
.sendFile = Writer.unimplementedSendFile,
},
},
.tls_version = tls_version,
.read_seq = switch (tls_version) {
.tls_1_3 => 0,
@ -867,17 +890,10 @@ pub fn init(
else => unreachable,
},
.received_close_notify = false,
.allow_truncation_attacks = false,
.allow_truncation_attacks = options.allow_truncation_attacks,
.application_cipher = app_cipher,
.ssl_key_log = options.ssl_key_log,
};
if (options.ssl_key_log) |ssl_key_log| ssl_key_log.* = .{
.client_key_seq = key_seq,
.server_key_seq = key_seq,
.client_random = client_hello_rand,
.writer = ssl_key_log.writer,
};
return;
},
else => return error.TlsUnexpectedMessage,
}
@ -891,25 +907,9 @@ pub fn init(
}
}
pub fn reader(c: *Client) Reader {
return .{
.context = c,
.vtable = &.{ .read = read },
};
}
pub fn writer(c: *Client) Writer {
return .{
.context = c,
.vtable = &.{
.writeSplat = writeSplat,
.writeFile = Writer.unimplementedWriteFile,
},
};
}
fn writeSplat(context: ?*anyopaque, data: []const []const u8, splat: usize) Writer.Error!usize {
const c: *Client = @alignCast(@ptrCast(context));
fn drain(w: *Writer, data: []const []const u8, splat: usize) Writer.Error!usize {
const c: *Client = @fieldParentPtr("writer", w);
if (true) @panic("update to use the buffer and flush");
const sliced_data = if (splat == 0) data[0..data.len -| 1] else data;
const output = c.output;
const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
@ -1043,8 +1043,8 @@ pub fn eof(c: Client) bool {
return c.received_close_notify;
}
fn read(context: ?*anyopaque, bw: *Writer, limit: std.io.Limit) Reader.StreamError!usize {
const c: *Client = @ptrCast(@alignCast(context));
fn stream(r: *Reader, w: *Writer, limit: std.io.Limit) Reader.StreamError!usize {
const c: *Client = @fieldParentPtr("reader", r);
if (c.eof()) return error.EndOfStream;
const input = c.input;
// If at least one full encrypted record is not buffered, read once.
@ -1214,7 +1214,7 @@ fn read(context: ?*anyopaque, bw: *Writer, limit: std.io.Limit) Reader.StreamErr
},
.application_data => {
if (@intFromEnum(limit) < cleartext.len) return failRead(c, error.OutputBufferUndersize);
try bw.writeAll(cleartext);
try w.writeAll(cleartext);
return cleartext.len;
},
else => return failRead(c, error.TlsUnexpectedMessage),
@ -1226,8 +1226,8 @@ fn failRead(c: *Client, err: ReadError) error{ReadFailed} {
return error.ReadFailed;
}
fn logSecrets(bw: *Writer, context: anytype, secrets: anytype) void {
inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| bw.print("{s}" ++
fn logSecrets(w: *Writer, context: anytype, secrets: anytype) void {
inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| w.print("{s}" ++
(if (@hasField(@TypeOf(context), "counter")) "_{d}" else "") ++ " {x} {x}\n", .{field.name} ++
(if (@hasField(@TypeOf(context), "counter")) .{context.counter} else .{}) ++ .{
context.client_random,


@ -943,7 +943,6 @@ pub const Reader = struct {
pub fn initInterface(buffer: []u8) std.io.Reader {
return .{
.context = undefined,
.vtable = &.{
.stream = Reader.stream,
.discard = Reader.discard,
@ -1291,7 +1290,6 @@ pub const Writer = struct {
pub fn initInterface(buffer: []u8) std.io.Writer {
return .{
.context = undefined,
.vtable = &.{
.drain = drain,
.sendFile = sendFile,


@ -328,6 +328,7 @@ pub const Header = struct {
pub const Reader = struct {
in: *std.io.Reader,
interface: std.io.Reader,
/// Keeps track of whether the stream is ready to accept a new request,
/// making invalid API usage cause assertion failures rather than HTTP
/// protocol violations.
@ -438,37 +439,41 @@ pub const Reader = struct {
buffer: []u8,
transfer_encoding: TransferEncoding,
content_length: ?u64,
) std.io.Reader {
) *std.io.Reader {
assert(reader.state == .received_head);
return switch (transfer_encoding) {
switch (transfer_encoding) {
.chunked => {
reader.state = .{ .body_remaining_chunk_len = .head };
return .{
reader.interface = .{
.buffer = buffer,
.context = reader,
.seek = 0,
.end = 0,
.vtable = &.{
.read = chunkedRead,
.stream = chunkedStream,
.discard = chunkedDiscard,
},
};
return &reader.interface;
},
.none => {
if (content_length) |len| {
reader.state = .{ .body_remaining_content_length = len };
return .{
reader.interface = .{
.buffer = buffer,
.context = reader,
.seek = 0,
.end = 0,
.vtable = &.{
.read = contentLengthRead,
.stream = contentLengthStream,
.discard = contentLengthDiscard,
},
};
return &reader.interface;
} else {
reader.state = .body_none;
return reader.in.reader();
return reader.in;
}
},
};
}
}
/// If compressed body has been negotiated this will return decompressed bytes.
@ -511,25 +516,25 @@ pub const Reader = struct {
return decompressor.reader(transfer_reader, decompression_buffer, content_encoding);
}
fn contentLengthRead(
ctx: ?*anyopaque,
bw: *Writer,
fn contentLengthStream(
io_r: *std.io.Reader,
w: *Writer,
limit: std.io.Limit,
) std.io.Reader.StreamError!usize {
const reader: *Reader = @alignCast(@ptrCast(ctx));
const reader: *Reader = @fieldParentPtr("interface", io_r);
const remaining_content_length = &reader.state.body_remaining_content_length;
const remaining = remaining_content_length.*;
if (remaining == 0) {
reader.state = .ready;
return error.EndOfStream;
}
const n = try reader.in.read(bw, limit.min(.limited(remaining)));
const n = try reader.in.stream(w, limit.min(.limited(remaining)));
remaining_content_length.* = remaining - n;
return n;
}
fn contentLengthDiscard(ctx: ?*anyopaque, limit: std.io.Limit) std.io.Reader.Error!usize {
const reader: *Reader = @alignCast(@ptrCast(ctx));
fn contentLengthDiscard(io_r: *std.io.Reader, limit: std.io.Limit) std.io.Reader.Error!usize {
const reader: *Reader = @fieldParentPtr("interface", io_r);
const remaining_content_length = &reader.state.body_remaining_content_length;
const remaining = remaining_content_length.*;
if (remaining == 0) {
@ -541,18 +546,14 @@ pub const Reader = struct {
return n;
}
fn chunkedRead(
ctx: ?*anyopaque,
bw: *Writer,
limit: std.io.Limit,
) std.io.Reader.StreamError!usize {
const reader: *Reader = @alignCast(@ptrCast(ctx));
fn chunkedStream(io_r: *std.io.Reader, w: *Writer, limit: std.io.Limit) std.io.Reader.StreamError!usize {
const reader: *Reader = @fieldParentPtr("interface", io_r);
const chunk_len_ptr = switch (reader.state) {
.ready => return error.EndOfStream,
.body_remaining_chunk_len => |*x| x,
else => unreachable,
};
return chunkedReadEndless(reader, bw, limit, chunk_len_ptr) catch |err| switch (err) {
return chunkedReadEndless(reader, w, limit, chunk_len_ptr) catch |err| switch (err) {
error.ReadFailed => return error.ReadFailed,
error.WriteFailed => return error.WriteFailed,
error.EndOfStream => {
@ -568,7 +569,7 @@ pub const Reader = struct {
fn chunkedReadEndless(
reader: *Reader,
bw: *Writer,
w: *Writer,
limit: std.io.Limit,
chunk_len_ptr: *RemainingChunkLen,
) (BodyError || std.io.Reader.StreamError)!usize {
@ -592,7 +593,7 @@ pub const Reader = struct {
}
}
if (cp.chunk_len == 0) return parseTrailers(reader, 0);
const n = try in.read(bw, limit.min(.limited(cp.chunk_len)));
const n = try in.stream(w, limit.min(.limited(cp.chunk_len)));
chunk_len_ptr.* = .init(cp.chunk_len + 2 - n);
return n;
},
@ -608,15 +609,15 @@ pub const Reader = struct {
continue :len .head;
},
else => |remaining_chunk_len| {
const n = try in.read(bw, limit.min(.limited(@intFromEnum(remaining_chunk_len) - 2)));
const n = try in.stream(w, limit.min(.limited(@intFromEnum(remaining_chunk_len) - 2)));
chunk_len_ptr.* = .init(@intFromEnum(remaining_chunk_len) - n);
return n;
},
}
}
fn chunkedDiscard(ctx: ?*anyopaque, limit: std.io.Limit) std.io.Reader.Error!usize {
const reader: *Reader = @alignCast(@ptrCast(ctx));
fn chunkedDiscard(io_r: *std.io.Reader, limit: std.io.Limit) std.io.Reader.Error!usize {
const reader: *Reader = @fieldParentPtr("interface", io_r);
const chunk_len_ptr = switch (reader.state) {
.ready => return error.EndOfStream,
.body_remaining_chunk_len => |*x| x,
@ -758,7 +759,6 @@ pub const BodyWriter = struct {
/// state of this other than via methods of `BodyWriter`.
http_protocol_output: *Writer,
state: State,
elide: bool,
interface: Writer,
pub const Error = Writer.Error;
@ -796,6 +796,10 @@ pub const BodyWriter = struct {
};
};
pub fn isEliding(w: *const BodyWriter) bool {
return w.interface.vtable.drain == Writer.discardingDrain;
}
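
Whether a BodyWriter elides is now decided at construction time by which vtable the embedded Writer gets, rather than by a stored flag. A sketch of the idea with placeholder values, mirroring the Server code further down (assumes `Writer.discarding` installs `discardingDrain`):

const body_writer: http.BodyWriter = .{
    .http_protocol_output = out,
    .state = .none,
    // A discarding Writer makes isEliding() return true; otherwise a
    // real vtable is installed and isEliding() returns false.
    .interface = if (elide_body) .discarding(buffer) else .{
        .buffer = buffer,
        .vtable = &.{
            .drain = http.BodyWriter.noneDrain,
            .sendFile = http.BodyWriter.noneSendFile,
        },
    },
};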
/// Sends all buffered data across `BodyWriter.http_protocol_output`.
pub fn flush(w: *BodyWriter) Error!void {
const out = w.http_protocol_output;
@ -825,7 +829,7 @@ pub const BodyWriter = struct {
/// with empty trailers, then flushes the stream to the system. Asserts any
/// started chunk has been completely finished.
///
/// Respects the value of `elide` to omit all data after the headers.
/// Respects the value of `isEliding` to omit all data after the headers.
///
/// See also:
/// * `endUnflushed`
@ -841,7 +845,7 @@ pub const BodyWriter = struct {
/// Otherwise, transfer-encoding: chunked is being used, and it writes the
/// end-of-stream message with empty trailers.
///
/// Respects the value of `elide` to omit all data after the headers.
/// Respects the value of `isEliding` to omit all data after the headers.
///
/// See also:
/// * `end`
@ -867,7 +871,7 @@ pub const BodyWriter = struct {
///
/// Asserts that the BodyWriter is using transfer-encoding: chunked.
///
/// Respects the value of `elide` to omit all data after the headers.
/// Respects the value of `isEliding` to omit all data after the headers.
///
/// See also:
/// * `endChunkedUnflushed`
@ -883,7 +887,7 @@ pub const BodyWriter = struct {
///
/// Asserts that the BodyWriter is using transfer-encoding: chunked.
///
/// Respects the value of `elide` to omit all data after the headers.
/// Respects the value of `isEliding` to omit all data after the headers.
///
/// See also:
/// * `endChunked`
@ -891,7 +895,7 @@ pub const BodyWriter = struct {
/// * `end`
pub fn endChunkedUnflushed(w: *BodyWriter, options: EndChunkedOptions) Error!void {
const chunked = &w.state.chunked;
if (w.elide) {
if (w.isEliding()) {
w.state = .end;
return;
}
@ -922,7 +926,7 @@ pub const BodyWriter = struct {
fn contentLengthDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize {
const bw: *BodyWriter = @fieldParentPtr("interface", w);
assert(!bw.elide);
assert(!bw.isEliding());
const out = bw.http_protocol_output;
const n = try w.drainTo(out, data, splat);
bw.state.content_length -= n;
@ -931,7 +935,7 @@ pub const BodyWriter = struct {
fn noneDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize {
const bw: *BodyWriter = @fieldParentPtr("interface", w);
assert(!bw.elide);
assert(!bw.isEliding());
const out = bw.http_protocol_output;
return try w.drainTo(out, data, splat);
}
@ -939,13 +943,13 @@ pub const BodyWriter = struct {
/// Returns `null` if size cannot be computed without making any syscalls.
fn noneSendFile(w: *Writer, file_reader: *File.Reader, limit: std.io.Limit) Writer.FileError!usize {
const bw: *BodyWriter = @fieldParentPtr("interface", w);
assert(!bw.elide);
assert(!bw.isEliding());
return w.sendFileTo(bw.http_protocol_output, file_reader, limit);
}
fn contentLengthSendFile(w: *Writer, file_reader: *File.Reader, limit: std.io.Limit) Writer.FileError!usize {
const bw: *BodyWriter = @fieldParentPtr("interface", w);
assert(!bw.elide);
assert(!bw.isEliding());
const n = try w.sendFileTo(bw.http_protocol_output, file_reader, limit);
bw.state.content_length -= n;
return n;
@ -953,7 +957,7 @@ pub const BodyWriter = struct {
fn chunkedSendFile(w: *Writer, file_reader: *File.Reader, limit: std.io.Limit) Writer.FileError!usize {
const bw: *BodyWriter = @fieldParentPtr("interface", w);
assert(!bw.elide);
assert(!bw.isEliding());
const data_len = w.countSendFileUpperBound(file_reader, limit) orelse {
// If the file size is unknown, we cannot lower to a `writeFile` since we would
// have to flush the chunk header before knowing the chunk length.
@ -1001,7 +1005,7 @@ pub const BodyWriter = struct {
fn chunkedDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize {
const bw: *BodyWriter = @fieldParentPtr("interface", w);
assert(!bw.elide);
assert(!bw.isEliding());
const out = bw.http_protocol_output;
const data_len = Writer.countSplat(w.end, data, splat);
const chunked = &bw.state.chunked;
@ -1059,27 +1063,6 @@ pub const BodyWriter = struct {
a /= base;
}
}
pub fn writer(w: *BodyWriter) Writer {
return if (w.elide) .discarding else .{
.context = w,
.vtable = switch (w.state) {
.none => &.{
.drain = noneDrain,
.sendFile = noneSendFile,
},
.content_length => &.{
.drain = contentLengthDrain,
.sendFile = contentLengthSendFile,
},
.chunked => &.{
.drain = chunkedDrain,
.sendFile = chunkedSendFile,
},
.end => unreachable,
},
};
}
};
test {


@ -14,6 +14,7 @@ const Uri = std.Uri;
const Allocator = mem.Allocator;
const assert = std.debug.assert;
const Writer = std.io.Writer;
const Reader = std.io.Reader;
const Client = @This();
@ -228,12 +229,6 @@ pub const Connection = struct {
client: *Client,
stream_writer: net.Stream.Writer,
stream_reader: net.Stream.Reader,
/// HTTP protocol from client to server.
/// This either goes directly to `stream_writer`, or to a TLS client.
writer: Writer,
/// HTTP protocol from server to client.
/// This either comes directly from `stream_reader`, or from a TLS client.
reader: std.io.Reader,
/// Entry in `ConnectionPool.used` or `ConnectionPool.free`.
pool_node: std.DoublyLinkedList.Node,
port: u16,
@ -264,10 +259,8 @@ pub const Connection = struct {
plain.* = .{
.connection = .{
.client = client,
.stream_writer = stream.writer(),
.stream_reader = stream.reader(),
.writer = plain.connection.stream_writer.interface().buffered(socket_write_buffer),
.reader = plain.connection.stream_reader.interface().buffered(socket_read_buffer),
.stream_writer = stream.writer(socket_write_buffer),
.stream_reader = stream.reader(socket_read_buffer),
.pool_node = .{},
.port = port,
.host_len = @intCast(remote_host.len),
@ -297,10 +290,6 @@ pub const Connection = struct {
};
const Tls = struct {
/// Data from `client` to `Connection.stream`.
writer: Writer,
/// Data from `Connection.stream` to `client`.
reader: std.io.Reader,
client: std.crypto.tls.Client,
connection: Connection,
@ -324,10 +313,8 @@ pub const Connection = struct {
tls.* = .{
.connection = .{
.client = client,
.stream_writer = stream.writer(),
.stream_reader = stream.reader(),
.writer = tls.client.writer().buffered(socket_write_buffer),
.reader = tls.client.reader().unbuffered(),
.stream_writer = stream.writer(socket_write_buffer),
.stream_reader = stream.reader(&.{}),
.pool_node = .{},
.port = port,
.host_len = @intCast(remote_host.len),
@ -335,20 +322,22 @@ pub const Connection = struct {
.closing = false,
.protocol = .tls,
},
.writer = tls.connection.stream_writer.interface().buffered(tls_write_buffer),
.reader = tls.connection.stream_reader.interface().buffered(tls_read_buffer),
.client = undefined,
// TODO data race here on ca_bundle if the user sets next_https_rescan_certs to true
.client = std.crypto.tls.Client.init(
tls.connection.stream_reader.interface(),
&tls.connection.stream_writer.interface,
.{
.host = .{ .explicit = remote_host },
.ca = .{ .bundle = client.ca_bundle },
.ssl_key_log = client.ssl_key_log,
.read_buffer = tls_read_buffer,
.write_buffer = tls_write_buffer,
// This is appropriate for HTTPS because the HTTP headers contain
// the content length which is used to detect truncation attacks.
.allow_truncation_attacks = true,
},
) catch return error.TlsInitializationFailed,
};
// TODO data race here on ca_bundle if the user sets next_https_rescan_certs to true
tls.client.init(&tls.reader, &tls.writer, .{
.host = .{ .explicit = remote_host },
.ca = .{ .bundle = client.ca_bundle },
.ssl_key_log = client.ssl_key_log,
}) catch return error.TlsInitializationFailed;
// This is appropriate for HTTPS because the HTTP headers contain
// the content length which is used to detect truncation attacks.
tls.client.allow_truncation_attacks = true;
return tls;
}
@ -404,26 +393,52 @@ pub const Connection = struct {
}
}
/// HTTP protocol from client to server.
/// This either goes directly to `stream_writer`, or to a TLS client.
pub fn writer(c: *Connection) *Writer {
return switch (c.protocol) {
.tls => {
if (disable_tls) unreachable;
const tls: *Tls = @fieldParentPtr("connection", c);
return &tls.client.writer;
},
.plain => &c.stream_writer.interface,
};
}
/// HTTP protocol from server to client.
/// This either comes directly from `stream_reader`, or from a TLS client.
pub fn reader(c: *Connection) *Reader {
return switch (c.protocol) {
.tls => {
if (disable_tls) unreachable;
const tls: *Tls = @fieldParentPtr("connection", c);
return &tls.client.reader;
},
.plain => c.stream_reader.interface(),
};
}
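
A hypothetical caller now asks the connection for its interfaces each time instead of keeping per-connection copies (the variable and request text are placeholders):

const w = connection.writer();
try w.writeAll("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n");
try connection.flush();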
pub fn flush(c: *Connection) Writer.Error!void {
try c.writer.flush();
if (c.protocol == .tls) {
if (disable_tls) unreachable;
const tls: *Tls = @fieldParentPtr("connection", c);
try tls.writer.flush();
try tls.client.writer.flush();
}
try c.stream_writer.interface.flush();
}
/// If the connection is a TLS connection, sends the close_notify alert.
///
/// Flushes all buffers.
pub fn end(c: *Connection) Writer.Error!void {
try c.writer.flush();
if (c.protocol == .tls) {
if (disable_tls) unreachable;
const tls: *Tls = @fieldParentPtr("connection", c);
try tls.client.end();
try tls.writer.flush();
try tls.client.writer.flush();
}
try c.stream_writer.interface.flush();
}
};
@ -660,14 +675,14 @@ pub const Response = struct {
/// If compressed body has been negotiated this will return compressed bytes.
///
/// If the returned `std.io.Reader` returns `error.ReadFailed` the error is
/// If the returned `Reader` returns `error.ReadFailed` the error is
/// available via `bodyErr`.
///
/// Asserts that this function is only called once.
///
/// See also:
/// * `readerDecompressing`
pub fn reader(response: *Response, buffer: []u8) std.io.Reader {
pub fn reader(response: *Response, buffer: []u8) Reader {
const req = response.request;
if (!req.method.responseHasBody()) return .ending;
const head = &response.head;
@ -676,7 +691,7 @@ pub const Response = struct {
/// If compressed body has been negotiated this will return decompressed bytes.
///
/// If the returned `std.io.Reader` returns `error.ReadFailed` the error is
/// If the returned `Reader` returns `error.ReadFailed` the error is
/// available via `bodyErr`.
///
/// Asserts that this function is only called once.
@ -687,7 +702,7 @@ pub const Response = struct {
response: *Response,
decompressor: *http.Decompressor,
decompression_buffer: []u8,
) std.io.Reader {
) Reader {
const head = &response.head;
return response.request.reader.bodyReaderDecompressing(
head.transfer_encoding,
@ -698,7 +713,7 @@ pub const Response = struct {
);
}
/// After receiving `error.ReadFailed` from the `std.io.Reader` returned by
/// After receiving `error.ReadFailed` from the `Reader` returned by
/// `reader` or `readerDecompressing`, this function accesses the
/// more specific error code.
pub fn bodyErr(response: *const Response) ?http.Reader.BodyError {
@ -835,8 +850,8 @@ pub const Request = struct {
///
/// See also:
/// * `sendBodyUnflushed`
pub fn sendBody(r: *Request) Writer.Error!http.BodyWriter {
const result = try sendBodyUnflushed(r);
pub fn sendBody(r: *Request, buffer: []u8) Writer.Error!http.BodyWriter {
const result = try sendBodyUnflushed(r, buffer);
try r.connection.?.flush();
return result;
}
@ -846,17 +861,44 @@ pub const Request = struct {
///
/// See also:
/// * `sendBody`
pub fn sendBodyUnflushed(r: *Request) Writer.Error!http.BodyWriter {
pub fn sendBodyUnflushed(r: *Request, buffer: []u8) Writer.Error!http.BodyWriter {
assert(r.method.requestHasBody());
try sendHead(r);
return .{
.http_protocol_output = &r.connection.?.writer,
.state = switch (r.transfer_encoding) {
.chunked => .{ .chunked = .init },
.content_length => |len| .{ .content_length = len },
.none => .none,
const http_protocol_output = r.connection.?.writer();
return switch (r.transfer_encoding) {
.chunked => .{
.http_protocol_output = http_protocol_output,
.state = .{ .chunked = .init },
.interface = .{
.buffer = buffer,
.vtable = &.{
.drain = http.BodyWriter.chunkedDrain,
.sendFile = http.BodyWriter.chunkedSendFile,
},
},
},
.content_length => |len| .{
.http_protocol_output = http_protocol_output,
.state = .{ .content_length = len },
.interface = .{
.buffer = buffer,
.vtable = &.{
.drain = http.BodyWriter.contentLengthDrain,
.sendFile = http.BodyWriter.contentLengthSendFile,
},
},
},
.none => .{
.http_protocol_output = http_protocol_output,
.state = .none,
.interface = .{
.buffer = buffer,
.vtable = &.{
.drain = http.BodyWriter.noneDrain,
.sendFile = http.BodyWriter.noneSendFile,
},
},
},
.elide = false,
};
}
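
Callers now own the body buffer. A hypothetical send, where `req`, the buffer size, and `payload` are placeholders:

var body_buffer: [4096]u8 = undefined;
var body = try req.sendBody(&body_buffer);
// The BodyWriter's embedded `interface` field is the Writer to use.
try body.interface.writeAll(payload);
try body.end();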


@ -381,6 +381,8 @@ pub const Request = struct {
content_length: ?u64 = null,
/// Options that are shared with the `respond` method.
respond_options: RespondOptions = .{},
/// Used by `http.BodyWriter`.
buffer: []u8,
};
/// The header is not guaranteed to be sent until `BodyWriter.flush` or
@ -436,16 +438,38 @@ pub const Request = struct {
try out.writeAll("\r\n");
const elide_body = request.head.method == .HEAD;
const state: http.BodyWriter.State = if (o.transfer_encoding) |te| switch (te) {
.chunked => .{ .chunked = .init },
.none => .none,
} else if (options.content_length) |len| .{
.content_length = len,
} else .{ .chunked = .init };
return .{
return if (elide_body) .{
.http_protocol_output = request.server.out,
.state = if (o.transfer_encoding) |te| switch (te) {
.chunked => .{ .chunked = .init },
.none => .none,
} else if (options.content_length) |len| .{
.content_length = len,
} else .{ .chunked = .init },
.elide = elide_body,
.state = state,
.interface = .discarding(options.buffer),
} else .{
.http_protocol_output = request.server.out,
.state = state,
.interface = .{
.buffer = options.buffer,
.vtable = switch (state) {
.none => &.{
.drain = http.BodyWriter.noneDrain,
.sendFile = http.BodyWriter.noneSendFile,
},
.content_length => &.{
.drain = http.BodyWriter.contentLengthDrain,
.sendFile = http.BodyWriter.contentLengthSendFile,
},
.chunked => &.{
.drain = http.BodyWriter.chunkedDrain,
.sendFile = http.BodyWriter.chunkedSendFile,
},
.end => unreachable,
},
},
};
}
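
On the server side, a handler supplies the buffer through the options struct. A hypothetical sketch, assuming this options struct belongs to `respondStreaming` and using placeholder names and sizes:

var send_buffer: [1024]u8 = undefined;
var body = try request.respondStreaming(.{ .buffer = &send_buffer });
try body.interface.writeAll("hello\n");
try body.end();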


@ -13,7 +13,6 @@ const Limit = std.io.Limit;
pub const Limited = @import("Reader/Limited.zig");
context: ?*anyopaque = null,
vtable: *const VTable,
buffer: []u8,
/// Number of bytes which have been consumed from `buffer`.
@ -88,7 +87,6 @@ pub const ShortError = error{
};
pub const failing: Reader = .{
.context = undefined,
.vtable = &.{
.read = failingStream,
.discard = failingDiscard,
@ -107,7 +105,6 @@ pub fn limited(r: *Reader, limit: Limit, buffer: []u8) Limited {
/// Constructs a `Reader` such that it will read from `buffer` and then end.
pub fn fixed(buffer: []const u8) Reader {
return .{
.context = undefined,
.vtable = &.{
.stream = endingStream,
.discard = endingDiscard,
@ -1402,12 +1399,6 @@ test "readAlloc when the backing reader provides one byte at a time" {
self.curr += 1;
return 1;
}
fn reader(self: *@This()) std.io.Reader {
return .{
.context = self,
};
}
};
const str = "This is a test";


@ -9,12 +9,6 @@ const File = std.fs.File;
const testing = std.testing;
const Allocator = std.mem.Allocator;
/// There are two strategies for obtaining context; one can use this field, or
/// embed the `Writer` and use `@fieldParentPtr`. This field must be either set
/// to a valid pointer or left as `null` because the interface will sometimes
/// check if this pointer value is a known special value, for example to make
/// `writableVector` work.
context: ?*anyopaque = null,
vtable: *const VTable,
/// If this has length zero, the writer is unbuffered, and `flush` is a no-op.
buffer: []u8,