From 5f219a2d118cac1410888fb2c0abc0cc91d092de Mon Sep 17 00:00:00 2001 From: Nameless Date: Mon, 1 May 2023 17:49:34 -0500 Subject: [PATCH] std.http.Server: give Response access to their own allocator * This makes it easier for threaded servers to use a different allocator for each request. --- lib/std/http/Server.zig | 33 +++++++++++++++++++-------------- test/standalone/http.zig | 11 ++++++++--- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 46c689ba6d..d473366092 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -352,7 +352,7 @@ pub const Response = struct { transfer_encoding: ResponseTransfer = .none, - server: *Server, + allocator: Allocator, address: net.Address, connection: BufferedConnection, @@ -376,7 +376,7 @@ pub const Response = struct { res.request.headers.deinit(); if (res.request.parser.header_bytes_owned) { - res.request.parser.header_bytes.deinit(res.server.allocator); + res.request.parser.header_bytes.deinit(res.allocator); } } @@ -545,13 +545,13 @@ pub const Response = struct { while (true) { try res.connection.fill(); - const nchecked = try res.request.parser.checkCompleteHead(res.server.allocator, res.connection.peek()); + const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek()); res.connection.clear(@intCast(u16, nchecked)); if (res.request.parser.state.isContent()) break; } - res.request.headers = .{ .allocator = res.server.allocator, .owned = true }; + res.request.headers = .{ .allocator = res.allocator, .owned = true }; try res.request.parse(res.request.parser.header_bytes.items); if (res.request.transfer_encoding) |te| { @@ -573,13 +573,13 @@ pub const Response = struct { if (res.request.transfer_compression) |tc| switch (tc) { .compress => return error.CompressionNotSupported, .deflate => res.request.compression = .{ - .deflate = std.compress.zlib.zlibStream(res.server.allocator, res.transferReader()) catch return error.CompressionInitializationFailed, + .deflate = std.compress.zlib.zlibStream(res.allocator, res.transferReader()) catch return error.CompressionInitializationFailed, }, .gzip => res.request.compression = .{ - .gzip = std.compress.gzip.decompress(res.server.allocator, res.transferReader()) catch return error.CompressionInitializationFailed, + .gzip = std.compress.gzip.decompress(res.allocator, res.transferReader()) catch return error.CompressionInitializationFailed, }, .zstd => res.request.compression = .{ - .zstd = std.compress.zstd.decompressStream(res.server.allocator, res.transferReader()), + .zstd = std.compress.zstd.decompressStream(res.allocator, res.transferReader()), }, }; } @@ -612,12 +612,12 @@ pub const Response = struct { while (!res.request.parser.state.isContent()) { // read trailing headers try res.connection.fill(); - const nchecked = try res.request.parser.checkCompleteHead(res.server.allocator, res.connection.peek()); + const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek()); res.connection.clear(@intCast(u16, nchecked)); } if (has_trail) { - res.request.headers = http.Headers{ .allocator = res.server.allocator, .owned = false }; + res.request.headers = http.Headers{ .allocator = res.allocator, .owned = false }; // The response headers before the trailers are already guaranteed to be valid, so they will always be parsed again and cannot return an error. // This will *only* fail for a malformed trailer. @@ -731,24 +731,29 @@ pub const HeaderStrategy = union(enum) { static: []u8, }; +pub const AcceptOptions = struct { + allocator: Allocator, + header_strategy: HeaderStrategy = .{ .dynamic = 8192 }, +}; + /// Accept a new connection. -pub fn accept(server: *Server, options: HeaderStrategy) AcceptError!Response { +pub fn accept(server: *Server, options: AcceptOptions) AcceptError!Response { const in = try server.socket.accept(); return Response{ - .server = server, + .allocator = options.allocator, .address = in.address, .connection = .{ .conn = .{ .stream = in.stream, .protocol = .plain, } }, - .headers = .{ .allocator = server.allocator }, + .headers = .{ .allocator = options.allocator }, .request = .{ .version = undefined, .method = undefined, .target = undefined, - .headers = .{ .allocator = server.allocator, .owned = false }, - .parser = switch (options) { + .headers = .{ .allocator = options.allocator, .owned = false }, + .parser = switch (options.header_strategy) { .dynamic => |max| proto.HeadersParser.initDynamic(max), .static => |buf| proto.HeadersParser.initStatic(buf), }, diff --git a/test/standalone/http.zig b/test/standalone/http.zig index 493931a2ea..180e863bba 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -15,6 +15,8 @@ var gpa_client = std.heap.GeneralPurposeAllocator(.{}){}; const salloc = gpa_server.allocator(); const calloc = gpa_client.allocator(); +var server: Server = undefined; + fn handleRequest(res: *Server.Response) !void { const log = std.log.scoped(.server); @@ -89,7 +91,7 @@ fn handleRequest(res: *Server.Response) !void { } else if (mem.eql(u8, res.request.target, "/redirect/3")) { res.transfer_encoding = .chunked; - const location = try std.fmt.allocPrint(salloc, "http://127.0.0.1:{d}/redirect/2", .{res.server.socket.listen_address.getPort()}); + const location = try std.fmt.allocPrint(salloc, "http://127.0.0.1:{d}/redirect/2", .{server.socket.listen_address.getPort()}); defer salloc.free(location); res.status = .found; @@ -119,7 +121,10 @@ var handle_new_requests = true; fn runServer(srv: *Server) !void { outer: while (handle_new_requests) { - var res = try srv.accept(.{ .dynamic = max_header_size }); + var res = try srv.accept(.{ + .allocator = salloc, + .header_strategy = .{ .dynamic = max_header_size }, + }); defer res.deinit(); while (res.reset() != .closing) { @@ -162,7 +167,7 @@ pub fn main() !void { defer _ = gpa_client.deinit(); - var server = Server.init(salloc, .{ .reuse_address = true }); + server = Server.init(salloc, .{ .reuse_address = true }); const addr = std.net.Address.parseIp("127.0.0.1", 0) catch unreachable; try server.listen(addr);