mirror of
https://github.com/ziglang/zig.git
synced 2026-01-25 16:55:22 +00:00
std.io: start removing context from Reader/Writer
rely on the field parent pointer pattern
This commit is contained in:
parent
7d6b5ed510
commit
d4545f216a
@ -21,11 +21,15 @@ const array = tls.array;
|
||||
/// here via `reader`.
|
||||
///
|
||||
/// The buffer is asserted to have capacity at least `min_buffer_len`.
|
||||
input: *std.io.Reader,
|
||||
input: *Reader,
|
||||
/// Decrypted stream from the server to the client.
|
||||
reader: Reader,
|
||||
|
||||
/// The encrypted stream from the client to the server. Bytes are pushed here
|
||||
/// via `writer`.
|
||||
output: *Writer,
|
||||
/// The plaintext stream from the client to the server.
|
||||
writer: Writer,
|
||||
|
||||
/// Populated when `error.TlsAlert` is returned.
|
||||
alert: ?tls.Alert = null,
|
||||
@ -36,14 +40,6 @@ write_seq: u64,
|
||||
/// When this is true, the stream may still not be at the end because there
|
||||
/// may be data in the input buffer.
|
||||
received_close_notify: bool,
|
||||
/// By default, reaching the end-of-stream when reading from the server will
|
||||
/// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify
|
||||
/// message has been received. By setting this flag to `true`, instead, the
|
||||
/// end-of-stream will be forwarded to the application layer above TLS.
|
||||
///
|
||||
/// This makes the application vulnerable to truncation attacks unless the
|
||||
/// application layer itself verifies that the amount of data received equals
|
||||
/// the amount of data expected, such as HTTP with the Content-Length header.
|
||||
allow_truncation_attacks: bool,
|
||||
application_cipher: tls.ApplicationCipher,
|
||||
|
||||
@ -85,7 +81,7 @@ pub const SslKeyLog = struct {
|
||||
}
|
||||
};
|
||||
|
||||
/// The `std.io.Reader` supplied to `init` requires a buffer capacity
|
||||
/// The `Reader` supplied to `init` requires a buffer capacity
|
||||
/// at least this amount.
|
||||
pub const min_buffer_len = tls.max_ciphertext_record_len;
|
||||
|
||||
@ -116,6 +112,20 @@ pub const Options = struct {
|
||||
/// Only the `writer` field is observed during the handshake (`init`).
|
||||
/// After that, the other fields are populated.
|
||||
ssl_key_log: ?*SslKeyLog = null,
|
||||
/// By default, reaching the end-of-stream when reading from the server will
|
||||
/// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify
|
||||
/// message has been received. By setting this flag to `true`, instead, the
|
||||
/// end-of-stream will be forwarded to the application layer above TLS.
|
||||
///
|
||||
/// This makes the application vulnerable to truncation attacks unless the
|
||||
/// application layer itself verifies that the amount of data received equals
|
||||
/// the amount of data expected, such as HTTP with the Content-Length header.
|
||||
allow_truncation_attacks: bool = false,
|
||||
write_buffer: []u8,
|
||||
/// Asserted to have capacity at least `min_buffer_len`.
|
||||
read_buffer: []u8,
|
||||
/// Populated when `error.TlsAlert` is returned from `init`.
|
||||
alert: ?*tls.Alert = null,
|
||||
};
|
||||
|
||||
const InitError = error{
|
||||
@ -173,14 +183,8 @@ const InitError = error{
|
||||
/// `host` is only borrowed during this function call.
|
||||
///
|
||||
/// `input` is asserted to have buffer capacity at least `min_buffer_len`.
|
||||
pub fn init(
|
||||
client: *Client,
|
||||
input: *std.io.Reader,
|
||||
output: *Writer,
|
||||
options: Options,
|
||||
) InitError!void {
|
||||
pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client {
|
||||
assert(input.buffer.len >= min_buffer_len);
|
||||
client.alert = null;
|
||||
const host = switch (options.host) {
|
||||
.no_verification => "",
|
||||
.explicit => |host| host,
|
||||
@ -411,7 +415,7 @@ pub fn init(
|
||||
switch (ct) {
|
||||
.alert => {
|
||||
ctd.ensure(2) catch continue :fragment;
|
||||
client.alert = .{
|
||||
if (options.alert) |a| a.* = .{
|
||||
.level = ctd.decode(tls.Alert.Level),
|
||||
.description = ctd.decode(tls.Alert.Description),
|
||||
};
|
||||
@ -852,9 +856,28 @@ pub fn init(
|
||||
else => unreachable,
|
||||
},
|
||||
};
|
||||
client.* = .{
|
||||
if (options.ssl_key_log) |ssl_key_log| ssl_key_log.* = .{
|
||||
.client_key_seq = key_seq,
|
||||
.server_key_seq = key_seq,
|
||||
.client_random = client_hello_rand,
|
||||
.writer = ssl_key_log.writer,
|
||||
};
|
||||
return .{
|
||||
.input = input,
|
||||
.reader = .{
|
||||
.buffer = options.read_buffer,
|
||||
.vtable = &.{ .stream = stream },
|
||||
.seek = 0,
|
||||
.end = 0,
|
||||
},
|
||||
.output = output,
|
||||
.writer = .{
|
||||
.buffer = options.write_buffer,
|
||||
.vtable = &.{
|
||||
.drain = drain,
|
||||
.sendFile = Writer.unimplementedSendFile,
|
||||
},
|
||||
},
|
||||
.tls_version = tls_version,
|
||||
.read_seq = switch (tls_version) {
|
||||
.tls_1_3 => 0,
|
||||
@ -867,17 +890,10 @@ pub fn init(
|
||||
else => unreachable,
|
||||
},
|
||||
.received_close_notify = false,
|
||||
.allow_truncation_attacks = false,
|
||||
.allow_truncation_attacks = options.allow_truncation_attacks,
|
||||
.application_cipher = app_cipher,
|
||||
.ssl_key_log = options.ssl_key_log,
|
||||
};
|
||||
if (options.ssl_key_log) |ssl_key_log| ssl_key_log.* = .{
|
||||
.client_key_seq = key_seq,
|
||||
.server_key_seq = key_seq,
|
||||
.client_random = client_hello_rand,
|
||||
.writer = ssl_key_log.writer,
|
||||
};
|
||||
return;
|
||||
},
|
||||
else => return error.TlsUnexpectedMessage,
|
||||
}
|
||||
@ -891,25 +907,9 @@ pub fn init(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reader(c: *Client) Reader {
|
||||
return .{
|
||||
.context = c,
|
||||
.vtable = &.{ .read = read },
|
||||
};
|
||||
}
|
||||
|
||||
pub fn writer(c: *Client) Writer {
|
||||
return .{
|
||||
.context = c,
|
||||
.vtable = &.{
|
||||
.writeSplat = writeSplat,
|
||||
.writeFile = Writer.unimplementedWriteFile,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
fn writeSplat(context: ?*anyopaque, data: []const []const u8, splat: usize) Writer.Error!usize {
|
||||
const c: *Client = @alignCast(@ptrCast(context));
|
||||
fn drain(w: *Writer, data: []const []const u8, splat: usize) Writer.Error!usize {
|
||||
const c: *Client = @fieldParentPtr("writer", w);
|
||||
if (true) @panic("update to use the buffer and flush");
|
||||
const sliced_data = if (splat == 0) data[0..data.len -| 1] else data;
|
||||
const output = c.output;
|
||||
const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
|
||||
@ -1043,8 +1043,8 @@ pub fn eof(c: Client) bool {
|
||||
return c.received_close_notify;
|
||||
}
|
||||
|
||||
fn read(context: ?*anyopaque, bw: *Writer, limit: std.io.Limit) Reader.StreamError!usize {
|
||||
const c: *Client = @ptrCast(@alignCast(context));
|
||||
fn stream(r: *Reader, w: *Writer, limit: std.io.Limit) Reader.StreamError!usize {
|
||||
const c: *Client = @fieldParentPtr("reader", r);
|
||||
if (c.eof()) return error.EndOfStream;
|
||||
const input = c.input;
|
||||
// If at least one full encrypted record is not buffered, read once.
|
||||
@ -1214,7 +1214,7 @@ fn read(context: ?*anyopaque, bw: *Writer, limit: std.io.Limit) Reader.StreamErr
|
||||
},
|
||||
.application_data => {
|
||||
if (@intFromEnum(limit) < cleartext.len) return failRead(c, error.OutputBufferUndersize);
|
||||
try bw.writeAll(cleartext);
|
||||
try w.writeAll(cleartext);
|
||||
return cleartext.len;
|
||||
},
|
||||
else => return failRead(c, error.TlsUnexpectedMessage),
|
||||
@ -1226,8 +1226,8 @@ fn failRead(c: *Client, err: ReadError) error{ReadFailed} {
|
||||
return error.ReadFailed;
|
||||
}
|
||||
|
||||
fn logSecrets(bw: *Writer, context: anytype, secrets: anytype) void {
|
||||
inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| bw.print("{s}" ++
|
||||
fn logSecrets(w: *Writer, context: anytype, secrets: anytype) void {
|
||||
inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| w.print("{s}" ++
|
||||
(if (@hasField(@TypeOf(context), "counter")) "_{d}" else "") ++ " {x} {x}\n", .{field.name} ++
|
||||
(if (@hasField(@TypeOf(context), "counter")) .{context.counter} else .{}) ++ .{
|
||||
context.client_random,
|
||||
|
||||
@ -943,7 +943,6 @@ pub const Reader = struct {
|
||||
|
||||
pub fn initInterface(buffer: []u8) std.io.Reader {
|
||||
return .{
|
||||
.context = undefined,
|
||||
.vtable = &.{
|
||||
.stream = Reader.stream,
|
||||
.discard = Reader.discard,
|
||||
@ -1291,7 +1290,6 @@ pub const Writer = struct {
|
||||
|
||||
pub fn initInterface(buffer: []u8) std.io.Writer {
|
||||
return .{
|
||||
.context = undefined,
|
||||
.vtable = &.{
|
||||
.drain = drain,
|
||||
.sendFile = sendFile,
|
||||
|
||||
107
lib/std/http.zig
107
lib/std/http.zig
@ -328,6 +328,7 @@ pub const Header = struct {
|
||||
|
||||
pub const Reader = struct {
|
||||
in: *std.io.Reader,
|
||||
interface: std.io.Reader,
|
||||
/// Keeps track of whether the stream is ready to accept a new request,
|
||||
/// making invalid API usage cause assertion failures rather than HTTP
|
||||
/// protocol violations.
|
||||
@ -438,37 +439,41 @@ pub const Reader = struct {
|
||||
buffer: []u8,
|
||||
transfer_encoding: TransferEncoding,
|
||||
content_length: ?u64,
|
||||
) std.io.Reader {
|
||||
) *std.io.Reader {
|
||||
assert(reader.state == .received_head);
|
||||
return switch (transfer_encoding) {
|
||||
switch (transfer_encoding) {
|
||||
.chunked => {
|
||||
reader.state = .{ .body_remaining_chunk_len = .head };
|
||||
return .{
|
||||
reader.interface = .{
|
||||
.buffer = buffer,
|
||||
.context = reader,
|
||||
.seek = 0,
|
||||
.end = 0,
|
||||
.vtable = &.{
|
||||
.read = chunkedRead,
|
||||
.stream = chunkedStream,
|
||||
.discard = chunkedDiscard,
|
||||
},
|
||||
};
|
||||
return &reader.interface;
|
||||
},
|
||||
.none => {
|
||||
if (content_length) |len| {
|
||||
reader.state = .{ .body_remaining_content_length = len };
|
||||
return .{
|
||||
reader.interface = .{
|
||||
.buffer = buffer,
|
||||
.context = reader,
|
||||
.seek = 0,
|
||||
.end = 0,
|
||||
.vtable = &.{
|
||||
.read = contentLengthRead,
|
||||
.stream = contentLengthStream,
|
||||
.discard = contentLengthDiscard,
|
||||
},
|
||||
};
|
||||
return &reader.interface;
|
||||
} else {
|
||||
reader.state = .body_none;
|
||||
return reader.in.reader();
|
||||
return reader.in;
|
||||
}
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// If compressed body has been negotiated this will return decompressed bytes.
|
||||
@ -511,25 +516,25 @@ pub const Reader = struct {
|
||||
return decompressor.reader(transfer_reader, decompression_buffer, content_encoding);
|
||||
}
|
||||
|
||||
fn contentLengthRead(
|
||||
ctx: ?*anyopaque,
|
||||
bw: *Writer,
|
||||
fn contentLengthStream(
|
||||
io_r: *std.io.Reader,
|
||||
w: *Writer,
|
||||
limit: std.io.Limit,
|
||||
) std.io.Reader.StreamError!usize {
|
||||
const reader: *Reader = @alignCast(@ptrCast(ctx));
|
||||
const reader: *Reader = @fieldParentPtr("interface", io_r);
|
||||
const remaining_content_length = &reader.state.body_remaining_content_length;
|
||||
const remaining = remaining_content_length.*;
|
||||
if (remaining == 0) {
|
||||
reader.state = .ready;
|
||||
return error.EndOfStream;
|
||||
}
|
||||
const n = try reader.in.read(bw, limit.min(.limited(remaining)));
|
||||
const n = try reader.in.stream(w, limit.min(.limited(remaining)));
|
||||
remaining_content_length.* = remaining - n;
|
||||
return n;
|
||||
}
|
||||
|
||||
fn contentLengthDiscard(ctx: ?*anyopaque, limit: std.io.Limit) std.io.Reader.Error!usize {
|
||||
const reader: *Reader = @alignCast(@ptrCast(ctx));
|
||||
fn contentLengthDiscard(io_r: *std.io.Reader, limit: std.io.Limit) std.io.Reader.Error!usize {
|
||||
const reader: *Reader = @fieldParentPtr("interface", io_r);
|
||||
const remaining_content_length = &reader.state.body_remaining_content_length;
|
||||
const remaining = remaining_content_length.*;
|
||||
if (remaining == 0) {
|
||||
@ -541,18 +546,14 @@ pub const Reader = struct {
|
||||
return n;
|
||||
}
|
||||
|
||||
fn chunkedRead(
|
||||
ctx: ?*anyopaque,
|
||||
bw: *Writer,
|
||||
limit: std.io.Limit,
|
||||
) std.io.Reader.StreamError!usize {
|
||||
const reader: *Reader = @alignCast(@ptrCast(ctx));
|
||||
fn chunkedStream(io_r: *std.io.Reader, w: *Writer, limit: std.io.Limit) std.io.Reader.StreamError!usize {
|
||||
const reader: *Reader = @fieldParentPtr("interface", io_r);
|
||||
const chunk_len_ptr = switch (reader.state) {
|
||||
.ready => return error.EndOfStream,
|
||||
.body_remaining_chunk_len => |*x| x,
|
||||
else => unreachable,
|
||||
};
|
||||
return chunkedReadEndless(reader, bw, limit, chunk_len_ptr) catch |err| switch (err) {
|
||||
return chunkedReadEndless(reader, w, limit, chunk_len_ptr) catch |err| switch (err) {
|
||||
error.ReadFailed => return error.ReadFailed,
|
||||
error.WriteFailed => return error.WriteFailed,
|
||||
error.EndOfStream => {
|
||||
@ -568,7 +569,7 @@ pub const Reader = struct {
|
||||
|
||||
fn chunkedReadEndless(
|
||||
reader: *Reader,
|
||||
bw: *Writer,
|
||||
w: *Writer,
|
||||
limit: std.io.Limit,
|
||||
chunk_len_ptr: *RemainingChunkLen,
|
||||
) (BodyError || std.io.Reader.StreamError)!usize {
|
||||
@ -592,7 +593,7 @@ pub const Reader = struct {
|
||||
}
|
||||
}
|
||||
if (cp.chunk_len == 0) return parseTrailers(reader, 0);
|
||||
const n = try in.read(bw, limit.min(.limited(cp.chunk_len)));
|
||||
const n = try in.stream(w, limit.min(.limited(cp.chunk_len)));
|
||||
chunk_len_ptr.* = .init(cp.chunk_len + 2 - n);
|
||||
return n;
|
||||
},
|
||||
@ -608,15 +609,15 @@ pub const Reader = struct {
|
||||
continue :len .head;
|
||||
},
|
||||
else => |remaining_chunk_len| {
|
||||
const n = try in.read(bw, limit.min(.limited(@intFromEnum(remaining_chunk_len) - 2)));
|
||||
const n = try in.stream(w, limit.min(.limited(@intFromEnum(remaining_chunk_len) - 2)));
|
||||
chunk_len_ptr.* = .init(@intFromEnum(remaining_chunk_len) - n);
|
||||
return n;
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn chunkedDiscard(ctx: ?*anyopaque, limit: std.io.Limit) std.io.Reader.Error!usize {
|
||||
const reader: *Reader = @alignCast(@ptrCast(ctx));
|
||||
fn chunkedDiscard(io_r: *std.io.Reader, limit: std.io.Limit) std.io.Reader.Error!usize {
|
||||
const reader: *Reader = @fieldParentPtr("interface", io_r);
|
||||
const chunk_len_ptr = switch (reader.state) {
|
||||
.ready => return error.EndOfStream,
|
||||
.body_remaining_chunk_len => |*x| x,
|
||||
@ -758,7 +759,6 @@ pub const BodyWriter = struct {
|
||||
/// state of this other than via methods of `BodyWriter`.
|
||||
http_protocol_output: *Writer,
|
||||
state: State,
|
||||
elide: bool,
|
||||
interface: Writer,
|
||||
|
||||
pub const Error = Writer.Error;
|
||||
@ -796,6 +796,10 @@ pub const BodyWriter = struct {
|
||||
};
|
||||
};
|
||||
|
||||
pub fn isEliding(w: *const BodyWriter) bool {
|
||||
return w.interface.vtable.drain == Writer.discardingDrain;
|
||||
}
|
||||
|
||||
/// Sends all buffered data across `BodyWriter.http_protocol_output`.
|
||||
pub fn flush(w: *BodyWriter) Error!void {
|
||||
const out = w.http_protocol_output;
|
||||
@ -825,7 +829,7 @@ pub const BodyWriter = struct {
|
||||
/// with empty trailers, then flushes the stream to the system. Asserts any
|
||||
/// started chunk has been completely finished.
|
||||
///
|
||||
/// Respects the value of `elide` to omit all data after the headers.
|
||||
/// Respects the value of `isEliding` to omit all data after the headers.
|
||||
///
|
||||
/// See also:
|
||||
/// * `endUnflushed`
|
||||
@ -841,7 +845,7 @@ pub const BodyWriter = struct {
|
||||
/// Otherwise, transfer-encoding: chunked is being used, and it writes the
|
||||
/// end-of-stream message with empty trailers.
|
||||
///
|
||||
/// Respects the value of `elide` to omit all data after the headers.
|
||||
/// Respects the value of `isEliding` to omit all data after the headers.
|
||||
///
|
||||
/// See also:
|
||||
/// * `end`
|
||||
@ -867,7 +871,7 @@ pub const BodyWriter = struct {
|
||||
///
|
||||
/// Asserts that the BodyWriter is using transfer-encoding: chunked.
|
||||
///
|
||||
/// Respects the value of `elide` to omit all data after the headers.
|
||||
/// Respects the value of `isEliding` to omit all data after the headers.
|
||||
///
|
||||
/// See also:
|
||||
/// * `endChunkedUnflushed`
|
||||
@ -883,7 +887,7 @@ pub const BodyWriter = struct {
|
||||
///
|
||||
/// Asserts that the BodyWriter is using transfer-encoding: chunked.
|
||||
///
|
||||
/// Respects the value of `elide` to omit all data after the headers.
|
||||
/// Respects the value of `isEliding` to omit all data after the headers.
|
||||
///
|
||||
/// See also:
|
||||
/// * `endChunked`
|
||||
@ -891,7 +895,7 @@ pub const BodyWriter = struct {
|
||||
/// * `end`
|
||||
pub fn endChunkedUnflushed(w: *BodyWriter, options: EndChunkedOptions) Error!void {
|
||||
const chunked = &w.state.chunked;
|
||||
if (w.elide) {
|
||||
if (w.isEliding()) {
|
||||
w.state = .end;
|
||||
return;
|
||||
}
|
||||
@ -922,7 +926,7 @@ pub const BodyWriter = struct {
|
||||
|
||||
fn contentLengthDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize {
|
||||
const bw: *BodyWriter = @fieldParentPtr("interface", w);
|
||||
assert(!bw.elide);
|
||||
assert(!bw.isEliding());
|
||||
const out = w.http_protocol_output;
|
||||
const n = try w.drainTo(out, data, splat);
|
||||
w.state.content_length -= n;
|
||||
@ -931,7 +935,7 @@ pub const BodyWriter = struct {
|
||||
|
||||
fn noneDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize {
|
||||
const bw: *BodyWriter = @fieldParentPtr("interface", w);
|
||||
assert(!bw.elide);
|
||||
assert(!bw.isEliding());
|
||||
const out = w.http_protocol_output;
|
||||
return try w.drainTo(out, data, splat);
|
||||
}
|
||||
@ -939,13 +943,13 @@ pub const BodyWriter = struct {
|
||||
/// Returns `null` if size cannot be computed without making any syscalls.
|
||||
fn noneSendFile(w: *Writer, file_reader: *File.Reader, limit: std.io.Limit) Writer.FileError!usize {
|
||||
const bw: *BodyWriter = @fieldParentPtr("interface", w);
|
||||
assert(!bw.elide);
|
||||
assert(!bw.isEliding());
|
||||
return w.sendFileTo(bw.http_protocol_output, file_reader, limit);
|
||||
}
|
||||
|
||||
fn contentLengthSendFile(w: *Writer, file_reader: *File.Reader, limit: std.io.Limit) Writer.FileError!usize {
|
||||
const bw: *BodyWriter = @fieldParentPtr("interface", w);
|
||||
assert(!bw.elide);
|
||||
assert(!bw.isEliding());
|
||||
const n = try w.sendFileTo(bw.http_protocol_output, file_reader, limit);
|
||||
bw.state.content_length -= n;
|
||||
return n;
|
||||
@ -953,7 +957,7 @@ pub const BodyWriter = struct {
|
||||
|
||||
fn chunkedSendFile(w: *Writer, file_reader: *File.Reader, limit: std.io.Limit) Writer.FileError!usize {
|
||||
const bw: *BodyWriter = @fieldParentPtr("interface", w);
|
||||
assert(!bw.elide);
|
||||
assert(!bw.isEliding());
|
||||
const data_len = w.countSendFileUpperBound(file_reader, limit) orelse {
|
||||
// If the file size is unknown, we cannot lower to a `writeFile` since we would
|
||||
// have to flush the chunk header before knowing the chunk length.
|
||||
@ -1001,7 +1005,7 @@ pub const BodyWriter = struct {
|
||||
|
||||
fn chunkedDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize {
|
||||
const bw: *BodyWriter = @fieldParentPtr("interface", w);
|
||||
assert(!bw.elide);
|
||||
assert(!bw.isEliding());
|
||||
const out = w.http_protocol_output;
|
||||
const data_len = Writer.countSplat(w.end, data, splat);
|
||||
const chunked = &bw.state.chunked;
|
||||
@ -1059,27 +1063,6 @@ pub const BodyWriter = struct {
|
||||
a /= base;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn writer(w: *BodyWriter) Writer {
|
||||
return if (w.elide) .discarding else .{
|
||||
.context = w,
|
||||
.vtable = switch (w.state) {
|
||||
.none => &.{
|
||||
.drain = noneDrain,
|
||||
.sendFile = noneSendFile,
|
||||
},
|
||||
.content_length => &.{
|
||||
.drain = contentLengthDrain,
|
||||
.sendFile = contentLengthSendFile,
|
||||
},
|
||||
.chunked => &.{
|
||||
.drain = chunkedDrain,
|
||||
.sendFile = chunkedSendFile,
|
||||
},
|
||||
.end => unreachable,
|
||||
},
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
test {
|
||||
|
||||
@ -14,6 +14,7 @@ const Uri = std.Uri;
|
||||
const Allocator = mem.Allocator;
|
||||
const assert = std.debug.assert;
|
||||
const Writer = std.io.Writer;
|
||||
const Reader = std.io.Reader;
|
||||
|
||||
const Client = @This();
|
||||
|
||||
@ -228,12 +229,6 @@ pub const Connection = struct {
|
||||
client: *Client,
|
||||
stream_writer: net.Stream.Writer,
|
||||
stream_reader: net.Stream.Reader,
|
||||
/// HTTP protocol from client to server.
|
||||
/// This either goes directly to `stream_writer`, or to a TLS client.
|
||||
writer: Writer,
|
||||
/// HTTP protocol from server to client.
|
||||
/// This either comes directly from `stream_reader`, or from a TLS client.
|
||||
reader: std.io.Reader,
|
||||
/// Entry in `ConnectionPool.used` or `ConnectionPool.free`.
|
||||
pool_node: std.DoublyLinkedList.Node,
|
||||
port: u16,
|
||||
@ -264,10 +259,8 @@ pub const Connection = struct {
|
||||
plain.* = .{
|
||||
.connection = .{
|
||||
.client = client,
|
||||
.stream_writer = stream.writer(),
|
||||
.stream_reader = stream.reader(),
|
||||
.writer = plain.connection.stream_writer.interface().buffered(socket_write_buffer),
|
||||
.reader = plain.connection.stream_reader.interface().buffered(socket_read_buffer),
|
||||
.stream_writer = stream.writer(socket_write_buffer),
|
||||
.stream_reader = stream.reader(socket_read_buffer),
|
||||
.pool_node = .{},
|
||||
.port = port,
|
||||
.host_len = @intCast(remote_host.len),
|
||||
@ -297,10 +290,6 @@ pub const Connection = struct {
|
||||
};
|
||||
|
||||
const Tls = struct {
|
||||
/// Data from `client` to `Connection.stream`.
|
||||
writer: Writer,
|
||||
/// Data from `Connection.stream` to `client`.
|
||||
reader: std.io.Reader,
|
||||
client: std.crypto.tls.Client,
|
||||
connection: Connection,
|
||||
|
||||
@ -324,10 +313,8 @@ pub const Connection = struct {
|
||||
tls.* = .{
|
||||
.connection = .{
|
||||
.client = client,
|
||||
.stream_writer = stream.writer(),
|
||||
.stream_reader = stream.reader(),
|
||||
.writer = tls.client.writer().buffered(socket_write_buffer),
|
||||
.reader = tls.client.reader().unbuffered(),
|
||||
.stream_writer = stream.writer(socket_write_buffer),
|
||||
.stream_reader = stream.reader(&.{}),
|
||||
.pool_node = .{},
|
||||
.port = port,
|
||||
.host_len = @intCast(remote_host.len),
|
||||
@ -335,20 +322,22 @@ pub const Connection = struct {
|
||||
.closing = false,
|
||||
.protocol = .tls,
|
||||
},
|
||||
.writer = tls.connection.stream_writer.interface().buffered(tls_write_buffer),
|
||||
.reader = tls.connection.stream_reader.interface().buffered(tls_read_buffer),
|
||||
.client = undefined,
|
||||
// TODO data race here on ca_bundle if the user sets next_https_rescan_certs to true
|
||||
.client = std.crypto.tls.Client.init(
|
||||
tls.connection.stream_reader.interface(),
|
||||
&tls.connection.stream_writer.interface,
|
||||
.{
|
||||
.host = .{ .explicit = remote_host },
|
||||
.ca = .{ .bundle = client.ca_bundle },
|
||||
.ssl_key_log = client.ssl_key_log,
|
||||
.read_buffer = tls_read_buffer,
|
||||
.write_buffer = tls_write_buffer,
|
||||
// This is appropriate for HTTPS because the HTTP headers contain
|
||||
// the content length which is used to detect truncation attacks.
|
||||
.allow_truncation_attacks = true,
|
||||
},
|
||||
) catch return error.TlsInitializationFailed,
|
||||
};
|
||||
// TODO data race here on ca_bundle if the user sets next_https_rescan_certs to true
|
||||
tls.client.init(&tls.reader, &tls.writer, .{
|
||||
.host = .{ .explicit = remote_host },
|
||||
.ca = .{ .bundle = client.ca_bundle },
|
||||
.ssl_key_log = client.ssl_key_log,
|
||||
}) catch return error.TlsInitializationFailed;
|
||||
// This is appropriate for HTTPS because the HTTP headers contain
|
||||
// the content length which is used to detect truncation attacks.
|
||||
tls.client.allow_truncation_attacks = true;
|
||||
|
||||
return tls;
|
||||
}
|
||||
|
||||
@ -404,26 +393,52 @@ pub const Connection = struct {
|
||||
}
|
||||
}
|
||||
|
||||
/// HTTP protocol from client to server.
|
||||
/// This either goes directly to `stream_writer`, or to a TLS client.
|
||||
pub fn writer(c: *Connection) *Writer {
|
||||
return switch (c.protocol) {
|
||||
.tls => {
|
||||
if (disable_tls) unreachable;
|
||||
const tls: *Tls = @fieldParentPtr("connection", c);
|
||||
return &tls.client.writer;
|
||||
},
|
||||
.plain => &c.stream_writer.interface,
|
||||
};
|
||||
}
|
||||
|
||||
/// HTTP protocol from server to client.
|
||||
/// This either comes directly from `stream_reader`, or from a TLS client.
|
||||
pub fn reader(c: *const Connection) *Reader {
|
||||
return switch (c.protocol) {
|
||||
.tls => {
|
||||
if (disable_tls) unreachable;
|
||||
const tls: *Tls = @fieldParentPtr("connection", c);
|
||||
return &tls.client.reader;
|
||||
},
|
||||
.plain => c.stream_reader.interface(),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn flush(c: *Connection) Writer.Error!void {
|
||||
try c.writer.flush();
|
||||
if (c.protocol == .tls) {
|
||||
if (disable_tls) unreachable;
|
||||
const tls: *Tls = @fieldParentPtr("connection", c);
|
||||
try tls.writer.flush();
|
||||
try tls.client.writer.flush();
|
||||
}
|
||||
try c.stream_writer.interface.flush();
|
||||
}
|
||||
|
||||
/// If the connection is a TLS connection, sends the close_notify alert.
|
||||
///
|
||||
/// Flushes all buffers.
|
||||
pub fn end(c: *Connection) Writer.Error!void {
|
||||
try c.writer.flush();
|
||||
if (c.protocol == .tls) {
|
||||
if (disable_tls) unreachable;
|
||||
const tls: *Tls = @fieldParentPtr("connection", c);
|
||||
try tls.client.end();
|
||||
try tls.writer.flush();
|
||||
try tls.client.writer.flush();
|
||||
}
|
||||
try c.stream_writer.interface.flush();
|
||||
}
|
||||
};
|
||||
|
||||
@ -660,14 +675,14 @@ pub const Response = struct {
|
||||
|
||||
/// If compressed body has been negotiated this will return compressed bytes.
|
||||
///
|
||||
/// If the returned `std.io.Reader` returns `error.ReadFailed` the error is
|
||||
/// If the returned `Reader` returns `error.ReadFailed` the error is
|
||||
/// available via `bodyErr`.
|
||||
///
|
||||
/// Asserts that this function is only called once.
|
||||
///
|
||||
/// See also:
|
||||
/// * `readerDecompressing`
|
||||
pub fn reader(response: *Response, buffer: []u8) std.io.Reader {
|
||||
pub fn reader(response: *Response, buffer: []u8) Reader {
|
||||
const req = response.request;
|
||||
if (!req.method.responseHasBody()) return .ending;
|
||||
const head = &response.head;
|
||||
@ -676,7 +691,7 @@ pub const Response = struct {
|
||||
|
||||
/// If compressed body has been negotiated this will return decompressed bytes.
|
||||
///
|
||||
/// If the returned `std.io.Reader` returns `error.ReadFailed` the error is
|
||||
/// If the returned `Reader` returns `error.ReadFailed` the error is
|
||||
/// available via `bodyErr`.
|
||||
///
|
||||
/// Asserts that this function is only called once.
|
||||
@ -687,7 +702,7 @@ pub const Response = struct {
|
||||
response: *Response,
|
||||
decompressor: *http.Decompressor,
|
||||
decompression_buffer: []u8,
|
||||
) std.io.Reader {
|
||||
) Reader {
|
||||
const head = &response.head;
|
||||
return response.request.reader.bodyReaderDecompressing(
|
||||
head.transfer_encoding,
|
||||
@ -698,7 +713,7 @@ pub const Response = struct {
|
||||
);
|
||||
}
|
||||
|
||||
/// After receiving `error.ReadFailed` from the `std.io.Reader` returned by
|
||||
/// After receiving `error.ReadFailed` from the `Reader` returned by
|
||||
/// `reader` or `readerDecompressing`, this function accesses the
|
||||
/// more specific error code.
|
||||
pub fn bodyErr(response: *const Response) ?http.Reader.BodyError {
|
||||
@ -835,8 +850,8 @@ pub const Request = struct {
|
||||
///
|
||||
/// See also:
|
||||
/// * `sendBodyUnflushed`
|
||||
pub fn sendBody(r: *Request) Writer.Error!http.BodyWriter {
|
||||
const result = try sendBodyUnflushed(r);
|
||||
pub fn sendBody(r: *Request, buffer: []u8) Writer.Error!http.BodyWriter {
|
||||
const result = try sendBodyUnflushed(r, buffer);
|
||||
try r.connection.?.flush();
|
||||
return result;
|
||||
}
|
||||
@ -846,17 +861,44 @@ pub const Request = struct {
|
||||
///
|
||||
/// See also:
|
||||
/// * `sendBody`
|
||||
pub fn sendBodyUnflushed(r: *Request) Writer.Error!http.BodyWriter {
|
||||
pub fn sendBodyUnflushed(r: *Request, buffer: []u8) Writer.Error!http.BodyWriter {
|
||||
assert(r.method.requestHasBody());
|
||||
try sendHead(r);
|
||||
return .{
|
||||
.http_protocol_output = &r.connection.?.writer,
|
||||
.state = switch (r.transfer_encoding) {
|
||||
.chunked => .{ .chunked = .init },
|
||||
.content_length => |len| .{ .content_length = len },
|
||||
.none => .none,
|
||||
const http_protocol_output = &r.connection.?.writer;
|
||||
return switch (r.transfer_encoding) {
|
||||
.chunked => .{
|
||||
.http_protocol_output = http_protocol_output,
|
||||
.state = .{ .chunked = .init },
|
||||
.interface = .{
|
||||
.buffer = buffer,
|
||||
.interface = &.{
|
||||
.drain = http.BodyWriter.chunkedDrain,
|
||||
.sendFile = http.BodyWriter.chunkedSendFile,
|
||||
},
|
||||
},
|
||||
},
|
||||
.content_length => |len| .{
|
||||
.http_protocol_output = http_protocol_output,
|
||||
.state = .{ .content_length = len },
|
||||
.interface = .{
|
||||
.buffer = buffer,
|
||||
.interface = &.{
|
||||
.drain = http.BodyWriter.contentLengthDrain,
|
||||
.sendFile = http.BodyWriter.contentLengthSendFile,
|
||||
},
|
||||
},
|
||||
},
|
||||
.none => .{
|
||||
.http_protocol_output = http_protocol_output,
|
||||
.state = .none,
|
||||
.interface = .{
|
||||
.buffer = buffer,
|
||||
.interface = &.{
|
||||
.drain = http.BodyWriter.noneDrain,
|
||||
.sendFile = http.BodyWriter.noneSendFile,
|
||||
},
|
||||
},
|
||||
},
|
||||
.elide = false,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -381,6 +381,8 @@ pub const Request = struct {
|
||||
content_length: ?u64 = null,
|
||||
/// Options that are shared with the `respond` method.
|
||||
respond_options: RespondOptions = .{},
|
||||
/// Used by `http.BodyWriter`.
|
||||
buffer: []u8,
|
||||
};
|
||||
|
||||
/// The header is not guaranteed to be sent until `BodyWriter.flush` or
|
||||
@ -436,16 +438,38 @@ pub const Request = struct {
|
||||
|
||||
try out.writeAll("\r\n");
|
||||
const elide_body = request.head.method == .HEAD;
|
||||
const state: http.BodyWriter.State = if (o.transfer_encoding) |te| switch (te) {
|
||||
.chunked => .{ .chunked = .init },
|
||||
.none => .none,
|
||||
} else if (options.content_length) |len| .{
|
||||
.content_length = len,
|
||||
} else .{ .chunked = .init };
|
||||
|
||||
return .{
|
||||
return if (elide_body) .{
|
||||
.http_protocol_output = request.server.out,
|
||||
.state = if (o.transfer_encoding) |te| switch (te) {
|
||||
.chunked => .{ .chunked = .init },
|
||||
.none => .none,
|
||||
} else if (options.content_length) |len| .{
|
||||
.content_length = len,
|
||||
} else .{ .chunked = .init },
|
||||
.elide = elide_body,
|
||||
.state = state,
|
||||
.interface = .discarding(options.buffer),
|
||||
} else .{
|
||||
.http_protocol_output = request.server.out,
|
||||
.state = state,
|
||||
.interface = .{
|
||||
.buffer = options.buffer,
|
||||
.vtable = switch (state) {
|
||||
.none => &.{
|
||||
.drain = http.BodyWriter.noneDrain,
|
||||
.sendFile = http.BodyWriter.noneSendFile,
|
||||
},
|
||||
.content_length => &.{
|
||||
.drain = http.BodyWriter.contentLengthDrain,
|
||||
.sendFile = http.BodyWriter.contentLengthSendFile,
|
||||
},
|
||||
.chunked => &.{
|
||||
.drain = http.BodyWriter.chunkedDrain,
|
||||
.sendFile = http.BodyWriter.chunkedSendFile,
|
||||
},
|
||||
.end => unreachable,
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -13,7 +13,6 @@ const Limit = std.io.Limit;
|
||||
|
||||
pub const Limited = @import("Reader/Limited.zig");
|
||||
|
||||
context: ?*anyopaque = null,
|
||||
vtable: *const VTable,
|
||||
buffer: []u8,
|
||||
/// Number of bytes which have been consumed from `buffer`.
|
||||
@ -88,7 +87,6 @@ pub const ShortError = error{
|
||||
};
|
||||
|
||||
pub const failing: Reader = .{
|
||||
.context = undefined,
|
||||
.vtable = &.{
|
||||
.read = failingStream,
|
||||
.discard = failingDiscard,
|
||||
@ -107,7 +105,6 @@ pub fn limited(r: *Reader, limit: Limit, buffer: []u8) Limited {
|
||||
/// Constructs a `Reader` such that it will read from `buffer` and then end.
|
||||
pub fn fixed(buffer: []const u8) Reader {
|
||||
return .{
|
||||
.context = undefined,
|
||||
.vtable = &.{
|
||||
.stream = endingStream,
|
||||
.discard = endingDiscard,
|
||||
@ -1402,12 +1399,6 @@ test "readAlloc when the backing reader provides one byte at a time" {
|
||||
self.curr += 1;
|
||||
return 1;
|
||||
}
|
||||
|
||||
fn reader(self: *@This()) std.io.Reader {
|
||||
return .{
|
||||
.context = self,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
const str = "This is a test";
|
||||
|
||||
@ -9,12 +9,6 @@ const File = std.fs.File;
|
||||
const testing = std.testing;
|
||||
const Allocator = std.mem.Allocator;
|
||||
|
||||
/// There are two strategies for obtaining context; one can use this field, or
|
||||
/// embed the `Writer` and use `@fieldParentPtr`. This field must be either set
|
||||
/// to a valid pointer or left as `null` because the interface will sometimes
|
||||
/// check if this pointer value is a known special value, for example to make
|
||||
/// `writableVector` work.
|
||||
context: ?*anyopaque = null,
|
||||
vtable: *const VTable,
|
||||
/// If this has length zero, the writer is unbuffered, and `flush` is a no-op.
|
||||
buffer: []u8,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user