mirror of
https://github.com/ziglang/zig.git
synced 2026-02-20 16:24:51 +00:00
std.crypto.Tls: parse the ServerHello handshake
This commit is contained in:
parent
ba44513c2f
commit
d2f5d0b199
@ -8,6 +8,13 @@ const assert = std.debug.assert;
|
|||||||
state: State = .start,
|
state: State = .start,
|
||||||
x25519_priv_key: [32]u8 = undefined,
|
x25519_priv_key: [32]u8 = undefined,
|
||||||
x25519_pub_key: [32]u8 = undefined,
|
x25519_pub_key: [32]u8 = undefined,
|
||||||
|
x25519_server_pub_key: [32]u8 = undefined,
|
||||||
|
|
||||||
|
const ProtocolVersion = enum(u16) {
|
||||||
|
tls_1_2 = 0x0303,
|
||||||
|
tls_1_3 = 0x0304,
|
||||||
|
_,
|
||||||
|
};
|
||||||
|
|
||||||
const State = enum {
|
const State = enum {
|
||||||
/// In this state, all fields are undefined except state.
|
/// In this state, all fields are undefined except state.
|
||||||
@ -186,6 +193,18 @@ const NamedGroup = enum(u16) {
|
|||||||
// * length: u24
|
// * length: u24
|
||||||
// * data: opaque
|
// * data: opaque
|
||||||
|
|
||||||
|
// ServerHello:
|
||||||
|
// * ProtocolVersion legacy_version = 0x0303;
|
||||||
|
// * Random random;
|
||||||
|
// * opaque legacy_session_id_echo<0..32>;
|
||||||
|
// * CipherSuite cipher_suite;
|
||||||
|
// * uint8 legacy_compression_method = 0;
|
||||||
|
// * Extension extensions<6..2^16-1>;
|
||||||
|
|
||||||
|
// Extension:
|
||||||
|
// * ExtensionType extension_type;
|
||||||
|
// * opaque extension_data<0..2^16-1>;
|
||||||
|
|
||||||
const CipherSuite = enum(u16) {
|
const CipherSuite = enum(u16) {
|
||||||
TLS_AES_128_GCM_SHA256 = 0x1301,
|
TLS_AES_128_GCM_SHA256 = 0x1301,
|
||||||
TLS_AES_256_GCM_SHA384 = 0x1302,
|
TLS_AES_256_GCM_SHA384 = 0x1302,
|
||||||
@ -259,10 +278,10 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
|
|||||||
|
|
||||||
// Extension: key_share
|
// Extension: key_share
|
||||||
0, 51, // ExtensionType.key_share
|
0, 51, // ExtensionType.key_share
|
||||||
0x00, 38, // byte length of this extension payload
|
0, 38, // byte length of this extension payload
|
||||||
0x00, 36, // byte length of client_shares
|
0, 36, // byte length of client_shares
|
||||||
0x00, 0x1D, // NamedGroup.x25519
|
0x00, 0x1D, // NamedGroup.x25519
|
||||||
0x00, 32, // byte length of key_exchange
|
0, 32, // byte length of key_exchange
|
||||||
} ++ tls.x25519_pub_key ++ [_]u8{
|
} ++ tls.x25519_pub_key ++ [_]u8{
|
||||||
|
|
||||||
// Extension: server_name
|
// Extension: server_name
|
||||||
@ -313,21 +332,103 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
|
|||||||
try stream.writevAll(&iovecs);
|
try stream.writevAll(&iovecs);
|
||||||
|
|
||||||
{
|
{
|
||||||
var buf: [1000]u8 = undefined;
|
var handshake_buf: [4000]u8 = undefined;
|
||||||
const amt = try stream.read(&buf);
|
const plaintext = handshake_buf[0..5];
|
||||||
const resp = buf[0..amt];
|
const amt = try stream.readAtLeast(&handshake_buf, plaintext.len);
|
||||||
const ct = @intToEnum(ContentType, resp[0]);
|
if (amt < plaintext.len) return error.EndOfStream;
|
||||||
|
const ct = @intToEnum(ContentType, plaintext[0]);
|
||||||
|
const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]);
|
||||||
|
const end = plaintext.len + frag_len;
|
||||||
|
if (end > handshake_buf.len) return error.TlsServerHelloTooBig;
|
||||||
|
if (amt < end) {
|
||||||
|
const amt2 = try stream.readAll(handshake_buf[amt..end]);
|
||||||
|
if (amt2 < plaintext.len) return error.EndOfStream;
|
||||||
|
}
|
||||||
|
const frag = handshake_buf[plaintext.len..end];
|
||||||
|
|
||||||
if (ct == .alert) {
|
if (ct == .alert) {
|
||||||
//const prot_ver = @bitCast(u16, resp[1..][0..2].*);
|
const level = @intToEnum(AlertLevel, frag[0]);
|
||||||
const len = std.mem.readIntBig(u16, resp[3..][0..2]);
|
const desc = @intToEnum(AlertDescription, frag[1]);
|
||||||
const alert = resp[5..][0..len];
|
|
||||||
const level = @intToEnum(AlertLevel, alert[0]);
|
|
||||||
const desc = @intToEnum(AlertDescription, alert[1]);
|
|
||||||
std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) });
|
std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) });
|
||||||
std.process.exit(1);
|
std.process.exit(1);
|
||||||
|
} else if (ct == .handshake) {
|
||||||
|
if (frag[0] != @enumToInt(HandshakeType.server_hello)) {
|
||||||
|
return error.TlsUnexpectedMessage;
|
||||||
|
}
|
||||||
|
const length = mem.readIntBig(u24, frag[1..4]);
|
||||||
|
if (4 + length != frag.len) return error.TlsBadLength;
|
||||||
|
const hello = frag[4..];
|
||||||
|
const legacy_version = mem.readIntBig(u16, hello[0..2]);
|
||||||
|
const random = hello[2..34].*;
|
||||||
|
_ = random;
|
||||||
|
const legacy_session_id_echo_len = hello[34];
|
||||||
|
if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter;
|
||||||
|
const cipher_suite_int = mem.readIntBig(u16, hello[35..37]);
|
||||||
|
const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch
|
||||||
|
return error.TlsIllegalParameter;
|
||||||
|
std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)});
|
||||||
|
const legacy_compression_method = hello[37];
|
||||||
|
_ = legacy_compression_method;
|
||||||
|
const extensions_size = mem.readIntBig(u16, hello[38..40]);
|
||||||
|
if (40 + extensions_size != hello.len) return error.TlsBadLength;
|
||||||
|
var i: usize = 40;
|
||||||
|
var supported_version: u16 = 0;
|
||||||
|
var have_server_pub_key = false;
|
||||||
|
while (i < hello.len) {
|
||||||
|
const et = mem.readIntBig(u16, hello[i..][0..2]);
|
||||||
|
i += 2;
|
||||||
|
const ext_size = mem.readIntBig(u16, hello[i..][0..2]);
|
||||||
|
i += 2;
|
||||||
|
const next_i = i + ext_size;
|
||||||
|
if (next_i > hello.len) return error.TlsBadLength;
|
||||||
|
switch (et) {
|
||||||
|
@enumToInt(ExtensionType.supported_versions) => {
|
||||||
|
if (supported_version != 0) return error.TlsIllegalParameter;
|
||||||
|
supported_version = mem.readIntBig(u16, hello[i..][0..2]);
|
||||||
|
},
|
||||||
|
@enumToInt(ExtensionType.key_share) => {
|
||||||
|
if (have_server_pub_key) return error.TlsIllegalParameter;
|
||||||
|
const named_group = mem.readIntBig(u16, hello[i..][0..2]);
|
||||||
|
i += 2;
|
||||||
|
switch (named_group) {
|
||||||
|
@enumToInt(NamedGroup.x25519) => {
|
||||||
|
const key_size = mem.readIntBig(u16, hello[i..][0..2]);
|
||||||
|
i += 2;
|
||||||
|
if (key_size != 32) return error.TlsBadLength;
|
||||||
|
const encrypted_key = hello[i..][0..32].*;
|
||||||
|
const server_pub_key = try crypto.dh.X25519.scalarmult(
|
||||||
|
tls.x25519_priv_key,
|
||||||
|
encrypted_key,
|
||||||
|
);
|
||||||
|
tls.x25519_server_pub_key = server_pub_key;
|
||||||
|
have_server_pub_key = true;
|
||||||
|
},
|
||||||
|
else => {
|
||||||
|
std.debug.print("named group: {x}\n", .{named_group});
|
||||||
|
return error.TlsIllegalParameter;
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
else => {
|
||||||
|
std.debug.print("unexpected extension: {x}\n", .{et});
|
||||||
|
},
|
||||||
|
}
|
||||||
|
i = next_i;
|
||||||
|
}
|
||||||
|
if (!have_server_pub_key) return error.TlsIllegalParameter;
|
||||||
|
const tls_version = if (supported_version == 0) legacy_version else supported_version;
|
||||||
|
switch (tls_version) {
|
||||||
|
@enumToInt(ProtocolVersion.tls_1_2) => {
|
||||||
|
std.debug.print("server wants TLS v1.2\n", .{});
|
||||||
|
},
|
||||||
|
@enumToInt(ProtocolVersion.tls_1_3) => {
|
||||||
|
std.debug.print("server wants TLS v1.3\n", .{});
|
||||||
|
},
|
||||||
|
else => return error.TlsIllegalParameter,
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
std.debug.print("content_type: {s}\n", .{@tagName(ct)});
|
std.debug.print("content_type: {s}\n", .{@tagName(ct)});
|
||||||
std.debug.print("got {d} bytes: {s}\n", .{ amt, std.fmt.fmtSliceHexLower(resp) });
|
std.debug.print("got {d} bytes: {s}\n", .{ amt, std.fmt.fmtSliceHexLower(frag) });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -59,7 +59,7 @@ pub const Request = struct {
|
|||||||
|
|
||||||
pub fn deinit(client: *Client) void {
|
pub fn deinit(client: *Client) void {
|
||||||
assert(client.active_requests == 0);
|
assert(client.active_requests == 0);
|
||||||
client.headers.denit(client.allocator);
|
client.headers.deinit(client.allocator);
|
||||||
client.* = undefined;
|
client.* = undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,6 +69,7 @@ pub fn request(client: *Client, options: Request.Options) !Request {
|
|||||||
.stream = try net.tcpConnectToHost(client.allocator, options.host, options.port),
|
.stream = try net.tcpConnectToHost(client.allocator, options.host, options.port),
|
||||||
.protocol = options.protocol,
|
.protocol = options.protocol,
|
||||||
};
|
};
|
||||||
|
client.active_requests += 1;
|
||||||
errdefer req.deinit();
|
errdefer req.deinit();
|
||||||
|
|
||||||
switch (options.protocol) {
|
switch (options.protocol) {
|
||||||
@ -100,7 +101,6 @@ pub fn request(client: *Client, options: Request.Options) !Request {
|
|||||||
}
|
}
|
||||||
req.headers.appendSliceAssumeCapacity(client.headers.items);
|
req.headers.appendSliceAssumeCapacity(client.headers.items);
|
||||||
|
|
||||||
client.active_requests += 1;
|
|
||||||
return req;
|
return req;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1672,6 +1672,28 @@ pub const Stream = struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the number of bytes read. If the number read is smaller than
|
||||||
|
/// `buffer.len`, it means the stream reached the end. Reaching the end of
|
||||||
|
/// a stream is not an error condition.
|
||||||
|
pub fn readAll(s: Stream, buffer: []u8) ReadError!usize {
|
||||||
|
return readAtLeast(s, buffer, buffer.len);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the number of bytes read, calling the underlying read function
|
||||||
|
/// multiple times until at least the buffer has at least `len` bytes
|
||||||
|
/// filled. If the number read is less than `len` it means the stream
|
||||||
|
/// reached the end. Reaching the end of the stream is not an error
|
||||||
|
/// condition.
|
||||||
|
pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize {
|
||||||
|
var index: usize = 0;
|
||||||
|
while (index < len) {
|
||||||
|
const amt = try s.read(buffer[index..]);
|
||||||
|
if (amt == 0) break;
|
||||||
|
index += amt;
|
||||||
|
}
|
||||||
|
return index;
|
||||||
|
}
|
||||||
|
|
||||||
/// TODO in evented I/O mode, this implementation incorrectly uses the event loop's
|
/// TODO in evented I/O mode, this implementation incorrectly uses the event loop's
|
||||||
/// file system thread instead of non-blocking. It needs to be reworked to properly
|
/// file system thread instead of non-blocking. It needs to be reworked to properly
|
||||||
/// use non-blocking I/O.
|
/// use non-blocking I/O.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user