mirror of
https://github.com/ziglang/zig.git
synced 2026-02-12 20:37:54 +00:00
std.crypto.tls.Client: upgrade to std.io.BufferedWriter
This is pretty clearly a better API.
This commit is contained in:
parent
d87b59f5a5
commit
6c7b103122
@ -44,6 +44,8 @@ application_cipher: tls.ApplicationCipher,
|
||||
/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and
|
||||
/// `partial_ciphertext_end` describe the span of the segments.
|
||||
partially_read_buffer: [tls.max_ciphertext_record_len]u8,
|
||||
/// Encrypted bytes sent to the server here.
|
||||
output: *std.io.BufferedWriter,
|
||||
/// If non-null, ssl secrets are logged to a file. Creating such a log file allows other
|
||||
/// programs with access to that file to decrypt all traffic over this connection.
|
||||
ssl_key_log: ?struct {
|
||||
@ -84,24 +86,6 @@ pub const StreamInterface = struct {
|
||||
_ = .{ this, iovecs };
|
||||
@panic("unimplemented");
|
||||
}
|
||||
|
||||
/// Can be any error set.
|
||||
pub const WriteError = error{};
|
||||
|
||||
/// Returns the number of bytes read, which may be less than the buffer
|
||||
/// space provided. A short read does not indicate end-of-stream.
|
||||
pub fn writev(this: @This(), iovecs: []const std.posix.iovec_const) WriteError!usize {
|
||||
_ = .{ this, iovecs };
|
||||
@panic("unimplemented");
|
||||
}
|
||||
|
||||
/// The `iovecs` parameter is mutable in case this function needs to mutate
|
||||
/// the fields in order to handle partial writes from the underlying layer.
|
||||
pub fn writevAll(this: @This(), iovecs: []std.posix.iovec_const) WriteError!void {
|
||||
// This can be implemented in terms of writev, or specialized if desired.
|
||||
_ = .{ this, iovecs };
|
||||
@panic("unimplemented");
|
||||
}
|
||||
};
|
||||
|
||||
pub const Options = struct {
|
||||
@ -129,61 +113,92 @@ pub const Options = struct {
|
||||
/// other programs with access to that file to decrypt all traffic over this connection.
|
||||
/// TODO `std.crypto` should have no dependencies on `std.fs`.
|
||||
ssl_key_log_file: ?std.fs.File = null,
|
||||
diagnostics: ?*Diagnostics = null,
|
||||
|
||||
pub const Diagnostics = union {
|
||||
/// Populated on `error.WriteFailure` and `error.ReadFailure`.
|
||||
err: anyerror,
|
||||
/// Populated on `error.TlsAlert`.
|
||||
///
|
||||
/// If this isn't a error alert, then it's a closure alert, which makes
|
||||
/// no sense in a handshake.
|
||||
alert: tls.AlertDescription,
|
||||
};
|
||||
};
|
||||
|
||||
pub fn InitError(comptime Stream: type) type {
|
||||
return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || tls.AlertDescription.Error || error{
|
||||
InsufficientEntropy,
|
||||
DiskQuota,
|
||||
LockViolation,
|
||||
NotOpenForWriting,
|
||||
TlsUnexpectedMessage,
|
||||
TlsIllegalParameter,
|
||||
TlsDecryptFailure,
|
||||
TlsRecordOverflow,
|
||||
TlsBadRecordMac,
|
||||
CertificateFieldHasInvalidLength,
|
||||
CertificateHostMismatch,
|
||||
CertificatePublicKeyInvalid,
|
||||
CertificateExpired,
|
||||
CertificateFieldHasWrongDataType,
|
||||
CertificateIssuerMismatch,
|
||||
CertificateNotYetValid,
|
||||
CertificateSignatureAlgorithmMismatch,
|
||||
CertificateSignatureAlgorithmUnsupported,
|
||||
CertificateSignatureInvalid,
|
||||
CertificateSignatureInvalidLength,
|
||||
CertificateSignatureNamedCurveUnsupported,
|
||||
CertificateSignatureUnsupportedBitCount,
|
||||
TlsCertificateNotVerified,
|
||||
TlsBadSignatureScheme,
|
||||
TlsBadRsaSignatureBitCount,
|
||||
InvalidEncoding,
|
||||
IdentityElement,
|
||||
SignatureVerificationFailed,
|
||||
TlsDecryptError,
|
||||
TlsConnectionTruncated,
|
||||
TlsDecodeError,
|
||||
UnsupportedCertificateVersion,
|
||||
CertificateTimeInvalid,
|
||||
CertificateHasUnrecognizedObjectId,
|
||||
CertificateHasInvalidBitString,
|
||||
MessageTooLong,
|
||||
NegativeIntoUnsigned,
|
||||
TargetTooSmall,
|
||||
BufferTooSmall,
|
||||
InvalidSignature,
|
||||
NotSquare,
|
||||
NonCanonical,
|
||||
WeakPublicKey,
|
||||
/// TODO I wish this could be a method of Diagnostics
|
||||
fn wrapWrite(opt_diags: ?*Options.Diagnostics, returned: anyerror!void) error{WriteFailure}!void {
|
||||
returned catch |err| {
|
||||
if (opt_diags) |diags| diags.* = .{ .err = err };
|
||||
return error.WriteFailure;
|
||||
};
|
||||
}
|
||||
|
||||
/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session with `stream`, which
|
||||
/// TODO I wish this could be a method of Diagnostics
|
||||
fn wrapRead(opt_diags: ?*Options.Diagnostics, returned: anyerror!void) error{ReadFailure}!void {
|
||||
returned catch |err| {
|
||||
if (opt_diags) |diags| diags.* = .{ .err = err };
|
||||
return error.ReadFailure;
|
||||
};
|
||||
}
|
||||
|
||||
const InitError = error{
|
||||
//OutOfMemory,
|
||||
WriteFailure,
|
||||
ReadFailure,
|
||||
InsufficientEntropy,
|
||||
DiskQuota,
|
||||
LockViolation,
|
||||
NotOpenForWriting,
|
||||
/// The alert description will be stored in `Options.Diagnostics.alert`.
|
||||
TlsAlert,
|
||||
TlsUnexpectedMessage,
|
||||
TlsIllegalParameter,
|
||||
TlsDecryptFailure,
|
||||
TlsRecordOverflow,
|
||||
TlsBadRecordMac,
|
||||
CertificateFieldHasInvalidLength,
|
||||
CertificateHostMismatch,
|
||||
CertificatePublicKeyInvalid,
|
||||
CertificateExpired,
|
||||
CertificateFieldHasWrongDataType,
|
||||
CertificateIssuerMismatch,
|
||||
CertificateNotYetValid,
|
||||
CertificateSignatureAlgorithmMismatch,
|
||||
CertificateSignatureAlgorithmUnsupported,
|
||||
CertificateSignatureInvalid,
|
||||
CertificateSignatureInvalidLength,
|
||||
CertificateSignatureNamedCurveUnsupported,
|
||||
CertificateSignatureUnsupportedBitCount,
|
||||
TlsCertificateNotVerified,
|
||||
TlsBadSignatureScheme,
|
||||
TlsBadRsaSignatureBitCount,
|
||||
InvalidEncoding,
|
||||
IdentityElement,
|
||||
SignatureVerificationFailed,
|
||||
TlsDecryptError,
|
||||
TlsConnectionTruncated,
|
||||
TlsDecodeError,
|
||||
UnsupportedCertificateVersion,
|
||||
CertificateTimeInvalid,
|
||||
CertificateHasUnrecognizedObjectId,
|
||||
CertificateHasInvalidBitString,
|
||||
MessageTooLong,
|
||||
NegativeIntoUnsigned,
|
||||
TargetTooSmall,
|
||||
BufferTooSmall,
|
||||
InvalidSignature,
|
||||
NotSquare,
|
||||
NonCanonical,
|
||||
WeakPublicKey,
|
||||
};
|
||||
|
||||
/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session with `input`, which
|
||||
/// must conform to `StreamInterface`.
|
||||
///
|
||||
/// `host` is only borrowed during this function call.
|
||||
pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client {
|
||||
pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) InitError!Client {
|
||||
const diags = options.diagnostics;
|
||||
const host = switch (options.host) {
|
||||
.no_verification => "",
|
||||
.explicit => |host| host,
|
||||
@ -275,11 +290,8 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
|
||||
};
|
||||
|
||||
{
|
||||
var iovecs = [_]std.posix.iovec_const{
|
||||
.{ .base = cleartext_header.ptr, .len = cleartext_header.len },
|
||||
.{ .base = host.ptr, .len = host.len },
|
||||
};
|
||||
try stream.writevAll(iovecs[0..if (host.len == 0) 1 else 2]);
|
||||
var iovecs: [2][]const u8 = .{ cleartext_header, host };
|
||||
try wrapWrite(diags, output.writevAll(iovecs[0..if (host.len == 0) 1 else 2]));
|
||||
}
|
||||
|
||||
var tls_version: tls.ProtocolVersion = undefined;
|
||||
@ -331,12 +343,12 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
|
||||
var handshake_buffer: [tls.max_ciphertext_record_len]u8 = undefined;
|
||||
var d: tls.Decoder = .{ .buf = &handshake_buffer };
|
||||
fragment: while (true) {
|
||||
try d.readAtLeastOurAmt(stream, tls.record_header_len);
|
||||
try wrapRead(diags, d.readAtLeastOurAmt(input, tls.record_header_len));
|
||||
const record_header = d.buf[d.idx..][0..tls.record_header_len];
|
||||
const record_ct = d.decode(tls.ContentType);
|
||||
d.skip(2); // legacy_version
|
||||
const record_len = d.decode(u16);
|
||||
try d.readAtLeast(stream, record_len);
|
||||
try wrapRead(diags, d.readAtLeast(input, record_len));
|
||||
var record_decoder = try d.sub(record_len);
|
||||
var ctd, const ct = content: switch (cipher_state) {
|
||||
.cleartext => .{ record_decoder, record_ct },
|
||||
@ -414,11 +426,8 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
|
||||
const level = ctd.decode(tls.AlertLevel);
|
||||
const desc = ctd.decode(tls.AlertDescription);
|
||||
_ = level;
|
||||
|
||||
// if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake
|
||||
try desc.toError();
|
||||
// TODO: handle server-side closures
|
||||
return error.TlsUnexpectedMessage;
|
||||
if (diags) |x| x.* = .{ .alert = desc };
|
||||
return error.TlsAlert;
|
||||
},
|
||||
.change_cipher_spec => {
|
||||
ctd.ensure(1) catch continue :fragment;
|
||||
@ -754,11 +763,12 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
|
||||
nonce,
|
||||
pv.app_cipher.client_write_key,
|
||||
);
|
||||
const all_msgs = client_key_exchange_msg ++ client_change_cipher_spec_msg ++ client_verify_msg;
|
||||
var all_msgs_vec = [_]std.posix.iovec_const{
|
||||
.{ .base = &all_msgs, .len = all_msgs.len },
|
||||
var all_msgs_vec: [3][]const u8 = .{
|
||||
&client_key_exchange_msg,
|
||||
&client_change_cipher_spec_msg,
|
||||
&client_verify_msg,
|
||||
};
|
||||
try stream.writevAll(&all_msgs_vec);
|
||||
try wrapWrite(diags, output.writevAll(&all_msgs_vec));
|
||||
},
|
||||
}
|
||||
write_seq += 1;
|
||||
@ -819,11 +829,11 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
|
||||
const nonce = pv.client_handshake_iv;
|
||||
P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, pv.client_handshake_key);
|
||||
|
||||
const all_msgs = client_change_cipher_spec_msg ++ finished_msg;
|
||||
var all_msgs_vec = [_]std.posix.iovec_const{
|
||||
.{ .base = &all_msgs, .len = all_msgs.len },
|
||||
var all_msgs_vec: [2][]const u8 = .{
|
||||
&client_change_cipher_spec_msg,
|
||||
&finished_msg,
|
||||
};
|
||||
try stream.writevAll(&all_msgs_vec);
|
||||
try wrapWrite(diags, output.writevAll(&all_msgs_vec));
|
||||
|
||||
const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
|
||||
const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
|
||||
@ -873,6 +883,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
|
||||
.received_close_notify = false,
|
||||
.allow_truncation_attacks = false,
|
||||
.application_cipher = app_cipher,
|
||||
.output = output,
|
||||
.partially_read_buffer = undefined,
|
||||
.ssl_key_log = if (options.ssl_key_log_file) |key_log_file| .{
|
||||
.client_key_seq = key_seq,
|
||||
@ -896,39 +907,39 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
|
||||
/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`.
|
||||
pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize {
|
||||
return writeEnd(c, stream, bytes, false);
|
||||
pub fn writer(c: *Client) std.io.Writer {
|
||||
return .{
|
||||
.context = c,
|
||||
.vtable = &.{
|
||||
.writeSplat = writeSplat,
|
||||
.writeFile = std.io.Writer.unimplemented_writeFile,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
|
||||
pub fn writeAll(c: *Client, stream: anytype, bytes: []const u8) !void {
|
||||
fn writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize {
|
||||
const c: *Client = @alignCast(@ptrCast(context));
|
||||
assert(data.len > 1 or splat > 0);
|
||||
return writeEnd(c, data[0], false);
|
||||
}
|
||||
|
||||
/// If `end` is true, then this function additionally sends a `close_notify`
|
||||
/// alert, which is necessary for the server to distinguish between a properly
|
||||
/// finished TLS session, or a truncation attack.
|
||||
pub fn writeAllEnd(c: *Client, bytes: []const u8, end: bool) anyerror!void {
|
||||
var index: usize = 0;
|
||||
while (index < bytes.len) {
|
||||
index += try c.write(stream, bytes[index..]);
|
||||
index += try c.writeEnd(bytes[index..], end);
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
|
||||
/// If `end` is true, then this function additionally sends a `close_notify` alert,
|
||||
/// which is necessary for the server to distinguish between a properly finished
|
||||
/// TLS session, or a truncation attack.
|
||||
pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !void {
|
||||
var index: usize = 0;
|
||||
while (index < bytes.len) {
|
||||
index += try c.writeEnd(stream, bytes[index..], end);
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
|
||||
/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`.
|
||||
/// If `end` is true, then this function additionally sends a `close_notify` alert,
|
||||
/// which is necessary for the server to distinguish between a properly finished
|
||||
/// TLS session, or a truncation attack.
|
||||
pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize {
|
||||
pub fn writeEnd(c: *Client, bytes: []const u8, end: bool) anyerror!usize {
|
||||
var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined;
|
||||
var iovecs_buf: [6]std.posix.iovec_const = undefined;
|
||||
var iovecs_buf: [6][]const u8 = undefined;
|
||||
var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data);
|
||||
if (end) {
|
||||
prepared.iovec_end += prepareCiphertextRecord(
|
||||
@ -948,7 +959,7 @@ pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usiz
|
||||
var i: usize = 0;
|
||||
var total_amt: usize = 0;
|
||||
while (true) {
|
||||
var amt = try stream.writev(iovecs_buf[i..iovec_end]);
|
||||
var amt = try c.output.writev(iovecs_buf[i..iovec_end]);
|
||||
while (amt >= iovecs_buf[i].len) {
|
||||
const encrypted_amt = iovecs_buf[i].len;
|
||||
total_amt += encrypted_amt - overhead_len;
|
||||
@ -962,14 +973,13 @@ pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usiz
|
||||
// not sent; otherwise the caller would not know to retry the call.
|
||||
if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt;
|
||||
}
|
||||
iovecs_buf[i].base += amt;
|
||||
iovecs_buf[i].len -= amt;
|
||||
iovecs_buf[i] = iovecs_buf[i][amt..];
|
||||
}
|
||||
}
|
||||
|
||||
fn prepareCiphertextRecord(
|
||||
c: *Client,
|
||||
iovecs: []std.posix.iovec_const,
|
||||
iovecs: [][]const u8,
|
||||
ciphertext_buf: []u8,
|
||||
bytes: []const u8,
|
||||
inner_content_type: tls.ContentType,
|
||||
@ -1031,10 +1041,7 @@ fn prepareCiphertextRecord(
|
||||
c.write_seq += 1; // TODO send key_update on overflow
|
||||
|
||||
const record = ciphertext_buf[record_start..ciphertext_end];
|
||||
iovecs[iovec_end] = .{
|
||||
.base = record.ptr,
|
||||
.len = record.len,
|
||||
};
|
||||
iovecs[iovec_end] = record;
|
||||
iovec_end += 1;
|
||||
}
|
||||
},
|
||||
@ -1084,10 +1091,7 @@ fn prepareCiphertextRecord(
|
||||
c.write_seq += 1; // TODO send key_update on overflow
|
||||
|
||||
const record = ciphertext_buf[record_start..ciphertext_end];
|
||||
iovecs[iovec_end] = .{
|
||||
.base = record.ptr,
|
||||
.len = record.len,
|
||||
};
|
||||
iovecs[iovec_end] = record;
|
||||
iovec_end += 1;
|
||||
}
|
||||
},
|
||||
@ -1511,7 +1515,8 @@ fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) voi
|
||||
const locked = if (key_log_file.lock(.exclusive)) |_| true else |_| false;
|
||||
defer if (locked) key_log_file.unlock();
|
||||
key_log_file.seekFromEnd(0) catch {};
|
||||
inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| key_log_file.writer().print("{s}" ++
|
||||
var w = key_log_file.writer().unbuffered();
|
||||
inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| w.print("{s}" ++
|
||||
(if (@hasField(@TypeOf(context), "counter")) "_{d}" else "") ++ " {x} {x}\n", .{field.name} ++
|
||||
(if (@hasField(@TypeOf(context), "counter")) .{context.counter} else .{}) ++ .{
|
||||
context.client_random,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user