mirror of
https://github.com/ziglang/zig.git
synced 2026-01-20 14:25:16 +00:00
std.crypto.Tls: client is working against some servers
This commit is contained in:
parent
40a85506b2
commit
b97fc43baa
@ -9,12 +9,14 @@ application_cipher: ApplicationCipher,
|
||||
read_seq: u64,
|
||||
write_seq: u64,
|
||||
/// The size is enough to contain exactly one TLSCiphertext record.
|
||||
partially_read_buffer: [max_ciphertext_len + ciphertext_record_header_len]u8,
|
||||
partially_read_buffer: [max_ciphertext_record_len]u8,
|
||||
/// The number of partially read bytes inside `partiall_read_buffer`.
|
||||
partially_read_len: u15,
|
||||
eof: bool,
|
||||
|
||||
pub const ciphertext_record_header_len = 5;
|
||||
pub const max_ciphertext_len = (1 << 14) + 256;
|
||||
pub const max_ciphertext_record_len = max_ciphertext_len + ciphertext_record_header_len;
|
||||
|
||||
pub const ProtocolVersion = enum(u16) {
|
||||
tls_1_2 = 0x0303,
|
||||
@ -416,7 +418,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
|
||||
|
||||
var cipher_params: CipherParams = undefined;
|
||||
|
||||
var handshake_buf: [4000]u8 = undefined;
|
||||
var handshake_buf: [8000]u8 = undefined;
|
||||
var len: usize = 0;
|
||||
var i: usize = i: {
|
||||
const plaintext = handshake_buf[0..5];
|
||||
@ -554,8 +556,8 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
|
||||
// std.fmt.fmtSliceHexLower(&hello_hash),
|
||||
// std.fmt.fmtSliceHexLower(&early_secret),
|
||||
// std.fmt.fmtSliceHexLower(&empty_hash),
|
||||
// std.fmt.fmtSliceHexLower(&derived_secret),
|
||||
// std.fmt.fmtSliceHexLower(&handshake_secret),
|
||||
// std.fmt.fmtSliceHexLower(&hs_derived_secret),
|
||||
// std.fmt.fmtSliceHexLower(&p.handshake_secret),
|
||||
// std.fmt.fmtSliceHexLower(&client_secret),
|
||||
// std.fmt.fmtSliceHexLower(&server_secret),
|
||||
//});
|
||||
@ -582,7 +584,9 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
|
||||
const end_hdr = i + 5;
|
||||
if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow;
|
||||
if (end_hdr > len) {
|
||||
std.debug.print("read len={d} atleast={d}\n", .{ len, end_hdr - len });
|
||||
len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len);
|
||||
std.debug.print("new len: {d} bytes\n", .{len});
|
||||
if (end_hdr > len) return error.EndOfStream;
|
||||
}
|
||||
const ct = @intToEnum(ContentType, handshake_buf[i]);
|
||||
@ -593,9 +597,12 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
|
||||
const record_size = mem.readIntBig(u16, handshake_buf[i..][0..2]);
|
||||
i += 2;
|
||||
const end = i + record_size;
|
||||
std.debug.print("ct={any} record_size={d} end={d}\n", .{ ct, record_size, end });
|
||||
if (end > handshake_buf.len) return error.TlsRecordOverflow;
|
||||
if (end > len) {
|
||||
std.debug.print("read len={d} atleast={d}\n", .{ len, end - len });
|
||||
len += try stream.readAtLeast(handshake_buf[len..], end - len);
|
||||
std.debug.print("new len: {d} bytes\n", .{len});
|
||||
if (end > len) return error.EndOfStream;
|
||||
}
|
||||
switch (ct) {
|
||||
@ -604,7 +611,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
|
||||
if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage;
|
||||
},
|
||||
.application_data => {
|
||||
var cleartext_buf: [1000]u8 = undefined;
|
||||
var cleartext_buf: [8000]u8 = undefined;
|
||||
const cleartext = switch (cipher_params) {
|
||||
inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: {
|
||||
const P = @TypeOf(p.*);
|
||||
@ -637,17 +644,18 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
|
||||
};
|
||||
|
||||
const inner_ct = cleartext[cleartext.len - 1];
|
||||
std.debug.print("inner_ct={any}\n", .{@intToEnum(ContentType, inner_ct)});
|
||||
switch (inner_ct) {
|
||||
@enumToInt(ContentType.handshake) => {
|
||||
const handshake_len = mem.readIntBig(u24, cleartext[1..4]);
|
||||
if (4 + handshake_len != cleartext.len - 1) return error.TlsBadLength;
|
||||
if (4 + handshake_len > cleartext.len - 1) return error.TlsBadLength;
|
||||
std.debug.print("handshake type: {any} size: {d}\n", .{ @intToEnum(HandshakeType, cleartext[0]), handshake_len });
|
||||
switch (cleartext[0]) {
|
||||
@enumToInt(HandshakeType.encrypted_extensions) => {
|
||||
const ext_size = mem.readIntBig(u16, cleartext[4..6]);
|
||||
if (ext_size != 0) {
|
||||
@panic("TODO handle encrypted extensions");
|
||||
}
|
||||
std.debug.print("empty encrypted extensions\n", .{});
|
||||
std.debug.print("{d} bytes of encrypted extensions\n", .{
|
||||
ext_size,
|
||||
});
|
||||
},
|
||||
@enumToInt(HandshakeType.certificate) => {
|
||||
std.debug.print("cool certificate bro\n", .{});
|
||||
@ -688,22 +696,18 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
|
||||
const nonce = p.client_handshake_iv;
|
||||
P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key);
|
||||
|
||||
{
|
||||
var iovecs = [_]std.os.iovec_const{
|
||||
.{
|
||||
.iov_base = &client_change_cipher_spec_msg,
|
||||
.iov_len = client_change_cipher_spec_msg.len,
|
||||
},
|
||||
.{
|
||||
.iov_base = &finished_msg,
|
||||
.iov_len = finished_msg.len,
|
||||
},
|
||||
};
|
||||
try stream.writevAll(&iovecs);
|
||||
}
|
||||
//const both_msgs = client_change_cipher_spec_msg ++ finished_msg;
|
||||
_ = client_change_cipher_spec_msg;
|
||||
const both_msgs = finished_msg;
|
||||
try stream.writeAll(&both_msgs);
|
||||
|
||||
const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
|
||||
const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
|
||||
//std.debug.print("master_secret={}\nclient_secret={}\nserver_secret={}\n", .{
|
||||
// std.fmt.fmtSliceHexLower(&p.master_secret),
|
||||
// std.fmt.fmtSliceHexLower(&client_secret),
|
||||
// std.fmt.fmtSliceHexLower(&server_secret),
|
||||
//});
|
||||
break :c @unionInit(ApplicationCipher, @tagName(tag), .{
|
||||
.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length),
|
||||
.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length),
|
||||
@ -721,12 +725,14 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
|
||||
@panic("TODO");
|
||||
},
|
||||
};
|
||||
std.debug.print("remaining bytes: {d}\n", .{len - end});
|
||||
return .{
|
||||
.application_cipher = app_cipher,
|
||||
.read_seq = read_seq,
|
||||
.write_seq = 1,
|
||||
.read_seq = 0,
|
||||
.write_seq = 0,
|
||||
.partially_read_buffer = undefined,
|
||||
.partially_read_len = 0,
|
||||
.eof = false,
|
||||
};
|
||||
},
|
||||
else => {
|
||||
@ -753,49 +759,67 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
|
||||
}
|
||||
|
||||
pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize {
|
||||
var ciphertext_buf: [max_ciphertext_len * 4]u8 = undefined;
|
||||
var ciphertext_buf: [max_ciphertext_record_len * 4]u8 = undefined;
|
||||
// Due to the trailing inner content type byte in the ciphertext, we need
|
||||
// an additional buffer for storing the cleartext into before encrypting.
|
||||
var cleartext_buf: [max_ciphertext_len]u8 = undefined;
|
||||
var iovecs_buf: [5]std.os.iovec_const = undefined;
|
||||
var ciphertext_end: usize = 0;
|
||||
var iovec_end: usize = 0;
|
||||
var bytes_i: usize = 0;
|
||||
switch (tls.application_cipher) {
|
||||
inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| {
|
||||
// How many bytes are taken up by overhead per record.
|
||||
const overhead_len: usize = switch (tls.application_cipher) {
|
||||
inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| l: {
|
||||
const P = @TypeOf(p.*);
|
||||
const V = @Vector(P.AEAD.nonce_length, u8);
|
||||
const overhead_len = ciphertext_record_header_len + P.AEAD.tag_length + 1;
|
||||
while (true) {
|
||||
const ciphertext_len = @intCast(u16, @min(
|
||||
@min(bytes.len - bytes_i, max_ciphertext_len),
|
||||
ciphertext_buf.len - 5 - P.AEAD.tag_length - ciphertext_end,
|
||||
const encrypted_content_len = @intCast(u16, @min(
|
||||
@min(bytes.len - bytes_i, max_ciphertext_len - 1),
|
||||
ciphertext_buf.len -
|
||||
ciphertext_record_header_len - P.AEAD.tag_length - ciphertext_end - 1,
|
||||
));
|
||||
if (ciphertext_len == 0) return bytes_i;
|
||||
if (encrypted_content_len == 0) break :l overhead_len;
|
||||
|
||||
const wrapped_len = ciphertext_len + P.AEAD.tag_length;
|
||||
const record = ciphertext_buf[ciphertext_end..][0 .. 5 + wrapped_len];
|
||||
mem.copy(u8, &cleartext_buf, bytes[bytes_i..][0..encrypted_content_len]);
|
||||
cleartext_buf[encrypted_content_len] = @enumToInt(ContentType.application_data);
|
||||
bytes_i += encrypted_content_len;
|
||||
const ciphertext_len = encrypted_content_len + 1;
|
||||
const cleartext = cleartext_buf[0..ciphertext_len];
|
||||
|
||||
const ad = record[0..5];
|
||||
ciphertext_end += 5;
|
||||
const record_start = ciphertext_end;
|
||||
const ad = ciphertext_buf[ciphertext_end..][0..5];
|
||||
ad.* =
|
||||
[_]u8{@enumToInt(ContentType.application_data)} ++
|
||||
int2(@enumToInt(ProtocolVersion.tls_1_2)) ++
|
||||
int2(ciphertext_len + P.AEAD.tag_length);
|
||||
ciphertext_end += ad.len;
|
||||
const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len];
|
||||
ciphertext_end += ciphertext_len;
|
||||
const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length];
|
||||
ciphertext_end += P.AEAD.tag_length;
|
||||
ciphertext_end += auth_tag.len;
|
||||
const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
|
||||
const operand: V = pad ++ @bitCast([8]u8, big(tls.write_seq));
|
||||
tls.write_seq += 1;
|
||||
const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.client_iv) ^ operand;
|
||||
ad.* =
|
||||
[_]u8{@enumToInt(ContentType.application_data)} ++
|
||||
int2(@enumToInt(ProtocolVersion.tls_1_2)) ++
|
||||
int2(wrapped_len);
|
||||
const cleartext = bytes[bytes_i..ciphertext.len];
|
||||
P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key);
|
||||
//std.debug.print("seq: {d} nonce: {} client_key: {} client_iv: {} ad: {} auth_tag: {}\nserver_key: {} server_iv: {}", .{
|
||||
// tls.write_seq - 1,
|
||||
// std.fmt.fmtSliceHexLower(&nonce),
|
||||
// std.fmt.fmtSliceHexLower(&p.client_key),
|
||||
// std.fmt.fmtSliceHexLower(&p.client_iv),
|
||||
// std.fmt.fmtSliceHexLower(ad),
|
||||
// std.fmt.fmtSliceHexLower(auth_tag),
|
||||
// std.fmt.fmtSliceHexLower(&p.server_key),
|
||||
// std.fmt.fmtSliceHexLower(&p.server_iv),
|
||||
//});
|
||||
|
||||
const record = ciphertext_buf[record_start..ciphertext_end];
|
||||
iovecs_buf[iovec_end] = .{
|
||||
.iov_base = record.ptr,
|
||||
.iov_len = record.len,
|
||||
};
|
||||
iovec_end += 1;
|
||||
|
||||
bytes_i += ciphertext_len;
|
||||
}
|
||||
},
|
||||
.TLS_CHACHA20_POLY1305_SHA256 => {
|
||||
@ -807,7 +831,7 @@ pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize {
|
||||
.TLS_AES_128_CCM_8_SHA256 => {
|
||||
@panic("TODO");
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
// Ideally we would call writev exactly once here, however, we must ensure
|
||||
// that we don't return with a record partially written.
|
||||
@ -815,9 +839,10 @@ pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize {
|
||||
var total_amt: usize = 0;
|
||||
while (true) {
|
||||
var amt = try stream.writev(iovecs_buf[i..iovec_end]);
|
||||
total_amt += amt;
|
||||
while (amt >= iovecs_buf[i].iov_len) {
|
||||
amt -= iovecs_buf[i].iov_len;
|
||||
const encrypted_amt = iovecs_buf[i].iov_len;
|
||||
total_amt += encrypted_amt - overhead_len;
|
||||
amt -= encrypted_amt;
|
||||
i += 1;
|
||||
// Rely on the property that iovecs delineate records, meaning that
|
||||
// if amt equals zero here, we have fortunately found ourselves
|
||||
@ -849,11 +874,17 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
|
||||
const wanted_read_len = buf_cap * (max_ciphertext_len + ciphertext_record_header_len);
|
||||
const actual_read_len = try stream.read(in_buf[prev_len..@min(wanted_read_len, in_buf.len)]);
|
||||
const frag = in_buf[0 .. prev_len + actual_read_len];
|
||||
if (frag.len == 0) {
|
||||
tls.eof = true;
|
||||
return 0;
|
||||
}
|
||||
std.debug.print("actual_read_len={d} frag.len={d}\n", .{ actual_read_len, frag.len });
|
||||
var in: usize = 0;
|
||||
var out: usize = 0;
|
||||
|
||||
while (true) {
|
||||
if (in + ciphertext_record_header_len > frag.len) {
|
||||
std.debug.print("in={d} frag.len={d}\n", .{ in, frag.len });
|
||||
return finishRead(tls, frag, in, out);
|
||||
}
|
||||
const ct = @intToEnum(ContentType, frag[in]);
|
||||
@ -866,6 +897,7 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
|
||||
const end = in + record_size;
|
||||
if (end > frag.len) {
|
||||
if (record_size > max_ciphertext_len) return error.TlsRecordOverflow;
|
||||
std.debug.print("end={d} frag.len={d}\n", .{ end, frag.len });
|
||||
return finishRead(tls, frag, in, out);
|
||||
}
|
||||
switch (ct) {
|
||||
@ -877,6 +909,7 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
|
||||
inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: {
|
||||
const P = @TypeOf(p.*);
|
||||
const V = @Vector(P.AEAD.nonce_length, u8);
|
||||
const ad = frag[in - 5 ..][0..5];
|
||||
const ciphertext_len = record_size - P.AEAD.tag_length;
|
||||
const ciphertext = frag[in..][0..ciphertext_len];
|
||||
in += ciphertext_len;
|
||||
@ -886,7 +919,12 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
|
||||
const operand: V = pad ++ @bitCast([8]u8, big(tls.read_seq));
|
||||
tls.read_seq += 1;
|
||||
const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand;
|
||||
const ad = frag[0..ciphertext_record_header_len];
|
||||
//std.debug.print("seq: {d} nonce: {} server_key: {} server_iv: {}\n", .{
|
||||
// tls.read_seq - 1,
|
||||
// std.fmt.fmtSliceHexLower(&nonce),
|
||||
// std.fmt.fmtSliceHexLower(&p.server_key),
|
||||
// std.fmt.fmtSliceHexLower(&p.server_iv),
|
||||
//});
|
||||
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch
|
||||
return error.TlsBadRecordMac;
|
||||
break :c cleartext.len;
|
||||
@ -902,15 +940,26 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
|
||||
},
|
||||
};
|
||||
|
||||
const inner_ct = buffer[out + cleartext_len - 1];
|
||||
const inner_ct = @intToEnum(ContentType, buffer[out + cleartext_len - 1]);
|
||||
switch (inner_ct) {
|
||||
@enumToInt(ContentType.handshake) => {
|
||||
.alert => {
|
||||
const level = @intToEnum(AlertLevel, buffer[out]);
|
||||
const desc = @intToEnum(AlertDescription, buffer[out + 1]);
|
||||
if (desc == .close_notify) {
|
||||
tls.eof = true;
|
||||
return out;
|
||||
}
|
||||
std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) });
|
||||
return error.TlsAlert;
|
||||
},
|
||||
.handshake => {
|
||||
std.debug.print("the server wants to keep shaking hands\n", .{});
|
||||
},
|
||||
@enumToInt(ContentType.application_data) => {
|
||||
.application_data => {
|
||||
out += cleartext_len - 1;
|
||||
},
|
||||
else => {
|
||||
std.debug.print("inner content type: {d}\n", .{inner_ct});
|
||||
return error.TlsUnexpectedMessage;
|
||||
},
|
||||
}
|
||||
|
||||
@ -62,6 +62,25 @@ pub const Request = struct {
|
||||
.https => return req.tls.read(req.stream, buffer),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn readAll(req: *Request, buffer: []u8) !usize {
|
||||
return readAtLeast(req, buffer, buffer.len);
|
||||
}
|
||||
|
||||
pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize {
|
||||
var index: usize = 0;
|
||||
while (index < len) {
|
||||
const amt = try req.read(buffer[index..]);
|
||||
if (amt == 0) {
|
||||
switch (req.protocol) {
|
||||
.http => break,
|
||||
.https => if (req.tls.eof) break,
|
||||
}
|
||||
}
|
||||
index += amt;
|
||||
}
|
||||
return index;
|
||||
}
|
||||
};
|
||||
|
||||
pub fn deinit(client: *Client) void {
|
||||
@ -92,7 +111,7 @@ pub fn request(client: *Client, options: Request.Options) !Request {
|
||||
@tagName(options.method).len +
|
||||
1 +
|
||||
options.path.len +
|
||||
" HTTP/2\r\nHost: ".len +
|
||||
" HTTP/1.1\r\nHost: ".len +
|
||||
options.host.len +
|
||||
"\r\nUpgrade-Insecure-Requests: 1\r\n".len +
|
||||
client.headers.items.len +
|
||||
@ -101,7 +120,7 @@ pub fn request(client: *Client, options: Request.Options) !Request {
|
||||
req.headers.appendSliceAssumeCapacity(@tagName(options.method));
|
||||
req.headers.appendSliceAssumeCapacity(" ");
|
||||
req.headers.appendSliceAssumeCapacity(options.path);
|
||||
req.headers.appendSliceAssumeCapacity(" HTTP/2\r\nHost: ");
|
||||
req.headers.appendSliceAssumeCapacity(" HTTP/1.1\r\nHost: ");
|
||||
req.headers.appendSliceAssumeCapacity(options.host);
|
||||
switch (options.protocol) {
|
||||
.https => req.headers.appendSliceAssumeCapacity("\r\nUpgrade-Insecure-Requests: 1\r\n"),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user