diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index bc284f5517..9c865bf604 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -474,7 +474,10 @@ pub const Request = struct { }) catch unreachable; if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); - if (options.content_length) |len| { + if (o.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { + .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"), + .none => {}, + } else if (options.content_length) |len| { h.fixedWriter().print("content-length: {d}\r\n", .{len}) catch unreachable; } else { h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"); @@ -496,7 +499,12 @@ pub const Request = struct { .send_buffer = options.send_buffer, .send_buffer_start = 0, .send_buffer_end = h.items.len, - .content_length = options.content_length, + .transfer_encoding = if (o.transfer_encoding) |te| switch (te) { + .chunked => .chunked, + .none => .none, + } else if (options.content_length) |len| .{ + .content_length = len, + } else .chunked, .elide_body = elide_body, .chunk_len = 0, }; @@ -709,12 +717,21 @@ pub const Response = struct { send_buffer_end: usize, /// `null` means transfer-encoding: chunked. /// As a debugging utility, counts down to zero as bytes are written. - content_length: ?u64, + transfer_encoding: TransferEncoding, elide_body: bool, /// Indicates how much of the end of the `send_buffer` corresponds to a /// chunk. This amount of data will be wrapped by an HTTP chunk header. chunk_len: usize, + pub const TransferEncoding = union(enum) { + /// End of connection signals the end of the stream. + none, + /// As a debugging utility, counts down to zero as bytes are written. + content_length: u64, + /// Each chunk is wrapped in a header and trailer. + chunked, + }; + pub const WriteError = net.Stream.WriteError; /// When using content-length, asserts that the amount of data sent matches @@ -723,11 +740,17 @@ pub const Response = struct { /// end-of-stream message, then flushes the stream to the system. /// Respects the value of `elide_body` to omit all data after the headers. pub fn end(r: *Response) WriteError!void { - if (r.content_length) |len| { - assert(len == 0); // Trips when end() called before all bytes written. - try flush_cl(r); - } else { - try flush_chunked(r, &.{}); + switch (r.transfer_encoding) { + .content_length => |len| { + assert(len == 0); // Trips when end() called before all bytes written. + try flush_cl(r); + }, + .none => { + try flush_cl(r); + }, + .chunked => { + try flush_chunked(r, &.{}); + }, } r.* = undefined; } @@ -752,16 +775,21 @@ pub const Response = struct { /// May return 0, which does not indicate end of stream. The caller decides /// when the end of stream occurs by calling `end`. pub fn write(r: *Response, bytes: []const u8) WriteError!usize { - if (r.content_length != null) { - return write_cl(r, bytes); - } else { - return write_chunked(r, bytes); + switch (r.transfer_encoding) { + .content_length, .none => return write_cl(r, bytes), + .chunked => return write_chunked(r, bytes), } } fn write_cl(context: *const anyopaque, bytes: []const u8) WriteError!usize { const r: *Response = @constCast(@alignCast(@ptrCast(context))); - const len = &r.content_length.?; + + var trash: u64 = std.math.maxInt(u64); + const len = switch (r.transfer_encoding) { + .content_length => |*len| len, + else => &trash, + }; + if (r.elide_body) { len.* -= bytes.len; return bytes.len; @@ -805,7 +833,7 @@ pub const Response = struct { fn write_chunked(context: *const anyopaque, bytes: []const u8) WriteError!usize { const r: *Response = @constCast(@alignCast(@ptrCast(context))); - assert(r.content_length == null); + assert(r.transfer_encoding == .chunked); if (r.elide_body) return bytes.len; @@ -867,15 +895,13 @@ pub const Response = struct { /// This is redundant after calling `end`. /// Respects the value of `elide_body` to omit all data after the headers. pub fn flush(r: *Response) WriteError!void { - if (r.content_length != null) { - return flush_cl(r); - } else { - return flush_chunked(r, null); + switch (r.transfer_encoding) { + .none, .content_length => return flush_cl(r), + .chunked => return flush_chunked(r, null), } } fn flush_cl(r: *Response) WriteError!void { - assert(r.content_length != null); try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]); r.send_buffer_start = 0; r.send_buffer_end = 0; @@ -884,7 +910,7 @@ pub const Response = struct { fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) WriteError!void { const max_trailers = 25; if (end_trailers) |trailers| assert(trailers.len <= max_trailers); - assert(r.content_length == null); + assert(r.transfer_encoding == .chunked); const http_headers = r.send_buffer[r.send_buffer_start .. r.send_buffer_end - r.chunk_len]; @@ -976,7 +1002,10 @@ pub const Response = struct { pub fn writer(r: *Response) std.io.AnyWriter { return .{ - .writeFn = if (r.content_length != null) write_cl else write_chunked, + .writeFn = switch (r.transfer_encoding) { + .none, .content_length => write_cl, + .chunked => write_chunked, + }, .context = r, }; } diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index 5a2189f7f4..a1f82dc892 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -222,6 +222,72 @@ test "echo content server" { } } +test "Server.Request.respondStreaming non-chunked, unknown content-length" { + // In this case, the response is expected to stream until the connection is + // closed, indicating the end of the body. + const test_server = try createTestServer(struct { + fn run(net_server: *std.net.Server) anyerror!void { + var header_buffer: [1000]u8 = undefined; + var remaining: usize = 1; + while (remaining != 0) : (remaining -= 1) { + const conn = try net_server.accept(); + defer conn.stream.close(); + + var server = std.http.Server.init(conn, &header_buffer); + + try expectEqual(.ready, server.state); + var request = try server.receiveHead(); + try expectEqualStrings(request.head.target, "/foo"); + var send_buffer: [500]u8 = undefined; + var response = request.respondStreaming(.{ + .send_buffer = &send_buffer, + .respond_options = .{ + .transfer_encoding = .none, + }, + }); + var total: usize = 0; + for (0..500) |i| { + var buf: [30]u8 = undefined; + const line = try std.fmt.bufPrint(&buf, "{d}, ah ha ha!\n", .{i}); + try response.writeAll(line); + total += line.len; + } + try expectEqual(7390, total); + try response.end(); + try expectEqual(.closing, server.state); + } + } + }); + defer test_server.destroy(); + + const request_bytes = "GET /foo HTTP/1.1\r\n\r\n"; + const gpa = std.testing.allocator; + const stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + defer stream.close(); + try stream.writeAll(request_bytes); + + const response = try stream.reader().readAllAlloc(gpa, 8192); + defer gpa.free(response); + + var expected_response = std.ArrayList(u8).init(gpa); + defer expected_response.deinit(); + + try expected_response.appendSlice("HTTP/1.1 200 OK\r\n\r\n"); + + { + var total: usize = 0; + for (0..500) |i| { + var buf: [30]u8 = undefined; + const line = try std.fmt.bufPrint(&buf, "{d}, ah ha ha!\n", .{i}); + try expected_response.appendSlice(line); + total += line.len; + } + try expectEqual(7390, total); + } + + try expectEqualStrings(expected_response.items, response); +} + fn echoTests(client: *std.http.Client, port: u16) !void { const gpa = std.testing.allocator; var location_buffer: [100]u8 = undefined;