mirror of
https://github.com/ziglang/zig.git
synced 2026-02-12 20:37:54 +00:00
std: start converting networking stuff to new reader/writer
This commit is contained in:
parent
c872a9fd49
commit
20a784f713
@ -155,38 +155,35 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type {
|
||||
}
|
||||
|
||||
// Read a DER-encoded integer.
|
||||
fn readDerInt(out: []u8, reader: anytype) EncodingError!void {
|
||||
var buf: [2]u8 = undefined;
|
||||
_ = reader.readNoEof(&buf) catch return error.InvalidEncoding;
|
||||
// Asserts `br` has storage capacity >= 2.
|
||||
fn readDerInt(out: []u8, br: *std.io.BufferedReader) EncodingError!void {
|
||||
const buf = br.take(2) catch return error.InvalidEncoding;
|
||||
if (buf[0] != 0x02) return error.InvalidEncoding;
|
||||
var expected_len = @as(usize, buf[1]);
|
||||
var expected_len: usize = buf[1];
|
||||
if (expected_len == 0 or expected_len > 1 + out.len) return error.InvalidEncoding;
|
||||
var has_top_bit = false;
|
||||
if (expected_len == 1 + out.len) {
|
||||
if ((reader.readByte() catch return error.InvalidEncoding) != 0) return error.InvalidEncoding;
|
||||
if ((br.takeByte() catch return error.InvalidEncoding) != 0) return error.InvalidEncoding;
|
||||
expected_len -= 1;
|
||||
has_top_bit = true;
|
||||
}
|
||||
const out_slice = out[out.len - expected_len ..];
|
||||
reader.readNoEof(out_slice) catch return error.InvalidEncoding;
|
||||
br.read(out_slice) catch return error.InvalidEncoding;
|
||||
if (@intFromBool(has_top_bit) != out[0] >> 7) return error.InvalidEncoding;
|
||||
}
|
||||
|
||||
/// Create a signature from a DER representation.
|
||||
/// Returns InvalidEncoding if the DER encoding is invalid.
|
||||
pub fn fromDer(der: []const u8) EncodingError!Signature {
|
||||
if (der.len < 2) return error.InvalidEncoding;
|
||||
var br: std.io.BufferedReader = undefined;
|
||||
br.initFixed(der);
|
||||
const buf = br.take(2) catch return error.InvalidEncoding;
|
||||
if (buf[0] != 0x30 or @as(usize, buf[1]) + 2 != der.len) return error.InvalidEncoding;
|
||||
var sig: Signature = mem.zeroInit(Signature, .{});
|
||||
var fb: std.io.FixedBufferStream = .{ .buffer = der };
|
||||
const reader = fb.reader();
|
||||
var buf: [2]u8 = undefined;
|
||||
_ = reader.readNoEof(&buf) catch return error.InvalidEncoding;
|
||||
if (buf[0] != 0x30 or @as(usize, buf[1]) + 2 != der.len) {
|
||||
return error.InvalidEncoding;
|
||||
}
|
||||
try readDerInt(&sig.r, reader);
|
||||
try readDerInt(&sig.s, reader);
|
||||
if (fb.getPos() catch unreachable != der.len) return error.InvalidEncoding;
|
||||
|
||||
try readDerInt(&sig.r, &br);
|
||||
try readDerInt(&sig.s, &br);
|
||||
if (br.seek != der.len) return error.InvalidEncoding;
|
||||
return sig;
|
||||
}
|
||||
};
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
const builtin = @import("builtin");
|
||||
const native_endian = builtin.cpu.arch.endian();
|
||||
|
||||
const std = @import("../../std.zig");
|
||||
const tls = std.crypto.tls;
|
||||
const Client = @This();
|
||||
@ -13,18 +16,44 @@ const hkdfExpandLabel = tls.hkdfExpandLabel;
|
||||
const int = tls.int;
|
||||
const array = tls.array;
|
||||
|
||||
/// The encrypted stream from the server to the client. Bytes are pulled from
|
||||
/// here via `reader`.
|
||||
///
|
||||
/// The buffer is asserted to have capacity at least `min_buffer_len`.
|
||||
///
|
||||
/// The size is enough to contain exactly one TLSCiphertext record.
|
||||
/// This buffer is segmented into four parts:
|
||||
/// 0. unused
|
||||
/// 1. cleartext
|
||||
/// 2. ciphertext
|
||||
/// 3. unused
|
||||
/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and
|
||||
/// `partial_ciphertext_end` describe the span of the segments.
|
||||
input: *std.io.BufferedReader,
|
||||
/// The encrypted stream from the client to the server. Bytes are pushed here
|
||||
/// via `writer`.
|
||||
///
|
||||
/// The buffer is asserted to have capacity at least `min_buffer_len`.
|
||||
output: *std.io.BufferedWriter,
|
||||
/// Cleartext received from the server here.
|
||||
///
|
||||
/// Its buffer aliases the buffer of `input`.
|
||||
reader: std.io.BufferedReader,
|
||||
/// Populated under various error conditions.
|
||||
diagnostics: Diagnostics,
|
||||
|
||||
tls_version: tls.ProtocolVersion,
|
||||
read_seq: u64,
|
||||
write_seq: u64,
|
||||
/// The starting index of cleartext bytes inside `partially_read_buffer`.
|
||||
/// The starting index of cleartext bytes inside the input buffer.
|
||||
partial_cleartext_idx: u15,
|
||||
/// The ending index of cleartext bytes inside `partially_read_buffer` as well
|
||||
/// The ending index of cleartext bytes inside the input buffer as well
|
||||
/// as the starting index of ciphertext bytes.
|
||||
partial_ciphertext_idx: u15,
|
||||
/// The ending index of ciphertext bytes inside `partially_read_buffer`.
|
||||
/// The ending index of ciphertext bytes inside the input buffer.
|
||||
partial_ciphertext_end: u15,
|
||||
/// When this is true, the stream may still not be at the end because there
|
||||
/// may be data in `partially_read_buffer`.
|
||||
/// may be data in the input buffer.
|
||||
received_close_notify: bool,
|
||||
/// By default, reaching the end-of-stream when reading from the server will
|
||||
/// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify
|
||||
@ -35,24 +64,40 @@ received_close_notify: bool,
|
||||
/// the amount of data expected, such as HTTP with the Content-Length header.
|
||||
allow_truncation_attacks: bool,
|
||||
application_cipher: tls.ApplicationCipher,
|
||||
/// The size is enough to contain exactly one TLSCiphertext record.
|
||||
/// This buffer is segmented into four parts:
|
||||
/// 0. unused
|
||||
/// 1. cleartext
|
||||
/// 2. ciphertext
|
||||
/// 3. unused
|
||||
/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and
|
||||
/// `partial_ciphertext_end` describe the span of the segments.
|
||||
partially_read_buffer: [tls.max_ciphertext_record_len]u8,
|
||||
/// Encrypted bytes sent to the server here.
|
||||
output: *std.io.BufferedWriter,
|
||||
/// If non-null, ssl secrets are logged to a file. Creating such a log file allows other
|
||||
/// programs with access to that file to decrypt all traffic over this connection.
|
||||
ssl_key_log: ?struct {
|
||||
/// If non-null, ssl secrets are logged to a stream. Creating such a log file
|
||||
/// allows other programs with access to that file to decrypt all traffic over
|
||||
/// this connection.
|
||||
ssl_key_log: ?*SslKeyLog,
|
||||
|
||||
pub const Diagnostics = union {
|
||||
/// Populated on `error.WriteFailure` and `error.ReadFailure`.
|
||||
err: anyerror,
|
||||
/// Populated on `error.TlsAlert`.
|
||||
///
|
||||
/// If this isn't a error alert, then it's a closure alert, which makes
|
||||
/// no sense in a handshake.
|
||||
alert: tls.AlertDescription,
|
||||
|
||||
fn wrapWrite(d: *Diagnostics, returned: anyerror!void) error{WriteFailure}!void {
|
||||
returned catch |err| {
|
||||
d.* = .{ .err = err };
|
||||
return error.WriteFailure;
|
||||
};
|
||||
}
|
||||
|
||||
fn wrapRead(d: *Diagnostics, returned: anyerror!void) error{ReadFailure}!void {
|
||||
returned catch |err| {
|
||||
d.* = .{ .err = err };
|
||||
return error.ReadFailure;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
pub const SslKeyLog = struct {
|
||||
client_key_seq: u64,
|
||||
server_key_seq: u64,
|
||||
client_random: [32]u8,
|
||||
file: std.fs.File,
|
||||
writer: *std.io.BufferedWriter,
|
||||
|
||||
fn clientCounter(key_log: *@This()) u64 {
|
||||
defer key_log.client_key_seq += 1;
|
||||
@ -63,31 +108,12 @@ ssl_key_log: ?struct {
|
||||
defer key_log.server_key_seq += 1;
|
||||
return key_log.server_key_seq;
|
||||
}
|
||||
},
|
||||
|
||||
/// This is an example of the type that is needed by the read and write
|
||||
/// functions. It can have any fields but it must at least have these
|
||||
/// functions.
|
||||
///
|
||||
/// Note that `std.net.Stream` conforms to this interface.
|
||||
///
|
||||
/// This declaration serves as documentation only.
|
||||
pub const StreamInterface = struct {
|
||||
/// Can be any error set.
|
||||
pub const ReadError = error{};
|
||||
|
||||
/// Returns the number of bytes read. The number read may be less than the
|
||||
/// buffer space provided. End-of-stream is indicated by a return value of 0.
|
||||
///
|
||||
/// The `iovecs` parameter is mutable because so that function may to
|
||||
/// mutate the fields in order to handle partial reads from the underlying
|
||||
/// stream layer.
|
||||
pub fn readv(this: @This(), iovecs: []std.posix.iovec) ReadError!usize {
|
||||
_ = .{ this, iovecs };
|
||||
@panic("unimplemented");
|
||||
}
|
||||
};
|
||||
|
||||
/// The `std.io.BufferedReader` and `std.io.BufferedWriter` supplied to `init`
|
||||
/// each require a buffer capacity at least this amount.
|
||||
pub const min_buffer_len = tls.max_ciphertext_record_len;
|
||||
|
||||
pub const Options = struct {
|
||||
/// How to perform host verification of server certificates.
|
||||
host: union(enum) {
|
||||
@ -109,39 +135,11 @@ pub const Options = struct {
|
||||
/// Verify that the server certificate is authorized by a given ca bundle.
|
||||
bundle: Certificate.Bundle,
|
||||
},
|
||||
/// If non-null, ssl secrets are logged to this file. Creating such a log file allows
|
||||
/// If non-null, ssl secrets are logged to this stream. Creating such a log file allows
|
||||
/// other programs with access to that file to decrypt all traffic over this connection.
|
||||
/// TODO `std.crypto` should have no dependencies on `std.fs`.
|
||||
ssl_key_log_file: ?std.fs.File = null,
|
||||
diagnostics: ?*Diagnostics = null,
|
||||
|
||||
pub const Diagnostics = union {
|
||||
/// Populated on `error.WriteFailure` and `error.ReadFailure`.
|
||||
err: anyerror,
|
||||
/// Populated on `error.TlsAlert`.
|
||||
///
|
||||
/// If this isn't a error alert, then it's a closure alert, which makes
|
||||
/// no sense in a handshake.
|
||||
alert: tls.AlertDescription,
|
||||
};
|
||||
ssl_key_log: ?*std.io.BufferedWriter = null,
|
||||
};
|
||||
|
||||
/// TODO I wish this could be a method of Diagnostics
|
||||
fn wrapWrite(opt_diags: ?*Options.Diagnostics, returned: anyerror!void) error{WriteFailure}!void {
|
||||
returned catch |err| {
|
||||
if (opt_diags) |diags| diags.* = .{ .err = err };
|
||||
return error.WriteFailure;
|
||||
};
|
||||
}
|
||||
|
||||
/// TODO I wish this could be a method of Diagnostics
|
||||
fn wrapRead(opt_diags: ?*Options.Diagnostics, returned: anyerror!void) error{ReadFailure}!void {
|
||||
returned catch |err| {
|
||||
if (opt_diags) |diags| diags.* = .{ .err = err };
|
||||
return error.ReadFailure;
|
||||
};
|
||||
}
|
||||
|
||||
const InitError = error{
|
||||
//OutOfMemory,
|
||||
WriteFailure,
|
||||
@ -193,12 +191,21 @@ const InitError = error{
|
||||
WeakPublicKey,
|
||||
};
|
||||
|
||||
/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session with `input`, which
|
||||
/// must conform to `StreamInterface`.
|
||||
/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session.
|
||||
///
|
||||
/// `host` is only borrowed during this function call.
|
||||
pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) InitError!Client {
|
||||
const diags = options.diagnostics;
|
||||
///
|
||||
/// Both `input` and `output` are asserted to have buffer capacity at least
|
||||
/// `min_buffer_len`.
|
||||
pub fn init(
|
||||
client: *Client,
|
||||
input: *std.io.BufferedReader,
|
||||
output: *std.io.BufferedWriter,
|
||||
options: Options,
|
||||
) InitError!void {
|
||||
assert(input.storage.buffer.len >= min_buffer_len);
|
||||
assert(output.buffer.len >= min_buffer_len);
|
||||
const diags = &client.diagnostics;
|
||||
const host = switch (options.host) {
|
||||
.no_verification => "",
|
||||
.explicit => |host| host,
|
||||
@ -291,7 +298,7 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In
|
||||
|
||||
{
|
||||
var iovecs: [2][]const u8 = .{ cleartext_header, host };
|
||||
try wrapWrite(diags, output.writevAll(iovecs[0..if (host.len == 0) 1 else 2]));
|
||||
try diags.wrapWrite(output.writevAll(iovecs[0..if (host.len == 0) 1 else 2]));
|
||||
}
|
||||
|
||||
var tls_version: tls.ProtocolVersion = undefined;
|
||||
@ -343,12 +350,12 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In
|
||||
var handshake_buffer: [tls.max_ciphertext_record_len]u8 = undefined;
|
||||
var d: tls.Decoder = .{ .buf = &handshake_buffer };
|
||||
fragment: while (true) {
|
||||
try wrapRead(diags, d.readAtLeastOurAmt(input, tls.record_header_len));
|
||||
try diags.wrapRead(d.readAtLeastOurAmt(input, tls.record_header_len));
|
||||
const record_header = d.buf[d.idx..][0..tls.record_header_len];
|
||||
const record_ct = d.decode(tls.ContentType);
|
||||
d.skip(2); // legacy_version
|
||||
const record_len = d.decode(u16);
|
||||
try wrapRead(diags, d.readAtLeast(input, record_len));
|
||||
try diags.wrapRead(d.readAtLeast(input, record_len));
|
||||
var record_decoder = try d.sub(record_len);
|
||||
var ctd, const ct = content: switch (cipher_state) {
|
||||
.cleartext => .{ record_decoder, record_ct },
|
||||
@ -426,7 +433,7 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In
|
||||
const level = ctd.decode(tls.AlertLevel);
|
||||
const desc = ctd.decode(tls.AlertDescription);
|
||||
_ = level;
|
||||
if (diags) |x| x.* = .{ .alert = desc };
|
||||
diags.* = .{ .alert = desc };
|
||||
return error.TlsAlert;
|
||||
},
|
||||
.change_cipher_spec => {
|
||||
@ -768,7 +775,7 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In
|
||||
&client_change_cipher_spec_msg,
|
||||
&client_verify_msg,
|
||||
};
|
||||
try wrapWrite(diags, output.writevAll(&all_msgs_vec));
|
||||
try diags.wrapWrite(output.writevAll(&all_msgs_vec));
|
||||
},
|
||||
}
|
||||
write_seq += 1;
|
||||
@ -833,7 +840,7 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In
|
||||
&client_change_cipher_spec_msg,
|
||||
&finished_msg,
|
||||
};
|
||||
try wrapWrite(diags, output.writevAll(&all_msgs_vec));
|
||||
try diags.wrapWrite(output.writevAll(&all_msgs_vec));
|
||||
|
||||
const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
|
||||
const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
|
||||
@ -865,7 +872,10 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In
|
||||
},
|
||||
};
|
||||
const leftover = d.rest();
|
||||
var client: Client = .{
|
||||
client.* = .{
|
||||
.input = input,
|
||||
.output = output,
|
||||
.reader = undefined,
|
||||
.tls_version = tls_version,
|
||||
.read_seq = switch (tls_version) {
|
||||
.tls_1_3 => 0,
|
||||
@ -883,7 +893,6 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In
|
||||
.received_close_notify = false,
|
||||
.allow_truncation_attacks = false,
|
||||
.application_cipher = app_cipher,
|
||||
.output = output,
|
||||
.partially_read_buffer = undefined,
|
||||
.ssl_key_log = if (options.ssl_key_log_file) |key_log_file| .{
|
||||
.client_key_seq = key_seq,
|
||||
@ -893,7 +902,14 @@ pub fn init(input: anytype, output: *std.io.BufferedWriter, options: Options) In
|
||||
} else null,
|
||||
};
|
||||
@memcpy(client.partially_read_buffer[0..leftover.len], leftover);
|
||||
return client;
|
||||
client.reader.init(.{
|
||||
.context = client,
|
||||
.vtable = &.{
|
||||
.read = reader_read,
|
||||
.readv = reader_readv,
|
||||
},
|
||||
}, input.storage.buffer[0..0]);
|
||||
return;
|
||||
},
|
||||
else => return error.TlsUnexpectedMessage,
|
||||
}
|
||||
@ -919,81 +935,45 @@ pub fn writer(c: *Client) std.io.Writer {
|
||||
|
||||
fn writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize {
|
||||
const c: *Client = @alignCast(@ptrCast(context));
|
||||
assert(data.len > 1 or splat > 0);
|
||||
return writeEnd(c, data[0], false);
|
||||
const sliced_data = if (splat == 0) data[0..data.len -| 1] else data;
|
||||
const output = &c.output;
|
||||
const ciphertext_buf = try output.writableSlice(min_buffer_len);
|
||||
var total_clear: usize = 0;
|
||||
var ciphertext_end: usize = 0;
|
||||
for (sliced_data) |buf| {
|
||||
const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data);
|
||||
total_clear += prepared.cleartext_len;
|
||||
ciphertext_end += prepared.ciphertext_end;
|
||||
if (total_clear < buf.len) break;
|
||||
}
|
||||
output.advance(ciphertext_end);
|
||||
return total_clear;
|
||||
}
|
||||
|
||||
/// If `end` is true, then this function additionally sends a `close_notify`
|
||||
/// alert, which is necessary for the server to distinguish between a properly
|
||||
/// finished TLS session, or a truncation attack.
|
||||
pub fn writeAllEnd(c: *Client, bytes: []const u8, end: bool) anyerror!void {
|
||||
var index: usize = 0;
|
||||
while (index < bytes.len) {
|
||||
index += try c.writeEnd(bytes[index..], end);
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`.
|
||||
/// If `end` is true, then this function additionally sends a `close_notify` alert,
|
||||
/// which is necessary for the server to distinguish between a properly finished
|
||||
/// TLS session, or a truncation attack.
|
||||
pub fn writeEnd(c: *Client, bytes: []const u8, end: bool) anyerror!usize {
|
||||
var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined;
|
||||
var iovecs_buf: [6][]const u8 = undefined;
|
||||
var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data);
|
||||
if (end) {
|
||||
prepared.iovec_end += prepareCiphertextRecord(
|
||||
c,
|
||||
iovecs_buf[prepared.iovec_end..],
|
||||
ciphertext_buf[prepared.ciphertext_end..],
|
||||
&tls.close_notify_alert,
|
||||
.alert,
|
||||
).iovec_end;
|
||||
}
|
||||
|
||||
const iovec_end = prepared.iovec_end;
|
||||
const overhead_len = prepared.overhead_len;
|
||||
|
||||
// Ideally we would call writev exactly once here, however, we must ensure
|
||||
// that we don't return with a record partially written.
|
||||
var i: usize = 0;
|
||||
var total_amt: usize = 0;
|
||||
while (true) {
|
||||
var amt = try c.output.writev(iovecs_buf[i..iovec_end]);
|
||||
while (amt >= iovecs_buf[i].len) {
|
||||
const encrypted_amt = iovecs_buf[i].len;
|
||||
total_amt += encrypted_amt - overhead_len;
|
||||
amt -= encrypted_amt;
|
||||
i += 1;
|
||||
// Rely on the property that iovecs delineate records, meaning that
|
||||
// if amt equals zero here, we have fortunately found ourselves
|
||||
// with a short read that aligns at the record boundary.
|
||||
if (i >= iovec_end) return total_amt;
|
||||
// We also cannot return on a vector boundary if the final close_notify is
|
||||
// not sent; otherwise the caller would not know to retry the call.
|
||||
if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt;
|
||||
}
|
||||
iovecs_buf[i] = iovecs_buf[i][amt..];
|
||||
}
|
||||
/// Sends a `close_notify` alert, which is necessary for the server to
|
||||
/// distinguish between a properly finished TLS session, or a truncation
|
||||
/// attack.
|
||||
pub fn end(c: *Client) anyerror!void {
|
||||
const output = &c.output;
|
||||
const ciphertext_buf = try output.writableSlice(min_buffer_len);
|
||||
const prepared = prepareCiphertextRecord(c, ciphertext_buf, &tls.close_notify_alert, .alert);
|
||||
output.advance(prepared.cleartext_len);
|
||||
return prepared.ciphertext_end;
|
||||
}
|
||||
|
||||
fn prepareCiphertextRecord(
|
||||
c: *Client,
|
||||
iovecs: [][]const u8,
|
||||
ciphertext_buf: []u8,
|
||||
bytes: []const u8,
|
||||
inner_content_type: tls.ContentType,
|
||||
) struct {
|
||||
iovec_end: usize,
|
||||
ciphertext_end: usize,
|
||||
/// How many bytes are taken up by overhead per record.
|
||||
overhead_len: usize,
|
||||
cleartext_len: usize,
|
||||
} {
|
||||
// Due to the trailing inner content type byte in the ciphertext, we need
|
||||
// an additional buffer for storing the cleartext into before encrypting.
|
||||
var cleartext_buf: [max_ciphertext_len]u8 = undefined;
|
||||
var ciphertext_end: usize = 0;
|
||||
var iovec_end: usize = 0;
|
||||
var bytes_i: usize = 0;
|
||||
switch (c.application_cipher) {
|
||||
inline else => |*p| switch (c.tls_version) {
|
||||
@ -1001,18 +981,15 @@ fn prepareCiphertextRecord(
|
||||
const pv = &p.tls_1_3;
|
||||
const P = @TypeOf(p.*);
|
||||
const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1;
|
||||
const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len;
|
||||
while (true) {
|
||||
const encrypted_content_len: u16 = @min(
|
||||
bytes.len - bytes_i,
|
||||
tls.max_ciphertext_inner_record_len,
|
||||
ciphertext_buf.len -|
|
||||
(close_notify_alert_reserved + overhead_len + ciphertext_end),
|
||||
ciphertext_buf.len -| (overhead_len + ciphertext_end),
|
||||
);
|
||||
if (encrypted_content_len == 0) return .{
|
||||
.iovec_end = iovec_end,
|
||||
.ciphertext_end = ciphertext_end,
|
||||
.overhead_len = overhead_len,
|
||||
.cleartext_len = bytes_i,
|
||||
};
|
||||
|
||||
@memcpy(cleartext_buf[0..encrypted_content_len], bytes[bytes_i..][0..encrypted_content_len]);
|
||||
@ -1021,7 +998,6 @@ fn prepareCiphertextRecord(
|
||||
const ciphertext_len = encrypted_content_len + 1;
|
||||
const cleartext = cleartext_buf[0..ciphertext_len];
|
||||
|
||||
const record_start = ciphertext_end;
|
||||
const ad = ciphertext_buf[ciphertext_end..][0..tls.record_header_len];
|
||||
ad.* = .{@intFromEnum(tls.ContentType.application_data)} ++
|
||||
int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
|
||||
@ -1039,35 +1015,27 @@ fn prepareCiphertextRecord(
|
||||
};
|
||||
P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_key);
|
||||
c.write_seq += 1; // TODO send key_update on overflow
|
||||
|
||||
const record = ciphertext_buf[record_start..ciphertext_end];
|
||||
iovecs[iovec_end] = record;
|
||||
iovec_end += 1;
|
||||
}
|
||||
},
|
||||
.tls_1_2 => {
|
||||
const pv = &p.tls_1_2;
|
||||
const P = @TypeOf(p.*);
|
||||
const overhead_len = tls.record_header_len + P.record_iv_length + P.mac_length;
|
||||
const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len;
|
||||
while (true) {
|
||||
const message_len: u16 = @min(
|
||||
bytes.len - bytes_i,
|
||||
tls.max_ciphertext_inner_record_len,
|
||||
ciphertext_buf.len -|
|
||||
(close_notify_alert_reserved + overhead_len + ciphertext_end),
|
||||
ciphertext_buf.len -| (overhead_len + ciphertext_end),
|
||||
);
|
||||
if (message_len == 0) return .{
|
||||
.iovec_end = iovec_end,
|
||||
.ciphertext_end = ciphertext_end,
|
||||
.overhead_len = overhead_len,
|
||||
.cleartext_len = bytes_i,
|
||||
};
|
||||
|
||||
@memcpy(cleartext_buf[0..message_len], bytes[bytes_i..][0..message_len]);
|
||||
bytes_i += message_len;
|
||||
const cleartext = cleartext_buf[0..message_len];
|
||||
|
||||
const record_start = ciphertext_end;
|
||||
const record_header = ciphertext_buf[ciphertext_end..][0..tls.record_header_len];
|
||||
ciphertext_end += tls.record_header_len;
|
||||
record_header.* = .{@intFromEnum(inner_content_type)} ++
|
||||
@ -1089,10 +1057,6 @@ fn prepareCiphertextRecord(
|
||||
ciphertext_end += P.mac_length;
|
||||
P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_write_key);
|
||||
c.write_seq += 1; // TODO send key_update on overflow
|
||||
|
||||
const record = ciphertext_buf[record_start..ciphertext_end];
|
||||
iovecs[iovec_end] = record;
|
||||
iovec_end += 1;
|
||||
}
|
||||
},
|
||||
else => unreachable,
|
||||
@ -1106,74 +1070,22 @@ pub fn eof(c: Client) bool {
|
||||
c.partial_ciphertext_idx >= c.partial_ciphertext_end;
|
||||
}
|
||||
|
||||
/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
|
||||
/// Returns the number of bytes read, calling the underlying read function the
|
||||
/// minimal number of times until the buffer has at least `len` bytes filled.
|
||||
/// If the number read is less than `len` it means the stream reached the end.
|
||||
/// Reaching the end of the stream is not an error condition.
|
||||
pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize {
|
||||
var iovecs = [1]std.posix.iovec{.{ .base = buffer.ptr, .len = buffer.len }};
|
||||
return readvAtLeast(c, stream, &iovecs, len);
|
||||
fn reader_read(
|
||||
context: ?*anyopaque,
|
||||
bw: *std.io.BufferedWriter,
|
||||
limit: std.io.Reader.Limit,
|
||||
) anyerror!std.io.Reader.Status {
|
||||
const buf = limit.slice(try bw.writableSlice(1));
|
||||
const status = try reader_readv(context, &.{buf});
|
||||
bw.advance(status.len);
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
|
||||
pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize {
|
||||
return readAtLeast(c, stream, buffer, 1);
|
||||
}
|
||||
fn reader_readv(context: ?*anyopaque, data: []const []u8) anyerror!std.io.Reader.Status {
|
||||
const c: *Client = @ptrCast(@alignCast(context));
|
||||
if (c.eof()) return .{ .end = true };
|
||||
|
||||
/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
|
||||
/// Returns the number of bytes read. If the number read is smaller than
|
||||
/// `buffer.len`, it means the stream reached the end. Reaching the end of the
|
||||
/// stream is not an error condition.
|
||||
pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize {
|
||||
return readAtLeast(c, stream, buffer, buffer.len);
|
||||
}
|
||||
|
||||
/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
|
||||
/// Returns the number of bytes read. If the number read is less than the space
|
||||
/// provided it means the stream reached the end. Reaching the end of the
|
||||
/// stream is not an error condition.
|
||||
/// The `iovecs` parameter is mutable because this function needs to mutate the fields in
|
||||
/// order to handle partial reads from the underlying stream layer.
|
||||
pub fn readv(c: *Client, stream: anytype, iovecs: []std.posix.iovec) !usize {
|
||||
return readvAtLeast(c, stream, iovecs, 1);
|
||||
}
|
||||
|
||||
/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
|
||||
/// Returns the number of bytes read, calling the underlying read function the
|
||||
/// minimal number of times until the iovecs have at least `len` bytes filled.
|
||||
/// If the number read is less than `len` it means the stream reached the end.
|
||||
/// Reaching the end of the stream is not an error condition.
|
||||
/// The `iovecs` parameter is mutable because this function needs to mutate the fields in
|
||||
/// order to handle partial reads from the underlying stream layer.
|
||||
pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.posix.iovec, len: usize) !usize {
|
||||
if (c.eof()) return 0;
|
||||
|
||||
var off_i: usize = 0;
|
||||
var vec_i: usize = 0;
|
||||
while (true) {
|
||||
var amt = try c.readvAdvanced(stream, iovecs[vec_i..]);
|
||||
off_i += amt;
|
||||
if (c.eof() or off_i >= len) return off_i;
|
||||
while (amt >= iovecs[vec_i].len) {
|
||||
amt -= iovecs[vec_i].len;
|
||||
vec_i += 1;
|
||||
}
|
||||
iovecs[vec_i].base += amt;
|
||||
iovecs[vec_i].len -= amt;
|
||||
}
|
||||
}
|
||||
|
||||
/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
|
||||
/// Returns number of bytes that have been read, populated inside `iovecs`. A
|
||||
/// return value of zero bytes does not mean end of stream. Instead, check the `eof()`
|
||||
/// for the end of stream. The `eof()` may be true after any call to
|
||||
/// `read`, including when greater than zero bytes are returned, and this
|
||||
/// function asserts that `eof()` is `false`.
|
||||
/// See `readv` for a higher level function that has the same, familiar API as
|
||||
/// other read functions, such as `std.fs.File.read`.
|
||||
pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iovec) !usize {
|
||||
var vp: VecPut = .{ .iovecs = iovecs };
|
||||
var vp: VecPut = .{ .iovecs = data };
|
||||
|
||||
// Give away the buffered cleartext we have, if any.
|
||||
const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx];
|
||||
@ -1193,11 +1105,11 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove
|
||||
if (c.received_close_notify) {
|
||||
c.partial_ciphertext_end = 0;
|
||||
assert(vp.total == amt);
|
||||
return amt;
|
||||
return .{ .len = amt, .end = c.eof() };
|
||||
} else if (amt > 0) {
|
||||
// We don't need more data, so don't call read.
|
||||
assert(vp.total == amt);
|
||||
return amt;
|
||||
return .{ .len = amt, .end = c.eof() };
|
||||
}
|
||||
}
|
||||
|
||||
@ -1241,7 +1153,7 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove
|
||||
const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len);
|
||||
const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len) - c.partial_ciphertext_end;
|
||||
const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len);
|
||||
const actual_read_len = try stream.readv(ask_iovecs);
|
||||
const actual_read_len = try c.input.readv(ask_iovecs);
|
||||
if (actual_read_len == 0) {
|
||||
// This is either a truncation attack, a bug in the server, or an
|
||||
// intentional omission of the close_notify message due to truncation
|
||||
@ -1268,7 +1180,7 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove
|
||||
// Perfect split.
|
||||
if (frag.ptr == frag1.ptr) {
|
||||
c.partial_ciphertext_end = c.partial_ciphertext_idx;
|
||||
return vp.total;
|
||||
return .{ .len = vp.total, .end = c.eof() };
|
||||
}
|
||||
frag = frag1;
|
||||
in = 0;
|
||||
@ -1310,8 +1222,8 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove
|
||||
const record_len = mem.readInt(u16, frag[in..][0..2], .big);
|
||||
if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
|
||||
in += 2;
|
||||
const end = in + record_len;
|
||||
if (end > frag.len) {
|
||||
const the_end = in + record_len;
|
||||
if (the_end > frag.len) {
|
||||
// We need the record header on the next iteration of the loop.
|
||||
in -= tls.record_header_len;
|
||||
|
||||
@ -1398,17 +1310,23 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove
|
||||
.alert => {
|
||||
if (cleartext.len != 2) return error.TlsDecodeError;
|
||||
const level: tls.AlertLevel = @enumFromInt(cleartext[0]);
|
||||
const desc: tls.AlertDescription = @enumFromInt(cleartext[1]);
|
||||
if (desc == .close_notify) {
|
||||
c.received_close_notify = true;
|
||||
c.partial_ciphertext_end = c.partial_ciphertext_idx;
|
||||
return vp.total;
|
||||
}
|
||||
_ = level;
|
||||
|
||||
try desc.toError();
|
||||
// TODO: handle server-side closures
|
||||
return error.TlsUnexpectedMessage;
|
||||
const desc: tls.AlertDescription = @enumFromInt(cleartext[1]);
|
||||
switch (desc) {
|
||||
.close_notify => {
|
||||
c.received_close_notify = true;
|
||||
c.partial_ciphertext_end = c.partial_ciphertext_idx;
|
||||
return .{ .len = vp.total, .end = c.eof() };
|
||||
},
|
||||
.user_canceled => {
|
||||
// TODO: handle server-side closures
|
||||
return error.TlsUnexpectedMessage;
|
||||
},
|
||||
else => {
|
||||
c.diagnostics = .{ .alert = desc };
|
||||
return error.TlsAlert;
|
||||
},
|
||||
}
|
||||
},
|
||||
.handshake => {
|
||||
var ct_i: usize = 0;
|
||||
@ -1524,7 +1442,7 @@ fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) voi
|
||||
}) catch {};
|
||||
}
|
||||
|
||||
fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize {
|
||||
fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) std.io.Reader.Status {
|
||||
const saved_buf = frag[in..];
|
||||
if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
|
||||
// There is cleartext at the beginning already which we need to preserve.
|
||||
@ -1536,11 +1454,11 @@ fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize {
|
||||
c.partial_ciphertext_end = @intCast(saved_buf.len);
|
||||
@memcpy(c.partially_read_buffer[0..saved_buf.len], saved_buf);
|
||||
}
|
||||
return out;
|
||||
return .{ .len = out, .end = c.eof() };
|
||||
}
|
||||
|
||||
/// Note that `first` usually overlaps with `c.partially_read_buffer`.
|
||||
fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize {
|
||||
fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) std.io.Reader.Status {
|
||||
if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
|
||||
// There is cleartext at the beginning already which we need to preserve.
|
||||
c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + first.len + frag1.len);
|
||||
@ -1555,7 +1473,7 @@ fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usi
|
||||
std.mem.copyForwards(u8, c.partially_read_buffer[0..first.len], first);
|
||||
@memcpy(c.partially_read_buffer[first.len..][0..frag1.len], frag1);
|
||||
}
|
||||
return out;
|
||||
return .{ .len = out, .end = c.eof() };
|
||||
}
|
||||
|
||||
fn limitedOverlapCopy(frag: []u8, in: usize) void {
|
||||
@ -1577,9 +1495,6 @@ fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 {
|
||||
}
|
||||
}
|
||||
|
||||
const builtin = @import("builtin");
|
||||
const native_endian = builtin.cpu.arch.endian();
|
||||
|
||||
inline fn big(x: anytype) @TypeOf(x) {
|
||||
return switch (native_endian) {
|
||||
.big => x,
|
||||
@ -1958,7 +1873,3 @@ else
|
||||
.AES_256_GCM_SHA384,
|
||||
.ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
});
|
||||
|
||||
test {
|
||||
_ = StreamInterface;
|
||||
}
|
||||
|
||||
@ -24,6 +24,12 @@ allocator: Allocator,
|
||||
|
||||
ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{},
|
||||
ca_bundle_mutex: std.Thread.Mutex = .{},
|
||||
/// Used both for the reader and writer buffers.
|
||||
tls_buffer_size: if (disable_tls) u0 else usize = if (disable_tls) 0 else std.crypto.tls.Client.min_buffer_len,
|
||||
/// If non-null, ssl secrets are logged to a stream. Creating such a stream
|
||||
/// allows other processes with access to that stream to decrypt all
|
||||
/// traffic over connections created with this `Client`.
|
||||
ssl_key_logger: ?*std.io.BufferedWriter = null,
|
||||
|
||||
/// When this is `true`, the next time this client performs an HTTPS request,
|
||||
/// it will first rescan the system for root certificates.
|
||||
@ -31,6 +37,10 @@ next_https_rescan_certs: bool = true,
|
||||
|
||||
/// The pool of connections that can be reused (and currently in use).
|
||||
connection_pool: ConnectionPool = .{},
|
||||
/// Each `Connection` allocates this amount for the reader buffer.
|
||||
read_buffer_size: usize,
|
||||
/// Each `Connection` allocates this amount for the writer buffer.
|
||||
write_buffer_size: usize,
|
||||
|
||||
/// If populated, all http traffic travels through this third party.
|
||||
/// This field cannot be modified while the client has active connections.
|
||||
@ -41,7 +51,7 @@ http_proxy: ?*Proxy = null,
|
||||
/// Pointer to externally-owned memory.
|
||||
https_proxy: ?*Proxy = null,
|
||||
|
||||
/// A set of linked lists of connections that can be reused.
|
||||
/// A Least-Recently-Used cache of open connections to be reused.
|
||||
pub const ConnectionPool = struct {
|
||||
mutex: std.Thread.Mutex = .{},
|
||||
/// Open connections that are currently in use.
|
||||
@ -58,8 +68,10 @@ pub const ConnectionPool = struct {
|
||||
protocol: Connection.Protocol,
|
||||
};
|
||||
|
||||
/// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe.
|
||||
/// Finds and acquires a connection from the connection pool matching the criteria.
|
||||
/// If no connection is found, null is returned.
|
||||
///
|
||||
/// Threadsafe.
|
||||
pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection {
|
||||
pool.mutex.lock();
|
||||
defer pool.mutex.unlock();
|
||||
@ -96,21 +108,21 @@ pub const ConnectionPool = struct {
|
||||
return pool.acquireUnsafe(connection);
|
||||
}
|
||||
|
||||
/// Tries to release a connection back to the connection pool. This function is threadsafe.
|
||||
/// Tries to release a connection back to the connection pool.
|
||||
/// If the connection is marked as closing, it will be closed instead.
|
||||
///
|
||||
/// The allocator must be the owner of all nodes in this pool.
|
||||
/// The allocator must be the owner of all resources associated with the connection.
|
||||
/// `allocator` must be the same one used to create `connection`.
|
||||
///
|
||||
/// Threadsafe.
|
||||
pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void {
|
||||
if (connection.closing) return connection.destroy(allocator);
|
||||
|
||||
pool.mutex.lock();
|
||||
defer pool.mutex.unlock();
|
||||
|
||||
pool.used.remove(&connection.pool_node);
|
||||
|
||||
if (connection.closing or pool.free_size == 0) {
|
||||
connection.close(allocator);
|
||||
return allocator.destroy(connection);
|
||||
}
|
||||
if (pool.free_size == 0) return connection.destroy(allocator);
|
||||
|
||||
if (pool.free_len >= pool.free_size) {
|
||||
const popped: *Connection = @fieldParentPtr("pool_node", pool.free.popFirst().?);
|
||||
@ -138,9 +150,11 @@ pub const ConnectionPool = struct {
|
||||
pool.used.append(&connection.pool_node);
|
||||
}
|
||||
|
||||
/// Resizes the connection pool. This function is threadsafe.
|
||||
/// Resizes the connection pool.
|
||||
///
|
||||
/// If the new size is smaller than the current size, then idle connections will be closed until the pool is the new size.
|
||||
///
|
||||
/// Threadsafe.
|
||||
pub fn resize(pool: *ConnectionPool, allocator: Allocator, new_size: usize) void {
|
||||
pool.mutex.lock();
|
||||
defer pool.mutex.unlock();
|
||||
@ -158,9 +172,11 @@ pub const ConnectionPool = struct {
|
||||
pool.free_size = new_size;
|
||||
}
|
||||
|
||||
/// Frees the connection pool and closes all connections within. This function is threadsafe.
|
||||
/// Frees the connection pool and closes all connections within.
|
||||
///
|
||||
/// All future operations on the connection pool will deadlock.
|
||||
///
|
||||
/// Threadsafe.
|
||||
pub fn deinit(pool: *ConnectionPool, allocator: Allocator) void {
|
||||
pool.mutex.lock();
|
||||
|
||||
@ -184,160 +200,212 @@ pub const ConnectionPool = struct {
|
||||
}
|
||||
};
|
||||
|
||||
/// An interface to either a plain or TLS connection.
|
||||
pub const Connection = struct {
|
||||
client: *Client,
|
||||
stream: net.Stream,
|
||||
/// Populated when protocol is TLS; this is the writer given to the TLS
|
||||
/// client, which writes directly to `stream`, unbuffered.
|
||||
stream_writer: std.io.BufferedWriter,
|
||||
/// undefined unless protocol is tls.
|
||||
tls_client: if (!disable_tls) *std.crypto.tls.Client else void,
|
||||
|
||||
/// HTTP protocol from client to server.
|
||||
/// This either goes directly to `stream`, or to a TLS client.
|
||||
writer: std.io.BufferedWriter,
|
||||
/// Entry in `ConnectionPool.used` or `ConnectionPool.free`.
|
||||
pool_node: std.DoublyLinkedList.Node,
|
||||
|
||||
/// The protocol that this connection is using.
|
||||
protocol: Protocol,
|
||||
|
||||
/// The host that this connection is connected to.
|
||||
host: []u8,
|
||||
|
||||
/// The port that this connection is connected to.
|
||||
port: u16,
|
||||
|
||||
/// Whether this connection is proxied and is not directly connected.
|
||||
proxied: bool = false,
|
||||
|
||||
/// Whether this connection is closing when we're done with it.
|
||||
closing: bool = false,
|
||||
|
||||
read_start: BufferSize = 0,
|
||||
read_end: BufferSize = 0,
|
||||
read_buf: [buffer_size]u8,
|
||||
|
||||
write_buffer: [buffer_size]u8,
|
||||
writer: std.io.BufferedWriter,
|
||||
|
||||
pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
|
||||
const BufferSize = std.math.IntFittingRange(0, buffer_size);
|
||||
host_len: u8,
|
||||
proxied: bool,
|
||||
closing: bool,
|
||||
protocol: Protocol,
|
||||
|
||||
pub const Protocol = enum { plain, tls };
|
||||
|
||||
pub fn readvDirectTls(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize {
|
||||
return conn.tls_client.readv(conn.stream, buffers) catch |err| {
|
||||
// https://github.com/ziglang/zig/issues/2473
|
||||
if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert;
|
||||
const Plain = struct {
|
||||
/// Data from `Connection.stream`.
|
||||
reader: std.io.BufferedReader,
|
||||
connection: Connection,
|
||||
|
||||
switch (err) {
|
||||
error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
|
||||
error.ConnectionTimedOut => return error.ConnectionTimedOut,
|
||||
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
|
||||
else => return error.UnexpectedReadFailure,
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn readvDirect(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize {
|
||||
if (conn.protocol == .tls) {
|
||||
if (disable_tls) unreachable;
|
||||
|
||||
return conn.readvDirectTls(buffers);
|
||||
fn create(
|
||||
client: *Client,
|
||||
remote_host: []const u8,
|
||||
port: u16,
|
||||
stream: net.Stream,
|
||||
) error{OutOfMemory}!*Connection {
|
||||
const gpa = client.allocator;
|
||||
const alloc_len = allocLen(client, remote_host.len);
|
||||
const base = try gpa.alignedAlloc(u8, .of(Plain), alloc_len);
|
||||
errdefer gpa.free(base);
|
||||
const host_buffer = base[@sizeOf(Plain)..][0..remote_host.len];
|
||||
const socket_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.read_buffer_size];
|
||||
const socket_write_buffer = socket_read_buffer.ptr[socket_read_buffer.len..][0..client.write_buffer_size];
|
||||
assert(base.ptr + alloc_len == socket_read_buffer.ptr + socket_read_buffer.len);
|
||||
@memcpy(host_buffer, remote_host);
|
||||
const plain: *Plain = @ptrCast(base);
|
||||
plain.* = .{
|
||||
.connection = .{
|
||||
.client = client,
|
||||
.stream = stream,
|
||||
.writer = stream.writer().buffered(socket_write_buffer),
|
||||
.pool_node = .{},
|
||||
.port = port,
|
||||
.proxied = false,
|
||||
.closing = false,
|
||||
.protocol = .plain,
|
||||
},
|
||||
.reader = undefined,
|
||||
};
|
||||
plain.reader.init(stream.reader(), socket_read_buffer);
|
||||
}
|
||||
|
||||
return conn.stream.readv(buffers) catch |err| switch (err) {
|
||||
error.ConnectionTimedOut => return error.ConnectionTimedOut,
|
||||
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
|
||||
else => return error.UnexpectedReadFailure,
|
||||
};
|
||||
}
|
||||
|
||||
/// Refills the read buffer with data from the connection.
|
||||
pub fn fill(conn: *Connection) ReadError!void {
|
||||
if (conn.read_end != conn.read_start) return;
|
||||
|
||||
var iovecs = [1]std.posix.iovec{
|
||||
.{ .base = &conn.read_buf, .len = conn.read_buf.len },
|
||||
};
|
||||
const nread = try conn.readvDirect(&iovecs);
|
||||
if (nread == 0) return error.EndOfStream;
|
||||
conn.read_start = 0;
|
||||
conn.read_end = @intCast(nread);
|
||||
}
|
||||
|
||||
/// Returns the current slice of buffered data.
|
||||
pub fn peek(conn: *Connection) []const u8 {
|
||||
return conn.read_buf[conn.read_start..conn.read_end];
|
||||
}
|
||||
|
||||
/// Discards the given number of bytes from the read buffer.
|
||||
pub fn drop(conn: *Connection, num: BufferSize) void {
|
||||
conn.read_start += num;
|
||||
}
|
||||
|
||||
/// Reads data from the connection into the given buffer.
|
||||
pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
|
||||
const available_read = conn.read_end - conn.read_start;
|
||||
const available_buffer = buffer.len;
|
||||
|
||||
if (available_read > available_buffer) { // partially read buffered data
|
||||
@memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]);
|
||||
conn.read_start += @intCast(available_buffer);
|
||||
|
||||
return available_buffer;
|
||||
} else if (available_read > 0) { // fully read buffered data
|
||||
@memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]);
|
||||
conn.read_start += available_read;
|
||||
|
||||
return available_read;
|
||||
fn destroy(plain: *Plain) void {
|
||||
const c = &plain.connection;
|
||||
const gpa = c.client.allocator;
|
||||
const base: [*]u8 = @ptrCast(plain);
|
||||
gpa.free(base[0..allocLen(c.client, c.host_len)]);
|
||||
}
|
||||
|
||||
var iovecs = [2]std.posix.iovec{
|
||||
.{ .base = buffer.ptr, .len = buffer.len },
|
||||
.{ .base = &conn.read_buf, .len = conn.read_buf.len },
|
||||
};
|
||||
const nread = try conn.readvDirect(&iovecs);
|
||||
|
||||
if (nread > buffer.len) {
|
||||
conn.read_start = 0;
|
||||
conn.read_end = @intCast(nread - buffer.len);
|
||||
return buffer.len;
|
||||
fn allocLen(client: *Client, host_len: usize) usize {
|
||||
return @sizeOf(Plain) + host_len + client.read_buffer_size + client.write_buffer_size;
|
||||
}
|
||||
|
||||
return nread;
|
||||
}
|
||||
|
||||
pub const ReadError = error{
|
||||
TlsFailure,
|
||||
TlsAlert,
|
||||
ConnectionTimedOut,
|
||||
ConnectionResetByPeer,
|
||||
UnexpectedReadFailure,
|
||||
EndOfStream,
|
||||
fn host(plain: *Plain) []u8 {
|
||||
const base: [*]u8 = @ptrCast(plain);
|
||||
return base[@sizeOf(Plain)..][0..plain.connection.host_len];
|
||||
}
|
||||
};
|
||||
|
||||
pub const Reader = std.io.Reader(*Connection, ReadError, read);
|
||||
const Tls = struct {
|
||||
/// Data from `client` to `Connection.stream`.
|
||||
writer: std.io.BufferedWriter,
|
||||
/// Data from `Connection.stream` to `client`.
|
||||
reader: std.io.BufferedReader,
|
||||
client: std.crypto.tls.Client,
|
||||
connection: Connection,
|
||||
|
||||
pub fn reader(conn: *Connection) Reader {
|
||||
return .{ .context = conn };
|
||||
}
|
||||
fn create(
|
||||
client: *Client,
|
||||
remote_host: []const u8,
|
||||
port: u16,
|
||||
stream: net.Stream,
|
||||
) error{ OutOfMemory, TlsInitializationFailed }!*Tls {
|
||||
const gpa = client.allocator;
|
||||
const alloc_len = allocLen(client, remote_host.len);
|
||||
const base = try gpa.alignedAlloc(u8, .of(Tls), alloc_len);
|
||||
errdefer gpa.free(base);
|
||||
const host_buffer = base[@sizeOf(Tls)..][0..remote_host.len];
|
||||
const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.tls_buffer_size];
|
||||
const tls_write_buffer = tls_read_buffer.ptr[tls_read_buffer.len..][0..client.tls_buffer_size];
|
||||
const socket_write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size];
|
||||
assert(base.ptr + alloc_len == socket_write_buffer.ptr + socket_write_buffer.len);
|
||||
@memcpy(host_buffer, remote_host);
|
||||
const tls: *Tls = @ptrCast(base);
|
||||
tls.* = .{
|
||||
.connection = .{
|
||||
.client = client,
|
||||
.stream = stream,
|
||||
.writer = tls.client.writer().buffered(socket_write_buffer),
|
||||
.pool_node = .{},
|
||||
.port = port,
|
||||
.proxied = false,
|
||||
.closing = false,
|
||||
.protocol = .tls,
|
||||
},
|
||||
.writer = stream.writer().buffered(tls_write_buffer),
|
||||
.reader = undefined,
|
||||
.client = undefined,
|
||||
};
|
||||
tls.reader.init(stream.reader(), tls_read_buffer);
|
||||
// TODO data race here on ca_bundle if the user sets next_https_rescan_certs to true
|
||||
tls.client.init(&tls.reader, &tls.writer, .{
|
||||
.host = .{ .explicit = remote_host },
|
||||
.ca = .{ .bundle = client.ca_bundle },
|
||||
.ssl_key_logger = client.ssl_key_logger,
|
||||
}) catch return error.TlsInitializationFailed;
|
||||
// This is appropriate for HTTPS because the HTTP headers contain
|
||||
// the content length which is used to detect truncation attacks.
|
||||
tls.client.allow_truncation_attacks = true;
|
||||
|
||||
pub const WriteError = error{
|
||||
ConnectionResetByPeer,
|
||||
UnexpectedWriteFailure,
|
||||
};
|
||||
|
||||
pub fn close(conn: *Connection, allocator: Allocator) void {
|
||||
if (conn.protocol == .tls) {
|
||||
if (disable_tls) unreachable;
|
||||
|
||||
// try to cleanly close the TLS connection, for any server that cares.
|
||||
_ = conn.tls_client.writeEnd("", true) catch {};
|
||||
if (conn.tls_client.ssl_key_log) |key_log| key_log.file.close();
|
||||
allocator.destroy(conn.tls_client);
|
||||
return tls;
|
||||
}
|
||||
|
||||
conn.stream.close();
|
||||
allocator.free(conn.host);
|
||||
fn destroy(tls: *Tls, gpa: Allocator) void {
|
||||
const c = &tls.connection;
|
||||
const base: [*]u8 = @ptrCast(tls);
|
||||
gpa.free(base[0..allocLen(c.client, c.host_len)]);
|
||||
}
|
||||
|
||||
fn allocLen(client: *Client, host_len: usize) usize {
|
||||
return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size + client.write_buffer_size;
|
||||
}
|
||||
|
||||
fn host(tls: *Tls) []u8 {
|
||||
const base: [*]u8 = @ptrCast(tls);
|
||||
return base[@sizeOf(Tls)..][0..tls.connection.host_len];
|
||||
}
|
||||
};
|
||||
|
||||
fn host(c: *Connection) []u8 {
|
||||
return switch (c.protocol) {
|
||||
.tls => {
|
||||
if (disable_tls) unreachable;
|
||||
const tls: *Tls = @fieldParentPtr("connection", c);
|
||||
return tls.host();
|
||||
},
|
||||
.plain => {
|
||||
const plain: *Plain = @fieldParentPtr("connection", c);
|
||||
return plain.host();
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/// This is either data from `stream`, or `Tls.client`.
|
||||
fn reader(c: *Connection) *std.io.BufferedReader {
|
||||
return switch (c.protocol) {
|
||||
.tls => {
|
||||
if (disable_tls) unreachable;
|
||||
const tls: *Tls = @fieldParentPtr("connection", c);
|
||||
return &tls.client.reader;
|
||||
},
|
||||
.plain => {
|
||||
const plain: *Plain = @fieldParentPtr("connection", c);
|
||||
return &plain.reader;
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/// If this is called without calling `flush` or `end`, data will be
|
||||
/// dropped unsent.
|
||||
pub fn destroy(c: *Connection) void {
|
||||
c.stream.close();
|
||||
switch (c.protocol) {
|
||||
.tls => {
|
||||
if (disable_tls) unreachable;
|
||||
const tls: *Tls = @fieldParentPtr("connection", c);
|
||||
tls.destroy();
|
||||
},
|
||||
.plain => {
|
||||
const plain: *Plain = @fieldParentPtr("connection", c);
|
||||
plain.destroy();
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn flush(c: *Connection) anyerror!void {
|
||||
try c.writer.flush();
|
||||
if (c.protocol == .tls) {
|
||||
if (disable_tls) unreachable;
|
||||
const tls: *Tls = @fieldParentPtr("connection", c);
|
||||
try tls.writer.flush();
|
||||
}
|
||||
}
|
||||
|
||||
/// If the connection is a TLS connection, sends the close_notify alert.
|
||||
///
|
||||
/// Flushes all buffers.
|
||||
pub fn end(c: *Connection) anyerror!void {
|
||||
try c.writer.flush();
|
||||
if (c.protocol == .tls) {
|
||||
if (disable_tls) unreachable;
|
||||
const tls: *Tls = @fieldParentPtr("connection", c);
|
||||
try tls.client.end();
|
||||
try tls.writer.flush();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -350,10 +418,10 @@ pub const RequestTransfer = union(enum) {
|
||||
|
||||
/// The decompressor for response messages.
|
||||
pub const Compression = union(enum) {
|
||||
pub const DeflateDecompressor = std.compress.zlib.Decompressor(Request.TransferReader);
|
||||
pub const GzipDecompressor = std.compress.gzip.Decompressor(Request.TransferReader);
|
||||
pub const DeflateDecompressor = std.compress.zlib.Decompressor;
|
||||
pub const GzipDecompressor = std.compress.gzip.Decompressor;
|
||||
// https://github.com/ziglang/zig/issues/18937
|
||||
//pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.TransferReader, .{});
|
||||
//pub const ZstdDecompressor = std.compress.zstd.DecompressStream(.{});
|
||||
|
||||
deflate: DeflateDecompressor,
|
||||
gzip: GzipDecompressor,
|
||||
@ -617,9 +685,6 @@ pub const Response = struct {
|
||||
}
|
||||
};
|
||||
|
||||
/// A HTTP request that has been sent.
|
||||
///
|
||||
/// Order of operations: open -> send[ -> write -> finish] -> wait -> read
|
||||
pub const Request = struct {
|
||||
uri: Uri,
|
||||
client: *Client,
|
||||
@ -1300,24 +1365,34 @@ pub const basic_authorization = struct {
|
||||
}
|
||||
};
|
||||
|
||||
pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed };
|
||||
pub const ConnectTcpError = Allocator.Error || error{
|
||||
ConnectionRefused,
|
||||
NetworkUnreachable,
|
||||
ConnectionTimedOut,
|
||||
ConnectionResetByPeer,
|
||||
TemporaryNameServerFailure,
|
||||
NameServerFailure,
|
||||
UnknownHostName,
|
||||
HostLacksNetworkAddresses,
|
||||
UnexpectedConnectFailure,
|
||||
TlsInitializationFailed,
|
||||
};
|
||||
|
||||
/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open.
|
||||
/// Reuses a `Connection` if one matching `host` and `port` is already open.
|
||||
///
|
||||
/// This function is threadsafe.
|
||||
pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*Connection {
|
||||
/// Threadsafe.
|
||||
pub fn connectTcp(
|
||||
client: *Client,
|
||||
host: []const u8,
|
||||
port: u16,
|
||||
protocol: Connection.Protocol,
|
||||
) ConnectTcpError!*Connection {
|
||||
if (client.connection_pool.findConnection(.{
|
||||
.host = host,
|
||||
.port = port,
|
||||
.protocol = protocol,
|
||||
})) |conn| return conn;
|
||||
|
||||
if (disable_tls and protocol == .tls)
|
||||
return error.TlsInitializationFailed;
|
||||
|
||||
const conn = try client.allocator.create(Connection);
|
||||
errdefer client.allocator.destroy(conn);
|
||||
|
||||
const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) {
|
||||
error.ConnectionRefused => return error.ConnectionRefused,
|
||||
error.NetworkUnreachable => return error.NetworkUnreachable,
|
||||
@ -1331,77 +1406,19 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec
|
||||
};
|
||||
errdefer stream.close();
|
||||
|
||||
conn.* = .{
|
||||
.stream = stream,
|
||||
.stream_writer = undefined,
|
||||
.tls_client = undefined,
|
||||
.read_buf = undefined,
|
||||
|
||||
.write_buffer = undefined,
|
||||
.writer = undefined, // populated below
|
||||
|
||||
.protocol = protocol,
|
||||
.host = try client.allocator.dupe(u8, host),
|
||||
.port = port,
|
||||
|
||||
.pool_node = .{},
|
||||
};
|
||||
errdefer client.allocator.free(conn.host);
|
||||
|
||||
switch (protocol) {
|
||||
.tls => {
|
||||
if (disable_tls) unreachable;
|
||||
|
||||
const tls_client = try client.allocator.create(std.crypto.tls.Client);
|
||||
errdefer client.allocator.destroy(tls_client);
|
||||
|
||||
const ssl_key_log_file: ?std.fs.File = if (std.options.http_enable_ssl_key_log_file) ssl_key_log_file: {
|
||||
const ssl_key_log_path = std.process.getEnvVarOwned(client.allocator, "SSLKEYLOGFILE") catch |err| switch (err) {
|
||||
error.EnvironmentVariableNotFound, error.InvalidWtf8 => break :ssl_key_log_file null,
|
||||
error.OutOfMemory => return error.OutOfMemory,
|
||||
};
|
||||
defer client.allocator.free(ssl_key_log_path);
|
||||
break :ssl_key_log_file std.fs.cwd().createFile(ssl_key_log_path, .{
|
||||
.truncate = false,
|
||||
.mode = switch (builtin.os.tag) {
|
||||
.windows, .wasi => 0,
|
||||
else => 0o600,
|
||||
},
|
||||
}) catch null;
|
||||
} else null;
|
||||
errdefer if (ssl_key_log_file) |key_log_file| key_log_file.close();
|
||||
|
||||
conn.stream_writer = .{
|
||||
.unbuffered_writer = stream.writer(),
|
||||
.buffer = &.{},
|
||||
};
|
||||
|
||||
tls_client.* = std.crypto.tls.Client.init(stream, &conn.stream_writer, .{
|
||||
.host = .{ .explicit = host },
|
||||
.ca = .{ .bundle = client.ca_bundle },
|
||||
.ssl_key_log_file = ssl_key_log_file,
|
||||
}) catch return error.TlsInitializationFailed;
|
||||
// This is appropriate for HTTPS because the HTTP headers contain
|
||||
// the content length which is used to detect truncation attacks.
|
||||
tls_client.allow_truncation_attacks = true;
|
||||
|
||||
conn.writer = .{
|
||||
.unbuffered_writer = tls_client.writer(),
|
||||
.buffer = &conn.write_buffer,
|
||||
};
|
||||
conn.tls_client = tls_client;
|
||||
if (disable_tls) return error.TlsInitializationFailed;
|
||||
const tc = try Connection.Tls.create(client, host, port, stream);
|
||||
client.connection_pool.addUsed(&tc.connection);
|
||||
return &tc.connection;
|
||||
},
|
||||
.plain => {
|
||||
conn.writer = .{
|
||||
.unbuffered_writer = stream.writer(),
|
||||
.buffer = &conn.write_buffer,
|
||||
};
|
||||
const pc = try Connection.Plain.create(client, host, port, stream);
|
||||
client.connection_pool.addUsed(&pc.connection);
|
||||
return &pc.connection;
|
||||
},
|
||||
}
|
||||
|
||||
client.connection_pool.addUsed(conn);
|
||||
|
||||
return conn;
|
||||
}
|
||||
|
||||
pub const ConnectUnixError = Allocator.Error || std.posix.SocketError || error{NameTooLong} || std.posix.ConnectError;
|
||||
@ -1662,16 +1679,17 @@ pub fn open(
|
||||
var server_header: std.heap.FixedBufferAllocator = .init(options.server_header_buffer);
|
||||
const protocol, const valid_uri = try validateUri(uri, server_header.allocator());
|
||||
|
||||
if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) {
|
||||
if (protocol == .tls) {
|
||||
if (disable_tls) unreachable;
|
||||
if (@atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) {
|
||||
client.ca_bundle_mutex.lock();
|
||||
defer client.ca_bundle_mutex.unlock();
|
||||
|
||||
client.ca_bundle_mutex.lock();
|
||||
defer client.ca_bundle_mutex.unlock();
|
||||
|
||||
if (client.next_https_rescan_certs) {
|
||||
client.ca_bundle.rescan(client.allocator) catch
|
||||
return error.CertificateBundleLoadFailure;
|
||||
@atomicStore(bool, &client.next_https_rescan_certs, false, .release);
|
||||
if (client.next_https_rescan_certs) {
|
||||
client.ca_bundle.rescan(client.allocator) catch
|
||||
return error.CertificateBundleLoadFailure;
|
||||
@atomicStore(bool, &client.next_https_rescan_certs, false, .release);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,18 +1,25 @@
|
||||
//! Blocking HTTP server implementation.
|
||||
//! Handles a single connection's lifecycle.
|
||||
|
||||
connection: net.Server.Connection,
|
||||
const std = @import("../std.zig");
|
||||
const http = std.http;
|
||||
const mem = std.mem;
|
||||
const net = std.net;
|
||||
const Uri = std.Uri;
|
||||
const assert = std.debug.assert;
|
||||
const testing = std.testing;
|
||||
|
||||
const Server = @This();
|
||||
|
||||
/// The reader's buffer must be large enough to store the client's entire HTTP
|
||||
/// header, otherwise `receiveHead` returns `error.HttpHeadersOversize`.
|
||||
in: *std.io.BufferedReader,
|
||||
out: *std.io.BufferedWriter,
|
||||
/// Keeps track of whether the Server is ready to accept a new request on the
|
||||
/// same connection, and makes invalid API usage cause assertion failures
|
||||
/// rather than HTTP protocol violations.
|
||||
state: State,
|
||||
/// User-provided buffer that must outlive this Server.
|
||||
/// Used to store the client's entire HTTP header.
|
||||
read_buffer: []u8,
|
||||
/// Amount of available data inside read_buffer.
|
||||
read_buffer_len: usize,
|
||||
/// Index into `read_buffer` of the first byte of the next HTTP request.
|
||||
next_request_start: usize,
|
||||
in_err: anyerror,
|
||||
|
||||
pub const State = enum {
|
||||
/// The connection is available to be used for the first time, or reused.
|
||||
@ -31,14 +38,13 @@ pub const State = enum {
|
||||
|
||||
/// Initialize an HTTP server that can respond to multiple requests on the same
|
||||
/// connection.
|
||||
///
|
||||
/// The returned `Server` is ready for `receiveHead` to be called.
|
||||
pub fn init(connection: net.Server.Connection, read_buffer: []u8) Server {
|
||||
pub fn init(in: *std.io.BufferedReader, out: *std.io.BufferedWriter) Server {
|
||||
return .{
|
||||
.connection = connection,
|
||||
.in = in,
|
||||
.out = out,
|
||||
.state = .ready,
|
||||
.read_buffer = read_buffer,
|
||||
.read_buffer_len = 0,
|
||||
.next_request_start = 0,
|
||||
};
|
||||
}
|
||||
|
||||
@ -48,78 +54,55 @@ pub const ReceiveHeadError = error{
|
||||
/// before closing the connection.
|
||||
HttpHeadersOversize,
|
||||
/// Client sent headers that did not conform to the HTTP protocol.
|
||||
/// `in_err` is populated with a `Request.Head.ParseError`.
|
||||
HttpHeadersInvalid,
|
||||
/// A low level I/O error occurred trying to read the headers.
|
||||
HttpHeadersUnreadable,
|
||||
/// Partial HTTP request was received but the connection was closed before
|
||||
/// fully receiving the headers.
|
||||
HttpRequestTruncated,
|
||||
/// The client sent 0 bytes of headers before closing the stream.
|
||||
/// In other words, a keep-alive connection was finally closed.
|
||||
HttpConnectionClosing,
|
||||
/// Error occurred reading from `in`; `in_err` is populated.
|
||||
ReadFailure,
|
||||
};
|
||||
|
||||
/// The header bytes reference the read buffer that Server was initialized with
|
||||
/// and remain alive until the next call to receiveHead.
|
||||
/// The header bytes reference the internal storage of `in`, which are
|
||||
/// invalidated with the next call to `receiveHead`.
|
||||
pub fn receiveHead(s: *Server) ReceiveHeadError!Request {
|
||||
assert(s.state == .ready);
|
||||
s.state = .received_head;
|
||||
errdefer s.state = .receiving_head;
|
||||
|
||||
// In case of a reused connection, move the next request's bytes to the
|
||||
// beginning of the buffer.
|
||||
if (s.next_request_start > 0) {
|
||||
if (s.read_buffer_len > s.next_request_start) {
|
||||
rebase(s, 0);
|
||||
} else {
|
||||
s.read_buffer_len = 0;
|
||||
}
|
||||
}
|
||||
|
||||
const in = &s.in;
|
||||
var hp: http.HeadParser = .{};
|
||||
|
||||
if (s.read_buffer_len > 0) {
|
||||
const bytes = s.read_buffer[0..s.read_buffer_len];
|
||||
const end = hp.feed(bytes);
|
||||
if (hp.state == .finished)
|
||||
return finishReceivingHead(s, end);
|
||||
}
|
||||
var head_end: usize = 0;
|
||||
|
||||
while (true) {
|
||||
const buf = s.read_buffer[s.read_buffer_len..];
|
||||
if (buf.len == 0)
|
||||
return error.HttpHeadersOversize;
|
||||
const read_n = s.connection.stream.read(buf) catch
|
||||
return error.HttpHeadersUnreadable;
|
||||
if (read_n == 0) {
|
||||
if (s.read_buffer_len > 0) {
|
||||
return error.HttpRequestTruncated;
|
||||
} else {
|
||||
return error.HttpConnectionClosing;
|
||||
}
|
||||
}
|
||||
s.read_buffer_len += read_n;
|
||||
const bytes = buf[0..read_n];
|
||||
const end = hp.feed(bytes);
|
||||
if (hp.state == .finished)
|
||||
return finishReceivingHead(s, s.read_buffer_len - bytes.len + end);
|
||||
if (head_end >= in.bufferContents().len) return error.HttpHeadersOversize;
|
||||
const buf = (in.peekGreedy(head_end + 1) catch |err| {
|
||||
s.in_err = err;
|
||||
return error.ReadFailure;
|
||||
}) orelse switch (head_end) {
|
||||
0 => return error.HttpConnectionClosing,
|
||||
else => return error.HttpRequestTruncated,
|
||||
};
|
||||
head_end += hp.feed(buf[head_end..]);
|
||||
if (hp.state == .finished) return .{
|
||||
.server = s,
|
||||
.head_end = head_end,
|
||||
.head = Request.Head.parse(buf[0..head_end]) catch |err| {
|
||||
s.in_err = err;
|
||||
return error.HttpHeadersInvalid;
|
||||
},
|
||||
.reader_state = undefined,
|
||||
.write_error = undefined,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn finishReceivingHead(s: *Server, head_end: usize) ReceiveHeadError!Request {
|
||||
return .{
|
||||
.server = s,
|
||||
.head_end = head_end,
|
||||
.head = Request.Head.parse(s.read_buffer[0..head_end]) catch
|
||||
return error.HttpHeadersInvalid,
|
||||
.reader_state = undefined,
|
||||
.write_error = undefined,
|
||||
};
|
||||
}
|
||||
|
||||
pub const Request = struct {
|
||||
server: *Server,
|
||||
/// Index into Server's read_buffer.
|
||||
/// Index into `Server.in` internal buffer.
|
||||
head_end: usize,
|
||||
head: Head,
|
||||
reader_state: union {
|
||||
@ -299,7 +282,7 @@ pub const Request = struct {
|
||||
};
|
||||
|
||||
pub fn iterateHeaders(r: *Request) http.HeaderIterator {
|
||||
return http.HeaderIterator.init(r.server.read_buffer[0..r.head_end]);
|
||||
return http.HeaderIterator.init(r.in.bufferContents()[0..r.head_end]);
|
||||
}
|
||||
|
||||
test iterateHeaders {
|
||||
@ -312,13 +295,14 @@ pub const Request = struct {
|
||||
|
||||
var read_buffer: [500]u8 = undefined;
|
||||
@memcpy(read_buffer[0..request_bytes.len], request_bytes);
|
||||
var br: std.io.BufferedReader = undefined;
|
||||
br.initFixed(&read_buffer);
|
||||
|
||||
var server: Server = .{
|
||||
.connection = undefined,
|
||||
.in = &br,
|
||||
.out = undefined,
|
||||
.state = .ready,
|
||||
.read_buffer = &read_buffer,
|
||||
.read_buffer_len = request_bytes.len,
|
||||
.next_request_start = 0,
|
||||
.in_err = undefined,
|
||||
};
|
||||
|
||||
var request: Request = .{
|
||||
@ -1158,13 +1142,3 @@ fn rebase(s: *Server, index: usize) void {
|
||||
}
|
||||
s.read_buffer_len = index + leftover.len;
|
||||
}
|
||||
|
||||
const std = @import("../std.zig");
|
||||
const http = std.http;
|
||||
const mem = std.mem;
|
||||
const net = std.net;
|
||||
const Uri = std.Uri;
|
||||
const assert = std.debug.assert;
|
||||
const testing = std.testing;
|
||||
|
||||
const Server = @This();
|
||||
|
||||
@ -94,6 +94,11 @@ pub fn storageBuffer(br: *BufferedReader) []u8 {
|
||||
return storage.buffer;
|
||||
}
|
||||
|
||||
pub fn bufferContents(br: *BufferedReader) []u8 {
|
||||
const storage = &br.storage;
|
||||
return storage.buffer[br.seek..storage.end];
|
||||
}
|
||||
|
||||
/// Although `BufferedReader` can easily satisfy the `Reader` interface, it's
|
||||
/// generally more practical to pass a `BufferedReader` instance itself around,
|
||||
/// since it will result in fewer calls across vtable boundaries.
|
||||
@ -159,31 +164,69 @@ pub fn seekForwardBy(br: *BufferedReader, seek_by: u64) anyerror!void {
|
||||
/// is returned instead.
|
||||
///
|
||||
/// See also:
|
||||
/// * `peekAll`
|
||||
/// * `peekGreedy`
|
||||
/// * `toss`
|
||||
pub fn peek(br: *BufferedReader, n: usize) anyerror![]u8 {
|
||||
return (try br.peekAll(n))[0..n];
|
||||
return (try br.peekGreedy(n))[0..n];
|
||||
}
|
||||
|
||||
/// Returns the next buffered bytes from `unbuffered_reader`, after filling the buffer
|
||||
/// with at least `n` bytes.
|
||||
/// Returns the next `n` bytes from `unbuffered_reader`, filling the buffer as
|
||||
/// necessary.
|
||||
///
|
||||
/// Invalidates previously returned values from `peek`.
|
||||
///
|
||||
/// Asserts that the `BufferedReader` was initialized with a buffer capacity at
|
||||
/// least as big as `n`.
|
||||
///
|
||||
/// If there are fewer than `n` bytes left in the stream, `null` is returned
|
||||
/// instead.
|
||||
///
|
||||
/// See also:
|
||||
/// * `peekGreedy`
|
||||
/// * `toss`
|
||||
pub fn peek2(br: *BufferedReader, n: usize) anyerror!?[]u8 {
|
||||
if (try br.peekGreedy(n)) |buf| return buf[0..n];
|
||||
return null;
|
||||
}
|
||||
|
||||
/// Returns all the next buffered bytes from `unbuffered_reader`, after filling
|
||||
/// the buffer to ensure it contains at least `n` bytes.
|
||||
///
|
||||
/// Invalidates previously returned values from `peek` and `peekGreedy`.
|
||||
///
|
||||
/// Asserts that the `BufferedReader` was initialized with a buffer capacity at
|
||||
/// least as big as `n`.
|
||||
///
|
||||
/// If there are fewer than `n` bytes left in the stream, `error.EndOfStream`
|
||||
/// is returned instead.
|
||||
///
|
||||
/// See also:
|
||||
/// * `peek`
|
||||
/// * `toss`
|
||||
pub fn peekAll(br: *BufferedReader, n: usize) anyerror![]u8 {
|
||||
const storage = &br.storage;
|
||||
assert(n <= storage.buffer.len);
|
||||
try br.fill(n);
|
||||
return storage.buffer[br.seek..storage.end];
|
||||
pub fn peekGreedy(br: *BufferedReader, n: usize) anyerror![]u8 {
|
||||
assert(n <= br.storage.buffer.len);
|
||||
if (try br.fill(n)) return br.bufferContents();
|
||||
return error.EndOfStream;
|
||||
}
|
||||
|
||||
/// Returns all the next buffered bytes from `unbuffered_reader`, after filling
|
||||
/// the buffer to ensure it contains at least `n` bytes.
|
||||
///
|
||||
/// Invalidates previously returned values from `peek` and `peekGreedy`.
|
||||
///
|
||||
/// Asserts that the `BufferedReader` was initialized with a buffer capacity at
|
||||
/// least as big as `n`.
|
||||
///
|
||||
/// If there are fewer than `n` bytes left in the stream, `null` is returned
|
||||
/// instead.
|
||||
///
|
||||
/// See also:
|
||||
/// * `peek`
|
||||
/// * `toss`
|
||||
pub fn peekGreedy2(br: *BufferedReader, n: usize) anyerror!?[]u8 {
|
||||
assert(n <= br.storage.buffer.len);
|
||||
if (try br.fill(n)) return br.bufferContents();
|
||||
return null;
|
||||
}
|
||||
|
||||
/// Skips the next `n` bytes from the stream, advancing the seek position. This
|
||||
@ -505,17 +548,17 @@ pub fn discardDelimiterInclusive(br: *BufferedReader, delimiter: u8) anyerror!vo
|
||||
/// Fills the buffer such that it contains at least `n` bytes, without
|
||||
/// advancing the seek position.
|
||||
///
|
||||
/// Returns `error.EndOfStream` if there are fewer than `n` bytes remaining.
|
||||
/// Returns `false` if and only if there are fewer than `n` bytes remaining.
|
||||
///
|
||||
/// Asserts buffer capacity is at least `n`.
|
||||
pub fn fill(br: *BufferedReader, n: usize) anyerror!void {
|
||||
pub fn fill(br: *BufferedReader, n: usize) anyerror!bool {
|
||||
const storage = &br.storage;
|
||||
assert(n <= storage.buffer.len);
|
||||
const buffer = storage.buffer[0..storage.end];
|
||||
const seek = br.seek;
|
||||
if (seek + n <= buffer.len) {
|
||||
@branchHint(.likely);
|
||||
return;
|
||||
return true;
|
||||
}
|
||||
const remainder = buffer[seek..];
|
||||
std.mem.copyForwards(u8, buffer[0..remainder.len], remainder);
|
||||
@ -523,8 +566,8 @@ pub fn fill(br: *BufferedReader, n: usize) anyerror!void {
|
||||
br.seek = 0;
|
||||
while (true) {
|
||||
const status = try br.unbuffered_reader.read(storage, .unlimited);
|
||||
if (n <= storage.end) return;
|
||||
if (status.end) return error.EndOfStream;
|
||||
if (n <= storage.end) return true;
|
||||
if (status.end) return false;
|
||||
}
|
||||
}
|
||||
|
||||
@ -535,7 +578,8 @@ pub fn takeByte(br: *BufferedReader) anyerror!u8 {
|
||||
const seek = br.seek;
|
||||
if (seek >= buffer.len) {
|
||||
@branchHint(.unlikely);
|
||||
try br.fill(1);
|
||||
const filled = try fill(br, 1);
|
||||
if (!filled) return error.EndOfStream;
|
||||
}
|
||||
br.seek = seek + 1;
|
||||
return buffer[seek];
|
||||
@ -603,7 +647,7 @@ fn takeMultipleOf7Leb128(br: *BufferedReader, comptime Result: type) anyerror!Re
|
||||
var result: UnsignedResult = 0;
|
||||
var fits = true;
|
||||
while (true) {
|
||||
const buffer: []const packed struct(u8) { bits: u7, more: bool } = @ptrCast(try br.peekAll(1));
|
||||
const buffer: []const packed struct(u8) { bits: u7, more: bool } = @ptrCast(try br.peekGreedy(1));
|
||||
for (buffer, 1..) |byte, len| {
|
||||
if (remaining_bits > 0) {
|
||||
result = @shlExact(@as(UnsignedResult, byte.bits), result_info.bits - 7) |
|
||||
@ -639,7 +683,7 @@ test peek {
|
||||
return error.Unimplemented;
|
||||
}
|
||||
|
||||
test peekAll {
|
||||
test peekGreedy {
|
||||
return error.Unimplemented;
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user