std: start converting networking stuff to new reader/writer

This commit is contained in:
Andrew Kelley 2025-04-15 23:09:01 -07:00
parent c872a9fd49
commit 20a784f713
5 changed files with 545 additions and 601 deletions

View File

@ -155,38 +155,35 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type {
}
// Read a DER-encoded integer.
fn readDerInt(out: []u8, reader: anytype) EncodingError!void {
var buf: [2]u8 = undefined;
_ = reader.readNoEof(&buf) catch return error.InvalidEncoding;
// Asserts `br` has storage capacity >= 2.
fn readDerInt(out: []u8, br: *std.io.BufferedReader) EncodingError!void {
const buf = br.take(2) catch return error.InvalidEncoding;
if (buf[0] != 0x02) return error.InvalidEncoding;
var expected_len = @as(usize, buf[1]);
var expected_len: usize = buf[1];
if (expected_len == 0 or expected_len > 1 + out.len) return error.InvalidEncoding;
var has_top_bit = false;
if (expected_len == 1 + out.len) {
if ((reader.readByte() catch return error.InvalidEncoding) != 0) return error.InvalidEncoding;
if ((br.takeByte() catch return error.InvalidEncoding) != 0) return error.InvalidEncoding;
expected_len -= 1;
has_top_bit = true;
}
const out_slice = out[out.len - expected_len ..];
reader.readNoEof(out_slice) catch return error.InvalidEncoding;
br.read(out_slice) catch return error.InvalidEncoding;
if (@intFromBool(has_top_bit) != out[0] >> 7) return error.InvalidEncoding;
}
/// Create a signature from a DER representation.
/// Returns InvalidEncoding if the DER encoding is invalid.
pub fn fromDer(der: []const u8) EncodingError!Signature {
if (der.len < 2) return error.InvalidEncoding;
var br: std.io.BufferedReader = undefined;
br.initFixed(der);
const buf = br.take(2) catch return error.InvalidEncoding;
if (buf[0] != 0x30 or @as(usize, buf[1]) + 2 != der.len) return error.InvalidEncoding;
var sig: Signature = mem.zeroInit(Signature, .{});
var fb: std.io.FixedBufferStream = .{ .buffer = der };
const reader = fb.reader();
var buf: [2]u8 = undefined;
_ = reader.readNoEof(&buf) catch return error.InvalidEncoding;
if (buf[0] != 0x30 or @as(usize, buf[1]) + 2 != der.len) {
return error.InvalidEncoding;
}
try readDerInt(&sig.r, reader);
try readDerInt(&sig.s, reader);
if (fb.getPos() catch unreachable != der.len) return error.InvalidEncoding;
try readDerInt(&sig.r, &br);
try readDerInt(&sig.s, &br);
if (br.seek != der.len) return error.InvalidEncoding;
return sig;
}
};

View File

