std.http.Server: reimplement chunked uploading

* Uncouple std.http.ChunkParser from protocol.zig
* Fix receiveHead not passing leftover buffer through the header parser.
* Fix content-length read streaming

This implementation handles the final chunk length correctly rather than
"hoping" that the buffer already contains \r\n.
Andrew Kelley 2024-02-21 00:16:03 -07:00
parent a8958c99a9
commit b4b9f6aa4a
6 changed files with 299 additions and 180 deletions
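As a rough sketch of the newly decoupled API (nothing here is part of the commit; decodeChunked, its name, and its buffer handling are hypothetical), driving std.http.ChunkParser by hand might look like the following. The final chunk shows up explicitly as state == .data with chunk_len == 0, rather than being inferred from a trailing \r\n that may or may not already be buffered:

const std = @import("std");

// Hypothetical helper: decode a fully buffered chunked body into `out`.
// Assumes `out` is large enough; trailers after the final chunk are left
// to the caller.
fn decodeChunked(input: []const u8, out: []u8) error{HttpChunkInvalid}!usize {
    var p = std.http.ChunkParser.init;
    var in_i: usize = 0;
    var out_i: usize = 0;
    while (in_i < input.len) {
        // feed consumes only chunk framing: sizes, extensions, and CRLFs.
        in_i += p.feed(input[in_i..]);
        switch (p.state) {
            .invalid => return error.HttpChunkInvalid,
            .data => {
                // chunk_len == 0 here means the terminating "0\r\n" chunk.
                if (p.chunk_len == 0) return out_i;
                const n: usize = @intCast(@min(p.chunk_len, input.len - in_i));
                @memcpy(out[out_i..][0..n], input[in_i..][0..n]);
                in_i += n;
                out_i += n;
                p.chunk_len -= n;
                if (p.chunk_len == 0) p.state = .data_suffix;
            },
            else => {},
        }
    }
    return out_i;
}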

lib/std/http.zig

@@ -4,6 +4,7 @@ pub const Client = @import("http/Client.zig");
pub const Server = @import("http/Server.zig");
pub const protocol = @import("http/protocol.zig");
pub const HeadParser = @import("http/HeadParser.zig");
pub const ChunkParser = @import("http/ChunkParser.zig");
pub const Version = enum {
@"HTTP/1.0",
@@ -313,5 +314,6 @@ test {
_ = Server;
_ = Status;
_ = HeadParser;
_ = ChunkParser;
_ = @import("http/test.zig");
}

lib/std/http/ChunkParser.zig

@@ -0,0 +1,131 @@
//! Parser for transfer-encoding: chunked.
state: State,
chunk_len: u64,
pub const init: ChunkParser = .{
.state = .head_size,
.chunk_len = 0,
};
pub const State = enum {
head_size,
head_ext,
head_r,
data,
data_suffix,
data_suffix_r,
invalid,
};
/// Returns the number of bytes consumed by the chunk size. This is always
/// less than or equal to `bytes.len`.
///
/// After this function returns, `chunk_len` will contain the parsed chunk size
/// in bytes when `state` is `data`. Alternately, `state` may become `invalid`,
/// indicating a syntax error in the input stream.
///
/// If the amount returned is less than `bytes.len`, the parser is in the
/// `data` state and the first byte of the chunk is at `bytes[result]`.
///
/// Asserts `state` is neither `data` nor `invalid`.
pub fn feed(p: *ChunkParser, bytes: []const u8) usize {
for (bytes, 0..) |c, i| switch (p.state) {
.data_suffix => switch (c) {
'\r' => p.state = .data_suffix_r,
'\n' => p.state = .head_size,
else => {
p.state = .invalid;
return i;
},
},
.data_suffix_r => switch (c) {
'\n' => p.state = .head_size,
else => {
p.state = .invalid;
return i;
},
},
.head_size => {
const digit = switch (c) {
'0'...'9' => |b| b - '0',
'A'...'Z' => |b| b - 'A' + 10,
'a'...'z' => |b| b - 'a' + 10,
'\r' => {
p.state = .head_r;
continue;
},
'\n' => {
p.state = .data;
return i + 1;
},
else => {
p.state = .head_ext;
continue;
},
};
const new_len = p.chunk_len *% 16 +% digit;
if (new_len <= p.chunk_len and p.chunk_len != 0) {
p.state = .invalid;
return i;
}
p.chunk_len = new_len;
},
.head_ext => switch (c) {
'\r' => p.state = .head_r,
'\n' => {
p.state = .data;
return i + 1;
},
else => continue,
},
.head_r => switch (c) {
'\n' => {
p.state = .data;
return i + 1;
},
else => {
p.state = .invalid;
return i;
},
},
.data => unreachable,
.invalid => unreachable,
};
return bytes.len;
}
const ChunkParser = @This();
const std = @import("std");
test feed {
const testing = std.testing;
const data = "Ff\r\nf0f000 ; ext\n0\r\nffffffffffffffffffffffffffffffffffffffff\r\n";
var p = init;
const first = p.feed(data[0..]);
try testing.expectEqual(@as(u32, 4), first);
try testing.expectEqual(@as(u64, 0xff), p.chunk_len);
try testing.expectEqual(.data, p.state);
p = init;
const second = p.feed(data[first..]);
try testing.expectEqual(@as(u32, 13), second);
try testing.expectEqual(@as(u64, 0xf0f000), p.chunk_len);
try testing.expectEqual(.data, p.state);
p = init;
const third = p.feed(data[first + second ..]);
try testing.expectEqual(@as(u32, 3), third);
try testing.expectEqual(@as(u64, 0), p.chunk_len);
try testing.expectEqual(.data, p.state);
p = init;
const fourth = p.feed(data[first + second + third ..]);
try testing.expectEqual(@as(u32, 16), fourth);
try testing.expectEqual(@as(u64, 0xffffffffffffffff), p.chunk_len);
try testing.expectEqual(.invalid, p.state);
}
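Because all of the parser's progress lives in state and chunk_len, a chunk header split across reads parses the same as one delivered whole. A hypothetical extra test (not part of this commit) sketching that property:

test "feed resumes across reads" {
    const testing = std.testing;
    var p = init;
    // "1ab\r\n" delivered in three pieces; state carries over between calls.
    try testing.expectEqual(@as(usize, 2), p.feed("1a"));
    try testing.expectEqual(@as(usize, 2), p.feed("b\r"));
    try testing.expectEqual(@as(usize, 1), p.feed("\ndata"));
    try testing.expectEqual(@as(u64, 0x1ab), p.chunk_len);
    try testing.expectEqual(.data, p.state);
}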

lib/std/http/HeadParser.zig

@@ -1,3 +1,5 @@
//! Finds the end of an HTTP head in a stream.
state: State = .start,
pub const State = enum {
@@ -17,13 +19,12 @@ pub const State = enum {
/// `bytes[result]`.
pub fn feed(p: *HeadParser, bytes: []const u8) usize {
const vector_len: comptime_int = @max(std.simd.suggestVectorLength(u8) orelse 1, 8);
const len: u32 = @intCast(bytes.len);
var index: u32 = 0;
var index: usize = 0;
while (true) {
switch (p.state) {
.finished => return index,
.start => switch (len - index) {
.start => switch (bytes.len - index) {
0 => return index,
1 => {
switch (bytes[index]) {
@@ -218,7 +219,7 @@ pub fn feed(p: *HeadParser, bytes: []const u8) usize {
continue;
},
},
.seen_n => switch (len - index) {
.seen_n => switch (bytes.len - index) {
0 => return index,
else => {
switch (bytes[index]) {
@@ -230,7 +231,7 @@ pub fn feed(p: *HeadParser, bytes: []const u8) usize {
continue;
},
},
.seen_r => switch (len - index) {
.seen_r => switch (bytes.len - index) {
0 => return index,
1 => {
switch (bytes[index]) {
@@ -286,7 +287,7 @@ pub fn feed(p: *HeadParser, bytes: []const u8) usize {
continue;
},
},
.seen_rn => switch (len - index) {
.seen_rn => switch (bytes.len - index) {
0 => return index,
1 => {
switch (bytes[index]) {
@@ -317,7 +318,7 @@ pub fn feed(p: *HeadParser, bytes: []const u8) usize {
continue;
},
},
.seen_rnr => switch (len - index) {
.seen_rnr => switch (bytes.len - index) {
0 => return index,
else => {
switch (bytes[index]) {

lib/std/http/Server.zig

@@ -1,4 +1,5 @@
//! Blocking HTTP server implementation.
//! Handles a single connection's lifecycle.
connection: net.Server.Connection,
/// Keeps track of whether the Server is ready to accept a new request on the
@@ -62,20 +63,19 @@ pub fn receiveHead(s: *Server) ReceiveHeadError!Request {
// In case of a reused connection, move the next request's bytes to the
// beginning of the buffer.
if (s.next_request_start > 0) {
if (s.read_buffer_len > s.next_request_start) {
const leftover = s.read_buffer[s.next_request_start..s.read_buffer_len];
const dest = s.read_buffer[0..leftover.len];
if (leftover.len <= s.next_request_start) {
@memcpy(dest, leftover);
} else {
mem.copyBackwards(u8, dest, leftover);
}
s.read_buffer_len = leftover.len;
}
if (s.read_buffer_len > s.next_request_start) rebase(s, 0);
s.next_request_start = 0;
}
var hp: http.HeadParser = .{};
if (s.read_buffer_len > 0) {
const bytes = s.read_buffer[0..s.read_buffer_len];
const end = hp.feed(bytes);
if (hp.state == .finished)
return finishReceivingHead(s, end);
}
while (true) {
const buf = s.read_buffer[s.read_buffer_len..];
if (buf.len == 0)
@@ -85,16 +85,21 @@ pub fn receiveHead(s: *Server) ReceiveHeadError!Request {
s.read_buffer_len += read_n;
const bytes = buf[0..read_n];
const end = hp.feed(bytes);
if (hp.state == .finished) return .{
.server = s,
.head_end = end,
.head = Request.Head.parse(s.read_buffer[0..end]) catch
return error.HttpHeadersInvalid,
.reader_state = undefined,
};
if (hp.state == .finished)
return finishReceivingHead(s, s.read_buffer_len - bytes.len + end);
}
}
fn finishReceivingHead(s: *Server, head_end: usize) ReceiveHeadError!Request {
return .{
.server = s,
.head_end = head_end,
.head = Request.Head.parse(s.read_buffer[0..head_end]) catch
return error.HttpHeadersInvalid,
.reader_state = undefined,
};
}
pub const Request = struct {
server: *Server,
/// Index into Server's read_buffer.
@@ -102,6 +107,7 @@ pub const Request = struct {
head: Head,
reader_state: union {
remaining_content_length: u64,
chunk_parser: http.ChunkParser,
},
pub const Compression = union(enum) {
@@ -416,51 +422,130 @@ pub const Request = struct {
};
}
pub const ReadError = net.Stream.ReadError;
pub const ReadError = net.Stream.ReadError || error{ HttpChunkInvalid, HttpHeadersOversize };
fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize {
const request: *Request = @constCast(@alignCast(@ptrCast(context)));
const s = request.server;
assert(s.state == .receiving_body);
const remaining_content_length = &request.reader_state.remaining_content_length;
if (remaining_content_length.* == 0) {
s.state = .ready;
return 0;
}
const available_bytes = s.read_buffer_len - request.head_end;
if (available_bytes == 0)
s.read_buffer_len += try s.connection.stream.read(s.read_buffer[request.head_end..]);
const available_buf = s.read_buffer[request.head_end..s.read_buffer_len];
const len = @min(remaining_content_length.*, available_buf.len, buffer.len);
@memcpy(buffer[0..len], available_buf[0..len]);
const available = try fill(s, request.head_end);
const len = @min(remaining_content_length.*, available.len, buffer.len);
@memcpy(buffer[0..len], available[0..len]);
remaining_content_length.* -= len;
s.next_request_start += len;
if (remaining_content_length.* == 0)
s.state = .ready;
return len;
}
fn fill(s: *Server, head_end: usize) ReadError![]u8 {
const available = s.read_buffer[s.next_request_start..s.read_buffer_len];
if (available.len > 0) return available;
s.next_request_start = head_end;
s.read_buffer_len = head_end + try s.connection.stream.read(s.read_buffer[head_end..]);
return s.read_buffer[head_end..s.read_buffer_len];
}
fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize {
const request: *Request = @constCast(@alignCast(@ptrCast(context)));
const s = request.server;
assert(s.state == .receiving_body);
_ = buffer;
@panic("TODO");
}
pub const ReadAllError = ReadError || error{HttpBodyOversize};
const cp = &request.reader_state.chunk_parser;
const head_end = request.head_end;
// Protect against returning 0 before the end of stream.
var out_end: usize = 0;
while (out_end == 0) {
switch (cp.state) {
.invalid => return 0,
.data => {
const available = try fill(s, head_end);
const len = @min(cp.chunk_len, available.len, buffer.len);
@memcpy(buffer[0..len], available[0..len]);
cp.chunk_len -= len;
if (cp.chunk_len == 0)
cp.state = .data_suffix;
out_end += len;
s.next_request_start += len;
continue;
},
else => {
const available = try fill(s, head_end);
const n = cp.feed(available);
switch (cp.state) {
.invalid => return error.HttpChunkInvalid,
.data => {
if (cp.chunk_len == 0) {
// The next bytes in the stream are trailers,
// or \r\n to indicate end of chunked body.
//
// This function must append the trailers at
// head_end so that headers and trailers are
// together.
//
// Since returning 0 would indicate end of
// stream, this function must read all the
// trailers before returning.
if (s.next_request_start > head_end) rebase(s, head_end);
var hp: http.HeadParser = .{};
{
const bytes = s.read_buffer[head_end..s.read_buffer_len];
const end = hp.feed(bytes);
if (hp.state == .finished) {
s.next_request_start = s.read_buffer_len - bytes.len + end;
return out_end;
}
}
while (true) {
const buf = s.read_buffer[s.read_buffer_len..];
if (buf.len == 0)
return error.HttpHeadersOversize;
const read_n = try s.connection.stream.read(buf);
s.read_buffer_len += read_n;
const bytes = buf[0..read_n];
const end = hp.feed(bytes);
if (hp.state == .finished) {
s.next_request_start = s.read_buffer_len - bytes.len + end;
return out_end;
}
}
}
const data = available[n..];
const len = @min(cp.chunk_len, data.len, buffer.len);
@memcpy(buffer[0..len], data[0..len]);
cp.chunk_len -= len;
if (cp.chunk_len == 0)
cp.state = .data_suffix;
out_end += len;
s.next_request_start += n + len;
continue;
},
else => continue,
}
},
}
}
return out_end;
}
pub fn reader(request: *Request) std.io.AnyReader {
const s = request.server;
assert(s.state == .received_head);
s.state = .receiving_body;
s.next_request_start = request.head_end;
switch (request.head.transfer_encoding) {
.chunked => return .{
.readFn = read_chunked,
.context = request,
.chunked => {
request.reader_state = .{ .chunk_parser = http.ChunkParser.init };
return .{
.readFn = read_chunked,
.context = request,
};
},
.none => {
request.reader_state = .{
@@ -489,31 +574,8 @@ pub const Request = struct {
const s = request.server;
if (keep_alive and request.head.keep_alive) switch (s.state) {
.received_head => {
s.state = .receiving_body;
switch (request.head.transfer_encoding) {
.none => t: {
const len = request.head.content_length orelse break :t;
const head_end = request.head_end;
var total_body_discarded: usize = 0;
while (true) {
const available_bytes = s.read_buffer_len - head_end;
const remaining_len = len - total_body_discarded;
if (available_bytes >= remaining_len) {
s.next_request_start = head_end + remaining_len;
break :t;
}
total_body_discarded += available_bytes;
// Preserve request header memory until receiveHead is called.
const buf = s.read_buffer[head_end..];
const read_n = s.connection.stream.read(buf) catch return false;
s.read_buffer_len = head_end + read_n;
}
},
.chunked => {
@panic("TODO");
},
}
s.state = .ready;
_ = request.reader().discard() catch return false;
assert(s.state == .ready);
return true;
},
.receiving_body, .ready => return true,
@@ -799,6 +861,17 @@ pub const Response = struct {
}
};
fn rebase(s: *Server, index: usize) void {
const leftover = s.read_buffer[s.next_request_start..s.read_buffer_len];
const dest = s.read_buffer[index..][0..leftover.len];
if (leftover.len <= s.next_request_start - index) {
@memcpy(dest, leftover);
} else {
mem.copyForwards(u8, dest, leftover);
}
s.read_buffer_len = index + leftover.len;
}
const std = @import("../std.zig");
const http = std.http;
const mem = std.mem;
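A minimal, hypothetical handler (not in the diff, assuming the usual std import and the 0.12-era Server.init/receiveHead/respond entry points) showing the path these functions back: reader() picks read_cl or read_chunked based on the request's transfer encoding, so the caller's code is the same either way. serveOne, the buffer sizes, and the response body are illustrative; error handling and connection close are omitted.

fn serveOne(connection: std.net.Server.Connection) !void {
    var read_buffer: [8000]u8 = undefined;
    var server = std.http.Server.init(connection, &read_buffer);
    var request = try server.receiveHead();

    // Drain the body; for a chunked request this exercises read_chunked,
    // including the final-chunk and trailer handling above.
    var body_buffer: [1024]u8 = undefined;
    const body_len = try request.reader().readAll(&body_buffer);

    try request.respond(body_buffer[0..body_len], .{});
}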

lib/std/http/protocol.zig

@@ -97,85 +97,32 @@ pub const HeadersParser = struct {
return @intCast(result);
}
/// Returns the number of bytes consumed by the chunk size. This is always
/// less than or equal to `bytes.len`.
/// You should check `r.state == .chunk_data` after this to check if the
/// chunk size has been fully parsed.
///
/// If the amount returned is less than `bytes.len`, you may assume that
/// the parser is in the `chunk_data` state and that the first byte of the
/// chunk is at `bytes[result]`.
pub fn findChunkedLen(r: *HeadersParser, bytes: []const u8) u32 {
const len = @as(u32, @intCast(bytes.len));
for (bytes[0..], 0..) |c, i| {
const index = @as(u32, @intCast(i));
switch (r.state) {
.chunk_data_suffix => switch (c) {
'\r' => r.state = .chunk_data_suffix_r,
'\n' => r.state = .chunk_head_size,
else => {
r.state = .invalid;
return index;
},
},
.chunk_data_suffix_r => switch (c) {
'\n' => r.state = .chunk_head_size,
else => {
r.state = .invalid;
return index;
},
},
.chunk_head_size => {
const digit = switch (c) {
'0'...'9' => |b| b - '0',
'A'...'Z' => |b| b - 'A' + 10,
'a'...'z' => |b| b - 'a' + 10,
'\r' => {
r.state = .chunk_head_r;
continue;
},
'\n' => {
r.state = .chunk_data;
return index + 1;
},
else => {
r.state = .chunk_head_ext;
continue;
},
};
const new_len = r.next_chunk_length *% 16 +% digit;
if (new_len <= r.next_chunk_length and r.next_chunk_length != 0) {
r.state = .invalid;
return index;
}
r.next_chunk_length = new_len;
},
.chunk_head_ext => switch (c) {
'\r' => r.state = .chunk_head_r,
'\n' => {
r.state = .chunk_data;
return index + 1;
},
else => continue,
},
.chunk_head_r => switch (c) {
'\n' => {
r.state = .chunk_data;
return index + 1;
},
else => {
r.state = .invalid;
return index;
},
},
var cp: std.http.ChunkParser = .{
.state = switch (r.state) {
.chunk_head_size => .head_size,
.chunk_head_ext => .head_ext,
.chunk_head_r => .head_r,
.chunk_data => .data,
.chunk_data_suffix => .data_suffix,
.chunk_data_suffix_r => .data_suffix_r,
.invalid => .invalid,
else => unreachable,
}
}
return len;
},
.chunk_len = r.next_chunk_length,
};
const result = cp.feed(bytes);
r.state = switch (cp.state) {
.head_size => .chunk_head_size,
.head_ext => .chunk_head_ext,
.head_r => .chunk_head_r,
.data => .chunk_data,
.data_suffix => .chunk_data_suffix,
.data_suffix_r => .chunk_data_suffix_r,
.invalid => .invalid,
};
r.next_chunk_length = cp.chunk_len;
return @intCast(result);
}
/// Returns whether or not the parser has finished parsing a complete
@@ -464,41 +411,6 @@ const MockBufferedConnection = struct {
}
};
test "HeadersParser.findChunkedLen" {
var r: HeadersParser = undefined;
const data = "Ff\r\nf0f000 ; ext\n0\r\nffffffffffffffffffffffffffffffffffffffff\r\n";
r = HeadersParser.init(&.{});
r.state = .chunk_head_size;
r.next_chunk_length = 0;
const first = r.findChunkedLen(data[0..]);
try testing.expectEqual(@as(u32, 4), first);
try testing.expectEqual(@as(u64, 0xff), r.next_chunk_length);
try testing.expectEqual(State.chunk_data, r.state);
r.state = .chunk_head_size;
r.next_chunk_length = 0;
const second = r.findChunkedLen(data[first..]);
try testing.expectEqual(@as(u32, 13), second);
try testing.expectEqual(@as(u64, 0xf0f000), r.next_chunk_length);
try testing.expectEqual(State.chunk_data, r.state);
r.state = .chunk_head_size;
r.next_chunk_length = 0;
const third = r.findChunkedLen(data[first + second ..]);
try testing.expectEqual(@as(u32, 3), third);
try testing.expectEqual(@as(u64, 0), r.next_chunk_length);
try testing.expectEqual(State.chunk_data, r.state);
r.state = .chunk_head_size;
r.next_chunk_length = 0;
const fourth = r.findChunkedLen(data[first + second + third ..]);
try testing.expectEqual(@as(u32, 16), fourth);
try testing.expectEqual(@as(u64, 0xffffffffffffffff), r.next_chunk_length);
try testing.expectEqual(State.invalid, r.state);
}
test "HeadersParser.read length" {
// mock BufferedConnection for read
var headers_buf: [256]u8 = undefined;

lib/std/http/test.zig

@@ -164,7 +164,7 @@ test "HTTP server handles a chunked transfer coding request" {
const stream = try std.net.tcpConnectToHost(allocator, "127.0.0.1", server_port);
defer stream.close();
_ = try stream.writeAll(request_bytes[0..]);
try stream.writeAll(request_bytes);
server_thread.join();
}
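The test's request_bytes are elided above; for reference, a chunked request of the kind this server now decodes looks roughly like the following (illustrative bytes only, not the test's actual data):

// Two data chunks, then the zero-length final chunk and terminating CRLF.
const example_request =
    "POST /upload HTTP/1.1\r\n" ++
    "transfer-encoding: chunked\r\n" ++
    "\r\n" ++
    "5\r\nHello\r\n" ++
    "7\r\n, world\r\n" ++
    "0\r\n" ++
    "\r\n";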