From 6c7b103122f7394f2596830e77c8194a0d33770e Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 18 Feb 2025 23:33:27 -0800 Subject: [PATCH] std.crypto.tls.Client: upgrade to `std.io.BufferedWriter` This is pretty clearly a better API. --- lib/std/crypto/tls/Client.zig | 243 +++++++++++++++++----------------- 1 file changed, 124 insertions(+), 119 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 868096f477..fd0ab12930 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -44,6 +44,8 @@ application_cipher: tls.ApplicationCipher, /// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and /// `partial_ciphertext_end` describe the span of the segments. partially_read_buffer: [tls.max_ciphertext_record_len]u8, +/// Encrypted bytes sent to the server here. +output: *std.io.BufferedWriter, /// If non-null, ssl secrets are logged to a file. Creating such a log file allows other /// programs with access to that file to decrypt all traffic over this connection. ssl_key_log: ?struct { @@ -84,24 +86,6 @@ pub const StreamInterface = struct { _ = .{ this, iovecs }; @panic("unimplemented"); } - - /// Can be any error set. - pub const WriteError = error{}; - - /// Returns the number of bytes read, which may be less than the buffer - /// space provided. A short read does not indicate end-of-stream. - pub fn writev(this: @This(), iovecs: []const std.posix.iovec_const) WriteError!usize { - _ = .{ this, iovecs }; - @panic("unimplemented"); - } - - /// The `iovecs` parameter is mutable in case this function needs to mutate - /// the fields in order to handle partial writes from the underlying layer. - pub fn writevAll(this: @This(), iovecs: []std.posix.iovec_const) WriteError!void { - // This can be implemented in terms of writev, or specialized if desired. - _ = .{ this, iovecs }; - @panic("unimplemented"); - } }; pub const Options = struct { @@ -129,61 +113,92 @@ pub const Options = struct { /// other programs with access to that file to decrypt all traffic over this connection. /// TODO `std.crypto` should have no dependencies on `std.fs`. ssl_key_log_file: ?std.fs.File = null, + diagnostics: ?*Diagnostics = null, + + pub const Diagnostics = union { + /// Populated on `error.WriteFailure` and `error.ReadFailure`. + err: anyerror, + /// Populated on `error.TlsAlert`. + /// + /// If this isn't a error alert, then it's a closure alert, which makes + /// no sense in a handshake. + alert: tls.AlertDescription, + }; }; -pub fn InitError(comptime Stream: type) type { - return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || tls.AlertDescription.Error || error{ - InsufficientEntropy, - DiskQuota, - LockViolation, - NotOpenForWriting, - TlsUnexpectedMessage, - TlsIllegalParameter, - TlsDecryptFailure, - TlsRecordOverflow, - TlsBadRecordMac, - CertificateFieldHasInvalidLength, - CertificateHostMismatch, - CertificatePublicKeyInvalid, - CertificateExpired, - CertificateFieldHasWrongDataType, - CertificateIssuerMismatch, - CertificateNotYetValid, - CertificateSignatureAlgorithmMismatch, - CertificateSignatureAlgorithmUnsupported, - CertificateSignatureInvalid, - CertificateSignatureInvalidLength, - CertificateSignatureNamedCurveUnsupported, - CertificateSignatureUnsupportedBitCount, - TlsCertificateNotVerified, - TlsBadSignatureScheme, - TlsBadRsaSignatureBitCount, - InvalidEncoding, - IdentityElement, - SignatureVerificationFailed, - TlsDecryptError, - TlsConnectionTruncated, - TlsDecodeError, - UnsupportedCertificateVersion, - CertificateTimeInvalid, - CertificateHasUnrecognizedObjectId, - CertificateHasInvalidBitString, - MessageTooLong, - NegativeIntoUnsigned, - TargetTooSmall, - BufferTooSmall, - InvalidSignature, - NotSquare, - NonCanonical, - WeakPublicKey, +/// TODO I wish this could be a method of Diagnostics +fn wrapWrite(opt_diags: ?*Options.Diagnostics, returned: anyerror!void) error{WriteFailure}!void { + returned catch |err| { + if (opt_diags) |diags| diags.* = .{ .err = err }; + return error.WriteFailure; }; } -/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session with `stream`, which +/// TODO I wish this could be a method of Diagnostics +fn wrapRead(opt_diags: ?*Options.Diagnostics, returned: anyerror!void) error{ReadFailure}!void { + returned catch |err| { + if (opt_diags) |diags| diags.* = .{ .err = err }; + return error.ReadFailure; + }; +} + +const InitError = error{ + //OutOfMemory, + WriteFailure, + ReadFailure, + InsufficientEntropy, + DiskQuota, + LockViolation, + NotOpenForWriting, + /// The alert description will be stored in `Options.Diagnostics.alert`. + TlsAlert, + TlsUnexpectedMessage, + TlsIllegalParameter, + TlsDecryptFailure, + TlsRecordOverflow, + TlsBadRecordMac, + CertificateFieldHasInvalidLength, + CertificateHostMismatch, + CertificatePublicKeyInvalid, + CertificateExpired, + CertificateFieldHasWrongDataType, + CertificateIssuerMismatch, + CertificateNotYetValid, + CertificateSignatureAlgorithmMismatch, + CertificateSignatureAlgorithmUnsupported, + CertificateSignatureInvalid, + CertificateSignatureInvalidLength, + CertificateSignatureNamedCurveUnsupported, + CertificateSignatureUnsupportedBitCount, + TlsCertificateNotVerified, + TlsBadSignatureScheme, + TlsBadRsaSignatureBitCount, + InvalidEncoding, + IdentityElement, + SignatureVerificationFailed, + TlsDecryptError, + TlsConnectionTruncated, + TlsDecodeError, + UnsupportedCertificateVersion, + CertificateTimeInvalid, + CertificateHasUnrecognizedObjectId, + CertificateHasInvalidBitString, + MessageTooLong, + NegativeIntoUnsigned, + TargetTooSmall, + BufferTooSmall, + InvalidSignature, + NotSquare, + NonCanonical, + WeakPublicKey, +}; + +/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session with `input`, which /// must conform to `StreamInterface`. /// /// `host` is only borrowed during this function call. -pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client { +pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) InitError!Client { + const diags = options.diagnostics; const host = switch (options.host) { .no_verification => "", .explicit => |host| host, @@ -275,11 +290,8 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client }; { - var iovecs = [_]std.posix.iovec_const{ - .{ .base = cleartext_header.ptr, .len = cleartext_header.len }, - .{ .base = host.ptr, .len = host.len }, - }; - try stream.writevAll(iovecs[0..if (host.len == 0) 1 else 2]); + var iovecs: [2][]const u8 = .{ cleartext_header, host }; + try wrapWrite(diags, output.writevAll(iovecs[0..if (host.len == 0) 1 else 2])); } var tls_version: tls.ProtocolVersion = undefined; @@ -331,12 +343,12 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client var handshake_buffer: [tls.max_ciphertext_record_len]u8 = undefined; var d: tls.Decoder = .{ .buf = &handshake_buffer }; fragment: while (true) { - try d.readAtLeastOurAmt(stream, tls.record_header_len); + try wrapRead(diags, d.readAtLeastOurAmt(input, tls.record_header_len)); const record_header = d.buf[d.idx..][0..tls.record_header_len]; const record_ct = d.decode(tls.ContentType); d.skip(2); // legacy_version const record_len = d.decode(u16); - try d.readAtLeast(stream, record_len); + try wrapRead(diags, d.readAtLeast(input, record_len)); var record_decoder = try d.sub(record_len); var ctd, const ct = content: switch (cipher_state) { .cleartext => .{ record_decoder, record_ct }, @@ -414,11 +426,8 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client const level = ctd.decode(tls.AlertLevel); const desc = ctd.decode(tls.AlertDescription); _ = level; - - // if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; + if (diags) |x| x.* = .{ .alert = desc }; + return error.TlsAlert; }, .change_cipher_spec => { ctd.ensure(1) catch continue :fragment; @@ -754,11 +763,12 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client nonce, pv.app_cipher.client_write_key, ); - const all_msgs = client_key_exchange_msg ++ client_change_cipher_spec_msg ++ client_verify_msg; - var all_msgs_vec = [_]std.posix.iovec_const{ - .{ .base = &all_msgs, .len = all_msgs.len }, + var all_msgs_vec: [3][]const u8 = .{ + &client_key_exchange_msg, + &client_change_cipher_spec_msg, + &client_verify_msg, }; - try stream.writevAll(&all_msgs_vec); + try wrapWrite(diags, output.writevAll(&all_msgs_vec)); }, } write_seq += 1; @@ -819,11 +829,11 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client const nonce = pv.client_handshake_iv; P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, pv.client_handshake_key); - const all_msgs = client_change_cipher_spec_msg ++ finished_msg; - var all_msgs_vec = [_]std.posix.iovec_const{ - .{ .base = &all_msgs, .len = all_msgs.len }, + var all_msgs_vec: [2][]const u8 = .{ + &client_change_cipher_spec_msg, + &finished_msg, }; - try stream.writevAll(&all_msgs_vec); + try wrapWrite(diags, output.writevAll(&all_msgs_vec)); const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); @@ -873,6 +883,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client .received_close_notify = false, .allow_truncation_attacks = false, .application_cipher = app_cipher, + .output = output, .partially_read_buffer = undefined, .ssl_key_log = if (options.ssl_key_log_file) |key_log_file| .{ .client_key_seq = key_seq, @@ -896,39 +907,39 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client } } -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`. -pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize { - return writeEnd(c, stream, bytes, false); +pub fn writer(c: *Client) std.io.Writer { + return .{ + .context = c, + .vtable = &.{ + .writeSplat = writeSplat, + .writeFile = std.io.Writer.unimplemented_writeFile, + }, + }; } -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -pub fn writeAll(c: *Client, stream: anytype, bytes: []const u8) !void { +fn writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize { + const c: *Client = @alignCast(@ptrCast(context)); + assert(data.len > 1 or splat > 0); + return writeEnd(c, data[0], false); +} + +/// If `end` is true, then this function additionally sends a `close_notify` +/// alert, which is necessary for the server to distinguish between a properly +/// finished TLS session, or a truncation attack. +pub fn writeAllEnd(c: *Client, bytes: []const u8, end: bool) anyerror!void { var index: usize = 0; while (index < bytes.len) { - index += try c.write(stream, bytes[index..]); + index += try c.writeEnd(bytes[index..], end); } } -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// If `end` is true, then this function additionally sends a `close_notify` alert, -/// which is necessary for the server to distinguish between a properly finished -/// TLS session, or a truncation attack. -pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !void { - var index: usize = 0; - while (index < bytes.len) { - index += try c.writeEnd(stream, bytes[index..], end); - } -} - -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. /// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`. /// If `end` is true, then this function additionally sends a `close_notify` alert, /// which is necessary for the server to distinguish between a properly finished /// TLS session, or a truncation attack. -pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize { +pub fn writeEnd(c: *Client, bytes: []const u8, end: bool) anyerror!usize { var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined; - var iovecs_buf: [6]std.posix.iovec_const = undefined; + var iovecs_buf: [6][]const u8 = undefined; var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data); if (end) { prepared.iovec_end += prepareCiphertextRecord( @@ -948,7 +959,7 @@ pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usiz var i: usize = 0; var total_amt: usize = 0; while (true) { - var amt = try stream.writev(iovecs_buf[i..iovec_end]); + var amt = try c.output.writev(iovecs_buf[i..iovec_end]); while (amt >= iovecs_buf[i].len) { const encrypted_amt = iovecs_buf[i].len; total_amt += encrypted_amt - overhead_len; @@ -962,14 +973,13 @@ pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usiz // not sent; otherwise the caller would not know to retry the call. if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt; } - iovecs_buf[i].base += amt; - iovecs_buf[i].len -= amt; + iovecs_buf[i] = iovecs_buf[i][amt..]; } } fn prepareCiphertextRecord( c: *Client, - iovecs: []std.posix.iovec_const, + iovecs: [][]const u8, ciphertext_buf: []u8, bytes: []const u8, inner_content_type: tls.ContentType, @@ -1031,10 +1041,7 @@ fn prepareCiphertextRecord( c.write_seq += 1; // TODO send key_update on overflow const record = ciphertext_buf[record_start..ciphertext_end]; - iovecs[iovec_end] = .{ - .base = record.ptr, - .len = record.len, - }; + iovecs[iovec_end] = record; iovec_end += 1; } }, @@ -1084,10 +1091,7 @@ fn prepareCiphertextRecord( c.write_seq += 1; // TODO send key_update on overflow const record = ciphertext_buf[record_start..ciphertext_end]; - iovecs[iovec_end] = .{ - .base = record.ptr, - .len = record.len, - }; + iovecs[iovec_end] = record; iovec_end += 1; } }, @@ -1511,7 +1515,8 @@ fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) voi const locked = if (key_log_file.lock(.exclusive)) |_| true else |_| false; defer if (locked) key_log_file.unlock(); key_log_file.seekFromEnd(0) catch {}; - inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| key_log_file.writer().print("{s}" ++ + var w = key_log_file.writer().unbuffered(); + inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| w.print("{s}" ++ (if (@hasField(@TypeOf(context), "counter")) "_{d}" else "") ++ " {x} {x}\n", .{field.name} ++ (if (@hasField(@TypeOf(context), "counter")) .{context.counter} else .{}) ++ .{ context.client_random,