From b4b9f6aa4a5bfd6a54b59444f3e1a3706358eb76 Mon Sep 17 00:00:00 2001
From: Andrew Kelley
Date: Wed, 21 Feb 2024 00:16:03 -0700
Subject: [PATCH] std.http.Server: reimplement chunked uploading

* Uncouple std.http.ChunkParser from protocol.zig.
* Fix receiveHead not passing leftover buffer through the header parser.
* Fix content-length read streaming.

This implementation handles the final chunk length correctly rather than
"hoping" that the buffer already contains \r\n.
---
 lib/std/http.zig             |   2 +
 lib/std/http/ChunkParser.zig | 131 ++++++++++++++++++++++++
 lib/std/http/HeadParser.zig  |  15 +--
 lib/std/http/Server.zig      | 193 ++++++++++++++++++++++++-----------
 lib/std/http/protocol.zig    | 136 +++++-------------------
 lib/std/http/test.zig        |   2 +-
 6 files changed, 299 insertions(+), 180 deletions(-)
 create mode 100644 lib/std/http/ChunkParser.zig

diff --git a/lib/std/http.zig b/lib/std/http.zig
index 613abc66b2..5898b39da8 100644
--- a/lib/std/http.zig
+++ b/lib/std/http.zig
@@ -4,6 +4,7 @@ pub const Client = @import("http/Client.zig");
 pub const Server = @import("http/Server.zig");
 pub const protocol = @import("http/protocol.zig");
 pub const HeadParser = @import("http/HeadParser.zig");
+pub const ChunkParser = @import("http/ChunkParser.zig");
 
 pub const Version = enum {
     @"HTTP/1.0",
@@ -313,5 +314,6 @@ test {
     _ = Server;
     _ = Status;
     _ = HeadParser;
+    _ = ChunkParser;
     _ = @import("http/test.zig");
 }
diff --git a/lib/std/http/ChunkParser.zig b/lib/std/http/ChunkParser.zig
new file mode 100644
index 0000000000..adcdc74bc7
--- /dev/null
+++ b/lib/std/http/ChunkParser.zig
@@ -0,0 +1,131 @@
+//! Parser for transfer-encoding: chunked.
+
+state: State,
+chunk_len: u64,
+
+pub const init: ChunkParser = .{
+    .state = .head_size,
+    .chunk_len = 0,
+};
+
+pub const State = enum {
+    head_size,
+    head_ext,
+    head_r,
+    data,
+    data_suffix,
+    data_suffix_r,
+    invalid,
+};
+
+/// Returns the number of bytes consumed by the chunk size. This is always
+/// less than or equal to `bytes.len`.
+///
+/// After this function returns, `chunk_len` will contain the parsed chunk size
+/// in bytes when `state` is `data`. Alternatively, `state` may become
+/// `invalid`, indicating a syntax error in the input stream.
+///
+/// If the amount returned is less than `bytes.len` and `state` is `data`,
+/// the first byte of the chunk is at `bytes[result]`.
+///
+/// Asserts `state` is neither `data` nor `invalid`.
+pub fn feed(p: *ChunkParser, bytes: []const u8) usize {
+    for (bytes, 0..) |c, i| switch (p.state) {
+        .data_suffix => switch (c) {
+            '\r' => p.state = .data_suffix_r,
+            '\n' => p.state = .head_size,
+            else => {
+                p.state = .invalid;
+                return i;
+            },
+        },
+        .data_suffix_r => switch (c) {
+            '\n' => p.state = .head_size,
+            else => {
+                p.state = .invalid;
+                return i;
+            },
+        },
+        .head_size => {
+            const digit = switch (c) {
+                '0'...'9' => |b| b - '0',
+                'A'...'Z' => |b| b - 'A' + 10,
+                'a'...'z' => |b| b - 'a' + 10,
+                '\r' => {
+                    p.state = .head_r;
+                    continue;
+                },
+                '\n' => {
+                    p.state = .data;
+                    return i + 1;
+                },
+                else => {
+                    p.state = .head_ext;
+                    continue;
+                },
+            };
+
+            const new_len = p.chunk_len *% 16 +% digit;
+            if (new_len <= p.chunk_len and p.chunk_len != 0) {
+                p.state = .invalid;
+                return i;
+            }
+
+            p.chunk_len = new_len;
+        },
+        .head_ext => switch (c) {
+            '\r' => p.state = .head_r,
+            '\n' => {
+                p.state = .data;
+                return i + 1;
+            },
+            else => continue,
+        },
+        .head_r => switch (c) {
+            '\n' => {
+                p.state = .data;
+                return i + 1;
+            },
+            else => {
+                p.state = .invalid;
+                return i;
+            },
+        },
+        .data => unreachable,
+        .invalid => unreachable,
+    };
+    return bytes.len;
+}
+
+const ChunkParser = @This();
+const std = @import("std");
+
+test feed {
+    const testing = std.testing;
+
+    const data = "Ff\r\nf0f000 ; ext\n0\r\nffffffffffffffffffffffffffffffffffffffff\r\n";
+
+    var p = init;
+    const first = p.feed(data[0..]);
+    try testing.expectEqual(@as(u32, 4), first);
+    try testing.expectEqual(@as(u64, 0xff), p.chunk_len);
+    try testing.expectEqual(.data, p.state);
+
+    p = init;
+    const second = p.feed(data[first..]);
+    try testing.expectEqual(@as(u32, 13), second);
+    try testing.expectEqual(@as(u64, 0xf0f000), p.chunk_len);
+    try testing.expectEqual(.data, p.state);
+
+    p = init;
+    const third = p.feed(data[first + second ..]);
+    try testing.expectEqual(@as(u32, 3), third);
+    try testing.expectEqual(@as(u64, 0), p.chunk_len);
+    try testing.expectEqual(.data, p.state);
+
+    p = init;
+    const fourth = p.feed(data[first + second + third ..]);
+    try testing.expectEqual(@as(u32, 16), fourth);
+    try testing.expectEqual(@as(u64, 0xffffffffffffffff), p.chunk_len);
+    try testing.expectEqual(.invalid, p.state);
+}
diff --git a/lib/std/http/HeadParser.zig b/lib/std/http/HeadParser.zig
index 07c357731a..bb49faa14b 100644
--- a/lib/std/http/HeadParser.zig
+++ b/lib/std/http/HeadParser.zig
@@ -1,3 +1,5 @@
+//! Finds the end of an HTTP head in a stream.
+
 state: State = .start,
 
 pub const State = enum {
@@ -17,13 +19,12 @@ pub const State = enum {
 /// `bytes[result]`.
 pub fn feed(p: *HeadParser, bytes: []const u8) usize {
     const vector_len: comptime_int = @max(std.simd.suggestVectorLength(u8) orelse 1, 8);
-    const len: u32 = @intCast(bytes.len);
-    var index: u32 = 0;
+    var index: usize = 0;
 
     while (true) {
         switch (p.state) {
             .finished => return index,
-            .start => switch (len - index) {
+            .start => switch (bytes.len - index) {
                 0 => return index,
                 1 => {
                     switch (bytes[index]) {
@@ -218,7 +219,7 @@ pub fn feed(p: *HeadParser, bytes: []const u8) usize {
                     continue;
                 },
             },
-            .seen_n => switch (len - index) {
+            .seen_n => switch (bytes.len - index) {
                 0 => return index,
                 else => {
                     switch (bytes[index]) {
@@ -230,7 +231,7 @@ pub fn feed(p: *HeadParser, bytes: []const u8) usize {
                     continue;
                 },
             },
-            .seen_r => switch (len - index) {
+            .seen_r => switch (bytes.len - index) {
                 0 => return index,
                 1 => {
                     switch (bytes[index]) {
@@ -286,7 +287,7 @@ pub fn feed(p: *HeadParser, bytes: []const u8) usize {
                     continue;
                 },
             },
-            .seen_rn => switch (len - index) {
+            .seen_rn => switch (bytes.len - index) {
                 0 => return index,
                 1 => {
                     switch (bytes[index]) {
@@ -317,7 +318,7 @@ pub fn feed(p: *HeadParser, bytes: []const u8) usize {
                     continue;
                 },
             },
-            .seen_rnr => switch (len - index) {
+            .seen_rnr => switch (bytes.len - index) {
                 0 => return index,
                 else => {
                     switch (bytes[index]) {
diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig
index 73289d713f..e19f6302de 100644
--- a/lib/std/http/Server.zig
+++ b/lib/std/http/Server.zig
@@ -1,4 +1,5 @@
 //! Blocking HTTP server implementation.
+//! Handles a single connection's lifecycle.
 
 connection: net.Server.Connection,
 /// Keeps track of whether the Server is ready to accept a new request on the
@@ -62,20 +63,19 @@ pub fn receiveHead(s: *Server) ReceiveHeadError!Request {
     // In case of a reused connection, move the next request's bytes to the
     // beginning of the buffer.
     if (s.next_request_start > 0) {
-        if (s.read_buffer_len > s.next_request_start) {
-            const leftover = s.read_buffer[s.next_request_start..s.read_buffer_len];
-            const dest = s.read_buffer[0..leftover.len];
-            if (leftover.len <= s.next_request_start) {
-                @memcpy(dest, leftover);
-            } else {
-                mem.copyBackwards(u8, dest, leftover);
-            }
-            s.read_buffer_len = leftover.len;
-        }
+        if (s.read_buffer_len > s.next_request_start) rebase(s, 0);
         s.next_request_start = 0;
     }
 
     var hp: http.HeadParser = .{};
+
+    if (s.read_buffer_len > 0) {
+        const bytes = s.read_buffer[0..s.read_buffer_len];
+        const end = hp.feed(bytes);
+        if (hp.state == .finished)
+            return finishReceivingHead(s, end);
+    }
+
     while (true) {
         const buf = s.read_buffer[s.read_buffer_len..];
         if (buf.len == 0)
@@ -85,16 +85,21 @@
         s.read_buffer_len += read_n;
         const bytes = buf[0..read_n];
         const end = hp.feed(bytes);
-        if (hp.state == .finished) return .{
-            .server = s,
-            .head_end = end,
-            .head = Request.Head.parse(s.read_buffer[0..end]) catch
-                return error.HttpHeadersInvalid,
-            .reader_state = undefined,
-        };
+        if (hp.state == .finished)
+            return finishReceivingHead(s, s.read_buffer_len - bytes.len + end);
     }
 }
 
+fn finishReceivingHead(s: *Server, head_end: usize) ReceiveHeadError!Request {
+    return .{
+        .server = s,
+        .head_end = head_end,
+        .head = Request.Head.parse(s.read_buffer[0..head_end]) catch
+            return error.HttpHeadersInvalid,
+        .reader_state = undefined,
+    };
+}
+
 pub const Request = struct {
     server: *Server,
     /// Index into Server's read_buffer.
@@ -102,6 +107,7 @@ pub const Request = struct {
     head: Head,
     reader_state: union {
         remaining_content_length: u64,
+        chunk_parser: http.ChunkParser,
     },
 
     pub const Compression = union(enum) {
@@ -416,51 +422,130 @@
         };
     }
 
-    pub const ReadError = net.Stream.ReadError;
+    pub const ReadError = net.Stream.ReadError || error{ HttpChunkInvalid, HttpHeadersOversize };
 
     fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize {
        const request: *Request = @constCast(@alignCast(@ptrCast(context)));
         const s = request.server;
         assert(s.state == .receiving_body);
-
         const remaining_content_length = &request.reader_state.remaining_content_length;
-
         if (remaining_content_length.* == 0) {
             s.state = .ready;
             return 0;
         }
-
-        const available_bytes = s.read_buffer_len - request.head_end;
-        if (available_bytes == 0)
-            s.read_buffer_len += try s.connection.stream.read(s.read_buffer[request.head_end..]);
-
-        const available_buf = s.read_buffer[request.head_end..s.read_buffer_len];
-        const len = @min(remaining_content_length.*, available_buf.len, buffer.len);
-        @memcpy(buffer[0..len], available_buf[0..len]);
+        const available = try fill(s, request.head_end);
+        const len = @min(remaining_content_length.*, available.len, buffer.len);
+        @memcpy(buffer[0..len], available[0..len]);
         remaining_content_length.* -= len;
+        s.next_request_start += len;
         if (remaining_content_length.* == 0)
             s.state = .ready;
         return len;
     }
 
+    fn fill(s: *Server, head_end: usize) ReadError![]u8 {
+        const available = s.read_buffer[s.next_request_start..s.read_buffer_len];
+        if (available.len > 0) return available;
+        s.next_request_start = head_end;
+        s.read_buffer_len = head_end + try s.connection.stream.read(s.read_buffer[head_end..]);
+        return s.read_buffer[head_end..s.read_buffer_len];
+    }
+
     fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize {
         const request: *Request = @constCast(@alignCast(@ptrCast(context)));
         const s = request.server;
         assert(s.state == .receiving_body);
-        _ = buffer;
-        @panic("TODO");
-    }
 
-    pub const ReadAllError = ReadError || error{HttpBodyOversize};
+        const cp = &request.reader_state.chunk_parser;
+        const head_end = request.head_end;
+
+        // Protect against returning 0 before the end of stream.
+        var out_end: usize = 0;
+        while (out_end == 0) {
+            switch (cp.state) {
+                .invalid => return 0,
+                .data => {
+                    const available = try fill(s, head_end);
+                    const len = @min(cp.chunk_len, available.len, buffer.len);
+                    @memcpy(buffer[0..len], available[0..len]);
+                    cp.chunk_len -= len;
+                    if (cp.chunk_len == 0)
+                        cp.state = .data_suffix;
+                    out_end += len;
+                    s.next_request_start += len;
+                    continue;
+                },
+                else => {
+                    const available = try fill(s, head_end);
+                    const n = cp.feed(available);
+                    switch (cp.state) {
+                        .invalid => return error.HttpChunkInvalid,
+                        .data => {
+                            if (cp.chunk_len == 0) {
+                                // The next bytes in the stream are trailers,
+                                // or \r\n to indicate end of chunked body.
+                                //
+                                // This function must append the trailers at
+                                // head_end so that headers and trailers are
+                                // together.
+                                //
+                                // Since returning 0 would indicate end of
+                                // stream, this function must read all the
+                                // trailers before returning.
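+                                //
+                                // For reference, the tail of a chunked body
+                                // looks like this on the wire (the trailer
+                                // field is optional, and the name
+                                // "X-Example-Trailer" is illustrative):
+                                //
+                                //   0\r\n
+                                //   X-Example-Trailer: value\r\n
+                                //   \r\n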
+                                if (s.next_request_start > head_end) rebase(s, head_end);
+                                var hp: http.HeadParser = .{};
+                                {
+                                    const bytes = s.read_buffer[head_end..s.read_buffer_len];
+                                    const end = hp.feed(bytes);
+                                    if (hp.state == .finished) {
+                                        s.next_request_start = s.read_buffer_len - bytes.len + end;
+                                        return out_end;
+                                    }
+                                }
+                                while (true) {
+                                    const buf = s.read_buffer[s.read_buffer_len..];
+                                    if (buf.len == 0)
+                                        return error.HttpHeadersOversize;
+                                    const read_n = try s.connection.stream.read(buf);
+                                    s.read_buffer_len += read_n;
+                                    const bytes = buf[0..read_n];
+                                    const end = hp.feed(bytes);
+                                    if (hp.state == .finished) {
+                                        s.next_request_start = s.read_buffer_len - bytes.len + end;
+                                        return out_end;
+                                    }
+                                }
+                            }
+                            const data = available[n..];
+                            const len = @min(cp.chunk_len, data.len, buffer.len);
+                            @memcpy(buffer[0..len], data[0..len]);
+                            cp.chunk_len -= len;
+                            if (cp.chunk_len == 0)
+                                cp.state = .data_suffix;
+                            out_end += len;
+                            s.next_request_start += n + len;
+                            continue;
+                        },
+                        else => continue,
+                    }
+                },
+            }
+        }
+        return out_end;
+    }
 
     pub fn reader(request: *Request) std.io.AnyReader {
         const s = request.server;
         assert(s.state == .received_head);
         s.state = .receiving_body;
+        s.next_request_start = request.head_end;
         switch (request.head.transfer_encoding) {
-            .chunked => return .{
-                .readFn = read_chunked,
-                .context = request,
+            .chunked => {
+                request.reader_state = .{ .chunk_parser = http.ChunkParser.init };
+                return .{
+                    .readFn = read_chunked,
+                    .context = request,
+                };
             },
             .none => {
                 request.reader_state = .{
@@ -489,31 +574,8 @@
         const s = request.server;
         if (keep_alive and request.head.keep_alive) switch (s.state) {
             .received_head => {
-                s.state = .receiving_body;
-                switch (request.head.transfer_encoding) {
-                    .none => t: {
-                        const len = request.head.content_length orelse break :t;
-                        const head_end = request.head_end;
-                        var total_body_discarded: usize = 0;
-                        while (true) {
-                            const available_bytes = s.read_buffer_len - head_end;
-                            const remaining_len = len - total_body_discarded;
-                            if (available_bytes >= remaining_len) {
-                                s.next_request_start = head_end + remaining_len;
-                                break :t;
-                            }
-                            total_body_discarded += available_bytes;
-                            // Preserve request header memory until receiveHead is called.
-                            const buf = s.read_buffer[head_end..];
-                            const read_n = s.connection.stream.read(buf) catch return false;
-                            s.read_buffer_len = head_end + read_n;
-                        }
-                    },
-                    .chunked => {
-                        @panic("TODO");
-                    },
-                }
-                s.state = .ready;
+                _ = request.reader().discard() catch return false;
+                assert(s.state == .ready);
                 return true;
             },
             .receiving_body, .ready => return true,
@@ -799,6 +861,17 @@ pub const Response = struct {
     }
 };
 
+fn rebase(s: *Server, index: usize) void {
+    const leftover = s.read_buffer[s.next_request_start..s.read_buffer_len];
+    const dest = s.read_buffer[index..][0..leftover.len];
+    if (leftover.len <= s.next_request_start - index) {
+        @memcpy(dest, leftover);
+    } else {
+        mem.copyBackwards(u8, dest, leftover);
+    }
+    s.read_buffer_len = index + leftover.len;
+}
+
 const std = @import("../std.zig");
 const http = std.http;
 const mem = std.mem;
diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig
index d080d3c389..64c87b9287 100644
--- a/lib/std/http/protocol.zig
+++ b/lib/std/http/protocol.zig
@@ -97,85 +97,32 @@ pub const HeadersParser = struct {
         return @intCast(result);
     }
 
-    /// Returns the number of bytes consumed by the chunk size. This is always
-    /// less than or equal to `bytes.len`.
-    /// You should check `r.state == .chunk_data` after this to check if the
-    /// chunk size has been fully parsed.
-    ///
-    /// If the amount returned is less than `bytes.len`, you may assume that
-    /// the parser is in the `chunk_data` state and that the first byte of the
-    /// chunk is at `bytes[result]`.
     pub fn findChunkedLen(r: *HeadersParser, bytes: []const u8) u32 {
-        const len = @as(u32, @intCast(bytes.len));
-
-        for (bytes[0..], 0..) |c, i| {
-            const index = @as(u32, @intCast(i));
-            switch (r.state) {
-                .chunk_data_suffix => switch (c) {
-                    '\r' => r.state = .chunk_data_suffix_r,
-                    '\n' => r.state = .chunk_head_size,
-                    else => {
-                        r.state = .invalid;
-                        return index;
-                    },
-                },
-                .chunk_data_suffix_r => switch (c) {
-                    '\n' => r.state = .chunk_head_size,
-                    else => {
-                        r.state = .invalid;
-                        return index;
-                    },
-                },
-                .chunk_head_size => {
-                    const digit = switch (c) {
-                        '0'...'9' => |b| b - '0',
-                        'A'...'Z' => |b| b - 'A' + 10,
-                        'a'...'z' => |b| b - 'a' + 10,
-                        '\r' => {
-                            r.state = .chunk_head_r;
-                            continue;
-                        },
-                        '\n' => {
-                            r.state = .chunk_data;
-                            return index + 1;
-                        },
-                        else => {
-                            r.state = .chunk_head_ext;
-                            continue;
-                        },
-                    };
-
-                    const new_len = r.next_chunk_length *% 16 +% digit;
-                    if (new_len <= r.next_chunk_length and r.next_chunk_length != 0) {
-                        r.state = .invalid;
-                        return index;
-                    }
-
-                    r.next_chunk_length = new_len;
-                },
-                .chunk_head_ext => switch (c) {
-                    '\r' => r.state = .chunk_head_r,
-                    '\n' => {
-                        r.state = .chunk_data;
-                        return index + 1;
-                    },
-                    else => continue,
-                },
-                .chunk_head_r => switch (c) {
-                    '\n' => {
-                        r.state = .chunk_data;
-                        return index + 1;
-                    },
-                    else => {
-                        r.state = .invalid;
-                        return index;
-                    },
-                },
+        var cp: std.http.ChunkParser = .{
+            .state = switch (r.state) {
+                .chunk_head_size => .head_size,
+                .chunk_head_ext => .head_ext,
+                .chunk_head_r => .head_r,
+                .chunk_data => .data,
+                .chunk_data_suffix => .data_suffix,
+                .chunk_data_suffix_r => .data_suffix_r,
+                .invalid => .invalid,
                 else => unreachable,
-            }
-        }
-
-        return len;
+            },
+            .chunk_len = r.next_chunk_length,
+        };
+        const result = cp.feed(bytes);
+        r.state = switch (cp.state) {
+            .head_size => .chunk_head_size,
+            .head_ext => .chunk_head_ext,
+            .head_r => .chunk_head_r,
+            .data => .chunk_data,
+            .data_suffix => .chunk_data_suffix,
+            .data_suffix_r => .chunk_data_suffix_r,
+            .invalid => .invalid,
+        };
+        r.next_chunk_length = cp.chunk_len;
+        return @intCast(result);
     }
 
     /// Returns whether or not the parser has finished parsing a complete
@@ -464,41 +411,6 @@ const MockBufferedConnection = struct {
     }
 };
 
-test "HeadersParser.findChunkedLen" {
-    var r: HeadersParser = undefined;
-    const data = "Ff\r\nf0f000 ; ext\n0\r\nffffffffffffffffffffffffffffffffffffffff\r\n";
-
-    r = HeadersParser.init(&.{});
-    r.state = .chunk_head_size;
-    r.next_chunk_length = 0;
-
-    const first = r.findChunkedLen(data[0..]);
-    try testing.expectEqual(@as(u32, 4), first);
-    try testing.expectEqual(@as(u64, 0xff), r.next_chunk_length);
-    try testing.expectEqual(State.chunk_data, r.state);
-    r.state = .chunk_head_size;
-    r.next_chunk_length = 0;
-
-    const second = r.findChunkedLen(data[first..]);
-    try testing.expectEqual(@as(u32, 13), second);
-    try testing.expectEqual(@as(u64, 0xf0f000), r.next_chunk_length);
-    try testing.expectEqual(State.chunk_data, r.state);
-    r.state = .chunk_head_size;
-    r.next_chunk_length = 0;
-
-    const third = r.findChunkedLen(data[first + second ..]);
-    try testing.expectEqual(@as(u32, 3), third);
-    try testing.expectEqual(@as(u64, 0), r.next_chunk_length);
-    try testing.expectEqual(State.chunk_data, r.state);
-    r.state = .chunk_head_size;
-    r.next_chunk_length = 0;
-
-    const fourth = r.findChunkedLen(data[first + second + third ..]);
-    try testing.expectEqual(@as(u32, 16), fourth);
-    try testing.expectEqual(@as(u64, 0xffffffffffffffff), r.next_chunk_length);
-    try testing.expectEqual(State.invalid, r.state);
-}
-
 test "HeadersParser.read length" {
     // mock BufferedConnection for read
     var headers_buf: [256]u8 = undefined;
diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig
index 3dbccbcff6..d803e3cd81 100644
--- a/lib/std/http/test.zig
+++ b/lib/std/http/test.zig
@@ -164,7 +164,7 @@ test "HTTP server handles a chunked transfer coding request" {
     const stream = try std.net.tcpConnectToHost(allocator, "127.0.0.1", server_port);
     defer stream.close();
 
-    _ = try stream.writeAll(request_bytes[0..]);
+    try stream.writeAll(request_bytes);
 
     server_thread.join();
 }
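
Usage sketch (illustrative, not part of the diff): after this change, a
chunked request body is consumed through the same `Request.reader()`
interface as a content-length body, and end of stream is signaled only
once the terminating zero-length chunk and any trailers have been read.
The `handle` function name and the 8 KiB buffer below are hypothetical;
the calls on `request` are the `std.http.Server` APIs modified above:

    fn handle(request: *std.http.Server.Request) !void {
        // For chunked requests, reader() initializes the ChunkParser and
        // returns a std.io.AnyReader that decodes the framing.
        const body_reader = request.reader();
        var buf: [8192]u8 = undefined;
        // readAll sees a clean end of stream after the last chunk and any
        // trailers have been received; bodies larger than the buffer
        // would need a read loop instead.
        const n = try body_reader.readAll(&buf);
        // Echo the decoded body back to the client.
        try request.respond(buf[0..n], .{});
    }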