diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 8ef4d9bfad..7d89da8929 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -47,6 +47,11 @@ pub const hello_retry_request_sequence = [32]u8{ 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, }; +pub const close_notify_alert = [_]u8{ + @enumToInt(AlertLevel.warning), + @enumToInt(AlertDescription.close_notify), +}; + pub const ProtocolVersion = enum(u16) { tls_1_2 = 0x0303, tls_1_3 = 0x0304, diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index bca05a3ffd..df59932d4a 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -37,8 +37,54 @@ application_cipher: tls.ApplicationCipher, /// `partial_ciphertext_end` describe the span of the segments. partially_read_buffer: [tls.max_ciphertext_record_len]u8, +/// This is an example of the type that is needed by the read and write +/// functions. It can have any fields but it must at least have these +/// functions. +/// +/// Note that `std.net.Stream` conforms to this interface. +/// +/// This declaration serves as documentation only. +pub const StreamInterface = struct { + /// Can be any error set. + pub const ReadError = error{}; + + /// Returns the number of bytes read. The number read may be less than the + /// buffer space provided. End-of-stream is indicated by a return value of 0. + /// + /// The `iovecs` parameter is mutable because so that function may to + /// mutate the fields in order to handle partial reads from the underlying + /// stream layer. + pub fn readv(this: @This(), iovecs: []std.os.iovec) ReadError!usize { + _ = .{ this, iovecs }; + @panic("unimplemented"); + } + + /// Can be any error set. + pub const WriteError = error{}; + + /// Returns the number of bytes read, which may be less than the buffer + /// space provided. A short read does not indicate end-of-stream. + pub fn writev(this: @This(), iovecs: []const std.os.iovec_const) WriteError!usize { + _ = .{ this, iovecs }; + @panic("unimplemented"); + } + + /// Returns the number of bytes read, which may be less than the buffer + /// space provided, indicating end-of-stream. + /// The `iovecs` parameter is mutable in case this function needs to mutate + /// the fields in order to handle partial writes from the underlying layer. + pub fn writevAll(this: @This(), iovecs: []std.os.iovec_const) WriteError!usize { + // This can be implemented in terms of writev, or specialized if desired. + _ = .{ this, iovecs }; + @panic("unimplemented"); + } +}; + +/// Initiates a TLS handshake and establishes a TLSv1.3 session with `stream`, which +/// must conform to `StreamInterface`. +/// /// `host` is only borrowed during this function call. -pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) !Client { +pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) !Client { const host_len = @intCast(u16, host.len); var random_buffer: [128]u8 = undefined; @@ -579,31 +625,115 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) } } -pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize { +/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. +/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. +pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize { + return writeEnd(c, stream, bytes, false); +} + +/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. +pub fn writeAll(c: *Client, stream: anytype, bytes: []const u8) !void { + var index: usize = 0; + while (index < bytes.len) { + index += try c.write(stream, bytes[index..]); + } +} + +/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. +/// If `end` is true, then this function additionally sends a `close_notify` alert, +/// which is necessary for the server to distinguish between a properly finished +/// TLS session, or a truncation attack. +pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !void { + var index: usize = 0; + while (index < bytes.len) { + index += try c.writeEnd(stream, bytes[index..], end); + } +} + +/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. +/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. +/// If `end` is true, then this function additionally sends a `close_notify` alert, +/// which is necessary for the server to distinguish between a properly finished +/// TLS session, or a truncation attack. +pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize { var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined; + var iovecs_buf: [6]std.os.iovec_const = undefined; + var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data); + if (end) { + prepared.iovec_end += prepareCiphertextRecord( + c, + iovecs_buf[prepared.iovec_end..], + ciphertext_buf[prepared.ciphertext_end..], + &tls.close_notify_alert, + .alert, + ).iovec_end; + } + + const iovec_end = prepared.iovec_end; + const overhead_len = prepared.overhead_len; + + // Ideally we would call writev exactly once here, however, we must ensure + // that we don't return with a record partially written. + var i: usize = 0; + var total_amt: usize = 0; + while (true) { + var amt = try stream.writev(iovecs_buf[i..iovec_end]); + while (amt >= iovecs_buf[i].iov_len) { + const encrypted_amt = iovecs_buf[i].iov_len; + total_amt += encrypted_amt - overhead_len; + amt -= encrypted_amt; + i += 1; + // Rely on the property that iovecs delineate records, meaning that + // if amt equals zero here, we have fortunately found ourselves + // with a short read that aligns at the record boundary. + if (i >= iovec_end) return total_amt; + // We also cannot return on a vector boundary if the final close_notify is + // not sent; otherwise the caller would not know to retry the call. + if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt; + } + iovecs_buf[i].iov_base += amt; + iovecs_buf[i].iov_len -= amt; + } +} + +fn prepareCiphertextRecord( + c: *Client, + iovecs: []std.os.iovec_const, + ciphertext_buf: []u8, + bytes: []const u8, + inner_content_type: tls.ContentType, +) struct { + iovec_end: usize, + ciphertext_end: usize, + /// How many bytes are taken up by overhead per record. + overhead_len: usize, +} { // Due to the trailing inner content type byte in the ciphertext, we need // an additional buffer for storing the cleartext into before encrypting. var cleartext_buf: [max_ciphertext_len]u8 = undefined; - var iovecs_buf: [5]std.os.iovec_const = undefined; var ciphertext_end: usize = 0; var iovec_end: usize = 0; var bytes_i: usize = 0; - // How many bytes are taken up by overhead per record. - const overhead_len: usize = switch (c.application_cipher) { - inline else => |*p| l: { + switch (c.application_cipher) { + inline else => |*p| { const P = @TypeOf(p.*); const V = @Vector(P.AEAD.nonce_length, u8); const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1; + const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; while (true) { const encrypted_content_len = @intCast(u16, @min( @min(bytes.len - bytes_i, max_ciphertext_len - 1), - ciphertext_buf.len - - tls.record_header_len - P.AEAD.tag_length - ciphertext_end - 1, + ciphertext_buf.len - close_notify_alert_reserved - + overhead_len - ciphertext_end, )); - if (encrypted_content_len == 0) break :l overhead_len; + if (encrypted_content_len == 0) return .{ + .iovec_end = iovec_end, + .ciphertext_end = ciphertext_end, + .overhead_len = overhead_len, + }; mem.copy(u8, &cleartext_buf, bytes[bytes_i..][0..encrypted_content_len]); - cleartext_buf[encrypted_content_len] = @enumToInt(tls.ContentType.application_data); + cleartext_buf[encrypted_content_len] = @enumToInt(inner_content_type); bytes_i += encrypted_content_len; const ciphertext_len = encrypted_content_len + 1; const cleartext = cleartext_buf[0..ciphertext_len]; @@ -626,40 +756,13 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize { P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key); const record = ciphertext_buf[record_start..ciphertext_end]; - iovecs_buf[iovec_end] = .{ + iovecs[iovec_end] = .{ .iov_base = record.ptr, .iov_len = record.len, }; iovec_end += 1; } }, - }; - - // Ideally we would call writev exactly once here, however, we must ensure - // that we don't return with a record partially written. - var i: usize = 0; - var total_amt: usize = 0; - while (true) { - var amt = try stream.writev(iovecs_buf[i..iovec_end]); - while (amt >= iovecs_buf[i].iov_len) { - const encrypted_amt = iovecs_buf[i].iov_len; - total_amt += encrypted_amt - overhead_len; - amt -= encrypted_amt; - i += 1; - // Rely on the property that iovecs delineate records, meaning that - // if amt equals zero here, we have fortunately found ourselves - // with a short read that aligns at the record boundary. - if (i >= iovec_end or amt == 0) return total_amt; - } - iovecs_buf[i].iov_base += amt; - iovecs_buf[i].iov_len -= amt; - } -} - -pub fn writeAll(c: *Client, stream: net.Stream, bytes: []const u8) !void { - var index: usize = 0; - while (index < bytes.len) { - index += try c.write(stream, bytes[index..]); } } @@ -669,6 +772,7 @@ pub fn eof(c: Client) bool { c.partial_ciphertext_idx >= c.partial_ciphertext_end; } +/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. /// Returns the number of bytes read, calling the underlying read function the /// minimal number of times until the buffer has at least `len` bytes filled. /// If the number read is less than `len` it means the stream reached the end. @@ -678,10 +782,12 @@ pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize return readvAtLeast(c, stream, &iovecs, len); } +/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize { return readAtLeast(c, stream, buffer, 1); } +/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. /// Returns the number of bytes read. If the number read is smaller than /// `buffer.len`, it means the stream reached the end. Reaching the end of the /// stream is not an error condition. @@ -689,6 +795,7 @@ pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize { return readAtLeast(c, stream, buffer, buffer.len); } +/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. /// Returns the number of bytes read. If the number read is less than the space /// provided it means the stream reached the end. Reaching the end of the /// stream is not an error condition. @@ -698,6 +805,7 @@ pub fn readv(c: *Client, stream: anytype, iovecs: []std.os.iovec) !usize { return readvAtLeast(c, stream, iovecs); } +/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. /// Returns the number of bytes read, calling the underlying read function the /// minimal number of times until the iovecs have at least `len` bytes filled. /// If the number read is less than `len` it means the stream reached the end. @@ -722,6 +830,7 @@ pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.os.iovec, len: us } } +/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. /// Returns number of bytes that have been read, populated inside `iovecs`. A /// return value of zero bytes does not mean end of stream. Instead, check the `eof()` /// for the end of stream. The `eof()` may be true after any call to @@ -729,7 +838,7 @@ pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.os.iovec, len: us /// function asserts that `eof()` is `false`. /// See `readv` for a higher level function that has the same, familiar API as /// other read functions, such as `std.fs.File.read`. -pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iovec) !usize { +pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec) !usize { var vp: VecPut = .{ .iovecs = iovecs }; // Give away the buffered cleartext we have, if any. @@ -905,7 +1014,8 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove break :c cleartext; }, }; - c.read_seq += 1; + + c.read_seq = try std.math.add(u64, c.read_seq, 1); const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]); switch (inner_ct) { @@ -1196,3 +1306,7 @@ const cipher_suites = enum_array(tls.CipherSuite, &.{ .AES_256_GCM_SHA384, .CHACHA20_POLY1305_SHA256, }); + +test { + _ = StreamInterface; +} diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index efae62680d..33df40866a 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -47,7 +47,7 @@ pub const Request = struct { try req.stream.writeAll(req.headers.items); }, .https => { - try req.tls_client.writeAll(req.stream, req.headers.items); + try req.tls_client.writeAllEnd(req.stream, req.headers.items, true); }, } }