mirror of
https://github.com/ziglang/zig.git
synced 2026-02-14 21:38:33 +00:00
TLS, HTTP, and package fetching fixes
* TLS: add missing assert for output buffer length requirement * TLS: add missing flushes * TLS: add flush implementation * TLS: finish drain implementation * HTTP: correct buffer sizes for TLS * HTTP: expose a getReadError method on Connection * HTTP: add missing flush on sendBodyComplete * Fetch: remove unwanted deinit * Fetch: improve error reporting
This commit is contained in:
parent
ac18b98aa3
commit
aac26f3b31
@ -8,8 +8,8 @@ const mem = std.mem;
|
||||
const crypto = std.crypto;
|
||||
const assert = std.debug.assert;
|
||||
const Certificate = std.crypto.Certificate;
|
||||
const Reader = std.io.Reader;
|
||||
const Writer = std.io.Writer;
|
||||
const Reader = std.Io.Reader;
|
||||
const Writer = std.Io.Writer;
|
||||
|
||||
const max_ciphertext_len = tls.max_ciphertext_len;
|
||||
const hmacExpandLabel = tls.hmacExpandLabel;
|
||||
@ -27,6 +27,8 @@ reader: Reader,
|
||||
|
||||
/// The encrypted stream from the client to the server. Bytes are pushed here
|
||||
/// via `writer`.
|
||||
///
|
||||
/// The buffer is asserted to have capacity at least `min_buffer_len`.
|
||||
output: *Writer,
|
||||
/// The plaintext stream from the client to the server.
|
||||
writer: Writer,
|
||||
@ -122,7 +124,6 @@ pub const Options = struct {
|
||||
/// the amount of data expected, such as HTTP with the Content-Length header.
|
||||
allow_truncation_attacks: bool = false,
|
||||
write_buffer: []u8,
|
||||
/// Asserted to have capacity at least `min_buffer_len`.
|
||||
read_buffer: []u8,
|
||||
/// Populated when `error.TlsAlert` is returned from `init`.
|
||||
alert: ?*tls.Alert = null,
|
||||
@ -185,6 +186,7 @@ const InitError = error{
|
||||
/// `input` is asserted to have buffer capacity at least `min_buffer_len`.
|
||||
pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client {
|
||||
assert(input.buffer.len >= min_buffer_len);
|
||||
assert(output.buffer.len >= min_buffer_len);
|
||||
const host = switch (options.host) {
|
||||
.no_verification => "",
|
||||
.explicit => |host| host,
|
||||
@ -278,6 +280,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
|
||||
{
|
||||
var iovecs: [2][]const u8 = .{ cleartext_header, host };
|
||||
try output.writeVecAll(iovecs[0..if (host.len == 0) 1 else 2]);
|
||||
try output.flush();
|
||||
}
|
||||
|
||||
var tls_version: tls.ProtocolVersion = undefined;
|
||||
@ -763,6 +766,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
|
||||
&client_verify_msg,
|
||||
};
|
||||
try output.writeVecAll(&all_msgs_vec);
|
||||
try output.flush();
|
||||
},
|
||||
}
|
||||
write_seq += 1;
|
||||
@ -828,6 +832,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
|
||||
&finished_msg,
|
||||
};
|
||||
try output.writeVecAll(&all_msgs_vec);
|
||||
try output.flush();
|
||||
|
||||
const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
|
||||
const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
|
||||
@ -877,7 +882,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
|
||||
.buffer = options.write_buffer,
|
||||
.vtable = &.{
|
||||
.drain = drain,
|
||||
.sendFile = Writer.unimplementedSendFile,
|
||||
.flush = flush,
|
||||
},
|
||||
},
|
||||
.tls_version = tls_version,
|
||||
@ -911,31 +916,56 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
|
||||
|
||||
fn drain(w: *Writer, data: []const []const u8, splat: usize) Writer.Error!usize {
|
||||
const c: *Client = @alignCast(@fieldParentPtr("writer", w));
|
||||
if (true) @panic("update to use the buffer and flush");
|
||||
const sliced_data = if (splat == 0) data[0..data.len -| 1] else data;
|
||||
const output = c.output;
|
||||
const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
|
||||
var total_clear: usize = 0;
|
||||
var ciphertext_end: usize = 0;
|
||||
for (sliced_data) |buf| {
|
||||
const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data);
|
||||
total_clear += prepared.cleartext_len;
|
||||
ciphertext_end += prepared.ciphertext_end;
|
||||
if (total_clear < buf.len) break;
|
||||
var total_clear: usize = 0;
|
||||
done: {
|
||||
{
|
||||
const buf = w.buffered();
|
||||
const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data);
|
||||
total_clear += prepared.cleartext_len;
|
||||
ciphertext_end += prepared.ciphertext_end;
|
||||
if (prepared.cleartext_len < buf.len) break :done;
|
||||
}
|
||||
for (data[0 .. data.len - 1]) |buf| {
|
||||
if (buf.len < min_buffer_len) break :done;
|
||||
const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data);
|
||||
total_clear += prepared.cleartext_len;
|
||||
ciphertext_end += prepared.ciphertext_end;
|
||||
if (prepared.cleartext_len < buf.len) break :done;
|
||||
}
|
||||
const buf = data[data.len - 1];
|
||||
for (0..splat) |_| {
|
||||
if (buf.len < min_buffer_len) break :done;
|
||||
const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data);
|
||||
total_clear += prepared.cleartext_len;
|
||||
ciphertext_end += prepared.ciphertext_end;
|
||||
if (prepared.cleartext_len < buf.len) break :done;
|
||||
}
|
||||
}
|
||||
output.advance(ciphertext_end);
|
||||
return total_clear;
|
||||
return w.consume(total_clear);
|
||||
}
|
||||
|
||||
fn flush(w: *Writer) Writer.Error!void {
|
||||
const c: *Client = @alignCast(@fieldParentPtr("writer", w));
|
||||
const output = c.output;
|
||||
const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
|
||||
const prepared = prepareCiphertextRecord(c, ciphertext_buf, w.buffered(), .application_data);
|
||||
output.advance(prepared.ciphertext_end);
|
||||
w.end = 0;
|
||||
}
|
||||
|
||||
/// Sends a `close_notify` alert, which is necessary for the server to
|
||||
/// distinguish between a properly finished TLS session, or a truncation
|
||||
/// attack.
|
||||
pub fn end(c: *Client) Writer.Error!void {
|
||||
try flush(&c.writer);
|
||||
const output = c.output;
|
||||
const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
|
||||
const prepared = prepareCiphertextRecord(c, ciphertext_buf, &tls.close_notify_alert, .alert);
|
||||
output.advance(prepared.cleartext_len);
|
||||
return prepared.ciphertext_end;
|
||||
output.advance(prepared.ciphertext_end);
|
||||
}
|
||||
|
||||
fn prepareCiphertextRecord(
|
||||
@ -1045,7 +1075,7 @@ pub fn eof(c: Client) bool {
|
||||
return c.received_close_notify;
|
||||
}
|
||||
|
||||
fn stream(r: *Reader, w: *Writer, limit: std.io.Limit) Reader.StreamError!usize {
|
||||
fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
|
||||
const c: *Client = @alignCast(@fieldParentPtr("reader", r));
|
||||
if (c.eof()) return error.EndOfStream;
|
||||
const input = c.input;
|
||||
|
||||
@ -42,7 +42,7 @@ connection_pool: ConnectionPool = .{},
|
||||
///
|
||||
/// If the entire HTTP header cannot fit in this amount of bytes,
|
||||
/// `error.HttpHeadersOversize` will be returned from `Request.wait`.
|
||||
read_buffer_size: usize = 4096,
|
||||
read_buffer_size: usize = 4096 + if (disable_tls) 0 else std.crypto.tls.Client.min_buffer_len,
|
||||
/// Each `Connection` allocates this amount for the writer buffer.
|
||||
write_buffer_size: usize = 1024,
|
||||
|
||||
@ -304,15 +304,16 @@ pub const Connection = struct {
|
||||
const host_buffer = base[@sizeOf(Tls)..][0..remote_host.len];
|
||||
const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.tls_buffer_size];
|
||||
const tls_write_buffer = tls_read_buffer.ptr[tls_read_buffer.len..][0..client.tls_buffer_size];
|
||||
const socket_write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size];
|
||||
assert(base.ptr + alloc_len == socket_write_buffer.ptr + socket_write_buffer.len);
|
||||
const write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size];
|
||||
const read_buffer = write_buffer.ptr[write_buffer.len..][0..client.read_buffer_size];
|
||||
assert(base.ptr + alloc_len == read_buffer.ptr + read_buffer.len);
|
||||
@memcpy(host_buffer, remote_host);
|
||||
const tls: *Tls = @ptrCast(base);
|
||||
tls.* = .{
|
||||
.connection = .{
|
||||
.client = client,
|
||||
.stream_writer = stream.writer(socket_write_buffer),
|
||||
.stream_reader = stream.reader(&.{}),
|
||||
.stream_writer = stream.writer(tls_write_buffer),
|
||||
.stream_reader = stream.reader(tls_read_buffer),
|
||||
.pool_node = .{},
|
||||
.port = port,
|
||||
.host_len = @intCast(remote_host.len),
|
||||
@ -328,8 +329,8 @@ pub const Connection = struct {
|
||||
.host = .{ .explicit = remote_host },
|
||||
.ca = .{ .bundle = client.ca_bundle },
|
||||
.ssl_key_log = client.ssl_key_log,
|
||||
.read_buffer = tls_read_buffer,
|
||||
.write_buffer = tls_write_buffer,
|
||||
.read_buffer = read_buffer,
|
||||
.write_buffer = write_buffer,
|
||||
// This is appropriate for HTTPS because the HTTP headers contain
|
||||
// the content length which is used to detect truncation attacks.
|
||||
.allow_truncation_attacks = true,
|
||||
@ -347,7 +348,8 @@ pub const Connection = struct {
|
||||
}
|
||||
|
||||
fn allocLen(client: *Client, host_len: usize) usize {
|
||||
return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size + client.write_buffer_size;
|
||||
return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size +
|
||||
client.write_buffer_size + client.read_buffer_size;
|
||||
}
|
||||
|
||||
fn host(tls: *Tls) []u8 {
|
||||
@ -356,6 +358,21 @@ pub const Connection = struct {
|
||||
}
|
||||
};
|
||||
|
||||
pub const ReadError = std.crypto.tls.Client.ReadError || std.net.Stream.ReadError;
|
||||
|
||||
pub fn getReadError(c: *const Connection) ?ReadError {
|
||||
return switch (c.protocol) {
|
||||
.tls => {
|
||||
if (disable_tls) unreachable;
|
||||
const tls: *const Tls = @alignCast(@fieldParentPtr("connection", c));
|
||||
return tls.client.read_err orelse c.stream_reader.getError();
|
||||
},
|
||||
.plain => {
|
||||
return c.stream_reader.getError();
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
fn getStream(c: *Connection) net.Stream {
|
||||
return c.stream_reader.getStream();
|
||||
}
|
||||
@ -434,7 +451,6 @@ pub const Connection = struct {
|
||||
if (disable_tls) unreachable;
|
||||
const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
|
||||
try tls.client.end();
|
||||
try tls.client.writer.flush();
|
||||
}
|
||||
try c.stream_writer.interface.flush();
|
||||
}
|
||||
@ -874,6 +890,7 @@ pub const Request = struct {
|
||||
var bw = try sendBodyUnflushed(r, body);
|
||||
bw.writer.end = body.len;
|
||||
try bw.end();
|
||||
try r.connection.?.flush();
|
||||
}
|
||||
|
||||
/// Transfers the HTTP head over the connection, which is not flushed until
|
||||
@ -1063,6 +1080,9 @@ pub const Request = struct {
|
||||
/// buffer capacity would be exceeded, `error.HttpRedirectLocationOversize`
|
||||
/// is returned instead. This buffer may be empty if no redirects are to be
|
||||
/// handled.
|
||||
///
|
||||
/// If this fails with `error.ReadFailed` then the `Connection.getReadError`
|
||||
/// method of `r.connection` can be used to get more detailed information.
|
||||
pub fn receiveHead(r: *Request, redirect_buffer: []u8) ReceiveHeadError!Response {
|
||||
var aux_buf = redirect_buffer;
|
||||
while (true) {
|
||||
|
||||
@ -998,15 +998,21 @@ fn initResource(f: *Fetch, uri: std.Uri, resource: *Resource, reader_buffer: []u
|
||||
.buffer = reader_buffer,
|
||||
} };
|
||||
const request = &resource.http_request.request;
|
||||
defer request.deinit();
|
||||
errdefer request.deinit();
|
||||
|
||||
request.sendBodiless() catch |err|
|
||||
return f.fail(f.location_tok, try eb.printString("HTTP request failed: {t}", .{err}));
|
||||
|
||||
var redirect_buffer: [1024]u8 = undefined;
|
||||
const response = &resource.http_request.response;
|
||||
response.* = request.receiveHead(&redirect_buffer) catch |err|
|
||||
return f.fail(f.location_tok, try eb.printString("invalid HTTP response: {t}", .{err}));
|
||||
response.* = request.receiveHead(&redirect_buffer) catch |err| switch (err) {
|
||||
error.ReadFailed => {
|
||||
return f.fail(f.location_tok, try eb.printString("HTTP response read failure: {t}", .{
|
||||
request.connection.?.getReadError().?,
|
||||
}));
|
||||
},
|
||||
else => |e| return f.fail(f.location_tok, try eb.printString("invalid HTTP response: {t}", .{e})),
|
||||
};
|
||||
|
||||
if (response.head.status != .ok) return f.fail(f.location_tok, try eb.printString(
|
||||
"bad HTTP response code: '{d} {s}'",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user