diff --git a/lib/std/crypto/Tls.zig b/lib/std/crypto/Tls.zig index 0dc6946003..65f54ffa68 100644 --- a/lib/std/crypto/Tls.zig +++ b/lib/std/crypto/Tls.zig @@ -188,6 +188,12 @@ const NamedGroup = enum(u16) { // * fragment: opaque // - the data being transmitted +// Ciphertext +// * ContentType opaque_type = application_data; /* 23 */ +// * ProtocolVersion legacy_record_version = 0x0303; /* TLS v1.2 */ +// * uint16 length; +// * opaque encrypted_record[TLSCiphertext.length]; + // Handshake: // * type: HandshakeType // * length: u24 @@ -331,105 +337,144 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { }; try stream.writevAll(&iovecs); - { - var handshake_buf: [4000]u8 = undefined; + var handshake_buf: [4000]u8 = undefined; + var len: usize = 0; + var i: usize = i: { const plaintext = handshake_buf[0..5]; - const amt = try stream.readAtLeast(&handshake_buf, plaintext.len); - if (amt < plaintext.len) return error.EndOfStream; + len = try stream.readAtLeast(&handshake_buf, plaintext.len); + if (len < plaintext.len) return error.EndOfStream; const ct = @intToEnum(ContentType, plaintext[0]); const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]); const end = plaintext.len + frag_len; - if (end > handshake_buf.len) return error.TlsServerHelloTooBig; - if (amt < end) { - const amt2 = try stream.readAll(handshake_buf[amt..end]); - if (amt2 < plaintext.len) return error.EndOfStream; + if (end > handshake_buf.len) return error.TlsRecordOverflow; + if (end > len) { + len += try stream.readAtLeast(handshake_buf[len..], end - len); + if (end > len) return error.EndOfStream; } const frag = handshake_buf[plaintext.len..end]; - if (ct == .alert) { - const level = @intToEnum(AlertLevel, frag[0]); - const desc = @intToEnum(AlertDescription, frag[1]); - std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) }); - std.process.exit(1); - } else if (ct == .handshake) { - if (frag[0] != @enumToInt(HandshakeType.server_hello)) { - return error.TlsUnexpectedMessage; - } - const length = mem.readIntBig(u24, frag[1..4]); - if (4 + length != frag.len) return error.TlsBadLength; - const hello = frag[4..]; - const legacy_version = mem.readIntBig(u16, hello[0..2]); - const random = hello[2..34].*; - _ = random; - const legacy_session_id_echo_len = hello[34]; - if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter; - const cipher_suite_int = mem.readIntBig(u16, hello[35..37]); - const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch - return error.TlsIllegalParameter; - std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)}); - const legacy_compression_method = hello[37]; - _ = legacy_compression_method; - const extensions_size = mem.readIntBig(u16, hello[38..40]); - if (40 + extensions_size != hello.len) return error.TlsBadLength; - var i: usize = 40; - var supported_version: u16 = 0; - var have_server_pub_key = false; - while (i < hello.len) { - const et = mem.readIntBig(u16, hello[i..][0..2]); - i += 2; - const ext_size = mem.readIntBig(u16, hello[i..][0..2]); - i += 2; - const next_i = i + ext_size; - if (next_i > hello.len) return error.TlsBadLength; - switch (et) { - @enumToInt(ExtensionType.supported_versions) => { - if (supported_version != 0) return error.TlsIllegalParameter; - supported_version = mem.readIntBig(u16, hello[i..][0..2]); - }, - @enumToInt(ExtensionType.key_share) => { - if (have_server_pub_key) return error.TlsIllegalParameter; - const named_group = mem.readIntBig(u16, hello[i..][0..2]); - i += 2; - switch (named_group) { - @enumToInt(NamedGroup.x25519) => { - const key_size = mem.readIntBig(u16, hello[i..][0..2]); - i += 2; - if (key_size != 32) return error.TlsBadLength; - const encrypted_key = hello[i..][0..32].*; - const server_pub_key = try crypto.dh.X25519.scalarmult( - tls.x25519_priv_key, - encrypted_key, - ); - tls.x25519_server_pub_key = server_pub_key; - have_server_pub_key = true; - }, - else => { - std.debug.print("named group: {x}\n", .{named_group}); - return error.TlsIllegalParameter; - }, - } - }, - else => { - std.debug.print("unexpected extension: {x}\n", .{et}); - }, + switch (ct) { + .alert => { + const level = @intToEnum(AlertLevel, frag[0]); + const desc = @intToEnum(AlertDescription, frag[1]); + std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) }); + return error.TlsAlert; + }, + .handshake => { + if (frag[0] != @enumToInt(HandshakeType.server_hello)) { + return error.TlsUnexpectedMessage; } - i = next_i; - } - if (!have_server_pub_key) return error.TlsIllegalParameter; - const tls_version = if (supported_version == 0) legacy_version else supported_version; - switch (tls_version) { - @enumToInt(ProtocolVersion.tls_1_2) => { - std.debug.print("server wants TLS v1.2\n", .{}); - }, - @enumToInt(ProtocolVersion.tls_1_3) => { - std.debug.print("server wants TLS v1.3\n", .{}); - }, - else => return error.TlsIllegalParameter, - } - } else { - std.debug.print("content_type: {s}\n", .{@tagName(ct)}); - std.debug.print("got {d} bytes: {s}\n", .{ amt, std.fmt.fmtSliceHexLower(frag) }); + const length = mem.readIntBig(u24, frag[1..4]); + if (4 + length != frag.len) return error.TlsBadLength; + const hello = frag[4..]; + const legacy_version = mem.readIntBig(u16, hello[0..2]); + const random = hello[2..34].*; + _ = random; + const legacy_session_id_echo_len = hello[34]; + if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter; + const cipher_suite_int = mem.readIntBig(u16, hello[35..37]); + const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch + return error.TlsIllegalParameter; + std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)}); + const legacy_compression_method = hello[37]; + _ = legacy_compression_method; + const extensions_size = mem.readIntBig(u16, hello[38..40]); + if (40 + extensions_size != hello.len) return error.TlsBadLength; + var i: usize = 40; + var supported_version: u16 = 0; + var have_server_pub_key = false; + while (i < hello.len) { + const et = mem.readIntBig(u16, hello[i..][0..2]); + i += 2; + const ext_size = mem.readIntBig(u16, hello[i..][0..2]); + i += 2; + const next_i = i + ext_size; + if (next_i > hello.len) return error.TlsBadLength; + switch (et) { + @enumToInt(ExtensionType.supported_versions) => { + if (supported_version != 0) return error.TlsIllegalParameter; + supported_version = mem.readIntBig(u16, hello[i..][0..2]); + }, + @enumToInt(ExtensionType.key_share) => { + if (have_server_pub_key) return error.TlsIllegalParameter; + const named_group = mem.readIntBig(u16, hello[i..][0..2]); + i += 2; + switch (named_group) { + @enumToInt(NamedGroup.x25519) => { + const key_size = mem.readIntBig(u16, hello[i..][0..2]); + i += 2; + if (key_size != 32) return error.TlsBadLength; + const encrypted_key = hello[i..][0..32].*; + const server_pub_key = try crypto.dh.X25519.scalarmult( + tls.x25519_priv_key, + encrypted_key, + ); + tls.x25519_server_pub_key = server_pub_key; + have_server_pub_key = true; + }, + else => { + std.debug.print("named group: {x}\n", .{named_group}); + return error.TlsIllegalParameter; + }, + } + }, + else => { + std.debug.print("unexpected extension: {x}\n", .{et}); + }, + } + i = next_i; + } + if (!have_server_pub_key) return error.TlsIllegalParameter; + const tls_version = if (supported_version == 0) legacy_version else supported_version; + switch (tls_version) { + @enumToInt(ProtocolVersion.tls_1_2) => { + std.debug.print("server wants TLS v1.2\n", .{}); + }, + @enumToInt(ProtocolVersion.tls_1_3) => { + std.debug.print("server wants TLS v1.3\n", .{}); + }, + else => return error.TlsIllegalParameter, + } + }, + else => return error.TlsUnexpectedMessage, } + break :i end; + }; + + while (true) { + const end_hdr = i + 5; + if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow; + if (end_hdr > len) { + len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len); + if (end_hdr > len) return error.EndOfStream; + } + const ct = @intToEnum(ContentType, handshake_buf[i]); + i += 1; + const legacy_version = mem.readIntBig(u16, handshake_buf[i..][0..2]); + i += 2; + _ = legacy_version; + const record_size = mem.readIntBig(u16, handshake_buf[i..][0..2]); + i += 2; + const end = i + record_size; + if (end > handshake_buf.len) return error.TlsRecordOverflow; + if (end > len) { + len += try stream.readAtLeast(handshake_buf[len..], end - len); + if (end > len) return error.EndOfStream; + } + switch (ct) { + .change_cipher_spec => { + if (record_size != 1) return error.TlsUnexpectedMessage; + if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage; + }, + .application_data => { + std.debug.print("TODO: decrypt these {d} bytes\n", .{record_size}); + }, + else => { + std.debug.print("content type: {s}\n", .{@tagName(ct)}); + return error.TlsUnexpectedMessage; + }, + } + i = end; } tls.state = .sent_hello; diff --git a/lib/std/net.zig b/lib/std/net.zig index 8c8ab51a4a..a265fa69a9 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -1680,9 +1680,9 @@ pub const Stream = struct { } /// Returns the number of bytes read, calling the underlying read function - /// multiple times until at least the buffer has at least `len` bytes - /// filled. If the number read is less than `len` it means the stream - /// reached the end. Reaching the end of the stream is not an error + /// the minimal number of times until at least the buffer has at least + /// `len` bytes filled. If the number read is less than `len` it means the + /// stream reached the end. Reaching the end of the stream is not an error /// condition. pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize { var index: usize = 0;