std.crypto.tls.Client: refactor to reduce namespace bloat

This commit is contained in:
Andrew Kelley 2022-12-28 19:54:17 -07:00
parent 16af6286c8
commit 1d20ada366

View File

@ -5,18 +5,14 @@ const net = std.net;
const mem = std.mem;
const crypto = std.crypto;
const assert = std.debug.assert;
const Certificate = std.crypto.Certificate;
const ApplicationCipher = tls.ApplicationCipher;
const CipherSuite = tls.CipherSuite;
const ContentType = tls.ContentType;
const HandshakeCipher = tls.HandshakeCipher;
const max_ciphertext_len = tls.max_ciphertext_len;
const hkdfExpandLabel = tls.hkdfExpandLabel;
const int2 = tls.int2;
const int3 = tls.int3;
const array = tls.array;
const enum_array = tls.enum_array;
const Certificate = crypto.Certificate;
read_seq: u64,
write_seq: u64,
@ -27,7 +23,7 @@ partially_read_len: u15,
/// re-decrypt bytes from `partially_read_buffer` when the buffer supplied by
/// the read() API user is not large enough.
partial_cleartext_index: u15,
application_cipher: ApplicationCipher,
application_cipher: tls.ApplicationCipher,
eof: bool,
/// The size is enough to contain exactly one TLSCiphertext record.
/// Contains encrypted bytes.
@ -101,7 +97,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
client_hello;
const plaintext_header = [_]u8{
@enumToInt(ContentType.handshake),
@enumToInt(tls.ContentType.handshake),
0x03, 0x01, // legacy_record_version
} ++ int2(@intCast(u16, out_handshake.len + host_len)) ++ out_handshake;
@ -121,7 +117,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
const client_hello_bytes1 = plaintext_header[5..];
var handshake_cipher: HandshakeCipher = undefined;
var handshake_cipher: tls.HandshakeCipher = undefined;
var handshake_buf: [8000]u8 = undefined;
var len: usize = 0;
@ -129,7 +125,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
const plaintext = handshake_buf[0..5];
len = try stream.readAtLeast(&handshake_buf, plaintext.len);
if (len < plaintext.len) return error.EndOfStream;
const ct = @intToEnum(ContentType, plaintext[0]);
const ct = @intToEnum(tls.ContentType, plaintext[0]);
const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]);
const end = plaintext.len + frag_len;
if (end > handshake_buf.len) return error.TlsRecordOverflow;
@ -169,7 +165,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
i += 32;
const cipher_suite_int = mem.readIntBig(u16, frag[i..][0..2]);
i += 2;
const cipher_suite_tag = @intToEnum(CipherSuite, cipher_suite_int);
const cipher_suite_tag = @intToEnum(tls.CipherSuite, cipher_suite_int);
const legacy_compression_method = frag[i];
i += 1;
_ = legacy_compression_method;
@ -247,8 +243,8 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
.AEGIS_256_SHA384,
.AEGIS_128L_SHA256,
=> |tag| {
const P = std.meta.TagPayloadByName(HandshakeCipher, @tagName(tag));
handshake_cipher = @unionInit(HandshakeCipher, @tagName(tag), .{
const P = std.meta.TagPayloadByName(tls.HandshakeCipher, @tagName(tag));
handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag), .{
.handshake_secret = undefined,
.master_secret = undefined,
.client_handshake_key = undefined,
@ -338,7 +334,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len);
if (end_hdr > len) return error.EndOfStream;
}
const ct = @intToEnum(ContentType, handshake_buf[i]);
const ct = @intToEnum(tls.ContentType, handshake_buf[i]);
i += 1;
const legacy_version = mem.readIntBig(u16, handshake_buf[i..][0..2]);
i += 2;
@ -380,7 +376,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
},
};
const inner_ct = @intToEnum(ContentType, cleartext[cleartext.len - 1]);
const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]);
switch (inner_ct) {
.handshake => {
var ct_i: usize = 0;
@ -546,7 +542,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
if (handshake_state != .finished) return error.TlsUnexpectedMessage;
// This message is to trick buggy proxies into behaving correctly.
const client_change_cipher_spec_msg = [_]u8{
@enumToInt(ContentType.change_cipher_spec),
@enumToInt(tls.ContentType.change_cipher_spec),
0x03, 0x03, // legacy protocol version
0x00, 0x01, // length
0x01,
@ -564,12 +560,12 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
const out_cleartext = [_]u8{
@enumToInt(tls.HandshakeType.finished),
0, 0, verify_data.len, // length
} ++ verify_data ++ [1]u8{@enumToInt(ContentType.handshake)};
} ++ verify_data ++ [1]u8{@enumToInt(tls.ContentType.handshake)};
const wrapped_len = out_cleartext.len + P.AEAD.tag_length;
var finished_msg = [_]u8{
@enumToInt(ContentType.application_data),
@enumToInt(tls.ContentType.application_data),
0x03, 0x03, // legacy protocol version
0, wrapped_len, // byte length of encrypted record
} ++ @as([wrapped_len]u8, undefined);
@ -590,7 +586,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
// std.fmt.fmtSliceHexLower(&client_secret),
// std.fmt.fmtSliceHexLower(&server_secret),
//});
break :c @unionInit(ApplicationCipher, @tagName(tag), .{
break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{
.client_secret = client_secret,
.server_secret = server_secret,
.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length),
@ -661,7 +657,7 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize {
if (encrypted_content_len == 0) break :l overhead_len;
mem.copy(u8, &cleartext_buf, bytes[bytes_i..][0..encrypted_content_len]);
cleartext_buf[encrypted_content_len] = @enumToInt(ContentType.application_data);
cleartext_buf[encrypted_content_len] = @enumToInt(tls.ContentType.application_data);
bytes_i += encrypted_content_len;
const ciphertext_len = encrypted_content_len + 1;
const cleartext = cleartext_buf[0..ciphertext_len];
@ -669,7 +665,7 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize {
const record_start = ciphertext_end;
const ad = ciphertext_buf[ciphertext_end..][0..5];
ad.* =
[_]u8{@enumToInt(ContentType.application_data)} ++
[_]u8{@enumToInt(tls.ContentType.application_data)} ++
int2(@enumToInt(tls.ProtocolVersion.tls_1_2)) ++
int2(ciphertext_len + P.AEAD.tag_length);
ciphertext_end += ad.len;
@ -818,7 +814,7 @@ pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize {
return finishRead(c, frag, in, out);
}
const record_start = in;
const ct = @intToEnum(ContentType, frag[in]);
const ct = @intToEnum(tls.ContentType, frag[in]);
in += 1;
const legacy_version = mem.readIntBig(u16, frag[in..][0..2]);
in += 2;
@ -861,7 +857,7 @@ pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize {
},
};
const inner_ct = @intToEnum(ContentType, cleartext[cleartext.len - 1]);
const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]);
switch (inner_ct) {
.alert => {
c.read_seq += 1;