From 0eef21d8ec290564ab503e5ad25f4c0c86f04d45 Mon Sep 17 00:00:00 2001 From: Nameless Date: Thu, 5 Oct 2023 12:19:06 -0500 Subject: [PATCH] std.http.Client: add option to disable https std_options.http_connection_pool_size removed in favor of ``` client.connection_pool.resize(client.allocator, size); ``` std_options.http_disable_tls will remove all https capability from std.http when true. Any https request will error with `error.TlsInitializationFailed`. Solves #17051. --- lib/std/http/Client.zig | 102 ++++++++++++++++++++++++++++----------- lib/std/std.zig | 11 +++-- test/standalone/http.zig | 4 ++ 3 files changed, 85 insertions(+), 32 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 55ae62a183..4107cfdcc8 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -12,8 +12,7 @@ const assert = std.debug.assert; const Client = @This(); const proto = @import("protocol.zig"); -pub const default_connection_pool_size = 32; -pub const connection_pool_size = std.options.http_connection_pool_size; +pub const disable_tls = std.options.http_disable_tls; allocator: Allocator, ca_bundle: std.crypto.Certificate.Bundle = .{}, @@ -50,7 +49,7 @@ pub const ConnectionPool = struct { /// Open connections that are not currently in use. free: Queue = .{}, free_len: usize = 0, - free_size: usize = connection_pool_size, + free_size: usize = 32, /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. /// If no connection is found, null is returned. @@ -127,23 +126,43 @@ pub const ConnectionPool = struct { pool.used.append(node); } - pub fn deinit(pool: *ConnectionPool, client: *Client) void { + /// Resizes the connection pool. This function is threadsafe. + /// + /// If the new size is smaller than the current size, then idle connections will be closed until the pool is the new size. + pub fn resize(pool: *ConnectionPool, allocator: Allocator, new_size: usize) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + var next = pool.free.first; + _ = next; + while (pool.free_len > new_size) { + const popped = pool.free.popFirst() orelse unreachable; + pool.free_len -= 1; + + popped.data.close(allocator); + allocator.destroy(popped); + } + + pool.free_size = new_size; + } + + pub fn deinit(pool: *ConnectionPool, allocator: Allocator) void { pool.mutex.lock(); var next = pool.free.first; while (next) |node| { - defer client.allocator.destroy(node); + defer allocator.destroy(node); next = node.next; - node.data.close(client.allocator); + node.data.close(allocator); } next = pool.used.first; while (next) |node| { - defer client.allocator.destroy(node); + defer allocator.destroy(node); next = node.next; - node.data.close(client.allocator); + node.data.close(allocator); } pool.* = undefined; @@ -159,7 +178,7 @@ pub const Connection = struct { stream: net.Stream, /// undefined unless protocol is tls. - tls_client: *std.crypto.tls.Client, + tls_client: if (!disable_tls) *std.crypto.tls.Client else void, protocol: Protocol, host: []u8, @@ -174,11 +193,8 @@ pub const Connection = struct { read_buf: [buffer_size]u8 = undefined, write_buf: [buffer_size]u8 = undefined, - pub fn readvDirect(conn: *Connection, buffers: []std.os.iovec) ReadError!usize { - return switch (conn.protocol) { - .plain => conn.stream.readv(buffers), - .tls => conn.tls_client.readv(conn.stream, buffers), - } catch |err| { + pub fn readvDirectTls(conn: *Connection, buffers: []std.os.iovec) ReadError!usize { + return conn.tls_client.readv(conn.stream, buffers) catch |err| { // TODO: https://github.com/ziglang/zig/issues/2473 if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; @@ -191,6 +207,20 @@ pub const Connection = struct { }; } + pub fn readvDirect(conn: *Connection, buffers: []std.os.iovec) ReadError!usize { + if (conn.protocol == .tls) { + if (disable_tls) unreachable; + + return conn.readvDirectTls(buffers); + } + + return conn.stream.readv(buffers) catch |err| switch (err) { + error.ConnectionTimedOut => return error.ConnectionTimedOut, + error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, + else => return error.UnexpectedReadFailure, + }; + } + pub fn fill(conn: *Connection) ReadError!void { if (conn.read_end != conn.read_start) return; @@ -257,11 +287,21 @@ pub const Connection = struct { return Reader{ .context = conn }; } + pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { + return conn.tls_client.writeAll(conn.stream, buffer) catch |err| switch (err) { + error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, + else => return error.UnexpectedWriteFailure, + }; + } + pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void { - return switch (conn.protocol) { - .plain => conn.stream.writeAll(buffer), - .tls => conn.tls_client.writeAll(conn.stream, buffer), - } catch |err| switch (err) { + if (conn.protocol == .tls) { + if (disable_tls) unreachable; + + return conn.writeAllDirectTls(buffer); + } + + return conn.stream.writeAll(buffer) catch |err| switch (err) { error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, else => return error.UnexpectedWriteFailure, }; @@ -303,6 +343,8 @@ pub const Connection = struct { pub fn close(conn: *Connection, allocator: Allocator) void { if (conn.protocol == .tls) { + if (disable_tls) unreachable; + // try to cleanly close the TLS connection, for any server that cares. _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {}; allocator.destroy(conn.tls_client); @@ -932,7 +974,7 @@ pub const ProxyInformation = struct { /// Release all associated resources with the client. /// TODO: currently leaks all request allocated data pub fn deinit(client: *Client) void { - client.connection_pool.deinit(client); + client.connection_pool.deinit(client.allocator); if (client.http_proxy) |*proxy| { proxy.allocator.free(proxy.host); @@ -1046,6 +1088,9 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec })) |node| return node; + if (disable_tls and protocol == .tls) + return error.TlsInitializationFailed; + const conn = try client.allocator.create(ConnectionPool.Node); errdefer client.allocator.destroy(conn); conn.* = .{ .data = undefined }; @@ -1073,17 +1118,16 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec }; errdefer client.allocator.free(conn.data.host); - switch (protocol) { - .plain => {}, - .tls => { - conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client); - errdefer client.allocator.destroy(conn.data.tls_client); + if (protocol == .tls) { + if (disable_tls) unreachable; - conn.data.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed; - // This is appropriate for HTTPS because the HTTP headers contain - // the content length which is used to detect truncation attacks. - conn.data.tls_client.allow_truncation_attacks = true; - }, + conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client); + errdefer client.allocator.destroy(conn.data.tls_client); + + conn.data.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed; + // This is appropriate for HTTPS because the HTTP headers contain + // the content length which is used to detect truncation attacks. + conn.data.tls_client.allow_truncation_attacks = true; } client.connection_pool.addUsed(conn); diff --git a/lib/std/std.zig b/lib/std/std.zig index 16222e52da..5829a241c4 100644 --- a/lib/std/std.zig +++ b/lib/std/std.zig @@ -283,10 +283,15 @@ pub const options = struct { else false; - pub const http_connection_pool_size = if (@hasDecl(options_override, "http_connection_pool_size")) - options_override.http_connection_pool_size + /// By default, std.http.Client will support HTTPS connections. Set this option to `true` to + /// disable TLS support. + /// + /// This will likely reduce the size of the binary, but it will also make it impossible to + /// make a HTTPS connection. + pub const http_disable_tls = if (@hasDecl(options_override, "http_disable_tls")) + options_override.http_disable_tls else - http.Client.default_connection_pool_size; + false; pub const side_channels_mitigations: crypto.SideChannelsMitigations = if (@hasDecl(options_override, "side_channels_mitigations")) options_override.side_channels_mitigations diff --git a/test/standalone/http.zig b/test/standalone/http.zig index 71f3481767..55a8456fde 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -7,6 +7,10 @@ const Client = http.Client; const mem = std.mem; const testing = std.testing; +pub const std_options = struct { + pub const http_disable_tls = true; +}; + const max_header_size = 8192; var gpa_server = std.heap.GeneralPurposeAllocator(.{ .stack_trace_frames = 12 }){};