mirror of
https://github.com/ziglang/zig.git
synced 2026-01-20 14:25:16 +00:00
std.crypto.Tls: discard ChangeCipherSpec messages
The next step here is to decrypt encrypted records
This commit is contained in:
parent
d2f5d0b199
commit
920e5bc4ff
@ -188,6 +188,12 @@ const NamedGroup = enum(u16) {
|
||||
// * fragment: opaque
|
||||
// - the data being transmitted
|
||||
|
||||
// Ciphertext
|
||||
// * ContentType opaque_type = application_data; /* 23 */
|
||||
// * ProtocolVersion legacy_record_version = 0x0303; /* TLS v1.2 */
|
||||
// * uint16 length;
|
||||
// * opaque encrypted_record[TLSCiphertext.length];
|
||||
|
||||
// Handshake:
|
||||
// * type: HandshakeType
|
||||
// * length: u24
|
||||
@ -331,105 +337,144 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
|
||||
};
|
||||
try stream.writevAll(&iovecs);
|
||||
|
||||
{
|
||||
var handshake_buf: [4000]u8 = undefined;
|
||||
var handshake_buf: [4000]u8 = undefined;
|
||||
var len: usize = 0;
|
||||
var i: usize = i: {
|
||||
const plaintext = handshake_buf[0..5];
|
||||
const amt = try stream.readAtLeast(&handshake_buf, plaintext.len);
|
||||
if (amt < plaintext.len) return error.EndOfStream;
|
||||
len = try stream.readAtLeast(&handshake_buf, plaintext.len);
|
||||
if (len < plaintext.len) return error.EndOfStream;
|
||||
const ct = @intToEnum(ContentType, plaintext[0]);
|
||||
const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]);
|
||||
const end = plaintext.len + frag_len;
|
||||
if (end > handshake_buf.len) return error.TlsServerHelloTooBig;
|
||||
if (amt < end) {
|
||||
const amt2 = try stream.readAll(handshake_buf[amt..end]);
|
||||
if (amt2 < plaintext.len) return error.EndOfStream;
|
||||
if (end > handshake_buf.len) return error.TlsRecordOverflow;
|
||||
if (end > len) {
|
||||
len += try stream.readAtLeast(handshake_buf[len..], end - len);
|
||||
if (end > len) return error.EndOfStream;
|
||||
}
|
||||
const frag = handshake_buf[plaintext.len..end];
|
||||
|
||||
if (ct == .alert) {
|
||||
const level = @intToEnum(AlertLevel, frag[0]);
|
||||
const desc = @intToEnum(AlertDescription, frag[1]);
|
||||
std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) });
|
||||
std.process.exit(1);
|
||||
} else if (ct == .handshake) {
|
||||
if (frag[0] != @enumToInt(HandshakeType.server_hello)) {
|
||||
return error.TlsUnexpectedMessage;
|
||||
}
|
||||
const length = mem.readIntBig(u24, frag[1..4]);
|
||||
if (4 + length != frag.len) return error.TlsBadLength;
|
||||
const hello = frag[4..];
|
||||
const legacy_version = mem.readIntBig(u16, hello[0..2]);
|
||||
const random = hello[2..34].*;
|
||||
_ = random;
|
||||
const legacy_session_id_echo_len = hello[34];
|
||||
if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter;
|
||||
const cipher_suite_int = mem.readIntBig(u16, hello[35..37]);
|
||||
const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch
|
||||
return error.TlsIllegalParameter;
|
||||
std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)});
|
||||
const legacy_compression_method = hello[37];
|
||||
_ = legacy_compression_method;
|
||||
const extensions_size = mem.readIntBig(u16, hello[38..40]);
|
||||
if (40 + extensions_size != hello.len) return error.TlsBadLength;
|
||||
var i: usize = 40;
|
||||
var supported_version: u16 = 0;
|
||||
var have_server_pub_key = false;
|
||||
while (i < hello.len) {
|
||||
const et = mem.readIntBig(u16, hello[i..][0..2]);
|
||||
i += 2;
|
||||
const ext_size = mem.readIntBig(u16, hello[i..][0..2]);
|
||||
i += 2;
|
||||
const next_i = i + ext_size;
|
||||
if (next_i > hello.len) return error.TlsBadLength;
|
||||
switch (et) {
|
||||
@enumToInt(ExtensionType.supported_versions) => {
|
||||
if (supported_version != 0) return error.TlsIllegalParameter;
|
||||
supported_version = mem.readIntBig(u16, hello[i..][0..2]);
|
||||
},
|
||||
@enumToInt(ExtensionType.key_share) => {
|
||||
if (have_server_pub_key) return error.TlsIllegalParameter;
|
||||
const named_group = mem.readIntBig(u16, hello[i..][0..2]);
|
||||
i += 2;
|
||||
switch (named_group) {
|
||||
@enumToInt(NamedGroup.x25519) => {
|
||||
const key_size = mem.readIntBig(u16, hello[i..][0..2]);
|
||||
i += 2;
|
||||
if (key_size != 32) return error.TlsBadLength;
|
||||
const encrypted_key = hello[i..][0..32].*;
|
||||
const server_pub_key = try crypto.dh.X25519.scalarmult(
|
||||
tls.x25519_priv_key,
|
||||
encrypted_key,
|
||||
);
|
||||
tls.x25519_server_pub_key = server_pub_key;
|
||||
have_server_pub_key = true;
|
||||
},
|
||||
else => {
|
||||
std.debug.print("named group: {x}\n", .{named_group});
|
||||
return error.TlsIllegalParameter;
|
||||
},
|
||||
}
|
||||
},
|
||||
else => {
|
||||
std.debug.print("unexpected extension: {x}\n", .{et});
|
||||
},
|
||||
switch (ct) {
|
||||
.alert => {
|
||||
const level = @intToEnum(AlertLevel, frag[0]);
|
||||
const desc = @intToEnum(AlertDescription, frag[1]);
|
||||
std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) });
|
||||
return error.TlsAlert;
|
||||
},
|
||||
.handshake => {
|
||||
if (frag[0] != @enumToInt(HandshakeType.server_hello)) {
|
||||
return error.TlsUnexpectedMessage;
|
||||
}
|
||||
i = next_i;
|
||||
}
|
||||
if (!have_server_pub_key) return error.TlsIllegalParameter;
|
||||
const tls_version = if (supported_version == 0) legacy_version else supported_version;
|
||||
switch (tls_version) {
|
||||
@enumToInt(ProtocolVersion.tls_1_2) => {
|
||||
std.debug.print("server wants TLS v1.2\n", .{});
|
||||
},
|
||||
@enumToInt(ProtocolVersion.tls_1_3) => {
|
||||
std.debug.print("server wants TLS v1.3\n", .{});
|
||||
},
|
||||
else => return error.TlsIllegalParameter,
|
||||
}
|
||||
} else {
|
||||
std.debug.print("content_type: {s}\n", .{@tagName(ct)});
|
||||
std.debug.print("got {d} bytes: {s}\n", .{ amt, std.fmt.fmtSliceHexLower(frag) });
|
||||
const length = mem.readIntBig(u24, frag[1..4]);
|
||||
if (4 + length != frag.len) return error.TlsBadLength;
|
||||
const hello = frag[4..];
|
||||
const legacy_version = mem.readIntBig(u16, hello[0..2]);
|
||||
const random = hello[2..34].*;
|
||||
_ = random;
|
||||
const legacy_session_id_echo_len = hello[34];
|
||||
if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter;
|
||||
const cipher_suite_int = mem.readIntBig(u16, hello[35..37]);
|
||||
const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch
|
||||
return error.TlsIllegalParameter;
|
||||
std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)});
|
||||
const legacy_compression_method = hello[37];
|
||||
_ = legacy_compression_method;
|
||||
const extensions_size = mem.readIntBig(u16, hello[38..40]);
|
||||
if (40 + extensions_size != hello.len) return error.TlsBadLength;
|
||||
var i: usize = 40;
|
||||
var supported_version: u16 = 0;
|
||||
var have_server_pub_key = false;
|
||||
while (i < hello.len) {
|
||||
const et = mem.readIntBig(u16, hello[i..][0..2]);
|
||||
i += 2;
|
||||
const ext_size = mem.readIntBig(u16, hello[i..][0..2]);
|
||||
i += 2;
|
||||
const next_i = i + ext_size;
|
||||
if (next_i > hello.len) return error.TlsBadLength;
|
||||
switch (et) {
|
||||
@enumToInt(ExtensionType.supported_versions) => {
|
||||
if (supported_version != 0) return error.TlsIllegalParameter;
|
||||
supported_version = mem.readIntBig(u16, hello[i..][0..2]);
|
||||
},
|
||||
@enumToInt(ExtensionType.key_share) => {
|
||||
if (have_server_pub_key) return error.TlsIllegalParameter;
|
||||
const named_group = mem.readIntBig(u16, hello[i..][0..2]);
|
||||
i += 2;
|
||||
switch (named_group) {
|
||||
@enumToInt(NamedGroup.x25519) => {
|
||||
const key_size = mem.readIntBig(u16, hello[i..][0..2]);
|
||||
i += 2;
|
||||
if (key_size != 32) return error.TlsBadLength;
|
||||
const encrypted_key = hello[i..][0..32].*;
|
||||
const server_pub_key = try crypto.dh.X25519.scalarmult(
|
||||
tls.x25519_priv_key,
|
||||
encrypted_key,
|
||||
);
|
||||
tls.x25519_server_pub_key = server_pub_key;
|
||||
have_server_pub_key = true;
|
||||
},
|
||||
else => {
|
||||
std.debug.print("named group: {x}\n", .{named_group});
|
||||
return error.TlsIllegalParameter;
|
||||
},
|
||||
}
|
||||
},
|
||||
else => {
|
||||
std.debug.print("unexpected extension: {x}\n", .{et});
|
||||
},
|
||||
}
|
||||
i = next_i;
|
||||
}
|
||||
if (!have_server_pub_key) return error.TlsIllegalParameter;
|
||||
const tls_version = if (supported_version == 0) legacy_version else supported_version;
|
||||
switch (tls_version) {
|
||||
@enumToInt(ProtocolVersion.tls_1_2) => {
|
||||
std.debug.print("server wants TLS v1.2\n", .{});
|
||||
},
|
||||
@enumToInt(ProtocolVersion.tls_1_3) => {
|
||||
std.debug.print("server wants TLS v1.3\n", .{});
|
||||
},
|
||||
else => return error.TlsIllegalParameter,
|
||||
}
|
||||
},
|
||||
else => return error.TlsUnexpectedMessage,
|
||||
}
|
||||
break :i end;
|
||||
};
|
||||
|
||||
while (true) {
|
||||
const end_hdr = i + 5;
|
||||
if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow;
|
||||
if (end_hdr > len) {
|
||||
len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len);
|
||||
if (end_hdr > len) return error.EndOfStream;
|
||||
}
|
||||
const ct = @intToEnum(ContentType, handshake_buf[i]);
|
||||
i += 1;
|
||||
const legacy_version = mem.readIntBig(u16, handshake_buf[i..][0..2]);
|
||||
i += 2;
|
||||
_ = legacy_version;
|
||||
const record_size = mem.readIntBig(u16, handshake_buf[i..][0..2]);
|
||||
i += 2;
|
||||
const end = i + record_size;
|
||||
if (end > handshake_buf.len) return error.TlsRecordOverflow;
|
||||
if (end > len) {
|
||||
len += try stream.readAtLeast(handshake_buf[len..], end - len);
|
||||
if (end > len) return error.EndOfStream;
|
||||
}
|
||||
switch (ct) {
|
||||
.change_cipher_spec => {
|
||||
if (record_size != 1) return error.TlsUnexpectedMessage;
|
||||
if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage;
|
||||
},
|
||||
.application_data => {
|
||||
std.debug.print("TODO: decrypt these {d} bytes\n", .{record_size});
|
||||
},
|
||||
else => {
|
||||
std.debug.print("content type: {s}\n", .{@tagName(ct)});
|
||||
return error.TlsUnexpectedMessage;
|
||||
},
|
||||
}
|
||||
i = end;
|
||||
}
|
||||
|
||||
tls.state = .sent_hello;
|
||||
|
||||
@ -1680,9 +1680,9 @@ pub const Stream = struct {
|
||||
}
|
||||
|
||||
/// Returns the number of bytes read, calling the underlying read function
|
||||
/// multiple times until at least the buffer has at least `len` bytes
|
||||
/// filled. If the number read is less than `len` it means the stream
|
||||
/// reached the end. Reaching the end of the stream is not an error
|
||||
/// the minimal number of times until at least the buffer has at least
|
||||
/// `len` bytes filled. If the number read is less than `len` it means the
|
||||
/// stream reached the end. Reaching the end of the stream is not an error
|
||||
/// condition.
|
||||
pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize {
|
||||
var index: usize = 0;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user