diff --git a/lib/std/Uri.zig b/lib/std/Uri.zig index 39e9fc2413..82bd93493c 100644 --- a/lib/std/Uri.zig +++ b/lib/std/Uri.zig @@ -42,18 +42,19 @@ pub const Component = union(enum) { pub fn format( component: Component, - comptime fmt_str: []const u8, - _: std.fmt.FormatOptions, + comptime fmt: []const u8, + options: std.fmt.Options, writer: *std.io.BufferedWriter, ) anyerror!void { - if (fmt_str.len == 0) { + _ = options; + if (fmt.len == 0) { try writer.print("std.Uri.Component{{ .{s} = \"{}\" }}", .{ @tagName(component), std.zig.fmtEscapes(switch (component) { .raw, .percent_encoded => |string| string, }), }); - } else if (comptime std.mem.eql(u8, fmt_str, "raw")) switch (component) { + } else if (comptime std.mem.eql(u8, fmt, "raw")) switch (component) { .raw => |raw| try writer.writeAll(raw), .percent_encoded => |percent_encoded| { var start: usize = 0; @@ -72,28 +73,28 @@ pub const Component = union(enum) { } try writer.writeAll(percent_encoded[start..]); }, - } else if (comptime std.mem.eql(u8, fmt_str, "%")) switch (component) { + } else if (comptime std.mem.eql(u8, fmt, "%")) switch (component) { .raw => |raw| try percentEncode(writer, raw, isUnreserved), .percent_encoded => |percent_encoded| try writer.writeAll(percent_encoded), - } else if (comptime std.mem.eql(u8, fmt_str, "user")) switch (component) { + } else if (comptime std.mem.eql(u8, fmt, "user")) switch (component) { .raw => |raw| try percentEncode(writer, raw, isUserChar), .percent_encoded => |percent_encoded| try writer.writeAll(percent_encoded), - } else if (comptime std.mem.eql(u8, fmt_str, "password")) switch (component) { + } else if (comptime std.mem.eql(u8, fmt, "password")) switch (component) { .raw => |raw| try percentEncode(writer, raw, isPasswordChar), .percent_encoded => |percent_encoded| try writer.writeAll(percent_encoded), - } else if (comptime std.mem.eql(u8, fmt_str, "host")) switch (component) { + } else if (comptime std.mem.eql(u8, fmt, "host")) switch (component) { .raw => |raw| try percentEncode(writer, raw, isHostChar), .percent_encoded => |percent_encoded| try writer.writeAll(percent_encoded), - } else if (comptime std.mem.eql(u8, fmt_str, "path")) switch (component) { + } else if (comptime std.mem.eql(u8, fmt, "path")) switch (component) { .raw => |raw| try percentEncode(writer, raw, isPathChar), .percent_encoded => |percent_encoded| try writer.writeAll(percent_encoded), - } else if (comptime std.mem.eql(u8, fmt_str, "query")) switch (component) { + } else if (comptime std.mem.eql(u8, fmt, "query")) switch (component) { .raw => |raw| try percentEncode(writer, raw, isQueryChar), .percent_encoded => |percent_encoded| try writer.writeAll(percent_encoded), - } else if (comptime std.mem.eql(u8, fmt_str, "fragment")) switch (component) { + } else if (comptime std.mem.eql(u8, fmt, "fragment")) switch (component) { .raw => |raw| try percentEncode(writer, raw, isFragmentChar), .percent_encoded => |percent_encoded| try writer.writeAll(percent_encoded), - } else @compileError("invalid format string '" ++ fmt_str ++ "'"); + } else @compileError("invalid format string '" ++ fmt ++ "'"); } pub fn percentEncode( @@ -227,31 +228,21 @@ pub fn parseAfterScheme(scheme: []const u8, text: []const u8) ParseError!Uri { pub const WriteToStreamOptions = struct { /// When true, include the scheme part of the URI. scheme: bool = false, - /// When true, include the user and password part of the URI. Ignored if `authority` is false. authentication: bool = false, - /// When true, include the authority part of the URI. authority: bool = false, - /// When true, include the path part of the URI. path: bool = false, - /// When true, include the query part of the URI. Ignored when `path` is false. query: bool = false, - /// When true, include the fragment part of the URI. Ignored when `path` is false. fragment: bool = false, - /// When true, include the port part of the URI. Ignored when `port` is null. port: bool = true, }; -pub fn writeToStream( - uri: Uri, - options: WriteToStreamOptions, - writer: anytype, -) @TypeOf(writer).Error!void { +pub fn writeToStream(uri: Uri, options: WriteToStreamOptions, writer: *std.io.BufferedWriter) anyerror!void { if (options.scheme) { try writer.print("{s}:", .{uri.scheme}); if (options.authority and uri.host != null) { @@ -261,45 +252,40 @@ pub fn writeToStream( if (options.authority) { if (options.authentication and uri.host != null) { if (uri.user) |user| { - try writer.print("{user}", .{user}); + try writer.print("{fuser}", .{user}); if (uri.password) |password| { - try writer.print(":{password}", .{password}); + try writer.print(":{fpassword}", .{password}); } try writer.writeByte('@'); } } if (uri.host) |host| { - try writer.print("{host}", .{host}); + try writer.print("{fhost}", .{host}); if (options.port) { if (uri.port) |port| try writer.print(":{d}", .{port}); } } } if (options.path) { - try writer.print("{path}", .{ + try writer.print("{fpath}", .{ if (uri.path.isEmpty()) Uri.Component{ .percent_encoded = "/" } else uri.path, }); if (options.query) { - if (uri.query) |query| try writer.print("?{query}", .{query}); + if (uri.query) |query| try writer.print("?{fquery}", .{query}); } if (options.fragment) { - if (uri.fragment) |fragment| try writer.print("#{fragment}", .{fragment}); + if (uri.fragment) |fragment| try writer.print("#{ffragment}", .{fragment}); } } } -pub fn format( - uri: Uri, - comptime fmt_str: []const u8, - _: std.fmt.FormatOptions, - writer: anytype, -) @TypeOf(writer).Error!void { - const scheme = comptime std.mem.indexOfScalar(u8, fmt_str, ';') != null or fmt_str.len == 0; - const authentication = comptime std.mem.indexOfScalar(u8, fmt_str, '@') != null or fmt_str.len == 0; - const authority = comptime std.mem.indexOfScalar(u8, fmt_str, '+') != null or fmt_str.len == 0; - const path = comptime std.mem.indexOfScalar(u8, fmt_str, '/') != null or fmt_str.len == 0; - const query = comptime std.mem.indexOfScalar(u8, fmt_str, '?') != null or fmt_str.len == 0; - const fragment = comptime std.mem.indexOfScalar(u8, fmt_str, '#') != null or fmt_str.len == 0; +pub fn format(uri: Uri, comptime fmt: []const u8, _: std.fmt.Options, writer: *std.io.BufferedWriter) anyerror!void { + const scheme = comptime std.mem.indexOfScalar(u8, fmt, ';') != null or fmt.len == 0; + const authentication = comptime std.mem.indexOfScalar(u8, fmt, '@') != null or fmt.len == 0; + const authority = comptime std.mem.indexOfScalar(u8, fmt, '+') != null or fmt.len == 0; + const path = comptime std.mem.indexOfScalar(u8, fmt, '/') != null or fmt.len == 0; + const query = comptime std.mem.indexOfScalar(u8, fmt, '?') != null or fmt.len == 0; + const fragment = comptime std.mem.indexOfScalar(u8, fmt, '#') != null or fmt.len == 0; return writeToStream(uri, .{ .scheme = scheme, @@ -449,7 +435,7 @@ fn merge_paths(base: Component, new: []u8, aux_buf: *[]u8) error{NoSpaceLeft}!Co aux.initFixed(aux_buf.*); if (!base.isEmpty()) { aux.print("{fpath}", .{base}) catch |err| return @errorCast(err); - aux.pos = std.mem.lastIndexOfScalar(u8, aux.getWritten(), '/') orelse + aux.end = std.mem.lastIndexOfScalar(u8, aux.getWritten(), '/') orelse return remove_dot_segments(new); } aux.print("/{s}", .{new}) catch |err| return @errorCast(err); diff --git a/lib/std/crypto/ecdsa.zig b/lib/std/crypto/ecdsa.zig index 9763989919..c7015b7fd3 100644 --- a/lib/std/crypto/ecdsa.zig +++ b/lib/std/crypto/ecdsa.zig @@ -135,8 +135,8 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type { /// The maximum length of the DER encoding is der_encoded_length_max. /// The function returns a slice, that can be shorter than der_encoded_length_max. pub fn toDer(sig: Signature, buf: *[der_encoded_length_max]u8) []u8 { - var fb = io.fixedBufferStream(buf); - const w = fb.writer(); + var w: std.io.BufferedWriter = undefined; + w.initFixed(buf); const r_len = @as(u8, @intCast(sig.r.len + (sig.r[0] >> 7))); const s_len = @as(u8, @intCast(sig.s.len + (sig.s[0] >> 7))); const seq_len = @as(u8, @intCast(2 + r_len + 2 + s_len)); @@ -151,7 +151,7 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type { w.writeByte(0x00) catch unreachable; } w.writeAll(&sig.s) catch unreachable; - return fb.getWritten(); + return w.getWritten(); } // Read a DER-encoded integer. @@ -176,7 +176,7 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type { /// Returns InvalidEncoding if the DER encoding is invalid. pub fn fromDer(der: []const u8) EncodingError!Signature { var sig: Signature = mem.zeroInit(Signature, .{}); - var fb = io.fixedBufferStream(der); + var fb: std.io.FixedBufferStream = .{ .buffer = der }; const reader = fb.reader(); var buf: [2]u8 = undefined; _ = reader.readNoEof(&buf) catch return error.InvalidEncoding; diff --git a/lib/std/fs/File.zig b/lib/std/fs/File.zig index 8b5f2ffc8e..74d2806a35 100644 --- a/lib/std/fs/File.zig +++ b/lib/std/fs/File.zig @@ -1594,122 +1594,124 @@ pub fn reader(file: File) Reader { pub fn writer(file: File) std.io.Writer { return .{ - .context = interface.handleToOpaque(file.handle), + .context = handleToOpaque(file.handle), .vtable = &.{ - .writeSplat = interface.writeSplat, - .writeFile = interface.writeFile, + .writeSplat = writer_writeSplat, + .writeFile = writer_writeFile, }, }; } -const interface = struct { - /// Number of slices to store on the stack, when trying to send as many byte - /// vectors through the underlying write calls as possible. - const max_buffers_len = 16; +/// Number of slices to store on the stack, when trying to send as many byte +/// vectors through the underlying write calls as possible. +const max_buffers_len = 16; - fn writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize { - const file = opaqueToHandle(context); - var splat_buffer: [256]u8 = undefined; - if (is_windows) { - if (data.len == 1 and splat == 0) return 0; - return windows.WriteFile(file, data[0], null); - } - var iovecs: [max_buffers_len]std.posix.iovec_const = undefined; - var len: usize = @min(iovecs.len, data.len); - for (iovecs[0..len], data[0..len]) |*v, d| v.* = .{ - .base = if (d.len == 0) "" else d.ptr, // OS sadly checks ptr addr before length. - .len = d.len, - }; - switch (splat) { - 0 => return std.posix.writev(file, iovecs[0 .. len - 1]), - 1 => return std.posix.writev(file, iovecs[0..len]), - else => { - const pattern = data[data.len - 1]; - if (pattern.len == 1) { - const memset_len = @min(splat_buffer.len, splat); - const buf = splat_buffer[0..memset_len]; - @memset(buf, pattern[0]); - iovecs[len - 1] = .{ .base = buf.ptr, .len = buf.len }; - var remaining_splat = splat - buf.len; - while (remaining_splat > 0 and len < iovecs.len) { - iovecs[len] = .{ .base = &splat_buffer, .len = splat_buffer.len }; - remaining_splat -= splat_buffer.len; - len += 1; - } - return std.posix.writev(file, iovecs[0..len]); +pub fn writer_writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize { + const file = opaqueToHandle(context); + var splat_buffer: [256]u8 = undefined; + if (is_windows) { + if (data.len == 1 and splat == 0) return 0; + return windows.WriteFile(file, data[0], null); + } + var iovecs: [max_buffers_len]std.posix.iovec_const = undefined; + var len: usize = @min(iovecs.len, data.len); + for (iovecs[0..len], data[0..len]) |*v, d| v.* = .{ + .base = if (d.len == 0) "" else d.ptr, // OS sadly checks ptr addr before length. + .len = d.len, + }; + switch (splat) { + 0 => return std.posix.writev(file, iovecs[0 .. len - 1]), + 1 => return std.posix.writev(file, iovecs[0..len]), + else => { + const pattern = data[data.len - 1]; + if (pattern.len == 1) { + const memset_len = @min(splat_buffer.len, splat); + const buf = splat_buffer[0..memset_len]; + @memset(buf, pattern[0]); + iovecs[len - 1] = .{ .base = buf.ptr, .len = buf.len }; + var remaining_splat = splat - buf.len; + while (remaining_splat > splat_buffer.len and len < iovecs.len) { + iovecs[len] = .{ .base = &splat_buffer, .len = splat_buffer.len }; + remaining_splat -= splat_buffer.len; + len += 1; } - }, - } - return std.posix.writev(file, iovecs[0..len]); + if (remaining_splat > 0 and len < iovecs.len) { + iovecs[len] = .{ .base = &splat_buffer, .len = remaining_splat }; + len += 1; + } + return std.posix.writev(file, iovecs[0..len]); + } + }, } + return std.posix.writev(file, iovecs[0..len]); +} - fn writeFile( - context: *anyopaque, - in_file: std.fs.File, - in_offset: u64, - in_len: std.io.Writer.VTable.FileLen, - headers_and_trailers: []const []const u8, - headers_len: usize, - ) anyerror!usize { - const out_fd = opaqueToHandle(context); - const in_fd = in_file.handle; - const len_int = switch (in_len) { - .zero => return interface.writeSplat(context, headers_and_trailers, 1), - .entire_file => 0, - else => in_len.int(), - }; - var iovecs_buffer: [max_buffers_len]std.posix.iovec_const = undefined; - const iovecs = iovecs_buffer[0..@min(iovecs_buffer.len, headers_and_trailers.len)]; - for (iovecs, headers_and_trailers[0..iovecs.len]) |*v, d| v.* = .{ .base = d.ptr, .len = d.len }; - const headers = iovecs[0..@min(headers_len, iovecs.len)]; - const trailers = iovecs[headers.len..]; - const flags = 0; - return posix.sendfile(out_fd, in_fd, in_offset, len_int, headers, trailers, flags) catch |err| switch (err) { - error.Unseekable, - error.FastOpenAlreadyInProgress, - error.MessageTooBig, - error.FileDescriptorNotASocket, - error.NetworkUnreachable, - error.NetworkSubsystemFailed, - => return writeFileUnseekable(out_fd, in_fd, in_offset, in_len, headers_and_trailers, headers_len), +pub fn writer_writeFile( + context: *anyopaque, + in_file: std.fs.File, + in_offset: u64, + in_len: std.io.Writer.VTable.FileLen, + headers_and_trailers: []const []const u8, + headers_len: usize, +) anyerror!usize { + const out_fd = opaqueToHandle(context); + const in_fd = in_file.handle; + const len_int = switch (in_len) { + .zero => return writer_writeSplat(context, headers_and_trailers, 1), + .entire_file => 0, + else => in_len.int(), + }; + var iovecs_buffer: [max_buffers_len]std.posix.iovec_const = undefined; + const iovecs = iovecs_buffer[0..@min(iovecs_buffer.len, headers_and_trailers.len)]; + for (iovecs, headers_and_trailers[0..iovecs.len]) |*v, d| v.* = .{ .base = d.ptr, .len = d.len }; + const headers = iovecs[0..@min(headers_len, iovecs.len)]; + const trailers = iovecs[headers.len..]; + const flags = 0; + return posix.sendfile(out_fd, in_fd, in_offset, len_int, headers, trailers, flags) catch |err| switch (err) { + error.Unseekable, + error.FastOpenAlreadyInProgress, + error.MessageTooBig, + error.FileDescriptorNotASocket, + error.NetworkUnreachable, + error.NetworkSubsystemFailed, + => return writeFileUnseekable(out_fd, in_fd, in_offset, in_len, headers_and_trailers, headers_len), - else => |e| return e, - }; - } + else => |e| return e, + }; +} - fn writeFileUnseekable( - out_fd: Handle, - in_fd: Handle, - in_offset: u64, - in_len: std.io.Writer.VTable.FileLen, - headers_and_trailers: []const []const u8, - headers_len: usize, - ) anyerror!usize { - _ = out_fd; - _ = in_fd; - _ = in_offset; - _ = in_len; - _ = headers_and_trailers; - _ = headers_len; - @panic("TODO writeFileUnseekable"); - } +fn writeFileUnseekable( + out_fd: Handle, + in_fd: Handle, + in_offset: u64, + in_len: std.io.Writer.VTable.FileLen, + headers_and_trailers: []const []const u8, + headers_len: usize, +) anyerror!usize { + _ = out_fd; + _ = in_fd; + _ = in_offset; + _ = in_len; + _ = headers_and_trailers; + _ = headers_len; + @panic("TODO writeFileUnseekable"); +} - fn handleToOpaque(handle: File.Handle) *anyopaque { - return switch (@typeInfo(Handle)) { - .pointer => @ptrCast(handle), - .int => @ptrFromInt(@as(u32, @bitCast(handle))), - else => @compileError("unhandled"), - }; - } +fn handleToOpaque(handle: Handle) *anyopaque { + return switch (@typeInfo(Handle)) { + .pointer => @ptrCast(handle), + .int => @ptrFromInt(@as(u32, @bitCast(handle))), + else => @compileError("unhandled"), + }; +} - fn opaqueToHandle(userdata: *anyopaque) Handle { - return switch (@typeInfo(Handle)) { - .pointer => @ptrCast(userdata), - .int => @intCast(@intFromPtr(userdata)), - else => @compileError("unhandled"), - }; - } -}; +fn opaqueToHandle(userdata: *anyopaque) Handle { + return switch (@typeInfo(Handle)) { + .pointer => @ptrCast(userdata), + .int => @intCast(@intFromPtr(userdata)), + else => @compileError("unhandled"), + }; +} pub const SeekableStream = io.SeekableStream( File, diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 1eea06e9c9..664da5524c 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -187,6 +187,9 @@ pub const ConnectionPool = struct { /// An interface to either a plain or TLS connection. pub const Connection = struct { stream: net.Stream, + /// Populated when protocol is TLS; this is the writer given to the TLS + /// client, which writes directly to `stream`, unbuffered. + stream_writer: std.io.BufferedWriter, /// undefined unless protocol is tls. tls_client: if (!disable_tls) *std.crypto.tls.Client else void, @@ -210,9 +213,10 @@ pub const Connection = struct { read_start: BufferSize = 0, read_end: BufferSize = 0, - write_end: BufferSize = 0, - read_buf: [buffer_size]u8 = undefined, - write_buf: [buffer_size]u8 = undefined, + read_buf: [buffer_size]u8, + + write_buffer: [buffer_size]u8, + writer: std.io.BufferedWriter, pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; const BufferSize = std.math.IntFittingRange(0, buffer_size); @@ -314,59 +318,7 @@ pub const Connection = struct { pub const Reader = std.io.Reader(*Connection, ReadError, read); pub fn reader(conn: *Connection) Reader { - return Reader{ .context = conn }; - } - - pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { - return conn.tls_client.writeAll(conn.stream, buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - - pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.writeAllDirectTls(buffer); - } - - return conn.stream.writeAll(buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - - /// Writes the given buffer to the connection. - pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { - if (conn.write_buf.len - conn.write_end < buffer.len) { - try conn.flush(); - - if (buffer.len > conn.write_buf.len) { - try conn.writeAllDirect(buffer); - return buffer.len; - } - } - - @memcpy(conn.write_buf[conn.write_end..][0..buffer.len], buffer); - conn.write_end += @intCast(buffer.len); - - return buffer.len; - } - - /// Returns a buffer to be filled with exactly len bytes to write to the connection. - pub fn allocWriteBuffer(conn: *Connection, len: BufferSize) WriteError![]u8 { - if (conn.write_buf.len - conn.write_end < len) try conn.flush(); - defer conn.write_end += len; - return conn.write_buf[conn.write_end..][0..len]; - } - - /// Flushes the write buffer to the connection. - pub fn flush(conn: *Connection) WriteError!void { - if (conn.write_end == 0) return; - - try conn.writeAllDirect(conn.write_buf[0..conn.write_end]); - conn.write_end = 0; + return .{ .context = conn }; } pub const WriteError = error{ @@ -374,19 +326,12 @@ pub const Connection = struct { UnexpectedWriteFailure, }; - pub const Writer = std.io.Writer(*Connection, WriteError, write); - - pub fn writer(conn: *Connection) Writer { - return Writer{ .context = conn }; - } - - /// Closes the connection. pub fn close(conn: *Connection, allocator: Allocator) void { if (conn.protocol == .tls) { if (disable_tls) unreachable; // try to cleanly close the TLS connection, for any server that cares. - _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {}; + _ = conn.tls_client.writeEnd("", true) catch {}; if (conn.tls_client.ssl_key_log) |key_log| key_log.file.close(); allocator.destroy(conn.tls_client); } @@ -815,15 +760,12 @@ pub const Request = struct { }; } - pub const SendError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding }; - /// Send the HTTP request headers to the server. - pub fn send(req: *Request) SendError!void { - if (!req.method.requestHasBody() and req.transfer_encoding != .none) - return error.UnsupportedTransferEncoding; + pub fn send(req: *Request) anyerror!void { + assert(req.transfer_encoding == .none or req.method.requestHasBody()); const connection = req.connection.?; - const w = connection.writer(); + const w = &connection.writer; try req.method.write(w); try w.writeByte(' '); @@ -852,10 +794,7 @@ pub const Request = struct { if (try emitOverridableHeader("authorization: ", req.headers.authorization, w)) { if (req.uri.user != null or req.uri.password != null) { try w.writeAll("authorization: "); - const authorization = try connection.allocWriteBuffer( - @intCast(basic_authorization.valueLengthFromUri(req.uri)), - ); - assert(basic_authorization.value(req.uri, authorization).len == authorization.len); + try basic_authorization.write(req.uri, w); try w.writeAll("\r\n"); } } @@ -914,7 +853,7 @@ pub const Request = struct { try w.writeAll("\r\n"); - try connection.flush(); + try connection.writer.flush(); } /// Returns true if the default behavior is required, otherwise handles @@ -953,7 +892,9 @@ pub const Request = struct { return index; } - pub const WaitError = RequestError || SendError || TransferReadError || + /// TODO collapse each error set into its own meta error code, and store + /// the underlying error code as a field on Request + pub const WaitError = RequestError || anyerror || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || error{ TooManyHttpRedirects, @@ -1132,59 +1073,112 @@ pub const Request = struct { return index; } - pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; - - pub const Writer = std.io.Writer(*Request, WriteError, write); - - pub fn writer(req: *Request) Writer { - return .{ .context = req }; - } - - /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn write(req: *Request, bytes: []const u8) WriteError!usize { - switch (req.transfer_encoding) { - .chunked => { - if (bytes.len > 0) { - try req.connection.?.writer().print("{x}\r\n", .{bytes.len}); - try req.connection.?.writer().writeAll(bytes); - try req.connection.?.writer().writeAll("\r\n"); - } - - return bytes.len; + /// Resulting `std.io.Writer` must used after `send` and before `finish`. + pub fn writer(req: *Request) std.io.Writer { + return .{ + .context = req, + .vtable = switch (req.transfer_encoding) { + .chunked => &.{ + .writeSplat = chunked_writeSplat, + .writeFile = chunked_writeFile, + }, + .content_length => &.{ + .writeSplat = cl_writeSplat, + .writeFile = cl_writeFile, + }, + .none => unreachable, }, - .content_length => |*len| { - if (len.* < bytes.len) return error.MessageTooLong; - - const amt = try req.connection.?.write(bytes); - len.* -= amt; - return amt; - }, - .none => return error.NotWriteable, - } + }; } - /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try write(req, bytes[index..]); - } + fn chunked_writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize { + const req: *Request = @ptrCast(@alignCast(context)); + var total: usize = 0; + for (data) |bytes| total += bytes.len; + if (total == 0) return 0; + var iovecs: [max_buffers_len][]const u8 = undefined; + var header_buffer: [30]u8 = undefined; + var header_buffer_writer: std.io.BufferedWriter = undefined; + header_buffer_writer.initFixed(&header_buffer); + header_buffer_writer.print("{x}\r\n", .{total}) catch unreachable; + iovecs[0] = header_buffer_writer.getWritten(); + @memcpy(iovecs[1..][0..data.len], data); + iovecs[data.len + 1] = "\r\n"; + // TODO: only 1 underlying write call + // TODO: don't rely on max_buffers_len exceeding the caller + // TODO: handle splat + _ = splat; + const w = &req.connection.?.writer; + try w.writevAll(iovecs[0 .. data.len + 2]); + return total; } - pub const FinishError = WriteError || error{MessageNotCompleted}; + const max_buffers_len = 16; + + pub fn chunked_writeFile( + context: *anyopaque, + file: std.fs.File, + offset: u64, + len: std.io.Writer.VTable.FileLen, + headers_and_trailers: []const []const u8, + headers_len: usize, + ) anyerror!usize { + if (len == .entire_file) return error.Unimplemented; + const req: *Request = @ptrCast(@alignCast(context)); + var total: usize = len.int(); + for (headers_and_trailers) |bytes| total += bytes.len; + if (total == 0) return 0; + var iovecs: [max_buffers_len][]const u8 = undefined; + var header_buffer: [30]u8 = undefined; + var header_buffer_writer: std.io.BufferedWriter = undefined; + header_buffer_writer.initFixed(&header_buffer); + header_buffer_writer.print("{x}\r\n", .{total}) catch unreachable; + iovecs[0] = header_buffer_writer.getWritten(); + @memcpy(iovecs[1..][0..headers_and_trailers.len], headers_and_trailers); + iovecs[headers_and_trailers.len + 1] = "\r\n"; + // TODO: only 1 underlying write call + // TODO: don't rely on max_buffers_len exceeding the caller + const w = &req.connection.?.writer; + try w.writeFileAll(file, .{ + .offset = offset, + .len = len, + .headers_and_trailers = iovecs[0 .. headers_and_trailers.len + 2], + .headers_len = headers_len + 1, + }); + return total; + } + + fn cl_writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize { + const req: *Request = @ptrCast(@alignCast(context)); + const n = try req.connection.?.writer.writeSplat(data, splat); + req.transfer_encoding.content_length -= n; + return n; + } + + pub fn cl_writeFile( + context: *anyopaque, + file: std.fs.File, + offset: u64, + len: std.io.Writer.VTable.FileLen, + headers_and_trailers: []const []const u8, + headers_len: usize, + ) anyerror!usize { + const req: *Request = @ptrCast(@alignCast(context)); + const n = try req.connection.?.writer.writeFile(file, offset, len, headers_and_trailers, headers_len); + req.transfer_encoding.content_length -= n; + return n; + } /// Finish the body of a request. This notifies the server that you have no more data to send. /// Must be called after `send`. - pub fn finish(req: *Request) FinishError!void { + pub fn finish(req: *Request) anyerror!void { switch (req.transfer_encoding) { - .chunked => try req.connection.?.writer().writeAll("0\r\n\r\n"), - .content_length => |len| if (len != 0) return error.MessageNotCompleted, + .chunked => try req.connection.?.writer.writeAll("0\r\n\r\n"), + .content_length => |len| assert(len == 0), .none => {}, } - try req.connection.?.flush(); + try req.connection.?.writer.flush(); } }; @@ -1276,10 +1270,8 @@ pub const basic_authorization = struct { pub const max_password_len = 255; pub const max_value_len = valueLength(max_user_len, max_password_len); - const prefix = "Basic "; - pub fn valueLength(user_len: usize, password_len: usize) usize { - return prefix.len + std.base64.standard.Encoder.calcSize(user_len + 1 + password_len); + return "Basic ".len + std.base64.standard.Encoder.calcSize(user_len + 1 + password_len); } pub fn valueLengthFromUri(uri: Uri) usize { @@ -1290,6 +1282,13 @@ pub const basic_authorization = struct { } pub fn value(uri: Uri, out: []u8) []u8 { + var bw: std.io.BufferedWriter = undefined; + bw.initFixed(out); + write(uri, &bw) catch unreachable; + return bw.getWritten(); + } + + pub fn write(uri: Uri, out: *std.io.BufferedWriter) anyerror!void { var buf: [max_user_len + ":".len + max_password_len]u8 = undefined; var bw: std.io.BufferedWriter = undefined; bw.initFixed(&buf); @@ -1297,9 +1296,7 @@ pub const basic_authorization = struct { uri.user orelse Uri.Component.empty, uri.password orelse Uri.Component.empty, }) catch unreachable; - @memcpy(out[0..prefix.len], prefix); - const base64 = std.base64.standard.Encoder.encode(out[prefix.len..], bw.getWritten()); - return out[0 .. prefix.len + base64.len]; + try out.print("Basic {b64}", .{bw.getWritten()}); } }; @@ -1313,7 +1310,7 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec .host = host, .port = port, .protocol = protocol, - })) |node| return node; + })) |conn| return conn; if (disable_tls and protocol == .tls) return error.TlsInitializationFailed; @@ -1336,7 +1333,12 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec conn.* = .{ .stream = stream, + .stream_writer = undefined, .tls_client = undefined, + .read_buf = undefined, + + .write_buffer = undefined, + .writer = undefined, // populated below .protocol = protocol, .host = try client.allocator.dupe(u8, host), @@ -1346,36 +1348,55 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec }; errdefer client.allocator.free(conn.host); - if (protocol == .tls) { - if (disable_tls) unreachable; + switch (protocol) { + .tls => { + if (disable_tls) unreachable; - conn.tls_client = try client.allocator.create(std.crypto.tls.Client); - errdefer client.allocator.destroy(conn.tls_client); + const tls_client = try client.allocator.create(std.crypto.tls.Client); + errdefer client.allocator.destroy(tls_client); - const ssl_key_log_file: ?std.fs.File = if (std.options.http_enable_ssl_key_log_file) ssl_key_log_file: { - const ssl_key_log_path = std.process.getEnvVarOwned(client.allocator, "SSLKEYLOGFILE") catch |err| switch (err) { - error.EnvironmentVariableNotFound, error.InvalidWtf8 => break :ssl_key_log_file null, - error.OutOfMemory => return error.OutOfMemory, + const ssl_key_log_file: ?std.fs.File = if (std.options.http_enable_ssl_key_log_file) ssl_key_log_file: { + const ssl_key_log_path = std.process.getEnvVarOwned(client.allocator, "SSLKEYLOGFILE") catch |err| switch (err) { + error.EnvironmentVariableNotFound, error.InvalidWtf8 => break :ssl_key_log_file null, + error.OutOfMemory => return error.OutOfMemory, + }; + defer client.allocator.free(ssl_key_log_path); + break :ssl_key_log_file std.fs.cwd().createFile(ssl_key_log_path, .{ + .truncate = false, + .mode = switch (builtin.os.tag) { + .windows, .wasi => 0, + else => 0o600, + }, + }) catch null; + } else null; + errdefer if (ssl_key_log_file) |key_log_file| key_log_file.close(); + + conn.stream_writer = .{ + .unbuffered_writer = stream.writer(), + .buffer = &.{}, }; - defer client.allocator.free(ssl_key_log_path); - break :ssl_key_log_file std.fs.cwd().createFile(ssl_key_log_path, .{ - .truncate = false, - .mode = switch (builtin.os.tag) { - .windows, .wasi => 0, - else => 0o600, - }, - }) catch null; - } else null; - errdefer if (ssl_key_log_file) |key_log_file| key_log_file.close(); - conn.tls_client.* = std.crypto.tls.Client.init(stream, .{ - .host = .{ .explicit = host }, - .ca = .{ .bundle = client.ca_bundle }, - .ssl_key_log_file = ssl_key_log_file, - }) catch return error.TlsInitializationFailed; - // This is appropriate for HTTPS because the HTTP headers contain - // the content length which is used to detect truncation attacks. - conn.tls_client.allow_truncation_attacks = true; + tls_client.* = std.crypto.tls.Client.init(stream, &conn.stream_writer, .{ + .host = .{ .explicit = host }, + .ca = .{ .bundle = client.ca_bundle }, + .ssl_key_log_file = ssl_key_log_file, + }) catch return error.TlsInitializationFailed; + // This is appropriate for HTTPS because the HTTP headers contain + // the content length which is used to detect truncation attacks. + tls_client.allow_truncation_attacks = true; + + conn.writer = .{ + .unbuffered_writer = tls_client.writer(), + .buffer = &conn.write_buffer, + }; + conn.tls_client = tls_client; + }, + .plain => { + conn.writer = .{ + .unbuffered_writer = stream.writer(), + .buffer = &conn.write_buffer, + }; + }, } client.connection_pool.addUsed(conn); @@ -1534,14 +1555,14 @@ pub fn connect( return conn; } -pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendError || +/// TODO collapse each error set into its own meta error code, and store +/// the underlying error code as a field on Request +pub const RequestError = ConnectTcpError || ConnectErrorPartial || anyerror || std.fmt.ParseIntError || Connection.WriteError || error{ UnsupportedUriScheme, UriMissingHost, - CertificateBundleLoadFailure, - UnsupportedTransferEncoding, }; pub const RequestOptions = struct { @@ -1754,7 +1775,10 @@ pub fn fetch(client: *Client, options: FetchOptions) !FetchResult { try req.send(); - if (options.payload) |payload| try req.writeAll(payload); + if (options.payload) |payload| { + var w = req.writer().unbuffered(); + try w.writeAll(payload); + } try req.finish(); try req.wait(); diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 06fb9e8cbf..bec647739d 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -113,6 +113,7 @@ fn finishReceivingHead(s: *Server, head_end: usize) ReceiveHeadError!Request { .head = Request.Head.parse(s.read_buffer[0..head_end]) catch return error.HttpHeadersInvalid, .reader_state = undefined, + .write_error = undefined, }; } @@ -125,6 +126,8 @@ pub const Request = struct { remaining_content_length: u64, chunk_parser: http.ChunkParser, }, + /// Populated when `error.HttpContinueWriteFailed` is received. + write_error: anyerror, pub const Compression = union(enum) { pub const DeflateDecompressor = std.compress.zlib.Decompressor(std.io.AnyReader); @@ -327,6 +330,7 @@ pub const Request = struct { .head_end = request_bytes.len, .head = undefined, .reader_state = undefined, + .write_error = undefined, }; var it = request.iterateHeaders(); @@ -391,7 +395,7 @@ pub const Request = struct { request: *Request, content: []const u8, options: RespondOptions, - ) Response.WriteError!void { + ) anyerror!void { const max_extra_headers = 25; assert(options.status != .@"continue"); assert(options.extra_headers.len <= max_extra_headers); @@ -418,7 +422,8 @@ pub const Request = struct { h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"); h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); - try request.server.connection.stream.writeAll(h.items); + var w = request.server.connection.stream.writer().unbuffered(); + try w.writeAll(h.items); return; } h.printAssumeCapacity("{s} {d} {s}\r\n", .{ @@ -438,47 +443,29 @@ pub const Request = struct { } var chunk_header_buffer: [18]u8 = undefined; - var iovecs: [max_extra_headers * 4 + 3]std.posix.iovec_const = undefined; + var iovecs: [max_extra_headers * 4 + 3][]const u8 = undefined; var iovecs_len: usize = 0; - iovecs[iovecs_len] = .{ - .base = h.items.ptr, - .len = h.items.len, - }; + iovecs[iovecs_len] = h.items; iovecs_len += 1; for (options.extra_headers) |header| { - iovecs[iovecs_len] = .{ - .base = header.name.ptr, - .len = header.name.len, - }; + iovecs[iovecs_len] = header.name; iovecs_len += 1; - iovecs[iovecs_len] = .{ - .base = ": ", - .len = 2, - }; + iovecs[iovecs_len] = ": "; iovecs_len += 1; if (header.value.len != 0) { - iovecs[iovecs_len] = .{ - .base = header.value.ptr, - .len = header.value.len, - }; + iovecs[iovecs_len] = header.value; iovecs_len += 1; } - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; + iovecs[iovecs_len] = "\r\n"; iovecs_len += 1; } - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; + iovecs[iovecs_len] = "\r\n"; iovecs_len += 1; if (request.head.method != .HEAD) { @@ -491,40 +478,26 @@ pub const Request = struct { .{content.len}, ) catch unreachable; - iovecs[iovecs_len] = .{ - .base = chunk_header.ptr, - .len = chunk_header.len, - }; + iovecs[iovecs_len] = chunk_header; iovecs_len += 1; - iovecs[iovecs_len] = .{ - .base = content.ptr, - .len = content.len, - }; + iovecs[iovecs_len] = content; iovecs_len += 1; - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; + iovecs[iovecs_len] = "\r\n"; iovecs_len += 1; } - iovecs[iovecs_len] = .{ - .base = "0\r\n\r\n", - .len = 5, - }; + iovecs[iovecs_len] = "0\r\n\r\n"; iovecs_len += 1; } else if (content.len > 0) { - iovecs[iovecs_len] = .{ - .base = content.ptr, - .len = content.len, - }; + iovecs[iovecs_len] = content; iovecs_len += 1; } } - try request.server.connection.stream.writevAll(iovecs[0..iovecs_len]); + var w = request.server.connection.stream.writer().unbuffered(); + try w.writevAll(iovecs[0..iovecs_len]); } pub const RespondStreamingOptions = struct { @@ -740,7 +713,10 @@ pub const Request = struct { return out_end; } - pub const ReaderError = Response.WriteError || error{ + pub const ReaderError = error{ + /// Failed to write "100-continue" to the stream. Error value is + /// stored in `Request.write_error`. + HttpContinueWriteFailed, /// The client sent an expect HTTP header value other than /// "100-continue". HttpExpectationFailed, @@ -760,7 +736,11 @@ pub const Request = struct { if (request.head.expect) |expect| { if (mem.eql(u8, expect, "100-continue")) { - try request.server.connection.stream.writeAll("HTTP/1.1 100 Continue\r\n\r\n"); + var w = request.server.connection.stream.writer().unbuffered(); + w.writeAll("HTTP/1.1 100 Continue\r\n\r\n") catch |err| { + request.write_error = err; + return error.HttpContinueWriteFailed; + }; request.head.expect = null; } else { return error.HttpExpectationFailed; @@ -845,14 +825,12 @@ pub const Response = struct { chunked, }; - pub const WriteError = net.Stream.WriteError; - /// When using content-length, asserts that the amount of data sent matches /// the value sent in the header, then calls `flush`. /// Otherwise, transfer-encoding: chunked is being used, and it writes the /// end-of-stream message, then flushes the stream to the system. /// Respects the value of `elide_body` to omit all data after the headers. - pub fn end(r: *Response) WriteError!void { + pub fn end(r: *Response) anyerror!void { switch (r.transfer_encoding) { .content_length => |len| { assert(len == 0); // Trips when end() called before all bytes written. @@ -877,7 +855,7 @@ pub const Response = struct { /// flushes the stream to the system. /// Respects the value of `elide_body` to omit all data after the headers. /// Asserts there are at most 25 trailers. - pub fn endChunked(r: *Response, options: EndChunkedOptions) WriteError!void { + pub fn endChunked(r: *Response, options: EndChunkedOptions) anyerror!void { assert(r.transfer_encoding == .chunked); try flush_chunked(r, options.trailers); r.* = undefined; @@ -887,7 +865,7 @@ pub const Response = struct { /// would not exceed the content-length value sent in the HTTP header. /// May return 0, which does not indicate end of stream. The caller decides /// when the end of stream occurs by calling `end`. - pub fn write(r: *Response, bytes: []const u8) WriteError!usize { + pub fn write(r: *Response, bytes: []const u8) anyerror!usize { switch (r.transfer_encoding) { .content_length, .none => return @errorCast(cl_writeSplat(r, &.{bytes}, 1)), .chunked => return @errorCast(chunked_writeSplat(r, &.{bytes}, 1)), @@ -916,7 +894,7 @@ pub const Response = struct { return error.Unimplemented; } - fn cl_write(context: *anyopaque, bytes: []const u8) WriteError!usize { + fn cl_write(context: *anyopaque, bytes: []const u8) anyerror!usize { const r: *Response = @constCast(@alignCast(@ptrCast(context))); var trash: u64 = std.math.maxInt(u64); @@ -932,17 +910,12 @@ pub const Response = struct { if (bytes.len + r.send_buffer_end > r.send_buffer.len) { const send_buffer_len = r.send_buffer_end - r.send_buffer_start; - var iovecs: [2]std.posix.iovec_const = .{ - .{ - .base = r.send_buffer.ptr + r.send_buffer_start, - .len = send_buffer_len, - }, - .{ - .base = bytes.ptr, - .len = bytes.len, - }, + var iovecs: [2][]const u8 = .{ + r.send_buffer[r.send_buffer_start..][0..send_buffer_len], + bytes, }; - const n = try r.stream.writev(&iovecs); + var w = r.stream.writer().unbuffered(); + const n = try w.writev(&iovecs); if (n >= send_buffer_len) { // It was enough to reset the buffer. @@ -985,10 +958,10 @@ pub const Response = struct { _ = len; _ = headers_and_trailers; _ = headers_len; - return error.Unimplemented; + return error.Unimplemented; // TODO lower to a call to writeFile on the output } - fn chunked_write(context: *anyopaque, bytes: []const u8) WriteError!usize { + fn chunked_write(context: *anyopaque, bytes: []const u8) anyerror!usize { const r: *Response = @constCast(@alignCast(@ptrCast(context))); assert(r.transfer_encoding == .chunked); @@ -1001,31 +974,17 @@ pub const Response = struct { var header_buf: [18]u8 = undefined; const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{chunk_len}) catch unreachable; - var iovecs: [5]std.posix.iovec_const = .{ - .{ - .base = r.send_buffer.ptr + r.send_buffer_start, - .len = send_buffer_len - r.chunk_len, - }, - .{ - .base = chunk_header.ptr, - .len = chunk_header.len, - }, - .{ - .base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len, - .len = r.chunk_len, - }, - .{ - .base = bytes.ptr, - .len = bytes.len, - }, - .{ - .base = "\r\n", - .len = 2, - }, + var iovecs: [5][]const u8 = .{ + r.send_buffer[r.send_buffer_start .. send_buffer_len - r.chunk_len], + chunk_header, + r.send_buffer[r.send_buffer_end - r.chunk_len ..][0..r.chunk_len], + bytes, + "\r\n", }; // TODO make this writev instead of writevAll, which involves // complicating the logic of this function. - try r.stream.writevAll(&iovecs); + var w = r.stream.writer().unbuffered(); + try w.writevAll(&iovecs); r.send_buffer_start = 0; r.send_buffer_end = 0; r.chunk_len = 0; @@ -1041,7 +1000,7 @@ pub const Response = struct { /// If using content-length, asserts that writing these bytes to the client /// would not exceed the content-length value sent in the HTTP header. - pub fn writeAll(r: *Response, bytes: []const u8) WriteError!void { + pub fn writeAll(r: *Response, bytes: []const u8) anyerror!void { var index: usize = 0; while (index < bytes.len) { index += try write(r, bytes[index..]); @@ -1051,20 +1010,21 @@ pub const Response = struct { /// Sends all buffered data to the client. /// This is redundant after calling `end`. /// Respects the value of `elide_body` to omit all data after the headers. - pub fn flush(r: *Response) WriteError!void { + pub fn flush(r: *Response) anyerror!void { switch (r.transfer_encoding) { .none, .content_length => return flush_cl(r), .chunked => return flush_chunked(r, null), } } - fn flush_cl(r: *Response) WriteError!void { - try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]); + fn flush_cl(r: *Response) anyerror!void { + var w = r.stream.writer().unbuffered(); + try w.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]); r.send_buffer_start = 0; r.send_buffer_end = 0; } - fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) WriteError!void { + fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) anyerror!void { const max_trailers = 25; if (end_trailers) |trailers| assert(trailers.len <= max_trailers); assert(r.transfer_encoding == .chunked); @@ -1072,7 +1032,8 @@ pub const Response = struct { const http_headers = r.send_buffer[r.send_buffer_start .. r.send_buffer_end - r.chunk_len]; if (r.elide_body) { - try r.stream.writeAll(http_headers); + var w = r.stream.writer().unbuffered(); + try w.writeAll(http_headers); r.send_buffer_start = 0; r.send_buffer_end = 0; r.chunk_len = 0; @@ -1082,78 +1043,49 @@ pub const Response = struct { var header_buf: [18]u8 = undefined; const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{r.chunk_len}) catch unreachable; - var iovecs: [max_trailers * 4 + 5]std.posix.iovec_const = undefined; + var iovecs: [max_trailers * 4 + 5][]const u8 = undefined; var iovecs_len: usize = 0; - iovecs[iovecs_len] = .{ - .base = http_headers.ptr, - .len = http_headers.len, - }; + iovecs[iovecs_len] = http_headers; iovecs_len += 1; if (r.chunk_len > 0) { - iovecs[iovecs_len] = .{ - .base = chunk_header.ptr, - .len = chunk_header.len, - }; + iovecs[iovecs_len] = chunk_header; iovecs_len += 1; - iovecs[iovecs_len] = .{ - .base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len, - .len = r.chunk_len, - }; + iovecs[iovecs_len] = r.send_buffer[r.send_buffer_end - r.chunk_len ..][0..r.chunk_len]; iovecs_len += 1; - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; + iovecs[iovecs_len] = "\r\n"; iovecs_len += 1; } if (end_trailers) |trailers| { - iovecs[iovecs_len] = .{ - .base = "0\r\n", - .len = 3, - }; + iovecs[iovecs_len] = "0\r\n"; iovecs_len += 1; for (trailers) |trailer| { - iovecs[iovecs_len] = .{ - .base = trailer.name.ptr, - .len = trailer.name.len, - }; + iovecs[iovecs_len] = trailer.name; iovecs_len += 1; - iovecs[iovecs_len] = .{ - .base = ": ", - .len = 2, - }; + iovecs[iovecs_len] = ": "; iovecs_len += 1; if (trailer.value.len != 0) { - iovecs[iovecs_len] = .{ - .base = trailer.value.ptr, - .len = trailer.value.len, - }; + iovecs[iovecs_len] = trailer.value; iovecs_len += 1; } - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; + iovecs[iovecs_len] = "\r\n"; iovecs_len += 1; } - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; + iovecs[iovecs_len] = "\r\n"; iovecs_len += 1; } - try r.stream.writevAll(iovecs[0..iovecs_len]); + var w = r.stream.writer().unbuffered(); + try w.writevAll(iovecs[0..iovecs_len]); r.send_buffer_start = 0; r.send_buffer_end = 0; r.chunk_len = 0; diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig index fc00a68ec3..a7b0cbc5d6 100644 --- a/lib/std/http/protocol.zig +++ b/lib/std/http/protocol.zig @@ -290,7 +290,7 @@ inline fn intShift(comptime T: type, x: anytype) T { const MockBufferedConnection = struct { pub const buffer_size = 0x2000; - conn: std.io.FixedBufferStream([]const u8), + conn: std.io.FixedBufferStream, buf: [buffer_size]u8 = undefined, start: u16 = 0, end: u16 = 0, @@ -343,27 +343,12 @@ const MockBufferedConnection = struct { return conn.readAtLeast(buffer, 1); } - pub const ReadError = std.io.FixedBufferStream([]const u8).ReadError || error{EndOfStream}; + pub const ReadError = std.io.FixedBufferStream.ReadError || error{EndOfStream}; pub const Reader = std.io.Reader(*MockBufferedConnection, ReadError, read); pub fn reader(conn: *MockBufferedConnection) Reader { return Reader{ .context = conn }; } - - pub fn writeAll(conn: *MockBufferedConnection, buffer: []const u8) WriteError!void { - return conn.conn.writeAll(buffer); - } - - pub fn write(conn: *MockBufferedConnection, buffer: []const u8) WriteError!usize { - return conn.conn.write(buffer); - } - - pub const WriteError = std.io.FixedBufferStream([]const u8).WriteError; - pub const Writer = std.io.Writer(*MockBufferedConnection, WriteError, write); - - pub fn writer(conn: *MockBufferedConnection) Writer { - return Writer{ .context = conn }; - } }; test "HeadersParser.read length" { @@ -374,7 +359,7 @@ test "HeadersParser.read length" { const data = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello"; var conn: MockBufferedConnection = .{ - .conn = std.io.fixedBufferStream(data), + .conn = .{ .buffer = data }, }; while (true) { // read headers @@ -404,7 +389,7 @@ test "HeadersParser.read chunked" { const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n"; var conn: MockBufferedConnection = .{ - .conn = std.io.fixedBufferStream(data), + .conn = .{ .buffer = data }, }; while (true) { // read headers @@ -433,7 +418,7 @@ test "HeadersParser.read chunked trailer" { const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n"; var conn: MockBufferedConnection = .{ - .conn = std.io.fixedBufferStream(data), + .conn = .{ .buffer = data }, }; while (true) { // read headers diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index dc944fbabb..b9861eb612 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -135,7 +135,8 @@ test "HTTP server handles a chunked transfer coding request" { const gpa = std.testing.allocator; const stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); defer stream.close(); - try stream.writeAll(request_bytes); + var writer = stream.writer().unbuffered(); + try writer.writeAll(request_bytes); const expected_response = "HTTP/1.1 200 OK\r\n" ++ @@ -276,7 +277,8 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { const gpa = std.testing.allocator; const stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); defer stream.close(); - try stream.writeAll(request_bytes); + var writer = stream.writer().unbuffered(); + try writer.writeAll(request_bytes); const response = try stream.reader().readAllAlloc(gpa, 8192); defer gpa.free(response); @@ -339,7 +341,8 @@ test "receiving arbitrary http headers from the client" { const gpa = std.testing.allocator; const stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); defer stream.close(); - try stream.writeAll(request_bytes); + var writer = stream.writer().unbuffered(); + try writer.writeAll(request_bytes); const response = try stream.reader().readAllAlloc(gpa, 8192); defer gpa.free(response); @@ -960,8 +963,9 @@ test "Server streams both reading and writing" { try req.send(); try req.wait(); - try req.writeAll("one "); - try req.writeAll("fish"); + var w = req.writer().unbuffered(); + try w.writeAll("one "); + try w.writeAll("fish"); try req.finish(); @@ -992,8 +996,9 @@ fn echoTests(client: *http.Client, port: u16) !void { req.transfer_encoding = .{ .content_length = 14 }; try req.send(); - try req.writeAll("Hello, "); - try req.writeAll("World!\n"); + var w = req.writer().unbuffered(); + try w.writeAll("Hello, "); + try w.writeAll("World!\n"); try req.finish(); try req.wait(); @@ -1026,8 +1031,9 @@ fn echoTests(client: *http.Client, port: u16) !void { req.transfer_encoding = .chunked; try req.send(); - try req.writeAll("Hello, "); - try req.writeAll("World!\n"); + var w = req.writer().unbuffered(); + try w.writeAll("Hello, "); + try w.writeAll("World!\n"); try req.finish(); try req.wait(); @@ -1080,8 +1086,9 @@ fn echoTests(client: *http.Client, port: u16) !void { req.transfer_encoding = .chunked; try req.send(); - try req.writeAll("Hello, "); - try req.writeAll("World!\n"); + var w = req.writer().unbuffered(); + try w.writeAll("Hello, "); + try w.writeAll("World!\n"); try req.finish(); try req.wait(); diff --git a/lib/std/net.zig b/lib/std/net.zig index 9d821c4399..e7d1704749 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -849,9 +849,9 @@ pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Stream { return Stream{ .handle = sockfd }; } -const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || posix.SocketError || posix.BindError || posix.SetSockOptError || error{ - // TODO: break this up into error sets from the various underlying functions - +// TODO: Instead of having a massive error set, make the error set have categories, and then +// store the sub-error as a diagnostic anyerror value. +const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || anyerror || posix.SocketError || posix.BindError || posix.SetSockOptError || error{ TemporaryNameServerFailure, NameServerFailure, AddressFamilyNotSupported, @@ -1358,16 +1358,23 @@ fn linuxLookupNameFromHosts( var buffered_reader = std.io.bufferedReader(file.reader()); const reader = buffered_reader.reader(); + // TODO: rework buffered reader so that we can use its buffer directly when searching for delimiters var line_buf: [512]u8 = undefined; - while (reader.readUntilDelimiterOrEof(&line_buf, '\n') catch |err| switch (err) { - error.StreamTooLong => blk: { - // Skip to the delimiter in the reader, to fix parsing - try reader.skipUntilDelimiterOrEof('\n'); - // Use the truncated line. A truncated comment or hostname will be handled correctly. - break :blk &line_buf; - }, - else => |e| return e, - }) |line| { + var line_buf_writer: std.io.BufferedWriter = undefined; + line_buf_writer.initFixed(&line_buf); + while (true) { + const line = if (reader.streamUntilDelimiter(&line_buf_writer, '\n', line_buf.len)) |_| l: { + break :l line_buf_writer.getWritten(); + } else |err| switch (err) { + error.EndOfStream => l: { + if (line_buf_writer.getWritten().len == 0) break; + // Skip to the delimiter in the reader, to fix parsing + try reader.skipUntilDelimiterOrEof('\n'); + // Use the truncated line. A truncated comment or hostname will be handled correctly. + break :l &line_buf; + }, + else => |e| return e, + }; var split_it = mem.splitScalar(u8, line, '#'); const no_comment_line = split_it.first(); @@ -1559,16 +1566,23 @@ fn getResolvConf(allocator: mem.Allocator, rc: *ResolvConf) !void { var buf_reader = std.io.bufferedReader(file.reader()); const stream = buf_reader.reader(); + // TODO: rework buffered reader so that we can use its buffer directly when searching for delimiters var line_buf: [512]u8 = undefined; - while (stream.readUntilDelimiterOrEof(&line_buf, '\n') catch |err| switch (err) { - error.StreamTooLong => blk: { - // Skip to the delimiter in the stream, to fix parsing - try stream.skipUntilDelimiterOrEof('\n'); - // Give an empty line to the while loop, which will be skipped. - break :blk line_buf[0..0]; - }, - else => |e| return e, - }) |line| { + var line_buf_writer: std.io.BufferedWriter = undefined; + line_buf_writer.initFixed(&line_buf); + while (true) { + const line = if (stream.streamUntilDelimiter(&line_buf_writer, '\n', line_buf.len)) |_| l: { + break :l line_buf_writer.getWritten(); + } else |err| switch (err) { + error.EndOfStream => l: { + if (line_buf_writer.getWritten().len == 0) break; + // Skip to the delimiter in the reader, to fix parsing + try stream.skipUntilDelimiterOrEof('\n'); + // Give an empty line to the while loop, which will be skipped. + break :l line_buf[0..0]; + }, + else => |e| return e, + }; const no_comment_line = no_comment_line: { var split = mem.splitScalar(u8, line, '#'); break :no_comment_line split.first(); @@ -1833,7 +1847,9 @@ fn dnsParseCallback(ctx: dpc_ctx, rr: u8, data: []const u8, packet: []const u8) pub const Stream = struct { /// Underlying platform-defined type which may or may not be /// interchangeable with a file system file descriptor. - handle: posix.socket_t, + handle: Handle, + + pub const Handle = if (native_os == .windows) windows.ws2_32.SOCKET else posix.fd_t; pub fn close(s: Stream) void { switch (native_os) { @@ -1843,17 +1859,36 @@ pub const Stream = struct { } pub const ReadError = posix.ReadError; - pub const WriteError = posix.WriteError; + pub const WriteError = posix.SendMsgError || error{ + ConnectionResetByPeer, + SocketNotBound, + MessageTooBig, + NetworkSubsystemFailed, + SystemResources, + SocketNotConnected, + Unexpected, + }; pub const Reader = io.Reader(Stream, ReadError, read); - pub const Writer = io.Writer(Stream, WriteError, write); pub fn reader(self: Stream) Reader { return .{ .context = self }; } - pub fn writer(self: Stream) Writer { - return .{ .context = self }; + pub fn writer(stream: Stream) std.io.Writer { + return .{ + .context = handleToOpaque(stream.handle), + .vtable = switch (native_os) { + .windows => &.{ + .writeSplat = windows_writeSplat, + .writeFile = windows_writeFile, + }, + else => &.{ + .writeSplat = posix_writeSplat, + .writeFile = std.fs.File.writer_writeFile, + }, + }, + }; } pub fn read(self: Stream, buffer: []u8) ReadError!usize { @@ -1898,48 +1933,148 @@ pub const Stream = struct { return index; } - /// TODO in evented I/O mode, this implementation incorrectly uses the event loop's - /// file system thread instead of non-blocking. It needs to be reworked to properly - /// use non-blocking I/O. - pub fn write(self: Stream, buffer: []const u8) WriteError!usize { - if (native_os == .windows) { - return windows.WriteFile(self.handle, buffer, null); + fn windows_writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize { + comptime assert(native_os == .windows); + if (data.len == 1 and splat == 0) return 0; + var splat_buffer: [256]u8 = undefined; + var iovecs: [max_buffers_len]windows.WSABUF = undefined; + var len: u32 = @min(iovecs.len, data.len); + for (iovecs[0..len], data[0..len]) |*v, d| v.* = .{ + .buf = if (d.len == 0) "" else d.ptr, // TODO: does Windows allow ptr=undefined len=0 ? + .len = d.len, + }; + switch (splat) { + 0 => len -= 1, + 1 => {}, + else => { + const pattern = data[data.len - 1]; + if (pattern.len == 1) { + const memset_len = @min(splat_buffer.len, splat); + const buf = splat_buffer[0..memset_len]; + @memset(buf, pattern[0]); + iovecs[len - 1] = .{ .base = buf.ptr, .len = buf.len }; + var remaining_splat = splat - buf.len; + while (remaining_splat > splat_buffer.len and len < iovecs.len) { + iovecs[len] = .{ .base = &splat_buffer, .len = splat_buffer.len }; + remaining_splat -= splat_buffer.len; + len += 1; + } + if (remaining_splat > 0 and len < iovecs.len) { + iovecs[len] = .{ .base = &splat_buffer, .len = remaining_splat }; + len += 1; + } + } + }, } - - return posix.write(self.handle, buffer); + const handle = opaqueToHandle(context); + var n: u32 = undefined; + const rc = windows.ws2_32.WSASend(handle, &iovecs, len, &n, 0, null, null); + if (rc == windows.ws2_32.SOCKET_ERROR) switch (windows.ws2_32.WSAGetLastError()) { + .WSAECONNABORTED => return error.ConnectionResetByPeer, + .WSAECONNRESET => return error.ConnectionResetByPeer, + .WSAEFAULT => unreachable, // a pointer is not completely contained in user address space. + .WSAEINPROGRESS, .WSAEINTR => unreachable, // deprecated and removed in WSA 2.2 + .WSAEINVAL => return error.SocketNotBound, + .WSAEMSGSIZE => return error.MessageTooBig, + .WSAENETDOWN => return error.NetworkSubsystemFailed, + .WSAENETRESET => return error.ConnectionResetByPeer, + .WSAENOBUFS => return error.SystemResources, + .WSAENOTCONN => return error.SocketNotConnected, + .WSAENOTSOCK => unreachable, // not a socket + .WSAEOPNOTSUPP => unreachable, // only for message-oriented sockets + .WSAESHUTDOWN => unreachable, // cannot send on a socket after write shutdown + .WSAEWOULDBLOCK => return error.WouldBlock, + .WSANOTINITIALISED => unreachable, // WSAStartup must be called before this function + .WSA_IO_PENDING => unreachable, // not using overlapped I/O + .WSA_OPERATION_ABORTED => unreachable, // not using overlapped I/O + else => |err| return windows.unexpectedWSAError(err), + }; + return n; } - pub fn writeAll(self: Stream, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try self.write(bytes[index..]); + fn posix_writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) anyerror!usize { + const sock_fd = opaqueToHandle(context); + comptime assert(native_os != .windows); + var splat_buffer: [256]u8 = undefined; + var iovecs: [max_buffers_len]std.posix.iovec_const = undefined; + var len: usize = @min(iovecs.len, data.len); + for (iovecs[0..len], data[0..len]) |*v, d| v.* = .{ + .base = if (d.len == 0) "" else d.ptr, // OS sadly checks ptr addr before length. + .len = d.len, + }; + var msg: posix.msghdr_const = .{ + .name = null, + .namelen = 0, + .iov = &iovecs, + .iovlen = len, + .control = null, + .controllen = 0, + .flags = 0, + }; + switch (splat) { + 0 => msg.iovlen = len - 1, + 1 => {}, + else => { + const pattern = data[data.len - 1]; + if (pattern.len == 1) { + const memset_len = @min(splat_buffer.len, splat); + const buf = splat_buffer[0..memset_len]; + @memset(buf, pattern[0]); + iovecs[len - 1] = .{ .base = buf.ptr, .len = buf.len }; + var remaining_splat = splat - buf.len; + while (remaining_splat > splat_buffer.len and len < iovecs.len) { + iovecs[len] = .{ .base = &splat_buffer, .len = splat_buffer.len }; + remaining_splat -= splat_buffer.len; + len += 1; + } + if (remaining_splat > 0 and len < iovecs.len) { + iovecs[len] = .{ .base = &splat_buffer, .len = remaining_splat }; + len += 1; + } + msg.iovlen = len; + } + }, } + const flags = posix.MSG.NOSIGNAL; + return std.posix.sendmsg(sock_fd, &msg, flags); } - /// See https://github.com/ziglang/zig/issues/7699 - /// See equivalent function: `std.fs.File.writev`. - pub fn writev(self: Stream, iovecs: []const posix.iovec_const) WriteError!usize { - return posix.writev(self.handle, iovecs); + fn windows_writeFile( + context: *anyopaque, + in_file: std.fs.File, + in_offset: u64, + in_len: std.io.Writer.VTable.FileLen, + headers_and_trailers: []const []const u8, + headers_len: usize, + ) anyerror!usize { + const len_int = switch (in_len) { + .zero => return windows_writeSplat(context, headers_and_trailers, 1), + .entire_file => std.math.maxInt(usize), + else => in_len.int(), + }; + if (headers_len > 0) return windows_writeSplat(context, headers_and_trailers[0..headers_len], 1); + var file_contents_buffer: [4096]u8 = undefined; + const read_buffer = file_contents_buffer[0..@min(file_contents_buffer.len, len_int)]; + const n = try windows.ReadFile(in_file.handle, read_buffer, in_offset); + return windows_writeSplat(context, &.{read_buffer[0..n]}, 1); } - /// The `iovecs` parameter is mutable because this function needs to mutate the fields in - /// order to handle partial writes from the underlying OS layer. - /// See https://github.com/ziglang/zig/issues/7699 - /// See equivalent function: `std.fs.File.writevAll`. - pub fn writevAll(self: Stream, iovecs: []posix.iovec_const) WriteError!void { - if (iovecs.len == 0) return; + const max_buffers_len = 8; - var i: usize = 0; - while (true) { - var amt = try self.writev(iovecs[i..]); - while (amt >= iovecs[i].len) { - amt -= iovecs[i].len; - i += 1; - if (i >= iovecs.len) return; - } - iovecs[i].base += amt; - iovecs[i].len -= amt; - } + fn handleToOpaque(handle: Handle) *anyopaque { + return switch (@typeInfo(Handle)) { + .pointer => @ptrCast(handle), + .int => @ptrFromInt(@as(u32, @bitCast(handle))), + else => @compileError("unhandled"), + }; + } + + fn opaqueToHandle(userdata: *anyopaque) Handle { + return switch (@typeInfo(Handle)) { + .pointer => @ptrCast(userdata), + .int => @intCast(@intFromPtr(userdata)), + else => @compileError("unhandled"), + }; } }; diff --git a/lib/std/os/windows.zig b/lib/std/os/windows.zig index 3aaff1d60a..33c2e7a548 100644 --- a/lib/std/os/windows.zig +++ b/lib/std/os/windows.zig @@ -1690,40 +1690,6 @@ pub fn getpeername(s: ws2_32.SOCKET, name: *ws2_32.sockaddr, namelen: *ws2_32.so return ws2_32.getpeername(s, name, @as(*i32, @ptrCast(namelen))); } -pub fn sendmsg( - s: ws2_32.SOCKET, - msg: *ws2_32.WSAMSG_const, - flags: u32, -) i32 { - var bytes_send: DWORD = undefined; - if (ws2_32.WSASendMsg(s, msg, flags, &bytes_send, null, null) == ws2_32.SOCKET_ERROR) { - return ws2_32.SOCKET_ERROR; - } else { - return @as(i32, @as(u31, @intCast(bytes_send))); - } -} - -pub fn sendto(s: ws2_32.SOCKET, buf: [*]const u8, len: usize, flags: u32, to: ?*const ws2_32.sockaddr, to_len: ws2_32.socklen_t) i32 { - var buffer = ws2_32.WSABUF{ .len = @as(u31, @truncate(len)), .buf = @constCast(buf) }; - var bytes_send: DWORD = undefined; - if (ws2_32.WSASendTo(s, @as([*]ws2_32.WSABUF, @ptrCast(&buffer)), 1, &bytes_send, flags, to, @as(i32, @intCast(to_len)), null, null) == ws2_32.SOCKET_ERROR) { - return ws2_32.SOCKET_ERROR; - } else { - return @as(i32, @as(u31, @intCast(bytes_send))); - } -} - -pub fn recvfrom(s: ws2_32.SOCKET, buf: [*]u8, len: usize, flags: u32, from: ?*ws2_32.sockaddr, from_len: ?*ws2_32.socklen_t) i32 { - var buffer = ws2_32.WSABUF{ .len = @as(u31, @truncate(len)), .buf = buf }; - var bytes_received: DWORD = undefined; - var flags_inout = flags; - if (ws2_32.WSARecvFrom(s, @as([*]ws2_32.WSABUF, @ptrCast(&buffer)), 1, &bytes_received, &flags_inout, from, @as(?*i32, @ptrCast(from_len)), null, null) == ws2_32.SOCKET_ERROR) { - return ws2_32.SOCKET_ERROR; - } else { - return @as(i32, @as(u31, @intCast(bytes_received))); - } -} - pub fn poll(fds: [*]ws2_32.pollfd, n: c_ulong, timeout: i32) i32 { return ws2_32.WSAPoll(fds, n, timeout); } diff --git a/lib/std/os/windows/ws2_32.zig b/lib/std/os/windows/ws2_32.zig index e8375dc2c1..83194425fa 100644 --- a/lib/std/os/windows/ws2_32.zig +++ b/lib/std/os/windows/ws2_32.zig @@ -1829,7 +1829,7 @@ pub extern "ws2_32" fn sendto( buf: [*]const u8, len: i32, flags: i32, - to: *const sockaddr, + to: ?*const sockaddr, tolen: i32, ) callconv(.winapi) i32; @@ -2116,14 +2116,6 @@ pub extern "ws2_32" fn WSASendMsg( lpCompletionRoutine: ?LPWSAOVERLAPPED_COMPLETION_ROUTINE, ) callconv(.winapi) i32; -pub extern "ws2_32" fn WSARecvMsg( - s: SOCKET, - lpMsg: *WSAMSG, - lpdwNumberOfBytesRecv: ?*u32, - lpOverlapped: ?*OVERLAPPED, - lpCompletionRoutine: ?LPWSAOVERLAPPED_COMPLETION_ROUTINE, -) callconv(.winapi) i32; - pub extern "ws2_32" fn WSASendDisconnect( s: SOCKET, lpOutboundDisconnectData: ?*WSABUF,