diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 29399dcd87..8eab61c2fa 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -21,11 +21,15 @@ const array = tls.array; /// here via `reader`. /// /// The buffer is asserted to have capacity at least `min_buffer_len`. -input: *std.io.Reader, +input: *Reader, +/// Decrypted stream from the server to the client. +reader: Reader, /// The encrypted stream from the client to the server. Bytes are pushed here /// via `writer`. output: *Writer, +/// The plaintext stream from the client to the server. +writer: Writer, /// Populated when `error.TlsAlert` is returned. alert: ?tls.Alert = null, @@ -36,14 +40,6 @@ write_seq: u64, /// When this is true, the stream may still not be at the end because there /// may be data in the input buffer. received_close_notify: bool, -/// By default, reaching the end-of-stream when reading from the server will -/// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify -/// message has been received. By setting this flag to `true`, instead, the -/// end-of-stream will be forwarded to the application layer above TLS. -/// -/// This makes the application vulnerable to truncation attacks unless the -/// application layer itself verifies that the amount of data received equals -/// the amount of data expected, such as HTTP with the Content-Length header. allow_truncation_attacks: bool, application_cipher: tls.ApplicationCipher, @@ -85,7 +81,7 @@ pub const SslKeyLog = struct { } }; -/// The `std.io.Reader` supplied to `init` requires a buffer capacity +/// The `Reader` supplied to `init` requires a buffer capacity /// at least this amount. pub const min_buffer_len = tls.max_ciphertext_record_len; @@ -116,6 +112,20 @@ pub const Options = struct { /// Only the `writer` field is observed during the handshake (`init`). /// After that, the other fields are populated. ssl_key_log: ?*SslKeyLog = null, + /// By default, reaching the end-of-stream when reading from the server will + /// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify + /// message has been received. By setting this flag to `true`, instead, the + /// end-of-stream will be forwarded to the application layer above TLS. + /// + /// This makes the application vulnerable to truncation attacks unless the + /// application layer itself verifies that the amount of data received equals + /// the amount of data expected, such as HTTP with the Content-Length header. + allow_truncation_attacks: bool = false, + write_buffer: []u8, + /// Asserted to have capacity at least `min_buffer_len`. + read_buffer: []u8, + /// Populated when `error.TlsAlert` is returned from `init`. + alert: ?*tls.Alert = null, }; const InitError = error{ @@ -173,14 +183,8 @@ const InitError = error{ /// `host` is only borrowed during this function call. /// /// `input` is asserted to have buffer capacity at least `min_buffer_len`. -pub fn init( - client: *Client, - input: *std.io.Reader, - output: *Writer, - options: Options, -) InitError!void { +pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client { assert(input.buffer.len >= min_buffer_len); - client.alert = null; const host = switch (options.host) { .no_verification => "", .explicit => |host| host, @@ -411,7 +415,7 @@ pub fn init( switch (ct) { .alert => { ctd.ensure(2) catch continue :fragment; - client.alert = .{ + if (options.alert) |a| a.* = .{ .level = ctd.decode(tls.Alert.Level), .description = ctd.decode(tls.Alert.Description), }; @@ -852,9 +856,28 @@ pub fn init( else => unreachable, }, }; - client.* = .{ + if (options.ssl_key_log) |ssl_key_log| ssl_key_log.* = .{ + .client_key_seq = key_seq, + .server_key_seq = key_seq, + .client_random = client_hello_rand, + .writer = ssl_key_log.writer, + }; + return .{ .input = input, + .reader = .{ + .buffer = options.read_buffer, + .vtable = &.{ .stream = stream }, + .seek = 0, + .end = 0, + }, .output = output, + .writer = .{ + .buffer = options.write_buffer, + .vtable = &.{ + .drain = drain, + .sendFile = Writer.unimplementedSendFile, + }, + }, .tls_version = tls_version, .read_seq = switch (tls_version) { .tls_1_3 => 0, @@ -867,17 +890,10 @@ pub fn init( else => unreachable, }, .received_close_notify = false, - .allow_truncation_attacks = false, + .allow_truncation_attacks = options.allow_truncation_attacks, .application_cipher = app_cipher, .ssl_key_log = options.ssl_key_log, }; - if (options.ssl_key_log) |ssl_key_log| ssl_key_log.* = .{ - .client_key_seq = key_seq, - .server_key_seq = key_seq, - .client_random = client_hello_rand, - .writer = ssl_key_log.writer, - }; - return; }, else => return error.TlsUnexpectedMessage, } @@ -891,25 +907,9 @@ pub fn init( } } -pub fn reader(c: *Client) Reader { - return .{ - .context = c, - .vtable = &.{ .read = read }, - }; -} - -pub fn writer(c: *Client) Writer { - return .{ - .context = c, - .vtable = &.{ - .writeSplat = writeSplat, - .writeFile = Writer.unimplementedWriteFile, - }, - }; -} - -fn writeSplat(context: ?*anyopaque, data: []const []const u8, splat: usize) Writer.Error!usize { - const c: *Client = @alignCast(@ptrCast(context)); +fn drain(w: *Writer, data: []const []const u8, splat: usize) Writer.Error!usize { + const c: *Client = @fieldParentPtr("writer", w); + if (true) @panic("update to use the buffer and flush"); const sliced_data = if (splat == 0) data[0..data.len -| 1] else data; const output = c.output; const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len); @@ -1043,8 +1043,8 @@ pub fn eof(c: Client) bool { return c.received_close_notify; } -fn read(context: ?*anyopaque, bw: *Writer, limit: std.io.Limit) Reader.StreamError!usize { - const c: *Client = @ptrCast(@alignCast(context)); +fn stream(r: *Reader, w: *Writer, limit: std.io.Limit) Reader.StreamError!usize { + const c: *Client = @fieldParentPtr("reader", r); if (c.eof()) return error.EndOfStream; const input = c.input; // If at least one full encrypted record is not buffered, read once. @@ -1214,7 +1214,7 @@ fn read(context: ?*anyopaque, bw: *Writer, limit: std.io.Limit) Reader.StreamErr }, .application_data => { if (@intFromEnum(limit) < cleartext.len) return failRead(c, error.OutputBufferUndersize); - try bw.writeAll(cleartext); + try w.writeAll(cleartext); return cleartext.len; }, else => return failRead(c, error.TlsUnexpectedMessage), @@ -1226,8 +1226,8 @@ fn failRead(c: *Client, err: ReadError) error{ReadFailed} { return error.ReadFailed; } -fn logSecrets(bw: *Writer, context: anytype, secrets: anytype) void { - inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| bw.print("{s}" ++ +fn logSecrets(w: *Writer, context: anytype, secrets: anytype) void { + inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| w.print("{s}" ++ (if (@hasField(@TypeOf(context), "counter")) "_{d}" else "") ++ " {x} {x}\n", .{field.name} ++ (if (@hasField(@TypeOf(context), "counter")) .{context.counter} else .{}) ++ .{ context.client_random, diff --git a/lib/std/fs/File.zig b/lib/std/fs/File.zig index 7f349e633e..1980056334 100644 --- a/lib/std/fs/File.zig +++ b/lib/std/fs/File.zig @@ -943,7 +943,6 @@ pub const Reader = struct { pub fn initInterface(buffer: []u8) std.io.Reader { return .{ - .context = undefined, .vtable = &.{ .stream = Reader.stream, .discard = Reader.discard, @@ -1291,7 +1290,6 @@ pub const Writer = struct { pub fn initInterface(buffer: []u8) std.io.Writer { return .{ - .context = undefined, .vtable = &.{ .drain = drain, .sendFile = sendFile, diff --git a/lib/std/http.zig b/lib/std/http.zig index fec584a53a..949f8e9619 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -328,6 +328,7 @@ pub const Header = struct { pub const Reader = struct { in: *std.io.Reader, + interface: std.io.Reader, /// Keeps track of whether the stream is ready to accept a new request, /// making invalid API usage cause assertion failures rather than HTTP /// protocol violations. @@ -438,37 +439,41 @@ pub const Reader = struct { buffer: []u8, transfer_encoding: TransferEncoding, content_length: ?u64, - ) std.io.Reader { + ) *std.io.Reader { assert(reader.state == .received_head); - return switch (transfer_encoding) { + switch (transfer_encoding) { .chunked => { reader.state = .{ .body_remaining_chunk_len = .head }; - return .{ + reader.interface = .{ .buffer = buffer, - .context = reader, + .seek = 0, + .end = 0, .vtable = &.{ - .read = chunkedRead, + .stream = chunkedStream, .discard = chunkedDiscard, }, }; + return &reader.interface; }, .none => { if (content_length) |len| { reader.state = .{ .body_remaining_content_length = len }; - return .{ + reader.interface = .{ .buffer = buffer, - .context = reader, + .seek = 0, + .end = 0, .vtable = &.{ - .read = contentLengthRead, + .stream = contentLengthStream, .discard = contentLengthDiscard, }, }; + return &reader.interface; } else { reader.state = .body_none; - return reader.in.reader(); + return reader.in; } }, - }; + } } /// If compressed body has been negotiated this will return decompressed bytes. @@ -511,25 +516,25 @@ pub const Reader = struct { return decompressor.reader(transfer_reader, decompression_buffer, content_encoding); } - fn contentLengthRead( - ctx: ?*anyopaque, - bw: *Writer, + fn contentLengthStream( + io_r: *std.io.Reader, + w: *Writer, limit: std.io.Limit, ) std.io.Reader.StreamError!usize { - const reader: *Reader = @alignCast(@ptrCast(ctx)); + const reader: *Reader = @fieldParentPtr("interface", io_r); const remaining_content_length = &reader.state.body_remaining_content_length; const remaining = remaining_content_length.*; if (remaining == 0) { reader.state = .ready; return error.EndOfStream; } - const n = try reader.in.read(bw, limit.min(.limited(remaining))); + const n = try reader.in.stream(w, limit.min(.limited(remaining))); remaining_content_length.* = remaining - n; return n; } - fn contentLengthDiscard(ctx: ?*anyopaque, limit: std.io.Limit) std.io.Reader.Error!usize { - const reader: *Reader = @alignCast(@ptrCast(ctx)); + fn contentLengthDiscard(io_r: *std.io.Reader, limit: std.io.Limit) std.io.Reader.Error!usize { + const reader: *Reader = @fieldParentPtr("interface", io_r); const remaining_content_length = &reader.state.body_remaining_content_length; const remaining = remaining_content_length.*; if (remaining == 0) { @@ -541,18 +546,14 @@ pub const Reader = struct { return n; } - fn chunkedRead( - ctx: ?*anyopaque, - bw: *Writer, - limit: std.io.Limit, - ) std.io.Reader.StreamError!usize { - const reader: *Reader = @alignCast(@ptrCast(ctx)); + fn chunkedStream(io_r: *std.io.Reader, w: *Writer, limit: std.io.Limit) std.io.Reader.StreamError!usize { + const reader: *Reader = @fieldParentPtr("interface", io_r); const chunk_len_ptr = switch (reader.state) { .ready => return error.EndOfStream, .body_remaining_chunk_len => |*x| x, else => unreachable, }; - return chunkedReadEndless(reader, bw, limit, chunk_len_ptr) catch |err| switch (err) { + return chunkedReadEndless(reader, w, limit, chunk_len_ptr) catch |err| switch (err) { error.ReadFailed => return error.ReadFailed, error.WriteFailed => return error.WriteFailed, error.EndOfStream => { @@ -568,7 +569,7 @@ pub const Reader = struct { fn chunkedReadEndless( reader: *Reader, - bw: *Writer, + w: *Writer, limit: std.io.Limit, chunk_len_ptr: *RemainingChunkLen, ) (BodyError || std.io.Reader.StreamError)!usize { @@ -592,7 +593,7 @@ pub const Reader = struct { } } if (cp.chunk_len == 0) return parseTrailers(reader, 0); - const n = try in.read(bw, limit.min(.limited(cp.chunk_len))); + const n = try in.stream(w, limit.min(.limited(cp.chunk_len))); chunk_len_ptr.* = .init(cp.chunk_len + 2 - n); return n; }, @@ -608,15 +609,15 @@ pub const Reader = struct { continue :len .head; }, else => |remaining_chunk_len| { - const n = try in.read(bw, limit.min(.limited(@intFromEnum(remaining_chunk_len) - 2))); + const n = try in.stream(w, limit.min(.limited(@intFromEnum(remaining_chunk_len) - 2))); chunk_len_ptr.* = .init(@intFromEnum(remaining_chunk_len) - n); return n; }, } } - fn chunkedDiscard(ctx: ?*anyopaque, limit: std.io.Limit) std.io.Reader.Error!usize { - const reader: *Reader = @alignCast(@ptrCast(ctx)); + fn chunkedDiscard(io_r: *std.io.Reader, limit: std.io.Limit) std.io.Reader.Error!usize { + const reader: *Reader = @fieldParentPtr("interface", io_r); const chunk_len_ptr = switch (reader.state) { .ready => return error.EndOfStream, .body_remaining_chunk_len => |*x| x, @@ -758,7 +759,6 @@ pub const BodyWriter = struct { /// state of this other than via methods of `BodyWriter`. http_protocol_output: *Writer, state: State, - elide: bool, interface: Writer, pub const Error = Writer.Error; @@ -796,6 +796,10 @@ pub const BodyWriter = struct { }; }; + pub fn isEliding(w: *const BodyWriter) bool { + return w.interface.vtable.drain == Writer.discardingDrain; + } + /// Sends all buffered data across `BodyWriter.http_protocol_output`. pub fn flush(w: *BodyWriter) Error!void { const out = w.http_protocol_output; @@ -825,7 +829,7 @@ pub const BodyWriter = struct { /// with empty trailers, then flushes the stream to the system. Asserts any /// started chunk has been completely finished. /// - /// Respects the value of `elide` to omit all data after the headers. + /// Respects the value of `isEliding` to omit all data after the headers. /// /// See also: /// * `endUnflushed` @@ -841,7 +845,7 @@ pub const BodyWriter = struct { /// Otherwise, transfer-encoding: chunked is being used, and it writes the /// end-of-stream message with empty trailers. /// - /// Respects the value of `elide` to omit all data after the headers. + /// Respects the value of `isEliding` to omit all data after the headers. /// /// See also: /// * `end` @@ -867,7 +871,7 @@ pub const BodyWriter = struct { /// /// Asserts that the BodyWriter is using transfer-encoding: chunked. /// - /// Respects the value of `elide` to omit all data after the headers. + /// Respects the value of `isEliding` to omit all data after the headers. /// /// See also: /// * `endChunkedUnflushed` @@ -883,7 +887,7 @@ pub const BodyWriter = struct { /// /// Asserts that the BodyWriter is using transfer-encoding: chunked. /// - /// Respects the value of `elide` to omit all data after the headers. + /// Respects the value of `isEliding` to omit all data after the headers. /// /// See also: /// * `endChunked` @@ -891,7 +895,7 @@ pub const BodyWriter = struct { /// * `end` pub fn endChunkedUnflushed(w: *BodyWriter, options: EndChunkedOptions) Error!void { const chunked = &w.state.chunked; - if (w.elide) { + if (w.isEliding()) { w.state = .end; return; } @@ -922,7 +926,7 @@ pub const BodyWriter = struct { fn contentLengthDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { const bw: *BodyWriter = @fieldParentPtr("interface", w); - assert(!bw.elide); + assert(!bw.isEliding()); const out = w.http_protocol_output; const n = try w.drainTo(out, data, splat); w.state.content_length -= n; @@ -931,7 +935,7 @@ pub const BodyWriter = struct { fn noneDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { const bw: *BodyWriter = @fieldParentPtr("interface", w); - assert(!bw.elide); + assert(!bw.isEliding()); const out = w.http_protocol_output; return try w.drainTo(out, data, splat); } @@ -939,13 +943,13 @@ pub const BodyWriter = struct { /// Returns `null` if size cannot be computed without making any syscalls. fn noneSendFile(w: *Writer, file_reader: *File.Reader, limit: std.io.Limit) Writer.FileError!usize { const bw: *BodyWriter = @fieldParentPtr("interface", w); - assert(!bw.elide); + assert(!bw.isEliding()); return w.sendFileTo(bw.http_protocol_output, file_reader, limit); } fn contentLengthSendFile(w: *Writer, file_reader: *File.Reader, limit: std.io.Limit) Writer.FileError!usize { const bw: *BodyWriter = @fieldParentPtr("interface", w); - assert(!bw.elide); + assert(!bw.isEliding()); const n = try w.sendFileTo(bw.http_protocol_output, file_reader, limit); bw.state.content_length -= n; return n; @@ -953,7 +957,7 @@ pub const BodyWriter = struct { fn chunkedSendFile(w: *Writer, file_reader: *File.Reader, limit: std.io.Limit) Writer.FileError!usize { const bw: *BodyWriter = @fieldParentPtr("interface", w); - assert(!bw.elide); + assert(!bw.isEliding()); const data_len = w.countSendFileUpperBound(file_reader, limit) orelse { // If the file size is unknown, we cannot lower to a `writeFile` since we would // have to flush the chunk header before knowing the chunk length. @@ -1001,7 +1005,7 @@ pub const BodyWriter = struct { fn chunkedDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { const bw: *BodyWriter = @fieldParentPtr("interface", w); - assert(!bw.elide); + assert(!bw.isEliding()); const out = w.http_protocol_output; const data_len = Writer.countSplat(w.end, data, splat); const chunked = &bw.state.chunked; @@ -1059,27 +1063,6 @@ pub const BodyWriter = struct { a /= base; } } - - pub fn writer(w: *BodyWriter) Writer { - return if (w.elide) .discarding else .{ - .context = w, - .vtable = switch (w.state) { - .none => &.{ - .drain = noneDrain, - .sendFile = noneSendFile, - }, - .content_length => &.{ - .drain = contentLengthDrain, - .sendFile = contentLengthSendFile, - }, - .chunked => &.{ - .drain = chunkedDrain, - .sendFile = chunkedSendFile, - }, - .end => unreachable, - }, - }; - } }; test { diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index a2bfae3d91..32b5853c68 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -14,6 +14,7 @@ const Uri = std.Uri; const Allocator = mem.Allocator; const assert = std.debug.assert; const Writer = std.io.Writer; +const Reader = std.io.Reader; const Client = @This(); @@ -228,12 +229,6 @@ pub const Connection = struct { client: *Client, stream_writer: net.Stream.Writer, stream_reader: net.Stream.Reader, - /// HTTP protocol from client to server. - /// This either goes directly to `stream_writer`, or to a TLS client. - writer: Writer, - /// HTTP protocol from server to client. - /// This either comes directly from `stream_reader`, or from a TLS client. - reader: std.io.Reader, /// Entry in `ConnectionPool.used` or `ConnectionPool.free`. pool_node: std.DoublyLinkedList.Node, port: u16, @@ -264,10 +259,8 @@ pub const Connection = struct { plain.* = .{ .connection = .{ .client = client, - .stream_writer = stream.writer(), - .stream_reader = stream.reader(), - .writer = plain.connection.stream_writer.interface().buffered(socket_write_buffer), - .reader = plain.connection.stream_reader.interface().buffered(socket_read_buffer), + .stream_writer = stream.writer(socket_write_buffer), + .stream_reader = stream.reader(socket_read_buffer), .pool_node = .{}, .port = port, .host_len = @intCast(remote_host.len), @@ -297,10 +290,6 @@ pub const Connection = struct { }; const Tls = struct { - /// Data from `client` to `Connection.stream`. - writer: Writer, - /// Data from `Connection.stream` to `client`. - reader: std.io.Reader, client: std.crypto.tls.Client, connection: Connection, @@ -324,10 +313,8 @@ pub const Connection = struct { tls.* = .{ .connection = .{ .client = client, - .stream_writer = stream.writer(), - .stream_reader = stream.reader(), - .writer = tls.client.writer().buffered(socket_write_buffer), - .reader = tls.client.reader().unbuffered(), + .stream_writer = stream.writer(socket_write_buffer), + .stream_reader = stream.reader(&.{}), .pool_node = .{}, .port = port, .host_len = @intCast(remote_host.len), @@ -335,20 +322,22 @@ pub const Connection = struct { .closing = false, .protocol = .tls, }, - .writer = tls.connection.stream_writer.interface().buffered(tls_write_buffer), - .reader = tls.connection.stream_reader.interface().buffered(tls_read_buffer), - .client = undefined, + // TODO data race here on ca_bundle if the user sets next_https_rescan_certs to true + .client = std.crypto.tls.Client.init( + tls.connection.stream_reader.interface(), + &tls.connection.stream_writer.interface, + .{ + .host = .{ .explicit = remote_host }, + .ca = .{ .bundle = client.ca_bundle }, + .ssl_key_log = client.ssl_key_log, + .read_buffer = tls_read_buffer, + .write_buffer = tls_write_buffer, + // This is appropriate for HTTPS because the HTTP headers contain + // the content length which is used to detect truncation attacks. + .allow_truncation_attacks = true, + }, + ) catch return error.TlsInitializationFailed, }; - // TODO data race here on ca_bundle if the user sets next_https_rescan_certs to true - tls.client.init(&tls.reader, &tls.writer, .{ - .host = .{ .explicit = remote_host }, - .ca = .{ .bundle = client.ca_bundle }, - .ssl_key_log = client.ssl_key_log, - }) catch return error.TlsInitializationFailed; - // This is appropriate for HTTPS because the HTTP headers contain - // the content length which is used to detect truncation attacks. - tls.client.allow_truncation_attacks = true; - return tls; } @@ -404,26 +393,52 @@ pub const Connection = struct { } } + /// HTTP protocol from client to server. + /// This either goes directly to `stream_writer`, or to a TLS client. + pub fn writer(c: *Connection) *Writer { + return switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *Tls = @fieldParentPtr("connection", c); + return &tls.client.writer; + }, + .plain => &c.stream_writer.interface, + }; + } + + /// HTTP protocol from server to client. + /// This either comes directly from `stream_reader`, or from a TLS client. + pub fn reader(c: *const Connection) *Reader { + return switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *Tls = @fieldParentPtr("connection", c); + return &tls.client.reader; + }, + .plain => c.stream_reader.interface(), + }; + } + pub fn flush(c: *Connection) Writer.Error!void { - try c.writer.flush(); if (c.protocol == .tls) { if (disable_tls) unreachable; const tls: *Tls = @fieldParentPtr("connection", c); - try tls.writer.flush(); + try tls.client.writer.flush(); } + try c.stream_writer.interface.flush(); } /// If the connection is a TLS connection, sends the close_notify alert. /// /// Flushes all buffers. pub fn end(c: *Connection) Writer.Error!void { - try c.writer.flush(); if (c.protocol == .tls) { if (disable_tls) unreachable; const tls: *Tls = @fieldParentPtr("connection", c); try tls.client.end(); - try tls.writer.flush(); + try tls.client.writer.flush(); } + try c.stream_writer.interface.flush(); } }; @@ -660,14 +675,14 @@ pub const Response = struct { /// If compressed body has been negotiated this will return compressed bytes. /// - /// If the returned `std.io.Reader` returns `error.ReadFailed` the error is + /// If the returned `Reader` returns `error.ReadFailed` the error is /// available via `bodyErr`. /// /// Asserts that this function is only called once. /// /// See also: /// * `readerDecompressing` - pub fn reader(response: *Response, buffer: []u8) std.io.Reader { + pub fn reader(response: *Response, buffer: []u8) Reader { const req = response.request; if (!req.method.responseHasBody()) return .ending; const head = &response.head; @@ -676,7 +691,7 @@ pub const Response = struct { /// If compressed body has been negotiated this will return decompressed bytes. /// - /// If the returned `std.io.Reader` returns `error.ReadFailed` the error is + /// If the returned `Reader` returns `error.ReadFailed` the error is /// available via `bodyErr`. /// /// Asserts that this function is only called once. @@ -687,7 +702,7 @@ pub const Response = struct { response: *Response, decompressor: *http.Decompressor, decompression_buffer: []u8, - ) std.io.Reader { + ) Reader { const head = &response.head; return response.request.reader.bodyReaderDecompressing( head.transfer_encoding, @@ -698,7 +713,7 @@ pub const Response = struct { ); } - /// After receiving `error.ReadFailed` from the `std.io.Reader` returned by + /// After receiving `error.ReadFailed` from the `Reader` returned by /// `reader` or `readerDecompressing`, this function accesses the /// more specific error code. pub fn bodyErr(response: *const Response) ?http.Reader.BodyError { @@ -835,8 +850,8 @@ pub const Request = struct { /// /// See also: /// * `sendBodyUnflushed` - pub fn sendBody(r: *Request) Writer.Error!http.BodyWriter { - const result = try sendBodyUnflushed(r); + pub fn sendBody(r: *Request, buffer: []u8) Writer.Error!http.BodyWriter { + const result = try sendBodyUnflushed(r, buffer); try r.connection.?.flush(); return result; } @@ -846,17 +861,44 @@ pub const Request = struct { /// /// See also: /// * `sendBody` - pub fn sendBodyUnflushed(r: *Request) Writer.Error!http.BodyWriter { + pub fn sendBodyUnflushed(r: *Request, buffer: []u8) Writer.Error!http.BodyWriter { assert(r.method.requestHasBody()); try sendHead(r); - return .{ - .http_protocol_output = &r.connection.?.writer, - .state = switch (r.transfer_encoding) { - .chunked => .{ .chunked = .init }, - .content_length => |len| .{ .content_length = len }, - .none => .none, + const http_protocol_output = &r.connection.?.writer; + return switch (r.transfer_encoding) { + .chunked => .{ + .http_protocol_output = http_protocol_output, + .state = .{ .chunked = .init }, + .interface = .{ + .buffer = buffer, + .interface = &.{ + .drain = http.BodyWriter.chunkedDrain, + .sendFile = http.BodyWriter.chunkedSendFile, + }, + }, + }, + .content_length => |len| .{ + .http_protocol_output = http_protocol_output, + .state = .{ .content_length = len }, + .interface = .{ + .buffer = buffer, + .interface = &.{ + .drain = http.BodyWriter.contentLengthDrain, + .sendFile = http.BodyWriter.contentLengthSendFile, + }, + }, + }, + .none => .{ + .http_protocol_output = http_protocol_output, + .state = .none, + .interface = .{ + .buffer = buffer, + .interface = &.{ + .drain = http.BodyWriter.noneDrain, + .sendFile = http.BodyWriter.noneSendFile, + }, + }, }, - .elide = false, }; } diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index f4b29d2d8b..d3077842c8 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -381,6 +381,8 @@ pub const Request = struct { content_length: ?u64 = null, /// Options that are shared with the `respond` method. respond_options: RespondOptions = .{}, + /// Used by `http.BodyWriter`. + buffer: []u8, }; /// The header is not guaranteed to be sent until `BodyWriter.flush` or @@ -436,16 +438,38 @@ pub const Request = struct { try out.writeAll("\r\n"); const elide_body = request.head.method == .HEAD; + const state: http.BodyWriter.State = if (o.transfer_encoding) |te| switch (te) { + .chunked => .{ .chunked = .init }, + .none => .none, + } else if (options.content_length) |len| .{ + .content_length = len, + } else .{ .chunked = .init }; - return .{ + return if (elide_body) .{ .http_protocol_output = request.server.out, - .state = if (o.transfer_encoding) |te| switch (te) { - .chunked => .{ .chunked = .init }, - .none => .none, - } else if (options.content_length) |len| .{ - .content_length = len, - } else .{ .chunked = .init }, - .elide = elide_body, + .state = state, + .interface = .discarding(options.buffer), + } else .{ + .http_protocol_output = request.server.out, + .state = state, + .interface = .{ + .buffer = options.buffer, + .vtable = switch (state) { + .none => &.{ + .drain = http.BodyWriter.noneDrain, + .sendFile = http.BodyWriter.noneSendFile, + }, + .content_length => &.{ + .drain = http.BodyWriter.contentLengthDrain, + .sendFile = http.BodyWriter.contentLengthSendFile, + }, + .chunked => &.{ + .drain = http.BodyWriter.chunkedDrain, + .sendFile = http.BodyWriter.chunkedSendFile, + }, + .end => unreachable, + }, + }, }; } diff --git a/lib/std/io/Reader.zig b/lib/std/io/Reader.zig index 9a518cdc33..fc0e57e5bc 100644 --- a/lib/std/io/Reader.zig +++ b/lib/std/io/Reader.zig @@ -13,7 +13,6 @@ const Limit = std.io.Limit; pub const Limited = @import("Reader/Limited.zig"); -context: ?*anyopaque = null, vtable: *const VTable, buffer: []u8, /// Number of bytes which have been consumed from `buffer`. @@ -88,7 +87,6 @@ pub const ShortError = error{ }; pub const failing: Reader = .{ - .context = undefined, .vtable = &.{ .read = failingStream, .discard = failingDiscard, @@ -107,7 +105,6 @@ pub fn limited(r: *Reader, limit: Limit, buffer: []u8) Limited { /// Constructs a `Reader` such that it will read from `buffer` and then end. pub fn fixed(buffer: []const u8) Reader { return .{ - .context = undefined, .vtable = &.{ .stream = endingStream, .discard = endingDiscard, @@ -1402,12 +1399,6 @@ test "readAlloc when the backing reader provides one byte at a time" { self.curr += 1; return 1; } - - fn reader(self: *@This()) std.io.Reader { - return .{ - .context = self, - }; - } }; const str = "This is a test"; diff --git a/lib/std/io/Writer.zig b/lib/std/io/Writer.zig index fd3670707c..58c497c4da 100644 --- a/lib/std/io/Writer.zig +++ b/lib/std/io/Writer.zig @@ -9,12 +9,6 @@ const File = std.fs.File; const testing = std.testing; const Allocator = std.mem.Allocator; -/// There are two strategies for obtaining context; one can use this field, or -/// embed the `Writer` and use `@fieldParentPtr`. This field must be either set -/// to a valid pointer or left as `null` because the interface will sometimes -/// check if this pointer value is a known special value, for example to make -/// `writableVector` work. -context: ?*anyopaque = null, vtable: *const VTable, /// If this has length zero, the writer is unbuffered, and `flush` is a no-op. buffer: []u8,