diff --git a/lib/std/Io/Reader.zig b/lib/std/Io/Reader.zig index 6485e2edff..6f6e14792b 100644 --- a/lib/std/Io/Reader.zig +++ b/lib/std/Io/Reader.zig @@ -1306,31 +1306,6 @@ pub fn defaultRebase(r: *Reader, capacity: usize) RebaseError!void { r.end = data.len; } -/// Advances the stream and decreases the size of the storage buffer by `n`, -/// returning the range of bytes no longer accessible by `r`. -/// -/// This action can be undone by `restitute`. -/// -/// Asserts there are at least `n` buffered bytes already. -/// -/// Asserts that `r.seek` is zero, i.e. the buffer is in a rebased state. -pub fn steal(r: *Reader, n: usize) []u8 { - assert(r.seek == 0); - assert(n <= r.end); - const stolen = r.buffer[0..n]; - r.buffer = r.buffer[n..]; - r.end -= n; - return stolen; -} - -/// Expands the storage buffer, undoing the effects of `steal` -/// Assumes that `n` does not exceed the total number of stolen bytes. -pub fn restitute(r: *Reader, n: usize) void { - r.buffer = (r.buffer.ptr - n)[0 .. r.buffer.len + n]; - r.end += n; - r.seek += n; -} - test fixed { var r: Reader = .fixed("a\x02"); try testing.expect((try r.takeByte()) == 'a'); diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index da6a431840..e647a7710e 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -49,8 +49,8 @@ pub const hello_retry_request_sequence = [32]u8{ }; pub const close_notify_alert = [_]u8{ - @intFromEnum(AlertLevel.warning), - @intFromEnum(AlertDescription.close_notify), + @intFromEnum(Alert.Level.warning), + @intFromEnum(Alert.Description.close_notify), }; pub const ProtocolVersion = enum(u16) { @@ -138,103 +138,108 @@ pub const ExtensionType = enum(u16) { _, }; -pub const AlertLevel = enum(u8) { - warning = 1, - fatal = 2, - _, -}; +pub const Alert = struct { + level: Level, + description: Description, -pub const AlertDescription = enum(u8) { - pub const Error = error{ - TlsAlertUnexpectedMessage, - TlsAlertBadRecordMac, - TlsAlertRecordOverflow, - TlsAlertHandshakeFailure, - TlsAlertBadCertificate, - TlsAlertUnsupportedCertificate, - TlsAlertCertificateRevoked, - TlsAlertCertificateExpired, - TlsAlertCertificateUnknown, - TlsAlertIllegalParameter, - TlsAlertUnknownCa, - TlsAlertAccessDenied, - TlsAlertDecodeError, - TlsAlertDecryptError, - TlsAlertProtocolVersion, - TlsAlertInsufficientSecurity, - TlsAlertInternalError, - TlsAlertInappropriateFallback, - TlsAlertMissingExtension, - TlsAlertUnsupportedExtension, - TlsAlertUnrecognizedName, - TlsAlertBadCertificateStatusResponse, - TlsAlertUnknownPskIdentity, - TlsAlertCertificateRequired, - TlsAlertNoApplicationProtocol, - TlsAlertUnknown, + pub const Level = enum(u8) { + warning = 1, + fatal = 2, + _, }; - close_notify = 0, - unexpected_message = 10, - bad_record_mac = 20, - record_overflow = 22, - handshake_failure = 40, - bad_certificate = 42, - unsupported_certificate = 43, - certificate_revoked = 44, - certificate_expired = 45, - certificate_unknown = 46, - illegal_parameter = 47, - unknown_ca = 48, - access_denied = 49, - decode_error = 50, - decrypt_error = 51, - protocol_version = 70, - insufficient_security = 71, - internal_error = 80, - inappropriate_fallback = 86, - user_canceled = 90, - missing_extension = 109, - unsupported_extension = 110, - unrecognized_name = 112, - bad_certificate_status_response = 113, - unknown_psk_identity = 115, - certificate_required = 116, - no_application_protocol = 120, - _, + pub const Description = enum(u8) { + pub const Error = error{ + TlsAlertUnexpectedMessage, + TlsAlertBadRecordMac, + TlsAlertRecordOverflow, + TlsAlertHandshakeFailure, + TlsAlertBadCertificate, + TlsAlertUnsupportedCertificate, + TlsAlertCertificateRevoked, + TlsAlertCertificateExpired, + TlsAlertCertificateUnknown, + TlsAlertIllegalParameter, + TlsAlertUnknownCa, + TlsAlertAccessDenied, + TlsAlertDecodeError, + TlsAlertDecryptError, + TlsAlertProtocolVersion, + TlsAlertInsufficientSecurity, + TlsAlertInternalError, + TlsAlertInappropriateFallback, + TlsAlertMissingExtension, + TlsAlertUnsupportedExtension, + TlsAlertUnrecognizedName, + TlsAlertBadCertificateStatusResponse, + TlsAlertUnknownPskIdentity, + TlsAlertCertificateRequired, + TlsAlertNoApplicationProtocol, + TlsAlertUnknown, + }; - pub fn toError(alert: AlertDescription) Error!void { - switch (alert) { - .close_notify => {}, // not an error - .unexpected_message => return error.TlsAlertUnexpectedMessage, - .bad_record_mac => return error.TlsAlertBadRecordMac, - .record_overflow => return error.TlsAlertRecordOverflow, - .handshake_failure => return error.TlsAlertHandshakeFailure, - .bad_certificate => return error.TlsAlertBadCertificate, - .unsupported_certificate => return error.TlsAlertUnsupportedCertificate, - .certificate_revoked => return error.TlsAlertCertificateRevoked, - .certificate_expired => return error.TlsAlertCertificateExpired, - .certificate_unknown => return error.TlsAlertCertificateUnknown, - .illegal_parameter => return error.TlsAlertIllegalParameter, - .unknown_ca => return error.TlsAlertUnknownCa, - .access_denied => return error.TlsAlertAccessDenied, - .decode_error => return error.TlsAlertDecodeError, - .decrypt_error => return error.TlsAlertDecryptError, - .protocol_version => return error.TlsAlertProtocolVersion, - .insufficient_security => return error.TlsAlertInsufficientSecurity, - .internal_error => return error.TlsAlertInternalError, - .inappropriate_fallback => return error.TlsAlertInappropriateFallback, - .user_canceled => {}, // not an error - .missing_extension => return error.TlsAlertMissingExtension, - .unsupported_extension => return error.TlsAlertUnsupportedExtension, - .unrecognized_name => return error.TlsAlertUnrecognizedName, - .bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse, - .unknown_psk_identity => return error.TlsAlertUnknownPskIdentity, - .certificate_required => return error.TlsAlertCertificateRequired, - .no_application_protocol => return error.TlsAlertNoApplicationProtocol, - _ => return error.TlsAlertUnknown, + close_notify = 0, + unexpected_message = 10, + bad_record_mac = 20, + record_overflow = 22, + handshake_failure = 40, + bad_certificate = 42, + unsupported_certificate = 43, + certificate_revoked = 44, + certificate_expired = 45, + certificate_unknown = 46, + illegal_parameter = 47, + unknown_ca = 48, + access_denied = 49, + decode_error = 50, + decrypt_error = 51, + protocol_version = 70, + insufficient_security = 71, + internal_error = 80, + inappropriate_fallback = 86, + user_canceled = 90, + missing_extension = 109, + unsupported_extension = 110, + unrecognized_name = 112, + bad_certificate_status_response = 113, + unknown_psk_identity = 115, + certificate_required = 116, + no_application_protocol = 120, + _, + + pub fn toError(description: Description) Error!void { + switch (description) { + .close_notify => {}, // not an error + .unexpected_message => return error.TlsAlertUnexpectedMessage, + .bad_record_mac => return error.TlsAlertBadRecordMac, + .record_overflow => return error.TlsAlertRecordOverflow, + .handshake_failure => return error.TlsAlertHandshakeFailure, + .bad_certificate => return error.TlsAlertBadCertificate, + .unsupported_certificate => return error.TlsAlertUnsupportedCertificate, + .certificate_revoked => return error.TlsAlertCertificateRevoked, + .certificate_expired => return error.TlsAlertCertificateExpired, + .certificate_unknown => return error.TlsAlertCertificateUnknown, + .illegal_parameter => return error.TlsAlertIllegalParameter, + .unknown_ca => return error.TlsAlertUnknownCa, + .access_denied => return error.TlsAlertAccessDenied, + .decode_error => return error.TlsAlertDecodeError, + .decrypt_error => return error.TlsAlertDecryptError, + .protocol_version => return error.TlsAlertProtocolVersion, + .insufficient_security => return error.TlsAlertInsufficientSecurity, + .internal_error => return error.TlsAlertInternalError, + .inappropriate_fallback => return error.TlsAlertInappropriateFallback, + .user_canceled => {}, // not an error + .missing_extension => return error.TlsAlertMissingExtension, + .unsupported_extension => return error.TlsAlertUnsupportedExtension, + .unrecognized_name => return error.TlsAlertUnrecognizedName, + .bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse, + .unknown_psk_identity => return error.TlsAlertUnknownPskIdentity, + .certificate_required => return error.TlsAlertCertificateRequired, + .no_application_protocol => return error.TlsAlertNoApplicationProtocol, + _ => return error.TlsAlertUnknown, + } } - } + }; }; pub const SignatureScheme = enum(u16) { @@ -650,7 +655,7 @@ pub const Decoder = struct { } /// Use this function to increase `their_end`. - pub fn readAtLeast(d: *Decoder, stream: anytype, their_amt: usize) !void { + pub fn readAtLeast(d: *Decoder, stream: *std.io.Reader, their_amt: usize) !void { assert(!d.disable_reads); const existing_amt = d.cap - d.idx; d.their_end = d.idx + their_amt; @@ -658,14 +663,16 @@ pub const Decoder = struct { const request_amt = their_amt - existing_amt; const dest = d.buf[d.cap..]; if (request_amt > dest.len) return error.TlsRecordOverflow; - const actual_amt = try stream.readAtLeast(dest, request_amt); - if (actual_amt < request_amt) return error.TlsConnectionTruncated; - d.cap += actual_amt; + stream.readSlice(dest[0..request_amt]) catch |err| switch (err) { + error.EndOfStream => return error.TlsConnectionTruncated, + error.ReadFailed => return error.ReadFailed, + }; + d.cap += request_amt; } /// Same as `readAtLeast` but also increases `our_end` by exactly `our_amt`. /// Use when `our_amt` is calculated by us, not by them. - pub fn readAtLeastOurAmt(d: *Decoder, stream: anytype, our_amt: usize) !void { + pub fn readAtLeastOurAmt(d: *Decoder, stream: *std.io.Reader, our_amt: usize) !void { assert(!d.disable_reads); try readAtLeast(d, stream, our_amt); d.our_end = d.idx + our_amt; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 3fa7b73d06..082fc9da70 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -1,11 +1,15 @@ +const builtin = @import("builtin"); +const native_endian = builtin.cpu.arch.endian(); + const std = @import("../../std.zig"); const tls = std.crypto.tls; const Client = @This(); -const net = std.net; const mem = std.mem; const crypto = std.crypto; const assert = std.debug.assert; const Certificate = std.crypto.Certificate; +const Reader = std.io.Reader; +const Writer = std.io.Writer; const max_ciphertext_len = tls.max_ciphertext_len; const hmacExpandLabel = tls.hmacExpandLabel; @@ -13,44 +17,58 @@ const hkdfExpandLabel = tls.hkdfExpandLabel; const int = tls.int; const array = tls.array; +/// The encrypted stream from the server to the client. Bytes are pulled from +/// here via `reader`. +/// +/// The buffer is asserted to have capacity at least `min_buffer_len`. +input: *Reader, +/// Decrypted stream from the server to the client. +reader: Reader, + +/// The encrypted stream from the client to the server. Bytes are pushed here +/// via `writer`. +output: *Writer, +/// The plaintext stream from the client to the server. +writer: Writer, + +/// Populated when `error.TlsAlert` is returned. +alert: ?tls.Alert = null, +read_err: ?ReadError = null, tls_version: tls.ProtocolVersion, read_seq: u64, write_seq: u64, -/// The starting index of cleartext bytes inside `partially_read_buffer`. -partial_cleartext_idx: u15, -/// The ending index of cleartext bytes inside `partially_read_buffer` as well -/// as the starting index of ciphertext bytes. -partial_ciphertext_idx: u15, -/// The ending index of ciphertext bytes inside `partially_read_buffer`. -partial_ciphertext_end: u15, /// When this is true, the stream may still not be at the end because there -/// may be data in `partially_read_buffer`. +/// may be data in the input buffer. received_close_notify: bool, -/// By default, reaching the end-of-stream when reading from the server will -/// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify -/// message has been received. By setting this flag to `true`, instead, the -/// end-of-stream will be forwarded to the application layer above TLS. -/// This makes the application vulnerable to truncation attacks unless the -/// application layer itself verifies that the amount of data received equals -/// the amount of data expected, such as HTTP with the Content-Length header. allow_truncation_attacks: bool, application_cipher: tls.ApplicationCipher, -/// The size is enough to contain exactly one TLSCiphertext record. -/// This buffer is segmented into four parts: -/// 0. unused -/// 1. cleartext -/// 2. ciphertext -/// 3. unused -/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and -/// `partial_ciphertext_end` describe the span of the segments. -partially_read_buffer: [tls.max_ciphertext_record_len]u8, -/// If non-null, ssl secrets are logged to a file. Creating such a log file allows other -/// programs with access to that file to decrypt all traffic over this connection. -ssl_key_log: ?struct { + +/// If non-null, ssl secrets are logged to a stream. Creating such a log file +/// allows other programs with access to that file to decrypt all traffic over +/// this connection. +ssl_key_log: ?*SslKeyLog, + +pub const ReadError = error{ + /// The alert description will be stored in `alert`. + TlsAlert, + TlsBadLength, + TlsBadRecordMac, + TlsConnectionTruncated, + TlsDecodeError, + TlsRecordOverflow, + TlsUnexpectedMessage, + TlsIllegalParameter, + TlsSequenceOverflow, + /// The buffer provided to the read function was not at least + /// `min_buffer_len`. + OutputBufferUndersize, +}; + +pub const SslKeyLog = struct { client_key_seq: u64, server_key_seq: u64, client_random: [32]u8, - file: std.fs.File, + writer: *Writer, fn clientCounter(key_log: *@This()) u64 { defer key_log.client_key_seq += 1; @@ -61,51 +79,12 @@ ssl_key_log: ?struct { defer key_log.server_key_seq += 1; return key_log.server_key_seq; } -}, - -/// This is an example of the type that is needed by the read and write -/// functions. It can have any fields but it must at least have these -/// functions. -/// -/// Note that `std.net.Stream` conforms to this interface. -/// -/// This declaration serves as documentation only. -pub const StreamInterface = struct { - /// Can be any error set. - pub const ReadError = error{}; - - /// Returns the number of bytes read. The number read may be less than the - /// buffer space provided. End-of-stream is indicated by a return value of 0. - /// - /// The `iovecs` parameter is mutable because so that function may to - /// mutate the fields in order to handle partial reads from the underlying - /// stream layer. - pub fn readv(this: @This(), iovecs: []std.posix.iovec) ReadError!usize { - _ = .{ this, iovecs }; - @panic("unimplemented"); - } - - /// Can be any error set. - pub const WriteError = error{}; - - /// Returns the number of bytes read, which may be less than the buffer - /// space provided. A short read does not indicate end-of-stream. - pub fn writev(this: @This(), iovecs: []const std.posix.iovec_const) WriteError!usize { - _ = .{ this, iovecs }; - @panic("unimplemented"); - } - - /// Returns the number of bytes read, which may be less than the buffer - /// space provided, indicating end-of-stream. - /// The `iovecs` parameter is mutable in case this function needs to mutate - /// the fields in order to handle partial writes from the underlying layer. - pub fn writevAll(this: @This(), iovecs: []std.posix.iovec_const) WriteError!usize { - // This can be implemented in terms of writev, or specialized if desired. - _ = .{ this, iovecs }; - @panic("unimplemented"); - } }; +/// The `Reader` supplied to `init` requires a buffer capacity +/// at least this amount. +pub const min_buffer_len = tls.max_ciphertext_record_len; + pub const Options = struct { /// How to perform host verification of server certificates. host: union(enum) { @@ -127,64 +106,85 @@ pub const Options = struct { /// Verify that the server certificate is authorized by a given ca bundle. bundle: Certificate.Bundle, }, - /// If non-null, ssl secrets are logged to this file. Creating such a log file allows + /// If non-null, ssl secrets are logged to this stream. Creating such a log file allows /// other programs with access to that file to decrypt all traffic over this connection. - ssl_key_log_file: ?std.fs.File = null, + /// + /// Only the `writer` field is observed during the handshake (`init`). + /// After that, the other fields are populated. + ssl_key_log: ?*SslKeyLog = null, + /// By default, reaching the end-of-stream when reading from the server will + /// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify + /// message has been received. By setting this flag to `true`, instead, the + /// end-of-stream will be forwarded to the application layer above TLS. + /// + /// This makes the application vulnerable to truncation attacks unless the + /// application layer itself verifies that the amount of data received equals + /// the amount of data expected, such as HTTP with the Content-Length header. + allow_truncation_attacks: bool = false, + write_buffer: []u8, + /// Asserted to have capacity at least `min_buffer_len`. + read_buffer: []u8, + /// Populated when `error.TlsAlert` is returned from `init`. + alert: ?*tls.Alert = null, }; -pub fn InitError(comptime Stream: type) type { - return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || tls.AlertDescription.Error || error{ - InsufficientEntropy, - DiskQuota, - LockViolation, - NotOpenForWriting, - TlsUnexpectedMessage, - TlsIllegalParameter, - TlsDecryptFailure, - TlsRecordOverflow, - TlsBadRecordMac, - CertificateFieldHasInvalidLength, - CertificateHostMismatch, - CertificatePublicKeyInvalid, - CertificateExpired, - CertificateFieldHasWrongDataType, - CertificateIssuerMismatch, - CertificateNotYetValid, - CertificateSignatureAlgorithmMismatch, - CertificateSignatureAlgorithmUnsupported, - CertificateSignatureInvalid, - CertificateSignatureInvalidLength, - CertificateSignatureNamedCurveUnsupported, - CertificateSignatureUnsupportedBitCount, - TlsCertificateNotVerified, - TlsBadSignatureScheme, - TlsBadRsaSignatureBitCount, - InvalidEncoding, - IdentityElement, - SignatureVerificationFailed, - TlsDecryptError, - TlsConnectionTruncated, - TlsDecodeError, - UnsupportedCertificateVersion, - CertificateTimeInvalid, - CertificateHasUnrecognizedObjectId, - CertificateHasInvalidBitString, - MessageTooLong, - NegativeIntoUnsigned, - TargetTooSmall, - BufferTooSmall, - InvalidSignature, - NotSquare, - NonCanonical, - WeakPublicKey, - }; -} +const InitError = error{ + WriteFailed, + ReadFailed, + InsufficientEntropy, + DiskQuota, + LockViolation, + NotOpenForWriting, + /// The alert description will be stored in `alert`. + TlsAlert, + TlsUnexpectedMessage, + TlsIllegalParameter, + TlsDecryptFailure, + TlsRecordOverflow, + TlsBadRecordMac, + CertificateFieldHasInvalidLength, + CertificateHostMismatch, + CertificatePublicKeyInvalid, + CertificateExpired, + CertificateFieldHasWrongDataType, + CertificateIssuerMismatch, + CertificateNotYetValid, + CertificateSignatureAlgorithmMismatch, + CertificateSignatureAlgorithmUnsupported, + CertificateSignatureInvalid, + CertificateSignatureInvalidLength, + CertificateSignatureNamedCurveUnsupported, + CertificateSignatureUnsupportedBitCount, + TlsCertificateNotVerified, + TlsBadSignatureScheme, + TlsBadRsaSignatureBitCount, + InvalidEncoding, + IdentityElement, + SignatureVerificationFailed, + TlsDecryptError, + TlsConnectionTruncated, + TlsDecodeError, + UnsupportedCertificateVersion, + CertificateTimeInvalid, + CertificateHasUnrecognizedObjectId, + CertificateHasInvalidBitString, + MessageTooLong, + NegativeIntoUnsigned, + TargetTooSmall, + BufferTooSmall, + InvalidSignature, + NotSquare, + NonCanonical, + WeakPublicKey, +}; -/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session with `stream`, which -/// must conform to `StreamInterface`. +/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session. /// /// `host` is only borrowed during this function call. -pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client { +/// +/// `input` is asserted to have buffer capacity at least `min_buffer_len`. +pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client { + assert(input.buffer.len >= min_buffer_len); const host = switch (options.host) { .no_verification => "", .explicit => |host| host, @@ -276,11 +276,8 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client }; { - var iovecs = [_]std.posix.iovec_const{ - .{ .base = cleartext_header.ptr, .len = cleartext_header.len }, - .{ .base = host.ptr, .len = host.len }, - }; - try stream.writevAll(iovecs[0..if (host.len == 0) 1 else 2]); + var iovecs: [2][]const u8 = .{ cleartext_header, host }; + try output.writeVecAll(iovecs[0..if (host.len == 0) 1 else 2]); } var tls_version: tls.ProtocolVersion = undefined; @@ -329,20 +326,26 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client var cleartext_fragment_start: usize = 0; var cleartext_fragment_end: usize = 0; var cleartext_bufs: [2][tls.max_ciphertext_inner_record_len]u8 = undefined; - var handshake_buffer: [tls.max_ciphertext_record_len]u8 = undefined; - var d: tls.Decoder = .{ .buf = &handshake_buffer }; fragment: while (true) { - try d.readAtLeastOurAmt(stream, tls.record_header_len); - const record_header = d.buf[d.idx..][0..tls.record_header_len]; - const record_ct = d.decode(tls.ContentType); - d.skip(2); // legacy_version - const record_len = d.decode(u16); - try d.readAtLeast(stream, record_len); - var record_decoder = try d.sub(record_len); + // Ensure the input buffer pointer is stable in this scope. + input.rebaseCapacity(tls.max_ciphertext_record_len); + const record_header = input.peek(tls.record_header_len) catch |err| switch (err) { + error.EndOfStream => return error.TlsConnectionTruncated, + error.ReadFailed => return error.ReadFailed, + }; + const record_ct = input.takeEnumNonexhaustive(tls.ContentType, .big) catch unreachable; // already peeked + input.toss(2); // legacy_version + const record_len = input.takeInt(u16, .big) catch unreachable; // already peeked + if (record_len > tls.max_ciphertext_len) return error.TlsRecordOverflow; + const record_buffer = input.take(record_len) catch |err| switch (err) { + error.EndOfStream => return error.TlsConnectionTruncated, + error.ReadFailed => return error.ReadFailed, + }; + var record_decoder: tls.Decoder = .fromTheirSlice(record_buffer); var ctd, const ct = content: switch (cipher_state) { .cleartext => .{ record_decoder, record_ct }, .handshake => { - std.debug.assert(tls_version == .tls_1_3); + assert(tls_version == .tls_1_3); if (record_ct != .application_data) return error.TlsUnexpectedMessage; try record_decoder.ensure(record_len); const cleartext_buf = &cleartext_bufs[cert_buf_index % 2]; @@ -374,7 +377,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client break :content .{ tls.Decoder.fromTheirSlice(@constCast(cleartext_buf[cleartext_fragment_start..cleartext_fragment_end])), ct }; }, .application => { - std.debug.assert(tls_version == .tls_1_2); + assert(tls_version == .tls_1_2); if (record_ct != .handshake) return error.TlsUnexpectedMessage; try record_decoder.ensure(record_len); const cleartext_buf = &cleartext_bufs[cert_buf_index % 2]; @@ -412,14 +415,11 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client switch (ct) { .alert => { ctd.ensure(2) catch continue :fragment; - const level = ctd.decode(tls.AlertLevel); - const desc = ctd.decode(tls.AlertDescription); - _ = level; - - // if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; + if (options.alert) |a| a.* = .{ + .level = ctd.decode(tls.Alert.Level), + .description = ctd.decode(tls.Alert.Description), + }; + return error.TlsAlert; }, .change_cipher_spec => { ctd.ensure(1) catch continue :fragment; @@ -533,7 +533,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client pv.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); const client_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length); const server_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length); - if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{ + if (options.ssl_key_log) |key_log| logSecrets(key_log.writer, .{ .client_random = &client_hello_rand, }, .{ .SERVER_HANDSHAKE_TRAFFIC_SECRET = &server_secret, @@ -707,7 +707,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client &client_hello_rand, &server_hello_rand, }, 48); - if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{ + if (options.ssl_key_log) |key_log| logSecrets(key_log.writer, .{ .client_random = &client_hello_rand, }, .{ .CLIENT_RANDOM = &master_secret, @@ -755,11 +755,12 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client nonce, pv.app_cipher.client_write_key, ); - const all_msgs = client_key_exchange_msg ++ client_change_cipher_spec_msg ++ client_verify_msg; - var all_msgs_vec = [_]std.posix.iovec_const{ - .{ .base = &all_msgs, .len = all_msgs.len }, + var all_msgs_vec: [3][]const u8 = .{ + &client_key_exchange_msg, + &client_change_cipher_spec_msg, + &client_verify_msg, }; - try stream.writevAll(&all_msgs_vec); + try output.writeVecAll(&all_msgs_vec); }, } write_seq += 1; @@ -820,15 +821,15 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client const nonce = pv.client_handshake_iv; P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, pv.client_handshake_key); - const all_msgs = client_change_cipher_spec_msg ++ finished_msg; - var all_msgs_vec = [_]std.posix.iovec_const{ - .{ .base = &all_msgs, .len = all_msgs.len }, + var all_msgs_vec: [2][]const u8 = .{ + &client_change_cipher_spec_msg, + &finished_msg, }; - try stream.writevAll(&all_msgs_vec); + try output.writeVecAll(&all_msgs_vec); const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); - if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{ + if (options.ssl_key_log) |key_log| logSecrets(key_log.writer, .{ .counter = key_seq, .client_random = &client_hello_rand, }, .{ @@ -855,8 +856,28 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client else => unreachable, }, }; - const leftover = d.rest(); - var client: Client = .{ + if (options.ssl_key_log) |ssl_key_log| ssl_key_log.* = .{ + .client_key_seq = key_seq, + .server_key_seq = key_seq, + .client_random = client_hello_rand, + .writer = ssl_key_log.writer, + }; + return .{ + .input = input, + .reader = .{ + .buffer = options.read_buffer, + .vtable = &.{ .stream = stream }, + .seek = 0, + .end = 0, + }, + .output = output, + .writer = .{ + .buffer = options.write_buffer, + .vtable = &.{ + .drain = drain, + .sendFile = Writer.unimplementedSendFile, + }, + }, .tls_version = tls_version, .read_seq = switch (tls_version) { .tls_1_3 => 0, @@ -868,22 +889,11 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client .tls_1_2 => write_seq, else => unreachable, }, - .partial_cleartext_idx = 0, - .partial_ciphertext_idx = 0, - .partial_ciphertext_end = @intCast(leftover.len), .received_close_notify = false, - .allow_truncation_attacks = false, + .allow_truncation_attacks = options.allow_truncation_attacks, .application_cipher = app_cipher, - .partially_read_buffer = undefined, - .ssl_key_log = if (options.ssl_key_log_file) |key_log_file| .{ - .client_key_seq = key_seq, - .server_key_seq = key_seq, - .client_random = client_hello_rand, - .file = key_log_file, - } else null, + .ssl_key_log = options.ssl_key_log, }; - @memcpy(client.partially_read_buffer[0..leftover.len], leftover); - return client; }, else => return error.TlsUnexpectedMessage, } @@ -897,94 +907,48 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client } } -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`. -pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize { - return writeEnd(c, stream, bytes, false); +fn drain(w: *Writer, data: []const []const u8, splat: usize) Writer.Error!usize { + const c: *Client = @fieldParentPtr("writer", w); + if (true) @panic("update to use the buffer and flush"); + const sliced_data = if (splat == 0) data[0..data.len -| 1] else data; + const output = c.output; + const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len); + var total_clear: usize = 0; + var ciphertext_end: usize = 0; + for (sliced_data) |buf| { + const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data); + total_clear += prepared.cleartext_len; + ciphertext_end += prepared.ciphertext_end; + if (total_clear < buf.len) break; + } + output.advance(ciphertext_end); + return total_clear; } -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -pub fn writeAll(c: *Client, stream: anytype, bytes: []const u8) !void { - var index: usize = 0; - while (index < bytes.len) { - index += try c.write(stream, bytes[index..]); - } -} - -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// If `end` is true, then this function additionally sends a `close_notify` alert, -/// which is necessary for the server to distinguish between a properly finished -/// TLS session, or a truncation attack. -pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !void { - var index: usize = 0; - while (index < bytes.len) { - index += try c.writeEnd(stream, bytes[index..], end); - } -} - -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`. -/// If `end` is true, then this function additionally sends a `close_notify` alert, -/// which is necessary for the server to distinguish between a properly finished -/// TLS session, or a truncation attack. -pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize { - var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined; - var iovecs_buf: [6]std.posix.iovec_const = undefined; - var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data); - if (end) { - prepared.iovec_end += prepareCiphertextRecord( - c, - iovecs_buf[prepared.iovec_end..], - ciphertext_buf[prepared.ciphertext_end..], - &tls.close_notify_alert, - .alert, - ).iovec_end; - } - - const iovec_end = prepared.iovec_end; - const overhead_len = prepared.overhead_len; - - // Ideally we would call writev exactly once here, however, we must ensure - // that we don't return with a record partially written. - var i: usize = 0; - var total_amt: usize = 0; - while (true) { - var amt = try stream.writev(iovecs_buf[i..iovec_end]); - while (amt >= iovecs_buf[i].len) { - const encrypted_amt = iovecs_buf[i].len; - total_amt += encrypted_amt - overhead_len; - amt -= encrypted_amt; - i += 1; - // Rely on the property that iovecs delineate records, meaning that - // if amt equals zero here, we have fortunately found ourselves - // with a short read that aligns at the record boundary. - if (i >= iovec_end) return total_amt; - // We also cannot return on a vector boundary if the final close_notify is - // not sent; otherwise the caller would not know to retry the call. - if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt; - } - iovecs_buf[i].base += amt; - iovecs_buf[i].len -= amt; - } +/// Sends a `close_notify` alert, which is necessary for the server to +/// distinguish between a properly finished TLS session, or a truncation +/// attack. +pub fn end(c: *Client) Writer.Error!void { + const output = c.output; + const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len); + const prepared = prepareCiphertextRecord(c, ciphertext_buf, &tls.close_notify_alert, .alert); + output.advance(prepared.cleartext_len); + return prepared.ciphertext_end; } fn prepareCiphertextRecord( c: *Client, - iovecs: []std.posix.iovec_const, ciphertext_buf: []u8, bytes: []const u8, inner_content_type: tls.ContentType, ) struct { - iovec_end: usize, ciphertext_end: usize, - /// How many bytes are taken up by overhead per record. - overhead_len: usize, + cleartext_len: usize, } { // Due to the trailing inner content type byte in the ciphertext, we need // an additional buffer for storing the cleartext into before encrypting. var cleartext_buf: [max_ciphertext_len]u8 = undefined; var ciphertext_end: usize = 0; - var iovec_end: usize = 0; var bytes_i: usize = 0; switch (c.application_cipher) { inline else => |*p| switch (c.tls_version) { @@ -992,18 +956,15 @@ fn prepareCiphertextRecord( const pv = &p.tls_1_3; const P = @TypeOf(p.*); const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1; - const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; while (true) { const encrypted_content_len: u16 = @min( bytes.len - bytes_i, tls.max_ciphertext_inner_record_len, - ciphertext_buf.len -| - (close_notify_alert_reserved + overhead_len + ciphertext_end), + ciphertext_buf.len -| (overhead_len + ciphertext_end), ); if (encrypted_content_len == 0) return .{ - .iovec_end = iovec_end, .ciphertext_end = ciphertext_end, - .overhead_len = overhead_len, + .cleartext_len = bytes_i, }; @memcpy(cleartext_buf[0..encrypted_content_len], bytes[bytes_i..][0..encrypted_content_len]); @@ -1012,7 +973,6 @@ fn prepareCiphertextRecord( const ciphertext_len = encrypted_content_len + 1; const cleartext = cleartext_buf[0..ciphertext_len]; - const record_start = ciphertext_end; const ad = ciphertext_buf[ciphertext_end..][0..tls.record_header_len]; ad.* = .{@intFromEnum(tls.ContentType.application_data)} ++ int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ @@ -1030,38 +990,27 @@ fn prepareCiphertextRecord( }; P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_key); c.write_seq += 1; // TODO send key_update on overflow - - const record = ciphertext_buf[record_start..ciphertext_end]; - iovecs[iovec_end] = .{ - .base = record.ptr, - .len = record.len, - }; - iovec_end += 1; } }, .tls_1_2 => { const pv = &p.tls_1_2; const P = @TypeOf(p.*); const overhead_len = tls.record_header_len + P.record_iv_length + P.mac_length; - const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; while (true) { const message_len: u16 = @min( bytes.len - bytes_i, tls.max_ciphertext_inner_record_len, - ciphertext_buf.len -| - (close_notify_alert_reserved + overhead_len + ciphertext_end), + ciphertext_buf.len -| (overhead_len + ciphertext_end), ); if (message_len == 0) return .{ - .iovec_end = iovec_end, .ciphertext_end = ciphertext_end, - .overhead_len = overhead_len, + .cleartext_len = bytes_i, }; @memcpy(cleartext_buf[0..message_len], bytes[bytes_i..][0..message_len]); bytes_i += message_len; const cleartext = cleartext_buf[0..message_len]; - const record_start = ciphertext_end; const record_header = ciphertext_buf[ciphertext_end..][0..tls.record_header_len]; ciphertext_end += tls.record_header_len; record_header.* = .{@intFromEnum(inner_content_type)} ++ @@ -1083,13 +1032,6 @@ fn prepareCiphertextRecord( ciphertext_end += P.mac_length; P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_write_key); c.write_seq += 1; // TODO send key_update on overflow - - const record = ciphertext_buf[record_start..ciphertext_end]; - iovecs[iovec_end] = .{ - .base = record.ptr, - .len = record.len, - }; - iovec_end += 1; } }, else => unreachable, @@ -1098,421 +1040,194 @@ fn prepareCiphertextRecord( } pub fn eof(c: Client) bool { - return c.received_close_notify and - c.partial_cleartext_idx >= c.partial_ciphertext_idx and - c.partial_ciphertext_idx >= c.partial_ciphertext_end; + return c.received_close_notify; } -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read, calling the underlying read function the -/// minimal number of times until the buffer has at least `len` bytes filled. -/// If the number read is less than `len` it means the stream reached the end. -/// Reaching the end of the stream is not an error condition. -pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize { - var iovecs = [1]std.posix.iovec{.{ .base = buffer.ptr, .len = buffer.len }}; - return readvAtLeast(c, stream, &iovecs, len); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize { - return readAtLeast(c, stream, buffer, 1); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read. If the number read is smaller than -/// `buffer.len`, it means the stream reached the end. Reaching the end of the -/// stream is not an error condition. -pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize { - return readAtLeast(c, stream, buffer, buffer.len); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read. If the number read is less than the space -/// provided it means the stream reached the end. Reaching the end of the -/// stream is not an error condition. -/// The `iovecs` parameter is mutable because this function needs to mutate the fields in -/// order to handle partial reads from the underlying stream layer. -pub fn readv(c: *Client, stream: anytype, iovecs: []std.posix.iovec) !usize { - return readvAtLeast(c, stream, iovecs, 1); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read, calling the underlying read function the -/// minimal number of times until the iovecs have at least `len` bytes filled. -/// If the number read is less than `len` it means the stream reached the end. -/// Reaching the end of the stream is not an error condition. -/// The `iovecs` parameter is mutable because this function needs to mutate the fields in -/// order to handle partial reads from the underlying stream layer. -pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.posix.iovec, len: usize) !usize { - if (c.eof()) return 0; - - var off_i: usize = 0; - var vec_i: usize = 0; - while (true) { - var amt = try c.readvAdvanced(stream, iovecs[vec_i..]); - off_i += amt; - if (c.eof() or off_i >= len) return off_i; - while (amt >= iovecs[vec_i].len) { - amt -= iovecs[vec_i].len; - vec_i += 1; - } - iovecs[vec_i].base += amt; - iovecs[vec_i].len -= amt; - } -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns number of bytes that have been read, populated inside `iovecs`. A -/// return value of zero bytes does not mean end of stream. Instead, check the `eof()` -/// for the end of stream. The `eof()` may be true after any call to -/// `read`, including when greater than zero bytes are returned, and this -/// function asserts that `eof()` is `false`. -/// See `readv` for a higher level function that has the same, familiar API as -/// other read functions, such as `std.fs.File.read`. -pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iovec) !usize { - var vp: VecPut = .{ .iovecs = iovecs }; - - // Give away the buffered cleartext we have, if any. - const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx]; - if (partial_cleartext.len > 0) { - const amt: u15 = @intCast(vp.put(partial_cleartext)); - c.partial_cleartext_idx += amt; - - if (c.partial_cleartext_idx == c.partial_ciphertext_idx and - c.partial_ciphertext_end == c.partial_ciphertext_idx) - { - // The buffer is now empty. - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = 0; - c.partial_ciphertext_end = 0; - } - - if (c.received_close_notify) { - c.partial_ciphertext_end = 0; - assert(vp.total == amt); - return amt; - } else if (amt > 0) { - // We don't need more data, so don't call read. - assert(vp.total == amt); - return amt; - } - } - - assert(!c.received_close_notify); - - // Ideally, this buffer would never be used. It is needed when `iovecs` are - // too small to fit the cleartext, which may be as large as `max_ciphertext_len`. - var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined; - // Temporarily stores ciphertext before decrypting it and giving it to `iovecs`. - var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined; - // How many bytes left in the user's buffer. - const free_size = vp.freeSize(); - // The amount of the user's buffer that we need to repurpose for storing - // ciphertext. The end of the buffer will be used for such purposes. - const ciphertext_buf_len = (free_size / 2) -| in_stack_buffer.len; - // The amount of the user's buffer that will be used to give cleartext. The - // beginning of the buffer will be used for such purposes. - const cleartext_buf_len = free_size - ciphertext_buf_len; - - // Recoup `partially_read_buffer` space. This is necessary because it is assumed - // below that `frag0` is big enough to hold at least one record. - limitedOverlapCopy(c.partially_read_buffer[0..c.partial_ciphertext_end], c.partial_ciphertext_idx); - c.partial_ciphertext_end -= c.partial_ciphertext_idx; - c.partial_ciphertext_idx = 0; - c.partial_cleartext_idx = 0; - const first_iov = c.partially_read_buffer[c.partial_ciphertext_end..]; - - var ask_iovecs_buf: [2]std.posix.iovec = .{ - .{ - .base = first_iov.ptr, - .len = first_iov.len, +fn stream(r: *Reader, w: *Writer, limit: std.io.Limit) Reader.StreamError!usize { + const c: *Client = @fieldParentPtr("reader", r); + if (c.eof()) return error.EndOfStream; + const input = c.input; + // If at least one full encrypted record is not buffered, read once. + const record_header = input.peek(tls.record_header_len) catch |err| switch (err) { + error.EndOfStream => { + // This is either a truncation attack, a bug in the server, or an + // intentional omission of the close_notify message due to truncation + // detection handled above the TLS layer. + if (c.allow_truncation_attacks) { + c.received_close_notify = true; + return error.EndOfStream; + } else { + return failRead(c, error.TlsConnectionTruncated); + } }, - .{ - .base = &in_stack_buffer, - .len = in_stack_buffer.len, + error.ReadFailed => return error.ReadFailed, + }; + const ct: tls.ContentType = @enumFromInt(record_header[0]); + const legacy_version = mem.readInt(u16, record_header[1..][0..2], .big); + _ = legacy_version; + const record_len = mem.readInt(u16, record_header[3..][0..2], .big); + if (record_len > max_ciphertext_len) return failRead(c, error.TlsRecordOverflow); + const record_end = 5 + record_len; + if (record_end > input.buffered().len) { + input.fillMore() catch |err| switch (err) { + error.EndOfStream => return failRead(c, error.TlsConnectionTruncated), + error.ReadFailed => return error.ReadFailed, + }; + if (record_end > input.buffered().len) return 0; + } + + var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined; + const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) { + inline else => |*p| switch (c.tls_version) { + .tls_1_3 => { + const pv = &p.tls_1_3; + const P = @TypeOf(p.*); + const ad = input.take(tls.record_header_len) catch unreachable; // already peeked + const ciphertext_len = record_len - P.AEAD.tag_length; + const ciphertext = input.take(ciphertext_len) catch unreachable; // already peeked + const auth_tag = (input.takeArray(P.AEAD.tag_length) catch unreachable).*; // already peeked + const nonce = nonce: { + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ std.mem.toBytes(big(c.read_seq)); + break :nonce @as(V, pv.server_iv) ^ operand; + }; + const cleartext = cleartext_stack_buffer[0..ciphertext.len]; + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch + return failRead(c, error.TlsBadRecordMac); + const msg = mem.trimRight(u8, cleartext, "\x00"); + break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) }; + }, + .tls_1_2 => { + const pv = &p.tls_1_2; + const P = @TypeOf(p.*); + const message_len: u16 = record_len - P.record_iv_length - P.mac_length; + const ad_header = input.take(tls.record_header_len) catch unreachable; // already peeked + const ad = std.mem.toBytes(big(c.read_seq)) ++ + ad_header[0 .. 1 + 2] ++ + std.mem.toBytes(big(message_len)); + const record_iv = (input.takeArray(P.record_iv_length) catch unreachable).*; // already peeked + const masked_read_seq = c.read_seq & + comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length); + const nonce: [P.AEAD.nonce_length]u8 = nonce: { + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ @as([8]u8, @bitCast(big(masked_read_seq))); + break :nonce @as(V, pv.server_write_IV ++ record_iv) ^ operand; + }; + const ciphertext = input.take(message_len) catch unreachable; // already peeked + const auth_tag = (input.takeArray(P.mac_length) catch unreachable).*; // already peeked + const cleartext = cleartext_stack_buffer[0..ciphertext.len]; + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch + return failRead(c, error.TlsBadRecordMac); + break :cleartext .{ cleartext, ct }; + }, + else => unreachable, }, }; - - // Cleartext capacity of output buffer, in records. Minimum one full record. - const buf_cap = @max(cleartext_buf_len / max_ciphertext_len, 1); - const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len); - const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len) - c.partial_ciphertext_end; - const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len); - const actual_read_len = try stream.readv(ask_iovecs); - if (actual_read_len == 0) { - // This is either a truncation attack, a bug in the server, or an - // intentional omission of the close_notify message due to truncation - // detection handled above the TLS layer. - if (c.allow_truncation_attacks) { - c.received_close_notify = true; - } else { - return error.TlsConnectionTruncated; - } - } - - // There might be more bytes inside `in_stack_buffer` that need to be processed, - // but at least frag0 will have one complete ciphertext record. - const frag0_end = @min(c.partially_read_buffer.len, c.partial_ciphertext_end + actual_read_len); - const frag0 = c.partially_read_buffer[c.partial_ciphertext_idx..frag0_end]; - var frag1 = in_stack_buffer[0..actual_read_len -| first_iov.len]; - // We need to decipher frag0 and frag1 but there may be a ciphertext record - // straddling the boundary. We can handle this with two memcpy() calls to - // assemble the straddling record in between handling the two sides. - var frag = frag0; - var in: usize = 0; - while (true) { - if (in == frag.len) { - // Perfect split. - if (frag.ptr == frag1.ptr) { - c.partial_ciphertext_end = c.partial_ciphertext_idx; - return vp.total; - } - frag = frag1; - in = 0; - continue; - } - - if (in + tls.record_header_len > frag.len) { - if (frag.ptr == frag1.ptr) - return finishRead(c, frag, in, vp.total); - - const first = frag[in..]; - - if (frag1.len < tls.record_header_len) - return finishRead2(c, first, frag1, vp.total); - - // A record straddles the two fragments. Copy into the now-empty first fragment. - const record_len_byte_0: u16 = straddleByte(frag, frag1, in + 3); - const record_len_byte_1: u16 = straddleByte(frag, frag1, in + 4); - const record_len = (record_len_byte_0 << 8) | record_len_byte_1; - if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; - - const full_record_len = record_len + tls.record_header_len; - const second_len = full_record_len - first.len; - if (frag1.len < second_len) - return finishRead2(c, first, frag1, vp.total); - - limitedOverlapCopy(frag, in); - @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]); - frag = frag[0..full_record_len]; - frag1 = frag1[second_len..]; - in = 0; - continue; - } - const ct: tls.ContentType = @enumFromInt(frag[in]); - in += 1; - const legacy_version = mem.readInt(u16, frag[in..][0..2], .big); - in += 2; - _ = legacy_version; - const record_len = mem.readInt(u16, frag[in..][0..2], .big); - if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; - in += 2; - const end = in + record_len; - if (end > frag.len) { - // We need the record header on the next iteration of the loop. - in -= tls.record_header_len; - - if (frag.ptr == frag1.ptr) - return finishRead(c, frag, in, vp.total); - - // A record straddles the two fragments. Copy into the now-empty first fragment. - const first = frag[in..]; - const full_record_len = record_len + tls.record_header_len; - const second_len = full_record_len - first.len; - if (frag1.len < second_len) - return finishRead2(c, first, frag1, vp.total); - - limitedOverlapCopy(frag, in); - @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]); - frag = frag[0..full_record_len]; - frag1 = frag1[second_len..]; - in = 0; - continue; - } - const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) { - inline else => |*p| switch (c.tls_version) { - .tls_1_3 => { - const pv = &p.tls_1_3; - const P = @TypeOf(p.*); - const ad = frag[in - tls.record_header_len ..][0..tls.record_header_len]; - const ciphertext_len = record_len - P.AEAD.tag_length; - const ciphertext = frag[in..][0..ciphertext_len]; - in += ciphertext_len; - const auth_tag = frag[in..][0..P.AEAD.tag_length].*; - const nonce = nonce: { - const V = @Vector(P.AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ std.mem.toBytes(big(c.read_seq)); - break :nonce @as(V, pv.server_iv) ^ operand; - }; - const out_buf = vp.peek(); - const cleartext_buf = if (ciphertext.len <= out_buf.len) - out_buf - else - &cleartext_stack_buffer; - const cleartext = cleartext_buf[0..ciphertext.len]; - P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch - return error.TlsBadRecordMac; - const msg = mem.trimEnd(u8, cleartext, "\x00"); - break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) }; - }, - .tls_1_2 => { - const pv = &p.tls_1_2; - const P = @TypeOf(p.*); - const message_len: u16 = record_len - P.record_iv_length - P.mac_length; - const ad = std.mem.toBytes(big(c.read_seq)) ++ - frag[in - tls.record_header_len ..][0 .. 1 + 2] ++ - std.mem.toBytes(big(message_len)); - const record_iv = frag[in..][0..P.record_iv_length].*; - in += P.record_iv_length; - const masked_read_seq = c.read_seq & - comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length); - const nonce: [P.AEAD.nonce_length]u8 = nonce: { - const V = @Vector(P.AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @as([8]u8, @bitCast(big(masked_read_seq))); - break :nonce @as(V, pv.server_write_IV ++ record_iv) ^ operand; - }; - const ciphertext = frag[in..][0..message_len]; - in += message_len; - const auth_tag = frag[in..][0..P.mac_length].*; - in += P.mac_length; - const out_buf = vp.peek(); - const cleartext_buf = if (message_len <= out_buf.len) - out_buf - else - &cleartext_stack_buffer; - const cleartext = cleartext_buf[0..ciphertext.len]; - P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch - return error.TlsBadRecordMac; - break :cleartext .{ cleartext, ct }; - }, - else => unreachable, - }, - }; - c.read_seq = try std.math.add(u64, c.read_seq, 1); - switch (inner_ct) { - .alert => { - if (cleartext.len != 2) return error.TlsDecodeError; - const level: tls.AlertLevel = @enumFromInt(cleartext[0]); - const desc: tls.AlertDescription = @enumFromInt(cleartext[1]); - if (desc == .close_notify) { + c.read_seq = std.math.add(u64, c.read_seq, 1) catch return failRead(c, error.TlsSequenceOverflow); + switch (inner_ct) { + .alert => { + if (cleartext.len != 2) return failRead(c, error.TlsDecodeError); + const alert: tls.Alert = .{ + .level = @enumFromInt(cleartext[0]), + .description = @enumFromInt(cleartext[1]), + }; + switch (alert.description) { + .close_notify => { c.received_close_notify = true; - c.partial_ciphertext_end = c.partial_ciphertext_idx; - return vp.total; - } - _ = level; - - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; - }, - .handshake => { - var ct_i: usize = 0; - while (true) { - const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]); - ct_i += 1; - const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big); - ct_i += 3; - const next_handshake_i = ct_i + handshake_len; - if (next_handshake_i > cleartext.len) - return error.TlsBadLength; - const handshake = cleartext[ct_i..next_handshake_i]; - switch (handshake_type) { - .new_session_ticket => { - // This client implementation ignores new session tickets. - }, - .key_update => { - switch (c.application_cipher) { - inline else => |*p| { - const pv = &p.tls_1_3; - const P = @TypeOf(p.*); - const server_secret = hkdfExpandLabel(P.Hkdf, pv.server_secret, "traffic upd", "", P.Hash.digest_length); - if (c.ssl_key_log) |*key_log| logSecrets(key_log.file, .{ - .counter = key_log.serverCounter(), - .client_random = &key_log.client_random, - }, .{ - .SERVER_TRAFFIC_SECRET = &server_secret, - }); - pv.server_secret = server_secret; - pv.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); - pv.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); - }, - } - c.read_seq = 0; - - switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) { - .update_requested => { - switch (c.application_cipher) { - inline else => |*p| { - const pv = &p.tls_1_3; - const P = @TypeOf(p.*); - const client_secret = hkdfExpandLabel(P.Hkdf, pv.client_secret, "traffic upd", "", P.Hash.digest_length); - if (c.ssl_key_log) |*key_log| logSecrets(key_log.file, .{ - .counter = key_log.clientCounter(), - .client_random = &key_log.client_random, - }, .{ - .CLIENT_TRAFFIC_SECRET = &client_secret, - }); - pv.client_secret = client_secret; - pv.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); - pv.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); - }, - } - c.write_seq = 0; - }, - .update_not_requested => {}, - _ => return error.TlsIllegalParameter, - } - }, - else => { - return error.TlsUnexpectedMessage; - }, - } - ct_i = next_handshake_i; - if (ct_i >= cleartext.len) break; - } - }, - .application_data => { - // Determine whether the output buffer or a stack - // buffer was used for storing the cleartext. - if (cleartext.ptr == &cleartext_stack_buffer) { - // Stack buffer was used, so we must copy to the output buffer. - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // We have already run out of room in iovecs. Continue - // appending to `partially_read_buffer`. - @memcpy( - c.partially_read_buffer[c.partial_ciphertext_idx..][0..cleartext.len], - cleartext, - ); - c.partial_ciphertext_idx = @intCast(c.partial_ciphertext_idx + cleartext.len); - } else { - const amt = vp.put(cleartext); - if (amt < cleartext.len) { - const rest = cleartext[amt..]; - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = @intCast(rest.len); - @memcpy(c.partially_read_buffer[0..rest.len], rest); + return 0; + }, + .user_canceled => { + // TODO: handle server-side closures + return failRead(c, error.TlsUnexpectedMessage); + }, + else => { + c.alert = alert; + return failRead(c, error.TlsAlert); + }, + } + }, + .handshake => { + var ct_i: usize = 0; + while (true) { + const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]); + ct_i += 1; + const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big); + ct_i += 3; + const next_handshake_i = ct_i + handshake_len; + if (next_handshake_i > cleartext.len) return failRead(c, error.TlsBadLength); + const handshake = cleartext[ct_i..next_handshake_i]; + switch (handshake_type) { + .new_session_ticket => { + // This client implementation ignores new session tickets. + }, + .key_update => { + switch (c.application_cipher) { + inline else => |*p| { + const pv = &p.tls_1_3; + const P = @TypeOf(p.*); + const server_secret = hkdfExpandLabel(P.Hkdf, pv.server_secret, "traffic upd", "", P.Hash.digest_length); + if (c.ssl_key_log) |key_log| logSecrets(key_log.writer, .{ + .counter = key_log.serverCounter(), + .client_random = &key_log.client_random, + }, .{ + .SERVER_TRAFFIC_SECRET = &server_secret, + }); + pv.server_secret = server_secret; + pv.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); + pv.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); + }, } - } - } else { - // Output buffer was used directly which means no - // memory copying needs to occur, and we can move - // on to the next ciphertext record. - vp.next(cleartext.len); + c.read_seq = 0; + + switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) { + .update_requested => { + switch (c.application_cipher) { + inline else => |*p| { + const pv = &p.tls_1_3; + const P = @TypeOf(p.*); + const client_secret = hkdfExpandLabel(P.Hkdf, pv.client_secret, "traffic upd", "", P.Hash.digest_length); + if (c.ssl_key_log) |key_log| logSecrets(key_log.writer, .{ + .counter = key_log.clientCounter(), + .client_random = &key_log.client_random, + }, .{ + .CLIENT_TRAFFIC_SECRET = &client_secret, + }); + pv.client_secret = client_secret; + pv.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); + pv.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); + }, + } + c.write_seq = 0; + }, + .update_not_requested => {}, + _ => return failRead(c, error.TlsIllegalParameter), + } + }, + else => return failRead(c, error.TlsUnexpectedMessage), } - }, - else => return error.TlsUnexpectedMessage, - } - in = end; + ct_i = next_handshake_i; + if (ct_i >= cleartext.len) break; + } + return 0; + }, + .application_data => { + if (@intFromEnum(limit) < cleartext.len) return failRead(c, error.OutputBufferUndersize); + try w.writeAll(cleartext); + return cleartext.len; + }, + else => return failRead(c, error.TlsUnexpectedMessage), } } -fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) void { - const locked = if (key_log_file.lock(.exclusive)) |_| true else |_| false; - defer if (locked) key_log_file.unlock(); - key_log_file.seekFromEnd(0) catch {}; - inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| key_log_file.deprecatedWriter().print("{s}" ++ +fn failRead(c: *Client, err: ReadError) error{ReadFailed} { + c.read_err = err; + return error.ReadFailed; +} + +fn logSecrets(w: *Writer, context: anytype, secrets: anytype) void { + inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| w.print("{s}" ++ (if (@hasField(@TypeOf(context), "counter")) "_{d}" else "") ++ " {x} {x}\n", .{field.name} ++ (if (@hasField(@TypeOf(context), "counter")) .{context.counter} else .{}) ++ .{ context.client_random, @@ -1520,62 +1235,6 @@ fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) voi }) catch {}; } -fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize { - const saved_buf = frag[in..]; - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // There is cleartext at the beginning already which we need to preserve. - c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + saved_buf.len); - @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx..][0..saved_buf.len], saved_buf); - } else { - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = 0; - c.partial_ciphertext_end = @intCast(saved_buf.len); - @memcpy(c.partially_read_buffer[0..saved_buf.len], saved_buf); - } - return out; -} - -/// Note that `first` usually overlaps with `c.partially_read_buffer`. -fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize { - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // There is cleartext at the beginning already which we need to preserve. - c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + first.len + frag1.len); - // TODO: eliminate this call to copyForwards - std.mem.copyForwards(u8, c.partially_read_buffer[c.partial_ciphertext_idx..][0..first.len], first); - @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx + first.len ..][0..frag1.len], frag1); - } else { - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = 0; - c.partial_ciphertext_end = @intCast(first.len + frag1.len); - // TODO: eliminate this call to copyForwards - std.mem.copyForwards(u8, c.partially_read_buffer[0..first.len], first); - @memcpy(c.partially_read_buffer[first.len..][0..frag1.len], frag1); - } - return out; -} - -fn limitedOverlapCopy(frag: []u8, in: usize) void { - const first = frag[in..]; - if (first.len <= in) { - // A single, non-overlapping memcpy suffices. - @memcpy(frag[0..first.len], first); - } else { - // One memcpy call would overlap, so just do this instead. - std.mem.copyForwards(u8, frag, first); - } -} - -fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 { - if (index < s1.len) { - return s1[index]; - } else { - return s2[index - s1.len]; - } -} - -const builtin = @import("builtin"); -const native_endian = builtin.cpu.arch.endian(); - fn big(x: anytype) @TypeOf(x) { return switch (native_endian) { .big => x, @@ -1836,81 +1495,6 @@ const CertificatePublicKey = struct { } }; -/// Abstraction for sending multiple byte buffers to a slice of iovecs. -const VecPut = struct { - iovecs: []const std.posix.iovec, - idx: usize = 0, - off: usize = 0, - total: usize = 0, - - /// Returns the amount actually put which is always equal to bytes.len - /// unless the vectors ran out of space. - fn put(vp: *VecPut, bytes: []const u8) usize { - if (vp.idx >= vp.iovecs.len) return 0; - var bytes_i: usize = 0; - while (true) { - const v = vp.iovecs[vp.idx]; - const dest = v.base[vp.off..v.len]; - const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)]; - @memcpy(dest[0..src.len], src); - bytes_i += src.len; - vp.off += src.len; - if (vp.off >= v.len) { - vp.off = 0; - vp.idx += 1; - if (vp.idx >= vp.iovecs.len) { - vp.total += bytes_i; - return bytes_i; - } - } - if (bytes_i >= bytes.len) { - vp.total += bytes_i; - return bytes_i; - } - } - } - - /// Returns the next buffer that consecutive bytes can go into. - fn peek(vp: VecPut) []u8 { - if (vp.idx >= vp.iovecs.len) return &.{}; - const v = vp.iovecs[vp.idx]; - return v.base[vp.off..v.len]; - } - - // After writing to the result of peek(), one can call next() to - // advance the cursor. - fn next(vp: *VecPut, len: usize) void { - vp.total += len; - vp.off += len; - if (vp.off >= vp.iovecs[vp.idx].len) { - vp.off = 0; - vp.idx += 1; - } - } - - fn freeSize(vp: VecPut) usize { - if (vp.idx >= vp.iovecs.len) return 0; - var total: usize = 0; - total += vp.iovecs[vp.idx].len - vp.off; - if (vp.idx + 1 >= vp.iovecs.len) return total; - for (vp.iovecs[vp.idx + 1 ..]) |v| total += v.len; - return total; - } -}; - -/// Limit iovecs to a specific byte size. -fn limitVecs(iovecs: []std.posix.iovec, len: usize) []std.posix.iovec { - var bytes_left: usize = len; - for (iovecs, 0..) |*iovec, vec_i| { - if (bytes_left <= iovec.len) { - iovec.len = bytes_left; - return iovecs[0 .. vec_i + 1]; - } - bytes_left -= iovec.len; - } - return iovecs; -} - /// The priority order here is chosen based on what crypto algorithms Zig has /// available in the standard library as well as what is faster. Following are /// a few data points on the relative performance of these algorithms. @@ -1954,7 +1538,3 @@ else .AES_256_GCM_SHA384, .ECDHE_RSA_WITH_AES_256_GCM_SHA384, }); - -test { - _ = StreamInterface; -} diff --git a/lib/std/http.zig b/lib/std/http.zig index 6075a2fe6d..c64a946a25 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -343,10 +343,9 @@ pub const Reader = struct { /// read from `in`. trailers: []const u8 = &.{}, body_err: ?BodyError = null, - /// Stolen from `in`. - head_buffer: []u8 = &.{}, - - pub const max_chunk_header_len = 22; + /// Determines at which point `error.HttpHeadersOversize` occurs, as well + /// as the minimum buffer capacity of `in`. + max_head_len: usize, pub const RemainingChunkLen = enum(u64) { head = 0, @@ -398,19 +397,11 @@ pub const Reader = struct { ReadFailed, }; - pub fn restituteHeadBuffer(reader: *Reader) void { - reader.in.restitute(reader.head_buffer.len); - reader.head_buffer.len = 0; - } - - /// Buffers the entire head into `head_buffer`, invalidating the previous - /// `head_buffer`, if any. + /// Buffers the entire head. pub fn receiveHead(reader: *Reader) HeadError!void { reader.trailers = &.{}; const in = reader.in; - in.restitute(reader.head_buffer.len); - reader.head_buffer.len = 0; - in.rebase(); + try in.rebase(reader.max_head_len); var hp: HeadParser = .{}; var head_end: usize = 0; while (true) {