From afa66f6111c88cda6bf576367f980f2a84c33116 Mon Sep 17 00:00:00 2001 From: lumanetic <396151+h57624paen@users.noreply.github.com> Date: Sat, 26 Jul 2025 01:34:19 -0400 Subject: [PATCH 01/70] std.process.Child: fix double path normalization in spawnWindows besides simply being redundant work, the now removed normalize call would cause spawn to errantly fail (BadPath) when passing a relative path which traversed 'above' the current working directory. This case is already handled by leaving normalization to the windows.wToPrefixedFileW call in windowsCreateProcessPathExt --- lib/std/process/Child.zig | 9 ----- test/standalone/windows_spawn/main.zig | 46 ++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/lib/std/process/Child.zig b/lib/std/process/Child.zig index 21cc545f12..ce228176f4 100644 --- a/lib/std/process/Child.zig +++ b/lib/std/process/Child.zig @@ -901,11 +901,6 @@ fn spawnWindows(self: *ChildProcess) SpawnError!void { if (dir_buf.items.len > 0) try dir_buf.append(self.allocator, fs.path.sep); try dir_buf.appendSlice(self.allocator, app_dir); } - if (dir_buf.items.len > 0) { - // Need to normalize the path, openDirW can't handle things like double backslashes - const normalized_len = windows.normalizePath(u16, dir_buf.items) catch return error.BadPathName; - dir_buf.shrinkRetainingCapacity(normalized_len); - } windowsCreateProcessPathExt(self.allocator, &dir_buf, &app_buf, PATHEXT, &cmd_line_cache, envp_ptr, cwd_w_ptr, flags, &siStartInfo, &piProcInfo) catch |no_path_err| { const original_err = switch (no_path_err) { @@ -930,10 +925,6 @@ fn spawnWindows(self: *ChildProcess) SpawnError!void { while (it.next()) |search_path| { dir_buf.clearRetainingCapacity(); try dir_buf.appendSlice(self.allocator, search_path); - // Need to normalize the path, some PATH values can contain things like double - // backslashes which openDirW can't handle - const normalized_len = windows.normalizePath(u16, dir_buf.items) catch continue; - dir_buf.shrinkRetainingCapacity(normalized_len); if (windowsCreateProcessPathExt(self.allocator, &dir_buf, &app_buf, PATHEXT, &cmd_line_cache, envp_ptr, cwd_w_ptr, flags, &siStartInfo, &piProcInfo)) { break :run; diff --git a/test/standalone/windows_spawn/main.zig b/test/standalone/windows_spawn/main.zig index 3b0d0efe75..4cacf14c7c 100644 --- a/test/standalone/windows_spawn/main.zig +++ b/test/standalone/windows_spawn/main.zig @@ -1,4 +1,5 @@ const std = @import("std"); + const windows = std.os.windows; const utf16Literal = std.unicode.utf8ToUtf16LeStringLiteral; @@ -39,6 +40,9 @@ pub fn main() anyerror!void { // No PATH, so it should fail to find anything not in the cwd try testExecError(error.FileNotFound, allocator, "something_missing"); + // make sure we don't get error.BadPath traversing out of cwd with a relative path + try testExecError(error.FileNotFound, allocator, "..\\.\\.\\.\\\\..\\more_missing"); + std.debug.assert(windows.kernel32.SetEnvironmentVariableW( utf16Literal("PATH"), tmp_absolute_path_w, @@ -149,6 +153,48 @@ pub fn main() anyerror!void { // If we try to exec but provide a cwd that is an absolute path, the PATH // should still be searched and the goodbye.exe in something should be found. try testExecWithCwd(allocator, "goodbye", tmp_absolute_path, "hello from exe\n"); + + // introduce some extra path separators into the path which is dealt with inside the spawn call. 
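For reference, the "denormalization" below just multiplies every separator using `std.mem.replacementSize` and `std.mem.replace`. A minimal u8 sketch of the same API (the test itself operates on WTF-16, and this literal path is hypothetical):

    const std = @import("std");

    test "denormalize a path by multiplying separators" {
        const gpa = std.testing.allocator;
        const original = "C:\\tmp\\something";
        // Each single backslash becomes four; spawn must tolerate the
        // redundant separators instead of failing with BadPathName.
        const size = std.mem.replacementSize(u8, original, "\\", "\\\\\\\\");
        const denormed = try gpa.alloc(u8, size);
        defer gpa.free(denormed);
        _ = std.mem.replace(u8, original, "\\", "\\\\\\\\", denormed);
        try std.testing.expectEqualStrings("C:\\\\\\\\tmp\\\\\\\\something", denormed);
    }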
+ const denormed_something_subdir_size = std.mem.replacementSize(u16, something_subdir_abs_path, utf16Literal("\\"), utf16Literal("\\\\\\\\")); + + const denormed_something_subdir_abs_path = try allocator.allocSentinel(u16, denormed_something_subdir_size, 0); + defer allocator.free(denormed_something_subdir_abs_path); + + _ = std.mem.replace(u16, something_subdir_abs_path, utf16Literal("\\"), utf16Literal("\\\\\\\\"), denormed_something_subdir_abs_path); + + const denormed_something_subdir_wtf8 = try std.unicode.wtf16LeToWtf8Alloc(allocator, denormed_something_subdir_abs_path); + defer allocator.free(denormed_something_subdir_wtf8); + + // clear the path to ensure that the match comes from the cwd + std.debug.assert(windows.kernel32.SetEnvironmentVariableW( + utf16Literal("PATH"), + null, + ) == windows.TRUE); + + try testExecWithCwd(allocator, "goodbye", denormed_something_subdir_wtf8, "hello from exe\n"); + + // normalization should also work if the non-normalized path is found in the PATH var. + std.debug.assert(windows.kernel32.SetEnvironmentVariableW( + utf16Literal("PATH"), + denormed_something_subdir_abs_path, + ) == windows.TRUE); + try testExec(allocator, "goodbye", "hello from exe\n"); + + // now make sure we can launch executables "outside" of the cwd + var subdir_cwd = try tmp.dir.openDir(denormed_something_subdir_wtf8, .{}); + defer subdir_cwd.close(); + + try tmp.dir.rename("something/goodbye.exe", "hello.exe"); + try subdir_cwd.setAsCwd(); + + // clear the PATH again + std.debug.assert(windows.kernel32.SetEnvironmentVariableW( + utf16Literal("PATH"), + null, + ) == windows.TRUE); + + // while we're at it make sure non-windows separators work fine + try testExec(allocator, "../hello", "hello from exe\n"); } fn testExecError(err: anyerror, allocator: std.mem.Allocator, command: []const u8) !void { From 0f4106356e298f894fe31704af350da6209019bc Mon Sep 17 00:00:00 2001 From: Ryan Liptak Date: Fri, 25 Jul 2025 00:24:09 -0700 Subject: [PATCH 02/70] child_process standalone test: Test spawning a path with leading .. Also check that FileNotFound is consistently returned when the path is missing. The new `run_relative` step will test spawning paths like: child_path: ../84385e7e669db0967d7a42765011dbe0/child missing_child_path: ../84385e7e669db0967d7a42765011dbe0/child_intentionally_missing --- test/standalone/child_process/build.zig | 11 +++++++++++ test/standalone/child_process/main.zig | 16 ++++++++++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/test/standalone/child_process/build.zig b/test/standalone/child_process/build.zig index 81b40835f3..9e96aaa608 100644 --- a/test/standalone/child_process/build.zig +++ b/test/standalone/child_process/build.zig @@ -31,5 +31,16 @@ pub fn build(b: *std.Build) void { run.addArtifactArg(child); run.expectExitCode(0); + // Use a temporary directory within the cache as the CWD to test + // spawning the child using a path that contains a leading `..` component. 
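The `..`-prefixed path itself is computed on the child side with `std.fs.path.relative` (see main.zig below). A minimal POSIX-flavoured sketch of that call, with hypothetical directories standing in for the cache layout:

    const std = @import("std");

    test "relative path out of a sibling cwd leads with '..'" {
        const gpa = std.testing.allocator;
        // The cwd is a scratch dir inside the cache; the child binary lives
        // in a sibling dir, so the relative path must climb out of the cwd.
        const rel = try std.fs.path.relative(gpa, "/cache/tmp/scratch", "/cache/o/child");
        defer gpa.free(rel);
        try std.testing.expectEqualStrings("../../o/child", rel);
    }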
+ const run_relative = b.addRunArtifact(main); + run_relative.addArtifactArg(child); + const write_tmp_dir = b.addWriteFiles(); + const tmp_cwd = write_tmp_dir.getDirectory(); + run_relative.addDirectoryArg(tmp_cwd); + run_relative.setCwd(tmp_cwd); + run_relative.expectExitCode(0); + test_step.dependOn(&run.step); + test_step.dependOn(&run_relative.step); } diff --git a/test/standalone/child_process/main.zig b/test/standalone/child_process/main.zig index 6537f90acf..9ded383d96 100644 --- a/test/standalone/child_process/main.zig +++ b/test/standalone/child_process/main.zig @@ -11,7 +11,14 @@ pub fn main() !void { var it = try std.process.argsWithAllocator(gpa); defer it.deinit(); _ = it.next() orelse unreachable; // skip binary name - const child_path = it.next() orelse unreachable; + const child_path, const needs_free = child_path: { + const child_path = it.next() orelse unreachable; + const cwd_path = it.next() orelse break :child_path .{ child_path, false }; + // If there is a third argument, it is the current CWD somewhere within the cache directory. + // In that case, modify the child path in order to test spawning a path with a leading `..` component. + break :child_path .{ try std.fs.path.relative(gpa, cwd_path, child_path), true }; + }; + defer if (needs_free) gpa.free(child_path); var child = std.process.Child.init(&.{ child_path, "hello arg" }, gpa); child.stdin_behavior = .Pipe; @@ -39,7 +46,12 @@ pub fn main() !void { }, else => |term| testError("abnormal child exit: {}", .{term}), } - return if (parent_test_error) error.ParentTestError else {}; + if (parent_test_error) return error.ParentTestError; + + // Check that FileNotFound is consistent across platforms when trying to spawn an executable that doesn't exist + const missing_child_path = try std.mem.concat(gpa, u8, &.{ child_path, "_intentionally_missing" }); + defer gpa.free(missing_child_path); + try std.testing.expectError(error.FileNotFound, std.process.Child.run(.{ .allocator = gpa, .argv = &.{missing_child_path} })); } var parent_test_error = false; From 259b7c3f3fc3f8902f8aa4e6e7539df271073a47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?= Date: Wed, 30 Jul 2025 18:58:47 +0200 Subject: [PATCH 03/70] std.Target: pull Os.requiresLibC() up to Target --- lib/std/Target.zig | 101 ++++++++++++++++++++++----------------------- src/target.zig | 2 +- 2 files changed, 51 insertions(+), 52 deletions(-) diff --git a/lib/std/Target.zig b/lib/std/Target.zig index c75e2b51fb..ccc8ab79d8 100644 --- a/lib/std/Target.zig +++ b/lib/std/Target.zig @@ -697,57 +697,6 @@ pub const Os = struct { => |field| @field(os.version_range, @tagName(field)).isAtLeast(ver), }; } - - /// On Darwin, we always link libSystem which contains libc. - /// Similarly on FreeBSD and NetBSD we always link system libc - /// since this is the stable syscall interface. 
-    pub fn requiresLibC(os: Os) bool {
-        return switch (os.tag) {
-            .aix,
-            .driverkit,
-            .macos,
-            .ios,
-            .tvos,
-            .watchos,
-            .visionos,
-            .dragonfly,
-            .openbsd,
-            .haiku,
-            .solaris,
-            .illumos,
-            .serenity,
-            => true,
-
-            .linux,
-            .windows,
-            .freebsd,
-            .netbsd,
-            .freestanding,
-            .fuchsia,
-            .ps3,
-            .zos,
-            .rtems,
-            .cuda,
-            .nvcl,
-            .amdhsa,
-            .ps4,
-            .ps5,
-            .mesa3d,
-            .contiki,
-            .amdpal,
-            .hermit,
-            .hurd,
-            .wasi,
-            .emscripten,
-            .uefi,
-            .opencl,
-            .opengl,
-            .vulkan,
-            .plan9,
-            .other,
-            => false,
-        };
-    }
 };
 
 pub const aarch64 = @import("Target/aarch64.zig");
@@ -2055,6 +2004,56 @@ pub inline fn isWasiLibC(target: *const Target) bool {
     return target.os.tag == .wasi and target.abi.isMusl();
 }
 
+/// Does this target require linking libc? This may be the case if the target has an unstable
+/// syscall interface, for example.
+pub fn requiresLibC(target: *const Target) bool {
+    return switch (target.os.tag) {
+        .aix,
+        .driverkit,
+        .macos,
+        .ios,
+        .tvos,
+        .watchos,
+        .visionos,
+        .dragonfly,
+        .openbsd,
+        .haiku,
+        .solaris,
+        .illumos,
+        .serenity,
+        => true,
+
+        .linux,
+        .windows,
+        .freebsd,
+        .netbsd,
+        .freestanding,
+        .fuchsia,
+        .ps3,
+        .zos,
+        .rtems,
+        .cuda,
+        .nvcl,
+        .amdhsa,
+        .ps4,
+        .ps5,
+        .mesa3d,
+        .contiki,
+        .amdpal,
+        .hermit,
+        .hurd,
+        .wasi,
+        .emscripten,
+        .uefi,
+        .opencl,
+        .opengl,
+        .vulkan,
+        .plan9,
+        .other,
+        => false,
+    };
+}
+
 pub const DynamicLinker = struct {
     /// Contains the memory used to store the dynamic linker path. This field
     /// should not be used directly. See `get` and `set`. This field exists so
diff --git a/src/target.zig b/src/target.zig
index ad83414c23..dbcfe341a4 100644
--- a/src/target.zig
+++ b/src/target.zig
@@ -20,7 +20,7 @@ pub fn cannotDynamicLink(target: *const std.Target) bool {
 /// Similarly on FreeBSD and NetBSD we always link system libc
 /// since this is the stable syscall interface.
 pub fn osRequiresLibC(target: *const std.Target) bool {
-    return target.os.requiresLibC();
+    return target.requiresLibC();
 }
 
 pub fn libCNeedsLibUnwind(target: *const std.Target, link_mode: std.builtin.LinkMode) bool {

From 03facba4963abd142985c9617ca38de46b1615f1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?=
Date: Wed, 30 Jul 2025 18:59:08 +0200
Subject: [PATCH 04/70] std.Target: require libc for Android API levels prior to 29

Emulated TLS depends on libc pthread functions.

Closes #24589.
---
 lib/std/Target.zig | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/lib/std/Target.zig b/lib/std/Target.zig
index ccc8ab79d8..fb9149c759 100644
--- a/lib/std/Target.zig
+++ b/lib/std/Target.zig
@@ -2023,7 +2023,12 @@ pub fn requiresLibC(target: *const Target) bool {
         .serenity,
         => true,
 
-        .linux,
+        // Android API levels prior to 29 did not have native TLS support. For these API levels, TLS
+        // is implemented through calls to `__emutls_get_address`. We provide this function in
+        // compiler-rt, but it's implemented by way of `pthread_key_create` et al, so linking libc
+        // is required.
+ .linux => target.abi.isAndroid() and target.os.version_range.linux.android < 29, + .windows, .freebsd, .netbsd, From 17330867eb8f9e24ec8aadf23592e36f90f9dcb3 Mon Sep 17 00:00:00 2001 From: David Rubin Date: Wed, 19 Feb 2025 09:46:30 -0800 Subject: [PATCH 05/70] Sema: compile error on reifying align(0) struct fields --- src/Sema.zig | 63 ++++++---------- test/cases/compile_errors/align_zero.zig | 74 +++++++++++++++---- .../compile_errors/bad_alignment_type.zig | 4 +- ...eify_type_with_invalid_field_alignment.zig | 6 +- 4 files changed, 87 insertions(+), 60 deletions(-) diff --git a/src/Sema.zig b/src/Sema.zig index d7add0724d..b855e4cc9d 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -2649,7 +2649,13 @@ pub fn analyzeAsAlign( src: LazySrcLoc, air_ref: Air.Inst.Ref, ) !Alignment { - const alignment_big = try sema.analyzeAsInt(block, src, air_ref, align_ty, .{ .simple = .@"align" }); + const alignment_big = try sema.analyzeAsInt( + block, + src, + air_ref, + align_ty, + .{ .simple = .@"align" }, + ); return sema.validateAlign(block, src, alignment_big); } @@ -18807,7 +18813,7 @@ fn zirPtrType(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air const abi_align: Alignment = if (inst_data.flags.has_align) blk: { const ref: Zir.Inst.Ref = @enumFromInt(sema.code.extra[extra_i]); extra_i += 1; - const coerced = try sema.coerce(block, .u32, try sema.resolveInst(ref), align_src); + const coerced = try sema.coerce(block, align_ty, try sema.resolveInst(ref), align_src); const val = try sema.resolveConstDefinedValue(block, align_src, coerced, .{ .simple = .@"align" }); // Check if this happens to be the lazy alignment of our element type, in // which case we can make this 0 without resolving it. @@ -20325,15 +20331,11 @@ fn zirReify( try ip.getOrPutString(gpa, pt.tid, "sentinel_ptr", .no_embedded_nulls), ).?); - if (!try sema.intFitsInType(alignment_val, .u32, null)) { - return sema.fail(block, src, "alignment must fit in 'u32'", .{}); + if (!try sema.intFitsInType(alignment_val, align_ty, null)) { + return sema.fail(block, src, "alignment must fit in '{}'", .{align_ty.fmt(pt)}); } - const alignment_val_int = try alignment_val.toUnsignedIntSema(pt); - if (alignment_val_int > 0 and !math.isPowerOfTwo(alignment_val_int)) { - return sema.fail(block, src, "alignment value '{d}' is not a power of two or zero", .{alignment_val_int}); - } - const abi_align = Alignment.fromByteUnits(alignment_val_int); + const abi_align = try sema.validateAlign(block, src, alignment_val_int); const elem_ty = child_val.toType(); if (abi_align != .none) { @@ -21017,11 +21019,7 @@ fn reifyUnion( field_ty.* = field_type_val.toIntern(); if (any_aligns) { const byte_align = try (try field_info.fieldValue(pt, 2)).toUnsignedIntSema(pt); - if (byte_align > 0 and !math.isPowerOfTwo(byte_align)) { - // TODO: better source location - return sema.fail(block, src, "alignment value '{d}' is not a power of two or zero", .{byte_align}); - } - field_aligns[field_idx] = Alignment.fromByteUnits(byte_align); + field_aligns[field_idx] = try sema.validateAlign(block, src, byte_align); } } @@ -21062,11 +21060,7 @@ fn reifyUnion( field_ty.* = field_type_val.toIntern(); if (any_aligns) { const byte_align = try (try field_info.fieldValue(pt, 2)).toUnsignedIntSema(pt); - if (byte_align > 0 and !math.isPowerOfTwo(byte_align)) { - // TODO: better source location - return sema.fail(block, src, "alignment value '{d}' is not a power of two or zero", .{byte_align}); - } - field_aligns[field_idx] = Alignment.fromByteUnits(byte_align); + 
field_aligns[field_idx] = try sema.validateAlign(block, src, byte_align); } } @@ -21266,7 +21260,6 @@ fn reifyStruct( var any_comptime_fields = false; var any_default_inits = false; - var any_aligned_fields = false; for (0..fields_len) |field_idx| { const field_info = try fields_val.elemValue(pt, field_idx); @@ -21301,11 +21294,6 @@ fn reifyStruct( if (field_is_comptime) any_comptime_fields = true; if (field_default_value != .none) any_default_inits = true; - switch (try field_alignment_val.orderAgainstZeroSema(pt)) { - .eq => {}, - .gt => any_aligned_fields = true, - .lt => unreachable, - } } const tracked_inst = try block.trackZir(inst); @@ -21317,7 +21305,7 @@ fn reifyStruct( .requires_comptime = .unknown, .any_comptime_fields = any_comptime_fields, .any_default_inits = any_default_inits, - .any_aligned_fields = any_aligned_fields, + .any_aligned_fields = true, .inits_resolved = true, .key = .{ .reified = .{ .zir_index = tracked_inst, @@ -21361,21 +21349,14 @@ fn reifyStruct( return sema.fail(block, src, "duplicate struct field name {f}", .{field_name.fmt(ip)}); } - if (any_aligned_fields) { - if (!try sema.intFitsInType(field_alignment_val, .u32, null)) { - return sema.fail(block, src, "alignment must fit in 'u32'", .{}); - } - - const byte_align = try field_alignment_val.toUnsignedIntSema(pt); - if (byte_align == 0) { - if (layout != .@"packed") { - struct_type.field_aligns.get(ip)[field_idx] = .none; - } - } else { - if (layout == .@"packed") return sema.fail(block, src, "alignment in a packed struct field must be set to 0", .{}); - if (!math.isPowerOfTwo(byte_align)) return sema.fail(block, src, "alignment value '{d}' is not a power of two or zero", .{byte_align}); - struct_type.field_aligns.get(ip)[field_idx] = Alignment.fromNonzeroByteUnits(byte_align); - } + if (!try sema.intFitsInType(field_alignment_val, align_ty, null)) { + return sema.fail(block, src, "alignment must fit in '{f}'", .{align_ty.fmt(pt)}); + } + const byte_align = try field_alignment_val.toUnsignedIntSema(pt); + if (layout == .@"packed") { + if (byte_align != 0) return sema.fail(block, src, "alignment in a packed struct field must be set to 0", .{}); + } else { + struct_type.field_aligns.get(ip)[field_idx] = try sema.validateAlign(block, src, byte_align); } const field_is_comptime = field_is_comptime_val.toBool(); diff --git a/test/cases/compile_errors/align_zero.zig b/test/cases/compile_errors/align_zero.zig index a63523b853..e6d1a993d4 100644 --- a/test/cases/compile_errors/align_zero.zig +++ b/test/cases/compile_errors/align_zero.zig @@ -1,52 +1,98 @@ -pub var global_var: i32 align(0) = undefined; +var global_var: i32 align(0) = undefined; -pub export fn a() void { +export fn a() void { _ = &global_var; } -pub extern var extern_var: i32 align(0); +extern var extern_var: i32 align(0); -pub export fn b() void { +export fn b() void { _ = &extern_var; } -pub export fn c() align(0) void {} +export fn c() align(0) void {} -pub export fn d() void { +export fn d() void { _ = *align(0) fn () i32; } -pub export fn e() void { +export fn e() void { var local_var: i32 align(0) = undefined; _ = &local_var; } -pub export fn f() void { +export fn f() void { _ = *align(0) i32; } -pub export fn g() void { +export fn g() void { _ = []align(0) i32; } -pub export fn h() void { +export fn h() void { _ = struct { field: i32 align(0) }; } -pub export fn i() void { +export fn i() void { _ = union { field: i32 align(0) }; } +export fn j() void { + _ = @Type(.{ .@"struct" = .{ + .layout = .auto, + .fields = &.{.{ + .name = "test", + 
.type = u32, + .default_value_ptr = null, + .is_comptime = false, + .alignment = 0, + }}, + .decls = &.{}, + .is_tuple = false, + } }); +} + +export fn k() void { + _ = @Type(.{ .pointer = .{ + .size = .one, + .is_const = false, + .is_volatile = false, + .alignment = 0, + .address_space = .generic, + .child = u32, + .is_allowzero = false, + .sentinel_ptr = null, + } }); +} + +export fn l() void { + _ = @Type(.{ .@"struct" = .{ + .layout = .@"packed", + .fields = &.{.{ + .name = "test", + .type = u32, + .default_value_ptr = null, + .is_comptime = false, + .alignment = 8, + }}, + .decls = &.{}, + .is_tuple = false, + } }); +} + // error // backend=stage2 // target=native // -// :1:31: error: alignment must be >= 1 -// :7:38: error: alignment must be >= 1 -// :13:25: error: alignment must be >= 1 +// :1:27: error: alignment must be >= 1 +// :7:34: error: alignment must be >= 1 +// :13:21: error: alignment must be >= 1 // :16:16: error: alignment must be >= 1 // :20:30: error: alignment must be >= 1 // :25:16: error: alignment must be >= 1 // :29:17: error: alignment must be >= 1 // :33:35: error: alignment must be >= 1 // :37:34: error: alignment must be >= 1 +// :41:9: error: alignment can only be 0 on packed struct fields +// :56:9: error: alignment must be >= 1 +// :69:9: error: alignment in a packed struct field must be set to 0 diff --git a/test/cases/compile_errors/bad_alignment_type.zig b/test/cases/compile_errors/bad_alignment_type.zig index c85eb8427d..c03f05d0ad 100644 --- a/test/cases/compile_errors/bad_alignment_type.zig +++ b/test/cases/compile_errors/bad_alignment_type.zig @@ -11,5 +11,5 @@ export fn entry2() void { // backend=stage2 // target=native // -// :2:22: error: expected type 'u32', found 'bool' -// :6:21: error: fractional component prevents float value '12.34' from coercion to type 'u32' +// :2:22: error: expected type 'u29', found 'bool' +// :6:21: error: fractional component prevents float value '12.34' from coercion to type 'u29' diff --git a/test/cases/compile_errors/reify_type_with_invalid_field_alignment.zig b/test/cases/compile_errors/reify_type_with_invalid_field_alignment.zig index 0fcb4ba7fc..a04f3e957c 100644 --- a/test/cases/compile_errors/reify_type_with_invalid_field_alignment.zig +++ b/test/cases/compile_errors/reify_type_with_invalid_field_alignment.zig @@ -43,6 +43,6 @@ comptime { // error // -// :2:9: error: alignment value '3' is not a power of two or zero -// :14:9: error: alignment value '5' is not a power of two or zero -// :30:9: error: alignment value '7' is not a power of two or zero +// :2:9: error: alignment value '3' is not a power of two +// :14:9: error: alignment value '5' is not a power of two +// :30:9: error: alignment value '7' is not a power of two From d6c74a95fdaba4ed373f80baa725dcc53b27e402 Mon Sep 17 00:00:00 2001 From: David Rubin Date: Mon, 24 Feb 2025 04:02:06 -0800 Subject: [PATCH 06/70] remove usages of `.alignment = 0` --- lib/compiler/aro/aro/Attribute.zig | 2 +- lib/std/meta.zig | 2 +- lib/std/zig/llvm/Builder.zig | 7 +++-- src/InternPool.zig | 38 ++++++++++++++---------- src/Sema.zig | 2 +- src/codegen/aarch64/Assemble.zig | 17 ++++++----- test/behavior/tuple.zig | 6 ++-- test/behavior/type.zig | 2 +- test/cases/compile_errors/align_zero.zig | 20 +------------ 9 files changed, 45 insertions(+), 51 deletions(-) diff --git a/lib/compiler/aro/aro/Attribute.zig b/lib/compiler/aro/aro/Attribute.zig index a5b78b8463..4db287b65c 100644 --- a/lib/compiler/aro/aro/Attribute.zig +++ b/lib/compiler/aro/aro/Attribute.zig @@ -708,7 
+708,7 @@ pub const Arguments = blk: { field.* = .{ .name = decl.name, .type = @field(attributes, decl.name), - .alignment = 0, + .alignment = @alignOf(@field(attributes, decl.name)), }; } diff --git a/lib/std/meta.zig b/lib/std/meta.zig index 0cee23cfa8..65b7d60c18 100644 --- a/lib/std/meta.zig +++ b/lib/std/meta.zig @@ -939,7 +939,7 @@ fn CreateUniqueTuple(comptime N: comptime_int, comptime types: [N]type) type { .type = T, .default_value_ptr = null, .is_comptime = false, - .alignment = 0, + .alignment = @alignOf(T), }; } diff --git a/lib/std/zig/llvm/Builder.zig b/lib/std/zig/llvm/Builder.zig index f3ff63ec33..ba6faaec2c 100644 --- a/lib/std/zig/llvm/Builder.zig +++ b/lib/std/zig/llvm/Builder.zig @@ -8533,18 +8533,19 @@ pub const Metadata = enum(u32) { .type = []const u8, .default_value_ptr = null, .is_comptime = false, - .alignment = 0, + .alignment = @alignOf([]const u8), }; } fmt_str = fmt_str ++ "("; inline for (fields[2..], names) |*field, name| { fmt_str = fmt_str ++ "{[" ++ name ++ "]f}"; + const T = std.fmt.Formatter(FormatData, format); field.* = .{ .name = name, - .type = std.fmt.Formatter(FormatData, format), + .type = T, .default_value_ptr = null, .is_comptime = false, - .alignment = 0, + .alignment = @alignOf(T), }; } fmt_str = fmt_str ++ ")\n"; diff --git a/src/InternPool.zig b/src/InternPool.zig index 15d895aed0..ed922db742 100644 --- a/src/InternPool.zig +++ b/src/InternPool.zig @@ -1137,13 +1137,16 @@ const Local = struct { const elem_info = @typeInfo(Elem).@"struct"; const elem_fields = elem_info.fields; var new_fields: [elem_fields.len]std.builtin.Type.StructField = undefined; - for (&new_fields, elem_fields) |*new_field, elem_field| new_field.* = .{ - .name = elem_field.name, - .type = *[len]elem_field.type, - .default_value_ptr = null, - .is_comptime = false, - .alignment = 0, - }; + for (&new_fields, elem_fields) |*new_field, elem_field| { + const T = *[len]elem_field.type; + new_field.* = .{ + .name = elem_field.name, + .type = T, + .default_value_ptr = null, + .is_comptime = false, + .alignment = @alignOf(T), + }; + } return @Type(.{ .@"struct" = .{ .layout = .auto, .fields = &new_fields, @@ -1158,22 +1161,25 @@ const Local = struct { const elem_info = @typeInfo(Elem).@"struct"; const elem_fields = elem_info.fields; var new_fields: [elem_fields.len]std.builtin.Type.StructField = undefined; - for (&new_fields, elem_fields) |*new_field, elem_field| new_field.* = .{ - .name = elem_field.name, - .type = @Type(.{ .pointer = .{ + for (&new_fields, elem_fields) |*new_field, elem_field| { + const T = @Type(.{ .pointer = .{ .size = opts.size, .is_const = opts.is_const, .is_volatile = false, - .alignment = 0, + .alignment = @alignOf(elem_field.type), .address_space = .generic, .child = elem_field.type, .is_allowzero = false, .sentinel_ptr = null, - } }), - .default_value_ptr = null, - .is_comptime = false, - .alignment = 0, - }; + } }); + new_field.* = .{ + .name = elem_field.name, + .type = T, + .default_value_ptr = null, + .is_comptime = false, + .alignment = @alignOf(T), + }; + } return @Type(.{ .@"struct" = .{ .layout = .auto, .fields = &new_fields, diff --git a/src/Sema.zig b/src/Sema.zig index b855e4cc9d..3bee1130af 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -20332,7 +20332,7 @@ fn zirReify( ).?); if (!try sema.intFitsInType(alignment_val, align_ty, null)) { - return sema.fail(block, src, "alignment must fit in '{}'", .{align_ty.fmt(pt)}); + return sema.fail(block, src, "alignment must fit in '{f}'", .{align_ty.fmt(pt)}); } const alignment_val_int = try 
alignment_val.toUnsignedIntSema(pt); const abi_align = try sema.validateAlign(block, src, alignment_val_int); diff --git a/src/codegen/aarch64/Assemble.zig b/src/codegen/aarch64/Assemble.zig index 494e012d80..098065bb16 100644 --- a/src/codegen/aarch64/Assemble.zig +++ b/src/codegen/aarch64/Assemble.zig @@ -33,13 +33,16 @@ pub fn nextInstruction(as: *Assemble) !?Instruction { var symbols: Symbols: { const symbols = @typeInfo(@TypeOf(instruction.symbols)).@"struct".fields; var symbol_fields: [symbols.len]std.builtin.Type.StructField = undefined; - for (&symbol_fields, symbols) |*symbol_field, symbol| symbol_field.* = .{ - .name = symbol.name, - .type = zonCast(SymbolSpec, @field(instruction.symbols, symbol.name), .{}).Storage(), - .default_value_ptr = null, - .is_comptime = false, - .alignment = 0, - }; + for (&symbol_fields, symbols) |*symbol_field, symbol| { + const Storage = zonCast(SymbolSpec, @field(instruction.symbols, symbol.name), .{}).Storage(); + symbol_field.* = .{ + .name = symbol.name, + .type = Storage, + .default_value_ptr = null, + .is_comptime = false, + .alignment = @alignOf(Storage), + }; + } break :Symbols @Type(.{ .@"struct" = .{ .layout = .auto, .fields = &symbol_fields, diff --git a/test/behavior/tuple.zig b/test/behavior/tuple.zig index e760455e09..9e17d61856 100644 --- a/test/behavior/tuple.zig +++ b/test/behavior/tuple.zig @@ -318,6 +318,8 @@ test "tuple type with void field" { test "zero sized struct in tuple handled correctly" { const State = struct { const Self = @This(); + const Inner = struct {}; + data: @Type(.{ .@"struct" = .{ .is_tuple = true, @@ -325,10 +327,10 @@ test "zero sized struct in tuple handled correctly" { .decls = &.{}, .fields = &.{.{ .name = "0", - .type = struct {}, + .type = Inner, .default_value_ptr = null, .is_comptime = false, - .alignment = 0, + .alignment = @alignOf(Inner), }}, }, }), diff --git a/test/behavior/type.zig b/test/behavior/type.zig index 58e049d896..b5ac2d95f7 100644 --- a/test/behavior/type.zig +++ b/test/behavior/type.zig @@ -735,7 +735,7 @@ test "struct field names sliced at comptime from larger string" { var it = std.mem.tokenizeScalar(u8, text, '\n'); while (it.next()) |name| { fields = fields ++ &[_]Type.StructField{.{ - .alignment = 0, + .alignment = @alignOf(usize), .name = name ++ "", .type = usize, .default_value_ptr = null, diff --git a/test/cases/compile_errors/align_zero.zig b/test/cases/compile_errors/align_zero.zig index e6d1a993d4..e54a32ce31 100644 --- a/test/cases/compile_errors/align_zero.zig +++ b/test/cases/compile_errors/align_zero.zig @@ -65,24 +65,7 @@ export fn k() void { } }); } -export fn l() void { - _ = @Type(.{ .@"struct" = .{ - .layout = .@"packed", - .fields = &.{.{ - .name = "test", - .type = u32, - .default_value_ptr = null, - .is_comptime = false, - .alignment = 8, - }}, - .decls = &.{}, - .is_tuple = false, - } }); -} - // error -// backend=stage2 -// target=native // // :1:27: error: alignment must be >= 1 // :7:34: error: alignment must be >= 1 @@ -93,6 +76,5 @@ export fn l() void { // :29:17: error: alignment must be >= 1 // :33:35: error: alignment must be >= 1 // :37:34: error: alignment must be >= 1 -// :41:9: error: alignment can only be 0 on packed struct fields +// :41:9: error: alignment must be >= 1 // :56:9: error: alignment must be >= 1 -// :69:9: error: alignment in a packed struct field must be set to 0 From 5678a600ffb2ec27f88ab6a308b4eaf56c60daed Mon Sep 17 00:00:00 2001 From: David Rubin Date: Fri, 1 Aug 2025 12:00:34 -0700 Subject: [PATCH 07/70] refactor `reifyUnion` 
alignment handling --- src/Sema.zig | 49 +++++++++++++++++++++---------------------------- 1 file changed, 21 insertions(+), 28 deletions(-) diff --git a/src/Sema.zig b/src/Sema.zig index 3bee1130af..8ea0a19fe8 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -20912,8 +20912,6 @@ fn reifyUnion( std.hash.autoHash(&hasher, opt_tag_type_val.toIntern()); std.hash.autoHash(&hasher, fields_len); - var any_aligns = false; - for (0..fields_len) |field_idx| { const field_info = try fields_val.elemValue(pt, field_idx); @@ -20922,16 +20920,11 @@ fn reifyUnion( const field_align_val = try sema.resolveLazyValue(try field_info.fieldValue(pt, 2)); const field_name = try sema.sliceToIpString(block, src, field_name_val, .{ .simple = .union_field_name }); - std.hash.autoHash(&hasher, .{ field_name, field_type_val.toIntern(), field_align_val.toIntern(), }); - - if (field_align_val.toUnsignedInt(zcu) != 0) { - any_aligns = true; - } } const tracked_inst = try block.trackZir(inst); @@ -20948,7 +20941,7 @@ fn reifyUnion( true => .safety, false => .none, }, - .any_aligned_fields = any_aligns, + .any_aligned_fields = layout != .@"packed", .requires_comptime = .unknown, .assumed_runtime_bits = false, .assumed_pointer_aligned = false, @@ -20981,8 +20974,7 @@ fn reifyUnion( ); wip_ty.setName(ip, type_name.name, type_name.nav); - const field_types = try sema.arena.alloc(InternPool.Index, fields_len); - const field_aligns = if (any_aligns) try sema.arena.alloc(InternPool.Alignment, fields_len) else undefined; + const loaded_union = ip.loadUnionType(wip_ty.index); const enum_tag_ty, const has_explicit_tag = if (opt_tag_type_val.optionalValue(zcu)) |tag_type_val| tag_ty: { switch (ip.indexToKey(tag_type_val.toIntern())) { @@ -20995,11 +20987,12 @@ fn reifyUnion( const tag_ty_fields_len = enum_tag_ty.enumFieldCount(zcu); var seen_tags = try std.DynamicBitSetUnmanaged.initEmpty(sema.arena, tag_ty_fields_len); - for (field_types, 0..) |*field_ty, field_idx| { + for (0..fields_len) |field_idx| { const field_info = try fields_val.elemValue(pt, field_idx); const field_name_val = try field_info.fieldValue(pt, 0); const field_type_val = try field_info.fieldValue(pt, 1); + const field_alignment_val = try field_info.fieldValue(pt, 2); // Don't pass a reason; first loop acts as an assertion that this is valid. const field_name = try sema.sliceToIpString(block, src, field_name_val, undefined); @@ -21016,10 +21009,12 @@ fn reifyUnion( } seen_tags.set(enum_index); - field_ty.* = field_type_val.toIntern(); - if (any_aligns) { - const byte_align = try (try field_info.fieldValue(pt, 2)).toUnsignedIntSema(pt); - field_aligns[field_idx] = try sema.validateAlign(block, src, byte_align); + loaded_union.field_types.get(ip)[field_idx] = field_type_val.toIntern(); + const byte_align = try field_alignment_val.toUnsignedIntSema(pt); + if (layout == .@"packed") { + if (byte_align != 0) return sema.fail(block, src, "alignment of a packed union field must be set to 0", .{}); + } else { + loaded_union.field_aligns.get(ip)[field_idx] = try sema.validateAlign(block, src, byte_align); } } @@ -21043,11 +21038,12 @@ fn reifyUnion( var field_names: std.AutoArrayHashMapUnmanaged(InternPool.NullTerminatedString, void) = .empty; try field_names.ensureTotalCapacity(sema.arena, fields_len); - for (field_types, 0..) 
|*field_ty, field_idx| { + for (0..fields_len) |field_idx| { const field_info = try fields_val.elemValue(pt, field_idx); const field_name_val = try field_info.fieldValue(pt, 0); const field_type_val = try field_info.fieldValue(pt, 1); + const field_alignment_val = try field_info.fieldValue(pt, 2); // Don't pass a reason; first loop acts as an assertion that this is valid. const field_name = try sema.sliceToIpString(block, src, field_name_val, undefined); @@ -21057,10 +21053,12 @@ fn reifyUnion( return sema.fail(block, src, "duplicate union field {f}", .{field_name.fmt(ip)}); } - field_ty.* = field_type_val.toIntern(); - if (any_aligns) { - const byte_align = try (try field_info.fieldValue(pt, 2)).toUnsignedIntSema(pt); - field_aligns[field_idx] = try sema.validateAlign(block, src, byte_align); + loaded_union.field_types.get(ip)[field_idx] = field_type_val.toIntern(); + const byte_align = try field_alignment_val.toUnsignedIntSema(pt); + if (layout == .@"packed") { + if (byte_align != 0) return sema.fail(block, src, "alignment of a packed union field must be set to 0", .{}); + } else { + loaded_union.field_aligns.get(ip)[field_idx] = try sema.validateAlign(block, src, byte_align); } } @@ -21069,7 +21067,7 @@ fn reifyUnion( }; errdefer if (!has_explicit_tag) ip.remove(pt.tid, enum_tag_ty); // remove generated tag type on error - for (field_types) |field_ty_ip| { + for (loaded_union.field_types.get(ip)) |field_ty_ip| { const field_ty: Type = .fromInterned(field_ty_ip); if (field_ty.zigTypeTag(zcu) == .@"opaque") { return sema.failWithOwnedErrorMsg(block, msg: { @@ -21103,11 +21101,6 @@ fn reifyUnion( } } - const loaded_union = ip.loadUnionType(wip_ty.index); - loaded_union.setFieldTypes(ip, field_types); - if (any_aligns) { - loaded_union.setFieldAligns(ip, field_aligns); - } loaded_union.setTagType(ip, enum_tag_ty); loaded_union.setStatus(ip, .have_field_types); @@ -21305,7 +21298,7 @@ fn reifyStruct( .requires_comptime = .unknown, .any_comptime_fields = any_comptime_fields, .any_default_inits = any_default_inits, - .any_aligned_fields = true, + .any_aligned_fields = layout != .@"packed", .inits_resolved = true, .key = .{ .reified = .{ .zir_index = tracked_inst, @@ -21354,7 +21347,7 @@ fn reifyStruct( } const byte_align = try field_alignment_val.toUnsignedIntSema(pt); if (layout == .@"packed") { - if (byte_align != 0) return sema.fail(block, src, "alignment in a packed struct field must be set to 0", .{}); + if (byte_align != 0) return sema.fail(block, src, "alignment of a packed struct field must be set to 0", .{}); } else { struct_type.field_aligns.get(ip)[field_idx] = try sema.validateAlign(block, src, byte_align); } From abf1795337b3781ab65da04e09e8e3bb85139e33 Mon Sep 17 00:00:00 2001 From: mlugg Date: Wed, 30 Jul 2025 00:07:47 +0100 Subject: [PATCH 08/70] std.Build.Watch: add macOS implementation based on FSEventStream Resolves: #21905 --- lib/compiler/build_runner.zig | 2 +- lib/std/Build/Watch.zig | 39 ++- lib/std/Build/Watch/FsEvents.zig | 493 +++++++++++++++++++++++++++++++ 3 files changed, 530 insertions(+), 4 deletions(-) create mode 100644 lib/std/Build/Watch/FsEvents.zig diff --git a/lib/compiler/build_runner.zig b/lib/compiler/build_runner.zig index 2e84309ccd..e97b7aa313 100644 --- a/lib/compiler/build_runner.zig +++ b/lib/compiler/build_runner.zig @@ -511,7 +511,7 @@ pub fn main() !void { // recursive dependants. 
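The `dir_count` field used in the caption below is new: each per-OS watch implementation now reports its own directory total. `std.Build.Watch` selects the implementation struct at comptime from the OS tag; roughly, as a heavily trimmed, hypothetical sketch of that dispatch pattern:

    const builtin = @import("builtin");

    // One implementation type per OS tag; `void` means no watching support.
    const Os = switch (builtin.os.tag) {
        .linux => struct { fan_fd: i32 }, // fanotify-based
        .macos => struct { fse: *anyopaque }, // FSEventStream-based state
        else => void,
    };

    pub const have_impl = Os != void;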
var caption_buf: [std.Progress.Node.max_name_len]u8 = undefined; const caption = std.fmt.bufPrint(&caption_buf, "watching {d} directories, {d} processes", .{ - w.dir_table.entries.len, countSubProcesses(run.step_stack.keys()), + w.dir_count, countSubProcesses(run.step_stack.keys()), }) catch &caption_buf; var debouncing_node = main_progress_node.start(caption, 0); var in_debounce = false; diff --git a/lib/std/Build/Watch.zig b/lib/std/Build/Watch.zig index cc8bd60c9a..4bd301e01e 100644 --- a/lib/std/Build/Watch.zig +++ b/lib/std/Build/Watch.zig @@ -1,13 +1,18 @@ const builtin = @import("builtin"); const std = @import("../std.zig"); -const Watch = @This(); const Step = std.Build.Step; const Allocator = std.mem.Allocator; const assert = std.debug.assert; const fatal = std.process.fatal; +const Watch = @This(); +const FsEvents = @import("Watch/FsEvents.zig"); -dir_table: DirTable, os: Os, +/// The number to show as the number of directories being watched. +dir_count: usize, +// These fields are common to most implementations so are kept here for simplicity. +// They are `undefined` on implementations which do not utilize then. +dir_table: DirTable, generation: Generation, pub const have_impl = Os != void; @@ -97,6 +102,7 @@ const Os = switch (builtin.os.tag) { fn init() !Watch { return .{ .dir_table = .{}, + .dir_count = 0, .os = switch (builtin.os.tag) { .linux => .{ .handle_table = .{}, @@ -273,6 +279,7 @@ const Os = switch (builtin.os.tag) { } w.generation +%= 1; } + w.dir_count = w.dir_table.count(); } fn wait(w: *Watch, gpa: Allocator, timeout: Timeout) !WaitResult { @@ -408,6 +415,7 @@ const Os = switch (builtin.os.tag) { fn init() !Watch { return .{ .dir_table = .{}, + .dir_count = 0, .os = switch (builtin.os.tag) { .windows => .{ .handle_table = .{}, @@ -572,6 +580,7 @@ const Os = switch (builtin.os.tag) { } w.generation +%= 1; } + w.dir_count = w.dir_table.count(); } fn wait(w: *Watch, gpa: Allocator, timeout: Timeout) !WaitResult { @@ -605,7 +614,7 @@ const Os = switch (builtin.os.tag) { }; } }, - .dragonfly, .freebsd, .netbsd, .openbsd, .ios, .macos, .tvos, .visionos, .watchos => struct { + .dragonfly, .freebsd, .netbsd, .openbsd, .ios, .tvos, .visionos, .watchos => struct { const posix = std.posix; kq_fd: i32, @@ -639,6 +648,7 @@ const Os = switch (builtin.os.tag) { errdefer posix.close(kq_fd); return .{ .dir_table = .{}, + .dir_count = 0, .os = .{ .kq_fd = kq_fd, .handles = .empty, @@ -769,6 +779,7 @@ const Os = switch (builtin.os.tag) { } w.generation +%= 1; } + w.dir_count = w.dir_table.count(); } fn wait(w: *Watch, gpa: Allocator, timeout: Timeout) !WaitResult { @@ -812,6 +823,28 @@ const Os = switch (builtin.os.tag) { return any_dirty; } }, + .macos => struct { + fse: FsEvents, + + fn init() !Watch { + return .{ + .os = .{ .fse = try .init() }, + .dir_count = 0, + .dir_table = undefined, + .generation = undefined, + }; + } + fn update(w: *Watch, gpa: Allocator, steps: []const *Step) !void { + try w.os.fse.setPaths(gpa, steps); + w.dir_count = w.os.fse.watch_roots.len; + } + fn wait(w: *Watch, gpa: Allocator, timeout: Timeout) !WaitResult { + return w.os.fse.wait(gpa, switch (timeout) { + .none => null, + .ms => |ms| @as(u64, ms) * std.time.ns_per_ms, + }); + } + }, else => void, }; diff --git a/lib/std/Build/Watch/FsEvents.zig b/lib/std/Build/Watch/FsEvents.zig new file mode 100644 index 0000000000..c8c53813c2 --- /dev/null +++ b/lib/std/Build/Watch/FsEvents.zig @@ -0,0 +1,493 @@ +//! An implementation of file-system watching based on the `FSEventStream` API in macOS. +//! 
While macOS supports kqueue, it does not allow detecting changes to files without +//! placing watches on each individual file, meaning FD limits are reached incredibly +//! quickly. The File System Events API works differently: it implements *recursive* +//! directory watches, managed by a system service. Rather than being in libc, the API is +//! exposed by the CoreServices framework. To avoid a compile dependency on the framework +//! bundle, we dynamically load CoreServices with `std.DynLib`. +//! +//! While the logic in this file *is* specialized to `std.Build.Watch`, efforts have been +//! made to keep that specialization to a minimum. Other use cases could be served with +//! relatively minimal modifications to the `watch_paths` field and its usages (in +//! particular the `setPaths` function). We avoid using the global GCD dispatch queue in +//! favour of creating our own and synchronizing with an explicit semaphore, meaning this +//! logic is thread-safe and does not affect process-global state. +//! +//! In theory, this API is quite good at avoiding filesystem race conditions. In practice, +//! the logic that would avoid them is currently disabled, because the build system kind +//! of relies on them at the time of writing to avoid redundant work -- see the comment at +//! the top of `wait` for details. + +const enable_debug_logs = false; + +core_services: std.DynLib, +resolved_symbols: ResolvedSymbols, + +paths_arena: std.heap.ArenaAllocator.State, +/// The roots of the recursive watches. FSEvents has relatively small limits on the number +/// of watched paths, so this slice must not be too long. The paths themselves are allocated +/// into `paths_arena`, but this slice is allocated into the GPA. +watch_roots: [][:0]const u8, +/// All of the paths being watched. Value is the set of steps which depend on the file/directory. +/// Keys and values are in `paths_arena`, but this map is allocated into the GPA. +watch_paths: std.StringArrayHashMapUnmanaged([]const *std.Build.Step), + +/// The semaphore we use to block the thread calling `wait` until the callback determines a relevant +/// event has occurred. This is retained across `wait` calls for simplicity and efficiency. +waiting_semaphore: dispatch_semaphore_t, +/// This dispatch queue is created by us and executes serially. It exists exclusively to trigger the +/// callbacks of the FSEventStream we create. This is not in use outside of `wait`, but is retained +/// across `wait` calls for simplicity and efficiency. +dispatch_queue: dispatch_queue_t, +/// In theory, this field avoids race conditions. In practice, it is essentially unused at the time +/// of writing. See the comment at the start of `wait` for details. +since_event: FSEventStreamEventId, + +/// All of the symbols we pull from the `dlopen`ed CoreServices framework. If any of these symbols +/// is not present, `init` will close the framework and return an error. 
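Every field of the struct below is resolved with one `DynLib.lookup` call via `inline for`, which is what `init` does. A self-contained sketch of the pattern, using a hypothetical one-symbol table:

    const std = @import("std");

    const Syms = struct {
        getpid: *const fn () callconv(.c) std.c.pid_t,
    };

    fn resolveAll(lib: *std.DynLib) error{MissingSymbol}!Syms {
        var syms: Syms = undefined;
        // The field name doubles as the symbol name, keeping the table in one place.
        inline for (@typeInfo(Syms).@"struct".fields) |f| {
            @field(syms, f.name) = lib.lookup(f.type, f.name) orelse
                return error.MissingSymbol;
        }
        return syms;
    }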
+const ResolvedSymbols = struct { + FSEventStreamCreate: *const fn ( + allocator: CFAllocatorRef, + callback: FSEventStreamCallback, + ctx: ?*const FSEventStreamContext, + paths_to_watch: CFArrayRef, + since_when: FSEventStreamEventId, + latency: CFTimeInterval, + flags: FSEventStreamCreateFlags, + ) callconv(.c) FSEventStreamRef, + FSEventStreamSetDispatchQueue: *const fn (stream: FSEventStreamRef, queue: dispatch_queue_t) callconv(.c) void, + FSEventStreamStart: *const fn (stream: FSEventStreamRef) callconv(.c) bool, + FSEventStreamStop: *const fn (stream: FSEventStreamRef) callconv(.c) void, + FSEventStreamInvalidate: *const fn (stream: FSEventStreamRef) callconv(.c) void, + FSEventStreamRelease: *const fn (stream: FSEventStreamRef) callconv(.c) void, + FSEventStreamGetLatestEventId: *const fn (stream: ConstFSEventStreamRef) callconv(.c) FSEventStreamEventId, + FSEventsGetCurrentEventId: *const fn () callconv(.c) FSEventStreamEventId, + CFRelease: *const fn (cf: *const anyopaque) callconv(.c) void, + CFArrayCreate: *const fn ( + allocator: CFAllocatorRef, + values: [*]const usize, + num_values: CFIndex, + call_backs: ?*const CFArrayCallBacks, + ) callconv(.c) CFArrayRef, + CFStringCreateWithCString: *const fn ( + alloc: CFAllocatorRef, + c_str: [*:0]const u8, + encoding: CFStringEncoding, + ) callconv(.c) CFStringRef, + CFAllocatorCreate: *const fn (allocator: CFAllocatorRef, context: *const CFAllocatorContext) callconv(.c) CFAllocatorRef, + kCFAllocatorUseContext: *const CFAllocatorRef, +}; + +pub fn init() error{ OpenFrameworkFailed, MissingCoreServicesSymbol }!FsEvents { + var core_services = std.DynLib.open("/System/Library/Frameworks/CoreServices.framework/CoreServices") catch + return error.OpenFrameworkFailed; + errdefer core_services.close(); + + var resolved_symbols: ResolvedSymbols = undefined; + inline for (@typeInfo(ResolvedSymbols).@"struct".fields) |f| { + @field(resolved_symbols, f.name) = core_services.lookup(f.type, f.name) orelse return error.MissingCoreServicesSymbol; + } + + return .{ + .core_services = core_services, + .resolved_symbols = resolved_symbols, + .paths_arena = .{}, + .watch_roots = &.{}, + .watch_paths = .empty, + .waiting_semaphore = dispatch_semaphore_create(0), + .dispatch_queue = dispatch_queue_create("zig-watch", .SERIAL), + // Not `.since_now`, because this means we can init `FsEvents` *before* we do work in order + // to notice any changes which happened during said work. + .since_event = resolved_symbols.FSEventsGetCurrentEventId(), + }; +} + +pub fn deinit(fse: *FsEvents, gpa: Allocator) void { + dispatch_release(fse.waiting_semaphore); + dispatch_release(fse.dispatch_queue); + fse.core_services.close(); + + gpa.free(fse.watch_roots); + fse.watch_paths.deinit(gpa); + { + var paths_arena = fse.paths_arena.promote(gpa); + paths_arena.deinit(); + } +} + +pub fn setPaths(fse: *FsEvents, gpa: Allocator, steps: []const *std.Build.Step) !void { + var paths_arena_instance = fse.paths_arena.promote(gpa); + defer fse.paths_arena = paths_arena_instance.state; + const paths_arena = paths_arena_instance.allocator(); + + const cwd_path = try std.process.getCwdAlloc(gpa); + defer gpa.free(cwd_path); + + var need_dirs: std.StringArrayHashMapUnmanaged(void) = .empty; + defer need_dirs.deinit(gpa); + + fse.watch_paths.clearRetainingCapacity(); + + // We take `step` by pointer for a slight memory optimization in a moment. 
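Concretely: `|*step|` yields a pointer into the `steps` slice, and slicing a single-item pointer with comptime-known bounds (`step[0..1]`) produces a one-element slice without any allocation. A tiny sketch of the same trick:

    const std = @import("std");

    test "one-element slice from a pointer capture" {
        const items = [_]u32{ 10, 20, 30 };
        for (&items) |*item| {
            // `item[0..1]` is a `*const [1]u32` pointing into `items`,
            // which coerces to a slice -- no allocation involved.
            const one: []const u32 = item[0..1];
            try std.testing.expectEqual(item.*, one[0]);
        }
    }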
+ for (steps) |*step| { + for (step.*.inputs.table.keys(), step.*.inputs.table.values()) |path, *files| { + const resolved_dir = try std.fs.path.resolvePosix(paths_arena, &.{ cwd_path, path.root_dir.path orelse ".", path.sub_path }); + try need_dirs.put(gpa, resolved_dir, {}); + for (files.items) |file_name| { + const watch_path = if (std.mem.eql(u8, file_name, ".")) + resolved_dir + else + try std.fs.path.join(paths_arena, &.{ resolved_dir, file_name }); + const gop = try fse.watch_paths.getOrPut(gpa, watch_path); + if (gop.found_existing) { + const old_steps = gop.value_ptr.*; + const new_steps = try paths_arena.alloc(*std.Build.Step, old_steps.len + 1); + @memcpy(new_steps[0..old_steps.len], old_steps); + new_steps[old_steps.len] = step.*; + gop.value_ptr.* = new_steps; + } else { + // This is why we captured `step` by pointer! We can avoid allocating a slice of one + // step in the arena in the common case where a file is referenced by only one step. + gop.value_ptr.* = step[0..1]; + } + } + } + } + + { + // There's no point looking at directories inside other ones (e.g. "/foo" and "/foo/bar"). + // To eliminate these, we'll re-add directories in order of path length with a redundancy check. + const old_dirs = try gpa.dupe([]const u8, need_dirs.keys()); + defer gpa.free(old_dirs); + std.mem.sort([]const u8, old_dirs, {}, struct { + fn lessThan(ctx: void, a: []const u8, b: []const u8) bool { + ctx; + return std.mem.lessThan(u8, a, b); + } + }.lessThan); + need_dirs.clearRetainingCapacity(); + for (old_dirs) |dir_path| { + var it: std.fs.path.ComponentIterator(.posix, u8) = try .init(dir_path); + while (it.next()) |component| { + if (need_dirs.contains(component.path)) { + // this path is '/foo/bar/qux', but '/foo' or '/foo/bar' was already added + break; + } + } else { + need_dirs.putAssumeCapacityNoClobber(dir_path, {}); + } + } + } + + // `need_dirs` is now a set of directories to watch with no redundancy. In practice, this is very + // likely to have reduced it to a quite small set (e.g. it'll typically coalesce a full `src/` + // directory into one entry). However, the FSEventStream API has a fairly low undocumented limit + // on total watches (supposedly 4096), so we should handle the case where we exceed it. To be + // safe, because this API can be a little unpredictable, we'll cap ourselves a little *below* + // that known limit. + if (need_dirs.count() > 2048) { + // Fallback: watch the whole filesystem. This is excessive, but... it *works* :P + if (enable_debug_logs) watch_log.debug("too many dirs; recursively watching root", .{}); + fse.watch_roots = try gpa.realloc(fse.watch_roots, 1); + fse.watch_roots[0] = "/"; + } else { + fse.watch_roots = try gpa.realloc(fse.watch_roots, need_dirs.count()); + for (fse.watch_roots, need_dirs.keys()) |*out, in| { + out.* = try paths_arena.dupeZ(u8, in); + } + } + if (enable_debug_logs) { + watch_log.debug("watching {d} paths using {d} recursive watches:", .{ fse.watch_paths.count(), fse.watch_roots.len }); + for (fse.watch_roots) |dir_path| { + watch_log.debug("- '{s}'", .{dir_path}); + } + } +} + +pub fn wait(fse: *FsEvents, gpa: Allocator, timeout_ns: ?u64) error{ OutOfMemory, StartFailed }!std.Build.Watch.WaitResult { + if (fse.watch_roots.len == 0) @panic("nothing to watch"); + + const rs = fse.resolved_symbols; + + // At the time of writing, using `since_event` in the obvious way causes redundant rebuilds + // to occur, because one step modifies a file which is an input to another step. 
The solution + // to this problem will probably be either: + // + // a) Don't include the output of one step as a watch input of another; only mark external + // files as watch inputs. Or... + // + // b) Note the current event ID when a step begins, and disregard events preceding that ID + // when considering whether to dirty that step in `eventCallback`. + // + // For now, to avoid the redundant rebuilds, we bypass this `since_event` mechanism. This does + // introduce race conditions, but the other `std.Build.Watch` implementations suffer from those + // too at the time of writing, so this is kind of expected. + fse.since_event = .since_now; + + const cf_allocator = rs.CFAllocatorCreate(rs.kCFAllocatorUseContext.*, &.{ + .version = 0, + .info = @constCast(&gpa), + .retain = null, + .release = null, + .copy_description = null, + .allocate = &cf_alloc_callbacks.allocate, + .reallocate = &cf_alloc_callbacks.reallocate, + .deallocate = &cf_alloc_callbacks.deallocate, + .preferred_size = null, + }) orelse return error.OutOfMemory; + defer rs.CFRelease(cf_allocator); + + const cf_paths = try gpa.alloc(?CFStringRef, fse.watch_roots.len); + @memset(cf_paths, null); + defer { + for (cf_paths) |o| if (o) |p| rs.CFRelease(p); + gpa.free(cf_paths); + } + for (fse.watch_roots, cf_paths) |raw_path, *cf_path| { + cf_path.* = rs.CFStringCreateWithCString(cf_allocator, raw_path, .utf8); + } + const cf_paths_array = rs.CFArrayCreate(cf_allocator, @ptrCast(cf_paths), @intCast(cf_paths.len), null); + defer rs.CFRelease(cf_paths_array); + + const callback_ctx: EventCallbackCtx = .{ + .fse = fse, + .gpa = gpa, + }; + const event_stream = rs.FSEventStreamCreate( + null, + &eventCallback, + &.{ + .version = 0, + .info = @constCast(&callback_ctx), + .retain = null, + .release = null, + .copy_description = null, + }, + cf_paths_array, + fse.since_event, + 0.05, // 0.05s latency; higher values increase efficiency by coalescing more events + .{ .watch_root = true, .file_events = true }, + ); + defer rs.FSEventStreamRelease(event_stream); + rs.FSEventStreamSetDispatchQueue(event_stream, fse.dispatch_queue); + defer rs.FSEventStreamInvalidate(event_stream); + if (!rs.FSEventStreamStart(event_stream)) return error.StartFailed; + defer rs.FSEventStreamStop(event_stream); + const result = dispatch_semaphore_wait(fse.waiting_semaphore, timeout: { + const ns = timeout_ns orelse break :timeout .forever; + break :timeout dispatch_time(.now, @intCast(ns)); + }); + return switch (result) { + 0 => .dirty, + else => .timeout, + }; +} + +const cf_alloc_callbacks = struct { + const log = std.log.scoped(.cf_alloc); + fn allocate(size: CFIndex, hint: CFOptionFlags, info: ?*const anyopaque) callconv(.c) ?*const anyopaque { + if (enable_debug_logs) log.debug("allocate {d}", .{size}); + _ = hint; + const gpa: *const Allocator = @ptrCast(@alignCast(info)); + const mem = gpa.alignedAlloc(u8, .of(usize), @intCast(size + @sizeOf(usize))) catch return null; + const metadata: *usize = @ptrCast(mem); + metadata.* = @intCast(size); + return mem[@sizeOf(usize)..].ptr; + } + fn reallocate(ptr: ?*anyopaque, new_size: CFIndex, hint: CFOptionFlags, info: ?*const anyopaque) callconv(.c) ?*const anyopaque { + if (enable_debug_logs) log.debug("reallocate @{*} {d}", .{ ptr, new_size }); + _ = hint; + if (ptr == null or new_size == 0) return null; // not a bug: documentation explicitly states that realloc on NULL should return NULL + const gpa: *const Allocator = @ptrCast(@alignCast(info)); + const old_base: [*]align(@alignOf(usize)) u8 = 
@alignCast(@as([*]u8, @ptrCast(ptr)) - @sizeOf(usize)); + const old_size = @as(*const usize, @ptrCast(old_base)).*; + const old_mem = old_base[0 .. old_size + @sizeOf(usize)]; + const new_mem = gpa.realloc(old_mem, @intCast(new_size + @sizeOf(usize))) catch return null; + const metadata: *usize = @ptrCast(new_mem); + metadata.* = @intCast(new_size); + return new_mem[@sizeOf(usize)..].ptr; + } + fn deallocate(ptr: *anyopaque, info: ?*const anyopaque) callconv(.c) void { + if (enable_debug_logs) log.debug("deallocate @{*}", .{ptr}); + const gpa: *const Allocator = @ptrCast(@alignCast(info)); + const old_base: [*]align(@alignOf(usize)) u8 = @alignCast(@as([*]u8, @ptrCast(ptr)) - @sizeOf(usize)); + const old_size = @as(*const usize, @ptrCast(old_base)).*; + const old_mem = old_base[0 .. old_size + @sizeOf(usize)]; + gpa.free(old_mem); + } +}; + +const EventCallbackCtx = struct { + fse: *FsEvents, + gpa: Allocator, +}; + +fn eventCallback( + stream: ConstFSEventStreamRef, + client_callback_info: ?*anyopaque, + num_events: usize, + events_paths_ptr: *anyopaque, + events_flags_ptr: [*]const FSEventStreamEventFlags, + events_ids_ptr: [*]const FSEventStreamEventId, +) callconv(.c) void { + const ctx: *const EventCallbackCtx = @ptrCast(@alignCast(client_callback_info)); + const fse = ctx.fse; + const gpa = ctx.gpa; + const rs = fse.resolved_symbols; + const events_paths_ptr_casted: [*]const [*:0]const u8 = @ptrCast(@alignCast(events_paths_ptr)); + const events_paths = events_paths_ptr_casted[0..num_events]; + const events_ids = events_ids_ptr[0..num_events]; + const events_flags = events_flags_ptr[0..num_events]; + var any_dirty = false; + for (events_paths, events_ids, events_flags) |event_path_nts, event_id, event_flags| { + _ = event_id; + if (event_flags.history_done) continue; // sentinel + const event_path = std.mem.span(event_path_nts); + switch (event_flags.must_scan_sub_dirs) { + false => { + if (fse.watch_paths.get(event_path)) |steps| { + assert(steps.len > 0); + for (steps) |s| dirtyStep(s, gpa, &any_dirty); + } + if (std.fs.path.dirname(event_path)) |event_dirname| { + // Modifying '/foo/bar' triggers the watch on '/foo'. + if (fse.watch_paths.get(event_dirname)) |steps| { + assert(steps.len > 0); + for (steps) |s| dirtyStep(s, gpa, &any_dirty); + } + } + }, + true => { + // This is unlikely, but can occasionally happen when bottlenecked: events have been + // coalesced into one. We want to see if any of these events are actually relevant + // to us. The only way we can reasonably do that in this rare edge case is iterate + // the watch paths and see if any is under this directory. That's acceptable because + // we would otherwise kick off a rebuild which would be clearing those paths anyway. 
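The containment check here cannot be a plain `std.mem.startsWith`: '/foo/barx' starts with '/foo/bar' without being inside it, which is exactly what the `dirStartsWith` helper defined below guards against. A sketch of the distinction:

    const std = @import("std");

    test "startsWith alone is not path containment" {
        // Lexical prefix match gives a false positive...
        try std.testing.expect(std.mem.startsWith(u8, "/foo/barx", "/foo/bar"));
        // ...so containment needs an exact match, or the prefix followed
        // immediately by a separator, as dirStartsWith implements.
        const path = "/foo/barx";
        const prefix = "/foo/bar";
        const contained = std.mem.eql(u8, path, prefix) or
            (std.mem.startsWith(u8, path, prefix) and path[prefix.len] == '/');
        try std.testing.expect(!contained);
    }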
+ const changed_path = std.fs.path.dirname(event_path) orelse event_path; + for (fse.watch_paths.keys(), fse.watch_paths.values()) |watching_path, steps| { + if (dirStartsWith(watching_path, changed_path)) { + for (steps) |s| dirtyStep(s, gpa, &any_dirty); + } + } + }, + } + } + if (any_dirty) { + fse.since_event = rs.FSEventStreamGetLatestEventId(stream); + _ = dispatch_semaphore_signal(fse.waiting_semaphore); + } +} +fn dirtyStep(s: *std.Build.Step, gpa: Allocator, any_dirty: *bool) void { + if (s.state == .precheck_done) return; + s.recursiveReset(gpa); + any_dirty.* = true; +} +fn dirStartsWith(path: []const u8, prefix: []const u8) bool { + if (std.mem.eql(u8, path, prefix)) return true; + if (!std.mem.startsWith(u8, path, prefix)) return false; + if (path[prefix.len] != '/') return false; // `path` is `/foo/barx`, `prefix` is `/foo/bar` + return true; // `path` is `/foo/bar/...`, `prefix` is `/foo/bar` +} + +const dispatch_time_t = enum(u64) { + now = 0, + forever = std.math.maxInt(u64), + _, +}; +extern fn dispatch_time(base: dispatch_time_t, delta_ns: i64) dispatch_time_t; + +const dispatch_semaphore_t = *opaque {}; +extern fn dispatch_semaphore_create(value: isize) dispatch_semaphore_t; +extern fn dispatch_semaphore_wait(dsema: dispatch_semaphore_t, timeout: dispatch_time_t) isize; +extern fn dispatch_semaphore_signal(dsema: dispatch_semaphore_t) isize; + +const dispatch_queue_t = *opaque {}; +const dispatch_queue_attr_t = ?*opaque { + const SERIAL: dispatch_queue_attr_t = null; +}; +extern fn dispatch_queue_create(label: [*:0]const u8, attr: dispatch_queue_attr_t) dispatch_queue_t; +extern fn dispatch_release(object: *anyopaque) void; + +const CFAllocatorRef = ?*const opaque {}; +const CFArrayRef = *const opaque {}; +const CFStringRef = *const opaque {}; +const CFTimeInterval = f64; +const CFIndex = i32; +const CFOptionFlags = enum(u32) { _ }; +const CFAllocatorRetainCallBack = *const fn (info: ?*const anyopaque) callconv(.c) *const anyopaque; +const CFAllocatorReleaseCallBack = *const fn (info: ?*const anyopaque) callconv(.c) void; +const CFAllocatorCopyDescriptionCallBack = *const fn (info: ?*const anyopaque) callconv(.c) CFStringRef; +const CFAllocatorAllocateCallBack = *const fn (alloc_size: CFIndex, hint: CFOptionFlags, info: ?*const anyopaque) callconv(.c) ?*const anyopaque; +const CFAllocatorReallocateCallBack = *const fn (ptr: ?*anyopaque, new_size: CFIndex, hint: CFOptionFlags, info: ?*const anyopaque) callconv(.c) ?*const anyopaque; +const CFAllocatorDeallocateCallBack = *const fn (ptr: *anyopaque, info: ?*const anyopaque) callconv(.c) void; +const CFAllocatorPreferredSizeCallBack = *const fn (size: CFIndex, hint: CFOptionFlags, info: ?*const anyopaque) callconv(.c) CFIndex; +const CFAllocatorContext = extern struct { + version: CFIndex, + info: ?*anyopaque, + retain: ?CFAllocatorRetainCallBack, + release: ?CFAllocatorReleaseCallBack, + copy_description: ?CFAllocatorCopyDescriptionCallBack, + allocate: CFAllocatorAllocateCallBack, + reallocate: ?CFAllocatorReallocateCallBack, + deallocate: ?CFAllocatorDeallocateCallBack, + preferred_size: ?CFAllocatorPreferredSizeCallBack, +}; +const CFArrayCallBacks = opaque {}; +const CFStringEncoding = enum(u32) { + invalid_id = std.math.maxInt(u32), + mac_roman = 0, + windows_latin_1 = 0x500, + iso_latin_1 = 0x201, + next_step_latin = 0xB01, + ascii = 0x600, + unicode = 0x100, + utf8 = 0x8000100, + non_lossy_ascii = 0xBFF, +}; + +const FSEventStreamRef = *opaque {}; +const ConstFSEventStreamRef = *const 
@typeInfo(FSEventStreamRef).pointer.child; +const FSEventStreamCallback = *const fn ( + stream: ConstFSEventStreamRef, + client_callback_info: ?*anyopaque, + num_events: usize, + event_paths: *anyopaque, + event_flags: [*]const FSEventStreamEventFlags, + event_ids: [*]const FSEventStreamEventId, +) callconv(.c) void; +const FSEventStreamContext = extern struct { + version: CFIndex, + info: ?*anyopaque, + retain: ?CFAllocatorRetainCallBack, + release: ?CFAllocatorReleaseCallBack, + copy_description: ?CFAllocatorCopyDescriptionCallBack, +}; +const FSEventStreamEventId = enum(u64) { + since_now = std.math.maxInt(u64), + _, +}; +const FSEventStreamCreateFlags = packed struct(u32) { + use_cf_types: bool = false, + no_defer: bool = false, + watch_root: bool = false, + ignore_self: bool = false, + file_events: bool = false, + _: u27 = 0, +}; +const FSEventStreamEventFlags = packed struct(u32) { + must_scan_sub_dirs: bool, + user_dropped: bool, + kernel_dropped: bool, + event_ids_wrapped: bool, + history_done: bool, + root_changed: bool, + mount: bool, + unmount: bool, + _: u24 = 0, +}; + +const std = @import("std"); +const assert = std.debug.assert; +const Allocator = std.mem.Allocator; +const watch_log = std.log.scoped(.watch); +const FsEvents = @This(); From 761857e3f9dae52b474b38bd2800b72e181745ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?= Date: Sun, 20 Jul 2025 11:21:34 +0200 Subject: [PATCH 09/70] ci: temporarily disable test-link on riscv64-linux https://github.com/ziglang/zig/issues/24663 --- ci/riscv64-linux-debug.sh | 2 +- ci/riscv64-linux-release.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ci/riscv64-linux-debug.sh b/ci/riscv64-linux-debug.sh index 4076ec2268..14982d1451 100755 --- a/ci/riscv64-linux-debug.sh +++ b/ci/riscv64-linux-debug.sh @@ -49,7 +49,7 @@ unset CXX ninja install # No -fqemu and -fwasmtime here as they're covered by the x86_64-linux scripts. -stage3-debug/bin/zig build test-cases test-modules test-unit test-standalone test-c-abi test-link test-stack-traces test-asm-link test-llvm-ir \ +stage3-debug/bin/zig build test-cases test-modules test-unit test-standalone test-c-abi test-stack-traces test-asm-link test-llvm-ir \ --maxrss 68719476736 \ -Dstatic-llvm \ -Dskip-non-native \ diff --git a/ci/riscv64-linux-release.sh b/ci/riscv64-linux-release.sh index 78c398cab4..b36e9a3407 100755 --- a/ci/riscv64-linux-release.sh +++ b/ci/riscv64-linux-release.sh @@ -49,7 +49,7 @@ unset CXX ninja install # No -fqemu and -fwasmtime here as they're covered by the x86_64-linux scripts. -stage3-release/bin/zig build test-cases test-modules test-unit test-standalone test-c-abi test-link test-stack-traces test-asm-link test-llvm-ir \ +stage3-release/bin/zig build test-cases test-modules test-unit test-standalone test-c-abi test-stack-traces test-asm-link test-llvm-ir \ --maxrss 68719476736 \ -Dstatic-llvm \ -Dskip-non-native \ From bcaae562d6f62b8cff829eab19417b8dfc4a8dcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?= Date: Wed, 23 Jul 2025 00:23:31 +0200 Subject: [PATCH 10/70] build: add -Dskip-compile-errors option Skips tests in test/cases/compile_errors. 
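A typical invocation, mirroring how the CI scripts updated later in this
series pass the new flag alongside the other skip options:

  zig build test-cases -Dskip-non-native -Dskip-compile-errors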
--- build.zig | 2 ++ test/src/Cases.zig | 3 +++ 2 files changed, 5 insertions(+) diff --git a/build.zig b/build.zig index 60c4e8f0e4..95c97f29b6 100644 --- a/build.zig +++ b/build.zig @@ -90,6 +90,7 @@ pub fn build(b: *std.Build) !void { const skip_non_native = b.option(bool, "skip-non-native", "Main test suite skips non-native builds") orelse false; const skip_libc = b.option(bool, "skip-libc", "Main test suite skips tests that link libc") orelse false; const skip_single_threaded = b.option(bool, "skip-single-threaded", "Main test suite skips tests that are single-threaded") orelse false; + const skip_compile_errors = b.option(bool, "skip-compile-errors", "Main test suite skips compile error tests") orelse false; const skip_translate_c = b.option(bool, "skip-translate-c", "Main test suite skips translate-c tests") orelse false; const skip_run_translated_c = b.option(bool, "skip-run-translated-c", "Main test suite skips run-translated-c tests") orelse false; const skip_freebsd = b.option(bool, "skip-freebsd", "Main test suite skips targets with freebsd OS") orelse false; @@ -418,6 +419,7 @@ pub fn build(b: *std.Build) !void { try tests.addCases(b, test_cases_step, target, .{ .test_filters = test_filters, .test_target_filters = test_target_filters, + .skip_compile_errors = skip_compile_errors, .skip_non_native = skip_non_native, .skip_freebsd = skip_freebsd, .skip_netbsd = skip_netbsd, diff --git a/test/src/Cases.zig b/test/src/Cases.zig index 60a564bc16..7632a15232 100644 --- a/test/src/Cases.zig +++ b/test/src/Cases.zig @@ -593,6 +593,7 @@ pub fn lowerToTranslateCSteps( pub const CaseTestOptions = struct { test_filters: []const []const u8, test_target_filters: []const []const u8, + skip_compile_errors: bool, skip_non_native: bool, skip_freebsd: bool, skip_netbsd: bool, @@ -618,6 +619,8 @@ pub fn lowerToBuildSteps( if (std.mem.indexOf(u8, case.name, test_filter)) |_| break; } else if (options.test_filters.len > 0) continue; + if (case.case.? 
== .Error and options.skip_compile_errors) continue; + if (options.skip_non_native and !case.target.query.isNative()) continue; From a0e58501affec246ec12b42f6ffee078d6e33b04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?= Date: Wed, 23 Jul 2025 00:25:03 +0200 Subject: [PATCH 11/70] ci: use -Dskip-compile-errors on riscv64-linux --- ci/riscv64-linux-debug.sh | 1 + ci/riscv64-linux-release.sh | 1 + 2 files changed, 2 insertions(+) diff --git a/ci/riscv64-linux-debug.sh b/ci/riscv64-linux-debug.sh index 14982d1451..512b33dc79 100755 --- a/ci/riscv64-linux-debug.sh +++ b/ci/riscv64-linux-debug.sh @@ -54,6 +54,7 @@ stage3-debug/bin/zig build test-cases test-modules test-unit test-standalone tes -Dstatic-llvm \ -Dskip-non-native \ -Dskip-single-threaded \ + -Dskip-compile-errors \ -Dskip-translate-c \ -Dskip-run-translated-c \ -Dtarget=native-native-musl \ diff --git a/ci/riscv64-linux-release.sh b/ci/riscv64-linux-release.sh index b36e9a3407..f8a9bb0939 100755 --- a/ci/riscv64-linux-release.sh +++ b/ci/riscv64-linux-release.sh @@ -54,6 +54,7 @@ stage3-release/bin/zig build test-cases test-modules test-unit test-standalone t -Dstatic-llvm \ -Dskip-non-native \ -Dskip-single-threaded \ + -Dskip-compile-errors \ -Dskip-translate-c \ -Dskip-run-translated-c \ -Dtarget=native-native-musl \ From 930c6ca49d8a556d256c5064a63af928e7833ef2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?= Date: Wed, 23 Jul 2025 00:25:26 +0200 Subject: [PATCH 12/70] ci: don't run test-standalone on riscv64-linux --- ci/riscv64-linux-debug.sh | 2 +- ci/riscv64-linux-release.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ci/riscv64-linux-debug.sh b/ci/riscv64-linux-debug.sh index 512b33dc79..c68952356c 100755 --- a/ci/riscv64-linux-debug.sh +++ b/ci/riscv64-linux-debug.sh @@ -49,7 +49,7 @@ unset CXX ninja install # No -fqemu and -fwasmtime here as they're covered by the x86_64-linux scripts. -stage3-debug/bin/zig build test-cases test-modules test-unit test-standalone test-c-abi test-stack-traces test-asm-link test-llvm-ir \ +stage3-debug/bin/zig build test-cases test-modules test-unit test-c-abi test-stack-traces test-asm-link test-llvm-ir \ --maxrss 68719476736 \ -Dstatic-llvm \ -Dskip-non-native \ diff --git a/ci/riscv64-linux-release.sh b/ci/riscv64-linux-release.sh index f8a9bb0939..757dace271 100755 --- a/ci/riscv64-linux-release.sh +++ b/ci/riscv64-linux-release.sh @@ -49,7 +49,7 @@ unset CXX ninja install # No -fqemu and -fwasmtime here as they're covered by the x86_64-linux scripts. 
-stage3-release/bin/zig build test-cases test-modules test-unit test-standalone test-c-abi test-stack-traces test-asm-link test-llvm-ir \
+stage3-release/bin/zig build test-cases test-modules test-unit test-c-abi test-stack-traces test-asm-link test-llvm-ir \
   --maxrss 68719476736 \
   -Dstatic-llvm \
   -Dskip-non-native \

From 4ec232a3460b5daeefbf89ad197a78c6856fb1b0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?=
Date: Sat, 2 Aug 2025 09:37:53 +0200
Subject: [PATCH 13/70] ci: set riscv64-linux timeouts to 6 hours

---
 .github/workflows/riscv.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/riscv.yaml b/.github/workflows/riscv.yaml
index b5baf43a6d..5d4cc91c55 100644
--- a/.github/workflows/riscv.yaml
+++ b/.github/workflows/riscv.yaml
@@ -5,7 +5,7 @@ permissions:
   contents: read
 jobs:
   riscv64-linux-debug:
-    timeout-minutes: 1020
+    timeout-minutes: 360
     runs-on: [self-hosted, Linux, riscv64]
     steps:
       - name: Checkout
@@ -13,7 +13,7 @@ jobs:
       - name: Build and Test
         run: sh ci/riscv64-linux-debug.sh
   riscv64-linux-release:
-    timeout-minutes: 900
+    timeout-minutes: 360
     runs-on: [self-hosted, Linux, riscv64]
     steps:
       - name: Checkout

From bab0de92b72b24a3d14d69ec633edeb91b9d0ac4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?=
Date: Sat, 2 Aug 2025 09:39:56 +0200
Subject: [PATCH 14/70] ci: re-enable riscv64-linux-debug and
 riscv64-linux-release for master

---
 .github/workflows/ci.yaml    | 18 ++++++++++++++++++
 .github/workflows/riscv.yaml | 22 ----------------------
 2 files changed, 18 insertions(+), 22 deletions(-)
 delete mode 100644 .github/workflows/riscv.yaml

diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index fd07b9add4..667316d8b1 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -50,6 +50,24 @@ jobs:
         uses: actions/checkout@v4
       - name: Build and Test
         run: sh ci/aarch64-linux-release.sh
+  riscv64-linux-debug:
+    if: github.event_name == 'push'
+    timeout-minutes: 360
+    runs-on: [self-hosted, Linux, riscv64]
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      - name: Build and Test
+        run: sh ci/riscv64-linux-debug.sh
+  riscv64-linux-release:
+    if: github.event_name == 'push'
+    timeout-minutes: 360
+    runs-on: [self-hosted, Linux, riscv64]
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      - name: Build and Test
+        run: sh ci/riscv64-linux-release.sh
   x86_64-macos-release:
     runs-on: "macos-13"
     env:
diff --git a/.github/workflows/riscv.yaml b/.github/workflows/riscv.yaml
deleted file mode 100644
index 5d4cc91c55..0000000000
--- a/.github/workflows/riscv.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-name: riscv
-on:
-  workflow_dispatch:
-permissions:
-  contents: read
-jobs:
-  riscv64-linux-debug:
-    timeout-minutes: 360
-    runs-on: [self-hosted, Linux, riscv64]
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v4
-      - name: Build and Test
-        run: sh ci/riscv64-linux-debug.sh
-  riscv64-linux-release:
-    timeout-minutes: 360
-    runs-on: [self-hosted, Linux, riscv64]
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v4
-      - name: Build and Test
-        run: sh ci/riscv64-linux-release.sh

From e98aeeb73fba942c8e061bb8158ed073a7b19e1d Mon Sep 17 00:00:00 2001
From: mlugg
Date: Sat, 2 Aug 2025 02:01:31 +0100
Subject: [PATCH 15/70] std.Build: keep compiler alive under `-fincremental
 --webui`

Previously, this only applied when using `-fincremental --watch`, but
`--webui` makes the build runner stay alive just like `--watch` does, so
the same logic applies here.
Without this, attempting to perform incremental updates with `--webui` performs full rebuilds. (I did test that before merging the PR, but at that time I was passing `--watch` too -- which has since been disallowed -- so I missed that it doesn't work as expected without that option!) --- lib/std/Build/Step/Compile.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/std/Build/Step/Compile.zig b/lib/std/Build/Step/Compile.zig index 5e14796792..59ccb87dad 100644 --- a/lib/std/Build/Step/Compile.zig +++ b/lib/std/Build/Step/Compile.zig @@ -1851,7 +1851,7 @@ fn make(step: *Step, options: Step.MakeOptions) !void { const maybe_output_dir = step.evalZigProcess( zig_args, options.progress_node, - (b.graph.incremental == true) and options.watch, + (b.graph.incremental == true) and (options.watch or options.web_server != null), options.web_server, options.gpa, ) catch |err| switch (err) { From e82d67233b615e04fb5e4f31cd93216a2c1c2899 Mon Sep 17 00:00:00 2001 From: David Rubin Date: Fri, 1 Aug 2025 14:57:06 -0700 Subject: [PATCH 16/70] disallow alignment on packed union fields --- lib/std/zig/AstGen.zig | 3 +++ test/behavior/type.zig | 4 ++-- ...struct_field_alignment_unavailable_for_reify_type.zig | 9 --------- .../compile_errors/packed_union_alignment_override.zig | 9 +++++++++ test/cases/compile_errors/reify_struct.zig | 3 ++- 5 files changed, 16 insertions(+), 12 deletions(-) delete mode 100644 test/cases/compile_errors/packed_struct_field_alignment_unavailable_for_reify_type.zig create mode 100644 test/cases/compile_errors/packed_union_alignment_override.zig diff --git a/lib/std/zig/AstGen.zig b/lib/std/zig/AstGen.zig index 52fc41a2c9..ab81f343bd 100644 --- a/lib/std/zig/AstGen.zig +++ b/lib/std/zig/AstGen.zig @@ -5386,6 +5386,9 @@ fn unionDeclInner( return astgen.failNode(member_node, "union field missing type", .{}); } if (member.ast.align_expr.unwrap()) |align_expr| { + if (layout == .@"packed") { + return astgen.failNode(align_expr, "unable to override alignment of packed union fields", .{}); + } const align_inst = try expr(&block_scope, &block_scope.base, coerced_align_ri, align_expr); wip_members.appendToField(@intFromEnum(align_inst)); any_aligned_fields = true; diff --git a/test/behavior/type.zig b/test/behavior/type.zig index b5ac2d95f7..78295d3a7d 100644 --- a/test/behavior/type.zig +++ b/test/behavior/type.zig @@ -433,8 +433,8 @@ test "Type.Union" { .layout = .@"packed", .tag_type = null, .fields = &.{ - .{ .name = "signed", .type = i32, .alignment = @alignOf(i32) }, - .{ .name = "unsigned", .type = u32, .alignment = @alignOf(u32) }, + .{ .name = "signed", .type = i32, .alignment = 0 }, + .{ .name = "unsigned", .type = u32, .alignment = 0 }, }, .decls = &.{}, }, diff --git a/test/cases/compile_errors/packed_struct_field_alignment_unavailable_for_reify_type.zig b/test/cases/compile_errors/packed_struct_field_alignment_unavailable_for_reify_type.zig deleted file mode 100644 index 20fb548d83..0000000000 --- a/test/cases/compile_errors/packed_struct_field_alignment_unavailable_for_reify_type.zig +++ /dev/null @@ -1,9 +0,0 @@ -export fn entry() void { - _ = @Type(.{ .@"struct" = .{ .layout = .@"packed", .fields = &.{ - .{ .name = "one", .type = u4, .default_value_ptr = null, .is_comptime = false, .alignment = 2 }, - }, .decls = &.{}, .is_tuple = false } }); -} - -// error -// -// :2:9: error: alignment in a packed struct field must be set to 0 diff --git a/test/cases/compile_errors/packed_union_alignment_override.zig 
b/test/cases/compile_errors/packed_union_alignment_override.zig
new file mode 100644
index 0000000000..e27db3ea4b
--- /dev/null
+++ b/test/cases/compile_errors/packed_union_alignment_override.zig
@@ -0,0 +1,9 @@
+const U = packed union {
+    x: f32,
+    y: u8 align(10),
+    z: u32,
+};
+
+// error
+//
+// :3:17: error: unable to override alignment of packed union fields
diff --git a/test/cases/compile_errors/reify_struct.zig b/test/cases/compile_errors/reify_struct.zig
index 12d445082c..60228061dd 100644
--- a/test/cases/compile_errors/reify_struct.zig
+++ b/test/cases/compile_errors/reify_struct.zig
@@ -75,4 +75,5 @@ comptime {
 // :16:5: error: tuple field name '3' does not match field index 0
 // :30:5: error: comptime field without default initialization value
 // :44:5: error: extern struct fields cannot be marked comptime
-// :58:5: error: alignment in a packed struct field must be set to 0
+// :58:5: error: alignment of a packed struct field must be set to 0
+

From 4d1010d36c69a88a94494f155ba22490d562e3e8 Mon Sep 17 00:00:00 2001
From: David Rubin
Date: Sat, 2 Aug 2025 17:28:33 -0700
Subject: [PATCH 17/70] llvm: correctly lower `double_integer` for rv32

---
 src/codegen/llvm.zig | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig
index 6bec900aec..2e1c390c09 100644
--- a/src/codegen/llvm.zig
+++ b/src/codegen/llvm.zig
@@ -12210,11 +12210,14 @@ fn lowerFnRetTy(o: *Object, pt: Zcu.PerThread, fn_info: InternPool.Key.FuncType)
         },
         .riscv64_lp64, .riscv32_ilp32 => switch (riscv_c_abi.classifyType(return_type, zcu)) {
             .memory => return .void,
-            .integer => {
-                return o.builder.intType(@intCast(return_type.bitSize(zcu)));
-            },
+            .integer => return o.builder.intType(@intCast(return_type.bitSize(zcu))),
             .double_integer => {
-                return o.builder.structType(.normal, &.{ .i64, .i64 });
+                const integer: Builder.Type = switch (zcu.getTarget().cpu.arch) {
+                    .riscv64 => .i64,
+                    .riscv32 => .i32,
+                    else => unreachable,
+                };
+                return o.builder.structType(.normal, &.{ integer, integer });
             },
             .byval => return o.lowerType(pt, return_type),
             .fields => {

From 616e69c80745cdc872b1db5ad63b919d82d9e9cb Mon Sep 17 00:00:00 2001
From: DialecticalMaterialist
 <170803884+DialecticalMaterialist@users.noreply.github.com>
Date: Sat, 2 Aug 2025 23:43:17 +0200
Subject: [PATCH 18/70] OpenGL SPIR-V support

The support was already there, but the spirv_fragment and spirv_vertex
calling conventions were not allowed when the OS tag is opengl.
---
 src/Zcu.zig | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/Zcu.zig b/src/Zcu.zig
index 90b3edd001..c444d57bf6 100644
--- a/src/Zcu.zig
+++ b/src/Zcu.zig
@@ -4510,7 +4510,7 @@ pub fn callconvSupported(zcu: *Zcu, cc: std.builtin.CallingConvention) union(enu
         },
         .stage2_spirv => switch (cc) {
             .spirv_device, .spirv_kernel => true,
-            .spirv_fragment, .spirv_vertex => target.os.tag == .vulkan,
+            .spirv_fragment, .spirv_vertex => target.os.tag == .vulkan or target.os.tag == .opengl,
             else => false,
         },
     };

From fa445d86a110f1171b75824fe5ec139089fa4733 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?=
Date: Sun, 3 Aug 2025 11:05:04 +0200
Subject: [PATCH 19/70] ci: target baseline instead of spacemit_x60 on
 riscv64-linux

Doesn't seem to make much of a difference anyway, and LLVM 20 appears to
still have some miscompilations with vector and bitmanip extensions
enabled.
--- ci/riscv64-linux-debug.sh | 2 +- ci/riscv64-linux-release.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ci/riscv64-linux-debug.sh b/ci/riscv64-linux-debug.sh index c68952356c..d5edbac282 100755 --- a/ci/riscv64-linux-debug.sh +++ b/ci/riscv64-linux-debug.sh @@ -7,7 +7,7 @@ set -e ARCH="$(uname -m)" TARGET="$ARCH-linux-musl" -MCPU="spacemit_x60" +MCPU="baseline" CACHE_BASENAME="zig+llvm+lld+clang-riscv64-linux-musl-0.15.0-dev.929+31e46be74" PREFIX="$HOME/deps/$CACHE_BASENAME" ZIG="$PREFIX/bin/zig" diff --git a/ci/riscv64-linux-release.sh b/ci/riscv64-linux-release.sh index 757dace271..6cb8fb2891 100755 --- a/ci/riscv64-linux-release.sh +++ b/ci/riscv64-linux-release.sh @@ -7,7 +7,7 @@ set -e ARCH="$(uname -m)" TARGET="$ARCH-linux-musl" -MCPU="spacemit_x60" +MCPU="baseline" CACHE_BASENAME="zig+llvm+lld+clang-riscv64-linux-musl-0.15.0-dev.929+31e46be74" PREFIX="$HOME/deps/$CACHE_BASENAME" ZIG="$PREFIX/bin/zig" From 765825b80224f491e54fac83479cb31014352256 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?= Date: Sun, 3 Aug 2025 16:44:58 +0200 Subject: [PATCH 20/70] ci: bump riscv64-linux timeout from 6 hours to 7 hours GitHub is apparently very bad at arithmetic and so will cancel jobs that pass the 5 hours mark, even if they're nowhere near the 6 hours timeout. So add an hour to assist GitHub in this very difficult task. --- .github/workflows/ci.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 667316d8b1..32c6702939 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -52,7 +52,7 @@ jobs: run: sh ci/aarch64-linux-release.sh riscv64-linux-debug: if: github.event_name == 'push' - timeout-minutes: 360 + timeout-minutes: 420 runs-on: [self-hosted, Linux, riscv64] steps: - name: Checkout @@ -61,7 +61,7 @@ jobs: run: sh ci/riscv64-linux-debug.sh riscv64-linux-release: if: github.event_name == 'push' - timeout-minutes: 360 + timeout-minutes: 420 runs-on: [self-hosted, Linux, riscv64] steps: - name: Checkout From 1808ecfa049d62b1f202e9316fd59aa47f7d4f31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?= Date: Sun, 3 Aug 2025 19:52:16 +0200 Subject: [PATCH 21/70] std.Target: bump fuchsia max version to 27.0.0 --- lib/std/Target.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/std/Target.zig b/lib/std/Target.zig index c75e2b51fb..9aa7725e14 100644 --- a/lib/std/Target.zig +++ b/lib/std/Target.zig @@ -405,7 +405,7 @@ pub const Os = struct { .fuchsia => .{ .semver = .{ .min = .{ .major = 1, .minor = 0, .patch = 0 }, - .max = .{ .major = 26, .minor = 0, .patch = 0 }, + .max = .{ .major = 27, .minor = 0, .patch = 0 }, }, }, .hermit => .{ From af3baee5ca2cf5a9a965cdd2d965306d487a4bd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?= Date: Sun, 3 Aug 2025 19:54:19 +0200 Subject: [PATCH 22/70] std.Target: bump linux max version to 6.16.0 --- lib/std/Target.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/std/Target.zig b/lib/std/Target.zig index 9aa7725e14..7949908ab8 100644 --- a/lib/std/Target.zig +++ b/lib/std/Target.zig @@ -446,7 +446,7 @@ pub const Os = struct { break :blk default_min; }, - .max = .{ .major = 6, .minor = 13, .patch = 4 }, + .max = .{ .major = 6, .minor = 16, .patch = 0 }, }, .glibc = blk: { // For 32-bit targets that traditionally used 32-bit time, we require From 39b653c5e776ff87247e8079db3c9c11981a1597 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?= Date: Sun, 3 Aug 2025 19:55:57 +0200 Subject: [PATCH 23/70] std.Target: bump freebsd max version to 14.3.0 --- lib/std/Target.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/std/Target.zig b/lib/std/Target.zig index 7949908ab8..24ca114fa0 100644 --- a/lib/std/Target.zig +++ b/lib/std/Target.zig @@ -519,7 +519,7 @@ pub const Os = struct { break :blk default_min; }, - .max = .{ .major = 14, .minor = 2, .patch = 0 }, + .max = .{ .major = 14, .minor = 3, .patch = 0 }, }, }, .netbsd => .{ From 7f2140710fa1f97062d8f632a26656a5b1c249e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?= Date: Sun, 3 Aug 2025 19:58:17 +0200 Subject: [PATCH 24/70] std.Target: bump cuda max version to 12.9.1 --- lib/std/Target.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/std/Target.zig b/lib/std/Target.zig index 24ca114fa0..9530e6bf90 100644 --- a/lib/std/Target.zig +++ b/lib/std/Target.zig @@ -626,7 +626,7 @@ pub const Os = struct { .cuda => .{ .semver = .{ .min = .{ .major = 11, .minor = 0, .patch = 1 }, - .max = .{ .major = 12, .minor = 9, .patch = 0 }, + .max = .{ .major = 12, .minor = 9, .patch = 1 }, }, }, .nvcl, From afe458e9b64ec5c1fd9d577ec90a4f0a015600ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?= Date: Sun, 3 Aug 2025 19:59:14 +0200 Subject: [PATCH 25/70] std.Target: bump vulkan max version to 1.4.321 --- lib/std/Target.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/std/Target.zig b/lib/std/Target.zig index 9530e6bf90..923a815a06 100644 --- a/lib/std/Target.zig +++ b/lib/std/Target.zig @@ -646,7 +646,7 @@ pub const Os = struct { .vulkan => .{ .semver = .{ .min = .{ .major = 1, .minor = 2, .patch = 0 }, - .max = .{ .major = 1, .minor = 4, .patch = 313 }, + .max = .{ .major = 1, .minor = 4, .patch = 321 }, }, }, }; From 5b74d3347191c97fb75e5f449eaa2a94003bbf02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?= Date: Sun, 3 Aug 2025 20:00:32 +0200 Subject: [PATCH 26/70] std.Target: bump amdhsa max version to 6.4.2 --- lib/std/Target.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/std/Target.zig b/lib/std/Target.zig index 923a815a06..8f7930adea 100644 --- a/lib/std/Target.zig +++ b/lib/std/Target.zig @@ -614,7 +614,7 @@ pub const Os = struct { .amdhsa => .{ .semver = .{ .min = .{ .major = 5, .minor = 0, .patch = 0 }, - .max = .{ .major = 6, .minor = 4, .patch = 0 }, + .max = .{ .major = 6, .minor = 4, .patch = 2 }, }, }, .amdpal => .{ From e9093b8d1801a3e5c54a131cd1ed48e2339d4ee9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?= Date: Sun, 3 Aug 2025 20:05:21 +0200 Subject: [PATCH 27/70] std.Target: bump max versions for Apple targets --- lib/std/Target.zig | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lib/std/Target.zig b/lib/std/Target.zig index 8f7930adea..067f8e45b6 100644 --- a/lib/std/Target.zig +++ b/lib/std/Target.zig @@ -550,37 +550,37 @@ pub const Os = struct { .driverkit => .{ .semver = .{ .min = .{ .major = 19, .minor = 0, .patch = 0 }, - .max = .{ .major = 24, .minor = 4, .patch = 0 }, + .max = .{ .major = 25, .minor = 0, .patch = 0 }, }, }, .macos => .{ .semver = .{ .min = .{ .major = 13, .minor = 0, .patch = 0 }, - .max = .{ .major = 15, .minor = 4, .patch = 1 }, + .max = .{ .major = 15, .minor = 6, .patch = 0 }, }, }, .ios => .{ .semver = .{ .min = .{ .major = 15, .minor = 0, .patch = 0 }, - .max = .{ .major = 18, .minor = 4, 
.patch = 1 }, + .max = .{ .major = 18, .minor = 6, .patch = 0 }, }, }, .tvos => .{ .semver = .{ .min = .{ .major = 15, .minor = 0, .patch = 0 }, - .max = .{ .major = 18, .minor = 4, .patch = 1 }, + .max = .{ .major = 18, .minor = 5, .patch = 0 }, }, }, .visionos => .{ .semver = .{ .min = .{ .major = 1, .minor = 0, .patch = 0 }, - .max = .{ .major = 2, .minor = 4, .patch = 1 }, + .max = .{ .major = 2, .minor = 5, .patch = 0 }, }, }, .watchos => .{ .semver = .{ .min = .{ .major = 7, .minor = 0, .patch = 0 }, - .max = .{ .major = 11, .minor = 4, .patch = 0 }, + .max = .{ .major = 11, .minor = 6, .patch = 0 }, }, }, From 71722df4ab41e61ea641ba5a83166d60208607fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?= Date: Sun, 3 Aug 2025 20:08:07 +0200 Subject: [PATCH 28/70] std.Target: bump driverkit min version to 20.0.0 --- lib/std/Target.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/std/Target.zig b/lib/std/Target.zig index 067f8e45b6..62ee94d543 100644 --- a/lib/std/Target.zig +++ b/lib/std/Target.zig @@ -549,7 +549,7 @@ pub const Os = struct { .driverkit => .{ .semver = .{ - .min = .{ .major = 19, .minor = 0, .patch = 0 }, + .min = .{ .major = 20, .minor = 0, .patch = 0 }, .max = .{ .major = 25, .minor = 0, .patch = 0 }, }, }, From ba7cc72c47a96fc4dad1ae35c58642debe1e2820 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?= Date: Sun, 3 Aug 2025 20:08:19 +0200 Subject: [PATCH 29/70] std.Target: bump watchos min version to 8.0.0 --- lib/std/Target.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/std/Target.zig b/lib/std/Target.zig index 62ee94d543..59cf7bda01 100644 --- a/lib/std/Target.zig +++ b/lib/std/Target.zig @@ -579,7 +579,7 @@ pub const Os = struct { }, .watchos => .{ .semver = .{ - .min = .{ .major = 7, .minor = 0, .patch = 0 }, + .min = .{ .major = 8, .minor = 0, .patch = 0 }, .max = .{ .major = 11, .minor = 6, .patch = 0 }, }, }, From 493265486c939c9da1985fc7d0add0242a64d16a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?= Date: Sun, 3 Aug 2025 22:38:39 +0200 Subject: [PATCH 30/70] Revert "ci: target baseline instead of spacemit_x60 on riscv64-linux" This reverts commit fa445d86a110f1171b75824fe5ec139089fa4733. Narrator: It did, in fact, make a difference. For whatever reason, building LLVM against spacemit_x60 or baseline makes no noticeable difference in terms of performance, but building the Zig compiler against spacemit_x60 does. Also, the miscompilation that was causing riscv64-linux-debug to fail was in the LLVM libraries, not in the Zig compiler, so we may as well take the win here. 
--- ci/riscv64-linux-debug.sh | 2 +- ci/riscv64-linux-release.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ci/riscv64-linux-debug.sh b/ci/riscv64-linux-debug.sh index d5edbac282..c68952356c 100755 --- a/ci/riscv64-linux-debug.sh +++ b/ci/riscv64-linux-debug.sh @@ -7,7 +7,7 @@ set -e ARCH="$(uname -m)" TARGET="$ARCH-linux-musl" -MCPU="baseline" +MCPU="spacemit_x60" CACHE_BASENAME="zig+llvm+lld+clang-riscv64-linux-musl-0.15.0-dev.929+31e46be74" PREFIX="$HOME/deps/$CACHE_BASENAME" ZIG="$PREFIX/bin/zig" diff --git a/ci/riscv64-linux-release.sh b/ci/riscv64-linux-release.sh index 6cb8fb2891..757dace271 100755 --- a/ci/riscv64-linux-release.sh +++ b/ci/riscv64-linux-release.sh @@ -7,7 +7,7 @@ set -e ARCH="$(uname -m)" TARGET="$ARCH-linux-musl" -MCPU="baseline" +MCPU="spacemit_x60" CACHE_BASENAME="zig+llvm+lld+clang-riscv64-linux-musl-0.15.0-dev.929+31e46be74" PREFIX="$HOME/deps/$CACHE_BASENAME" ZIG="$PREFIX/bin/zig" From dabae3f9dc868af92e7608d897befc39e5db5c33 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 1 Aug 2025 21:44:56 -0700 Subject: [PATCH 31/70] linker: remove dependency on std.fifo --- src/link/MachO/dyld_info/Trie.zig | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/link/MachO/dyld_info/Trie.zig b/src/link/MachO/dyld_info/Trie.zig index 49e1866880..ce56101e54 100644 --- a/src/link/MachO/dyld_info/Trie.zig +++ b/src/link/MachO/dyld_info/Trie.zig @@ -138,18 +138,23 @@ fn finalize(self: *Trie, allocator: Allocator) !void { defer ordered_nodes.deinit(); try ordered_nodes.ensureTotalCapacityPrecise(self.nodes.items(.is_terminal).len); - var fifo = std.fifo.LinearFifo(Node.Index, .Dynamic).init(allocator); - defer fifo.deinit(); + { + var fifo: std.ArrayListUnmanaged(Node.Index) = .empty; + defer fifo.deinit(allocator); - try fifo.writeItem(self.root.?); + try fifo.append(allocator, self.root.?); - while (fifo.readItem()) |next_index| { - const edges = &self.nodes.items(.edges)[next_index]; - for (edges.items) |edge_index| { - const edge = self.edges.items[edge_index]; - try fifo.writeItem(edge.node); + var i: usize = 0; + while (i < fifo.items.len) { + const next_index = fifo.items[i]; + i += 1; + const edges = &self.nodes.items(.edges)[next_index]; + for (edges.items) |edge_index| { + const edge = self.edges.items[edge_index]; + try fifo.append(allocator, edge.node); + } + ordered_nodes.appendAssumeCapacity(next_index); } - ordered_nodes.appendAssumeCapacity(next_index); } var more: bool = true; From 32a069f909a34b22a2049dca39ef5a1965cba21b Mon Sep 17 00:00:00 2001 From: mlugg Date: Mon, 4 Aug 2025 09:00:41 +0100 Subject: [PATCH 32/70] cli: add `--debug-libc` to `zig build` This option is similar to `--debug-target` in letting us override details of the build runner target when debugging the build system. While `--debug-target` lets us override the target query, this option lets us override the libc installation. This option is only usable in a compiler built with debug extensions. I am using this to (try to) test the build runner targeting SerenityOS. 
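A rough usage sketch, assuming the paths file follows the key=value
format that `zig libc` prints and `LibCInstallation.parse` expects (keys
such as include_dir, sys_include_dir, and crt_dir; `my-libc.txt` is a
hypothetical file name):

  zig libc > my-libc.txt    # dump the native paths, then hand-edit them for the target
  zig build --debug-libc my-libc.txt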
--- src/main.zig | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/main.zig b/src/main.zig index 2a143c30ae..e757b58be5 100644 --- a/src/main.zig +++ b/src/main.zig @@ -4891,6 +4891,7 @@ fn cmdBuild(gpa: Allocator, arena: Allocator, args: []const []const u8) !void { var fetch_mode: Package.Fetch.JobQueue.Mode = .needed; var system_pkg_dir_path: ?[]const u8 = null; var debug_target: ?[]const u8 = null; + var debug_libc_paths_file: ?[]const u8 = null; const argv_index_exe = child_argv.items.len; _ = try child_argv.addOne(); @@ -5014,6 +5015,14 @@ fn cmdBuild(gpa: Allocator, arena: Allocator, args: []const []const u8) !void { } else { warn("Zig was compiled without debug extensions. --debug-target has no effect.", .{}); } + } else if (mem.eql(u8, arg, "--debug-libc")) { + if (i + 1 >= args.len) fatal("expected argument after '{s}'", .{arg}); + i += 1; + if (build_options.enable_debug_extensions) { + debug_libc_paths_file = args[i]; + } else { + warn("Zig was compiled without debug extensions. --debug-libc has no effect.", .{}); + } } else if (mem.eql(u8, arg, "--verbose-link")) { verbose_link = true; } else if (mem.eql(u8, arg, "--verbose-cc")) { @@ -5101,6 +5110,14 @@ fn cmdBuild(gpa: Allocator, arena: Allocator, args: []const []const u8) !void { .is_explicit_dynamic_linker = false, }; }; + // Likewise, `--debug-libc` allows overriding the libc installation. + const libc_installation: ?*const LibCInstallation = lci: { + const paths_file = debug_libc_paths_file orelse break :lci null; + if (!build_options.enable_debug_extensions) unreachable; + const lci = try arena.create(LibCInstallation); + lci.* = try .parse(arena, paths_file, &resolved_target.result); + break :lci lci; + }; process.raiseFileDescriptorLimit(); @@ -5365,6 +5382,7 @@ fn cmdBuild(gpa: Allocator, arena: Allocator, args: []const []const u8) !void { } const comp = Compilation.create(gpa, arena, .{ + .libc_installation = libc_installation, .dirs = dirs, .root_name = "build", .config = config, From 422e8d476c4f407abe2915932fc01012682051fe Mon Sep 17 00:00:00 2001 From: mlugg Date: Mon, 4 Aug 2025 09:37:18 +0100 Subject: [PATCH 33/70] build runner: fix FTBFS on targets without `--watch` implementation This was a regression in #24588. I have verified that this patch works by confirming that with the downstream patches SerenityOS apply to the Zig source tree (sans the one working around this regression), I can build the build runner for SerenityOS. Resolves: #24682 --- lib/compiler/build_runner.zig | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lib/compiler/build_runner.zig b/lib/compiler/build_runner.zig index e97b7aa313..f1a0caf47c 100644 --- a/lib/compiler/build_runner.zig +++ b/lib/compiler/build_runner.zig @@ -502,6 +502,9 @@ pub fn main() !void { }; } + // Comptime-known guard to prevent including the logic below when `!Watch.have_impl`. + if (!Watch.have_impl) unreachable; + try w.update(gpa, run.step_stack.keys()); // Wait until a file system notification arrives. 
Read all such events From 70c6a9fba67a10838cfbbc7554cfa164e9b06462 Mon Sep 17 00:00:00 2001 From: Loris Cro Date: Mon, 4 Aug 2025 14:25:08 +0200 Subject: [PATCH 34/70] init: small fix to zig init template it was placing the current zig version in the wrong field --- src/main.zig | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main.zig b/src/main.zig index 2a143c30ae..3316639fee 100644 --- a/src/main.zig +++ b/src/main.zig @@ -4796,7 +4796,8 @@ fn cmdInit(gpa: Allocator, arena: Allocator, args: []const []const u8) !void { writeSimpleTemplateFile(Package.Manifest.basename, \\.{{ \\ .name = .{s}, - \\ .version = "{s}", + \\ .version = "0.0.1", + \\ .minimum_zig_version = "{s}", \\ .paths = .{{""}}, \\ .fingerprint = 0x{x}, \\}} @@ -4811,6 +4812,7 @@ fn cmdInit(gpa: Allocator, arena: Allocator, args: []const []const u8) !void { }; writeSimpleTemplateFile(Package.build_zig_basename, \\const std = @import("std"); + \\ \\pub fn build(b: *std.Build) void {{ \\ _ = b; // stub \\}} From 96be6f6566b19c242220f0f8571dc81ecadf6a79 Mon Sep 17 00:00:00 2001 From: Ian Johnson Date: Mon, 4 Aug 2025 21:30:21 -0400 Subject: [PATCH 35/70] std.compress.flate.Decompress: return correct size for unbuffered decompression Closes #24686 As a bonus, this commit also makes the `git.zig` "testing `main`" compile again. --- lib/std/compress/flate/Decompress.zig | 5 +++-- src/Package/Fetch/git.zig | 12 +++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/lib/std/compress/flate/Decompress.zig b/lib/std/compress/flate/Decompress.zig index 90f89ea6c9..da57d56ab8 100644 --- a/lib/std/compress/flate/Decompress.zig +++ b/lib/std/compress/flate/Decompress.zig @@ -373,7 +373,7 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader d.state = .{ .stored_block = @intCast(remaining_len - n) }; } w.advance(n); - return n; + return @intFromEnum(limit) - remaining + n; }, .fixed_block => { while (remaining > 0) { @@ -1265,6 +1265,7 @@ fn testDecompress(container: Container, compressed: []const u8, expected_plain: defer aw.deinit(); var decompress: Decompress = .init(&in, container, &.{}); - _ = try decompress.reader.streamRemaining(&aw.writer); + const decompressed_len = try decompress.reader.streamRemaining(&aw.writer); + try testing.expectEqual(expected_plain.len, decompressed_len); try testing.expectEqualSlices(u8, expected_plain, aw.getWritten()); } diff --git a/src/Package/Fetch/git.zig b/src/Package/Fetch/git.zig index 8f5bff2522..6ff951014b 100644 --- a/src/Package/Fetch/git.zig +++ b/src/Package/Fetch/git.zig @@ -1678,7 +1678,7 @@ test "SHA-256 packfile indexing and checkout" { /// Checks out a commit of a packfile. Intended for experimenting with and /// benchmarking possible optimizations to the indexing and checkout behavior. 
pub fn main() !void { - const allocator = std.heap.c_allocator; + const allocator = std.heap.smp_allocator; const args = try std.process.argsAlloc(allocator); defer std.process.argsFree(allocator, args); @@ -1703,12 +1703,14 @@ pub fn main() !void { std.debug.print("Starting index...\n", .{}); var index_file = try git_dir.createFile("idx", .{ .read = true }); defer index_file.close(); - var index_buffered_writer = std.io.bufferedWriter(index_file.deprecatedWriter()); - try indexPack(allocator, format, &pack_file_reader, index_buffered_writer.writer()); - try index_buffered_writer.flush(); + var index_file_buffer: [4096]u8 = undefined; + var index_file_writer = index_file.writer(&index_file_buffer); + try indexPack(allocator, format, &pack_file_reader, &index_file_writer); std.debug.print("Starting checkout...\n", .{}); - var repository = try Repository.init(allocator, format, &pack_file_reader, index_file); + var index_file_reader = index_file.reader(&index_file_buffer); + var repository: Repository = undefined; + try repository.init(allocator, format, &pack_file_reader, &index_file_reader); defer repository.deinit(); var diagnostics: Diagnostics = .{ .allocator = allocator }; defer diagnostics.deinit(); From fcb088cb6abf8a0a20de2650491c2d9a4b1ba399 Mon Sep 17 00:00:00 2001 From: KNnut Date: Mon, 4 Aug 2025 23:07:49 +0800 Subject: [PATCH 36/70] std.Target.Query: fix `WindowsVersion` format in `zigTriple()` --- lib/std/Target/Query.zig | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/lib/std/Target/Query.zig b/lib/std/Target/Query.zig index 2d3b0f4436..4b0579d34d 100644 --- a/lib/std/Target/Query.zig +++ b/lib/std/Target/Query.zig @@ -423,7 +423,7 @@ pub fn zigTriple(self: Query, gpa: Allocator) Allocator.Error![]u8 { try formatVersion(v, gpa, &result); }, .windows => |v| { - try result.print(gpa, "{d}", .{v}); + try result.print(gpa, "{f}", .{v}); }, } } @@ -437,7 +437,7 @@ pub fn zigTriple(self: Query, gpa: Allocator) Allocator.Error![]u8 { .windows => |v| { // This is counting on a custom format() function defined on `WindowsVersion` // to add a prefix '.' and make there be a total of three dots. 
- try result.print(gpa, "..{d}", .{v}); + try result.print(gpa, "..{f}", .{v}); }, } } @@ -729,4 +729,20 @@ test parse { defer std.testing.allocator.free(text); try std.testing.expectEqualSlices(u8, "aarch64-linux.3.10...4.4.1-android.30", text); } + { + const query = try Query.parse(.{ + .arch_os_abi = "x86-windows.xp...win8-msvc", + }); + const target = try std.zig.system.resolveTargetQuery(query); + + try std.testing.expect(target.cpu.arch == .x86); + try std.testing.expect(target.os.tag == .windows); + try std.testing.expect(target.os.version_range.windows.min == .xp); + try std.testing.expect(target.os.version_range.windows.max == .win8); + try std.testing.expect(target.abi == .msvc); + + const text = try query.zigTriple(std.testing.allocator); + defer std.testing.allocator.free(text); + try std.testing.expectEqualSlices(u8, "x86-windows.xp...win8-msvc", text); + } } From def25b918914405ced7f0c7eb66f8f086e878eb0 Mon Sep 17 00:00:00 2001 From: David Rubin Date: Mon, 4 Aug 2025 15:37:57 -0700 Subject: [PATCH 37/70] crypto: fix typo in ecdsa comment --- lib/std/crypto/ecdsa.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/std/crypto/ecdsa.zig b/lib/std/crypto/ecdsa.zig index 9763989919..831c101970 100644 --- a/lib/std/crypto/ecdsa.zig +++ b/lib/std/crypto/ecdsa.zig @@ -57,7 +57,7 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type { pub const PublicKey = struct { /// Length (in bytes) of a compressed sec1-encoded key. pub const compressed_sec1_encoded_length = 1 + Curve.Fe.encoded_length; - /// Length (in bytes) of a compressed sec1-encoded key. + /// Length (in bytes) of an uncompressed sec1-encoded key. pub const uncompressed_sec1_encoded_length = 1 + 2 * Curve.Fe.encoded_length; p: Curve, From 82961a8c9f66c03b7f1a802d2d47d26bf0af7dca Mon Sep 17 00:00:00 2001 From: Veikka Tuominen Date: Mon, 4 Aug 2025 12:02:20 +0300 Subject: [PATCH 38/70] std.c: fix utsname array sizes --- lib/std/c.zig | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lib/std/c.zig b/lib/std/c.zig index 6a0289f5f0..8756782956 100644 --- a/lib/std/c.zig +++ b/lib/std/c.zig @@ -6970,11 +6970,11 @@ pub const utsname = switch (native_os) { domainname: [256:0]u8, }, .macos => extern struct { - sysname: [256:0]u8, - nodename: [256:0]u8, - release: [256:0]u8, - version: [256:0]u8, - machine: [256:0]u8, + sysname: [255:0]u8, + nodename: [255:0]u8, + release: [255:0]u8, + version: [255:0]u8, + machine: [255:0]u8, }, // https://github.com/SerenityOS/serenity/blob/d794ed1de7a46482272683f8dc4c858806390f29/Kernel/API/POSIX/sys/utsname.h#L17-L23 .serenity => extern struct { @@ -6984,7 +6984,7 @@ pub const utsname = switch (native_os) { version: [UTSNAME_ENTRY_LEN:0]u8, machine: [UTSNAME_ENTRY_LEN:0]u8, - const UTSNAME_ENTRY_LEN = 65; + const UTSNAME_ENTRY_LEN = 64; }, else => void, }; From 6f545683f3d53917df816ca6cd5c9c20474071f3 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 4 Aug 2025 23:14:54 -0700 Subject: [PATCH 39/70] std: replace various mem copies with `@memmove` --- lib/std/Progress.zig | 2 +- lib/std/array_list.zig | 15 ++++++--------- lib/std/math/big/int.zig | 8 +++----- lib/std/os/windows.zig | 11 +++-------- src/Sema.zig | 6 +++--- 5 files changed, 16 insertions(+), 26 deletions(-) diff --git a/lib/std/Progress.zig b/lib/std/Progress.zig index 8b741187e7..58d409ac07 100644 --- a/lib/std/Progress.zig +++ b/lib/std/Progress.zig @@ -1006,7 +1006,7 @@ fn serializeIpc(start_serialized_len: usize, serialized_buffer: *Serialized.Buff continue; } 
const src = pipe_buf[m.remaining_read_trash_bytes..n]; - std.mem.copyForwards(u8, &pipe_buf, src); + @memmove(pipe_buf[0..src.len], src); m.remaining_read_trash_bytes = 0; bytes_read = src.len; continue; diff --git a/lib/std/array_list.zig b/lib/std/array_list.zig index c3fade794f..02e4608399 100644 --- a/lib/std/array_list.zig +++ b/lib/std/array_list.zig @@ -158,7 +158,7 @@ pub fn ArrayListAligned(comptime T: type, comptime alignment: ?mem.Alignment) ty assert(self.items.len < self.capacity); self.items.len += 1; - mem.copyBackwards(T, self.items[i + 1 .. self.items.len], self.items[i .. self.items.len - 1]); + @memmove(self.items[i + 1 .. self.items.len], self.items[i .. self.items.len - 1]); self.items[i] = item; } @@ -216,7 +216,7 @@ pub fn ArrayListAligned(comptime T: type, comptime alignment: ?mem.Alignment) ty assert(self.capacity >= new_len); const to_move = self.items[index..]; self.items.len = new_len; - mem.copyBackwards(T, self.items[index + count ..], to_move); + @memmove(self.items[index + count ..][0..to_move.len], to_move); const result = self.items[index..][0..count]; @memset(result, undefined); return result; @@ -746,7 +746,7 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig assert(self.items.len < self.capacity); self.items.len += 1; - mem.copyBackwards(T, self.items[i + 1 .. self.items.len], self.items[i .. self.items.len - 1]); + @memmove(self.items[i + 1 .. self.items.len], self.items[i .. self.items.len - 1]); self.items[i] = item; } @@ -782,7 +782,7 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig assert(self.capacity >= new_len); const to_move = self.items[index..]; self.items.len = new_len; - mem.copyBackwards(T, self.items[index + count ..], to_move); + @memmove(self.items[index + count ..][0..to_move.len], to_move); const result = self.items[index..][0..count]; @memset(result, undefined); return result; @@ -848,11 +848,8 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig } else { const extra = range.len - new_items.len; @memcpy(range[0..new_items.len], new_items); - std.mem.copyForwards( - T, - self.items[after_range - extra ..], - self.items[after_range..], - ); + const src = self.items[after_range..]; + @memmove(self.items[after_range - extra ..][0..src.len], src); @memset(self.items[self.items.len - extra ..], undefined); self.items.len -= extra; } diff --git a/lib/std/math/big/int.zig b/lib/std/math/big/int.zig index f07f5e422d..685dbead52 100644 --- a/lib/std/math/big/int.zig +++ b/lib/std/math/big/int.zig @@ -1710,7 +1710,7 @@ pub const Mutable = struct { if (xy_trailing != 0 and r.limbs[r.len - 1] != 0) { // Manually shift here since we know its limb aligned. 
- mem.copyBackwards(Limb, r.limbs[xy_trailing..], r.limbs[0..r.len]); + @memmove(r.limbs[xy_trailing..][0..r.len], r.limbs[0..r.len]); @memset(r.limbs[0..xy_trailing], 0); r.len += xy_trailing; } @@ -3836,8 +3836,7 @@ fn llshl(r: []Limb, a: []const Limb, shift: usize) usize { std.debug.assert(@intFromPtr(r.ptr) >= @intFromPtr(a.ptr)); if (shift == 0) { - if (a.ptr != r.ptr) - std.mem.copyBackwards(Limb, r[0..a.len], a); + if (a.ptr != r.ptr) @memmove(r[0..a.len], a); return a.len; } if (shift >= limb_bits) { @@ -3891,8 +3890,7 @@ fn llshr(r: []Limb, a: []const Limb, shift: usize) usize { if (shift == 0) { std.debug.assert(r.len >= a.len); - if (a.ptr != r.ptr) - std.mem.copyForwards(Limb, r[0..a.len], a); + if (a.ptr != r.ptr) @memmove(r[0..a.len], a); return a.len; } if (shift >= limb_bits) { diff --git a/lib/std/os/windows.zig b/lib/std/os/windows.zig index 829a8c37cc..df325e1785 100644 --- a/lib/std/os/windows.zig +++ b/lib/std/os/windows.zig @@ -1332,7 +1332,7 @@ pub fn GetFinalPathNameByHandle( // dropping the \Device\Mup\ and making sure the path begins with \\ if (mem.eql(u16, device_name_u16, std.unicode.utf8ToUtf16LeStringLiteral("Mup"))) { out_buffer[0] = '\\'; - mem.copyForwards(u16, out_buffer[1..][0..file_name_u16.len], file_name_u16); + @memmove(out_buffer[1..][0..file_name_u16.len], file_name_u16); return out_buffer[0 .. 1 + file_name_u16.len]; } @@ -1400,7 +1400,7 @@ pub fn GetFinalPathNameByHandle( if (out_buffer.len < drive_letter.len + file_name_u16.len) return error.NameTooLong; @memcpy(out_buffer[0..drive_letter.len], drive_letter); - mem.copyForwards(u16, out_buffer[drive_letter.len..][0..file_name_u16.len], file_name_u16); + @memmove(out_buffer[drive_letter.len..][0..file_name_u16.len], file_name_u16); const total_len = drive_letter.len + file_name_u16.len; // Validate that DOS does not contain any spurious nul bytes. @@ -1449,12 +1449,7 @@ pub fn GetFinalPathNameByHandle( // to copy backwards. We also need to do this before copying the volume path because // it could overwrite the file_name_u16 memory. 
const file_name_dest = out_buffer[volume_path.len..][0..file_name_u16.len]; - const file_name_byte_offset = @intFromPtr(file_name_u16.ptr) - @intFromPtr(out_buffer.ptr); - const file_name_index = file_name_byte_offset / @sizeOf(u16); - if (volume_path.len > file_name_index) - mem.copyBackwards(u16, file_name_dest, file_name_u16) - else - mem.copyForwards(u16, file_name_dest, file_name_u16); + @memmove(file_name_dest, file_name_u16); @memcpy(out_buffer[0..volume_path.len], volume_path); const total_len = volume_path.len + file_name_u16.len; diff --git a/src/Sema.zig b/src/Sema.zig index 29c8daa8b3..5816990eb2 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -2631,7 +2631,7 @@ fn reparentOwnedErrorMsg( const orig_notes = msg.notes.len; msg.notes = try sema.gpa.realloc(msg.notes, orig_notes + 1); - std.mem.copyBackwards(Zcu.ErrorMsg, msg.notes[1..], msg.notes[0..orig_notes]); + @memmove(msg.notes[1..][0..orig_notes], msg.notes[0..orig_notes]); msg.notes[0] = .{ .src_loc = msg.src_loc, .msg = msg.msg, @@ -14464,8 +14464,8 @@ fn analyzeTupleMul( } } for (0..factor) |i| { - mem.copyForwards(InternPool.Index, types[tuple_len * i ..], types[0..tuple_len]); - mem.copyForwards(InternPool.Index, values[tuple_len * i ..], values[0..tuple_len]); + @memmove(types[tuple_len * i ..][0..tuple_len], types[0..tuple_len]); + @memmove(values[tuple_len * i ..][0..tuple_len], values[0..tuple_len]); } break :rs runtime_src; }; From c47ec4f3d7a6bf79be3adcffa33aa51bcc26ed0b Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 4 Aug 2025 23:37:54 -0700 Subject: [PATCH 40/70] std.array_list: add bounded methods --- lib/std/array_list.zig | 174 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 167 insertions(+), 7 deletions(-) diff --git a/lib/std/array_list.zig b/lib/std/array_list.zig index 02e4608399..c866a34e03 100644 --- a/lib/std/array_list.zig +++ b/lib/std/array_list.zig @@ -738,9 +738,13 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig } /// Insert `item` at index `i`. Moves `list[i .. list.len]` to higher indices to make room. - /// If in` is equal to the length of the list this operation is equivalent to append. + /// + /// If `i` is equal to the length of the list this operation is equivalent to append. + /// /// This operation is O(N). + /// /// Asserts that the list has capacity for one additional item. + /// /// Asserts that the index is in bounds or equal to the length. pub fn insertAssumeCapacity(self: *Self, i: usize, item: T) void { assert(self.items.len < self.capacity); @@ -750,6 +754,21 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig self.items[i] = item; } + /// Insert `item` at index `i`, moving `list[i .. list.len]` to higher indices to make room. + /// + /// If `i` is equal to the length of the list this operation is equivalent to append. + /// + /// This operation is O(N). + /// + /// If the list lacks unused capacity for the additional item, returns + /// `error.OutOfMemory`. + /// + /// Asserts that the index is in bounds or equal to the length. + pub fn insertBounded(self: *Self, i: usize, item: T) error{OutOfMemory}!void { + if (self.capacity - self.items.len == 0) return error.OutOfMemory; + return insertAssumeCapacity(self, i, item); + } + /// Add `count` new elements at position `index`, which have /// `undefined` values. 
Returns a slice pointing to the newly allocated /// elements, which becomes invalid after various `ArrayList` @@ -788,6 +807,23 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig return result; } + /// Add `count` new elements at position `index`, which have + /// `undefined` values, returning a slice pointing to the newly + /// allocated elements, which becomes invalid after various `ArrayList` + /// operations. + /// + /// Invalidates pre-existing pointers to elements at and after `index`, but + /// does not invalidate any before that. + /// + /// If the list lacks unused capacity for the additional items, returns + /// `error.OutOfMemory`. + /// + /// Asserts that the index is in bounds or equal to the length. + pub fn addManyAtBounded(self: *Self, index: usize, count: usize) error{OutOfMemory}![]T { + if (self.capacity - self.items.len < count) return error.OutOfMemory; + return addManyAtAssumeCapacity(self, index, count); + } + /// Insert slice `items` at index `i` by moving `list[i .. list.len]` to make room. /// This operation is O(N). /// Invalidates pre-existing pointers to elements at and after `index`. @@ -831,7 +867,9 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig } /// Grows or shrinks the list as necessary. + /// /// Never invalidates element pointers. + /// /// Asserts the capacity is enough for additional items. pub fn replaceRangeAssumeCapacity(self: *Self, start: usize, len: usize, new_items: []const T) void { const after_range = start + len; @@ -855,6 +893,17 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig } } + /// Grows or shrinks the list as necessary. + /// + /// Never invalidates element pointers. + /// + /// If the unused capacity is insufficient for additional items, + /// returns `error.OutOfMemory`. + pub fn replaceRangeBounded(self: *Self, start: usize, len: usize, new_items: []const T) error{OutOfMemory}!void { + if (self.capacity - self.items.len < new_items.len -| len) return error.OutOfMemory; + return replaceRangeAssumeCapacity(self, start, len, new_items); + } + /// Extend the list by 1 element. Allocates more memory as necessary. /// Invalidates element pointers if additional memory is needed. pub fn append(self: *Self, gpa: Allocator, item: T) Allocator.Error!void { @@ -863,12 +912,25 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig } /// Extend the list by 1 element. + /// /// Never invalidates element pointers. + /// /// Asserts that the list can hold one additional item. pub fn appendAssumeCapacity(self: *Self, item: T) void { self.addOneAssumeCapacity().* = item; } + /// Extend the list by 1 element. + /// + /// Never invalidates element pointers. + /// + /// If the list lacks unused capacity for the additional item, returns + /// `error.OutOfMemory`. + pub fn appendBounded(self: *Self, item: T) error{OutOfMemory}!void { + if (self.capacity - self.items.len == 0) return error.OutOfMemory; + return appendAssumeCapacity(self, item); + } + /// Remove the element at index `i` from the list and return its value. /// Invalidates pointers to the last element. /// This operation is O(N). @@ -903,6 +965,7 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig } /// Append the slice of items to the list. + /// /// Asserts that the list can hold the additional items. 
pub fn appendSliceAssumeCapacity(self: *Self, items: []const T) void { const old_len = self.items.len; @@ -912,6 +975,14 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig @memcpy(self.items[old_len..][0..items.len], items); } + /// Append the slice of items to the list. + /// + /// If the list lacks unused capacity for the additional items, returns `error.OutOfMemory`. + pub fn appendSliceBounded(self: *Self, items: []const T) error{OutOfMemory}!void { + if (self.capacity - self.items.len < items.len) return error.OutOfMemory; + return appendSliceAssumeCapacity(self, items); + } + /// Append the slice of items to the list. Allocates more /// memory as necessary. Only call this function if a call to `appendSlice` instead would /// be a compile error. @@ -922,8 +993,10 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig } /// Append an unaligned slice of items to the list. - /// Only call this function if a call to `appendSliceAssumeCapacity` - /// instead would be a compile error. + /// + /// Intended to be used only when `appendSliceAssumeCapacity` would be + /// a compile error. + /// /// Asserts that the list can hold the additional items. pub fn appendUnalignedSliceAssumeCapacity(self: *Self, items: []align(1) const T) void { const old_len = self.items.len; @@ -933,6 +1006,18 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig @memcpy(self.items[old_len..][0..items.len], items); } + /// Append an unaligned slice of items to the list. + /// + /// Intended to be used only when `appendSliceAssumeCapacity` would be + /// a compile error. + /// + /// If the list lacks unused capacity for the additional items, returns + /// `error.OutOfMemory`. + pub fn appendUnalignedSliceBounded(self: *Self, items: []align(1) const T) error{OutOfMemory}!void { + if (self.capacity - self.items.len < items.len) return error.OutOfMemory; + return appendUnalignedSliceAssumeCapacity(self, items); + } + pub fn print(self: *Self, gpa: Allocator, comptime fmt: []const u8, args: anytype) error{OutOfMemory}!void { comptime assert(T == u8); try self.ensureUnusedCapacity(gpa, fmt.len); @@ -950,6 +1035,13 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig self.items.len += w.end; } + pub fn printBounded(self: *Self, comptime fmt: []const u8, args: anytype) error{OutOfMemory}!void { + comptime assert(T == u8); + var w: std.io.Writer = .fixed(self.unusedCapacitySlice()); + w.print(fmt, args) catch return error.OutOfMemory; + self.items.len += w.end; + } + /// Deprecated in favor of `print` or `std.io.Writer.Allocating`. pub const WriterContext = struct { self: *Self, @@ -1004,9 +1096,12 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig } /// Append a value to the list `n` times. + /// /// Never invalidates element pointers. + /// /// The function is inline so that a comptime-known `value` parameter will /// have better memset codegen in case it has a repeated byte pattern. + /// /// Asserts that the list can hold the additional items. pub inline fn appendNTimesAssumeCapacity(self: *Self, value: T, n: usize) void { const new_len = self.items.len + n; @@ -1015,6 +1110,22 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig self.items.len = new_len; } + /// Append a value to the list `n` times. + /// + /// Never invalidates element pointers. 
+ /// + /// The function is inline so that a comptime-known `value` parameter will + /// have better memset codegen in case it has a repeated byte pattern. + /// + /// If the list lacks unused capacity for the additional items, returns + /// `error.OutOfMemory`. + pub inline fn appendNTimesBounded(self: *Self, value: T, n: usize) error{OutOfMemory}!void { + const new_len = self.items.len + n; + if (self.capacity < new_len) return error.OutOfMemory; + @memset(self.items.ptr[self.items.len..new_len], value); + self.items.len = new_len; + } + /// Adjust the list length to `new_len`. /// Additional elements contain the value `undefined`. /// Invalidates element pointers if additional memory is needed. @@ -1140,8 +1251,11 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig } /// Increase length by 1, returning pointer to the new item. + /// /// Never invalidates element pointers. + /// /// The returned element pointer becomes invalid when the list is resized. + /// /// Asserts that the list can hold one additional item. pub fn addOneAssumeCapacity(self: *Self) *T { assert(self.items.len < self.capacity); @@ -1150,6 +1264,18 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig return &self.items[self.items.len - 1]; } + /// Increase length by 1, returning pointer to the new item. + /// + /// Never invalidates element pointers. + /// + /// The returned element pointer becomes invalid when the list is resized. + /// + /// If the list lacks unused capacity for the additional item, returns `error.OutOfMemory`. + pub fn addOneBounded(self: *Self) error{OutOfMemory}!*T { + if (self.capacity - self.items.len < 1) return error.OutOfMemory; + return addOneAssumeCapacity(self); + } + /// Resize the array, adding `n` new elements, which have `undefined` values. /// The return value is an array pointing to the newly allocated elements. /// The returned pointer becomes invalid when the list is resized. @@ -1160,9 +1286,13 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig } /// Resize the array, adding `n` new elements, which have `undefined` values. + /// /// The return value is an array pointing to the newly allocated elements. + /// /// Never invalidates element pointers. + /// /// The returned pointer becomes invalid when the list is resized. + /// /// Asserts that the list can hold the additional items. pub fn addManyAsArrayAssumeCapacity(self: *Self, comptime n: usize) *[n]T { assert(self.items.len + n <= self.capacity); @@ -1171,6 +1301,21 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig return self.items[prev_len..][0..n]; } + /// Resize the array, adding `n` new elements, which have `undefined` values. + /// + /// The return value is an array pointing to the newly allocated elements. + /// + /// Never invalidates element pointers. + /// + /// The returned pointer becomes invalid when the list is resized. + /// + /// If the list lacks unused capacity for the additional items, returns + /// `error.OutOfMemory`. + pub fn addManyAsArrayBounded(self: *Self, comptime n: usize) error{OutOfMemory}!*[n]T { + if (self.capacity - self.items.len < n) return error.OutOfMemory; + return addManyAsArrayAssumeCapacity(self, n); + } + /// Resize the array, adding `n` new elements, which have `undefined` values. /// The return value is a slice pointing to the newly allocated elements. /// The returned pointer becomes invalid when the list is resized. 
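Taken together, the `...Bounded` variants above give a fixed-buffer list a fallible growth path: the same semantics as the `AssumeCapacity` functions, but returning `error.OutOfMemory` instead of asserting. A minimal usage sketch, assuming only `initBuffer` and the methods added in this patch:

const std = @import("std");

test "bounded growth on a fixed buffer" {
    var buffer: [4]u8 = undefined;
    // Capacity comes from the caller-owned buffer; no allocator is involved.
    var list: std.ArrayListUnmanaged(u8) = .initBuffer(&buffer);

    try list.appendBounded('a');
    try list.appendSliceBounded("bc");
    (try list.addOneBounded()).* = 'd';

    // The buffer is now full: the Bounded variants report error.OutOfMemory
    // where the AssumeCapacity variants would assert.
    try std.testing.expectError(error.OutOfMemory, list.appendBounded('e'));
    try std.testing.expectEqualStrings("abcd", list.items);
}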
@@ -1181,10 +1326,12 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig return self.items[prev_len..][0..n]; } - /// Resize the array, adding `n` new elements, which have `undefined` values. - /// The return value is a slice pointing to the newly allocated elements. - /// Never invalidates element pointers. - /// The returned pointer becomes invalid when the list is resized. + /// Resizes the array, adding `n` new elements, which have `undefined` + /// values, returning a slice pointing to the newly allocated elements. + /// + /// Never invalidates element pointers. The returned pointer becomes + /// invalid when the list is resized. + /// /// Asserts that the list can hold the additional items. pub fn addManyAsSliceAssumeCapacity(self: *Self, n: usize) []T { assert(self.items.len + n <= self.capacity); @@ -1193,6 +1340,19 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig return self.items[prev_len..][0..n]; } + /// Resizes the array, adding `n` new elements, which have `undefined` + /// values, returning a slice pointing to the newly allocated elements. + /// + /// Never invalidates element pointers. The returned pointer becomes + /// invalid when the list is resized. + /// + /// If the list lacks unused capacity for the additional items, returns + /// `error.OutOfMemory`. + pub fn addManyAsSliceBounded(self: *Self, n: usize) error{OutOfMemory}![]T { + if (self.capacity - self.items.len < n) return error.OutOfMemory; + return addManyAsSliceAssumeCapacity(self, n); + } + /// Remove and return the last element from the list. /// If the list is empty, returns `null`. /// Invalidates pointers to last element. From 196e36bbb27b0f0ebd7cd7a866b85f477b3662fb Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 4 Aug 2025 23:51:45 -0700 Subject: [PATCH 41/70] std: remove BoundedArray This use case is handled by ArrayListUnmanaged via the "...Bounded" method variants, and it's more optimal to share machine code, versus generating multiple versions of each function for differing array lengths. --- doc/langref/test_switch_dispatch_loop.zig | 8 +- lib/docs/wasm/markdown/Parser.zig | 66 ++-- lib/std/Io.zig | 15 - lib/std/Io/DeprecatedReader.zig | 27 -- lib/std/Io/Reader/test.zig | 21 -- lib/std/array_list.zig | 1 + lib/std/base64.zig | 22 +- lib/std/bounded_array.zig | 412 ---------------------- lib/std/std.zig | 2 - src/Compilation.zig | 18 +- src/Zcu.zig | 11 +- src/Zcu/PerThread.zig | 5 +- 12 files changed, 82 insertions(+), 526 deletions(-) delete mode 100644 lib/std/bounded_array.zig diff --git a/doc/langref/test_switch_dispatch_loop.zig b/doc/langref/test_switch_dispatch_loop.zig index e3b891595e..1756549094 100644 --- a/doc/langref/test_switch_dispatch_loop.zig +++ b/doc/langref/test_switch_dispatch_loop.zig @@ -8,20 +8,22 @@ const Instruction = enum { }; fn evaluate(initial_stack: []const i32, code: []const Instruction) !i32 { - var stack = try std.BoundedArray(i32, 8).fromSlice(initial_stack); + var buffer: [8]i32 = undefined; + var stack = std.ArrayListUnmanaged(i32).initBuffer(&buffer); + try stack.appendSliceBounded(initial_stack); var ip: usize = 0; return vm: switch (code[ip]) { // Because all code after `continue` is unreachable, this branch does // not provide a result. .add => { - try stack.append(stack.pop().? + stack.pop().?); + try stack.appendBounded(stack.pop().? + stack.pop().?); ip += 1; continue :vm code[ip]; }, .mul => { - try stack.append(stack.pop().? 
* stack.pop().?); + try stack.appendBounded(stack.pop().? * stack.pop().?); ip += 1; continue :vm code[ip]; diff --git a/lib/docs/wasm/markdown/Parser.zig b/lib/docs/wasm/markdown/Parser.zig index ce8db08d96..2c9fff2ae8 100644 --- a/lib/docs/wasm/markdown/Parser.zig +++ b/lib/docs/wasm/markdown/Parser.zig @@ -29,13 +29,14 @@ const Node = Document.Node; const ExtraIndex = Document.ExtraIndex; const ExtraData = Document.ExtraData; const StringIndex = Document.StringIndex; +const ArrayList = std.ArrayListUnmanaged; nodes: Node.List = .{}, -extra: std.ArrayListUnmanaged(u32) = .empty, -scratch_extra: std.ArrayListUnmanaged(u32) = .empty, -string_bytes: std.ArrayListUnmanaged(u8) = .empty, -scratch_string: std.ArrayListUnmanaged(u8) = .empty, -pending_blocks: std.ArrayListUnmanaged(Block) = .empty, +extra: ArrayList(u32) = .empty, +scratch_extra: ArrayList(u32) = .empty, +string_bytes: ArrayList(u8) = .empty, +scratch_string: ArrayList(u8) = .empty, +pending_blocks: ArrayList(Block) = .empty, allocator: Allocator, const Parser = @This(); @@ -86,7 +87,8 @@ const Block = struct { continuation_indent: usize, }, table: struct { - column_alignments: std.BoundedArray(Node.TableCellAlignment, max_table_columns) = .{}, + column_alignments_buffer: [max_table_columns]Node.TableCellAlignment, + column_alignments_len: usize, }, heading: struct { /// Between 1 and 6, inclusive. @@ -354,7 +356,8 @@ const BlockStart = struct { continuation_indent: usize, }, table_row: struct { - cells: std.BoundedArray([]const u8, max_table_columns), + cells_buffer: [max_table_columns][]const u8, + cells_len: usize, }, heading: struct { /// Between 1 and 6, inclusive. @@ -422,7 +425,8 @@ fn appendBlockStart(p: *Parser, block_start: BlockStart) !void { try p.pending_blocks.append(p.allocator, .{ .tag = .table, .data = .{ .table = .{ - .column_alignments = .{}, + .column_alignments_buffer = undefined, + .column_alignments_len = 0, } }, .string_start = p.scratch_string.items.len, .extra_start = p.scratch_extra.items.len, @@ -431,15 +435,19 @@ fn appendBlockStart(p: *Parser, block_start: BlockStart) !void { const current_row = p.scratch_extra.items.len - p.pending_blocks.getLast().extra_start; if (current_row <= 1) { - if (parseTableHeaderDelimiter(block_start.data.table_row.cells)) |alignments| { - p.pending_blocks.items[p.pending_blocks.items.len - 1].data.table.column_alignments = alignments; + var buffer: [max_table_columns]Node.TableCellAlignment = undefined; + const table_row = &block_start.data.table_row; + if (parseTableHeaderDelimiter(table_row.cells_buffer[0..table_row.cells_len], &buffer)) |alignments| { + const table = &p.pending_blocks.items[p.pending_blocks.items.len - 1].data.table; + @memcpy(table.column_alignments_buffer[0..alignments.len], alignments); + table.column_alignments_len = alignments.len; if (current_row == 1) { // We need to go back and mark the header row and its column // alignments. const datas = p.nodes.items(.data); const header_data = datas[p.scratch_extra.getLast()]; for (p.extraChildren(header_data.container.children), 0..) |header_cell, i| { - const alignment = if (i < alignments.len) alignments.buffer[i] else .unset; + const alignment = if (i < alignments.len) alignments[i] else .unset; const cell_data = &datas[@intFromEnum(header_cell)].table_cell; cell_data.info.alignment = alignment; cell_data.info.header = true; @@ -480,8 +488,10 @@ fn appendBlockStart(p: *Parser, block_start: BlockStart) !void { // available in the BlockStart. We can immediately parse and append // these children now. 
const containing_table = p.pending_blocks.items[p.pending_blocks.items.len - 2]; - const column_alignments = containing_table.data.table.column_alignments.slice(); - for (block_start.data.table_row.cells.slice(), 0..) |cell_content, i| { + const table = &containing_table.data.table; + const column_alignments = table.column_alignments_buffer[0..table.column_alignments_len]; + const table_row = &block_start.data.table_row; + for (table_row.cells_buffer[0..table_row.cells_len], 0..) |cell_content, i| { const cell_children = try p.parseInlines(cell_content); const alignment = if (i < column_alignments.len) column_alignments[i] else .unset; const cell = try p.addNode(.{ @@ -523,7 +533,8 @@ fn startBlock(p: *Parser, line: []const u8) !?BlockStart { return .{ .tag = .table_row, .data = .{ .table_row = .{ - .cells = table_row.cells, + .cells_buffer = table_row.cells_buffer, + .cells_len = table_row.cells_len, } }, .rest = "", }; @@ -606,7 +617,8 @@ fn startListItem(unindented_line: []const u8) ?ListItemStart { } const TableRowStart = struct { - cells: std.BoundedArray([]const u8, max_table_columns), + cells_buffer: [max_table_columns][]const u8, + cells_len: usize, }; fn startTableRow(unindented_line: []const u8) ?TableRowStart { @@ -615,7 +627,8 @@ fn startTableRow(unindented_line: []const u8) ?TableRowStart { mem.endsWith(u8, unindented_line, "\\|") or !mem.endsWith(u8, unindented_line, "|")) return null; - var cells: std.BoundedArray([]const u8, max_table_columns) = .{}; + var cells_buffer: [max_table_columns][]const u8 = undefined; + var cells: ArrayList([]const u8) = .initBuffer(&cells_buffer); const table_row_content = unindented_line[1 .. unindented_line.len - 1]; var cell_start: usize = 0; var i: usize = 0; @@ -623,7 +636,7 @@ fn startTableRow(unindented_line: []const u8) ?TableRowStart { switch (table_row_content[i]) { '\\' => i += 1, '|' => { - cells.append(table_row_content[cell_start..i]) catch return null; + cells.appendBounded(table_row_content[cell_start..i]) catch return null; cell_start = i + 1; }, '`' => { @@ -641,20 +654,21 @@ fn startTableRow(unindented_line: []const u8) ?TableRowStart { else => {}, } } - cells.append(table_row_content[cell_start..]) catch return null; + cells.appendBounded(table_row_content[cell_start..]) catch return null; - return .{ .cells = cells }; + return .{ .cells_buffer = cells_buffer, .cells_len = cells.items.len }; } fn parseTableHeaderDelimiter( - row_cells: std.BoundedArray([]const u8, max_table_columns), -) ?std.BoundedArray(Node.TableCellAlignment, max_table_columns) { - var alignments: std.BoundedArray(Node.TableCellAlignment, max_table_columns) = .{}; - for (row_cells.slice()) |content| { + row_cells: []const []const u8, + buffer: []Node.TableCellAlignment, +) ?[]Node.TableCellAlignment { + var alignments: ArrayList(Node.TableCellAlignment) = .initBuffer(buffer); + for (row_cells) |content| { const alignment = parseTableHeaderDelimiterCell(content) orelse return null; alignments.appendAssumeCapacity(alignment); } - return alignments; + return alignments.items; } fn parseTableHeaderDelimiterCell(content: []const u8) ?Node.TableCellAlignment { @@ -928,8 +942,8 @@ const InlineParser = struct { parent: *Parser, content: []const u8, pos: usize = 0, - pending_inlines: std.ArrayListUnmanaged(PendingInline) = .empty, - completed_inlines: std.ArrayListUnmanaged(CompletedInline) = .empty, + pending_inlines: ArrayList(PendingInline) = .empty, + completed_inlines: ArrayList(CompletedInline) = .empty, const PendingInline = struct { tag: Tag, diff --git 
a/lib/std/Io.zig b/lib/std/Io.zig index c55c28f177..688120c08b 100644 --- a/lib/std/Io.zig +++ b/lib/std/Io.zig @@ -231,21 +231,6 @@ pub fn GenericReader( return @errorCast(self.any().readBytesNoEof(num_bytes)); } - pub inline fn readIntoBoundedBytes( - self: Self, - comptime num_bytes: usize, - bounded: *std.BoundedArray(u8, num_bytes), - ) Error!void { - return @errorCast(self.any().readIntoBoundedBytes(num_bytes, bounded)); - } - - pub inline fn readBoundedBytes( - self: Self, - comptime num_bytes: usize, - ) Error!std.BoundedArray(u8, num_bytes) { - return @errorCast(self.any().readBoundedBytes(num_bytes)); - } - pub inline fn readInt(self: Self, comptime T: type, endian: std.builtin.Endian) NoEofError!T { return @errorCast(self.any().readInt(T, endian)); } diff --git a/lib/std/Io/DeprecatedReader.zig b/lib/std/Io/DeprecatedReader.zig index 59f163b39c..2e99328e42 100644 --- a/lib/std/Io/DeprecatedReader.zig +++ b/lib/std/Io/DeprecatedReader.zig @@ -249,33 +249,6 @@ pub fn readBytesNoEof(self: Self, comptime num_bytes: usize) anyerror![num_bytes return bytes; } -/// Reads bytes until `bounded.len` is equal to `num_bytes`, -/// or the stream ends. -/// -/// * it is assumed that `num_bytes` will not exceed `bounded.capacity()` -pub fn readIntoBoundedBytes( - self: Self, - comptime num_bytes: usize, - bounded: *std.BoundedArray(u8, num_bytes), -) anyerror!void { - while (bounded.len < num_bytes) { - // get at most the number of bytes free in the bounded array - const bytes_read = try self.read(bounded.unusedCapacitySlice()); - if (bytes_read == 0) return; - - // bytes_read will never be larger than @TypeOf(bounded.len) - // due to `self.read` being bounded by `bounded.unusedCapacitySlice()` - bounded.len += @as(@TypeOf(bounded.len), @intCast(bytes_read)); - } -} - -/// Reads at most `num_bytes` and returns as a bounded array. 
-pub fn readBoundedBytes(self: Self, comptime num_bytes: usize) anyerror!std.BoundedArray(u8, num_bytes) { - var result = std.BoundedArray(u8, num_bytes){}; - try self.readIntoBoundedBytes(num_bytes, &result); - return result; -} - pub inline fn readInt(self: Self, comptime T: type, endian: std.builtin.Endian) anyerror!T { const bytes = try self.readBytesNoEof(@divExact(@typeInfo(T).int.bits, 8)); return mem.readInt(T, &bytes, endian); diff --git a/lib/std/Io/Reader/test.zig b/lib/std/Io/Reader/test.zig index 30f0e1269c..69b7bcdbda 100644 --- a/lib/std/Io/Reader/test.zig +++ b/lib/std/Io/Reader/test.zig @@ -349,24 +349,3 @@ test "streamUntilDelimiter writes all bytes without delimiter to the output" { try std.testing.expectError(error.StreamTooLong, reader.streamUntilDelimiter(writer, '!', 5)); } - -test "readBoundedBytes correctly reads into a new bounded array" { - const test_string = "abcdefg"; - var fis = std.io.fixedBufferStream(test_string); - const reader = fis.reader(); - - var array = try reader.readBoundedBytes(10000); - try testing.expectEqualStrings(array.slice(), test_string); -} - -test "readIntoBoundedBytes correctly reads into a provided bounded array" { - const test_string = "abcdefg"; - var fis = std.io.fixedBufferStream(test_string); - const reader = fis.reader(); - - var bounded_array = std.BoundedArray(u8, 10000){}; - - // compile time error if the size is not the same at the provided `bounded.capacity()` - try reader.readIntoBoundedBytes(10000, &bounded_array); - try testing.expectEqualStrings(bounded_array.slice(), test_string); -} diff --git a/lib/std/array_list.zig b/lib/std/array_list.zig index c866a34e03..8af36a4a8e 100644 --- a/lib/std/array_list.zig +++ b/lib/std/array_list.zig @@ -657,6 +657,7 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?mem.Alig /// Initialize with externally-managed memory. The buffer determines the /// capacity, and the length is set to zero. + /// /// When initialized this way, all functions that accept an Allocator /// argument cause illegal behavior. 
pub fn initBuffer(buffer: Slice) Self { diff --git a/lib/std/base64.zig b/lib/std/base64.zig index a84f4a0b4f..37d3b70c20 100644 --- a/lib/std/base64.zig +++ b/lib/std/base64.zig @@ -517,17 +517,21 @@ fn testAllApis(codecs: Codecs, expected_decoded: []const u8, expected_encoded: [ var buffer: [0x100]u8 = undefined; const encoded = codecs.Encoder.encode(&buffer, expected_decoded); try testing.expectEqualSlices(u8, expected_encoded, encoded); - + } + { // stream encode - var list = try std.BoundedArray(u8, 0x100).init(0); - try codecs.Encoder.encodeWriter(list.writer(), expected_decoded); - try testing.expectEqualSlices(u8, expected_encoded, list.slice()); - + var buffer: [0x100]u8 = undefined; + var writer: std.Io.Writer = .fixed(&buffer); + try codecs.Encoder.encodeWriter(&writer, expected_decoded); + try testing.expectEqualSlices(u8, expected_encoded, writer.buffered()); + } + { // reader to writer encode - var stream = std.io.fixedBufferStream(expected_decoded); - list = try std.BoundedArray(u8, 0x100).init(0); - try codecs.Encoder.encodeFromReaderToWriter(list.writer(), stream.reader()); - try testing.expectEqualSlices(u8, expected_encoded, list.slice()); + var stream: std.Io.Reader = .fixed(expected_decoded); + var buffer: [0x100]u8 = undefined; + var writer: std.Io.Writer = .fixed(&buffer); + try codecs.Encoder.encodeFromReaderToWriter(&writer, &stream); + try testing.expectEqualSlices(u8, expected_encoded, writer.buffered()); } // Base64Decoder diff --git a/lib/std/bounded_array.zig b/lib/std/bounded_array.zig deleted file mode 100644 index 7864dfb775..0000000000 --- a/lib/std/bounded_array.zig +++ /dev/null @@ -1,412 +0,0 @@ -const std = @import("std.zig"); -const assert = std.debug.assert; -const mem = std.mem; -const testing = std.testing; -const Alignment = std.mem.Alignment; - -/// A structure with an array and a length, that can be used as a slice. -/// -/// Useful to pass around small arrays whose exact size is only known at -/// runtime, but whose maximum size is known at comptime, without requiring -/// an `Allocator`. -/// -/// ```zig -/// var actual_size = 32; -/// var a = try BoundedArray(u8, 64).init(actual_size); -/// var slice = a.slice(); // a slice of the 64-byte array -/// var a_clone = a; // creates a copy - the structure doesn't use any internal pointers -/// ``` -pub fn BoundedArray(comptime T: type, comptime buffer_capacity: usize) type { - return BoundedArrayAligned(T, .of(T), buffer_capacity); -} - -/// A structure with an array, length and alignment, that can be used as a -/// slice. -/// -/// Useful to pass around small explicitly-aligned arrays whose exact size is -/// only known at runtime, but whose maximum size is known at comptime, without -/// requiring an `Allocator`. -/// ```zig -// var a = try BoundedArrayAligned(u8, 16, 2).init(0); -// try a.append(255); -// try a.append(255); -// const b = @ptrCast(*const [1]u16, a.constSlice().ptr); -// try testing.expectEqual(@as(u16, 65535), b[0]); -/// ``` -pub fn BoundedArrayAligned( - comptime T: type, - comptime alignment: Alignment, - comptime buffer_capacity: usize, -) type { - return struct { - const Self = @This(); - buffer: [buffer_capacity]T align(alignment.toByteUnits()) = undefined, - len: usize = 0, - - /// Set the actual length of the slice. - /// Returns error.Overflow if it exceeds the length of the backing array. 
- pub fn init(len: usize) error{Overflow}!Self { - if (len > buffer_capacity) return error.Overflow; - return Self{ .len = len }; - } - - /// View the internal array as a slice whose size was previously set. - pub fn slice(self: anytype) switch (@TypeOf(&self.buffer)) { - *align(alignment.toByteUnits()) [buffer_capacity]T => []align(alignment.toByteUnits()) T, - *align(alignment.toByteUnits()) const [buffer_capacity]T => []align(alignment.toByteUnits()) const T, - else => unreachable, - } { - return self.buffer[0..self.len]; - } - - /// View the internal array as a constant slice whose size was previously set. - pub fn constSlice(self: *const Self) []align(alignment.toByteUnits()) const T { - return self.slice(); - } - - /// Adjust the slice's length to `len`. - /// Does not initialize added items if any. - pub fn resize(self: *Self, len: usize) error{Overflow}!void { - if (len > buffer_capacity) return error.Overflow; - self.len = len; - } - - /// Remove all elements from the slice. - pub fn clear(self: *Self) void { - self.len = 0; - } - - /// Copy the content of an existing slice. - pub fn fromSlice(m: []const T) error{Overflow}!Self { - var list = try init(m.len); - @memcpy(list.slice(), m); - return list; - } - - /// Return the element at index `i` of the slice. - pub fn get(self: Self, i: usize) T { - return self.constSlice()[i]; - } - - /// Set the value of the element at index `i` of the slice. - pub fn set(self: *Self, i: usize, item: T) void { - self.slice()[i] = item; - } - - /// Return the maximum length of a slice. - pub fn capacity(self: Self) usize { - return self.buffer.len; - } - - /// Check that the slice can hold at least `additional_count` items. - pub fn ensureUnusedCapacity(self: Self, additional_count: usize) error{Overflow}!void { - if (self.len + additional_count > buffer_capacity) { - return error.Overflow; - } - } - - /// Increase length by 1, returning a pointer to the new item. - pub fn addOne(self: *Self) error{Overflow}!*T { - try self.ensureUnusedCapacity(1); - return self.addOneAssumeCapacity(); - } - - /// Increase length by 1, returning pointer to the new item. - /// Asserts that there is space for the new item. - pub fn addOneAssumeCapacity(self: *Self) *T { - assert(self.len < buffer_capacity); - self.len += 1; - return &self.slice()[self.len - 1]; - } - - /// Resize the slice, adding `n` new elements, which have `undefined` values. - /// The return value is a pointer to the array of uninitialized elements. - pub fn addManyAsArray(self: *Self, comptime n: usize) error{Overflow}!*align(alignment.toByteUnits()) [n]T { - const prev_len = self.len; - try self.resize(self.len + n); - return self.slice()[prev_len..][0..n]; - } - - /// Resize the slice, adding `n` new elements, which have `undefined` values. - /// The return value is a slice pointing to the uninitialized elements. - pub fn addManyAsSlice(self: *Self, n: usize) error{Overflow}![]align(alignment.toByteUnits()) T { - const prev_len = self.len; - try self.resize(self.len + n); - return self.slice()[prev_len..][0..n]; - } - - /// Remove and return the last element from the slice, or return `null` if the slice is empty. - pub fn pop(self: *Self) ?T { - if (self.len == 0) return null; - const item = self.get(self.len - 1); - self.len -= 1; - return item; - } - - /// Return a slice of only the extra capacity after items. - /// This can be useful for writing directly into it. 
- /// Note that such an operation must be followed up with a - /// call to `resize()` - pub fn unusedCapacitySlice(self: *Self) []align(alignment.toByteUnits()) T { - return self.buffer[self.len..]; - } - - /// Insert `item` at index `i` by moving `slice[n .. slice.len]` to make room. - /// This operation is O(N). - pub fn insert( - self: *Self, - i: usize, - item: T, - ) error{Overflow}!void { - if (i > self.len) { - return error.Overflow; - } - _ = try self.addOne(); - var s = self.slice(); - mem.copyBackwards(T, s[i + 1 .. s.len], s[i .. s.len - 1]); - self.buffer[i] = item; - } - - /// Insert slice `items` at index `i` by moving `slice[i .. slice.len]` to make room. - /// This operation is O(N). - pub fn insertSlice(self: *Self, i: usize, items: []const T) error{Overflow}!void { - try self.ensureUnusedCapacity(items.len); - self.len += items.len; - mem.copyBackwards(T, self.slice()[i + items.len .. self.len], self.constSlice()[i .. self.len - items.len]); - @memcpy(self.slice()[i..][0..items.len], items); - } - - /// Replace range of elements `slice[start..][0..len]` with `new_items`. - /// Grows slice if `len < new_items.len`. - /// Shrinks slice if `len > new_items.len`. - pub fn replaceRange( - self: *Self, - start: usize, - len: usize, - new_items: []const T, - ) error{Overflow}!void { - const after_range = start + len; - var range = self.slice()[start..after_range]; - - if (range.len == new_items.len) { - @memcpy(range[0..new_items.len], new_items); - } else if (range.len < new_items.len) { - const first = new_items[0..range.len]; - const rest = new_items[range.len..]; - @memcpy(range[0..first.len], first); - try self.insertSlice(after_range, rest); - } else { - @memcpy(range[0..new_items.len], new_items); - const after_subrange = start + new_items.len; - for (self.constSlice()[after_range..], 0..) |item, i| { - self.slice()[after_subrange..][i] = item; - } - self.len -= len - new_items.len; - } - } - - /// Extend the slice by 1 element. - pub fn append(self: *Self, item: T) error{Overflow}!void { - const new_item_ptr = try self.addOne(); - new_item_ptr.* = item; - } - - /// Extend the slice by 1 element, asserting the capacity is already - /// enough to store the new item. - pub fn appendAssumeCapacity(self: *Self, item: T) void { - const new_item_ptr = self.addOneAssumeCapacity(); - new_item_ptr.* = item; - } - - /// Remove the element at index `i`, shift elements after index - /// `i` forward, and return the removed element. - /// Asserts the slice has at least one item. - /// This operation is O(N). - pub fn orderedRemove(self: *Self, i: usize) T { - const newlen = self.len - 1; - if (newlen == i) return self.pop().?; - const old_item = self.get(i); - for (self.slice()[i..newlen], 0..) |*b, j| b.* = self.get(i + 1 + j); - self.set(newlen, undefined); - self.len = newlen; - return old_item; - } - - /// Remove the element at the specified index and return it. - /// The empty slot is filled from the end of the slice. - /// This operation is O(1). - pub fn swapRemove(self: *Self, i: usize) T { - if (self.len - 1 == i) return self.pop().?; - const old_item = self.get(i); - self.set(i, self.pop().?); - return old_item; - } - - /// Append the slice of items to the slice. - pub fn appendSlice(self: *Self, items: []const T) error{Overflow}!void { - try self.ensureUnusedCapacity(items.len); - self.appendSliceAssumeCapacity(items); - } - - /// Append the slice of items to the slice, asserting the capacity is already - /// enough to store the new items. 
- pub fn appendSliceAssumeCapacity(self: *Self, items: []const T) void { - const old_len = self.len; - self.len += items.len; - @memcpy(self.slice()[old_len..][0..items.len], items); - } - - /// Append a value to the slice `n` times. - /// Allocates more memory as necessary. - pub fn appendNTimes(self: *Self, value: T, n: usize) error{Overflow}!void { - const old_len = self.len; - try self.resize(old_len + n); - @memset(self.slice()[old_len..self.len], value); - } - - /// Append a value to the slice `n` times. - /// Asserts the capacity is enough. - pub fn appendNTimesAssumeCapacity(self: *Self, value: T, n: usize) void { - const old_len = self.len; - self.len += n; - assert(self.len <= buffer_capacity); - @memset(self.slice()[old_len..self.len], value); - } - - pub const Writer = if (T != u8) - @compileError("The Writer interface is only defined for BoundedArray(u8, ...) " ++ - "but the given type is BoundedArray(" ++ @typeName(T) ++ ", ...)") - else - std.io.GenericWriter(*Self, error{Overflow}, appendWrite); - - /// Initializes a writer which will write into the array. - pub fn writer(self: *Self) Writer { - return .{ .context = self }; - } - - /// Same as `appendSlice` except it returns the number of bytes written, which is always the same - /// as `m.len`. The purpose of this function existing is to match `std.io.GenericWriter` API. - fn appendWrite(self: *Self, m: []const u8) error{Overflow}!usize { - try self.appendSlice(m); - return m.len; - } - }; -} - -test BoundedArray { - var a = try BoundedArray(u8, 64).init(32); - - try testing.expectEqual(a.capacity(), 64); - try testing.expectEqual(a.slice().len, 32); - try testing.expectEqual(a.constSlice().len, 32); - - try a.resize(48); - try testing.expectEqual(a.len, 48); - - const x = [_]u8{1} ** 10; - a = try BoundedArray(u8, 64).fromSlice(&x); - try testing.expectEqualSlices(u8, &x, a.constSlice()); - - var a2 = a; - try testing.expectEqualSlices(u8, a.constSlice(), a2.constSlice()); - a2.set(0, 0); - try testing.expect(a.get(0) != a2.get(0)); - - try testing.expectError(error.Overflow, a.resize(100)); - try testing.expectError(error.Overflow, BoundedArray(u8, x.len - 1).fromSlice(&x)); - - try a.resize(0); - try a.ensureUnusedCapacity(a.capacity()); - (try a.addOne()).* = 0; - try a.ensureUnusedCapacity(a.capacity() - 1); - try testing.expectEqual(a.len, 1); - - const uninitialized = try a.addManyAsArray(4); - try testing.expectEqual(uninitialized.len, 4); - try testing.expectEqual(a.len, 5); - - try a.append(0xff); - try testing.expectEqual(a.len, 6); - try testing.expectEqual(a.pop(), 0xff); - - a.appendAssumeCapacity(0xff); - try testing.expectEqual(a.len, 6); - try testing.expectEqual(a.pop(), 0xff); - - try a.resize(1); - try testing.expectEqual(a.pop(), 0); - try testing.expectEqual(a.pop(), null); - var unused = a.unusedCapacitySlice(); - @memset(unused[0..8], 2); - unused[8] = 3; - unused[9] = 4; - try testing.expectEqual(unused.len, a.capacity()); - try a.resize(10); - - try a.insert(5, 0xaa); - try testing.expectEqual(a.len, 11); - try testing.expectEqual(a.get(5), 0xaa); - try testing.expectEqual(a.get(9), 3); - try testing.expectEqual(a.get(10), 4); - - try a.insert(11, 0xbb); - try testing.expectEqual(a.len, 12); - try testing.expectEqual(a.pop(), 0xbb); - - try a.appendSlice(&x); - try testing.expectEqual(a.len, 11 + x.len); - - try a.appendNTimes(0xbb, 5); - try testing.expectEqual(a.len, 11 + x.len + 5); - try testing.expectEqual(a.pop(), 0xbb); - - a.appendNTimesAssumeCapacity(0xcc, 5); - try 
testing.expectEqual(a.len, 11 + x.len + 5 - 1 + 5); - try testing.expectEqual(a.pop(), 0xcc); - - try testing.expectEqual(a.len, 29); - try a.replaceRange(1, 20, &x); - try testing.expectEqual(a.len, 29 + x.len - 20); - - try a.insertSlice(0, &x); - try testing.expectEqual(a.len, 29 + x.len - 20 + x.len); - - try a.replaceRange(1, 5, &x); - try testing.expectEqual(a.len, 29 + x.len - 20 + x.len + x.len - 5); - - try a.append(10); - try testing.expectEqual(a.pop(), 10); - - try a.append(20); - const removed = a.orderedRemove(5); - try testing.expectEqual(removed, 1); - try testing.expectEqual(a.len, 34); - - a.set(0, 0xdd); - a.set(a.len - 1, 0xee); - const swapped = a.swapRemove(0); - try testing.expectEqual(swapped, 0xdd); - try testing.expectEqual(a.get(0), 0xee); - - const added_slice = try a.addManyAsSlice(3); - try testing.expectEqual(added_slice.len, 3); - try testing.expectEqual(a.len, 36); - - while (a.pop()) |_| {} - const w = a.writer(); - const s = "hello, this is a test string"; - try w.writeAll(s); - try testing.expectEqualStrings(s, a.constSlice()); -} - -test "BoundedArrayAligned" { - var a = try BoundedArrayAligned(u8, .@"16", 4).init(0); - try a.append(0); - try a.append(0); - try a.append(255); - try a.append(255); - - const b = @as(*const [2]u16, @ptrCast(a.constSlice().ptr)); - try testing.expectEqual(@as(u16, 0), b[0]); - try testing.expectEqual(@as(u16, 65535), b[1]); -} diff --git a/lib/std/std.zig b/lib/std/std.zig index 564b04c609..891c0bc256 100644 --- a/lib/std/std.zig +++ b/lib/std/std.zig @@ -9,8 +9,6 @@ pub const AutoArrayHashMapUnmanaged = array_hash_map.AutoArrayHashMapUnmanaged; pub const AutoHashMap = hash_map.AutoHashMap; pub const AutoHashMapUnmanaged = hash_map.AutoHashMapUnmanaged; pub const BitStack = @import("BitStack.zig"); -pub const BoundedArray = @import("bounded_array.zig").BoundedArray; -pub const BoundedArrayAligned = @import("bounded_array.zig").BoundedArrayAligned; pub const Build = @import("Build.zig"); pub const BufMap = @import("buf_map.zig").BufMap; pub const BufSet = @import("buf_set.zig").BufSet; diff --git a/src/Compilation.zig b/src/Compilation.zig index 419c56019f..4024e0a49e 100644 --- a/src/Compilation.zig +++ b/src/Compilation.zig @@ -2103,6 +2103,8 @@ pub fn create(gpa: Allocator, arena: Allocator, options: CreateOptions) !*Compil .local_zir_cache = local_zir_cache, .error_limit = error_limit, .llvm_object = null, + .analysis_roots_buffer = undefined, + .analysis_roots_len = 0, }; try zcu.init(options.thread_pool.getIdCount()); break :blk zcu; @@ -2933,22 +2935,26 @@ pub fn update(comp: *Compilation, main_progress_node: std.Progress.Node) !void { try comp.appendFileSystemInput(embed_file.path); } - zcu.analysis_roots.clear(); + zcu.analysis_roots_len = 0; - zcu.analysis_roots.appendAssumeCapacity(zcu.std_mod); + zcu.analysis_roots_buffer[zcu.analysis_roots_len] = zcu.std_mod; + zcu.analysis_roots_len += 1; // Normally we rely on importing std to in turn import the root source file in the start code. // However, the main module is distinct from the root module in tests, so that won't happen there. 
if (comp.config.is_test and zcu.main_mod != zcu.std_mod) { - zcu.analysis_roots.appendAssumeCapacity(zcu.main_mod); + zcu.analysis_roots_buffer[zcu.analysis_roots_len] = zcu.main_mod; + zcu.analysis_roots_len += 1; } if (zcu.root_mod.deps.get("compiler_rt")) |compiler_rt_mod| { - zcu.analysis_roots.appendAssumeCapacity(compiler_rt_mod); + zcu.analysis_roots_buffer[zcu.analysis_roots_len] = compiler_rt_mod; + zcu.analysis_roots_len += 1; } if (zcu.root_mod.deps.get("ubsan_rt")) |ubsan_rt_mod| { - zcu.analysis_roots.appendAssumeCapacity(ubsan_rt_mod); + zcu.analysis_roots_buffer[zcu.analysis_roots_len] = ubsan_rt_mod; + zcu.analysis_roots_len += 1; } } @@ -4745,7 +4751,7 @@ fn performAllTheWork( try zcu.flushRetryableFailures(); // It's analysis time! Queue up our initial analysis. - for (zcu.analysis_roots.slice()) |mod| { + for (zcu.analysisRoots()) |mod| { try comp.queueJob(.{ .analyze_mod = mod }); } diff --git a/src/Zcu.zig b/src/Zcu.zig index c444d57bf6..d77a6edc31 100644 --- a/src/Zcu.zig +++ b/src/Zcu.zig @@ -268,7 +268,8 @@ nav_val_analysis_queued: std.AutoArrayHashMapUnmanaged(InternPool.Nav.Index, voi /// These are the modules which we initially queue for analysis in `Compilation.update`. /// `resolveReferences` will use these as the root of its reachability traversal. -analysis_roots: std.BoundedArray(*Package.Module, 4) = .{}, +analysis_roots_buffer: [4]*Package.Module, +analysis_roots_len: usize = 0, /// This is the cached result of `Zcu.resolveReferences`. It is computed on-demand, and /// reset to `null` when any semantic analysis occurs (since this invalidates the data). /// Allocated into `gpa`. @@ -4013,8 +4014,8 @@ fn resolveReferencesInner(zcu: *Zcu) !std.AutoHashMapUnmanaged(AnalUnit, ?Resolv // This is not a sufficient size, but a lower bound. try result.ensureTotalCapacity(gpa, @intCast(zcu.reference_table.count())); - try type_queue.ensureTotalCapacity(gpa, zcu.analysis_roots.len); - for (zcu.analysis_roots.slice()) |mod| { + try type_queue.ensureTotalCapacity(gpa, zcu.analysis_roots_len); + for (zcu.analysisRoots()) |mod| { const file = zcu.module_roots.get(mod).?.unwrap() orelse continue; const root_ty = zcu.fileRootType(file); if (root_ty == .none) continue; @@ -4202,6 +4203,10 @@ fn resolveReferencesInner(zcu: *Zcu) !std.AutoHashMapUnmanaged(AnalUnit, ?Resolv return result; } +pub fn analysisRoots(zcu: *Zcu) []*Package.Module { + return zcu.analysis_roots_buffer[0..zcu.analysis_roots_len]; +} + pub fn fileByIndex(zcu: *const Zcu, file_index: File.Index) *File { return zcu.intern_pool.filePtr(file_index); } diff --git a/src/Zcu/PerThread.zig b/src/Zcu/PerThread.zig index de4be438f5..79ad9f14e9 100644 --- a/src/Zcu/PerThread.zig +++ b/src/Zcu/PerThread.zig @@ -2116,8 +2116,9 @@ pub fn computeAliveFiles(pt: Zcu.PerThread) Allocator.Error!bool { // multi-threaded environment (where things like file indices could differ between compiler runs). // The roots of our file liveness analysis will be the analysis roots. 
- try zcu.alive_files.ensureTotalCapacity(gpa, zcu.analysis_roots.len); - for (zcu.analysis_roots.slice()) |mod| { + const analysis_roots = zcu.analysisRoots(); + try zcu.alive_files.ensureTotalCapacity(gpa, analysis_roots.len); + for (analysis_roots) |mod| { const file_index = zcu.module_roots.get(mod).?.unwrap() orelse continue; const file = zcu.fileByIndex(file_index); From b6f84c47c4144827ebf96617dbe40aeacd8cc34f Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 4 Aug 2025 23:57:51 -0700 Subject: [PATCH 42/70] std.base64: delete encodeFromReaderToWriter this function is wacky, should not have been merged --- lib/std/base64.zig | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/lib/std/base64.zig b/lib/std/base64.zig index 37d3b70c20..15e48b5c51 100644 --- a/lib/std/base64.zig +++ b/lib/std/base64.zig @@ -118,22 +118,6 @@ pub const Base64Encoder = struct { } } - // destWriter must be compatible with std.io.GenericWriter's writeAll interface - // sourceReader must be compatible with `std.io.GenericReader` read interface - pub fn encodeFromReaderToWriter(encoder: *const Base64Encoder, destWriter: anytype, sourceReader: anytype) !void { - while (true) { - var tempSource: [3]u8 = undefined; - const bytesRead = try sourceReader.read(&tempSource); - if (bytesRead == 0) { - break; - } - - var temp: [5]u8 = undefined; - const s = encoder.encode(&temp, tempSource[0..bytesRead]); - try destWriter.writeAll(s); - } - } - /// dest.len must at least be what you get from ::calcSize. pub fn encode(encoder: *const Base64Encoder, dest: []u8, source: []const u8) []const u8 { const out_len = encoder.calcSize(source.len); @@ -525,14 +509,6 @@ fn testAllApis(codecs: Codecs, expected_decoded: []const u8, expected_encoded: [ try codecs.Encoder.encodeWriter(&writer, expected_decoded); try testing.expectEqualSlices(u8, expected_encoded, writer.buffered()); } - { - // reader to writer encode - var stream: std.Io.Reader = .fixed(expected_decoded); - var buffer: [0x100]u8 = undefined; - var writer: std.Io.Writer = .fixed(&buffer); - try codecs.Encoder.encodeFromReaderToWriter(&writer, &stream); - try testing.expectEqualSlices(u8, expected_encoded, writer.buffered()); - } // Base64Decoder { From 8c11ada66caa011523e5c1019f9bb23c2db89231 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 4 Aug 2025 23:59:48 -0700 Subject: [PATCH 43/70] std: delete RingBuffer Progress towards #19231 --- lib/std/RingBuffer.zig | 230 ----------------------------------------- lib/std/std.zig | 1 - 2 files changed, 231 deletions(-) delete mode 100644 lib/std/RingBuffer.zig diff --git a/lib/std/RingBuffer.zig b/lib/std/RingBuffer.zig deleted file mode 100644 index 5bdfb3e070..0000000000 --- a/lib/std/RingBuffer.zig +++ /dev/null @@ -1,230 +0,0 @@ -//! This ring buffer stores read and write indices while being able to utilise -//! the full backing slice by incrementing the indices modulo twice the slice's -//! length and reducing indices modulo the slice's length on slice access. This -//! means that whether the ring buffer is full or empty can be distinguished by -//! looking at the difference between the read and write indices without adding -//! an extra boolean flag or having to reserve a slot in the buffer. -//! -//! This ring buffer has not been implemented with thread safety in mind, and -//! therefore should not be assumed to be suitable for use cases involving -//! separate reader and writer threads. 
- -const Allocator = @import("std").mem.Allocator; -const assert = @import("std").debug.assert; -const copyForwards = @import("std").mem.copyForwards; - -const RingBuffer = @This(); - -data: []u8, -read_index: usize, -write_index: usize, - -pub const Error = error{ Full, ReadLengthInvalid }; - -/// Allocate a new `RingBuffer`; `deinit()` should be called to free the buffer. -pub fn init(allocator: Allocator, capacity: usize) Allocator.Error!RingBuffer { - const bytes = try allocator.alloc(u8, capacity); - return RingBuffer{ - .data = bytes, - .write_index = 0, - .read_index = 0, - }; -} - -/// Free the data backing a `RingBuffer`; must be passed the same `Allocator` as -/// `init()`. -pub fn deinit(self: *RingBuffer, allocator: Allocator) void { - allocator.free(self.data); - self.* = undefined; -} - -/// Returns `index` modulo the length of the backing slice. -pub fn mask(self: RingBuffer, index: usize) usize { - return index % self.data.len; -} - -/// Returns `index` modulo twice the length of the backing slice. -pub fn mask2(self: RingBuffer, index: usize) usize { - return index % (2 * self.data.len); -} - -/// Write `byte` into the ring buffer. Returns `error.Full` if the ring -/// buffer is full. -pub fn write(self: *RingBuffer, byte: u8) Error!void { - if (self.isFull()) return error.Full; - self.writeAssumeCapacity(byte); -} - -/// Write `byte` into the ring buffer. If the ring buffer is full, the -/// oldest byte is overwritten. -pub fn writeAssumeCapacity(self: *RingBuffer, byte: u8) void { - self.data[self.mask(self.write_index)] = byte; - self.write_index = self.mask2(self.write_index + 1); -} - -/// Write `bytes` into the ring buffer. Returns `error.Full` if the ring -/// buffer does not have enough space, without writing any data. -/// Uses memcpy and so `bytes` must not overlap ring buffer data. -pub fn writeSlice(self: *RingBuffer, bytes: []const u8) Error!void { - if (self.len() + bytes.len > self.data.len) return error.Full; - self.writeSliceAssumeCapacity(bytes); -} - -/// Write `bytes` into the ring buffer. If there is not enough space, older -/// bytes will be overwritten. -/// Uses memcpy and so `bytes` must not overlap ring buffer data. -pub fn writeSliceAssumeCapacity(self: *RingBuffer, bytes: []const u8) void { - assert(bytes.len <= self.data.len); - const data_start = self.mask(self.write_index); - const part1_data_end = @min(data_start + bytes.len, self.data.len); - const part1_len = part1_data_end - data_start; - @memcpy(self.data[data_start..part1_data_end], bytes[0..part1_len]); - - const remaining = bytes.len - part1_len; - const to_write = @min(remaining, remaining % self.data.len + self.data.len); - const part2_bytes_start = bytes.len - to_write; - const part2_bytes_end = @min(part2_bytes_start + self.data.len, bytes.len); - const part2_len = part2_bytes_end - part2_bytes_start; - @memcpy(self.data[0..part2_len], bytes[part2_bytes_start..part2_bytes_end]); - if (part2_bytes_end != bytes.len) { - const part3_len = bytes.len - part2_bytes_end; - @memcpy(self.data[0..part3_len], bytes[part2_bytes_end..bytes.len]); - } - self.write_index = self.mask2(self.write_index + bytes.len); -} - -/// Write `bytes` into the ring buffer. Returns `error.Full` if the ring -/// buffer does not have enough space, without writing any data. -/// Uses copyForwards and can write slices from this RingBuffer into itself. 
-pub fn writeSliceForwards(self: *RingBuffer, bytes: []const u8) Error!void { - if (self.len() + bytes.len > self.data.len) return error.Full; - self.writeSliceForwardsAssumeCapacity(bytes); -} - -/// Write `bytes` into the ring buffer. If there is not enough space, older -/// bytes will be overwritten. -/// Uses copyForwards and can write slices from this RingBuffer into itself. -pub fn writeSliceForwardsAssumeCapacity(self: *RingBuffer, bytes: []const u8) void { - assert(bytes.len <= self.data.len); - const data_start = self.mask(self.write_index); - const part1_data_end = @min(data_start + bytes.len, self.data.len); - const part1_len = part1_data_end - data_start; - copyForwards(u8, self.data[data_start..], bytes[0..part1_len]); - - const remaining = bytes.len - part1_len; - const to_write = @min(remaining, remaining % self.data.len + self.data.len); - const part2_bytes_start = bytes.len - to_write; - const part2_bytes_end = @min(part2_bytes_start + self.data.len, bytes.len); - copyForwards(u8, self.data[0..], bytes[part2_bytes_start..part2_bytes_end]); - if (part2_bytes_end != bytes.len) - copyForwards(u8, self.data[0..], bytes[part2_bytes_end..bytes.len]); - self.write_index = self.mask2(self.write_index + bytes.len); -} - -/// Consume a byte from the ring buffer and return it. Returns `null` if the -/// ring buffer is empty. -pub fn read(self: *RingBuffer) ?u8 { - if (self.isEmpty()) return null; - return self.readAssumeLength(); -} - -/// Consume a byte from the ring buffer and return it; asserts that the buffer -/// is not empty. -pub fn readAssumeLength(self: *RingBuffer) u8 { - assert(!self.isEmpty()); - const byte = self.data[self.mask(self.read_index)]; - self.read_index = self.mask2(self.read_index + 1); - return byte; -} - -/// Reads first `length` bytes written to the ring buffer into `dest`; Returns -/// Error.ReadLengthInvalid if length greater than ring or dest length -/// Uses memcpy and so `dest` must not overlap ring buffer data. -pub fn readFirst(self: *RingBuffer, dest: []u8, length: usize) Error!void { - if (length > self.len() or length > dest.len) return error.ReadLengthInvalid; - self.readFirstAssumeLength(dest, length); -} - -/// Reads first `length` bytes written to the ring buffer into `dest`; -/// Asserts that length not greater than ring buffer or dest length -/// Uses memcpy and so `dest` must not overlap ring buffer data. -pub fn readFirstAssumeLength(self: *RingBuffer, dest: []u8, length: usize) void { - assert(length <= self.len() and length <= dest.len); - const slice = self.sliceAt(self.read_index, length); - slice.copyTo(dest); - self.read_index = self.mask2(self.read_index + length); -} - -/// Reads last `length` bytes written to the ring buffer into `dest`; Returns -/// Error.ReadLengthInvalid if length greater than ring or dest length -/// Uses memcpy and so `dest` must not overlap ring buffer data. -/// Reduces write index by `length`. -pub fn readLast(self: *RingBuffer, dest: []u8, length: usize) Error!void { - if (length > self.len() or length > dest.len) return error.ReadLengthInvalid; - self.readLastAssumeLength(dest, length); -} - -/// Reads last `length` bytes written to the ring buffer into `dest`; -/// Asserts that length not greater than ring buffer or dest length -/// Uses memcpy and so `dest` must not overlap ring buffer data. -/// Reduces write index by `length`. 
-pub fn readLastAssumeLength(self: *RingBuffer, dest: []u8, length: usize) void { - assert(length <= self.len() and length <= dest.len); - const slice = self.sliceLast(length); - slice.copyTo(dest); - self.write_index = if (self.write_index >= self.data.len) - self.write_index - length - else - self.mask(self.write_index + self.data.len - length); -} - -/// Returns `true` if the ring buffer is empty and `false` otherwise. -pub fn isEmpty(self: RingBuffer) bool { - return self.write_index == self.read_index; -} - -/// Returns `true` if the ring buffer is full and `false` otherwise. -pub fn isFull(self: RingBuffer) bool { - return self.mask2(self.write_index + self.data.len) == self.read_index; -} - -/// Returns the length of data available for reading -pub fn len(self: RingBuffer) usize { - const wrap_offset = 2 * self.data.len * @intFromBool(self.write_index < self.read_index); - const adjusted_write_index = self.write_index + wrap_offset; - return adjusted_write_index - self.read_index; -} - -/// A `Slice` represents a region of a ring buffer. The region is split into two -/// sections as the ring buffer data will not be contiguous if the desired -/// region wraps to the start of the backing slice. -pub const Slice = struct { - first: []u8, - second: []u8, - - /// Copy data from `self` into `dest` - pub fn copyTo(self: Slice, dest: []u8) void { - @memcpy(dest[0..self.first.len], self.first); - @memcpy(dest[self.first.len..][0..self.second.len], self.second); - } -}; - -/// Returns a `Slice` for the region of the ring buffer starting at -/// `self.mask(start_unmasked)` with the specified length. -pub fn sliceAt(self: RingBuffer, start_unmasked: usize, length: usize) Slice { - assert(length <= self.data.len); - const slice1_start = self.mask(start_unmasked); - const slice1_end = @min(self.data.len, slice1_start + length); - const slice1 = self.data[slice1_start..slice1_end]; - const slice2 = self.data[0 .. length - slice1.len]; - return Slice{ - .first = slice1, - .second = slice2, - }; -} - -/// Returns a `Slice` for the last `length` bytes written to the ring buffer. -/// Does not check that any bytes have been written into the region. -pub fn sliceLast(self: RingBuffer, length: usize) Slice { - return self.sliceAt(self.write_index + self.data.len - length, length); -} diff --git a/lib/std/std.zig b/lib/std/std.zig index 891c0bc256..aaae4c2eba 100644 --- a/lib/std/std.zig +++ b/lib/std/std.zig @@ -29,7 +29,6 @@ pub const PriorityQueue = @import("priority_queue.zig").PriorityQueue; pub const PriorityDequeue = @import("priority_dequeue.zig").PriorityDequeue; pub const Progress = @import("Progress.zig"); pub const Random = @import("Random.zig"); -pub const RingBuffer = @import("RingBuffer.zig"); pub const SegmentedList = @import("segmented_list.zig").SegmentedList; pub const SemanticVersion = @import("SemanticVersion.zig"); pub const SinglyLinkedList = @import("SinglyLinkedList.zig"); From 3914eaf3571949718bcd986ab8129b3c9f39b1d0 Mon Sep 17 00:00:00 2001 From: Giuseppe Cesarano Date: Tue, 5 Aug 2025 19:00:33 +0200 Subject: [PATCH 44/70] std.elf: buffer header iterator API (#24691) Closes #24666. 
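A brief usage sketch of the new buffer-based iteration (assumptions: `image` holds the entire ELF file contents in memory, and `h` is a `std.elf.Header` already parsed from those same bytes):

const std = @import("std");

fn listSections(h: std.elf.Header, image: []const u8) !void {
    // Walks the section header table directly in `image`; no file seeking is involved.
    var it = h.iterateSectionHeadersBuffer(image);
    while (try it.next()) |shdr| {
        std.debug.print("type=0x{x} offset=0x{x} size={d}\n", .{ shdr.sh_type, shdr.sh_offset, shdr.sh_size });
    }
}

The file-backed `iterateSectionHeaders` remains available for the streaming case; both paths now share the same header-decoding helpers.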
---
 lib/std/elf.zig | 130 ++++++++++++++++++++++++++++++++++--------------
 1 file changed, 92 insertions(+), 38 deletions(-)

diff --git a/lib/std/elf.zig b/lib/std/elf.zig
index 2583e83d19..0be6e6d0a2 100644
--- a/lib/std/elf.zig
+++ b/lib/std/elf.zig
@@ -502,6 +502,13 @@ pub const Header = struct {
         };
     }
 
+    pub fn iterateProgramHeadersBuffer(h: Header, buf: []const u8) ProgramHeaderBufferIterator {
+        return .{
+            .elf_header = h,
+            .buf = buf,
+        };
+    }
+
     pub fn iterateSectionHeaders(h: Header, file_reader: *std.fs.File.Reader) SectionHeaderIterator {
         return .{
             .elf_header = h,
@@ -509,6 +516,13 @@ pub const Header = struct {
         };
     }
 
+    pub fn iterateSectionHeadersBuffer(h: Header, buf: []const u8) SectionHeaderBufferIterator {
+        return .{
+            .elf_header = h,
+            .buf = buf,
+        };
+    }
+
     pub const ReadError = std.Io.Reader.Error || error{
         InvalidElfMagic,
         InvalidElfVersion,
@@ -570,29 +584,48 @@ pub const ProgramHeaderIterator = struct {
         if (it.index >= it.elf_header.phnum) return null;
         defer it.index += 1;
 
-        if (it.elf_header.is_64) {
-            const offset = it.elf_header.phoff + @sizeOf(Elf64_Phdr) * it.index;
-            try it.file_reader.seekTo(offset);
-            const phdr = try it.file_reader.interface.takeStruct(Elf64_Phdr, it.elf_header.endian);
-            return phdr;
-        }
-
-        const offset = it.elf_header.phoff + @sizeOf(Elf32_Phdr) * it.index;
+        const offset = it.elf_header.phoff + (if (it.elf_header.is_64) @sizeOf(Elf64_Phdr) else @sizeOf(Elf32_Phdr)) * it.index;
         try it.file_reader.seekTo(offset);
-        const phdr = try it.file_reader.interface.takeStruct(Elf32_Phdr, it.elf_header.endian);
-        return .{
-            .p_type = phdr.p_type,
-            .p_offset = phdr.p_offset,
-            .p_vaddr = phdr.p_vaddr,
-            .p_paddr = phdr.p_paddr,
-            .p_filesz = phdr.p_filesz,
-            .p_memsz = phdr.p_memsz,
-            .p_flags = phdr.p_flags,
-            .p_align = phdr.p_align,
-        };
+
+        return takePhdr(&it.file_reader.interface, it.elf_header);
     }
 };
 
+pub const ProgramHeaderBufferIterator = struct {
+    elf_header: Header,
+    buf: []const u8,
+    index: usize = 0,
+
+    pub fn next(it: *ProgramHeaderBufferIterator) !?Elf64_Phdr {
+        if (it.index >= it.elf_header.phnum) return null;
+        defer it.index += 1;
+
+        const offset = it.elf_header.phoff + (if (it.elf_header.is_64) @sizeOf(Elf64_Phdr) else @sizeOf(Elf32_Phdr)) * it.index;
+        var reader = std.Io.Reader.fixed(it.buf[offset..]);
+
+        return takePhdr(&reader, it.elf_header);
+    }
+};
+
+fn takePhdr(reader: *std.Io.Reader, elf_header: Header) !?Elf64_Phdr {
+    if (elf_header.is_64) {
+        const phdr = try reader.takeStruct(Elf64_Phdr, elf_header.endian);
+        return phdr;
+    }
+
+    const phdr = try reader.takeStruct(Elf32_Phdr, elf_header.endian);
+    return .{
+        .p_type = phdr.p_type,
+        .p_offset = phdr.p_offset,
+        .p_vaddr = phdr.p_vaddr,
+        .p_paddr = phdr.p_paddr,
+        .p_filesz = phdr.p_filesz,
+        .p_memsz = phdr.p_memsz,
+        .p_flags = phdr.p_flags,
+        .p_align = phdr.p_align,
+    };
+}
+
 pub const SectionHeaderIterator = struct {
     elf_header: Header,
     file_reader: *std.fs.File.Reader,
@@ -602,29 +635,50 @@ pub const SectionHeaderIterator = struct {
         if (it.index >= it.elf_header.shnum) return null;
        defer it.index += 1;
 
-        if (it.elf_header.is_64) {
-            try it.file_reader.seekTo(it.elf_header.shoff + @sizeOf(Elf64_Shdr) * it.index);
-            const shdr = try it.file_reader.interface.takeStruct(Elf64_Shdr, it.elf_header.endian);
-            return shdr;
-        }
+        const offset = it.elf_header.shoff + (if (it.elf_header.is_64) @sizeOf(Elf64_Shdr) else @sizeOf(Elf32_Shdr)) * it.index;
+        try it.file_reader.seekTo(offset);
 
-        try it.file_reader.seekTo(it.elf_header.shoff + @sizeOf(Elf32_Shdr) * it.index);
-        const shdr = try it.file_reader.interface.takeStruct(Elf32_Shdr, it.elf_header.endian);
-        return .{
-            .sh_name = shdr.sh_name,
-            .sh_type = shdr.sh_type,
-            .sh_flags = shdr.sh_flags,
-            .sh_addr = shdr.sh_addr,
-            .sh_offset = shdr.sh_offset,
-            .sh_size = shdr.sh_size,
-            .sh_link = shdr.sh_link,
-            .sh_info = shdr.sh_info,
-            .sh_addralign = shdr.sh_addralign,
-            .sh_entsize = shdr.sh_entsize,
-        };
+        return takeShdr(&it.file_reader.interface, it.elf_header);
     }
 };
 
+pub const SectionHeaderBufferIterator = struct {
+    elf_header: Header,
+    buf: []const u8,
+    index: usize = 0,
+
+    pub fn next(it: *SectionHeaderBufferIterator) !?Elf64_Shdr {
+        if (it.index >= it.elf_header.shnum) return null;
+        defer it.index += 1;
+
+        const offset = it.elf_header.shoff + (if (it.elf_header.is_64) @sizeOf(Elf64_Shdr) else @sizeOf(Elf32_Shdr)) * it.index;
+        var reader = std.Io.Reader.fixed(it.buf[offset..]);
+
+        return takeShdr(&reader, it.elf_header);
+    }
+};
+
+fn takeShdr(reader: *std.Io.Reader, elf_header: Header) !?Elf64_Shdr {
+    if (elf_header.is_64) {
+        const shdr = try reader.takeStruct(Elf64_Shdr, elf_header.endian);
+        return shdr;
+    }
+
+    const shdr = try reader.takeStruct(Elf32_Shdr, elf_header.endian);
+    return .{
+        .sh_name = shdr.sh_name,
+        .sh_type = shdr.sh_type,
+        .sh_flags = shdr.sh_flags,
+        .sh_addr = shdr.sh_addr,
+        .sh_offset = shdr.sh_offset,
+        .sh_size = shdr.sh_size,
+        .sh_link = shdr.sh_link,
+        .sh_info = shdr.sh_info,
+        .sh_addralign = shdr.sh_addralign,
+        .sh_entsize = shdr.sh_entsize,
+    };
+}
+
 pub const ELFCLASSNONE = 0;
 pub const ELFCLASS32 = 1;
 pub const ELFCLASS64 = 2;

From d2149106a6c03301d0467283814db86be030b335 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Igor=20Anic=CC=81?=
Date: Tue, 5 Aug 2025 18:47:22 +0200
Subject: [PATCH 45/70] flate zlib fix end of block reading

`n` is the wanted number of bits to toss; `buffered_n` is the actual
number of bytes in `next_int`.
---
 lib/std/compress/flate/Decompress.zig | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/std/compress/flate/Decompress.zig b/lib/std/compress/flate/Decompress.zig
index da57d56ab8..7c58905ee1 100644
--- a/lib/std/compress/flate/Decompress.zig
+++ b/lib/std/compress/flate/Decompress.zig
@@ -603,7 +603,7 @@ fn tossBitsEnding(d: *Decompress, n: u4) !void {
         error.EndOfStream => unreachable,
     };
     d.next_bits = next_int >> needed_bits;
-    d.remaining_bits = @intCast(@as(usize, n) * 8 -| @as(usize, needed_bits));
+    d.remaining_bits = @intCast(@as(usize, buffered_n) * 8 -| @as(usize, needed_bits));
 }
 
 fn takeBitsRuntime(d: *Decompress, n: u4) !u16 {

From 9a158c1dae531f2a4e5667569bed38c27cbd4d57 Mon Sep 17 00:00:00 2001
From: massi
Date: Tue, 5 Aug 2025 20:23:24 -0700
Subject: [PATCH 46/70] autodoc: Use the search input's value on load (#24467)

Co-authored-by: massi
---
 lib/docs/main.js | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/lib/docs/main.js b/lib/docs/main.js
index 330b51ba49..71a14516cb 100644
--- a/lib/docs/main.js
+++ b/lib/docs/main.js
@@ -129,6 +129,11 @@
       domSearch.addEventListener('input', onSearchChange, false);
       window.addEventListener('keydown', onWindowKeyDown, false);
       onHashChange(null);
+      if (domSearch.value) {
+        // user started typing a search query while the page was loading
+        curSearchIndex = -1;
+        startAsyncSearch();
+      }
     });
   });
 
@@ -643,6 +648,7 @@
   }
 
   function onHashChange(state) {
+    // Use a non-null state value to prevent the window scrolling if the user goes back to this history entry.
history.replaceState({}, ""); navigate(location.hash); if (state == null) window.scrollTo({top: 0}); @@ -650,13 +656,11 @@ function onPopState(ev) { onHashChange(ev.state); + syncDomSearch(); } function navigate(location_hash) { updateCurNav(location_hash); - if (domSearch.value !== curNavSearch) { - domSearch.value = curNavSearch; - } render(); if (imFeelingLucky) { imFeelingLucky = false; @@ -664,6 +668,12 @@ } } + function syncDomSearch() { + if (domSearch.value !== curNavSearch) { + domSearch.value = curNavSearch; + } + } + function activateSelectedResult() { if (domSectSearchResults.classList.contains("hidden")) { return; From 7ee6dab39fac7aa12fa9fd952bb2bdc28d5eabe8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?= Date: Tue, 5 Aug 2025 22:32:35 +0200 Subject: [PATCH 47/70] Revert "Sema: Stop adding Windows implib link inputs for `extern "..."` syntax." This reverts commit b461d07a5464aec86c533434dab0b58edfffb331. After some discussion in the team, we've decided that this is too disruptive, especially because the linker errors are less than helpful. That's a fixable problem, so we might reconsider this in the future, but revert it for now. --- build.zig | 20 ----------- src/Compilation.zig | 39 ++++++++++++++++------ src/Sema.zig | 13 ++++++++ src/libs/mingw.zig | 3 +- src/main.zig | 24 ++----------- test/standalone/simple/build.zig | 8 ----- test/standalone/windows_argv/build.zig | 2 -- test/standalone/windows_bat_args/build.zig | 4 --- test/standalone/windows_spawn/build.zig | 2 -- test/tests.zig | 9 ----- 10 files changed, 46 insertions(+), 78 deletions(-) diff --git a/build.zig b/build.zig index 95c97f29b6..ace08e740b 100644 --- a/build.zig +++ b/build.zig @@ -452,7 +452,6 @@ pub fn build(b: *std.Build) !void { .desc = "Run the behavior tests", .optimize_modes = optimization_modes, .include_paths = &.{}, - .windows_libs = &.{}, .skip_single_threaded = skip_single_threaded, .skip_non_native = skip_non_native, .skip_freebsd = skip_freebsd, @@ -475,7 +474,6 @@ pub fn build(b: *std.Build) !void { .desc = "Run the @cImport tests", .optimize_modes = optimization_modes, .include_paths = &.{"test/c_import"}, - .windows_libs = &.{}, .skip_single_threaded = true, .skip_non_native = skip_non_native, .skip_freebsd = skip_freebsd, @@ -496,7 +494,6 @@ pub fn build(b: *std.Build) !void { .desc = "Run the compiler_rt tests", .optimize_modes = optimization_modes, .include_paths = &.{}, - .windows_libs = &.{}, .skip_single_threaded = true, .skip_non_native = skip_non_native, .skip_freebsd = skip_freebsd, @@ -518,7 +515,6 @@ pub fn build(b: *std.Build) !void { .desc = "Run the zigc tests", .optimize_modes = optimization_modes, .include_paths = &.{}, - .windows_libs = &.{}, .skip_single_threaded = true, .skip_non_native = skip_non_native, .skip_freebsd = skip_freebsd, @@ -540,12 +536,6 @@ pub fn build(b: *std.Build) !void { .desc = "Run the standard library tests", .optimize_modes = optimization_modes, .include_paths = &.{}, - .windows_libs = &.{ - "advapi32", - "crypt32", - "iphlpapi", - "ws2_32", - }, .skip_single_threaded = skip_single_threaded, .skip_non_native = skip_non_native, .skip_freebsd = skip_freebsd, @@ -743,12 +733,6 @@ fn addCompilerMod(b: *std.Build, options: AddCompilerModOptions) *std.Build.Modu compiler_mod.addImport("aro", aro_mod); compiler_mod.addImport("aro_translate_c", aro_translate_c_mod); - if (options.target.result.os.tag == .windows) { - compiler_mod.linkSystemLibrary("advapi32", .{}); - compiler_mod.linkSystemLibrary("crypt32", .{}); - 
compiler_mod.linkSystemLibrary("ws2_32", .{}); - } - return compiler_mod; } @@ -1446,10 +1430,6 @@ fn generateLangRef(b: *std.Build) std.Build.LazyPath { }), }); - if (b.graph.host.result.os.tag == .windows) { - doctest_exe.root_module.linkSystemLibrary("advapi32", .{}); - } - var dir = b.build_root.handle.openDir("doc/langref", .{ .iterate = true }) catch |err| { std.debug.panic("unable to open '{f}doc/langref' directory: {s}", .{ b.build_root, @errorName(err), diff --git a/src/Compilation.zig b/src/Compilation.zig index 4024e0a49e..ae3ab118b7 100644 --- a/src/Compilation.zig +++ b/src/Compilation.zig @@ -2185,8 +2185,12 @@ pub fn create(gpa: Allocator, arena: Allocator, options: CreateOptions) !*Compil .emit_docs = try options.emit_docs.resolve(arena, &options, .docs), }; - comp.windows_libs = try std.StringArrayHashMapUnmanaged(void).init(gpa, options.windows_lib_names, &.{}); - errdefer comp.windows_libs.deinit(gpa); + errdefer { + for (comp.windows_libs.keys()) |windows_lib| gpa.free(windows_lib); + comp.windows_libs.deinit(gpa); + } + try comp.windows_libs.ensureUnusedCapacity(gpa, options.windows_lib_names.len); + for (options.windows_lib_names) |windows_lib| comp.windows_libs.putAssumeCapacity(try gpa.dupe(u8, windows_lib), {}); // Prevent some footguns by making the "any" fields of config reflect // the default Module settings. @@ -2417,13 +2421,6 @@ pub fn create(gpa: Allocator, arena: Allocator, options: CreateOptions) !*Compil if (comp.emit_bin != null and target.ofmt != .c) { if (!comp.skip_linker_dependencies) { - // These DLLs are always loaded into every Windows process. - if (target.os.tag == .windows and is_exe_or_dyn_lib) { - try comp.windows_libs.ensureUnusedCapacity(gpa, 2); - comp.windows_libs.putAssumeCapacity("kernel32", {}); - comp.windows_libs.putAssumeCapacity("ntdll", {}); - } - // If we need to build libc for the target, add work items for it. // We go through the work queue so that building can be done in parallel. // If linking against host libc installation, instead queue up jobs @@ -2512,7 +2509,7 @@ pub fn create(gpa: Allocator, arena: Allocator, options: CreateOptions) !*Compil // When linking mingw-w64 there are some import libs we always need. try comp.windows_libs.ensureUnusedCapacity(gpa, mingw.always_link_libs.len); - for (mingw.always_link_libs) |name| comp.windows_libs.putAssumeCapacity(name, {}); + for (mingw.always_link_libs) |name| comp.windows_libs.putAssumeCapacity(try gpa.dupe(u8, name), {}); } else { return error.LibCUnavailable; } @@ -2610,6 +2607,7 @@ pub fn destroy(comp: *Compilation) void { comp.c_object_work_queue.deinit(); comp.win32_resource_work_queue.deinit(); + for (comp.windows_libs.keys()) |windows_lib| gpa.free(windows_lib); comp.windows_libs.deinit(gpa); { @@ -7795,6 +7793,27 @@ fn getCrtPathsInner( }; } +pub fn addLinkLib(comp: *Compilation, lib_name: []const u8) !void { + // Avoid deadlocking on building import libs such as kernel32.lib + // This can happen when the user uses `build-exe foo.obj -lkernel32` and + // then when we create a sub-Compilation for zig libc, it also tries to + // build kernel32.lib. + if (comp.skip_linker_dependencies) return; + const target = &comp.root_mod.resolved_target.result; + if (target.os.tag != .windows or target.ofmt == .c) return; + + // This happens when an `extern "foo"` function is referenced. + // If we haven't seen this library yet and we're targeting Windows, we need + // to queue up a work item to produce the DLL import library for this. 
+ const gop = try comp.windows_libs.getOrPut(comp.gpa, lib_name); + if (gop.found_existing) return; + { + errdefer _ = comp.windows_libs.pop(); + gop.key_ptr.* = try comp.gpa.dupe(u8, lib_name); + } + try comp.queueJob(.{ .windows_import_lib = gop.index }); +} + /// This decides the optimization mode for all zig-provided libraries, including /// compiler-rt, libcxx, libc, libunwind, etc. pub fn compilerRtOptMode(comp: Compilation) std.builtin.OptimizeMode { diff --git a/src/Sema.zig b/src/Sema.zig index 5816990eb2..41e1444420 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -8906,6 +8906,14 @@ fn resolveGenericBody( return sema.resolveConstDefinedValue(block, src, result, reason); } +/// Given a library name, examines if the library name should end up in +/// `link.File.Options.windows_libs` table (for example, libc is always +/// specified via dedicated flag `link_libc` instead), +/// and puts it there if it doesn't exist. +/// It also dupes the library name which can then be saved as part of the +/// respective `Decl` (either `ExternFn` or `Var`). +/// The liveness of the duped library name is tied to liveness of `Zcu`. +/// To deallocate, call `deinit` on the respective `Decl` (`ExternFn` or `Var`). pub fn handleExternLibName( sema: *Sema, block: *Block, @@ -8955,6 +8963,11 @@ pub fn handleExternLibName( .{ lib_name, lib_name }, ); } + comp.addLinkLib(lib_name) catch |err| { + return sema.fail(block, src_loc, "unable to add link lib '{s}': {s}", .{ + lib_name, @errorName(err), + }); + }; } } diff --git a/src/libs/mingw.zig b/src/libs/mingw.zig index 1c2927eba0..5603a15a57 100644 --- a/src/libs/mingw.zig +++ b/src/libs/mingw.zig @@ -1012,7 +1012,6 @@ const mingw32_winpthreads_src = [_][]const u8{ "winpthreads" ++ path.sep_str ++ "thread.c", }; -// Note: kernel32 and ntdll are always linked even without targeting MinGW-w64. pub const always_link_libs = [_][]const u8{ "api-ms-win-crt-conio-l1-1-0", "api-ms-win-crt-convert-l1-1-0", @@ -1030,6 +1029,8 @@ pub const always_link_libs = [_][]const u8{ "api-ms-win-crt-time-l1-1-0", "api-ms-win-crt-utility-l1-1-0", "advapi32", + "kernel32", + "ntdll", "shell32", "user32", }; diff --git a/src/main.zig b/src/main.zig index 0efa54ac1e..5bc6c2e923 100644 --- a/src/main.zig +++ b/src/main.zig @@ -312,7 +312,6 @@ fn mainArgs(gpa: Allocator, arena: Allocator, args: []const []const u8) !void { return jitCmd(gpa, arena, cmd_args, .{ .cmd_name = "resinator", .root_src_path = "resinator/main.zig", - .windows_libs = &.{"advapi32"}, .depend_on_aro = true, .prepend_zig_lib_dir_path = true, .server = use_server, @@ -337,7 +336,6 @@ fn mainArgs(gpa: Allocator, arena: Allocator, args: []const []const u8) !void { return jitCmd(gpa, arena, cmd_args, .{ .cmd_name = "std", .root_src_path = "std-docs.zig", - .windows_libs = &.{"ws2_32"}, .prepend_zig_lib_dir_path = true, .prepend_zig_exe_path = true, .prepend_global_cache_path = true, @@ -3659,6 +3657,7 @@ fn buildOutputType( } else if (target.os.tag == .windows) { try test_exec_args.appendSlice(arena, &.{ "--subsystem", "console", + "-lkernel32", "-lntdll", }); } @@ -3862,8 +3861,7 @@ fn createModule( .only_compiler_rt => continue, } - // We currently prefer import libraries provided by MinGW-w64 even for MSVC. 
- if (target.os.tag == .windows) { + if (target.isMinGW()) { const exists = mingw.libExists(arena, target, create_module.dirs.zig_lib, lib_name) catch |err| { fatal("failed to check zig installation for DLL import libs: {s}", .{ @errorName(err), @@ -5375,14 +5373,6 @@ fn cmdBuild(gpa: Allocator, arena: Allocator, args: []const []const u8) !void { try root_mod.deps.put(arena, "@build", build_mod); - var windows_libs: std.StringArrayHashMapUnmanaged(void) = .empty; - - if (resolved_target.result.os.tag == .windows) { - try windows_libs.ensureUnusedCapacity(arena, 2); - windows_libs.putAssumeCapacity("advapi32", {}); - windows_libs.putAssumeCapacity("ws2_32", {}); // for `--listen` (web interface) - } - const comp = Compilation.create(gpa, arena, .{ .libc_installation = libc_installation, .dirs = dirs, @@ -5405,7 +5395,6 @@ fn cmdBuild(gpa: Allocator, arena: Allocator, args: []const []const u8) !void { .cache_mode = .whole, .reference_trace = reference_trace, .debug_compile_errors = debug_compile_errors, - .windows_lib_names = windows_libs.keys(), }) catch |err| { fatal("unable to create compilation: {s}", .{@errorName(err)}); }; @@ -5509,7 +5498,6 @@ fn cmdBuild(gpa: Allocator, arena: Allocator, args: []const []const u8) !void { const JitCmdOptions = struct { cmd_name: []const u8, root_src_path: []const u8, - windows_libs: []const []const u8 = &.{}, prepend_zig_lib_dir_path: bool = false, prepend_global_cache_path: bool = false, prepend_zig_exe_path: bool = false, @@ -5626,13 +5614,6 @@ fn jitCmd( try root_mod.deps.put(arena, "aro", aro_mod); } - var windows_libs: std.StringArrayHashMapUnmanaged(void) = .empty; - - if (resolved_target.result.os.tag == .windows) { - try windows_libs.ensureUnusedCapacity(arena, options.windows_libs.len); - for (options.windows_libs) |lib| windows_libs.putAssumeCapacity(lib, {}); - } - const comp = Compilation.create(gpa, arena, .{ .dirs = dirs, .root_name = options.cmd_name, @@ -5643,7 +5624,6 @@ fn jitCmd( .self_exe_path = self_exe_path, .thread_pool = &thread_pool, .cache_mode = .whole, - .windows_lib_names = windows_libs.keys(), }) catch |err| { fatal("unable to create compilation: {s}", .{@errorName(err)}); }; diff --git a/test/standalone/simple/build.zig b/test/standalone/simple/build.zig index 51d9b3a9b1..5bd76a3540 100644 --- a/test/standalone/simple/build.zig +++ b/test/standalone/simple/build.zig @@ -50,10 +50,6 @@ pub fn build(b: *std.Build) void { }); if (case.link_libc) exe.root_module.link_libc = true; - if (resolved_target.result.os.tag == .windows) { - exe.root_module.linkSystemLibrary("advapi32", .{}); - } - _ = exe.getEmittedBin(); step.dependOn(&exe.step); @@ -70,10 +66,6 @@ pub fn build(b: *std.Build) void { }); if (case.link_libc) exe.root_module.link_libc = true; - if (resolved_target.result.os.tag == .windows) { - exe.root_module.linkSystemLibrary("advapi32", .{}); - } - const run = b.addRunArtifact(exe); step.dependOn(&run.step); } diff --git a/test/standalone/windows_argv/build.zig b/test/standalone/windows_argv/build.zig index 9ace58088a..df988d2371 100644 --- a/test/standalone/windows_argv/build.zig +++ b/test/standalone/windows_argv/build.zig @@ -47,8 +47,6 @@ pub fn build(b: *std.Build) !void { }), }); - fuzz.root_module.linkSystemLibrary("advapi32", .{}); - const fuzz_max_iterations = b.option(u64, "iterations", "The max fuzz iterations (default: 100)") orelse 100; const fuzz_iterations_arg = std.fmt.allocPrint(b.allocator, "{}", .{fuzz_max_iterations}) catch @panic("oom"); diff --git a/test/standalone/windows_bat_args/build.zig 
b/test/standalone/windows_bat_args/build.zig index e30f5b36c1..b91b4fb3bb 100644 --- a/test/standalone/windows_bat_args/build.zig +++ b/test/standalone/windows_bat_args/build.zig @@ -28,8 +28,6 @@ pub fn build(b: *std.Build) !void { }), }); - test_exe.root_module.linkSystemLibrary("advapi32", .{}); - const run = b.addRunArtifact(test_exe); run.addArtifactArg(echo_args); run.expectExitCode(0); @@ -46,8 +44,6 @@ pub fn build(b: *std.Build) !void { }), }); - fuzz.root_module.linkSystemLibrary("advapi32", .{}); - const fuzz_max_iterations = b.option(u64, "iterations", "The max fuzz iterations (default: 100)") orelse 100; const fuzz_iterations_arg = std.fmt.allocPrint(b.allocator, "{}", .{fuzz_max_iterations}) catch @panic("oom"); diff --git a/test/standalone/windows_spawn/build.zig b/test/standalone/windows_spawn/build.zig index 628b1a6900..2b967b16d5 100644 --- a/test/standalone/windows_spawn/build.zig +++ b/test/standalone/windows_spawn/build.zig @@ -28,8 +28,6 @@ pub fn build(b: *std.Build) void { }), }); - main.root_module.linkSystemLibrary("advapi32", .{}); - const run = b.addRunArtifact(main); run.addArtifactArg(hello); run.expectExitCode(0); diff --git a/test/tests.zig b/test/tests.zig index a12312d278..1e5db67f07 100644 --- a/test/tests.zig +++ b/test/tests.zig @@ -2238,7 +2238,6 @@ const ModuleTestOptions = struct { desc: []const u8, optimize_modes: []const OptimizeMode, include_paths: []const []const u8, - windows_libs: []const []const u8, skip_single_threaded: bool, skip_non_native: bool, skip_freebsd: bool, @@ -2373,10 +2372,6 @@ pub fn addModuleTests(b: *std.Build, options: ModuleTestOptions) *Step { for (options.include_paths) |include_path| these_tests.root_module.addIncludePath(b.path(include_path)); - if (target.os.tag == .windows) { - for (options.windows_libs) |lib| these_tests.root_module.linkSystemLibrary(lib, .{}); - } - const qualified_name = b.fmt("{s}-{s}-{s}-{s}{s}{s}{s}{s}{s}{s}", .{ options.name, triple_txt, @@ -2672,10 +2667,6 @@ pub fn addIncrementalTests(b: *std.Build, test_step: *Step) !void { }), }); - if (b.graph.host.result.os.tag == .windows) { - incr_check.root_module.linkSystemLibrary("advapi32", .{}); - } - var dir = try b.build_root.handle.openDir("test/incremental", .{ .iterate = true }); defer dir.close(); From 44ea11d71f6a713081e3dec11a08ec322dfeb787 Mon Sep 17 00:00:00 2001 From: "kj4tmp@gmail.com" Date: Tue, 15 Jul 2025 22:30:58 -0700 Subject: [PATCH 48/70] #24471: add mlock syscalls to std.os.linux --- lib/std/os/linux.zig | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/lib/std/os/linux.zig b/lib/std/os/linux.zig index a02451c0fd..9242c33ece 100644 --- a/lib/std/os/linux.zig +++ b/lib/std/os/linux.zig @@ -1014,6 +1014,44 @@ pub fn munmap(address: [*]const u8, length: usize) usize { return syscall2(.munmap, @intFromPtr(address), length); } +pub fn mlock(address: [*]const u8, length: usize) usize { + return syscall2(.mlock, @intFromPtr(address), length); +} + +pub fn munlock(address: [*]const u8, length: usize) usize { + return syscall2(.munlock, @intFromPtr(address), length); +} + +pub const MLOCK = packed struct(u32) { + ONFAULT: bool = false, + _1: u31 = 0, +}; + +pub fn mlock2(address: [*]const u8, length: usize, flags: MLOCK) usize { + return syscall3(.mlock2, @intFromPtr(address), length, @as(u32, @bitCast(flags))); +} + +pub const MCL = if (native_arch.isSPARC() or native_arch.isPowerPC()) packed struct(u32) { + _0: u13 = 0, + CURRENT: bool = false, + FUTURE: bool = false, + ONFAULT: bool = false, + 
_4: u16 = 0,
+} else packed struct(u32) {
+    CURRENT: bool = false,
+    FUTURE: bool = false,
+    ONFAULT: bool = false,
+    _3: u29 = 0,
+};
+
+pub fn mlockall(flags: MCL) usize {
+    return syscall1(.mlockall, @as(u32, @bitCast(flags)));
+}
+
+pub fn munlockall() usize {
+    return syscall0(.munlockall);
+}
+
 pub fn poll(fds: [*]pollfd, n: nfds_t, timeout: i32) usize {
     if (@hasField(SYS, "poll")) {
         return syscall3(.poll, @intFromPtr(fds), n, @as(u32, @bitCast(timeout)));

From 3de8bbd3d4e262df11a582fb52401b8077b5f352 Mon Sep 17 00:00:00 2001
From: mlugg
Date: Wed, 6 Aug 2025 15:00:58 +0100
Subject: [PATCH 49/70] Sema: fix initializing comptime-known constant with OPV union field

Resolves: #24716
---
 src/Sema.zig            | 10 +++++-----
 test/behavior/union.zig |  8 ++++++++
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/src/Sema.zig b/src/Sema.zig
index 41e1444420..8cfcadea66 100644
--- a/src/Sema.zig
+++ b/src/Sema.zig
@@ -3932,11 +3932,12 @@ fn resolveComptimeKnownAllocPtr(sema: *Sema, block: *Block, alloc: Air.Inst.Ref,
     // Whilst constructing our mapping, we will also initialize optional and error union payloads when
     // we encounter the corresponding pointers. For this reason, the ordering of `to_map` matters.
     var to_map = try std.ArrayList(Air.Inst.Index).initCapacity(sema.arena, stores.len);
+
     for (stores) |store_inst_idx| {
         const store_inst = sema.air_instructions.get(@intFromEnum(store_inst_idx));
         const ptr_to_map = switch (store_inst.tag) {
             .store, .store_safe => store_inst.data.bin_op.lhs.toIndex().?, // Map the pointer being stored to.
-            .set_union_tag => continue, // Ignore for now; handled after we map pointers
+            .set_union_tag => store_inst.data.bin_op.lhs.toIndex().?, // Map the union pointer.
             .optional_payload_ptr_set, .errunion_payload_ptr_set => store_inst_idx, // Map the generated pointer itself.
             else => unreachable,
         };
@@ -4053,13 +4054,12 @@ fn resolveComptimeKnownAllocPtr(sema: *Sema, block: *Block, alloc: Air.Inst.Ref,
                 const maybe_union_ty = Value.fromInterned(decl_parent_ptr).typeOf(zcu).childType(zcu);
                 if (zcu.typeToUnion(maybe_union_ty)) |union_obj| {
                     // As this is a union field, we must store to the pointer now to set the tag.
-                    // If the payload is OPV, there will not be a payload store, so we store that value.
-                    // Otherwise, there will be a payload store to process later, so undef will suffice.
+                    // The payload value will be stored later, so undef is a sufficient payload for now.
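+                    // For example, for `union { none: void, some: u8 }` this stores
+                    // `.{ .none = undefined }` through the pointer, which is enough to
+                    // make the active tag comptime-known.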
                     const payload_ty: Type = .fromInterned(union_obj.field_types.get(&zcu.intern_pool)[idx]);
-                    const payload_val = try sema.typeHasOnePossibleValue(payload_ty) orelse try pt.undefValue(payload_ty);
+                    const payload_val = try pt.undefValue(payload_ty);
                     const tag_val = try pt.enumValueFieldIndex(.fromInterned(union_obj.enum_tag_ty), idx);
                     const store_val = try pt.unionValue(maybe_union_ty, tag_val, payload_val);
-                    try sema.storePtrVal(block, LazySrcLoc.unneeded, Value.fromInterned(decl_parent_ptr), store_val, maybe_union_ty);
+                    try sema.storePtrVal(block, .unneeded, .fromInterned(decl_parent_ptr), store_val, maybe_union_ty);
                 }
                 break :ptr (try Value.fromInterned(decl_parent_ptr).ptrField(idx, pt)).toIntern();
             },
diff --git a/test/behavior/union.zig b/test/behavior/union.zig
index 186ae56593..98da2dc185 100644
--- a/test/behavior/union.zig
+++ b/test/behavior/union.zig
@@ -2311,3 +2311,11 @@ test "set mutable union by switching on same union" {
     try expect(val == .bar);
     try expect(val.bar == 2);
 }
+
+test "initialize empty field of union inside comptime-known struct constant" {
+    const Inner = union { none: void, some: u8 };
+    const Wrapper = struct { inner: Inner };
+
+    const val: Wrapper = .{ .inner = .{ .none = {} } };
+    comptime assert(val.inner.none == {});
+}

From e17a050bc695f7d117b89adb1d258813593ca111 Mon Sep 17 00:00:00 2001
From: mlugg
Date: Tue, 5 Aug 2025 22:48:50 +0100
Subject: [PATCH 50/70] link: prevent deadlock when prelink tasks fail

If an error occurred which prevented a prelink task from being queued,
then `pending_prelink_tasks` would never be decremented, which could
cause deadlocks in some cases.

So, instead of calculating ahead of time the number of prelink tasks to
expect, we use a simpler strategy which is much like a wait group: we
add 1 to a value when we spawn a worker, and in the worker function,
`defer` decrementing the value. The initial value is 1, and there's a
decrement after all of the workers are spawned, so once it hits 0,
prelink is done (be it with a failure or a success).
---
 src/Compilation.zig  | 81 ++++++++++++++++++++++++--------------------
 src/libs/freebsd.zig |  4 ---
 src/libs/glibc.zig   | 12 -------
 src/libs/netbsd.zig  |  4 ---
 src/link/Queue.zig   | 58 ++++++++++++++++++++++++-------
 5 files changed, 90 insertions(+), 69 deletions(-)

diff --git a/src/Compilation.zig b/src/Compilation.zig
index ae3ab118b7..8382da2804 100644
--- a/src/Compilation.zig
+++ b/src/Compilation.zig
@@ -2384,7 +2384,6 @@ pub fn create(gpa: Allocator, arena: Allocator, options: CreateOptions) !*Compil
         };
         comp.c_object_table.putAssumeCapacityNoClobber(c_object, {});
     }
-    comp.link_task_queue.pending_prelink_tasks += @intCast(comp.c_object_table.count());
 
     // Add a `Win32Resource` for each `rc_source_files` and one for `manifest_file`.
     const win32_resource_count =
@@ -2392,10 +2391,6 @@ pub fn create(gpa: Allocator, arena: Allocator, options: CreateOptions) !*Compil
     if (win32_resource_count > 0) {
         dev.check(.win32_resource);
         try comp.win32_resource_table.ensureTotalCapacity(gpa, win32_resource_count);
-        // Add this after adding logic to updateWin32Resource to pass the
-        // result into link.loadInput. loadInput integration is not implemented
-        // for Windows linking logic yet.
- //comp.link_task_queue.pending_prelink_tasks += @intCast(win32_resource_count); for (options.rc_source_files) |rc_source_file| { const win32_resource = try gpa.create(Win32Resource); errdefer gpa.destroy(win32_resource); @@ -2454,58 +2449,47 @@ pub fn create(gpa: Allocator, arena: Allocator, options: CreateOptions) !*Compil if (musl.needsCrt0(comp.config.output_mode, comp.config.link_mode, comp.config.pie)) |f| { comp.queued_jobs.musl_crt_file[@intFromEnum(f)] = true; - comp.link_task_queue.pending_prelink_tasks += 1; } switch (comp.config.link_mode) { .static => comp.queued_jobs.musl_crt_file[@intFromEnum(musl.CrtFile.libc_a)] = true, .dynamic => comp.queued_jobs.musl_crt_file[@intFromEnum(musl.CrtFile.libc_so)] = true, } - comp.link_task_queue.pending_prelink_tasks += 1; } else if (target.isGnuLibC()) { if (!std.zig.target.canBuildLibC(target)) return error.LibCUnavailable; if (glibc.needsCrt0(comp.config.output_mode)) |f| { comp.queued_jobs.glibc_crt_file[@intFromEnum(f)] = true; - comp.link_task_queue.pending_prelink_tasks += 1; } comp.queued_jobs.glibc_shared_objects = true; - comp.link_task_queue.pending_prelink_tasks += glibc.sharedObjectsCount(target); comp.queued_jobs.glibc_crt_file[@intFromEnum(glibc.CrtFile.libc_nonshared_a)] = true; - comp.link_task_queue.pending_prelink_tasks += 1; } else if (target.isFreeBSDLibC()) { if (!std.zig.target.canBuildLibC(target)) return error.LibCUnavailable; if (freebsd.needsCrt0(comp.config.output_mode)) |f| { comp.queued_jobs.freebsd_crt_file[@intFromEnum(f)] = true; - comp.link_task_queue.pending_prelink_tasks += 1; } comp.queued_jobs.freebsd_shared_objects = true; - comp.link_task_queue.pending_prelink_tasks += freebsd.sharedObjectsCount(); } else if (target.isNetBSDLibC()) { if (!std.zig.target.canBuildLibC(target)) return error.LibCUnavailable; if (netbsd.needsCrt0(comp.config.output_mode)) |f| { comp.queued_jobs.netbsd_crt_file[@intFromEnum(f)] = true; - comp.link_task_queue.pending_prelink_tasks += 1; } comp.queued_jobs.netbsd_shared_objects = true; - comp.link_task_queue.pending_prelink_tasks += netbsd.sharedObjectsCount(); } else if (target.isWasiLibC()) { if (!std.zig.target.canBuildLibC(target)) return error.LibCUnavailable; comp.queued_jobs.wasi_libc_crt_file[@intFromEnum(wasi_libc.execModelCrtFile(comp.config.wasi_exec_model))] = true; comp.queued_jobs.wasi_libc_crt_file[@intFromEnum(wasi_libc.CrtFile.libc_a)] = true; - comp.link_task_queue.pending_prelink_tasks += 2; } else if (target.isMinGW()) { if (!std.zig.target.canBuildLibC(target)) return error.LibCUnavailable; const main_crt_file: mingw.CrtFile = if (is_dyn_lib) .dllcrt2_o else .crt2_o; comp.queued_jobs.mingw_crt_file[@intFromEnum(main_crt_file)] = true; comp.queued_jobs.mingw_crt_file[@intFromEnum(mingw.CrtFile.libmingw32_lib)] = true; - comp.link_task_queue.pending_prelink_tasks += 2; // When linking mingw-w64 there are some import libs we always need. 
try comp.windows_libs.ensureUnusedCapacity(gpa, mingw.always_link_libs.len); @@ -2519,7 +2503,6 @@ pub fn create(gpa: Allocator, arena: Allocator, options: CreateOptions) !*Compil target.isMinGW()) { comp.queued_jobs.zigc_lib = true; - comp.link_task_queue.pending_prelink_tasks += 1; } } @@ -2536,50 +2519,41 @@ pub fn create(gpa: Allocator, arena: Allocator, options: CreateOptions) !*Compil } if (comp.wantBuildLibUnwindFromSource()) { comp.queued_jobs.libunwind = true; - comp.link_task_queue.pending_prelink_tasks += 1; } if (build_options.have_llvm and is_exe_or_dyn_lib and comp.config.link_libcpp) { comp.queued_jobs.libcxx = true; comp.queued_jobs.libcxxabi = true; - comp.link_task_queue.pending_prelink_tasks += 2; } if (build_options.have_llvm and is_exe_or_dyn_lib and comp.config.any_sanitize_thread) { comp.queued_jobs.libtsan = true; - comp.link_task_queue.pending_prelink_tasks += 1; } if (can_build_compiler_rt) { if (comp.compiler_rt_strat == .lib) { log.debug("queuing a job to build compiler_rt_lib", .{}); comp.queued_jobs.compiler_rt_lib = true; - comp.link_task_queue.pending_prelink_tasks += 1; } else if (comp.compiler_rt_strat == .obj) { log.debug("queuing a job to build compiler_rt_obj", .{}); // In this case we are making a static library, so we ask // for a compiler-rt object to put in it. comp.queued_jobs.compiler_rt_obj = true; - comp.link_task_queue.pending_prelink_tasks += 1; } else if (comp.compiler_rt_strat == .dyn_lib) { // hack for stage2_x86_64 + coff log.debug("queuing a job to build compiler_rt_dyn_lib", .{}); comp.queued_jobs.compiler_rt_dyn_lib = true; - comp.link_task_queue.pending_prelink_tasks += 1; } if (comp.ubsan_rt_strat == .lib) { log.debug("queuing a job to build ubsan_rt_lib", .{}); comp.queued_jobs.ubsan_rt_lib = true; - comp.link_task_queue.pending_prelink_tasks += 1; } else if (comp.ubsan_rt_strat == .obj) { log.debug("queuing a job to build ubsan_rt_obj", .{}); comp.queued_jobs.ubsan_rt_obj = true; - comp.link_task_queue.pending_prelink_tasks += 1; } if (is_exe_or_dyn_lib and comp.config.any_fuzz) { log.debug("queuing a job to build libfuzzer", .{}); comp.queued_jobs.fuzzer_lib = true; - comp.link_task_queue.pending_prelink_tasks += 1; } } } @@ -2587,8 +2561,6 @@ pub fn create(gpa: Allocator, arena: Allocator, options: CreateOptions) !*Compil try comp.link_task_queue.queued_prelink.append(gpa, .load_explicitly_provided); } log.debug("queued prelink tasks: {d}", .{comp.link_task_queue.queued_prelink.items.len}); - log.debug("pending prelink tasks: {d}", .{comp.link_task_queue.pending_prelink_tasks}); - return comp; } @@ -4408,10 +4380,8 @@ fn performAllTheWork( comp.link_task_wait_group.reset(); defer comp.link_task_wait_group.wait(); - comp.link_prog_node.increaseEstimatedTotalItems( - comp.link_task_queue.queued_prelink.items.len + // already queued prelink tasks - comp.link_task_queue.pending_prelink_tasks, // prelink tasks which will be queued - ); + // Already-queued prelink tasks + comp.link_prog_node.increaseEstimatedTotalItems(comp.link_task_queue.queued_prelink.items.len); comp.link_task_queue.start(comp); if (comp.emit_docs != null) { @@ -4427,6 +4397,7 @@ fn performAllTheWork( // compiler-rt due to LLD bugs as well, e.g.: // // https://github.com/llvm/llvm-project/issues/43698#issuecomment-2542660611 + comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildRt, .{ comp, "compiler_rt.zig", @@ -4444,6 +4415,7 @@ fn performAllTheWork( } if (comp.queued_jobs.compiler_rt_obj and comp.compiler_rt_obj == null) { + 
comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildRt, .{ comp, "compiler_rt.zig", @@ -4462,6 +4434,7 @@ fn performAllTheWork( // hack for stage2_x86_64 + coff if (comp.queued_jobs.compiler_rt_dyn_lib and comp.compiler_rt_dyn_lib == null) { + comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildRt, .{ comp, "compiler_rt.zig", @@ -4479,6 +4452,7 @@ fn performAllTheWork( } if (comp.queued_jobs.fuzzer_lib and comp.fuzzer_lib == null) { + comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildRt, .{ comp, "fuzzer.zig", @@ -4493,6 +4467,7 @@ fn performAllTheWork( } if (comp.queued_jobs.ubsan_rt_lib and comp.ubsan_rt_lib == null) { + comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildRt, .{ comp, "ubsan_rt.zig", @@ -4509,6 +4484,7 @@ fn performAllTheWork( } if (comp.queued_jobs.ubsan_rt_obj and comp.ubsan_rt_obj == null) { + comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildRt, .{ comp, "ubsan_rt.zig", @@ -4525,40 +4501,49 @@ fn performAllTheWork( } if (comp.queued_jobs.glibc_shared_objects) { + comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildGlibcSharedObjects, .{ comp, main_progress_node }); } if (comp.queued_jobs.freebsd_shared_objects) { + comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildFreeBSDSharedObjects, .{ comp, main_progress_node }); } if (comp.queued_jobs.netbsd_shared_objects) { + comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildNetBSDSharedObjects, .{ comp, main_progress_node }); } if (comp.queued_jobs.libunwind) { + comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildLibUnwind, .{ comp, main_progress_node }); } if (comp.queued_jobs.libcxx) { + comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildLibCxx, .{ comp, main_progress_node }); } if (comp.queued_jobs.libcxxabi) { + comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildLibCxxAbi, .{ comp, main_progress_node }); } if (comp.queued_jobs.libtsan) { + comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildLibTsan, .{ comp, main_progress_node }); } if (comp.queued_jobs.zigc_lib and comp.zigc_static_lib == null) { + comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildLibZigC, .{ comp, main_progress_node }); } for (0..@typeInfo(musl.CrtFile).@"enum".fields.len) |i| { if (comp.queued_jobs.musl_crt_file[i]) { const tag: musl.CrtFile = @enumFromInt(i); + comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildMuslCrtFile, .{ comp, tag, main_progress_node }); } } @@ -4566,6 +4551,7 @@ fn performAllTheWork( for (0..@typeInfo(glibc.CrtFile).@"enum".fields.len) |i| { if (comp.queued_jobs.glibc_crt_file[i]) { const tag: glibc.CrtFile = @enumFromInt(i); + comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildGlibcCrtFile, .{ comp, tag, main_progress_node }); } } @@ -4573,6 +4559,7 @@ fn performAllTheWork( for (0..@typeInfo(freebsd.CrtFile).@"enum".fields.len) |i| { if (comp.queued_jobs.freebsd_crt_file[i]) { const tag: freebsd.CrtFile = @enumFromInt(i); + comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildFreeBSDCrtFile, .{ comp, tag, main_progress_node }); } } @@ -4580,6 +4567,7 @@ fn performAllTheWork( for 
(0..@typeInfo(netbsd.CrtFile).@"enum".fields.len) |i| { if (comp.queued_jobs.netbsd_crt_file[i]) { const tag: netbsd.CrtFile = @enumFromInt(i); + comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildNetBSDCrtFile, .{ comp, tag, main_progress_node }); } } @@ -4587,6 +4575,7 @@ fn performAllTheWork( for (0..@typeInfo(wasi_libc.CrtFile).@"enum".fields.len) |i| { if (comp.queued_jobs.wasi_libc_crt_file[i]) { const tag: wasi_libc.CrtFile = @enumFromInt(i); + comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildWasiLibcCrtFile, .{ comp, tag, main_progress_node }); } } @@ -4594,6 +4583,7 @@ fn performAllTheWork( for (0..@typeInfo(mingw.CrtFile).@"enum".fields.len) |i| { if (comp.queued_jobs.mingw_crt_file[i]) { const tag: mingw.CrtFile = @enumFromInt(i); + comp.link_task_queue.startPrelinkItem(); comp.link_task_wait_group.spawnManager(buildMingwCrtFile, .{ comp, tag, main_progress_node }); } } @@ -4665,12 +4655,14 @@ fn performAllTheWork( } while (comp.c_object_work_queue.readItem()) |c_object| { + comp.link_task_queue.startPrelinkItem(); comp.thread_pool.spawnWg(&comp.link_task_wait_group, workerUpdateCObject, .{ comp, c_object, main_progress_node, }); } while (comp.win32_resource_work_queue.readItem()) |win32_resource| { + comp.link_task_queue.startPrelinkItem(); comp.thread_pool.spawnWg(&comp.link_task_wait_group, workerUpdateWin32Resource, .{ comp, win32_resource, main_progress_node, }); @@ -4773,15 +4765,14 @@ fn performAllTheWork( } }; + // We aren't going to queue any more prelink tasks. + comp.link_task_queue.finishPrelinkItem(comp); + if (!comp.separateCodegenThreadOk()) { // Waits until all input files have been parsed. comp.link_task_wait_group.wait(); comp.link_task_wait_group.reset(); std.log.scoped(.link).debug("finished waiting for link_task_wait_group", .{}); - if (comp.link_task_queue.pending_prelink_tasks > 0) { - // Indicates an error occurred preventing prelink phase from completing. 
- return; - } } if (comp.zcu != null) { @@ -5568,6 +5559,7 @@ fn workerUpdateCObject( c_object: *CObject, progress_node: std.Progress.Node, ) void { + defer comp.link_task_queue.finishPrelinkItem(comp); comp.updateCObject(c_object, progress_node) catch |err| switch (err) { error.AnalysisFail => return, else => { @@ -5585,6 +5577,7 @@ fn workerUpdateWin32Resource( win32_resource: *Win32Resource, progress_node: std.Progress.Node, ) void { + defer comp.link_task_queue.finishPrelinkItem(comp); comp.updateWin32Resource(win32_resource, progress_node) catch |err| switch (err) { error.AnalysisFail => return, else => { @@ -5628,6 +5621,7 @@ fn buildRt( options: RtOptions, out: *?CrtFile, ) void { + defer comp.link_task_queue.finishPrelinkItem(comp); comp.buildOutputFromZig( root_source_name, root_name, @@ -5646,6 +5640,7 @@ fn buildRt( } fn buildMuslCrtFile(comp: *Compilation, crt_file: musl.CrtFile, prog_node: std.Progress.Node) void { + defer comp.link_task_queue.finishPrelinkItem(comp); if (musl.buildCrtFile(comp, crt_file, prog_node)) |_| { comp.queued_jobs.musl_crt_file[@intFromEnum(crt_file)] = false; } else |err| switch (err) { @@ -5657,6 +5652,7 @@ fn buildMuslCrtFile(comp: *Compilation, crt_file: musl.CrtFile, prog_node: std.P } fn buildGlibcCrtFile(comp: *Compilation, crt_file: glibc.CrtFile, prog_node: std.Progress.Node) void { + defer comp.link_task_queue.finishPrelinkItem(comp); if (glibc.buildCrtFile(comp, crt_file, prog_node)) |_| { comp.queued_jobs.glibc_crt_file[@intFromEnum(crt_file)] = false; } else |err| switch (err) { @@ -5668,6 +5664,7 @@ fn buildGlibcCrtFile(comp: *Compilation, crt_file: glibc.CrtFile, prog_node: std } fn buildGlibcSharedObjects(comp: *Compilation, prog_node: std.Progress.Node) void { + defer comp.link_task_queue.finishPrelinkItem(comp); if (glibc.buildSharedObjects(comp, prog_node)) |_| { // The job should no longer be queued up since it succeeded. comp.queued_jobs.glibc_shared_objects = false; @@ -5680,6 +5677,7 @@ fn buildGlibcSharedObjects(comp: *Compilation, prog_node: std.Progress.Node) voi } fn buildFreeBSDCrtFile(comp: *Compilation, crt_file: freebsd.CrtFile, prog_node: std.Progress.Node) void { + defer comp.link_task_queue.finishPrelinkItem(comp); if (freebsd.buildCrtFile(comp, crt_file, prog_node)) |_| { comp.queued_jobs.freebsd_crt_file[@intFromEnum(crt_file)] = false; } else |err| switch (err) { @@ -5691,6 +5689,7 @@ fn buildFreeBSDCrtFile(comp: *Compilation, crt_file: freebsd.CrtFile, prog_node: } fn buildFreeBSDSharedObjects(comp: *Compilation, prog_node: std.Progress.Node) void { + defer comp.link_task_queue.finishPrelinkItem(comp); if (freebsd.buildSharedObjects(comp, prog_node)) |_| { // The job should no longer be queued up since it succeeded. 
comp.queued_jobs.freebsd_shared_objects = false; @@ -5703,6 +5702,7 @@ fn buildFreeBSDSharedObjects(comp: *Compilation, prog_node: std.Progress.Node) v } fn buildNetBSDCrtFile(comp: *Compilation, crt_file: netbsd.CrtFile, prog_node: std.Progress.Node) void { + defer comp.link_task_queue.finishPrelinkItem(comp); if (netbsd.buildCrtFile(comp, crt_file, prog_node)) |_| { comp.queued_jobs.netbsd_crt_file[@intFromEnum(crt_file)] = false; } else |err| switch (err) { @@ -5714,6 +5714,7 @@ fn buildNetBSDCrtFile(comp: *Compilation, crt_file: netbsd.CrtFile, prog_node: s } fn buildNetBSDSharedObjects(comp: *Compilation, prog_node: std.Progress.Node) void { + defer comp.link_task_queue.finishPrelinkItem(comp); if (netbsd.buildSharedObjects(comp, prog_node)) |_| { // The job should no longer be queued up since it succeeded. comp.queued_jobs.netbsd_shared_objects = false; @@ -5726,6 +5727,7 @@ fn buildNetBSDSharedObjects(comp: *Compilation, prog_node: std.Progress.Node) vo } fn buildMingwCrtFile(comp: *Compilation, crt_file: mingw.CrtFile, prog_node: std.Progress.Node) void { + defer comp.link_task_queue.finishPrelinkItem(comp); if (mingw.buildCrtFile(comp, crt_file, prog_node)) |_| { comp.queued_jobs.mingw_crt_file[@intFromEnum(crt_file)] = false; } else |err| switch (err) { @@ -5737,6 +5739,7 @@ fn buildMingwCrtFile(comp: *Compilation, crt_file: mingw.CrtFile, prog_node: std } fn buildWasiLibcCrtFile(comp: *Compilation, crt_file: wasi_libc.CrtFile, prog_node: std.Progress.Node) void { + defer comp.link_task_queue.finishPrelinkItem(comp); if (wasi_libc.buildCrtFile(comp, crt_file, prog_node)) |_| { comp.queued_jobs.wasi_libc_crt_file[@intFromEnum(crt_file)] = false; } else |err| switch (err) { @@ -5748,6 +5751,7 @@ fn buildWasiLibcCrtFile(comp: *Compilation, crt_file: wasi_libc.CrtFile, prog_no } fn buildLibUnwind(comp: *Compilation, prog_node: std.Progress.Node) void { + defer comp.link_task_queue.finishPrelinkItem(comp); if (libunwind.buildStaticLib(comp, prog_node)) |_| { comp.queued_jobs.libunwind = false; } else |err| switch (err) { @@ -5757,6 +5761,7 @@ fn buildLibUnwind(comp: *Compilation, prog_node: std.Progress.Node) void { } fn buildLibCxx(comp: *Compilation, prog_node: std.Progress.Node) void { + defer comp.link_task_queue.finishPrelinkItem(comp); if (libcxx.buildLibCxx(comp, prog_node)) |_| { comp.queued_jobs.libcxx = false; } else |err| switch (err) { @@ -5766,6 +5771,7 @@ fn buildLibCxx(comp: *Compilation, prog_node: std.Progress.Node) void { } fn buildLibCxxAbi(comp: *Compilation, prog_node: std.Progress.Node) void { + defer comp.link_task_queue.finishPrelinkItem(comp); if (libcxx.buildLibCxxAbi(comp, prog_node)) |_| { comp.queued_jobs.libcxxabi = false; } else |err| switch (err) { @@ -5775,6 +5781,7 @@ fn buildLibCxxAbi(comp: *Compilation, prog_node: std.Progress.Node) void { } fn buildLibTsan(comp: *Compilation, prog_node: std.Progress.Node) void { + defer comp.link_task_queue.finishPrelinkItem(comp); if (libtsan.buildTsan(comp, prog_node)) |_| { comp.queued_jobs.libtsan = false; } else |err| switch (err) { @@ -5784,6 +5791,7 @@ fn buildLibTsan(comp: *Compilation, prog_node: std.Progress.Node) void { } fn buildLibZigC(comp: *Compilation, prog_node: std.Progress.Node) void { + defer comp.link_task_queue.finishPrelinkItem(comp); comp.buildOutputFromZig( "c.zig", "zigc", @@ -7721,6 +7729,7 @@ pub fn queuePrelinkTaskMode(comp: *Compilation, path: Cache.Path, config: *const /// Only valid to call during `update`. 
Automatically handles queuing up a /// linker worker task if there is not already one. pub fn queuePrelinkTasks(comp: *Compilation, tasks: []const link.PrelinkTask) void { + comp.link_prog_node.increaseEstimatedTotalItems(tasks.len); comp.link_task_queue.enqueuePrelink(comp, tasks) catch |err| switch (err) { error.OutOfMemory => return comp.setAllocFailure(), }; diff --git a/src/libs/freebsd.zig b/src/libs/freebsd.zig index 6baa899087..0e8e8f850b 100644 --- a/src/libs/freebsd.zig +++ b/src/libs/freebsd.zig @@ -977,10 +977,6 @@ pub fn buildSharedObjects(comp: *Compilation, prog_node: std.Progress.Node) anye }); } -pub fn sharedObjectsCount() u8 { - return libs.len; -} - fn queueSharedObjects(comp: *Compilation, so_files: BuiltSharedObjects) void { assert(comp.freebsd_so_files == null); comp.freebsd_so_files = so_files; diff --git a/src/libs/glibc.zig b/src/libs/glibc.zig index 43baaf38d4..5518dc9321 100644 --- a/src/libs/glibc.zig +++ b/src/libs/glibc.zig @@ -1130,18 +1130,6 @@ pub fn buildSharedObjects(comp: *Compilation, prog_node: std.Progress.Node) anye }); } -pub fn sharedObjectsCount(target: *const std.Target) u8 { - const target_version = target.os.versionRange().gnuLibCVersion() orelse return 0; - var count: u8 = 0; - for (libs) |lib| { - if (lib.removed_in) |rem_in| { - if (target_version.order(rem_in) != .lt) continue; - } - count += 1; - } - return count; -} - fn queueSharedObjects(comp: *Compilation, so_files: BuiltSharedObjects) void { const target_version = comp.getTarget().os.versionRange().gnuLibCVersion().?; diff --git a/src/libs/netbsd.zig b/src/libs/netbsd.zig index 094165b9c5..945c6853a0 100644 --- a/src/libs/netbsd.zig +++ b/src/libs/netbsd.zig @@ -642,10 +642,6 @@ pub fn buildSharedObjects(comp: *Compilation, prog_node: std.Progress.Node) anye }); } -pub fn sharedObjectsCount() u8 { - return libs.len; -} - fn queueSharedObjects(comp: *Compilation, so_files: BuiltSharedObjects) void { assert(comp.netbsd_so_files == null); comp.netbsd_so_files = so_files; diff --git a/src/link/Queue.zig b/src/link/Queue.zig index d1595636ac..9f4535e1fe 100644 --- a/src/link/Queue.zig +++ b/src/link/Queue.zig @@ -16,9 +16,9 @@ mutex: std.Thread.Mutex, /// Validates that only one `flushTaskQueue` thread is running at a time. flush_safety: std.debug.SafetyLock, -/// This is the number of prelink tasks which are expected but have not yet been enqueued. -/// Guarded by `mutex`. -pending_prelink_tasks: u32, +/// This value is positive while there are still prelink tasks yet to be queued. Once they are +/// all queued, this value becomes 0, and ZCU tasks can be run. Guarded by `mutex`. +prelink_wait_count: u32, /// Prelink tasks which have been enqueued and are not yet owned by the worker thread. /// Allocated into `gpa`, guarded by `mutex`. @@ -59,7 +59,7 @@ state: union(enum) { /// The link thread is currently running or queued to run. running, /// The link thread is not running or queued, because it has exhausted all immediately available - /// tasks. It should be spawned when more tasks are enqueued. If `pending_prelink_tasks` is not + /// tasks. It should be spawned when more tasks are enqueued. If `prelink_wait_count` is not /// zero, we are specifically waiting for prelink tasks. finished, /// The link thread is not running or queued, because it is waiting for this MIR to be populated. @@ -73,11 +73,11 @@ state: union(enum) { const max_air_bytes_in_flight = 10 * 1024 * 1024; /// The initial `Queue` state, containing no tasks, expecting no prelink tasks, and with no running worker thread. 
-/// The `pending_prelink_tasks` and `queued_prelink` fields may be modified as needed before calling `start`. +/// The `queued_prelink` field may be appended to before calling `start`. pub const empty: Queue = .{ .mutex = .{}, .flush_safety = .{}, - .pending_prelink_tasks = 0, + .prelink_wait_count = undefined, // set in `start` .queued_prelink = .empty, .wip_prelink = .empty, .queued_zcu = .empty, @@ -100,17 +100,49 @@ pub fn deinit(q: *Queue, comp: *Compilation) void { } /// This is expected to be called exactly once, after which the caller must not directly access -/// `queued_prelink` or `pending_prelink_tasks` any longer. This will spawn the link thread if -/// necessary. +/// `queued_prelink` any longer. This will spawn the link thread if necessary. pub fn start(q: *Queue, comp: *Compilation) void { assert(q.state == .finished); assert(q.queued_zcu.items.len == 0); + // Reset this to 1. We can't init it to 1 in `empty`, because it would fall to 0 on successive + // incremental updates, but we still need the initial 1. + q.prelink_wait_count = 1; if (q.queued_prelink.items.len != 0) { q.state = .running; comp.thread_pool.spawnWgId(&comp.link_task_wait_group, flushTaskQueue, .{ q, comp }); } } +/// Every call to this must be paired with a call to `finishPrelinkItem`. +pub fn startPrelinkItem(q: *Queue) void { + q.mutex.lock(); + defer q.mutex.unlock(); + assert(q.prelink_wait_count > 0); // must not have finished everything already + q.prelink_wait_count += 1; +} +/// This function must be called exactly one more time than `startPrelinkItem` is. The final call +/// indicates that we have finished calling `startPrelinkItem`, so once all pending items finish, +/// we are ready to move on to ZCU tasks. +pub fn finishPrelinkItem(q: *Queue, comp: *Compilation) void { + { + q.mutex.lock(); + defer q.mutex.unlock(); + q.prelink_wait_count -= 1; + if (q.prelink_wait_count != 0) return; + // The prelink task count dropped to 0; restart the linker thread if necessary. + switch (q.state) { + .wait_for_mir => unreachable, // we've not started zcu tasks yet + .running => return, + .finished => {}, + } + assert(q.queued_prelink.items.len == 0); + // Even if there are no ZCU tasks, we must restart the linker thread to make sure + // that `link.File.prelink()` is called. + q.state = .running; + } + comp.thread_pool.spawnWgId(&comp.link_task_wait_group, flushTaskQueue, .{ q, comp }); +} + /// Called by codegen workers after they have populated a `ZcuTask.LinkFunc.SharedMir`. If the link /// thread was waiting for this MIR, it can resume. pub fn mirReady(q: *Queue, comp: *Compilation, func_index: InternPool.Index, mir: *ZcuTask.LinkFunc.SharedMir) void { @@ -130,14 +162,14 @@ pub fn mirReady(q: *Queue, comp: *Compilation, func_index: InternPool.Index, mir comp.thread_pool.spawnWgId(&comp.link_task_wait_group, flushTaskQueue, .{ q, comp }); } -/// Enqueues all prelink tasks in `tasks`. Asserts that they were expected, i.e. that `tasks.len` is -/// less than or equal to `q.pending_prelink_tasks`. Also asserts that `tasks.len` is not 0. +/// Enqueues all prelink tasks in `tasks`. Asserts that they were expected, i.e. that +/// `prelink_wait_count` is not yet 0. Also asserts that `tasks.len` is not 0. 
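+///
+/// This is the consuming side of the wait-group-like counter: producers call
+/// `startPrelinkItem` before spawning work and `finishPrelinkItem` (typically
+/// via `defer`) when that work completes, so a positive `prelink_wait_count`
+/// here means more prelink tasks may still arrive.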
 pub fn enqueuePrelink(q: *Queue, comp: *Compilation, tasks: []const PrelinkTask) Allocator.Error!void {
     {
         q.mutex.lock();
         defer q.mutex.unlock();
+        assert(q.prelink_wait_count > 0);
         try q.queued_prelink.appendSlice(comp.gpa, tasks);
-        q.pending_prelink_tasks -= @intCast(tasks.len);
         switch (q.state) {
             .wait_for_mir => unreachable, // we've not started zcu tasks yet
             .running => return,
@@ -167,7 +199,7 @@ pub fn enqueueZcu(q: *Queue, comp: *Compilation, task: ZcuTask) Allocator.Error!
     try q.queued_zcu.append(comp.gpa, task);
     switch (q.state) {
         .running, .wait_for_mir => return,
-        .finished => if (q.pending_prelink_tasks != 0) return,
+        .finished => if (q.prelink_wait_count > 0) return,
     }
     // Restart the linker thread, unless it would immediately be blocked
     if (task == .link_func and task.link_func.mir.status.load(.acquire) == .pending) {
@@ -194,7 +226,7 @@ fn flushTaskQueue(tid: usize, q: *Queue, comp: *Compilation) void {
             defer q.mutex.unlock();
             std.mem.swap(std.ArrayListUnmanaged(PrelinkTask), &q.queued_prelink, &q.wip_prelink);
             if (q.wip_prelink.items.len == 0) {
-                if (q.pending_prelink_tasks == 0) {
+                if (q.prelink_wait_count == 0) {
                     break :prelink; // prelink is done
                 } else {
                     // We're expecting more prelink tasks so can't move on to ZCU tasks.

From 04fe1bfe3ceabd632183a85101cddeeec11f0745 Mon Sep 17 00:00:00 2001
From: Andrew Kelley
Date: Tue, 5 Aug 2025 11:26:00 -0700
Subject: [PATCH 51/70] std.Io.Reader: use readVec for fill functions

readVec has two updated responsibilities:

1. it must respect any data that is already buffered.
2. it must write to `Reader.buffer` if `data[0]` has zero length.
---
 lib/std/Io/Reader.zig | 166 ++++++++++++++++++++++++++----------------
 lib/std/fs/File.zig   |  22 +++++-
 2 files changed, 123 insertions(+), 65 deletions(-)

diff --git a/lib/std/Io/Reader.zig b/lib/std/Io/Reader.zig
index 05ab489286..78d3646cbe 100644
--- a/lib/std/Io/Reader.zig
+++ b/lib/std/Io/Reader.zig
@@ -70,13 +70,14 @@ pub const VTable = struct {
     /// Returns number of bytes written to `data`.
     ///
-    /// `data` may not have nonzero length.
+    /// `data` must have nonzero length. `data[0]` may have zero length, in
+    /// which case the implementation must write to `Reader.buffer`.
     ///
     /// `data` may not contain an alias to `Reader.buffer`.
     ///
-    /// `data` is mutable because the implementation may to temporarily modify
-    /// the fields in order to handle partial reads. Implementations must
-    /// restore the original value before returning.
+    /// `data` is mutable because the implementation may temporarily modify the
+    /// fields in order to handle partial reads. Implementations must restore
+    /// the original value before returning.
     ///
     /// Implementations may ignore `data`, writing directly to `Reader.buffer`,
     /// modifying `seek` and `end` accordingly, and returning 0 from this
@@ -421,23 +422,29 @@ pub fn readVec(r: *Reader, data: [][]u8) Error!usize {
 
 /// Writes to `Reader.buffer` or `data`, whichever has larger capacity.
pub fn defaultReadVec(r: *Reader, data: [][]u8) Error!usize { - assert(r.seek == r.end); - r.seek = 0; - r.end = 0; const first = data[0]; - const direct = first.len >= r.buffer.len; + if (r.seek == r.end and first.len >= r.buffer.len) { + var writer: Writer = .{ + .buffer = first, + .end = 0, + .vtable = &.{ .drain = Writer.fixedDrain }, + }; + const limit: Limit = .limited(writer.buffer.len - writer.end); + return r.vtable.stream(r, &writer, limit) catch |err| switch (err) { + error.WriteFailed => unreachable, + else => |e| return e, + }; + } var writer: Writer = .{ - .buffer = if (direct) first else r.buffer, - .end = 0, + .buffer = r.buffer, + .end = r.end, .vtable = &.{ .drain = Writer.fixedDrain }, }; const limit: Limit = .limited(writer.buffer.len - writer.end); - const n = r.vtable.stream(r, &writer, limit) catch |err| switch (err) { + r.end += r.vtable.stream(r, &writer, limit) catch |err| switch (err) { error.WriteFailed => unreachable, else => |e| return e, }; - if (direct) return n; - r.end += n; return 0; } @@ -1059,17 +1066,8 @@ pub fn fill(r: *Reader, n: usize) Error!void { /// increasing by a factor of 5 or more. fn fillUnbuffered(r: *Reader, n: usize) Error!void { try rebase(r, n); - var writer: Writer = .{ - .buffer = r.buffer, - .vtable = &.{ .drain = Writer.fixedDrain }, - }; - while (r.end < r.seek + n) { - writer.end = r.end; - r.end += r.vtable.stream(r, &writer, .limited(r.buffer.len - r.end)) catch |err| switch (err) { - error.WriteFailed => unreachable, - error.ReadFailed, error.EndOfStream => |e| return e, - }; - } + var bufs: [1][]u8 = .{""}; + while (r.end < r.seek + n) _ = try r.vtable.readVec(r, &bufs); } /// Without advancing the seek position, does exactly one underlying read, filling the buffer as @@ -1079,15 +1077,8 @@ fn fillUnbuffered(r: *Reader, n: usize) Error!void { /// Asserts buffer capacity is at least 1. pub fn fillMore(r: *Reader) Error!void { try rebase(r, 1); - var writer: Writer = .{ - .buffer = r.buffer, - .end = r.end, - .vtable = &.{ .drain = Writer.fixedDrain }, - }; - r.end += r.vtable.stream(r, &writer, .limited(r.buffer.len - r.end)) catch |err| switch (err) { - error.WriteFailed => unreachable, - else => |e| return e, - }; + var bufs: [1][]u8 = .{""}; + _ = try r.vtable.readVec(r, &bufs); } /// Returns the next byte from the stream or returns `error.EndOfStream`. @@ -1796,18 +1787,26 @@ pub fn Hashed(comptime Hasher: type) type { fn readVec(r: *Reader, data: [][]u8) Error!usize { const this: *@This() = @alignCast(@fieldParentPtr("reader", r)); - const n = try this.in.readVec(data); + var vecs: [8][]u8 = undefined; // Arbitrarily chosen amount. 
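+            // When nothing is buffered, `vecs` receives the caller's
+            // nonzero-length slices followed by `r.buffer`; otherwise only the
+            // unfilled tail of `r.buffer` is used. `data_size` counts just the
+            // caller-owned bytes, so anything read beyond it landed in `r.buffer`.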
+ const dest_n, const data_size = try r.writableVector(&vecs, data); + const dest = vecs[0..dest_n]; + const n = try this.in.readVec(dest); var remaining: usize = n; - for (data) |slice| { + for (dest) |slice| { if (remaining < slice.len) { this.hasher.update(slice[0..remaining]); - return n; + remaining = 0; + break; } else { remaining -= slice.len; this.hasher.update(slice); } } assert(remaining == 0); + if (n > data_size) { + r.end += n - data_size; + return data_size; + } return n; } @@ -1824,17 +1823,24 @@ pub fn Hashed(comptime Hasher: type) type { pub fn writableVectorPosix(r: *Reader, buffer: []std.posix.iovec, data: []const []u8) Error!struct { usize, usize } { var i: usize = 0; var n: usize = 0; - for (data) |buf| { - if (buffer.len - i == 0) return .{ i, n }; + if (r.seek == r.end) { + for (data) |buf| { + if (buffer.len - i == 0) return .{ i, n }; + if (buf.len != 0) { + buffer[i] = .{ .base = buf.ptr, .len = buf.len }; + i += 1; + n += buf.len; + } + } + const buf = r.buffer; if (buf.len != 0) { + r.seek = 0; + r.end = 0; buffer[i] = .{ .base = buf.ptr, .len = buf.len }; i += 1; - n += buf.len; } - } - assert(r.seek == r.end); - const buf = r.buffer; - if (buf.len != 0) { + } else { + const buf = r.buffer[r.end..]; buffer[i] = .{ .base = buf.ptr, .len = buf.len }; i += 1; } @@ -1848,28 +1854,62 @@ pub fn writableVectorWsa( ) Error!struct { usize, usize } { var i: usize = 0; var n: usize = 0; - for (data) |buf| { - if (buffer.len - i == 0) return .{ i, n }; - if (buf.len == 0) continue; - if (std.math.cast(u32, buf.len)) |len| { - buffer[i] = .{ .buf = buf.ptr, .len = len }; - i += 1; - n += len; - continue; - } - buffer[i] = .{ .buf = buf.ptr, .len = std.math.maxInt(u32) }; - i += 1; - n += std.math.maxInt(u32); - return .{ i, n }; - } - assert(r.seek == r.end); - const buf = r.buffer; - if (buf.len != 0) { - if (std.math.cast(u32, buf.len)) |len| { - buffer[i] = .{ .buf = buf.ptr, .len = len }; - } else { + if (r.seek == r.end) { + for (data) |buf| { + if (buffer.len - i == 0) return .{ i, n }; + if (buf.len == 0) continue; + if (std.math.cast(u32, buf.len)) |len| { + buffer[i] = .{ .buf = buf.ptr, .len = len }; + i += 1; + n += len; + continue; + } buffer[i] = .{ .buf = buf.ptr, .len = std.math.maxInt(u32) }; + i += 1; + n += std.math.maxInt(u32); + return .{ i, n }; } + const buf = r.buffer; + if (buf.len != 0) { + r.seek = 0; + r.end = 0; + if (std.math.cast(u32, buf.len)) |len| { + buffer[i] = .{ .buf = buf.ptr, .len = len }; + } else { + buffer[i] = .{ .buf = buf.ptr, .len = std.math.maxInt(u32) }; + } + i += 1; + } + } else { + buffer[i] = .{ + .buf = r.buffer.ptr + r.end, + .len = @min(std.math.maxInt(u32), r.buffer.len - r.end), + }; + i += 1; + } + return .{ i, n }; +} + +pub fn writableVector(r: *Reader, buffer: [][]u8, data: []const []u8) Error!struct { usize, usize } { + var i: usize = 0; + var n: usize = 0; + if (r.seek == r.end) { + for (data) |buf| { + if (buffer.len - i == 0) return .{ i, n }; + if (buf.len != 0) { + buffer[i] = buf; + i += 1; + n += buf.len; + } + } + if (r.buffer.len != 0) { + r.seek = 0; + r.end = 0; + buffer[i] = r.buffer; + i += 1; + } + } else { + buffer[i] = r.buffer[r.end..]; i += 1; } return .{ i, n }; diff --git a/lib/std/fs/File.zig b/lib/std/fs/File.zig index 9776854f18..2791642ac7 100644 --- a/lib/std/fs/File.zig +++ b/lib/std/fs/File.zig @@ -1312,7 +1312,16 @@ pub const Reader = struct { if (is_windows) { // Unfortunately, `ReadFileScatter` cannot be used since it // requires page alignment. 
- return readPositional(r, data[0]); + assert(io_reader.seek == io_reader.end); + io_reader.seek = 0; + io_reader.end = 0; + const first = data[0]; + if (first.len >= io_reader.buffer.len) { + return readPositional(r, first); + } else { + io_reader.end += try readPositional(r, io_reader.buffer); + return 0; + } } var iovecs_buffer: [max_buffers_len]posix.iovec = undefined; const dest_n, const data_size = try io_reader.writableVectorPosix(&iovecs_buffer, data); @@ -1352,7 +1361,16 @@ pub const Reader = struct { if (is_windows) { // Unfortunately, `ReadFileScatter` cannot be used since it // requires page alignment. - return readStreaming(r, data[0]); + assert(io_reader.seek == io_reader.end); + io_reader.seek = 0; + io_reader.end = 0; + const first = data[0]; + if (first.len >= io_reader.buffer.len) { + return readStreaming(r, first); + } else { + io_reader.end += try readStreaming(r, io_reader.buffer); + return 0; + } } var iovecs_buffer: [max_buffers_len]posix.iovec = undefined; const dest_n, const data_size = try io_reader.writableVectorPosix(&iovecs_buffer, data); From 6e671d4c779dc2087b0687f9e5ed5cd7a3341ea9 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 31 Jul 2025 22:36:08 -0700 Subject: [PATCH 52/70] std.http: rework for new std.Io API --- lib/std/http.zig | 781 ++++++++++++- lib/std/http/ChunkParser.zig | 6 +- lib/std/http/Client.zig | 2052 +++++++++++++++++----------------- lib/std/http/Server.zig | 1140 +++++++------------ lib/std/http/WebSocket.zig | 246 ---- lib/std/http/protocol.zig | 464 -------- lib/std/http/test.zig | 594 +++++----- 7 files changed, 2449 insertions(+), 2834 deletions(-) delete mode 100644 lib/std/http/WebSocket.zig delete mode 100644 lib/std/http/protocol.zig diff --git a/lib/std/http.zig b/lib/std/http.zig index 5bf12a1876..6075a2fe6d 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -1,14 +1,14 @@ const builtin = @import("builtin"); const std = @import("std.zig"); const assert = std.debug.assert; +const Writer = std.Io.Writer; +const File = std.fs.File; pub const Client = @import("http/Client.zig"); pub const Server = @import("http/Server.zig"); -pub const protocol = @import("http/protocol.zig"); pub const HeadParser = @import("http/HeadParser.zig"); pub const ChunkParser = @import("http/ChunkParser.zig"); pub const HeaderIterator = @import("http/HeaderIterator.zig"); -pub const WebSocket = @import("http/WebSocket.zig"); pub const Version = enum { @"HTTP/1.0", @@ -42,7 +42,7 @@ pub const Method = enum(u64) { return x; } - pub fn format(self: Method, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn format(self: Method, w: *Writer) Writer.Error!void { const bytes: []const u8 = @ptrCast(&@intFromEnum(self)); const str = std.mem.sliceTo(bytes, 0); try w.writeAll(str); @@ -296,13 +296,24 @@ pub const TransferEncoding = enum { }; pub const ContentEncoding = enum { - identity, - compress, - @"x-compress", - deflate, - gzip, - @"x-gzip", zstd, + gzip, + deflate, + compress, + identity, + + pub fn fromString(s: []const u8) ?ContentEncoding { + const map = std.StaticStringMap(ContentEncoding).initComptime(.{ + .{ "zstd", .zstd }, + .{ "gzip", .gzip }, + .{ "x-gzip", .gzip }, + .{ "deflate", .deflate }, + .{ "compress", .compress }, + .{ "x-compress", .compress }, + .{ "identity", .identity }, + }); + return map.get(s); + } }; pub const Connection = enum { @@ -315,15 +326,755 @@ pub const Header = struct { value: []const u8, }; +pub const Reader = struct { + in: *std.Io.Reader, + /// This is preallocated memory that might be used by 
`bodyReader`. That + /// function might return a pointer to this field, or a different + /// `*std.Io.Reader`. Advisable to not access this field directly. + interface: std.Io.Reader, + /// Keeps track of whether the stream is ready to accept a new request, + /// making invalid API usage cause assertion failures rather than HTTP + /// protocol violations. + state: State, + /// HTTP trailer bytes. These are at the end of a transfer-encoding: + /// chunked message. This data is available only after calling one of the + /// "end" functions and points to data inside the buffer of `in`, and is + /// therefore invalidated on the next call to `receiveHead`, or any other + /// read from `in`. + trailers: []const u8 = &.{}, + body_err: ?BodyError = null, + /// Stolen from `in`. + head_buffer: []u8 = &.{}, + + pub const max_chunk_header_len = 22; + + pub const RemainingChunkLen = enum(u64) { + head = 0, + n = 1, + rn = 2, + _, + + pub fn init(integer: u64) RemainingChunkLen { + return @enumFromInt(integer); + } + + pub fn int(rcl: RemainingChunkLen) u64 { + return @intFromEnum(rcl); + } + }; + + pub const State = union(enum) { + /// The stream is available to be used for the first time, or reused. + ready, + received_head, + /// The stream goes until the connection is closed. + body_none, + body_remaining_content_length: u64, + body_remaining_chunk_len: RemainingChunkLen, + /// The stream would be eligible for another HTTP request, however the + /// client and server did not negotiate a persistent connection. + closing, + }; + + pub const BodyError = error{ + HttpChunkInvalid, + HttpChunkTruncated, + HttpHeadersOversize, + }; + + pub const HeadError = error{ + /// Too many bytes of HTTP headers. + /// + /// The HTTP specification suggests to respond with a 431 status code + /// before closing the connection. + HttpHeadersOversize, + /// Partial HTTP request was received but the connection was closed + /// before fully receiving the headers. + HttpRequestTruncated, + /// The client sent 0 bytes of headers before closing the stream. This + /// happens when a keep-alive connection is finally closed. + HttpConnectionClosing, + /// Transitive error occurred reading from `in`. + ReadFailed, + }; + + pub fn restituteHeadBuffer(reader: *Reader) void { + reader.in.restitute(reader.head_buffer.len); + reader.head_buffer.len = 0; + } + + /// Buffers the entire head into `head_buffer`, invalidating the previous + /// `head_buffer`, if any. + pub fn receiveHead(reader: *Reader) HeadError!void { + reader.trailers = &.{}; + const in = reader.in; + in.restitute(reader.head_buffer.len); + reader.head_buffer.len = 0; + in.rebase(); + var hp: HeadParser = .{}; + var head_end: usize = 0; + while (true) { + if (head_end >= in.buffer.len) return error.HttpHeadersOversize; + in.fillMore() catch |err| switch (err) { + error.EndOfStream => switch (head_end) { + 0 => return error.HttpConnectionClosing, + else => return error.HttpRequestTruncated, + }, + error.ReadFailed => return error.ReadFailed, + }; + head_end += hp.feed(in.buffered()[head_end..]); + if (hp.state == .finished) { + reader.head_buffer = in.steal(head_end); + reader.state = .received_head; + return; + } + } + } + + /// If compressed body has been negotiated this will return compressed bytes. + /// + /// Asserts only called once and after `receiveHead`. 
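+    ///
+    /// A typical call site, as a sketch (`head` here stands for the parsed
+    /// response or request head; it is not a field of this struct):
+    ///
+    ///     const body = reader.bodyReader(&buffer, head.transfer_encoding, head.content_length);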
+    ///
+    /// See also:
+    /// * `bodyReaderDecompressing`
+    pub fn bodyReader(
+        reader: *Reader,
+        buffer: []u8,
+        transfer_encoding: TransferEncoding,
+        content_length: ?u64,
+    ) *std.Io.Reader {
+        assert(reader.state == .received_head);
+        switch (transfer_encoding) {
+            .chunked => {
+                reader.state = .{ .body_remaining_chunk_len = .head };
+                reader.interface = .{
+                    .buffer = buffer,
+                    .seek = 0,
+                    .end = 0,
+                    .vtable = &.{
+                        .stream = chunkedStream,
+                        .discard = chunkedDiscard,
+                    },
+                };
+                return &reader.interface;
+            },
+            .none => {
+                if (content_length) |len| {
+                    reader.state = .{ .body_remaining_content_length = len };
+                    reader.interface = .{
+                        .buffer = buffer,
+                        .seek = 0,
+                        .end = 0,
+                        .vtable = &.{
+                            .stream = contentLengthStream,
+                            .discard = contentLengthDiscard,
+                        },
+                    };
+                    return &reader.interface;
+                } else {
+                    reader.state = .body_none;
+                    return reader.in;
+                }
+            },
+        }
+    }
+
+    /// If compressed body has been negotiated this will return decompressed bytes.
+    ///
+    /// Asserts only called once and after `receiveHead`.
+    ///
+    /// See also:
+    /// * `bodyReader`
+    pub fn bodyReaderDecompressing(
+        reader: *Reader,
+        transfer_encoding: TransferEncoding,
+        content_length: ?u64,
+        content_encoding: ContentEncoding,
+        decompressor: *Decompressor,
+        decompression_buffer: []u8,
+    ) *std.Io.Reader {
+        if (transfer_encoding == .none and content_length == null) {
+            assert(reader.state == .received_head);
+            reader.state = .body_none;
+            switch (content_encoding) {
+                .identity => {
+                    return reader.in;
+                },
+                .deflate => {
+                    decompressor.* = .{ .flate = .init(reader.in, .raw, decompression_buffer) };
+                    return &decompressor.flate.reader;
+                },
+                .gzip => {
+                    decompressor.* = .{ .flate = .init(reader.in, .gzip, decompression_buffer) };
+                    return &decompressor.flate.reader;
+                },
+                .zstd => {
+                    decompressor.* = .{ .zstd = .init(reader.in, decompression_buffer, .{ .verify_checksum = false }) };
+                    return &decompressor.zstd.reader;
+                },
+                .compress => unreachable,
+            }
+        }
+        const transfer_reader = bodyReader(reader, &.{}, transfer_encoding, content_length);
+        return decompressor.init(transfer_reader, decompression_buffer, content_encoding);
+    }
+
+    fn contentLengthStream(
+        io_r: *std.Io.Reader,
+        w: *Writer,
+        limit: std.Io.Limit,
+    ) std.Io.Reader.StreamError!usize {
+        const reader: *Reader = @fieldParentPtr("interface", io_r);
+        const remaining_content_length = &reader.state.body_remaining_content_length;
+        const remaining = remaining_content_length.*;
+        if (remaining == 0) {
+            reader.state = .ready;
+            return error.EndOfStream;
+        }
+        const n = try reader.in.stream(w, limit.min(.limited(remaining)));
+        remaining_content_length.* = remaining - n;
+        return n;
+    }
+
+    fn contentLengthDiscard(io_r: *std.Io.Reader, limit: std.Io.Limit) std.Io.Reader.Error!usize {
+        const reader: *Reader = @fieldParentPtr("interface", io_r);
+        const remaining_content_length = &reader.state.body_remaining_content_length;
+        const remaining = remaining_content_length.*;
+        if (remaining == 0) {
+            reader.state = .ready;
+            return error.EndOfStream;
+        }
+        const n = try reader.in.discard(limit.min(.limited(remaining)));
+        remaining_content_length.* = remaining - n;
+        return n;
+    }
+
+    fn chunkedStream(io_r: *std.Io.Reader, w: *Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize {
+        const reader: *Reader = @fieldParentPtr("interface", io_r);
+        const chunk_len_ptr = switch (reader.state) {
+            .ready => return error.EndOfStream,
+            .body_remaining_chunk_len => |*x| x,
+            else => unreachable,
+        };
+        return chunkedReadEndless(reader,
w, limit, chunk_len_ptr) catch |err| switch (err) { + error.ReadFailed => return error.ReadFailed, + error.WriteFailed => return error.WriteFailed, + error.EndOfStream => { + reader.body_err = error.HttpChunkTruncated; + return error.ReadFailed; + }, + else => |e| { + reader.body_err = e; + return error.ReadFailed; + }, + }; + } + + fn chunkedReadEndless( + reader: *Reader, + w: *Writer, + limit: std.Io.Limit, + chunk_len_ptr: *RemainingChunkLen, + ) (BodyError || std.Io.Reader.StreamError)!usize { + const in = reader.in; + len: switch (chunk_len_ptr.*) { + .head => { + var cp: ChunkParser = .init; + while (true) { + const i = cp.feed(in.buffered()); + switch (cp.state) { + .invalid => return error.HttpChunkInvalid, + .data => { + in.toss(i); + break; + }, + else => { + in.toss(i); + try in.fillMore(); + continue; + }, + } + } + if (cp.chunk_len == 0) return parseTrailers(reader, 0); + const n = try in.stream(w, limit.min(.limited(cp.chunk_len))); + chunk_len_ptr.* = .init(cp.chunk_len + 2 - n); + return n; + }, + .n => { + if ((try in.peekByte()) != '\n') return error.HttpChunkInvalid; + in.toss(1); + continue :len .head; + }, + .rn => { + const rn = try in.peekArray(2); + if (rn[0] != '\r' or rn[1] != '\n') return error.HttpChunkInvalid; + in.toss(2); + continue :len .head; + }, + else => |remaining_chunk_len| { + const n = try in.stream(w, limit.min(.limited(@intFromEnum(remaining_chunk_len) - 2))); + chunk_len_ptr.* = .init(@intFromEnum(remaining_chunk_len) - n); + return n; + }, + } + } + + fn chunkedDiscard(io_r: *std.Io.Reader, limit: std.Io.Limit) std.Io.Reader.Error!usize { + const reader: *Reader = @fieldParentPtr("interface", io_r); + const chunk_len_ptr = switch (reader.state) { + .ready => return error.EndOfStream, + .body_remaining_chunk_len => |*x| x, + else => unreachable, + }; + return chunkedDiscardEndless(reader, limit, chunk_len_ptr) catch |err| switch (err) { + error.ReadFailed => return error.ReadFailed, + error.EndOfStream => { + reader.body_err = error.HttpChunkTruncated; + return error.ReadFailed; + }, + else => |e| { + reader.body_err = e; + return error.ReadFailed; + }, + }; + } + + fn chunkedDiscardEndless( + reader: *Reader, + limit: std.Io.Limit, + chunk_len_ptr: *RemainingChunkLen, + ) (BodyError || std.Io.Reader.Error)!usize { + const in = reader.in; + len: switch (chunk_len_ptr.*) { + .head => { + var cp: ChunkParser = .init; + while (true) { + const i = cp.feed(in.buffered()); + switch (cp.state) { + .invalid => return error.HttpChunkInvalid, + .data => { + in.toss(i); + break; + }, + else => { + in.toss(i); + try in.fillMore(); + continue; + }, + } + } + if (cp.chunk_len == 0) return parseTrailers(reader, 0); + const n = try in.discard(limit.min(.limited(cp.chunk_len))); + chunk_len_ptr.* = .init(cp.chunk_len + 2 - n); + return n; + }, + .n => { + if ((try in.peekByte()) != '\n') return error.HttpChunkInvalid; + in.toss(1); + continue :len .head; + }, + .rn => { + const rn = try in.peekArray(2); + if (rn[0] != '\r' or rn[1] != '\n') return error.HttpChunkInvalid; + in.toss(2); + continue :len .head; + }, + else => |remaining_chunk_len| { + const n = try in.discard(limit.min(.limited(remaining_chunk_len.int() - 2))); + chunk_len_ptr.* = .init(remaining_chunk_len.int() - n); + return n; + }, + } + } + + /// Called when next bytes in the stream are trailers, or "\r\n" to indicate + /// end of chunked body. 
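+    ///
+    /// A bare "\r\n" ends the body immediately; anything else is parsed as
+    /// trailer headers and exposed via the `trailers` field until the next
+    /// call to `receiveHead`.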
+    fn parseTrailers(reader: *Reader, amt_read: usize) (BodyError || std.Io.Reader.Error)!usize {
+        const in = reader.in;
+        const rn = try in.peekArray(2);
+        if (rn[0] == '\r' and rn[1] == '\n') {
+            in.toss(2);
+            reader.state = .ready;
+            assert(reader.trailers.len == 0);
+            return amt_read;
+        }
+        var hp: HeadParser = .{ .state = .seen_rn };
+        var trailers_len: usize = 2;
+        while (true) {
+            if (in.buffer.len - trailers_len == 0) return error.HttpHeadersOversize;
+            const remaining = in.buffered()[trailers_len..];
+            if (remaining.len == 0) {
+                try in.fillMore();
+                continue;
+            }
+            trailers_len += hp.feed(remaining);
+            if (hp.state == .finished) {
+                reader.state = .ready;
+                reader.trailers = in.buffered()[0..trailers_len];
+                in.toss(trailers_len);
+                return amt_read;
+            }
+        }
+    }
+};
+
+pub const Decompressor = union(enum) {
+    flate: std.compress.flate.Decompress,
+    zstd: std.compress.zstd.Decompress,
+    none: *std.Io.Reader,
+
+    pub fn init(
+        decompressor: *Decompressor,
+        transfer_reader: *std.Io.Reader,
+        buffer: []u8,
+        content_encoding: ContentEncoding,
+    ) *std.Io.Reader {
+        switch (content_encoding) {
+            .identity => {
+                decompressor.* = .{ .none = transfer_reader };
+                return transfer_reader;
+            },
+            .deflate => {
+                decompressor.* = .{ .flate = .init(transfer_reader, .raw, buffer) };
+                return &decompressor.flate.reader;
+            },
+            .gzip => {
+                decompressor.* = .{ .flate = .init(transfer_reader, .gzip, buffer) };
+                return &decompressor.flate.reader;
+            },
+            .zstd => {
+                decompressor.* = .{ .zstd = .init(transfer_reader, buffer, .{ .verify_checksum = false }) };
+                return &decompressor.zstd.reader;
+            },
+            .compress => unreachable,
+        }
+    }
+};
+
+/// Request or response body.
+pub const BodyWriter = struct {
+    /// Until the lifetime of `BodyWriter` ends, it is illegal to modify the
+    /// state of this other than via methods of `BodyWriter`.
+    http_protocol_output: *Writer,
+    state: State,
+    writer: Writer,
+
+    pub const Error = Writer.Error;
+
+    /// How many zeroes to reserve for hex-encoded chunk length.
+    const chunk_len_digits = 8;
+    const max_chunk_len: usize = std.math.pow(usize, 16, chunk_len_digits) - 1;
+    const chunk_header_template = ("0" ** chunk_len_digits) ++ "\r\n";
+
+    comptime {
+        assert(max_chunk_len == std.math.maxInt(u32));
+    }
+
+    pub const State = union(enum) {
+        /// End of connection signals the end of the stream.
+        none,
+        /// As a debugging utility, counts down to zero as bytes are written.
+        content_length: u64,
+        /// Each chunk is wrapped in a header and trailer.
+        chunked: Chunked,
+        /// Cleanly finished stream; connection can be reused.
+        end,
+
+        pub const Chunked = union(enum) {
+            /// Index to the start of the hex-encoded chunk length in the chunk
+            /// header within the buffer of `BodyWriter.http_protocol_output`.
+            /// Buffered chunk data starts here plus length of `chunk_header_template`.
+            offset: usize,
+            /// We are in the middle of a chunk and this is how many bytes are
+            /// left until the next header. This includes +2 for "\r\n", and
+            /// is zero for the beginning of the stream.
+            chunk_len: usize,
+
+            pub const init: Chunked = .{ .chunk_len = 0 };
+        };
+    };
+
+    pub fn isEliding(w: *const BodyWriter) bool {
+        return w.writer.vtable.drain == Writer.discardingDrain;
+    }
+
+    /// Sends all buffered data across `BodyWriter.http_protocol_output`.
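+    ///
+    /// In the chunked state, this first patches the pending chunk header in
+    /// place with the real hex length, which is only known once the buffered
+    /// data is about to be sent.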
+ pub fn flush(w: *BodyWriter) Error!void { + const out = w.http_protocol_output; + switch (w.state) { + .end, .none, .content_length => return out.flush(), + .chunked => |*chunked| switch (chunked.*) { + .offset => |offset| { + const chunk_len = out.end - offset - chunk_header_template.len; + if (chunk_len > 0) { + writeHex(out.buffer[offset..][0..chunk_len_digits], chunk_len); + chunked.* = .{ .chunk_len = 2 }; + } else { + out.end = offset; + chunked.* = .{ .chunk_len = 0 }; + } + try out.flush(); + }, + .chunk_len => return out.flush(), + }, + } + } + + /// When using content-length, asserts that the amount of data sent matches + /// the value sent in the header, then flushes. + /// + /// When using transfer-encoding: chunked, writes the end-of-stream message + /// with empty trailers, then flushes the stream to the system. Asserts any + /// started chunk has been completely finished. + /// + /// Respects the value of `isEliding` to omit all data after the headers. + /// + /// See also: + /// * `endUnflushed` + /// * `endChunked` + pub fn end(w: *BodyWriter) Error!void { + try endUnflushed(w); + try w.http_protocol_output.flush(); + } + + /// When using content-length, asserts that the amount of data sent matches + /// the value sent in the header. + /// + /// Otherwise, transfer-encoding: chunked is being used, and it writes the + /// end-of-stream message with empty trailers. + /// + /// Respects the value of `isEliding` to omit all data after the headers. + /// + /// See also: + /// * `end` + /// * `endChunked` + pub fn endUnflushed(w: *BodyWriter) Error!void { + switch (w.state) { + .end => unreachable, + .content_length => |len| { + assert(len == 0); // Trips when end() called before all bytes written. + w.state = .end; + }, + .none => {}, + .chunked => return endChunkedUnflushed(w, .{}), + } + } + + pub const EndChunkedOptions = struct { + trailers: []const Header = &.{}, + }; + + /// Writes the end-of-stream message and any optional trailers, flushing + /// the underlying stream. + /// + /// Asserts that the BodyWriter is using transfer-encoding: chunked. + /// + /// Respects the value of `isEliding` to omit all data after the headers. + /// + /// See also: + /// * `endChunkedUnflushed` + /// * `end` + pub fn endChunked(w: *BodyWriter, options: EndChunkedOptions) Error!void { + try endChunkedUnflushed(w, options); + try w.http_protocol_output.flush(); + } + + /// Writes the end-of-stream message and any optional trailers. + /// + /// Does not flush. + /// + /// Asserts that the BodyWriter is using transfer-encoding: chunked. + /// + /// Respects the value of `isEliding` to omit all data after the headers. + /// + /// See also: + /// * `endChunked` + /// * `endUnflushed` + /// * `end` + pub fn endChunkedUnflushed(w: *BodyWriter, options: EndChunkedOptions) Error!void { + const chunked = &w.state.chunked; + if (w.isEliding()) { + w.state = .end; + return; + } + const bw = w.http_protocol_output; + switch (chunked.*) { + .offset => |offset| { + const chunk_len = bw.end - offset - chunk_header_template.len; + writeHex(bw.buffer[offset..][0..chunk_len_digits], chunk_len); + try bw.writeAll("\r\n"); + }, + .chunk_len => |chunk_len| switch (chunk_len) { + 0 => {}, + 1 => try bw.writeByte('\n'), + 2 => try bw.writeAll("\r\n"), + else => unreachable, // An earlier write call indicated more data would follow. 
+ }, + } + try bw.writeAll("0\r\n"); + for (options.trailers) |trailer| { + try bw.writeAll(trailer.name); + try bw.writeAll(": "); + try bw.writeAll(trailer.value); + try bw.writeAll("\r\n"); + } + try bw.writeAll("\r\n"); + w.state = .end; + } + + pub fn contentLengthDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { + const bw: *BodyWriter = @fieldParentPtr("writer", w); + assert(!bw.isEliding()); + const out = bw.http_protocol_output; + const n = try out.writeSplatHeader(w.buffered(), data, splat); + bw.state.content_length -= n; + return w.consume(n); + } + + pub fn noneDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { + const bw: *BodyWriter = @fieldParentPtr("writer", w); + assert(!bw.isEliding()); + const out = bw.http_protocol_output; + const n = try out.writeSplatHeader(w.buffered(), data, splat); + return w.consume(n); + } + + /// Returns `null` if size cannot be computed without making any syscalls. + pub fn noneSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize { + const bw: *BodyWriter = @fieldParentPtr("writer", w); + assert(!bw.isEliding()); + const out = bw.http_protocol_output; + const n = try out.sendFileHeader(w.buffered(), file_reader, limit); + return w.consume(n); + } + + pub fn contentLengthSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize { + const bw: *BodyWriter = @fieldParentPtr("writer", w); + assert(!bw.isEliding()); + const out = bw.http_protocol_output; + const n = try out.sendFileHeader(w.buffered(), file_reader, limit); + bw.state.content_length -= n; + return w.consume(n); + } + + pub fn chunkedSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize { + const bw: *BodyWriter = @fieldParentPtr("writer", w); + assert(!bw.isEliding()); + const data_len = Writer.countSendFileLowerBound(w.end, file_reader, limit) orelse { + // If the file size is unknown, we cannot lower to a `sendFile` since we would + // have to flush the chunk header before knowing the chunk length. + return error.Unimplemented; + }; + const out = bw.http_protocol_output; + const chunked = &bw.state.chunked; + state: switch (chunked.*) { + .offset => |off| { + // TODO: is it better perf to read small files into the buffer? 
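+                // The header for this chunk was written as all zeroes when the
+                // chunk was started; now that its total length (buffered bytes
+                // plus file bytes) is known, patch the real hex length over it.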
+ const buffered_len = out.end - off - chunk_header_template.len; + const chunk_len = data_len + buffered_len; + writeHex(out.buffer[off..][0..chunk_len_digits], chunk_len); + const n = try out.sendFileHeader(w.buffered(), file_reader, limit); + chunked.* = .{ .chunk_len = data_len + 2 - n }; + return w.consume(n); + }, + .chunk_len => |chunk_len| l: switch (chunk_len) { + 0 => { + const off = out.end; + const header_buf = try out.writableArray(chunk_header_template.len); + @memcpy(header_buf, chunk_header_template); + chunked.* = .{ .offset = off }; + continue :state .{ .offset = off }; + }, + 1 => { + try out.writeByte('\n'); + chunked.chunk_len = 0; + continue :l 0; + }, + 2 => { + try out.writeByte('\r'); + chunked.chunk_len = 1; + continue :l 1; + }, + else => { + const new_limit = limit.min(.limited(chunk_len - 2)); + const n = try out.sendFileHeader(w.buffered(), file_reader, new_limit); + chunked.chunk_len = chunk_len - n; + return w.consume(n); + }, + }, + } + } + + pub fn chunkedDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { + const bw: *BodyWriter = @fieldParentPtr("writer", w); + assert(!bw.isEliding()); + const out = bw.http_protocol_output; + const data_len = w.end + Writer.countSplat(data, splat); + const chunked = &bw.state.chunked; + state: switch (chunked.*) { + .offset => |offset| { + if (out.unusedCapacityLen() >= data_len) { + return w.consume(out.writeSplatHeader(w.buffered(), data, splat) catch unreachable); + } + const buffered_len = out.end - offset - chunk_header_template.len; + const chunk_len = data_len + buffered_len; + writeHex(out.buffer[offset..][0..chunk_len_digits], chunk_len); + const n = try out.writeSplatHeader(w.buffered(), data, splat); + chunked.* = .{ .chunk_len = data_len + 2 - n }; + return w.consume(n); + }, + .chunk_len => |chunk_len| l: switch (chunk_len) { + 0 => { + const offset = out.end; + const header_buf = try out.writableArray(chunk_header_template.len); + @memcpy(header_buf, chunk_header_template); + chunked.* = .{ .offset = offset }; + continue :state .{ .offset = offset }; + }, + 1 => { + try out.writeByte('\n'); + chunked.chunk_len = 0; + continue :l 0; + }, + 2 => { + try out.writeByte('\r'); + chunked.chunk_len = 1; + continue :l 1; + }, + else => { + const n = try out.writeSplatHeaderLimit(w.buffered(), data, splat, .limited(chunk_len - 2)); + chunked.chunk_len = chunk_len - n; + return w.consume(n); + }, + }, + } + } + + /// Writes an integer as base 16 to `buf`, right-aligned, assuming the + /// buffer has already been filled with zeroes. + fn writeHex(buf: []u8, x: usize) void { + assert(std.mem.allEqual(u8, buf, '0')); + const base = 16; + var index: usize = buf.len; + var a = x; + while (a > 0) { + const digit = a % base; + index -= 1; + buf[index] = std.fmt.digitToChar(@intCast(digit), .lower); + a /= base; + } + } +}; + test { + _ = Server; + _ = Status; + _ = Method; + _ = ChunkParser; + _ = HeadParser; + if (builtin.os.tag != .wasi) { _ = Client; - _ = Method; - _ = Server; - _ = Status; - _ = HeadParser; - _ = ChunkParser; - _ = WebSocket; _ = @import("http/test.zig"); } } diff --git a/lib/std/http/ChunkParser.zig b/lib/std/http/ChunkParser.zig index adcdc74bc7..7c628ec327 100644 --- a/lib/std/http/ChunkParser.zig +++ b/lib/std/http/ChunkParser.zig @@ -1,5 +1,8 @@ //! Parser for transfer-encoding: chunked. 
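+//!
+//! A sketch of the wire format this parser walks through, for a body that
+//! carries "hello" in a single chunk followed by the terminating zero chunk:
+//!
+//!     "5\r\nhello\r\n" ++ "0\r\n\r\n"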
+const ChunkParser = @This(); +const std = @import("std"); + state: State, chunk_len: u64, @@ -97,9 +100,6 @@ pub fn feed(p: *ChunkParser, bytes: []const u8) usize { return bytes.len; } -const ChunkParser = @This(); -const std = @import("std"); - test feed { const testing = std.testing; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 20f6018e45..61a9eeb5c3 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -13,9 +13,10 @@ const net = std.net; const Uri = std.Uri; const Allocator = mem.Allocator; const assert = std.debug.assert; +const Writer = std.io.Writer; +const Reader = std.io.Reader; const Client = @This(); -const proto = @import("protocol.zig"); pub const disable_tls = std.options.http_disable_tls; @@ -24,6 +25,12 @@ allocator: Allocator, ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{}, ca_bundle_mutex: std.Thread.Mutex = .{}, +/// Used both for the reader and writer buffers. +tls_buffer_size: if (disable_tls) u0 else usize = if (disable_tls) 0 else std.crypto.tls.Client.min_buffer_len, +/// If non-null, ssl secrets are logged to a stream. Creating such a stream +/// allows other processes with access to that stream to decrypt all +/// traffic over connections created with this `Client`. +ssl_key_log: ?*std.crypto.tls.Client.SslKeyLog = null, /// When this is `true`, the next time this client performs an HTTPS request, /// it will first rescan the system for root certificates. @@ -31,6 +38,13 @@ next_https_rescan_certs: bool = true, /// The pool of connections that can be reused (and currently in use). connection_pool: ConnectionPool = .{}, +/// Each `Connection` allocates this amount for the reader buffer. +/// +/// If the entire HTTP header cannot fit in this amount of bytes, +/// `error.HttpHeadersOversize` will be returned from `Request.wait`. +read_buffer_size: usize = 4096, +/// Each `Connection` allocates this amount for the writer buffer. +write_buffer_size: usize = 1024, /// If populated, all http traffic travels through this third party. /// This field cannot be modified while the client has active connections. @@ -41,7 +55,7 @@ http_proxy: ?*Proxy = null, /// Pointer to externally-owned memory. https_proxy: ?*Proxy = null, -/// A set of linked lists of connections that can be reused. +/// A Least-Recently-Used cache of open connections to be reused. pub const ConnectionPool = struct { mutex: std.Thread.Mutex = .{}, /// Open connections that are currently in use. @@ -55,11 +69,13 @@ pub const ConnectionPool = struct { pub const Criteria = struct { host: []const u8, port: u16, - protocol: Connection.Protocol, + protocol: Protocol, }; - /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. + /// Finds and acquires a connection from the connection pool matching the criteria. /// If no connection is found, null is returned. + /// + /// Threadsafe. 
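+    ///
+    /// An acquired connection must eventually be given back with `release`.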
pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection { pool.mutex.lock(); defer pool.mutex.unlock(); @@ -71,7 +87,7 @@ pub const ConnectionPool = struct { if (connection.port != criteria.port) continue; // Domain names are case-insensitive (RFC 5890, Section 2.3.2.4) - if (!std.ascii.eqlIgnoreCase(connection.host, criteria.host)) continue; + if (!std.ascii.eqlIgnoreCase(connection.host(), criteria.host)) continue; pool.acquireUnsafe(connection); return connection; @@ -96,28 +112,25 @@ pub const ConnectionPool = struct { return pool.acquireUnsafe(connection); } - /// Tries to release a connection back to the connection pool. This function is threadsafe. + /// Tries to release a connection back to the connection pool. /// If the connection is marked as closing, it will be closed instead. /// - /// The allocator must be the owner of all nodes in this pool. - /// The allocator must be the owner of all resources associated with the connection. - pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void { + /// `allocator` must be the same one used to create `connection`. + /// + /// Threadsafe. + pub fn release(pool: *ConnectionPool, connection: *Connection) void { pool.mutex.lock(); defer pool.mutex.unlock(); pool.used.remove(&connection.pool_node); - if (connection.closing or pool.free_size == 0) { - connection.close(allocator); - return allocator.destroy(connection); - } + if (connection.closing or pool.free_size == 0) return connection.destroy(); if (pool.free_len >= pool.free_size) { const popped: *Connection = @fieldParentPtr("pool_node", pool.free.popFirst().?); pool.free_len -= 1; - popped.close(allocator); - allocator.destroy(popped); + popped.destroy(); } if (connection.proxied) { @@ -138,9 +151,11 @@ pub const ConnectionPool = struct { pool.used.append(&connection.pool_node); } - /// Resizes the connection pool. This function is threadsafe. + /// Resizes the connection pool. /// /// If the new size is smaller than the current size, then idle connections will be closed until the pool is the new size. + /// + /// Threadsafe. pub fn resize(pool: *ConnectionPool, allocator: Allocator, new_size: usize) void { pool.mutex.lock(); defer pool.mutex.unlock(); @@ -158,538 +173,586 @@ pub const ConnectionPool = struct { pool.free_size = new_size; } - /// Frees the connection pool and closes all connections within. This function is threadsafe. + /// Frees the connection pool and closes all connections within. /// /// All future operations on the connection pool will deadlock. - pub fn deinit(pool: *ConnectionPool, allocator: Allocator) void { + /// + /// Threadsafe. + pub fn deinit(pool: *ConnectionPool) void { pool.mutex.lock(); var next = pool.free.first; while (next) |node| { const connection: *Connection = @fieldParentPtr("pool_node", node); next = node.next; - connection.close(allocator); - allocator.destroy(connection); + connection.destroy(); } next = pool.used.first; while (next) |node| { const connection: *Connection = @fieldParentPtr("pool_node", node); next = node.next; - connection.close(allocator); - allocator.destroy(node); + connection.destroy(); } pool.* = undefined; } }; -/// An interface to either a plain or TLS connection. -pub const Connection = struct { - stream: net.Stream, - /// undefined unless protocol is tls. 
- tls_client: if (!disable_tls) *std.crypto.tls.Client else void, +pub const Protocol = enum { + plain, + tls, + fn port(protocol: Protocol) u16 { + return switch (protocol) { + .plain => 80, + .tls => 443, + }; + } + + pub fn fromScheme(scheme: []const u8) ?Protocol { + const protocol_map = std.StaticStringMap(Protocol).initComptime(.{ + .{ "http", .plain }, + .{ "ws", .plain }, + .{ "https", .tls }, + .{ "wss", .tls }, + }); + return protocol_map.get(scheme); + } + + pub fn fromUri(uri: Uri) ?Protocol { + return fromScheme(uri.scheme); + } +}; + +pub const Connection = struct { + client: *Client, + stream_writer: net.Stream.Writer, + stream_reader: net.Stream.Reader, /// Entry in `ConnectionPool.used` or `ConnectionPool.free`. pool_node: std.DoublyLinkedList.Node, - - /// The protocol that this connection is using. + port: u16, + host_len: u8, + proxied: bool, + closing: bool, protocol: Protocol, - /// The host that this connection is connected to. - host: []u8, + const Plain = struct { + connection: Connection, - /// The port that this connection is connected to. - port: u16, - - /// Whether this connection is proxied and is not directly connected. - proxied: bool = false, - - /// Whether this connection is closing when we're done with it. - closing: bool = false, - - read_start: BufferSize = 0, - read_end: BufferSize = 0, - write_end: BufferSize = 0, - read_buf: [buffer_size]u8 = undefined, - write_buf: [buffer_size]u8 = undefined, - - pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; - const BufferSize = std.math.IntFittingRange(0, buffer_size); - - pub const Protocol = enum { plain, tls }; - - pub fn readvDirectTls(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { - return conn.tls_client.readv(conn.stream, buffers) catch |err| { - // https://github.com/ziglang/zig/issues/2473 - if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; - - switch (err) { - error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure, - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - } - }; - } - - pub fn readvDirect(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.readvDirectTls(buffers); + fn create( + client: *Client, + remote_host: []const u8, + port: u16, + stream: net.Stream, + ) error{OutOfMemory}!*Plain { + const gpa = client.allocator; + const alloc_len = allocLen(client, remote_host.len); + const base = try gpa.alignedAlloc(u8, .of(Plain), alloc_len); + errdefer gpa.free(base); + const host_buffer = base[@sizeOf(Plain)..][0..remote_host.len]; + const socket_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.read_buffer_size]; + const socket_write_buffer = socket_read_buffer.ptr[socket_read_buffer.len..][0..client.write_buffer_size]; + assert(base.ptr + alloc_len == socket_write_buffer.ptr + socket_write_buffer.len); + @memcpy(host_buffer, remote_host); + const plain: *Plain = @ptrCast(base); + plain.* = .{ + .connection = .{ + .client = client, + .stream_writer = stream.writer(socket_write_buffer), + .stream_reader = stream.reader(socket_read_buffer), + .pool_node = .{}, + .port = port, + .host_len = @intCast(remote_host.len), + .proxied = false, + .closing = 
false, + .protocol = .plain, + }, + }; + return plain; } - return conn.stream.readv(buffers) catch |err| switch (err) { - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - }; - } - - /// Refills the read buffer with data from the connection. - pub fn fill(conn: *Connection) ReadError!void { - if (conn.read_end != conn.read_start) return; - - var iovecs = [1]std.posix.iovec{ - .{ .base = &conn.read_buf, .len = conn.read_buf.len }, - }; - const nread = try conn.readvDirect(&iovecs); - if (nread == 0) return error.EndOfStream; - conn.read_start = 0; - conn.read_end = @intCast(nread); - } - - /// Returns the current slice of buffered data. - pub fn peek(conn: *Connection) []const u8 { - return conn.read_buf[conn.read_start..conn.read_end]; - } - - /// Discards the given number of bytes from the read buffer. - pub fn drop(conn: *Connection, num: BufferSize) void { - conn.read_start += num; - } - - /// Reads data from the connection into the given buffer. - pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { - const available_read = conn.read_end - conn.read_start; - const available_buffer = buffer.len; - - if (available_read > available_buffer) { // partially read buffered data - @memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); - conn.read_start += @intCast(available_buffer); - - return available_buffer; - } else if (available_read > 0) { // fully read buffered data - @memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]); - conn.read_start += available_read; - - return available_read; + fn destroy(plain: *Plain) void { + const c = &plain.connection; + const gpa = c.client.allocator; + const base: [*]align(@alignOf(Plain)) u8 = @ptrCast(plain); + gpa.free(base[0..allocLen(c.client, c.host_len)]); } - var iovecs = [2]std.posix.iovec{ - .{ .base = buffer.ptr, .len = buffer.len }, - .{ .base = &conn.read_buf, .len = conn.read_buf.len }, - }; - const nread = try conn.readvDirect(&iovecs); - - if (nread > buffer.len) { - conn.read_start = 0; - conn.read_end = @intCast(nread - buffer.len); - return buffer.len; + fn allocLen(client: *Client, host_len: usize) usize { + return @sizeOf(Plain) + host_len + client.read_buffer_size + client.write_buffer_size; } - return nread; - } - - pub const ReadError = error{ - TlsFailure, - TlsAlert, - ConnectionTimedOut, - ConnectionResetByPeer, - UnexpectedReadFailure, - EndOfStream, + fn host(plain: *Plain) []u8 { + const base: [*]u8 = @ptrCast(plain); + return base[@sizeOf(Plain)..][0..plain.connection.host_len]; + } }; - pub const Reader = std.io.GenericReader(*Connection, ReadError, read); + const Tls = struct { + client: std.crypto.tls.Client, + connection: Connection, - pub fn reader(conn: *Connection) Reader { - return Reader{ .context = conn }; - } - - pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { - return conn.tls_client.writeAll(conn.stream, buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - - pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.writeAllDirectTls(buffer); + fn create( + client: *Client, + remote_host: []const u8, + port: u16, + stream: 
net.Stream, + ) error{ OutOfMemory, TlsInitializationFailed }!*Tls { + const gpa = client.allocator; + const alloc_len = allocLen(client, remote_host.len); + const base = try gpa.alignedAlloc(u8, .of(Tls), alloc_len); + errdefer gpa.free(base); + const host_buffer = base[@sizeOf(Tls)..][0..remote_host.len]; + const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.tls_buffer_size]; + const tls_write_buffer = tls_read_buffer.ptr[tls_read_buffer.len..][0..client.tls_buffer_size]; + const socket_write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size]; + assert(base.ptr + alloc_len == socket_write_buffer.ptr + socket_write_buffer.len); + @memcpy(host_buffer, remote_host); + const tls: *Tls = @ptrCast(base); + tls.* = .{ + .connection = .{ + .client = client, + .stream_writer = stream.writer(socket_write_buffer), + .stream_reader = stream.reader(&.{}), + .pool_node = .{}, + .port = port, + .host_len = @intCast(remote_host.len), + .proxied = false, + .closing = false, + .protocol = .tls, + }, + // TODO data race here on ca_bundle if the user sets next_https_rescan_certs to true + .client = std.crypto.tls.Client.init( + tls.connection.stream_reader.interface(), + &tls.connection.stream_writer.interface, + .{ + .host = .{ .explicit = remote_host }, + .ca = .{ .bundle = client.ca_bundle }, + .ssl_key_log = client.ssl_key_log, + .read_buffer = tls_read_buffer, + .write_buffer = tls_write_buffer, + // This is appropriate for HTTPS because the HTTP headers contain + // the content length which is used to detect truncation attacks. + .allow_truncation_attacks = true, + }, + ) catch return error.TlsInitializationFailed, + }; + return tls; } - return conn.stream.writeAll(buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - - /// Writes the given buffer to the connection. - pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { - if (conn.write_buf.len - conn.write_end < buffer.len) { - try conn.flush(); - - if (buffer.len > conn.write_buf.len) { - try conn.writeAllDirect(buffer); - return buffer.len; - } + fn destroy(tls: *Tls) void { + const c = &tls.connection; + const gpa = c.client.allocator; + const base: [*]align(@alignOf(Tls)) u8 = @ptrCast(tls); + gpa.free(base[0..allocLen(c.client, c.host_len)]); } - @memcpy(conn.write_buf[conn.write_end..][0..buffer.len], buffer); - conn.write_end += @intCast(buffer.len); + fn allocLen(client: *Client, host_len: usize) usize { + return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size + client.write_buffer_size; + } - return buffer.len; - } - - /// Returns a buffer to be filled with exactly len bytes to write to the connection. - pub fn allocWriteBuffer(conn: *Connection, len: BufferSize) WriteError![]u8 { - if (conn.write_buf.len - conn.write_end < len) try conn.flush(); - defer conn.write_end += len; - return conn.write_buf[conn.write_end..][0..len]; - } - - /// Flushes the write buffer to the connection. 
- pub fn flush(conn: *Connection) WriteError!void { - if (conn.write_end == 0) return; - - try conn.writeAllDirect(conn.write_buf[0..conn.write_end]); - conn.write_end = 0; - } - - pub const WriteError = error{ - ConnectionResetByPeer, - UnexpectedWriteFailure, + fn host(tls: *Tls) []u8 { + const base: [*]u8 = @ptrCast(tls); + return base[@sizeOf(Tls)..][0..tls.connection.host_len]; + } }; - pub const Writer = std.io.GenericWriter(*Connection, WriteError, write); - - pub fn writer(conn: *Connection) Writer { - return Writer{ .context = conn }; + fn getStream(c: *Connection) net.Stream { + return c.stream_reader.getStream(); } - /// Closes the connection. - pub fn close(conn: *Connection, allocator: Allocator) void { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; + fn host(c: *Connection) []u8 { + return switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *Tls = @fieldParentPtr("connection", c); + return tls.host(); + }, + .plain => { + const plain: *Plain = @fieldParentPtr("connection", c); + return plain.host(); + }, + }; + } - // try to cleanly close the TLS connection, for any server that cares. - _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {}; - if (conn.tls_client.ssl_key_log) |key_log| key_log.file.close(); - allocator.destroy(conn.tls_client); + /// If this is called without calling `flush` or `end`, data will be + /// dropped unsent. + pub fn destroy(c: *Connection) void { + c.getStream().close(); + switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *Tls = @fieldParentPtr("connection", c); + tls.destroy(); + }, + .plain => { + const plain: *Plain = @fieldParentPtr("connection", c); + plain.destroy(); + }, } + } - conn.stream.close(); - allocator.free(conn.host); + /// HTTP protocol from client to server. + /// This either goes directly to `stream_writer`, or to a TLS client. + pub fn writer(c: *Connection) *Writer { + return switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *Tls = @fieldParentPtr("connection", c); + return &tls.client.writer; + }, + .plain => &c.stream_writer.interface, + }; + } + + /// HTTP protocol from server to client. + /// This either comes directly from `stream_reader`, or from a TLS client. + pub fn reader(c: *Connection) *Reader { + return switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *Tls = @fieldParentPtr("connection", c); + return &tls.client.reader; + }, + .plain => c.stream_reader.interface(), + }; + } + + pub fn flush(c: *Connection) Writer.Error!void { + if (c.protocol == .tls) { + if (disable_tls) unreachable; + const tls: *Tls = @fieldParentPtr("connection", c); + try tls.client.writer.flush(); + } + try c.stream_writer.interface.flush(); + } + + /// If the connection is a TLS connection, sends the close_notify alert. + /// + /// Flushes all buffers. + pub fn end(c: *Connection) Writer.Error!void { + if (c.protocol == .tls) { + if (disable_tls) unreachable; + const tls: *Tls = @fieldParentPtr("connection", c); + try tls.client.end(); + try tls.client.writer.flush(); + } + try c.stream_writer.interface.flush(); } }; -/// The mode of transport for requests. -pub const RequestTransfer = union(enum) { - content_length: u64, - chunked: void, - none: void, -}; - -/// The decompressor for response messages. 
-pub const Compression = union(enum) { - //deflate: std.compress.flate.Decompress, - //gzip: std.compress.flate.Decompress, - // https://github.com/ziglang/zig/issues/18937 - //zstd: ZstdDecompressor, - none: void, -}; - -/// A HTTP response originating from a server. pub const Response = struct { - version: http.Version, - status: http.Status, - reason: []const u8, + request: *Request, + /// Pointers in this struct are invalidated with the next call to + /// `receiveHead`. + head: Head, - /// Points into the user-provided `server_header_buffer`. - location: ?[]const u8 = null, - /// Points into the user-provided `server_header_buffer`. - content_type: ?[]const u8 = null, - /// Points into the user-provided `server_header_buffer`. - content_disposition: ?[]const u8 = null, + pub const Head = struct { + bytes: []const u8, + version: http.Version, + status: http.Status, + reason: []const u8, + location: ?[]const u8 = null, + content_type: ?[]const u8 = null, + content_disposition: ?[]const u8 = null, - keep_alive: bool, + keep_alive: bool, - /// If present, the number of bytes in the response body. - content_length: ?u64 = null, + /// If present, the number of bytes in the response body. + content_length: ?u64 = null, - /// If present, the transfer encoding of the response body, otherwise none. - transfer_encoding: http.TransferEncoding = .none, + transfer_encoding: http.TransferEncoding = .none, + content_encoding: http.ContentEncoding = .identity, - /// If present, the compression of the response body, otherwise identity (no compression). - transfer_compression: http.ContentEncoding = .identity, - - parser: proto.HeadersParser, - compression: Compression = .none, - - /// Whether the response body should be skipped. Any data read from the - /// response body will be discarded. 
- skip: bool = false, - - pub const ParseError = error{ - HttpHeadersInvalid, - HttpHeaderContinuationsUnsupported, - HttpTransferEncodingUnsupported, - HttpConnectionHeaderUnsupported, - InvalidContentLength, - CompressionUnsupported, - }; - - pub fn parse(res: *Response, bytes: []const u8) ParseError!void { - var it = mem.splitSequence(u8, bytes, "\r\n"); - - const first_line = it.next().?; - if (first_line.len < 12) { - return error.HttpHeadersInvalid; - } - - const version: http.Version = switch (int64(first_line[0..8])) { - int64("HTTP/1.0") => .@"HTTP/1.0", - int64("HTTP/1.1") => .@"HTTP/1.1", - else => return error.HttpHeadersInvalid, - }; - if (first_line[8] != ' ') return error.HttpHeadersInvalid; - const status: http.Status = @enumFromInt(parseInt3(first_line[9..12])); - const reason = mem.trimStart(u8, first_line[12..], " "); - - res.version = version; - res.status = status; - res.reason = reason; - res.keep_alive = switch (version) { - .@"HTTP/1.0" => false, - .@"HTTP/1.1" => true, + pub const ParseError = error{ + HttpConnectionHeaderUnsupported, + HttpContentEncodingUnsupported, + HttpHeaderContinuationsUnsupported, + HttpHeadersInvalid, + HttpTransferEncodingUnsupported, + InvalidContentLength, }; - while (it.next()) |line| { - if (line.len == 0) return; - switch (line[0]) { - ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, - else => {}, + pub fn parse(bytes: []const u8) ParseError!Head { + var res: Head = .{ + .bytes = bytes, + .status = undefined, + .reason = undefined, + .version = undefined, + .keep_alive = false, + }; + var it = mem.splitSequence(u8, bytes, "\r\n"); + + const first_line = it.next().?; + if (first_line.len < 12) { + return error.HttpHeadersInvalid; } - var line_it = mem.splitScalar(u8, line, ':'); - const header_name = line_it.next().?; - const header_value = mem.trim(u8, line_it.rest(), " \t"); - if (header_name.len == 0) return error.HttpHeadersInvalid; + const version: http.Version = switch (int64(first_line[0..8])) { + int64("HTTP/1.0") => .@"HTTP/1.0", + int64("HTTP/1.1") => .@"HTTP/1.1", + else => return error.HttpHeadersInvalid, + }; + if (first_line[8] != ' ') return error.HttpHeadersInvalid; + const status: http.Status = @enumFromInt(parseInt3(first_line[9..12])); + const reason = mem.trimLeft(u8, first_line[12..], " "); - if (std.ascii.eqlIgnoreCase(header_name, "connection")) { - res.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); - } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { - res.content_type = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "location")) { - res.location = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-disposition")) { - res.content_disposition = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { - // Transfer-Encoding: second, first - // Transfer-Encoding: deflate, chunked - var iter = mem.splitBackwardsScalar(u8, header_value, ','); + res.version = version; + res.status = status; + res.reason = reason; + res.keep_alive = switch (version) { + .@"HTTP/1.0" => false, + .@"HTTP/1.1" => true, + }; - const first = iter.first(); - const trimmed_first = mem.trim(u8, first, " "); - - var next: ?[]const u8 = first; - if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { - if (res.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding - res.transfer_encoding = transfer; - - next = iter.next(); + while (it.next()) |line| { + if (line.len 
== 0) return res; + switch (line[0]) { + ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, + else => {}, } - if (next) |second| { - const trimmed_second = mem.trim(u8, second, " "); + var line_it = mem.splitScalar(u8, line, ':'); + const header_name = line_it.next().?; + const header_value = mem.trim(u8, line_it.rest(), " \t"); + if (header_name.len == 0) return error.HttpHeadersInvalid; - if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { - if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; // double compression is not supported - res.transfer_compression = transfer; + if (std.ascii.eqlIgnoreCase(header_name, "connection")) { + res.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); + } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { + res.content_type = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "location")) { + res.location = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-disposition")) { + res.content_disposition = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { + // Transfer-Encoding: second, first + // Transfer-Encoding: deflate, chunked + var iter = mem.splitBackwardsScalar(u8, header_value, ','); + + const first = iter.first(); + const trimmed_first = mem.trim(u8, first, " "); + + var next: ?[]const u8 = first; + if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { + if (res.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding + res.transfer_encoding = transfer; + + next = iter.next(); + } + + if (next) |second| { + const trimmed_second = mem.trim(u8, second, " "); + + if (http.ContentEncoding.fromString(trimmed_second)) |transfer| { + if (res.content_encoding != .identity) return error.HttpHeadersInvalid; // double compression is not supported + res.content_encoding = transfer; + } else { + return error.HttpTransferEncodingUnsupported; + } + } + + if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; + + if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid; + + res.content_length = content_length; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { + if (res.content_encoding != .identity) return error.HttpHeadersInvalid; + + const trimmed = mem.trim(u8, header_value, " "); + + if (http.ContentEncoding.fromString(trimmed)) |ce| { + res.content_encoding = ce; } else { - return error.HttpTransferEncodingUnsupported; + return error.HttpContentEncodingUnsupported; } } - - if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { - const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; - - if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid; - - res.content_length = content_length; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { - if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; - - const trimmed = mem.trim(u8, header_value, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - res.transfer_compression = ce; - } else { - return 
error.HttpTransferEncodingUnsupported; - } } + return error.HttpHeadersInvalid; // missing empty line } - return error.HttpHeadersInvalid; // missing empty line + + test parse { + const response_bytes = "HTTP/1.1 200 OK\r\n" ++ + "LOcation:url\r\n" ++ + "content-tYpe: text/plain\r\n" ++ + "content-disposition:attachment; filename=example.txt \r\n" ++ + "content-Length:10\r\n" ++ + "TRansfer-encoding:\tdeflate, chunked \r\n" ++ + "connectioN:\t keep-alive \r\n\r\n"; + + const head = try Head.parse(response_bytes); + + try testing.expectEqual(.@"HTTP/1.1", head.version); + try testing.expectEqualStrings("OK", head.reason); + try testing.expectEqual(.ok, head.status); + + try testing.expectEqualStrings("url", head.location.?); + try testing.expectEqualStrings("text/plain", head.content_type.?); + try testing.expectEqualStrings("attachment; filename=example.txt", head.content_disposition.?); + + try testing.expectEqual(true, head.keep_alive); + try testing.expectEqual(10, head.content_length.?); + try testing.expectEqual(.chunked, head.transfer_encoding); + try testing.expectEqual(.deflate, head.content_encoding); + } + + pub fn iterateHeaders(h: Head) http.HeaderIterator { + return .init(h.bytes); + } + + test iterateHeaders { + const response_bytes = "HTTP/1.1 200 OK\r\n" ++ + "LOcation:url\r\n" ++ + "content-tYpe: text/plain\r\n" ++ + "content-disposition:attachment; filename=example.txt \r\n" ++ + "content-Length:10\r\n" ++ + "TRansfer-encoding:\tdeflate, chunked \r\n" ++ + "connectioN:\t keep-alive \r\n\r\n"; + + const head = try Head.parse(response_bytes); + var it = head.iterateHeaders(); + { + const header = it.next().?; + try testing.expectEqualStrings("LOcation", header.name); + try testing.expectEqualStrings("url", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("content-tYpe", header.name); + try testing.expectEqualStrings("text/plain", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("content-disposition", header.name); + try testing.expectEqualStrings("attachment; filename=example.txt", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("content-Length", header.name); + try testing.expectEqualStrings("10", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("TRansfer-encoding", header.name); + try testing.expectEqualStrings("deflate, chunked", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("connectioN", header.name); + try testing.expectEqualStrings("keep-alive", header.value); + try testing.expect(!it.is_trailer); + } + try testing.expectEqual(null, it.next()); + } + + inline fn int64(array: *const [8]u8) u64 { + return @bitCast(array.*); + } + + fn parseInt3(text: *const [3]u8) u10 { + const nnn: @Vector(3, u8) = text.*; + const zero: @Vector(3, u8) = .{ '0', '0', '0' }; + const mmm: @Vector(3, u10) = .{ 100, 10, 1 }; + return @reduce(.Add, (nnn -% zero) *% mmm); + } + + test parseInt3 { + const expectEqual = testing.expectEqual; + try expectEqual(@as(u10, 0), parseInt3("000")); + try expectEqual(@as(u10, 418), parseInt3("418")); + try expectEqual(@as(u10, 999), parseInt3("999")); + } + }; + + /// If compressed body has been negotiated this will return compressed bytes. 
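+    /// Callers that want decompressed bytes can either feed this reader
+    /// through an `http.Decompressor` themselves or use `readerDecompressing`.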
+ /// + /// If the returned `Reader` returns `error.ReadFailed` the error is + /// available via `bodyErr`. + /// + /// Asserts that this function is only called once. + /// + /// See also: + /// * `readerDecompressing` + pub fn reader(response: *Response, buffer: []u8) *Reader { + const req = response.request; + if (!req.method.responseHasBody()) return .ending; + const head = &response.head; + return req.reader.bodyReader(buffer, head.transfer_encoding, head.content_length); } - test parse { - const response_bytes = "HTTP/1.1 200 OK\r\n" ++ - "LOcation:url\r\n" ++ - "content-tYpe: text/plain\r\n" ++ - "content-disposition:attachment; filename=example.txt \r\n" ++ - "content-Length:10\r\n" ++ - "TRansfer-encoding:\tdeflate, chunked \r\n" ++ - "connectioN:\t keep-alive \r\n\r\n"; + /// If compressed body has been negotiated this will return decompressed bytes. + /// + /// If the returned `Reader` returns `error.ReadFailed` the error is + /// available via `bodyErr`. + /// + /// Asserts that this function is only called once. + /// + /// See also: + /// * `reader` + pub fn readerDecompressing( + response: *Response, + decompressor: *http.Decompressor, + decompression_buffer: []u8, + ) *Reader { + const head = &response.head; + return response.request.reader.bodyReaderDecompressing( + head.transfer_encoding, + head.content_length, + head.content_encoding, + decompressor, + decompression_buffer, + ); + } - var header_buffer: [1024]u8 = undefined; - var res = Response{ - .status = undefined, - .reason = undefined, - .version = undefined, - .keep_alive = false, - .parser = .init(&header_buffer), + /// After receiving `error.ReadFailed` from the `Reader` returned by + /// `reader` or `readerDecompressing`, this function accesses the + /// more specific error code. 
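+    ///
+    /// Illustrative pattern (editor's sketch mirroring `fetch` below; `gpa`
+    /// and `list` are assumed caller-provided names):
+    ///
+    ///     reader.appendRemaining(gpa, null, &list, .unlimited) catch |err| switch (err) {
+    ///         error.ReadFailed => return response.bodyErr().?,
+    ///         else => |e| return e,
+    ///     };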
+ pub fn bodyErr(response: *const Response) ?http.Reader.BodyError { + return response.request.reader.body_err; + } + + pub fn iterateTrailers(response: *const Response) http.HeaderIterator { + const r = &response.request.reader; + assert(r.state == .ready); + return .{ + .bytes = r.trailers, + .index = 0, + .is_trailer = true, }; - - @memcpy(header_buffer[0..response_bytes.len], response_bytes); - res.parser.header_bytes_len = response_bytes.len; - - try res.parse(response_bytes); - - try testing.expectEqual(.@"HTTP/1.1", res.version); - try testing.expectEqualStrings("OK", res.reason); - try testing.expectEqual(.ok, res.status); - - try testing.expectEqualStrings("url", res.location.?); - try testing.expectEqualStrings("text/plain", res.content_type.?); - try testing.expectEqualStrings("attachment; filename=example.txt", res.content_disposition.?); - - try testing.expectEqual(true, res.keep_alive); - try testing.expectEqual(10, res.content_length.?); - try testing.expectEqual(.chunked, res.transfer_encoding); - try testing.expectEqual(.deflate, res.transfer_compression); - } - - inline fn int64(array: *const [8]u8) u64 { - return @bitCast(array.*); - } - - fn parseInt3(text: *const [3]u8) u10 { - const nnn: @Vector(3, u8) = text.*; - const zero: @Vector(3, u8) = .{ '0', '0', '0' }; - const mmm: @Vector(3, u10) = .{ 100, 10, 1 }; - return @reduce(.Add, (nnn -% zero) *% mmm); - } - - test parseInt3 { - const expectEqual = testing.expectEqual; - try expectEqual(@as(u10, 0), parseInt3("000")); - try expectEqual(@as(u10, 418), parseInt3("418")); - try expectEqual(@as(u10, 999), parseInt3("999")); - } - - pub fn iterateHeaders(r: Response) http.HeaderIterator { - return .init(r.parser.get()); - } - - test iterateHeaders { - const response_bytes = "HTTP/1.1 200 OK\r\n" ++ - "LOcation:url\r\n" ++ - "content-tYpe: text/plain\r\n" ++ - "content-disposition:attachment; filename=example.txt \r\n" ++ - "content-Length:10\r\n" ++ - "TRansfer-encoding:\tdeflate, chunked \r\n" ++ - "connectioN:\t keep-alive \r\n\r\n"; - - var header_buffer: [1024]u8 = undefined; - var res = Response{ - .status = undefined, - .reason = undefined, - .version = undefined, - .keep_alive = false, - .parser = .init(&header_buffer), - }; - - @memcpy(header_buffer[0..response_bytes.len], response_bytes); - res.parser.header_bytes_len = response_bytes.len; - - var it = res.iterateHeaders(); - { - const header = it.next().?; - try testing.expectEqualStrings("LOcation", header.name); - try testing.expectEqualStrings("url", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-tYpe", header.name); - try testing.expectEqualStrings("text/plain", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-disposition", header.name); - try testing.expectEqualStrings("attachment; filename=example.txt", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-Length", header.name); - try testing.expectEqualStrings("10", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("TRansfer-encoding", header.name); - try testing.expectEqualStrings("deflate, chunked", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("connectioN", header.name); - try 
testing.expectEqualStrings("keep-alive", header.value); - try testing.expect(!it.is_trailer); - } - try testing.expectEqual(null, it.next()); } }; -/// A HTTP request that has been sent. -/// -/// Order of operations: open -> send[ -> write -> finish] -> wait -> read pub const Request = struct { + /// This field is provided so that clients can observe redirected URIs. + /// + /// Its backing memory is externally provided by API users when creating a + /// request, and then again provided externally via `redirect_buffer` to + /// `receiveHead`. uri: Uri, client: *Client, /// This is null when the connection is released. connection: ?*Connection, + reader: http.Reader, keep_alive: bool, method: http.Method, version: http.Version = .@"HTTP/1.1", - transfer_encoding: RequestTransfer, + transfer_encoding: TransferEncoding, redirect_behavior: RedirectBehavior, + accept_encoding: @TypeOf(default_accept_encoding) = default_accept_encoding, /// Whether the request should handle a 100-continue response before sending the request body. handle_continue: bool, - /// The response associated with this request. - /// - /// This field is undefined until `wait` is called. - response: Response, - /// Standard headers that have default, but overridable, behavior. headers: Headers, @@ -703,6 +766,20 @@ pub const Request = struct { /// Externally-owned; must outlive the Request. privileged_headers: []const http.Header, + pub const default_accept_encoding: [@typeInfo(http.ContentEncoding).@"enum".fields.len]bool = b: { + var result: [@typeInfo(http.ContentEncoding).@"enum".fields.len]bool = @splat(false); + result[@intFromEnum(http.ContentEncoding.gzip)] = true; + result[@intFromEnum(http.ContentEncoding.deflate)] = true; + result[@intFromEnum(http.ContentEncoding.identity)] = true; + break :b result; + }; + + pub const TransferEncoding = union(enum) { + content_length: u64, + chunked: void, + none: void, + }; + pub const Headers = struct { host: Value = .default, authorization: Value = .default, @@ -742,98 +819,102 @@ pub const Request = struct { } }; - /// Frees all resources associated with the request. - pub fn deinit(req: *Request) void { - if (req.connection) |connection| { - if (!req.response.parser.done) { - // If the response wasn't fully read, then we need to close the connection. - connection.closing = true; - } - req.client.connection_pool.release(req.client.allocator, connection); + /// Returns the request's `Connection` back to the pool of the `Client`. + pub fn deinit(r: *Request) void { + r.reader.restituteHeadBuffer(); + if (r.connection) |connection| { + connection.closing = connection.closing or switch (r.reader.state) { + .ready => false, + .received_head => r.method.requestHasBody(), + else => true, + }; + r.client.connection_pool.release(connection); } - req.* = undefined; + r.* = undefined; } - // This function must deallocate all resources associated with the request, - // or keep those which will be used. - // This needs to be kept in sync with deinit and request. - fn redirect(req: *Request, uri: Uri) !void { - assert(req.response.parser.done); + /// Sends and flushes a complete request as only HTTP head, no body. + pub fn sendBodiless(r: *Request) Writer.Error!void { + try sendBodilessUnflushed(r); + try r.connection.?.flush(); + } - req.client.connection_pool.release(req.client.allocator, req.connection.?); - req.connection = null; + /// Sends but does not flush a complete request as only HTTP head, no body. 
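+    ///
+    /// Editor's sketch of the typical bodiless round trip, using the flushed
+    /// variant (names as in `fetch` below):
+    ///
+    ///     try req.sendBodiless();
+    ///     var response = try req.receiveHead(redirect_buffer);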
+ pub fn sendBodilessUnflushed(r: *Request) Writer.Error!void { + assert(r.transfer_encoding == .none); + assert(!r.method.requestHasBody()); + try sendHead(r); + } - var server_header: std.heap.FixedBufferAllocator = .init(req.response.parser.header_bytes_buffer); - defer req.response.parser.header_bytes_buffer = server_header.buffer[server_header.end_index..]; - const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); + /// Transfers the HTTP head over the connection and flushes. + /// + /// See also: + /// * `sendBodyUnflushed` + pub fn sendBody(r: *Request, buffer: []u8) Writer.Error!http.BodyWriter { + const result = try sendBodyUnflushed(r, buffer); + try r.connection.?.flush(); + return result; + } - const new_host = valid_uri.host.?.raw; - const prev_host = req.uri.host.?.raw; - const keep_privileged_headers = - std.ascii.eqlIgnoreCase(valid_uri.scheme, req.uri.scheme) and - std.ascii.endsWithIgnoreCase(new_host, prev_host) and - (new_host.len == prev_host.len or new_host[new_host.len - prev_host.len - 1] == '.'); - if (!keep_privileged_headers) { - // When redirecting to a different domain, strip privileged headers. - req.privileged_headers = &.{}; - } - - if (switch (req.response.status) { - .see_other => true, - .moved_permanently, .found => req.method == .POST, - else => false, - }) { - // A redirect to a GET must change the method and remove the body. - req.method = .GET; - req.transfer_encoding = .none; - req.headers.content_type = .omit; - } - - if (req.transfer_encoding != .none) { - // The request body has already been sent. The request is - // still in a valid state, but the redirect must be handled - // manually. - return error.RedirectRequiresResend; - } - - req.uri = valid_uri; - req.connection = try req.client.connect(new_host, uriPort(valid_uri, protocol), protocol); - req.redirect_behavior.subtractOne(); - req.response.parser.reset(); - - req.response = .{ - .version = undefined, - .status = undefined, - .reason = undefined, - .keep_alive = undefined, - .parser = req.response.parser, + /// Transfers the HTTP head over the connection, which is not flushed until + /// `BodyWriter.flush` or `BodyWriter.end` is called. + /// + /// See also: + /// * `sendBody` + pub fn sendBodyUnflushed(r: *Request, buffer: []u8) Writer.Error!http.BodyWriter { + assert(r.method.requestHasBody()); + try sendHead(r); + const http_protocol_output = r.connection.?.writer(); + return switch (r.transfer_encoding) { + .chunked => .{ + .http_protocol_output = http_protocol_output, + .state = .{ .chunked = .init }, + .writer = .{ + .buffer = buffer, + .vtable = &.{ + .drain = http.BodyWriter.chunkedDrain, + .sendFile = http.BodyWriter.chunkedSendFile, + }, + }, + }, + .content_length => |len| .{ + .http_protocol_output = http_protocol_output, + .state = .{ .content_length = len }, + .writer = .{ + .buffer = buffer, + .vtable = &.{ + .drain = http.BodyWriter.contentLengthDrain, + .sendFile = http.BodyWriter.contentLengthSendFile, + }, + }, + }, + .none => .{ + .http_protocol_output = http_protocol_output, + .state = .none, + .writer = .{ + .buffer = buffer, + .vtable = &.{ + .drain = http.BodyWriter.noneDrain, + .sendFile = http.BodyWriter.noneSendFile, + }, + }, + }, }; } - pub const SendError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding }; + /// Sends HTTP headers without flushing. 
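+    ///
+    /// Callers normally reach this through `sendBody`/`sendBodiless`. An
+    /// editor's sketch of the body-sending path, as used by `fetch` below:
+    ///
+    ///     req.transfer_encoding = .{ .content_length = payload.len };
+    ///     var body = try req.sendBody(&.{});
+    ///     try body.writer.writeAll(payload);
+    ///     try body.end();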
+ fn sendHead(r: *Request) Writer.Error!void { + const uri = r.uri; + const connection = r.connection.?; + const w = connection.writer(); - /// Send the HTTP request headers to the server. - pub fn send(req: *Request) SendError!void { - if (!req.method.requestHasBody() and req.transfer_encoding != .none) - return error.UnsupportedTransferEncoding; - - const connection = req.connection.?; - var connection_writer_adapter = connection.writer().adaptToNewApi(); - const w = &connection_writer_adapter.new_interface; - sendAdapted(req, connection, w) catch |err| switch (err) { - error.WriteFailed => return connection_writer_adapter.err.?, - else => |e| return e, - }; - } - - fn sendAdapted(req: *Request, connection: *Connection, w: *std.io.Writer) !void { - try req.method.format(w); + try r.method.write(w); try w.writeByte(' '); - if (req.method == .CONNECT) { - try req.uri.writeToStream(w, .{ .authority = true }); + if (r.method == .CONNECT) { + try uri.writeToStream(.{ .authority = true }, w); } else { - try req.uri.writeToStream(w, .{ + try uri.writeToStream(.{ .scheme = connection.proxied, .authentication = connection.proxied, .authority = connection.proxied, @@ -842,58 +923,64 @@ pub const Request = struct { }); } try w.writeByte(' '); - try w.writeAll(@tagName(req.version)); + try w.writeAll(@tagName(r.version)); try w.writeAll("\r\n"); - if (try emitOverridableHeader("host: ", req.headers.host, w)) { + if (try emitOverridableHeader("host: ", r.headers.host, w)) { try w.writeAll("host: "); - try req.uri.writeToStream(w, .{ .authority = true }); + try uri.writeToStream(.{ .authority = true }, w); try w.writeAll("\r\n"); } - if (try emitOverridableHeader("authorization: ", req.headers.authorization, w)) { - if (req.uri.user != null or req.uri.password != null) { + if (try emitOverridableHeader("authorization: ", r.headers.authorization, w)) { + if (uri.user != null or uri.password != null) { try w.writeAll("authorization: "); - const authorization = try connection.allocWriteBuffer( - @intCast(basic_authorization.valueLengthFromUri(req.uri)), - ); - assert(basic_authorization.value(req.uri, authorization).len == authorization.len); + try basic_authorization.write(uri, w); try w.writeAll("\r\n"); } } - if (try emitOverridableHeader("user-agent: ", req.headers.user_agent, w)) { + if (try emitOverridableHeader("user-agent: ", r.headers.user_agent, w)) { try w.writeAll("user-agent: zig/"); try w.writeAll(builtin.zig_version_string); try w.writeAll(" (std.http)\r\n"); } - if (try emitOverridableHeader("connection: ", req.headers.connection, w)) { - if (req.keep_alive) { + if (try emitOverridableHeader("connection: ", r.headers.connection, w)) { + if (r.keep_alive) { try w.writeAll("connection: keep-alive\r\n"); } else { try w.writeAll("connection: close\r\n"); } } - if (try emitOverridableHeader("accept-encoding: ", req.headers.accept_encoding, w)) { - // https://github.com/ziglang/zig/issues/18937 - //try w.writeAll("accept-encoding: gzip, deflate, zstd\r\n"); - try w.writeAll("accept-encoding: gzip, deflate\r\n"); + if (try emitOverridableHeader("accept-encoding: ", r.headers.accept_encoding, w)) { + try w.writeAll("accept-encoding: "); + for (r.accept_encoding, 0..) 
|enabled, i| { + if (!enabled) continue; + const tag: http.ContentEncoding = @enumFromInt(i); + if (tag == .identity) continue; + const tag_name = @tagName(tag); + try w.ensureUnusedCapacity(tag_name.len + 2); + try w.writeAll(tag_name); + try w.writeAll(", "); + } + w.undo(2); + try w.writeAll("\r\n"); } - switch (req.transfer_encoding) { + switch (r.transfer_encoding) { .chunked => try w.writeAll("transfer-encoding: chunked\r\n"), .content_length => |len| try w.print("content-length: {d}\r\n", .{len}), .none => {}, } - if (try emitOverridableHeader("content-type: ", req.headers.content_type, w)) { + if (try emitOverridableHeader("content-type: ", r.headers.content_type, w)) { // The default is to omit content-type if not provided because // "application/octet-stream" is redundant. } - for (req.extra_headers) |header| { + for (r.extra_headers) |header| { assert(header.name.len != 0); try w.writeAll(header.name); @@ -904,8 +991,8 @@ pub const Request = struct { if (connection.proxied) proxy: { const proxy = switch (connection.protocol) { - .plain => req.client.http_proxy, - .tls => req.client.https_proxy, + .plain => r.client.http_proxy, + .tls => r.client.https_proxy, } orelse break :proxy; const authorization = proxy.authorization orelse break :proxy; @@ -915,282 +1002,198 @@ pub const Request = struct { } try w.writeAll("\r\n"); - - try connection.flush(); } - /// Returns true if the default behavior is required, otherwise handles - /// writing (or not writing) the header. - fn emitOverridableHeader(prefix: []const u8, v: Headers.Value, w: anytype) !bool { - switch (v) { - .default => return true, - .omit => return false, - .override => |x| { - try w.writeAll(prefix); - try w.writeAll(x); - try w.writeAll("\r\n"); - return false; - }, - } - } + pub const ReceiveHeadError = http.Reader.HeadError || ConnectError || error{ + /// Server sent headers that did not conform to the HTTP protocol. + /// + /// To find out more detailed diagnostics, `http.Reader.head_buffer` can be + /// passed directly to `Request.Head.parse`. + HttpHeadersInvalid, + TooManyHttpRedirects, + /// This can be avoided by calling `receiveHead` before sending the + /// request body. + RedirectRequiresResend, + HttpRedirectLocationMissing, + HttpRedirectLocationOversize, + HttpRedirectLocationInvalid, + HttpContentEncodingUnsupported, + HttpChunkInvalid, + HttpChunkTruncated, + HttpHeadersOversize, + UnsupportedUriScheme, - const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; + /// Sending the request failed. Error code can be found on the + /// `Connection` object. 
+ WriteFailed, + }; - const TransferReader = std.io.GenericReader(*Request, TransferReadError, transferRead); - - fn transferReader(req: *Request) TransferReader { - return .{ .context = req }; - } - - fn transferRead(req: *Request, buf: []u8) TransferReadError!usize { - if (req.response.parser.done) return 0; - - var index: usize = 0; - while (index == 0) { - const amt = try req.response.parser.read(req.connection.?, buf[index..], req.response.skip); - if (amt == 0 and req.response.parser.done) break; - index += amt; - } - - return index; - } - - pub const WaitError = RequestError || SendError || TransferReadError || - proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || - error{ - TooManyHttpRedirects, - RedirectRequiresResend, - HttpRedirectLocationMissing, - HttpRedirectLocationInvalid, - CompressionInitializationFailed, - CompressionUnsupported, - }; - - /// Waits for a response from the server and parses any headers that are sent. - /// This function will block until the final response is received. - /// /// If handling redirects and the request has no payload, then this - /// function will automatically follow redirects. If a request payload is - /// present, then this function will error with - /// error.RedirectRequiresResend. + /// function will automatically follow redirects. /// - /// Must be called after `send` and, if any data was written to the request - /// body, then also after `finish`. - pub fn wait(req: *Request) WaitError!void { + /// If a request payload is present, then this function will error with + /// `error.RedirectRequiresResend`. + /// + /// This function takes an auxiliary buffer to store the arbitrarily large + /// URI which may need to be merged with the previous URI, and that data + /// needs to survive across different connections, which is where the input + /// buffer lives. + /// + /// `redirect_buffer` must outlive accesses to `Request.uri`. If this + /// buffer capacity would be exceeded, `error.HttpRedirectLocationOversize` + /// is returned instead. This buffer may be empty if no redirects are to be + /// handled. + pub fn receiveHead(r: *Request, redirect_buffer: []u8) ReceiveHeadError!Response { + var aux_buf = redirect_buffer; while (true) { + try r.reader.receiveHead(); + const response: Response = .{ + .request = r, + .head = Response.Head.parse(r.reader.head_buffer) catch return error.HttpHeadersInvalid, + }; + const head = &response.head; + + if (head.status == .@"continue") { + if (r.handle_continue) continue; + return response; // we're not handling the 100-continue + } + // This while loop is for handling redirects, which means the request's // connection may be different than the previous iteration. However, it // is still guaranteed to be non-null with each iteration of this loop. - const connection = req.connection.?; + const connection = r.connection.?; - while (true) { // read headers - try connection.fill(); - - const nchecked = try req.response.parser.checkCompleteHead(connection.peek()); - connection.drop(@intCast(nchecked)); - - if (req.response.parser.state.isContent()) break; - } - - try req.response.parse(req.response.parser.get()); - - if (req.response.status == .@"continue") { - // We're done parsing the continue response; reset to prepare - // for the real response. 
- req.response.parser.done = true; - req.response.parser.reset(); - - if (req.handle_continue) - continue; - - return; // we're not handling the 100-continue - } - - // we're switching protocols, so this connection is no longer doing http - if (req.method == .CONNECT and req.response.status.class() == .success) { + if (r.method == .CONNECT and head.status.class() == .success) { + // This connection is no longer doing HTTP. connection.closing = false; - req.response.parser.done = true; - return; // the connection is not HTTP past this point + return response; } - connection.closing = !req.response.keep_alive or !req.keep_alive; + connection.closing = !head.keep_alive or !r.keep_alive; // Any response to a HEAD request and any response with a 1xx // (Informational), 204 (No Content), or 304 (Not Modified) status // code is always terminated by the first empty line after the // header fields, regardless of the header fields present in the // message. - if (req.method == .HEAD or req.response.status.class() == .informational or - req.response.status == .no_content or req.response.status == .not_modified) + if (r.method == .HEAD or head.status.class() == .informational or + head.status == .no_content or head.status == .not_modified) { - req.response.parser.done = true; - return; // The response is empty; no further setup or redirection is necessary. + return response; } - switch (req.response.transfer_encoding) { - .none => { - if (req.response.content_length) |cl| { - req.response.parser.next_chunk_length = cl; - - if (cl == 0) req.response.parser.done = true; - } else { - // read until the connection is closed - req.response.parser.next_chunk_length = std.math.maxInt(u64); - } - }, - .chunked => { - req.response.parser.next_chunk_length = 0; - req.response.parser.state = .chunk_head_size; - }, - } - - if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) { - // skip the body of the redirect response, this will at least - // leave the connection in a known good state. - req.response.skip = true; - assert(try req.transferRead(&.{}) == 0); // we're skipping, no buffer is necessary - - if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects; - - const location = req.response.location orelse - return error.HttpRedirectLocationMissing; - - // This mutates the beginning of header_bytes_buffer and uses that - // for the backing memory of the returned Uri. 
- try req.redirect(req.uri.resolve_inplace( - location, - &req.response.parser.header_bytes_buffer, - ) catch |err| switch (err) { - error.UnexpectedCharacter, - error.InvalidFormat, - error.InvalidPort, - => return error.HttpRedirectLocationInvalid, - error.NoSpaceLeft => return error.HttpHeadersOversize, - }); - try req.send(); - } else { - req.response.skip = false; - if (!req.response.parser.done) { - switch (req.response.transfer_compression) { - .identity => req.response.compression = .none, - .compress, .@"x-compress" => return error.CompressionUnsupported, - // I'm about to upstream my http.Client rewrite - .deflate => return error.CompressionUnsupported, - // I'm about to upstream my http.Client rewrite - .gzip, .@"x-gzip" => return error.CompressionUnsupported, - // https://github.com/ziglang/zig/issues/18937 - //.zstd => req.response.compression = .{ - // .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()), - //}, - .zstd => return error.CompressionUnsupported, - } + if (head.status.class() == .redirect and r.redirect_behavior != .unhandled) { + if (r.redirect_behavior == .not_allowed) { + // Connection can still be reused by skipping the body. + const reader = r.reader.bodyReader(&.{}, head.transfer_encoding, head.content_length); + _ = reader.discardRemaining() catch |err| switch (err) { + error.ReadFailed => connection.closing = true, + }; + return error.TooManyHttpRedirects; } - - break; + try r.redirect(head, &aux_buf); + try r.sendBodiless(); + continue; } + + if (!r.accept_encoding[@intFromEnum(head.content_encoding)]) + return error.HttpContentEncodingUnsupported; + + return response; } } - pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || - error{ DecompressionFailure, InvalidTrailers }; - - pub const Reader = std.io.GenericReader(*Request, ReadError, read); - - pub fn reader(req: *Request) Reader { - return .{ .context = req }; - } - - /// Reads data from the response body. Must be called after `wait`. - pub fn read(req: *Request, buffer: []u8) ReadError!usize { - const out_index = switch (req.response.compression) { - // I'm about to upstream my http client rewrite - //.deflate => |*deflate| deflate.readSlice(buffer) catch return error.DecompressionFailure, - //.gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, - // https://github.com/ziglang/zig/issues/18937 - //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, - else => try req.transferRead(buffer), + /// This function takes an auxiliary buffer to store the arbitrarily large + /// URI which may need to be merged with the previous URI, and that data + /// needs to survive across different connections, which is where the input + /// buffer lives. + /// + /// `aux_buf` must outlive accesses to `Request.uri`. + fn redirect(r: *Request, head: *const Response.Head, aux_buf: *[]u8) !void { + const new_location = head.location orelse return error.HttpRedirectLocationMissing; + if (new_location.len > aux_buf.*.len) return error.HttpRedirectLocationOversize; + const location = aux_buf.*[0..new_location.len]; + @memcpy(location, new_location); + { + // Skip the body of the redirect response to leave the connection in + // the correct state. This causes `new_location` to be invalidated. 
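+            // (Editor's note) `location` was copied into `aux_buf` above so
+            // that it survives this invalidation.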
+ const reader = r.reader.bodyReader(&.{}, head.transfer_encoding, head.content_length); + _ = reader.discardRemaining() catch |err| switch (err) { + error.ReadFailed => return r.reader.body_err.?, + }; + r.reader.restituteHeadBuffer(); + } + const new_uri = r.uri.resolveInPlace(location.len, aux_buf) catch |err| switch (err) { + error.UnexpectedCharacter => return error.HttpRedirectLocationInvalid, + error.InvalidFormat => return error.HttpRedirectLocationInvalid, + error.InvalidPort => return error.HttpRedirectLocationInvalid, + error.NoSpaceLeft => return error.HttpRedirectLocationOversize, }; - if (out_index > 0) return out_index; - while (!req.response.parser.state.isContent()) { // read trailing headers - try req.connection.?.fill(); + const protocol = Protocol.fromUri(new_uri) orelse return error.UnsupportedUriScheme; + const old_connection = r.connection.?; + const old_host = old_connection.host(); + var new_host_name_buffer: [Uri.host_name_max]u8 = undefined; + const new_host = try new_uri.getHost(&new_host_name_buffer); + const keep_privileged_headers = + std.ascii.eqlIgnoreCase(r.uri.scheme, new_uri.scheme) and + sameParentDomain(old_host, new_host); - const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek()); - req.connection.?.drop(@intCast(nchecked)); + r.client.connection_pool.release(old_connection); + r.connection = null; + + if (!keep_privileged_headers) { + // When redirecting to a different domain, strip privileged headers. + r.privileged_headers = &.{}; } - return 0; - } - - /// Reads data from the response body. Must be called after `wait`. - pub fn readAll(req: *Request, buffer: []u8) !usize { - var index: usize = 0; - while (index < buffer.len) { - const amt = try read(req, buffer[index..]); - if (amt == 0) break; - index += amt; + if (switch (head.status) { + .see_other => true, + .moved_permanently, .found => r.method == .POST, + else => false, + }) { + // A redirect to a GET must change the method and remove the body. + r.method = .GET; + r.transfer_encoding = .none; + r.headers.content_type = .omit; } - return index; + + if (r.transfer_encoding != .none) { + // The request body has already been sent. The request is + // still in a valid state, but the redirect must be handled + // manually. + return error.RedirectRequiresResend; + } + + const new_connection = try r.client.connect(new_host, uriPort(new_uri, protocol), protocol); + r.uri = new_uri; + r.connection = new_connection; + r.reader = .{ + .in = new_connection.reader(), + .state = .ready, + // Populated when `http.Reader.bodyReader` is called. + .interface = undefined, + }; + r.redirect_behavior.subtractOne(); } - pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; - - pub const Writer = std.io.GenericWriter(*Request, WriteError, write); - - pub fn writer(req: *Request) Writer { - return .{ .context = req }; - } - - /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn write(req: *Request, bytes: []const u8) WriteError!usize { - switch (req.transfer_encoding) { - .chunked => { - if (bytes.len > 0) { - try req.connection.?.writer().print("{x}\r\n", .{bytes.len}); - try req.connection.?.writer().writeAll(bytes); - try req.connection.?.writer().writeAll("\r\n"); - } - - return bytes.len; + /// Returns true if the default behavior is required, otherwise handles + /// writing (or not writing) the header. 
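+    ///
+    /// For example (editor's illustration, taken from `sendHead` above):
+    ///
+    ///     if (try emitOverridableHeader("host: ", r.headers.host, w)) {
+    ///         // Default behavior: derive the host header from the URI.
+    ///     }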
+ fn emitOverridableHeader(prefix: []const u8, v: Headers.Value, bw: *Writer) Writer.Error!bool { + switch (v) { + .default => return true, + .omit => return false, + .override => |x| { + var vecs: [3][]const u8 = .{ prefix, x, "\r\n" }; + try bw.writeVecAll(&vecs); + return false; }, - .content_length => |*len| { - if (len.* < bytes.len) return error.MessageTooLong; - - const amt = try req.connection.?.write(bytes); - len.* -= amt; - return amt; - }, - .none => return error.NotWriteable, } } - - /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try write(req, bytes[index..]); - } - } - - pub const FinishError = WriteError || error{MessageNotCompleted}; - - /// Finish the body of a request. This notifies the server that you have no more data to send. - /// Must be called after `send`. - pub fn finish(req: *Request) FinishError!void { - switch (req.transfer_encoding) { - .chunked => try req.connection.?.writer().writeAll("0\r\n\r\n"), - .content_length => |len| if (len != 0) return error.MessageNotCompleted, - .none => {}, - } - - try req.connection.?.flush(); - } }; pub const Proxy = struct { - protocol: Connection.Protocol, + protocol: Protocol, host: []const u8, authorization: ?[]const u8, port: u16, @@ -1204,10 +1207,8 @@ pub const Proxy = struct { pub fn deinit(client: *Client) void { assert(client.connection_pool.used.first == null); // There are still active requests. - client.connection_pool.deinit(client.allocator); - - if (!disable_tls) - client.ca_bundle.deinit(client.allocator); + client.connection_pool.deinit(); + if (!disable_tls) client.ca_bundle.deinit(client.allocator); client.* = undefined; } @@ -1249,24 +1250,21 @@ fn createProxyFromEnvVar(arena: Allocator, env_var_names: []const []const u8) !? 
*Proxy {
} else return null; const uri = Uri.parse(content) catch try Uri.parseAfterScheme("http", content); - const protocol, const valid_uri = validateUri(uri, arena) catch |err| switch (err) { - error.UnsupportedUriScheme => return null, - error.UriMissingHost => return error.HttpProxyMissingHost, - error.OutOfMemory => |e| return e, - }; + const protocol = Protocol.fromUri(uri) orelse return null; + const raw_host = try uri.getHostAlloc(arena); - const authorization: ?[]const u8 = if (valid_uri.user != null or valid_uri.password != null) a: { - const authorization = try arena.alloc(u8, basic_authorization.valueLengthFromUri(valid_uri)); - assert(basic_authorization.value(valid_uri, authorization).len == authorization.len); + const authorization: ?[]const u8 = if (uri.user != null or uri.password != null) a: { + const authorization = try arena.alloc(u8, basic_authorization.valueLengthFromUri(uri)); + assert(basic_authorization.value(uri, authorization).len == authorization.len); break :a authorization; } else null; const proxy = try arena.create(Proxy); proxy.* = .{ .protocol = protocol, - .host = valid_uri.host.?.raw, + .host = raw_host, .authorization = authorization, - .port = uriPort(valid_uri, protocol), + .port = uriPort(uri, protocol), .supports_connect = true, }; return proxy; @@ -1277,10 +1275,8 @@ pub const basic_authorization = struct { pub const max_password_len = 255; pub const max_value_len = valueLength(max_user_len, max_password_len); - const prefix = "Basic "; - pub fn valueLength(user_len: usize, password_len: usize) usize { - return prefix.len + std.base64.standard.Encoder.calcSize(user_len + 1 + password_len); + return "Basic ".len + std.base64.standard.Encoder.calcSize(user_len + 1 + password_len); } pub fn valueLengthFromUri(uri: Uri) usize { @@ -1300,37 +1296,69 @@ pub const basic_authorization = struct { } pub fn value(uri: Uri, out: []u8) []u8 { - const user: Uri.Component = uri.user orelse .empty; - const password: Uri.Component = uri.password orelse .empty; + var bw: Writer = .fixed(out); + write(uri, &bw) catch unreachable; + return bw.getWritten(); + } + pub fn write(uri: Uri, out: *Writer) Writer.Error!void { var buf: [max_user_len + ":".len + max_password_len]u8 = undefined; - var w: std.io.Writer = .fixed(&buf); - user.formatUser(&w) catch unreachable; // fixed - password.formatPassword(&w) catch unreachable; // fixed - - @memcpy(out[0..prefix.len], prefix); - const base64 = std.base64.standard.Encoder.encode(out[prefix.len..], w.buffered()); - return out[0 .. prefix.len + base64.len]; + var w: Writer = .fixed(&buf); + w.print("{fuser}:{fpassword}", .{ + uri.user orelse Uri.Component.empty, + uri.password orelse Uri.Component.empty, + }) catch unreachable; + try out.print("Basic {b64}", .{w.buffered()}); } }; -pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed }; +pub const ConnectTcpError = Allocator.Error || error{ + ConnectionRefused, + NetworkUnreachable, + ConnectionTimedOut, + ConnectionResetByPeer, + TemporaryNameServerFailure, + NameServerFailure, + UnknownHostName, + HostLacksNetworkAddresses, + UnexpectedConnectFailure, + TlsInitializationFailed, +}; -/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open. 
+/// Reuses a `Connection` if one matching `host` and `port` is already open. /// -/// This function is threadsafe. -pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*Connection { +/// Threadsafe. +pub fn connectTcp( + client: *Client, + host: []const u8, + port: u16, + protocol: Protocol, +) ConnectTcpError!*Connection { + return connectTcpOptions(client, .{ .host = host, .port = port, .protocol = protocol }); +} + +pub const ConnectTcpOptions = struct { + host: []const u8, + port: u16, + protocol: Protocol, + + proxied_host: ?[]const u8 = null, + proxied_port: ?u16 = null, +}; + +pub fn connectTcpOptions(client: *Client, options: ConnectTcpOptions) ConnectTcpError!*Connection { + const host = options.host; + const port = options.port; + const protocol = options.protocol; + + const proxied_host = options.proxied_host orelse host; + const proxied_port = options.proxied_port orelse port; + if (client.connection_pool.findConnection(.{ - .host = host, - .port = port, + .host = proxied_host, + .port = proxied_port, .protocol = protocol, - })) |node| return node; - - if (disable_tls and protocol == .tls) - return error.TlsInitializationFailed; - - const conn = try client.allocator.create(Connection); - errdefer client.allocator.destroy(conn); + })) |conn| return conn; const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) { error.ConnectionRefused => return error.ConnectionRefused, @@ -1345,53 +1373,19 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec }; errdefer stream.close(); - conn.* = .{ - .stream = stream, - .tls_client = undefined, - - .protocol = protocol, - .host = try client.allocator.dupe(u8, host), - .port = port, - - .pool_node = .{}, - }; - errdefer client.allocator.free(conn.host); - - if (protocol == .tls) { - if (disable_tls) unreachable; - - conn.tls_client = try client.allocator.create(std.crypto.tls.Client); - errdefer client.allocator.destroy(conn.tls_client); - - const ssl_key_log_file: ?std.fs.File = if (std.options.http_enable_ssl_key_log_file) ssl_key_log_file: { - const ssl_key_log_path = std.process.getEnvVarOwned(client.allocator, "SSLKEYLOGFILE") catch |err| switch (err) { - error.EnvironmentVariableNotFound, error.InvalidWtf8 => break :ssl_key_log_file null, - error.OutOfMemory => return error.OutOfMemory, - }; - defer client.allocator.free(ssl_key_log_path); - break :ssl_key_log_file std.fs.cwd().createFile(ssl_key_log_path, .{ - .truncate = false, - .mode = switch (builtin.os.tag) { - .windows, .wasi => 0, - else => 0o600, - }, - }) catch null; - } else null; - errdefer if (ssl_key_log_file) |key_log_file| key_log_file.close(); - - conn.tls_client.* = std.crypto.tls.Client.init(stream, .{ - .host = .{ .explicit = host }, - .ca = .{ .bundle = client.ca_bundle }, - .ssl_key_log_file = ssl_key_log_file, - }) catch return error.TlsInitializationFailed; - // This is appropriate for HTTPS because the HTTP headers contain - // the content length which is used to detect truncation attacks. 
- conn.tls_client.allow_truncation_attacks = true; + switch (protocol) { + .tls => { + if (disable_tls) return error.TlsInitializationFailed; + const tc = try Connection.Tls.create(client, proxied_host, proxied_port, stream); + client.connection_pool.addUsed(&tc.connection); + return &tc.connection; + }, + .plain => { + const pc = try Connection.Plain.create(client, proxied_host, proxied_port, stream); + client.connection_pool.addUsed(&pc.connection); + return &pc.connection; + }, } - - client.connection_pool.addUsed(conn); - - return conn; } pub const ConnectUnixError = Allocator.Error || std.posix.SocketError || error{NameTooLong} || std.posix.ConnectError; @@ -1429,69 +1423,67 @@ pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connecti return &conn.data; } -/// Connect to `tunnel_host:tunnel_port` using the specified proxy with HTTP +/// Connect to `proxied_host:proxied_port` using the specified proxy with HTTP /// CONNECT. This will reuse a connection if one is already open. /// /// This function is threadsafe. -pub fn connectTunnel( +pub fn connectProxied( client: *Client, proxy: *Proxy, - tunnel_host: []const u8, - tunnel_port: u16, + proxied_host: []const u8, + proxied_port: u16, ) !*Connection { if (!proxy.supports_connect) return error.TunnelNotSupported; if (client.connection_pool.findConnection(.{ - .host = tunnel_host, - .port = tunnel_port, + .host = proxied_host, + .port = proxied_port, .protocol = proxy.protocol, - })) |node| - return node; + })) |node| return node; var maybe_valid = false; (tunnel: { - const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); + const connection = try client.connectTcpOptions(.{ + .host = proxy.host, + .port = proxy.port, + .protocol = proxy.protocol, + .proxied_host = proxied_host, + .proxied_port = proxied_port, + }); errdefer { - conn.closing = true; - client.connection_pool.release(client.allocator, conn); + connection.closing = true; + client.connection_pool.release(connection); } - var buffer: [8096]u8 = undefined; - var req = client.open(.CONNECT, .{ + var req = client.request(.CONNECT, .{ .scheme = "http", - .host = .{ .raw = tunnel_host }, - .port = tunnel_port, + .host = .{ .raw = proxied_host }, + .port = proxied_port, }, .{ .redirect_behavior = .unhandled, - .connection = conn, - .server_header_buffer = &buffer, + .connection = connection, }) catch |err| { - std.log.debug("err {}", .{err}); break :tunnel err; }; defer req.deinit(); - req.send() catch |err| break :tunnel err; - req.wait() catch |err| break :tunnel err; + req.sendBodiless() catch |err| break :tunnel err; + const response = req.receiveHead(&.{}) catch |err| break :tunnel err; - if (req.response.status.class() == .server_error) { + if (response.head.status.class() == .server_error) { maybe_valid = true; break :tunnel error.ServerError; } - if (req.response.status != .ok) break :tunnel error.ConnectionRefused; + if (response.head.status != .ok) break :tunnel error.ConnectionRefused; - // this connection is now a tunnel, so we can't use it for anything else, it will only be released when the client is de-initialized. + // this connection is now a tunnel, so we can't use it for anything + // else, it will only be released when the client is de-initialized. 
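+        // (Editor's note) Detaching the connection below keeps the deferred
+        // `req.deinit()` from releasing the tunnel back into the pool.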
req.connection = null; - client.allocator.free(conn.host); - conn.host = try client.allocator.dupe(u8, tunnel_host); - errdefer client.allocator.free(conn.host); + connection.closing = false; - conn.port = tunnel_port; - conn.closing = false; - - return conn; + return connection; }) catch { // something went wrong with the tunnel proxy.supports_connect = maybe_valid; @@ -1499,12 +1491,11 @@ pub fn connectTunnel( }; } -// Prevents a dependency loop in open() -const ConnectErrorPartial = ConnectTcpError || error{ UnsupportedUriScheme, ConnectionRefused }; -pub const ConnectError = ConnectErrorPartial || RequestError; +pub const ConnectError = ConnectTcpError || RequestError; /// Connect to `host:port` using the specified protocol. This will reuse a /// connection if one is already open. +/// /// If a proxy is configured for the client, then the proxy will be used to /// connect to the host. /// @@ -1513,7 +1504,7 @@ pub fn connect( client: *Client, host: []const u8, port: u16, - protocol: Connection.Protocol, + protocol: Protocol, ) ConnectError!*Connection { const proxy = switch (protocol) { .plain => client.http_proxy, @@ -1528,32 +1519,24 @@ pub fn connect( } if (proxy.supports_connect) tunnel: { - return connectTunnel(client, proxy, host, port) catch |err| switch (err) { + return connectProxied(client, proxy, host, port) catch |err| switch (err) { error.TunnelNotSupported => break :tunnel, else => |e| return e, }; } // fall back to using the proxy as a normal http proxy - const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); - errdefer { - conn.closing = true; - client.connection_pool.release(conn); - } - - conn.proxied = true; - return conn; + const connection = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); + connection.proxied = true; + return connection; } -pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendError || - std.fmt.ParseIntError || Connection.WriteError || - error{ - UnsupportedUriScheme, - UriMissingHost, - - CertificateBundleLoadFailure, - UnsupportedTransferEncoding, - }; +pub const RequestError = ConnectTcpError || error{ + UnsupportedUriScheme, + UriMissingHost, + UriHostTooLong, + CertificateBundleLoadFailure, +}; pub const RequestOptions = struct { version: http.Version = .@"HTTP/1.1", @@ -1578,11 +1561,6 @@ pub const RequestOptions = struct { /// payload or the server has acknowledged the payload). redirect_behavior: Request.RedirectBehavior = @enumFromInt(3), - /// Externally-owned memory used to store the server's entire HTTP header. - /// `error.HttpHeadersOversize` is returned from read() when a - /// client sends too many bytes of HTTP headers. - server_header_buffer: []u8, - /// Must be an already acquired connection. connection: ?*Connection = null, @@ -1598,38 +1576,17 @@ pub const RequestOptions = struct { privileged_headers: []const http.Header = &.{}, }; -fn validateUri(uri: Uri, arena: Allocator) !struct { Connection.Protocol, Uri } { - const protocol_map = std.StaticStringMap(Connection.Protocol).initComptime(.{ - .{ "http", .plain }, - .{ "ws", .plain }, - .{ "https", .tls }, - .{ "wss", .tls }, - }); - const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUriScheme; - var valid_uri = uri; - // The host is always going to be needed as a raw string for hostname resolution anyway. 
- valid_uri.host = .{ - .raw = try (uri.host orelse return error.UriMissingHost).toRawMaybeAlloc(arena), - }; - return .{ protocol, valid_uri }; -} - -fn uriPort(uri: Uri, protocol: Connection.Protocol) u16 { - return uri.port orelse switch (protocol) { - .plain => 80, - .tls => 443, - }; +fn uriPort(uri: Uri, protocol: Protocol) u16 { + return uri.port orelse protocol.port(); } /// Open a connection to the host specified by `uri` and prepare to send a HTTP request. /// -/// `uri` must remain alive during the entire request. -/// /// The caller is responsible for calling `deinit()` on the `Request`. /// This function is threadsafe. /// /// Asserts that "\r\n" does not occur in any header name or value. -pub fn open( +pub fn request( client: *Client, method: http.Method, uri: Uri, @@ -1649,59 +1606,58 @@ pub fn open( } } - var server_header: std.heap.FixedBufferAllocator = .init(options.server_header_buffer); - const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); + const protocol = Protocol.fromUri(uri) orelse return error.UnsupportedUriScheme; - if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { + if (protocol == .tls) { if (disable_tls) unreachable; + if (@atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { + client.ca_bundle_mutex.lock(); + defer client.ca_bundle_mutex.unlock(); - client.ca_bundle_mutex.lock(); - defer client.ca_bundle_mutex.unlock(); - - if (client.next_https_rescan_certs) { - client.ca_bundle.rescan(client.allocator) catch - return error.CertificateBundleLoadFailure; - @atomicStore(bool, &client.next_https_rescan_certs, false, .release); + if (client.next_https_rescan_certs) { + client.ca_bundle.rescan(client.allocator) catch + return error.CertificateBundleLoadFailure; + @atomicStore(bool, &client.next_https_rescan_certs, false, .release); + } } } - const conn = options.connection orelse - try client.connect(valid_uri.host.?.raw, uriPort(valid_uri, protocol), protocol); + const connection = options.connection orelse c: { + var host_name_buffer: [Uri.host_name_max]u8 = undefined; + const host_name = try uri.getHost(&host_name_buffer); + break :c try client.connect(host_name, uriPort(uri, protocol), protocol); + }; - var req: Request = .{ - .uri = valid_uri, + return .{ + .uri = uri, .client = client, - .connection = conn, + .connection = connection, + .reader = .{ + .in = connection.reader(), + .state = .ready, + // Populated when `http.Reader.bodyReader` is called. + .interface = undefined, + }, .keep_alive = options.keep_alive, .method = method, .version = options.version, .transfer_encoding = .none, .redirect_behavior = options.redirect_behavior, .handle_continue = options.handle_continue, - .response = .{ - .version = undefined, - .status = undefined, - .reason = undefined, - .keep_alive = undefined, - .parser = .init(server_header.buffer[server_header.end_index..]), - }, .headers = options.headers, .extra_headers = options.extra_headers, .privileged_headers = options.privileged_headers, }; - errdefer req.deinit(); - - return req; } pub const FetchOptions = struct { - server_header_buffer: ?[]u8 = null, + /// `null` means it will be heap-allocated. + redirect_buffer: ?[]u8 = null, + /// `null` means it will be heap-allocated. + decompress_buffer: ?[]u8 = null, redirect_behavior: ?Request.RedirectBehavior = null, - - /// If the server sends a body, it will be appended to this ArrayList. - /// `max_append_size` provides an upper limit for how much they can grow. 
- response_storage: ResponseStorage = .ignore, - max_append_size: ?usize = null, + /// If the server sends a body, it will be stored here. + response_storage: ?ResponseStorage = null, location: Location, method: ?http.Method = null, @@ -1725,11 +1681,11 @@ pub const FetchOptions = struct { uri: Uri, }; - pub const ResponseStorage = union(enum) { - ignore, - /// Only the existing capacity will be used. - static: *std.ArrayListUnmanaged(u8), - dynamic: *std.ArrayList(u8), + pub const ResponseStorage = struct { + list: *std.ArrayListUnmanaged(u8), + /// If null then only the existing capacity will be used. + allocator: ?Allocator = null, + append_limit: std.io.Limit = .unlimited, }; }; @@ -1737,23 +1693,28 @@ pub const FetchResult = struct { status: http.Status, }; +pub const FetchError = Uri.ParseError || RequestError || Request.ReceiveHeadError || error{ + StreamTooLong, + /// TODO provide optional diagnostics when this occurs or break into more error codes + WriteFailed, +}; + /// Perform a one-shot HTTP request with the provided options. /// /// This function is threadsafe. -pub fn fetch(client: *Client, options: FetchOptions) !FetchResult { +pub fn fetch(client: *Client, options: FetchOptions) FetchError!FetchResult { const uri = switch (options.location) { .url => |u| try Uri.parse(u), .uri => |u| u, }; - var server_header_buffer: [16 * 1024]u8 = undefined; - const method: http.Method = options.method orelse if (options.payload != null) .POST else .GET; - var req = try open(client, method, uri, .{ - .server_header_buffer = options.server_header_buffer orelse &server_header_buffer, - .redirect_behavior = options.redirect_behavior orelse - if (options.payload == null) @enumFromInt(3) else .unhandled, + const redirect_behavior: Request.RedirectBehavior = options.redirect_behavior orelse + if (options.payload == null) @enumFromInt(3) else .unhandled; + + var req = try request(client, method, uri, .{ + .redirect_behavior = redirect_behavior, .headers = options.headers, .extra_headers = options.extra_headers, .privileged_headers = options.privileged_headers, @@ -1761,44 +1722,69 @@ pub fn fetch(client: *Client, options: FetchOptions) !FetchResult { }); defer req.deinit(); - if (options.payload) |payload| req.transfer_encoding = .{ .content_length = payload.len }; - - try req.send(); - - if (options.payload) |payload| try req.writeAll(payload); - - try req.finish(); - try req.wait(); - - switch (options.response_storage) { - .ignore => { - // Take advantage of request internals to discard the response body - // and make the connection available for another request. - req.response.skip = true; - assert(try req.transferRead(&.{}) == 0); // No buffer is necessary when skipping. 
- }, - .dynamic => |list| { - const max_append_size = options.max_append_size orelse 2 * 1024 * 1024; - try req.reader().readAllArrayList(list, max_append_size); - }, - .static => |list| { - const buf = b: { - const buf = list.unusedCapacitySlice(); - if (options.max_append_size) |len| { - if (len < buf.len) break :b buf[0..len]; - } - break :b buf; - }; - list.items.len += try req.reader().readAll(buf); - }, + if (options.payload) |payload| { + req.transfer_encoding = .{ .content_length = payload.len }; + var body = try req.sendBody(&.{}); + try body.writer.writeAll(payload); + try body.end(); + } else { + try req.sendBodiless(); } - return .{ - .status = req.response.status, + const redirect_buffer: []u8 = if (redirect_behavior == .unhandled) &.{} else options.redirect_buffer orelse + try client.allocator.alloc(u8, 8 * 1024); + defer if (options.redirect_buffer == null) client.allocator.free(redirect_buffer); + + var response = try req.receiveHead(redirect_buffer); + + const storage = options.response_storage orelse { + const reader = response.reader(&.{}); + _ = reader.discardRemaining() catch |err| switch (err) { + error.ReadFailed => return response.bodyErr().?, + }; + return .{ .status = response.head.status }; }; + + const decompress_buffer: []u8 = switch (response.head.content_encoding) { + .identity => &.{}, + .zstd => options.decompress_buffer orelse try client.allocator.alloc(u8, std.compress.zstd.default_window_len), + else => options.decompress_buffer orelse try client.allocator.alloc(u8, 8 * 1024), + }; + defer if (options.decompress_buffer == null) client.allocator.free(decompress_buffer); + + var decompressor: http.Decompressor = undefined; + const reader = response.readerDecompressing(&decompressor, decompress_buffer); + const list = storage.list; + + if (storage.allocator) |allocator| { + reader.appendRemaining(allocator, null, list, storage.append_limit) catch |err| switch (err) { + error.ReadFailed => return response.bodyErr().?, + else => |e| return e, + }; + } else { + const buf = storage.append_limit.slice(list.unusedCapacitySlice()); + list.items.len += reader.readSliceShort(buf) catch |err| switch (err) { + error.ReadFailed => return response.bodyErr().?, + }; + } + + return .{ .status = response.head.status }; +} + +pub fn sameParentDomain(parent_host: []const u8, child_host: []const u8) bool { + if (!std.ascii.endsWithIgnoreCase(child_host, parent_host)) return false; + if (child_host.len == parent_host.len) return true; + if (parent_host.len > child_host.len) return false; + return child_host[child_host.len - parent_host.len - 1] == '.'; +} + +test sameParentDomain { + try testing.expect(!sameParentDomain("foo.com", "bar.com")); + try testing.expect(sameParentDomain("foo.com", "foo.com")); + try testing.expect(sameParentDomain("foo.com", "bar.foo.com")); + try testing.expect(!sameParentDomain("bar.foo.com", "foo.com")); } test { _ = Response; - _ = &initDefaultProxies; } diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 7ec5d5c11f..004741d1ae 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -1,139 +1,70 @@ -//! Blocking HTTP server implementation. -//! Handles a single connection's lifecycle. +//! Handles a single connection lifecycle. -connection: net.Server.Connection, -/// Keeps track of whether the Server is ready to accept a new request on the -/// same connection, and makes invalid API usage cause assertion failures -/// rather than HTTP protocol violations. 
-state: State, -/// User-provided buffer that must outlive this Server. -/// Used to store the client's entire HTTP header. -read_buffer: []u8, -/// Amount of available data inside read_buffer. -read_buffer_len: usize, -/// Index into `read_buffer` of the first byte of the next HTTP request. -next_request_start: usize, +const std = @import("../std.zig"); +const http = std.http; +const mem = std.mem; +const Uri = std.Uri; +const assert = std.debug.assert; +const testing = std.testing; +const Writer = std.io.Writer; -pub const State = enum { - /// The connection is available to be used for the first time, or reused. - ready, - /// An error occurred in `receiveHead`. - receiving_head, - /// A Request object has been obtained and from there a Response can be - /// opened. - received_head, - /// The client is uploading something to this Server. - receiving_body, - /// The connection is eligible for another HTTP request, however the client - /// and server did not negotiate a persistent connection. - closing, -}; +const Server = @This(); + +/// Data from the HTTP server to the HTTP client. +out: *Writer, +reader: http.Reader, /// Initialize an HTTP server that can respond to multiple requests on the same /// connection. +/// +/// The buffer of `in` must be large enough to store the client's entire HTTP +/// header, otherwise `receiveHead` returns `error.HttpHeadersOversize`. +/// /// The returned `Server` is ready for `receiveHead` to be called. -pub fn init(connection: net.Server.Connection, read_buffer: []u8) Server { +pub fn init(in: *std.io.Reader, out: *Writer) Server { return .{ - .connection = connection, - .state = .ready, - .read_buffer = read_buffer, - .read_buffer_len = 0, - .next_request_start = 0, + .reader = .{ + .in = in, + .state = .ready, + // Populated when `http.Reader.bodyReader` is called. + .interface = undefined, + }, + .out = out, }; } -pub const ReceiveHeadError = error{ - /// Client sent too many bytes of HTTP headers. - /// The HTTP specification suggests to respond with a 431 status code - /// before closing the connection. - HttpHeadersOversize, - /// Client sent headers that did not conform to the HTTP protocol. - HttpHeadersInvalid, - /// A low level I/O error occurred trying to read the headers. - HttpHeadersUnreadable, - /// Partial HTTP request was received but the connection was closed before - /// fully receiving the headers. - HttpRequestTruncated, - /// The client sent 0 bytes of headers before closing the stream. - /// In other words, a keep-alive connection was finally closed. - HttpConnectionClosing, -}; - -/// The header bytes reference the read buffer that Server was initialized with -/// and remain alive until the next call to receiveHead. -pub fn receiveHead(s: *Server) ReceiveHeadError!Request { - assert(s.state == .ready); - s.state = .received_head; - errdefer s.state = .receiving_head; - - // In case of a reused connection, move the next request's bytes to the - // beginning of the buffer. 
- if (s.next_request_start > 0) { - if (s.read_buffer_len > s.next_request_start) { - rebase(s, 0); - } else { - s.read_buffer_len = 0; - } - } - - var hp: http.HeadParser = .{}; - - if (s.read_buffer_len > 0) { - const bytes = s.read_buffer[0..s.read_buffer_len]; - const end = hp.feed(bytes); - if (hp.state == .finished) - return finishReceivingHead(s, end); - } - - while (true) { - const buf = s.read_buffer[s.read_buffer_len..]; - if (buf.len == 0) - return error.HttpHeadersOversize; - const read_n = s.connection.stream.read(buf) catch - return error.HttpHeadersUnreadable; - if (read_n == 0) { - if (s.read_buffer_len > 0) { - return error.HttpRequestTruncated; - } else { - return error.HttpConnectionClosing; - } - } - s.read_buffer_len += read_n; - const bytes = buf[0..read_n]; - const end = hp.feed(bytes); - if (hp.state == .finished) - return finishReceivingHead(s, s.read_buffer_len - bytes.len + end); - } +pub fn deinit(s: *Server) void { + s.reader.restituteHeadBuffer(); } -fn finishReceivingHead(s: *Server, head_end: usize) ReceiveHeadError!Request { +pub const ReceiveHeadError = http.Reader.HeadError || error{ + /// Client sent headers that did not conform to the HTTP protocol. + /// + /// To find out more detailed diagnostics, `http.Reader.head_buffer` can be + /// passed directly to `Request.Head.parse`. + HttpHeadersInvalid, +}; + +pub fn receiveHead(s: *Server) ReceiveHeadError!Request { + try s.reader.receiveHead(); return .{ .server = s, - .head_end = head_end, - .head = Request.Head.parse(s.read_buffer[0..head_end]) catch - return error.HttpHeadersInvalid, - .reader_state = undefined, + // No need to track the returned error here since users can repeat the + // parse with the header buffer to get detailed diagnostics. + .head = Request.Head.parse(s.reader.head_buffer) catch return error.HttpHeadersInvalid, }; } pub const Request = struct { server: *Server, - /// Index into Server's read_buffer. - head_end: usize, + /// Pointers in this struct are invalidated with the next call to + /// `receiveHead`. head: Head, - reader_state: union { - remaining_content_length: u64, - chunk_parser: http.ChunkParser, - }, + respond_err: ?RespondError = null, - pub const Compression = union(enum) { - pub const DeflateDecompressor = std.compress.zlib.Decompressor(std.io.AnyReader); - pub const GzipDecompressor = std.compress.gzip.Decompressor(std.io.AnyReader); - - deflate: std.compress.flate.Decompress, - gzip: std.compress.flate.Decompress, - zstd: std.compress.zstd.Decompress, - none: void, + pub const RespondError = error{ + /// The request contained an `expect` header with an unrecognized value. 
+ HttpExpectationFailed, }; pub const Head = struct { @@ -146,7 +77,6 @@ pub const Request = struct { transfer_encoding: http.TransferEncoding, transfer_compression: http.ContentEncoding, keep_alive: bool, - compression: Compression, pub const ParseError = error{ UnknownHttpMethod, @@ -200,7 +130,6 @@ pub const Request = struct { .@"HTTP/1.0" => false, .@"HTTP/1.1" => true, }, - .compression = .none, }; while (it.next()) |line| { @@ -230,7 +159,7 @@ pub const Request = struct { const trimmed = mem.trim(u8, header_value, " "); - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + if (http.ContentEncoding.fromString(trimmed)) |ce| { head.transfer_compression = ce; } else { return error.HttpTransferEncodingUnsupported; @@ -255,7 +184,7 @@ pub const Request = struct { if (next) |second| { const trimmed_second = mem.trim(u8, second, " "); - if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { + if (http.ContentEncoding.fromString(trimmed_second)) |transfer| { if (head.transfer_compression != .identity) return error.HttpHeadersInvalid; // double compression is not supported head.transfer_compression = transfer; @@ -299,7 +228,8 @@ pub const Request = struct { }; pub fn iterateHeaders(r: *Request) http.HeaderIterator { - return http.HeaderIterator.init(r.server.read_buffer[0..r.head_end]); + assert(r.server.reader.state == .received_head); + return http.HeaderIterator.init(r.server.reader.head_buffer); } test iterateHeaders { @@ -310,22 +240,19 @@ pub const Request = struct { "TRansfer-encoding:\tdeflate, chunked \r\n" ++ "connectioN:\t keep-alive \r\n\r\n"; - var read_buffer: [500]u8 = undefined; - @memcpy(read_buffer[0..request_bytes.len], request_bytes); - var server: Server = .{ - .connection = undefined, - .state = .ready, - .read_buffer = &read_buffer, - .read_buffer_len = request_bytes.len, - .next_request_start = 0, + .reader = .{ + .in = undefined, + .state = .received_head, + .head_buffer = @constCast(request_bytes), + .interface = undefined, + }, + .out = undefined, }; var request: Request = .{ .server = &server, - .head_end = request_bytes.len, .head = undefined, - .reader_state = undefined, }; var it = request.iterateHeaders(); @@ -384,16 +311,22 @@ pub const Request = struct { /// no error is surfaced. /// /// Asserts status is not `continue`. - /// Asserts there are at most 25 extra_headers. /// Asserts that "\r\n" does not occur in any header name or value. 
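// Illustrative call site, not from the diff, for the respond() defined just
// below: it writes head and body into server.out and flushes before returning.
fn serveHello(request: *Request) !void {
    try request.respond("Hello, World!\n", .{
        .extra_headers = &.{
            .{ .name = "content-type", .value = "text/plain" },
        },
    });
}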
    pub fn respond(
        request: *Request,
        content: []const u8,
        options: RespondOptions,
-    ) Response.WriteError!void {
-        const max_extra_headers = 25;
+    ) ExpectContinueError!void {
+        try respondUnflushed(request, content, options);
+        try request.server.out.flush();
+    }
+
+    pub fn respondUnflushed(
+        request: *Request,
+        content: []const u8,
+        options: RespondOptions,
+    ) ExpectContinueError!void {
         assert(options.status != .@"continue");
-        assert(options.extra_headers.len <= max_extra_headers);
         if (std.debug.runtime_safety) {
             for (options.extra_headers) |header| {
                 assert(header.name.len != 0);
@@ -402,6 +335,7 @@
                 assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null);
             }
         }
+        try writeExpectContinue(request);
 
         const transfer_encoding_none = (options.transfer_encoding orelse .chunked) == .none;
         const server_keep_alive = !transfer_encoding_none and options.keep_alive;
@@ -409,130 +343,42 @@
         const phrase = options.reason orelse options.status.phrase() orelse "";
 
-        var first_buffer: [500]u8 = undefined;
-        var h = std.ArrayListUnmanaged(u8).initBuffer(&first_buffer);
-        if (request.head.expect != null) {
-            // reader() and hence discardBody() above sets expect to null if it
-            // is handled. So the fact that it is not null here means unhandled.
-            h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n");
-            if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n");
-            h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n");
-            try request.server.connection.stream.writeAll(h.items);
-            return;
-        }
-        h.fixedWriter().print("{s} {d} {s}\r\n", .{
+        const out = request.server.out;
+        try out.print("{s} {d} {s}\r\n", .{
            @tagName(options.version),
            @intFromEnum(options.status),
            phrase,
-        }) catch unreachable;
+        });
 
         switch (options.version) {
-            .@"HTTP/1.0" => if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"),
-            .@"HTTP/1.1" => if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"),
+            .@"HTTP/1.0" => if (keep_alive) try out.writeAll("connection: keep-alive\r\n"),
+            .@"HTTP/1.1" => if (!keep_alive) try out.writeAll("connection: close\r\n"),
         }
 
         if (options.transfer_encoding) |transfer_encoding| switch (transfer_encoding) {
             .none => {},
-            .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"),
+            .chunked => try out.writeAll("transfer-encoding: chunked\r\n"),
         } else {
-            h.fixedWriter().print("content-length: {d}\r\n", .{content.len}) catch unreachable;
+            try out.print("content-length: {d}\r\n", .{content.len});
         }
 
-        var chunk_header_buffer: [18]u8 = undefined;
-        var iovecs: [max_extra_headers * 4 + 3]std.posix.iovec_const = undefined;
-        var iovecs_len: usize = 0;
-
-        iovecs[iovecs_len] = .{
-            .base = h.items.ptr,
-            .len = h.items.len,
-        };
-        iovecs_len += 1;
-
         for (options.extra_headers) |header| {
-            iovecs[iovecs_len] = .{
-                .base = header.name.ptr,
-                .len = header.name.len,
-            };
-            iovecs_len += 1;
-
-            iovecs[iovecs_len] = .{
-                .base = ": ",
-                .len = 2,
-            };
-            iovecs_len += 1;
-
-            if (header.value.len != 0) {
-                iovecs[iovecs_len] = .{
-                    .base = header.value.ptr,
-                    .len = header.value.len,
-                };
-                iovecs_len += 1;
-            }
-
-            iovecs[iovecs_len] = .{
-                .base = "\r\n",
-                .len = 2,
-            };
-            iovecs_len += 1;
+            var vecs: [4][]const u8 = .{ header.name, ": ", header.value, "\r\n" };
+            try out.writeVecAll(&vecs);
         }
 
-        iovecs[iovecs_len] = .{
-            .base = "\r\n",
-            .len = 2,
-        };
-        iovecs_len += 1;
+        try out.writeAll("\r\n");
 
         if (request.head.method != .HEAD) {
             const is_chunked = (options.transfer_encoding orelse .none) == .chunked;
             if (is_chunked) {
-                if (content.len > 0) {
-                    const chunk_header = std.fmt.bufPrint(
-                        &chunk_header_buffer,
-                        "{x}\r\n",
-                        .{content.len},
-                    ) catch unreachable;
-
-                    iovecs[iovecs_len] = .{
-                        .base = chunk_header.ptr,
-                        .len = chunk_header.len,
-                    };
-                    iovecs_len += 1;
-
-                    iovecs[iovecs_len] = .{
-                        .base = content.ptr,
-                        .len = content.len,
-                    };
-                    iovecs_len += 1;
-
-                    iovecs[iovecs_len] = .{
-                        .base = "\r\n",
-                        .len = 2,
-                    };
-                    iovecs_len += 1;
-                }
-
-                iovecs[iovecs_len] = .{
-                    .base = "0\r\n\r\n",
-                    .len = 5,
-                };
-                iovecs_len += 1;
+                if (content.len > 0) try out.print("{x}\r\n{s}\r\n", .{ content.len, content });
+                try out.writeAll("0\r\n\r\n");
             } else if (content.len > 0) {
-                iovecs[iovecs_len] = .{
-                    .base = content.ptr,
-                    .len = content.len,
-                };
-                iovecs_len += 1;
+                try out.writeAll(content);
             }
         }
-
-        try request.server.connection.stream.writevAll(iovecs[0..iovecs_len]);
     }
 
     pub const RespondStreamingOptions = struct {
-        /// An externally managed slice of memory used to batch bytes before
-        /// sending. `respondStreaming` asserts this is large enough to store
-        /// the full HTTP response head.
-        ///
-        /// Must outlive the returned Response.
-        send_buffer: []u8,
         /// If provided, the response will use the content-length header;
         /// otherwise it will use transfer-encoding: chunked.
         content_length: ?u64 = null,
@@ -540,254 +386,221 @@
         respond_options: RespondOptions = .{},
     };
 
-    /// The header is buffered but not sent until Response.flush is called.
+    /// The header is not guaranteed to be sent until `BodyWriter.flush` or
+    /// `BodyWriter.end` is called.
     ///
     /// If the request contains a body and the connection is to be reused,
     /// discards the request body, leaving the Server in the `ready` state. If
     /// this discarding fails, the connection is marked as not to be reused and
     /// no error is surfaced.
     ///
-    /// HEAD requests are handled transparently by setting a flag on the
-    /// returned Response to omit the body. However it may be worth noticing
+    /// HEAD requests are handled transparently by setting the
+    /// `BodyWriter.elide` flag on the returned `BodyWriter`, causing
+    /// the response stream to omit the body. However, it may be worth noticing
    /// that flag and skipping any expensive work that would otherwise need to
     /// be done to satisfy the request.
     ///
-    /// Asserts `send_buffer` is large enough to store the entire response header.
     /// Asserts status is not `continue`.
-    pub fn respondStreaming(request: *Request, options: RespondStreamingOptions) Response {
+    pub fn respondStreaming(
+        request: *Request,
+        buffer: []u8,
+        options: RespondStreamingOptions,
+    ) ExpectContinueError!http.BodyWriter {
+        try writeExpectContinue(request);
         const o = options.respond_options;
         assert(o.status != .@"continue");
 
         const transfer_encoding_none = (o.transfer_encoding orelse .chunked) == .none;
         const server_keep_alive = !transfer_encoding_none and o.keep_alive;
         const keep_alive = request.discardBody(server_keep_alive);
 
         const phrase = o.reason orelse o.status.phrase() orelse "";
+        const out = request.server.out;
 
-        var h = std.ArrayListUnmanaged(u8).initBuffer(options.send_buffer);
+        try out.print("{s} {d} {s}\r\n", .{
+            @tagName(o.version), @intFromEnum(o.status), phrase,
+        });
 
-        const elide_body = if (request.head.expect != null) eb: {
-            // reader() and hence discardBody() above sets expect to null if it
-            // is handled. So the fact that it is not null here means unhandled.
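// Editorial sketch, not from the diff: using respondStreaming() the way the
// updated "trailers" test below does. With no content_length the body is
// chunked, so each flush emits a size-prefixed chunk ("7\r\nHello, \r\n" for
// the first write here) and end() writes the "0\r\n\r\n" terminator:
var response = try request.respondStreaming(&.{}, .{});
try response.writer.writeAll("Hello, ");
try response.flush();
try response.writer.writeAll("World!\n");
try response.end();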
-            h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n");
-            if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n");
-            h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n");
-            break :eb true;
-        } else eb: {
-            h.fixedWriter().print("{s} {d} {s}\r\n", .{
-                @tagName(o.version), @intFromEnum(o.status), phrase,
-            }) catch unreachable;
+        switch (o.version) {
+            .@"HTTP/1.0" => if (keep_alive) try out.writeAll("connection: keep-alive\r\n"),
+            .@"HTTP/1.1" => if (!keep_alive) try out.writeAll("connection: close\r\n"),
+        }
 
-            switch (o.version) {
-                .@"HTTP/1.0" => if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"),
-                .@"HTTP/1.1" => if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"),
-            }
+        if (o.transfer_encoding) |transfer_encoding| switch (transfer_encoding) {
+            .chunked => try out.writeAll("transfer-encoding: chunked\r\n"),
+            .none => {},
+        } else if (options.content_length) |len| {
+            try out.print("content-length: {d}\r\n", .{len});
+        } else {
+            try out.writeAll("transfer-encoding: chunked\r\n");
+        }
 
-            if (o.transfer_encoding) |transfer_encoding| switch (transfer_encoding) {
-                .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"),
-                .none => {},
-            } else if (options.content_length) |len| {
-                h.fixedWriter().print("content-length: {d}\r\n", .{len}) catch unreachable;
-            } else {
-                h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n");
-            }
+        for (o.extra_headers) |header| {
+            assert(header.name.len != 0);
+            try out.writeAll(header.name);
+            try out.writeAll(": ");
+            try out.writeAll(header.value);
+            try out.writeAll("\r\n");
+        }
 
-            for (o.extra_headers) |header| {
-                assert(header.name.len != 0);
-                h.appendSliceAssumeCapacity(header.name);
-                h.appendSliceAssumeCapacity(": ");
-                h.appendSliceAssumeCapacity(header.value);
-                h.appendSliceAssumeCapacity("\r\n");
-            }
+        try out.writeAll("\r\n");
+        const elide_body = request.head.method == .HEAD;
+        const state: http.BodyWriter.State = if (o.transfer_encoding) |te| switch (te) {
+            .chunked => .{ .chunked = .init },
+            .none => .none,
+        } else if (options.content_length) |len| .{
+            .content_length = len,
+        } else .{ .chunked = .init };
 
-            h.appendSliceAssumeCapacity("\r\n");
-            break :eb request.head.method == .HEAD;
+        return if (elide_body) .{
+            .http_protocol_output = request.server.out,
+            .state = state,
+            .writer = .discarding(buffer),
+        } else .{
+            .http_protocol_output = request.server.out,
+            .state = state,
+            .writer = .{
+                .buffer = buffer,
+                .vtable = switch (state) {
+                    .none => &.{
+                        .drain = http.BodyWriter.noneDrain,
+                        .sendFile = http.BodyWriter.noneSendFile,
+                    },
+                    .content_length => &.{
+                        .drain = http.BodyWriter.contentLengthDrain,
+                        .sendFile = http.BodyWriter.contentLengthSendFile,
+                    },
+                    .chunked => &.{
+                        .drain = http.BodyWriter.chunkedDrain,
+                        .sendFile = http.BodyWriter.chunkedSendFile,
+                    },
+                    .end => unreachable,
+                },
+            },
        };
+    }
+
+    pub const UpgradeRequest = union(enum) {
+        websocket: ?[]const u8,
+        other: []const u8,
+        none,
+    };
+
+    pub fn upgradeRequested(request: *const Request) UpgradeRequest {
+        switch (request.head.version) {
+            .@"HTTP/1.0" => return .none,
+            .@"HTTP/1.1" => if (request.head.method != .GET) return .none,
+        }
+
+        var sec_websocket_key: ?[]const u8 = null;
+        var upgrade_name: ?[]const u8 = null;
+        var it = request.iterateHeaders();
+        while (it.next()) |header| {
+            if (std.ascii.eqlIgnoreCase(header.name, "sec-websocket-key")) {
+                sec_websocket_key = header.value;
+            } else if (std.ascii.eqlIgnoreCase(header.name, "upgrade")) {
+                upgrade_name = header.value;
+            }
+        }
+
+        const name = upgrade_name orelse return .none;
+        if (std.ascii.eqlIgnoreCase(name, "websocket")) return .{ .websocket = sec_websocket_key };
+        return .{ .other = name };
+    }
+
+    pub const WebSocketOptions = struct {
+        /// The value from `UpgradeRequest.websocket` (sec-websocket-key header value).
+        key: []const u8,
+        reason: ?[]const u8 = null,
+        extra_headers: []const http.Header = &.{},
+    };
+
+    /// The header is not guaranteed to be sent until `WebSocket.flush` is
+    /// called on the returned struct.
+    pub fn respondWebSocket(request: *Request, options: WebSocketOptions) (Writer.Error || RespondError)!WebSocket {
+        if (request.head.expect != null) return error.HttpExpectationFailed;
+
+        const out = request.server.out;
+        const version: http.Version = .@"HTTP/1.1";
+        const status: http.Status = .switching_protocols;
+        const phrase = options.reason orelse status.phrase() orelse "";
+
+        assert(request.head.version == version);
+        assert(request.head.method == .GET);
+
+        var sha1 = std.crypto.hash.Sha1.init(.{});
+        sha1.update(options.key);
+        sha1.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
+        var digest: [std.crypto.hash.Sha1.digest_length]u8 = undefined;
+        sha1.final(&digest);
+        try out.print("{s} {d} {s}\r\n", .{ @tagName(version), @intFromEnum(status), phrase });
+        try out.writeAll("connection: upgrade\r\nupgrade: websocket\r\nsec-websocket-accept: ");
+        const base64_digest = try out.writableArray(28);
+        assert(std.base64.standard.Encoder.encode(base64_digest, &digest).len == base64_digest.len);
+        out.advance(base64_digest.len);
+        try out.writeAll("\r\n");
+
+        for (options.extra_headers) |header| {
+            assert(header.name.len != 0);
+            try out.writeAll(header.name);
+            try out.writeAll(": ");
+            try out.writeAll(header.value);
+            try out.writeAll("\r\n");
+        }
+
+        try out.writeAll("\r\n");
 
         return .{
-            .stream = request.server.connection.stream,
-            .send_buffer = options.send_buffer,
-            .send_buffer_start = 0,
-            .send_buffer_end = h.items.len,
-            .transfer_encoding = if (o.transfer_encoding) |te| switch (te) {
-                .chunked => .chunked,
-                .none => .none,
-            } else if (options.content_length) |len| .{
-                .content_length = len,
-            } else .chunked,
-            .elide_body = elide_body,
-            .chunk_len = 0,
+            .input = request.server.reader.in,
+            .output = request.server.out,
+            .key = options.key,
         };
     }
 
-    pub const ReadError = net.Stream.ReadError || error{
-        HttpChunkInvalid,
-        HttpHeadersOversize,
-    };
-
-    fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize {
-        const request: *Request = @constCast(@alignCast(@ptrCast(context)));
-        const s = request.server;
-
-        const remaining_content_length = &request.reader_state.remaining_content_length;
-        if (remaining_content_length.* == 0) {
-            s.state = .ready;
-            return 0;
-        }
-        assert(s.state == .receiving_body);
-        const available = try fill(s, request.head_end);
-        const len = @min(remaining_content_length.*, available.len, buffer.len);
-        @memcpy(buffer[0..len], available[0..len]);
-        remaining_content_length.* -= len;
-        s.next_request_start += len;
-        if (remaining_content_length.* == 0)
-            s.state = .ready;
-        return len;
-    }
-
-    fn fill(s: *Server, head_end: usize) ReadError![]u8 {
-        const available = s.read_buffer[s.next_request_start..s.read_buffer_len];
-        if (available.len > 0) return available;
-        s.next_request_start = head_end;
-        s.read_buffer_len = head_end + try s.connection.stream.read(s.read_buffer[head_end..]);
-        return s.read_buffer[head_end..s.read_buffer_len];
-    }
-
-    fn read_chunked(context: *const anyopaque,
buffer: []u8) ReadError!usize { - const request: *Request = @constCast(@alignCast(@ptrCast(context))); - const s = request.server; - - const cp = &request.reader_state.chunk_parser; - const head_end = request.head_end; - - // Protect against returning 0 before the end of stream. - var out_end: usize = 0; - while (out_end == 0) { - switch (cp.state) { - .invalid => return 0, - .data => { - assert(s.state == .receiving_body); - const available = try fill(s, head_end); - const len = @min(cp.chunk_len, available.len, buffer.len); - @memcpy(buffer[0..len], available[0..len]); - cp.chunk_len -= len; - if (cp.chunk_len == 0) - cp.state = .data_suffix; - out_end += len; - s.next_request_start += len; - continue; - }, - else => { - assert(s.state == .receiving_body); - const available = try fill(s, head_end); - const n = cp.feed(available); - switch (cp.state) { - .invalid => return error.HttpChunkInvalid, - .data => { - if (cp.chunk_len == 0) { - // The next bytes in the stream are trailers, - // or \r\n to indicate end of chunked body. - // - // This function must append the trailers at - // head_end so that headers and trailers are - // together. - // - // Since returning 0 would indicate end of - // stream, this function must read all the - // trailers before returning. - if (s.next_request_start > head_end) rebase(s, head_end); - var hp: http.HeadParser = .{}; - { - const bytes = s.read_buffer[head_end..s.read_buffer_len]; - const end = hp.feed(bytes); - if (hp.state == .finished) { - cp.state = .invalid; - s.state = .ready; - s.next_request_start = s.read_buffer_len - bytes.len + end; - return out_end; - } - } - while (true) { - const buf = s.read_buffer[s.read_buffer_len..]; - if (buf.len == 0) - return error.HttpHeadersOversize; - const read_n = try s.connection.stream.read(buf); - s.read_buffer_len += read_n; - const bytes = buf[0..read_n]; - const end = hp.feed(bytes); - if (hp.state == .finished) { - cp.state = .invalid; - s.state = .ready; - s.next_request_start = s.read_buffer_len - bytes.len + end; - return out_end; - } - } - } - const data = available[n..]; - const len = @min(cp.chunk_len, data.len, buffer.len); - @memcpy(buffer[0..len], data[0..len]); - cp.chunk_len -= len; - if (cp.chunk_len == 0) - cp.state = .data_suffix; - out_end += len; - s.next_request_start += n + len; - continue; - }, - else => continue, - } - }, - } - } - return out_end; - } - - pub const ReaderError = Response.WriteError || error{ - /// The client sent an expect HTTP header value other than - /// "100-continue". - HttpExpectationFailed, - }; - /// In the case that the request contains "expect: 100-continue", this /// function writes the continuation header, which means it can fail with a /// write error. After sending the continuation header, it sets the /// request's expect field to `null`. /// /// Asserts that this function is only called once. - pub fn reader(request: *Request) ReaderError!std.io.AnyReader { - const s = request.server; - assert(s.state == .received_head); - s.state = .receiving_body; - s.next_request_start = request.head_end; + /// + /// See `readerExpectNone` for an infallible alternative that cannot write + /// to the server output stream. 
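// Editorial sketch, not from the diff: consuming a request body through
// readerExpectContinue() (defined just below), which answers
// "expect: 100-continue" before handing back a std.io.Reader positioned at
// the body. The error mapping at the end is hypothetical:
var body_buffer: [4096]u8 = undefined;
const body = try request.readerExpectContinue(&body_buffer);
_ = body.discardRemaining() catch |err| switch (err) {
    // The transport failed mid-body; the connection cannot be reused.
    error.ReadFailed => return error.HttpRequestTruncated, // hypothetical mapping
};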
+ pub fn readerExpectContinue(request: *Request, buffer: []u8) ExpectContinueError!*std.io.Reader { + const flush = request.head.expect != null; + try writeExpectContinue(request); + if (flush) try request.server.out.flush(); + return readerExpectNone(request, buffer); + } - if (request.head.expect) |expect| { - if (mem.eql(u8, expect, "100-continue")) { - try request.server.connection.stream.writeAll("HTTP/1.1 100 Continue\r\n\r\n"); - request.head.expect = null; - } else { - return error.HttpExpectationFailed; - } - } + /// Asserts the expect header is `null`. The caller must handle the + /// expectation manually and then set the value to `null` prior to calling + /// this function. + /// + /// Asserts that this function is only called once. + pub fn readerExpectNone(request: *Request, buffer: []u8) *std.io.Reader { + assert(request.server.reader.state == .received_head); + assert(request.head.expect == null); + if (!request.head.method.requestHasBody()) return .ending; + return request.server.reader.bodyReader(buffer, request.head.transfer_encoding, request.head.content_length); + } - switch (request.head.transfer_encoding) { - .chunked => { - request.reader_state = .{ .chunk_parser = http.ChunkParser.init }; - return .{ - .readFn = read_chunked, - .context = request, - }; - }, - .none => { - request.reader_state = .{ - .remaining_content_length = request.head.content_length orelse 0, - }; - return .{ - .readFn = read_cl, - .context = request, - }; - }, - } + pub const ExpectContinueError = error{ + /// Failed to write "HTTP/1.1 100 Continue\r\n\r\n" to the stream. + WriteFailed, + /// The client sent an expect HTTP header value other than + /// "100-continue". + HttpExpectationFailed, + }; + + pub fn writeExpectContinue(request: *Request) ExpectContinueError!void { + const expect = request.head.expect orelse return; + if (!mem.eql(u8, expect, "100-continue")) return error.HttpExpectationFailed; + try request.server.out.writeAll("HTTP/1.1 100 Continue\r\n\r\n"); + request.head.expect = null; } /// Returns whether the connection should remain persistent. - /// If it would fail, it instead sets the Server state to `receiving_body` + /// + /// If it would fail, it instead sets the Server state to receiving body /// and returns false. fn discardBody(request: *Request, keep_alive: bool) bool { // Prepare to receive another request on the same connection. @@ -798,350 +611,175 @@ pub const Request = struct { // or the request body. // If the connection won't be kept alive, then none of this matters // because the connection will be severed after the response is sent. - const s = request.server; - if (keep_alive and request.head.keep_alive) switch (s.state) { + const r = &request.server.reader; + if (keep_alive and request.head.keep_alive) switch (r.state) { .received_head => { - const r = request.reader() catch return false; - _ = r.discard() catch return false; - assert(s.state == .ready); + if (request.head.method.requestHasBody()) { + assert(request.head.transfer_encoding != .none or request.head.content_length != null); + const reader_interface = request.readerExpectContinue(&.{}) catch return false; + _ = reader_interface.discardRemaining() catch return false; + assert(r.state == .ready); + } else { + r.state = .ready; + } return true; }, - .receiving_body, .ready => return true, + .body_remaining_content_length, .body_remaining_chunk_len, .body_none, .ready => return true, else => unreachable, }; // Avoid clobbering the state in case a reading stream already exists. 
- switch (s.state) { - .received_head => s.state = .closing, + switch (r.state) { + .received_head => r.state = .closing, else => {}, } return false; } }; -pub const Response = struct { - stream: net.Stream, - send_buffer: []u8, - /// Index of the first byte in `send_buffer`. - /// This is 0 unless a short write happens in `write`. - send_buffer_start: usize, - /// Index of the last byte + 1 in `send_buffer`. - send_buffer_end: usize, - /// `null` means transfer-encoding: chunked. - /// As a debugging utility, counts down to zero as bytes are written. - transfer_encoding: TransferEncoding, - elide_body: bool, - /// Indicates how much of the end of the `send_buffer` corresponds to a - /// chunk. This amount of data will be wrapped by an HTTP chunk header. - chunk_len: usize, +/// See https://tools.ietf.org/html/rfc6455 +pub const WebSocket = struct { + key: []const u8, + input: *std.io.Reader, + output: *Writer, - pub const TransferEncoding = union(enum) { - /// End of connection signals the end of the stream. - none, - /// As a debugging utility, counts down to zero as bytes are written. - content_length: u64, - /// Each chunk is wrapped in a header and trailer. - chunked, + pub const Header0 = packed struct(u8) { + opcode: Opcode, + rsv3: u1 = 0, + rsv2: u1 = 0, + rsv1: u1 = 0, + fin: bool, }; - pub const WriteError = net.Stream.WriteError; - - /// When using content-length, asserts that the amount of data sent matches - /// the value sent in the header, then calls `flush`. - /// Otherwise, transfer-encoding: chunked is being used, and it writes the - /// end-of-stream message, then flushes the stream to the system. - /// Respects the value of `elide_body` to omit all data after the headers. - pub fn end(r: *Response) WriteError!void { - switch (r.transfer_encoding) { - .content_length => |len| { - assert(len == 0); // Trips when end() called before all bytes written. - try flush_cl(r); - }, - .none => { - try flush_cl(r); - }, - .chunked => { - try flush_chunked(r, &.{}); - }, - } - r.* = undefined; - } - - pub const EndChunkedOptions = struct { - trailers: []const http.Header = &.{}, + pub const Header1 = packed struct(u8) { + payload_len: enum(u7) { + len16 = 126, + len64 = 127, + _, + }, + mask: bool, }; - /// Asserts that the Response is using transfer-encoding: chunked. - /// Writes the end-of-stream message and any optional trailers, then - /// flushes the stream to the system. - /// Respects the value of `elide_body` to omit all data after the headers. - /// Asserts there are at most 25 trailers. - pub fn endChunked(r: *Response, options: EndChunkedOptions) WriteError!void { - assert(r.transfer_encoding == .chunked); - try flush_chunked(r, options.trailers); - r.* = undefined; - } + pub const Opcode = enum(u4) { + continuation = 0, + text = 1, + binary = 2, + connection_close = 8, + ping = 9, + /// "A Pong frame MAY be sent unsolicited. This serves as a unidirectional + /// heartbeat. A response to an unsolicited Pong frame is not expected." + pong = 10, + _, + }; - /// If using content-length, asserts that writing these bytes to the client - /// would not exceed the content-length value sent in the HTTP header. - /// May return 0, which does not indicate end of stream. The caller decides - /// when the end of stream occurs by calling `end`. 
-    pub fn write(r: *Response, bytes: []const u8) WriteError!usize {
-        switch (r.transfer_encoding) {
-            .content_length, .none => return write_cl(r, bytes),
-            .chunked => return write_chunked(r, bytes),
-        }
-    }
+    pub const ReadSmallTextMessageError = error{
+        ReadFailed,
+        EndOfStream,
+        ConnectionClose,
+        UnexpectedOpCode,
+        MessageTooBig,
+        MissingMaskBit,
+    };
 
-    fn write_cl(context: *const anyopaque, bytes: []const u8) WriteError!usize {
-        const r: *Response = @constCast(@alignCast(@ptrCast(context)));
+    pub const SmallMessage = struct {
+        /// Can be text, binary, or ping.
+        opcode: Opcode,
+        data: []u8,
+    };
 
-        var trash: u64 = std.math.maxInt(u64);
-        const len = switch (r.transfer_encoding) {
-            .content_length => |*len| len,
-            else => &trash,
-        };
+    /// Reads the next message from the WebSocket stream, failing if the
+    /// message does not fit into the input buffer. The returned memory points
+    /// into the input buffer and is invalidated on the next read.
+    pub fn readSmallMessage(ws: *WebSocket) ReadSmallTextMessageError!SmallMessage {
+        const in = ws.input;
+        while (true) {
+            const h0: Header0 = @bitCast(try in.takeByte());
+            const h1: Header1 = @bitCast(try in.takeByte());
 
-        if (r.elide_body) {
-            len.* -= bytes.len;
-            return bytes.len;
-        }
-
-        if (bytes.len + r.send_buffer_end > r.send_buffer.len) {
-            const send_buffer_len = r.send_buffer_end - r.send_buffer_start;
-            var iovecs: [2]std.posix.iovec_const = .{
-                .{
-                    .base = r.send_buffer.ptr + r.send_buffer_start,
-                    .len = send_buffer_len,
-                },
-                .{
-                    .base = bytes.ptr,
-                    .len = bytes.len,
-                },
-            };
-            const n = try r.stream.writev(&iovecs);
-
-            if (n >= send_buffer_len) {
-                // It was enough to reset the buffer.
-                r.send_buffer_start = 0;
-                r.send_buffer_end = 0;
-                const bytes_n = n - send_buffer_len;
-                len.* -= bytes_n;
-                return bytes_n;
+            switch (h0.opcode) {
+                .text, .binary, .pong, .ping => {},
+                .connection_close => return error.ConnectionClose,
+                .continuation => return error.UnexpectedOpCode,
+                _ => return error.UnexpectedOpCode,
             }
 
-            // It didn't even make it through the existing buffer, let
-            // alone the new bytes provided.
-            r.send_buffer_start += n;
-            return 0;
-        }
+            if (!h0.fin) return error.MessageTooBig;
+            if (!h1.mask) return error.MissingMaskBit;
 
-        // All bytes can be stored in the remaining space of the buffer.
- @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes); - r.send_buffer_end += bytes.len; - len.* -= bytes.len; - return bytes.len; - } - - fn write_chunked(context: *const anyopaque, bytes: []const u8) WriteError!usize { - const r: *Response = @constCast(@alignCast(@ptrCast(context))); - assert(r.transfer_encoding == .chunked); - - if (r.elide_body) - return bytes.len; - - if (bytes.len + r.send_buffer_end > r.send_buffer.len) { - const send_buffer_len = r.send_buffer_end - r.send_buffer_start; - const chunk_len = r.chunk_len + bytes.len; - var header_buf: [18]u8 = undefined; - const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{chunk_len}) catch unreachable; - - var iovecs: [5]std.posix.iovec_const = .{ - .{ - .base = r.send_buffer.ptr + r.send_buffer_start, - .len = send_buffer_len - r.chunk_len, - }, - .{ - .base = chunk_header.ptr, - .len = chunk_header.len, - }, - .{ - .base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len, - .len = r.chunk_len, - }, - .{ - .base = bytes.ptr, - .len = bytes.len, - }, - .{ - .base = "\r\n", - .len = 2, - }, + const len: usize = switch (h1.payload_len) { + .len16 => try in.takeInt(u16, .big), + .len64 => std.math.cast(usize, try in.takeInt(u64, .big)) orelse return error.MessageTooBig, + else => @intFromEnum(h1.payload_len), }; - // TODO make this writev instead of writevAll, which involves - // complicating the logic of this function. - try r.stream.writevAll(&iovecs); - r.send_buffer_start = 0; - r.send_buffer_end = 0; - r.chunk_len = 0; - return bytes.len; - } + if (len > in.buffer.len) return error.MessageTooBig; + const mask: u32 = @bitCast((try in.takeArray(4)).*); + const payload = try in.take(len); - // All bytes can be stored in the remaining space of the buffer. - @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes); - r.send_buffer_end += bytes.len; - r.chunk_len += bytes.len; - return bytes.len; - } + // Skip pongs. + if (h0.opcode == .pong) continue; - /// If using content-length, asserts that writing these bytes to the client - /// would not exceed the content-length value sent in the HTTP header. - pub fn writeAll(r: *Response, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try write(r, bytes[index..]); + // The last item may contain a partial word of unused data. + const floored_len = (payload.len / 4) * 4; + const u32_payload: []align(1) u32 = @ptrCast(payload[0..floored_len]); + for (u32_payload) |*elem| elem.* ^= mask; + const mask_bytes: []const u8 = @ptrCast(&mask); + for (payload[floored_len..], mask_bytes[0 .. payload.len - floored_len]) |*leftover, m| + leftover.* ^= m; + + return .{ + .opcode = h0.opcode, + .data = payload, + }; } } - /// Sends all buffered data to the client. - /// This is redundant after calling `end`. - /// Respects the value of `elide_body` to omit all data after the headers. 
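// Editorial sketch, not from the diff: an echo loop built from the pieces
// above. `key` is the value carried by UpgradeRequest.websocket, and
// writeMessage (defined just below) frames and flushes one message:
var ws = try request.respondWebSocket(.{ .key = key });
while (true) {
    const msg = try ws.readSmallMessage();
    switch (msg.opcode) {
        .text, .binary => try ws.writeMessage(msg.data, msg.opcode),
        .ping => try ws.writeMessage(msg.data, .pong),
        else => {},
    }
}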
-    pub fn flush(r: *Response) WriteError!void {
-        switch (r.transfer_encoding) {
-            .none, .content_length => return flush_cl(r),
-            .chunked => return flush_chunked(r, null),
-        }
+    pub fn writeMessage(ws: *WebSocket, data: []const u8, op: Opcode) Writer.Error!void {
+        try writeMessageVecUnflushed(ws, &.{data}, op);
+        try ws.output.flush();
     }
 
-    fn flush_cl(r: *Response) WriteError!void {
-        try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]);
-        r.send_buffer_start = 0;
-        r.send_buffer_end = 0;
+    pub fn writeMessageUnflushed(ws: *WebSocket, data: []const u8, op: Opcode) Writer.Error!void {
+        try writeMessageVecUnflushed(ws, &.{data}, op);
     }
 
-    fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) WriteError!void {
-        const max_trailers = 25;
-        if (end_trailers) |trailers| assert(trailers.len <= max_trailers);
-        assert(r.transfer_encoding == .chunked);
+    pub fn writeMessageVec(ws: *WebSocket, data: []const []const u8, op: Opcode) Writer.Error!void {
+        try writeMessageVecUnflushed(ws, data, op);
+        try ws.output.flush();
+    }
 
-        const http_headers = r.send_buffer[r.send_buffer_start .. r.send_buffer_end - r.chunk_len];
-
-        if (r.elide_body) {
-            try r.stream.writeAll(http_headers);
-            r.send_buffer_start = 0;
-            r.send_buffer_end = 0;
-            r.chunk_len = 0;
-            return;
-        }
-
-        var header_buf: [18]u8 = undefined;
-        const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{r.chunk_len}) catch unreachable;
-
-        var iovecs: [max_trailers * 4 + 5]std.posix.iovec_const = undefined;
-        var iovecs_len: usize = 0;
-
-        iovecs[iovecs_len] = .{
-            .base = http_headers.ptr,
-            .len = http_headers.len,
+    pub fn writeMessageVecUnflushed(ws: *WebSocket, data: []const []const u8, op: Opcode) Writer.Error!void {
+        const total_len = l: {
+            var total_len: u64 = 0;
+            for (data) |iovec| total_len += iovec.len;
+            break :l total_len;
         };
-        iovecs_len += 1;
-
-        if (r.chunk_len > 0) {
-            iovecs[iovecs_len] = .{
-                .base = chunk_header.ptr,
-                .len = chunk_header.len,
-            };
-            iovecs_len += 1;
-
-            iovecs[iovecs_len] = .{
-                .base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len,
-                .len = r.chunk_len,
-            };
-            iovecs_len += 1;
-
-            iovecs[iovecs_len] = .{
-                .base = "\r\n",
-                .len = 2,
-            };
-            iovecs_len += 1;
-        }
-
-        if (end_trailers) |trailers| {
-            iovecs[iovecs_len] = .{
-                .base = "0\r\n",
-                .len = 3,
-            };
-            iovecs_len += 1;
-
-            for (trailers) |trailer| {
-                iovecs[iovecs_len] = .{
-                    .base = trailer.name.ptr,
-                    .len = trailer.name.len,
-                };
-                iovecs_len += 1;
-
-                iovecs[iovecs_len] = .{
-                    .base = ": ",
-                    .len = 2,
-                };
-                iovecs_len += 1;
-
-                if (trailer.value.len != 0) {
-                    iovecs[iovecs_len] = .{
-                        .base = trailer.value.ptr,
-                        .len = trailer.value.len,
-                    };
-                    iovecs_len += 1;
-                }
-
-                iovecs[iovecs_len] = .{
-                    .base = "\r\n",
-                    .len = 2,
-                };
-                iovecs_len += 1;
-            }
-
-            iovecs[iovecs_len] = .{
-                .base = "\r\n",
-                .len = 2,
-            };
-            iovecs_len += 1;
-        }
-
-        try r.stream.writevAll(iovecs[0..iovecs_len]);
-        r.send_buffer_start = 0;
-        r.send_buffer_end = 0;
-        r.chunk_len = 0;
-    }
-
-    pub fn writer(r: *Response) std.io.AnyWriter {
-        return .{
-            .writeFn = switch (r.transfer_encoding) {
-                .none, .content_length => write_cl,
-                .chunked => write_chunked,
+        const out = ws.output;
+        try out.writeStruct(@as(Header0, .{
+            .opcode = op,
+            .fin = true,
+        }));
+        switch (total_len) {
+            0...125 => try out.writeStruct(@as(Header1, .{
+                .payload_len = @enumFromInt(total_len),
+                .mask = false,
+            })),
+            126...0xffff => {
+                try out.writeStruct(@as(Header1, .{
+                    .payload_len = .len16,
+                    .mask = false,
+                }));
+                try out.writeInt(u16, @intCast(total_len), .big);
             },
-            .context = r,
-        };
+            else => {
+                try out.writeStruct(@as(Header1, .{
+                    .payload_len = .len64,
+                    .mask = false,
+                }));
+                try out.writeInt(u64, total_len, .big);
+            },
+        }
+        try out.writeVecAll(data);
+    }
+
+    pub fn flush(ws: *WebSocket) Writer.Error!void {
+        try ws.output.flush();
     }
 };
-
-fn rebase(s: *Server, index: usize) void {
-    const leftover = s.read_buffer[s.next_request_start..s.read_buffer_len];
-    const dest = s.read_buffer[index..][0..leftover.len];
-    if (leftover.len <= s.next_request_start - index) {
-        @memcpy(dest, leftover);
-    } else {
-        mem.copyBackwards(u8, dest, leftover);
-    }
-    s.read_buffer_len = index + leftover.len;
-}
-
-const std = @import("../std.zig");
-const http = std.http;
-const mem = std.mem;
-const net = std.net;
-const Uri = std.Uri;
-const assert = std.debug.assert;
-const testing = std.testing;
-
-const Server = @This();
diff --git a/lib/std/http/WebSocket.zig b/lib/std/http/WebSocket.zig
deleted file mode 100644
index b9a66cdbd6..0000000000
--- a/lib/std/http/WebSocket.zig
+++ /dev/null
@@ -1,246 +0,0 @@
-//! See https://tools.ietf.org/html/rfc6455
-
-const builtin = @import("builtin");
-const std = @import("std");
-const WebSocket = @This();
-const assert = std.debug.assert;
-const native_endian = builtin.cpu.arch.endian();
-
-key: []const u8,
-request: *std.http.Server.Request,
-recv_fifo: std.fifo.LinearFifo(u8, .Slice),
-reader: std.io.AnyReader,
-response: std.http.Server.Response,
-/// Number of bytes that have been peeked but not discarded yet.
-outstanding_len: usize,
-
-pub const InitError = error{WebSocketUpgradeMissingKey} ||
-    std.http.Server.Request.ReaderError;
-
-pub fn init(
-    request: *std.http.Server.Request,
-    send_buffer: []u8,
-    recv_buffer: []align(4) u8,
-) InitError!?WebSocket {
-    switch (request.head.version) {
-        .@"HTTP/1.0" => return null,
-        .@"HTTP/1.1" => if (request.head.method != .GET) return null,
-    }
-
-    var sec_websocket_key: ?[]const u8 = null;
-    var upgrade_websocket: bool = false;
-    var it = request.iterateHeaders();
-    while (it.next()) |header| {
-        if (std.ascii.eqlIgnoreCase(header.name, "sec-websocket-key")) {
-            sec_websocket_key = header.value;
-        } else if (std.ascii.eqlIgnoreCase(header.name, "upgrade")) {
-            if (!std.ascii.eqlIgnoreCase(header.value, "websocket"))
-                return null;
-            upgrade_websocket = true;
-        }
-    }
-    if (!upgrade_websocket)
-        return null;
-
-    const key = sec_websocket_key orelse return error.WebSocketUpgradeMissingKey;
-
-    var sha1 = std.crypto.hash.Sha1.init(.{});
-    sha1.update(key);
-    sha1.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
-    var digest: [std.crypto.hash.Sha1.digest_length]u8 = undefined;
-    sha1.final(&digest);
-    var base64_digest: [28]u8 = undefined;
-    assert(std.base64.standard.Encoder.encode(&base64_digest, &digest).len == base64_digest.len);
-
-    request.head.content_length = std.math.maxInt(u64);
-
-    return .{
-        .key = key,
-        .recv_fifo = std.fifo.LinearFifo(u8, .Slice).init(recv_buffer),
-        .reader = try request.reader(),
-        .response = request.respondStreaming(.{
-            .send_buffer = send_buffer,
-            .respond_options = .{
-                .status = .switching_protocols,
-                .extra_headers = &.{
-                    .{ .name = "upgrade", .value = "websocket" },
-                    .{ .name = "connection", .value = "upgrade" },
-                    .{ .name = "sec-websocket-accept", .value = &base64_digest },
-                },
-                .transfer_encoding = .none,
-            },
-        }),
-        .request = request,
-        .outstanding_len = 0,
-    };
-}
-
-pub const Header0 = packed struct(u8) {
-    opcode: Opcode,
-    rsv3: u1 = 0,
-    rsv2: u1 = 0,
-    rsv1: u1 =
0, - fin: bool, -}; - -pub const Header1 = packed struct(u8) { - payload_len: enum(u7) { - len16 = 126, - len64 = 127, - _, - }, - mask: bool, -}; - -pub const Opcode = enum(u4) { - continuation = 0, - text = 1, - binary = 2, - connection_close = 8, - ping = 9, - /// "A Pong frame MAY be sent unsolicited. This serves as a unidirectional - /// heartbeat. A response to an unsolicited Pong frame is not expected." - pong = 10, - _, -}; - -pub const ReadSmallTextMessageError = error{ - ConnectionClose, - UnexpectedOpCode, - MessageTooBig, - MissingMaskBit, -} || RecvError; - -pub const SmallMessage = struct { - /// Can be text, binary, or ping. - opcode: Opcode, - data: []u8, -}; - -/// Reads the next message from the WebSocket stream, failing if the message does not fit -/// into `recv_buffer`. -pub fn readSmallMessage(ws: *WebSocket) ReadSmallTextMessageError!SmallMessage { - while (true) { - const header_bytes = (try recv(ws, 2))[0..2]; - const h0: Header0 = @bitCast(header_bytes[0]); - const h1: Header1 = @bitCast(header_bytes[1]); - - switch (h0.opcode) { - .text, .binary, .pong, .ping => {}, - .connection_close => return error.ConnectionClose, - .continuation => return error.UnexpectedOpCode, - _ => return error.UnexpectedOpCode, - } - - if (!h0.fin) return error.MessageTooBig; - if (!h1.mask) return error.MissingMaskBit; - - const len: usize = switch (h1.payload_len) { - .len16 => try recvReadInt(ws, u16), - .len64 => std.math.cast(usize, try recvReadInt(ws, u64)) orelse return error.MessageTooBig, - else => @intFromEnum(h1.payload_len), - }; - if (len > ws.recv_fifo.buf.len) return error.MessageTooBig; - - const mask: u32 = @bitCast((try recv(ws, 4))[0..4].*); - const payload = try recv(ws, len); - - // Skip pongs. - if (h0.opcode == .pong) continue; - - // The last item may contain a partial word of unused data. - const floored_len = (payload.len / 4) * 4; - const u32_payload: []align(1) u32 = @alignCast(std.mem.bytesAsSlice(u32, payload[0..floored_len])); - for (u32_payload) |*elem| elem.* ^= mask; - const mask_bytes = std.mem.asBytes(&mask)[0 .. payload.len - floored_len]; - for (payload[floored_len..], mask_bytes) |*leftover, m| leftover.* ^= m; - - return .{ - .opcode = h0.opcode, - .data = payload, - }; - } -} - -const RecvError = std.http.Server.Request.ReadError || error{EndOfStream}; - -fn recv(ws: *WebSocket, len: usize) RecvError![]u8 { - ws.recv_fifo.discard(ws.outstanding_len); - assert(len <= ws.recv_fifo.buf.len); - if (len > ws.recv_fifo.count) { - const small_buf = ws.recv_fifo.writableSlice(0); - const needed = len - ws.recv_fifo.count; - const buf = if (small_buf.len >= needed) small_buf else b: { - ws.recv_fifo.realign(); - break :b ws.recv_fifo.writableSlice(0); - }; - const n = try @as(RecvError!usize, @errorCast(ws.reader.readAtLeast(buf, needed))); - if (n < needed) return error.EndOfStream; - ws.recv_fifo.update(n); - } - ws.outstanding_len = len; - // TODO: improve the std lib API so this cast isn't necessary. 
- return @constCast(ws.recv_fifo.readableSliceOfLen(len)); -} - -fn recvReadInt(ws: *WebSocket, comptime I: type) !I { - const unswapped: I = @bitCast((try recv(ws, @sizeOf(I)))[0..@sizeOf(I)].*); - return switch (native_endian) { - .little => @byteSwap(unswapped), - .big => unswapped, - }; -} - -pub const WriteError = std.http.Server.Response.WriteError; - -pub fn writeMessage(ws: *WebSocket, message: []const u8, opcode: Opcode) WriteError!void { - const iovecs: [1]std.posix.iovec_const = .{ - .{ .base = message.ptr, .len = message.len }, - }; - return writeMessagev(ws, &iovecs, opcode); -} - -pub fn writeMessagev(ws: *WebSocket, message: []const std.posix.iovec_const, opcode: Opcode) WriteError!void { - const total_len = l: { - var total_len: u64 = 0; - for (message) |iovec| total_len += iovec.len; - break :l total_len; - }; - - var header_buf: [2 + 8]u8 = undefined; - header_buf[0] = @bitCast(@as(Header0, .{ - .opcode = opcode, - .fin = true, - })); - const header = switch (total_len) { - 0...125 => blk: { - header_buf[1] = @bitCast(@as(Header1, .{ - .payload_len = @enumFromInt(total_len), - .mask = false, - })); - break :blk header_buf[0..2]; - }, - 126...0xffff => blk: { - header_buf[1] = @bitCast(@as(Header1, .{ - .payload_len = .len16, - .mask = false, - })); - std.mem.writeInt(u16, header_buf[2..4], @intCast(total_len), .big); - break :blk header_buf[0..4]; - }, - else => blk: { - header_buf[1] = @bitCast(@as(Header1, .{ - .payload_len = .len64, - .mask = false, - })); - std.mem.writeInt(u64, header_buf[2..10], total_len, .big); - break :blk header_buf[0..10]; - }, - }; - - const response = &ws.response; - try response.writeAll(header); - for (message) |iovec| - try response.writeAll(iovec.base[0..iovec.len]); - try response.flush(); -} diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig deleted file mode 100644 index 797ed989ad..0000000000 --- a/lib/std/http/protocol.zig +++ /dev/null @@ -1,464 +0,0 @@ -const std = @import("../std.zig"); -const builtin = @import("builtin"); -const testing = std.testing; -const mem = std.mem; - -const assert = std.debug.assert; - -pub const State = enum { - invalid, - - // Begin header and trailer parsing states. - - start, - seen_n, - seen_r, - seen_rn, - seen_rnr, - finished, - - // Begin transfer-encoding: chunked parsing states. - - chunk_head_size, - chunk_head_ext, - chunk_head_r, - chunk_data, - chunk_data_suffix, - chunk_data_suffix_r, - - /// Returns true if the parser is in a content state (ie. not waiting for more headers). - pub fn isContent(self: State) bool { - return switch (self) { - .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => false, - .finished, .chunk_head_size, .chunk_head_ext, .chunk_head_r, .chunk_data, .chunk_data_suffix, .chunk_data_suffix_r => true, - }; - } -}; - -pub const HeadersParser = struct { - state: State = .start, - /// A fixed buffer of len `max_header_bytes`. - /// Pointers into this buffer are not stable until after a message is complete. - header_bytes_buffer: []u8, - header_bytes_len: u32, - next_chunk_length: u64, - /// `false`: headers. `true`: trailers. - done: bool, - - /// Initializes the parser with a provided buffer `buf`. - pub fn init(buf: []u8) HeadersParser { - return .{ - .header_bytes_buffer = buf, - .header_bytes_len = 0, - .done = false, - .next_chunk_length = 0, - }; - } - - /// Reinitialize the parser. - /// Asserts the parser is in the "done" state. 
- pub fn reset(hp: *HeadersParser) void { - assert(hp.done); - hp.* = .{ - .state = .start, - .header_bytes_buffer = hp.header_bytes_buffer, - .header_bytes_len = 0, - .done = false, - .next_chunk_length = 0, - }; - } - - pub fn get(hp: HeadersParser) []u8 { - return hp.header_bytes_buffer[0..hp.header_bytes_len]; - } - - pub fn findHeadersEnd(r: *HeadersParser, bytes: []const u8) u32 { - var hp: std.http.HeadParser = .{ - .state = switch (r.state) { - .start => .start, - .seen_n => .seen_n, - .seen_r => .seen_r, - .seen_rn => .seen_rn, - .seen_rnr => .seen_rnr, - .finished => .finished, - else => unreachable, - }, - }; - const result = hp.feed(bytes); - r.state = switch (hp.state) { - .start => .start, - .seen_n => .seen_n, - .seen_r => .seen_r, - .seen_rn => .seen_rn, - .seen_rnr => .seen_rnr, - .finished => .finished, - }; - return @intCast(result); - } - - pub fn findChunkedLen(r: *HeadersParser, bytes: []const u8) u32 { - var cp: std.http.ChunkParser = .{ - .state = switch (r.state) { - .chunk_head_size => .head_size, - .chunk_head_ext => .head_ext, - .chunk_head_r => .head_r, - .chunk_data => .data, - .chunk_data_suffix => .data_suffix, - .chunk_data_suffix_r => .data_suffix_r, - .invalid => .invalid, - else => unreachable, - }, - .chunk_len = r.next_chunk_length, - }; - const result = cp.feed(bytes); - r.state = switch (cp.state) { - .head_size => .chunk_head_size, - .head_ext => .chunk_head_ext, - .head_r => .chunk_head_r, - .data => .chunk_data, - .data_suffix => .chunk_data_suffix, - .data_suffix_r => .chunk_data_suffix_r, - .invalid => .invalid, - }; - r.next_chunk_length = cp.chunk_len; - return @intCast(result); - } - - /// Returns whether or not the parser has finished parsing a complete - /// message. A message is only complete after the entire body has been read - /// and any trailing headers have been parsed. - pub fn isComplete(r: *HeadersParser) bool { - return r.done and r.state == .finished; - } - - pub const CheckCompleteHeadError = error{HttpHeadersOversize}; - - /// Pushes `in` into the parser. Returns the number of bytes consumed by - /// the header. Any header bytes are appended to `header_bytes_buffer`. - pub fn checkCompleteHead(hp: *HeadersParser, in: []const u8) CheckCompleteHeadError!u32 { - if (hp.state.isContent()) return 0; - - const i = hp.findHeadersEnd(in); - const data = in[0..i]; - if (hp.header_bytes_len + data.len > hp.header_bytes_buffer.len) - return error.HttpHeadersOversize; - - @memcpy(hp.header_bytes_buffer[hp.header_bytes_len..][0..data.len], data); - hp.header_bytes_len += @intCast(data.len); - - return i; - } - - pub const ReadError = error{ - HttpChunkInvalid, - }; - - /// Reads the body of the message into `buffer`. Returns the number of - /// bytes placed in the buffer. - /// - /// If `skip` is true, the buffer will be unused and the body will be skipped. - /// - /// See `std.http.Client.Connection for an example of `conn`. 
- pub fn read(r: *HeadersParser, conn: anytype, buffer: []u8, skip: bool) !usize { - assert(r.state.isContent()); - if (r.done) return 0; - - var out_index: usize = 0; - while (true) { - switch (r.state) { - .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => unreachable, - .finished => { - const data_avail = r.next_chunk_length; - - if (skip) { - conn.fill() catch |err| switch (err) { - error.EndOfStream => { - r.done = true; - return 0; - }, - else => |e| return e, - }; - - const nread = @min(conn.peek().len, data_avail); - conn.drop(@intCast(nread)); - r.next_chunk_length -= nread; - - if (r.next_chunk_length == 0 or nread == 0) r.done = true; - - return out_index; - } else if (out_index < buffer.len) { - const out_avail = buffer.len - out_index; - - const can_read = @as(usize, @intCast(@min(data_avail, out_avail))); - const nread = try conn.read(buffer[0..can_read]); - r.next_chunk_length -= nread; - - if (r.next_chunk_length == 0 or nread == 0) r.done = true; - - return nread; - } else { - return out_index; - } - }, - .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => { - conn.fill() catch |err| switch (err) { - error.EndOfStream => { - r.done = true; - return 0; - }, - else => |e| return e, - }; - - const i = r.findChunkedLen(conn.peek()); - conn.drop(@intCast(i)); - - switch (r.state) { - .invalid => return error.HttpChunkInvalid, - .chunk_data => if (r.next_chunk_length == 0) { - if (std.mem.eql(u8, conn.peek(), "\r\n")) { - r.state = .finished; - conn.drop(2); - } else { - // The trailer section is formatted identically - // to the header section. - r.state = .seen_rn; - } - r.done = true; - - return out_index; - }, - else => return out_index, - } - - continue; - }, - .chunk_data => { - const data_avail = r.next_chunk_length; - const out_avail = buffer.len - out_index; - - if (skip) { - conn.fill() catch |err| switch (err) { - error.EndOfStream => { - r.done = true; - return 0; - }, - else => |e| return e, - }; - - const nread = @min(conn.peek().len, data_avail); - conn.drop(@intCast(nread)); - r.next_chunk_length -= nread; - } else if (out_avail > 0) { - const can_read: usize = @intCast(@min(data_avail, out_avail)); - const nread = try conn.read(buffer[out_index..][0..can_read]); - r.next_chunk_length -= nread; - out_index += nread; - } - - if (r.next_chunk_length == 0) { - r.state = .chunk_data_suffix; - continue; - } - - return out_index; - }, - } - } - } -}; - -inline fn int16(array: *const [2]u8) u16 { - return @as(u16, @bitCast(array.*)); -} - -inline fn int24(array: *const [3]u8) u24 { - return @as(u24, @bitCast(array.*)); -} - -inline fn int32(array: *const [4]u8) u32 { - return @as(u32, @bitCast(array.*)); -} - -inline fn intShift(comptime T: type, x: anytype) T { - switch (@import("builtin").cpu.arch.endian()) { - .little => return @as(T, @truncate(x >> (@bitSizeOf(@TypeOf(x)) - @bitSizeOf(T)))), - .big => return @as(T, @truncate(x)), - } -} - -/// A buffered (and peekable) Connection. 
-const MockBufferedConnection = struct { - pub const buffer_size = 0x2000; - - conn: std.io.FixedBufferStream([]const u8), - buf: [buffer_size]u8 = undefined, - start: u16 = 0, - end: u16 = 0, - - pub fn fill(conn: *MockBufferedConnection) ReadError!void { - if (conn.end != conn.start) return; - - const nread = try conn.conn.read(conn.buf[0..]); - if (nread == 0) return error.EndOfStream; - conn.start = 0; - conn.end = @as(u16, @truncate(nread)); - } - - pub fn peek(conn: *MockBufferedConnection) []const u8 { - return conn.buf[conn.start..conn.end]; - } - - pub fn drop(conn: *MockBufferedConnection, num: u16) void { - conn.start += num; - } - - pub fn readAtLeast(conn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize { - var out_index: u16 = 0; - while (out_index < len) { - const available = conn.end - conn.start; - const left = buffer.len - out_index; - - if (available > 0) { - const can_read = @as(u16, @truncate(@min(available, left))); - - @memcpy(buffer[out_index..][0..can_read], conn.buf[conn.start..][0..can_read]); - out_index += can_read; - conn.start += can_read; - - continue; - } - - if (left > conn.buf.len) { - // skip the buffer if the output is large enough - return conn.conn.read(buffer[out_index..]); - } - - try conn.fill(); - } - - return out_index; - } - - pub fn read(conn: *MockBufferedConnection, buffer: []u8) ReadError!usize { - return conn.readAtLeast(buffer, 1); - } - - pub const ReadError = std.io.FixedBufferStream([]const u8).ReadError || error{EndOfStream}; - pub const Reader = std.io.GenericReader(*MockBufferedConnection, ReadError, read); - - pub fn reader(conn: *MockBufferedConnection) Reader { - return Reader{ .context = conn }; - } - - pub fn writeAll(conn: *MockBufferedConnection, buffer: []const u8) WriteError!void { - return conn.conn.writeAll(buffer); - } - - pub fn write(conn: *MockBufferedConnection, buffer: []const u8) WriteError!usize { - return conn.conn.write(buffer); - } - - pub const WriteError = std.io.FixedBufferStream([]const u8).WriteError; - pub const Writer = std.io.GenericWriter(*MockBufferedConnection, WriteError, write); - - pub fn writer(conn: *MockBufferedConnection) Writer { - return Writer{ .context = conn }; - } -}; - -test "HeadersParser.read length" { - // mock BufferedConnection for read - var headers_buf: [256]u8 = undefined; - - var r = HeadersParser.init(&headers_buf); - const data = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello"; - - var conn: MockBufferedConnection = .{ - .conn = std.io.fixedBufferStream(data), - }; - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - - var buf: [8]u8 = undefined; - - r.next_chunk_length = 5; - const len = try r.read(&conn, &buf, false); - try std.testing.expectEqual(@as(usize, 5), len); - try std.testing.expectEqualStrings("Hello", buf[0..len]); - - try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\n", r.get()); -} - -test "HeadersParser.read chunked" { - // mock BufferedConnection for read - - var headers_buf: [256]u8 = undefined; - var r = HeadersParser.init(&headers_buf); - const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n"; - - var conn: MockBufferedConnection = .{ - .conn = std.io.fixedBufferStream(data), - }; - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - 
conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - var buf: [8]u8 = undefined; - - r.state = .chunk_head_size; - const len = try r.read(&conn, &buf, false); - try std.testing.expectEqual(@as(usize, 5), len); - try std.testing.expectEqualStrings("Hello", buf[0..len]); - - try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n", r.get()); -} - -test "HeadersParser.read chunked trailer" { - // mock BufferedConnection for read - - var headers_buf: [256]u8 = undefined; - var r = HeadersParser.init(&headers_buf); - const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n"; - - var conn: MockBufferedConnection = .{ - .conn = std.io.fixedBufferStream(data), - }; - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - var buf: [8]u8 = undefined; - - r.state = .chunk_head_size; - const len = try r.read(&conn, &buf, false); - try std.testing.expectEqual(@as(usize, 5), len); - try std.testing.expectEqualStrings("Hello", buf[0..len]); - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - - try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nContent-Type: text/plain\r\n\r\n", r.get()); -} diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index 33bc2eb191..4c3466d5c9 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -10,32 +10,33 @@ const expectError = std.testing.expectError; test "trailers" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var header_buffer: [1024]u8 = undefined; + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [1024]u8 = undefined; + var send_buffer: [1024]u8 = undefined; var remaining: usize = 1; while (remaining != 0) : (remaining -= 1) { - const conn = try net_server.accept(); - defer conn.stream.close(); + const connection = try net_server.accept(); + defer connection.stream.close(); - var server = http.Server.init(conn, &header_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); - try expectEqual(.ready, server.state); + try expectEqual(.ready, server.reader.state); var request = try server.receiveHead(); try serve(&request); - try expectEqual(.ready, server.state); + try expectEqual(.ready, server.reader.state); } } fn serve(request: *http.Server.Request) !void { try expectEqualStrings(request.head.target, "/trailer"); - var send_buffer: [1024]u8 = undefined; - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, - }); - try response.writeAll("Hello, "); + var response = try request.respondStreaming(&.{}, .{}); + try response.writer.writeAll("Hello, "); try response.flush(); - try response.writeAll("World!\n"); + try response.writer.writeAll("World!\n"); try response.flush(); try response.endChunked(.{ .trailers = &.{ @@ -58,34 +59,33 @@ test "trailers" { const uri = try std.Uri.parse(location); { - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var req 
= try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&.{}); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); - var it = req.response.iterateHeaders(); { + var it = response.head.iterateHeaders(); const header = it.next().?; try expect(!it.is_trailer); try expectEqualStrings("transfer-encoding", header.name); try expectEqualStrings("chunked", header.value); + try expectEqual(null, it.next()); } { + var it = response.iterateTrailers(); const header = it.next().?; try expect(it.is_trailer); try expectEqualStrings("X-Checksum", header.name); try expectEqualStrings("aaaa", header.value); + try expectEqual(null, it.next()); } - try expectEqual(null, it.next()); } // connection has been kept alive @@ -94,19 +94,24 @@ test "trailers" { test "HTTP server handles a chunked transfer coding request" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) !void { - var header_buffer: [8192]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [8192]u8 = undefined; + var send_buffer: [500]u8 = undefined; + const connection = try net_server.accept(); + defer connection.stream.close(); - var server = http.Server.init(conn, &header_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); var request = try server.receiveHead(); try expect(request.head.transfer_encoding == .chunked); var buf: [128]u8 = undefined; - const n = try (try request.reader()).readAll(&buf); - try expect(mem.eql(u8, buf[0..n], "ABCD")); + var br = try request.readerExpectContinue(&.{}); + const n = try br.readSliceShort(&buf); + try expectEqualStrings("ABCD", buf[0..n]); try request.respond("message from server!\n", .{ .extra_headers = &.{ @@ -154,16 +159,20 @@ test "HTTP server handles a chunked transfer coding request" { test "echo content server" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var read_buffer: [1024]u8 = undefined; + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [1024]u8 = undefined; + var send_buffer: [100]u8 = undefined; - accept: while (true) { - const conn = try net_server.accept(); - defer conn.stream.close(); + accept: while (!test_server.shutting_down) { + const connection = try net_server.accept(); + defer connection.stream.close(); - var http_server = http.Server.init(conn, &read_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var http_server = http.Server.init(connection_br.interface(), &connection_bw.interface); - while (http_server.state == .ready) { + while (http_server.reader.state == .ready) { var request = http_server.receiveHead() catch |err| switch (err) { error.HttpConnectionClosing => continue :accept, else => |e| return e, @@ -173,7 +182,7 @@ test "echo content server" { } if (request.head.expect) |expect_header_value| { if (mem.eql(u8, expect_header_value, "garbage")) { - try 
expectError(error.HttpExpectationFailed, request.reader()); + try expectError(error.HttpExpectationFailed, request.readerExpectContinue(&.{})); try request.respond("", .{ .keep_alive = false }); continue; } @@ -195,16 +204,14 @@ test "echo content server" { // request.head.target, //}); - const body = try (try request.reader()).readAllAlloc(std.testing.allocator, 8192); + const body = try (try request.readerExpectContinue(&.{})).allocRemaining(std.testing.allocator, .limited(8192)); defer std.testing.allocator.free(body); try expect(mem.startsWith(u8, request.head.target, "/echo-content")); try expectEqualStrings("Hello, World!\n", body); try expectEqualStrings("text/plain", request.head.content_type.?); - var send_buffer: [100]u8 = undefined; - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, + var response = try request.respondStreaming(&.{}, .{ .content_length = switch (request.head.transfer_encoding) { .chunked => null, .none => len: { @@ -213,9 +220,8 @@ test "echo content server" { }, }, }); - try response.flush(); // Test an early flush to send the HTTP headers before the body. - const w = response.writer(); + const w = &response.writer; try w.writeAll("Hello, "); try w.writeAll("World!\n"); try response.end(); @@ -241,35 +247,36 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { // In this case, the response is expected to stream until the connection is // closed, indicating the end of the body. const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var header_buffer: [1000]u8 = undefined; + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [1000]u8 = undefined; + var send_buffer: [500]u8 = undefined; var remaining: usize = 1; while (remaining != 0) : (remaining -= 1) { - const conn = try net_server.accept(); - defer conn.stream.close(); + const connection = try net_server.accept(); + defer connection.stream.close(); - var server = http.Server.init(conn, &header_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); - try expectEqual(.ready, server.state); + try expectEqual(.ready, server.reader.state); var request = try server.receiveHead(); try expectEqualStrings(request.head.target, "/foo"); - var send_buffer: [500]u8 = undefined; - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, + var buf: [30]u8 = undefined; + var response = try request.respondStreaming(&buf, .{ .respond_options = .{ .transfer_encoding = .none, }, }); - var total: usize = 0; + const w = &response.writer; for (0..500) |i| { - var buf: [30]u8 = undefined; - const line = try std.fmt.bufPrint(&buf, "{d}, ah ha ha!\n", .{i}); - try response.writeAll(line); - total += line.len; + try w.print("{d}, ah ha ha!\n", .{i}); } - try expectEqual(7390, total); + try expectEqual(7390, w.count); + try w.flush(); try response.end(); - try expectEqual(.closing, server.state); + try expectEqual(.closing, server.reader.state); } } }); @@ -308,15 +315,20 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { test "receiving arbitrary http headers from the client" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var read_buffer: [666]u8 = undefined; + fn run(test_server: *TestServer) anyerror!void 
{ + const net_server = &test_server.net_server; + var recv_buffer: [666]u8 = undefined; + var send_buffer: [777]u8 = undefined; var remaining: usize = 1; while (remaining != 0) : (remaining -= 1) { - const conn = try net_server.accept(); - defer conn.stream.close(); + const connection = try net_server.accept(); + defer connection.stream.close(); - var server = http.Server.init(conn, &read_buffer); - try expectEqual(.ready, server.state); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); + + try expectEqual(.ready, server.reader.state); var request = try server.receiveHead(); try expectEqualStrings("/bar", request.head.target); var it = request.iterateHeaders(); @@ -368,19 +380,21 @@ test "general client/server API coverage" { return error.SkipZigTest; } - const global = struct { - var handle_new_requests = true; - }; const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var client_header_buffer: [1024]u8 = undefined; - outer: while (global.handle_new_requests) { + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [1024]u8 = undefined; + var send_buffer: [100]u8 = undefined; + + outer: while (!test_server.shutting_down) { var connection = try net_server.accept(); defer connection.stream.close(); - var http_server = http.Server.init(connection, &client_header_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var http_server = http.Server.init(connection_br.interface(), &connection_bw.interface); - while (http_server.state == .ready) { + while (http_server.reader.state == .ready) { var request = http_server.receiveHead() catch |err| switch (err) { error.HttpConnectionClosing => continue :outer, else => |e| return e, @@ -399,14 +413,11 @@ test "general client/server API coverage" { }); const gpa = std.testing.allocator; - const body = try (try request.reader()).readAllAlloc(gpa, 8192); + const body = try (try request.readerExpectContinue(&.{})).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); - var send_buffer: [100]u8 = undefined; - if (mem.startsWith(u8, request.head.target, "/get")) { - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, + var response = try request.respondStreaming(&.{}, .{ .content_length = if (mem.indexOf(u8, request.head.target, "?chunked") == null) 14 else @@ -417,20 +428,19 @@ test "general client/server API coverage" { }, }, }); - const w = response.writer(); + const w = &response.writer; try w.writeAll("Hello, "); try w.writeAll("World!\n"); try response.end(); // Writing again would cause an assertion failure. } else if (mem.startsWith(u8, request.head.target, "/large")) { - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, + var response = try request.respondStreaming(&.{}, .{ .content_length = 14 * 1024 + 14 * 10, }); try response.flush(); // Test an early flush to send the HTTP headers before the body. 
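                        // A standalone sketch of the new streaming-response API used
                        // throughout these tests (illustrative; assumes a `request` in
                        // scope and a caller-owned `send_buffer`): respondStreaming now
                        // takes the send buffer as a leading argument and exposes a
                        // concrete std.io.Writer, so handlers hold `&response.writer`
                        // instead of calling `response.writer()`. A null
                        // `content_length` selects chunked transfer encoding.
                        //
                        //     var response = try request.respondStreaming(&send_buffer, .{
                        //         .content_length = null,
                        //     });
                        //     const w = &response.writer;
                        //     try w.writeAll("Hello, World!\n");
                        //     try response.end();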
- const w = response.writer(); + const w = &response.writer; var i: u32 = 0; while (i < 5) : (i += 1) { @@ -446,8 +456,7 @@ test "general client/server API coverage" { try response.end(); } else if (mem.eql(u8, request.head.target, "/redirect/1")) { - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, + var response = try request.respondStreaming(&.{}, .{ .respond_options = .{ .status = .found, .extra_headers = &.{ @@ -456,7 +465,7 @@ test "general client/server API coverage" { }, }); - const w = response.writer(); + const w = &response.writer; try w.writeAll("Hello, "); try w.writeAll("Redirected!\n"); try response.end(); @@ -524,17 +533,13 @@ test "general client/server API coverage" { return s.listen_address.in.getPort(); } }); - defer { - global.handle_new_requests = false; - test_server.destroy(); - } + defer test_server.destroy(); const log = std.log.scoped(.client); const gpa = std.testing.allocator; var client: http.Client = .{ .allocator = gpa }; - errdefer client.deinit(); - // defer client.deinit(); handled below + defer client.deinit(); const port = test_server.port(); @@ -544,20 +549,18 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); - try expectEqualStrings("text/plain", req.response.content_type.?); + try expectEqualStrings("text/plain", response.head.content_type.?); } // connection has been kept alive @@ -569,16 +572,14 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192 * 1024); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192 * 1024)); defer gpa.free(body); try expectEqual(@as(usize, 14 * 1024 + 14 * 10), body.len); @@ -593,21 +594,19 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.HEAD, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.HEAD, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("", body); - try expectEqualStrings("text/plain", req.response.content_type.?); - try 
expectEqual(14, req.response.content_length.?); + try expectEqualStrings("text/plain", response.head.content_type.?); + try expectEqual(14, response.head.content_length.?); } // connection has been kept alive @@ -619,20 +618,18 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); - try expectEqualStrings("text/plain", req.response.content_type.?); + try expectEqualStrings("text/plain", response.head.content_type.?); } // connection has been kept alive @@ -644,21 +641,19 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.HEAD, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.HEAD, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("", body); - try expectEqualStrings("text/plain", req.response.content_type.?); - try expect(req.response.transfer_encoding == .chunked); + try expectEqualStrings("text/plain", response.head.content_type.?); + try expect(response.head.transfer_encoding == .chunked); } // connection has been kept alive @@ -670,21 +665,20 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{ .keep_alive = false, }); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); - try expectEqualStrings("text/plain", req.response.content_type.?); + try expectEqualStrings("text/plain", response.head.content_type.?); } // connection has been closed @@ -696,26 +690,25 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{ .extra_headers = &.{ .{ .name = "empty", .value = "" }, }, }); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try 
req.receiveHead(&redirect_buffer); - try std.testing.expectEqual(.ok, req.response.status); + try std.testing.expectEqual(.ok, response.head.status); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("", body); - var it = req.response.iterateHeaders(); + var it = response.head.iterateHeaders(); { const header = it.next().?; try expect(!it.is_trailer); @@ -740,16 +733,14 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -764,16 +755,14 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -788,16 +777,14 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -812,17 +799,17 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - req.wait() catch |err| switch (err) { + try req.sendBodiless(); + if (req.receiveHead(&redirect_buffer)) |_| { + return error.TestFailed; + } else |err| switch (err) { error.TooManyHttpRedirects => {}, else => return err, - }; + } } { // redirect to encoded url @@ -831,16 +818,14 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ 
- .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Encoded redirect successful!\n", body); @@ -855,14 +840,12 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - const result = req.wait(); + try req.sendBodiless(); + const result = req.receiveHead(&redirect_buffer); // a proxy without an upstream is likely to return a 5xx status. if (client.http_proxy == null) { @@ -872,77 +855,40 @@ test "general client/server API coverage" { // connection has been kept alive try expect(client.http_proxy != null or client.connection_pool.free_len == 1); - - { // issue 16282 *** This test leaves the client in an invalid state, it must be last *** - const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/get", .{port}); - defer gpa.free(location); - const uri = try std.Uri.parse(location); - - const total_connections = client.connection_pool.free_size + 64; - var requests = try gpa.alloc(http.Client.Request, total_connections); - defer gpa.free(requests); - - var header_bufs = std.ArrayList([]u8).init(gpa); - defer header_bufs.deinit(); - defer for (header_bufs.items) |item| gpa.free(item); - - for (0..total_connections) |i| { - const headers_buf = try gpa.alloc(u8, 1024); - try header_bufs.append(headers_buf); - var req = try client.open(.GET, uri, .{ - .server_header_buffer = headers_buf, - }); - req.response.parser.done = true; - req.connection.?.closing = false; - requests[i] = req; - } - - for (0..total_connections) |i| { - requests[i].deinit(); - } - - // free connections should be full now - try expect(client.connection_pool.free_len == client.connection_pool.free_size); - } - - client.deinit(); - - { - global.handle_new_requests = false; - - const conn = try std.net.tcpConnectToAddress(test_server.net_server.listen_address); - conn.close(); - } } test "Server streams both reading and writing" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var header_buffer: [1024]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); - - var server = http.Server.init(conn, &header_buffer); - var request = try server.receiveHead(); - const reader = try request.reader(); - + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [1024]u8 = undefined; var send_buffer: [777]u8 = undefined; - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, + + const connection = try net_server.accept(); + defer connection.stream.close(); + + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); + var request = try server.receiveHead(); + var read_buffer: 
[100]u8 = undefined; + var br = try request.readerExpectContinue(&read_buffer); + var response = try request.respondStreaming(&.{}, .{ .respond_options = .{ .transfer_encoding = .none, // Causes keep_alive=false }, }); - const writer = response.writer(); + const w = &response.writer; while (true) { try response.flush(); - var buf: [100]u8 = undefined; - const n = try reader.read(&buf); - if (n == 0) break; - const sub_buf = buf[0..n]; - for (sub_buf) |*b| b.* = std.ascii.toUpper(b.*); - try writer.writeAll(sub_buf); + const buf = br.peekGreedy(1) catch |err| switch (err) { + error.EndOfStream => break, + error.ReadFailed => return error.ReadFailed, + }; + br.toss(buf.len); + for (buf) |*b| b.* = std.ascii.toUpper(b.*); + try w.writeAll(buf); } try response.end(); } @@ -952,27 +898,24 @@ test "Server streams both reading and writing" { var client: http.Client = .{ .allocator = std.testing.allocator }; defer client.deinit(); - var server_header_buffer: [555]u8 = undefined; - var req = try client.open(.POST, .{ + var redirect_buffer: [555]u8 = undefined; + var req = try client.request(.POST, .{ .scheme = "http", .host = .{ .raw = "127.0.0.1" }, .port = test_server.port(), .path = .{ .percent_encoded = "/" }, - }, .{ - .server_header_buffer = &server_header_buffer, - }); + }, .{}); defer req.deinit(); req.transfer_encoding = .chunked; - try req.send(); - try req.wait(); + var body_writer = try req.sendBody(&.{}); + var response = try req.receiveHead(&redirect_buffer); - try req.writeAll("one "); - try req.writeAll("fish"); + try body_writer.writer.writeAll("one "); + try body_writer.writer.writeAll("fish"); + try body_writer.end(); - try req.finish(); - - const body = try req.reader().readAllAlloc(std.testing.allocator, 8192); + const body = try response.reader(&.{}).allocRemaining(std.testing.allocator, .limited(8192)); defer std.testing.allocator.free(body); try expectEqualStrings("ONE FISH", body); @@ -987,9 +930,8 @@ fn echoTests(client: *http.Client, port: u16) !void { defer gpa.free(location); const uri = try std.Uri.parse(location); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.POST, uri, .{ - .server_header_buffer = &server_header_buffer, + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.POST, uri, .{ .extra_headers = &.{ .{ .name = "content-type", .value = "text/plain" }, }, @@ -998,14 +940,14 @@ fn echoTests(client: *http.Client, port: u16) !void { req.transfer_encoding = .{ .content_length = 14 }; - try req.send(); - try req.writeAll("Hello, "); - try req.writeAll("World!\n"); - try req.finish(); + var body_writer = try req.sendBody(&.{}); + try body_writer.writer.writeAll("Hello, "); + try body_writer.writer.writeAll("World!\n"); + try body_writer.end(); - try req.wait(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -1021,9 +963,8 @@ fn echoTests(client: *http.Client, port: u16) !void { .{port}, )); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.POST, uri, .{ - .server_header_buffer = &server_header_buffer, + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.POST, uri, .{ .extra_headers = &.{ .{ .name = "content-type", .value = "text/plain" }, }, @@ -1032,14 +973,14 @@ fn echoTests(client: *http.Client, port: u16) !void { req.transfer_encoding = 
.chunked; - try req.send(); - try req.writeAll("Hello, "); - try req.writeAll("World!\n"); - try req.finish(); + var body_writer = try req.sendBody(&.{}); + try body_writer.writer.writeAll("Hello, "); + try body_writer.writer.writeAll("World!\n"); + try body_writer.end(); - try req.wait(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -1053,8 +994,8 @@ fn echoTests(client: *http.Client, port: u16) !void { const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/echo-content#fetch", .{port}); defer gpa.free(location); - var body = std.ArrayList(u8).init(gpa); - defer body.deinit(); + var body: std.ArrayListUnmanaged(u8) = .empty; + defer body.deinit(gpa); const res = try client.fetch(.{ .location = .{ .url = location }, @@ -1063,7 +1004,7 @@ fn echoTests(client: *http.Client, port: u16) !void { .extra_headers = &.{ .{ .name = "content-type", .value = "text/plain" }, }, - .response_storage = .{ .dynamic = &body }, + .response_storage = .{ .allocator = gpa, .list = &body }, }); try expectEqual(.ok, res.status); try expectEqualStrings("Hello, World!\n", body.items); @@ -1074,9 +1015,8 @@ fn echoTests(client: *http.Client, port: u16) !void { defer gpa.free(location); const uri = try std.Uri.parse(location); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.POST, uri, .{ - .server_header_buffer = &server_header_buffer, + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.POST, uri, .{ .extra_headers = &.{ .{ .name = "expect", .value = "100-continue" }, .{ .name = "content-type", .value = "text/plain" }, @@ -1086,15 +1026,15 @@ fn echoTests(client: *http.Client, port: u16) !void { req.transfer_encoding = .chunked; - try req.send(); - try req.writeAll("Hello, "); - try req.writeAll("World!\n"); - try req.finish(); + var body_writer = try req.sendBody(&.{}); + try body_writer.writer.writeAll("Hello, "); + try body_writer.writer.writeAll("World!\n"); + try body_writer.end(); - try req.wait(); - try expectEqual(.ok, req.response.status); + var response = try req.receiveHead(&redirect_buffer); + try expectEqual(.ok, response.head.status); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -1105,9 +1045,8 @@ fn echoTests(client: *http.Client, port: u16) !void { defer gpa.free(location); const uri = try std.Uri.parse(location); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.POST, uri, .{ - .server_header_buffer = &server_header_buffer, + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.POST, uri, .{ .extra_headers = &.{ .{ .name = "content-type", .value = "text/plain" }, .{ .name = "expect", .value = "garbage" }, @@ -1117,23 +1056,24 @@ fn echoTests(client: *http.Client, port: u16) !void { req.transfer_encoding = .chunked; - try req.send(); - try req.wait(); - try expectEqual(.expectation_failed, req.response.status); + var body_writer = try req.sendBody(&.{}); + try body_writer.flush(); + var response = try req.receiveHead(&redirect_buffer); + try expectEqual(.expectation_failed, response.head.status); + _ = try response.reader(&.{}).discardRemaining(); } - - _ = try client.fetch(.{ - .location = .{ - .url = try 
std.fmt.bufPrint(&location_buffer, "http://127.0.0.1:{d}/end", .{port}), - }, - }); } const TestServer = struct { + shutting_down: bool, server_thread: std.Thread, net_server: std.net.Server, fn destroy(self: *@This()) void { + self.shutting_down = true; + const conn = std.net.tcpConnectToAddress(self.net_server.listen_address) catch @panic("shutdown failure"); + conn.close(); + self.server_thread.join(); self.net_server.deinit(); std.testing.allocator.destroy(self); @@ -1153,20 +1093,27 @@ fn createTestServer(S: type) !*TestServer { const address = try std.net.Address.parseIp("127.0.0.1", 0); const test_server = try std.testing.allocator.create(TestServer); - test_server.net_server = try address.listen(.{ .reuse_address = true }); - test_server.server_thread = try std.Thread.spawn(.{}, S.run, .{&test_server.net_server}); + test_server.* = .{ + .net_server = try address.listen(.{ .reuse_address = true }), + .server_thread = try std.Thread.spawn(.{}, S.run, .{test_server}), + .shutting_down = false, + }; return test_server; } test "redirect to different connection" { const test_server_new = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var header_buffer: [888]u8 = undefined; + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [888]u8 = undefined; + var send_buffer: [777]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); + const connection = try net_server.accept(); + defer connection.stream.close(); - var server = http.Server.init(conn, &header_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); var request = try server.receiveHead(); try expectEqualStrings(request.head.target, "/ok"); try request.respond("good job, you pass", .{}); @@ -1180,18 +1127,22 @@ test "redirect to different connection" { global.other_port = test_server_new.port(); const test_server_orig = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var header_buffer: [999]u8 = undefined; + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [999]u8 = undefined; var send_buffer: [100]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); + const connection = try net_server.accept(); + defer connection.stream.close(); - const new_loc = try std.fmt.bufPrint(&send_buffer, "http://127.0.0.1:{d}/ok", .{ + var loc_buf: [50]u8 = undefined; + const new_loc = try std.fmt.bufPrint(&loc_buf, "http://127.0.0.1:{d}/ok", .{ global.other_port.?, }); - var server = http.Server.init(conn, &header_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); var request = try server.receiveHead(); try expectEqualStrings(request.head.target, "/help"); try request.respond("", .{ @@ -1216,16 +1167,15 @@ test "redirect to different connection" { const uri = try std.Uri.parse(location); { - var server_header_buffer: [666]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [666]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try 
req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); + var reader = response.reader(&.{}); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try reader.allocRemaining(gpa, .limited(8192)); defer gpa.free(body); try expectEqualStrings("good job, you pass", body); From 8f06754a06996e1114e5ada9644fa36da4908642 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 1 Aug 2025 00:13:28 -0700 Subject: [PATCH 53/70] std.crypto.tls: rework for new std.Io API --- lib/std/Io/Reader.zig | 25 - lib/std/crypto/tls.zig | 205 +++--- lib/std/crypto/tls/Client.zig | 1202 +++++++++++---------------------- lib/std/http.zig | 19 +- 4 files changed, 502 insertions(+), 949 deletions(-) diff --git a/lib/std/Io/Reader.zig b/lib/std/Io/Reader.zig index 78d3646cbe..88c2a20f97 100644 --- a/lib/std/Io/Reader.zig +++ b/lib/std/Io/Reader.zig @@ -1306,31 +1306,6 @@ pub fn defaultRebase(r: *Reader, capacity: usize) RebaseError!void { r.end = data.len; } -/// Advances the stream and decreases the size of the storage buffer by `n`, -/// returning the range of bytes no longer accessible by `r`. -/// -/// This action can be undone by `restitute`. -/// -/// Asserts there are at least `n` buffered bytes already. -/// -/// Asserts that `r.seek` is zero, i.e. the buffer is in a rebased state. -pub fn steal(r: *Reader, n: usize) []u8 { - assert(r.seek == 0); - assert(n <= r.end); - const stolen = r.buffer[0..n]; - r.buffer = r.buffer[n..]; - r.end -= n; - return stolen; -} - -/// Expands the storage buffer, undoing the effects of `steal` -/// Assumes that `n` does not exceed the total number of stolen bytes. -pub fn restitute(r: *Reader, n: usize) void { - r.buffer = (r.buffer.ptr - n)[0 .. r.buffer.len + n]; - r.end += n; - r.seek += n; -} - test fixed { var r: Reader = .fixed("a\x02"); try testing.expect((try r.takeByte()) == 'a'); diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index da6a431840..e647a7710e 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -49,8 +49,8 @@ pub const hello_retry_request_sequence = [32]u8{ }; pub const close_notify_alert = [_]u8{ - @intFromEnum(AlertLevel.warning), - @intFromEnum(AlertDescription.close_notify), + @intFromEnum(Alert.Level.warning), + @intFromEnum(Alert.Description.close_notify), }; pub const ProtocolVersion = enum(u16) { @@ -138,103 +138,108 @@ pub const ExtensionType = enum(u16) { _, }; -pub const AlertLevel = enum(u8) { - warning = 1, - fatal = 2, - _, -}; +pub const Alert = struct { + level: Level, + description: Description, -pub const AlertDescription = enum(u8) { - pub const Error = error{ - TlsAlertUnexpectedMessage, - TlsAlertBadRecordMac, - TlsAlertRecordOverflow, - TlsAlertHandshakeFailure, - TlsAlertBadCertificate, - TlsAlertUnsupportedCertificate, - TlsAlertCertificateRevoked, - TlsAlertCertificateExpired, - TlsAlertCertificateUnknown, - TlsAlertIllegalParameter, - TlsAlertUnknownCa, - TlsAlertAccessDenied, - TlsAlertDecodeError, - TlsAlertDecryptError, - TlsAlertProtocolVersion, - TlsAlertInsufficientSecurity, - TlsAlertInternalError, - TlsAlertInappropriateFallback, - TlsAlertMissingExtension, - TlsAlertUnsupportedExtension, - TlsAlertUnrecognizedName, - TlsAlertBadCertificateStatusResponse, - TlsAlertUnknownPskIdentity, - TlsAlertCertificateRequired, - TlsAlertNoApplicationProtocol, - TlsAlertUnknown, + pub const Level = enum(u8) { + warning = 1, + fatal = 2, + _, }; - close_notify = 0, - unexpected_message = 10, - bad_record_mac = 20, - record_overflow = 22, - 
handshake_failure = 40, - bad_certificate = 42, - unsupported_certificate = 43, - certificate_revoked = 44, - certificate_expired = 45, - certificate_unknown = 46, - illegal_parameter = 47, - unknown_ca = 48, - access_denied = 49, - decode_error = 50, - decrypt_error = 51, - protocol_version = 70, - insufficient_security = 71, - internal_error = 80, - inappropriate_fallback = 86, - user_canceled = 90, - missing_extension = 109, - unsupported_extension = 110, - unrecognized_name = 112, - bad_certificate_status_response = 113, - unknown_psk_identity = 115, - certificate_required = 116, - no_application_protocol = 120, - _, + pub const Description = enum(u8) { + pub const Error = error{ + TlsAlertUnexpectedMessage, + TlsAlertBadRecordMac, + TlsAlertRecordOverflow, + TlsAlertHandshakeFailure, + TlsAlertBadCertificate, + TlsAlertUnsupportedCertificate, + TlsAlertCertificateRevoked, + TlsAlertCertificateExpired, + TlsAlertCertificateUnknown, + TlsAlertIllegalParameter, + TlsAlertUnknownCa, + TlsAlertAccessDenied, + TlsAlertDecodeError, + TlsAlertDecryptError, + TlsAlertProtocolVersion, + TlsAlertInsufficientSecurity, + TlsAlertInternalError, + TlsAlertInappropriateFallback, + TlsAlertMissingExtension, + TlsAlertUnsupportedExtension, + TlsAlertUnrecognizedName, + TlsAlertBadCertificateStatusResponse, + TlsAlertUnknownPskIdentity, + TlsAlertCertificateRequired, + TlsAlertNoApplicationProtocol, + TlsAlertUnknown, + }; - pub fn toError(alert: AlertDescription) Error!void { - switch (alert) { - .close_notify => {}, // not an error - .unexpected_message => return error.TlsAlertUnexpectedMessage, - .bad_record_mac => return error.TlsAlertBadRecordMac, - .record_overflow => return error.TlsAlertRecordOverflow, - .handshake_failure => return error.TlsAlertHandshakeFailure, - .bad_certificate => return error.TlsAlertBadCertificate, - .unsupported_certificate => return error.TlsAlertUnsupportedCertificate, - .certificate_revoked => return error.TlsAlertCertificateRevoked, - .certificate_expired => return error.TlsAlertCertificateExpired, - .certificate_unknown => return error.TlsAlertCertificateUnknown, - .illegal_parameter => return error.TlsAlertIllegalParameter, - .unknown_ca => return error.TlsAlertUnknownCa, - .access_denied => return error.TlsAlertAccessDenied, - .decode_error => return error.TlsAlertDecodeError, - .decrypt_error => return error.TlsAlertDecryptError, - .protocol_version => return error.TlsAlertProtocolVersion, - .insufficient_security => return error.TlsAlertInsufficientSecurity, - .internal_error => return error.TlsAlertInternalError, - .inappropriate_fallback => return error.TlsAlertInappropriateFallback, - .user_canceled => {}, // not an error - .missing_extension => return error.TlsAlertMissingExtension, - .unsupported_extension => return error.TlsAlertUnsupportedExtension, - .unrecognized_name => return error.TlsAlertUnrecognizedName, - .bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse, - .unknown_psk_identity => return error.TlsAlertUnknownPskIdentity, - .certificate_required => return error.TlsAlertCertificateRequired, - .no_application_protocol => return error.TlsAlertNoApplicationProtocol, - _ => return error.TlsAlertUnknown, + close_notify = 0, + unexpected_message = 10, + bad_record_mac = 20, + record_overflow = 22, + handshake_failure = 40, + bad_certificate = 42, + unsupported_certificate = 43, + certificate_revoked = 44, + certificate_expired = 45, + certificate_unknown = 46, + illegal_parameter = 47, + unknown_ca = 48, + 
access_denied = 49, + decode_error = 50, + decrypt_error = 51, + protocol_version = 70, + insufficient_security = 71, + internal_error = 80, + inappropriate_fallback = 86, + user_canceled = 90, + missing_extension = 109, + unsupported_extension = 110, + unrecognized_name = 112, + bad_certificate_status_response = 113, + unknown_psk_identity = 115, + certificate_required = 116, + no_application_protocol = 120, + _, + + pub fn toError(description: Description) Error!void { + switch (description) { + .close_notify => {}, // not an error + .unexpected_message => return error.TlsAlertUnexpectedMessage, + .bad_record_mac => return error.TlsAlertBadRecordMac, + .record_overflow => return error.TlsAlertRecordOverflow, + .handshake_failure => return error.TlsAlertHandshakeFailure, + .bad_certificate => return error.TlsAlertBadCertificate, + .unsupported_certificate => return error.TlsAlertUnsupportedCertificate, + .certificate_revoked => return error.TlsAlertCertificateRevoked, + .certificate_expired => return error.TlsAlertCertificateExpired, + .certificate_unknown => return error.TlsAlertCertificateUnknown, + .illegal_parameter => return error.TlsAlertIllegalParameter, + .unknown_ca => return error.TlsAlertUnknownCa, + .access_denied => return error.TlsAlertAccessDenied, + .decode_error => return error.TlsAlertDecodeError, + .decrypt_error => return error.TlsAlertDecryptError, + .protocol_version => return error.TlsAlertProtocolVersion, + .insufficient_security => return error.TlsAlertInsufficientSecurity, + .internal_error => return error.TlsAlertInternalError, + .inappropriate_fallback => return error.TlsAlertInappropriateFallback, + .user_canceled => {}, // not an error + .missing_extension => return error.TlsAlertMissingExtension, + .unsupported_extension => return error.TlsAlertUnsupportedExtension, + .unrecognized_name => return error.TlsAlertUnrecognizedName, + .bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse, + .unknown_psk_identity => return error.TlsAlertUnknownPskIdentity, + .certificate_required => return error.TlsAlertCertificateRequired, + .no_application_protocol => return error.TlsAlertNoApplicationProtocol, + _ => return error.TlsAlertUnknown, + } } - } + }; }; pub const SignatureScheme = enum(u16) { @@ -650,7 +655,7 @@ pub const Decoder = struct { } /// Use this function to increase `their_end`. - pub fn readAtLeast(d: *Decoder, stream: anytype, their_amt: usize) !void { + pub fn readAtLeast(d: *Decoder, stream: *std.io.Reader, their_amt: usize) !void { assert(!d.disable_reads); const existing_amt = d.cap - d.idx; d.their_end = d.idx + their_amt; @@ -658,14 +663,16 @@ pub const Decoder = struct { const request_amt = their_amt - existing_amt; const dest = d.buf[d.cap..]; if (request_amt > dest.len) return error.TlsRecordOverflow; - const actual_amt = try stream.readAtLeast(dest, request_amt); - if (actual_amt < request_amt) return error.TlsConnectionTruncated; - d.cap += actual_amt; + stream.readSlice(dest[0..request_amt]) catch |err| switch (err) { + error.EndOfStream => return error.TlsConnectionTruncated, + error.ReadFailed => return error.ReadFailed, + }; + d.cap += request_amt; } /// Same as `readAtLeast` but also increases `our_end` by exactly `our_amt`. /// Use when `our_amt` is calculated by us, not by them. 
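    /// For example (sketch; `input` is assumed to be the connection's
    /// `*std.io.Reader` and `handshake_buffer` a caller-owned record buffer):
    ///
    ///     var d: tls.Decoder = .{ .buf = &handshake_buffer };
    ///     try d.readAtLeastOurAmt(input, tls.record_header_len);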
- pub fn readAtLeastOurAmt(d: *Decoder, stream: anytype, our_amt: usize) !void { + pub fn readAtLeastOurAmt(d: *Decoder, stream: *std.io.Reader, our_amt: usize) !void { assert(!d.disable_reads); try readAtLeast(d, stream, our_amt); d.our_end = d.idx + our_amt; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 3fa7b73d06..082fc9da70 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -1,11 +1,15 @@ +const builtin = @import("builtin"); +const native_endian = builtin.cpu.arch.endian(); + const std = @import("../../std.zig"); const tls = std.crypto.tls; const Client = @This(); -const net = std.net; const mem = std.mem; const crypto = std.crypto; const assert = std.debug.assert; const Certificate = std.crypto.Certificate; +const Reader = std.io.Reader; +const Writer = std.io.Writer; const max_ciphertext_len = tls.max_ciphertext_len; const hmacExpandLabel = tls.hmacExpandLabel; @@ -13,44 +17,58 @@ const hkdfExpandLabel = tls.hkdfExpandLabel; const int = tls.int; const array = tls.array; +/// The encrypted stream from the server to the client. Bytes are pulled from +/// here via `reader`. +/// +/// The buffer is asserted to have capacity at least `min_buffer_len`. +input: *Reader, +/// Decrypted stream from the server to the client. +reader: Reader, + +/// The encrypted stream from the client to the server. Bytes are pushed here +/// via `writer`. +output: *Writer, +/// The plaintext stream from the client to the server. +writer: Writer, + +/// Populated when `error.TlsAlert` is returned. +alert: ?tls.Alert = null, +read_err: ?ReadError = null, tls_version: tls.ProtocolVersion, read_seq: u64, write_seq: u64, -/// The starting index of cleartext bytes inside `partially_read_buffer`. -partial_cleartext_idx: u15, -/// The ending index of cleartext bytes inside `partially_read_buffer` as well -/// as the starting index of ciphertext bytes. -partial_ciphertext_idx: u15, -/// The ending index of ciphertext bytes inside `partially_read_buffer`. -partial_ciphertext_end: u15, /// When this is true, the stream may still not be at the end because there -/// may be data in `partially_read_buffer`. +/// may be data in the input buffer. received_close_notify: bool, -/// By default, reaching the end-of-stream when reading from the server will -/// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify -/// message has been received. By setting this flag to `true`, instead, the -/// end-of-stream will be forwarded to the application layer above TLS. -/// This makes the application vulnerable to truncation attacks unless the -/// application layer itself verifies that the amount of data received equals -/// the amount of data expected, such as HTTP with the Content-Length header. allow_truncation_attacks: bool, application_cipher: tls.ApplicationCipher, -/// The size is enough to contain exactly one TLSCiphertext record. -/// This buffer is segmented into four parts: -/// 0. unused -/// 1. cleartext -/// 2. ciphertext -/// 3. unused -/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and -/// `partial_ciphertext_end` describe the span of the segments. -partially_read_buffer: [tls.max_ciphertext_record_len]u8, -/// If non-null, ssl secrets are logged to a file. Creating such a log file allows other -/// programs with access to that file to decrypt all traffic over this connection. -ssl_key_log: ?struct { + +/// If non-null, ssl secrets are logged to a stream. 
Creating such a log file +/// allows other programs with access to that file to decrypt all traffic over +/// this connection. +ssl_key_log: ?*SslKeyLog, + +pub const ReadError = error{ + /// The alert description will be stored in `alert`. + TlsAlert, + TlsBadLength, + TlsBadRecordMac, + TlsConnectionTruncated, + TlsDecodeError, + TlsRecordOverflow, + TlsUnexpectedMessage, + TlsIllegalParameter, + TlsSequenceOverflow, + /// The buffer provided to the read function was not at least + /// `min_buffer_len`. + OutputBufferUndersize, +}; + +pub const SslKeyLog = struct { client_key_seq: u64, server_key_seq: u64, client_random: [32]u8, - file: std.fs.File, + writer: *Writer, fn clientCounter(key_log: *@This()) u64 { defer key_log.client_key_seq += 1; @@ -61,51 +79,12 @@ ssl_key_log: ?struct { defer key_log.server_key_seq += 1; return key_log.server_key_seq; } -}, - -/// This is an example of the type that is needed by the read and write -/// functions. It can have any fields but it must at least have these -/// functions. -/// -/// Note that `std.net.Stream` conforms to this interface. -/// -/// This declaration serves as documentation only. -pub const StreamInterface = struct { - /// Can be any error set. - pub const ReadError = error{}; - - /// Returns the number of bytes read. The number read may be less than the - /// buffer space provided. End-of-stream is indicated by a return value of 0. - /// - /// The `iovecs` parameter is mutable because so that function may to - /// mutate the fields in order to handle partial reads from the underlying - /// stream layer. - pub fn readv(this: @This(), iovecs: []std.posix.iovec) ReadError!usize { - _ = .{ this, iovecs }; - @panic("unimplemented"); - } - - /// Can be any error set. - pub const WriteError = error{}; - - /// Returns the number of bytes read, which may be less than the buffer - /// space provided. A short read does not indicate end-of-stream. - pub fn writev(this: @This(), iovecs: []const std.posix.iovec_const) WriteError!usize { - _ = .{ this, iovecs }; - @panic("unimplemented"); - } - - /// Returns the number of bytes read, which may be less than the buffer - /// space provided, indicating end-of-stream. - /// The `iovecs` parameter is mutable in case this function needs to mutate - /// the fields in order to handle partial writes from the underlying layer. - pub fn writevAll(this: @This(), iovecs: []std.posix.iovec_const) WriteError!usize { - // This can be implemented in terms of writev, or specialized if desired. - _ = .{ this, iovecs }; - @panic("unimplemented"); - } }; +/// The `Reader` supplied to `init` requires a buffer capacity +/// at least this amount. +pub const min_buffer_len = tls.max_ciphertext_record_len; + pub const Options = struct { /// How to perform host verification of server certificates. host: union(enum) { @@ -127,64 +106,85 @@ pub const Options = struct { /// Verify that the server certificate is authorized by a given ca bundle. bundle: Certificate.Bundle, }, - /// If non-null, ssl secrets are logged to this file. Creating such a log file allows + /// If non-null, ssl secrets are logged to this stream. Creating such a log file allows /// other programs with access to that file to decrypt all traffic over this connection. - ssl_key_log_file: ?std.fs.File = null, + /// + /// Only the `writer` field is observed during the handshake (`init`). + /// After that, the other fields are populated. 
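    /// A minimal wiring sketch (illustrative; assumes a caller-opened
    /// `key_log_file: std.fs.File` with the new buffered writer API, and that
    /// `init` fills in the sequence and random fields as documented above):
    ///
    ///     var key_log_buf: [512]u8 = undefined;
    ///     var key_log_writer = key_log_file.writer(&key_log_buf);
    ///     var key_log: SslKeyLog = .{
    ///         .client_key_seq = 0,
    ///         .server_key_seq = 0,
    ///         .client_random = undefined,
    ///         .writer = &key_log_writer.interface,
    ///     };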
+ ssl_key_log: ?*SslKeyLog = null, + /// By default, reaching the end-of-stream when reading from the server will + /// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify + /// message has been received. By setting this flag to `true`, instead, the + /// end-of-stream will be forwarded to the application layer above TLS. + /// + /// This makes the application vulnerable to truncation attacks unless the + /// application layer itself verifies that the amount of data received equals + /// the amount of data expected, such as HTTP with the Content-Length header. + allow_truncation_attacks: bool = false, + write_buffer: []u8, + /// Asserted to have capacity at least `min_buffer_len`. + read_buffer: []u8, + /// Populated when `error.TlsAlert` is returned from `init`. + alert: ?*tls.Alert = null, }; -pub fn InitError(comptime Stream: type) type { - return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || tls.AlertDescription.Error || error{ - InsufficientEntropy, - DiskQuota, - LockViolation, - NotOpenForWriting, - TlsUnexpectedMessage, - TlsIllegalParameter, - TlsDecryptFailure, - TlsRecordOverflow, - TlsBadRecordMac, - CertificateFieldHasInvalidLength, - CertificateHostMismatch, - CertificatePublicKeyInvalid, - CertificateExpired, - CertificateFieldHasWrongDataType, - CertificateIssuerMismatch, - CertificateNotYetValid, - CertificateSignatureAlgorithmMismatch, - CertificateSignatureAlgorithmUnsupported, - CertificateSignatureInvalid, - CertificateSignatureInvalidLength, - CertificateSignatureNamedCurveUnsupported, - CertificateSignatureUnsupportedBitCount, - TlsCertificateNotVerified, - TlsBadSignatureScheme, - TlsBadRsaSignatureBitCount, - InvalidEncoding, - IdentityElement, - SignatureVerificationFailed, - TlsDecryptError, - TlsConnectionTruncated, - TlsDecodeError, - UnsupportedCertificateVersion, - CertificateTimeInvalid, - CertificateHasUnrecognizedObjectId, - CertificateHasInvalidBitString, - MessageTooLong, - NegativeIntoUnsigned, - TargetTooSmall, - BufferTooSmall, - InvalidSignature, - NotSquare, - NonCanonical, - WeakPublicKey, - }; -} +const InitError = error{ + WriteFailed, + ReadFailed, + InsufficientEntropy, + DiskQuota, + LockViolation, + NotOpenForWriting, + /// The alert description will be stored in `alert`. + TlsAlert, + TlsUnexpectedMessage, + TlsIllegalParameter, + TlsDecryptFailure, + TlsRecordOverflow, + TlsBadRecordMac, + CertificateFieldHasInvalidLength, + CertificateHostMismatch, + CertificatePublicKeyInvalid, + CertificateExpired, + CertificateFieldHasWrongDataType, + CertificateIssuerMismatch, + CertificateNotYetValid, + CertificateSignatureAlgorithmMismatch, + CertificateSignatureAlgorithmUnsupported, + CertificateSignatureInvalid, + CertificateSignatureInvalidLength, + CertificateSignatureNamedCurveUnsupported, + CertificateSignatureUnsupportedBitCount, + TlsCertificateNotVerified, + TlsBadSignatureScheme, + TlsBadRsaSignatureBitCount, + InvalidEncoding, + IdentityElement, + SignatureVerificationFailed, + TlsDecryptError, + TlsConnectionTruncated, + TlsDecodeError, + UnsupportedCertificateVersion, + CertificateTimeInvalid, + CertificateHasUnrecognizedObjectId, + CertificateHasInvalidBitString, + MessageTooLong, + NegativeIntoUnsigned, + TargetTooSmall, + BufferTooSmall, + InvalidSignature, + NotSquare, + NonCanonical, + WeakPublicKey, +}; -/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session with `stream`, which -/// must conform to `StreamInterface`. 
+/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session. /// /// `host` is only borrowed during this function call. -pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client { +/// +/// `input` is asserted to have buffer capacity at least `min_buffer_len`. +pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client { + assert(input.buffer.len >= min_buffer_len); const host = switch (options.host) { .no_verification => "", .explicit => |host| host, @@ -276,11 +276,8 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client }; { - var iovecs = [_]std.posix.iovec_const{ - .{ .base = cleartext_header.ptr, .len = cleartext_header.len }, - .{ .base = host.ptr, .len = host.len }, - }; - try stream.writevAll(iovecs[0..if (host.len == 0) 1 else 2]); + var iovecs: [2][]const u8 = .{ cleartext_header, host }; + try output.writeVecAll(iovecs[0..if (host.len == 0) 1 else 2]); } var tls_version: tls.ProtocolVersion = undefined; @@ -329,20 +326,26 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client var cleartext_fragment_start: usize = 0; var cleartext_fragment_end: usize = 0; var cleartext_bufs: [2][tls.max_ciphertext_inner_record_len]u8 = undefined; - var handshake_buffer: [tls.max_ciphertext_record_len]u8 = undefined; - var d: tls.Decoder = .{ .buf = &handshake_buffer }; fragment: while (true) { - try d.readAtLeastOurAmt(stream, tls.record_header_len); - const record_header = d.buf[d.idx..][0..tls.record_header_len]; - const record_ct = d.decode(tls.ContentType); - d.skip(2); // legacy_version - const record_len = d.decode(u16); - try d.readAtLeast(stream, record_len); - var record_decoder = try d.sub(record_len); + // Ensure the input buffer pointer is stable in this scope. 
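// Record framing recap for the loop body that follows: every TLS record
// begins with a 5-byte header holding the content type (1 byte), the
// legacy_version (2 bytes, ignored here), and a big-endian u16 ciphertext
// length that must not exceed `tls.max_ciphertext_len`. The same parse,
// sketched over a plain 5-byte slice instead of the buffered `Reader`:
//
//     fn parseRecordHeader(header: *const [tls.record_header_len]u8) struct { tls.ContentType, u16 } {
//         const ct: tls.ContentType = @enumFromInt(header[0]);
//         const len = mem.readInt(u16, header[3..][0..2], .big);
//         return .{ ct, len };
//     }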
+ input.rebaseCapacity(tls.max_ciphertext_record_len); + const record_header = input.peek(tls.record_header_len) catch |err| switch (err) { + error.EndOfStream => return error.TlsConnectionTruncated, + error.ReadFailed => return error.ReadFailed, + }; + const record_ct = input.takeEnumNonexhaustive(tls.ContentType, .big) catch unreachable; // already peeked + input.toss(2); // legacy_version + const record_len = input.takeInt(u16, .big) catch unreachable; // already peeked + if (record_len > tls.max_ciphertext_len) return error.TlsRecordOverflow; + const record_buffer = input.take(record_len) catch |err| switch (err) { + error.EndOfStream => return error.TlsConnectionTruncated, + error.ReadFailed => return error.ReadFailed, + }; + var record_decoder: tls.Decoder = .fromTheirSlice(record_buffer); var ctd, const ct = content: switch (cipher_state) { .cleartext => .{ record_decoder, record_ct }, .handshake => { - std.debug.assert(tls_version == .tls_1_3); + assert(tls_version == .tls_1_3); if (record_ct != .application_data) return error.TlsUnexpectedMessage; try record_decoder.ensure(record_len); const cleartext_buf = &cleartext_bufs[cert_buf_index % 2]; @@ -374,7 +377,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client break :content .{ tls.Decoder.fromTheirSlice(@constCast(cleartext_buf[cleartext_fragment_start..cleartext_fragment_end])), ct }; }, .application => { - std.debug.assert(tls_version == .tls_1_2); + assert(tls_version == .tls_1_2); if (record_ct != .handshake) return error.TlsUnexpectedMessage; try record_decoder.ensure(record_len); const cleartext_buf = &cleartext_bufs[cert_buf_index % 2]; @@ -412,14 +415,11 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client switch (ct) { .alert => { ctd.ensure(2) catch continue :fragment; - const level = ctd.decode(tls.AlertLevel); - const desc = ctd.decode(tls.AlertDescription); - _ = level; - - // if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; + if (options.alert) |a| a.* = .{ + .level = ctd.decode(tls.Alert.Level), + .description = ctd.decode(tls.Alert.Description), + }; + return error.TlsAlert; }, .change_cipher_spec => { ctd.ensure(1) catch continue :fragment; @@ -533,7 +533,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client pv.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); const client_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length); const server_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length); - if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{ + if (options.ssl_key_log) |key_log| logSecrets(key_log.writer, .{ .client_random = &client_hello_rand, }, .{ .SERVER_HANDSHAKE_TRAFFIC_SECRET = &server_secret, @@ -707,7 +707,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client &client_hello_rand, &server_hello_rand, }, 48); - if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{ + if (options.ssl_key_log) |key_log| logSecrets(key_log.writer, .{ .client_random = &client_hello_rand, }, .{ .CLIENT_RANDOM = &master_secret, @@ -755,11 +755,12 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client nonce, pv.app_cipher.client_write_key, ); - const all_msgs = client_key_exchange_msg 
++ client_change_cipher_spec_msg ++ client_verify_msg; - var all_msgs_vec = [_]std.posix.iovec_const{ - .{ .base = &all_msgs, .len = all_msgs.len }, + var all_msgs_vec: [3][]const u8 = .{ + &client_key_exchange_msg, + &client_change_cipher_spec_msg, + &client_verify_msg, }; - try stream.writevAll(&all_msgs_vec); + try output.writeVecAll(&all_msgs_vec); }, } write_seq += 1; @@ -820,15 +821,15 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client const nonce = pv.client_handshake_iv; P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, pv.client_handshake_key); - const all_msgs = client_change_cipher_spec_msg ++ finished_msg; - var all_msgs_vec = [_]std.posix.iovec_const{ - .{ .base = &all_msgs, .len = all_msgs.len }, + var all_msgs_vec: [2][]const u8 = .{ + &client_change_cipher_spec_msg, + &finished_msg, }; - try stream.writevAll(&all_msgs_vec); + try output.writeVecAll(&all_msgs_vec); const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); - if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{ + if (options.ssl_key_log) |key_log| logSecrets(key_log.writer, .{ .counter = key_seq, .client_random = &client_hello_rand, }, .{ @@ -855,8 +856,28 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client else => unreachable, }, }; - const leftover = d.rest(); - var client: Client = .{ + if (options.ssl_key_log) |ssl_key_log| ssl_key_log.* = .{ + .client_key_seq = key_seq, + .server_key_seq = key_seq, + .client_random = client_hello_rand, + .writer = ssl_key_log.writer, + }; + return .{ + .input = input, + .reader = .{ + .buffer = options.read_buffer, + .vtable = &.{ .stream = stream }, + .seek = 0, + .end = 0, + }, + .output = output, + .writer = .{ + .buffer = options.write_buffer, + .vtable = &.{ + .drain = drain, + .sendFile = Writer.unimplementedSendFile, + }, + }, .tls_version = tls_version, .read_seq = switch (tls_version) { .tls_1_3 => 0, @@ -868,22 +889,11 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client .tls_1_2 => write_seq, else => unreachable, }, - .partial_cleartext_idx = 0, - .partial_ciphertext_idx = 0, - .partial_ciphertext_end = @intCast(leftover.len), .received_close_notify = false, - .allow_truncation_attacks = false, + .allow_truncation_attacks = options.allow_truncation_attacks, .application_cipher = app_cipher, - .partially_read_buffer = undefined, - .ssl_key_log = if (options.ssl_key_log_file) |key_log_file| .{ - .client_key_seq = key_seq, - .server_key_seq = key_seq, - .client_random = client_hello_rand, - .file = key_log_file, - } else null, + .ssl_key_log = options.ssl_key_log, }; - @memcpy(client.partially_read_buffer[0..leftover.len], leftover); - return client; }, else => return error.TlsUnexpectedMessage, } @@ -897,94 +907,48 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client } } -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`. 
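// The Client returned above wires up the `std.Io` interfaces by hand:
// `reader.vtable.stream` and `writer.vtable.drain` point at private
// functions defined further down, and those functions recover their
// containing Client with `@fieldParentPtr`. Minimal sketch of the pattern,
// with all names invented for illustration:
//
//     const Wrapper = struct {
//         reader: Reader,
//         fn stream(r: *Reader, w: *Writer, limit: std.io.Limit) Reader.StreamError!usize {
//             const wrapper: *Wrapper = @fieldParentPtr("reader", r);
//             _ = .{ wrapper, w, limit };
//             return 0;
//         }
//     };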
-pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize { - return writeEnd(c, stream, bytes, false); +fn drain(w: *Writer, data: []const []const u8, splat: usize) Writer.Error!usize { + const c: *Client = @fieldParentPtr("writer", w); + if (true) @panic("update to use the buffer and flush"); + const sliced_data = if (splat == 0) data[0..data.len -| 1] else data; + const output = c.output; + const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len); + var total_clear: usize = 0; + var ciphertext_end: usize = 0; + for (sliced_data) |buf| { + const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data); + total_clear += prepared.cleartext_len; + ciphertext_end += prepared.ciphertext_end; + if (total_clear < buf.len) break; + } + output.advance(ciphertext_end); + return total_clear; } -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -pub fn writeAll(c: *Client, stream: anytype, bytes: []const u8) !void { - var index: usize = 0; - while (index < bytes.len) { - index += try c.write(stream, bytes[index..]); - } -} - -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// If `end` is true, then this function additionally sends a `close_notify` alert, -/// which is necessary for the server to distinguish between a properly finished -/// TLS session, or a truncation attack. -pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !void { - var index: usize = 0; - while (index < bytes.len) { - index += try c.writeEnd(stream, bytes[index..], end); - } -} - -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`. -/// If `end` is true, then this function additionally sends a `close_notify` alert, -/// which is necessary for the server to distinguish between a properly finished -/// TLS session, or a truncation attack. -pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize { - var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined; - var iovecs_buf: [6]std.posix.iovec_const = undefined; - var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data); - if (end) { - prepared.iovec_end += prepareCiphertextRecord( - c, - iovecs_buf[prepared.iovec_end..], - ciphertext_buf[prepared.ciphertext_end..], - &tls.close_notify_alert, - .alert, - ).iovec_end; - } - - const iovec_end = prepared.iovec_end; - const overhead_len = prepared.overhead_len; - - // Ideally we would call writev exactly once here, however, we must ensure - // that we don't return with a record partially written. - var i: usize = 0; - var total_amt: usize = 0; - while (true) { - var amt = try stream.writev(iovecs_buf[i..iovec_end]); - while (amt >= iovecs_buf[i].len) { - const encrypted_amt = iovecs_buf[i].len; - total_amt += encrypted_amt - overhead_len; - amt -= encrypted_amt; - i += 1; - // Rely on the property that iovecs delineate records, meaning that - // if amt equals zero here, we have fortunately found ourselves - // with a short read that aligns at the record boundary. - if (i >= iovec_end) return total_amt; - // We also cannot return on a vector boundary if the final close_notify is - // not sent; otherwise the caller would not know to retry the call. 
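// Contract of the `drain` implementation above, as implied by its handling
// of `splat`: `data` is a list of byte slices whose final element acts as a
// pattern repeated `splat` times, so `splat == 0` drops the last slice
// entirely. Worked example of what one vectored write expands to:
//
//     data = .{ "GET ", "/ HTTP/1.1\r\n", "\r\n" }, splat = 2
//     resulting byte stream: "GET / HTTP/1.1\r\n\r\n\r\n"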
- if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt; - } - iovecs_buf[i].base += amt; - iovecs_buf[i].len -= amt; - } +/// Sends a `close_notify` alert, which is necessary for the server to +/// distinguish between a properly finished TLS session, or a truncation +/// attack. +pub fn end(c: *Client) Writer.Error!void { + const output = c.output; + const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len); + const prepared = prepareCiphertextRecord(c, ciphertext_buf, &tls.close_notify_alert, .alert); + output.advance(prepared.cleartext_len); + return prepared.ciphertext_end; } fn prepareCiphertextRecord( c: *Client, - iovecs: []std.posix.iovec_const, ciphertext_buf: []u8, bytes: []const u8, inner_content_type: tls.ContentType, ) struct { - iovec_end: usize, ciphertext_end: usize, - /// How many bytes are taken up by overhead per record. - overhead_len: usize, + cleartext_len: usize, } { // Due to the trailing inner content type byte in the ciphertext, we need // an additional buffer for storing the cleartext into before encrypting. var cleartext_buf: [max_ciphertext_len]u8 = undefined; var ciphertext_end: usize = 0; - var iovec_end: usize = 0; var bytes_i: usize = 0; switch (c.application_cipher) { inline else => |*p| switch (c.tls_version) { @@ -992,18 +956,15 @@ fn prepareCiphertextRecord( const pv = &p.tls_1_3; const P = @TypeOf(p.*); const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1; - const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; while (true) { const encrypted_content_len: u16 = @min( bytes.len - bytes_i, tls.max_ciphertext_inner_record_len, - ciphertext_buf.len -| - (close_notify_alert_reserved + overhead_len + ciphertext_end), + ciphertext_buf.len -| (overhead_len + ciphertext_end), ); if (encrypted_content_len == 0) return .{ - .iovec_end = iovec_end, .ciphertext_end = ciphertext_end, - .overhead_len = overhead_len, + .cleartext_len = bytes_i, }; @memcpy(cleartext_buf[0..encrypted_content_len], bytes[bytes_i..][0..encrypted_content_len]); @@ -1012,7 +973,6 @@ fn prepareCiphertextRecord( const ciphertext_len = encrypted_content_len + 1; const cleartext = cleartext_buf[0..ciphertext_len]; - const record_start = ciphertext_end; const ad = ciphertext_buf[ciphertext_end..][0..tls.record_header_len]; ad.* = .{@intFromEnum(tls.ContentType.application_data)} ++ int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ @@ -1030,38 +990,27 @@ fn prepareCiphertextRecord( }; P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_key); c.write_seq += 1; // TODO send key_update on overflow - - const record = ciphertext_buf[record_start..ciphertext_end]; - iovecs[iovec_end] = .{ - .base = record.ptr, - .len = record.len, - }; - iovec_end += 1; } }, .tls_1_2 => { const pv = &p.tls_1_2; const P = @TypeOf(p.*); const overhead_len = tls.record_header_len + P.record_iv_length + P.mac_length; - const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; while (true) { const message_len: u16 = @min( bytes.len - bytes_i, tls.max_ciphertext_inner_record_len, - ciphertext_buf.len -| - (close_notify_alert_reserved + overhead_len + ciphertext_end), + ciphertext_buf.len -| (overhead_len + ciphertext_end), ); if (message_len == 0) return .{ - .iovec_end = iovec_end, .ciphertext_end = ciphertext_end, - .overhead_len = overhead_len, + .cleartext_len = bytes_i, }; @memcpy(cleartext_buf[0..message_len], bytes[bytes_i..][0..message_len]); bytes_i += message_len; const cleartext = cleartext_buf[0..message_len]; - 
const record_start = ciphertext_end; const record_header = ciphertext_buf[ciphertext_end..][0..tls.record_header_len]; ciphertext_end += tls.record_header_len; record_header.* = .{@intFromEnum(inner_content_type)} ++ @@ -1083,13 +1032,6 @@ fn prepareCiphertextRecord( ciphertext_end += P.mac_length; P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_write_key); c.write_seq += 1; // TODO send key_update on overflow - - const record = ciphertext_buf[record_start..ciphertext_end]; - iovecs[iovec_end] = .{ - .base = record.ptr, - .len = record.len, - }; - iovec_end += 1; } }, else => unreachable, @@ -1098,421 +1040,194 @@ fn prepareCiphertextRecord( } pub fn eof(c: Client) bool { - return c.received_close_notify and - c.partial_cleartext_idx >= c.partial_ciphertext_idx and - c.partial_ciphertext_idx >= c.partial_ciphertext_end; + return c.received_close_notify; } -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read, calling the underlying read function the -/// minimal number of times until the buffer has at least `len` bytes filled. -/// If the number read is less than `len` it means the stream reached the end. -/// Reaching the end of the stream is not an error condition. -pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize { - var iovecs = [1]std.posix.iovec{.{ .base = buffer.ptr, .len = buffer.len }}; - return readvAtLeast(c, stream, &iovecs, len); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize { - return readAtLeast(c, stream, buffer, 1); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read. If the number read is smaller than -/// `buffer.len`, it means the stream reached the end. Reaching the end of the -/// stream is not an error condition. -pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize { - return readAtLeast(c, stream, buffer, buffer.len); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read. If the number read is less than the space -/// provided it means the stream reached the end. Reaching the end of the -/// stream is not an error condition. -/// The `iovecs` parameter is mutable because this function needs to mutate the fields in -/// order to handle partial reads from the underlying stream layer. -pub fn readv(c: *Client, stream: anytype, iovecs: []std.posix.iovec) !usize { - return readvAtLeast(c, stream, iovecs, 1); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read, calling the underlying read function the -/// minimal number of times until the iovecs have at least `len` bytes filled. -/// If the number read is less than `len` it means the stream reached the end. -/// Reaching the end of the stream is not an error condition. -/// The `iovecs` parameter is mutable because this function needs to mutate the fields in -/// order to handle partial reads from the underlying stream layer. 
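// Sizing note for `prepareCiphertextRecord` above: each record carries a
// fixed overhead on top of at most `tls.max_ciphertext_inner_record_len`
// cleartext bytes. With AES-128-GCM as a concrete example (16-byte auth tag,
// 8-byte explicit record IV in TLS 1.2):
//
//     TLS 1.3: 5 (header) + 16 (tag) + 1 (inner content type) = 22 bytes
//     TLS 1.2: 5 (header) + 8 (record IV) + 16 (tag) = 29 bytes
//
// A write buffer sized n * (tls.max_ciphertext_inner_record_len + overhead)
// therefore holds n full records.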
-pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.posix.iovec, len: usize) !usize { - if (c.eof()) return 0; - - var off_i: usize = 0; - var vec_i: usize = 0; - while (true) { - var amt = try c.readvAdvanced(stream, iovecs[vec_i..]); - off_i += amt; - if (c.eof() or off_i >= len) return off_i; - while (amt >= iovecs[vec_i].len) { - amt -= iovecs[vec_i].len; - vec_i += 1; - } - iovecs[vec_i].base += amt; - iovecs[vec_i].len -= amt; - } -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns number of bytes that have been read, populated inside `iovecs`. A -/// return value of zero bytes does not mean end of stream. Instead, check the `eof()` -/// for the end of stream. The `eof()` may be true after any call to -/// `read`, including when greater than zero bytes are returned, and this -/// function asserts that `eof()` is `false`. -/// See `readv` for a higher level function that has the same, familiar API as -/// other read functions, such as `std.fs.File.read`. -pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iovec) !usize { - var vp: VecPut = .{ .iovecs = iovecs }; - - // Give away the buffered cleartext we have, if any. - const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx]; - if (partial_cleartext.len > 0) { - const amt: u15 = @intCast(vp.put(partial_cleartext)); - c.partial_cleartext_idx += amt; - - if (c.partial_cleartext_idx == c.partial_ciphertext_idx and - c.partial_ciphertext_end == c.partial_ciphertext_idx) - { - // The buffer is now empty. - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = 0; - c.partial_ciphertext_end = 0; - } - - if (c.received_close_notify) { - c.partial_ciphertext_end = 0; - assert(vp.total == amt); - return amt; - } else if (amt > 0) { - // We don't need more data, so don't call read. - assert(vp.total == amt); - return amt; - } - } - - assert(!c.received_close_notify); - - // Ideally, this buffer would never be used. It is needed when `iovecs` are - // too small to fit the cleartext, which may be as large as `max_ciphertext_len`. - var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined; - // Temporarily stores ciphertext before decrypting it and giving it to `iovecs`. - var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined; - // How many bytes left in the user's buffer. - const free_size = vp.freeSize(); - // The amount of the user's buffer that we need to repurpose for storing - // ciphertext. The end of the buffer will be used for such purposes. - const ciphertext_buf_len = (free_size / 2) -| in_stack_buffer.len; - // The amount of the user's buffer that will be used to give cleartext. The - // beginning of the buffer will be used for such purposes. - const cleartext_buf_len = free_size - ciphertext_buf_len; - - // Recoup `partially_read_buffer` space. This is necessary because it is assumed - // below that `frag0` is big enough to hold at least one record. 
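// In the `stream` implementation that replaces the removed code here, the
// TLS 1.3 per-record nonce is the static server IV XORed with the 64-bit
// record sequence number, left-padded to the nonce length (RFC 8446,
// Section 5.3). The same construction as a standalone sketch, assuming a
// 12-byte nonce:
//
//     fn recordNonce(server_iv: [12]u8, read_seq: u64) [12]u8 {
//         const V = @Vector(12, u8);
//         const pad = [1]u8{0} ** 4;
//         const operand: V = pad ++ std.mem.toBytes(big(read_seq));
//         const nonce: [12]u8 = @as(V, server_iv) ^ operand;
//         return nonce;
//     }
//
// Consecutive records therefore differ only in the low-order nonce bytes.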
- limitedOverlapCopy(c.partially_read_buffer[0..c.partial_ciphertext_end], c.partial_ciphertext_idx); - c.partial_ciphertext_end -= c.partial_ciphertext_idx; - c.partial_ciphertext_idx = 0; - c.partial_cleartext_idx = 0; - const first_iov = c.partially_read_buffer[c.partial_ciphertext_end..]; - - var ask_iovecs_buf: [2]std.posix.iovec = .{ - .{ - .base = first_iov.ptr, - .len = first_iov.len, +fn stream(r: *Reader, w: *Writer, limit: std.io.Limit) Reader.StreamError!usize { + const c: *Client = @fieldParentPtr("reader", r); + if (c.eof()) return error.EndOfStream; + const input = c.input; + // If at least one full encrypted record is not buffered, read once. + const record_header = input.peek(tls.record_header_len) catch |err| switch (err) { + error.EndOfStream => { + // This is either a truncation attack, a bug in the server, or an + // intentional omission of the close_notify message due to truncation + // detection handled above the TLS layer. + if (c.allow_truncation_attacks) { + c.received_close_notify = true; + return error.EndOfStream; + } else { + return failRead(c, error.TlsConnectionTruncated); + } }, - .{ - .base = &in_stack_buffer, - .len = in_stack_buffer.len, + error.ReadFailed => return error.ReadFailed, + }; + const ct: tls.ContentType = @enumFromInt(record_header[0]); + const legacy_version = mem.readInt(u16, record_header[1..][0..2], .big); + _ = legacy_version; + const record_len = mem.readInt(u16, record_header[3..][0..2], .big); + if (record_len > max_ciphertext_len) return failRead(c, error.TlsRecordOverflow); + const record_end = 5 + record_len; + if (record_end > input.buffered().len) { + input.fillMore() catch |err| switch (err) { + error.EndOfStream => return failRead(c, error.TlsConnectionTruncated), + error.ReadFailed => return error.ReadFailed, + }; + if (record_end > input.buffered().len) return 0; + } + + var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined; + const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) { + inline else => |*p| switch (c.tls_version) { + .tls_1_3 => { + const pv = &p.tls_1_3; + const P = @TypeOf(p.*); + const ad = input.take(tls.record_header_len) catch unreachable; // already peeked + const ciphertext_len = record_len - P.AEAD.tag_length; + const ciphertext = input.take(ciphertext_len) catch unreachable; // already peeked + const auth_tag = (input.takeArray(P.AEAD.tag_length) catch unreachable).*; // already peeked + const nonce = nonce: { + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ std.mem.toBytes(big(c.read_seq)); + break :nonce @as(V, pv.server_iv) ^ operand; + }; + const cleartext = cleartext_stack_buffer[0..ciphertext.len]; + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch + return failRead(c, error.TlsBadRecordMac); + const msg = mem.trimRight(u8, cleartext, "\x00"); + break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) }; + }, + .tls_1_2 => { + const pv = &p.tls_1_2; + const P = @TypeOf(p.*); + const message_len: u16 = record_len - P.record_iv_length - P.mac_length; + const ad_header = input.take(tls.record_header_len) catch unreachable; // already peeked + const ad = std.mem.toBytes(big(c.read_seq)) ++ + ad_header[0 .. 
1 + 2] ++ + std.mem.toBytes(big(message_len)); + const record_iv = (input.takeArray(P.record_iv_length) catch unreachable).*; // already peeked + const masked_read_seq = c.read_seq & + comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length); + const nonce: [P.AEAD.nonce_length]u8 = nonce: { + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ @as([8]u8, @bitCast(big(masked_read_seq))); + break :nonce @as(V, pv.server_write_IV ++ record_iv) ^ operand; + }; + const ciphertext = input.take(message_len) catch unreachable; // already peeked + const auth_tag = (input.takeArray(P.mac_length) catch unreachable).*; // already peeked + const cleartext = cleartext_stack_buffer[0..ciphertext.len]; + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch + return failRead(c, error.TlsBadRecordMac); + break :cleartext .{ cleartext, ct }; + }, + else => unreachable, }, }; - - // Cleartext capacity of output buffer, in records. Minimum one full record. - const buf_cap = @max(cleartext_buf_len / max_ciphertext_len, 1); - const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len); - const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len) - c.partial_ciphertext_end; - const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len); - const actual_read_len = try stream.readv(ask_iovecs); - if (actual_read_len == 0) { - // This is either a truncation attack, a bug in the server, or an - // intentional omission of the close_notify message due to truncation - // detection handled above the TLS layer. - if (c.allow_truncation_attacks) { - c.received_close_notify = true; - } else { - return error.TlsConnectionTruncated; - } - } - - // There might be more bytes inside `in_stack_buffer` that need to be processed, - // but at least frag0 will have one complete ciphertext record. - const frag0_end = @min(c.partially_read_buffer.len, c.partial_ciphertext_end + actual_read_len); - const frag0 = c.partially_read_buffer[c.partial_ciphertext_idx..frag0_end]; - var frag1 = in_stack_buffer[0..actual_read_len -| first_iov.len]; - // We need to decipher frag0 and frag1 but there may be a ciphertext record - // straddling the boundary. We can handle this with two memcpy() calls to - // assemble the straddling record in between handling the two sides. - var frag = frag0; - var in: usize = 0; - while (true) { - if (in == frag.len) { - // Perfect split. - if (frag.ptr == frag1.ptr) { - c.partial_ciphertext_end = c.partial_ciphertext_idx; - return vp.total; - } - frag = frag1; - in = 0; - continue; - } - - if (in + tls.record_header_len > frag.len) { - if (frag.ptr == frag1.ptr) - return finishRead(c, frag, in, vp.total); - - const first = frag[in..]; - - if (frag1.len < tls.record_header_len) - return finishRead2(c, first, frag1, vp.total); - - // A record straddles the two fragments. Copy into the now-empty first fragment. 
- const record_len_byte_0: u16 = straddleByte(frag, frag1, in + 3); - const record_len_byte_1: u16 = straddleByte(frag, frag1, in + 4); - const record_len = (record_len_byte_0 << 8) | record_len_byte_1; - if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; - - const full_record_len = record_len + tls.record_header_len; - const second_len = full_record_len - first.len; - if (frag1.len < second_len) - return finishRead2(c, first, frag1, vp.total); - - limitedOverlapCopy(frag, in); - @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]); - frag = frag[0..full_record_len]; - frag1 = frag1[second_len..]; - in = 0; - continue; - } - const ct: tls.ContentType = @enumFromInt(frag[in]); - in += 1; - const legacy_version = mem.readInt(u16, frag[in..][0..2], .big); - in += 2; - _ = legacy_version; - const record_len = mem.readInt(u16, frag[in..][0..2], .big); - if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; - in += 2; - const end = in + record_len; - if (end > frag.len) { - // We need the record header on the next iteration of the loop. - in -= tls.record_header_len; - - if (frag.ptr == frag1.ptr) - return finishRead(c, frag, in, vp.total); - - // A record straddles the two fragments. Copy into the now-empty first fragment. - const first = frag[in..]; - const full_record_len = record_len + tls.record_header_len; - const second_len = full_record_len - first.len; - if (frag1.len < second_len) - return finishRead2(c, first, frag1, vp.total); - - limitedOverlapCopy(frag, in); - @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]); - frag = frag[0..full_record_len]; - frag1 = frag1[second_len..]; - in = 0; - continue; - } - const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) { - inline else => |*p| switch (c.tls_version) { - .tls_1_3 => { - const pv = &p.tls_1_3; - const P = @TypeOf(p.*); - const ad = frag[in - tls.record_header_len ..][0..tls.record_header_len]; - const ciphertext_len = record_len - P.AEAD.tag_length; - const ciphertext = frag[in..][0..ciphertext_len]; - in += ciphertext_len; - const auth_tag = frag[in..][0..P.AEAD.tag_length].*; - const nonce = nonce: { - const V = @Vector(P.AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ std.mem.toBytes(big(c.read_seq)); - break :nonce @as(V, pv.server_iv) ^ operand; - }; - const out_buf = vp.peek(); - const cleartext_buf = if (ciphertext.len <= out_buf.len) - out_buf - else - &cleartext_stack_buffer; - const cleartext = cleartext_buf[0..ciphertext.len]; - P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch - return error.TlsBadRecordMac; - const msg = mem.trimEnd(u8, cleartext, "\x00"); - break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) }; - }, - .tls_1_2 => { - const pv = &p.tls_1_2; - const P = @TypeOf(p.*); - const message_len: u16 = record_len - P.record_iv_length - P.mac_length; - const ad = std.mem.toBytes(big(c.read_seq)) ++ - frag[in - tls.record_header_len ..][0 .. 
1 + 2] ++ - std.mem.toBytes(big(message_len)); - const record_iv = frag[in..][0..P.record_iv_length].*; - in += P.record_iv_length; - const masked_read_seq = c.read_seq & - comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length); - const nonce: [P.AEAD.nonce_length]u8 = nonce: { - const V = @Vector(P.AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @as([8]u8, @bitCast(big(masked_read_seq))); - break :nonce @as(V, pv.server_write_IV ++ record_iv) ^ operand; - }; - const ciphertext = frag[in..][0..message_len]; - in += message_len; - const auth_tag = frag[in..][0..P.mac_length].*; - in += P.mac_length; - const out_buf = vp.peek(); - const cleartext_buf = if (message_len <= out_buf.len) - out_buf - else - &cleartext_stack_buffer; - const cleartext = cleartext_buf[0..ciphertext.len]; - P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch - return error.TlsBadRecordMac; - break :cleartext .{ cleartext, ct }; - }, - else => unreachable, - }, - }; - c.read_seq = try std.math.add(u64, c.read_seq, 1); - switch (inner_ct) { - .alert => { - if (cleartext.len != 2) return error.TlsDecodeError; - const level: tls.AlertLevel = @enumFromInt(cleartext[0]); - const desc: tls.AlertDescription = @enumFromInt(cleartext[1]); - if (desc == .close_notify) { + c.read_seq = std.math.add(u64, c.read_seq, 1) catch return failRead(c, error.TlsSequenceOverflow); + switch (inner_ct) { + .alert => { + if (cleartext.len != 2) return failRead(c, error.TlsDecodeError); + const alert: tls.Alert = .{ + .level = @enumFromInt(cleartext[0]), + .description = @enumFromInt(cleartext[1]), + }; + switch (alert.description) { + .close_notify => { c.received_close_notify = true; - c.partial_ciphertext_end = c.partial_ciphertext_idx; - return vp.total; - } - _ = level; - - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; - }, - .handshake => { - var ct_i: usize = 0; - while (true) { - const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]); - ct_i += 1; - const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big); - ct_i += 3; - const next_handshake_i = ct_i + handshake_len; - if (next_handshake_i > cleartext.len) - return error.TlsBadLength; - const handshake = cleartext[ct_i..next_handshake_i]; - switch (handshake_type) { - .new_session_ticket => { - // This client implementation ignores new session tickets. 
- }, - .key_update => { - switch (c.application_cipher) { - inline else => |*p| { - const pv = &p.tls_1_3; - const P = @TypeOf(p.*); - const server_secret = hkdfExpandLabel(P.Hkdf, pv.server_secret, "traffic upd", "", P.Hash.digest_length); - if (c.ssl_key_log) |*key_log| logSecrets(key_log.file, .{ - .counter = key_log.serverCounter(), - .client_random = &key_log.client_random, - }, .{ - .SERVER_TRAFFIC_SECRET = &server_secret, - }); - pv.server_secret = server_secret; - pv.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); - pv.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); - }, - } - c.read_seq = 0; - - switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) { - .update_requested => { - switch (c.application_cipher) { - inline else => |*p| { - const pv = &p.tls_1_3; - const P = @TypeOf(p.*); - const client_secret = hkdfExpandLabel(P.Hkdf, pv.client_secret, "traffic upd", "", P.Hash.digest_length); - if (c.ssl_key_log) |*key_log| logSecrets(key_log.file, .{ - .counter = key_log.clientCounter(), - .client_random = &key_log.client_random, - }, .{ - .CLIENT_TRAFFIC_SECRET = &client_secret, - }); - pv.client_secret = client_secret; - pv.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); - pv.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); - }, - } - c.write_seq = 0; - }, - .update_not_requested => {}, - _ => return error.TlsIllegalParameter, - } - }, - else => { - return error.TlsUnexpectedMessage; - }, - } - ct_i = next_handshake_i; - if (ct_i >= cleartext.len) break; - } - }, - .application_data => { - // Determine whether the output buffer or a stack - // buffer was used for storing the cleartext. - if (cleartext.ptr == &cleartext_stack_buffer) { - // Stack buffer was used, so we must copy to the output buffer. - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // We have already run out of room in iovecs. Continue - // appending to `partially_read_buffer`. - @memcpy( - c.partially_read_buffer[c.partial_ciphertext_idx..][0..cleartext.len], - cleartext, - ); - c.partial_ciphertext_idx = @intCast(c.partial_ciphertext_idx + cleartext.len); - } else { - const amt = vp.put(cleartext); - if (amt < cleartext.len) { - const rest = cleartext[amt..]; - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = @intCast(rest.len); - @memcpy(c.partially_read_buffer[0..rest.len], rest); + return 0; + }, + .user_canceled => { + // TODO: handle server-side closures + return failRead(c, error.TlsUnexpectedMessage); + }, + else => { + c.alert = alert; + return failRead(c, error.TlsAlert); + }, + } + }, + .handshake => { + var ct_i: usize = 0; + while (true) { + const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]); + ct_i += 1; + const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big); + ct_i += 3; + const next_handshake_i = ct_i + handshake_len; + if (next_handshake_i > cleartext.len) return failRead(c, error.TlsBadLength); + const handshake = cleartext[ct_i..next_handshake_i]; + switch (handshake_type) { + .new_session_ticket => { + // This client implementation ignores new session tickets. 
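// Two framing details behind the arms above: an alert payload is exactly two
// bytes, level then description, which is why any other length yields
// `error.TlsDecodeError`; and handshake messages are framed inside a record
// as a 1-byte type followed by a 3-byte big-endian length, so several
// messages may be coalesced into one record (hence the inner loop):
//
//     const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]);
//     const handshake_len = mem.readInt(u24, cleartext[ct_i + 1 ..][0..3], .big);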
+ }, + .key_update => { + switch (c.application_cipher) { + inline else => |*p| { + const pv = &p.tls_1_3; + const P = @TypeOf(p.*); + const server_secret = hkdfExpandLabel(P.Hkdf, pv.server_secret, "traffic upd", "", P.Hash.digest_length); + if (c.ssl_key_log) |key_log| logSecrets(key_log.writer, .{ + .counter = key_log.serverCounter(), + .client_random = &key_log.client_random, + }, .{ + .SERVER_TRAFFIC_SECRET = &server_secret, + }); + pv.server_secret = server_secret; + pv.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); + pv.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); + }, } - } - } else { - // Output buffer was used directly which means no - // memory copying needs to occur, and we can move - // on to the next ciphertext record. - vp.next(cleartext.len); + c.read_seq = 0; + + switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) { + .update_requested => { + switch (c.application_cipher) { + inline else => |*p| { + const pv = &p.tls_1_3; + const P = @TypeOf(p.*); + const client_secret = hkdfExpandLabel(P.Hkdf, pv.client_secret, "traffic upd", "", P.Hash.digest_length); + if (c.ssl_key_log) |key_log| logSecrets(key_log.writer, .{ + .counter = key_log.clientCounter(), + .client_random = &key_log.client_random, + }, .{ + .CLIENT_TRAFFIC_SECRET = &client_secret, + }); + pv.client_secret = client_secret; + pv.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); + pv.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); + }, + } + c.write_seq = 0; + }, + .update_not_requested => {}, + _ => return failRead(c, error.TlsIllegalParameter), + } + }, + else => return failRead(c, error.TlsUnexpectedMessage), } - }, - else => return error.TlsUnexpectedMessage, - } - in = end; + ct_i = next_handshake_i; + if (ct_i >= cleartext.len) break; + } + return 0; + }, + .application_data => { + if (@intFromEnum(limit) < cleartext.len) return failRead(c, error.OutputBufferUndersize); + try w.writeAll(cleartext); + return cleartext.len; + }, + else => return failRead(c, error.TlsUnexpectedMessage), } } -fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) void { - const locked = if (key_log_file.lock(.exclusive)) |_| true else |_| false; - defer if (locked) key_log_file.unlock(); - key_log_file.seekFromEnd(0) catch {}; - inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| key_log_file.deprecatedWriter().print("{s}" ++ +fn failRead(c: *Client, err: ReadError) error{ReadFailed} { + c.read_err = err; + return error.ReadFailed; +} + +fn logSecrets(w: *Writer, context: anytype, secrets: anytype) void { + inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| w.print("{s}" ++ (if (@hasField(@TypeOf(context), "counter")) "_{d}" else "") ++ " {x} {x}\n", .{field.name} ++ (if (@hasField(@TypeOf(context), "counter")) .{context.counter} else .{}) ++ .{ context.client_random, @@ -1520,62 +1235,6 @@ fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) voi }) catch {}; } -fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize { - const saved_buf = frag[in..]; - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // There is cleartext at the beginning already which we need to preserve. 
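// Error reporting pattern for the new read path: `failRead` stores the
// precise `ReadError` on the Client and returns the generic
// `error.ReadFailed` that the `std.Io.Reader` interface requires. A caller
// can then recover the detail; sketch, assuming `read_err` is the field
// `failRead` populates:
//
//     _ = client.reader.take(1) catch |err| switch (err) {
//         error.ReadFailed => return client.read_err.?,
//         error.EndOfStream => return, // clean close_notify end of stream
//     };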
- c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + saved_buf.len); - @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx..][0..saved_buf.len], saved_buf); - } else { - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = 0; - c.partial_ciphertext_end = @intCast(saved_buf.len); - @memcpy(c.partially_read_buffer[0..saved_buf.len], saved_buf); - } - return out; -} - -/// Note that `first` usually overlaps with `c.partially_read_buffer`. -fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize { - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // There is cleartext at the beginning already which we need to preserve. - c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + first.len + frag1.len); - // TODO: eliminate this call to copyForwards - std.mem.copyForwards(u8, c.partially_read_buffer[c.partial_ciphertext_idx..][0..first.len], first); - @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx + first.len ..][0..frag1.len], frag1); - } else { - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = 0; - c.partial_ciphertext_end = @intCast(first.len + frag1.len); - // TODO: eliminate this call to copyForwards - std.mem.copyForwards(u8, c.partially_read_buffer[0..first.len], first); - @memcpy(c.partially_read_buffer[first.len..][0..frag1.len], frag1); - } - return out; -} - -fn limitedOverlapCopy(frag: []u8, in: usize) void { - const first = frag[in..]; - if (first.len <= in) { - // A single, non-overlapping memcpy suffices. - @memcpy(frag[0..first.len], first); - } else { - // One memcpy call would overlap, so just do this instead. - std.mem.copyForwards(u8, frag, first); - } -} - -fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 { - if (index < s1.len) { - return s1[index]; - } else { - return s2[index - s1.len]; - } -} - -const builtin = @import("builtin"); -const native_endian = builtin.cpu.arch.endian(); - fn big(x: anytype) @TypeOf(x) { return switch (native_endian) { .big => x, @@ -1836,81 +1495,6 @@ const CertificatePublicKey = struct { } }; -/// Abstraction for sending multiple byte buffers to a slice of iovecs. -const VecPut = struct { - iovecs: []const std.posix.iovec, - idx: usize = 0, - off: usize = 0, - total: usize = 0, - - /// Returns the amount actually put which is always equal to bytes.len - /// unless the vectors ran out of space. - fn put(vp: *VecPut, bytes: []const u8) usize { - if (vp.idx >= vp.iovecs.len) return 0; - var bytes_i: usize = 0; - while (true) { - const v = vp.iovecs[vp.idx]; - const dest = v.base[vp.off..v.len]; - const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)]; - @memcpy(dest[0..src.len], src); - bytes_i += src.len; - vp.off += src.len; - if (vp.off >= v.len) { - vp.off = 0; - vp.idx += 1; - if (vp.idx >= vp.iovecs.len) { - vp.total += bytes_i; - return bytes_i; - } - } - if (bytes_i >= bytes.len) { - vp.total += bytes_i; - return bytes_i; - } - } - } - - /// Returns the next buffer that consecutive bytes can go into. - fn peek(vp: VecPut) []u8 { - if (vp.idx >= vp.iovecs.len) return &.{}; - const v = vp.iovecs[vp.idx]; - return v.base[vp.off..v.len]; - } - - // After writing to the result of peek(), one can call next() to - // advance the cursor. 
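// The `big` helper above byte-swaps on little-endian hosts so that
// `std.mem.toBytes` yields network (big-endian) byte order on any target,
// which is how sequence numbers enter the nonce and additional-data
// constructions. For example, on either endianness:
//
//     std.mem.toBytes(big(@as(u64, 3))) == [8]u8{ 0, 0, 0, 0, 0, 0, 0, 3 }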
- fn next(vp: *VecPut, len: usize) void { - vp.total += len; - vp.off += len; - if (vp.off >= vp.iovecs[vp.idx].len) { - vp.off = 0; - vp.idx += 1; - } - } - - fn freeSize(vp: VecPut) usize { - if (vp.idx >= vp.iovecs.len) return 0; - var total: usize = 0; - total += vp.iovecs[vp.idx].len - vp.off; - if (vp.idx + 1 >= vp.iovecs.len) return total; - for (vp.iovecs[vp.idx + 1 ..]) |v| total += v.len; - return total; - } -}; - -/// Limit iovecs to a specific byte size. -fn limitVecs(iovecs: []std.posix.iovec, len: usize) []std.posix.iovec { - var bytes_left: usize = len; - for (iovecs, 0..) |*iovec, vec_i| { - if (bytes_left <= iovec.len) { - iovec.len = bytes_left; - return iovecs[0 .. vec_i + 1]; - } - bytes_left -= iovec.len; - } - return iovecs; -} - /// The priority order here is chosen based on what crypto algorithms Zig has /// available in the standard library as well as what is faster. Following are /// a few data points on the relative performance of these algorithms. @@ -1954,7 +1538,3 @@ else .AES_256_GCM_SHA384, .ECDHE_RSA_WITH_AES_256_GCM_SHA384, }); - -test { - _ = StreamInterface; -} diff --git a/lib/std/http.zig b/lib/std/http.zig index 6075a2fe6d..c64a946a25 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -343,10 +343,9 @@ pub const Reader = struct { /// read from `in`. trailers: []const u8 = &.{}, body_err: ?BodyError = null, - /// Stolen from `in`. - head_buffer: []u8 = &.{}, - - pub const max_chunk_header_len = 22; + /// Determines at which point `error.HttpHeadersOversize` occurs, as well + /// as the minimum buffer capacity of `in`. + max_head_len: usize, pub const RemainingChunkLen = enum(u64) { head = 0, @@ -398,19 +397,11 @@ pub const Reader = struct { ReadFailed, }; - pub fn restituteHeadBuffer(reader: *Reader) void { - reader.in.restitute(reader.head_buffer.len); - reader.head_buffer.len = 0; - } - - /// Buffers the entire head into `head_buffer`, invalidating the previous - /// `head_buffer`, if any. + /// Buffers the entire head. pub fn receiveHead(reader: *Reader) HeadError!void { reader.trailers = &.{}; const in = reader.in; - in.restitute(reader.head_buffer.len); - reader.head_buffer.len = 0; - in.rebase(); + try in.rebase(reader.max_head_len); var hp: HeadParser = .{}; var head_end: usize = 0; while (true) { From 98ee9360555730328621aa8b9b9170a4e2b0df7b Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 1 Aug 2025 16:38:41 -0700 Subject: [PATCH 54/70] http fixes --- lib/std/Io/Writer.zig | 94 +++++++++++--- lib/std/Uri.zig | 236 +++++++++++++--------------------- lib/std/crypto/tls/Client.zig | 4 +- lib/std/http.zig | 82 +++++++++--- lib/std/http/Client.zig | 29 +++-- lib/std/http/Server.zig | 40 +++--- lib/std/http/test.zig | 56 ++++---- 7 files changed, 297 insertions(+), 244 deletions(-) diff --git a/lib/std/Io/Writer.zig b/lib/std/Io/Writer.zig index a84077f8f3..797a69914c 100644 --- a/lib/std/Io/Writer.zig +++ b/lib/std/Io/Writer.zig @@ -191,29 +191,87 @@ pub fn writeSplatHeader( data: []const []const u8, splat: usize, ) Error!usize { - const new_end = w.end + header.len; - if (new_end <= w.buffer.len) { - @memcpy(w.buffer[w.end..][0..header.len], header); - w.end = new_end; - return header.len + try writeSplat(w, data, splat); + return writeSplatHeaderLimit(w, header, data, splat, .unlimited); +} + +/// Equivalent to `writeSplatHeader` but writes at most `limit` bytes. 
+pub fn writeSplatHeaderLimit( + w: *Writer, + header: []const u8, + data: []const []const u8, + splat: usize, + limit: Limit, +) Error!usize { + var remaining = @intFromEnum(limit); + { + const copy_len = @min(header.len, w.buffer.len - w.end, remaining); + if (header.len - copy_len != 0) return writeSplatHeaderLimitFinish(w, header, data, splat, remaining); + @memcpy(w.buffer[w.end..][0..copy_len], header[0..copy_len]); + w.end += copy_len; + remaining -= copy_len; } - var vecs: [8][]const u8 = undefined; // Arbitrarily chosen size. - var i: usize = 1; - vecs[0] = header; - for (data[0 .. data.len - 1]) |buf| { - if (buf.len == 0) continue; - vecs[i] = buf; - i += 1; - if (vecs.len - i == 0) break; + for (data[0 .. data.len - 1], 0..) |buf, i| { + const copy_len = @min(buf.len, w.buffer.len - w.end, remaining); + if (buf.len - copy_len != 0) return @intFromEnum(limit) - remaining + + try writeSplatHeaderLimitFinish(w, &.{}, data[i..], splat, remaining); + @memcpy(w.buffer[w.end..][0..copy_len], buf[0..copy_len]); + w.end += copy_len; + remaining -= copy_len; } const pattern = data[data.len - 1]; - const new_splat = s: { - if (pattern.len == 0 or vecs.len - i == 0) break :s 1; + const splat_n = pattern.len * splat; + if (splat_n > @min(w.buffer.len - w.end, remaining)) { + const buffered_n = @intFromEnum(limit) - remaining; + const written = try writeSplatHeaderLimitFinish(w, &.{}, data[data.len - 1 ..][0..1], splat, remaining); + return buffered_n + written; + } + + for (0..splat) |_| { + @memcpy(w.buffer[w.end..][0..pattern.len], pattern); + w.end += pattern.len; + } + + remaining -= splat_n; + return @intFromEnum(limit) - remaining; +} + +fn writeSplatHeaderLimitFinish( + w: *Writer, + header: []const u8, + data: []const []const u8, + splat: usize, + limit: usize, +) Error!usize { + var remaining = limit; + var vecs: [8][]const u8 = undefined; + var i: usize = 0; + v: { + if (header.len != 0) { + const copy_len = @min(header.len, remaining); + vecs[i] = header[0..copy_len]; + i += 1; + remaining -= copy_len; + if (remaining == 0) break :v; + } + for (data[0 .. data.len - 1]) |buf| if (buf.len != 0) { + const copy_len = @min(header.len, remaining); + vecs[i] = buf; + i += 1; + remaining -= copy_len; + if (remaining == 0) break :v; + if (vecs.len - i == 0) break :v; + }; + const pattern = data[data.len - 1]; + if (splat == 1) { + vecs[i] = pattern[0..@min(remaining, pattern.len)]; + i += 1; + break :v; + } vecs[i] = pattern; i += 1; - break :s splat; - }; - return w.vtable.drain(w, vecs[0..i], new_splat); + return w.vtable.drain(w, (&vecs)[0..i], @min(remaining / pattern.len, splat)); + } + return w.vtable.drain(w, (&vecs)[0..i], 1); } test "writeSplatHeader splatting avoids buffer aliasing temptation" { diff --git a/lib/std/Uri.zig b/lib/std/Uri.zig index 19af1512c2..5390bee5b5 100644 --- a/lib/std/Uri.zig +++ b/lib/std/Uri.zig @@ -4,6 +4,8 @@ const std = @import("std.zig"); const testing = std.testing; const Uri = @This(); +const Allocator = std.mem.Allocator; +const Writer = std.Io.Writer; scheme: []const u8, user: ?Component = null, @@ -14,6 +16,32 @@ path: Component = Component.empty, query: ?Component = null, fragment: ?Component = null, +pub const host_name_max = 255; + +/// Returned value may point into `buffer` or be the original string. +/// +/// Suggested buffer length: `host_name_max`. 
+/// +/// See also: +/// * `getHostAlloc` +pub fn getHost(uri: Uri, buffer: []u8) error{ UriMissingHost, UriHostTooLong }![]const u8 { + const component = uri.host orelse return error.UriMissingHost; + return component.toRaw(buffer) catch |err| switch (err) { + error.NoSpaceLeft => return error.UriHostTooLong, + }; +} + +/// Returned value may point into `buffer` or be the original string. +/// +/// See also: +/// * `getHost` +pub fn getHostAlloc(uri: Uri, arena: Allocator) error{ UriMissingHost, UriHostTooLong, OutOfMemory }![]const u8 { + const component = uri.host orelse return error.UriMissingHost; + const result = try component.toRawMaybeAlloc(arena); + if (result.len > host_name_max) return error.UriHostTooLong; + return result; +} + pub const Component = union(enum) { /// Invalid characters in this component must be percent encoded /// before being printed as part of a URI. @@ -30,11 +58,19 @@ pub const Component = union(enum) { }; } + /// Returned value may point into `buffer` or be the original string. + pub fn toRaw(component: Component, buffer: []u8) error{NoSpaceLeft}![]const u8 { + return switch (component) { + .raw => |raw| raw, + .percent_encoded => |percent_encoded| if (std.mem.indexOfScalar(u8, percent_encoded, '%')) |_| + try std.fmt.bufPrint(buffer, "{f}", .{std.fmt.alt(component, .formatRaw)}) + else + percent_encoded, + }; + } + /// Allocates the result with `arena` only if needed, so the result should not be freed. - pub fn toRawMaybeAlloc( - component: Component, - arena: std.mem.Allocator, - ) std.mem.Allocator.Error![]const u8 { + pub fn toRawMaybeAlloc(component: Component, arena: Allocator) Allocator.Error![]const u8 { return switch (component) { .raw => |raw| raw, .percent_encoded => |percent_encoded| if (std.mem.indexOfScalar(u8, percent_encoded, '%')) |_| @@ -44,7 +80,7 @@ pub const Component = union(enum) { }; } - pub fn formatRaw(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatRaw(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try w.writeAll(raw), .percent_encoded => |percent_encoded| { @@ -67,56 +103,56 @@ pub const Component = union(enum) { } } - pub fn formatEscaped(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatEscaped(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isUnreserved), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn formatUser(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatUser(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isUserChar), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn formatPassword(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatPassword(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isPasswordChar), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn formatHost(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatHost(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isHostChar), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn formatPath(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn 
formatPath(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isPathChar), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn formatQuery(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatQuery(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isQueryChar), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn formatFragment(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatFragment(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isFragmentChar), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn percentEncode(w: *std.io.Writer, raw: []const u8, comptime isValidChar: fn (u8) bool) std.io.Writer.Error!void { + pub fn percentEncode(w: *Writer, raw: []const u8, comptime isValidChar: fn (u8) bool) Writer.Error!void { var start: usize = 0; for (raw, 0..) |char, index| { if (isValidChar(char)) continue; @@ -165,17 +201,15 @@ pub const ParseError = error{ UnexpectedCharacter, InvalidFormat, InvalidPort }; /// The return value will contain strings pointing into the original `text`. /// Each component that is provided, will be non-`null`. pub fn parseAfterScheme(scheme: []const u8, text: []const u8) ParseError!Uri { - var reader = SliceReader{ .slice = text }; - var uri: Uri = .{ .scheme = scheme, .path = undefined }; + var i: usize = 0; - if (reader.peekPrefix("//")) a: { // authority part - std.debug.assert(reader.get().? == '/'); - std.debug.assert(reader.get().? == '/'); - - const authority = reader.readUntil(isAuthoritySeparator); + if (std.mem.startsWith(u8, text, "//")) a: { + i = std.mem.indexOfAnyPos(u8, text, 2, &authority_sep) orelse text.len; + const authority = text[2..i]; if (authority.len == 0) { - if (reader.peekPrefix("/")) break :a else return error.InvalidFormat; + if (!std.mem.startsWith(u8, text[2..], "/")) return error.InvalidFormat; + break :a; } var start_of_host: usize = 0; @@ -225,26 +259,28 @@ pub fn parseAfterScheme(scheme: []const u8, text: []const u8) ParseError!Uri { uri.host = .{ .percent_encoded = authority[start_of_host..end_of_host] }; } - uri.path = .{ .percent_encoded = reader.readUntil(isPathSeparator) }; + const path_start = i; + i = std.mem.indexOfAnyPos(u8, text, path_start, &path_sep) orelse text.len; + uri.path = .{ .percent_encoded = text[path_start..i] }; - if ((reader.peek() orelse 0) == '?') { // query part - std.debug.assert(reader.get().? == '?'); - uri.query = .{ .percent_encoded = reader.readUntil(isQuerySeparator) }; + if (std.mem.startsWith(u8, text[i..], "?")) { + const query_start = i + 1; + i = std.mem.indexOfScalarPos(u8, text, query_start, '#') orelse text.len; + uri.query = .{ .percent_encoded = text[query_start..i] }; } - if ((reader.peek() orelse 0) == '#') { // fragment part - std.debug.assert(reader.get().? == '#'); - uri.fragment = .{ .percent_encoded = reader.readUntilEof() }; + if (std.mem.startsWith(u8, text[i..], "#")) { + uri.fragment = .{ .percent_encoded = text[i + 1 ..] 
}; } return uri; } -pub fn format(uri: *const Uri, writer: *std.io.Writer) std.io.Writer.Error!void { +pub fn format(uri: *const Uri, writer: *Writer) Writer.Error!void { return writeToStream(uri, writer, .all); } -pub fn writeToStream(uri: *const Uri, writer: *std.io.Writer, flags: Format.Flags) std.io.Writer.Error!void { +pub fn writeToStream(uri: *const Uri, writer: *Writer, flags: Format.Flags) Writer.Error!void { if (flags.scheme) { try writer.print("{s}:", .{uri.scheme}); if (flags.authority and uri.host != null) { @@ -318,7 +354,7 @@ pub const Format = struct { }; }; - pub fn default(f: Format, writer: *std.io.Writer) std.io.Writer.Error!void { + pub fn default(f: Format, writer: *Writer) Writer.Error!void { return writeToStream(f.uri, writer, f.flags); } }; @@ -327,41 +363,33 @@ pub fn fmt(uri: *const Uri, flags: Format.Flags) std.fmt.Formatter(Format, Forma return .{ .data = .{ .uri = uri, .flags = flags } }; } -/// Parses the URI or returns an error. -/// The return value will contain strings pointing into the -/// original `text`. Each component that is provided, will be non-`null`. +/// The return value will contain strings pointing into the original `text`. +/// Each component that is provided will be non-`null`. pub fn parse(text: []const u8) ParseError!Uri { - var reader: SliceReader = .{ .slice = text }; - const scheme = reader.readWhile(isSchemeChar); - - // after the scheme, a ':' must appear - if (reader.get()) |c| { - if (c != ':') - return error.UnexpectedCharacter; - } else { - return error.InvalidFormat; - } - - return parseAfterScheme(scheme, reader.readUntilEof()); + const end = for (text, 0..) |byte, i| { + if (!isSchemeChar(byte)) break i; + } else text.len; + // After the scheme, a ':' must appear. + if (end >= text.len) return error.InvalidFormat; + if (text[end] != ':') return error.UnexpectedCharacter; + return parseAfterScheme(text[0..end], text[end + 1 ..]); } pub const ResolveInPlaceError = ParseError || error{NoSpaceLeft}; /// Resolves a URI against a base URI, conforming to RFC 3986, Section 5. -/// Copies `new` to the beginning of `aux_buf.*`, allowing the slices to overlap, -/// then parses `new` as a URI, and then resolves the path in place. +/// +/// Assumes new location is already copied to the beginning of `aux_buf.*`. +/// Parses that new location as a URI, and then resolves the path in place. +/// /// If a merge needs to take place, the newly constructed path will be stored -/// in `aux_buf.*` just after the copied `new`, and `aux_buf.*` will be modified -/// to only contain the remaining unused space. -pub fn resolve_inplace(base: Uri, new: []const u8, aux_buf: *[]u8) ResolveInPlaceError!Uri { - std.mem.copyForwards(u8, aux_buf.*, new); - // At this point, new is an invalid pointer. - const new_mut = aux_buf.*[0..new.len]; - aux_buf.* = aux_buf.*[new.len..]; - - const new_parsed = parse(new_mut) catch |err| - (parseAfterScheme("", new_mut) catch return err); - // As you can see above, `new_mut` is not a const pointer. +/// in `aux_buf.*` just after the copied location, and `aux_buf.*` will be +/// modified to only contain the remaining unused space. +pub fn resolveInPlace(base: Uri, new_len: usize, aux_buf: *[]u8) ResolveInPlaceError!Uri { + const new = aux_buf.*[0..new_len]; + const new_parsed = parse(new) catch |err| (parseAfterScheme("", new) catch return err); + aux_buf.* = aux_buf.*[new_len..]; + // As you can see above, `new` is not a const pointer. 
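+    // (Hence the `@constCast` below is sound: `new_parsed.path` points into
+    // `new`, i.e. into the mutable memory of `aux_buf`.)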
const new_path: []u8 = @constCast(new_parsed.path.percent_encoded); if (new_parsed.scheme.len > 0) return .{ @@ -461,7 +489,7 @@ test remove_dot_segments { /// 5.2.3. Merge Paths fn merge_paths(base: Component, new: []u8, aux_buf: *[]u8) error{NoSpaceLeft}!Component { - var aux: std.io.Writer = .fixed(aux_buf.*); + var aux: Writer = .fixed(aux_buf.*); if (!base.isEmpty()) { base.formatPath(&aux) catch return error.NoSpaceLeft; aux.end = std.mem.lastIndexOfScalar(u8, aux.buffered(), '/') orelse return remove_dot_segments(new); @@ -472,59 +500,6 @@ fn merge_paths(base: Component, new: []u8, aux_buf: *[]u8) error{NoSpaceLeft}!Co return merged_path; } -const SliceReader = struct { - const Self = @This(); - - slice: []const u8, - offset: usize = 0, - - fn get(self: *Self) ?u8 { - if (self.offset >= self.slice.len) - return null; - const c = self.slice[self.offset]; - self.offset += 1; - return c; - } - - fn peek(self: Self) ?u8 { - if (self.offset >= self.slice.len) - return null; - return self.slice[self.offset]; - } - - fn readWhile(self: *Self, comptime predicate: fn (u8) bool) []const u8 { - const start = self.offset; - var end = start; - while (end < self.slice.len and predicate(self.slice[end])) { - end += 1; - } - self.offset = end; - return self.slice[start..end]; - } - - fn readUntil(self: *Self, comptime predicate: fn (u8) bool) []const u8 { - const start = self.offset; - var end = start; - while (end < self.slice.len and !predicate(self.slice[end])) { - end += 1; - } - self.offset = end; - return self.slice[start..end]; - } - - fn readUntilEof(self: *Self) []const u8 { - const start = self.offset; - self.offset = self.slice.len; - return self.slice[start..]; - } - - fn peekPrefix(self: Self, prefix: []const u8) bool { - if (self.offset + prefix.len > self.slice.len) - return false; - return std.mem.eql(u8, self.slice[self.offset..][0..prefix.len], prefix); - } -}; - /// scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ) fn isSchemeChar(c: u8) bool { return switch (c) { @@ -533,19 +508,6 @@ fn isSchemeChar(c: u8) bool { }; } -/// reserved = gen-delims / sub-delims -fn isReserved(c: u8) bool { - return isGenLimit(c) or isSubLimit(c); -} - -/// gen-delims = ":" / "/" / "?" / "#" / "[" / "]" / "@" -fn isGenLimit(c: u8) bool { - return switch (c) { - ':', ',', '?', '#', '[', ']', '@' => true, - else => false, - }; -} - /// sub-delims = "!" / "$" / "&" / "'" / "(" / ")" /// / "*" / "+" / "," / ";" / "=" fn isSubLimit(c: u8) bool { @@ -585,26 +547,8 @@ fn isQueryChar(c: u8) bool { const isFragmentChar = isQueryChar; -fn isAuthoritySeparator(c: u8) bool { - return switch (c) { - '/', '?', '#' => true, - else => false, - }; -} - -fn isPathSeparator(c: u8) bool { - return switch (c) { - '?', '#' => true, - else => false, - }; -} - -fn isQuerySeparator(c: u8) bool { - return switch (c) { - '#' => true, - else => false, - }; -} +const authority_sep: [3]u8 = .{ '/', '?', '#' }; +const path_sep: [2]u8 = .{ '?', '#' }; test "basic" { const parsed = try parse("https://ziglang.org/download"); diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 082fc9da70..5a9e333b67 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -328,7 +328,9 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client var cleartext_bufs: [2][tls.max_ciphertext_inner_record_len]u8 = undefined; fragment: while (true) { // Ensure the input buffer pointer is stable in this scope. 
- input.rebaseCapacity(tls.max_ciphertext_record_len); + input.rebase(tls.max_ciphertext_record_len) catch |err| switch (err) { + error.EndOfStream => {}, // We have assurance the remainder of stream can be buffered. + }; const record_header = input.peek(tls.record_header_len) catch |err| switch (err) { error.EndOfStream => return error.TlsConnectionTruncated, error.ReadFailed => return error.ReadFailed, diff --git a/lib/std/http.zig b/lib/std/http.zig index c64a946a25..24bc016c30 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -343,9 +343,6 @@ pub const Reader = struct { /// read from `in`. trailers: []const u8 = &.{}, body_err: ?BodyError = null, - /// Determines at which point `error.HttpHeadersOversize` occurs, as well - /// as the minimum buffer capacity of `in`. - max_head_len: usize, pub const RemainingChunkLen = enum(u64) { head = 0, @@ -397,27 +394,34 @@ pub const Reader = struct { ReadFailed, }; - /// Buffers the entire head. - pub fn receiveHead(reader: *Reader) HeadError!void { + /// Buffers the entire head inside `in`. + /// + /// The resulting memory is invalidated by any subsequent consumption of + /// the input stream. + pub fn receiveHead(reader: *Reader) HeadError![]const u8 { reader.trailers = &.{}; const in = reader.in; - try in.rebase(reader.max_head_len); var hp: HeadParser = .{}; - var head_end: usize = 0; + var head_len: usize = 0; while (true) { - if (head_end >= in.buffer.len) return error.HttpHeadersOversize; - in.fillMore() catch |err| switch (err) { - error.EndOfStream => switch (head_end) { - 0 => return error.HttpConnectionClosing, - else => return error.HttpRequestTruncated, - }, - error.ReadFailed => return error.ReadFailed, - }; - head_end += hp.feed(in.buffered()[head_end..]); + if (in.buffer.len - head_len == 0) return error.HttpHeadersOversize; + const remaining = in.buffered()[head_len..]; + if (remaining.len == 0) { + in.fillMore() catch |err| switch (err) { + error.EndOfStream => switch (head_len) { + 0 => return error.HttpConnectionClosing, + else => return error.HttpRequestTruncated, + }, + error.ReadFailed => return error.ReadFailed, + }; + continue; + } + head_len += hp.feed(remaining); if (hp.state == .finished) { - reader.head_buffer = in.steal(head_end); reader.state = .received_head; - return; + const head_buffer = in.buffered()[0..head_len]; + in.toss(head_len); + return head_buffer; } } } @@ -786,7 +790,7 @@ pub const BodyWriter = struct { }; pub fn isEliding(w: *const BodyWriter) bool { - return w.writer.vtable.drain == Writer.discardingDrain; + return w.writer.vtable.drain == elidingDrain; } /// Sends all buffered data across `BodyWriter.http_protocol_output`. @@ -930,6 +934,46 @@ pub const BodyWriter = struct { return w.consume(n); } + pub fn elidingDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { + const bw: *BodyWriter = @fieldParentPtr("writer", w); + const slice = data[0 .. 
data.len - 1]; + const pattern = data[slice.len]; + var written: usize = pattern.len * splat; + for (slice) |bytes| written += bytes.len; + switch (bw.state) { + .content_length => |*len| len.* -= written + w.end, + else => {}, + } + w.end = 0; + return written; + } + + pub fn elidingSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize { + const bw: *BodyWriter = @fieldParentPtr("writer", w); + if (File.Handle == void) return error.Unimplemented; + if (builtin.zig_backend == .stage2_aarch64) return error.Unimplemented; + switch (bw.state) { + .content_length => |*len| len.* -= w.end, + else => {}, + } + w.end = 0; + if (limit == .nothing) return 0; + if (file_reader.getSize()) |size| { + const n = limit.minInt64(size - file_reader.pos); + if (n == 0) return error.EndOfStream; + file_reader.seekBy(@intCast(n)) catch return error.Unimplemented; + switch (bw.state) { + .content_length => |*len| len.* -= n, + else => {}, + } + return n; + } else |_| { + // Error is observable on `file_reader` instance, and it is better to + // treat the file as a pipe. + return error.Unimplemented; + } + } + /// Returns `null` if size cannot be computed without making any syscalls. pub fn noneSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize { const bw: *BodyWriter = @fieldParentPtr("writer", w); diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 61a9eeb5c3..6660761930 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -821,7 +821,6 @@ pub const Request = struct { /// Returns the request's `Connection` back to the pool of the `Client`. pub fn deinit(r: *Request) void { - r.reader.restituteHeadBuffer(); if (r.connection) |connection| { connection.closing = connection.closing or switch (r.reader.state) { .ready => false, @@ -908,13 +907,13 @@ pub const Request = struct { const connection = r.connection.?; const w = connection.writer(); - try r.method.write(w); + try r.method.format(w); try w.writeByte(' '); if (r.method == .CONNECT) { - try uri.writeToStream(.{ .authority = true }, w); + try uri.writeToStream(w, .{ .authority = true }); } else { - try uri.writeToStream(.{ + try uri.writeToStream(w, .{ .scheme = connection.proxied, .authentication = connection.proxied, .authority = connection.proxied, @@ -928,7 +927,7 @@ pub const Request = struct { if (try emitOverridableHeader("host: ", r.headers.host, w)) { try w.writeAll("host: "); - try uri.writeToStream(.{ .authority = true }, w); + try uri.writeToStream(w, .{ .authority = true }); try w.writeAll("\r\n"); } @@ -1046,10 +1045,10 @@ pub const Request = struct { pub fn receiveHead(r: *Request, redirect_buffer: []u8) ReceiveHeadError!Response { var aux_buf = redirect_buffer; while (true) { - try r.reader.receiveHead(); + const head_buffer = try r.reader.receiveHead(); const response: Response = .{ .request = r, - .head = Response.Head.parse(r.reader.head_buffer) catch return error.HttpHeadersInvalid, + .head = Response.Head.parse(head_buffer) catch return error.HttpHeadersInvalid, }; const head = &response.head; @@ -1121,7 +1120,6 @@ pub const Request = struct { _ = reader.discardRemaining() catch |err| switch (err) { error.ReadFailed => return r.reader.body_err.?, }; - r.reader.restituteHeadBuffer(); } const new_uri = r.uri.resolveInPlace(location.len, aux_buf) catch |err| switch (err) { error.UnexpectedCharacter => return error.HttpRedirectLocationInvalid, @@ -1302,12 +1300,13 @@ pub const basic_authorization = struct { } pub fn write(uri: 
Uri, out: *Writer) Writer.Error!void {
-        var buf: [max_user_len + ":".len + max_password_len]u8 = undefined;
+        var buf: [max_user_len + 1 + max_password_len]u8 = undefined;
         var w: Writer = .fixed(&buf);
-        w.print("{fuser}:{fpassword}", .{
-            uri.user orelse Uri.Component.empty,
-            uri.password orelse Uri.Component.empty,
-        }) catch unreachable;
+        const user: Uri.Component = uri.user orelse .empty;
+        const password: Uri.Component = uri.password orelse .empty;
+        user.formatUser(&w) catch unreachable;
+        w.writeByte(':') catch unreachable;
+        password.formatPassword(&w) catch unreachable;
         try out.print("Basic {b64}", .{w.buffered()});
     }
 };
@@ -1697,6 +1696,7 @@ pub const FetchError = Uri.ParseError || RequestError || Request.ReceiveHeadErro
     StreamTooLong,
     /// TODO provide optional diagnostics when this occurs or break into more error codes
     WriteFailed,
+    UnsupportedCompressionMethod,
 };
 
 /// Perform a one-shot HTTP request with the provided options.
@@ -1748,7 +1748,8 @@ pub fn fetch(client: *Client, options: FetchOptions) FetchError!FetchResult {
     const decompress_buffer: []u8 = switch (response.head.content_encoding) {
         .identity => &.{},
         .zstd => options.decompress_buffer orelse try client.allocator.alloc(u8, std.compress.zstd.default_window_len),
-        else => options.decompress_buffer orelse try client.allocator.alloc(u8, 8 * 1024),
+        .deflate, .gzip => options.decompress_buffer orelse try client.allocator.alloc(u8, std.compress.flate.max_window_len),
+        .compress => return error.UnsupportedCompressionMethod,
     };
     defer if (options.decompress_buffer == null) client.allocator.free(decompress_buffer);
 
diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig
index 004741d1ae..aa6c72a5b3 100644
--- a/lib/std/http/Server.zig
+++ b/lib/std/http/Server.zig
@@ -6,7 +6,7 @@ const mem = std.mem;
 const Uri = std.Uri;
 const assert = std.debug.assert;
 const testing = std.testing;
-const Writer = std.io.Writer;
+const Writer = std.Io.Writer;
 
 const Server = @This();
 
@@ -21,7 +21,7 @@ reader: http.Reader,
 /// header, otherwise `receiveHead` returns `error.HttpHeadersOversize`.
 ///
 /// The returned `Server` is ready for `receiveHead` to be called.
-pub fn init(in: *std.io.Reader, out: *Writer) Server {
+pub fn init(in: *std.Io.Reader, out: *Writer) Server {
     return .{
         .reader = .{
             .in = in,
@@ -33,25 +33,22 @@ pub fn init(in: *std.io.Reader, out: *Writer) Server {
     };
 }
 
-pub fn deinit(s: *Server) void {
-    s.reader.restituteHeadBuffer();
-}
-
 pub const ReceiveHeadError = http.Reader.HeadError || error{
     /// Client sent headers that did not conform to the HTTP protocol.
     ///
-    /// To find out more detailed diagnostics, `http.Reader.head_buffer` can be
+    /// To find out more detailed diagnostics, `Request.head_buffer` can be
     /// passed directly to `Request.Head.parse`.
     HttpHeadersInvalid,
 };
 
 pub fn receiveHead(s: *Server) ReceiveHeadError!Request {
-    try s.reader.receiveHead();
+    const head_buffer = try s.reader.receiveHead();
     return .{
         .server = s,
+        .head_buffer = head_buffer,
         // No need to track the returned error here since users can repeat the
         // parse with the header buffer to get detailed diagnostics.
-        .head = Request.Head.parse(s.reader.head_buffer) catch return error.HttpHeadersInvalid,
+        .head = Request.Head.parse(head_buffer) catch return error.HttpHeadersInvalid,
     };
 }
 
@@ -60,6 +57,7 @@ pub const Request = struct {
     /// Pointers in this struct are invalidated with the next call to
     /// `receiveHead`. 
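+    /// In particular, `head_buffer` points directly into the input stream's
+    /// buffer and is invalidated when more input is consumed.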
head: Head, + head_buffer: []const u8, respond_err: ?RespondError = null, pub const RespondError = error{ @@ -229,7 +227,7 @@ pub const Request = struct { pub fn iterateHeaders(r: *Request) http.HeaderIterator { assert(r.server.reader.state == .received_head); - return http.HeaderIterator.init(r.server.reader.head_buffer); + return http.HeaderIterator.init(r.head_buffer); } test iterateHeaders { @@ -244,7 +242,6 @@ pub const Request = struct { .reader = .{ .in = undefined, .state = .received_head, - .head_buffer = @constCast(request_bytes), .interface = undefined, }, .out = undefined, @@ -253,6 +250,7 @@ pub const Request = struct { var request: Request = .{ .server = &server, .head = undefined, + .head_buffer = @constCast(request_bytes), }; var it = request.iterateHeaders(); @@ -435,10 +433,8 @@ pub const Request = struct { for (o.extra_headers) |header| { assert(header.name.len != 0); - try out.writeAll(header.name); - try out.writeAll(": "); - try out.writeAll(header.value); - try out.writeAll("\r\n"); + var bufs: [4][]const u8 = .{ header.name, ": ", header.value, "\r\n" }; + try out.writeVecAll(&bufs); } try out.writeAll("\r\n"); @@ -453,7 +449,13 @@ pub const Request = struct { return if (elide_body) .{ .http_protocol_output = request.server.out, .state = state, - .writer = .discarding(buffer), + .writer = .{ + .buffer = buffer, + .vtable = &.{ + .drain = http.BodyWriter.elidingDrain, + .sendFile = http.BodyWriter.elidingSendFile, + }, + }, } else .{ .http_protocol_output = request.server.out, .state = state, @@ -564,7 +566,7 @@ pub const Request = struct { /// /// See `readerExpectNone` for an infallible alternative that cannot write /// to the server output stream. - pub fn readerExpectContinue(request: *Request, buffer: []u8) ExpectContinueError!*std.io.Reader { + pub fn readerExpectContinue(request: *Request, buffer: []u8) ExpectContinueError!*std.Io.Reader { const flush = request.head.expect != null; try writeExpectContinue(request); if (flush) try request.server.out.flush(); @@ -576,7 +578,7 @@ pub const Request = struct { /// this function. /// /// Asserts that this function is only called once. 
- pub fn readerExpectNone(request: *Request, buffer: []u8) *std.io.Reader { + pub fn readerExpectNone(request: *Request, buffer: []u8) *std.Io.Reader { assert(request.server.reader.state == .received_head); assert(request.head.expect == null); if (!request.head.method.requestHasBody()) return .ending; @@ -640,7 +642,7 @@ pub const Request = struct { /// See https://tools.ietf.org/html/rfc6455 pub const WebSocket = struct { key: []const u8, - input: *std.io.Reader, + input: *std.Io.Reader, output: *Writer, pub const Header0 = packed struct(u8) { diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index 4c3466d5c9..df80ca6339 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -65,7 +65,7 @@ test "trailers" { try req.sendBodiless(); var response = try req.receiveHead(&.{}); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -183,7 +183,11 @@ test "echo content server" { if (request.head.expect) |expect_header_value| { if (mem.eql(u8, expect_header_value, "garbage")) { try expectError(error.HttpExpectationFailed, request.readerExpectContinue(&.{})); - try request.respond("", .{ .keep_alive = false }); + request.head.expect = null; + try request.respond("", .{ + .keep_alive = false, + .status = .expectation_failed, + }); continue; } } @@ -204,7 +208,7 @@ test "echo content server" { // request.head.target, //}); - const body = try (try request.readerExpectContinue(&.{})).allocRemaining(std.testing.allocator, .limited(8192)); + const body = try (try request.readerExpectContinue(&.{})).allocRemaining(std.testing.allocator, .unlimited); defer std.testing.allocator.free(body); try expect(mem.startsWith(u8, request.head.target, "/echo-content")); @@ -273,7 +277,6 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { for (0..500) |i| { try w.print("{d}, ah ha ha!\n", .{i}); } - try expectEqual(7390, w.count); try w.flush(); try response.end(); try expectEqual(.closing, server.reader.state); @@ -291,7 +294,7 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { var tiny_buffer: [1]u8 = undefined; // allows allocRemaining to detect limit exceeded var stream_reader = stream.reader(&tiny_buffer); - const response = try stream_reader.interface().allocRemaining(gpa, .limited(8192)); + const response = try stream_reader.interface().allocRemaining(gpa, .unlimited); defer gpa.free(response); var expected_response = std.ArrayList(u8).init(gpa); @@ -362,7 +365,7 @@ test "receiving arbitrary http headers from the client" { var tiny_buffer: [1]u8 = undefined; // allows allocRemaining to detect limit exceeded var stream_reader = stream.reader(&tiny_buffer); - const response = try stream_reader.interface().allocRemaining(gpa, .limited(8192)); + const response = try stream_reader.interface().allocRemaining(gpa, .unlimited); defer gpa.free(response); var expected_response = std.ArrayList(u8).init(gpa); @@ -408,12 +411,10 @@ test "general client/server API coverage" { fn handleRequest(request: *http.Server.Request, listen_port: u16) !void { const log = std.log.scoped(.server); - log.info("{f} {s} {s}", .{ - request.head.method, @tagName(request.head.version), request.head.target, - }); + log.info("{f} {t} {s}", .{ request.head.method, request.head.version, request.head.target }); const gpa = std.testing.allocator; - const body = try (try 
request.readerExpectContinue(&.{})).allocRemaining(gpa, .limited(8192)); + const body = try (try request.readerExpectContinue(&.{})).allocRemaining(gpa, .unlimited); defer gpa.free(body); if (mem.startsWith(u8, request.head.target, "/get")) { @@ -447,7 +448,8 @@ test "general client/server API coverage" { try w.writeAll("Hello, World!\n"); } - try w.writeAll("Hello, World!\n" ** 1024); + var vec: [1][]const u8 = .{"Hello, World!\n"}; + try w.writeSplatAll(&vec, 1024); i = 0; while (i < 5) : (i += 1) { @@ -556,7 +558,7 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -579,7 +581,7 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192 * 1024)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqual(@as(usize, 14 * 1024 + 14 * 10), body.len); @@ -601,7 +603,7 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("", body); @@ -625,7 +627,7 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -648,7 +650,7 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("", body); @@ -674,7 +676,7 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -703,7 +705,7 @@ test "general client/server API coverage" { try std.testing.expectEqual(.ok, response.head.status); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("", body); @@ -740,7 +742,7 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -762,7 +764,7 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try 
response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -784,7 +786,7 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -825,7 +827,7 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Encoded redirect successful!\n", body); @@ -915,7 +917,7 @@ test "Server streams both reading and writing" { try body_writer.writer.writeAll("fish"); try body_writer.end(); - const body = try response.reader(&.{}).allocRemaining(std.testing.allocator, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(std.testing.allocator, .unlimited); defer std.testing.allocator.free(body); try expectEqualStrings("ONE FISH", body); @@ -947,7 +949,7 @@ fn echoTests(client: *http.Client, port: u16) !void { var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -980,7 +982,7 @@ fn echoTests(client: *http.Client, port: u16) !void { var response = try req.receiveHead(&redirect_buffer); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -1034,7 +1036,7 @@ fn echoTests(client: *http.Client, port: u16) !void { var response = try req.receiveHead(&redirect_buffer); try expectEqual(.ok, response.head.status); - const body = try response.reader(&.{}).allocRemaining(gpa, .limited(8192)); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -1175,7 +1177,7 @@ test "redirect to different connection" { var response = try req.receiveHead(&redirect_buffer); var reader = response.reader(&.{}); - const body = try reader.allocRemaining(gpa, .limited(8192)); + const body = try reader.allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("good job, you pass", body); From fdf0e4612e4441d2f0386839078d8a7adf738a3e Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 1 Aug 2025 21:30:27 -0700 Subject: [PATCH 55/70] update build system to new http.Server API --- lib/std/Build/Fuzz.zig | 39 +++++++-------- lib/std/Build/WebServer.zig | 95 ++++++++++++++++--------------------- lib/std/http/Server.zig | 59 ++++++++++++----------- 3 files changed, 90 insertions(+), 103 deletions(-) diff --git a/lib/std/Build/Fuzz.zig b/lib/std/Build/Fuzz.zig index a25b501755..bc10f7907a 100644 --- a/lib/std/Build/Fuzz.zig +++ b/lib/std/Build/Fuzz.zig @@ -234,7 +234,7 @@ pub const Previous = struct { }; pub fn sendUpdate( fuzz: *Fuzz, - socket: *std.http.WebSocket, + socket: *std.http.Server.WebSocket, prev: *Previous, ) !void { 
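+    // The payload vectors below are `var` rather than `const`: `writeMessageVec`
+    // takes `[][]const u8`, presumably so the underlying vectored write can
+    // rewrite entries while draining partial writes.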
fuzz.coverage_mutex.lock(); @@ -263,36 +263,36 @@ pub fn sendUpdate( .string_bytes_len = @intCast(coverage_map.coverage.string_bytes.items.len), .start_timestamp = coverage_map.start_timestamp, }; - const iovecs: [5]std.posix.iovec_const = .{ - makeIov(@ptrCast(&header)), - makeIov(@ptrCast(coverage_map.coverage.directories.keys())), - makeIov(@ptrCast(coverage_map.coverage.files.keys())), - makeIov(@ptrCast(coverage_map.source_locations)), - makeIov(coverage_map.coverage.string_bytes.items), + var iovecs: [5][]const u8 = .{ + @ptrCast(&header), + @ptrCast(coverage_map.coverage.directories.keys()), + @ptrCast(coverage_map.coverage.files.keys()), + @ptrCast(coverage_map.source_locations), + coverage_map.coverage.string_bytes.items, }; - try socket.writeMessagev(&iovecs, .binary); + try socket.writeMessageVec(&iovecs, .binary); } const header: abi.CoverageUpdateHeader = .{ .n_runs = n_runs, .unique_runs = unique_runs, }; - const iovecs: [2]std.posix.iovec_const = .{ - makeIov(@ptrCast(&header)), - makeIov(@ptrCast(seen_pcs)), + var iovecs: [2][]const u8 = .{ + @ptrCast(&header), + @ptrCast(seen_pcs), }; - try socket.writeMessagev(&iovecs, .binary); + try socket.writeMessageVec(&iovecs, .binary); prev.unique_runs = unique_runs; } if (prev.entry_points != coverage_map.entry_points.items.len) { const header: abi.EntryPointHeader = .init(@intCast(coverage_map.entry_points.items.len)); - const iovecs: [2]std.posix.iovec_const = .{ - makeIov(@ptrCast(&header)), - makeIov(@ptrCast(coverage_map.entry_points.items)), + var iovecs: [2][]const u8 = .{ + @ptrCast(&header), + @ptrCast(coverage_map.entry_points.items), }; - try socket.writeMessagev(&iovecs, .binary); + try socket.writeMessageVec(&iovecs, .binary); prev.entry_points = coverage_map.entry_points.items.len; } @@ -448,10 +448,3 @@ fn addEntryPoint(fuzz: *Fuzz, coverage_id: u64, addr: u64) error{ AlreadyReporte } try coverage_map.entry_points.append(fuzz.ws.gpa, @intCast(index)); } - -fn makeIov(s: []const u8) std.posix.iovec_const { - return .{ - .base = s.ptr, - .len = s.len, - }; -} diff --git a/lib/std/Build/WebServer.zig b/lib/std/Build/WebServer.zig index 9264d7473c..868aabe67e 100644 --- a/lib/std/Build/WebServer.zig +++ b/lib/std/Build/WebServer.zig @@ -251,48 +251,44 @@ pub fn now(s: *const WebServer) i64 { fn accept(ws: *WebServer, connection: std.net.Server.Connection) void { defer connection.stream.close(); - var read_buf: [0x4000]u8 = undefined; - var server: std.http.Server = .init(connection, &read_buf); + var send_buffer: [4096]u8 = undefined; + var recv_buffer: [4096]u8 = undefined; + var connection_reader = connection.stream.reader(&recv_buffer); + var connection_writer = connection.stream.writer(&send_buffer); + var server: http.Server = .init(connection_reader.interface(), &connection_writer.interface); while (true) { var request = server.receiveHead() catch |err| switch (err) { error.HttpConnectionClosing => return, - else => { - log.err("failed to receive http request: {s}", .{@errorName(err)}); - return; - }, + else => return log.err("failed to receive http request: {t}", .{err}), }; - var ws_send_buf: [0x4000]u8 = undefined; - var ws_recv_buf: [0x4000]u8 align(4) = undefined; - if (std.http.WebSocket.init(&request, &ws_send_buf, &ws_recv_buf) catch |err| { - log.err("failed to initialize websocket connection: {s}", .{@errorName(err)}); - return; - }) |ws_init| { - var web_socket = ws_init; - ws.serveWebSocket(&web_socket) catch |err| { - log.err("failed to serve websocket: {s}", .{@errorName(err)}); - return; - }; - 
comptime unreachable; - } else { - ws.serveRequest(&request) catch |err| switch (err) { - error.AlreadyReported => return, - else => { - log.err("failed to serve '{s}': {s}", .{ request.head.target, @errorName(err) }); + switch (request.upgradeRequested()) { + .websocket => |opt_key| { + const key = opt_key orelse return log.err("missing websocket key", .{}); + var web_socket = request.respondWebSocket(.{ .key = key }) catch { + return log.err("failed to respond web socket: {t}", .{connection_writer.err.?}); + }; + ws.serveWebSocket(&web_socket) catch |err| { + log.err("failed to serve websocket: {t}", .{err}); return; - }, - }; + }; + comptime unreachable; + }, + .other => |name| return log.err("unknown upgrade request: {s}", .{name}), + .none => { + ws.serveRequest(&request) catch |err| switch (err) { + error.AlreadyReported => return, + else => { + log.err("failed to serve '{s}': {t}", .{ request.head.target, err }); + return; + }, + }; + }, } } } -fn makeIov(s: []const u8) std.posix.iovec_const { - return .{ - .base = s.ptr, - .len = s.len, - }; -} -fn serveWebSocket(ws: *WebServer, sock: *std.http.WebSocket) !noreturn { +fn serveWebSocket(ws: *WebServer, sock: *http.Server.WebSocket) !noreturn { var prev_build_status = ws.build_status.load(.monotonic); const prev_step_status_bits = try ws.gpa.alloc(u8, ws.step_status_bits.len); @@ -312,11 +308,8 @@ fn serveWebSocket(ws: *WebServer, sock: *std.http.WebSocket) !noreturn { .timestamp = ws.now(), .steps_len = @intCast(ws.all_steps.len), }; - try sock.writeMessagev(&.{ - makeIov(@ptrCast(&hello_header)), - makeIov(ws.step_names_trailing), - makeIov(prev_step_status_bits), - }, .binary); + var bufs: [3][]const u8 = .{ @ptrCast(&hello_header), ws.step_names_trailing, prev_step_status_bits }; + try sock.writeMessageVec(&bufs, .binary); } var prev_fuzz: Fuzz.Previous = .init; @@ -380,7 +373,7 @@ fn serveWebSocket(ws: *WebServer, sock: *std.http.WebSocket) !noreturn { std.Thread.Futex.timedWait(&ws.update_id, start_update_id, std.time.ns_per_ms * default_update_interval_ms) catch {}; } } -fn recvWebSocketMessages(ws: *WebServer, sock: *std.http.WebSocket) void { +fn recvWebSocketMessages(ws: *WebServer, sock: *http.Server.WebSocket) void { while (true) { const msg = sock.readSmallMessage() catch return; if (msg.opcode != .binary) continue; @@ -402,7 +395,7 @@ fn recvWebSocketMessages(ws: *WebServer, sock: *std.http.WebSocket) void { } } -fn serveRequest(ws: *WebServer, req: *std.http.Server.Request) !void { +fn serveRequest(ws: *WebServer, req: *http.Server.Request) !void { // Strip an optional leading '/debug' component from the request. 
const target: []const u8, const debug: bool = target: { if (mem.eql(u8, req.head.target, "/debug")) break :target .{ "/", true }; @@ -431,7 +424,7 @@ fn serveRequest(ws: *WebServer, req: *std.http.Server.Request) !void { fn serveLibFile( ws: *WebServer, - request: *std.http.Server.Request, + request: *http.Server.Request, sub_path: []const u8, content_type: []const u8, ) !void { @@ -442,7 +435,7 @@ fn serveLibFile( } fn serveClientWasm( ws: *WebServer, - req: *std.http.Server.Request, + req: *http.Server.Request, optimize_mode: std.builtin.OptimizeMode, ) !void { var arena_state: std.heap.ArenaAllocator = .init(ws.gpa); @@ -456,12 +449,12 @@ fn serveClientWasm( pub fn serveFile( ws: *WebServer, - request: *std.http.Server.Request, + request: *http.Server.Request, path: Cache.Path, content_type: []const u8, ) !void { const gpa = ws.gpa; - // The desired API is actually sendfile, which will require enhancing std.http.Server. + // The desired API is actually sendfile, which will require enhancing http.Server. // We load the file with every request so that the user can make changes to the file // and refresh the HTML page without restarting this server. const file_contents = path.root_dir.handle.readFileAlloc(gpa, path.sub_path, 10 * 1024 * 1024) catch |err| { @@ -478,14 +471,13 @@ pub fn serveFile( } pub fn serveTarFile( ws: *WebServer, - request: *std.http.Server.Request, + request: *http.Server.Request, paths: []const Cache.Path, ) !void { const gpa = ws.gpa; - var send_buf: [0x4000]u8 = undefined; - var response = request.respondStreaming(.{ - .send_buffer = &send_buf, + var send_buffer: [0x4000]u8 = undefined; + var response = try request.respondStreaming(&send_buffer, .{ .respond_options = .{ .extra_headers = &.{ .{ .name = "Content-Type", .value = "application/x-tar" }, @@ -497,10 +489,7 @@ pub fn serveTarFile( var cached_cwd_path: ?[]const u8 = null; defer if (cached_cwd_path) |p| gpa.free(p); - var response_buf: [1024]u8 = undefined; - var adapter = response.writer().adaptToNewApi(); - adapter.new_interface.buffer = &response_buf; - var archiver: std.tar.Writer = .{ .underlying_writer = &adapter.new_interface }; + var archiver: std.tar.Writer = .{ .underlying_writer = &response.writer }; for (paths) |path| { var file = path.root_dir.handle.openFile(path.sub_path, .{}) catch |err| { @@ -526,7 +515,6 @@ pub fn serveTarFile( } // intentionally not calling `archiver.finishPedantically` - try adapter.new_interface.flush(); try response.end(); } @@ -804,7 +792,7 @@ pub fn wait(ws: *WebServer) RunnerRequest { } } -const cache_control_header: std.http.Header = .{ +const cache_control_header: http.Header = .{ .name = "Cache-Control", .value = "max-age=0, must-revalidate", }; @@ -819,5 +807,6 @@ const Build = std.Build; const Cache = Build.Cache; const Fuzz = Build.Fuzz; const abi = Build.abi; +const http = std.http; const WebServer = @This(); diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index aa6c72a5b3..188f2a45e4 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -7,6 +7,7 @@ const Uri = std.Uri; const assert = std.debug.assert; const testing = std.testing; const Writer = std.Io.Writer; +const Reader = std.Io.Reader; const Server = @This(); @@ -21,7 +22,7 @@ reader: http.Reader, /// header, otherwise `receiveHead` returns `error.HttpHeadersOversize`. /// /// The returned `Server` is ready for `receiveHead` to be called. 
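+/// A typical setup wraps a connection's two stream interfaces, e.g.
+/// `.init(connection_reader.interface(), &connection_writer.interface)`.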
-pub fn init(in: *std.Io.Reader, out: *Writer) Server { +pub fn init(in: *Reader, out: *Writer) Server { return .{ .reader = .{ .in = in, @@ -225,7 +226,7 @@ pub const Request = struct { } }; - pub fn iterateHeaders(r: *Request) http.HeaderIterator { + pub fn iterateHeaders(r: *const Request) http.HeaderIterator { assert(r.server.reader.state == .received_head); return http.HeaderIterator.init(r.head_buffer); } @@ -486,10 +487,11 @@ pub const Request = struct { none, }; + /// Does not invalidate `request.head`. pub fn upgradeRequested(request: *const Request) UpgradeRequest { switch (request.head.version) { - .@"HTTP/1.0" => return null, - .@"HTTP/1.1" => if (request.head.method != .GET) return null, + .@"HTTP/1.0" => return .none, + .@"HTTP/1.1" => if (request.head.method != .GET) return .none, } var sec_websocket_key: ?[]const u8 = null; @@ -517,7 +519,7 @@ pub const Request = struct { /// The header is not guaranteed to be sent until `WebSocket.flush` is /// called on the returned struct. - pub fn respondWebSocket(request: *Request, options: WebSocketOptions) Writer.Error!WebSocket { + pub fn respondWebSocket(request: *Request, options: WebSocketOptions) ExpectContinueError!WebSocket { if (request.head.expect != null) return error.HttpExpectationFailed; const out = request.server.out; @@ -536,16 +538,14 @@ pub const Request = struct { try out.print("{s} {d} {s}\r\n", .{ @tagName(version), @intFromEnum(status), phrase }); try out.writeAll("connection: upgrade\r\nupgrade: websocket\r\nsec-websocket-accept: "); const base64_digest = try out.writableArray(28); - assert(std.base64.standard.Encoder.encode(&base64_digest, &digest).len == base64_digest.len); + assert(std.base64.standard.Encoder.encode(base64_digest, &digest).len == base64_digest.len); out.advance(base64_digest.len); try out.writeAll("\r\n"); for (options.extra_headers) |header| { assert(header.name.len != 0); - try out.writeAll(header.name); - try out.writeAll(": "); - try out.writeAll(header.value); - try out.writeAll("\r\n"); + var bufs: [4][]const u8 = .{ header.name, ": ", header.value, "\r\n" }; + try out.writeVecAll(&bufs); } try out.writeAll("\r\n"); @@ -566,7 +566,7 @@ pub const Request = struct { /// /// See `readerExpectNone` for an infallible alternative that cannot write /// to the server output stream. - pub fn readerExpectContinue(request: *Request, buffer: []u8) ExpectContinueError!*std.Io.Reader { + pub fn readerExpectContinue(request: *Request, buffer: []u8) ExpectContinueError!*Reader { const flush = request.head.expect != null; try writeExpectContinue(request); if (flush) try request.server.out.flush(); @@ -578,7 +578,7 @@ pub const Request = struct { /// this function. /// /// Asserts that this function is only called once. 
- pub fn readerExpectNone(request: *Request, buffer: []u8) *std.Io.Reader { + pub fn readerExpectNone(request: *Request, buffer: []u8) *Reader { assert(request.server.reader.state == .received_head); assert(request.head.expect == null); if (!request.head.method.requestHasBody()) return .ending; @@ -642,7 +642,7 @@ pub const Request = struct { /// See https://tools.ietf.org/html/rfc6455 pub const WebSocket = struct { key: []const u8, - input: *std.Io.Reader, + input: *Reader, output: *Writer, pub const Header0 = packed struct(u8) { @@ -679,6 +679,8 @@ pub const WebSocket = struct { UnexpectedOpCode, MessageTooBig, MissingMaskBit, + ReadFailed, + EndOfStream, }; pub const SmallMessage = struct { @@ -693,8 +695,9 @@ pub const WebSocket = struct { pub fn readSmallMessage(ws: *WebSocket) ReadSmallTextMessageError!SmallMessage { const in = ws.input; while (true) { - const h0 = in.takeStruct(Header0); - const h1 = in.takeStruct(Header1); + const header = try in.takeArray(2); + const h0: Header0 = @bitCast(header[0]); + const h1: Header1 = @bitCast(header[1]); switch (h0.opcode) { .text, .binary, .pong, .ping => {}, @@ -734,47 +737,49 @@ pub const WebSocket = struct { } pub fn writeMessage(ws: *WebSocket, data: []const u8, op: Opcode) Writer.Error!void { - try writeMessageVecUnflushed(ws, &.{data}, op); + var bufs: [1][]const u8 = .{data}; + try writeMessageVecUnflushed(ws, &bufs, op); try ws.output.flush(); } pub fn writeMessageUnflushed(ws: *WebSocket, data: []const u8, op: Opcode) Writer.Error!void { - try writeMessageVecUnflushed(ws, &.{data}, op); + var bufs: [1][]const u8 = .{data}; + try writeMessageVecUnflushed(ws, &bufs, op); } - pub fn writeMessageVec(ws: *WebSocket, data: []const []const u8, op: Opcode) Writer.Error!void { + pub fn writeMessageVec(ws: *WebSocket, data: [][]const u8, op: Opcode) Writer.Error!void { try writeMessageVecUnflushed(ws, data, op); try ws.output.flush(); } - pub fn writeMessageVecUnflushed(ws: *WebSocket, data: []const []const u8, op: Opcode) Writer.Error!void { + pub fn writeMessageVecUnflushed(ws: *WebSocket, data: [][]const u8, op: Opcode) Writer.Error!void { const total_len = l: { var total_len: u64 = 0; for (data) |iovec| total_len += iovec.len; break :l total_len; }; const out = ws.output; - try out.writeStruct(@as(Header0, .{ + try out.writeByte(@bitCast(@as(Header0, .{ .opcode = op, .fin = true, - })); + }))); switch (total_len) { - 0...125 => try out.writeStruct(@as(Header1, .{ + 0...125 => try out.writeByte(@bitCast(@as(Header1, .{ .payload_len = @enumFromInt(total_len), .mask = false, - })), + }))), 126...0xffff => { - try out.writeStruct(@as(Header1, .{ + try out.writeByte(@bitCast(@as(Header1, .{ .payload_len = .len16, .mask = false, - })); + }))); try out.writeInt(u16, @intCast(total_len), .big); }, else => { - try out.writeStruct(@as(Header1, .{ + try out.writeByte(@bitCast(@as(Header1, .{ .payload_len = .len64, .mask = false, - })); + }))); try out.writeInt(u64, total_len, .big); }, } From b757d7f941e243d4a40fbba90e51a9f514350320 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 4 Aug 2025 19:31:03 -0700 Subject: [PATCH 56/70] fetch: update for new http API it's not quite finished because I need to make it not copy the Resource --- lib/std/http/Client.zig | 17 +- src/Package/Fetch.zig | 323 +++++++++++++++++--------------------- src/Package/Fetch/git.zig | 283 ++++++++++++++++++--------------- 3 files changed, 309 insertions(+), 314 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 6660761930..5b817af467 
100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -682,7 +682,7 @@ pub const Response = struct { /// /// See also: /// * `readerDecompressing` - pub fn reader(response: *Response, buffer: []u8) *Reader { + pub fn reader(response: *const Response, buffer: []u8) *Reader { const req = response.request; if (!req.method.responseHasBody()) return .ending; const head = &response.head; @@ -805,6 +805,11 @@ pub const Request = struct { unhandled = std.math.maxInt(u16), _, + pub fn init(n: u16) RedirectBehavior { + assert(n != std.math.maxInt(u16)); + return @enumFromInt(n); + } + pub fn subtractOne(rb: *RedirectBehavior) void { switch (rb.*) { .not_allowed => unreachable, @@ -855,6 +860,14 @@ pub const Request = struct { return result; } + /// Transfers the HTTP head and body over the connection and flushes. + pub fn sendBodyComplete(r: *Request, body: []u8) Writer.Error!void { + r.transfer_encoding = .{ .content_length = body.len }; + var bw = try sendBodyUnflushed(r, body); + bw.writer.end = body.len; + try bw.end(); + } + /// Transfers the HTTP head over the connection, which is not flushed until /// `BodyWriter.flush` or `BodyWriter.end` is called. /// @@ -1296,7 +1309,7 @@ pub const basic_authorization = struct { pub fn value(uri: Uri, out: []u8) []u8 { var bw: Writer = .fixed(out); write(uri, &bw) catch unreachable; - return bw.getWritten(); + return bw.buffered(); } pub fn write(uri: Uri, out: *Writer) Writer.Error!void { diff --git a/src/Package/Fetch.zig b/src/Package/Fetch.zig index 836435a86b..b6d22a37ac 100644 --- a/src/Package/Fetch.zig +++ b/src/Package/Fetch.zig @@ -385,20 +385,21 @@ pub fn run(f: *Fetch) RunError!void { var resource: Resource = .{ .dir = dir }; return f.runResource(path_or_url, &resource, null); } else |dir_err| { + var server_header_buffer: [init_resource_buffer_size]u8 = undefined; + const file_err = if (dir_err == error.NotDir) e: { if (fs.cwd().openFile(path_or_url, .{})) |file| { - var resource: Resource = .{ .file = file }; + var resource: Resource = .{ .file = file.reader(&server_header_buffer) }; return f.runResource(path_or_url, &resource, null); } else |err| break :e err; } else dir_err; const uri = std.Uri.parse(path_or_url) catch |uri_err| { return f.fail(0, try eb.printString( - "'{s}' could not be recognized as a file path ({s}) or an URL ({s})", - .{ path_or_url, @errorName(file_err), @errorName(uri_err) }, + "'{s}' could not be recognized as a file path ({t}) or an URL ({t})", + .{ path_or_url, file_err, uri_err }, )); }; - var server_header_buffer: [header_buffer_size]u8 = undefined; var resource = try f.initResource(uri, &server_header_buffer); return f.runResource(try uri.path.toRawMaybeAlloc(arena), &resource, null); } @@ -464,8 +465,8 @@ pub fn run(f: *Fetch) RunError!void { f.location_tok, try eb.printString("invalid URI: {s}", .{@errorName(err)}), ); - var server_header_buffer: [header_buffer_size]u8 = undefined; - var resource = try f.initResource(uri, &server_header_buffer); + var buffer: [init_resource_buffer_size]u8 = undefined; + var resource = try f.initResource(uri, &buffer); return f.runResource(try uri.path.toRawMaybeAlloc(arena), &resource, remote.hash); } @@ -866,8 +867,8 @@ fn fail(f: *Fetch, msg_tok: std.zig.Ast.TokenIndex, msg_str: u32) RunError { } const Resource = union(enum) { - file: fs.File, - http_request: std.http.Client.Request, + file: fs.File.Reader, + http_request: HttpRequest, git: Git, dir: fs.Dir, @@ -877,10 +878,16 @@ const Resource = union(enum) { want_oid: git.Oid, }; + const HttpRequest = 
struct {
+        request: std.http.Client.Request,
+        head: std.http.Client.Response.Head,
+        buffer: []u8,
+    };
+
     fn deinit(resource: *Resource) void {
         switch (resource.*) {
-            .file => |*file| file.close(),
-            .http_request => |*req| req.deinit(),
+            .file => |*file_reader| file_reader.file.close(),
+            .http_request => |*http_request| http_request.request.deinit(),
             .git => |*git_resource| {
                 git_resource.fetch_stream.deinit();
                 git_resource.session.deinit();
@@ -890,21 +897,19 @@ const Resource = union(enum) {
         resource.* = undefined;
     }
 
-    fn reader(resource: *Resource) std.io.AnyReader {
-        return .{
-            .context = resource,
-            .readFn = read,
-        };
-    }
-
-    fn read(context: *const anyopaque, buffer: []u8) anyerror!usize {
-        const resource: *Resource = @constCast(@ptrCast(@alignCast(context)));
-        switch (resource.*) {
-            .file => |*f| return f.read(buffer),
-            .http_request => |*r| return r.read(buffer),
-            .git => |*g| return g.fetch_stream.read(buffer),
+    fn reader(resource: *Resource) *std.Io.Reader {
+        return switch (resource.*) {
+            .file => |*file_reader| return &file_reader.interface,
+            .http_request => |*http_request| {
+                const response: std.http.Client.Response = .{
+                    .request = &http_request.request,
+                    .head = http_request.head,
+                };
+                return response.reader(http_request.buffer);
+            },
+            .git => |*g| return &g.fetch_stream.reader,
             .dir => unreachable,
-        }
+        };
     }
 };
 
@@ -967,20 +972,21 @@ const FileType = enum {
     }
 };
 
-const header_buffer_size = 16 * 1024;
+const init_resource_buffer_size = git.Packet.max_data_length;
 
-fn initResource(f: *Fetch, uri: std.Uri, server_header_buffer: []u8) RunError!Resource {
+fn initResource(f: *Fetch, uri: std.Uri, reader_buffer: []u8) RunError!Resource {
     const gpa = f.arena.child_allocator;
     const arena = f.arena.allocator();
     const eb = &f.error_bundle;
 
     if (ascii.eqlIgnoreCase(uri.scheme, "file")) {
         const path = try uri.path.toRawMaybeAlloc(arena);
-        return .{ .file = f.parent_package_root.openFile(path, .{}) catch |err| {
-            return f.fail(f.location_tok, try eb.printString("unable to open '{f}{s}': {s}", .{
-                f.parent_package_root, path, @errorName(err),
+        const file = f.parent_package_root.openFile(path, .{}) catch |err| {
+            return f.fail(f.location_tok, try eb.printString("unable to open '{f}{s}': {t}", .{
+                f.parent_package_root, path, err,
             }));
-        } };
+        };
+        return .{ .file = file.reader(reader_buffer) };
    }
 
     const http_client = f.job_queue.http_client;
@@ -988,37 +994,27 @@ fn initResource(f: *Fetch, uri: std.Uri, server_header_buffer: []u8) RunError!Re
     if (ascii.eqlIgnoreCase(uri.scheme, "http") or
         ascii.eqlIgnoreCase(uri.scheme, "https"))
     {
-        var req = http_client.open(.GET, uri, .{
-            .server_header_buffer = server_header_buffer,
-        }) catch |err| {
-            return f.fail(f.location_tok, try eb.printString(
-                "unable to connect to server: {s}",
-                .{@errorName(err)},
-            ));
-        };
-        errdefer req.deinit(); // releases more than memory
+        var request = http_client.request(.GET, uri, .{}) catch |err|
+            return f.fail(f.location_tok, try eb.printString("unable to connect to server: {t}", .{err}));
+        errdefer request.deinit();
 
-        req.send() catch |err| {
-            return f.fail(f.location_tok, try eb.printString(
-                "HTTP request failed: {s}",
-                .{@errorName(err)},
-            ));
-        };
-        req.wait() catch |err| {
-            return f.fail(f.location_tok, try eb.printString(
-                "invalid HTTP response: {s}",
-                .{@errorName(err)},
-            ));
-        };
+        request.sendBodiless() catch |err|
+            return f.fail(f.location_tok, try eb.printString("HTTP request failed: {t}", .{err}));
 
-        if (req.response.status != .ok) {
-            return 
f.fail(f.location_tok, try eb.printString( - "bad HTTP response code: '{d} {s}'", - .{ @intFromEnum(req.response.status), req.response.status.phrase() orelse "" }, - )); - } + var redirect_buffer: [1024]u8 = undefined; + const response = request.receiveHead(&redirect_buffer) catch |err| + return f.fail(f.location_tok, try eb.printString("invalid HTTP response: {t}", .{err})); - return .{ .http_request = req }; + if (response.head.status != .ok) return f.fail(f.location_tok, try eb.printString( + "bad HTTP response code: '{d} {s}'", + .{ response.head.status, response.head.status.phrase() orelse "" }, + )); + + return .{ .http_request = .{ + .request = request, + .head = response.head, + .buffer = reader_buffer, + } }; } if (ascii.eqlIgnoreCase(uri.scheme, "git+http") or @@ -1026,7 +1022,7 @@ fn initResource(f: *Fetch, uri: std.Uri, server_header_buffer: []u8) RunError!Re { var transport_uri = uri; transport_uri.scheme = uri.scheme["git+".len..]; - var session = git.Session.init(gpa, http_client, transport_uri, server_header_buffer) catch |err| { + var session = git.Session.init(gpa, http_client, transport_uri, reader_buffer) catch |err| { return f.fail(f.location_tok, try eb.printString( "unable to discover remote git server capabilities: {s}", .{@errorName(err)}, @@ -1042,16 +1038,12 @@ fn initResource(f: *Fetch, uri: std.Uri, server_header_buffer: []u8) RunError!Re const want_ref_head = try std.fmt.allocPrint(arena, "refs/heads/{s}", .{want_ref}); const want_ref_tag = try std.fmt.allocPrint(arena, "refs/tags/{s}", .{want_ref}); - var ref_iterator = session.listRefs(.{ + var ref_iterator: git.Session.RefIterator = undefined; + session.listRefs(&ref_iterator, .{ .ref_prefixes = &.{ want_ref, want_ref_head, want_ref_tag }, .include_peeled = true, - .server_header_buffer = server_header_buffer, - }) catch |err| { - return f.fail(f.location_tok, try eb.printString( - "unable to list refs: {s}", - .{@errorName(err)}, - )); - }; + .buffer = reader_buffer, + }) catch |err| return f.fail(f.location_tok, try eb.printString("unable to list refs: {t}", .{err})); defer ref_iterator.deinit(); while (ref_iterator.next() catch |err| { return f.fail(f.location_tok, try eb.printString( @@ -1089,14 +1081,14 @@ fn initResource(f: *Fetch, uri: std.Uri, server_header_buffer: []u8) RunError!Re var want_oid_buf: [git.Oid.max_formatted_length]u8 = undefined; _ = std.fmt.bufPrint(&want_oid_buf, "{f}", .{want_oid}) catch unreachable; - var fetch_stream = session.fetch(&.{&want_oid_buf}, server_header_buffer) catch |err| { - return f.fail(f.location_tok, try eb.printString( - "unable to create fetch stream: {s}", - .{@errorName(err)}, - )); + var fetch_stream: git.Session.FetchStream = undefined; + session.fetch(&fetch_stream, &.{&want_oid_buf}, reader_buffer) catch |err| { + return f.fail(f.location_tok, try eb.printString("unable to create fetch stream: {t}", .{err})); }; errdefer fetch_stream.deinit(); + if (true) @panic("TODO this moves fetch_stream, invalidating its reader"); + return .{ .git = .{ .session = session, .fetch_stream = fetch_stream, @@ -1104,10 +1096,7 @@ fn initResource(f: *Fetch, uri: std.Uri, server_header_buffer: []u8) RunError!Re } }; } - return f.fail(f.location_tok, try eb.printString( - "unsupported URL scheme: {s}", - .{uri.scheme}, - )); + return f.fail(f.location_tok, try eb.printString("unsupported URL scheme: {s}", .{uri.scheme})); } fn unpackResource( @@ -1121,9 +1110,11 @@ fn unpackResource( .file => FileType.fromPath(uri_path) orelse return f.fail(f.location_tok, try 
eb.printString("unknown file type: '{s}'", .{uri_path})), - .http_request => |req| ft: { + .http_request => |*http_request| ft: { + const head = &http_request.head; + // Content-Type takes first precedence. - const content_type = req.response.content_type orelse + const content_type = head.content_type orelse return f.fail(f.location_tok, try eb.addString("missing 'Content-Type' header")); // Extract the MIME type, ignoring charset and boundary directives @@ -1165,7 +1156,7 @@ fn unpackResource( } // Next, the filename from 'content-disposition: attachment' takes precedence. - if (req.response.content_disposition) |cd_header| { + if (head.content_disposition) |cd_header| { break :ft FileType.fromContentDisposition(cd_header) orelse { return f.fail(f.location_tok, try eb.printString( "unsupported Content-Disposition header value: '{s}' for Content-Type=application/octet-stream", @@ -1176,10 +1167,7 @@ fn unpackResource( // Finally, the path from the URI is used. break :ft FileType.fromPath(uri_path) orelse { - return f.fail(f.location_tok, try eb.printString( - "unknown file type: '{s}'", - .{uri_path}, - )); + return f.fail(f.location_tok, try eb.printString("unknown file type: '{s}'", .{uri_path})); }; }, @@ -1187,10 +1175,9 @@ fn unpackResource( .dir => |dir| { f.recursiveDirectoryCopy(dir, tmp_directory.handle) catch |err| { - return f.fail(f.location_tok, try eb.printString( - "unable to copy directory '{s}': {s}", - .{ uri_path, @errorName(err) }, - )); + return f.fail(f.location_tok, try eb.printString("unable to copy directory '{s}': {t}", .{ + uri_path, err, + })); }; return .{}; }, @@ -1198,15 +1185,11 @@ fn unpackResource( switch (file_type) { .tar => { - var adapter_buffer: [1024]u8 = undefined; - var adapter = resource.reader().adaptToNewApi(&adapter_buffer); - return unpackTarball(f, tmp_directory.handle, &adapter.new_interface); + return unpackTarball(f, tmp_directory.handle, resource.reader()); }, .@"tar.gz" => { - var adapter_buffer: [std.crypto.tls.max_ciphertext_record_len]u8 = undefined; - var adapter = resource.reader().adaptToNewApi(&adapter_buffer); var flate_buffer: [std.compress.flate.max_window_len]u8 = undefined; - var decompress: std.compress.flate.Decompress = .init(&adapter.new_interface, .gzip, &flate_buffer); + var decompress: std.compress.flate.Decompress = .init(resource.reader(), .gzip, &flate_buffer); return try unpackTarball(f, tmp_directory.handle, &decompress.reader); }, .@"tar.xz" => { @@ -1227,9 +1210,7 @@ fn unpackResource( .@"tar.zst" => { const window_size = std.compress.zstd.default_window_len; const window_buffer = try f.arena.allocator().create([window_size]u8); - var adapter_buffer: [std.crypto.tls.max_ciphertext_record_len]u8 = undefined; - var adapter = resource.reader().adaptToNewApi(&adapter_buffer); - var decompress: std.compress.zstd.Decompress = .init(&adapter.new_interface, window_buffer, .{ + var decompress: std.compress.zstd.Decompress = .init(resource.reader(), window_buffer, .{ .verify_checksum = false, }); return try unpackTarball(f, tmp_directory.handle, &decompress.reader); @@ -1237,12 +1218,15 @@ fn unpackResource( .git_pack => return unpackGitPack(f, tmp_directory.handle, &resource.git) catch |err| switch (err) { error.FetchFailed => return error.FetchFailed, error.OutOfMemory => return error.OutOfMemory, - else => |e| return f.fail(f.location_tok, try eb.printString( - "unable to unpack git files: {s}", - .{@errorName(e)}, - )), + else => |e| return f.fail(f.location_tok, try eb.printString("unable to unpack git files: {t}", 
.{e})), + }, + .zip => return unzip(f, tmp_directory.handle, resource.reader()) catch |err| switch (err) { + error.ReadFailed => return f.fail(f.location_tok, try eb.printString( + "failed reading resource: {t}", + .{err}, + )), + else => |e| return e, }, - .zip => return try unzip(f, tmp_directory.handle, resource.reader()), } } @@ -1277,99 +1261,69 @@ fn unpackTarball(f: *Fetch, out_dir: fs.Dir, reader: *std.Io.Reader) RunError!Un return res; } -fn unzip(f: *Fetch, out_dir: fs.Dir, reader: anytype) RunError!UnpackResult { +fn unzip(f: *Fetch, out_dir: fs.Dir, reader: *std.Io.Reader) error{ ReadFailed, OutOfMemory, FetchFailed }!UnpackResult { // We write the entire contents to a file first because zip files // must be processed back to front and they could be too large to // load into memory. const cache_root = f.job_queue.global_cache; - - // TODO: the downside of this solution is if we get a failure/crash/oom/power out - // during this process, we leave behind a zip file that would be - // difficult to know if/when it can be cleaned up. - // Might be worth it to use a mechanism that enables other processes - // to see if the owning process of a file is still alive (on linux this - // can be done with file locks). - // Coupled with this mechansism, we could also use slots (i.e. zig-cache/tmp/0, - // zig-cache/tmp/1, etc) which would mean that subsequent runs would - // automatically clean up old dead files. - // This could all be done with a simple TmpFile abstraction. const prefix = "tmp/"; const suffix = ".zip"; - - const random_bytes_count = 20; - const random_path_len = comptime std.fs.base64_encoder.calcSize(random_bytes_count); - var zip_path: [prefix.len + random_path_len + suffix.len]u8 = undefined; - @memcpy(zip_path[0..prefix.len], prefix); - @memcpy(zip_path[prefix.len + random_path_len ..], suffix); - { - var random_bytes: [random_bytes_count]u8 = undefined; - std.crypto.random.bytes(&random_bytes); - _ = std.fs.base64_encoder.encode( - zip_path[prefix.len..][0..random_path_len], - &random_bytes, - ); - } - - defer cache_root.handle.deleteFile(&zip_path) catch {}; - const eb = &f.error_bundle; + const random_len = @sizeOf(u64) * 2; - { - var zip_file = cache_root.handle.createFile( - &zip_path, - .{}, - ) catch |err| return f.fail(f.location_tok, try eb.printString( - "failed to create tmp zip file: {s}", - .{@errorName(err)}, - )); - defer zip_file.close(); - var buf: [4096]u8 = undefined; - while (true) { - const len = reader.readAll(&buf) catch |err| return f.fail(f.location_tok, try eb.printString( - "read zip stream failed: {s}", - .{@errorName(err)}, - )); - if (len == 0) break; - zip_file.deprecatedWriter().writeAll(buf[0..len]) catch |err| return f.fail(f.location_tok, try eb.printString( - "write temporary zip file failed: {s}", - .{@errorName(err)}, - )); - } - } + var zip_path: [prefix.len + random_len + suffix.len]u8 = undefined; + zip_path[0..prefix.len].* = prefix.*; + zip_path[prefix.len + random_len ..].* = suffix.*; + + var zip_file = while (true) { + const random_integer = std.crypto.random.int(u64); + zip_path[prefix.len..][0..random_len].* = std.fmt.hex(random_integer); + + break cache_root.handle.createFile(&zip_path, .{ + .exclusive = true, + .read = true, + }) catch |err| switch (err) { + error.PathAlreadyExists => continue, + else => |e| return f.fail( + f.location_tok, + try eb.printString("failed to create temporary zip file: {t}", .{e}), + ), + }; + }; + defer zip_file.close(); + var zip_file_buffer: [4096]u8 = undefined; + var zip_file_reader = b: 
{ + var zip_file_writer = zip_file.writer(&zip_file_buffer); + + _ = reader.streamRemaining(&zip_file_writer.interface) catch |err| switch (err) { + error.ReadFailed => return error.ReadFailed, + error.WriteFailed => return f.fail( + f.location_tok, + try eb.printString("failed writing temporary zip file: {t}", .{err}), + ), + }; + zip_file_writer.interface.flush() catch |err| return f.fail( + f.location_tok, + try eb.printString("failed writing temporary zip file: {t}", .{err}), + ); + break :b zip_file_writer.moveToReader(); + }; var diagnostics: std.zip.Diagnostics = .{ .allocator = f.arena.allocator() }; // no need to deinit since we are using an arena allocator - { - var zip_file = cache_root.handle.openFile( - &zip_path, - .{}, - ) catch |err| return f.fail(f.location_tok, try eb.printString( - "failed to open temporary zip file: {s}", - .{@errorName(err)}, - )); - defer zip_file.close(); + zip_file_reader.seekTo(0) catch |err| + return f.fail(f.location_tok, try eb.printString("failed to seek temporary zip file: {t}", .{err})); + std.zip.extract(out_dir, &zip_file_reader, .{ + .allow_backslashes = true, + .diagnostics = &diagnostics, + }) catch |err| return f.fail(f.location_tok, try eb.printString("zip extract failed: {t}", .{err})); - var zip_file_buffer: [1024]u8 = undefined; - var zip_file_reader = zip_file.reader(&zip_file_buffer); + cache_root.handle.deleteFile(&zip_path) catch |err| + return f.fail(f.location_tok, try eb.printString("delete temporary zip failed: {t}", .{err})); - std.zip.extract(out_dir, &zip_file_reader, .{ - .allow_backslashes = true, - .diagnostics = &diagnostics, - }) catch |err| return f.fail(f.location_tok, try eb.printString( - "zip extract failed: {s}", - .{@errorName(err)}, - )); - } - - cache_root.handle.deleteFile(&zip_path) catch |err| return f.fail(f.location_tok, try eb.printString( - "delete temporary zip failed: {s}", - .{@errorName(err)}, - )); - - const res: UnpackResult = .{ .root_dir = diagnostics.root_dir }; - return res; + return .{ .root_dir = diagnostics.root_dir }; } fn unpackGitPack(f: *Fetch, out_dir: fs.Dir, resource: *Resource.Git) anyerror!UnpackResult { @@ -1387,10 +1341,13 @@ fn unpackGitPack(f: *Fetch, out_dir: fs.Dir, resource: *Resource.Git) anyerror!U var pack_file = try pack_dir.createFile("pkg.pack", .{ .read = true }); defer pack_file.close(); var pack_file_buffer: [4096]u8 = undefined; - var fifo = std.fifo.LinearFifo(u8, .{ .Slice = {} }).init(&pack_file_buffer); - try fifo.pump(resource.fetch_stream.reader(), pack_file.deprecatedWriter()); - - var pack_file_reader = pack_file.reader(&pack_file_buffer); + var pack_file_reader = b: { + var pack_file_writer = pack_file.writer(&pack_file_buffer); + const fetch_reader = &resource.fetch_stream.reader; + _ = try fetch_reader.streamRemaining(&pack_file_writer.interface); + try pack_file_writer.interface.flush(); + break :b pack_file_writer.moveToReader(); + }; var index_file = try pack_dir.createFile("pkg.idx", .{ .read = true }); defer index_file.close(); diff --git a/src/Package/Fetch/git.zig b/src/Package/Fetch/git.zig index 6ff951014b..88652343f5 100644 --- a/src/Package/Fetch/git.zig +++ b/src/Package/Fetch/git.zig @@ -585,17 +585,17 @@ const ObjectCache = struct { /// [protocol-common](https://git-scm.com/docs/protocol-common). The special /// meanings of the delimiter and response-end packets are documented in /// [protocol-v2](https://git-scm.com/docs/protocol-v2). 
-const Packet = union(enum) { +pub const Packet = union(enum) { flush, delimiter, response_end, data: []const u8, - const max_data_length = 65516; + pub const max_data_length = 65516; /// Reads a packet in pkt-line format. - fn read(reader: anytype, buf: *[max_data_length]u8) !Packet { - const length = std.fmt.parseUnsigned(u16, &try reader.readBytesNoEof(4), 16) catch return error.InvalidPacket; + fn read(reader: *std.Io.Reader) !Packet { + const length = std.fmt.parseUnsigned(u16, try reader.take(4), 16) catch return error.InvalidPacket; switch (length) { 0 => return .flush, 1 => return .delimiter, @@ -603,13 +603,11 @@ const Packet = union(enum) { 3 => return error.InvalidPacket, else => if (length - 4 > max_data_length) return error.InvalidPacket, } - const data = buf[0 .. length - 4]; - try reader.readNoEof(data); - return .{ .data = data }; + return .{ .data = try reader.take(length - 4) }; } /// Writes a packet in pkt-line format. - fn write(packet: Packet, writer: anytype) !void { + fn write(packet: Packet, writer: *std.Io.Writer) !void { switch (packet) { .flush => try writer.writeAll("0000"), .delimiter => try writer.writeAll("0001"), @@ -657,8 +655,10 @@ pub const Session = struct { allocator: Allocator, transport: *std.http.Client, uri: std.Uri, - http_headers_buffer: []u8, + /// Asserted to be at least `Packet.max_data_length` + response_buffer: []u8, ) !Session { + assert(response_buffer.len >= Packet.max_data_length); var session: Session = .{ .transport = transport, .location = try .init(allocator, uri), @@ -668,7 +668,8 @@ pub const Session = struct { .allocator = allocator, }; errdefer session.deinit(); - var capability_iterator = try session.getCapabilities(http_headers_buffer); + var capability_iterator: CapabilityIterator = undefined; + try session.getCapabilities(&capability_iterator, response_buffer); defer capability_iterator.deinit(); while (try capability_iterator.next()) |capability| { if (mem.eql(u8, capability.key, "agent")) { @@ -743,7 +744,8 @@ pub const Session = struct { /// /// The `session.location` is updated if the server returns a redirect, so /// that subsequent session functions do not need to handle redirects. 
- fn getCapabilities(session: *Session, http_headers_buffer: []u8) !CapabilityIterator { + fn getCapabilities(session: *Session, it: *CapabilityIterator, response_buffer: []u8) !void { + assert(response_buffer.len >= Packet.max_data_length); var info_refs_uri = session.location.uri; { const session_uri_path = try std.fmt.allocPrint(session.allocator, "{f}", .{ @@ -757,19 +759,22 @@ pub const Session = struct { info_refs_uri.fragment = null; const max_redirects = 3; - var request = try session.transport.open(.GET, info_refs_uri, .{ - .redirect_behavior = @enumFromInt(max_redirects), - .server_header_buffer = http_headers_buffer, - .extra_headers = &.{ - .{ .name = "Git-Protocol", .value = "version=2" }, - }, - }); - errdefer request.deinit(); - try request.send(); - try request.finish(); + it.* = .{ + .request = try session.transport.request(.GET, info_refs_uri, .{ + .redirect_behavior = .init(max_redirects), + .extra_headers = &.{ + .{ .name = "Git-Protocol", .value = "version=2" }, + }, + }), + .reader = undefined, + }; + errdefer it.deinit(); + const request = &it.request; + try request.sendBodiless(); - try request.wait(); - if (request.response.status != .ok) return error.ProtocolError; + var redirect_buffer: [1024]u8 = undefined; + const response = try request.receiveHead(&redirect_buffer); + if (response.head.status != .ok) return error.ProtocolError; const any_redirects_occurred = request.redirect_behavior.remaining() < max_redirects; if (any_redirects_occurred) { const request_uri_path = try std.fmt.allocPrint(session.allocator, "{f}", .{ @@ -784,8 +789,7 @@ pub const Session = struct { session.location = new_location; } - const reader = request.reader(); - var buf: [Packet.max_data_length]u8 = undefined; + it.reader = response.reader(response_buffer); var state: enum { response_start, response_content } = .response_start; while (true) { // Some Git servers (at least GitHub) include an additional @@ -795,15 +799,15 @@ pub const Session = struct { // Thus, we need to skip any such useless additional responses // before we get the one we're actually looking for. The responses // will be delimited by flush packets. 
- const packet = Packet.read(reader, &buf) catch |e| switch (e) { + const packet = Packet.read(it.reader) catch |err| switch (err) { error.EndOfStream => return error.UnsupportedProtocol, // 'version 2' packet not found - else => |other| return other, + else => |e| return e, }; switch (packet) { .flush => state = .response_start, .data => |data| switch (state) { .response_start => if (mem.eql(u8, Packet.normalizeText(data), "version 2")) { - return .{ .request = request }; + return; } else { state = .response_content; }, @@ -816,7 +820,7 @@ pub const Session = struct { const CapabilityIterator = struct { request: std.http.Client.Request, - buf: [Packet.max_data_length]u8 = undefined, + reader: *std.Io.Reader, const Capability = struct { key: []const u8, @@ -830,13 +834,13 @@ pub const Session = struct { } }; - fn deinit(iterator: *CapabilityIterator) void { - iterator.request.deinit(); - iterator.* = undefined; + fn deinit(it: *CapabilityIterator) void { + it.request.deinit(); + it.* = undefined; } - fn next(iterator: *CapabilityIterator) !?Capability { - switch (try Packet.read(iterator.request.reader(), &iterator.buf)) { + fn next(it: *CapabilityIterator) !?Capability { + switch (try Packet.read(it.reader)) { .flush => return null, .data => |data| return Capability.parse(Packet.normalizeText(data)), else => return error.UnexpectedPacket, @@ -854,11 +858,13 @@ pub const Session = struct { include_symrefs: bool = false, /// Whether to include the peeled object ID for returned tag refs. include_peeled: bool = false, - server_header_buffer: []u8, + /// Asserted to be at least `Packet.max_data_length`. + buffer: []u8, }; /// Returns an iterator over refs known to the server. - pub fn listRefs(session: Session, options: ListRefsOptions) !RefIterator { + pub fn listRefs(session: Session, it: *RefIterator, options: ListRefsOptions) !void { + assert(options.buffer.len >= Packet.max_data_length); var upload_pack_uri = session.location.uri; { const session_uri_path = try std.fmt.allocPrint(session.allocator, "{f}", .{ @@ -871,59 +877,56 @@ pub const Session = struct { upload_pack_uri.query = null; upload_pack_uri.fragment = null; - var body: std.ArrayListUnmanaged(u8) = .empty; - defer body.deinit(session.allocator); - const body_writer = body.writer(session.allocator); - try Packet.write(.{ .data = "command=ls-refs\n" }, body_writer); + var body: std.Io.Writer = .fixed(options.buffer); + try Packet.write(.{ .data = "command=ls-refs\n" }, &body); if (session.supports_agent) { - try Packet.write(.{ .data = agent_capability }, body_writer); + try Packet.write(.{ .data = agent_capability }, &body); } { - const object_format_packet = try std.fmt.allocPrint(session.allocator, "object-format={s}\n", .{@tagName(session.object_format)}); + const object_format_packet = try std.fmt.allocPrint(session.allocator, "object-format={t}\n", .{ + session.object_format, + }); defer session.allocator.free(object_format_packet); - try Packet.write(.{ .data = object_format_packet }, body_writer); + try Packet.write(.{ .data = object_format_packet }, &body); } - try Packet.write(.delimiter, body_writer); + try Packet.write(.delimiter, &body); for (options.ref_prefixes) |ref_prefix| { const ref_prefix_packet = try std.fmt.allocPrint(session.allocator, "ref-prefix {s}\n", .{ref_prefix}); defer session.allocator.free(ref_prefix_packet); - try Packet.write(.{ .data = ref_prefix_packet }, body_writer); + try Packet.write(.{ .data = ref_prefix_packet }, &body); } if (options.include_symrefs) { - try Packet.write(.{ .data = 
"symrefs\n" }, body_writer); + try Packet.write(.{ .data = "symrefs\n" }, &body); } if (options.include_peeled) { - try Packet.write(.{ .data = "peel\n" }, body_writer); + try Packet.write(.{ .data = "peel\n" }, &body); } - try Packet.write(.flush, body_writer); + try Packet.write(.flush, &body); - var request = try session.transport.open(.POST, upload_pack_uri, .{ - .redirect_behavior = .unhandled, - .server_header_buffer = options.server_header_buffer, - .extra_headers = &.{ - .{ .name = "Content-Type", .value = "application/x-git-upload-pack-request" }, - .{ .name = "Git-Protocol", .value = "version=2" }, - }, - }); - errdefer request.deinit(); - request.transfer_encoding = .{ .content_length = body.items.len }; - try request.send(); - try request.writeAll(body.items); - try request.finish(); - - try request.wait(); - if (request.response.status != .ok) return error.ProtocolError; - - return .{ + it.* = .{ + .request = try session.transport.request(.POST, upload_pack_uri, .{ + .redirect_behavior = .unhandled, + .extra_headers = &.{ + .{ .name = "Content-Type", .value = "application/x-git-upload-pack-request" }, + .{ .name = "Git-Protocol", .value = "version=2" }, + }, + }), + .reader = undefined, .format = session.object_format, - .request = request, }; + const request = &it.request; + errdefer request.deinit(); + try request.sendBodyComplete(body.buffered()); + + const response = try request.receiveHead(options.buffer); + if (response.head.status != .ok) return error.ProtocolError; + it.reader = response.reader(options.buffer); } pub const RefIterator = struct { format: Oid.Format, request: std.http.Client.Request, - buf: [Packet.max_data_length]u8 = undefined, + reader: *std.Io.Reader, pub const Ref = struct { oid: Oid, @@ -937,13 +940,13 @@ pub const Session = struct { iterator.* = undefined; } - pub fn next(iterator: *RefIterator) !?Ref { - switch (try Packet.read(iterator.request.reader(), &iterator.buf)) { + pub fn next(it: *RefIterator) !?Ref { + switch (try Packet.read(it.reader)) { .flush => return null, .data => |data| { const ref_data = Packet.normalizeText(data); const oid_sep_pos = mem.indexOfScalar(u8, ref_data, ' ') orelse return error.InvalidRefPacket; - const oid = Oid.parse(iterator.format, data[0..oid_sep_pos]) catch return error.InvalidRefPacket; + const oid = Oid.parse(it.format, data[0..oid_sep_pos]) catch return error.InvalidRefPacket; const name_sep_pos = mem.indexOfScalarPos(u8, ref_data, oid_sep_pos + 1, ' ') orelse ref_data.len; const name = ref_data[oid_sep_pos + 1 .. name_sep_pos]; @@ -957,7 +960,7 @@ pub const Session = struct { if (mem.startsWith(u8, attribute, "symref-target:")) { symref_target = attribute["symref-target:".len..]; } else if (mem.startsWith(u8, attribute, "peeled:")) { - peeled = Oid.parse(iterator.format, attribute["peeled:".len..]) catch return error.InvalidRefPacket; + peeled = Oid.parse(it.format, attribute["peeled:".len..]) catch return error.InvalidRefPacket; } last_sep_pos = next_sep_pos; } @@ -973,9 +976,12 @@ pub const Session = struct { /// performed if the server supports it. pub fn fetch( session: Session, + fs: *FetchStream, wants: []const []const u8, - http_headers_buffer: []u8, - ) !FetchStream { + /// Asserted to be at least `Packet.max_data_length`. 
+ response_buffer: []u8, + ) !void { + assert(response_buffer.len >= Packet.max_data_length); var upload_pack_uri = session.location.uri; { const session_uri_path = try std.fmt.allocPrint(session.allocator, "{f}", .{ @@ -988,63 +994,71 @@ pub const Session = struct { upload_pack_uri.query = null; upload_pack_uri.fragment = null; - var body: std.ArrayListUnmanaged(u8) = .empty; - defer body.deinit(session.allocator); - const body_writer = body.writer(session.allocator); - try Packet.write(.{ .data = "command=fetch\n" }, body_writer); + var body: std.Io.Writer = .fixed(response_buffer); + try Packet.write(.{ .data = "command=fetch\n" }, &body); if (session.supports_agent) { - try Packet.write(.{ .data = agent_capability }, body_writer); + try Packet.write(.{ .data = agent_capability }, &body); } { const object_format_packet = try std.fmt.allocPrint(session.allocator, "object-format={s}\n", .{@tagName(session.object_format)}); defer session.allocator.free(object_format_packet); - try Packet.write(.{ .data = object_format_packet }, body_writer); + try Packet.write(.{ .data = object_format_packet }, &body); } - try Packet.write(.delimiter, body_writer); + try Packet.write(.delimiter, &body); // Our packfile parser supports the OFS_DELTA object type - try Packet.write(.{ .data = "ofs-delta\n" }, body_writer); + try Packet.write(.{ .data = "ofs-delta\n" }, &body); // We do not currently convey server progress information to the user - try Packet.write(.{ .data = "no-progress\n" }, body_writer); + try Packet.write(.{ .data = "no-progress\n" }, &body); if (session.supports_shallow) { - try Packet.write(.{ .data = "deepen 1\n" }, body_writer); + try Packet.write(.{ .data = "deepen 1\n" }, &body); } for (wants) |want| { var buf: [Packet.max_data_length]u8 = undefined; const arg = std.fmt.bufPrint(&buf, "want {s}\n", .{want}) catch unreachable; - try Packet.write(.{ .data = arg }, body_writer); + try Packet.write(.{ .data = arg }, &body); } - try Packet.write(.{ .data = "done\n" }, body_writer); - try Packet.write(.flush, body_writer); + try Packet.write(.{ .data = "done\n" }, &body); + try Packet.write(.flush, &body); - var request = try session.transport.open(.POST, upload_pack_uri, .{ - .redirect_behavior = .not_allowed, - .server_header_buffer = http_headers_buffer, - .extra_headers = &.{ - .{ .name = "Content-Type", .value = "application/x-git-upload-pack-request" }, - .{ .name = "Git-Protocol", .value = "version=2" }, - }, - }); + fs.* = .{ + .request = try session.transport.request(.POST, upload_pack_uri, .{ + .redirect_behavior = .not_allowed, + .extra_headers = &.{ + .{ .name = "Content-Type", .value = "application/x-git-upload-pack-request" }, + .{ .name = "Git-Protocol", .value = "version=2" }, + }, + }), + .input = undefined, + .reader = undefined, + .remaining_len = undefined, + }; + const request = &fs.request; errdefer request.deinit(); - request.transfer_encoding = .{ .content_length = body.items.len }; - try request.send(); - try request.writeAll(body.items); - try request.finish(); - try request.wait(); - if (request.response.status != .ok) return error.ProtocolError; + try request.sendBodyComplete(body.buffered()); - const reader = request.reader(); + const response = try request.receiveHead(&.{}); + if (response.head.status != .ok) return error.ProtocolError; + + const reader = response.reader(response_buffer); // We are not interested in any of the sections of the returned fetch // data other than the packfile section, since we aren't doing anything // complex like ref negotiation 
(this is a fresh clone). var state: enum { section_start, section_content } = .section_start; while (true) { - var buf: [Packet.max_data_length]u8 = undefined; - const packet = try Packet.read(reader, &buf); + const packet = try Packet.read(reader); switch (state) { .section_start => switch (packet) { .data => |data| if (mem.eql(u8, Packet.normalizeText(data), "packfile")) { - return .{ .request = request }; + fs.input = reader; + fs.reader = .{ + .buffer = &.{}, + .vtable = &.{ .stream = FetchStream.stream }, + .seek = 0, + .end = 0, + }; + fs.remaining_len = 0; + return; } else { state = .section_content; }, @@ -1061,20 +1075,23 @@ pub const Session = struct { pub const FetchStream = struct { request: std.http.Client.Request, - buf: [Packet.max_data_length]u8 = undefined, - pos: usize = 0, - len: usize = 0, + input: *std.Io.Reader, + reader: std.Io.Reader, + err: ?Error = null, + remaining_len: usize, - pub fn deinit(stream: *FetchStream) void { - stream.request.deinit(); + pub fn deinit(fs: *FetchStream) void { + fs.request.deinit(); } - pub const ReadError = std.http.Client.Request.ReadError || error{ + pub const Error = error{ InvalidPacket, ProtocolError, UnexpectedPacket, + WriteFailed, + ReadFailed, + EndOfStream, }; - pub const Reader = std.io.GenericReader(*FetchStream, ReadError, read); const StreamCode = enum(u8) { pack_data = 1, @@ -1083,33 +1100,41 @@ pub const Session = struct { _, }; - pub fn reader(stream: *FetchStream) Reader { - return .{ .context = stream }; - } - - pub fn read(stream: *FetchStream, buf: []u8) !usize { - if (stream.pos == stream.len) { + pub fn stream(r: *std.Io.Reader, w: *std.Io.Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize { + const fs: *FetchStream = @alignCast(@fieldParentPtr("reader", r)); + const input = fs.input; + if (fs.remaining_len == 0) { while (true) { - switch (try Packet.read(stream.request.reader(), &stream.buf)) { - .flush => return 0, + switch (Packet.read(input) catch |err| { + fs.err = err; + return error.ReadFailed; + }) { + .flush => return error.EndOfStream, .data => |data| if (data.len > 1) switch (@as(StreamCode, @enumFromInt(data[0]))) { .pack_data => { - stream.pos = 1; - stream.len = data.len; + input.toss(1); + fs.remaining_len = data.len; break; }, - .fatal_error => return error.ProtocolError, + .fatal_error => { + fs.err = error.ProtocolError; + return error.ReadFailed; + }, else => {}, }, - else => return error.UnexpectedPacket, + else => { + fs.err = error.UnexpectedPacket; + return error.ReadFailed; + }, } } } - - const size = @min(buf.len, stream.len - stream.pos); - @memcpy(buf[0..size], stream.buf[stream.pos .. stream.pos + size]); - stream.pos += size; - return size; + const buf = limit.slice(try w.writableSliceGreedy(1)); + const n = @min(buf.len, fs.remaining_len); + @memcpy(buf[0..n], input.buffered()[0..n]); + input.toss(n); + fs.remaining_len -= n; + return n; } }; }; From 18c4a500b6fd93e3e48ccaf413f2e9f90730f069 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 4 Aug 2025 20:27:55 -0700 Subject: [PATCH 57/70] std.Io.Reader: fix appendRemainingUnlimited Now it avoids mutating `r` unnecessarily, allowing the `ending` Reader to work. 
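A minimal sketch of the distinction, assuming `ending` refers to the permanently-empty
`std.Io.Reader` that always reports end-of-stream (the helper name below is illustrative,
not part of the diff that follows):

    const std = @import("std");

    // Relinquish any buffered bytes to the caller, roughly as
    // appendRemainingUnlimited does before draining the rest.
    fn takeBuffered(r: *std.Io.Reader) []u8 {
        const contents = r.buffer[r.seek..r.end];
        // Previously the reset was unconditional. For a reader like
        // `ending`, whose buffer is empty and whose state presumably must
        // never be stored to, that was a needless (and harmful) mutation.
        if (r.end != 0) {
            r.seek = 0;
            r.end = 0;
        }
        return contents;
    }

    test takeBuffered {
        var r: std.Io.Reader = .fixed("abc"); // fully buffered: seek == 0, end == 3
        try std.testing.expectEqualStrings("abc", takeBuffered(&r));
        try std.testing.expectEqual(@as(usize, 0), r.end);
    }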
--- lib/std/Io/Reader.zig | 7 +++++-- lib/std/http/test.zig | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/lib/std/Io/Reader.zig b/lib/std/Io/Reader.zig index 88c2a20f97..ee8f875716 100644 --- a/lib/std/Io/Reader.zig +++ b/lib/std/Io/Reader.zig @@ -367,8 +367,11 @@ pub fn appendRemainingUnlimited( const buffer_contents = r.buffer[r.seek..r.end]; try list.ensureUnusedCapacity(gpa, buffer_contents.len + bump); list.appendSliceAssumeCapacity(buffer_contents); - r.seek = 0; - r.end = 0; + // If statement protects `ending`. + if (r.end != 0) { + r.seek = 0; + r.end = 0; + } // From here, we leave `buffer` empty, appending directly to `list`. var writer: Writer = .{ .buffer = undefined, diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index df80ca6339..2b8f64c606 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -414,7 +414,8 @@ test "general client/server API coverage" { log.info("{f} {t} {s}", .{ request.head.method, request.head.version, request.head.target }); const gpa = std.testing.allocator; - const body = try (try request.readerExpectContinue(&.{})).allocRemaining(gpa, .unlimited); + const reader = (try request.readerExpectContinue(&.{})); + const body = try reader.allocRemaining(gpa, .unlimited); defer gpa.free(body); if (mem.startsWith(u8, request.head.target, "/get")) { From fe0ff7f718182a023a39c6126f40e5ffe3364f7b Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 4 Aug 2025 20:35:07 -0700 Subject: [PATCH 58/70] fix 32-bit builds --- lib/std/crypto/tls/Client.zig | 4 ++-- lib/std/http.zig | 38 +++++++++++++++++------------------ lib/std/http/Client.zig | 24 +++++++++++----------- 3 files changed, 33 insertions(+), 33 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 5a9e333b67..3fe51e7b3b 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -910,7 +910,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client } fn drain(w: *Writer, data: []const []const u8, splat: usize) Writer.Error!usize { - const c: *Client = @fieldParentPtr("writer", w); + const c: *Client = @alignCast(@fieldParentPtr("writer", w)); if (true) @panic("update to use the buffer and flush"); const sliced_data = if (splat == 0) data[0..data.len -| 1] else data; const output = c.output; @@ -1046,7 +1046,7 @@ pub fn eof(c: Client) bool { } fn stream(r: *Reader, w: *Writer, limit: std.io.Limit) Reader.StreamError!usize { - const c: *Client = @fieldParentPtr("reader", r); + const c: *Client = @alignCast(@fieldParentPtr("reader", r)); if (c.eof()) return error.EndOfStream; const input = c.input; // If at least one full encrypted record is not buffered, read once. 
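The hunks above and below are two instances of the same 32-bit issue. The `@alignCast`
additions are needed because the parent struct can be more strictly aligned than the
embedded `std.Io.Reader`/`std.Io.Writer` field, so recovering the parent pointer from
the field pointer is an alignment-increasing cast. A reduced sketch of the idiom
(this `Client` layout is illustrative, not the real one):

    const std = @import("std");

    const Client = struct {
        // On some 32-bit targets @alignOf(u64) == 8 while pointers are
        // 4-byte aligned, so *Client is more aligned than *std.Io.Writer.
        count: u64,
        writer: std.Io.Writer,

        fn fromWriter(w: *std.Io.Writer) *Client {
            // Without @alignCast this is a compile error whenever
            // @alignOf(Client) > @alignOf(std.Io.Writer).
            return @alignCast(@fieldParentPtr("writer", w));
        }
    };

    test {
        var buf: [8]u8 = undefined;
        var c: Client = .{ .count = 0, .writer = .fixed(&buf) };
        try std.testing.expectEqual(&c, Client.fromWriter(&c.writer));
    }

The `limited` to `limited64` switches in the http.zig hunks below are the second
instance: `Limit.limited` takes a `usize`, which no longer accepts the `u64` body
lengths once `usize` is 32 bits wide.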
diff --git a/lib/std/http.zig b/lib/std/http.zig index 24bc016c30..f8cbb84dc5 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -519,33 +519,33 @@ pub const Reader = struct { w: *Writer, limit: std.Io.Limit, ) std.Io.Reader.StreamError!usize { - const reader: *Reader = @fieldParentPtr("interface", io_r); + const reader: *Reader = @alignCast(@fieldParentPtr("interface", io_r)); const remaining_content_length = &reader.state.body_remaining_content_length; const remaining = remaining_content_length.*; if (remaining == 0) { reader.state = .ready; return error.EndOfStream; } - const n = try reader.in.stream(w, limit.min(.limited(remaining))); + const n = try reader.in.stream(w, limit.min(.limited64(remaining))); remaining_content_length.* = remaining - n; return n; } fn contentLengthDiscard(io_r: *std.Io.Reader, limit: std.Io.Limit) std.Io.Reader.Error!usize { - const reader: *Reader = @fieldParentPtr("interface", io_r); + const reader: *Reader = @alignCast(@fieldParentPtr("interface", io_r)); const remaining_content_length = &reader.state.body_remaining_content_length; const remaining = remaining_content_length.*; if (remaining == 0) { reader.state = .ready; return error.EndOfStream; } - const n = try reader.in.discard(limit.min(.limited(remaining))); + const n = try reader.in.discard(limit.min(.limited64(remaining))); remaining_content_length.* = remaining - n; return n; } fn chunkedStream(io_r: *std.Io.Reader, w: *Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize { - const reader: *Reader = @fieldParentPtr("interface", io_r); + const reader: *Reader = @alignCast(@fieldParentPtr("interface", io_r)); const chunk_len_ptr = switch (reader.state) { .ready => return error.EndOfStream, .body_remaining_chunk_len => |*x| x, @@ -591,7 +591,7 @@ pub const Reader = struct { } } if (cp.chunk_len == 0) return parseTrailers(reader, 0); - const n = try in.stream(w, limit.min(.limited(cp.chunk_len))); + const n = try in.stream(w, limit.min(.limited64(cp.chunk_len))); chunk_len_ptr.* = .init(cp.chunk_len + 2 - n); return n; }, @@ -607,7 +607,7 @@ pub const Reader = struct { continue :len .head; }, else => |remaining_chunk_len| { - const n = try in.stream(w, limit.min(.limited(@intFromEnum(remaining_chunk_len) - 2))); + const n = try in.stream(w, limit.min(.limited64(@intFromEnum(remaining_chunk_len) - 2))); chunk_len_ptr.* = .init(@intFromEnum(remaining_chunk_len) - n); return n; }, @@ -615,7 +615,7 @@ pub const Reader = struct { } fn chunkedDiscard(io_r: *std.Io.Reader, limit: std.Io.Limit) std.Io.Reader.Error!usize { - const reader: *Reader = @fieldParentPtr("interface", io_r); + const reader: *Reader = @alignCast(@fieldParentPtr("interface", io_r)); const chunk_len_ptr = switch (reader.state) { .ready => return error.EndOfStream, .body_remaining_chunk_len => |*x| x, @@ -659,7 +659,7 @@ pub const Reader = struct { } } if (cp.chunk_len == 0) return parseTrailers(reader, 0); - const n = try in.discard(limit.min(.limited(cp.chunk_len))); + const n = try in.discard(limit.min(.limited64(cp.chunk_len))); chunk_len_ptr.* = .init(cp.chunk_len + 2 - n); return n; }, @@ -675,7 +675,7 @@ pub const Reader = struct { continue :len .head; }, else => |remaining_chunk_len| { - const n = try in.discard(limit.min(.limited(remaining_chunk_len.int() - 2))); + const n = try in.discard(limit.min(.limited64(remaining_chunk_len.int() - 2))); chunk_len_ptr.* = .init(remaining_chunk_len.int() - n); return n; }, @@ -758,7 +758,7 @@ pub const BodyWriter = struct { /// How many zeroes to reserve for hex-encoded chunk 
length. const chunk_len_digits = 8; - const max_chunk_len: usize = std.math.pow(usize, 16, chunk_len_digits) - 1; + const max_chunk_len: usize = std.math.pow(u64, 16, chunk_len_digits) - 1; const chunk_header_template = ("0" ** chunk_len_digits) ++ "\r\n"; comptime { @@ -918,7 +918,7 @@ pub const BodyWriter = struct { } pub fn contentLengthDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { - const bw: *BodyWriter = @fieldParentPtr("writer", w); + const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w)); assert(!bw.isEliding()); const out = bw.http_protocol_output; const n = try out.writeSplatHeader(w.buffered(), data, splat); @@ -927,7 +927,7 @@ pub const BodyWriter = struct { } pub fn noneDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { - const bw: *BodyWriter = @fieldParentPtr("writer", w); + const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w)); assert(!bw.isEliding()); const out = bw.http_protocol_output; const n = try out.writeSplatHeader(w.buffered(), data, splat); @@ -935,7 +935,7 @@ pub const BodyWriter = struct { } pub fn elidingDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { - const bw: *BodyWriter = @fieldParentPtr("writer", w); + const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w)); const slice = data[0 .. data.len - 1]; const pattern = data[slice.len]; var written: usize = pattern.len * splat; @@ -949,7 +949,7 @@ pub const BodyWriter = struct { } pub fn elidingSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize { - const bw: *BodyWriter = @fieldParentPtr("writer", w); + const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w)); if (File.Handle == void) return error.Unimplemented; if (builtin.zig_backend == .stage2_aarch64) return error.Unimplemented; switch (bw.state) { @@ -976,7 +976,7 @@ pub const BodyWriter = struct { /// Returns `null` if size cannot be computed without making any syscalls. 
pub fn noneSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize { - const bw: *BodyWriter = @fieldParentPtr("writer", w); + const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w)); assert(!bw.isEliding()); const out = bw.http_protocol_output; const n = try out.sendFileHeader(w.buffered(), file_reader, limit); @@ -984,7 +984,7 @@ pub const BodyWriter = struct { } pub fn contentLengthSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize { - const bw: *BodyWriter = @fieldParentPtr("writer", w); + const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w)); assert(!bw.isEliding()); const out = bw.http_protocol_output; const n = try out.sendFileHeader(w.buffered(), file_reader, limit); @@ -993,7 +993,7 @@ pub const BodyWriter = struct { } pub fn chunkedSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize { - const bw: *BodyWriter = @fieldParentPtr("writer", w); + const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w)); assert(!bw.isEliding()); const data_len = Writer.countSendFileLowerBound(w.end, file_reader, limit) orelse { // If the file size is unknown, we cannot lower to a `sendFile` since we would @@ -1041,7 +1041,7 @@ pub const BodyWriter = struct { } pub fn chunkedDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { - const bw: *BodyWriter = @fieldParentPtr("writer", w); + const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w)); assert(!bw.isEliding()); const out = bw.http_protocol_output; const data_len = w.end + Writer.countSplat(data, splat); diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 5b817af467..e6e1d4434f 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -82,7 +82,7 @@ pub const ConnectionPool = struct { var next = pool.free.last; while (next) |node| : (next = node.prev) { - const connection: *Connection = @fieldParentPtr("pool_node", node); + const connection: *Connection = @alignCast(@fieldParentPtr("pool_node", node)); if (connection.protocol != criteria.protocol) continue; if (connection.port != criteria.port) continue; @@ -127,7 +127,7 @@ pub const ConnectionPool = struct { if (connection.closing or pool.free_size == 0) return connection.destroy(); if (pool.free_len >= pool.free_size) { - const popped: *Connection = @fieldParentPtr("pool_node", pool.free.popFirst().?); + const popped: *Connection = @alignCast(@fieldParentPtr("pool_node", pool.free.popFirst().?)); pool.free_len -= 1; popped.destroy(); @@ -183,14 +183,14 @@ pub const ConnectionPool = struct { var next = pool.free.first; while (next) |node| { - const connection: *Connection = @fieldParentPtr("pool_node", node); + const connection: *Connection = @alignCast(@fieldParentPtr("pool_node", node)); next = node.next; connection.destroy(); } next = pool.used.first; while (next) |node| { - const connection: *Connection = @fieldParentPtr("pool_node", node); + const connection: *Connection = @alignCast(@fieldParentPtr("pool_node", node)); next = node.next; connection.destroy(); } @@ -366,11 +366,11 @@ pub const Connection = struct { return switch (c.protocol) { .tls => { if (disable_tls) unreachable; - const tls: *Tls = @fieldParentPtr("connection", c); + const tls: *Tls = @alignCast(@fieldParentPtr("connection", c)); return tls.host(); }, .plain => { - const plain: *Plain = @fieldParentPtr("connection", c); + const plain: *Plain = @alignCast(@fieldParentPtr("connection", c)); return plain.host(); }, }; @@ -383,11 
+383,11 @@ pub const Connection = struct { switch (c.protocol) { .tls => { if (disable_tls) unreachable; - const tls: *Tls = @fieldParentPtr("connection", c); + const tls: *Tls = @alignCast(@fieldParentPtr("connection", c)); tls.destroy(); }, .plain => { - const plain: *Plain = @fieldParentPtr("connection", c); + const plain: *Plain = @alignCast(@fieldParentPtr("connection", c)); plain.destroy(); }, } @@ -399,7 +399,7 @@ pub const Connection = struct { return switch (c.protocol) { .tls => { if (disable_tls) unreachable; - const tls: *Tls = @fieldParentPtr("connection", c); + const tls: *Tls = @alignCast(@fieldParentPtr("connection", c)); return &tls.client.writer; }, .plain => &c.stream_writer.interface, @@ -412,7 +412,7 @@ pub const Connection = struct { return switch (c.protocol) { .tls => { if (disable_tls) unreachable; - const tls: *Tls = @fieldParentPtr("connection", c); + const tls: *Tls = @alignCast(@fieldParentPtr("connection", c)); return &tls.client.reader; }, .plain => c.stream_reader.interface(), @@ -422,7 +422,7 @@ pub const Connection = struct { pub fn flush(c: *Connection) Writer.Error!void { if (c.protocol == .tls) { if (disable_tls) unreachable; - const tls: *Tls = @fieldParentPtr("connection", c); + const tls: *Tls = @alignCast(@fieldParentPtr("connection", c)); try tls.client.writer.flush(); } try c.stream_writer.interface.flush(); @@ -434,7 +434,7 @@ pub const Connection = struct { pub fn end(c: *Connection) Writer.Error!void { if (c.protocol == .tls) { if (disable_tls) unreachable; - const tls: *Tls = @fieldParentPtr("connection", c); + const tls: *Tls = @alignCast(@fieldParentPtr("connection", c)); try tls.client.end(); try tls.client.writer.flush(); } From 75ddb62a2e58c61b0ae046745a619132f63c3bcf Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 4 Aug 2025 20:36:34 -0700 Subject: [PATCH 59/70] std.net: fix windows build --- lib/std/net.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/std/net.zig b/lib/std/net.zig index 38c0785194..f63aa6dd0a 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -1944,7 +1944,7 @@ pub const Stream = struct { pub const Error = ReadError; pub fn getStream(r: *const Reader) Stream { - return r.stream; + return r.net_stream; } pub fn getError(r: *const Reader) ?Error { From 8b571723fa1487099cd0e676062dfd703ef6847f Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 31 Jul 2025 22:38:38 -0700 Subject: [PATCH 60/70] remove std.fifo I never liked how this data structure took its API as a parameter. This use case is now served by std.Io buffering. --- lib/std/fifo.zig | 548 ----------------------------------------------- lib/std/std.zig | 1 - 2 files changed, 549 deletions(-) delete mode 100644 lib/std/fifo.zig diff --git a/lib/std/fifo.zig b/lib/std/fifo.zig deleted file mode 100644 index e18b5edb01..0000000000 --- a/lib/std/fifo.zig +++ /dev/null @@ -1,548 +0,0 @@ -// FIFO of fixed size items -// Usually used for e.g. byte buffers - -const std = @import("std"); -const math = std.math; -const mem = std.mem; -const Allocator = mem.Allocator; -const assert = std.debug.assert; -const testing = std.testing; - -pub const LinearFifoBufferType = union(enum) { - /// The buffer is internal to the fifo; it is of the specified size. - Static: usize, - - /// The buffer is passed as a slice to the initialiser. - Slice, - - /// The buffer is managed dynamically using a `mem.Allocator`. 
- Dynamic, -}; - -pub fn LinearFifo( - comptime T: type, - comptime buffer_type: LinearFifoBufferType, -) type { - const autoalign = false; - - const powers_of_two = switch (buffer_type) { - .Static => std.math.isPowerOfTwo(buffer_type.Static), - .Slice => false, // Any size slice could be passed in - .Dynamic => true, // This could be configurable in future - }; - - return struct { - allocator: if (buffer_type == .Dynamic) Allocator else void, - buf: if (buffer_type == .Static) [buffer_type.Static]T else []T, - head: usize, - count: usize, - - const Self = @This(); - pub const Reader = std.io.GenericReader(*Self, error{}, readFn); - pub const Writer = std.io.GenericWriter(*Self, error{OutOfMemory}, appendWrite); - - // Type of Self argument for slice operations. - // If buffer is inline (Static) then we need to ensure we haven't - // returned a slice into a copy on the stack - const SliceSelfArg = if (buffer_type == .Static) *Self else Self; - - pub const init = switch (buffer_type) { - .Static => initStatic, - .Slice => initSlice, - .Dynamic => initDynamic, - }; - - fn initStatic() Self { - comptime assert(buffer_type == .Static); - return .{ - .allocator = {}, - .buf = undefined, - .head = 0, - .count = 0, - }; - } - - fn initSlice(buf: []T) Self { - comptime assert(buffer_type == .Slice); - return .{ - .allocator = {}, - .buf = buf, - .head = 0, - .count = 0, - }; - } - - fn initDynamic(allocator: Allocator) Self { - comptime assert(buffer_type == .Dynamic); - return .{ - .allocator = allocator, - .buf = &.{}, - .head = 0, - .count = 0, - }; - } - - pub fn deinit(self: Self) void { - if (buffer_type == .Dynamic) self.allocator.free(self.buf); - } - - pub fn realign(self: *Self) void { - if (self.buf.len - self.head >= self.count) { - mem.copyForwards(T, self.buf[0..self.count], self.buf[self.head..][0..self.count]); - self.head = 0; - } else { - var tmp: [4096 / 2 / @sizeOf(T)]T = undefined; - - while (self.head != 0) { - const n = @min(self.head, tmp.len); - const m = self.buf.len - n; - @memcpy(tmp[0..n], self.buf[0..n]); - mem.copyForwards(T, self.buf[0..m], self.buf[n..][0..m]); - @memcpy(self.buf[m..][0..n], tmp[0..n]); - self.head -= n; - } - } - { // set unused area to undefined - const unused = mem.sliceAsBytes(self.buf[self.count..]); - @memset(unused, undefined); - } - } - - /// Reduce allocated capacity to `size`. - pub fn shrink(self: *Self, size: usize) void { - assert(size >= self.count); - if (buffer_type == .Dynamic) { - self.realign(); - self.buf = self.allocator.realloc(self.buf, size) catch |e| switch (e) { - error.OutOfMemory => return, // no problem, capacity is still correct then. 
- }; - } - } - - /// Ensure that the buffer can fit at least `size` items - pub fn ensureTotalCapacity(self: *Self, size: usize) !void { - if (self.buf.len >= size) return; - if (buffer_type == .Dynamic) { - self.realign(); - const new_size = if (powers_of_two) math.ceilPowerOfTwo(usize, size) catch return error.OutOfMemory else size; - self.buf = try self.allocator.realloc(self.buf, new_size); - } else { - return error.OutOfMemory; - } - } - - /// Makes sure at least `size` items are unused - pub fn ensureUnusedCapacity(self: *Self, size: usize) error{OutOfMemory}!void { - if (self.writableLength() >= size) return; - - return try self.ensureTotalCapacity(math.add(usize, self.count, size) catch return error.OutOfMemory); - } - - /// Returns number of items currently in fifo - pub fn readableLength(self: Self) usize { - return self.count; - } - - /// Returns a writable slice from the 'read' end of the fifo - fn readableSliceMut(self: SliceSelfArg, offset: usize) []T { - if (offset > self.count) return &[_]T{}; - - var start = self.head + offset; - if (start >= self.buf.len) { - start -= self.buf.len; - return self.buf[start .. start + (self.count - offset)]; - } else { - const end = @min(self.head + self.count, self.buf.len); - return self.buf[start..end]; - } - } - - /// Returns a readable slice from `offset` - pub fn readableSlice(self: SliceSelfArg, offset: usize) []const T { - return self.readableSliceMut(offset); - } - - pub fn readableSliceOfLen(self: *Self, len: usize) []const T { - assert(len <= self.count); - const buf = self.readableSlice(0); - if (buf.len >= len) { - return buf[0..len]; - } else { - self.realign(); - return self.readableSlice(0)[0..len]; - } - } - - /// Discard first `count` items in the fifo - pub fn discard(self: *Self, count: usize) void { - assert(count <= self.count); - { // set old range to undefined. Note: may be wrapped around - const slice = self.readableSliceMut(0); - if (slice.len >= count) { - const unused = mem.sliceAsBytes(slice[0..count]); - @memset(unused, undefined); - } else { - const unused = mem.sliceAsBytes(slice[0..]); - @memset(unused, undefined); - const unused2 = mem.sliceAsBytes(self.readableSliceMut(slice.len)[0 .. count - slice.len]); - @memset(unused2, undefined); - } - } - if (autoalign and self.count == count) { - self.head = 0; - self.count = 0; - } else { - var head = self.head + count; - if (powers_of_two) { - // Note it is safe to do a wrapping subtract as - // bitwise & with all 1s is a noop - head &= self.buf.len -% 1; - } else { - head %= self.buf.len; - } - self.head = head; - self.count -= count; - } - } - - /// Read the next item from the fifo - pub fn readItem(self: *Self) ?T { - if (self.count == 0) return null; - - const c = self.buf[self.head]; - self.discard(1); - return c; - } - - /// Read data from the fifo into `dst`, returns number of items copied. - pub fn read(self: *Self, dst: []T) usize { - var dst_left = dst; - - while (dst_left.len > 0) { - const slice = self.readableSlice(0); - if (slice.len == 0) break; - const n = @min(slice.len, dst_left.len); - @memcpy(dst_left[0..n], slice[0..n]); - self.discard(n); - dst_left = dst_left[n..]; - } - - return dst.len - dst_left.len; - } - - /// Same as `read` except it returns an error union - /// The purpose of this function existing is to match `std.io.GenericReader` API. 
- fn readFn(self: *Self, dest: []u8) error{}!usize { - return self.read(dest); - } - - pub fn reader(self: *Self) Reader { - return .{ .context = self }; - } - - /// Returns number of items available in fifo - pub fn writableLength(self: Self) usize { - return self.buf.len - self.count; - } - - /// Returns the first section of writable buffer. - /// Note that this may be of length 0 - pub fn writableSlice(self: SliceSelfArg, offset: usize) []T { - if (offset > self.buf.len) return &[_]T{}; - - const tail = self.head + offset + self.count; - if (tail < self.buf.len) { - return self.buf[tail..]; - } else { - return self.buf[tail - self.buf.len ..][0 .. self.writableLength() - offset]; - } - } - - /// Returns a writable buffer of at least `size` items, allocating memory as needed. - /// Use `fifo.update` once you've written data to it. - pub fn writableWithSize(self: *Self, size: usize) ![]T { - try self.ensureUnusedCapacity(size); - - // try to avoid realigning buffer - var slice = self.writableSlice(0); - if (slice.len < size) { - self.realign(); - slice = self.writableSlice(0); - } - return slice; - } - - /// Update the tail location of the buffer (usually follows use of writable/writableWithSize) - pub fn update(self: *Self, count: usize) void { - assert(self.count + count <= self.buf.len); - self.count += count; - } - - /// Appends the data in `src` to the fifo. - /// You must have ensured there is enough space. - pub fn writeAssumeCapacity(self: *Self, src: []const T) void { - assert(self.writableLength() >= src.len); - - var src_left = src; - while (src_left.len > 0) { - const writable_slice = self.writableSlice(0); - assert(writable_slice.len != 0); - const n = @min(writable_slice.len, src_left.len); - @memcpy(writable_slice[0..n], src_left[0..n]); - self.update(n); - src_left = src_left[n..]; - } - } - - /// Write a single item to the fifo - pub fn writeItem(self: *Self, item: T) !void { - try self.ensureUnusedCapacity(1); - return self.writeItemAssumeCapacity(item); - } - - pub fn writeItemAssumeCapacity(self: *Self, item: T) void { - var tail = self.head + self.count; - if (powers_of_two) { - tail &= self.buf.len - 1; - } else { - tail %= self.buf.len; - } - self.buf[tail] = item; - self.update(1); - } - - /// Appends the data in `src` to the fifo. - /// Allocates more memory as necessary - pub fn write(self: *Self, src: []const T) !void { - try self.ensureUnusedCapacity(src.len); - - return self.writeAssumeCapacity(src); - } - - /// Same as `write` except it returns the number of bytes written, which is always the same - /// as `bytes.len`. The purpose of this function existing is to match `std.io.GenericWriter` API. 
- fn appendWrite(self: *Self, bytes: []const u8) error{OutOfMemory}!usize { - try self.write(bytes); - return bytes.len; - } - - pub fn writer(self: *Self) Writer { - return .{ .context = self }; - } - - /// Make `count` items available before the current read location - fn rewind(self: *Self, count: usize) void { - assert(self.writableLength() >= count); - - var head = self.head + (self.buf.len - count); - if (powers_of_two) { - head &= self.buf.len - 1; - } else { - head %= self.buf.len; - } - self.head = head; - self.count += count; - } - - /// Place data back into the read stream - pub fn unget(self: *Self, src: []const T) !void { - try self.ensureUnusedCapacity(src.len); - - self.rewind(src.len); - - const slice = self.readableSliceMut(0); - if (src.len < slice.len) { - @memcpy(slice[0..src.len], src); - } else { - @memcpy(slice, src[0..slice.len]); - const slice2 = self.readableSliceMut(slice.len); - @memcpy(slice2[0 .. src.len - slice.len], src[slice.len..]); - } - } - - /// Returns the item at `offset`. - /// Asserts offset is within bounds. - pub fn peekItem(self: Self, offset: usize) T { - assert(offset < self.count); - - var index = self.head + offset; - if (powers_of_two) { - index &= self.buf.len - 1; - } else { - index %= self.buf.len; - } - return self.buf[index]; - } - - /// Pump data from a reader into a writer. - /// Stops when reader returns 0 bytes (EOF). - /// Buffer size must be set before calling; a buffer length of 0 is invalid. - pub fn pump(self: *Self, src_reader: anytype, dest_writer: anytype) !void { - assert(self.buf.len > 0); - while (true) { - if (self.writableLength() > 0) { - const n = try src_reader.read(self.writableSlice(0)); - if (n == 0) break; // EOF - self.update(n); - } - self.discard(try dest_writer.write(self.readableSlice(0))); - } - // flush remaining data - while (self.readableLength() > 0) { - self.discard(try dest_writer.write(self.readableSlice(0))); - } - } - - pub fn toOwnedSlice(self: *Self) Allocator.Error![]T { - if (self.head != 0) self.realign(); - assert(self.head == 0); - assert(self.count <= self.buf.len); - const allocator = self.allocator; - if (allocator.resize(self.buf, self.count)) { - const result = self.buf[0..self.count]; - self.* = Self.init(allocator); - return result; - } - const new_memory = try allocator.dupe(T, self.buf[0..self.count]); - allocator.free(self.buf); - self.* = Self.init(allocator); - return new_memory; - } - }; -} - -test "LinearFifo(u8, .Dynamic) discard(0) from empty buffer should not error on overflow" { - var fifo = LinearFifo(u8, .Dynamic).init(testing.allocator); - defer fifo.deinit(); - - // If overflow is not explicitly allowed this will crash in debug / safe mode - fifo.discard(0); -} - -test "LinearFifo(u8, .Dynamic)" { - var fifo = LinearFifo(u8, .Dynamic).init(testing.allocator); - defer fifo.deinit(); - - try fifo.write("HELLO"); - try testing.expectEqual(@as(usize, 5), fifo.readableLength()); - try testing.expectEqualSlices(u8, "HELLO", fifo.readableSlice(0)); - - { - var i: usize = 0; - while (i < 5) : (i += 1) { - try fifo.write(&[_]u8{fifo.peekItem(i)}); - } - try testing.expectEqual(@as(usize, 10), fifo.readableLength()); - try testing.expectEqualSlices(u8, "HELLOHELLO", fifo.readableSlice(0)); - } - - { - try testing.expectEqual(@as(u8, 'H'), fifo.readItem().?); - try testing.expectEqual(@as(u8, 'E'), fifo.readItem().?); - try testing.expectEqual(@as(u8, 'L'), fifo.readItem().?); - try testing.expectEqual(@as(u8, 'L'), fifo.readItem().?); - try testing.expectEqual(@as(u8, 'O'), 
fifo.readItem().?); - } - try testing.expectEqual(@as(usize, 5), fifo.readableLength()); - - { // Writes that wrap around - try testing.expectEqual(@as(usize, 11), fifo.writableLength()); - try testing.expectEqual(@as(usize, 6), fifo.writableSlice(0).len); - fifo.writeAssumeCapacity("6 FifoType.init(), - .Slice => FifoType.init(buf[0..]), - .Dynamic => FifoType.init(testing.allocator), - }; - defer fifo.deinit(); - - try fifo.write(&[_]T{ 0, 1, 1, 0, 1 }); - try testing.expectEqual(@as(usize, 5), fifo.readableLength()); - - { - try testing.expectEqual(@as(T, 0), fifo.readItem().?); - try testing.expectEqual(@as(T, 1), fifo.readItem().?); - try testing.expectEqual(@as(T, 1), fifo.readItem().?); - try testing.expectEqual(@as(T, 0), fifo.readItem().?); - try testing.expectEqual(@as(T, 1), fifo.readItem().?); - try testing.expectEqual(@as(usize, 0), fifo.readableLength()); - } - - { - try fifo.writeItem(1); - try fifo.writeItem(1); - try fifo.writeItem(1); - try testing.expectEqual(@as(usize, 3), fifo.readableLength()); - } - - { - var readBuf: [3]T = undefined; - const n = fifo.read(&readBuf); - try testing.expectEqual(@as(usize, 3), n); // NOTE: It should be the number of items. - } - } - } -} diff --git a/lib/std/std.zig b/lib/std/std.zig index aaae4c2eba..5cca56262a 100644 --- a/lib/std/std.zig +++ b/lib/std/std.zig @@ -57,7 +57,6 @@ pub const debug = @import("debug.zig"); pub const dwarf = @import("dwarf.zig"); pub const elf = @import("elf.zig"); pub const enums = @import("enums.zig"); -pub const fifo = @import("fifo.zig"); pub const fmt = @import("fmt.zig"); pub const fs = @import("fs.zig"); pub const gpu = @import("gpu.zig"); From 499bf2a1c77f21fc8724c9999fa233f283f3dc77 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 1 Aug 2025 00:22:03 -0700 Subject: [PATCH 61/70] std.Io: delete BufferedReader --- lib/compiler/resinator/errors.zig | 8 +- lib/docs/wasm/markdown.zig | 13 +- lib/std/Io.zig | 7 -- lib/std/Io/buffered_reader.zig | 201 ------------------------------ lib/std/Io/test.zig | 6 +- src/Package/Fetch.zig | 10 +- 6 files changed, 13 insertions(+), 232 deletions(-) delete mode 100644 lib/std/Io/buffered_reader.zig diff --git a/lib/compiler/resinator/errors.zig b/lib/compiler/resinator/errors.zig index 14a001894e..431f14692a 100644 --- a/lib/compiler/resinator/errors.zig +++ b/lib/compiler/resinator/errors.zig @@ -1078,11 +1078,9 @@ const CorrespondingLines = struct { at_eof: bool = false, span: SourceMappings.CorrespondingSpan, file: std.fs.File, - buffered_reader: BufferedReaderType, + buffered_reader: *std.Io.Reader, code_page: SupportedCodePage, - const BufferedReaderType = std.io.BufferedReader(512, std.fs.File.DeprecatedReader); - pub fn init(cwd: std.fs.Dir, err_details: ErrorDetails, line_for_comparison: []const u8, corresponding_span: SourceMappings.CorrespondingSpan, corresponding_file: []const u8) !CorrespondingLines { // We don't do line comparison for this error, so don't print the note if the line // number is different @@ -1101,9 +1099,7 @@ const CorrespondingLines = struct { .buffered_reader = undefined, .code_page = err_details.code_page, }; - corresponding_lines.buffered_reader = BufferedReaderType{ - .unbuffered_reader = corresponding_lines.file.deprecatedReader(), - }; + corresponding_lines.buffered_reader = corresponding_lines.file.reader(); errdefer corresponding_lines.deinit(); var fbs = std.io.fixedBufferStream(&corresponding_lines.line_buf); diff --git a/lib/docs/wasm/markdown.zig b/lib/docs/wasm/markdown.zig index 3293b680c9..32e32b4104 100644 --- 
a/lib/docs/wasm/markdown.zig
+++ b/lib/docs/wasm/markdown.zig
@@ -145,13 +145,12 @@ fn mainImpl() !void {
     var parser = try Parser.init(gpa);
     defer parser.deinit();
 
-    var stdin_buf = std.io.bufferedReader(std.fs.File.stdin().deprecatedReader());
-    var line_buf = std.ArrayList(u8).init(gpa);
-    defer line_buf.deinit();
-    while (stdin_buf.reader().streamUntilDelimiter(line_buf.writer(), '\n', null)) {
-        if (line_buf.getLastOrNull() == '\r') _ = line_buf.pop();
-        try parser.feedLine(line_buf.items);
-        line_buf.clearRetainingCapacity();
+    var stdin_buffer: [1024]u8 = undefined;
+    var stdin_reader = std.fs.File.stdin().reader(&stdin_buffer);
+
+    while (stdin_reader.interface.takeDelimiterExclusive('\n')) |line| {
+        const trimmed = std.mem.trimRight(u8, line, "\r");
+        try parser.feedLine(trimmed);
     } else |err| switch (err) {
         error.EndOfStream => {},
         else => |e| return e,
diff --git a/lib/std/Io.zig b/lib/std/Io.zig
index 688120c08b..9018f4b714 100644
--- a/lib/std/Io.zig
+++ b/lib/std/Io.zig
@@ -428,12 +428,6 @@
 pub const BufferedWriter = @import("Io/buffered_writer.zig").BufferedWriter;
 /// Deprecated in favor of `Writer`.
 pub const bufferedWriter = @import("Io/buffered_writer.zig").bufferedWriter;
 /// Deprecated in favor of `Reader`.
-pub const BufferedReader = @import("Io/buffered_reader.zig").BufferedReader;
-/// Deprecated in favor of `Reader`.
-pub const bufferedReader = @import("Io/buffered_reader.zig").bufferedReader;
-/// Deprecated in favor of `Reader`.
-pub const bufferedReaderSize = @import("Io/buffered_reader.zig").bufferedReaderSize;
-/// Deprecated in favor of `Reader`.
 pub const FixedBufferStream = @import("Io/fixed_buffer_stream.zig").FixedBufferStream;
 /// Deprecated in favor of `Reader`.
 pub const fixedBufferStream = @import("Io/fixed_buffer_stream.zig").fixedBufferStream;
@@ -926,7 +920,6 @@ pub fn PollFiles(comptime StreamEnum: type) type {
 test {
     _ = Reader;
     _ = Writer;
-    _ = BufferedReader;
     _ = BufferedWriter;
     _ = CountingWriter;
     _ = CountingReader;
diff --git a/lib/std/Io/buffered_reader.zig b/lib/std/Io/buffered_reader.zig
deleted file mode 100644
index 548dd92f73..0000000000
--- a/lib/std/Io/buffered_reader.zig
+++ /dev/null
@@ -1,201 +0,0 @@
-const std = @import("../std.zig");
-const io = std.io;
-const mem = std.mem;
-const assert = std.debug.assert;
-const testing = std.testing;
-
-pub fn BufferedReader(comptime buffer_size: usize, comptime ReaderType: type) type {
-    return struct {
-        unbuffered_reader: ReaderType,
-        buf: [buffer_size]u8 = undefined,
-        start: usize = 0,
-        end: usize = 0,
-
-        pub const Error = ReaderType.Error;
-        pub const Reader = io.GenericReader(*Self, Error, read);
-
-        const Self = @This();
-
-        pub fn read(self: *Self, dest: []u8) Error!usize {
-            // First try reading from the already buffered data onto the destination.
-            const current = self.buf[self.start..self.end];
-            if (current.len != 0) {
-                const to_transfer = @min(current.len, dest.len);
-                @memcpy(dest[0..to_transfer], current[0..to_transfer]);
-                self.start += to_transfer;
-                return to_transfer;
-            }
-
-            // If dest is large, read from the unbuffered reader directly into the destination.
-            if (dest.len >= buffer_size) {
-                return self.unbuffered_reader.read(dest);
-            }
-
-            // If dest is small, read from the unbuffered reader into our own internal buffer,
-            // and then transfer to destination.
-            self.end = try self.unbuffered_reader.read(&self.buf);
-            const to_transfer = @min(self.end, dest.len);
-            @memcpy(dest[0..to_transfer], self.buf[0..to_transfer]);
-            self.start = to_transfer;
-            return to_transfer;
-        }
-
-        pub fn reader(self: *Self) Reader {
-            return .{ .context = self };
-        }
-    };
-}
-
-pub fn bufferedReader(reader: anytype) BufferedReader(4096, @TypeOf(reader)) {
-    return .{ .unbuffered_reader = reader };
-}
-
-pub fn bufferedReaderSize(comptime size: usize, reader: anytype) BufferedReader(size, @TypeOf(reader)) {
-    return .{ .unbuffered_reader = reader };
-}
-
-test "OneByte" {
-    const OneByteReadReader = struct {
-        str: []const u8,
-        curr: usize,
-
-        const Error = error{NoError};
-        const Self = @This();
-        const Reader = io.GenericReader(*Self, Error, read);
-
-        fn init(str: []const u8) Self {
-            return Self{
-                .str = str,
-                .curr = 0,
-            };
-        }
-
-        fn read(self: *Self, dest: []u8) Error!usize {
-            if (self.str.len <= self.curr or dest.len == 0)
-                return 0;
-
-            dest[0] = self.str[self.curr];
-            self.curr += 1;
-            return 1;
-        }
-
-        fn reader(self: *Self) Reader {
-            return .{ .context = self };
-        }
-    };
-
-    const str = "This is a test";
-    var one_byte_stream = OneByteReadReader.init(str);
-    var buf_reader = bufferedReader(one_byte_stream.reader());
-    const stream = buf_reader.reader();
-
-    const res = try stream.readAllAlloc(testing.allocator, str.len + 1);
-    defer testing.allocator.free(res);
-    try testing.expectEqualSlices(u8, str, res);
-}
-
-fn smallBufferedReader(underlying_stream: anytype) BufferedReader(8, @TypeOf(underlying_stream)) {
-    return .{ .unbuffered_reader = underlying_stream };
-}
-test "Block" {
-    const BlockReader = struct {
-        block: []const u8,
-        reads_allowed: usize,
-        curr_read: usize,
-
-        const Error = error{NoError};
-        const Self = @This();
-        const Reader = io.GenericReader(*Self, Error, read);
-
-        fn init(block: []const u8, reads_allowed: usize) Self {
-            return Self{
-                .block = block,
-                .reads_allowed = reads_allowed,
-                .curr_read = 0,
-            };
-        }
-
-        fn read(self: *Self, dest: []u8) Error!usize {
-            if (self.curr_read >= self.reads_allowed) return 0;
-            @memcpy(dest[0..self.block.len], self.block);
-
-            self.curr_read += 1;
-            return self.block.len;
-        }
-
-        fn reader(self: *Self) Reader {
-            return .{ .context = self };
-        }
-    };
-
-    const block = "0123";
-
-    // len out == block
-    {
-        var test_buf_reader: BufferedReader(4, BlockReader) = .{
-            .unbuffered_reader = BlockReader.init(block, 2),
-        };
-        const reader = test_buf_reader.reader();
-        var out_buf: [4]u8 = undefined;
-        _ = try reader.readAll(&out_buf);
-        try testing.expectEqualSlices(u8, &out_buf, block);
-        _ = try reader.readAll(&out_buf);
-        try testing.expectEqualSlices(u8, &out_buf, block);
-        try testing.expectEqual(try reader.readAll(&out_buf), 0);
-    }
-
-    // len out < block
-    {
-        var test_buf_reader: BufferedReader(4, BlockReader) = .{
-            .unbuffered_reader = BlockReader.init(block, 2),
-        };
-        const reader = test_buf_reader.reader();
-        var out_buf: [3]u8 = undefined;
-        _ = try reader.readAll(&out_buf);
-        try testing.expectEqualSlices(u8, &out_buf, "012");
-        _ = try reader.readAll(&out_buf);
-        try testing.expectEqualSlices(u8, &out_buf, "301");
-        const n = try reader.readAll(&out_buf);
-        try testing.expectEqualSlices(u8, out_buf[0..n], "23");
-        try testing.expectEqual(try reader.readAll(&out_buf), 0);
-    }
-
-    // len out > block
-    {
-        var test_buf_reader: BufferedReader(4, BlockReader) = .{
-            .unbuffered_reader = BlockReader.init(block, 2),
-        };
-        const reader = test_buf_reader.reader();
-        var out_buf: [5]u8 = undefined;
-        _ = try reader.readAll(&out_buf);
-        try testing.expectEqualSlices(u8, &out_buf, "01230");
-        const n = try reader.readAll(&out_buf);
-        try testing.expectEqualSlices(u8, out_buf[0..n], "123");
-        try testing.expectEqual(try reader.readAll(&out_buf), 0);
-    }
-
-    // len out == 0
-    {
-        var test_buf_reader: BufferedReader(4, BlockReader) = .{
-            .unbuffered_reader = BlockReader.init(block, 2),
-        };
-        const reader = test_buf_reader.reader();
-        var out_buf: [0]u8 = undefined;
-        _ = try reader.readAll(&out_buf);
-        try testing.expectEqualSlices(u8, &out_buf, "");
-    }
-
-    // len bufreader buf > block
-    {
-        var test_buf_reader: BufferedReader(5, BlockReader) = .{
-            .unbuffered_reader = BlockReader.init(block, 2),
-        };
-        const reader = test_buf_reader.reader();
-        var out_buf: [4]u8 = undefined;
-        _ = try reader.readAll(&out_buf);
-        try testing.expectEqualSlices(u8, &out_buf, block);
-        _ = try reader.readAll(&out_buf);
-        try testing.expectEqualSlices(u8, &out_buf, block);
-        try testing.expectEqual(try reader.readAll(&out_buf), 0);
-    }
-}
diff --git a/lib/std/Io/test.zig b/lib/std/Io/test.zig
index e0fcea7674..9733e6f044 100644
--- a/lib/std/Io/test.zig
+++ b/lib/std/Io/test.zig
@@ -45,9 +45,9 @@ test "write a file, read it, then delete it" {
     const expected_file_size: u64 = "begin".len + data.len + "end".len;
     try expectEqual(expected_file_size, file_size);
 
-    var buf_stream = io.bufferedReader(file.deprecatedReader());
-    const st = buf_stream.reader();
-    const contents = try st.readAllAlloc(std.testing.allocator, 2 * 1024);
+    var file_buffer: [1024]u8 = undefined;
+    var file_reader = file.reader(&file_buffer);
+    const contents = try file_reader.interface.allocRemaining(std.testing.allocator, .limited(2 * 1024));
     defer std.testing.allocator.free(contents);
 
     try expect(mem.eql(u8, contents[0.."begin".len], "begin"));
diff --git a/src/Package/Fetch.zig b/src/Package/Fetch.zig
index b6d22a37ac..2a11e1ea6e 100644
--- a/src/Package/Fetch.zig
+++ b/src/Package/Fetch.zig
@@ -1194,14 +1194,8 @@ fn unpackResource(
         },
         .@"tar.xz" => {
             const gpa = f.arena.child_allocator;
-            const reader = resource.reader();
-            var br = std.io.bufferedReaderSize(std.crypto.tls.max_ciphertext_record_len, reader);
-            var dcp = std.compress.xz.decompress(gpa, br.reader()) catch |err| {
-                return f.fail(f.location_tok, try eb.printString(
-                    "unable to decompress tarball: {s}",
-                    .{@errorName(err)},
-                ));
-            };
+            var dcp = std.compress.xz.decompress(gpa, resource.reader().adaptToOldInterface()) catch |err|
+                return f.fail(f.location_tok, try eb.printString("unable to decompress tarball: {t}", .{err}));
             defer dcp.deinit();
             var adapter_buffer: [1024]u8 = undefined;
             var adapter = dcp.reader().adaptToNewApi(&adapter_buffer);

From 4d4251c9f95254fe94858a4c2319cbb59d2261d0 Mon Sep 17 00:00:00 2001
From: Andrew Kelley
Date: Fri, 1 Aug 2025 00:24:35 -0700
Subject: [PATCH 62/70] std.Io: delete LimitedReader
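
Call sites that used `limitedReader` to cap how many bytes a helper may
consume can usually pass a `std.Io.Limit` instead. A sketch, assuming
`r: *std.Io.Reader` and an allocator `gpa` (mirrors the test updates in
the previous commit):

    // Collect at most 1024 bytes; exceeding the limit is an error
    // rather than a silent truncation.
    const bytes = try r.allocRemaining(gpa, .limited(1024));
    defer gpa.free(bytes);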
---
 lib/std/Io.zig                |  4 ----
 lib/std/Io/limited_reader.zig | 45 -----------------------------
 2 files changed, 49 deletions(-)
 delete mode 100644 lib/std/Io/limited_reader.zig

diff --git a/lib/std/Io.zig b/lib/std/Io.zig
index 9018f4b714..bd1a724e4b 100644
--- a/lib/std/Io.zig
+++ b/lib/std/Io.zig
@@ -431,10 +431,6 @@ pub const bufferedWriter = @import("Io/buffered_writer.zig").bufferedWriter;
 pub const FixedBufferStream = @import("Io/fixed_buffer_stream.zig").FixedBufferStream;
 /// Deprecated in favor of `Reader`.
 pub const fixedBufferStream = @import("Io/fixed_buffer_stream.zig").fixedBufferStream;
-/// Deprecated in favor of `Reader.Limited`.
-pub const LimitedReader = @import("Io/limited_reader.zig").LimitedReader;
-/// Deprecated in favor of `Reader.Limited`.
-pub const limitedReader = @import("Io/limited_reader.zig").limitedReader;
 /// Deprecated with no replacement; inefficient pattern
 pub const CountingWriter = @import("Io/counting_writer.zig").CountingWriter;
 /// Deprecated with no replacement; inefficient pattern
diff --git a/lib/std/Io/limited_reader.zig b/lib/std/Io/limited_reader.zig
deleted file mode 100644
index b6b555f76d..0000000000
--- a/lib/std/Io/limited_reader.zig
+++ /dev/null
@@ -1,45 +0,0 @@
-const std = @import("../std.zig");
-const io = std.io;
-const assert = std.debug.assert;
-const testing = std.testing;
-
-pub fn LimitedReader(comptime ReaderType: type) type {
-    return struct {
-        inner_reader: ReaderType,
-        bytes_left: u64,
-
-        pub const Error = ReaderType.Error;
-        pub const Reader = io.GenericReader(*Self, Error, read);
-
-        const Self = @This();
-
-        pub fn read(self: *Self, dest: []u8) Error!usize {
-            const max_read = @min(self.bytes_left, dest.len);
-            const n = try self.inner_reader.read(dest[0..max_read]);
-            self.bytes_left -= n;
-            return n;
-        }
-
-        pub fn reader(self: *Self) Reader {
-            return .{ .context = self };
-        }
-    };
-}
-
-/// Returns an initialised `LimitedReader`.
-/// `bytes_left` is a `u64` to be able to take 64 bit file offsets
-pub fn limitedReader(inner_reader: anytype, bytes_left: u64) LimitedReader(@TypeOf(inner_reader)) {
-    return .{ .inner_reader = inner_reader, .bytes_left = bytes_left };
-}
-
-test "basic usage" {
-    const data = "hello world";
-    var fbs = std.io.fixedBufferStream(data);
-    var early_stream = limitedReader(fbs.reader(), 3);
-
-    var buf: [5]u8 = undefined;
-    try testing.expectEqual(@as(usize, 3), try early_stream.reader().read(&buf));
-    try testing.expectEqualSlices(u8, data[0..3], buf[0..3]);
-    try testing.expectEqual(@as(usize, 0), try early_stream.reader().read(&buf));
-    try testing.expectError(error.EndOfStream, early_stream.reader().skipBytes(10, .{}));
-}

From 4395f512c932826360ed9efa1564f4f22693394b Mon Sep 17 00:00:00 2001
From: Andrew Kelley
Date: Mon, 4 Aug 2025 22:26:21 -0700
Subject: [PATCH 63/70] resinator: just enough fixes to make it compile
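
Most changes below follow one of two patterns: file access goes through a
buffered `std.fs.File.Reader`, and in-memory byte slices are parsed through
a fixed `std.Io.Reader`. A sketch of the latter (`bytes` is a hypothetical
slice):

    // A fixed Reader borrows the slice; no copying or allocation.
    var r: std.Io.Reader = .fixed(bytes);
    const data_size = try r.takeInt(u32, .little);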
---
 lib/compiler/resinator/compile.zig | 98 +++++++++++++++---------------
 lib/compiler/resinator/cvtres.zig  | 36 +++++------
 lib/compiler/resinator/errors.zig  |  8 +--
 lib/compiler/resinator/ico.zig     |  3 +-
 4 files changed, 75 insertions(+), 70 deletions(-)

diff --git a/lib/compiler/resinator/compile.zig b/lib/compiler/resinator/compile.zig
index 3515421ff0..18d142eb96 100644
--- a/lib/compiler/resinator/compile.zig
+++ b/lib/compiler/resinator/compile.zig
@@ -550,7 +550,7 @@ pub const Compiler = struct {
         // so get it here to simplify future usage.
         const filename_token = node.filename.getFirstToken();
 
-        const file = self.searchForFile(filename_utf8) catch |err| switch (err) {
+        const file_handle = self.searchForFile(filename_utf8) catch |err| switch (err) {
             error.OutOfMemory => |e| return e,
             else => |e| {
                 const filename_string_index = try self.diagnostics.putString(filename_utf8);
@@ -564,13 +564,15 @@ pub const Compiler = struct {
                 });
             },
         };
-        defer file.close();
+        defer file_handle.close();
+        var file_buffer: [2048]u8 = undefined;
+        var file_reader = file_handle.reader(&file_buffer);
 
         if (maybe_predefined_type) |predefined_type| {
             switch (predefined_type) {
                 .GROUP_ICON, .GROUP_CURSOR => {
                     // Check for animated icon first
-                    if (ani.isAnimatedIcon(file.deprecatedReader())) {
+                    if (ani.isAnimatedIcon(file_reader.interface.adaptToOldInterface())) {
                         // Animated icons are just put into the resource unmodified,
                         // and the resource type changes to ANIICON/ANICURSOR
@@ -582,18 +584,18 @@ pub const Compiler = struct {
                         header.type_value.ordinal = @intFromEnum(new_predefined_type);
                         header.memory_flags = MemoryFlags.defaults(new_predefined_type);
                         header.applyMemoryFlags(node.common_resource_attributes, self.source);
-                        header.data_size = @intCast(try file.getEndPos());
+                        header.data_size = @intCast(try file_reader.getSize());
                         try header.write(writer, self.errContext(node.id));
 
-                        try file.seekTo(0);
-                        try writeResourceData(writer, file.deprecatedReader(), header.data_size);
+                        try file_reader.seekTo(0);
+                        try writeResourceData(writer, &file_reader.interface, header.data_size);
                         return;
                     }
 
                     // isAnimatedIcon moved the file cursor so reset to the start
-                    try file.seekTo(0);
+                    try file_reader.seekTo(0);
 
-                    const icon_dir = ico.read(self.allocator, file.deprecatedReader(), try file.getEndPos()) catch |err| switch (err) {
+                    const icon_dir = ico.read(self.allocator, file_reader.interface.adaptToOldInterface(), try file_reader.getSize()) catch |err| switch (err) {
                         error.OutOfMemory => |e| return e,
                         else => |e| {
                             return self.iconReadError(
@@ -671,15 +673,15 @@ pub const Compiler = struct {
                             try writer.writeInt(u16, entry.type_specific_data.cursor.hotspot_y, .little);
                         }
 
-                        try file.seekTo(entry.data_offset_from_start_of_file);
-                        var header_bytes = file.deprecatedReader().readBytesNoEof(16) catch {
+                        try file_reader.seekTo(entry.data_offset_from_start_of_file);
+                        var header_bytes = (file_reader.interface.takeArray(16) catch {
                             return self.iconReadError(
                                 error.UnexpectedEOF,
                                 filename_utf8,
                                 filename_token,
                                 predefined_type,
                             );
-                        };
+                        }).*;
 
                         const image_format = ico.ImageFormat.detect(&header_bytes);
                         if (!image_format.validate(&header_bytes)) {
@@ -802,8 +804,8 @@ pub const Compiler = struct {
                             },
                         }
 
-                        try file.seekTo(entry.data_offset_from_start_of_file);
-                        try writeResourceDataNoPadding(writer, file.deprecatedReader(), entry.data_size_in_bytes);
+                        try file_reader.seekTo(entry.data_offset_from_start_of_file);
+                        try writeResourceDataNoPadding(writer, &file_reader.interface, entry.data_size_in_bytes);
                         try writeDataPadding(writer, full_data_size);
 
                         if (self.state.icon_id == std.math.maxInt(u16)) {
@@ -857,9 +859,9 @@ pub const Compiler = struct {
                 },
                 .BITMAP => {
                     header.applyMemoryFlags(node.common_resource_attributes, self.source);
-                    const file_size = try file.getEndPos();
+                    const file_size = try file_reader.getSize();
 
-                    const bitmap_info = bmp.read(file.deprecatedReader(), file_size) catch |err| {
+                    const bitmap_info = bmp.read(file_reader.interface.adaptToOldInterface(), file_size) catch |err| {
                         const filename_string_index = try self.diagnostics.putString(filename_utf8);
                        return self.addErrorDetailsAndFail(.{
                            .err = .bmp_read_error,
@@ -921,18 +923,17 @@ pub const Compiler = struct {
                    header.data_size = bmp_bytes_to_write;
                    try header.write(writer, self.errContext(node.id));
 
-                    try file.seekTo(bmp.file_header_len);
-                    const file_reader = file.deprecatedReader();
-                    try writeResourceDataNoPadding(writer, file_reader, bitmap_info.dib_header_size);
+                    try file_reader.seekTo(bmp.file_header_len);
+                    try writeResourceDataNoPadding(writer, &file_reader.interface, bitmap_info.dib_header_size);
                    if (bitmap_info.getBitmasksByteLen() > 0) {
-                        try writeResourceDataNoPadding(writer, file_reader, bitmap_info.getBitmasksByteLen());
+                        try writeResourceDataNoPadding(writer, &file_reader.interface, bitmap_info.getBitmasksByteLen());
                    }
                    if (bitmap_info.getExpectedPaletteByteLen() > 0) {
-                        try writeResourceDataNoPadding(writer, file_reader, @intCast(bitmap_info.getActualPaletteByteLen()));
+                        try writeResourceDataNoPadding(writer, &file_reader.interface, @intCast(bitmap_info.getActualPaletteByteLen()));
                    }
-                    try file.seekTo(bitmap_info.pixel_data_offset);
+                    try file_reader.seekTo(bitmap_info.pixel_data_offset);
                    const pixel_bytes: u32 = @intCast(file_size - bitmap_info.pixel_data_offset);
-                    try writeResourceDataNoPadding(writer, file_reader, pixel_bytes);
+                    try writeResourceDataNoPadding(writer, &file_reader.interface, pixel_bytes);
                    try writeDataPadding(writer, bmp_bytes_to_write);
                    return;
                },
@@ -956,7 +957,7 @@ pub const Compiler = struct {
                        return;
                    }
                    header.applyMemoryFlags(node.common_resource_attributes, self.source);
-                    const file_size = try file.getEndPos();
+                    const file_size = try file_reader.getSize();
                    if (file_size > std.math.maxInt(u32)) {
                        return self.addErrorDetailsAndFail(.{
                            .err = .resource_data_size_exceeds_max,
@@ -968,8 +969,9 @@ pub const Compiler = struct {
                    header.data_size = @intCast(file_size);
                    try header.write(writer, self.errContext(node.id));
 
-                    var header_slurping_reader = headerSlurpingReader(148, file.deprecatedReader());
-                    try writeResourceData(writer, header_slurping_reader.reader(), header.data_size);
+                    var header_slurping_reader = headerSlurpingReader(148, file_reader.interface.adaptToOldInterface());
+                    var adapter = header_slurping_reader.reader().adaptToNewApi(&.{});
+                    try writeResourceData(writer, &adapter.new_interface, header.data_size);
 
                    try self.state.font_dir.add(self.arena, FontDir.Font{
                        .id = header.name_value.ordinal,
@@ -992,7 +994,7 @@ pub const Compiler = struct {
        }
 
        // Fallback to just writing out the entire contents of the file
-        const data_size = try file.getEndPos();
+        const data_size = try file_reader.getSize();
        if (data_size > std.math.maxInt(u32)) {
            return self.addErrorDetailsAndFail(.{
                .err = .resource_data_size_exceeds_max,
@@ -1002,7 +1004,7 @@ pub const Compiler = struct {
        // We now know that the data size will fit in a u32
        header.data_size = @intCast(data_size);
        try header.write(writer, self.errContext(node.id));
-        try writeResourceData(writer, file.deprecatedReader(), header.data_size);
+        try writeResourceData(writer, &file_reader.interface, header.data_size);
    }
 
    fn iconReadError(
@@ -1250,8 +1252,8 @@ pub const Compiler = struct {
        const data_len: u32 = @intCast(data_buffer.items.len);
        try self.writeResourceHeader(writer, node.id, node.type, data_len, node.common_resource_attributes, self.state.language);
 
-        var data_fbs = std.io.fixedBufferStream(data_buffer.items);
-        try writeResourceData(writer, data_fbs.reader(), data_len);
+        var data_fbs: std.Io.Reader = .fixed(data_buffer.items);
+        try writeResourceData(writer, &data_fbs, data_len);
    }
 
    pub fn writeResourceHeader(self: *Compiler, writer: anytype, id_token: Token, type_token: Token, data_size: u32, common_resource_attributes: []Token, language: res.Language) !void {
@@ -1266,15 +1268,15 @@ pub const Compiler = struct {
        try header.write(writer, self.errContext(id_token));
    }
 
-    pub fn writeResourceDataNoPadding(writer: anytype, data_reader: anytype, data_size: u32) !void {
-        var limited_reader = std.io.limitedReader(data_reader, data_size);
-
-        const FifoBuffer = std.fifo.LinearFifo(u8, .{ .Static = 4096 });
-        var fifo = FifoBuffer.init();
-        try fifo.pump(limited_reader.reader(), writer);
+    pub fn writeResourceDataNoPadding(writer: anytype, data_reader: *std.Io.Reader, data_size: u32) !void {
+        var adapted = writer.adaptToNewApi();
+        var buffer: [128]u8 = undefined;
+        adapted.new_interface.buffer = &buffer;
+        try data_reader.streamExact(&adapted.new_interface, data_size);
+        try adapted.new_interface.flush();
    }
 
-    pub fn writeResourceData(writer: anytype, data_reader: anytype, data_size: u32) !void {
+    pub fn writeResourceData(writer: anytype, data_reader: *std.Io.Reader, data_size: u32) !void {
        try writeResourceDataNoPadding(writer, data_reader, data_size);
        try writeDataPadding(writer, data_size);
    }
@@ -1339,8 +1341,8 @@ pub const Compiler = struct {
 
        try header.write(writer, self.errContext(node.id));
 
-        var data_fbs = std.io.fixedBufferStream(data_buffer.items);
-        try writeResourceData(writer, data_fbs.reader(), data_size);
+        var data_fbs: std.Io.Reader = .fixed(data_buffer.items);
+        try writeResourceData(writer, &data_fbs, data_size);
    }
 
    /// Expects `data_writer` to be a LimitedWriter limited to u32, meaning all writes to
@@ -1732,8 +1734,8 @@ pub const Compiler = struct {
 
        try header.write(writer, self.errContext(node.id));
 
-        var data_fbs = std.io.fixedBufferStream(data_buffer.items);
-        try writeResourceData(writer, data_fbs.reader(), data_size);
+        var data_fbs: std.Io.Reader = .fixed(data_buffer.items);
+        try writeResourceData(writer, &data_fbs, data_size);
    }
 
    fn writeDialogHeaderAndStrings(
@@ -2046,8 +2048,8 @@ pub const Compiler = struct {
 
        try header.write(writer, self.errContext(node.id));
 
-        var data_fbs = std.io.fixedBufferStream(data_buffer.items);
-        try writeResourceData(writer, data_fbs.reader(), data_size);
+        var data_fbs: std.Io.Reader = .fixed(data_buffer.items);
+        try writeResourceData(writer, &data_fbs, data_size);
    }
 
    /// Weight and italic carry over from previous FONT statements within a single resource,
@@ -2121,8 +2123,8 @@ pub const Compiler = struct {
 
        try header.write(writer, self.errContext(node.id));
 
-        var data_fbs = std.io.fixedBufferStream(data_buffer.items);
-        try writeResourceData(writer, data_fbs.reader(), data_size);
+        var data_fbs: std.Io.Reader = .fixed(data_buffer.items);
+        try writeResourceData(writer, &data_fbs, data_size);
    }
 
    /// Expects `data_writer` to be a LimitedWriter limited to u32, meaning all writes to
@@ -2386,8 +2388,8 @@ pub const Compiler = struct {
 
        try header.write(writer, self.errContext(node.id));
 
-        var data_fbs = std.io.fixedBufferStream(data_buffer.items);
-        try writeResourceData(writer, data_fbs.reader(), data_size);
+        var data_fbs: std.Io.Reader = .fixed(data_buffer.items);
+        try writeResourceData(writer, &data_fbs, data_size);
    }
 
    /// Expects writer to be a LimitedWriter limited to u16, meaning all writes to
@@ -3321,8 +3323,8 @@ pub const StringTable = struct {
        // we fully control and know are numbers, so they have a fixed size.
        try header.writeAssertNoOverflow(writer);
 
-        var data_fbs = std.io.fixedBufferStream(data_buffer.items);
-        try Compiler.writeResourceData(writer, data_fbs.reader(), data_size);
+        var data_fbs: std.Io.Reader = .fixed(data_buffer.items);
+        try Compiler.writeResourceData(writer, &data_fbs, data_size);
    }
};
diff --git a/lib/compiler/resinator/cvtres.zig b/lib/compiler/resinator/cvtres.zig
index 21574f2704..e0c69ad4d9 100644
--- a/lib/compiler/resinator/cvtres.zig
+++ b/lib/compiler/resinator/cvtres.zig
@@ -105,31 +105,33 @@ pub const ResourceAndSize = struct {
 pub fn parseResource(allocator: Allocator, reader: anytype, max_size: u64) !ResourceAndSize {
     var header_counting_reader = std.io.countingReader(reader);
-    const header_reader = header_counting_reader.reader();
-    const data_size = try header_reader.readInt(u32, .little);
-    const header_size = try header_reader.readInt(u32, .little);
+    var buffer: [1024]u8 = undefined;
+    var header_reader_adapter = header_counting_reader.reader().adaptToNewApi(&buffer);
+    const header_reader = &header_reader_adapter.new_interface;
+    const data_size = try header_reader.takeInt(u32, .little);
+    const header_size = try header_reader.takeInt(u32, .little);
     const total_size: u64 = @as(u64, header_size) + data_size;
     if (total_size > max_size) return error.ImpossibleSize;
 
     var header_bytes_available = header_size -| 8;
-    var type_reader = std.io.limitedReader(header_reader, header_bytes_available);
-    const type_value = try parseNameOrOrdinal(allocator, type_reader.reader());
+    var type_reader: std.Io.Reader = .fixed(try header_reader.take(header_bytes_available));
+    const type_value = try parseNameOrOrdinal(allocator, &type_reader);
     errdefer type_value.deinit(allocator);
     header_bytes_available -|= @intCast(type_value.byteLen());
-    var name_reader = std.io.limitedReader(header_reader, header_bytes_available);
-    const name_value = try parseNameOrOrdinal(allocator, name_reader.reader());
+    var name_reader: std.Io.Reader = .fixed(try header_reader.take(header_bytes_available));
+    const name_value = try parseNameOrOrdinal(allocator, &name_reader);
     errdefer name_value.deinit(allocator);
 
     const padding_after_name = numPaddingBytesNeeded(@intCast(header_counting_reader.bytes_read));
-    try header_reader.skipBytes(padding_after_name, .{ .buf_size = 3 });
+    try header_reader.discardAll(padding_after_name);
 
     std.debug.assert(header_counting_reader.bytes_read % 4 == 0);
-    const data_version = try header_reader.readInt(u32, .little);
-    const memory_flags: MemoryFlags = @bitCast(try header_reader.readInt(u16, .little));
-    const language: Language = @bitCast(try header_reader.readInt(u16, .little));
-    const version = try header_reader.readInt(u32, .little);
-    const characteristics = try header_reader.readInt(u32, .little);
+    const data_version = try header_reader.takeInt(u32, .little);
+    const memory_flags: MemoryFlags = @bitCast(try header_reader.takeInt(u16, .little));
+    const language: Language = @bitCast(try header_reader.takeInt(u16, .little));
+    const version = try header_reader.takeInt(u32, .little);
+    const characteristics = try header_reader.takeInt(u32, .little);
 
     const header_bytes_read = header_counting_reader.bytes_read;
     if (header_size != header_bytes_read) return error.HeaderSizeMismatch;
@@ -156,10 +158,10 @@ pub fn parseResource(allocator: Allocator, reader: anytype, max_size: u64) !Reso
    };
}
 
-pub fn parseNameOrOrdinal(allocator: Allocator, reader: anytype) !NameOrOrdinal {
-    const first_code_unit = try reader.readInt(u16, .little);
+pub fn parseNameOrOrdinal(allocator: Allocator, reader: *std.Io.Reader) !NameOrOrdinal {
+    const first_code_unit = try reader.takeInt(u16, .little);
     if (first_code_unit == 0xFFFF) {
-        const ordinal_value = try reader.readInt(u16, .little);
+        const ordinal_value = try reader.takeInt(u16, .little);
         return .{ .ordinal = ordinal_value };
     }
     var name_buf = try std.ArrayListUnmanaged(u16).initCapacity(allocator, 16);
@@ -167,7 +169,7 @@ pub fn parseNameOrOrdinal(allocator: Allocator, reader: anytype) !NameOrOrdinal
     var code_unit = first_code_unit;
     while (code_unit != 0) {
         try name_buf.append(allocator, std.mem.nativeToLittle(u16, code_unit));
-        code_unit = try reader.readInt(u16, .little);
+        code_unit = try reader.takeInt(u16, .little);
     }
     return .{ .name = try name_buf.toOwnedSliceSentinel(allocator, 0) };
 }
diff --git a/lib/compiler/resinator/errors.zig b/lib/compiler/resinator/errors.zig
index 431f14692a..4bc443c4e7 100644
--- a/lib/compiler/resinator/errors.zig
+++ b/lib/compiler/resinator/errors.zig
@@ -1078,7 +1078,7 @@ const CorrespondingLines = struct {
     at_eof: bool = false,
     span: SourceMappings.CorrespondingSpan,
     file: std.fs.File,
-    buffered_reader: *std.Io.Reader,
+    buffered_reader: std.fs.File.Reader,
     code_page: SupportedCodePage,
 
     pub fn init(cwd: std.fs.Dir, err_details: ErrorDetails, line_for_comparison: []const u8, corresponding_span: SourceMappings.CorrespondingSpan, corresponding_file: []const u8) !CorrespondingLines {
@@ -1099,7 +1099,7 @@ const CorrespondingLines = struct {
            .buffered_reader = undefined,
            .code_page = err_details.code_page,
        };
-        corresponding_lines.buffered_reader = corresponding_lines.file.reader();
+        corresponding_lines.buffered_reader = corresponding_lines.file.reader(&.{});
        errdefer corresponding_lines.deinit();
 
        var fbs = std.io.fixedBufferStream(&corresponding_lines.line_buf);
@@ -1107,7 +1107,7 @@ const CorrespondingLines = struct {
 
        try corresponding_lines.writeLineFromStreamVerbatim(
            writer,
-            corresponding_lines.buffered_reader.reader(),
+            corresponding_lines.buffered_reader.interface.adaptToOldInterface(),
            corresponding_span.start_line,
        );
 
@@ -1150,7 +1150,7 @@ const CorrespondingLines = struct {
 
        try self.writeLineFromStreamVerbatim(
            writer,
-            self.buffered_reader.reader(),
+            self.buffered_reader.interface.adaptToOldInterface(),
            self.line_num,
        );
 
diff --git a/lib/compiler/resinator/ico.zig b/lib/compiler/resinator/ico.zig
index e6de1d469e..a73becd7b9 100644
--- a/lib/compiler/resinator/ico.zig
+++ b/lib/compiler/resinator/ico.zig
@@ -18,7 +18,8 @@ pub fn read(allocator: std.mem.Allocator, reader: anytype, max_size: u64) ReadEr
     if (empty_reader_errorset) {
         return readAnyError(allocator, reader, max_size) catch |err| switch (err) {
             error.EndOfStream => error.UnexpectedEOF,
-            else => |e| return e,
+            error.OutOfMemory, error.InvalidHeader, error.InvalidImageType, error.ImpossibleDataSize, error.UnexpectedEOF, error.ReadError => |e| return e,
+            else => return error.ReadError,
         };
     } else {
         return readAnyError(allocator, reader, max_size) catch |err| switch (err) {

From 063d87132a779d912a7db59872a586a6c331f154 Mon Sep 17 00:00:00 2001
From: Ryan Liptak
Date: Wed, 6 Aug 2025 01:11:00 -0700
Subject: [PATCH 64/70] resinator: a few more updates/fixes

Just enough to get things working correctly again
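
The cvtres rewrite below drops the countingReader/limitedReader scaffolding:
`take` borrows the next N bytes out of the reader's buffer, a fixed Reader
re-parses them, and its `seek` index replaces the old byte counting. A
sketch of the pattern (names as in parseResource below):

    const remaining_header_bytes = try reader.take(header_size -| 8);
    var remaining_header_reader: std.Io.Reader = .fixed(remaining_header_bytes);
    // ...parse fields out of remaining_header_reader...
    const padding = numPaddingBytesNeeded(@intCast(remaining_header_reader.seek));
    try remaining_header_reader.discardAll(padding);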
---
 lib/compiler/resinator/cli.zig    | 12 ++++----
 lib/compiler/resinator/cvtres.zig | 47 +++++++++++++------------------
 lib/compiler/resinator/ico.zig    |  6 ++--
 lib/compiler/resinator/main.zig   |  4 +--
 4 files changed, 32 insertions(+), 37 deletions(-)

diff --git a/lib/compiler/resinator/cli.zig b/lib/compiler/resinator/cli.zig
index 6d4a368b4e..bfc67d8791 100644
--- a/lib/compiler/resinator/cli.zig
+++ b/lib/compiler/resinator/cli.zig
@@ -1141,6 +1141,8 @@ pub fn parse(allocator: Allocator, args: []const []const u8, diagnostics: *Diagn
                }
                output_format = .res;
            }
+        } else {
+            output_format_source = .output_format_arg;
        }
        options.output_source = .{ .filename = try filepathWithExtension(allocator, options.input_source.filename, output_format.?.extension()) };
    } else {
@@ -1529,21 +1531,21 @@ fn testParseOutput(args: []const []const u8, expected_output: []const u8) !?Opti
    var diagnostics = Diagnostics.init(std.testing.allocator);
    defer diagnostics.deinit();
 
-    var output = std.ArrayList(u8).init(std.testing.allocator);
+    var output: std.io.Writer.Allocating = .init(std.testing.allocator);
    defer output.deinit();
 
    var options = parse(std.testing.allocator, args, &diagnostics) catch |err| switch (err) {
        error.ParseError => {
-            try diagnostics.renderToWriter(args, output.writer(), .no_color);
-            try std.testing.expectEqualStrings(expected_output, output.items);
+            try diagnostics.renderToWriter(args, &output.writer, .no_color);
+            try std.testing.expectEqualStrings(expected_output, output.getWritten());
            return null;
        },
        else => |e| return e,
    };
    errdefer options.deinit();
 
-    try diagnostics.renderToWriter(args, output.writer(), .no_color);
-    try std.testing.expectEqualStrings(expected_output, output.items);
+    try diagnostics.renderToWriter(args, &output.writer, .no_color);
+    try std.testing.expectEqualStrings(expected_output, output.getWritten());
    return options;
}
 
diff --git a/lib/compiler/resinator/cvtres.zig b/lib/compiler/resinator/cvtres.zig
index e0c69ad4d9..27e14ae9a3 100644
--- a/lib/compiler/resinator/cvtres.zig
+++ b/lib/compiler/resinator/cvtres.zig
@@ -65,7 +65,7 @@ pub const ParseResOptions = struct {
};
 
/// The returned ParsedResources should be freed by calling its `deinit` function.
-pub fn parseRes(allocator: Allocator, reader: anytype, options: ParseResOptions) !ParsedResources {
+pub fn parseRes(allocator: Allocator, reader: *std.Io.Reader, options: ParseResOptions) !ParsedResources {
     var resources = ParsedResources.init(allocator);
     errdefer resources.deinit();
 
@@ -74,7 +74,7 @@ pub fn parseRes(allocator: Allocator, reader: anytype, options: ParseResOptions)
     return resources;
 }
 
-pub fn parseResInto(resources: *ParsedResources, reader: anytype, options: ParseResOptions) !void {
+pub fn parseResInto(resources: *ParsedResources, reader: *std.Io.Reader, options: ParseResOptions) !void {
     const allocator = resources.allocator;
     var bytes_remaining: u64 = options.max_size;
     {
@@ -103,45 +103,38 @@ pub const ResourceAndSize = struct {
     total_size: u64,
 };
 
-pub fn parseResource(allocator: Allocator, reader: anytype, max_size: u64) !ResourceAndSize {
-    var header_counting_reader = std.io.countingReader(reader);
-    var buffer: [1024]u8 = undefined;
-    var header_reader_adapter = header_counting_reader.reader().adaptToNewApi(&buffer);
-    const header_reader = &header_reader_adapter.new_interface;
-    const data_size = try header_reader.takeInt(u32, .little);
-    const header_size = try header_reader.takeInt(u32, .little);
+pub fn parseResource(allocator: Allocator, reader: *std.Io.Reader, max_size: u64) !ResourceAndSize {
+    const data_size = try reader.takeInt(u32, .little);
+    const header_size = try reader.takeInt(u32, .little);
     const total_size: u64 = @as(u64, header_size) + data_size;
     if (total_size > max_size) return error.ImpossibleSize;
 
-    var header_bytes_available = header_size -| 8;
-    var type_reader: std.Io.Reader = .fixed(try header_reader.take(header_bytes_available));
-    const type_value = try parseNameOrOrdinal(allocator, &type_reader);
+    const remaining_header_bytes = try reader.take(header_size -| 8);
+    var remaining_header_reader: std.Io.Reader = .fixed(remaining_header_bytes);
+    const type_value = try parseNameOrOrdinal(allocator, &remaining_header_reader);
     errdefer type_value.deinit(allocator);
-    header_bytes_available -|= @intCast(type_value.byteLen());
-    var name_reader: std.Io.Reader = .fixed(try header_reader.take(header_bytes_available));
-    const name_value = try parseNameOrOrdinal(allocator, &name_reader);
+    const name_value = try parseNameOrOrdinal(allocator, &remaining_header_reader);
     errdefer name_value.deinit(allocator);
 
-    const padding_after_name = numPaddingBytesNeeded(@intCast(header_counting_reader.bytes_read));
-    try header_reader.discardAll(padding_after_name);
+    const padding_after_name = numPaddingBytesNeeded(@intCast(remaining_header_reader.seek));
+    try remaining_header_reader.discardAll(padding_after_name);
 
-    std.debug.assert(header_counting_reader.bytes_read % 4 == 0);
-    const data_version = try header_reader.takeInt(u32, .little);
-    const memory_flags: MemoryFlags = @bitCast(try header_reader.takeInt(u16, .little));
-    const language: Language = @bitCast(try header_reader.takeInt(u16, .little));
-    const version = try header_reader.takeInt(u32, .little);
-    const characteristics = try header_reader.takeInt(u32, .little);
+    std.debug.assert(remaining_header_reader.seek % 4 == 0);
+    const data_version = try remaining_header_reader.takeInt(u32, .little);
+    const memory_flags: MemoryFlags = @bitCast(try remaining_header_reader.takeInt(u16, .little));
+    const language: Language = @bitCast(try remaining_header_reader.takeInt(u16, .little));
+    const version = try remaining_header_reader.takeInt(u32, .little);
+    const characteristics = try remaining_header_reader.takeInt(u32, .little);
 
-    const header_bytes_read = header_counting_reader.bytes_read;
-    if (header_size != header_bytes_read) return error.HeaderSizeMismatch;
+    if (remaining_header_reader.seek != remaining_header_reader.end) return error.HeaderSizeMismatch;
 
     const data = try allocator.alloc(u8, data_size);
     errdefer allocator.free(data);
-    try reader.readNoEof(data);
+    try reader.readSliceAll(data);
 
     const padding_after_data = numPaddingBytesNeeded(@intCast(data_size));
-    try reader.skipBytes(padding_after_data, .{ .buf_size = 3 });
+    try reader.discardAll(padding_after_data);
 
     return .{
         .resource = .{
diff --git a/lib/compiler/resinator/ico.zig b/lib/compiler/resinator/ico.zig
index a73becd7b9..dca74fc857 100644
--- a/lib/compiler/resinator/ico.zig
+++ b/lib/compiler/resinator/ico.zig
@@ -14,12 +14,12 @@ pub fn read(allocator: std.mem.Allocator, reader: anytype, max_size: u64) ReadEr
     // Some Reader implementations have an empty ReadError error set which would
     // cause 'unreachable else' if we tried to use an else in the switch, so we
     // need to detect this case and not try to translate to ReadError
+    const anyerror_reader_errorset = @TypeOf(reader).Error == anyerror;
     const empty_reader_errorset = @typeInfo(@TypeOf(reader).Error).error_set == null or @typeInfo(@TypeOf(reader).Error).error_set.?.len == 0;
-    if (empty_reader_errorset) {
+    if (empty_reader_errorset and !anyerror_reader_errorset) {
         return readAnyError(allocator, reader, max_size) catch |err| switch (err) {
             error.EndOfStream => error.UnexpectedEOF,
-            error.OutOfMemory, error.InvalidHeader, error.InvalidImageType, error.ImpossibleDataSize, error.UnexpectedEOF, error.ReadError => |e| return e,
-            else => return error.ReadError,
+            else => |e| return e,
         };
     } else {
         return readAnyError(allocator, reader, max_size) catch |err| switch (err) {
diff --git a/lib/compiler/resinator/main.zig b/lib/compiler/resinator/main.zig
index 30e9c825bb..3187f038b9 100644
--- a/lib/compiler/resinator/main.zig
+++ b/lib/compiler/resinator/main.zig
@@ -325,8 +325,8 @@ pub fn main() !void {
         std.debug.assert(options.output_format == .coff);
 
         // TODO: Maybe use a buffered file reader instead of reading file into memory -> fbs
-        var fbs = std.io.fixedBufferStream(res_data.bytes);
-        break :resources cvtres.parseRes(allocator, fbs.reader(), .{ .max_size = res_data.bytes.len }) catch |err| {
+        var res_reader: std.Io.Reader = .fixed(res_data.bytes);
+        break :resources cvtres.parseRes(allocator, &res_reader, .{ .max_size = res_data.bytes.len }) catch |err| {
             // TODO: Better errors
             try error_handler.emitMessage(allocator, .err, "unable to parse res from '{s}': {s}", .{ res_stream.name, @errorName(err) });
             std.process.exit(1);

From 995bfdf0ffb28cf2031b0b3ada53b2cb9275912f Mon Sep 17 00:00:00 2001
From: Andrew Kelley
Date: Wed, 6 Aug 2025 18:37:48 -0700
Subject: [PATCH 65/70] std.http.Server: add safety for invalidated Head
 strings and fix bad unit test API usage that it finds
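
The hazard being guarded against: `Head` is parsed in place out of the
connection buffer, so its string fields die as soon as the body stream
starts reusing that buffer. A sketch of the safe ordering (as in the test
fixes below; `request` and `gpa` are stand-ins):

    // Copy anything needed later *before* initializing the body reader...
    const target = try gpa.dupe(u8, request.head.target);
    defer gpa.free(target);
    // ...because this call invalidates request.head's string memory.
    const body = try (try request.readerExpectContinue(&.{})).allocRemaining(gpa, .unlimited);
    defer gpa.free(body);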
---
 lib/std/http/Client.zig | 18 ++++++++--
 lib/std/http/Server.zig | 15 ++++++--
 lib/std/http/test.zig   | 79 +++++++++++++++++++++++------------------
 3 files changed, 72 insertions(+), 40 deletions(-)

diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig
index e6e1d4434f..a6b940018e 100644
--- a/lib/std/http/Client.zig
+++ b/lib/std/http/Client.zig
@@ -444,8 +444,8 @@ pub const Connection = struct {
 pub const Response = struct {
     request: *Request,
 
-    /// Pointers in this struct are invalidated with the next call to
-    /// `receiveHead`.
+    /// Pointers in this struct are invalidated when the response body stream
+    /// is initialized.
     head: Head,
 
     pub const Head = struct {
@@ -671,6 +671,16 @@ pub const Response = struct {
         try expectEqual(@as(u10, 418), parseInt3("418"));
         try expectEqual(@as(u10, 999), parseInt3("999"));
     }
+
+    /// Help the programmer avoid bugs by calling this when the string
+    /// memory of `Head` becomes invalidated.
+    fn invalidateStrings(h: *Head) void {
+        h.bytes = undefined;
+        h.reason = undefined;
+        if (h.location) |*s| s.* = undefined;
+        if (h.content_type) |*s| s.* = undefined;
+        if (h.content_disposition) |*s| s.* = undefined;
+    }
 };
 
 /// If compressed body has been negotiated this will return compressed bytes.
@@ -682,7 +692,8 @@ pub const Response = struct {
 ///
 /// See also:
 /// * `readerDecompressing`
-pub fn reader(response: *const Response, buffer: []u8) *Reader {
+pub fn reader(response: *Response, buffer: []u8) *Reader {
+    response.head.invalidateStrings();
     const req = response.request;
     if (!req.method.responseHasBody()) return .ending;
     const head = &response.head;
@@ -703,6 +714,7 @@ pub const Response = struct {
     decompressor: *http.Decompressor,
     decompression_buffer: []u8,
 ) *Reader {
+    response.head.invalidateStrings();
     const head = &response.head;
     return response.request.reader.bodyReaderDecompressing(
         head.transfer_encoding,
diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig
index 188f2a45e4..be59cd05d0 100644
--- a/lib/std/http/Server.zig
+++ b/lib/std/http/Server.zig
@@ -55,8 +55,8 @@ pub fn receiveHead(s: *Server) ReceiveHeadError!Request {
 pub const Request = struct {
     server: *Server,
 
-    /// Pointers in this struct are invalidated with the next call to
-    /// `receiveHead`.
+    /// Pointers in this struct are invalidated when the request body stream is
+    /// initialized.
     head: Head,
     head_buffer: []const u8,
     respond_err: ?RespondError = null,
@@ -224,6 +224,14 @@ pub const Request = struct {
         inline fn int64(array: *const [8]u8) u64 {
             return @bitCast(array.*);
         }
+
+        /// Help the programmer avoid bugs by calling this when the string
+        /// memory of `Head` becomes invalidated.
+        fn invalidateStrings(h: *Head) void {
+            h.target = undefined;
+            if (h.expect) |*s| s.* = undefined;
+            if (h.content_type) |*s| s.* = undefined;
+        }
     };
 
     pub fn iterateHeaders(r: *const Request) http.HeaderIterator {
@@ -578,9 +586,12 @@ pub const Request = struct {
     /// this function.
     ///
     /// Asserts that this function is only called once.
+    ///
+    /// Invalidates the string memory inside `Head`.
     pub fn readerExpectNone(request: *Request, buffer: []u8) *Reader {
         assert(request.server.reader.state == .received_head);
         assert(request.head.expect == null);
+        request.head.invalidateStrings();
         if (!request.head.method.requestHasBody()) return .ending;
         return request.server.reader.bodyReader(buffer, request.head.transfer_encoding, request.head.content_length);
     }
diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig
index 2b8f64c606..69355667ed 100644
--- a/lib/std/http/test.zig
+++ b/lib/std/http/test.zig
@@ -65,23 +65,22 @@ test "trailers" {
             try req.sendBodiless();
             var response = try req.receiveHead(&.{});
 
+            {
+                var it = response.head.iterateHeaders();
+                const header = it.next().?;
+                try expectEqualStrings("transfer-encoding", header.name);
+                try expectEqualStrings("chunked", header.value);
+                try expectEqual(null, it.next());
+            }
+
             const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
             defer gpa.free(body);
 
             try expectEqualStrings("Hello, World!\n", body);
 
-            {
-                var it = response.head.iterateHeaders();
-                const header = it.next().?;
-                try expect(!it.is_trailer);
-                try expectEqualStrings("transfer-encoding", header.name);
-                try expectEqualStrings("chunked", header.value);
-                try expectEqual(null, it.next());
-            }
             {
                 var it = response.iterateTrailers();
                 const header = it.next().?;
-                try expect(it.is_trailer);
                 try expectEqualStrings("X-Checksum", header.name);
                 try expectEqualStrings("aaaa", header.value);
                 try expectEqual(null, it.next());
@@ -208,12 +207,14 @@ test "echo content server" {
                     //    request.head.target,
                     //});
 
+                    try expect(mem.startsWith(u8, request.head.target, "/echo-content"));
+                    try expectEqualStrings("text/plain", request.head.content_type.?);
+
+                    // head strings expire here
                     const body = try (try request.readerExpectContinue(&.{})).allocRemaining(std.testing.allocator, .unlimited);
                     defer std.testing.allocator.free(body);
 
-                    try expect(mem.startsWith(u8, request.head.target, "/echo-content"));
                     try expectEqualStrings("Hello, World!\n", body);
-                    try expectEqualStrings("text/plain", request.head.content_type.?);
 
                     var response = try request.respondStreaming(&.{}, .{
                         .content_length = switch (request.head.transfer_encoding) {
@@ -410,17 +411,19 @@ test "general client/server API coverage" {
 fn handleRequest(request: *http.Server.Request, listen_port: u16) !void {
     const log = std.log.scoped(.server);
+    const gpa = std.testing.allocator;
 
     log.info("{f} {t} {s}", .{ request.head.method, request.head.version, request.head.target });
+    const target = try gpa.dupe(u8, request.head.target);
+    defer gpa.free(target);
 
-    const gpa = std.testing.allocator;
     const reader = (try request.readerExpectContinue(&.{}));
     const body = try reader.allocRemaining(gpa, .unlimited);
     defer gpa.free(body);
 
-    if (mem.startsWith(u8, request.head.target, "/get")) {
+    if (mem.startsWith(u8, target, "/get")) {
         var response = try request.respondStreaming(&.{}, .{
-            .content_length = if (mem.indexOf(u8, request.head.target, "?chunked") == null)
+            .content_length = if (mem.indexOf(u8, target, "?chunked") == null)
                 14
             else
                 null,
@@ -435,7 +438,7 @@ test "general client/server API coverage" {
         try w.writeAll("World!\n");
         try response.end();
         // Writing again would cause an assertion failure.
-    } else if (mem.startsWith(u8, request.head.target, "/large")) {
+    } else if (mem.startsWith(u8, target, "/large")) {
         var response = try request.respondStreaming(&.{}, .{
             .content_length = 14 * 1024 + 14 * 10,
         });
@@ -458,7 +461,7 @@ test "general client/server API coverage" {
         }
 
         try response.end();
-    } else if (mem.eql(u8, request.head.target, "/redirect/1")) {
+    } else if (mem.eql(u8, target, "/redirect/1")) {
         var response = try request.respondStreaming(&.{}, .{
             .respond_options = .{
                 .status = .found,
@@ -472,14 +475,14 @@ test "general client/server API coverage" {
         try w.writeAll("Hello, ");
         try w.writeAll("Redirected!\n");
         try response.end();
-    } else if (mem.eql(u8, request.head.target, "/redirect/2")) {
+    } else if (mem.eql(u8, target, "/redirect/2")) {
         try request.respond("Hello, Redirected!\n", .{
             .status = .found,
             .extra_headers = &.{
                 .{ .name = "location", .value = "/redirect/1" },
             },
         });
-    } else if (mem.eql(u8, request.head.target, "/redirect/3")) {
+    } else if (mem.eql(u8, target, "/redirect/3")) {
         const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/redirect/2", .{
             listen_port,
         });
@@ -491,23 +494,23 @@ test "general client/server API coverage" {
                 .{ .name = "location", .value = location },
             },
         });
-    } else if (mem.eql(u8, request.head.target, "/redirect/4")) {
+    } else if (mem.eql(u8, target, "/redirect/4")) {
         try request.respond("Hello, Redirected!\n", .{
             .status = .found,
             .extra_headers = &.{
                 .{ .name = "location", .value = "/redirect/3" },
             },
         });
-    } else if (mem.eql(u8, request.head.target, "/redirect/5")) {
+    } else if (mem.eql(u8, target, "/redirect/5")) {
         try request.respond("Hello, Redirected!\n", .{
             .status = .found,
             .extra_headers = &.{
                 .{ .name = "location", .value = "/%2525" },
             },
         });
-    } else if (mem.eql(u8, request.head.target, "/%2525")) {
+    } else if (mem.eql(u8, target, "/%2525")) {
         try request.respond("Encoded redirect successful!\n", .{});
-    } else if (mem.eql(u8, request.head.target, "/redirect/invalid")) {
+    } else if (mem.eql(u8, target, "/redirect/invalid")) {
         const invalid_port = try getUnusedTcpPort();
         const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}", .{invalid_port});
         defer gpa.free(location);
@@ -518,7 +521,7 @@ test "general client/server API coverage" {
                 .{ .name = "location", .value = location },
             },
         });
-    } else if (mem.eql(u8, request.head.target, "/empty")) {
+    } else if (mem.eql(u8, target, "/empty")) {
         try request.respond("", .{
             .extra_headers = &.{
                 .{ .name = "empty", .value = "" },
@@ -559,11 +562,12 @@ test "general client/server API coverage" {
         try req.sendBodiless();
         var response = try req.receiveHead(&redirect_buffer);
 
+        try expectEqualStrings("text/plain", response.head.content_type.?);
+
         const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
         defer gpa.free(body);
 
         try expectEqualStrings("Hello, World!\n", body);
-        try expectEqualStrings("text/plain", response.head.content_type.?);
     }
 
     // connection has been kept alive
@@ -604,12 +608,13 @@ test "general client/server API coverage" {
         try req.sendBodiless();
         var response = try req.receiveHead(&redirect_buffer);
 
+        try expectEqualStrings("text/plain", response.head.content_type.?);
+        try expectEqual(14, response.head.content_length.?);
+
         const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
         defer gpa.free(body);
 
         try expectEqualStrings("", body);
-        try expectEqualStrings("text/plain", response.head.content_type.?);
-        try expectEqual(14, response.head.content_length.?);
     }
 
     // connection has been kept alive
@@ -628,11 +633,12 @@ test "general client/server API coverage" {
test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); + try expectEqualStrings("text/plain", response.head.content_type.?); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); - try expectEqualStrings("text/plain", response.head.content_type.?); } // connection has been kept alive @@ -651,12 +657,13 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); + try expectEqualStrings("text/plain", response.head.content_type.?); + try expect(response.head.transfer_encoding == .chunked); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("", body); - try expectEqualStrings("text/plain", response.head.content_type.?); - try expect(response.head.transfer_encoding == .chunked); } // connection has been kept alive @@ -677,11 +684,12 @@ test "general client/server API coverage" { try req.sendBodiless(); var response = try req.receiveHead(&redirect_buffer); + try expectEqualStrings("text/plain", response.head.content_type.?); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); - try expectEqualStrings("text/plain", response.head.content_type.?); } // connection has been closed @@ -706,11 +714,6 @@ test "general client/server API coverage" { try std.testing.expectEqual(.ok, response.head.status); - const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); - defer gpa.free(body); - - try expectEqualStrings("", body); - var it = response.head.iterateHeaders(); { const header = it.next().?; @@ -718,6 +721,12 @@ test "general client/server API coverage" { try expectEqualStrings("content-length", header.name); try expectEqualStrings("0", header.value); } + + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); + defer gpa.free(body); + + try expectEqualStrings("", body); + { const header = it.next().?; try expect(!it.is_trailer); From 59ba1973089c01554a5cefee1339fea773a3200a Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 6 Aug 2025 19:01:22 -0700 Subject: [PATCH 66/70] std.http: address review comments thank you everybody --- lib/std/Uri.zig | 3 ++- lib/std/http.zig | 4 ++-- lib/std/http/Client.zig | 8 ++------ 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/lib/std/Uri.zig b/lib/std/Uri.zig index 5390bee5b5..7244c9595b 100644 --- a/lib/std/Uri.zig +++ b/lib/std/Uri.zig @@ -377,7 +377,8 @@ pub fn parse(text: []const u8) ParseError!Uri { pub const ResolveInPlaceError = ParseError || error{NoSpaceLeft}; -/// Resolves a URI against a base URI, conforming to RFC 3986, Section 5. +/// Resolves a URI against a base URI, conforming to +/// [RFC 3986, Section 5](https://www.rfc-editor.org/rfc/rfc3986#section-5) /// /// Assumes new location is already copied to the beginning of `aux_buf.*`. /// Parses that new location as a URI, and then resolves the path in place. 
---
 lib/std/Uri.zig         | 3 ++-
 lib/std/http.zig        | 4 ++--
 lib/std/http/Client.zig | 8 ++------
 3 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/lib/std/Uri.zig b/lib/std/Uri.zig
index 5390bee5b5..7244c9595b 100644
--- a/lib/std/Uri.zig
+++ b/lib/std/Uri.zig
@@ -377,7 +377,8 @@ pub fn parse(text: []const u8) ParseError!Uri {
 
 pub const ResolveInPlaceError = ParseError || error{NoSpaceLeft};
 
-/// Resolves a URI against a base URI, conforming to RFC 3986, Section 5.
+/// Resolves a URI against a base URI, conforming to
+/// [RFC 3986, Section 5](https://www.rfc-editor.org/rfc/rfc3986#section-5)
 ///
 /// Assumes new location is already copied to the beginning of `aux_buf.*`.
 /// Parses that new location as a URI, and then resolves the path in place.
diff --git a/lib/std/http.zig b/lib/std/http.zig
index f8cbb84dc5..bb6dd54f47 100644
--- a/lib/std/http.zig
+++ b/lib/std/http.zig
@@ -496,7 +496,7 @@ pub const Reader = struct {
                 return reader.in;
             },
             .deflate => {
-                decompressor.* = .{ .flate = .init(reader.in, .raw, decompression_buffer) };
+                decompressor.* = .{ .flate = .init(reader.in, .zlib, decompression_buffer) };
                 return &decompressor.flate.reader;
             },
             .gzip => {
@@ -730,7 +730,7 @@ pub const Decompressor = union(enum) {
             return transfer_reader;
         },
         .deflate => {
-            decompressor.* = .{ .flate = .init(transfer_reader, .raw, buffer) };
+            decompressor.* = .{ .flate = .init(transfer_reader, .zlib, buffer) };
             return &decompressor.flate.reader;
         },
         .gzip => {
diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig
index a6b940018e..12af5cf2a0 100644
--- a/lib/std/http/Client.zig
+++ b/lib/std/http/Client.zig
@@ -115,8 +115,6 @@ pub const ConnectionPool = struct {
     /// Tries to release a connection back to the connection pool.
     /// If the connection is marked as closing, it will be closed instead.
     ///
-    /// `allocator` must be the same one used to create `connection`.
-    ///
     /// Threadsafe.
     pub fn release(pool: *ConnectionPool, connection: *Connection) void {
         pool.mutex.lock();
@@ -484,10 +482,8 @@ pub const Response = struct {
     };
 
     var it = mem.splitSequence(u8, bytes, "\r\n");
-    const first_line = it.next().?;
-    if (first_line.len < 12) {
-        return error.HttpHeadersInvalid;
-    }
+    const first_line = it.first();
+    if (first_line.len < 12) return error.HttpHeadersInvalid;
 
     const version: http.Version = switch (int64(first_line[0..8])) {
         int64("HTTP/1.0") => .@"HTTP/1.0",

From 1596aed1fd1306b99a1da417d10e79cf53aa05ab Mon Sep 17 00:00:00 2001
From: Andrew Kelley
Date: Wed, 6 Aug 2025 19:39:39 -0700
Subject: [PATCH 67/70] fetch: avoid copying Resource
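
Why the out parameter: `Resource` can contain self-referential state (the
git fetch stream's reader points back into the struct), so returning it by
value moves it and dangles those interior pointers -- see the
`@panic("TODO this moves fetch_stream, invalidating its reader")` deleted
below. Constructing it in caller-owned memory keeps its address stable.
A sketch:

    var resource: Resource = undefined;
    try f.initResource(uri, &resource, &buffer);
    // `resource` never moves again, so this interior reader stays valid.
    const reader = resource.reader();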
---
 src/Package/Fetch.zig     | 46 +++++++++++++++++++--------------------
 src/Package/Fetch/git.zig |  6 ++---
 2 files changed, 25 insertions(+), 27 deletions(-)

diff --git a/src/Package/Fetch.zig b/src/Package/Fetch.zig
index 2a11e1ea6e..6fd5a5989c 100644
--- a/src/Package/Fetch.zig
+++ b/src/Package/Fetch.zig
@@ -400,7 +400,8 @@ pub fn run(f: *Fetch) RunError!void {
                     .{ path_or_url, file_err, uri_err },
                 ));
             };
-            var resource = try f.initResource(uri, &server_header_buffer);
+            var resource: Resource = undefined;
+            try f.initResource(uri, &resource, &server_header_buffer);
             return f.runResource(try uri.path.toRawMaybeAlloc(arena), &resource, null);
         }
     },
@@ -466,7 +467,8 @@ pub fn run(f: *Fetch) RunError!void {
         try eb.printString("invalid URI: {s}", .{@errorName(err)}),
     );
     var buffer: [init_resource_buffer_size]u8 = undefined;
-    var resource = try f.initResource(uri, &buffer);
+    var resource: Resource = undefined;
+    try f.initResource(uri, &resource, &buffer);
     return f.runResource(try uri.path.toRawMaybeAlloc(arena), &resource, remote.hash);
 }
 
@@ -880,7 +882,7 @@ const Resource = union(enum) {
     const HttpRequest = struct {
         request: std.http.Client.Request,
-        head: std.http.Client.Response.Head,
+        response: std.http.Client.Response,
         buffer: []u8,
     };
 
@@ -900,13 +902,7 @@ const Resource = union(enum) {
     fn reader(resource: *Resource) *std.Io.Reader {
         return switch (resource.*) {
             .file => |*file_reader| return &file_reader.interface,
-            .http_request => |*http_request| {
-                const response: std.http.Client.Response = .{
-                    .request = &http_request.request,
-                    .head = http_request.head,
-                };
-                return response.reader(http_request.buffer);
-            },
+            .http_request => |*http_request| return http_request.response.reader(http_request.buffer),
             .git => |*g| return &g.fetch_stream.reader,
             .dir => unreachable,
         };
@@ -974,7 +970,7 @@ const FileType = enum {
 
 const init_resource_buffer_size = git.Packet.max_data_length;
 
-fn initResource(f: *Fetch, uri: std.Uri, reader_buffer: []u8) RunError!Resource {
+fn initResource(f: *Fetch, uri: std.Uri, resource: *Resource, reader_buffer: []u8) RunError!void {
     const gpa = f.arena.child_allocator;
     const arena = f.arena.allocator();
     const eb = &f.error_bundle;
@@ -986,7 +982,8 @@ fn initResource(f: *Fetch, uri: std.Uri, reader_buffer: []u8) RunError!Resource
                 f.parent_package_root, path, err,
             }));
         };
-        return .{ .file = file.reader(reader_buffer) };
+        resource.* = .{ .file = file.reader(reader_buffer) };
+        return;
     }
 
     const http_client = f.job_queue.http_client;
@@ -994,15 +991,21 @@ fn initResource(f: *Fetch, uri: std.Uri, reader_buffer: []u8) RunError!Resource
     if (ascii.eqlIgnoreCase(uri.scheme, "http") or
         ascii.eqlIgnoreCase(uri.scheme, "https"))
     {
-        var request = http_client.request(.GET, uri, .{}) catch |err|
-            return f.fail(f.location_tok, try eb.printString("unable to connect to server: {t}", .{err}));
+        resource.* = .{ .http_request = .{
+            .request = http_client.request(.GET, uri, .{}) catch |err|
+                return f.fail(f.location_tok, try eb.printString("unable to connect to server: {t}", .{err})),
+            .response = undefined,
+            .buffer = reader_buffer,
+        } };
+        const request = &resource.http_request.request;
         defer request.deinit();
 
         request.sendBodiless() catch |err|
             return f.fail(f.location_tok, try eb.printString("HTTP request failed: {t}", .{err}));
 
         var redirect_buffer: [1024]u8 = undefined;
-        const response = request.receiveHead(&redirect_buffer) catch |err|
+        const response = &resource.http_request.response;
+        response.* = request.receiveHead(&redirect_buffer) catch |err|
             return f.fail(f.location_tok, try eb.printString("invalid HTTP response: {t}", .{err}));
 
         if (response.head.status != .ok) return f.fail(f.location_tok, try eb.printString(
@@ -1010,11 +1013,7 @@ fn initResource(f: *Fetch, uri: std.Uri, reader_buffer: []u8) RunError!Resource
             .{ response.head.status, response.head.status.phrase() orelse "" },
         ));
 
-        return .{ .http_request = .{
-            .request = request,
-            .head = response.head,
-            .buffer = reader_buffer,
-        } };
+        return;
     }
 
     if (ascii.eqlIgnoreCase(uri.scheme, "git+http") or
@@ -1087,13 +1086,12 @@ fn initResource(f: *Fetch, uri: std.Uri, reader_buffer: []u8) RunError!Resource
         };
         errdefer fetch_stream.deinit();
 
-        if (true) @panic("TODO this moves fetch_stream, invalidating its reader");
-
-        return .{ .git = .{
+        resource.* = .{ .git = .{
             .session = session,
             .fetch_stream = fetch_stream,
             .want_oid = want_oid,
         } };
+        return;
     }
 
     return f.fail(f.location_tok, try eb.printString("unsupported URL scheme: {s}", .{uri.scheme}));
@@ -1111,7 +1109,7 @@ fn unpackResource(
             return f.fail(f.location_tok, try eb.printString("unknown file type: '{s}'", .{uri_path})),
 
         .http_request => |*http_request| ft: {
-            const head = &http_request.head;
+            const head = &http_request.response.head;
 
             // Content-Type takes first precedence.
             const content_type = head.content_type orelse
diff --git a/src/Package/Fetch/git.zig b/src/Package/Fetch/git.zig
index 88652343f5..390b977c3a 100644
--- a/src/Package/Fetch/git.zig
+++ b/src/Package/Fetch/git.zig
@@ -773,7 +773,7 @@ pub const Session = struct {
         try request.sendBodiless();
 
         var redirect_buffer: [1024]u8 = undefined;
-        const response = try request.receiveHead(&redirect_buffer);
+        var response = try request.receiveHead(&redirect_buffer);
         if (response.head.status != .ok) return error.ProtocolError;
         const any_redirects_occurred = request.redirect_behavior.remaining() < max_redirects;
         if (any_redirects_occurred) {
@@ -918,7 +918,7 @@ pub const Session = struct {
             errdefer request.deinit();
             try request.sendBodyComplete(body.buffered());
 
-            const response = try request.receiveHead(options.buffer);
+            var response = try request.receiveHead(options.buffer);
             if (response.head.status != .ok) return error.ProtocolError;
             it.reader = response.reader(options.buffer);
         }
@@ -1037,7 +1037,7 @@ pub const Session = struct {
 
         try request.sendBodyComplete(body.buffered());
 
-        const response = try request.receiveHead(&.{});
+        var response = try request.receiveHead(&.{});
         if (response.head.status != .ok) return error.ProtocolError;
 
         const reader = response.reader(response_buffer);

From 2c82d1c03aaa2c2abce3d620ae831ac31cf58e2c Mon Sep 17 00:00:00 2001
From: Andrew Kelley
Date: Wed, 6 Aug 2025 19:56:10 -0700
Subject: [PATCH 68/70] std.http: remove custom method support

let's see if anybody notices it missing
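
With the enum now exhaustive, the request line can be parsed with the
standard helper, and an unrecognized method becomes a parse error instead
of a catch-all variant (this is the Server.zig change below):

    const method = std.meta.stringToEnum(http.Method, first_line[0..method_end]) orelse
        return error.UnknownHttpMethod;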
- pub fn parse(s: []const u8) u64 { - var x: u64 = 0; - const len = @min(s.len, @sizeOf(@TypeOf(x))); - @memcpy(std.mem.asBytes(&x)[0..len], s[0..len]); - return x; - } - - pub fn format(self: Method, w: *Writer) Writer.Error!void { - const bytes: []const u8 = @ptrCast(&@intFromEnum(self)); - const str = std.mem.sliceTo(bytes, 0); - try w.writeAll(str); - } +pub const Method = enum { + GET, + HEAD, + POST, + PUT, + DELETE, + CONNECT, + OPTIONS, + TRACE, + PATCH, /// Returns true if a request of this method is allowed to have a body /// Actual behavior from servers may vary and should still be checked - pub fn requestHasBody(self: Method) bool { - return switch (self) { + pub fn requestHasBody(m: Method) bool { + return switch (m) { .POST, .PUT, .PATCH => true, .GET, .HEAD, .DELETE, .CONNECT, .OPTIONS, .TRACE => false, - else => true, }; } /// Returns true if a response to this method is allowed to have a body /// Actual behavior from clients may vary and should still be checked - pub fn responseHasBody(self: Method) bool { - return switch (self) { + pub fn responseHasBody(m: Method) bool { + return switch (m) { .GET, .POST, .DELETE, .CONNECT, .OPTIONS, .PATCH => true, .HEAD, .PUT, .TRACE => false, - else => true, }; } @@ -73,11 +54,10 @@ pub const Method = enum(u64) { /// https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP /// /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 - pub fn safe(self: Method) bool { - return switch (self) { + pub fn safe(m: Method) bool { + return switch (m) { .GET, .HEAD, .OPTIONS, .TRACE => true, .POST, .PUT, .DELETE, .CONNECT, .PATCH => false, - else => false, }; } @@ -88,11 +68,10 @@ pub const Method = enum(u64) { /// https://developer.mozilla.org/en-US/docs/Glossary/Idempotent /// /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.2 - pub fn idempotent(self: Method) bool { - return switch (self) { + pub fn idempotent(m: Method) bool { + return switch (m) { .GET, .HEAD, .PUT, .DELETE, .OPTIONS, .TRACE => true, .CONNECT, .POST, .PATCH => false, - else => false, }; } @@ -102,11 +81,10 @@ pub const Method = enum(u64) { /// https://developer.mozilla.org/en-US/docs/Glossary/cacheable /// /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.3 - pub fn cacheable(self: Method) bool { - return switch (self) { + pub fn cacheable(m: Method) bool { + return switch (m) { .GET, .HEAD => true, .POST, .PUT, .DELETE, .CONNECT, .OPTIONS, .TRACE, .PATCH => false, - else => false, }; } }; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 12af5cf2a0..95397b6f07 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -928,7 +928,7 @@ pub const Request = struct { const connection = r.connection.?; const w = connection.writer(); - try r.method.format(w); + try w.writeAll(@tagName(r.method)); try w.writeByte(' '); if (r.method == .CONNECT) { diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index be59cd05d0..9574a4dc6a 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -97,10 +97,9 @@ pub const Request = struct { const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid; - if (method_end > 24) return error.HttpHeadersInvalid; - const method_str = first_line[0..method_end]; - const method: http.Method = @enumFromInt(http.Method.parse(method_str)); + const method = std.meta.stringToEnum(http.Method, first_line[0..method_end]) orelse + return error.UnknownHttpMethod; const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse 
return error.HttpHeadersInvalid; diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index 69355667ed..556afc092f 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -413,7 +413,7 @@ test "general client/server API coverage" { const log = std.log.scoped(.server); const gpa = std.testing.allocator; - log.info("{f} {t} {s}", .{ request.head.method, request.head.version, request.head.target }); + log.info("{t} {t} {s}", .{ request.head.method, request.head.version, request.head.target }); const target = try gpa.dupe(u8, request.head.target); defer gpa.free(target); From ac18b98aa398b9d310bee0420c6130e51b973332 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 6 Aug 2025 20:26:53 -0700 Subject: [PATCH 69/70] std.fs.File.Reader: fix readVec fill respect the case when there is existing buffer --- lib/std/fs/File.zig | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lib/std/fs/File.zig b/lib/std/fs/File.zig index 2791642ac7..7ad71ad274 100644 --- a/lib/std/fs/File.zig +++ b/lib/std/fs/File.zig @@ -1351,8 +1351,7 @@ pub const Reader = struct { } r.pos += n; if (n > data_size) { - io_reader.seek = 0; - io_reader.end = n - data_size; + io_reader.end += n - data_size; return data_size; } return n; @@ -1386,8 +1385,7 @@ pub const Reader = struct { } r.pos += n; if (n > data_size) { - io_reader.seek = 0; - io_reader.end = n - data_size; + io_reader.end += n - data_size; return data_size; } return n; From aac26f3b31ddb43e863ac7186fd19a7e251c1b8a Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 6 Aug 2025 22:39:26 -0700 Subject: [PATCH 70/70] TLS, HTTP, and package fetching fixes * TLS: add missing assert for output buffer length requirement * TLS: add missing flushes * TLS: add flush implementation * TLS: finish drain implementation * HTTP: correct buffer sizes for TLS * HTTP: expose a getReadError method on Connection * HTTP: add missing flush on sendBodyComplete * Fetch: remove unwanted deinit * Fetch: improve error reporting --- lib/std/crypto/tls/Client.zig | 62 ++++++++++++++++++++++++++--------- lib/std/http/Client.zig | 38 ++++++++++++++++----- src/Package/Fetch.zig | 12 +++++-- 3 files changed, 84 insertions(+), 28 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 3fe51e7b3b..5e89c071c6 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -8,8 +8,8 @@ const mem = std.mem; const crypto = std.crypto; const assert = std.debug.assert; const Certificate = std.crypto.Certificate; -const Reader = std.io.Reader; -const Writer = std.io.Writer; +const Reader = std.Io.Reader; +const Writer = std.Io.Writer; const max_ciphertext_len = tls.max_ciphertext_len; const hmacExpandLabel = tls.hmacExpandLabel; @@ -27,6 +27,8 @@ reader: Reader, /// The encrypted stream from the client to the server. Bytes are pushed here /// via `writer`. +/// +/// The buffer is asserted to have capacity at least `min_buffer_len`. output: *Writer, /// The plaintext stream from the client to the server. writer: Writer, @@ -122,7 +124,6 @@ pub const Options = struct { /// the amount of data expected, such as HTTP with the Content-Length header. allow_truncation_attacks: bool = false, write_buffer: []u8, - /// Asserted to have capacity at least `min_buffer_len`. read_buffer: []u8, /// Populated when `error.TlsAlert` is returned from `init`. alert: ?*tls.Alert = null, @@ -185,6 +186,7 @@ const InitError = error{ /// `input` is asserted to have buffer capacity at least `min_buffer_len`. 
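Patch 69's `readVec` fix above is pure index bookkeeping: when a vectored read returns `n` bytes and only `data_size` of them fit in the caller's buffers, the spill lands in the `Reader`'s buffer after any bytes already buffered there, so `end` must advance rather than be reset. A worked example with assumed values:

const std = @import("std");

pub fn main() void {
    var end: usize = 10; // bytes already sitting in the Reader's buffer
    const n: usize = 25; // bytes returned by the vectored read
    const data_size: usize = 20; // bytes that went into the caller's buffers
    // Broken version: end = n - data_size, i.e. 5, silently dropping
    // the 10 bytes that were already buffered.
    // Fixed version: advance past the existing bytes instead.
    if (n > data_size) end += n - data_size;
    std.debug.print("end = {d}\n", .{end}); // end = 15
}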
pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client { assert(input.buffer.len >= min_buffer_len); + assert(output.buffer.len >= min_buffer_len); const host = switch (options.host) { .no_verification => "", .explicit => |host| host, @@ -278,6 +280,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client { var iovecs: [2][]const u8 = .{ cleartext_header, host }; try output.writeVecAll(iovecs[0..if (host.len == 0) 1 else 2]); + try output.flush(); } var tls_version: tls.ProtocolVersion = undefined; @@ -763,6 +766,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client &client_verify_msg, }; try output.writeVecAll(&all_msgs_vec); + try output.flush(); }, } write_seq += 1; @@ -828,6 +832,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client &finished_msg, }; try output.writeVecAll(&all_msgs_vec); + try output.flush(); const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); @@ -877,7 +882,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client .buffer = options.write_buffer, .vtable = &.{ .drain = drain, - .sendFile = Writer.unimplementedSendFile, + .flush = flush, }, }, .tls_version = tls_version, @@ -911,31 +916,56 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client fn drain(w: *Writer, data: []const []const u8, splat: usize) Writer.Error!usize { const c: *Client = @alignCast(@fieldParentPtr("writer", w)); - if (true) @panic("update to use the buffer and flush"); - const sliced_data = if (splat == 0) data[0..data.len -| 1] else data; const output = c.output; const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len); - var total_clear: usize = 0; var ciphertext_end: usize = 0; - for (sliced_data) |buf| { - const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data); - total_clear += prepared.cleartext_len; - ciphertext_end += prepared.ciphertext_end; - if (total_clear < buf.len) break; + var total_clear: usize = 0; + done: { + { + const buf = w.buffered(); + const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data); + total_clear += prepared.cleartext_len; + ciphertext_end += prepared.ciphertext_end; + if (prepared.cleartext_len < buf.len) break :done; + } + for (data[0 .. 
data.len - 1]) |buf| { + if (buf.len < min_buffer_len) break :done; + const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data); + total_clear += prepared.cleartext_len; + ciphertext_end += prepared.ciphertext_end; + if (prepared.cleartext_len < buf.len) break :done; + } + const buf = data[data.len - 1]; + for (0..splat) |_| { + if (buf.len < min_buffer_len) break :done; + const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data); + total_clear += prepared.cleartext_len; + ciphertext_end += prepared.ciphertext_end; + if (prepared.cleartext_len < buf.len) break :done; + } } output.advance(ciphertext_end); - return total_clear; + return w.consume(total_clear); +} + +fn flush(w: *Writer) Writer.Error!void { + const c: *Client = @alignCast(@fieldParentPtr("writer", w)); + const output = c.output; + const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len); + const prepared = prepareCiphertextRecord(c, ciphertext_buf, w.buffered(), .application_data); + output.advance(prepared.ciphertext_end); + w.end = 0; } /// Sends a `close_notify` alert, which is necessary for the server to /// distinguish between a properly finished TLS session, or a truncation /// attack. pub fn end(c: *Client) Writer.Error!void { + try flush(&c.writer); const output = c.output; const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len); const prepared = prepareCiphertextRecord(c, ciphertext_buf, &tls.close_notify_alert, .alert); - output.advance(prepared.cleartext_len); - return prepared.ciphertext_end; + output.advance(prepared.ciphertext_end); } fn prepareCiphertextRecord( @@ -1045,7 +1075,7 @@ pub fn eof(c: Client) bool { return c.received_close_notify; } -fn stream(r: *Reader, w: *Writer, limit: std.io.Limit) Reader.StreamError!usize { +fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize { const c: *Client = @alignCast(@fieldParentPtr("reader", r)); if (c.eof()) return error.EndOfStream; const input = c.input; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 95397b6f07..37022b4d0b 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -42,7 +42,7 @@ connection_pool: ConnectionPool = .{}, /// /// If the entire HTTP header cannot fit in this amount of bytes, /// `error.HttpHeadersOversize` will be returned from `Request.wait`. -read_buffer_size: usize = 4096, +read_buffer_size: usize = 4096 + if (disable_tls) 0 else std.crypto.tls.Client.min_buffer_len, /// Each `Connection` allocates this amount for the writer buffer. 
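The finished `drain` above follows the `std.Io.Writer` vtable contract: consume the writer's own `w.buffered()` bytes first, then the caller's `data` slices with the final slice repeated `splat` times, reporting progress via `w.consume`. A toy drain that forwards by counting instead of encrypting into TLS records; this is a sketch that mirrors only the calls visible in the patch, not the real TLS logic:

const std = @import("std");
const Writer = std.Io.Writer;

// Toy drain: counts bytes instead of encrypting them, but walks the
// inputs in the contract's order: the writer's own buffered bytes,
// then each data slice, with the last slice repeated `splat` times.
fn countingDrain(w: *Writer, data: []const []const u8, splat: usize) Writer.Error!usize {
    var total: usize = w.buffered().len;
    for (data[0 .. data.len - 1]) |buf| total += buf.len;
    total += data[data.len - 1].len * splat;
    return w.consume(total);
}

test {
    // Force semantic analysis of the sketch.
    _ = &countingDrain;
}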
write_buffer_size: usize = 1024, @@ -304,15 +304,16 @@ pub const Connection = struct { const host_buffer = base[@sizeOf(Tls)..][0..remote_host.len]; const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.tls_buffer_size]; const tls_write_buffer = tls_read_buffer.ptr[tls_read_buffer.len..][0..client.tls_buffer_size]; - const socket_write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size]; - assert(base.ptr + alloc_len == socket_write_buffer.ptr + socket_write_buffer.len); + const write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size]; + const read_buffer = write_buffer.ptr[write_buffer.len..][0..client.read_buffer_size]; + assert(base.ptr + alloc_len == read_buffer.ptr + read_buffer.len); @memcpy(host_buffer, remote_host); const tls: *Tls = @ptrCast(base); tls.* = .{ .connection = .{ .client = client, - .stream_writer = stream.writer(socket_write_buffer), - .stream_reader = stream.reader(&.{}), + .stream_writer = stream.writer(tls_write_buffer), + .stream_reader = stream.reader(tls_read_buffer), .pool_node = .{}, .port = port, .host_len = @intCast(remote_host.len), @@ -328,8 +329,8 @@ pub const Connection = struct { .host = .{ .explicit = remote_host }, .ca = .{ .bundle = client.ca_bundle }, .ssl_key_log = client.ssl_key_log, - .read_buffer = tls_read_buffer, - .write_buffer = tls_write_buffer, + .read_buffer = read_buffer, + .write_buffer = write_buffer, // This is appropriate for HTTPS because the HTTP headers contain // the content length which is used to detect truncation attacks. .allow_truncation_attacks = true, @@ -347,7 +348,8 @@ pub const Connection = struct { } fn allocLen(client: *Client, host_len: usize) usize { - return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size + client.write_buffer_size; + return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size + + client.write_buffer_size + client.read_buffer_size; } fn host(tls: *Tls) []u8 { @@ -356,6 +358,21 @@ pub const Connection = struct { } }; + pub const ReadError = std.crypto.tls.Client.ReadError || std.net.Stream.ReadError; + + pub fn getReadError(c: *const Connection) ?ReadError { + return switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *const Tls = @alignCast(@fieldParentPtr("connection", c)); + return tls.client.read_err orelse c.stream_reader.getError(); + }, + .plain => { + return c.stream_reader.getError(); + }, + }; + } + fn getStream(c: *Connection) net.Stream { return c.stream_reader.getStream(); } @@ -434,7 +451,6 @@ pub const Connection = struct { if (disable_tls) unreachable; const tls: *Tls = @alignCast(@fieldParentPtr("connection", c)); try tls.client.end(); - try tls.client.writer.flush(); } try c.stream_writer.interface.flush(); } @@ -874,6 +890,7 @@ pub const Request = struct { var bw = try sendBodyUnflushed(r, body); bw.writer.end = body.len; try bw.end(); + try r.connection.?.flush(); } /// Transfers the HTTP head over the connection, which is not flushed until @@ -1063,6 +1080,9 @@ pub const Request = struct { /// buffer capacity would be exceeded, `error.HttpRedirectLocationOversize` /// is returned instead. This buffer may be empty if no redirects are to be /// handled. + /// + /// If this fails with `error.ReadFailed` then the `Connection.getReadError` + /// method of `r.connection` can be used to get more detailed information. 
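The `Connection` changes above carve one heap allocation into consecutive regions (the `Tls` header, host name, TLS read/write buffers, and the newly added socket write/read buffers), with `allocLen` kept in sync and a closing assert that the slices exactly tile the block. A reduced sketch of the same slicing pattern, with invented sizes and the `Tls` header omitted:

const std = @import("std");

pub fn main() !void {
    const gpa = std.heap.page_allocator;
    // Invented sizes; the real code derives them from Client fields.
    const host_len: usize = 11;
    const tls_buffer_size: usize = 256;
    const write_buffer_size: usize = 64;
    const read_buffer_size: usize = 128;

    const alloc_len = host_len + 2 * tls_buffer_size + write_buffer_size + read_buffer_size;
    const base = try gpa.alloc(u8, alloc_len);
    defer gpa.free(base);

    // Each region starts where the previous one ends.
    const host_buffer: []u8 = base[0..host_len];
    const tls_read_buffer: []u8 = host_buffer.ptr[host_buffer.len..][0..tls_buffer_size];
    const tls_write_buffer: []u8 = tls_read_buffer.ptr[tls_read_buffer.len..][0..tls_buffer_size];
    const write_buffer: []u8 = tls_write_buffer.ptr[tls_write_buffer.len..][0..write_buffer_size];
    const read_buffer: []u8 = write_buffer.ptr[write_buffer.len..][0..read_buffer_size];

    // The closing assert from the patch: the slices exactly tile the block.
    std.debug.assert(base.ptr + alloc_len == read_buffer.ptr + read_buffer.len);
}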
pub fn receiveHead(r: *Request, redirect_buffer: []u8) ReceiveHeadError!Response { var aux_buf = redirect_buffer; while (true) { diff --git a/src/Package/Fetch.zig b/src/Package/Fetch.zig index 6fd5a5989c..fd8c26f1e6 100644 --- a/src/Package/Fetch.zig +++ b/src/Package/Fetch.zig @@ -998,15 +998,21 @@ fn initResource(f: *Fetch, uri: std.Uri, resource: *Resource, reader_buffer: []u .buffer = reader_buffer, } }; const request = &resource.http_request.request; - defer request.deinit(); + errdefer request.deinit(); request.sendBodiless() catch |err| return f.fail(f.location_tok, try eb.printString("HTTP request failed: {t}", .{err})); var redirect_buffer: [1024]u8 = undefined; const response = &resource.http_request.response; - response.* = request.receiveHead(&redirect_buffer) catch |err| - return f.fail(f.location_tok, try eb.printString("invalid HTTP response: {t}", .{err})); + response.* = request.receiveHead(&redirect_buffer) catch |err| switch (err) { + error.ReadFailed => { + return f.fail(f.location_tok, try eb.printString("HTTP response read failure: {t}", .{ + request.connection.?.getReadError().?, + })); + }, + else => |e| return f.fail(f.location_tok, try eb.printString("invalid HTTP response: {t}", .{e})), + }; if (response.head.status != .ok) return f.fail(f.location_tok, try eb.printString( "bad HTTP response code: '{d} {s}'",
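The Fetch change below relies on the pattern that `receiveHead`'s read path surfaces only `error.ReadFailed`, while the concrete cause is parked on the connection for `getReadError` to retrieve afterwards. A self-contained sketch of that side-channel error pattern; `FakeConn` and its methods are invented for illustration and are not the real `http.Client.Connection` API:

const std = @import("std");

const FakeConn = struct {
    read_err: ?anyerror = null,

    // The read path may only surface error.ReadFailed; the concrete
    // cause is recorded on the connection for later inspection.
    fn read(c: *FakeConn) error{ReadFailed}!u8 {
        c.read_err = error.ConnectionResetByPeer;
        return error.ReadFailed;
    }

    fn getReadError(c: *const FakeConn) ?anyerror {
        return c.read_err;
    }
};

pub fn main() void {
    var conn: FakeConn = .{};
    _ = conn.read() catch {
        // Mirrors the Fetch change: on ReadFailed, ask the connection
        // what actually went wrong and report that instead.
        std.debug.print("read failure: {t}\n", .{conn.getReadError().?});
    };
}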