std.crypto.tls.Client: upgrade to std.io.BufferedWriter

This is pretty clearly a better API.
This commit is contained in:
Andrew Kelley 2025-02-18 23:33:27 -08:00
parent d87b59f5a5
commit 6c7b103122

View File

@ -44,6 +44,8 @@ application_cipher: tls.ApplicationCipher,
/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and
/// `partial_ciphertext_end` describe the span of the segments.
partially_read_buffer: [tls.max_ciphertext_record_len]u8,
/// Encrypted bytes sent to the server here.
output: *std.io.BufferedWriter,
/// If non-null, ssl secrets are logged to a file. Creating such a log file allows other
/// programs with access to that file to decrypt all traffic over this connection.
ssl_key_log: ?struct {
@ -84,24 +86,6 @@ pub const StreamInterface = struct {
_ = .{ this, iovecs };
@panic("unimplemented");
}
/// Can be any error set.
pub const WriteError = error{};
/// Returns the number of bytes read, which may be less than the buffer
/// space provided. A short read does not indicate end-of-stream.
pub fn writev(this: @This(), iovecs: []const std.posix.iovec_const) WriteError!usize {
_ = .{ this, iovecs };
@panic("unimplemented");
}
/// The `iovecs` parameter is mutable in case this function needs to mutate
/// the fields in order to handle partial writes from the underlying layer.
pub fn writevAll(this: @This(), iovecs: []std.posix.iovec_const) WriteError!void {
// This can be implemented in terms of writev, or specialized if desired.
_ = .{ this, iovecs };
@panic("unimplemented");
}
};
pub const Options = struct {
@ -129,61 +113,92 @@ pub const Options = struct {
/// other programs with access to that file to decrypt all traffic over this connection.
/// TODO `std.crypto` should have no dependencies on `std.fs`.
ssl_key_log_file: ?std.fs.File = null,
diagnostics: ?*Diagnostics = null,
pub const Diagnostics = union {
/// Populated on `error.WriteFailure` and `error.ReadFailure`.
err: anyerror,
/// Populated on `error.TlsAlert`.
///
/// If this isn't a error alert, then it's a closure alert, which makes
/// no sense in a handshake.
alert: tls.AlertDescription,
};
};
pub fn InitError(comptime Stream: type) type {
return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || tls.AlertDescription.Error || error{
InsufficientEntropy,
DiskQuota,
LockViolation,
NotOpenForWriting,
TlsUnexpectedMessage,
TlsIllegalParameter,
TlsDecryptFailure,
TlsRecordOverflow,
TlsBadRecordMac,
CertificateFieldHasInvalidLength,
CertificateHostMismatch,
CertificatePublicKeyInvalid,
CertificateExpired,
CertificateFieldHasWrongDataType,
CertificateIssuerMismatch,
CertificateNotYetValid,
CertificateSignatureAlgorithmMismatch,
CertificateSignatureAlgorithmUnsupported,
CertificateSignatureInvalid,
CertificateSignatureInvalidLength,
CertificateSignatureNamedCurveUnsupported,
CertificateSignatureUnsupportedBitCount,
TlsCertificateNotVerified,
TlsBadSignatureScheme,
TlsBadRsaSignatureBitCount,
InvalidEncoding,
IdentityElement,
SignatureVerificationFailed,
TlsDecryptError,
TlsConnectionTruncated,
TlsDecodeError,
UnsupportedCertificateVersion,
CertificateTimeInvalid,
CertificateHasUnrecognizedObjectId,
CertificateHasInvalidBitString,
MessageTooLong,
NegativeIntoUnsigned,
TargetTooSmall,
BufferTooSmall,
InvalidSignature,
NotSquare,
NonCanonical,
WeakPublicKey,
/// TODO I wish this could be a method of Diagnostics
fn wrapWrite(opt_diags: ?*Options.Diagnostics, returned: anyerror!void) error{WriteFailure}!void {
returned catch |err| {
if (opt_diags) |diags| diags.* = .{ .err = err };
return error.WriteFailure;
};
}
/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session with `stream`, which
/// TODO I wish this could be a method of Diagnostics
fn wrapRead(opt_diags: ?*Options.Diagnostics, returned: anyerror!void) error{ReadFailure}!void {
returned catch |err| {
if (opt_diags) |diags| diags.* = .{ .err = err };
return error.ReadFailure;
};
}
const InitError = error{
//OutOfMemory,
WriteFailure,
ReadFailure,
InsufficientEntropy,
DiskQuota,
LockViolation,
NotOpenForWriting,
/// The alert description will be stored in `Options.Diagnostics.alert`.
TlsAlert,
TlsUnexpectedMessage,
TlsIllegalParameter,
TlsDecryptFailure,
TlsRecordOverflow,
TlsBadRecordMac,
CertificateFieldHasInvalidLength,
CertificateHostMismatch,
CertificatePublicKeyInvalid,
CertificateExpired,
CertificateFieldHasWrongDataType,
CertificateIssuerMismatch,
CertificateNotYetValid,
CertificateSignatureAlgorithmMismatch,
CertificateSignatureAlgorithmUnsupported,
CertificateSignatureInvalid,
CertificateSignatureInvalidLength,
CertificateSignatureNamedCurveUnsupported,
CertificateSignatureUnsupportedBitCount,
TlsCertificateNotVerified,
TlsBadSignatureScheme,
TlsBadRsaSignatureBitCount,
InvalidEncoding,
IdentityElement,
SignatureVerificationFailed,
TlsDecryptError,
TlsConnectionTruncated,
TlsDecodeError,
UnsupportedCertificateVersion,
CertificateTimeInvalid,
CertificateHasUnrecognizedObjectId,
CertificateHasInvalidBitString,
MessageTooLong,
NegativeIntoUnsigned,
TargetTooSmall,
BufferTooSmall,
InvalidSignature,
NotSquare,
NonCanonical,
WeakPublicKey,
};
/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session with `input`, which
/// must conform to `StreamInterface`.
///
/// `host` is only borrowed during this function call.
pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client {
pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) InitError!Client {
const diags = options.diagnostics;
const host = switch (options.host) {
.no_verification => "",
.explicit => |host| host,
@ -275,11 +290,8 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
};
{
var iovecs = [_]std.posix.iovec_const{
.{ .base = cleartext_header.ptr, .len = cleartext_header.len },
.{ .base = host.ptr, .len = host.len },
};
try stream.writevAll(iovecs[0..if (host.len == 0) 1 else 2]);
var iovecs: [2][]const u8 = .{ cleartext_header, host };
try wrapWrite(diags, output.writevAll(iovecs[0..if (host.len == 0) 1 else 2]));
}
var tls_version: tls.ProtocolVersion = undefined;
@ -331,12 +343,12 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
var handshake_buffer: [tls.max_ciphertext_record_len]u8 = undefined;
var d: tls.Decoder = .{ .buf = &handshake_buffer };
fragment: while (true) {
try d.readAtLeastOurAmt(stream, tls.record_header_len);
try wrapRead(diags, d.readAtLeastOurAmt(input, tls.record_header_len));
const record_header = d.buf[d.idx..][0..tls.record_header_len];
const record_ct = d.decode(tls.ContentType);
d.skip(2); // legacy_version
const record_len = d.decode(u16);
try d.readAtLeast(stream, record_len);
try wrapRead(diags, d.readAtLeast(input, record_len));
var record_decoder = try d.sub(record_len);
var ctd, const ct = content: switch (cipher_state) {
.cleartext => .{ record_decoder, record_ct },
@ -414,11 +426,8 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
const level = ctd.decode(tls.AlertLevel);
const desc = ctd.decode(tls.AlertDescription);
_ = level;
// if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake
try desc.toError();
// TODO: handle server-side closures
return error.TlsUnexpectedMessage;
if (diags) |x| x.* = .{ .alert = desc };
return error.TlsAlert;
},
.change_cipher_spec => {
ctd.ensure(1) catch continue :fragment;
@ -754,11 +763,12 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
nonce,
pv.app_cipher.client_write_key,
);
const all_msgs = client_key_exchange_msg ++ client_change_cipher_spec_msg ++ client_verify_msg;
var all_msgs_vec = [_]std.posix.iovec_const{
.{ .base = &all_msgs, .len = all_msgs.len },
var all_msgs_vec: [3][]const u8 = .{
&client_key_exchange_msg,
&client_change_cipher_spec_msg,
&client_verify_msg,
};
try stream.writevAll(&all_msgs_vec);
try wrapWrite(diags, output.writevAll(&all_msgs_vec));
},
}
write_seq += 1;
@ -819,11 +829,11 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
const nonce = pv.client_handshake_iv;
P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, pv.client_handshake_key);
const all_msgs = client_change_cipher_spec_msg ++ finished_msg;
var all_msgs_vec = [_]std.posix.iovec_const{
.{ .base = &all_msgs, .len = all_msgs.len },
var all_msgs_vec: [2][]const u8 = .{
&client_change_cipher_spec_msg,
&finished_msg,
};
try stream.writevAll(&all_msgs_vec);
try wrapWrite(diags, output.writevAll(&all_msgs_vec));
const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
@ -873,6 +883,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
.received_close_notify = false,
.allow_truncation_attacks = false,
.application_cipher = app_cipher,
.output = output,
.partially_read_buffer = undefined,
.ssl_key_log = if (options.ssl_key_log_file) |key_log_file| .{
.client_key_seq = key_seq,
@ -896,39 +907,39 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
}
}
/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`.
pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize {
return writeEnd(c, stream, bytes, false);
pub fn writer(c: *Client) std.io.Writer {
return .{
.context = c,
.vtable = &.{
.writeSplat = writeSplat,
.writeFile = std.io.Writer.unimplemented_writeFile,
},
};
}
/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
pub fn writeAll(c: *Client, stream: anytype, bytes: []const u8) !void {
fn writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize {
const c: *Client = @alignCast(@ptrCast(context));
assert(data.len > 1 or splat > 0);
return writeEnd(c, data[0], false);
}
/// If `end` is true, then this function additionally sends a `close_notify`
/// alert, which is necessary for the server to distinguish between a properly
/// finished TLS session, or a truncation attack.
pub fn writeAllEnd(c: *Client, bytes: []const u8, end: bool) anyerror!void {
var index: usize = 0;
while (index < bytes.len) {
index += try c.write(stream, bytes[index..]);
index += try c.writeEnd(bytes[index..], end);
}
}
/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
/// If `end` is true, then this function additionally sends a `close_notify` alert,
/// which is necessary for the server to distinguish between a properly finished
/// TLS session, or a truncation attack.
pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !void {
var index: usize = 0;
while (index < bytes.len) {
index += try c.writeEnd(stream, bytes[index..], end);
}
}
/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`.
/// If `end` is true, then this function additionally sends a `close_notify` alert,
/// which is necessary for the server to distinguish between a properly finished
/// TLS session, or a truncation attack.
pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize {
pub fn writeEnd(c: *Client, bytes: []const u8, end: bool) anyerror!usize {
var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined;
var iovecs_buf: [6]std.posix.iovec_const = undefined;
var iovecs_buf: [6][]const u8 = undefined;
var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data);
if (end) {
prepared.iovec_end += prepareCiphertextRecord(
@ -948,7 +959,7 @@ pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usiz
var i: usize = 0;
var total_amt: usize = 0;
while (true) {
var amt = try stream.writev(iovecs_buf[i..iovec_end]);
var amt = try c.output.writev(iovecs_buf[i..iovec_end]);
while (amt >= iovecs_buf[i].len) {
const encrypted_amt = iovecs_buf[i].len;
total_amt += encrypted_amt - overhead_len;
@ -962,14 +973,13 @@ pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usiz
// not sent; otherwise the caller would not know to retry the call.
if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt;
}
iovecs_buf[i].base += amt;
iovecs_buf[i].len -= amt;
iovecs_buf[i] = iovecs_buf[i][amt..];
}
}
fn prepareCiphertextRecord(
c: *Client,
iovecs: []std.posix.iovec_const,
iovecs: [][]const u8,
ciphertext_buf: []u8,
bytes: []const u8,
inner_content_type: tls.ContentType,
@ -1031,10 +1041,7 @@ fn prepareCiphertextRecord(
c.write_seq += 1; // TODO send key_update on overflow
const record = ciphertext_buf[record_start..ciphertext_end];
iovecs[iovec_end] = .{
.base = record.ptr,
.len = record.len,
};
iovecs[iovec_end] = record;
iovec_end += 1;
}
},
@ -1084,10 +1091,7 @@ fn prepareCiphertextRecord(
c.write_seq += 1; // TODO send key_update on overflow
const record = ciphertext_buf[record_start..ciphertext_end];
iovecs[iovec_end] = .{
.base = record.ptr,
.len = record.len,
};
iovecs[iovec_end] = record;
iovec_end += 1;
}
},
@ -1511,7 +1515,8 @@ fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) voi
const locked = if (key_log_file.lock(.exclusive)) |_| true else |_| false;
defer if (locked) key_log_file.unlock();
key_log_file.seekFromEnd(0) catch {};
inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| key_log_file.writer().print("{s}" ++
var w = key_log_file.writer().unbuffered();
inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| w.print("{s}" ++
(if (@hasField(@TypeOf(context), "counter")) "_{d}" else "") ++ " {x} {x}\n", .{field.name} ++
(if (@hasField(@TypeOf(context), "counter")) .{context.counter} else .{}) ++ .{
context.client_random,