@ -1,3 +1,6 @@
const builtin = @import("builtin");
const native_endian = builtin.cpu.arch.endian();
const std = @import("../../std.zig");
const tls = std.crypto.tls;
const Client = @This();
@ -13,18 +16,44 @@ const hkdfExpandLabel = tls.hkdfExpandLabel;
const int = tls.int;
const array = tls.array;
/// The encrypted stream from the server to the client. Bytes are pulled from
/// here via `reader`.
///
/// The buffer is asserted to have capacity at least `min_buffer_len`.
///
/// The size is enough to contain exactly one TLSCiphertext record.
/// This buffer is segmented into four parts:
/// 0. unused
/// 1. cleartext
/// 2. ciphertext
/// 3. unused
/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and
/// `partial_ciphertext_end` describe the span of the segments.
input: *std.io.BufferedReader,
/// The encrypted stream from the client to the server. Bytes are pushed here
/// via `writer`.
///
/// The buffer is asserted to have capacity at least `min_buffer_len`.
output: *std.io.BufferedWriter,
/// Cleartext received from the server here.
///
/// Its buffer aliases the buffer of `input`.
reader: std.io.BufferedReader,
/// Populated under various error conditions.
diagnostics: Diagnostics,
tls_version: tls.ProtocolVersion,
read_seq: u64,
write_seq: u64,
/// The starting index of cleartext bytes inside `partially_read_buffer`.
/// The starting index of cleartext bytes inside the input buffer.
partial_cleartext_idx: u15,
/// The ending index of cleartext bytes inside `partially_read_buffer` as well
/// The ending index of cleartext bytes inside the input buffer as well
/// as the starting index of ciphertext bytes.
partial_ciphertext_idx: u15,
/// The ending index of ciphertext bytes inside `partially_read_buffer`.
/// The ending index of ciphertext bytes inside the input buffer.
partial_ciphertext_end: u15,
/// When this is true, the stream may still not be at the end because there
/// may be data in `partially_read_buffer`.
/// may be data in the input buffer.
received_close_notify: bool,
/// By default, reaching the end-of-stream when reading from the server will
/// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify
@ -35,24 +64,40 @@ received_close_notify: bool,
/// the amount of data expected, such as HTTP with the Content-Length header.
allow_truncation_attacks: bool,
application_cipher: tls.ApplicationCipher,
/// The size is enough to contain exactly one TLSCiphertext record.
/// This buffer is segmented into four parts:
/// 0. unused
/// 1. cleartext
/// 2. ciphertext
/// 3. unused
/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and
/// `partial_ciphertext_end` describe the span of the segments.
partially_read_buffer: [tls.max_ciphertext_record_len]u8,
/// Encrypted bytes sent to the server here.
output: *std.io.BufferedWriter,
/// If non-null, ssl secrets are logged to a file. Creating such a log file allows other
/// programs with access to that file to decrypt all traffic over this connection.
ssl_key_log: ?struct {
/// If non-null, ssl secrets are logged to a stream. Creating such a log file
/// allows other programs with access to that file to decrypt all traffic over
/// this connection.
ssl_key_log: ?*SslKeyLog,
pub const Diagnostics = union {
/// Populated on `error.WriteFailure` and `error.ReadFailure`.
err: anyerror,
/// Populated on `error.TlsAlert`.
///
/// If this isn't an error alert, then it's a closure alert, which makes
/// no sense in a handshake.
alert: tls.AlertDescription,
fn wrapWrite(d: *Diagnostics, returned: anyerror!void) error{WriteFailure}!void {
returned catch |err| {
d.* = .{ .err = err };
return error.WriteFailure;
};
}
fn wrapRead(d: *Diagnostics, returned: anyerror!void) error{ReadFailure}!void {
returned catch |err| {
d.* = .{ .err = err };
return error.ReadFailure;
};
}
};
pub const SslKeyLog = struct {
client_key_seq: u64,
server_key_seq: u64,
client_random: [32]u8,
file: std.fs.File,
writer: *std.io.BufferedWriter,
fn clientCounter(key_log: *@This()) u64 {
defer key_log.client_key_seq += 1;
@ -63,31 +108,12 @@ ssl_key_log: ?struct {
defer key_log.server_key_seq += 1;
return key_log.server_key_seq;
}
},
/// This is an example of the type that is needed by the read and write
/// functions. It can have any fields but it must at least have these
/// functions.
///
/// Note that `std.net.Stream` conforms to this interface.
///
/// This declaration serves as documentation only.
pub const StreamInterface = struct {
/// Can be any error set.
pub const ReadError = error{};
/// Returns the number of bytes read. The number read may be less than the
/// buffer space provided. End-of-stream is indicated by a return value of 0.
///
/// The `iovecs` parameter is mutable so that the function may
/// mutate the fields in order to handle partial reads from the underlying
/// stream layer.
pub fn readv(this: @This(), iovecs: []std.posix.iovec) ReadError!usize {
_ = .{ this, iovecs };
@panic("unimplemented");
}
};
/// The `std.io.BufferedReader` and `std.io.BufferedWriter` supplied to `init`
/// each require a buffer capacity at least this amount.
pub const min_buffer_len = tls.max_ciphertext_record_len;
pub const Options = struct {
/// How to perform host verification of server certificates.
host: union(enum) {
@ -109,39 +135,11 @@ pub const Options = struct {
/// Verify that the server certificate is authorized by a given ca bundle.
bundle: Certificate.Bundle,
},
/// If non-null, ssl secrets are logged to this file. Creating such a log file allows
/// If non-null, ssl secrets are logged to this stream. Creating such a log file allows
/// other programs with access to that file to decrypt all traffic over this connection.
/// TODO `std.crypto` should have no dependencies on `std.fs`.
ssl_key_log_file: ?std.fs.File = null,
diagnostics: ?*Diagnostics = null,
pub const Diagnostics = union {
/// Populated on `error.WriteFailure` and `error.ReadFailure`.
err: anyerror,
/// Populated on `error.TlsAlert`.
///
/// If this isn't an error alert, then it's a closure alert, which makes
/// no sense in a handshake.
alert: tls.AlertDescription,
};
ssl_key_log: ?*std.io.BufferedWriter = null,
};
/// TODO I wish this could be a method of Diagnostics
fn wrapWrite(opt_diags: ?*Options.Diagnostics, returned: anyerror!void) error{WriteFailure}!void {
returned catch |err| {
if (opt_diags) |diags| diags.* = .{ .err = err };
return error.WriteFailure;
};
}
/// TODO I wish this could be a method of Diagnostics
fn wrapRead(opt_diags: ?*Options.Diagnostics, returned: anyerror!void) error{ReadFailure}!void {
returned catch |err| {
if (opt_diags) |diags| diags.* = .{ .err = err };
return error.ReadFailure;
};
}
const InitError = error{
//OutOfMemory,
WriteFailure,
@ -193,12 +191,21 @@ const InitError = error{
WeakPublicKey,
};
/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session with `input`, which
/// must conform to `StreamInterface`.
/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session.
///
/// `host` is only borrowed during this function call.
pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) InitError!Client {
const diags = options.diagnostics;
///
/// Both `input` and `output` are asserted to have buffer capacity at least
/// `min_buffer_len`.
pub fn init(
client: *Client,
input: *std.io.BufferedReader,
output: *std.io.BufferedWriter,
options: Options,
) InitError!void {
assert(input.storage.buffer.len >= min_buffer_len);
assert(output.buffer.len >= min_buffer_len);
const diags = &client.diagnostics;
const host = switch (options.host) {
.no_verification => "",
.explicit => |host| host,
@ -291,7 +298,7 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In
{
var iovecs: [2][]const u8 = .{ cleartext_header, host };
try wrapWrite(diags, output.writevAll(iovecs[0..if (host.len == 0) 1 else 2]));
try diags.wrapWrite(output.writevAll(iovecs[0..if (host.len == 0) 1 else 2]));
}
var tls_version: tls.ProtocolVersion = undefined;
@ -343,12 +350,12 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In
var handshake_buffer: [tls.max_ciphertext_record_len]u8 = undefined;
var d: tls.Decoder = .{ .buf = &handshake_buffer };
fragment: while (true) {
try wrapRead(diags, d.readAtLeastOurAmt(input, tls.record_header_len));
try diags.wrapRead(d.readAtLeastOurAmt(input, tls.record_header_len));
const record_header = d.buf[d.idx..][0..tls.record_header_len];
const record_ct = d.decode(tls.ContentType);
d.skip(2); // legacy_version
const record_len = d.decode(u16);
try wrapRead(diags, d.readAtLeast(input, record_len));
try diags.wrapRead(d.readAtLeast(input, record_len));
var record_decoder = try d.sub(record_len);
var ctd, const ct = content: switch (cipher_state) {
.cleartext => .{ record_decoder, record_ct },
@ -426,7 +433,7 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In
const level = ctd.decode(tls.AlertLevel);
const desc = ctd.decode(tls.AlertDescription);
_ = level;
if (diags) |x| x.* = .{ .alert = desc };
diags.* = .{ .alert = desc };
return error.TlsAlert;
},
.change_cipher_spec => {
@ -768,7 +775,7 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In
&client_change_cipher_spec_msg,
&client_verify_msg,
};
try wrapWrite(diags, output.writevAll(&all_msgs_vec));
try diags.wrapWrite(output.writevAll(&all_msgs_vec));
},
}
write_seq += 1;
@ -833,7 +840,7 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In
&client_change_cipher_spec_msg,
&finished_msg,
};
try wrapWrite(diags, output.writevAll(&all_msgs_vec));
try diags.wrapWrite(output.writevAll(&all_msgs_vec));
const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
@ -865,7 +872,10 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In
},
};
const leftover = d.rest();
var client: Client = .{
client.* = .{
.input = input,
.output = output,
.reader = undefined,
.tls_version = tls_version,
.read_seq = switch (tls_version) {
.tls_1_3 => 0,
@ -883,7 +893,6 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In
.received_close_notify = false,
.allow_truncation_attacks = false,
.application_cipher = app_cipher,
.output = output,
.partially_read_buffer = undefined,
.ssl_key_log = if (options.ssl_key_log_file) |key_log_file| .{
.client_key_seq = key_seq,
@ -893,7 +902,14 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In
} else null,
};
@memcpy(client.partially_read_buffer[0..leftover.len], leftover);
return client;
client.reader.init(.{
.context = client,
.vtable = &.{
.read = reader_read,
.readv = reader_readv,
},
}, input.storage.buffer[0..0]);
return;
},
else => return error.TlsUnexpectedMessage,
}
@ -919,81 +935,45 @@ pub fn writer(c: *Client) std.io.Writer {
fn writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize {
const c: *Client = @alignCast(@ptrCast(context));
assert(data.len > 1 or splat > 0);
return writeEnd(c, data[0], false);
const sliced_data = if (splat == 0) data[0..data.len -| 1] else data;
const output = &c.output;
const ciphertext_buf = try output.writableSlice(min_buffer_len);
var total_clear: usize = 0;
var ciphertext_end: usize = 0;
for (sliced_data) |buf| {
const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data);
total_clear += prepared.cleartext_len;
ciphertext_end += prepared.ciphertext_end;
if (total_clear < buf.len) break;
}
output.advance(ciphertext_end);
return total_clear;
}
/// If `end` is true, then this function additionally sends a `close_notify`
/// alert, which is necessary for the server to distinguish between a properly
/// finished TLS session, or a truncation attack.
pub fn writeAllEnd(c: *Client, bytes: []const u8, end: bool) anyerror!void {
var index: usize = 0;
while (index < bytes.len) {
index += try c.writeEnd(bytes[index..], end);
}
}
/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`.
/// If `end` is true, then this function additionally sends a `close_notify` alert,
/// which is necessary for the server to distinguish between a properly finished
/// TLS session, or a truncation attack.
pub fn writeEnd(c: *Client, bytes: []const u8, end: bool) anyerror!usize {
var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined;
var iovecs_buf: [6][]const u8 = undefined;
var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data);
if (end) {
prepared.iovec_end += prepareCiphertextRecord(
c,
iovecs_buf[prepared.iovec_end..],
ciphertext_buf[prepared.ciphertext_end..],
&tls.close_notify_alert,
.alert,
).iovec_end;
}
const iovec_end = prepared.iovec_end;
const overhead_len = prepared.overhead_len;
// Ideally we would call writev exactly once here, however, we must ensure
// that we don't return with a record partially written.
var i: usize = 0;
var total_amt: usize = 0;
while (true) {
var amt = try c.output.writev(iovecs_buf[i..iovec_end]);
while (amt >= iovecs_buf[i].len) {
const encrypted_amt = iovecs_buf[i].len;
total_amt += encrypted_amt - overhead_len;
amt -= encrypted_amt;
i += 1;
// Rely on the property that iovecs delineate records, meaning that
// if amt equals zero here, we have fortunately found ourselves
// with a short read that aligns at the record boundary.
if (i >= iovec_end) return total_amt;
// We also cannot return on a vector boundary if the final close_notify is
// not sent; otherwise the caller would not know to retry the call.
if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt;
}
iovecs_buf[i] = iovecs_buf[i][amt..];
}
/// Sends a `close_notify` alert, which is necessary for the server to
/// distinguish between a properly finished TLS session, or a truncation
/// attack.
pub fn end(c: *Client) anyerror!void {
const output = &c.output;
const ciphertext_buf = try output.writableSlice(min_buffer_len);
const prepared = prepareCiphertextRecord(c, ciphertext_buf, &tls.close_notify_alert, .alert);
output.advance(prepared.cleartext_len);
return prepared.ciphertext_end;
}
fn prepareCiphertextRecord(
c: *Client,
iovecs: [][]const u8,
ciphertext_buf: []u8,
bytes: []const u8,
inner_content_type: tls.ContentType,
) struct {
iovec_end: usize,
ciphertext_end: usize,
/// How many bytes are taken up by overhead per record.
overhead_len: usize,
cleartext_len: usize,
} {
// Due to the trailing inner content type byte in the ciphertext, we need
// an additional buffer for storing the cleartext into before encrypting.
var cleartext_buf: [max_ciphertext_len]u8 = undefined;
var ciphertext_end: usize = 0;
var iovec_end: usize = 0;
var bytes_i: usize = 0;
switch (c.application_cipher) {
inline else => |*p| switch (c.tls_version) {
@ -1001,18 +981,15 @@ fn prepareCiphertextRecord(
const pv = &p.tls_1_3;
const P = @TypeOf(p.*);
const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1;
const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len;
while (true) {
const encrypted_content_len: u16 = @min(
bytes.len - bytes_i,
tls.max_ciphertext_inner_record_len,
ciphertext_buf.len -|
(close_notify_alert_reserved + overhead_len + ciphertext_end),
ciphertext_buf.len -| (overhead_len + ciphertext_end),
);
if (encrypted_content_len == 0) return .{
.iovec_end = iovec_end,
.ciphertext_end = ciphertext_end,
.overhead_len = overhead_len,
.cleartext_len = bytes_i,
};
@memcpy(cleartext_buf[0..encrypted_content_len], bytes[bytes_i..][0..encrypted_content_len]);
@ -1021,7 +998,6 @@ fn prepareCiphertextRecord(
const ciphertext_len = encrypted_content_len + 1;
const cleartext = cleartext_buf[0..ciphertext_len];
const record_start = ciphertext_end;
const ad = ciphertext_buf[ciphertext_end..][0..tls.record_header_len];
ad.* = .{@intFromEnum(tls.ContentType.application_data)} ++
int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
@ -1039,35 +1015,27 @@ fn prepareCiphertextRecord(
};
P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_key);
c.write_seq += 1; // TODO send key_update on overflow
const record = ciphertext_buf[record_start..ciphertext_end];
iovecs[iovec_end] = record;
iovec_end += 1;
}
},
.tls_1_2 => {
const pv = &p.tls_1_2;
const P = @TypeOf(p.*);
const overhead_len = tls.record_header_len + P.record_iv_length + P.mac_length;
const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len;
while (true) {
const message_len: u16 = @min(
bytes.len - bytes_i,
tls.max_ciphertext_inner_record_len,
ciphertext_buf.len -|
(close_notify_alert_reserved + overhead_len + ciphertext_end),
ciphertext_buf.len -| (overhead_len + ciphertext_end),
);
if (message_len == 0) return .{
.iovec_end = iovec_end,
.ciphertext_end = ciphertext_end,
.overhead_len = overhead_len,
.cleartext_len = bytes_i,
};
@memcpy(cleartext_buf[0..message_len], bytes[bytes_i..][0..message_len]);
bytes_i += message_len;
const cleartext = cleartext_buf[0..message_len];
const record_start = ciphertext_end;
const record_header = ciphertext_buf[ciphertext_end..][0..tls.record_header_len];
ciphertext_end += tls.record_header_len;
record_header.* = .{@intFromEnum(inner_content_type)} ++
@ -1089,10 +1057,6 @@ fn prepareCiphertextRecord(
ciphertext_end += P.mac_length;
P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_write_key);
c.write_seq += 1; // TODO send key_update on overflow
const record = ciphertext_buf[record_start..ciphertext_end];
iovecs[iovec_end] = record;
iovec_end += 1;
}
},
else => unreachable,
@ -1106,74 +1070,22 @@ pub fn eof(c: Client) bool {
c.partial_ciphertext_idx >= c.partial_ciphertext_end;
}
/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
/// Returns the number of bytes read, calling the underlying read function the
/// minimal number of times until the buffer has at least `len` bytes filled.
/// If the number read is less than `len` it means the stream reached the end.
/// Reaching the end of the stream is not an error condition.
pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize {
var iovecs = [1]std.posix.iovec{.{ .base = buffer.ptr, .len = buffer.len }};
return readvAtLeast(c, stream, &iovecs, len);
fn reader_read(
context: ?*anyopaque,
bw: *std.io.BufferedWriter,
limit: std.io.Reader.Limit,
) anyerror!std.io.Reader.Status {
const buf = limit.slice(try bw.writableSlice(1));
const status = try reader_readv(context, &.{buf});
bw.advance(status.len);
return status;
}
/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize {
return readAtLeast(c, stream, buffer, 1);
}
fn reader_readv(context: ?*anyopaque, data: []const []u8) anyerror!std.io.Reader.Status {
const c: *Client = @ptrCast(@alignCast(context));
if (c.eof()) return .{ .end = true };
/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
/// Returns the number of bytes read. If the number read is smaller than
/// `buffer.len`, it means the stream reached the end. Reaching the end of the
/// stream is not an error condition.
pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize {
return readAtLeast(c, stream, buffer, buffer.len);
}
/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
/// Returns the number of bytes read. If the number read is less than the space
/// provided it means the stream reached the end. Reaching the end of the
/// stream is not an error condition.
/// The `iovecs` parameter is mutable because this function needs to mutate the fields in
/// order to handle partial reads from the underlying stream layer.
pub fn readv(c: *Client, stream: anytype, iovecs: []std.posix.iovec) !usize {
return readvAtLeast(c, stream, iovecs, 1);
}
/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
/// Returns the number of bytes read, calling the underlying read function the
/// minimal number of times until the iovecs have at least `len` bytes filled.
/// If the number read is less than `len` it means the stream reached the end.
/// Reaching the end of the stream is not an error condition.
/// The `iovecs` parameter is mutable because this function needs to mutate the fields in
/// order to handle partial reads from the underlying stream layer.
pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.posix.iovec, len: usize) !usize {
if (c.eof()) return 0;
var off_i: usize = 0;
var vec_i: usize = 0;
while (true) {
var amt = try c.readvAdvanced(stream, iovecs[vec_i..]);
off_i += amt;
if (c.eof() or off_i >= len) return off_i;
while (amt >= iovecs[vec_i].len) {
amt -= iovecs[vec_i].len;
vec_i += 1;
}
iovecs[vec_i].base += amt;
iovecs[vec_i].len -= amt;
}
}
/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
/// Returns number of bytes that have been read, populated inside `iovecs`. A
/// return value of zero bytes does not mean end of stream. Instead, check the `eof()`
/// for the end of stream. The `eof()` may be true after any call to
/// `read`, including when greater than zero bytes are returned, and this
/// function asserts that `eof()` is `false`.
/// See `readv` for a higher level function that has the same, familiar API as
/// other read functions, such as `std.fs.File.read`.
pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iovec) !usize {
var vp: VecPut = .{ .iovecs = iovecs };
var vp: VecPut = .{ .iovecs = data };
// Give away the buffered cleartext we have, if any.
const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx];
@ -1193,11 +1105,11 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove
if (c.received_close_notify) {
c.partial_ciphertext_end = 0;
assert(vp.total == amt);
return amt;
return .{ .len = amt, .end = c.eof() };
} else if (amt > 0) {
// We don't need more data, so don't call read.
assert(vp.total == amt);
return amt;
return .{ .len = amt, .end = c.eof() };
}
}
@ -1241,7 +1153,7 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove
const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len);
const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len) - c.partial_ciphertext_end;
const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len);
const actual_read_len = try stream.readv(ask_iovecs);
const actual_read_len = try c.input.readv(ask_iovecs);
if (actual_read_len == 0) {
// This is either a truncation attack, a bug in the server, or an
// intentional omission of the close_notify message due to truncation
@ -1268,7 +1180,7 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove
// Perfect split.
if (frag.ptr == frag1.ptr) {
c.partial_ciphertext_end = c.partial_ciphertext_idx;
return vp.total;
return .{ .len = vp.total, .end = c.eof() };
}
frag = frag1;
in = 0;
@ -1310,8 +1222,8 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove
const record_len = mem.readInt(u16, frag[in..][0..2], .big);
if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
in += 2;
const end = in + record_len;
if (end > frag.len) {
const the_end = in + record_len;
if (the_end > frag.len) {
// We need the record header on the next iteration of the loop.
in -= tls.record_header_len;
@ -1398,17 +1310,23 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove
.alert => {
if (cleartext.len != 2) return error.TlsDecodeError;
const level: tls.AlertLevel = @enumFromInt(cleartext[0]);
const desc: tls.AlertDescription = @enumFromInt(cleartext[1]);
if (desc == .close_notify) {
c.received_close_notify = true;
c.partial_ciphertext_end = c.partial_ciphertext_idx;
return vp.total;
}
_ = level;
try desc.toError();
// TODO: handle server-side closures
return error.TlsUnexpectedMessage;
const desc: tls.AlertDescription = @enumFromInt(cleartext[1]);
switch (desc) {
.close_notify => {
c.received_close_notify = true;
c.partial_ciphertext_end = c.partial_ciphertext_idx;
return .{ .len = vp.total, .end = c.eof() };
},
.user_canceled => {
// TODO: handle server-side closures
return error.TlsUnexpectedMessage;
},
else => {
c.diagnostics = .{ .alert = desc };
return error.TlsAlert;
},
}
},
.handshake => {
var ct_i: usize = 0;
@ -1524,7 +1442,7 @@ fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) voi
}) catch {};
}
fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize {
fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) std.io.Reader.Status {
const saved_buf = frag[in..];
if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
// There is cleartext at the beginning already which we need to preserve.
@ -1536,11 +1454,11 @@ fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize {
c.partial_ciphertext_end = @intCast(saved_buf.len);
@memcpy(c.partially_read_buffer[0..saved_buf.len], saved_buf);
}
return out;
return .{ .len = out, .end = c.eof() };
}
/// Note that `first` usually overlaps with `c.partially_read_buffer`.
fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize {
fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) std.io.Reader.Status {
if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
// There is cleartext at the beginning already which we need to preserve.
c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + first.len + frag1.len);
@ -1555,7 +1473,7 @@ fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usi
std.mem.copyForwards(u8, c.partially_read_buffer[0..first.len], first);
@memcpy(c.partially_read_buffer[first.len..][0..frag1.len], frag1);
}
return out;
return .{ .len = out, .end = c.eof() };
}
fn limitedOverlapCopy(frag: []u8, in: usize) void {
@ -1577,9 +1495,6 @@ fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 {
}
}
const builtin = @import("builtin");
const native_endian = builtin.cpu.arch.endian();
inline fn big(x: anytype) @TypeOf(x) {
return switch (native_endian) {
.big => x,
@ -1958,7 +1873,3 @@ else
.AES_256_GCM_SHA384,
.ECDHE_RSA_WITH_AES_256_GCM_SHA384,
});
test {
_ = StreamInterface;
}

View File

@ -24,6 +24,12 @@ allocator: Allocator,
ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{},
ca_bundle_mutex: std.Thread.Mutex = .{},
/// Used both for the reader and writer buffers.
tls_buffer_size: if (disable_tls) u0 else usize = if (disable_tls) 0 else std.crypto.tls.Client.min_buffer_len,
/// If non-null, ssl secrets are logged to a stream. Creating such a stream
/// allows other processes with access to that stream to decrypt all
/// traffic over connections created with this `Client`.
ssl_key_logger: ?*std.io.BufferedWriter = null,
/// When this is `true`, the next time this client performs an HTTPS request,
/// it will first rescan the system for root certificates.
@ -31,6 +37,10 @@ next_https_rescan_certs: bool = true,
/// The pool of connections that can be reused (and currently in use).
connection_pool: ConnectionPool = .{},
/// Each `Connection` allocates this amount for the reader buffer.
read_buffer_size: usize,
/// Each `Connection` allocates this amount for the writer buffer.
write_buffer_size: usize,
/// If populated, all http traffic travels through this third party.
/// This field cannot be modified while the client has active connections.
@ -41,7 +51,7 @@ http_proxy: ?*Proxy = null,
/// Pointer to externally-owned memory.
https_proxy: ?*Proxy = null,
/// A set of linked lists of connections that can be reused.
/// A Least-Recently-Used cache of open connections to be reused.
pub const ConnectionPool = struct {
mutex: std.Thread.Mutex = .{},
/// Open connections that are currently in use.
@ -58,8 +68,10 @@ pub const ConnectionPool = struct {
protocol: Connection.Protocol,
};
/// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe.
/// Finds and acquires a connection from the connection pool matching the criteria.
/// If no connection is found, null is returned.
///
/// Threadsafe.
pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection {
pool.mutex.lock();
defer pool.mutex.unlock();
@ -96,21 +108,21 @@ pub const ConnectionPool = struct {
return pool.acquireUnsafe(connection);
}
/// Tries to release a connection back to the connection pool. This function is threadsafe.
/// Tries to release a connection back to the connection pool.
/// If the connection is marked as closing, it will be closed instead.
///
/// The allocator must be the owner of all nodes in this pool.
/// The allocator must be the owner of all resources associated with the connection.
/// `allocator` must be the same one used to create `connection`.
///
/// Threadsafe.
pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void {
if (connection.closing) return connection.destroy(allocator);
pool.mutex.lock();
defer pool.mutex.unlock();
pool.used.remove(&connection.pool_node);
if (connection.closing or pool.free_size == 0) {
connection.close(allocator);
return allocator.destroy(connection);
}
if (pool.free_size == 0) return connection.destroy(allocator);
if (pool.free_len >= pool.free_size) {
const popped: *Connection = @fieldParentPtr("pool_node", pool.free.popFirst().?);
@ -138,9 +150,11 @@ pub const ConnectionPool = struct {
pool.used.append(&connection.pool_node);
}
/// Resizes the connection pool. This function is threadsafe.
/// Resizes the connection pool.
///
/// If the new size is smaller than the current size, then idle connections will be closed until the pool is the new size.
///
/// Threadsafe.
pub fn resize(pool: *ConnectionPool, allocator: Allocator, new_size: usize) void {
pool.mutex.lock();
defer pool.mutex.unlock();
@ -158,9 +172,11 @@ pub const ConnectionPool = struct {
pool.free_size = new_size;
}
/// Frees the connection pool and closes all connections within. This function is threadsafe.
/// Frees the connection pool and closes all connections within.
///
/// All future operations on the connection pool will deadlock.
///
/// Threadsafe.
pub fn deinit(pool: *ConnectionPool, allocator: Allocator) void {
pool.mutex.lock();
@ -184,160 +200,212 @@ pub const ConnectionPool = struct {
}
};
/// An interface to either a plain or TLS connection.
pub const Connection = struct {
client: *Client,
stream: net.Stream,
/// Populated when protocol is TLS; this is the writer given to the TLS
/// client, which writes directly to `stream`, unbuffered.
stream_writer: std.io.BufferedWriter,
/// undefined unless protocol is tls.
tls_client: if (!disable_tls) *std.crypto.tls.Client else void,
/// HTTP protocol from client to server.
/// This either goes directly to `stream`, or to a TLS client.
writer: std.io.BufferedWriter,
/// Entry in `ConnectionPool.used` or `ConnectionPool.free`.
pool_node: std.DoublyLinkedList.Node,
/// The protocol that this connection is using.
protocol: Protocol,
/// The host that this connection is connected to.
host: []u8,
/// The port that this connection is connected to.
port: u16,
/// Whether this connection is proxied and is not directly connected.
proxied: bool = false,
/// Whether this connection is closing when we're done with it.
closing: bool = false,
read_start: BufferSize = 0,
read_end: BufferSize = 0,
read_buf: [buffer_size]u8,
write_buffer: [buffer_size]u8,
writer: std.io.BufferedWriter,
pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
const BufferSize = std.math.IntFittingRange(0, buffer_size);
host_len: u8,
proxied: bool,
closing: bool,
protocol: Protocol,
pub const Protocol = enum { plain, tls };
pub fn readvDirectTls(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize {
return conn.tls_client.readv(conn.stream, buffers) catch |err| {
// https://github.com/ziglang/zig/issues/2473
if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert;
const Plain = struct {
/// Data from `Connection.stream`.
reader: std.io.BufferedReader,
connection: Connection,
switch (err) {
error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
error.ConnectionTimedOut => return error.ConnectionTimedOut,
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
else => return error.UnexpectedReadFailure,
}
};
}
pub fn readvDirect(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize {
if (conn.protocol == .tls) {
if (disable_tls) unreachable;
return conn.readvDirectTls(buffers);
fn create(
client: *Client,
remote_host: []const u8,
port: u16,
stream: net.Stream,
) error{OutOfMemory}!*Connection {
const gpa = client.allocator;
const alloc_len = allocLen(client, remote_host.len);
const base = try gpa.alignedAlloc(u8, .of(Plain), alloc_len);
errdefer gpa.free(base);
const host_buffer = base[@sizeOf(Plain)..][0..remote_host.len];
const socket_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.read_buffer_size];
const socket_write_buffer = socket_read_buffer.ptr[socket_read_buffer.len..][0..client.write_buffer_size];
assert(base.ptr + alloc_len == socket_read_buffer.ptr + socket_read_buffer.len);
@memcpy(host_buffer, remote_host);
const plain: *Plain = @ptrCast(base);
plain.* = .{
.connection = .{
.client = client,
.stream = stream,
.writer = stream.writer().buffered(socket_write_buffer),
.pool_node = .{},
.port = port,
.proxied = false,
.closing = false,
.protocol = .plain,
},
.reader = undefined,
};
plain.reader.init(stream.reader(), socket_read_buffer);
}
return conn.stream.readv(buffers) catch |err| switch (err) {
error.ConnectionTimedOut => return error.ConnectionTimedOut,
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
else => return error.UnexpectedReadFailure,
};
}
/// Refills the read buffer with data from the connection.
pub fn fill(conn: *Connection) ReadError!void {
if (conn.read_end != conn.read_start) return;
var iovecs = [1]std.posix.iovec{
.{ .base = &conn.read_buf, .len = conn.read_buf.len },
};
const nread = try conn.readvDirect(&iovecs);
if (nread == 0) return error.EndOfStream;
conn.read_start = 0;
conn.read_end = @intCast(nread);
}
/// Returns the current slice of buffered data.
pub fn peek(conn: *Connection) []const u8 {
return conn.read_buf[conn.read_start..conn.read_end];
}
/// Discards the given number of bytes from the read buffer.
pub fn drop(conn: *Connection, num: BufferSize) void {
conn.read_start += num;
}
/// Reads data from the connection into the given buffer.
pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
const available_read = conn.read_end - conn.read_start;
const available_buffer = buffer.len;
if (available_read > available_buffer) { // partially read buffered data
@memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]);
conn.read_start += @intCast(available_buffer);
return available_buffer;
} else if (available_read > 0) { // fully read buffered data
@memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]);
conn.read_start += available_read;
return available_read;
fn destroy(plain: *Plain) void {
const c = &plain.connection;
const gpa = c.client.allocator;
const base: [*]u8 = @ptrCast(plain);
gpa.free(base[0..allocLen(c.client, c.host_len)]);
}
var iovecs = [2]std.posix.iovec{
.{ .base = buffer.ptr, .len = buffer.len },
.{ .base = &conn.read_buf, .len = conn.read_buf.len },
};
const nread = try conn.readvDirect(&iovecs);
if (nread > buffer.len) {
conn.read_start = 0;
conn.read_end = @intCast(nread - buffer.len);
return buffer.len;
fn allocLen(client: *Client, host_len: usize) usize {
return @sizeOf(Plain) + host_len + client.read_buffer_size + client.write_buffer_size;
}
return nread;
}
pub const ReadError = error{
TlsFailure,
TlsAlert,
ConnectionTimedOut,
ConnectionResetByPeer,
UnexpectedReadFailure,
EndOfStream,
fn host(plain: *Plain) []u8 {
const base: [*]u8 = @ptrCast(plain);
return base[@sizeOf(Plain)..][0..plain.connection.host_len];
}
};
pub const Reader = std.io.Reader(*Connection, ReadError, read);
const Tls = struct {
/// Data from `client` to `Connection.stream`.
writer: std.io.BufferedWriter,
/// Data from `Connection.stream` to `client`.
reader: std.io.BufferedReader,
client: std.crypto.tls.Client,
connection: Connection,
pub fn reader(conn: *Connection) Reader {
return .{ .context = conn };
}
fn create(
client: *Client,
remote_host: []const u8,
port: u16,
stream: net.Stream,
) error{ OutOfMemory, TlsInitializationFailed }!*Tls {
const gpa = client.allocator;
const alloc_len = allocLen(client, remote_host.len);
const base = try gpa.alignedAlloc(u8, .of(Tls), alloc_len);
errdefer gpa.free(base);
const host_buffer = base[@sizeOf(Tls)..][0..remote_host.len];
const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.tls_buffer_size];
const tls_write_buffer = tls_read_buffer.ptr[tls_read_buffer.len..][0..client.tls_buffer_size];
const socket_write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size];
assert(base.ptr + alloc_len == socket_write_buffer.ptr + socket_write_buffer.len);
@memcpy(host_buffer, remote_host);
const tls: *Tls = @ptrCast(base);
tls.* = .{
.connection = .{
.client = client,
.stream = stream,
.writer = tls.client.writer().buffered(socket_write_buffer),
.pool_node = .{},
.port = port,
.proxied = false,
.closing = false,
.protocol = .tls,
},
.writer = stream.writer().buffered(tls_write_buffer),
.reader = undefined,
.client = undefined,
};
tls.reader.init(stream.reader(), tls_read_buffer);
// TODO data race here on ca_bundle if the user sets next_https_rescan_certs to true
tls.client.init(&tls.reader, &tls.writer, .{
.host = .{ .explicit = remote_host },
.ca = .{ .bundle = client.ca_bundle },
.ssl_key_logger = client.ssl_key_logger,
}) catch return error.TlsInitializationFailed;
// This is appropriate for HTTPS because the HTTP headers contain
// the content length which is used to detect truncation attacks.
tls.client.allow_truncation_attacks = true;
pub const WriteError = error{
ConnectionResetByPeer,
UnexpectedWriteFailure,
};
pub fn close(conn: *Connection, allocator: Allocator) void {
if (conn.protocol == .tls) {
if (disable_tls) unreachable;
// try to cleanly close the TLS connection, for any server that cares.
_ = conn.tls_client.writeEnd("", true) catch {};
if (conn.tls_client.ssl_key_log) |key_log| key_log.file.close();
allocator.destroy(conn.tls_client);
return tls;
}
conn.stream.close();
allocator.free(conn.host);
fn destroy(tls: *Tls, gpa: Allocator) void {
const c = &tls.connection;
const base: [*]u8 = @ptrCast(tls);
gpa.free(base[0..allocLen(c.client, c.host_len)]);
}
fn allocLen(client: *Client, host_len: usize) usize {
return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size + client.write_buffer_size;
}
fn host(tls: *Tls) []u8 {
const base: [*]u8 = @ptrCast(tls);
return base[@sizeOf(Tls)..][0..tls.connection.host_len];
}
};
fn host(c: *Connection) []u8 {
return switch (c.protocol) {
.tls => {
if (disable_tls) unreachable;
const tls: *Tls = @fieldParentPtr("connection", c);
return tls.host();
},
.plain => {
const plain: *Plain = @fieldParentPtr("connection", c);
return plain.host();
},
};
}
/// This is either data from `stream`, or `Tls.client`.
fn reader(c: *Connection) *std.io.BufferedReader {
return switch (c.protocol) {
.tls => {
if (disable_tls) unreachable;
const tls: *Tls = @fieldParentPtr("connection", c);
return &tls.client.reader;
},
.plain => {
const plain: *Plain = @fieldParentPtr("connection", c);
return &plain.reader;
},
};
}
/// If this is called without calling `flush` or `end`, data will be
/// dropped unsent.
pub fn destroy(c: *Connection) void {
c.stream.close();
switch (c.protocol) {
.tls => {
if (disable_tls) unreachable;
const tls: *Tls = @fieldParentPtr("connection", c);
tls.destroy();
},
.plain => {
const plain: *Plain = @fieldParentPtr("connection", c);
plain.destroy();
},
}
}
pub fn flush(c: *Connection) anyerror!void {
try c.writer.flush();
if (c.protocol == .tls) {
if (disable_tls) unreachable;
const tls: *Tls = @fieldParentPtr("connection", c);
try tls.writer.flush();
}
}
/// If the connection is a TLS connection, sends the close_notify alert.
///
/// Flushes all buffers.
pub fn end(c: *Connection) anyerror!void {
try c.writer.flush();
if (c.protocol == .tls) {
if (disable_tls) unreachable;
const tls: *Tls = @fieldParentPtr("connection", c);
try tls.client.end();
try tls.writer.flush();
}
}
};
@ -350,10 +418,10 @@ pub const RequestTransfer = union(enum) {
/// The decompressor for response messages.
pub const Compression = union(enum) {
pub const DeflateDecompressor = std.compress.zlib.Decompressor(Request.TransferReader);
pub const GzipDecompressor = std.compress.gzip.Decompressor(Request.TransferReader);
pub const DeflateDecompressor = std.compress.zlib.Decompressor;
pub const GzipDecompressor = std.compress.gzip.Decompressor;
// https://github.com/ziglang/zig/issues/18937
//pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.TransferReader, .{});
//pub const ZstdDecompressor = std.compress.zstd.DecompressStream(.{});
deflate: DeflateDecompressor,
gzip: GzipDecompressor,
@ -617,9 +685,6 @@ pub const Response = struct {
}
};
/// A HTTP request that has been sent.
///
/// Order of operations: open -> send[ -> write -> finish] -> wait -> read
pub const Request = struct {
uri: Uri,
client: *Client,
@ -1300,24 +1365,34 @@ pub const basic_authorization = struct {
}
};
pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed };
pub const ConnectTcpError = Allocator.Error || error{
ConnectionRefused,
NetworkUnreachable,
ConnectionTimedOut,
ConnectionResetByPeer,
TemporaryNameServerFailure,
NameServerFailure,
UnknownHostName,
HostLacksNetworkAddresses,
UnexpectedConnectFailure,
TlsInitializationFailed,
};
/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open.
/// Reuses a `Connection` if one matching `host` and `port` is already open.
///
/// This function is threadsafe.
pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*Connection {
/// Threadsafe.
pub fn connectTcp(
client: *Client,
host: []const u8,
port: u16,
protocol: Connection.Protocol,
) ConnectTcpError!*Connection {
if (client.connection_pool.findConnection(.{
.host = host,
.port = port,
.protocol = protocol,
})) |conn| return conn;
if (disable_tls and protocol == .tls)
return error.TlsInitializationFailed;
const conn = try client.allocator.create(Connection);
errdefer client.allocator.destroy(conn);
const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) {
error.ConnectionRefused => return error.ConnectionRefused,
error.NetworkUnreachable => return error.NetworkUnreachable,
@ -1331,77 +1406,19 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec
};
errdefer stream.close();
conn.* = .{
.stream = stream,
.stream_writer = undefined,
.tls_client = undefined,
.read_buf = undefined,
.write_buffer = undefined,
.writer = undefined, // populated below
.protocol = protocol,
.host = try client.allocator.dupe(u8, host),
.port = port,
.pool_node = .{},
};
errdefer client.allocator.free(conn.host);
switch (protocol) {
.tls => {
if (disable_tls) unreachable;
const tls_client = try client.allocator.create(std.crypto.tls.Client);
errdefer client.allocator.destroy(tls_client);
const ssl_key_log_file: ?std.fs.File = if (std.options.http_enable_ssl_key_log_file) ssl_key_log_file: {
const ssl_key_log_path = std.process.getEnvVarOwned(client.allocator, "SSLKEYLOGFILE") catch |err| switch (err) {
error.EnvironmentVariableNotFound, error.InvalidWtf8 => break :ssl_key_log_file null,
error.OutOfMemory => return error.OutOfMemory,
};
defer client.allocator.free(ssl_key_log_path);
break :ssl_key_log_file std.fs.cwd().createFile(ssl_key_log_path, .{
.truncate = false,
.mode = switch (builtin.os.tag) {
.windows, .wasi => 0,
else => 0o600,
},
}) catch null;
} else null;
errdefer if (ssl_key_log_file) |key_log_file| key_log_file.close();
conn.stream_writer = .{
.unbuffered_writer = stream.writer(),
.buffer = &.{},
};
tls_client.* = std.crypto.tls.Client.init(stream, &conn.stream_writer, .{
.host = .{ .explicit = host },
.ca = .{ .bundle = client.ca_bundle },
.ssl_key_log_file = ssl_key_log_file,
}) catch return error.TlsInitializationFailed;
// This is appropriate for HTTPS because the HTTP headers contain
// the content length which is used to detect truncation attacks.
tls_client.allow_truncation_attacks = true;
conn.writer = .{
.unbuffered_writer = tls_client.writer(),
.buffer = &conn.write_buffer,
};
conn.tls_client = tls_client;
if (disable_tls) return error.TlsInitializationFailed;
const tc = try Connection.Tls.create(client, host, port, stream);
client.connection_pool.addUsed(&tc.connection);
return &tc.connection;
},
.plain => {
conn.writer = .{
.unbuffered_writer = stream.writer(),
.buffer = &conn.write_buffer,
};
const pc = try Connection.Plain.create(client, host, port, stream);
client.connection_pool.addUsed(&pc.connection);
return &pc.connection;
},
}
client.connection_pool.addUsed(conn);
return conn;
}
pub const ConnectUnixError = Allocator.Error || std.posix.SocketError || error{NameTooLong} || std.posix.ConnectError;
@ -1662,16 +1679,17 @@ pub fn open(
var server_header: std.heap.FixedBufferAllocator = .init(options.server_header_buffer);
const protocol, const valid_uri = try validateUri(uri, server_header.allocator());
if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) {
if (protocol == .tls) {
if (disable_tls) unreachable;
if (@atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) {
client.ca_bundle_mutex.lock();
defer client.ca_bundle_mutex.unlock();
client.ca_bundle_mutex.lock();
defer client.ca_bundle_mutex.unlock();
if (client.next_https_rescan_certs) {
client.ca_bundle.rescan(client.allocator) catch
return error.CertificateBundleLoadFailure;
@atomicStore(bool, &client.next_https_rescan_certs, false, .release);
if (client.next_https_rescan_certs) {
client.ca_bundle.rescan(client.allocator) catch
return error.CertificateBundleLoadFailure;
@atomicStore(bool, &client.next_https_rescan_certs, false, .release);
}
}
}

View File

@ -1,18 +1,25 @@
//! Blocking HTTP server implementation.
//! Handles a single connection's lifecycle.
connection: net.Server.Connection,
const std = @import("../std.zig");
const http = std.http;
const mem = std.mem;
const net = std.net;
const Uri = std.Uri;
const assert = std.debug.assert;
const testing = std.testing;
const Server = @This();
/// The reader's buffer must be large enough to store the client's entire HTTP
/// header, otherwise `receiveHead` returns `error.HttpHeadersOversize`.
in: *std.io.BufferedReader,
out: *std.io.BufferedWriter,
/// Keeps track of whether the Server is ready to accept a new request on the
/// same connection, and makes invalid API usage cause assertion failures
/// rather than HTTP protocol violations.
state: State,
/// User-provided buffer that must outlive this Server.
/// Used to store the client's entire HTTP header.
read_buffer: []u8,
/// Amount of available data inside read_buffer.
read_buffer_len: usize,
/// Index into `read_buffer` of the first byte of the next HTTP request.
next_request_start: usize,
in_err: anyerror,
pub const State = enum {
/// The connection is available to be used for the first time, or reused.
@ -31,14 +38,13 @@ pub const State = enum {
/// Initialize an HTTP server that can respond to multiple requests on the same
/// connection.
///
/// The returned `Server` is ready for `receiveHead` to be called.
pub fn init(connection: net.Server.Connection, read_buffer: []u8) Server {
pub fn init(in: *std.io.BufferedReader, out: *std.io.BufferedWriter) Server {
return .{
.connection = connection,
.in = in,
.out = out,
.state = .ready,
.read_buffer = read_buffer,
.read_buffer_len = 0,
.next_request_start = 0,
};
}
@ -48,78 +54,55 @@ pub const ReceiveHeadError = error{
/// before closing the connection.
HttpHeadersOversize,
/// Client sent headers that did not conform to the HTTP protocol.
/// `in_err` is populated with a `Request.Head.ParseError`.
HttpHeadersInvalid,
/// A low level I/O error occurred trying to read the headers.
HttpHeadersUnreadable,
/// Partial HTTP request was received but the connection was closed before
/// fully receiving the headers.
HttpRequestTruncated,
/// The client sent 0 bytes of headers before closing the stream.
/// In other words, a keep-alive connection was finally closed.
HttpConnectionClosing,
/// Error occurred reading from `in`; `in_err` is populated.
ReadFailure,
};
/// The header bytes reference the read buffer that Server was initialized with
/// and remain alive until the next call to receiveHead.
/// The header bytes reference the internal storage of `in`, which are
/// invalidated with the next call to `receiveHead`.
pub fn receiveHead(s: *Server) ReceiveHeadError!Request {
assert(s.state == .ready);
s.state = .received_head;
errdefer s.state = .receiving_head;
// In case of a reused connection, move the next request's bytes to the
// beginning of the buffer.
if (s.next_request_start > 0) {
if (s.read_buffer_len > s.next_request_start) {
rebase(s, 0);
} else {
s.read_buffer_len = 0;
}
}
const in = &s.in;
var hp: http.HeadParser = .{};
if (s.read_buffer_len > 0) {
const bytes = s.read_buffer[0..s.read_buffer_len];
const end = hp.feed(bytes);
if (hp.state == .finished)
return finishReceivingHead(s, end);
}
var head_end: usize = 0;
while (true) {
const buf = s.read_buffer[s.read_buffer_len..];
if (buf.len == 0)
return error.HttpHeadersOversize;
const read_n = s.connection.stream.read(buf) catch
return error.HttpHeadersUnreadable;
if (read_n == 0) {
if (s.read_buffer_len > 0) {
return error.HttpRequestTruncated;
} else {
return error.HttpConnectionClosing;
}
}
s.read_buffer_len += read_n;
const bytes = buf[0..read_n];
const end = hp.feed(bytes);
if (hp.state == .finished)
return finishReceivingHead(s, s.read_buffer_len - bytes.len + end);
if (head_end >= in.bufferContents().len) return error.HttpHeadersOversize;
const buf = (in.peekGreedy(head_end + 1) catch |err| {
s.in_err = err;
return error.ReadFailure;
}) orelse switch (head_end) {
0 => return error.HttpConnectionClosing,
else => return error.HttpRequestTruncated,
};
head_end += hp.feed(buf[head_end..]);
if (hp.state == .finished) return .{
.server = s,
.head_end = head_end,
.head = Request.Head.parse(buf[0..head_end]) catch |err| {
s.in_err = err;
return error.HttpHeadersInvalid;
},
.reader_state = undefined,
.write_error = undefined,
};
}
}
fn finishReceivingHead(s: *Server, head_end: usize) ReceiveHeadError!Request {
return .{
.server = s,
.head_end = head_end,
.head = Request.Head.parse(s.read_buffer[0..head_end]) catch
return error.HttpHeadersInvalid,
.reader_state = undefined,
.write_error = undefined,
};
}
pub const Request = struct {
server: *Server,
/// Index into Server's read_buffer.
/// Index into `Server.in` internal buffer.
head_end: usize,
head: Head,
reader_state: union {
@ -299,7 +282,7 @@ pub const Request = struct {
};
pub fn iterateHeaders(r: *Request) http.HeaderIterator {
return http.HeaderIterator.init(r.server.read_buffer[0..r.head_end]);
return http.HeaderIterator.init(r.in.bufferContents()[0..r.head_end]);
}
test iterateHeaders {
@ -312,13 +295,14 @@ pub const Request = struct {
var read_buffer: [500]u8 = undefined;
@memcpy(read_buffer[0..request_bytes.len], request_bytes);
var br: std.io.BufferedReader = undefined;
br.initFixed(&read_buffer);
var server: Server = .{
.connection = undefined,
.in = &br,
.out = undefined,
.state = .ready,
.read_buffer = &read_buffer,
.read_buffer_len = request_bytes.len,
.next_request_start = 0,
.in_err = undefined,
};
var request: Request = .{
@ -1158,13 +1142,3 @@ fn rebase(s: *Server, index: usize) void {
}
s.read_buffer_len = index + leftover.len;
}
const std = @import("../std.zig");
const http = std.http;
const mem = std.mem;
const net = std.net;
const Uri = std.Uri;
const assert = std.debug.assert;
const testing = std.testing;
const Server = @This();

View File

@ -94,6 +94,11 @@ pub fn storageBuffer(br: *BufferedReader) []u8 {
return storage.buffer;
}
pub fn bufferContents(br: *BufferedReader) []u8 {
const storage = &br.storage;
return storage.buffer[br.seek..storage.end];
}
/// Although `BufferedReader` can easily satisfy the `Reader` interface, it's
/// generally more practical to pass a `BufferedReader` instance itself around,
/// since it will result in fewer calls across vtable boundaries.
@ -159,31 +164,69 @@ pub fn seekForwardBy(br: *BufferedReader, seek_by: u64) anyerror!void {
/// is returned instead.
///
/// See also:
/// * `peekAll`
/// * `peekGreedy`
/// * `toss`
pub fn peek(br: *BufferedReader, n: usize) anyerror![]u8 {
return (try br.peekAll(n))[0..n];
return (try br.peekGreedy(n))[0..n];
}
/// Returns the next buffered bytes from `unbuffered_reader`, after filling the buffer
/// with at least `n` bytes.
/// Returns the next `n` bytes from `unbuffered_reader`, filling the buffer as
/// necessary.
///
/// Invalidates previously returned values from `peek`.
///
/// Asserts that the `BufferedReader` was initialized with a buffer capacity at
/// least as big as `n`.
///
/// If there are fewer than `n` bytes left in the stream, `null` is returned
/// instead.
///
/// See also:
/// * `peekGreedy`
/// * `toss`
pub fn peek2(br: *BufferedReader, n: usize) anyerror!?[]u8 {
if (try br.peekGreedy(n)) |buf| return buf[0..n];
return null;
}
/// Returns all the next buffered bytes from `unbuffered_reader`, after filling
/// the buffer to ensure it contains at least `n` bytes.
///
/// Invalidates previously returned values from `peek` and `peekGreedy`.
///
/// Asserts that the `BufferedReader` was initialized with a buffer capacity at
/// least as big as `n`.
///
/// If there are fewer than `n` bytes left in the stream, `error.EndOfStream`
/// is returned instead.
///
/// See also:
/// * `peek`
/// * `toss`
pub fn peekAll(br: *BufferedReader, n: usize) anyerror![]u8 {
const storage = &br.storage;
assert(n <= storage.buffer.len);
try br.fill(n);
return storage.buffer[br.seek..storage.end];
pub fn peekGreedy(br: *BufferedReader, n: usize) anyerror![]u8 {
assert(n <= br.storage.buffer.len);
if (try br.fill(n)) return br.bufferContents();
return error.EndOfStream;
}
/// Returns all the next buffered bytes from `unbuffered_reader`, after filling
/// the buffer to ensure it contains at least `n` bytes.
///
/// Invalidates previously returned values from `peek` and `peekGreedy`.
///
/// Asserts that the `BufferedReader` was initialized with a buffer capacity at
/// least as big as `n`.
///
/// If there are fewer than `n` bytes left in the stream, `null` is returned
/// instead.
///
/// See also:
/// * `peek`
/// * `toss`
pub fn peekGreedy2(br: *BufferedReader, n: usize) anyerror!?[]u8 {
assert(n <= br.storage.buffer.len);
if (try br.fill(n)) return br.bufferContents();
return null;
}
/// Skips the next `n` bytes from the stream, advancing the seek position. This
@ -505,17 +548,17 @@ pub fn discardDelimiterInclusive(br: *BufferedReader, delimiter: u8) anyerror!vo
/// Fills the buffer such that it contains at least `n` bytes, without
/// advancing the seek position.
///
/// Returns `error.EndOfStream` if there are fewer than `n` bytes remaining.
/// Returns `false` if and only if there are fewer than `n` bytes remaining.
///
/// Asserts buffer capacity is at least `n`.
pub fn fill(br: *BufferedReader, n: usize) anyerror!void {
pub fn fill(br: *BufferedReader, n: usize) anyerror!bool {
const storage = &br.storage;
assert(n <= storage.buffer.len);
const buffer = storage.buffer[0..storage.end];
const seek = br.seek;
if (seek + n <= buffer.len) {
@branchHint(.likely);
return;
return true;
}
const remainder = buffer[seek..];
std.mem.copyForwards(u8, buffer[0..remainder.len], remainder);
@ -523,8 +566,8 @@ pub fn fill(br: *BufferedReader, n: usize) anyerror!void {
br.seek = 0;
while (true) {
const status = try br.unbuffered_reader.read(storage, .unlimited);
if (n <= storage.end) return;
if (status.end) return error.EndOfStream;
if (n <= storage.end) return true;
if (status.end) return false;
}
}
@ -535,7 +578,8 @@ pub fn takeByte(br: *BufferedReader) anyerror!u8 {
const seek = br.seek;
if (seek >= buffer.len) {
@branchHint(.unlikely);
try br.fill(1);
const filled = try fill(br, 1);
if (!filled) return error.EndOfStream;
}
br.seek = seek + 1;
return buffer[seek];
@ -603,7 +647,7 @@ fn takeMultipleOf7Leb128(br: *BufferedReader, comptime Result: type) anyerror!Re
var result: UnsignedResult = 0;
var fits = true;
while (true) {
const buffer: []const packed struct(u8) { bits: u7, more: bool } = @ptrCast(try br.peekAll(1));
const buffer: []const packed struct(u8) { bits: u7, more: bool } = @ptrCast(try br.peekGreedy(1));
for (buffer, 1..) |byte, len| {
if (remaining_bits > 0) {
result = @shlExact(@as(UnsignedResult, byte.bits), result_info.bits - 7) |
@ -639,7 +683,7 @@ test peek {
return error.Unimplemented;
}
test peekAll {
test peekGreedy {
return error.Unimplemented;
}