From 71c228fe6572b9f3b30e82035bf8fd7e2b1dd29d Mon Sep 17 00:00:00 2001 From: Nameless Date: Tue, 25 Apr 2023 09:22:20 -0500 Subject: [PATCH 1/7] std.http: add simple standalone http tests, add state check for http server --- lib/std/Uri.zig | 9 +- lib/std/http.zig | 1 - lib/std/http/Client.zig | 16 +- lib/std/http/Headers.zig | 11 - lib/std/http/Server.zig | 135 ++++++++---- lib/std/http/test.zig | 72 ------- test/standalone.zig | 4 + test/standalone/http.zig | 432 +++++++++++++++++++++++++++++++++++++++ 8 files changed, 546 insertions(+), 134 deletions(-) delete mode 100644 lib/std/http/test.zig create mode 100644 test/standalone/http.zig diff --git a/lib/std/Uri.zig b/lib/std/Uri.zig index b0bb3047cb..2cf0f7a46b 100644 --- a/lib/std/Uri.zig +++ b/lib/std/Uri.zig @@ -216,6 +216,7 @@ pub fn format( const needs_absolute = comptime std.mem.indexOf(u8, fmt, "+") != null; const needs_path = comptime std.mem.indexOf(u8, fmt, "/") != null or fmt.len == 0; + const needs_fragment = comptime std.mem.indexOf(u8, fmt, "#") != null; if (needs_absolute) { try writer.writeAll(uri.scheme); @@ -253,9 +254,11 @@ pub fn format( try Uri.writeEscapedQuery(writer, q); } - if (uri.fragment) |f| { - try writer.writeAll("#"); - try Uri.writeEscapedQuery(writer, f); + if (needs_fragment) { + if (uri.fragment) |f| { + try writer.writeAll("#"); + try Uri.writeEscapedQuery(writer, f); + } } } } diff --git a/lib/std/http.zig b/lib/std/http.zig index 744615d7d7..e9c62705b5 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -275,5 +275,4 @@ test { _ = Client; _ = Method; _ = Status; - _ = @import("http/test.zig"); } diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 42ef766bd3..74321492c8 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -264,7 +264,7 @@ pub const BufferedConnection = struct { const nread = try bconn.conn.read(bconn.buf[0..]); if (nread == 0) return error.EndOfStream; bconn.start = 0; - bconn.end = @truncate(u16, nread); + bconn.end = @intCast(u16, nread); } pub fn peek(bconn: *BufferedConnection) []const u8 { @@ -282,7 +282,7 @@ pub const BufferedConnection = struct { const left = buffer.len - out_index; if (available > 0) { - const can_read = @truncate(u16, @min(available, left)); + const can_read = @intCast(u16, @min(available, left)); @memcpy(buffer[out_index..][0..can_read], bconn.buf[bconn.start..][0..can_read]); out_index += can_read; @@ -355,8 +355,6 @@ pub const Compression = union(enum) { /// A HTTP response originating from a server. pub const Response = struct { pub const ParseError = Allocator.Error || error{ - ShortHttpStatusLine, - BadHttpVersion, HttpHeadersInvalid, HttpHeaderContinuationsUnsupported, HttpTransferEncodingUnsupported, @@ -370,12 +368,12 @@ pub const Response = struct { const first_line = it.next() orelse return error.HttpHeadersInvalid; if (first_line.len < 12) - return error.ShortHttpStatusLine; + return error.HttpHeadersInvalid; const version: http.Version = switch (int64(first_line[0..8])) { int64("HTTP/1.0") => .@"HTTP/1.0", int64("HTTP/1.1") => .@"HTTP/1.1", - else => return error.BadHttpVersion, + else => return error.HttpHeadersInvalid, }; if (first_line[8] != ' ') return error.HttpHeadersInvalid; const status = @intToEnum(http.Status, parseInt3(first_line[9..12].*)); @@ -695,7 +693,6 @@ pub const Request = struct { if (req.method == .CONNECT and req.response.status == .ok) { req.connection.data.closing = false; - req.connection.data.proxied = true; req.response.parser.done = true; } @@ -725,6 +722,11 @@ pub const Request = struct { req.response.parser.done = true; } + // HEAD requests have no body + if (req.method == .HEAD) { + req.response.parser.done = true; + } + if (req.transfer_encoding == .none and req.response.status.class() == .redirect and req.handle_redirects) { req.response.skip = true; diff --git a/lib/std/http/Headers.zig b/lib/std/http/Headers.zig index 376fd60b61..429df9368a 100644 --- a/lib/std/http/Headers.zig +++ b/lib/std/http/Headers.zig @@ -36,17 +36,6 @@ pub const Field = struct { name: []const u8, value: []const u8, - pub fn modify(entry: *Field, allocator: Allocator, new_value: []const u8) !void { - if (entry.value.len <= new_value.len) { - // TODO: eliminate this use of `@constCast`. - @memcpy(@constCast(entry.value)[0..new_value.len], new_value); - } else { - allocator.free(entry.value); - - entry.value = try allocator.dupe(u8, new_value); - } - } - fn lessThan(ctx: void, a: Field, b: Field) bool { _ = ctx; if (a.name.ptr == b.name.ptr) return false; diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index c7f2a86c27..1b5fd045fa 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -108,7 +108,7 @@ pub const BufferedConnection = struct { const nread = try bconn.conn.read(bconn.buf[0..]); if (nread == 0) return error.EndOfStream; bconn.start = 0; - bconn.end = @truncate(u16, nread); + bconn.end = @intCast(u16, nread); } pub fn peek(bconn: *BufferedConnection) []const u8 { @@ -126,7 +126,7 @@ pub const BufferedConnection = struct { const left = buffer.len - out_index; if (available > 0) { - const can_read = @truncate(u16, @min(available, left)); + const can_read = @intCast(u16, @min(available, left)); @memcpy(buffer[out_index..][0..can_read], bconn.buf[bconn.start..][0..can_read]); out_index += can_read; @@ -199,8 +199,6 @@ pub const Compression = union(enum) { /// A HTTP request originating from a client. pub const Request = struct { pub const ParseError = Allocator.Error || error{ - ShortHttpStatusLine, - BadHttpVersion, UnknownHttpMethod, HttpHeadersInvalid, HttpHeaderContinuationsUnsupported, @@ -215,7 +213,7 @@ pub const Request = struct { const first_line = it.next() orelse return error.HttpHeadersInvalid; if (first_line.len < 10) - return error.ShortHttpStatusLine; + return error.HttpHeadersInvalid; const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid; const method_str = first_line[0..method_end]; @@ -229,7 +227,7 @@ pub const Request = struct { const version: http.Version = switch (int64(version_str[0..8])) { int64("HTTP/1.0") => .@"HTTP/1.0", int64("HTTP/1.1") => .@"HTTP/1.1", - else => return error.BadHttpVersion, + else => return error.HttpHeadersInvalid, }; const target = first_line[method_end + 1 .. version_start]; @@ -312,7 +310,7 @@ pub const Request = struct { transfer_encoding: ?http.TransferEncoding = null, transfer_compression: ?http.ContentEncoding = null, - headers: http.Headers = undefined, + headers: http.Headers, parser: proto.HeadersParser, compression: Compression = .none, }; @@ -336,14 +334,52 @@ pub const Response = struct { headers: http.Headers, request: Request, + state: State = .first, + + const State = enum { + first, + start, + waited, + responded, + finished, + }; + pub fn deinit(res: *Response) void { - res.server.allocator.destroy(res); + res.connection.close(); + + res.headers.deinit(); + res.request.headers.deinit(); + + if (res.request.parser.header_bytes_owned) { + res.request.parser.header_bytes.deinit(res.server.allocator); + } } /// Reset this response to its initial state. This must be called before handling a second request on the same connection. - pub fn reset(res: *Response) void { - res.request.headers.deinit(); - res.headers.deinit(); + pub fn reset(res: *Response) bool { + if (res.state == .first) { + res.state = .start; + return true; + } + + if (!res.request.parser.done) { + // If the response wasn't fully read, then we need to close the connection. + res.connection.conn.closing = true; + return false; + } + + // A connection is only keep-alive if the Connection header is present and it's value is not "close". + // The server and client must both agree + const res_connection = res.headers.getFirstValue("connection"); + const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?); + + const req_connection = res.request.headers.getFirstValue("connection"); + const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); + if (res_keepalive and req_keepalive) { + res.connection.conn.closing = false; + } else { + res.connection.conn.closing = true; + } switch (res.request.compression) { .none => {}, @@ -352,26 +388,38 @@ pub const Response = struct { .zstd => |*zstd| zstd.deinit(), } - if (!res.request.parser.done) { - // If the response wasn't fully read, then we need to close the connection. - res.connection.conn.closing = true; - } + res.state = .start; + res.version = .@"HTTP/1.1"; + res.status = .ok; + res.reason = null; - if (res.connection.conn.closing) { - res.connection.close(); + res.transfer_encoding = .none; - if (res.request.parser.header_bytes_owned) { - res.request.parser.header_bytes.deinit(res.server.allocator); - } - } else { - res.request.parser.reset(); - } + res.headers.clearRetainingCapacity(); + + res.request.headers.clearRetainingCapacity(); + res.request.parser.reset(); + + res.request = Request{ + .version = undefined, + .method = undefined, + .target = undefined, + .headers = res.request.headers, + .parser = res.request.parser, + }; + + return !res.connection.conn.closing; } pub const DoError = BufferedConnection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength }; /// Send the response headers. pub fn do(res: *Response) !void { + switch (res.state) { + .waited => res.state = .responded, + .first, .start, .responded, .finished => unreachable, + } + var buffered = std.io.bufferedWriter(res.connection.writer()); const w = buffered.writer(); @@ -452,6 +500,11 @@ pub const Response = struct { /// Wait for the client to send a complete request head. pub fn wait(res: *Response) WaitError!void { + switch (res.state) { + .first, .start => res.state = .waited, + .waited, .responded, .finished => unreachable, + } + while (true) { try res.connection.fill(); @@ -464,17 +517,6 @@ pub const Response = struct { res.request.headers = .{ .allocator = res.server.allocator, .owned = true }; try res.request.parse(res.request.parser.header_bytes.items); - const res_connection = res.headers.getFirstValue("connection"); - const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?); - - const req_connection = res.request.headers.getFirstValue("connection"); - const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); - if (res_keepalive and req_keepalive) { - res.connection.conn.closing = false; - } else { - res.connection.conn.closing = true; - } - if (res.request.transfer_encoding) |te| { switch (te) { .chunked => { @@ -515,6 +557,11 @@ pub const Response = struct { } pub fn read(res: *Response, buffer: []u8) ReadError!usize { + switch (res.state) { + .waited, .responded, .finished => {}, + .first, .start => unreachable, + } + const out_index = switch (res.request.compression) { .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, @@ -564,6 +611,11 @@ pub const Response = struct { /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. pub fn write(res: *Response, bytes: []const u8) WriteError!usize { + switch (res.state) { + .responded => {}, + .first, .waited, .start, .finished => unreachable, + } + switch (res.transfer_encoding) { .chunked => { try res.connection.writer().print("{x}\r\n", .{bytes.len}); @@ -583,7 +635,7 @@ pub const Response = struct { } } - pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void { + pub fn writeAll(req: *Response, bytes: []const u8) WriteError!void { var index: usize = 0; while (index < bytes.len) { index += try write(req, bytes[index..]); @@ -594,6 +646,11 @@ pub const Response = struct { /// Finish the body of a request. This notifies the server that you have no more data to send. pub fn finish(res: *Response) FinishError!void { + switch (res.state) { + .responded => res.state = .finished, + .first, .waited, .start, .finished => unreachable, + } + switch (res.transfer_encoding) { .chunked => try res.connection.writeAll("0\r\n\r\n"), .content_length => |len| if (len != 0) return error.MessageNotCompleted, @@ -636,11 +693,10 @@ pub const HeaderStrategy = union(enum) { }; /// Accept a new connection and allocate a Response for it. -pub fn accept(server: *Server, options: HeaderStrategy) AcceptError!*Response { +pub fn accept(server: *Server, options: HeaderStrategy) AcceptError!Response { const in = try server.socket.accept(); - const res = try server.allocator.create(Response); - res.* = .{ + return Response{ .server = server, .address = in.address, .connection = .{ .conn = .{ @@ -652,14 +708,13 @@ pub fn accept(server: *Server, options: HeaderStrategy) AcceptError!*Response { .version = undefined, .method = undefined, .target = undefined, + .headers = .{ .allocator = server.allocator, .owned = false }, .parser = switch (options) { .dynamic => |max| proto.HeadersParser.initDynamic(max), .static => |buf| proto.HeadersParser.initStatic(buf), }, }, }; - - return res; } test "HTTP server handles a chunked transfer coding request" { diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig deleted file mode 100644 index ee164c297a..0000000000 --- a/lib/std/http/test.zig +++ /dev/null @@ -1,72 +0,0 @@ -const std = @import("std"); -const expect = std.testing.expect; - -test "client requests server" { - const builtin = @import("builtin"); - - // This test requires spawning threads. - if (builtin.single_threaded) { - return error.SkipZigTest; - } - - const native_endian = comptime builtin.cpu.arch.endian(); - if (builtin.zig_backend == .stage2_llvm and native_endian == .Big) { - // https://github.com/ziglang/zig/issues/13782 - return error.SkipZigTest; - } - - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - const allocator = std.testing.allocator; - - const max_header_size = 8192; - var server = std.http.Server.init(allocator, .{ .reuse_address = true }); - defer server.deinit(); - - const address = try std.net.Address.parseIp("127.0.0.1", 0); - try server.listen(address); - const server_port = server.socket.listen_address.in.getPort(); - - const server_thread = try std.Thread.spawn(.{}, (struct { - fn apply(s: *std.http.Server) !void { - const res = try s.accept(.{ .dynamic = max_header_size }); - defer res.deinit(); - defer res.reset(); - try res.wait(); - - const server_body: []const u8 = "message from server!\n"; - res.transfer_encoding = .{ .content_length = server_body.len }; - try res.headers.append("content-type", "text/plain"); - try res.headers.append("connection", "close"); - try res.do(); - - var buf: [128]u8 = undefined; - const n = try res.readAll(&buf); - try expect(std.mem.eql(u8, buf[0..n], "Hello, World!\n")); - _ = try res.writer().writeAll(server_body); - try res.finish(); - } - }).apply, .{&server}); - - var uri_buf: [22]u8 = undefined; - const uri = try std.Uri.parse(try std.fmt.bufPrint(&uri_buf, "http://127.0.0.1:{d}", .{server_port})); - var client = std.http.Client{ .allocator = allocator }; - defer client.deinit(); - var client_headers = std.http.Headers{ .allocator = allocator }; - defer client_headers.deinit(); - var client_req = try client.request(.POST, uri, client_headers, .{}); - defer client_req.deinit(); - - client_req.transfer_encoding = .{ .content_length = 14 }; // this will be checked to ensure you sent exactly 14 bytes - try client_req.start(); // this sends the request - try client_req.writeAll("Hello, "); - try client_req.writeAll("World!\n"); - try client_req.finish(); - try client_req.wait(); // this waits for a response - - const body = try client_req.reader().readAllAlloc(allocator, 8192 * 1024); - defer allocator.free(body); - try expect(std.mem.eql(u8, body, "message from server!\n")); - - server_thread.join(); -} diff --git a/test/standalone.zig b/test/standalone.zig index 1f4d7cfded..a055da9761 100644 --- a/test/standalone.zig +++ b/test/standalone.zig @@ -55,6 +55,10 @@ pub const simple_cases = [_]SimpleCase{ .os_filter = .windows, .link_libc = true, }, + .{ + .src_path = "test/standalone/http.zig", + .all_modes = true, + }, // Ensure the development tools are buildable. Alphabetically sorted. // No need to build `tools/spirv/grammar.zig`. diff --git a/test/standalone/http.zig b/test/standalone/http.zig new file mode 100644 index 0000000000..1ed4b1b279 --- /dev/null +++ b/test/standalone/http.zig @@ -0,0 +1,432 @@ +const std = @import("std"); + +const http = std.http; +const Server = http.Server; +const Client = http.Client; + +const mem = std.mem; +const testing = std.testing; + +const max_header_size = 8192; + +var gpa_server = std.heap.GeneralPurposeAllocator(.{}){}; +var gpa_client = std.heap.GeneralPurposeAllocator(.{}){}; + +const salloc = gpa_server.allocator(); +const calloc = gpa_client.allocator(); + +fn handleRequest(res: *Server.Response) !void { + const log = std.log.scoped(.server); + + log.info("{s} {s} {s}", .{ @tagName(res.request.method), @tagName(res.request.version), res.request.target }); + + const body = try res.reader().readAllAlloc(salloc, 8192); + defer salloc.free(body); + + if (res.request.headers.contains("connection")) { + try res.headers.append("connection", "keep-alive"); + } + + if (mem.startsWith(u8, res.request.target, "/get")) { + if (std.mem.indexOf(u8, res.request.target, "?chunked") != null) { + res.transfer_encoding = .chunked; + } else { + res.transfer_encoding = .{ .content_length = 14 }; + } + + try res.headers.append("content-type", "text/plain"); + + try res.do(); + if (res.request.method != .HEAD) { + try res.writeAll("Hello, "); + try res.writeAll("World!\n"); + try res.finish(); + } + } else if (mem.eql(u8, res.request.target, "/echo-content")) { + try testing.expectEqualStrings("Hello, World!\n", body); + try testing.expectEqualStrings("text/plain", res.request.headers.getFirstValue("content-type").?); + + if (res.request.headers.contains("transfer-encoding")) { + try testing.expectEqualStrings("chunked", res.request.headers.getFirstValue("transfer-encoding").?); + res.transfer_encoding = .chunked; + } else { + res.transfer_encoding = .{ .content_length = 14 }; + try testing.expectEqualStrings("14", res.request.headers.getFirstValue("content-length").?); + } + + try res.do(); + try res.writeAll("Hello, "); + try res.writeAll("World!\n"); + try res.finish(); + } else if (mem.eql(u8, res.request.target, "/trailer")) { + res.transfer_encoding = .chunked; + + try res.do(); + try res.writeAll("Hello, "); + try res.writeAll("World!\n"); + // try res.finish(); + try res.connection.writeAll("0\r\nX-Checksum: aaaa\r\n\r\n"); + } else if (mem.eql(u8, res.request.target, "/redirect/1")) { + res.transfer_encoding = .chunked; + + res.status = .found; + try res.headers.append("location", "../../get"); + + try res.do(); + try res.writeAll("Hello, "); + try res.writeAll("Redirected!\n"); + try res.finish(); + } else if (mem.eql(u8, res.request.target, "/redirect/2")) { + res.transfer_encoding = .chunked; + + res.status = .found; + try res.headers.append("location", "/redirect/1"); + + try res.do(); + try res.writeAll("Hello, "); + try res.writeAll("Redirected!\n"); + try res.finish(); + } else if (mem.eql(u8, res.request.target, "/redirect/3")) { + res.transfer_encoding = .chunked; + + const location = try std.fmt.allocPrint(salloc, "http://127.0.0.1:{d}/redirect/2", .{res.server.socket.listen_address.getPort()}); + defer salloc.free(location); + + res.status = .found; + try res.headers.append("location", location); + + try res.do(); + try res.writeAll("Hello, "); + try res.writeAll("Redirected!\n"); + try res.finish(); + } else if (mem.eql(u8, res.request.target, "/redirect/4")) { + res.transfer_encoding = .chunked; + + res.status = .found; + try res.headers.append("location", "/redirect/3"); + + try res.do(); + try res.writeAll("Hello, "); + try res.writeAll("Redirected!\n"); + try res.finish(); + } else { + res.status = .not_found; + try res.do(); + } +} + +var handle_new_requests = true; + +fn runServer(srv: *Server) !void { + outer: while (handle_new_requests) { + var res = try srv.accept(.{ .dynamic = max_header_size }); + defer res.deinit(); + + while (res.reset()) { + res.wait() catch |err| switch (err) { + error.HttpHeadersInvalid => continue :outer, + error.EndOfStream => continue, + else => return err, + }; + + try handleRequest(&res); + } + } +} + +fn serverThread(srv: *Server) void { + defer srv.deinit(); + defer _ = gpa_server.deinit(); + + runServer(srv) catch |err| { + std.debug.print("server error: {}\n", .{err}); + + if (@errorReturnTrace()) |trace| { + std.debug.dumpStackTrace(trace.*); + } + + _ = gpa_server.deinit(); + std.os.exit(1); + }; +} + +fn killServer(addr: std.net.Address) void { + handle_new_requests = false; + + const conn = std.net.tcpConnectToAddress(addr) catch return; + conn.close(); +} + +pub fn main() !void { + const log = std.log.scoped(.client); + + defer _ = gpa_client.deinit(); + + var server = Server.init(salloc, .{ .reuse_address = true }); + + const addr = std.net.Address.parseIp("127.0.0.1", 0) catch unreachable; + try server.listen(addr); + + const port = server.socket.listen_address.getPort(); + + const server_thread = try std.Thread.spawn(.{}, serverThread, .{&server}); + + var client = Client{ .allocator = calloc }; + + defer client.deinit(); + + { // read content-length response + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.GET, uri, h, .{}); + defer req.deinit(); + + try req.start(); + try req.wait(); + + const body = try req.reader().readAllAlloc(calloc, 8192); + defer calloc.free(body); + + try testing.expectEqualStrings("Hello, World!\n", body); + try testing.expectEqualStrings("text/plain", req.response.headers.getFirstValue("content-type").?); + } + + { // send head request and not read chunked + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.HEAD, uri, h, .{}); + defer req.deinit(); + + try req.start(); + try req.wait(); + + const body = try req.reader().readAllAlloc(calloc, 8192); + defer calloc.free(body); + + try testing.expectEqualStrings("", body); + try testing.expectEqualStrings("text/plain", req.response.headers.getFirstValue("content-type").?); + try testing.expectEqualStrings("14", req.response.headers.getFirstValue("content-length").?); + } + + { // read chunked response + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get?chunked", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.GET, uri, h, .{}); + defer req.deinit(); + + try req.start(); + try req.wait(); + + const body = try req.reader().readAllAlloc(calloc, 8192); + defer calloc.free(body); + + try testing.expectEqualStrings("Hello, World!\n", body); + try testing.expectEqualStrings("text/plain", req.response.headers.getFirstValue("content-type").?); + } + + { // send head request and not read chunked + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get?chunked", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.HEAD, uri, h, .{}); + defer req.deinit(); + + try req.start(); + try req.wait(); + + const body = try req.reader().readAllAlloc(calloc, 8192); + defer calloc.free(body); + + try testing.expectEqualStrings("", body); + try testing.expectEqualStrings("text/plain", req.response.headers.getFirstValue("content-type").?); + try testing.expectEqualStrings("chunked", req.response.headers.getFirstValue("transfer-encoding").?); + } + + { // check trailing headers + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/trailer", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.GET, uri, h, .{}); + defer req.deinit(); + + try req.start(); + try req.wait(); + + const body = try req.reader().readAllAlloc(calloc, 8192); + defer calloc.free(body); + + try testing.expectEqualStrings("Hello, World!\n", body); + try testing.expectEqualStrings("aaaa", req.response.headers.getFirstValue("x-checksum").?); + } + + { // send content-length request + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + try h.append("content-type", "text/plain"); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.POST, uri, h, .{}); + defer req.deinit(); + + req.transfer_encoding = .{ .content_length = 14 }; + + try req.start(); + try req.writeAll("Hello, "); + try req.writeAll("World!\n"); + try req.finish(); + + try req.wait(); + + const body = try req.reader().readAllAlloc(calloc, 8192); + defer calloc.free(body); + + try testing.expectEqualStrings("Hello, World!\n", body); + } + + { // send chunked request + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + try h.append("content-type", "text/plain"); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.POST, uri, h, .{}); + defer req.deinit(); + + req.transfer_encoding = .chunked; + + try req.start(); + try req.writeAll("Hello, "); + try req.writeAll("World!\n"); + try req.finish(); + + try req.wait(); + + const body = try req.reader().readAllAlloc(calloc, 8192); + defer calloc.free(body); + + try testing.expectEqualStrings("Hello, World!\n", body); + } + + { // relative redirect + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/redirect/1", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.GET, uri, h, .{}); + defer req.deinit(); + + try req.start(); + try req.wait(); + + const body = try req.reader().readAllAlloc(calloc, 8192); + defer calloc.free(body); + + try testing.expectEqualStrings("Hello, World!\n", body); + } + + { // redirect from root + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/redirect/2", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.GET, uri, h, .{}); + defer req.deinit(); + + try req.start(); + try req.wait(); + + const body = try req.reader().readAllAlloc(calloc, 8192); + defer calloc.free(body); + + try testing.expectEqualStrings("Hello, World!\n", body); + } + + { // absolute redirect + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/redirect/3", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.GET, uri, h, .{}); + defer req.deinit(); + + try req.start(); + try req.wait(); + + const body = try req.reader().readAllAlloc(calloc, 8192); + defer calloc.free(body); + + try testing.expectEqualStrings("Hello, World!\n", body); + } + + { // too many redirects + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/redirect/4", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.GET, uri, h, .{}); + defer req.deinit(); + + try req.start(); + req.wait() catch |err| switch (err) { + error.TooManyHttpRedirects => {}, + else => return err, + }; + } + + killServer(server.socket.listen_address); + server_thread.join(); +} From 6513eb4696551a91e322fe2d9879335cd73c92db Mon Sep 17 00:00:00 2001 From: Nameless Date: Thu, 27 Apr 2023 20:25:46 -0500 Subject: [PATCH 2/7] std.http.Server: use client recommendation for keepalive --- lib/std/http/Server.zig | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 1b5fd045fa..71a7a351ad 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -439,7 +439,14 @@ pub const Response = struct { } if (!res.headers.contains("connection")) { - try w.writeAll("Connection: keep-alive\r\n"); + const req_connection = res.request.headers.getFirstValue("connection"); + const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); + + if (req_keepalive) { + try w.writeAll("Connection: keep-alive\r\n"); + } else { + try w.writeAll("Connection: close\r\n"); + } } const has_transfer_encoding = res.headers.contains("transfer-encoding"); From 533049fdd80a4c7bc3098512b4033a60daea745e Mon Sep 17 00:00:00 2001 From: Nameless Date: Fri, 28 Apr 2023 09:55:23 -0500 Subject: [PATCH 3/7] std.http.Server: use enum for reset state instead of bool --- lib/std/http/Server.zig | 16 +++++++++++----- test/standalone/http.zig | 2 +- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 71a7a351ad..2e7d985e4c 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -355,17 +355,19 @@ pub const Response = struct { } } + pub const ResetState = enum { reset, closing }; + /// Reset this response to its initial state. This must be called before handling a second request on the same connection. - pub fn reset(res: *Response) bool { + pub fn reset(res: *Response) ResetState { if (res.state == .first) { res.state = .start; - return true; + return .reset; } if (!res.request.parser.done) { // If the response wasn't fully read, then we need to close the connection. res.connection.conn.closing = true; - return false; + return .closing; } // A connection is only keep-alive if the Connection header is present and it's value is not "close". @@ -408,7 +410,11 @@ pub const Response = struct { .parser = res.request.parser, }; - return !res.connection.conn.closing; + if (res.connection.conn.closing) { + return .closing; + } else { + return .reset; + } } pub const DoError = BufferedConnection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength }; @@ -699,7 +705,7 @@ pub const HeaderStrategy = union(enum) { static: []u8, }; -/// Accept a new connection and allocate a Response for it. +/// Accept a new connection. pub fn accept(server: *Server, options: HeaderStrategy) AcceptError!Response { const in = try server.socket.accept(); diff --git a/test/standalone/http.zig b/test/standalone/http.zig index 1ed4b1b279..493931a2ea 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -122,7 +122,7 @@ fn runServer(srv: *Server) !void { var res = try srv.accept(.{ .dynamic = max_header_size }); defer res.deinit(); - while (res.reset()) { + while (res.reset() != .closing) { res.wait() catch |err| switch (err) { error.HttpHeadersInvalid => continue :outer, error.EndOfStream => continue, From 7b0962938859a955fa8e057ef34f4abd925bb1ca Mon Sep 17 00:00:00 2001 From: Nameless Date: Mon, 1 May 2023 12:13:33 -0500 Subject: [PATCH 4/7] std.http: buffer writes --- lib/std/http/Client.zig | 62 +++++++++++++++++++++++++++++------------ lib/std/http/Server.zig | 62 +++++++++++++++++++++++++++++------------ 2 files changed, 88 insertions(+), 36 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 74321492c8..72f16b0e76 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -254,44 +254,47 @@ pub const BufferedConnection = struct { pub const buffer_size = 0x2000; conn: Connection, - buf: [buffer_size]u8 = undefined, - start: u16 = 0, - end: u16 = 0, + read_buf: [buffer_size]u8 = undefined, + read_start: u16 = 0, + read_end: u16 = 0, + + write_buf: [buffer_size]u8 = undefined, + write_end: u16 = 0, pub fn fill(bconn: *BufferedConnection) ReadError!void { - if (bconn.end != bconn.start) return; + if (bconn.read_end != bconn.read_start) return; - const nread = try bconn.conn.read(bconn.buf[0..]); + const nread = try bconn.conn.read(bconn.read_buf[0..]); if (nread == 0) return error.EndOfStream; - bconn.start = 0; - bconn.end = @intCast(u16, nread); + bconn.read_start = 0; + bconn.read_end = @intCast(u16, nread); } pub fn peek(bconn: *BufferedConnection) []const u8 { - return bconn.buf[bconn.start..bconn.end]; + return bconn.read_buf[bconn.read_start..bconn.read_end]; } pub fn clear(bconn: *BufferedConnection, num: u16) void { - bconn.start += num; + bconn.read_start += num; } pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize { var out_index: u16 = 0; while (out_index < len) { - const available = bconn.end - bconn.start; + const available = bconn.read_end - bconn.read_start; const left = buffer.len - out_index; if (available > 0) { const can_read = @intCast(u16, @min(available, left)); - @memcpy(buffer[out_index..][0..can_read], bconn.buf[bconn.start..][0..can_read]); + @memcpy(buffer[out_index..][0..can_read], bconn.read_buf[bconn.read_start..][0..can_read]); out_index += can_read; - bconn.start += can_read; + bconn.read_start += can_read; continue; } - if (left > bconn.buf.len) { + if (left > bconn.read_buf.len) { // skip the buffer if the output is large enough return bconn.conn.read(buffer[out_index..]); } @@ -314,11 +317,33 @@ pub const BufferedConnection = struct { } pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void { - return bconn.conn.writeAll(buffer); + if (bconn.write_buf.len - bconn.write_end <= buffer.len) { + @memcpy(bconn.write_buf[bconn.write_end..], buffer); + bconn.write_end += @intCast(u16, buffer.len); + } else { + try bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); + bconn.write_end = 0; + + try bconn.conn.writeAll(buffer); + } } pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize { - return bconn.conn.write(buffer); + if (bconn.write_buf.len - bconn.write_end <= buffer.len) { + @memcpy(bconn.write_buf[bconn.write_end..], buffer); + bconn.write_end += @intCast(u16, buffer.len); + + return buffer.len; + } else { + try bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); + bconn.write_end = 0; + + return try bconn.conn.write(buffer); + } + } + + pub fn flush(bconn: *BufferedConnection) WriteError!void { + return bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); } pub const WriteError = Connection.WriteError; @@ -567,8 +592,7 @@ pub const Request = struct { /// Send the request to the server. pub fn start(req: *Request) StartError!void { - var buffered = std.io.bufferedWriter(req.connection.data.buffered.writer()); - const w = buffered.writer(); + const w = req.connection.data.buffered.writer(); try w.writeAll(@tagName(req.method)); try w.writeByte(' '); @@ -642,7 +666,7 @@ pub const Request = struct { try w.writeAll("\r\n"); - try buffered.flush(); + try req.connection.data.buffered.flush(); } pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError; @@ -868,6 +892,8 @@ pub const Request = struct { .content_length => |len| if (len != 0) return error.MessageNotCompleted, .none => {}, } + + try req.connection.data.buffered.flush(); } }; diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 2e7d985e4c..46c689ba6d 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -98,44 +98,47 @@ pub const BufferedConnection = struct { pub const buffer_size = 0x2000; conn: Connection, - buf: [buffer_size]u8 = undefined, - start: u16 = 0, - end: u16 = 0, + read_buf: [buffer_size]u8 = undefined, + read_start: u16 = 0, + read_end: u16 = 0, + + write_buf: [buffer_size]u8 = undefined, + write_end: u16 = 0, pub fn fill(bconn: *BufferedConnection) ReadError!void { - if (bconn.end != bconn.start) return; + if (bconn.read_end != bconn.read_start) return; - const nread = try bconn.conn.read(bconn.buf[0..]); + const nread = try bconn.conn.read(bconn.read_buf[0..]); if (nread == 0) return error.EndOfStream; - bconn.start = 0; - bconn.end = @intCast(u16, nread); + bconn.read_start = 0; + bconn.read_end = @intCast(u16, nread); } pub fn peek(bconn: *BufferedConnection) []const u8 { - return bconn.buf[bconn.start..bconn.end]; + return bconn.read_buf[bconn.read_start..bconn.read_end]; } pub fn clear(bconn: *BufferedConnection, num: u16) void { - bconn.start += num; + bconn.read_start += num; } pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize { var out_index: u16 = 0; while (out_index < len) { - const available = bconn.end - bconn.start; + const available = bconn.read_end - bconn.read_start; const left = buffer.len - out_index; if (available > 0) { const can_read = @intCast(u16, @min(available, left)); - @memcpy(buffer[out_index..][0..can_read], bconn.buf[bconn.start..][0..can_read]); + @memcpy(buffer[out_index..][0..can_read], bconn.read_buf[bconn.read_start..][0..can_read]); out_index += can_read; - bconn.start += can_read; + bconn.read_start += can_read; continue; } - if (left > bconn.buf.len) { + if (left > bconn.read_buf.len) { // skip the buffer if the output is large enough return bconn.conn.read(buffer[out_index..]); } @@ -158,11 +161,33 @@ pub const BufferedConnection = struct { } pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void { - return bconn.conn.writeAll(buffer); + if (bconn.write_buf.len - bconn.write_end <= buffer.len) { + @memcpy(bconn.write_buf[bconn.write_end..], buffer); + bconn.write_end += @intCast(u16, buffer.len); + } else { + try bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); + bconn.write_end = 0; + + try bconn.conn.writeAll(buffer); + } } pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize { - return bconn.conn.write(buffer); + if (bconn.write_buf.len - bconn.write_end <= buffer.len) { + @memcpy(bconn.write_buf[bconn.write_end..], buffer); + bconn.write_end += @intCast(u16, buffer.len); + + return buffer.len; + } else { + try bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); + bconn.write_end = 0; + + return try bconn.conn.write(buffer); + } + } + + pub fn flush(bconn: *BufferedConnection) WriteError!void { + return bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); } pub const WriteError = Connection.WriteError; @@ -426,8 +451,7 @@ pub const Response = struct { .first, .start, .responded, .finished => unreachable, } - var buffered = std.io.bufferedWriter(res.connection.writer()); - const w = buffered.writer(); + const w = res.connection.writer(); try w.writeAll(@tagName(res.version)); try w.writeByte(' '); @@ -485,7 +509,7 @@ pub const Response = struct { try w.writeAll("\r\n"); - try buffered.flush(); + try res.connection.flush(); } pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError; @@ -669,6 +693,8 @@ pub const Response = struct { .content_length => |len| if (len != 0) return error.MessageNotCompleted, .none => {}, } + + try res.connection.flush(); } }; From 5f219a2d118cac1410888fb2c0abc0cc91d092de Mon Sep 17 00:00:00 2001 From: Nameless Date: Mon, 1 May 2023 17:49:34 -0500 Subject: [PATCH 5/7] std.http.Server: give Response access to their own allocator * This makes it easier for threaded servers to use a different allocator for each request. --- lib/std/http/Server.zig | 33 +++++++++++++++++++-------------- test/standalone/http.zig | 11 ++++++++--- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 46c689ba6d..d473366092 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -352,7 +352,7 @@ pub const Response = struct { transfer_encoding: ResponseTransfer = .none, - server: *Server, + allocator: Allocator, address: net.Address, connection: BufferedConnection, @@ -376,7 +376,7 @@ pub const Response = struct { res.request.headers.deinit(); if (res.request.parser.header_bytes_owned) { - res.request.parser.header_bytes.deinit(res.server.allocator); + res.request.parser.header_bytes.deinit(res.allocator); } } @@ -545,13 +545,13 @@ pub const Response = struct { while (true) { try res.connection.fill(); - const nchecked = try res.request.parser.checkCompleteHead(res.server.allocator, res.connection.peek()); + const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek()); res.connection.clear(@intCast(u16, nchecked)); if (res.request.parser.state.isContent()) break; } - res.request.headers = .{ .allocator = res.server.allocator, .owned = true }; + res.request.headers = .{ .allocator = res.allocator, .owned = true }; try res.request.parse(res.request.parser.header_bytes.items); if (res.request.transfer_encoding) |te| { @@ -573,13 +573,13 @@ pub const Response = struct { if (res.request.transfer_compression) |tc| switch (tc) { .compress => return error.CompressionNotSupported, .deflate => res.request.compression = .{ - .deflate = std.compress.zlib.zlibStream(res.server.allocator, res.transferReader()) catch return error.CompressionInitializationFailed, + .deflate = std.compress.zlib.zlibStream(res.allocator, res.transferReader()) catch return error.CompressionInitializationFailed, }, .gzip => res.request.compression = .{ - .gzip = std.compress.gzip.decompress(res.server.allocator, res.transferReader()) catch return error.CompressionInitializationFailed, + .gzip = std.compress.gzip.decompress(res.allocator, res.transferReader()) catch return error.CompressionInitializationFailed, }, .zstd => res.request.compression = .{ - .zstd = std.compress.zstd.decompressStream(res.server.allocator, res.transferReader()), + .zstd = std.compress.zstd.decompressStream(res.allocator, res.transferReader()), }, }; } @@ -612,12 +612,12 @@ pub const Response = struct { while (!res.request.parser.state.isContent()) { // read trailing headers try res.connection.fill(); - const nchecked = try res.request.parser.checkCompleteHead(res.server.allocator, res.connection.peek()); + const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek()); res.connection.clear(@intCast(u16, nchecked)); } if (has_trail) { - res.request.headers = http.Headers{ .allocator = res.server.allocator, .owned = false }; + res.request.headers = http.Headers{ .allocator = res.allocator, .owned = false }; // The response headers before the trailers are already guaranteed to be valid, so they will always be parsed again and cannot return an error. // This will *only* fail for a malformed trailer. @@ -731,24 +731,29 @@ pub const HeaderStrategy = union(enum) { static: []u8, }; +pub const AcceptOptions = struct { + allocator: Allocator, + header_strategy: HeaderStrategy = .{ .dynamic = 8192 }, +}; + /// Accept a new connection. -pub fn accept(server: *Server, options: HeaderStrategy) AcceptError!Response { +pub fn accept(server: *Server, options: AcceptOptions) AcceptError!Response { const in = try server.socket.accept(); return Response{ - .server = server, + .allocator = options.allocator, .address = in.address, .connection = .{ .conn = .{ .stream = in.stream, .protocol = .plain, } }, - .headers = .{ .allocator = server.allocator }, + .headers = .{ .allocator = options.allocator }, .request = .{ .version = undefined, .method = undefined, .target = undefined, - .headers = .{ .allocator = server.allocator, .owned = false }, - .parser = switch (options) { + .headers = .{ .allocator = options.allocator, .owned = false }, + .parser = switch (options.header_strategy) { .dynamic => |max| proto.HeadersParser.initDynamic(max), .static => |buf| proto.HeadersParser.initStatic(buf), }, diff --git a/test/standalone/http.zig b/test/standalone/http.zig index 493931a2ea..180e863bba 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -15,6 +15,8 @@ var gpa_client = std.heap.GeneralPurposeAllocator(.{}){}; const salloc = gpa_server.allocator(); const calloc = gpa_client.allocator(); +var server: Server = undefined; + fn handleRequest(res: *Server.Response) !void { const log = std.log.scoped(.server); @@ -89,7 +91,7 @@ fn handleRequest(res: *Server.Response) !void { } else if (mem.eql(u8, res.request.target, "/redirect/3")) { res.transfer_encoding = .chunked; - const location = try std.fmt.allocPrint(salloc, "http://127.0.0.1:{d}/redirect/2", .{res.server.socket.listen_address.getPort()}); + const location = try std.fmt.allocPrint(salloc, "http://127.0.0.1:{d}/redirect/2", .{server.socket.listen_address.getPort()}); defer salloc.free(location); res.status = .found; @@ -119,7 +121,10 @@ var handle_new_requests = true; fn runServer(srv: *Server) !void { outer: while (handle_new_requests) { - var res = try srv.accept(.{ .dynamic = max_header_size }); + var res = try srv.accept(.{ + .allocator = salloc, + .header_strategy = .{ .dynamic = max_header_size }, + }); defer res.deinit(); while (res.reset() != .closing) { @@ -162,7 +167,7 @@ pub fn main() !void { defer _ = gpa_client.deinit(); - var server = Server.init(salloc, .{ .reuse_address = true }); + server = Server.init(salloc, .{ .reuse_address = true }); const addr = std.net.Address.parseIp("127.0.0.1", 0) catch unreachable; try server.listen(addr); From 1b3ebfefd8d8fd05152d55c431f741880d1ce2a7 Mon Sep 17 00:00:00 2001 From: Nameless Date: Wed, 3 May 2023 14:34:10 -0500 Subject: [PATCH 6/7] fix keepalive and large buffered writes --- lib/std/http/Client.zig | 22 ++++---- lib/std/http/Server.zig | 23 ++++---- test/standalone/http.zig | 112 +++++++++++++++++++++++++++++++++++++-- 3 files changed, 129 insertions(+), 28 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 72f16b0e76..77f4c6c6c3 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -71,7 +71,7 @@ pub const ConnectionPool = struct { while (next) |node| : (next = node.prev) { if ((node.data.buffered.conn.protocol == .tls) != criteria.is_tls) continue; if (node.data.port != criteria.port) continue; - if (mem.eql(u8, node.data.host, criteria.host)) continue; + if (!mem.eql(u8, node.data.host, criteria.host)) continue; pool.acquireUnsafe(node); return node; @@ -317,32 +317,29 @@ pub const BufferedConnection = struct { } pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void { - if (bconn.write_buf.len - bconn.write_end <= buffer.len) { - @memcpy(bconn.write_buf[bconn.write_end..], buffer); + if (bconn.write_buf.len - bconn.write_end >= buffer.len) { + @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer); bconn.write_end += @intCast(u16, buffer.len); } else { - try bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); - bconn.write_end = 0; - + try bconn.flush(); try bconn.conn.writeAll(buffer); } } pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize { - if (bconn.write_buf.len - bconn.write_end <= buffer.len) { - @memcpy(bconn.write_buf[bconn.write_end..], buffer); + if (bconn.write_buf.len - bconn.write_end >= buffer.len) { + @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer); bconn.write_end += @intCast(u16, buffer.len); return buffer.len; } else { - try bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); - bconn.write_end = 0; - + try bconn.flush(); return try bconn.conn.write(buffer); } } pub fn flush(bconn: *BufferedConnection) WriteError!void { + defer bconn.write_end = 0; return bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); } @@ -720,12 +717,13 @@ pub const Request = struct { req.response.parser.done = true; } + // we default to using keep-alive if not provided const req_connection = req.headers.getFirstValue("connection"); const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); const res_connection = req.response.headers.getFirstValue("connection"); const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?); - if (req_keepalive and res_keepalive) { + if (res_keepalive and (req_keepalive or req_connection == null)) { req.connection.data.closing = false; } else { req.connection.data.closing = true; diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index d473366092..2d3032bcdb 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -161,32 +161,29 @@ pub const BufferedConnection = struct { } pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void { - if (bconn.write_buf.len - bconn.write_end <= buffer.len) { - @memcpy(bconn.write_buf[bconn.write_end..], buffer); + if (bconn.write_buf.len - bconn.write_end >= buffer.len) { + @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer); bconn.write_end += @intCast(u16, buffer.len); } else { - try bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); - bconn.write_end = 0; - + try bconn.flush(); try bconn.conn.writeAll(buffer); } } pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize { - if (bconn.write_buf.len - bconn.write_end <= buffer.len) { - @memcpy(bconn.write_buf[bconn.write_end..], buffer); + if (bconn.write_buf.len - bconn.write_end >= buffer.len) { + @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer); bconn.write_end += @intCast(u16, buffer.len); return buffer.len; } else { - try bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); - bconn.write_end = 0; - + try bconn.flush(); return try bconn.conn.write(buffer); } } pub fn flush(bconn: *BufferedConnection) WriteError!void { + defer bconn.write_end = 0; return bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); } @@ -397,12 +394,14 @@ pub const Response = struct { // A connection is only keep-alive if the Connection header is present and it's value is not "close". // The server and client must both agree + // + // do() defaults to using keep-alive if the client requests it. const res_connection = res.headers.getFirstValue("connection"); const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?); const req_connection = res.request.headers.getFirstValue("connection"); const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); - if (res_keepalive and req_keepalive) { + if (req_keepalive and (res_keepalive or res_connection == null)) { res.connection.conn.closing = false; } else { res.connection.conn.closing = true; @@ -424,7 +423,7 @@ pub const Response = struct { res.headers.clearRetainingCapacity(); - res.request.headers.clearRetainingCapacity(); + res.request.headers.clearAndFree(); // FIXME: figure out why `clearRetainingCapacity` causes a leak in hash_map here res.request.parser.reset(); res.request = Request{ diff --git a/test/standalone/http.zig b/test/standalone/http.zig index 180e863bba..13dc278b6d 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -9,8 +9,8 @@ const testing = std.testing; const max_header_size = 8192; -var gpa_server = std.heap.GeneralPurposeAllocator(.{}){}; -var gpa_client = std.heap.GeneralPurposeAllocator(.{}){}; +var gpa_server = std.heap.GeneralPurposeAllocator(.{ .stack_trace_frames = 12 }){}; +var gpa_client = std.heap.GeneralPurposeAllocator(.{ .stack_trace_frames = 12 }){}; const salloc = gpa_server.allocator(); const calloc = gpa_client.allocator(); @@ -44,6 +44,24 @@ fn handleRequest(res: *Server.Response) !void { try res.writeAll("World!\n"); try res.finish(); } + } else if (mem.startsWith(u8, res.request.target, "/large")) { + res.transfer_encoding = .{ .content_length = 14 * 1024 + 14 * 10 }; + + try res.do(); + + var i: u32 = 0; + while (i < 5) : (i += 1) { + try res.writeAll("Hello, World!\n"); + } + + try res.writeAll("Hello, World!\n" ** 1024); + + i = 0; + while (i < 5) : (i += 1) { + try res.writeAll("Hello, World!\n"); + } + + try res.finish(); } else if (mem.eql(u8, res.request.target, "/echo-content")) { try testing.expectEqualStrings("Hello, World!\n", body); try testing.expectEqualStrings("text/plain", res.request.headers.getFirstValue("content-type").?); @@ -68,6 +86,7 @@ fn handleRequest(res: *Server.Response) !void { try res.writeAll("World!\n"); // try res.finish(); try res.connection.writeAll("0\r\nX-Checksum: aaaa\r\n\r\n"); + try res.connection.flush(); } else if (mem.eql(u8, res.request.target, "/redirect/1")) { res.transfer_encoding = .chunked; @@ -177,8 +196,7 @@ pub fn main() !void { const server_thread = try std.Thread.spawn(.{}, serverThread, .{&server}); var client = Client{ .allocator = calloc }; - - defer client.deinit(); + // defer client.deinit(); handled below { // read content-length response var h = http.Headers{ .allocator = calloc }; @@ -202,6 +220,33 @@ pub fn main() !void { try testing.expectEqualStrings("text/plain", req.response.headers.getFirstValue("content-type").?); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + + { // read large content-length response + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/large", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.GET, uri, h, .{}); + defer req.deinit(); + + try req.start(); + try req.wait(); + + const body = try req.reader().readAllAlloc(calloc, 8192 * 1024); + defer calloc.free(body); + + try testing.expectEqual(@as(usize, 14 * 1024 + 14 * 10), body.len); + } + + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + { // send head request and not read chunked var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -225,6 +270,9 @@ pub fn main() !void { try testing.expectEqualStrings("14", req.response.headers.getFirstValue("content-length").?); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + { // read chunked response var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -247,6 +295,9 @@ pub fn main() !void { try testing.expectEqualStrings("text/plain", req.response.headers.getFirstValue("content-type").?); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + { // send head request and not read chunked var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -270,6 +321,9 @@ pub fn main() !void { try testing.expectEqualStrings("chunked", req.response.headers.getFirstValue("transfer-encoding").?); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + { // check trailing headers var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -292,6 +346,9 @@ pub fn main() !void { try testing.expectEqualStrings("aaaa", req.response.headers.getFirstValue("x-checksum").?); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + { // send content-length request var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -321,6 +378,36 @@ pub fn main() !void { try testing.expectEqualStrings("Hello, World!\n", body); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + + { // read content-length response with connection close + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + try h.append("connection", "close"); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.GET, uri, h, .{}); + defer req.deinit(); + + try req.start(); + try req.wait(); + + const body = try req.reader().readAllAlloc(calloc, 8192); + defer calloc.free(body); + + try testing.expectEqualStrings("Hello, World!\n", body); + try testing.expectEqualStrings("text/plain", req.response.headers.getFirstValue("content-type").?); + } + + // connection has been closed + try testing.expect(client.connection_pool.free_len == 0); + { // send chunked request var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -350,6 +437,9 @@ pub fn main() !void { try testing.expectEqualStrings("Hello, World!\n", body); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + { // relative redirect var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -371,6 +461,9 @@ pub fn main() !void { try testing.expectEqualStrings("Hello, World!\n", body); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + { // redirect from root var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -392,6 +485,9 @@ pub fn main() !void { try testing.expectEqualStrings("Hello, World!\n", body); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + { // absolute redirect var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -413,6 +509,9 @@ pub fn main() !void { try testing.expectEqualStrings("Hello, World!\n", body); } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + { // too many redirects var h = http.Headers{ .allocator = calloc }; defer h.deinit(); @@ -432,6 +531,11 @@ pub fn main() !void { }; } + // connection has been kept alive + try testing.expect(client.connection_pool.free_len == 1); + + client.deinit(); + killServer(server.socket.listen_address); server_thread.join(); } From 9017d758b96c1f296249670dffb774a598bc8598 Mon Sep 17 00:00:00 2001 From: Nameless Date: Sat, 6 May 2023 21:35:04 -0500 Subject: [PATCH 7/7] std.http: use larger read buffer to hit faster tls code --- lib/std/http/Client.zig | 2 +- lib/std/http/Server.zig | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 77f4c6c6c3..023bdd28bc 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -251,7 +251,7 @@ pub const Connection = struct { /// A buffered (and peekable) Connection. pub const BufferedConnection = struct { - pub const buffer_size = 0x2000; + pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; conn: Connection, read_buf: [buffer_size]u8 = undefined, diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 2d3032bcdb..6b5db6725f 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -95,7 +95,7 @@ pub const Connection = struct { /// A buffered (and peekable) Connection. pub const BufferedConnection = struct { - pub const buffer_size = 0x2000; + pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; conn: Connection, read_buf: [buffer_size]u8 = undefined,