diff --git a/lib/std/crypto/ecdsa.zig b/lib/std/crypto/ecdsa.zig index c7015b7fd3..76a1761ec5 100644 --- a/lib/std/crypto/ecdsa.zig +++ b/lib/std/crypto/ecdsa.zig @@ -155,38 +155,35 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type { } // Read a DER-encoded integer. - fn readDerInt(out: []u8, reader: anytype) EncodingError!void { - var buf: [2]u8 = undefined; - _ = reader.readNoEof(&buf) catch return error.InvalidEncoding; + // Asserts `br` has storage capacity >= 2. + fn readDerInt(out: []u8, br: *std.io.BufferedReader) EncodingError!void { + const buf = br.take(2) catch return error.InvalidEncoding; if (buf[0] != 0x02) return error.InvalidEncoding; - var expected_len = @as(usize, buf[1]); + var expected_len: usize = buf[1]; if (expected_len == 0 or expected_len > 1 + out.len) return error.InvalidEncoding; var has_top_bit = false; if (expected_len == 1 + out.len) { - if ((reader.readByte() catch return error.InvalidEncoding) != 0) return error.InvalidEncoding; + if ((br.takeByte() catch return error.InvalidEncoding) != 0) return error.InvalidEncoding; expected_len -= 1; has_top_bit = true; } const out_slice = out[out.len - expected_len ..]; - reader.readNoEof(out_slice) catch return error.InvalidEncoding; + br.read(out_slice) catch return error.InvalidEncoding; if (@intFromBool(has_top_bit) != out[0] >> 7) return error.InvalidEncoding; } /// Create a signature from a DER representation. /// Returns InvalidEncoding if the DER encoding is invalid. pub fn fromDer(der: []const u8) EncodingError!Signature { + if (der.len < 2) return error.InvalidEncoding; + var br: std.io.BufferedReader = undefined; + br.initFixed(der); + const buf = br.take(2) catch return error.InvalidEncoding; + if (buf[0] != 0x30 or @as(usize, buf[1]) + 2 != der.len) return error.InvalidEncoding; var sig: Signature = mem.zeroInit(Signature, .{}); - var fb: std.io.FixedBufferStream = .{ .buffer = der }; - const reader = fb.reader(); - var buf: [2]u8 = undefined; - _ = reader.readNoEof(&buf) catch return error.InvalidEncoding; - if (buf[0] != 0x30 or @as(usize, buf[1]) + 2 != der.len) { - return error.InvalidEncoding; - } - try readDerInt(&sig.r, reader); - try readDerInt(&sig.s, reader); - if (fb.getPos() catch unreachable != der.len) return error.InvalidEncoding; - + try readDerInt(&sig.r, &br); + try readDerInt(&sig.s, &br); + if (br.seek != der.len) return error.InvalidEncoding; return sig; } }; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index fd0ab12930..9cfce837b7 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -1,3 +1,6 @@ +const builtin = @import("builtin"); +const native_endian = builtin.cpu.arch.endian(); + const std = @import("../../std.zig"); const tls = std.crypto.tls; const Client = @This(); @@ -13,18 +16,44 @@ const hkdfExpandLabel = tls.hkdfExpandLabel; const int = tls.int; const array = tls.array; +/// The encrypted stream from the server to the client. Bytes are pulled from +/// here via `reader`. +/// +/// The buffer is asserted to have capacity at least `min_buffer_len`. +/// +/// The size is enough to contain exactly one TLSCiphertext record. +/// This buffer is segmented into four parts: +/// 0. unused +/// 1. cleartext +/// 2. ciphertext +/// 3. unused +/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and +/// `partial_ciphertext_end` describe the span of the segments. +input: *std.io.BufferedReader, +/// The encrypted stream from the client to the server. Bytes are pushed here +/// via `writer`. +/// +/// The buffer is asserted to have capacity at least `min_buffer_len`. +output: *std.io.BufferedWriter, +/// Cleartext received from the server here. +/// +/// Its buffer aliases the buffer of `input`. +reader: std.io.BufferedReader, +/// Populated under various error conditions. +diagnostics: Diagnostics, + tls_version: tls.ProtocolVersion, read_seq: u64, write_seq: u64, -/// The starting index of cleartext bytes inside `partially_read_buffer`. +/// The starting index of cleartext bytes inside the input buffer. partial_cleartext_idx: u15, -/// The ending index of cleartext bytes inside `partially_read_buffer` as well +/// The ending index of cleartext bytes inside the input buffer as well /// as the starting index of ciphertext bytes. partial_ciphertext_idx: u15, -/// The ending index of ciphertext bytes inside `partially_read_buffer`. +/// The ending index of ciphertext bytes inside the input buffer. partial_ciphertext_end: u15, /// When this is true, the stream may still not be at the end because there -/// may be data in `partially_read_buffer`. +/// may be data in the input buffer. received_close_notify: bool, /// By default, reaching the end-of-stream when reading from the server will /// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify @@ -35,24 +64,40 @@ received_close_notify: bool, /// the amount of data expected, such as HTTP with the Content-Length header. allow_truncation_attacks: bool, application_cipher: tls.ApplicationCipher, -/// The size is enough to contain exactly one TLSCiphertext record. -/// This buffer is segmented into four parts: -/// 0. unused -/// 1. cleartext -/// 2. ciphertext -/// 3. unused -/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and -/// `partial_ciphertext_end` describe the span of the segments. -partially_read_buffer: [tls.max_ciphertext_record_len]u8, -/// Encrypted bytes sent to the server here. -output: *std.io.BufferedWriter, -/// If non-null, ssl secrets are logged to a file. Creating such a log file allows other -/// programs with access to that file to decrypt all traffic over this connection. -ssl_key_log: ?struct { +/// If non-null, ssl secrets are logged to a stream. Creating such a log file +/// allows other programs with access to that file to decrypt all traffic over +/// this connection. +ssl_key_log: ?*SslKeyLog, + +pub const Diagnostics = union { + /// Populated on `error.WriteFailure` and `error.ReadFailure`. + err: anyerror, + /// Populated on `error.TlsAlert`. + /// + /// If this isn't a error alert, then it's a closure alert, which makes + /// no sense in a handshake. + alert: tls.AlertDescription, + + fn wrapWrite(d: *Diagnostics, returned: anyerror!void) error{WriteFailure}!void { + returned catch |err| { + d.* = .{ .err = err }; + return error.WriteFailure; + }; + } + + fn wrapRead(d: *Diagnostics, returned: anyerror!void) error{ReadFailure}!void { + returned catch |err| { + d.* = .{ .err = err }; + return error.ReadFailure; + }; + } +}; + +pub const SslKeyLog = struct { client_key_seq: u64, server_key_seq: u64, client_random: [32]u8, - file: std.fs.File, + writer: *std.io.BufferedWriter, fn clientCounter(key_log: *@This()) u64 { defer key_log.client_key_seq += 1; @@ -63,31 +108,12 @@ ssl_key_log: ?struct { defer key_log.server_key_seq += 1; return key_log.server_key_seq; } -}, - -/// This is an example of the type that is needed by the read and write -/// functions. It can have any fields but it must at least have these -/// functions. -/// -/// Note that `std.net.Stream` conforms to this interface. -/// -/// This declaration serves as documentation only. -pub const StreamInterface = struct { - /// Can be any error set. - pub const ReadError = error{}; - - /// Returns the number of bytes read. The number read may be less than the - /// buffer space provided. End-of-stream is indicated by a return value of 0. - /// - /// The `iovecs` parameter is mutable because so that function may to - /// mutate the fields in order to handle partial reads from the underlying - /// stream layer. - pub fn readv(this: @This(), iovecs: []std.posix.iovec) ReadError!usize { - _ = .{ this, iovecs }; - @panic("unimplemented"); - } }; +/// The `std.io.BufferedReader` and `std.io.BufferedWriter` supplied to `init` +/// each require a buffer capacity at least this amount. +pub const min_buffer_len = tls.max_ciphertext_record_len; + pub const Options = struct { /// How to perform host verification of server certificates. host: union(enum) { @@ -109,39 +135,11 @@ pub const Options = struct { /// Verify that the server certificate is authorized by a given ca bundle. bundle: Certificate.Bundle, }, - /// If non-null, ssl secrets are logged to this file. Creating such a log file allows + /// If non-null, ssl secrets are logged to this stream. Creating such a log file allows /// other programs with access to that file to decrypt all traffic over this connection. - /// TODO `std.crypto` should have no dependencies on `std.fs`. - ssl_key_log_file: ?std.fs.File = null, - diagnostics: ?*Diagnostics = null, - - pub const Diagnostics = union { - /// Populated on `error.WriteFailure` and `error.ReadFailure`. - err: anyerror, - /// Populated on `error.TlsAlert`. - /// - /// If this isn't a error alert, then it's a closure alert, which makes - /// no sense in a handshake. - alert: tls.AlertDescription, - }; + ssl_key_log: ?*std.io.BufferedWriter = null, }; -/// TODO I wish this could be a method of Diagnostics -fn wrapWrite(opt_diags: ?*Options.Diagnostics, returned: anyerror!void) error{WriteFailure}!void { - returned catch |err| { - if (opt_diags) |diags| diags.* = .{ .err = err }; - return error.WriteFailure; - }; -} - -/// TODO I wish this could be a method of Diagnostics -fn wrapRead(opt_diags: ?*Options.Diagnostics, returned: anyerror!void) error{ReadFailure}!void { - returned catch |err| { - if (opt_diags) |diags| diags.* = .{ .err = err }; - return error.ReadFailure; - }; -} - const InitError = error{ //OutOfMemory, WriteFailure, @@ -193,12 +191,21 @@ const InitError = error{ WeakPublicKey, }; -/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session with `input`, which -/// must conform to `StreamInterface`. +/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session. /// /// `host` is only borrowed during this function call. -pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) InitError!Client { - const diags = options.diagnostics; +/// +/// Both `input` and `output` are asserted to have buffer capacity at least +/// `min_buffer_len`. +pub fn init( + client: *Client, + input: *std.io.BufferedReader, + output: *std.io.BufferedWriter, + options: Options, +) InitError!void { + assert(input.storage.buffer.len >= min_buffer_len); + assert(output.buffer.len >= min_buffer_len); + const diags = &client.diagnostics; const host = switch (options.host) { .no_verification => "", .explicit => |host| host, @@ -291,7 +298,7 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In { var iovecs: [2][]const u8 = .{ cleartext_header, host }; - try wrapWrite(diags, output.writevAll(iovecs[0..if (host.len == 0) 1 else 2])); + try diags.wrapWrite(output.writevAll(iovecs[0..if (host.len == 0) 1 else 2])); } var tls_version: tls.ProtocolVersion = undefined; @@ -343,12 +350,12 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In var handshake_buffer: [tls.max_ciphertext_record_len]u8 = undefined; var d: tls.Decoder = .{ .buf = &handshake_buffer }; fragment: while (true) { - try wrapRead(diags, d.readAtLeastOurAmt(input, tls.record_header_len)); + try diags.wrapRead(d.readAtLeastOurAmt(input, tls.record_header_len)); const record_header = d.buf[d.idx..][0..tls.record_header_len]; const record_ct = d.decode(tls.ContentType); d.skip(2); // legacy_version const record_len = d.decode(u16); - try wrapRead(diags, d.readAtLeast(input, record_len)); + try diags.wrapRead(d.readAtLeast(input, record_len)); var record_decoder = try d.sub(record_len); var ctd, const ct = content: switch (cipher_state) { .cleartext => .{ record_decoder, record_ct }, @@ -426,7 +433,7 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In const level = ctd.decode(tls.AlertLevel); const desc = ctd.decode(tls.AlertDescription); _ = level; - if (diags) |x| x.* = .{ .alert = desc }; + diags.* = .{ .alert = desc }; return error.TlsAlert; }, .change_cipher_spec => { @@ -768,7 +775,7 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In &client_change_cipher_spec_msg, &client_verify_msg, }; - try wrapWrite(diags, output.writevAll(&all_msgs_vec)); + try diags.wrapWrite(output.writevAll(&all_msgs_vec)); }, } write_seq += 1; @@ -833,7 +840,7 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In &client_change_cipher_spec_msg, &finished_msg, }; - try wrapWrite(diags, output.writevAll(&all_msgs_vec)); + try diags.wrapWrite(output.writevAll(&all_msgs_vec)); const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); @@ -865,7 +872,10 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In }, }; const leftover = d.rest(); - var client: Client = .{ + client.* = .{ + .input = input, + .output = output, + .reader = undefined, .tls_version = tls_version, .read_seq = switch (tls_version) { .tls_1_3 => 0, @@ -883,7 +893,6 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In .received_close_notify = false, .allow_truncation_attacks = false, .application_cipher = app_cipher, - .output = output, .partially_read_buffer = undefined, .ssl_key_log = if (options.ssl_key_log_file) |key_log_file| .{ .client_key_seq = key_seq, @@ -893,7 +902,14 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In } else null, }; @memcpy(client.partially_read_buffer[0..leftover.len], leftover); - return client; + client.reader.init(.{ + .context = client, + .vtable = &.{ + .read = reader_read, + .readv = reader_readv, + }, + }, input.storage.buffer[0..0]); + return; }, else => return error.TlsUnexpectedMessage, } @@ -919,81 +935,45 @@ pub fn writer(c: *Client) std.io.Writer { fn writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize { const c: *Client = @alignCast(@ptrCast(context)); - assert(data.len > 1 or splat > 0); - return writeEnd(c, data[0], false); + const sliced_data = if (splat == 0) data[0..data.len -| 1] else data; + const output = &c.output; + const ciphertext_buf = try output.writableSlice(min_buffer_len); + var total_clear: usize = 0; + var ciphertext_end: usize = 0; + for (sliced_data) |buf| { + const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data); + total_clear += prepared.cleartext_len; + ciphertext_end += prepared.ciphertext_end; + if (total_clear < buf.len) break; + } + output.advance(ciphertext_end); + return total_clear; } -/// If `end` is true, then this function additionally sends a `close_notify` -/// alert, which is necessary for the server to distinguish between a properly -/// finished TLS session, or a truncation attack. -pub fn writeAllEnd(c: *Client, bytes: []const u8, end: bool) anyerror!void { - var index: usize = 0; - while (index < bytes.len) { - index += try c.writeEnd(bytes[index..], end); - } -} - -/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`. -/// If `end` is true, then this function additionally sends a `close_notify` alert, -/// which is necessary for the server to distinguish between a properly finished -/// TLS session, or a truncation attack. -pub fn writeEnd(c: *Client, bytes: []const u8, end: bool) anyerror!usize { - var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined; - var iovecs_buf: [6][]const u8 = undefined; - var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data); - if (end) { - prepared.iovec_end += prepareCiphertextRecord( - c, - iovecs_buf[prepared.iovec_end..], - ciphertext_buf[prepared.ciphertext_end..], - &tls.close_notify_alert, - .alert, - ).iovec_end; - } - - const iovec_end = prepared.iovec_end; - const overhead_len = prepared.overhead_len; - - // Ideally we would call writev exactly once here, however, we must ensure - // that we don't return with a record partially written. - var i: usize = 0; - var total_amt: usize = 0; - while (true) { - var amt = try c.output.writev(iovecs_buf[i..iovec_end]); - while (amt >= iovecs_buf[i].len) { - const encrypted_amt = iovecs_buf[i].len; - total_amt += encrypted_amt - overhead_len; - amt -= encrypted_amt; - i += 1; - // Rely on the property that iovecs delineate records, meaning that - // if amt equals zero here, we have fortunately found ourselves - // with a short read that aligns at the record boundary. - if (i >= iovec_end) return total_amt; - // We also cannot return on a vector boundary if the final close_notify is - // not sent; otherwise the caller would not know to retry the call. - if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt; - } - iovecs_buf[i] = iovecs_buf[i][amt..]; - } +/// Sends a `close_notify` alert, which is necessary for the server to +/// distinguish between a properly finished TLS session, or a truncation +/// attack. +pub fn end(c: *Client) anyerror!void { + const output = &c.output; + const ciphertext_buf = try output.writableSlice(min_buffer_len); + const prepared = prepareCiphertextRecord(c, ciphertext_buf, &tls.close_notify_alert, .alert); + output.advance(prepared.cleartext_len); + return prepared.ciphertext_end; } fn prepareCiphertextRecord( c: *Client, - iovecs: [][]const u8, ciphertext_buf: []u8, bytes: []const u8, inner_content_type: tls.ContentType, ) struct { - iovec_end: usize, ciphertext_end: usize, - /// How many bytes are taken up by overhead per record. - overhead_len: usize, + cleartext_len: usize, } { // Due to the trailing inner content type byte in the ciphertext, we need // an additional buffer for storing the cleartext into before encrypting. var cleartext_buf: [max_ciphertext_len]u8 = undefined; var ciphertext_end: usize = 0; - var iovec_end: usize = 0; var bytes_i: usize = 0; switch (c.application_cipher) { inline else => |*p| switch (c.tls_version) { @@ -1001,18 +981,15 @@ fn prepareCiphertextRecord( const pv = &p.tls_1_3; const P = @TypeOf(p.*); const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1; - const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; while (true) { const encrypted_content_len: u16 = @min( bytes.len - bytes_i, tls.max_ciphertext_inner_record_len, - ciphertext_buf.len -| - (close_notify_alert_reserved + overhead_len + ciphertext_end), + ciphertext_buf.len -| (overhead_len + ciphertext_end), ); if (encrypted_content_len == 0) return .{ - .iovec_end = iovec_end, .ciphertext_end = ciphertext_end, - .overhead_len = overhead_len, + .cleartext_len = bytes_i, }; @memcpy(cleartext_buf[0..encrypted_content_len], bytes[bytes_i..][0..encrypted_content_len]); @@ -1021,7 +998,6 @@ fn prepareCiphertextRecord( const ciphertext_len = encrypted_content_len + 1; const cleartext = cleartext_buf[0..ciphertext_len]; - const record_start = ciphertext_end; const ad = ciphertext_buf[ciphertext_end..][0..tls.record_header_len]; ad.* = .{@intFromEnum(tls.ContentType.application_data)} ++ int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ @@ -1039,35 +1015,27 @@ fn prepareCiphertextRecord( }; P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_key); c.write_seq += 1; // TODO send key_update on overflow - - const record = ciphertext_buf[record_start..ciphertext_end]; - iovecs[iovec_end] = record; - iovec_end += 1; } }, .tls_1_2 => { const pv = &p.tls_1_2; const P = @TypeOf(p.*); const overhead_len = tls.record_header_len + P.record_iv_length + P.mac_length; - const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; while (true) { const message_len: u16 = @min( bytes.len - bytes_i, tls.max_ciphertext_inner_record_len, - ciphertext_buf.len -| - (close_notify_alert_reserved + overhead_len + ciphertext_end), + ciphertext_buf.len -| (overhead_len + ciphertext_end), ); if (message_len == 0) return .{ - .iovec_end = iovec_end, .ciphertext_end = ciphertext_end, - .overhead_len = overhead_len, + .cleartext_len = bytes_i, }; @memcpy(cleartext_buf[0..message_len], bytes[bytes_i..][0..message_len]); bytes_i += message_len; const cleartext = cleartext_buf[0..message_len]; - const record_start = ciphertext_end; const record_header = ciphertext_buf[ciphertext_end..][0..tls.record_header_len]; ciphertext_end += tls.record_header_len; record_header.* = .{@intFromEnum(inner_content_type)} ++ @@ -1089,10 +1057,6 @@ fn prepareCiphertextRecord( ciphertext_end += P.mac_length; P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_write_key); c.write_seq += 1; // TODO send key_update on overflow - - const record = ciphertext_buf[record_start..ciphertext_end]; - iovecs[iovec_end] = record; - iovec_end += 1; } }, else => unreachable, @@ -1106,74 +1070,22 @@ pub fn eof(c: Client) bool { c.partial_ciphertext_idx >= c.partial_ciphertext_end; } -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read, calling the underlying read function the -/// minimal number of times until the buffer has at least `len` bytes filled. -/// If the number read is less than `len` it means the stream reached the end. -/// Reaching the end of the stream is not an error condition. -pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize { - var iovecs = [1]std.posix.iovec{.{ .base = buffer.ptr, .len = buffer.len }}; - return readvAtLeast(c, stream, &iovecs, len); +fn reader_read( + context: ?*anyopaque, + bw: *std.io.BufferedWriter, + limit: std.io.Reader.Limit, +) anyerror!std.io.Reader.Status { + const buf = limit.slice(try bw.writableSlice(1)); + const status = try reader_readv(context, &.{buf}); + bw.advance(status.len); + return status; } -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize { - return readAtLeast(c, stream, buffer, 1); -} +fn reader_readv(context: ?*anyopaque, data: []const []u8) anyerror!std.io.Reader.Status { + const c: *Client = @ptrCast(@alignCast(context)); + if (c.eof()) return .{ .end = true }; -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read. If the number read is smaller than -/// `buffer.len`, it means the stream reached the end. Reaching the end of the -/// stream is not an error condition. -pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize { - return readAtLeast(c, stream, buffer, buffer.len); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read. If the number read is less than the space -/// provided it means the stream reached the end. Reaching the end of the -/// stream is not an error condition. -/// The `iovecs` parameter is mutable because this function needs to mutate the fields in -/// order to handle partial reads from the underlying stream layer. -pub fn readv(c: *Client, stream: anytype, iovecs: []std.posix.iovec) !usize { - return readvAtLeast(c, stream, iovecs, 1); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read, calling the underlying read function the -/// minimal number of times until the iovecs have at least `len` bytes filled. -/// If the number read is less than `len` it means the stream reached the end. -/// Reaching the end of the stream is not an error condition. -/// The `iovecs` parameter is mutable because this function needs to mutate the fields in -/// order to handle partial reads from the underlying stream layer. -pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.posix.iovec, len: usize) !usize { - if (c.eof()) return 0; - - var off_i: usize = 0; - var vec_i: usize = 0; - while (true) { - var amt = try c.readvAdvanced(stream, iovecs[vec_i..]); - off_i += amt; - if (c.eof() or off_i >= len) return off_i; - while (amt >= iovecs[vec_i].len) { - amt -= iovecs[vec_i].len; - vec_i += 1; - } - iovecs[vec_i].base += amt; - iovecs[vec_i].len -= amt; - } -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns number of bytes that have been read, populated inside `iovecs`. A -/// return value of zero bytes does not mean end of stream. Instead, check the `eof()` -/// for the end of stream. The `eof()` may be true after any call to -/// `read`, including when greater than zero bytes are returned, and this -/// function asserts that `eof()` is `false`. -/// See `readv` for a higher level function that has the same, familiar API as -/// other read functions, such as `std.fs.File.read`. -pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iovec) !usize { - var vp: VecPut = .{ .iovecs = iovecs }; + var vp: VecPut = .{ .iovecs = data }; // Give away the buffered cleartext we have, if any. const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx]; @@ -1193,11 +1105,11 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove if (c.received_close_notify) { c.partial_ciphertext_end = 0; assert(vp.total == amt); - return amt; + return .{ .len = amt, .end = c.eof() }; } else if (amt > 0) { // We don't need more data, so don't call read. assert(vp.total == amt); - return amt; + return .{ .len = amt, .end = c.eof() }; } } @@ -1241,7 +1153,7 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len); const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len) - c.partial_ciphertext_end; const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len); - const actual_read_len = try stream.readv(ask_iovecs); + const actual_read_len = try c.input.readv(ask_iovecs); if (actual_read_len == 0) { // This is either a truncation attack, a bug in the server, or an // intentional omission of the close_notify message due to truncation @@ -1268,7 +1180,7 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove // Perfect split. if (frag.ptr == frag1.ptr) { c.partial_ciphertext_end = c.partial_ciphertext_idx; - return vp.total; + return .{ .len = vp.total, .end = c.eof() }; } frag = frag1; in = 0; @@ -1310,8 +1222,8 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove const record_len = mem.readInt(u16, frag[in..][0..2], .big); if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; in += 2; - const end = in + record_len; - if (end > frag.len) { + const the_end = in + record_len; + if (the_end > frag.len) { // We need the record header on the next iteration of the loop. in -= tls.record_header_len; @@ -1398,17 +1310,23 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove .alert => { if (cleartext.len != 2) return error.TlsDecodeError; const level: tls.AlertLevel = @enumFromInt(cleartext[0]); - const desc: tls.AlertDescription = @enumFromInt(cleartext[1]); - if (desc == .close_notify) { - c.received_close_notify = true; - c.partial_ciphertext_end = c.partial_ciphertext_idx; - return vp.total; - } _ = level; - - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; + const desc: tls.AlertDescription = @enumFromInt(cleartext[1]); + switch (desc) { + .close_notify => { + c.received_close_notify = true; + c.partial_ciphertext_end = c.partial_ciphertext_idx; + return .{ .len = vp.total, .end = c.eof() }; + }, + .user_canceled => { + // TODO: handle server-side closures + return error.TlsUnexpectedMessage; + }, + else => { + c.diagnostics = .{ .alert = desc }; + return error.TlsAlert; + }, + } }, .handshake => { var ct_i: usize = 0; @@ -1524,7 +1442,7 @@ fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) voi }) catch {}; } -fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize { +fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) std.io.Reader.Status { const saved_buf = frag[in..]; if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { // There is cleartext at the beginning already which we need to preserve. @@ -1536,11 +1454,11 @@ fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize { c.partial_ciphertext_end = @intCast(saved_buf.len); @memcpy(c.partially_read_buffer[0..saved_buf.len], saved_buf); } - return out; + return .{ .len = out, .end = c.eof() }; } /// Note that `first` usually overlaps with `c.partially_read_buffer`. -fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize { +fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) std.io.Reader.Status { if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { // There is cleartext at the beginning already which we need to preserve. c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + first.len + frag1.len); @@ -1555,7 +1473,7 @@ fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usi std.mem.copyForwards(u8, c.partially_read_buffer[0..first.len], first); @memcpy(c.partially_read_buffer[first.len..][0..frag1.len], frag1); } - return out; + return .{ .len = out, .end = c.eof() }; } fn limitedOverlapCopy(frag: []u8, in: usize) void { @@ -1577,9 +1495,6 @@ fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 { } } -const builtin = @import("builtin"); -const native_endian = builtin.cpu.arch.endian(); - inline fn big(x: anytype) @TypeOf(x) { return switch (native_endian) { .big => x, @@ -1958,7 +1873,3 @@ else .AES_256_GCM_SHA384, .ECDHE_RSA_WITH_AES_256_GCM_SHA384, }); - -test { - _ = StreamInterface; -} diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 815cad6d27..6f7388a9b4 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -24,6 +24,12 @@ allocator: Allocator, ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{}, ca_bundle_mutex: std.Thread.Mutex = .{}, +/// Used both for the reader and writer buffers. +tls_buffer_size: if (disable_tls) u0 else usize = if (disable_tls) 0 else std.crypto.tls.Client.min_buffer_len, +/// If non-null, ssl secrets are logged to a stream. Creating such a stream +/// allows other processes with access to that stream to decrypt all +/// traffic over connections created with this `Client`. +ssl_key_logger: ?*std.io.BufferedWriter = null, /// When this is `true`, the next time this client performs an HTTPS request, /// it will first rescan the system for root certificates. @@ -31,6 +37,10 @@ next_https_rescan_certs: bool = true, /// The pool of connections that can be reused (and currently in use). connection_pool: ConnectionPool = .{}, +/// Each `Connection` allocates this amount for the reader buffer. +read_buffer_size: usize, +/// Each `Connection` allocates this amount for the writer buffer. +write_buffer_size: usize, /// If populated, all http traffic travels through this third party. /// This field cannot be modified while the client has active connections. @@ -41,7 +51,7 @@ http_proxy: ?*Proxy = null, /// Pointer to externally-owned memory. https_proxy: ?*Proxy = null, -/// A set of linked lists of connections that can be reused. +/// A Least-Recently-Used cache of open connections to be reused. pub const ConnectionPool = struct { mutex: std.Thread.Mutex = .{}, /// Open connections that are currently in use. @@ -58,8 +68,10 @@ pub const ConnectionPool = struct { protocol: Connection.Protocol, }; - /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. + /// Finds and acquires a connection from the connection pool matching the criteria. /// If no connection is found, null is returned. + /// + /// Threadsafe. pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection { pool.mutex.lock(); defer pool.mutex.unlock(); @@ -96,21 +108,21 @@ pub const ConnectionPool = struct { return pool.acquireUnsafe(connection); } - /// Tries to release a connection back to the connection pool. This function is threadsafe. + /// Tries to release a connection back to the connection pool. /// If the connection is marked as closing, it will be closed instead. /// - /// The allocator must be the owner of all nodes in this pool. - /// The allocator must be the owner of all resources associated with the connection. + /// `allocator` must be the same one used to create `connection`. + /// + /// Threadsafe. pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void { + if (connection.closing) return connection.destroy(allocator); + pool.mutex.lock(); defer pool.mutex.unlock(); pool.used.remove(&connection.pool_node); - if (connection.closing or pool.free_size == 0) { - connection.close(allocator); - return allocator.destroy(connection); - } + if (pool.free_size == 0) return connection.destroy(allocator); if (pool.free_len >= pool.free_size) { const popped: *Connection = @fieldParentPtr("pool_node", pool.free.popFirst().?); @@ -138,9 +150,11 @@ pub const ConnectionPool = struct { pool.used.append(&connection.pool_node); } - /// Resizes the connection pool. This function is threadsafe. + /// Resizes the connection pool. /// /// If the new size is smaller than the current size, then idle connections will be closed until the pool is the new size. + /// + /// Threadsafe. pub fn resize(pool: *ConnectionPool, allocator: Allocator, new_size: usize) void { pool.mutex.lock(); defer pool.mutex.unlock(); @@ -158,9 +172,11 @@ pub const ConnectionPool = struct { pool.free_size = new_size; } - /// Frees the connection pool and closes all connections within. This function is threadsafe. + /// Frees the connection pool and closes all connections within. /// /// All future operations on the connection pool will deadlock. + /// + /// Threadsafe. pub fn deinit(pool: *ConnectionPool, allocator: Allocator) void { pool.mutex.lock(); @@ -184,160 +200,212 @@ pub const ConnectionPool = struct { } }; -/// An interface to either a plain or TLS connection. pub const Connection = struct { + client: *Client, stream: net.Stream, - /// Populated when protocol is TLS; this is the writer given to the TLS - /// client, which writes directly to `stream`, unbuffered. - stream_writer: std.io.BufferedWriter, - /// undefined unless protocol is tls. - tls_client: if (!disable_tls) *std.crypto.tls.Client else void, - + /// HTTP protocol from client to server. + /// This either goes directly to `stream`, or to a TLS client. + writer: std.io.BufferedWriter, /// Entry in `ConnectionPool.used` or `ConnectionPool.free`. pool_node: std.DoublyLinkedList.Node, - - /// The protocol that this connection is using. - protocol: Protocol, - - /// The host that this connection is connected to. - host: []u8, - - /// The port that this connection is connected to. port: u16, - - /// Whether this connection is proxied and is not directly connected. - proxied: bool = false, - - /// Whether this connection is closing when we're done with it. - closing: bool = false, - - read_start: BufferSize = 0, - read_end: BufferSize = 0, - read_buf: [buffer_size]u8, - - write_buffer: [buffer_size]u8, - writer: std.io.BufferedWriter, - - pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; - const BufferSize = std.math.IntFittingRange(0, buffer_size); + host_len: u8, + proxied: bool, + closing: bool, + protocol: Protocol, pub const Protocol = enum { plain, tls }; - pub fn readvDirectTls(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { - return conn.tls_client.readv(conn.stream, buffers) catch |err| { - // https://github.com/ziglang/zig/issues/2473 - if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; + const Plain = struct { + /// Data from `Connection.stream`. + reader: std.io.BufferedReader, + connection: Connection, - switch (err) { - error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure, - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - } - }; - } - - pub fn readvDirect(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.readvDirectTls(buffers); + fn create( + client: *Client, + remote_host: []const u8, + port: u16, + stream: net.Stream, + ) error{OutOfMemory}!*Connection { + const gpa = client.allocator; + const alloc_len = allocLen(client, remote_host.len); + const base = try gpa.alignedAlloc(u8, .of(Plain), alloc_len); + errdefer gpa.free(base); + const host_buffer = base[@sizeOf(Plain)..][0..remote_host.len]; + const socket_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.read_buffer_size]; + const socket_write_buffer = socket_read_buffer.ptr[socket_read_buffer.len..][0..client.write_buffer_size]; + assert(base.ptr + alloc_len == socket_read_buffer.ptr + socket_read_buffer.len); + @memcpy(host_buffer, remote_host); + const plain: *Plain = @ptrCast(base); + plain.* = .{ + .connection = .{ + .client = client, + .stream = stream, + .writer = stream.writer().buffered(socket_write_buffer), + .pool_node = .{}, + .port = port, + .proxied = false, + .closing = false, + .protocol = .plain, + }, + .reader = undefined, + }; + plain.reader.init(stream.reader(), socket_read_buffer); } - return conn.stream.readv(buffers) catch |err| switch (err) { - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - }; - } - - /// Refills the read buffer with data from the connection. - pub fn fill(conn: *Connection) ReadError!void { - if (conn.read_end != conn.read_start) return; - - var iovecs = [1]std.posix.iovec{ - .{ .base = &conn.read_buf, .len = conn.read_buf.len }, - }; - const nread = try conn.readvDirect(&iovecs); - if (nread == 0) return error.EndOfStream; - conn.read_start = 0; - conn.read_end = @intCast(nread); - } - - /// Returns the current slice of buffered data. - pub fn peek(conn: *Connection) []const u8 { - return conn.read_buf[conn.read_start..conn.read_end]; - } - - /// Discards the given number of bytes from the read buffer. - pub fn drop(conn: *Connection, num: BufferSize) void { - conn.read_start += num; - } - - /// Reads data from the connection into the given buffer. - pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { - const available_read = conn.read_end - conn.read_start; - const available_buffer = buffer.len; - - if (available_read > available_buffer) { // partially read buffered data - @memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); - conn.read_start += @intCast(available_buffer); - - return available_buffer; - } else if (available_read > 0) { // fully read buffered data - @memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]); - conn.read_start += available_read; - - return available_read; + fn destroy(plain: *Plain) void { + const c = &plain.connection; + const gpa = c.client.allocator; + const base: [*]u8 = @ptrCast(plain); + gpa.free(base[0..allocLen(c.client, c.host_len)]); } - var iovecs = [2]std.posix.iovec{ - .{ .base = buffer.ptr, .len = buffer.len }, - .{ .base = &conn.read_buf, .len = conn.read_buf.len }, - }; - const nread = try conn.readvDirect(&iovecs); - - if (nread > buffer.len) { - conn.read_start = 0; - conn.read_end = @intCast(nread - buffer.len); - return buffer.len; + fn allocLen(client: *Client, host_len: usize) usize { + return @sizeOf(Plain) + host_len + client.read_buffer_size + client.write_buffer_size; } - return nread; - } - - pub const ReadError = error{ - TlsFailure, - TlsAlert, - ConnectionTimedOut, - ConnectionResetByPeer, - UnexpectedReadFailure, - EndOfStream, + fn host(plain: *Plain) []u8 { + const base: [*]u8 = @ptrCast(plain); + return base[@sizeOf(Plain)..][0..plain.connection.host_len]; + } }; - pub const Reader = std.io.Reader(*Connection, ReadError, read); + const Tls = struct { + /// Data from `client` to `Connection.stream`. + writer: std.io.BufferedWriter, + /// Data from `Connection.stream` to `client`. + reader: std.io.BufferedReader, + client: std.crypto.tls.Client, + connection: Connection, - pub fn reader(conn: *Connection) Reader { - return .{ .context = conn }; - } + fn create( + client: *Client, + remote_host: []const u8, + port: u16, + stream: net.Stream, + ) error{ OutOfMemory, TlsInitializationFailed }!*Tls { + const gpa = client.allocator; + const alloc_len = allocLen(client, remote_host.len); + const base = try gpa.alignedAlloc(u8, .of(Tls), alloc_len); + errdefer gpa.free(base); + const host_buffer = base[@sizeOf(Tls)..][0..remote_host.len]; + const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.tls_buffer_size]; + const tls_write_buffer = tls_read_buffer.ptr[tls_read_buffer.len..][0..client.tls_buffer_size]; + const socket_write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size]; + assert(base.ptr + alloc_len == socket_write_buffer.ptr + socket_write_buffer.len); + @memcpy(host_buffer, remote_host); + const tls: *Tls = @ptrCast(base); + tls.* = .{ + .connection = .{ + .client = client, + .stream = stream, + .writer = tls.client.writer().buffered(socket_write_buffer), + .pool_node = .{}, + .port = port, + .proxied = false, + .closing = false, + .protocol = .tls, + }, + .writer = stream.writer().buffered(tls_write_buffer), + .reader = undefined, + .client = undefined, + }; + tls.reader.init(stream.reader(), tls_read_buffer); + // TODO data race here on ca_bundle if the user sets next_https_rescan_certs to true + tls.client.init(&tls.reader, &tls.writer, .{ + .host = .{ .explicit = remote_host }, + .ca = .{ .bundle = client.ca_bundle }, + .ssl_key_logger = client.ssl_key_logger, + }) catch return error.TlsInitializationFailed; + // This is appropriate for HTTPS because the HTTP headers contain + // the content length which is used to detect truncation attacks. + tls.client.allow_truncation_attacks = true; - pub const WriteError = error{ - ConnectionResetByPeer, - UnexpectedWriteFailure, - }; - - pub fn close(conn: *Connection, allocator: Allocator) void { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - // try to cleanly close the TLS connection, for any server that cares. - _ = conn.tls_client.writeEnd("", true) catch {}; - if (conn.tls_client.ssl_key_log) |key_log| key_log.file.close(); - allocator.destroy(conn.tls_client); + return tls; } - conn.stream.close(); - allocator.free(conn.host); + fn destroy(tls: *Tls, gpa: Allocator) void { + const c = &tls.connection; + const base: [*]u8 = @ptrCast(tls); + gpa.free(base[0..allocLen(c.client, c.host_len)]); + } + + fn allocLen(client: *Client, host_len: usize) usize { + return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size + client.write_buffer_size; + } + + fn host(tls: *Tls) []u8 { + const base: [*]u8 = @ptrCast(tls); + return base[@sizeOf(Tls)..][0..tls.connection.host_len]; + } + }; + + fn host(c: *Connection) []u8 { + return switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *Tls = @fieldParentPtr("connection", c); + return tls.host(); + }, + .plain => { + const plain: *Plain = @fieldParentPtr("connection", c); + return plain.host(); + }, + }; + } + + /// This is either data from `stream`, or `Tls.client`. + fn reader(c: *Connection) *std.io.BufferedReader { + return switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *Tls = @fieldParentPtr("connection", c); + return &tls.client.reader; + }, + .plain => { + const plain: *Plain = @fieldParentPtr("connection", c); + return &plain.reader; + }, + }; + } + + /// If this is called without calling `flush` or `end`, data will be + /// dropped unsent. + pub fn destroy(c: *Connection) void { + c.stream.close(); + switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *Tls = @fieldParentPtr("connection", c); + tls.destroy(); + }, + .plain => { + const plain: *Plain = @fieldParentPtr("connection", c); + plain.destroy(); + }, + } + } + + pub fn flush(c: *Connection) anyerror!void { + try c.writer.flush(); + if (c.protocol == .tls) { + if (disable_tls) unreachable; + const tls: *Tls = @fieldParentPtr("connection", c); + try tls.writer.flush(); + } + } + + /// If the connection is a TLS connection, sends the close_notify alert. + /// + /// Flushes all buffers. + pub fn end(c: *Connection) anyerror!void { + try c.writer.flush(); + if (c.protocol == .tls) { + if (disable_tls) unreachable; + const tls: *Tls = @fieldParentPtr("connection", c); + try tls.client.end(); + try tls.writer.flush(); + } } }; @@ -350,10 +418,10 @@ pub const RequestTransfer = union(enum) { /// The decompressor for response messages. pub const Compression = union(enum) { - pub const DeflateDecompressor = std.compress.zlib.Decompressor(Request.TransferReader); - pub const GzipDecompressor = std.compress.gzip.Decompressor(Request.TransferReader); + pub const DeflateDecompressor = std.compress.zlib.Decompressor; + pub const GzipDecompressor = std.compress.gzip.Decompressor; // https://github.com/ziglang/zig/issues/18937 - //pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.TransferReader, .{}); + //pub const ZstdDecompressor = std.compress.zstd.DecompressStream(.{}); deflate: DeflateDecompressor, gzip: GzipDecompressor, @@ -617,9 +685,6 @@ pub const Response = struct { } }; -/// A HTTP request that has been sent. -/// -/// Order of operations: open -> send[ -> write -> finish] -> wait -> read pub const Request = struct { uri: Uri, client: *Client, @@ -1300,24 +1365,34 @@ pub const basic_authorization = struct { } }; -pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed }; +pub const ConnectTcpError = Allocator.Error || error{ + ConnectionRefused, + NetworkUnreachable, + ConnectionTimedOut, + ConnectionResetByPeer, + TemporaryNameServerFailure, + NameServerFailure, + UnknownHostName, + HostLacksNetworkAddresses, + UnexpectedConnectFailure, + TlsInitializationFailed, +}; -/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open. +/// Reuses a `Connection` if one matching `host` and `port` is already open. /// -/// This function is threadsafe. -pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*Connection { +/// Threadsafe. +pub fn connectTcp( + client: *Client, + host: []const u8, + port: u16, + protocol: Connection.Protocol, +) ConnectTcpError!*Connection { if (client.connection_pool.findConnection(.{ .host = host, .port = port, .protocol = protocol, })) |conn| return conn; - if (disable_tls and protocol == .tls) - return error.TlsInitializationFailed; - - const conn = try client.allocator.create(Connection); - errdefer client.allocator.destroy(conn); - const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) { error.ConnectionRefused => return error.ConnectionRefused, error.NetworkUnreachable => return error.NetworkUnreachable, @@ -1331,77 +1406,19 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec }; errdefer stream.close(); - conn.* = .{ - .stream = stream, - .stream_writer = undefined, - .tls_client = undefined, - .read_buf = undefined, - - .write_buffer = undefined, - .writer = undefined, // populated below - - .protocol = protocol, - .host = try client.allocator.dupe(u8, host), - .port = port, - - .pool_node = .{}, - }; - errdefer client.allocator.free(conn.host); - switch (protocol) { .tls => { - if (disable_tls) unreachable; - - const tls_client = try client.allocator.create(std.crypto.tls.Client); - errdefer client.allocator.destroy(tls_client); - - const ssl_key_log_file: ?std.fs.File = if (std.options.http_enable_ssl_key_log_file) ssl_key_log_file: { - const ssl_key_log_path = std.process.getEnvVarOwned(client.allocator, "SSLKEYLOGFILE") catch |err| switch (err) { - error.EnvironmentVariableNotFound, error.InvalidWtf8 => break :ssl_key_log_file null, - error.OutOfMemory => return error.OutOfMemory, - }; - defer client.allocator.free(ssl_key_log_path); - break :ssl_key_log_file std.fs.cwd().createFile(ssl_key_log_path, .{ - .truncate = false, - .mode = switch (builtin.os.tag) { - .windows, .wasi => 0, - else => 0o600, - }, - }) catch null; - } else null; - errdefer if (ssl_key_log_file) |key_log_file| key_log_file.close(); - - conn.stream_writer = .{ - .unbuffered_writer = stream.writer(), - .buffer = &.{}, - }; - - tls_client.* = std.crypto.tls.Client.init(stream, &conn.stream_writer, .{ - .host = .{ .explicit = host }, - .ca = .{ .bundle = client.ca_bundle }, - .ssl_key_log_file = ssl_key_log_file, - }) catch return error.TlsInitializationFailed; - // This is appropriate for HTTPS because the HTTP headers contain - // the content length which is used to detect truncation attacks. - tls_client.allow_truncation_attacks = true; - - conn.writer = .{ - .unbuffered_writer = tls_client.writer(), - .buffer = &conn.write_buffer, - }; - conn.tls_client = tls_client; + if (disable_tls) return error.TlsInitializationFailed; + const tc = try Connection.Tls.create(client, host, port, stream); + client.connection_pool.addUsed(&tc.connection); + return &tc.connection; }, .plain => { - conn.writer = .{ - .unbuffered_writer = stream.writer(), - .buffer = &conn.write_buffer, - }; + const pc = try Connection.Plain.create(client, host, port, stream); + client.connection_pool.addUsed(&pc.connection); + return &pc.connection; }, } - - client.connection_pool.addUsed(conn); - - return conn; } pub const ConnectUnixError = Allocator.Error || std.posix.SocketError || error{NameTooLong} || std.posix.ConnectError; @@ -1662,16 +1679,17 @@ pub fn open( var server_header: std.heap.FixedBufferAllocator = .init(options.server_header_buffer); const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); - if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { + if (protocol == .tls) { if (disable_tls) unreachable; + if (@atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { + client.ca_bundle_mutex.lock(); + defer client.ca_bundle_mutex.unlock(); - client.ca_bundle_mutex.lock(); - defer client.ca_bundle_mutex.unlock(); - - if (client.next_https_rescan_certs) { - client.ca_bundle.rescan(client.allocator) catch - return error.CertificateBundleLoadFailure; - @atomicStore(bool, &client.next_https_rescan_certs, false, .release); + if (client.next_https_rescan_certs) { + client.ca_bundle.rescan(client.allocator) catch + return error.CertificateBundleLoadFailure; + @atomicStore(bool, &client.next_https_rescan_certs, false, .release); + } } } diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 91208f56d9..b01a883e48 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -1,18 +1,25 @@ //! Blocking HTTP server implementation. //! Handles a single connection's lifecycle. -connection: net.Server.Connection, +const std = @import("../std.zig"); +const http = std.http; +const mem = std.mem; +const net = std.net; +const Uri = std.Uri; +const assert = std.debug.assert; +const testing = std.testing; + +const Server = @This(); + +/// The reader's buffer must be large enough to store the client's entire HTTP +/// header, otherwise `receiveHead` returns `error.HttpHeadersOversize`. +in: *std.io.BufferedReader, +out: *std.io.BufferedWriter, /// Keeps track of whether the Server is ready to accept a new request on the /// same connection, and makes invalid API usage cause assertion failures /// rather than HTTP protocol violations. state: State, -/// User-provided buffer that must outlive this Server. -/// Used to store the client's entire HTTP header. -read_buffer: []u8, -/// Amount of available data inside read_buffer. -read_buffer_len: usize, -/// Index into `read_buffer` of the first byte of the next HTTP request. -next_request_start: usize, +in_err: anyerror, pub const State = enum { /// The connection is available to be used for the first time, or reused. @@ -31,14 +38,13 @@ pub const State = enum { /// Initialize an HTTP server that can respond to multiple requests on the same /// connection. +/// /// The returned `Server` is ready for `receiveHead` to be called. -pub fn init(connection: net.Server.Connection, read_buffer: []u8) Server { +pub fn init(in: *std.io.BufferedReader, out: *std.io.BufferedWriter) Server { return .{ - .connection = connection, + .in = in, + .out = out, .state = .ready, - .read_buffer = read_buffer, - .read_buffer_len = 0, - .next_request_start = 0, }; } @@ -48,78 +54,55 @@ pub const ReceiveHeadError = error{ /// before closing the connection. HttpHeadersOversize, /// Client sent headers that did not conform to the HTTP protocol. + /// `in_err` is populated with a `Request.Head.ParseError`. HttpHeadersInvalid, - /// A low level I/O error occurred trying to read the headers. - HttpHeadersUnreadable, /// Partial HTTP request was received but the connection was closed before /// fully receiving the headers. HttpRequestTruncated, /// The client sent 0 bytes of headers before closing the stream. /// In other words, a keep-alive connection was finally closed. HttpConnectionClosing, + /// Error occurred reading from `in`; `in_err` is populated. + ReadFailure, }; -/// The header bytes reference the read buffer that Server was initialized with -/// and remain alive until the next call to receiveHead. +/// The header bytes reference the internal storage of `in`, which are +/// invalidated with the next call to `receiveHead`. pub fn receiveHead(s: *Server) ReceiveHeadError!Request { assert(s.state == .ready); s.state = .received_head; errdefer s.state = .receiving_head; - // In case of a reused connection, move the next request's bytes to the - // beginning of the buffer. - if (s.next_request_start > 0) { - if (s.read_buffer_len > s.next_request_start) { - rebase(s, 0); - } else { - s.read_buffer_len = 0; - } - } - + const in = &s.in; var hp: http.HeadParser = .{}; - - if (s.read_buffer_len > 0) { - const bytes = s.read_buffer[0..s.read_buffer_len]; - const end = hp.feed(bytes); - if (hp.state == .finished) - return finishReceivingHead(s, end); - } + var head_end: usize = 0; while (true) { - const buf = s.read_buffer[s.read_buffer_len..]; - if (buf.len == 0) - return error.HttpHeadersOversize; - const read_n = s.connection.stream.read(buf) catch - return error.HttpHeadersUnreadable; - if (read_n == 0) { - if (s.read_buffer_len > 0) { - return error.HttpRequestTruncated; - } else { - return error.HttpConnectionClosing; - } - } - s.read_buffer_len += read_n; - const bytes = buf[0..read_n]; - const end = hp.feed(bytes); - if (hp.state == .finished) - return finishReceivingHead(s, s.read_buffer_len - bytes.len + end); + if (head_end >= in.bufferContents().len) return error.HttpHeadersOversize; + const buf = (in.peekGreedy(head_end + 1) catch |err| { + s.in_err = err; + return error.ReadFailure; + }) orelse switch (head_end) { + 0 => return error.HttpConnectionClosing, + else => return error.HttpRequestTruncated, + }; + head_end += hp.feed(buf[head_end..]); + if (hp.state == .finished) return .{ + .server = s, + .head_end = head_end, + .head = Request.Head.parse(buf[0..head_end]) catch |err| { + s.in_err = err; + return error.HttpHeadersInvalid; + }, + .reader_state = undefined, + .write_error = undefined, + }; } } -fn finishReceivingHead(s: *Server, head_end: usize) ReceiveHeadError!Request { - return .{ - .server = s, - .head_end = head_end, - .head = Request.Head.parse(s.read_buffer[0..head_end]) catch - return error.HttpHeadersInvalid, - .reader_state = undefined, - .write_error = undefined, - }; -} - pub const Request = struct { server: *Server, - /// Index into Server's read_buffer. + /// Index into `Server.in` internal buffer. head_end: usize, head: Head, reader_state: union { @@ -299,7 +282,7 @@ pub const Request = struct { }; pub fn iterateHeaders(r: *Request) http.HeaderIterator { - return http.HeaderIterator.init(r.server.read_buffer[0..r.head_end]); + return http.HeaderIterator.init(r.in.bufferContents()[0..r.head_end]); } test iterateHeaders { @@ -312,13 +295,14 @@ pub const Request = struct { var read_buffer: [500]u8 = undefined; @memcpy(read_buffer[0..request_bytes.len], request_bytes); + var br: std.io.BufferedReader = undefined; + br.initFixed(&read_buffer); var server: Server = .{ - .connection = undefined, + .in = &br, + .out = undefined, .state = .ready, - .read_buffer = &read_buffer, - .read_buffer_len = request_bytes.len, - .next_request_start = 0, + .in_err = undefined, }; var request: Request = .{ @@ -1158,13 +1142,3 @@ fn rebase(s: *Server, index: usize) void { } s.read_buffer_len = index + leftover.len; } - -const std = @import("../std.zig"); -const http = std.http; -const mem = std.mem; -const net = std.net; -const Uri = std.Uri; -const assert = std.debug.assert; -const testing = std.testing; - -const Server = @This(); diff --git a/lib/std/io/BufferedReader.zig b/lib/std/io/BufferedReader.zig index b119002392..ff14af18a5 100644 --- a/lib/std/io/BufferedReader.zig +++ b/lib/std/io/BufferedReader.zig @@ -94,6 +94,11 @@ pub fn storageBuffer(br: *BufferedReader) []u8 { return storage.buffer; } +pub fn bufferContents(br: *BufferedReader) []u8 { + const storage = &br.storage; + return storage.buffer[br.seek..storage.end]; +} + /// Although `BufferedReader` can easily satisfy the `Reader` interface, it's /// generally more practical to pass a `BufferedReader` instance itself around, /// since it will result in fewer calls across vtable boundaries. @@ -159,31 +164,69 @@ pub fn seekForwardBy(br: *BufferedReader, seek_by: u64) anyerror!void { /// is returned instead. /// /// See also: -/// * `peekAll` +/// * `peekGreedy` /// * `toss` pub fn peek(br: *BufferedReader, n: usize) anyerror![]u8 { - return (try br.peekAll(n))[0..n]; + return (try br.peekGreedy(n))[0..n]; } -/// Returns the next buffered bytes from `unbuffered_reader`, after filling the buffer -/// with at least `n` bytes. +/// Returns the next `n` bytes from `unbuffered_reader`, filling the buffer as +/// necessary. /// /// Invalidates previously returned values from `peek`. /// /// Asserts that the `BufferedReader` was initialized with a buffer capacity at /// least as big as `n`. /// +/// If there are fewer than `n` bytes left in the stream, `null` is returned +/// instead. +/// +/// See also: +/// * `peekGreedy` +/// * `toss` +pub fn peek2(br: *BufferedReader, n: usize) anyerror!?[]u8 { + if (try br.peekGreedy(n)) |buf| return buf[0..n]; + return null; +} + +/// Returns all the next buffered bytes from `unbuffered_reader`, after filling +/// the buffer to ensure it contains at least `n` bytes. +/// +/// Invalidates previously returned values from `peek` and `peekGreedy`. +/// +/// Asserts that the `BufferedReader` was initialized with a buffer capacity at +/// least as big as `n`. +/// /// If there are fewer than `n` bytes left in the stream, `error.EndOfStream` /// is returned instead. /// /// See also: /// * `peek` /// * `toss` -pub fn peekAll(br: *BufferedReader, n: usize) anyerror![]u8 { - const storage = &br.storage; - assert(n <= storage.buffer.len); - try br.fill(n); - return storage.buffer[br.seek..storage.end]; +pub fn peekGreedy(br: *BufferedReader, n: usize) anyerror![]u8 { + assert(n <= br.storage.buffer.len); + if (try br.fill(n)) return br.bufferContents(); + return error.EndOfStream; +} + +/// Returns all the next buffered bytes from `unbuffered_reader`, after filling +/// the buffer to ensure it contains at least `n` bytes. +/// +/// Invalidates previously returned values from `peek` and `peekGreedy`. +/// +/// Asserts that the `BufferedReader` was initialized with a buffer capacity at +/// least as big as `n`. +/// +/// If there are fewer than `n` bytes left in the stream, `null` is returned +/// instead. +/// +/// See also: +/// * `peek` +/// * `toss` +pub fn peekGreedy2(br: *BufferedReader, n: usize) anyerror!?[]u8 { + assert(n <= br.storage.buffer.len); + if (try br.fill(n)) return br.bufferContents(); + return null; } /// Skips the next `n` bytes from the stream, advancing the seek position. This @@ -505,17 +548,17 @@ pub fn discardDelimiterInclusive(br: *BufferedReader, delimiter: u8) anyerror!vo /// Fills the buffer such that it contains at least `n` bytes, without /// advancing the seek position. /// -/// Returns `error.EndOfStream` if there are fewer than `n` bytes remaining. +/// Returns `false` if and only if there are fewer than `n` bytes remaining. /// /// Asserts buffer capacity is at least `n`. -pub fn fill(br: *BufferedReader, n: usize) anyerror!void { +pub fn fill(br: *BufferedReader, n: usize) anyerror!bool { const storage = &br.storage; assert(n <= storage.buffer.len); const buffer = storage.buffer[0..storage.end]; const seek = br.seek; if (seek + n <= buffer.len) { @branchHint(.likely); - return; + return true; } const remainder = buffer[seek..]; std.mem.copyForwards(u8, buffer[0..remainder.len], remainder); @@ -523,8 +566,8 @@ pub fn fill(br: *BufferedReader, n: usize) anyerror!void { br.seek = 0; while (true) { const status = try br.unbuffered_reader.read(storage, .unlimited); - if (n <= storage.end) return; - if (status.end) return error.EndOfStream; + if (n <= storage.end) return true; + if (status.end) return false; } } @@ -535,7 +578,8 @@ pub fn takeByte(br: *BufferedReader) anyerror!u8 { const seek = br.seek; if (seek >= buffer.len) { @branchHint(.unlikely); - try br.fill(1); + const filled = try fill(br, 1); + if (!filled) return error.EndOfStream; } br.seek = seek + 1; return buffer[seek]; @@ -603,7 +647,7 @@ fn takeMultipleOf7Leb128(br: *BufferedReader, comptime Result: type) anyerror!Re var result: UnsignedResult = 0; var fits = true; while (true) { - const buffer: []const packed struct(u8) { bits: u7, more: bool } = @ptrCast(try br.peekAll(1)); + const buffer: []const packed struct(u8) { bits: u7, more: bool } = @ptrCast(try br.peekGreedy(1)); for (buffer, 1..) |byte, len| { if (remaining_bits > 0) { result = @shlExact(@as(UnsignedResult, byte.bits), result_info.bits - 7) | @@ -639,7 +683,7 @@ test peek { return error.Unimplemented; } -test peekAll { +test peekGreedy { return error.Unimplemented; }