std.crypto.Tls: client is working against some servers

This commit is contained in:
Andrew Kelley 2022-12-16 02:14:35 -07:00
parent 40a85506b2
commit b97fc43baa
2 changed files with 121 additions and 53 deletions

View File

@ -9,12 +9,14 @@ application_cipher: ApplicationCipher,
read_seq: u64,
write_seq: u64,
/// The size is enough to contain exactly one TLSCiphertext record.
partially_read_buffer: [max_ciphertext_len + ciphertext_record_header_len]u8,
partially_read_buffer: [max_ciphertext_record_len]u8,
/// The number of partially read bytes inside `partiall_read_buffer`.
partially_read_len: u15,
eof: bool,
pub const ciphertext_record_header_len = 5;
pub const max_ciphertext_len = (1 << 14) + 256;
pub const max_ciphertext_record_len = max_ciphertext_len + ciphertext_record_header_len;
pub const ProtocolVersion = enum(u16) {
tls_1_2 = 0x0303,
@ -416,7 +418,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
var cipher_params: CipherParams = undefined;
var handshake_buf: [4000]u8 = undefined;
var handshake_buf: [8000]u8 = undefined;
var len: usize = 0;
var i: usize = i: {
const plaintext = handshake_buf[0..5];
@ -554,8 +556,8 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
// std.fmt.fmtSliceHexLower(&hello_hash),
// std.fmt.fmtSliceHexLower(&early_secret),
// std.fmt.fmtSliceHexLower(&empty_hash),
// std.fmt.fmtSliceHexLower(&derived_secret),
// std.fmt.fmtSliceHexLower(&handshake_secret),
// std.fmt.fmtSliceHexLower(&hs_derived_secret),
// std.fmt.fmtSliceHexLower(&p.handshake_secret),
// std.fmt.fmtSliceHexLower(&client_secret),
// std.fmt.fmtSliceHexLower(&server_secret),
//});
@ -582,7 +584,9 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
const end_hdr = i + 5;
if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow;
if (end_hdr > len) {
std.debug.print("read len={d} atleast={d}\n", .{ len, end_hdr - len });
len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len);
std.debug.print("new len: {d} bytes\n", .{len});
if (end_hdr > len) return error.EndOfStream;
}
const ct = @intToEnum(ContentType, handshake_buf[i]);
@ -593,9 +597,12 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
const record_size = mem.readIntBig(u16, handshake_buf[i..][0..2]);
i += 2;
const end = i + record_size;
std.debug.print("ct={any} record_size={d} end={d}\n", .{ ct, record_size, end });
if (end > handshake_buf.len) return error.TlsRecordOverflow;
if (end > len) {
std.debug.print("read len={d} atleast={d}\n", .{ len, end - len });
len += try stream.readAtLeast(handshake_buf[len..], end - len);
std.debug.print("new len: {d} bytes\n", .{len});
if (end > len) return error.EndOfStream;
}
switch (ct) {
@ -604,7 +611,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage;
},
.application_data => {
var cleartext_buf: [1000]u8 = undefined;
var cleartext_buf: [8000]u8 = undefined;
const cleartext = switch (cipher_params) {
inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: {
const P = @TypeOf(p.*);
@ -637,17 +644,18 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
};
const inner_ct = cleartext[cleartext.len - 1];
std.debug.print("inner_ct={any}\n", .{@intToEnum(ContentType, inner_ct)});
switch (inner_ct) {
@enumToInt(ContentType.handshake) => {
const handshake_len = mem.readIntBig(u24, cleartext[1..4]);
if (4 + handshake_len != cleartext.len - 1) return error.TlsBadLength;
if (4 + handshake_len > cleartext.len - 1) return error.TlsBadLength;
std.debug.print("handshake type: {any} size: {d}\n", .{ @intToEnum(HandshakeType, cleartext[0]), handshake_len });
switch (cleartext[0]) {
@enumToInt(HandshakeType.encrypted_extensions) => {
const ext_size = mem.readIntBig(u16, cleartext[4..6]);
if (ext_size != 0) {
@panic("TODO handle encrypted extensions");
}
std.debug.print("empty encrypted extensions\n", .{});
std.debug.print("{d} bytes of encrypted extensions\n", .{
ext_size,
});
},
@enumToInt(HandshakeType.certificate) => {
std.debug.print("cool certificate bro\n", .{});
@ -688,22 +696,18 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
const nonce = p.client_handshake_iv;
P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key);
{
var iovecs = [_]std.os.iovec_const{
.{
.iov_base = &client_change_cipher_spec_msg,
.iov_len = client_change_cipher_spec_msg.len,
},
.{
.iov_base = &finished_msg,
.iov_len = finished_msg.len,
},
};
try stream.writevAll(&iovecs);
}
//const both_msgs = client_change_cipher_spec_msg ++ finished_msg;
_ = client_change_cipher_spec_msg;
const both_msgs = finished_msg;
try stream.writeAll(&both_msgs);
const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
//std.debug.print("master_secret={}\nclient_secret={}\nserver_secret={}\n", .{
// std.fmt.fmtSliceHexLower(&p.master_secret),
// std.fmt.fmtSliceHexLower(&client_secret),
// std.fmt.fmtSliceHexLower(&server_secret),
//});
break :c @unionInit(ApplicationCipher, @tagName(tag), .{
.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length),
.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length),
@ -721,12 +725,14 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
@panic("TODO");
},
};
std.debug.print("remaining bytes: {d}\n", .{len - end});
return .{
.application_cipher = app_cipher,
.read_seq = read_seq,
.write_seq = 1,
.read_seq = 0,
.write_seq = 0,
.partially_read_buffer = undefined,
.partially_read_len = 0,
.eof = false,
};
},
else => {
@ -753,49 +759,67 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
}
pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize {
var ciphertext_buf: [max_ciphertext_len * 4]u8 = undefined;
var ciphertext_buf: [max_ciphertext_record_len * 4]u8 = undefined;
// Due to the trailing inner content type byte in the ciphertext, we need
// an additional buffer for storing the cleartext into before encrypting.
var cleartext_buf: [max_ciphertext_len]u8 = undefined;
var iovecs_buf: [5]std.os.iovec_const = undefined;
var ciphertext_end: usize = 0;
var iovec_end: usize = 0;
var bytes_i: usize = 0;
switch (tls.application_cipher) {
inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| {
// How many bytes are taken up by overhead per record.
const overhead_len: usize = switch (tls.application_cipher) {
inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| l: {
const P = @TypeOf(p.*);
const V = @Vector(P.AEAD.nonce_length, u8);
const overhead_len = ciphertext_record_header_len + P.AEAD.tag_length + 1;
while (true) {
const ciphertext_len = @intCast(u16, @min(
@min(bytes.len - bytes_i, max_ciphertext_len),
ciphertext_buf.len - 5 - P.AEAD.tag_length - ciphertext_end,
const encrypted_content_len = @intCast(u16, @min(
@min(bytes.len - bytes_i, max_ciphertext_len - 1),
ciphertext_buf.len -
ciphertext_record_header_len - P.AEAD.tag_length - ciphertext_end - 1,
));
if (ciphertext_len == 0) return bytes_i;
if (encrypted_content_len == 0) break :l overhead_len;
const wrapped_len = ciphertext_len + P.AEAD.tag_length;
const record = ciphertext_buf[ciphertext_end..][0 .. 5 + wrapped_len];
mem.copy(u8, &cleartext_buf, bytes[bytes_i..][0..encrypted_content_len]);
cleartext_buf[encrypted_content_len] = @enumToInt(ContentType.application_data);
bytes_i += encrypted_content_len;
const ciphertext_len = encrypted_content_len + 1;
const cleartext = cleartext_buf[0..ciphertext_len];
const ad = record[0..5];
ciphertext_end += 5;
const record_start = ciphertext_end;
const ad = ciphertext_buf[ciphertext_end..][0..5];
ad.* =
[_]u8{@enumToInt(ContentType.application_data)} ++
int2(@enumToInt(ProtocolVersion.tls_1_2)) ++
int2(ciphertext_len + P.AEAD.tag_length);
ciphertext_end += ad.len;
const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len];
ciphertext_end += ciphertext_len;
const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length];
ciphertext_end += P.AEAD.tag_length;
ciphertext_end += auth_tag.len;
const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
const operand: V = pad ++ @bitCast([8]u8, big(tls.write_seq));
tls.write_seq += 1;
const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.client_iv) ^ operand;
ad.* =
[_]u8{@enumToInt(ContentType.application_data)} ++
int2(@enumToInt(ProtocolVersion.tls_1_2)) ++
int2(wrapped_len);
const cleartext = bytes[bytes_i..ciphertext.len];
P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key);
//std.debug.print("seq: {d} nonce: {} client_key: {} client_iv: {} ad: {} auth_tag: {}\nserver_key: {} server_iv: {}", .{
// tls.write_seq - 1,
// std.fmt.fmtSliceHexLower(&nonce),
// std.fmt.fmtSliceHexLower(&p.client_key),
// std.fmt.fmtSliceHexLower(&p.client_iv),
// std.fmt.fmtSliceHexLower(ad),
// std.fmt.fmtSliceHexLower(auth_tag),
// std.fmt.fmtSliceHexLower(&p.server_key),
// std.fmt.fmtSliceHexLower(&p.server_iv),
//});
const record = ciphertext_buf[record_start..ciphertext_end];
iovecs_buf[iovec_end] = .{
.iov_base = record.ptr,
.iov_len = record.len,
};
iovec_end += 1;
bytes_i += ciphertext_len;
}
},
.TLS_CHACHA20_POLY1305_SHA256 => {
@ -807,7 +831,7 @@ pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize {
.TLS_AES_128_CCM_8_SHA256 => {
@panic("TODO");
},
}
};
// Ideally we would call writev exactly once here, however, we must ensure
// that we don't return with a record partially written.
@ -815,9 +839,10 @@ pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize {
var total_amt: usize = 0;
while (true) {
var amt = try stream.writev(iovecs_buf[i..iovec_end]);
total_amt += amt;
while (amt >= iovecs_buf[i].iov_len) {
amt -= iovecs_buf[i].iov_len;
const encrypted_amt = iovecs_buf[i].iov_len;
total_amt += encrypted_amt - overhead_len;
amt -= encrypted_amt;
i += 1;
// Rely on the property that iovecs delineate records, meaning that
// if amt equals zero here, we have fortunately found ourselves
@ -849,11 +874,17 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
const wanted_read_len = buf_cap * (max_ciphertext_len + ciphertext_record_header_len);
const actual_read_len = try stream.read(in_buf[prev_len..@min(wanted_read_len, in_buf.len)]);
const frag = in_buf[0 .. prev_len + actual_read_len];
if (frag.len == 0) {
tls.eof = true;
return 0;
}
std.debug.print("actual_read_len={d} frag.len={d}\n", .{ actual_read_len, frag.len });
var in: usize = 0;
var out: usize = 0;
while (true) {
if (in + ciphertext_record_header_len > frag.len) {
std.debug.print("in={d} frag.len={d}\n", .{ in, frag.len });
return finishRead(tls, frag, in, out);
}
const ct = @intToEnum(ContentType, frag[in]);
@ -866,6 +897,7 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
const end = in + record_size;
if (end > frag.len) {
if (record_size > max_ciphertext_len) return error.TlsRecordOverflow;
std.debug.print("end={d} frag.len={d}\n", .{ end, frag.len });
return finishRead(tls, frag, in, out);
}
switch (ct) {
@ -877,6 +909,7 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: {
const P = @TypeOf(p.*);
const V = @Vector(P.AEAD.nonce_length, u8);
const ad = frag[in - 5 ..][0..5];
const ciphertext_len = record_size - P.AEAD.tag_length;
const ciphertext = frag[in..][0..ciphertext_len];
in += ciphertext_len;
@ -886,7 +919,12 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
const operand: V = pad ++ @bitCast([8]u8, big(tls.read_seq));
tls.read_seq += 1;
const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand;
const ad = frag[0..ciphertext_record_header_len];
//std.debug.print("seq: {d} nonce: {} server_key: {} server_iv: {}\n", .{
// tls.read_seq - 1,
// std.fmt.fmtSliceHexLower(&nonce),
// std.fmt.fmtSliceHexLower(&p.server_key),
// std.fmt.fmtSliceHexLower(&p.server_iv),
//});
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch
return error.TlsBadRecordMac;
break :c cleartext.len;
@ -902,15 +940,26 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
},
};
const inner_ct = buffer[out + cleartext_len - 1];
const inner_ct = @intToEnum(ContentType, buffer[out + cleartext_len - 1]);
switch (inner_ct) {
@enumToInt(ContentType.handshake) => {
.alert => {
const level = @intToEnum(AlertLevel, buffer[out]);
const desc = @intToEnum(AlertDescription, buffer[out + 1]);
if (desc == .close_notify) {
tls.eof = true;
return out;
}
std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) });
return error.TlsAlert;
},
.handshake => {
std.debug.print("the server wants to keep shaking hands\n", .{});
},
@enumToInt(ContentType.application_data) => {
.application_data => {
out += cleartext_len - 1;
},
else => {
std.debug.print("inner content type: {d}\n", .{inner_ct});
return error.TlsUnexpectedMessage;
},
}

View File

@ -62,6 +62,25 @@ pub const Request = struct {
.https => return req.tls.read(req.stream, buffer),
}
}
pub fn readAll(req: *Request, buffer: []u8) !usize {
return readAtLeast(req, buffer, buffer.len);
}
pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize {
var index: usize = 0;
while (index < len) {
const amt = try req.read(buffer[index..]);
if (amt == 0) {
switch (req.protocol) {
.http => break,
.https => if (req.tls.eof) break,
}
}
index += amt;
}
return index;
}
};
pub fn deinit(client: *Client) void {
@ -92,7 +111,7 @@ pub fn request(client: *Client, options: Request.Options) !Request {
@tagName(options.method).len +
1 +
options.path.len +
" HTTP/2\r\nHost: ".len +
" HTTP/1.1\r\nHost: ".len +
options.host.len +
"\r\nUpgrade-Insecure-Requests: 1\r\n".len +
client.headers.items.len +
@ -101,7 +120,7 @@ pub fn request(client: *Client, options: Request.Options) !Request {
req.headers.appendSliceAssumeCapacity(@tagName(options.method));
req.headers.appendSliceAssumeCapacity(" ");
req.headers.appendSliceAssumeCapacity(options.path);
req.headers.appendSliceAssumeCapacity(" HTTP/2\r\nHost: ");
req.headers.appendSliceAssumeCapacity(" HTTP/1.1\r\nHost: ");
req.headers.appendSliceAssumeCapacity(options.host);
switch (options.protocol) {
.https => req.headers.appendSliceAssumeCapacity("\r\nUpgrade-Insecure-Requests: 1\r\n"),