From aac26f3b31ddb43e863ac7186fd19a7e251c1b8a Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 6 Aug 2025 22:39:26 -0700 Subject: [PATCH] TLS, HTTP, and package fetching fixes * TLS: add missing assert for output buffer length requirement * TLS: add missing flushes * TLS: add flush implementation * TLS: finish drain implementation * HTTP: correct buffer sizes for TLS * HTTP: expose a getReadError method on Connection * HTTP: add missing flush on sendBodyComplete * Fetch: remove unwanted deinit * Fetch: improve error reporting --- lib/std/crypto/tls/Client.zig | 62 ++++++++++++++++++++++++++--------- lib/std/http/Client.zig | 38 ++++++++++++++++----- src/Package/Fetch.zig | 12 +++++-- 3 files changed, 84 insertions(+), 28 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 3fe51e7b3b..5e89c071c6 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -8,8 +8,8 @@ const mem = std.mem; const crypto = std.crypto; const assert = std.debug.assert; const Certificate = std.crypto.Certificate; -const Reader = std.io.Reader; -const Writer = std.io.Writer; +const Reader = std.Io.Reader; +const Writer = std.Io.Writer; const max_ciphertext_len = tls.max_ciphertext_len; const hmacExpandLabel = tls.hmacExpandLabel; @@ -27,6 +27,8 @@ reader: Reader, /// The encrypted stream from the client to the server. Bytes are pushed here /// via `writer`. +/// +/// The buffer is asserted to have capacity at least `min_buffer_len`. output: *Writer, /// The plaintext stream from the client to the server. writer: Writer, @@ -122,7 +124,6 @@ pub const Options = struct { /// the amount of data expected, such as HTTP with the Content-Length header. allow_truncation_attacks: bool = false, write_buffer: []u8, - /// Asserted to have capacity at least `min_buffer_len`. read_buffer: []u8, /// Populated when `error.TlsAlert` is returned from `init`. alert: ?*tls.Alert = null, @@ -185,6 +186,7 @@ const InitError = error{ /// `input` is asserted to have buffer capacity at least `min_buffer_len`. pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client { assert(input.buffer.len >= min_buffer_len); + assert(output.buffer.len >= min_buffer_len); const host = switch (options.host) { .no_verification => "", .explicit => |host| host, @@ -278,6 +280,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client { var iovecs: [2][]const u8 = .{ cleartext_header, host }; try output.writeVecAll(iovecs[0..if (host.len == 0) 1 else 2]); + try output.flush(); } var tls_version: tls.ProtocolVersion = undefined; @@ -763,6 +766,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client &client_verify_msg, }; try output.writeVecAll(&all_msgs_vec); + try output.flush(); }, } write_seq += 1; @@ -828,6 +832,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client &finished_msg, }; try output.writeVecAll(&all_msgs_vec); + try output.flush(); const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); @@ -877,7 +882,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client .buffer = options.write_buffer, .vtable = &.{ .drain = drain, - .sendFile = Writer.unimplementedSendFile, + .flush = flush, }, }, .tls_version = tls_version, @@ -911,31 +916,56 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client fn drain(w: *Writer, data: []const []const u8, splat: usize) Writer.Error!usize { const c: *Client = @alignCast(@fieldParentPtr("writer", w)); - if (true) @panic("update to use the buffer and flush"); - const sliced_data = if (splat == 0) data[0..data.len -| 1] else data; const output = c.output; const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len); - var total_clear: usize = 0; var ciphertext_end: usize = 0; - for (sliced_data) |buf| { - const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data); - total_clear += prepared.cleartext_len; - ciphertext_end += prepared.ciphertext_end; - if (total_clear < buf.len) break; + var total_clear: usize = 0; + done: { + { + const buf = w.buffered(); + const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data); + total_clear += prepared.cleartext_len; + ciphertext_end += prepared.ciphertext_end; + if (prepared.cleartext_len < buf.len) break :done; + } + for (data[0 .. data.len - 1]) |buf| { + if (buf.len < min_buffer_len) break :done; + const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data); + total_clear += prepared.cleartext_len; + ciphertext_end += prepared.ciphertext_end; + if (prepared.cleartext_len < buf.len) break :done; + } + const buf = data[data.len - 1]; + for (0..splat) |_| { + if (buf.len < min_buffer_len) break :done; + const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data); + total_clear += prepared.cleartext_len; + ciphertext_end += prepared.ciphertext_end; + if (prepared.cleartext_len < buf.len) break :done; + } } output.advance(ciphertext_end); - return total_clear; + return w.consume(total_clear); +} + +fn flush(w: *Writer) Writer.Error!void { + const c: *Client = @alignCast(@fieldParentPtr("writer", w)); + const output = c.output; + const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len); + const prepared = prepareCiphertextRecord(c, ciphertext_buf, w.buffered(), .application_data); + output.advance(prepared.ciphertext_end); + w.end = 0; } /// Sends a `close_notify` alert, which is necessary for the server to /// distinguish between a properly finished TLS session, or a truncation /// attack. pub fn end(c: *Client) Writer.Error!void { + try flush(&c.writer); const output = c.output; const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len); const prepared = prepareCiphertextRecord(c, ciphertext_buf, &tls.close_notify_alert, .alert); - output.advance(prepared.cleartext_len); - return prepared.ciphertext_end; + output.advance(prepared.ciphertext_end); } fn prepareCiphertextRecord( @@ -1045,7 +1075,7 @@ pub fn eof(c: Client) bool { return c.received_close_notify; } -fn stream(r: *Reader, w: *Writer, limit: std.io.Limit) Reader.StreamError!usize { +fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize { const c: *Client = @alignCast(@fieldParentPtr("reader", r)); if (c.eof()) return error.EndOfStream; const input = c.input; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 95397b6f07..37022b4d0b 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -42,7 +42,7 @@ connection_pool: ConnectionPool = .{}, /// /// If the entire HTTP header cannot fit in this amount of bytes, /// `error.HttpHeadersOversize` will be returned from `Request.wait`. -read_buffer_size: usize = 4096, +read_buffer_size: usize = 4096 + if (disable_tls) 0 else std.crypto.tls.Client.min_buffer_len, /// Each `Connection` allocates this amount for the writer buffer. write_buffer_size: usize = 1024, @@ -304,15 +304,16 @@ pub const Connection = struct { const host_buffer = base[@sizeOf(Tls)..][0..remote_host.len]; const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.tls_buffer_size]; const tls_write_buffer = tls_read_buffer.ptr[tls_read_buffer.len..][0..client.tls_buffer_size]; - const socket_write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size]; - assert(base.ptr + alloc_len == socket_write_buffer.ptr + socket_write_buffer.len); + const write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size]; + const read_buffer = write_buffer.ptr[write_buffer.len..][0..client.read_buffer_size]; + assert(base.ptr + alloc_len == read_buffer.ptr + read_buffer.len); @memcpy(host_buffer, remote_host); const tls: *Tls = @ptrCast(base); tls.* = .{ .connection = .{ .client = client, - .stream_writer = stream.writer(socket_write_buffer), - .stream_reader = stream.reader(&.{}), + .stream_writer = stream.writer(tls_write_buffer), + .stream_reader = stream.reader(tls_read_buffer), .pool_node = .{}, .port = port, .host_len = @intCast(remote_host.len), @@ -328,8 +329,8 @@ pub const Connection = struct { .host = .{ .explicit = remote_host }, .ca = .{ .bundle = client.ca_bundle }, .ssl_key_log = client.ssl_key_log, - .read_buffer = tls_read_buffer, - .write_buffer = tls_write_buffer, + .read_buffer = read_buffer, + .write_buffer = write_buffer, // This is appropriate for HTTPS because the HTTP headers contain // the content length which is used to detect truncation attacks. .allow_truncation_attacks = true, @@ -347,7 +348,8 @@ pub const Connection = struct { } fn allocLen(client: *Client, host_len: usize) usize { - return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size + client.write_buffer_size; + return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size + + client.write_buffer_size + client.read_buffer_size; } fn host(tls: *Tls) []u8 { @@ -356,6 +358,21 @@ pub const Connection = struct { } }; + pub const ReadError = std.crypto.tls.Client.ReadError || std.net.Stream.ReadError; + + pub fn getReadError(c: *const Connection) ?ReadError { + return switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *const Tls = @alignCast(@fieldParentPtr("connection", c)); + return tls.client.read_err orelse c.stream_reader.getError(); + }, + .plain => { + return c.stream_reader.getError(); + }, + }; + } + fn getStream(c: *Connection) net.Stream { return c.stream_reader.getStream(); } @@ -434,7 +451,6 @@ pub const Connection = struct { if (disable_tls) unreachable; const tls: *Tls = @alignCast(@fieldParentPtr("connection", c)); try tls.client.end(); - try tls.client.writer.flush(); } try c.stream_writer.interface.flush(); } @@ -874,6 +890,7 @@ pub const Request = struct { var bw = try sendBodyUnflushed(r, body); bw.writer.end = body.len; try bw.end(); + try r.connection.?.flush(); } /// Transfers the HTTP head over the connection, which is not flushed until @@ -1063,6 +1080,9 @@ pub const Request = struct { /// buffer capacity would be exceeded, `error.HttpRedirectLocationOversize` /// is returned instead. This buffer may be empty if no redirects are to be /// handled. + /// + /// If this fails with `error.ReadFailed` then the `Connection.getReadError` + /// method of `r.connection` can be used to get more detailed information. pub fn receiveHead(r: *Request, redirect_buffer: []u8) ReceiveHeadError!Response { var aux_buf = redirect_buffer; while (true) { diff --git a/src/Package/Fetch.zig b/src/Package/Fetch.zig index 6fd5a5989c..fd8c26f1e6 100644 --- a/src/Package/Fetch.zig +++ b/src/Package/Fetch.zig @@ -998,15 +998,21 @@ fn initResource(f: *Fetch, uri: std.Uri, resource: *Resource, reader_buffer: []u .buffer = reader_buffer, } }; const request = &resource.http_request.request; - defer request.deinit(); + errdefer request.deinit(); request.sendBodiless() catch |err| return f.fail(f.location_tok, try eb.printString("HTTP request failed: {t}", .{err})); var redirect_buffer: [1024]u8 = undefined; const response = &resource.http_request.response; - response.* = request.receiveHead(&redirect_buffer) catch |err| - return f.fail(f.location_tok, try eb.printString("invalid HTTP response: {t}", .{err})); + response.* = request.receiveHead(&redirect_buffer) catch |err| switch (err) { + error.ReadFailed => { + return f.fail(f.location_tok, try eb.printString("HTTP response read failure: {t}", .{ + request.connection.?.getReadError().?, + })); + }, + else => |e| return f.fail(f.location_tok, try eb.printString("invalid HTTP response: {t}", .{e})), + }; if (response.head.status != .ok) return f.fail(f.location_tok, try eb.printString( "bad HTTP response code: '{d} {s}'",