From 4689d93cb204a4143770105200eb65dcdca5d7a0 Mon Sep 17 00:00:00 2001 From: Nameless Date: Sun, 27 Aug 2023 16:36:24 -0500 Subject: [PATCH] std.http: allow for arbitrary http methods --- lib/std/http.zig | 49 ++++++++++++++++++++++++++++++---------- lib/std/http/Client.zig | 10 ++++---- lib/std/http/Server.zig | 10 ++++---- test/standalone/http.zig | 2 +- 4 files changed, 49 insertions(+), 22 deletions(-) diff --git a/lib/std/http.zig b/lib/std/http.zig index f81032da50..424cbc8bb7 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -1,3 +1,5 @@ +const std = @import("std.zig"); + pub const Client = @import("http/Client.zig"); pub const Server = @import("http/Server.zig"); pub const protocol = @import("http/protocol.zig"); @@ -14,16 +16,36 @@ pub const Version = enum { /// https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods /// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definition /// https://datatracker.ietf.org/doc/html/rfc5789#section-2 PATCH -pub const Method = enum { - GET, - HEAD, - POST, - PUT, - DELETE, - CONNECT, - OPTIONS, - TRACE, - PATCH, +pub const Method = enum(u64) { // TODO: should be u192 or u256, but neither is supported by the C backend, and therefore cannot pass CI + GET = parse("GET"), + HEAD = parse("HEAD"), + POST = parse("POST"), + PUT = parse("PUT"), + DELETE = parse("DELETE"), + CONNECT = parse("CONNECT"), + OPTIONS = parse("OPTIONS"), + TRACE = parse("TRACE"), + PATCH = parse("PATCH"), + + _, + + /// Converts `s` into a type that may be used as a `Method` field. + /// Asserts that `s` is 24 or fewer bytes. + pub fn parse(s: []const u8) u64 { + var x: u64 = 0; + @memcpy(std.mem.asBytes(&x)[0..s.len], s); + return x; + } + + pub fn write(self: Method, w: anytype) !void { + const bytes = std.mem.asBytes(&@intFromEnum(self)); + const str = std.mem.sliceTo(bytes, 0); + try w.writeAll(str); + } + + pub fn format(value: Method, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) @TypeOf(writer).Error!void { + return try value.write(writer); + } /// Returns true if a request of this method is allowed to have a body /// Actual behavior from servers may vary and should still be checked @@ -31,6 +53,7 @@ pub const Method = enum { return switch (self) { .POST, .PUT, .PATCH => true, .GET, .HEAD, .DELETE, .CONNECT, .OPTIONS, .TRACE => false, + else => true, }; } @@ -40,6 +63,7 @@ pub const Method = enum { return switch (self) { .GET, .POST, .DELETE, .CONNECT, .OPTIONS, .PATCH => true, .HEAD, .PUT, .TRACE => false, + else => true, }; } @@ -50,6 +74,7 @@ pub const Method = enum { return switch (self) { .GET, .HEAD, .OPTIONS, .TRACE => true, .POST, .PUT, .DELETE, .CONNECT, .PATCH => false, + else => false, }; } @@ -60,6 +85,7 @@ pub const Method = enum { return switch (self) { .GET, .HEAD, .PUT, .DELETE, .OPTIONS, .TRACE => true, .CONNECT, .POST, .PATCH => false, + else => false, }; } @@ -70,6 +96,7 @@ pub const Method = enum { return switch (self) { .GET, .HEAD => true, .POST, .PUT, .DELETE, .CONNECT, .OPTIONS, .TRACE, .PATCH => false, + else => false, }; } }; @@ -269,8 +296,6 @@ pub const Connection = enum { close, }; -const std = @import("std.zig"); - test { _ = Client; _ = Method; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index ec2e875d59..8df4525430 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -545,7 +545,7 @@ pub const Request = struct { var buffered = std.io.bufferedWriter(req.connection.?.data.writer()); const w = buffered.writer(); - try w.writeAll(@tagName(req.method)); + try req.method.write(w); try w.writeByte(' '); if (req.method == .CONNECT) { @@ -627,15 +627,15 @@ pub const Request = struct { try buffered.flush(); } - pub const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; + const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; - pub const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead); + const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead); - pub fn transferReader(req: *Request) TransferReader { + fn transferReader(req: *Request) TransferReader { return .{ .context = req }; } - pub fn transferRead(req: *Request, buf: []u8) TransferReadError!usize { + fn transferRead(req: *Request, buf: []u8) TransferReadError!usize { if (req.response.parser.done) return 0; var index: usize = 0; diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index a9803ce47f..c493cc1bab 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -185,8 +185,10 @@ pub const Request = struct { return error.HttpHeadersInvalid; const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid; + if (method_end > 24) return error.HttpHeadersInvalid; + const method_str = first_line[0..method_end]; - const method = std.meta.stringToEnum(http.Method, method_str) orelse return error.UnknownHttpMethod; + const method: http.Method = @enumFromInt(http.Method.parse(method_str)); const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid; if (version_start == method_end) return error.HttpHeadersInvalid; @@ -467,11 +469,11 @@ pub const Response = struct { try buffered.flush(); } - pub const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; + const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; - pub const TransferReader = std.io.Reader(*Response, TransferReadError, transferRead); + const TransferReader = std.io.Reader(*Response, TransferReadError, transferRead); - pub fn transferReader(res: *Response) TransferReader { + fn transferReader(res: *Response) TransferReader { return .{ .context = res }; } diff --git a/test/standalone/http.zig b/test/standalone/http.zig index 6a503c9338..f0dc0e6b2a 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -20,7 +20,7 @@ var server: Server = undefined; fn handleRequest(res: *Server.Response) !void { const log = std.log.scoped(.server); - log.info("{s} {s} {s}", .{ @tagName(res.request.method), @tagName(res.request.version), res.request.target }); + log.info("{} {s} {s}", .{ res.request.method, @tagName(res.request.version), res.request.target }); if (res.request.headers.contains("expect")) { if (mem.eql(u8, res.request.headers.getFirstValue("expect").?, "100-continue")) {