From ebcfc86bb9c8cec1a66511858a6443b1927191f2 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 11 Dec 2022 13:58:11 -0700 Subject: [PATCH 01/59] Compilation: better error message for file not found --- src/Compilation.zig | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/Compilation.zig b/src/Compilation.zig index a18b05a939..4c7489c0c8 100644 --- a/src/Compilation.zig +++ b/src/Compilation.zig @@ -584,7 +584,17 @@ pub const AllErrors = struct { Message.HashContext, std.hash_map.default_max_load_percentage, ).init(allocator); - const err_source = try module_err_msg.src_loc.file_scope.getSource(module.gpa); + const err_source = module_err_msg.src_loc.file_scope.getSource(module.gpa) catch |err| { + const file_path = try module_err_msg.src_loc.file_scope.fullPath(allocator); + try errors.append(.{ + .plain = .{ + .msg = try std.fmt.allocPrint(allocator, "unable to load '{s}': {s}", .{ + file_path, @errorName(err), + }), + }, + }); + return; + }; const err_span = try module_err_msg.src_loc.span(module.gpa); const err_loc = std.zig.findLineColumn(err_source.bytes, err_span.main); From cd0d514643404103a83881fc4d7c46674ed9f991 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 11 Dec 2022 15:35:15 -0700 Subject: [PATCH 02/59] remove the experimental std.x namespace Playtime is over. I'm working on networking now. --- lib/std/c.zig | 4 +- lib/std/c/darwin.zig | 11 +- lib/std/c/dragonfly.zig | 14 +- lib/std/c/freebsd.zig | 12 +- lib/std/c/haiku.zig | 12 +- lib/std/c/netbsd.zig | 12 +- lib/std/c/openbsd.zig | 12 +- lib/std/c/solaris.zig | 11 +- lib/std/os.zig | 4 +- lib/std/os/linux.zig | 53 +- lib/std/os/linux/seccomp.zig | 18 +- lib/std/os/windows/ws2_32.zig | 19 +- lib/std/std.zig | 1 - lib/std/x.zig | 19 - lib/std/x/net/bpf.zig | 1003 ------------------------------- lib/std/x/net/ip.zig | 57 -- lib/std/x/net/tcp.zig | 447 -------------- lib/std/x/os/io.zig | 224 ------- lib/std/x/os/net.zig | 605 ------------------- lib/std/x/os/socket.zig | 320 ---------- lib/std/x/os/socket_posix.zig | 275 --------- lib/std/x/os/socket_windows.zig | 458 -------------- 22 files changed, 143 insertions(+), 3448 deletions(-) delete mode 100644 lib/std/x.zig delete mode 100644 lib/std/x/net/bpf.zig delete mode 100644 lib/std/x/net/ip.zig delete mode 100644 lib/std/x/net/tcp.zig delete mode 100644 lib/std/x/os/io.zig delete mode 100644 lib/std/x/os/net.zig delete mode 100644 lib/std/x/os/socket.zig delete mode 100644 lib/std/x/os/socket_posix.zig delete mode 100644 lib/std/x/os/socket_windows.zig diff --git a/lib/std/c.zig b/lib/std/c.zig index 5f03f1c619..212b8e2d4d 100644 --- a/lib/std/c.zig +++ b/lib/std/c.zig @@ -206,7 +206,7 @@ pub extern "c" fn sendto( dest_addr: ?*const c.sockaddr, addrlen: c.socklen_t, ) isize; -pub extern "c" fn sendmsg(sockfd: c.fd_t, msg: *const std.x.os.Socket.Message, flags: c_int) isize; +pub extern "c" fn sendmsg(sockfd: c.fd_t, msg: *const c.msghdr_const, flags: u32) isize; pub extern "c" fn recv(sockfd: c.fd_t, arg1: ?*anyopaque, arg2: usize, arg3: c_int) isize; pub extern "c" fn recvfrom( @@ -217,7 +217,7 @@ pub extern "c" fn recvfrom( noalias src_addr: ?*c.sockaddr, noalias addrlen: ?*c.socklen_t, ) isize; -pub extern "c" fn recvmsg(sockfd: c.fd_t, msg: *std.x.os.Socket.Message, flags: c_int) isize; +pub extern "c" fn recvmsg(sockfd: c.fd_t, msg: *c.msghdr, flags: u32) isize; pub extern "c" fn kill(pid: c.pid_t, sig: c_int) c_int; pub extern "c" fn getdirentries(fd: c.fd_t, buf_ptr: [*]u8, nbytes: usize, basep: *i64) isize; diff --git 
a/lib/std/c/darwin.zig b/lib/std/c/darwin.zig index b68f04379f..9c5ac1e93a 100644 --- a/lib/std/c/darwin.zig +++ b/lib/std/c/darwin.zig @@ -1007,7 +1007,16 @@ pub const sockaddr = extern struct { data: [14]u8, pub const SS_MAXSIZE = 128; - pub const storage = std.x.os.Socket.Address.Native.Storage; + pub const storage = extern struct { + len: u8 align(8), + family: sa_family_t, + padding: [126]u8 = undefined, + + comptime { + assert(@sizeOf(storage) == SS_MAXSIZE); + assert(@alignOf(storage) == 8); + } + }; pub const in = extern struct { len: u8 = @sizeOf(in), family: sa_family_t = AF.INET, diff --git a/lib/std/c/dragonfly.zig b/lib/std/c/dragonfly.zig index 2410310fc7..f5471c3145 100644 --- a/lib/std/c/dragonfly.zig +++ b/lib/std/c/dragonfly.zig @@ -1,5 +1,6 @@ const builtin = @import("builtin"); const std = @import("../std.zig"); +const assert = std.debug.assert; const maxInt = std.math.maxInt; const iovec = std.os.iovec; @@ -476,11 +477,20 @@ pub const CLOCK = struct { pub const sockaddr = extern struct { len: u8, - family: u8, + family: sa_family_t, data: [14]u8, pub const SS_MAXSIZE = 128; - pub const storage = std.x.os.Socket.Address.Native.Storage; + pub const storage = extern struct { + len: u8 align(8), + family: sa_family_t, + padding: [126]u8 = undefined, + + comptime { + assert(@sizeOf(storage) == SS_MAXSIZE); + assert(@alignOf(storage) == 8); + } + }; pub const in = extern struct { len: u8 = @sizeOf(in), diff --git a/lib/std/c/freebsd.zig b/lib/std/c/freebsd.zig index c4bd4a44a7..f3858317df 100644 --- a/lib/std/c/freebsd.zig +++ b/lib/std/c/freebsd.zig @@ -1,4 +1,5 @@ const std = @import("../std.zig"); +const assert = std.debug.assert; const builtin = @import("builtin"); const maxInt = std.math.maxInt; const iovec = std.os.iovec; @@ -401,7 +402,16 @@ pub const sockaddr = extern struct { data: [14]u8, pub const SS_MAXSIZE = 128; - pub const storage = std.x.os.Socket.Address.Native.Storage; + pub const storage = extern struct { + len: u8 align(8), + family: sa_family_t, + padding: [126]u8 = undefined, + + comptime { + assert(@sizeOf(storage) == SS_MAXSIZE); + assert(@alignOf(storage) == 8); + } + }; pub const in = extern struct { len: u8 = @sizeOf(in), diff --git a/lib/std/c/haiku.zig b/lib/std/c/haiku.zig index 86b9f25902..9c4f8460de 100644 --- a/lib/std/c/haiku.zig +++ b/lib/std/c/haiku.zig @@ -1,4 +1,5 @@ const std = @import("../std.zig"); +const assert = std.debug.assert; const builtin = @import("builtin"); const maxInt = std.math.maxInt; const iovec = std.os.iovec; @@ -339,7 +340,16 @@ pub const sockaddr = extern struct { data: [14]u8, pub const SS_MAXSIZE = 128; - pub const storage = std.x.os.Socket.Address.Native.Storage; + pub const storage = extern struct { + len: u8 align(8), + family: sa_family_t, + padding: [126]u8 = undefined, + + comptime { + assert(@sizeOf(storage) == SS_MAXSIZE); + assert(@alignOf(storage) == 8); + } + }; pub const in = extern struct { len: u8 = @sizeOf(in), diff --git a/lib/std/c/netbsd.zig b/lib/std/c/netbsd.zig index d9bf925c17..805bb1bd3e 100644 --- a/lib/std/c/netbsd.zig +++ b/lib/std/c/netbsd.zig @@ -1,4 +1,5 @@ const std = @import("../std.zig"); +const assert = std.debug.assert; const builtin = @import("builtin"); const maxInt = std.math.maxInt; const iovec = std.os.iovec; @@ -481,7 +482,16 @@ pub const sockaddr = extern struct { data: [14]u8, pub const SS_MAXSIZE = 128; - pub const storage = std.x.os.Socket.Address.Native.Storage; + pub const storage = extern struct { + len: u8 align(8), + family: sa_family_t, + padding: [126]u8 = 
undefined, + + comptime { + assert(@sizeOf(storage) == SS_MAXSIZE); + assert(@alignOf(storage) == 8); + } + }; pub const in = extern struct { len: u8 = @sizeOf(in), diff --git a/lib/std/c/openbsd.zig b/lib/std/c/openbsd.zig index 83aed68483..74f824cae3 100644 --- a/lib/std/c/openbsd.zig +++ b/lib/std/c/openbsd.zig @@ -1,4 +1,5 @@ const std = @import("../std.zig"); +const assert = std.debug.assert; const maxInt = std.math.maxInt; const builtin = @import("builtin"); const iovec = std.os.iovec; @@ -372,7 +373,16 @@ pub const sockaddr = extern struct { data: [14]u8, pub const SS_MAXSIZE = 256; - pub const storage = std.x.os.Socket.Address.Native.Storage; + pub const storage = extern struct { + len: u8 align(8), + family: sa_family_t, + padding: [254]u8 = undefined, + + comptime { + assert(@sizeOf(storage) == SS_MAXSIZE); + assert(@alignOf(storage) == 8); + } + }; pub const in = extern struct { len: u8 = @sizeOf(in), diff --git a/lib/std/c/solaris.zig b/lib/std/c/solaris.zig index cbeeb5fb42..fe60c426e5 100644 --- a/lib/std/c/solaris.zig +++ b/lib/std/c/solaris.zig @@ -1,4 +1,5 @@ const std = @import("../std.zig"); +const assert = std.debug.assert; const builtin = @import("builtin"); const maxInt = std.math.maxInt; const iovec = std.os.iovec; @@ -435,7 +436,15 @@ pub const sockaddr = extern struct { data: [14]u8, pub const SS_MAXSIZE = 256; - pub const storage = std.x.os.Socket.Address.Native.Storage; + pub const storage = extern struct { + family: sa_family_t align(8), + padding: [254]u8 = undefined, + + comptime { + assert(@sizeOf(storage) == SS_MAXSIZE); + assert(@alignOf(storage) == 8); + } + }; pub const in = extern struct { family: sa_family_t = AF.INET, diff --git a/lib/std/os.zig b/lib/std/os.zig index b0884cef05..a47e3d0068 100644 --- a/lib/std/os.zig +++ b/lib/std/os.zig @@ -5616,11 +5616,11 @@ pub fn sendmsg( /// The file descriptor of the sending socket. 
sockfd: socket_t, /// Message header and iovecs - msg: msghdr_const, + msg: *const msghdr_const, flags: u32, ) SendMsgError!usize { while (true) { - const rc = system.sendmsg(sockfd, @ptrCast(*const std.x.os.Socket.Message, &msg), @intCast(c_int, flags)); + const rc = system.sendmsg(sockfd, msg, flags); if (builtin.os.tag == .windows) { if (rc == windows.ws2_32.SOCKET_ERROR) { switch (windows.ws2_32.WSAGetLastError()) { diff --git a/lib/std/os/linux.zig b/lib/std/os/linux.zig index ecb8a21d7a..d9d5fb3204 100644 --- a/lib/std/os/linux.zig +++ b/lib/std/os/linux.zig @@ -1226,11 +1226,14 @@ pub fn getsockopt(fd: i32, level: u32, optname: u32, noalias optval: [*]u8, noal return syscall5(.getsockopt, @bitCast(usize, @as(isize, fd)), level, optname, @ptrToInt(optval), @ptrToInt(optlen)); } -pub fn sendmsg(fd: i32, msg: *const std.x.os.Socket.Message, flags: c_int) usize { +pub fn sendmsg(fd: i32, msg: *const msghdr_const, flags: u32) usize { + const fd_usize = @bitCast(usize, @as(isize, fd)); + const msg_usize = @ptrToInt(msg); if (native_arch == .x86) { - return socketcall(SC.sendmsg, &[3]usize{ @bitCast(usize, @as(isize, fd)), @ptrToInt(msg), @bitCast(usize, @as(isize, flags)) }); + return socketcall(SC.sendmsg, &[3]usize{ fd_usize, msg_usize, flags }); + } else { + return syscall3(.sendmsg, fd_usize, msg_usize, flags); } - return syscall3(.sendmsg, @bitCast(usize, @as(isize, fd)), @ptrToInt(msg), @bitCast(usize, @as(isize, flags))); } pub fn sendmmsg(fd: i32, msgvec: [*]mmsghdr_const, vlen: u32, flags: u32) usize { @@ -1274,24 +1277,42 @@ pub fn sendmmsg(fd: i32, msgvec: [*]mmsghdr_const, vlen: u32, flags: u32) usize } pub fn connect(fd: i32, addr: *const anyopaque, len: socklen_t) usize { + const fd_usize = @bitCast(usize, @as(isize, fd)); + const addr_usize = @ptrToInt(addr); if (native_arch == .x86) { - return socketcall(SC.connect, &[3]usize{ @bitCast(usize, @as(isize, fd)), @ptrToInt(addr), len }); + return socketcall(SC.connect, &[3]usize{ fd_usize, addr_usize, len }); + } else { + return syscall3(.connect, fd_usize, addr_usize, len); } - return syscall3(.connect, @bitCast(usize, @as(isize, fd)), @ptrToInt(addr), len); } -pub fn recvmsg(fd: i32, msg: *std.x.os.Socket.Message, flags: c_int) usize { +pub fn recvmsg(fd: i32, msg: *msghdr, flags: u32) usize { + const fd_usize = @bitCast(usize, @as(isize, fd)); + const msg_usize = @ptrToInt(msg); if (native_arch == .x86) { - return socketcall(SC.recvmsg, &[3]usize{ @bitCast(usize, @as(isize, fd)), @ptrToInt(msg), @bitCast(usize, @as(isize, flags)) }); + return socketcall(SC.recvmsg, &[3]usize{ fd_usize, msg_usize, flags }); + } else { + return syscall3(.recvmsg, fd_usize, msg_usize, flags); } - return syscall3(.recvmsg, @bitCast(usize, @as(isize, fd)), @ptrToInt(msg), @bitCast(usize, @as(isize, flags))); } -pub fn recvfrom(fd: i32, noalias buf: [*]u8, len: usize, flags: u32, noalias addr: ?*sockaddr, noalias alen: ?*socklen_t) usize { +pub fn recvfrom( + fd: i32, + noalias buf: [*]u8, + len: usize, + flags: u32, + noalias addr: ?*sockaddr, + noalias alen: ?*socklen_t, +) usize { + const fd_usize = @bitCast(usize, @as(isize, fd)); + const buf_usize = @ptrToInt(buf); + const addr_usize = @ptrToInt(addr); + const alen_usize = @ptrToInt(alen); if (native_arch == .x86) { - return socketcall(SC.recvfrom, &[6]usize{ @bitCast(usize, @as(isize, fd)), @ptrToInt(buf), len, flags, @ptrToInt(addr), @ptrToInt(alen) }); + return socketcall(SC.recvfrom, &[6]usize{ fd_usize, buf_usize, len, flags, addr_usize, alen_usize }); + } else { + return 
syscall6(.recvfrom, fd_usize, buf_usize, len, flags, addr_usize, alen_usize); } - return syscall6(.recvfrom, @bitCast(usize, @as(isize, fd)), @ptrToInt(buf), len, flags, @ptrToInt(addr), @ptrToInt(alen)); } pub fn shutdown(fd: i32, how: i32) usize { @@ -3219,7 +3240,15 @@ pub const sockaddr = extern struct { data: [14]u8, pub const SS_MAXSIZE = 128; - pub const storage = std.x.os.Socket.Address.Native.Storage; + pub const storage = extern struct { + family: sa_family_t align(8), + padding: [SS_MAXSIZE - @sizeOf(sa_family_t)]u8 = undefined, + + comptime { + assert(@sizeOf(storage) == SS_MAXSIZE); + assert(@alignOf(storage) == 8); + } + }; /// IPv4 socket address pub const in = extern struct { diff --git a/lib/std/os/linux/seccomp.zig b/lib/std/os/linux/seccomp.zig index fd002e7416..03a96633f8 100644 --- a/lib/std/os/linux/seccomp.zig +++ b/lib/std/os/linux/seccomp.zig @@ -6,16 +6,14 @@ //! isn't that useful for general-purpose applications, and so a mode that //! utilizes user-supplied filters mode was added. //! -//! Seccomp filters are classic BPF programs, which means that all the -//! information under `std.x.net.bpf` applies here as well. Conceptually, a -//! seccomp program is attached to the kernel and is executed on each syscall. -//! The "packet" being validated is the `data` structure, and the verdict is an -//! action that the kernel performs on the calling process. The actions are -//! variations on a "pass" or "fail" result, where a pass allows the syscall to -//! continue and a fail blocks the syscall and returns some sort of error value. -//! See the full list of actions under ::RET for more information. Finally, only -//! word-sized, absolute loads (`ld [k]`) are supported to read from the `data` -//! structure. +//! Seccomp filters are classic BPF programs. Conceptually, a seccomp program +//! is attached to the kernel and is executed on each syscall. The "packet" +//! being validated is the `data` structure, and the verdict is an action that +//! the kernel performs on the calling process. The actions are variations on a +//! "pass" or "fail" result, where a pass allows the syscall to continue and a +//! fail blocks the syscall and returns some sort of error value. See the full +//! list of actions under ::RET for more information. Finally, only word-sized, +//! absolute loads (`ld [k]`) are supported to read from the `data` structure. //! //! There are some issues with the filter API that have traditionally made //! 
writing them a pain: diff --git a/lib/std/os/windows/ws2_32.zig b/lib/std/os/windows/ws2_32.zig index 90e1422fd2..b4d18264f3 100644 --- a/lib/std/os/windows/ws2_32.zig +++ b/lib/std/os/windows/ws2_32.zig @@ -1,4 +1,5 @@ const std = @import("../../std.zig"); +const assert = std.debug.assert; const windows = std.os.windows; const WINAPI = windows.WINAPI; @@ -1106,7 +1107,15 @@ pub const sockaddr = extern struct { data: [14]u8, pub const SS_MAXSIZE = 128; - pub const storage = std.x.os.Socket.Address.Native.Storage; + pub const storage = extern struct { + family: ADDRESS_FAMILY align(8), + padding: [SS_MAXSIZE - @sizeOf(ADDRESS_FAMILY)]u8 = undefined, + + comptime { + assert(@sizeOf(storage) == SS_MAXSIZE); + assert(@alignOf(storage) == 8); + } + }; /// IPv4 socket address pub const in = extern struct { @@ -1207,7 +1216,7 @@ pub const LPFN_GETACCEPTEXSOCKADDRS = *const fn ( pub const LPFN_WSASENDMSG = *const fn ( s: SOCKET, - lpMsg: *const std.x.os.Socket.Message, + lpMsg: *const WSAMSG_const, dwFlags: u32, lpNumberOfBytesSent: ?*u32, lpOverlapped: ?*OVERLAPPED, @@ -1216,7 +1225,7 @@ pub const LPFN_WSASENDMSG = *const fn ( pub const LPFN_WSARECVMSG = *const fn ( s: SOCKET, - lpMsg: *std.x.os.Socket.Message, + lpMsg: *WSAMSG, lpdwNumberOfBytesRecv: ?*u32, lpOverlapped: ?*OVERLAPPED, lpCompletionRoutine: ?LPWSAOVERLAPPED_COMPLETION_ROUTINE, @@ -2090,7 +2099,7 @@ pub extern "ws2_32" fn WSASend( pub extern "ws2_32" fn WSASendMsg( s: SOCKET, - lpMsg: *const std.x.os.Socket.Message, + lpMsg: *WSAMSG_const, dwFlags: u32, lpNumberOfBytesSent: ?*u32, lpOverlapped: ?*OVERLAPPED, @@ -2099,7 +2108,7 @@ pub extern "ws2_32" fn WSASendMsg( pub extern "ws2_32" fn WSARecvMsg( s: SOCKET, - lpMsg: *std.x.os.Socket.Message, + lpMsg: *WSAMSG, lpdwNumberOfBytesRecv: ?*u32, lpOverlapped: ?*OVERLAPPED, lpCompletionRoutine: ?LPWSAOVERLAPPED_COMPLETION_ROUTINE, diff --git a/lib/std/std.zig b/lib/std/std.zig index 1b4217b506..4bfb44d12f 100644 --- a/lib/std/std.zig +++ b/lib/std/std.zig @@ -90,7 +90,6 @@ pub const tz = @import("tz.zig"); pub const unicode = @import("unicode.zig"); pub const valgrind = @import("valgrind.zig"); pub const wasm = @import("wasm.zig"); -pub const x = @import("x.zig"); pub const zig = @import("zig.zig"); pub const start = @import("start.zig"); diff --git a/lib/std/x.zig b/lib/std/x.zig deleted file mode 100644 index 64caf324ed..0000000000 --- a/lib/std/x.zig +++ /dev/null @@ -1,19 +0,0 @@ -const std = @import("std.zig"); - -pub const os = struct { - pub const Socket = @import("x/os/socket.zig").Socket; - pub usingnamespace @import("x/os/io.zig"); - pub usingnamespace @import("x/os/net.zig"); -}; - -pub const net = struct { - pub const ip = @import("x/net/ip.zig"); - pub const tcp = @import("x/net/tcp.zig"); - pub const bpf = @import("x/net/bpf.zig"); -}; - -test { - inline for (.{ os, net }) |module| { - std.testing.refAllDecls(module); - } -} diff --git a/lib/std/x/net/bpf.zig b/lib/std/x/net/bpf.zig deleted file mode 100644 index bee930c332..0000000000 --- a/lib/std/x/net/bpf.zig +++ /dev/null @@ -1,1003 +0,0 @@ -//! This package provides instrumentation for creating Berkeley Packet Filter[1] -//! (BPF) programs, along with a simulator for running them. -//! -//! BPF is a mechanism for cheap, in-kernel packet filtering. Programs are -//! attached to a network device and executed for every packet that flows -//! through it. The program must then return a verdict: the amount of packet -//! bytes that the kernel should copy into userspace. Execution speed is -//! 
achieved by having programs run in a limited virtual machine, which has the -//! added benefit of graceful failure in the face of buggy programs. -//! -//! The BPF virtual machine has a 32-bit word length and a small number of -//! word-sized registers: -//! -//! - The accumulator, `a`: The source/destination of arithmetic and logic -//! operations. -//! - The index register, `x`: Used as an offset for indirect memory access and -//! as a comparison value for conditional jumps. -//! - The scratch memory store, `M[0]..M[15]`: Used for saving the value of a/x -//! for later use. -//! -//! The packet being examined is an array of bytes, and is addressed using plain -//! array subscript notation, e.g. [10] for the byte at offset 10. An implicit -//! program counter, `pc`, is intialized to zero and incremented for each instruction. -//! -//! The machine has a fixed instruction set with the following form, where the -//! numbers represent bit length: -//! -//! ``` -//! ┌───────────┬──────┬──────┐ -//! │ opcode:16 │ jt:8 │ jt:8 │ -//! ├───────────┴──────┴──────┤ -//! │ k:32 │ -//! └─────────────────────────┘ -//! ``` -//! -//! The `opcode` indicates the instruction class and its addressing mode. -//! Opcodes are generated by performing binary addition on the 8-bit class and -//! mode constants. For example, the opcode for loading a byte from the packet -//! at X + 2, (`ldb [x + 2]`), is: -//! -//! ``` -//! LD | IND | B = 0x00 | 0x40 | 0x20 -//! = 0x60 -//! ``` -//! -//! `jt` is an offset used for conditional jumps, and increments the program -//! counter by its amount if the comparison was true. Conversely, `jf` -//! increments the counter if it was false. These fields are ignored in all -//! other cases. `k` is a generic variable used for various purposes, most -//! commonly as some sort of constant. -//! -//! This package contains opcode extensions used by different implementations, -//! where "extension" is anything outside of the original that was imported into -//! 4.4BSD[2]. These are marked with "EXTENSION", along with a list of -//! implementations that use them. -//! -//! Most of the doc-comments use the BPF assembly syntax as described in the -//! original paper[1]. For the sake of completeness, here is the complete -//! instruction set, along with the extensions: -//! -//!``` -//! opcode addressing modes -//! ld #k #len M[k] [k] [x + k] -//! ldh [k] [x + k] -//! ldb [k] [x + k] -//! ldx #k #len M[k] 4 * ([k] & 0xf) arc4random() -//! st M[k] -//! stx M[k] -//! jmp L -//! jeq #k, Lt, Lf -//! jgt #k, Lt, Lf -//! jge #k, Lt, Lf -//! jset #k, Lt, Lf -//! add #k x -//! sub #k x -//! mul #k x -//! div #k x -//! or #k x -//! and #k x -//! lsh #k x -//! rsh #k x -//! neg #k x -//! mod #k x -//! xor #k x -//! ret #k a -//! tax -//! txa -//! ``` -//! -//! Finally, a note on program design. The lack of backwards jumps leads to a -//! "return early, return often" control flow. Take for example the program -//! generated from the tcpdump filter `ip`: -//! -//! ``` -//! (000) ldh [12] ; Ethernet Packet Type -//! (001) jeq #0x86dd, 2, 7 ; ETHERTYPE_IPV6 -//! (002) ldb [20] ; IPv6 Next Header -//! (003) jeq #0x6, 10, 4 ; TCP -//! (004) jeq #0x2c, 5, 11 ; IPv6 Fragment Header -//! (005) ldb [54] ; TCP Source Port -//! (006) jeq #0x6, 10, 11 ; IPPROTO_TCP -//! (007) jeq #0x800, 8, 11 ; ETHERTYPE_IP -//! (008) ldb [23] ; IPv4 Protocol -//! (009) jeq #0x6, 10, 11 ; IPPROTO_TCP -//! (010) ret #262144 ; copy 0x40000 -//! (011) ret #0 ; skip packet -//! ``` -//! -//! Here we can make a few observations: -//! 
-//! - The problem "filter only tcp packets" has essentially been transformed -//! into a series of layer checks. -//! - There are two distinct branches in the code, one for validating IPv4 -//! headers and one for IPv6 headers. -//! - Most conditional jumps in these branches lead directly to the last two -//! instructions, a pass or fail. Thus the goal of a program is to find the -//! fastest route to a pass/fail comparison. -//! -//! [1]: S. McCanne and V. Jacobson, "The BSD Packet Filter: A New Architecture -//! for User-level Packet Capture", Proceedings of the 1993 Winter USENIX. -//! [2]: https://minnie.tuhs.org/cgi-bin/utree.pl?file=4.4BSD/usr/src/sys/net/bpf.h -const std = @import("std"); -const builtin = @import("builtin"); -const native_endian = builtin.target.cpu.arch.endian(); -const mem = std.mem; -const math = std.math; -const random = std.crypto.random; -const assert = std.debug.assert; -const expectEqual = std.testing.expectEqual; -const expectError = std.testing.expectError; -const expect = std.testing.expect; - -// instruction classes -/// ld, ldh, ldb: Load data into a. -pub const LD = 0x00; -/// ldx: Load data into x. -pub const LDX = 0x01; -/// st: Store into scratch memory the value of a. -pub const ST = 0x02; -/// st: Store into scratch memory the value of x. -pub const STX = 0x03; -/// alu: Wrapping arithmetic/bitwise operations on a using the value of k/x. -pub const ALU = 0x04; -/// jmp, jeq, jgt, je, jset: Increment the program counter based on a comparison -/// between k/x and the accumulator. -pub const JMP = 0x05; -/// ret: Return a verdict using the value of k/the accumulator. -pub const RET = 0x06; -/// tax, txa: Register value copying between X and a. -pub const MISC = 0x07; - -// Size of data to be loaded from the packet. -/// ld: 32-bit full word. -pub const W = 0x00; -/// ldh: 16-bit half word. -pub const H = 0x08; -/// ldb: Single byte. -pub const B = 0x10; - -// Addressing modes used for loads to a/x. -/// #k: The immediate value stored in k. -pub const IMM = 0x00; -/// [k]: The value at offset k in the packet. -pub const ABS = 0x20; -/// [x + k]: The value at offset x + k in the packet. -pub const IND = 0x40; -/// M[k]: The value of the k'th scratch memory register. -pub const MEM = 0x60; -/// #len: The size of the packet. -pub const LEN = 0x80; -/// 4 * ([k] & 0xf): Four times the low four bits of the byte at offset k in the -/// packet. This is used for efficiently loading the header length of an IP -/// packet. -pub const MSH = 0xa0; -/// arc4random: 32-bit integer generated from a CPRNG (see arc4random(3)) loaded into a. -/// EXTENSION. Defined for: -/// - OpenBSD. -pub const RND = 0xc0; - -// Modifiers for different instruction classes. -/// Use the value of k for alu operations (add #k). -/// Compare against the value of k for jumps (jeq #k, Lt, Lf). -/// Return the value of k for returns (ret #k). -pub const K = 0x00; -/// Use the value of x for alu operations (add x). -/// Compare against the value of X for jumps (jeq x, Lt, Lf). -pub const X = 0x08; -/// Return the value of a for returns (ret a). -pub const A = 0x10; - -// ALU Operations on a using the value of k/x. -// All arithmetic operations are defined to overflow the value of a. -/// add: a = a + k -/// a = a + x. -pub const ADD = 0x00; -/// sub: a = a - k -/// a = a - x. -pub const SUB = 0x10; -/// mul: a = a * k -/// a = a * x. -pub const MUL = 0x20; -/// div: a = a / k -/// a = a / x. -/// Truncated division. -pub const DIV = 0x30; -/// or: a = a | k -/// a = a | x. 
-pub const OR = 0x40; -/// and: a = a & k -/// a = a & x. -pub const AND = 0x50; -/// lsh: a = a << k -/// a = a << x. -/// a = a << k, a = a << x. -pub const LSH = 0x60; -/// rsh: a = a >> k -/// a = a >> x. -pub const RSH = 0x70; -/// neg: a = -a. -/// Note that this isn't a binary negation, rather the value of `~a + 1`. -pub const NEG = 0x80; -/// mod: a = a % k -/// a = a % x. -/// EXTENSION. Defined for: -/// - Linux. -/// - NetBSD + Minix 3. -/// - FreeBSD and derivitives. -pub const MOD = 0x90; -/// xor: a = a ^ k -/// a = a ^ x. -/// EXTENSION. Defined for: -/// - Linux. -/// - NetBSD + Minix 3. -/// - FreeBSD and derivitives. -pub const XOR = 0xa0; - -// Jump operations using a comparison between a and x/k. -/// jmp L: pc += k. -/// No comparison done here. -pub const JA = 0x00; -/// jeq #k, Lt, Lf: pc += (a == k) ? jt : jf. -/// jeq x, Lt, Lf: pc += (a == x) ? jt : jf. -pub const JEQ = 0x10; -/// jgt #k, Lt, Lf: pc += (a > k) ? jt : jf. -/// jgt x, Lt, Lf: pc += (a > x) ? jt : jf. -pub const JGT = 0x20; -/// jge #k, Lt, Lf: pc += (a >= k) ? jt : jf. -/// jge x, Lt, Lf: pc += (a >= x) ? jt : jf. -pub const JGE = 0x30; -/// jset #k, Lt, Lf: pc += (a & k > 0) ? jt : jf. -/// jset x, Lt, Lf: pc += (a & x > 0) ? jt : jf. -pub const JSET = 0x40; - -// Miscellaneous operations/register copy. -/// tax: x = a. -pub const TAX = 0x00; -/// txa: a = x. -pub const TXA = 0x80; - -/// The 16 registers in the scratch memory store as named enums. -pub const Scratch = enum(u4) { m0, m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15 }; -pub const MEMWORDS = 16; -pub const MAXINSNS = switch (builtin.os.tag) { - .linux => 4096, - else => 512, -}; -pub const MINBUFSIZE = 32; -pub const MAXBUFSIZE = 1 << 21; - -pub const Insn = extern struct { - opcode: u16, - jt: u8, - jf: u8, - k: u32, - - /// Implements the `std.fmt.format` API. - /// The formatting is similar to the output of tcpdump -dd. 
- pub fn format( - self: Insn, - comptime layout: []const u8, - opts: std.fmt.FormatOptions, - writer: anytype, - ) !void { - _ = opts; - if (layout.len != 0) std.fmt.invalidFmtError(layout, self); - - try std.fmt.format( - writer, - "Insn{{ 0x{X:0<2}, {d}, {d}, 0x{X:0<8} }}", - .{ self.opcode, self.jt, self.jf, self.k }, - ); - } - - const Size = enum(u8) { - word = W, - half_word = H, - byte = B, - }; - - fn stmt(opcode: u16, k: u32) Insn { - return .{ - .opcode = opcode, - .jt = 0, - .jf = 0, - .k = k, - }; - } - - pub fn ld_imm(value: u32) Insn { - return stmt(LD | IMM, value); - } - - pub fn ld_abs(size: Size, offset: u32) Insn { - return stmt(LD | ABS | @enumToInt(size), offset); - } - - pub fn ld_ind(size: Size, offset: u32) Insn { - return stmt(LD | IND | @enumToInt(size), offset); - } - - pub fn ld_mem(reg: Scratch) Insn { - return stmt(LD | MEM, @enumToInt(reg)); - } - - pub fn ld_len() Insn { - return stmt(LD | LEN | W, 0); - } - - pub fn ld_rnd() Insn { - return stmt(LD | RND | W, 0); - } - - pub fn ldx_imm(value: u32) Insn { - return stmt(LDX | IMM, value); - } - - pub fn ldx_mem(reg: Scratch) Insn { - return stmt(LDX | MEM, @enumToInt(reg)); - } - - pub fn ldx_len() Insn { - return stmt(LDX | LEN | W, 0); - } - - pub fn ldx_msh(offset: u32) Insn { - return stmt(LDX | MSH | B, offset); - } - - pub fn st(reg: Scratch) Insn { - return stmt(ST, @enumToInt(reg)); - } - pub fn stx(reg: Scratch) Insn { - return stmt(STX, @enumToInt(reg)); - } - - const AluOp = enum(u16) { - add = ADD, - sub = SUB, - mul = MUL, - div = DIV, - @"or" = OR, - @"and" = AND, - lsh = LSH, - rsh = RSH, - mod = MOD, - xor = XOR, - }; - - const Source = enum(u16) { - k = K, - x = X, - }; - const KOrX = union(Source) { - k: u32, - x: void, - }; - - pub fn alu_neg() Insn { - return stmt(ALU | NEG, 0); - } - - pub fn alu(op: AluOp, source: KOrX) Insn { - return stmt( - ALU | @enumToInt(op) | @enumToInt(source), - if (source == .k) source.k else 0, - ); - } - - const JmpOp = enum(u16) { - jeq = JEQ, - jgt = JGT, - jge = JGE, - jset = JSET, - }; - - pub fn jmp_ja(location: u32) Insn { - return stmt(JMP | JA, location); - } - - pub fn jmp(op: JmpOp, source: KOrX, jt: u8, jf: u8) Insn { - return Insn{ - .opcode = JMP | @enumToInt(op) | @enumToInt(source), - .jt = jt, - .jf = jf, - .k = if (source == .k) source.k else 0, - }; - } - - const Verdict = enum(u16) { - k = K, - a = A, - }; - const KOrA = union(Verdict) { - k: u32, - a: void, - }; - - pub fn ret(verdict: KOrA) Insn { - return stmt( - RET | @enumToInt(verdict), - if (verdict == .k) verdict.k else 0, - ); - } - - pub fn tax() Insn { - return stmt(MISC | TAX, 0); - } - - pub fn txa() Insn { - return stmt(MISC | TXA, 0); - } -}; - -fn opcodeEqual(opcode: u16, insn: Insn) !void { - try expectEqual(opcode, insn.opcode); -} - -test "opcodes" { - try opcodeEqual(0x00, Insn.ld_imm(0)); - try opcodeEqual(0x20, Insn.ld_abs(.word, 0)); - try opcodeEqual(0x28, Insn.ld_abs(.half_word, 0)); - try opcodeEqual(0x30, Insn.ld_abs(.byte, 0)); - try opcodeEqual(0x40, Insn.ld_ind(.word, 0)); - try opcodeEqual(0x48, Insn.ld_ind(.half_word, 0)); - try opcodeEqual(0x50, Insn.ld_ind(.byte, 0)); - try opcodeEqual(0x60, Insn.ld_mem(.m0)); - try opcodeEqual(0x80, Insn.ld_len()); - try opcodeEqual(0xc0, Insn.ld_rnd()); - - try opcodeEqual(0x01, Insn.ldx_imm(0)); - try opcodeEqual(0x61, Insn.ldx_mem(.m0)); - try opcodeEqual(0x81, Insn.ldx_len()); - try opcodeEqual(0xb1, Insn.ldx_msh(0)); - - try opcodeEqual(0x02, Insn.st(.m0)); - try opcodeEqual(0x03, Insn.stx(.m0)); - - try 
opcodeEqual(0x04, Insn.alu(.add, .{ .k = 0 })); - try opcodeEqual(0x14, Insn.alu(.sub, .{ .k = 0 })); - try opcodeEqual(0x24, Insn.alu(.mul, .{ .k = 0 })); - try opcodeEqual(0x34, Insn.alu(.div, .{ .k = 0 })); - try opcodeEqual(0x44, Insn.alu(.@"or", .{ .k = 0 })); - try opcodeEqual(0x54, Insn.alu(.@"and", .{ .k = 0 })); - try opcodeEqual(0x64, Insn.alu(.lsh, .{ .k = 0 })); - try opcodeEqual(0x74, Insn.alu(.rsh, .{ .k = 0 })); - try opcodeEqual(0x94, Insn.alu(.mod, .{ .k = 0 })); - try opcodeEqual(0xa4, Insn.alu(.xor, .{ .k = 0 })); - try opcodeEqual(0x84, Insn.alu_neg()); - try opcodeEqual(0x0c, Insn.alu(.add, .x)); - try opcodeEqual(0x1c, Insn.alu(.sub, .x)); - try opcodeEqual(0x2c, Insn.alu(.mul, .x)); - try opcodeEqual(0x3c, Insn.alu(.div, .x)); - try opcodeEqual(0x4c, Insn.alu(.@"or", .x)); - try opcodeEqual(0x5c, Insn.alu(.@"and", .x)); - try opcodeEqual(0x6c, Insn.alu(.lsh, .x)); - try opcodeEqual(0x7c, Insn.alu(.rsh, .x)); - try opcodeEqual(0x9c, Insn.alu(.mod, .x)); - try opcodeEqual(0xac, Insn.alu(.xor, .x)); - - try opcodeEqual(0x05, Insn.jmp_ja(0)); - try opcodeEqual(0x15, Insn.jmp(.jeq, .{ .k = 0 }, 0, 0)); - try opcodeEqual(0x25, Insn.jmp(.jgt, .{ .k = 0 }, 0, 0)); - try opcodeEqual(0x35, Insn.jmp(.jge, .{ .k = 0 }, 0, 0)); - try opcodeEqual(0x45, Insn.jmp(.jset, .{ .k = 0 }, 0, 0)); - try opcodeEqual(0x1d, Insn.jmp(.jeq, .x, 0, 0)); - try opcodeEqual(0x2d, Insn.jmp(.jgt, .x, 0, 0)); - try opcodeEqual(0x3d, Insn.jmp(.jge, .x, 0, 0)); - try opcodeEqual(0x4d, Insn.jmp(.jset, .x, 0, 0)); - - try opcodeEqual(0x06, Insn.ret(.{ .k = 0 })); - try opcodeEqual(0x16, Insn.ret(.a)); - - try opcodeEqual(0x07, Insn.tax()); - try opcodeEqual(0x87, Insn.txa()); -} - -pub const Error = error{ - InvalidOpcode, - InvalidOffset, - InvalidLocation, - DivisionByZero, - NoReturn, -}; - -/// A simple implementation of the BPF virtual-machine. -/// Use this to run/debug programs. -pub fn simulate( - packet: []const u8, - filter: []const Insn, - byte_order: std.builtin.Endian, -) Error!u32 { - assert(filter.len > 0 and filter.len < MAXINSNS); - assert(packet.len < MAXBUFSIZE); - const len = @intCast(u32, packet.len); - - var a: u32 = 0; - var x: u32 = 0; - var m = mem.zeroes([MEMWORDS]u32); - var pc: usize = 0; - - while (pc < filter.len) : (pc += 1) { - const i = filter[pc]; - // Cast to a wider type to protect against overflow. - const k = @as(u64, i.k); - const remaining = filter.len - (pc + 1); - - // Do validation/error checking here to compress the second switch. 
- switch (i.opcode) { - LD | ABS | W => if (k + @sizeOf(u32) - 1 >= packet.len) return error.InvalidOffset, - LD | ABS | H => if (k + @sizeOf(u16) - 1 >= packet.len) return error.InvalidOffset, - LD | ABS | B => if (k >= packet.len) return error.InvalidOffset, - LD | IND | W => if (k + x + @sizeOf(u32) - 1 >= packet.len) return error.InvalidOffset, - LD | IND | H => if (k + x + @sizeOf(u16) - 1 >= packet.len) return error.InvalidOffset, - LD | IND | B => if (k + x >= packet.len) return error.InvalidOffset, - - LDX | MSH | B => if (k >= packet.len) return error.InvalidOffset, - ST, STX, LD | MEM, LDX | MEM => if (i.k >= MEMWORDS) return error.InvalidOffset, - - JMP | JA => if (remaining <= i.k) return error.InvalidOffset, - JMP | JEQ | K, - JMP | JGT | K, - JMP | JGE | K, - JMP | JSET | K, - JMP | JEQ | X, - JMP | JGT | X, - JMP | JGE | X, - JMP | JSET | X, - => if (remaining <= i.jt or remaining <= i.jf) return error.InvalidLocation, - else => {}, - } - switch (i.opcode) { - LD | IMM => a = i.k, - LD | MEM => a = m[i.k], - LD | LEN | W => a = len, - LD | RND | W => a = random.int(u32), - LD | ABS | W => a = mem.readInt(u32, packet[i.k..][0..@sizeOf(u32)], byte_order), - LD | ABS | H => a = mem.readInt(u16, packet[i.k..][0..@sizeOf(u16)], byte_order), - LD | ABS | B => a = packet[i.k], - LD | IND | W => a = mem.readInt(u32, packet[i.k + x ..][0..@sizeOf(u32)], byte_order), - LD | IND | H => a = mem.readInt(u16, packet[i.k + x ..][0..@sizeOf(u16)], byte_order), - LD | IND | B => a = packet[i.k + x], - - LDX | IMM => x = i.k, - LDX | MEM => x = m[i.k], - LDX | LEN | W => x = len, - LDX | MSH | B => x = @as(u32, @truncate(u4, packet[i.k])) << 2, - - ST => m[i.k] = a, - STX => m[i.k] = x, - - ALU | ADD | K => a +%= i.k, - ALU | SUB | K => a -%= i.k, - ALU | MUL | K => a *%= i.k, - ALU | DIV | K => a = try math.divTrunc(u32, a, i.k), - ALU | OR | K => a |= i.k, - ALU | AND | K => a &= i.k, - ALU | LSH | K => a = math.shl(u32, a, i.k), - ALU | RSH | K => a = math.shr(u32, a, i.k), - ALU | MOD | K => a = try math.mod(u32, a, i.k), - ALU | XOR | K => a ^= i.k, - ALU | ADD | X => a +%= x, - ALU | SUB | X => a -%= x, - ALU | MUL | X => a *%= x, - ALU | DIV | X => a = try math.divTrunc(u32, a, x), - ALU | OR | X => a |= x, - ALU | AND | X => a &= x, - ALU | LSH | X => a = math.shl(u32, a, x), - ALU | RSH | X => a = math.shr(u32, a, x), - ALU | MOD | X => a = try math.mod(u32, a, x), - ALU | XOR | X => a ^= x, - ALU | NEG => a = @bitCast(u32, -%@bitCast(i32, a)), - - JMP | JA => pc += i.k, - JMP | JEQ | K => pc += if (a == i.k) i.jt else i.jf, - JMP | JGT | K => pc += if (a > i.k) i.jt else i.jf, - JMP | JGE | K => pc += if (a >= i.k) i.jt else i.jf, - JMP | JSET | K => pc += if (a & i.k > 0) i.jt else i.jf, - JMP | JEQ | X => pc += if (a == x) i.jt else i.jf, - JMP | JGT | X => pc += if (a > x) i.jt else i.jf, - JMP | JGE | X => pc += if (a >= x) i.jt else i.jf, - JMP | JSET | X => pc += if (a & x > 0) i.jt else i.jf, - - RET | K => return i.k, - RET | A => return a, - - MISC | TAX => x = a, - MISC | TXA => a = x, - else => return error.InvalidOpcode, - } - } - - return error.NoReturn; -} - -// This program is the BPF form of the tcpdump filter: -// -// tcpdump -dd 'ip host mirror.internode.on.net and tcp port ftp-data' -// -// As of January 2022, mirror.internode.on.net resolves to 150.101.135.3 -// -// For reference, here's what it looks like in BPF assembler. -// Note that the jumps are used for TCP/IP layer checks. 
-// -// ``` -// ldh [12] (#proto) -// jeq #0x0800 (ETHERTYPE_IP), L1, fail -// L1: ld [26] -// jeq #150.101.135.3, L2, dest -// dest: ld [30] -// jeq #150.101.135.3, L2, fail -// L2: ldb [23] -// jeq #0x6 (IPPROTO_TCP), L3, fail -// L3: ldh [20] -// jset #0x1fff, fail, plen -// plen: ldx 4 * ([14] & 0xf) -// ldh [x + 14] -// jeq #0x14 (FTP), pass, dstp -// dstp: ldh [x + 16] -// jeq #0x14 (FTP), pass, fail -// pass: ret #0x40000 -// fail: ret #0 -// ``` -const tcpdump_filter = [_]Insn{ - Insn.ld_abs(.half_word, 12), - Insn.jmp(.jeq, .{ .k = 0x800 }, 0, 14), - Insn.ld_abs(.word, 26), - Insn.jmp(.jeq, .{ .k = 0x96658703 }, 2, 0), - Insn.ld_abs(.word, 30), - Insn.jmp(.jeq, .{ .k = 0x96658703 }, 0, 10), - Insn.ld_abs(.byte, 23), - Insn.jmp(.jeq, .{ .k = 0x6 }, 0, 8), - Insn.ld_abs(.half_word, 20), - Insn.jmp(.jset, .{ .k = 0x1fff }, 6, 0), - Insn.ldx_msh(14), - Insn.ld_ind(.half_word, 14), - Insn.jmp(.jeq, .{ .k = 0x14 }, 2, 0), - Insn.ld_ind(.half_word, 16), - Insn.jmp(.jeq, .{ .k = 0x14 }, 0, 1), - Insn.ret(.{ .k = 0x40000 }), - Insn.ret(.{ .k = 0 }), -}; - -// This packet is the output of `ls` on mirror.internode.on.net:/, captured -// using the filter above. -// -// zig fmt: off -const ftp_data = [_]u8{ - // ethernet - 14 bytes: IPv4(0x0800) from a4:71:74:ad:4b:f0 -> de:ad:be:ef:f0:0f - 0xde, 0xad, 0xbe, 0xef, 0xf0, 0x0f, 0xa4, 0x71, 0x74, 0xad, 0x4b, 0xf0, 0x08, 0x00, - // IPv4 - 20 bytes: TCP data from 150.101.135.3 -> 192.168.1.3 - 0x45, 0x00, 0x01, 0xf2, 0x70, 0x3b, 0x40, 0x00, 0x37, 0x06, 0xf2, 0xb6, - 0x96, 0x65, 0x87, 0x03, 0xc0, 0xa8, 0x01, 0x03, - // TCP - 32 bytes: Source port: 20 (FTP). Payload = 446 bytes - 0x00, 0x14, 0x80, 0x6d, 0x35, 0x81, 0x2d, 0x40, 0x4f, 0x8a, 0x29, 0x9e, 0x80, 0x18, 0x00, 0x2e, - 0x88, 0x8d, 0x00, 0x00, 0x01, 0x01, 0x08, 0x0a, 0x0b, 0x59, 0x5d, 0x09, 0x32, 0x8b, 0x51, 0xa0 -} ++ - // Raw line-based FTP data - 446 bytes - "lrwxrwxrwx 1 root root 12 Feb 14 2012 debian -> .pub2/debian\r\n" ++ - "lrwxrwxrwx 1 root root 15 Feb 14 2012 debian-cd -> .pub2/debian-cd\r\n" ++ - "lrwxrwxrwx 1 root root 9 Mar 9 2018 linux -> pub/linux\r\n" ++ - "drwxr-xr-X 3 mirror mirror 4096 Sep 20 08:10 pub\r\n" ++ - "lrwxrwxrwx 1 root root 12 Feb 14 2012 ubuntu -> .pub2/ubuntu\r\n" ++ - "-rw-r--r-- 1 root root 1044 Jan 20 2015 welcome.msg\r\n"; -// zig fmt: on - -test "tcpdump filter" { - try expectEqual( - @as(u32, 0x40000), - try simulate(ftp_data, &tcpdump_filter, .Big), - ); -} - -fn expectPass(data: anytype, filter: []const Insn) !void { - try expectEqual( - @as(u32, 0), - try simulate(mem.asBytes(data), filter, .Big), - ); -} - -fn expectFail(expected_error: anyerror, data: anytype, filter: []const Insn) !void { - try expectError( - expected_error, - simulate(mem.asBytes(data), filter, native_endian), - ); -} - -test "simulator coverage" { - const some_data = [_]u8{ - 0xaa, 0xbb, 0xcc, 0xdd, 0x7f, - }; - - try expectPass(&some_data, &.{ - // ld #10 - // ldx #1 - // st M[0] - // stx M[1] - // fail if A != 10 - Insn.ld_imm(10), - Insn.ldx_imm(1), - Insn.st(.m0), - Insn.stx(.m1), - Insn.jmp(.jeq, .{ .k = 10 }, 1, 0), - Insn.ret(.{ .k = 1 }), - // ld [0] - // fail if A != 0xaabbccdd - Insn.ld_abs(.word, 0), - Insn.jmp(.jeq, .{ .k = 0xaabbccdd }, 1, 0), - Insn.ret(.{ .k = 2 }), - // ldh [0] - // fail if A != 0xaabb - Insn.ld_abs(.half_word, 0), - Insn.jmp(.jeq, .{ .k = 0xaabb }, 1, 0), - Insn.ret(.{ .k = 3 }), - // ldb [0] - // fail if A != 0xaa - Insn.ld_abs(.byte, 0), - Insn.jmp(.jeq, .{ .k = 0xaa }, 1, 0), - Insn.ret(.{ .k = 4 }), - // ld [x + 0] - // fail if A != 0xbbccdd7f 
- Insn.ld_ind(.word, 0), - Insn.jmp(.jeq, .{ .k = 0xbbccdd7f }, 1, 0), - Insn.ret(.{ .k = 5 }), - // ldh [x + 0] - // fail if A != 0xbbcc - Insn.ld_ind(.half_word, 0), - Insn.jmp(.jeq, .{ .k = 0xbbcc }, 1, 0), - Insn.ret(.{ .k = 6 }), - // ldb [x + 0] - // fail if A != 0xbb - Insn.ld_ind(.byte, 0), - Insn.jmp(.jeq, .{ .k = 0xbb }, 1, 0), - Insn.ret(.{ .k = 7 }), - // ld M[0] - // fail if A != 10 - Insn.ld_mem(.m0), - Insn.jmp(.jeq, .{ .k = 10 }, 1, 0), - Insn.ret(.{ .k = 8 }), - // ld #len - // fail if A != 5 - Insn.ld_len(), - Insn.jmp(.jeq, .{ .k = some_data.len }, 1, 0), - Insn.ret(.{ .k = 9 }), - // ld #0 - // ld arc4random() - // fail if A == 0 - Insn.ld_imm(0), - Insn.ld_rnd(), - Insn.jmp(.jgt, .{ .k = 0 }, 1, 0), - Insn.ret(.{ .k = 10 }), - // ld #3 - // ldx #10 - // st M[2] - // txa - // fail if a != x - Insn.ld_imm(3), - Insn.ldx_imm(10), - Insn.st(.m2), - Insn.txa(), - Insn.jmp(.jeq, .x, 1, 0), - Insn.ret(.{ .k = 11 }), - // ldx M[2] - // fail if A <= X - Insn.ldx_mem(.m2), - Insn.jmp(.jgt, .x, 1, 0), - Insn.ret(.{ .k = 12 }), - // ldx #len - // fail if a <= x - Insn.ldx_len(), - Insn.jmp(.jgt, .x, 1, 0), - Insn.ret(.{ .k = 13 }), - // a = 4 * (0x7f & 0xf) - // x = 4 * ([4] & 0xf) - // fail if a != x - Insn.ld_imm(4 * (0x7f & 0xf)), - Insn.ldx_msh(4), - Insn.jmp(.jeq, .x, 1, 0), - Insn.ret(.{ .k = 14 }), - // ld #(u32)-1 - // ldx #2 - // add #1 - // fail if a != 0 - Insn.ld_imm(0xffffffff), - Insn.ldx_imm(2), - Insn.alu(.add, .{ .k = 1 }), - Insn.jmp(.jeq, .{ .k = 0 }, 1, 0), - Insn.ret(.{ .k = 15 }), - // sub #1 - // fail if a != (u32)-1 - Insn.alu(.sub, .{ .k = 1 }), - Insn.jmp(.jeq, .{ .k = 0xffffffff }, 1, 0), - Insn.ret(.{ .k = 16 }), - // add x - // fail if a != 1 - Insn.alu(.add, .x), - Insn.jmp(.jeq, .{ .k = 1 }, 1, 0), - Insn.ret(.{ .k = 17 }), - // sub x - // fail if a != (u32)-1 - Insn.alu(.sub, .x), - Insn.jmp(.jeq, .{ .k = 0xffffffff }, 1, 0), - Insn.ret(.{ .k = 18 }), - // ld #16 - // mul #2 - // fail if a != 32 - Insn.ld_imm(16), - Insn.alu(.mul, .{ .k = 2 }), - Insn.jmp(.jeq, .{ .k = 32 }, 1, 0), - Insn.ret(.{ .k = 19 }), - // mul x - // fail if a != 64 - Insn.alu(.mul, .x), - Insn.jmp(.jeq, .{ .k = 64 }, 1, 0), - Insn.ret(.{ .k = 20 }), - // div #2 - // fail if a != 32 - Insn.alu(.div, .{ .k = 2 }), - Insn.jmp(.jeq, .{ .k = 32 }, 1, 0), - Insn.ret(.{ .k = 21 }), - // div x - // fail if a != 16 - Insn.alu(.div, .x), - Insn.jmp(.jeq, .{ .k = 16 }, 1, 0), - Insn.ret(.{ .k = 22 }), - // or #4 - // fail if a != 20 - Insn.alu(.@"or", .{ .k = 4 }), - Insn.jmp(.jeq, .{ .k = 20 }, 1, 0), - Insn.ret(.{ .k = 23 }), - // or x - // fail if a != 22 - Insn.alu(.@"or", .x), - Insn.jmp(.jeq, .{ .k = 22 }, 1, 0), - Insn.ret(.{ .k = 24 }), - // and #6 - // fail if a != 6 - Insn.alu(.@"and", .{ .k = 0b110 }), - Insn.jmp(.jeq, .{ .k = 6 }, 1, 0), - Insn.ret(.{ .k = 25 }), - // and x - // fail if a != 2 - Insn.alu(.@"and", .x), - Insn.jmp(.jeq, .x, 1, 0), - Insn.ret(.{ .k = 26 }), - // xor #15 - // fail if a != 13 - Insn.alu(.xor, .{ .k = 0b1111 }), - Insn.jmp(.jeq, .{ .k = 0b1101 }, 1, 0), - Insn.ret(.{ .k = 27 }), - // xor x - // fail if a != 15 - Insn.alu(.xor, .x), - Insn.jmp(.jeq, .{ .k = 0b1111 }, 1, 0), - Insn.ret(.{ .k = 28 }), - // rsh #1 - // fail if a != 7 - Insn.alu(.rsh, .{ .k = 1 }), - Insn.jmp(.jeq, .{ .k = 0b0111 }, 1, 0), - Insn.ret(.{ .k = 29 }), - // rsh x - // fail if a != 1 - Insn.alu(.rsh, .x), - Insn.jmp(.jeq, .{ .k = 0b0001 }, 1, 0), - Insn.ret(.{ .k = 30 }), - // lsh #1 - // fail if a != 2 - Insn.alu(.lsh, .{ .k = 1 }), - Insn.jmp(.jeq, .{ .k = 0b0010 }, 1, 
0), - Insn.ret(.{ .k = 31 }), - // lsh x - // fail if a != 8 - Insn.alu(.lsh, .x), - Insn.jmp(.jeq, .{ .k = 0b1000 }, 1, 0), - Insn.ret(.{ .k = 32 }), - // mod 6 - // fail if a != 2 - Insn.alu(.mod, .{ .k = 6 }), - Insn.jmp(.jeq, .{ .k = 2 }, 1, 0), - Insn.ret(.{ .k = 33 }), - // mod x - // fail if a != 0 - Insn.alu(.mod, .x), - Insn.jmp(.jeq, .{ .k = 0 }, 1, 0), - Insn.ret(.{ .k = 34 }), - // tax - // neg - // fail if a != (u32)-2 - Insn.txa(), - Insn.alu_neg(), - Insn.jmp(.jeq, .{ .k = ~@as(u32, 2) + 1 }, 1, 0), - Insn.ret(.{ .k = 35 }), - // ja #1 (skip the next instruction) - Insn.jmp_ja(1), - Insn.ret(.{ .k = 36 }), - // ld #20 - // tax - // fail if a != 20 - // fail if a != x - Insn.ld_imm(20), - Insn.tax(), - Insn.jmp(.jeq, .{ .k = 20 }, 1, 0), - Insn.ret(.{ .k = 37 }), - Insn.jmp(.jeq, .x, 1, 0), - Insn.ret(.{ .k = 38 }), - // ld #19 - // fail if a == 20 - // fail if a == x - // fail if a >= 20 - // fail if a >= X - Insn.ld_imm(19), - Insn.jmp(.jeq, .{ .k = 20 }, 0, 1), - Insn.ret(.{ .k = 39 }), - Insn.jmp(.jeq, .x, 0, 1), - Insn.ret(.{ .k = 40 }), - Insn.jmp(.jgt, .{ .k = 20 }, 0, 1), - Insn.ret(.{ .k = 41 }), - Insn.jmp(.jgt, .x, 0, 1), - Insn.ret(.{ .k = 42 }), - // ld #21 - // fail if a < 20 - // fail if a < x - Insn.ld_imm(21), - Insn.jmp(.jgt, .{ .k = 20 }, 1, 0), - Insn.ret(.{ .k = 43 }), - Insn.jmp(.jgt, .x, 1, 0), - Insn.ret(.{ .k = 44 }), - // ldx #22 - // fail if a < 22 - // fail if a < x - Insn.ldx_imm(22), - Insn.jmp(.jge, .{ .k = 22 }, 0, 1), - Insn.ret(.{ .k = 45 }), - Insn.jmp(.jge, .x, 0, 1), - Insn.ret(.{ .k = 46 }), - // ld #23 - // fail if a >= 22 - // fail if a >= x - Insn.ld_imm(23), - Insn.jmp(.jge, .{ .k = 22 }, 1, 0), - Insn.ret(.{ .k = 47 }), - Insn.jmp(.jge, .x, 1, 0), - Insn.ret(.{ .k = 48 }), - // ldx #0b10100 - // fail if a & 0b10100 == 0 - // fail if a & x == 0 - Insn.ldx_imm(0b10100), - Insn.jmp(.jset, .{ .k = 0b10100 }, 1, 0), - Insn.ret(.{ .k = 47 }), - Insn.jmp(.jset, .x, 1, 0), - Insn.ret(.{ .k = 48 }), - // ldx #0 - // fail if a & 0 > 0 - // fail if a & x > 0 - Insn.ldx_imm(0), - Insn.jmp(.jset, .{ .k = 0 }, 0, 1), - Insn.ret(.{ .k = 49 }), - Insn.jmp(.jset, .x, 0, 1), - Insn.ret(.{ .k = 50 }), - Insn.ret(.{ .k = 0 }), - }); - try expectPass(&some_data, &.{ - Insn.ld_imm(35), - Insn.ld_imm(0), - Insn.ret(.a), - }); - - // Errors - try expectFail(error.NoReturn, &some_data, &.{ - Insn.ld_imm(10), - }); - try expectFail(error.InvalidOpcode, &some_data, &.{ - Insn.stmt(0x7f, 0xdeadbeef), - }); - try expectFail(error.InvalidOffset, &some_data, &.{ - Insn.stmt(LD | ABS | W, 10), - }); - try expectFail(error.InvalidLocation, &some_data, &.{ - Insn.jmp(.jeq, .{ .k = 0 }, 10, 0), - }); - try expectFail(error.InvalidLocation, &some_data, &.{ - Insn.jmp(.jeq, .{ .k = 0 }, 0, 10), - }); -} diff --git a/lib/std/x/net/ip.zig b/lib/std/x/net/ip.zig deleted file mode 100644 index b3da9725d8..0000000000 --- a/lib/std/x/net/ip.zig +++ /dev/null @@ -1,57 +0,0 @@ -const std = @import("../../std.zig"); - -const fmt = std.fmt; - -const IPv4 = std.x.os.IPv4; -const IPv6 = std.x.os.IPv6; -const Socket = std.x.os.Socket; - -/// A generic IP abstraction. -const ip = @This(); - -/// A union of all eligible types of IP addresses. -pub const Address = union(enum) { - ipv4: IPv4.Address, - ipv6: IPv6.Address, - - /// Instantiate a new address with a IPv4 host and port. - pub fn initIPv4(host: IPv4, port: u16) Address { - return .{ .ipv4 = .{ .host = host, .port = port } }; - } - - /// Instantiate a new address with a IPv6 host and port. 
- pub fn initIPv6(host: IPv6, port: u16) Address { - return .{ .ipv6 = .{ .host = host, .port = port } }; - } - - /// Re-interpret a generic socket address into an IP address. - pub fn from(address: Socket.Address) ip.Address { - return switch (address) { - .ipv4 => |ipv4_address| .{ .ipv4 = ipv4_address }, - .ipv6 => |ipv6_address| .{ .ipv6 = ipv6_address }, - }; - } - - /// Re-interpret an IP address into a generic socket address. - pub fn into(self: ip.Address) Socket.Address { - return switch (self) { - .ipv4 => |ipv4_address| .{ .ipv4 = ipv4_address }, - .ipv6 => |ipv6_address| .{ .ipv6 = ipv6_address }, - }; - } - - /// Implements the `std.fmt.format` API. - pub fn format( - self: ip.Address, - comptime layout: []const u8, - opts: fmt.FormatOptions, - writer: anytype, - ) !void { - if (layout.len != 0) std.fmt.invalidFmtError(layout, self); - _ = opts; - switch (self) { - .ipv4 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }), - .ipv6 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }), - } - } -}; diff --git a/lib/std/x/net/tcp.zig b/lib/std/x/net/tcp.zig deleted file mode 100644 index 0293deb9db..0000000000 --- a/lib/std/x/net/tcp.zig +++ /dev/null @@ -1,447 +0,0 @@ -const std = @import("../../std.zig"); -const builtin = @import("builtin"); - -const io = std.io; -const os = std.os; -const ip = std.x.net.ip; - -const fmt = std.fmt; -const mem = std.mem; -const testing = std.testing; -const native_os = builtin.os; - -const IPv4 = std.x.os.IPv4; -const IPv6 = std.x.os.IPv6; -const Socket = std.x.os.Socket; -const Buffer = std.x.os.Buffer; - -/// A generic TCP socket abstraction. -const tcp = @This(); - -/// A TCP client-address pair. -pub const Connection = struct { - client: tcp.Client, - address: ip.Address, - - /// Enclose a TCP client and address into a client-address pair. - pub fn from(conn: Socket.Connection) tcp.Connection { - return .{ - .client = tcp.Client.from(conn.socket), - .address = ip.Address.from(conn.address), - }; - } - - /// Unravel a TCP client-address pair into a socket-address pair. - pub fn into(self: tcp.Connection) Socket.Connection { - return .{ - .socket = self.client.socket, - .address = self.address.into(), - }; - } - - /// Closes the underlying client of the connection. - pub fn deinit(self: tcp.Connection) void { - self.client.deinit(); - } -}; - -/// Possible domains that a TCP client/listener may operate over. -pub const Domain = enum(u16) { - ip = os.AF.INET, - ipv6 = os.AF.INET6, -}; - -/// A TCP client. -pub const Client = struct { - socket: Socket, - - /// Implements `std.io.Reader`. - pub const Reader = struct { - client: Client, - flags: u32, - - /// Implements `readFn` for `std.io.Reader`. - pub fn read(self: Client.Reader, buffer: []u8) !usize { - return self.client.read(buffer, self.flags); - } - }; - - /// Implements `std.io.Writer`. - pub const Writer = struct { - client: Client, - flags: u32, - - /// Implements `writeFn` for `std.io.Writer`. - pub fn write(self: Client.Writer, buffer: []const u8) !usize { - return self.client.write(buffer, self.flags); - } - }; - - /// Opens a new client. - pub fn init(domain: tcp.Domain, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Client { - return Client{ - .socket = try Socket.init( - @enumToInt(domain), - os.SOCK.STREAM, - os.IPPROTO.TCP, - flags, - ), - }; - } - - /// Enclose a TCP client over an existing socket. - pub fn from(socket: Socket) Client { - return Client{ .socket = socket }; - } - - /// Closes the client. 
- pub fn deinit(self: Client) void { - self.socket.deinit(); - } - - /// Shutdown either the read side, write side, or all sides of the client's underlying socket. - pub fn shutdown(self: Client, how: os.ShutdownHow) !void { - return self.socket.shutdown(how); - } - - /// Have the client attempt to the connect to an address. - pub fn connect(self: Client, address: ip.Address) !void { - return self.socket.connect(address.into()); - } - - /// Extracts the error set of a function. - /// TODO: remove after Socket.{read, write} error unions are well-defined across different platforms - fn ErrorSetOf(comptime Function: anytype) type { - return @typeInfo(@typeInfo(@TypeOf(Function)).Fn.return_type.?).ErrorUnion.error_set; - } - - /// Wrap `tcp.Client` into `std.io.Reader`. - pub fn reader(self: Client, flags: u32) io.Reader(Client.Reader, ErrorSetOf(Client.Reader.read), Client.Reader.read) { - return .{ .context = .{ .client = self, .flags = flags } }; - } - - /// Wrap `tcp.Client` into `std.io.Writer`. - pub fn writer(self: Client, flags: u32) io.Writer(Client.Writer, ErrorSetOf(Client.Writer.write), Client.Writer.write) { - return .{ .context = .{ .client = self, .flags = flags } }; - } - - /// Read data from the socket into the buffer provided with a set of flags - /// specified. It returns the number of bytes read into the buffer provided. - pub fn read(self: Client, buf: []u8, flags: u32) !usize { - return self.socket.read(buf, flags); - } - - /// Write a buffer of data provided to the socket with a set of flags specified. - /// It returns the number of bytes that are written to the socket. - pub fn write(self: Client, buf: []const u8, flags: u32) !usize { - return self.socket.write(buf, flags); - } - - /// Writes multiple I/O vectors with a prepended message header to the socket - /// with a set of flags specified. It returns the number of bytes that are - /// written to the socket. - pub fn writeMessage(self: Client, msg: Socket.Message, flags: u32) !usize { - return self.socket.writeMessage(msg, flags); - } - - /// Read multiple I/O vectors with a prepended message header from the socket - /// with a set of flags specified. It returns the number of bytes that were - /// read into the buffer provided. - pub fn readMessage(self: Client, msg: *Socket.Message, flags: u32) !usize { - return self.socket.readMessage(msg, flags); - } - - /// Query and return the latest cached error on the client's underlying socket. - pub fn getError(self: Client) !void { - return self.socket.getError(); - } - - /// Query the read buffer size of the client's underlying socket. - pub fn getReadBufferSize(self: Client) !u32 { - return self.socket.getReadBufferSize(); - } - - /// Query the write buffer size of the client's underlying socket. - pub fn getWriteBufferSize(self: Client) !u32 { - return self.socket.getWriteBufferSize(); - } - - /// Query the address that the client's socket is locally bounded to. - pub fn getLocalAddress(self: Client) !ip.Address { - return ip.Address.from(try self.socket.getLocalAddress()); - } - - /// Query the address that the socket is connected to. - pub fn getRemoteAddress(self: Client) !ip.Address { - return ip.Address.from(try self.socket.getRemoteAddress()); - } - - /// Have close() or shutdown() syscalls block until all queued messages in the client have been successfully - /// sent, or if the timeout specified in seconds has been reached. 
It returns `error.UnsupportedSocketOption` - /// if the host does not support the option for a socket to linger around up until a timeout specified in - /// seconds. - pub fn setLinger(self: Client, timeout_seconds: ?u16) !void { - return self.socket.setLinger(timeout_seconds); - } - - /// Have keep-alive messages be sent periodically. The timing in which keep-alive messages are sent are - /// dependant on operating system settings. It returns `error.UnsupportedSocketOption` if the host does - /// not support periodically sending keep-alive messages on connection-oriented sockets. - pub fn setKeepAlive(self: Client, enabled: bool) !void { - return self.socket.setKeepAlive(enabled); - } - - /// Disable Nagle's algorithm on a TCP socket. It returns `error.UnsupportedSocketOption` if - /// the host does not support sockets disabling Nagle's algorithm. - pub fn setNoDelay(self: Client, enabled: bool) !void { - if (@hasDecl(os.TCP, "NODELAY")) { - const bytes = mem.asBytes(&@as(usize, @boolToInt(enabled))); - return self.socket.setOption(os.IPPROTO.TCP, os.TCP.NODELAY, bytes); - } - return error.UnsupportedSocketOption; - } - - /// Enables TCP Quick ACK on a TCP socket to immediately send rather than delay ACKs when necessary. It returns - /// `error.UnsupportedSocketOption` if the host does not support TCP Quick ACK. - pub fn setQuickACK(self: Client, enabled: bool) !void { - if (@hasDecl(os.TCP, "QUICKACK")) { - return self.socket.setOption(os.IPPROTO.TCP, os.TCP.QUICKACK, mem.asBytes(&@as(u32, @boolToInt(enabled)))); - } - return error.UnsupportedSocketOption; - } - - /// Set the write buffer size of the socket. - pub fn setWriteBufferSize(self: Client, size: u32) !void { - return self.socket.setWriteBufferSize(size); - } - - /// Set the read buffer size of the socket. - pub fn setReadBufferSize(self: Client, size: u32) !void { - return self.socket.setReadBufferSize(size); - } - - /// Set a timeout on the socket that is to occur if no messages are successfully written - /// to its bound destination after a specified number of milliseconds. A subsequent write - /// to the socket will thereafter return `error.WouldBlock` should the timeout be exceeded. - pub fn setWriteTimeout(self: Client, milliseconds: u32) !void { - return self.socket.setWriteTimeout(milliseconds); - } - - /// Set a timeout on the socket that is to occur if no messages are successfully read - /// from its bound destination after a specified number of milliseconds. A subsequent - /// read from the socket will thereafter return `error.WouldBlock` should the timeout be - /// exceeded. - pub fn setReadTimeout(self: Client, milliseconds: u32) !void { - return self.socket.setReadTimeout(milliseconds); - } -}; - -/// A TCP listener. -pub const Listener = struct { - socket: Socket, - - /// Opens a new listener. - pub fn init(domain: tcp.Domain, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Listener { - return Listener{ - .socket = try Socket.init( - @enumToInt(domain), - os.SOCK.STREAM, - os.IPPROTO.TCP, - flags, - ), - }; - } - - /// Closes the listener. - pub fn deinit(self: Listener) void { - self.socket.deinit(); - } - - /// Shuts down the underlying listener's socket. The next subsequent call, or - /// a current pending call to accept() after shutdown is called will return - /// an error. - pub fn shutdown(self: Listener) !void { - return self.socket.shutdown(.recv); - } - - /// Binds the listener's socket to an address. 
- pub fn bind(self: Listener, address: ip.Address) !void { - return self.socket.bind(address.into()); - } - - /// Start listening for incoming connections. - pub fn listen(self: Listener, max_backlog_size: u31) !void { - return self.socket.listen(max_backlog_size); - } - - /// Accept a pending incoming connection queued to the kernel backlog - /// of the listener's socket. - pub fn accept(self: Listener, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !tcp.Connection { - return tcp.Connection.from(try self.socket.accept(flags)); - } - - /// Query and return the latest cached error on the listener's underlying socket. - pub fn getError(self: Client) !void { - return self.socket.getError(); - } - - /// Query the address that the listener's socket is locally bounded to. - pub fn getLocalAddress(self: Listener) !ip.Address { - return ip.Address.from(try self.socket.getLocalAddress()); - } - - /// Allow multiple sockets on the same host to listen on the same address. It returns `error.UnsupportedSocketOption` if - /// the host does not support sockets listening the same address. - pub fn setReuseAddress(self: Listener, enabled: bool) !void { - return self.socket.setReuseAddress(enabled); - } - - /// Allow multiple sockets on the same host to listen on the same port. It returns `error.UnsupportedSocketOption` if - /// the host does not supports sockets listening on the same port. - pub fn setReusePort(self: Listener, enabled: bool) !void { - return self.socket.setReusePort(enabled); - } - - /// Enables TCP Fast Open (RFC 7413) on a TCP socket. It returns `error.UnsupportedSocketOption` if the host does not - /// support TCP Fast Open. - pub fn setFastOpen(self: Listener, enabled: bool) !void { - if (@hasDecl(os.TCP, "FASTOPEN")) { - return self.socket.setOption(os.IPPROTO.TCP, os.TCP.FASTOPEN, mem.asBytes(&@as(u32, @boolToInt(enabled)))); - } - return error.UnsupportedSocketOption; - } - - /// Set a timeout on the listener that is to occur if no new incoming connections come in - /// after a specified number of milliseconds. A subsequent accept call to the listener - /// will thereafter return `error.WouldBlock` should the timeout be exceeded. 
- pub fn setAcceptTimeout(self: Listener, milliseconds: usize) !void { - return self.socket.setReadTimeout(milliseconds); - } -}; - -test "tcp: create client/listener pair" { - if (native_os.tag == .wasi) return error.SkipZigTest; - - const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true }); - defer listener.deinit(); - - try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0)); - try listener.listen(128); - - var binded_address = try listener.getLocalAddress(); - switch (binded_address) { - .ipv4 => |*ipv4| ipv4.host = IPv4.localhost, - .ipv6 => |*ipv6| ipv6.host = IPv6.localhost, - } - - const client = try tcp.Client.init(.ip, .{ .close_on_exec = true }); - defer client.deinit(); - - try client.connect(binded_address); - - const conn = try listener.accept(.{ .close_on_exec = true }); - defer conn.deinit(); -} - -test "tcp/client: 1ms read timeout" { - if (native_os.tag == .wasi) return error.SkipZigTest; - - const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true }); - defer listener.deinit(); - - try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0)); - try listener.listen(128); - - var binded_address = try listener.getLocalAddress(); - switch (binded_address) { - .ipv4 => |*ipv4| ipv4.host = IPv4.localhost, - .ipv6 => |*ipv6| ipv6.host = IPv6.localhost, - } - - const client = try tcp.Client.init(.ip, .{ .close_on_exec = true }); - defer client.deinit(); - - try client.connect(binded_address); - try client.setReadTimeout(1); - - const conn = try listener.accept(.{ .close_on_exec = true }); - defer conn.deinit(); - - var buf: [1]u8 = undefined; - try testing.expectError(error.WouldBlock, client.reader(0).read(&buf)); -} - -test "tcp/client: read and write multiple vectors" { - if (native_os.tag == .wasi) return error.SkipZigTest; - - if (builtin.os.tag == .windows) { - // https://github.com/ziglang/zig/issues/13893 - return error.SkipZigTest; - } - - const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true }); - defer listener.deinit(); - - try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0)); - try listener.listen(128); - - var binded_address = try listener.getLocalAddress(); - switch (binded_address) { - .ipv4 => |*ipv4| ipv4.host = IPv4.localhost, - .ipv6 => |*ipv6| ipv6.host = IPv6.localhost, - } - - const client = try tcp.Client.init(.ip, .{ .close_on_exec = true }); - defer client.deinit(); - - try client.connect(binded_address); - - const conn = try listener.accept(.{ .close_on_exec = true }); - defer conn.deinit(); - - const message = "hello world"; - _ = try conn.client.writeMessage(Socket.Message.fromBuffers(&[_]Buffer{ - Buffer.from(message[0 .. message.len / 2]), - Buffer.from(message[message.len / 2 ..]), - }), 0); - - var buf: [message.len + 1]u8 = undefined; - var msg = Socket.Message.fromBuffers(&[_]Buffer{ - Buffer.from(buf[0 .. 
message.len / 2]), - Buffer.from(buf[message.len / 2 ..]), - }); - _ = try client.readMessage(&msg, 0); - - try testing.expectEqualStrings(message, buf[0..message.len]); -} - -test "tcp/listener: bind to unspecified ipv4 address" { - if (native_os.tag == .wasi) return error.SkipZigTest; - - const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true }); - defer listener.deinit(); - - try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0)); - try listener.listen(128); - - const address = try listener.getLocalAddress(); - try testing.expect(address == .ipv4); -} - -test "tcp/listener: bind to unspecified ipv6 address" { - if (native_os.tag == .wasi) return error.SkipZigTest; - - if (builtin.os.tag == .windows) { - // https://github.com/ziglang/zig/issues/13893 - return error.SkipZigTest; - } - - const listener = try tcp.Listener.init(.ipv6, .{ .close_on_exec = true }); - defer listener.deinit(); - - try listener.bind(ip.Address.initIPv6(IPv6.unspecified, 0)); - try listener.listen(128); - - const address = try listener.getLocalAddress(); - try testing.expect(address == .ipv6); -} diff --git a/lib/std/x/os/io.zig b/lib/std/x/os/io.zig deleted file mode 100644 index 6c4763df65..0000000000 --- a/lib/std/x/os/io.zig +++ /dev/null @@ -1,224 +0,0 @@ -const std = @import("../../std.zig"); -const builtin = @import("builtin"); - -const os = std.os; -const mem = std.mem; -const testing = std.testing; -const native_os = builtin.os; -const linux = std.os.linux; - -/// POSIX `iovec`, or Windows `WSABUF`. The difference between the two are the ordering -/// of fields, alongside the length being represented as either a ULONG or a size_t. -pub const Buffer = if (native_os.tag == .windows) - extern struct { - len: c_ulong, - ptr: usize, - - pub fn from(slice: []const u8) Buffer { - return .{ .len = @intCast(c_ulong, slice.len), .ptr = @ptrToInt(slice.ptr) }; - } - - pub fn into(self: Buffer) []const u8 { - return @intToPtr([*]const u8, self.ptr)[0..self.len]; - } - - pub fn intoMutable(self: Buffer) []u8 { - return @intToPtr([*]u8, self.ptr)[0..self.len]; - } - } -else - extern struct { - ptr: usize, - len: usize, - - pub fn from(slice: []const u8) Buffer { - return .{ .ptr = @ptrToInt(slice.ptr), .len = slice.len }; - } - - pub fn into(self: Buffer) []const u8 { - return @intToPtr([*]const u8, self.ptr)[0..self.len]; - } - - pub fn intoMutable(self: Buffer) []u8 { - return @intToPtr([*]u8, self.ptr)[0..self.len]; - } - }; - -pub const Reactor = struct { - pub const InitFlags = enum { - close_on_exec, - }; - - pub const Event = struct { - data: usize, - is_error: bool, - is_hup: bool, - is_readable: bool, - is_writable: bool, - }; - - pub const Interest = struct { - hup: bool = false, - oneshot: bool = false, - readable: bool = false, - writable: bool = false, - }; - - fd: os.fd_t, - - pub fn init(flags: std.enums.EnumFieldStruct(Reactor.InitFlags, bool, false)) !Reactor { - var raw_flags: u32 = 0; - const set = std.EnumSet(Reactor.InitFlags).init(flags); - if (set.contains(.close_on_exec)) raw_flags |= linux.EPOLL.CLOEXEC; - return Reactor{ .fd = try os.epoll_create1(raw_flags) }; - } - - pub fn deinit(self: Reactor) void { - os.close(self.fd); - } - - pub fn update(self: Reactor, fd: os.fd_t, identifier: usize, interest: Reactor.Interest) !void { - var flags: u32 = 0; - flags |= if (interest.oneshot) linux.EPOLL.ONESHOT else linux.EPOLL.ET; - if (interest.hup) flags |= linux.EPOLL.RDHUP; - if (interest.readable) flags |= linux.EPOLL.IN; - if (interest.writable) flags |= linux.EPOLL.OUT; - - 
const event = &linux.epoll_event{ - .events = flags, - .data = .{ .ptr = identifier }, - }; - - os.epoll_ctl(self.fd, linux.EPOLL.CTL_MOD, fd, event) catch |err| switch (err) { - error.FileDescriptorNotRegistered => try os.epoll_ctl(self.fd, linux.EPOLL.CTL_ADD, fd, event), - else => return err, - }; - } - - pub fn remove(self: Reactor, fd: os.fd_t) !void { - // directly from man epoll_ctl BUGS section - // In kernel versions before 2.6.9, the EPOLL_CTL_DEL operation re‐ - // quired a non-null pointer in event, even though this argument is - // ignored. Since Linux 2.6.9, event can be specified as NULL when - // using EPOLL_CTL_DEL. Applications that need to be portable to - // kernels before 2.6.9 should specify a non-null pointer in event. - var event = linux.epoll_event{ - .events = 0, - .data = .{ .ptr = 0 }, - }; - - return os.epoll_ctl(self.fd, linux.EPOLL.CTL_DEL, fd, &event); - } - - pub fn poll(self: Reactor, comptime max_num_events: comptime_int, closure: anytype, timeout_milliseconds: ?u64) !void { - var events: [max_num_events]linux.epoll_event = undefined; - - const num_events = os.epoll_wait(self.fd, &events, if (timeout_milliseconds) |ms| @intCast(i32, ms) else -1); - for (events[0..num_events]) |ev| { - const is_error = ev.events & linux.EPOLL.ERR != 0; - const is_hup = ev.events & (linux.EPOLL.HUP | linux.EPOLL.RDHUP) != 0; - const is_readable = ev.events & linux.EPOLL.IN != 0; - const is_writable = ev.events & linux.EPOLL.OUT != 0; - - try closure.call(Reactor.Event{ - .data = ev.data.ptr, - .is_error = is_error, - .is_hup = is_hup, - .is_readable = is_readable, - .is_writable = is_writable, - }); - } - } -}; - -test "reactor/linux: drive async tcp client/listener pair" { - if (native_os.tag != .linux) return error.SkipZigTest; - - const ip = std.x.net.ip; - const tcp = std.x.net.tcp; - - const IPv4 = std.x.os.IPv4; - const IPv6 = std.x.os.IPv6; - - const reactor = try Reactor.init(.{ .close_on_exec = true }); - defer reactor.deinit(); - - const listener = try tcp.Listener.init(.ip, .{ - .close_on_exec = true, - .nonblocking = true, - }); - defer listener.deinit(); - - try reactor.update(listener.socket.fd, 0, .{ .readable = true }); - try reactor.poll(1, struct { - fn call(event: Reactor.Event) !void { - try testing.expectEqual(Reactor.Event{ - .data = 0, - .is_error = false, - .is_hup = true, - .is_readable = false, - .is_writable = false, - }, event); - } - }, null); - - try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0)); - try listener.listen(128); - - var binded_address = try listener.getLocalAddress(); - switch (binded_address) { - .ipv4 => |*ipv4| ipv4.host = IPv4.localhost, - .ipv6 => |*ipv6| ipv6.host = IPv6.localhost, - } - - const client = try tcp.Client.init(.ip, .{ - .close_on_exec = true, - .nonblocking = true, - }); - defer client.deinit(); - - try reactor.update(client.socket.fd, 1, .{ .readable = true, .writable = true }); - try reactor.poll(1, struct { - fn call(event: Reactor.Event) !void { - try testing.expectEqual(Reactor.Event{ - .data = 1, - .is_error = false, - .is_hup = true, - .is_readable = false, - .is_writable = true, - }, event); - } - }, null); - - client.connect(binded_address) catch |err| switch (err) { - error.WouldBlock => {}, - else => return err, - }; - - try reactor.poll(1, struct { - fn call(event: Reactor.Event) !void { - try testing.expectEqual(Reactor.Event{ - .data = 1, - .is_error = false, - .is_hup = false, - .is_readable = false, - .is_writable = true, - }, event); - } - }, null); - - try reactor.poll(1, struct { - fn 
call(event: Reactor.Event) !void { - try testing.expectEqual(Reactor.Event{ - .data = 0, - .is_error = false, - .is_hup = false, - .is_readable = true, - .is_writable = false, - }, event); - } - }, null); - - try reactor.remove(client.socket.fd); - try reactor.remove(listener.socket.fd); -} diff --git a/lib/std/x/os/net.zig b/lib/std/x/os/net.zig deleted file mode 100644 index e00299e243..0000000000 --- a/lib/std/x/os/net.zig +++ /dev/null @@ -1,605 +0,0 @@ -const std = @import("../../std.zig"); -const builtin = @import("builtin"); - -const os = std.os; -const fmt = std.fmt; -const mem = std.mem; -const math = std.math; -const testing = std.testing; -const native_os = builtin.os; -const have_ifnamesize = @hasDecl(os.system, "IFNAMESIZE"); - -pub const ResolveScopeIdError = error{ - NameTooLong, - PermissionDenied, - AddressFamilyNotSupported, - ProtocolFamilyNotAvailable, - ProcessFdQuotaExceeded, - SystemFdQuotaExceeded, - SystemResources, - ProtocolNotSupported, - SocketTypeNotSupported, - InterfaceNotFound, - FileSystem, - Unexpected, -}; - -/// Resolves a network interface name into a scope/zone ID. It returns -/// an error if either resolution fails, or if the interface name is -/// too long. -pub fn resolveScopeId(name: []const u8) ResolveScopeIdError!u32 { - if (have_ifnamesize) { - if (name.len >= os.IFNAMESIZE) return error.NameTooLong; - - if (native_os.tag == .windows or comptime native_os.tag.isDarwin()) { - var interface_name: [os.IFNAMESIZE:0]u8 = undefined; - mem.copy(u8, &interface_name, name); - interface_name[name.len] = 0; - - const rc = blk: { - if (native_os.tag == .windows) { - break :blk os.windows.ws2_32.if_nametoindex(@ptrCast([*:0]const u8, &interface_name)); - } else { - const index = os.system.if_nametoindex(@ptrCast([*:0]const u8, &interface_name)); - break :blk @bitCast(u32, index); - } - }; - if (rc == 0) { - return error.InterfaceNotFound; - } - return rc; - } - - if (native_os.tag == .linux) { - const fd = try os.socket(os.AF.INET, os.SOCK.DGRAM, 0); - defer os.closeSocket(fd); - - var f: os.ifreq = undefined; - mem.copy(u8, &f.ifrn.name, name); - f.ifrn.name[name.len] = 0; - - try os.ioctl_SIOCGIFINDEX(fd, &f); - - return @bitCast(u32, f.ifru.ivalue); - } - } - - return error.InterfaceNotFound; -} - -/// An IPv4 address comprised of 4 bytes. -pub const IPv4 = extern struct { - /// A IPv4 host-port pair. - pub const Address = extern struct { - host: IPv4, - port: u16, - }; - - /// Octets of a IPv4 address designating the local host. - pub const localhost_octets = [_]u8{ 127, 0, 0, 1 }; - - /// The IPv4 address of the local host. - pub const localhost: IPv4 = .{ .octets = localhost_octets }; - - /// Octets of an unspecified IPv4 address. - pub const unspecified_octets = [_]u8{0} ** 4; - - /// An unspecified IPv4 address. - pub const unspecified: IPv4 = .{ .octets = unspecified_octets }; - - /// Octets of a broadcast IPv4 address. - pub const broadcast_octets = [_]u8{255} ** 4; - - /// An IPv4 broadcast address. - pub const broadcast: IPv4 = .{ .octets = broadcast_octets }; - - /// The prefix octet pattern of a link-local IPv4 address. - pub const link_local_prefix = [_]u8{ 169, 254 }; - - /// The prefix octet patterns of IPv4 addresses intended for - /// documentation. - pub const documentation_prefixes = [_][]const u8{ - &[_]u8{ 192, 0, 2 }, - &[_]u8{ 198, 51, 100 }, - &[_]u8{ 203, 0, 113 }, - }; - - octets: [4]u8, - - /// Returns whether or not the two addresses are equal to, less than, or - /// greater than each other. 
- pub fn cmp(self: IPv4, other: IPv4) math.Order { - return mem.order(u8, &self.octets, &other.octets); - } - - /// Returns true if both addresses are semantically equivalent. - pub fn eql(self: IPv4, other: IPv4) bool { - return mem.eql(u8, &self.octets, &other.octets); - } - - /// Returns true if the address is a loopback address. - pub fn isLoopback(self: IPv4) bool { - return self.octets[0] == 127; - } - - /// Returns true if the address is an unspecified IPv4 address. - pub fn isUnspecified(self: IPv4) bool { - return mem.eql(u8, &self.octets, &unspecified_octets); - } - - /// Returns true if the address is a private IPv4 address. - pub fn isPrivate(self: IPv4) bool { - return self.octets[0] == 10 or - (self.octets[0] == 172 and self.octets[1] >= 16 and self.octets[1] <= 31) or - (self.octets[0] == 192 and self.octets[1] == 168); - } - - /// Returns true if the address is a link-local IPv4 address. - pub fn isLinkLocal(self: IPv4) bool { - return mem.startsWith(u8, &self.octets, &link_local_prefix); - } - - /// Returns true if the address is a multicast IPv4 address. - pub fn isMulticast(self: IPv4) bool { - return self.octets[0] >= 224 and self.octets[0] <= 239; - } - - /// Returns true if the address is a IPv4 broadcast address. - pub fn isBroadcast(self: IPv4) bool { - return mem.eql(u8, &self.octets, &broadcast_octets); - } - - /// Returns true if the address is in a range designated for documentation. Refer - /// to IETF RFC 5737 for more details. - pub fn isDocumentation(self: IPv4) bool { - inline for (documentation_prefixes) |prefix| { - if (mem.startsWith(u8, &self.octets, prefix)) { - return true; - } - } - return false; - } - - /// Implements the `std.fmt.format` API. - pub fn format( - self: IPv4, - comptime layout: []const u8, - opts: fmt.FormatOptions, - writer: anytype, - ) !void { - _ = opts; - if (layout.len != 0) std.fmt.invalidFmtError(layout, self); - - try fmt.format(writer, "{}.{}.{}.{}", .{ - self.octets[0], - self.octets[1], - self.octets[2], - self.octets[3], - }); - } - - /// Set of possible errors that may encountered when parsing an IPv4 - /// address. - pub const ParseError = error{ - UnexpectedEndOfOctet, - TooManyOctets, - OctetOverflow, - UnexpectedToken, - IncompleteAddress, - }; - - /// Parses an arbitrary IPv4 address. - pub fn parse(buf: []const u8) ParseError!IPv4 { - var octets: [4]u8 = undefined; - var octet: u8 = 0; - - var index: u8 = 0; - var saw_any_digits: bool = false; - - for (buf) |c| { - switch (c) { - '.' => { - if (!saw_any_digits) return error.UnexpectedEndOfOctet; - if (index == 3) return error.TooManyOctets; - octets[index] = octet; - index += 1; - octet = 0; - saw_any_digits = false; - }, - '0'...'9' => { - saw_any_digits = true; - octet = math.mul(u8, octet, 10) catch return error.OctetOverflow; - octet = math.add(u8, octet, c - '0') catch return error.OctetOverflow; - }, - else => return error.UnexpectedToken, - } - } - - if (index == 3 and saw_any_digits) { - octets[index] = octet; - return IPv4{ .octets = octets }; - } - - return error.IncompleteAddress; - } - - /// Maps the address to its IPv6 equivalent. In most cases, you would - /// want to map the address to its IPv6 equivalent rather than directly - /// re-interpreting the address. 
- pub fn mapToIPv6(self: IPv4) IPv6 { - var octets: [16]u8 = undefined; - mem.copy(u8, octets[0..12], &IPv6.v4_mapped_prefix); - mem.copy(u8, octets[12..], &self.octets); - return IPv6{ .octets = octets, .scope_id = IPv6.no_scope_id }; - } - - /// Directly re-interprets the address to its IPv6 equivalent. In most - /// cases, you would want to map the address to its IPv6 equivalent rather - /// than directly re-interpreting the address. - pub fn toIPv6(self: IPv4) IPv6 { - var octets: [16]u8 = undefined; - mem.set(u8, octets[0..12], 0); - mem.copy(u8, octets[12..], &self.octets); - return IPv6{ .octets = octets, .scope_id = IPv6.no_scope_id }; - } -}; - -/// An IPv6 address comprised of 16 bytes for an address, and 4 bytes -/// for a scope ID; cumulatively summing to 20 bytes in total. -pub const IPv6 = extern struct { - /// A IPv6 host-port pair. - pub const Address = extern struct { - host: IPv6, - port: u16, - }; - - /// Octets of a IPv6 address designating the local host. - pub const localhost_octets = [_]u8{0} ** 15 ++ [_]u8{0x01}; - - /// The IPv6 address of the local host. - pub const localhost: IPv6 = .{ - .octets = localhost_octets, - .scope_id = no_scope_id, - }; - - /// Octets of an unspecified IPv6 address. - pub const unspecified_octets = [_]u8{0} ** 16; - - /// An unspecified IPv6 address. - pub const unspecified: IPv6 = .{ - .octets = unspecified_octets, - .scope_id = no_scope_id, - }; - - /// The prefix of a IPv6 address that is mapped to a IPv4 address. - pub const v4_mapped_prefix = [_]u8{0} ** 10 ++ [_]u8{0xFF} ** 2; - - /// A marker value used to designate an IPv6 address with no - /// associated scope ID. - pub const no_scope_id = math.maxInt(u32); - - octets: [16]u8, - scope_id: u32, - - /// Returns whether or not the two addresses are equal to, less than, or - /// greater than each other. - pub fn cmp(self: IPv6, other: IPv6) math.Order { - return switch (mem.order(u8, self.octets, other.octets)) { - .eq => math.order(self.scope_id, other.scope_id), - else => |order| order, - }; - } - - /// Returns true if both addresses are semantically equivalent. - pub fn eql(self: IPv6, other: IPv6) bool { - return self.scope_id == other.scope_id and mem.eql(u8, &self.octets, &other.octets); - } - - /// Returns true if the address is an unspecified IPv6 address. - pub fn isUnspecified(self: IPv6) bool { - return mem.eql(u8, &self.octets, &unspecified_octets); - } - - /// Returns true if the address is a loopback address. - pub fn isLoopback(self: IPv6) bool { - return mem.eql(u8, self.octets[0..3], &[_]u8{ 0, 0, 0 }) and - mem.eql(u8, self.octets[12..], &[_]u8{ 0, 0, 0, 1 }); - } - - /// Returns true if the address maps to an IPv4 address. - pub fn mapsToIPv4(self: IPv6) bool { - return mem.startsWith(u8, &self.octets, &v4_mapped_prefix); - } - - /// Returns an IPv4 address representative of the address should - /// it the address be mapped to an IPv4 address. It returns null - /// otherwise. - pub fn toIPv4(self: IPv6) ?IPv4 { - if (!self.mapsToIPv4()) return null; - return IPv4{ .octets = self.octets[12..][0..4].* }; - } - - /// Returns true if the address is a multicast IPv6 address. - pub fn isMulticast(self: IPv6) bool { - return self.octets[0] == 0xFF; - } - - /// Returns true if the address is a unicast link local IPv6 address. - pub fn isLinkLocal(self: IPv6) bool { - return self.octets[0] == 0xFE and self.octets[1] & 0xC0 == 0x80; - } - - /// Returns true if the address is a deprecated unicast site local - /// IPv6 address. 
Refer to IETF RFC 3879 for more details as to - /// why they are deprecated. - pub fn isSiteLocal(self: IPv6) bool { - return self.octets[0] == 0xFE and self.octets[1] & 0xC0 == 0xC0; - } - - /// IPv6 multicast address scopes. - pub const Scope = enum(u8) { - interface = 1, - link = 2, - realm = 3, - admin = 4, - site = 5, - organization = 8, - global = 14, - unknown = 0xFF, - }; - - /// Returns the multicast scope of the address. - pub fn scope(self: IPv6) Scope { - if (!self.isMulticast()) return .unknown; - - return switch (self.octets[0] & 0x0F) { - 1 => .interface, - 2 => .link, - 3 => .realm, - 4 => .admin, - 5 => .site, - 8 => .organization, - 14 => .global, - else => .unknown, - }; - } - - /// Implements the `std.fmt.format` API. Specifying 'x' or 's' formats the - /// address lower-cased octets, while specifying 'X' or 'S' formats the - /// address using upper-cased ASCII octets. - /// - /// The default specifier is 'x'. - pub fn format( - self: IPv6, - comptime layout: []const u8, - opts: fmt.FormatOptions, - writer: anytype, - ) !void { - _ = opts; - const specifier = comptime &[_]u8{if (layout.len == 0) 'x' else switch (layout[0]) { - 'x', 'X' => |specifier| specifier, - 's' => 'x', - 'S' => 'X', - else => std.fmt.invalidFmtError(layout, self), - }}; - - if (mem.startsWith(u8, &self.octets, &v4_mapped_prefix)) { - return fmt.format(writer, "::{" ++ specifier ++ "}{" ++ specifier ++ "}:{}.{}.{}.{}", .{ - 0xFF, - 0xFF, - self.octets[12], - self.octets[13], - self.octets[14], - self.octets[15], - }); - } - - const zero_span: struct { from: usize, to: usize } = span: { - var i: usize = 0; - while (i < self.octets.len) : (i += 2) { - if (self.octets[i] == 0 and self.octets[i + 1] == 0) break; - } else break :span .{ .from = 0, .to = 0 }; - - const from = i; - - while (i < self.octets.len) : (i += 2) { - if (self.octets[i] != 0 or self.octets[i + 1] != 0) break; - } - - break :span .{ .from = from, .to = i }; - }; - - var i: usize = 0; - while (i != 16) : (i += 2) { - if (zero_span.from != zero_span.to and i == zero_span.from) { - try writer.writeAll("::"); - } else if (i >= zero_span.from and i < zero_span.to) {} else { - if (i != 0 and i != zero_span.to) try writer.writeAll(":"); - - const val = @as(u16, self.octets[i]) << 8 | self.octets[i + 1]; - try fmt.formatIntValue(val, specifier, .{}, writer); - } - } - - if (self.scope_id != no_scope_id and self.scope_id != 0) { - try fmt.format(writer, "%{d}", .{self.scope_id}); - } - } - - /// Set of possible errors that may encountered when parsing an IPv6 - /// address. - pub const ParseError = error{ - MalformedV4Mapping, - InterfaceNotFound, - UnknownScopeId, - } || IPv4.ParseError; - - /// Parses an arbitrary IPv6 address, including link-local addresses. - pub fn parse(buf: []const u8) ParseError!IPv6 { - if (mem.lastIndexOfScalar(u8, buf, '%')) |index| { - const ip_slice = buf[0..index]; - const scope_id_slice = buf[index + 1 ..]; - - if (scope_id_slice.len == 0) return error.UnknownScopeId; - - const scope_id: u32 = switch (scope_id_slice[0]) { - '0'...'9' => fmt.parseInt(u32, scope_id_slice, 10), - else => resolveScopeId(scope_id_slice) catch |err| switch (err) { - error.InterfaceNotFound => return error.InterfaceNotFound, - else => err, - }, - } catch return error.UnknownScopeId; - - return parseWithScopeID(ip_slice, scope_id); - } - - return parseWithScopeID(buf, no_scope_id); - } - - /// Parses an IPv6 address with a pre-specified scope ID. Presumes - /// that the address is not a link-local address. 
- pub fn parseWithScopeID(buf: []const u8, scope_id: u32) ParseError!IPv6 { - var octets: [16]u8 = undefined; - var octet: u16 = 0; - var tail: [16]u8 = undefined; - - var out: []u8 = &octets; - var index: u8 = 0; - - var saw_any_digits: bool = false; - var abbrv: bool = false; - - for (buf) |c, i| { - switch (c) { - ':' => { - if (!saw_any_digits) { - if (abbrv) return error.UnexpectedToken; - if (i != 0) abbrv = true; - mem.set(u8, out[index..], 0); - out = &tail; - index = 0; - continue; - } - if (index == 14) return error.TooManyOctets; - - out[index] = @truncate(u8, octet >> 8); - index += 1; - out[index] = @truncate(u8, octet); - index += 1; - - octet = 0; - saw_any_digits = false; - }, - '.' => { - if (!abbrv or out[0] != 0xFF and out[1] != 0xFF) { - return error.MalformedV4Mapping; - } - const start_index = mem.lastIndexOfScalar(u8, buf[0..i], ':').? + 1; - const v4 = try IPv4.parse(buf[start_index..]); - octets[10] = 0xFF; - octets[11] = 0xFF; - mem.copy(u8, octets[12..], &v4.octets); - - return IPv6{ .octets = octets, .scope_id = scope_id }; - }, - else => { - saw_any_digits = true; - const digit = fmt.charToDigit(c, 16) catch return error.UnexpectedToken; - octet = math.mul(u16, octet, 16) catch return error.OctetOverflow; - octet = math.add(u16, octet, digit) catch return error.OctetOverflow; - }, - } - } - - if (!saw_any_digits and !abbrv) { - return error.IncompleteAddress; - } - - if (index == 14) { - out[14] = @truncate(u8, octet >> 8); - out[15] = @truncate(u8, octet); - } else { - out[index] = @truncate(u8, octet >> 8); - index += 1; - out[index] = @truncate(u8, octet); - index += 1; - mem.copy(u8, octets[16 - index ..], out[0..index]); - } - - return IPv6{ .octets = octets, .scope_id = scope_id }; - } -}; - -test { - testing.refAllDecls(@This()); -} - -test "ip: convert to and from ipv6" { - try testing.expectFmt("::7f00:1", "{}", .{IPv4.localhost.toIPv6()}); - try testing.expect(!IPv4.localhost.toIPv6().mapsToIPv4()); - - try testing.expectFmt("::ffff:127.0.0.1", "{}", .{IPv4.localhost.mapToIPv6()}); - try testing.expect(IPv4.localhost.mapToIPv6().mapsToIPv4()); - - try testing.expect(IPv4.localhost.toIPv6().toIPv4() == null); - try testing.expectFmt("127.0.0.1", "{?}", .{IPv4.localhost.mapToIPv6().toIPv4()}); -} - -test "ipv4: parse & format" { - const cases = [_][]const u8{ - "0.0.0.0", - "255.255.255.255", - "1.2.3.4", - "123.255.0.91", - "127.0.0.1", - }; - - for (cases) |case| { - try testing.expectFmt(case, "{}", .{try IPv4.parse(case)}); - } -} - -test "ipv6: parse & format" { - const inputs = [_][]const u8{ - "FF01:0:0:0:0:0:0:FB", - "FF01::Fb", - "::1", - "::", - "2001:db8::", - "::1234:5678", - "2001:db8::1234:5678", - "::ffff:123.5.123.5", - }; - - const outputs = [_][]const u8{ - "ff01::fb", - "ff01::fb", - "::1", - "::", - "2001:db8::", - "::1234:5678", - "2001:db8::1234:5678", - "::ffff:123.5.123.5", - }; - - for (inputs) |input, i| { - try testing.expectFmt(outputs[i], "{}", .{try IPv6.parse(input)}); - } -} - -test "ipv6: parse & format addresses with scope ids" { - if (!have_ifnamesize) return error.SkipZigTest; - const iface = if (native_os.tag == .linux) - "lo" - else - "lo0"; - const input = "FF01::FB%" ++ iface; - const output = "ff01::fb%1"; - - const parsed = IPv6.parse(input) catch |err| switch (err) { - error.InterfaceNotFound => return, - else => return err, - }; - - try testing.expectFmt(output, "{}", .{parsed}); -} diff --git a/lib/std/x/os/socket.zig b/lib/std/x/os/socket.zig deleted file mode 100644 index 99782710cb..0000000000 --- 
a/lib/std/x/os/socket.zig +++ /dev/null @@ -1,320 +0,0 @@ -const std = @import("../../std.zig"); -const builtin = @import("builtin"); -const net = @import("net.zig"); - -const os = std.os; -const fmt = std.fmt; -const mem = std.mem; -const time = std.time; -const meta = std.meta; -const native_os = builtin.os; -const native_endian = builtin.cpu.arch.endian(); - -const Buffer = std.x.os.Buffer; - -const assert = std.debug.assert; - -/// A generic, cross-platform socket abstraction. -pub const Socket = struct { - /// A socket-address pair. - pub const Connection = struct { - socket: Socket, - address: Socket.Address, - - /// Enclose a socket and address into a socket-address pair. - pub fn from(socket: Socket, address: Socket.Address) Socket.Connection { - return .{ .socket = socket, .address = address }; - } - }; - - /// A generic socket address abstraction. It is safe to directly access and modify - /// the fields of a `Socket.Address`. - pub const Address = union(enum) { - pub const Native = struct { - pub const requires_prepended_length = native_os.getVersionRange() == .semver; - pub const Length = if (requires_prepended_length) u8 else [0]u8; - - pub const Family = if (requires_prepended_length) u8 else c_ushort; - - /// POSIX `sockaddr.storage`. The expected size and alignment is specified in IETF RFC 2553. - pub const Storage = extern struct { - pub const expected_size = os.sockaddr.SS_MAXSIZE; - pub const expected_alignment = 8; - - pub const padding_size = expected_size - - mem.alignForward(@sizeOf(Address.Native.Length), expected_alignment) - - mem.alignForward(@sizeOf(Address.Native.Family), expected_alignment); - - len: Address.Native.Length align(expected_alignment) = undefined, - family: Address.Native.Family align(expected_alignment) = undefined, - padding: [padding_size]u8 align(expected_alignment) = undefined, - - comptime { - assert(@sizeOf(Storage) == Storage.expected_size); - assert(@alignOf(Storage) == Storage.expected_alignment); - } - }; - }; - - ipv4: net.IPv4.Address, - ipv6: net.IPv6.Address, - - /// Instantiate a new address with a IPv4 host and port. - pub fn initIPv4(host: net.IPv4, port: u16) Socket.Address { - return .{ .ipv4 = .{ .host = host, .port = port } }; - } - - /// Instantiate a new address with a IPv6 host and port. - pub fn initIPv6(host: net.IPv6, port: u16) Socket.Address { - return .{ .ipv6 = .{ .host = host, .port = port } }; - } - - /// Parses a `sockaddr` into a generic socket address. - pub fn fromNative(address: *align(4) const os.sockaddr) Socket.Address { - switch (address.family) { - os.AF.INET => { - const info = @ptrCast(*const os.sockaddr.in, address); - const host = net.IPv4{ .octets = @bitCast([4]u8, info.addr) }; - const port = mem.bigToNative(u16, info.port); - return Socket.Address.initIPv4(host, port); - }, - os.AF.INET6 => { - const info = @ptrCast(*const os.sockaddr.in6, address); - const host = net.IPv6{ .octets = info.addr, .scope_id = info.scope_id }; - const port = mem.bigToNative(u16, info.port); - return Socket.Address.initIPv6(host, port); - }, - else => unreachable, - } - } - - /// Encodes a generic socket address into an extern union that may be reliably - /// casted into a `sockaddr` which may be passed into socket syscalls. 
- pub fn toNative(self: Socket.Address) extern union { - ipv4: os.sockaddr.in, - ipv6: os.sockaddr.in6, - } { - return switch (self) { - .ipv4 => |address| .{ - .ipv4 = .{ - .addr = @bitCast(u32, address.host.octets), - .port = mem.nativeToBig(u16, address.port), - }, - }, - .ipv6 => |address| .{ - .ipv6 = .{ - .addr = address.host.octets, - .port = mem.nativeToBig(u16, address.port), - .scope_id = address.host.scope_id, - .flowinfo = 0, - }, - }, - }; - } - - /// Returns the number of bytes that make up the `sockaddr` equivalent to the address. - pub fn getNativeSize(self: Socket.Address) u32 { - return switch (self) { - .ipv4 => @sizeOf(os.sockaddr.in), - .ipv6 => @sizeOf(os.sockaddr.in6), - }; - } - - /// Implements the `std.fmt.format` API. - pub fn format( - self: Socket.Address, - comptime layout: []const u8, - opts: fmt.FormatOptions, - writer: anytype, - ) !void { - if (layout.len != 0) std.fmt.invalidFmtError(layout, self); - _ = opts; - switch (self) { - .ipv4 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }), - .ipv6 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }), - } - } - }; - - /// POSIX `msghdr`. Denotes a destination address, set of buffers, control data, and flags. Ported - /// directly from musl. - pub const Message = if (native_os.isAtLeast(.windows, .vista) != null and native_os.isAtLeast(.windows, .vista).?) - extern struct { - name: usize = @ptrToInt(@as(?[*]u8, null)), - name_len: c_int = 0, - - buffers: usize = undefined, - buffers_len: c_ulong = undefined, - - control: Buffer = .{ - .ptr = @ptrToInt(@as(?[*]u8, null)), - .len = 0, - }, - flags: c_ulong = 0, - - pub usingnamespace MessageMixin(Message); - } - else if (native_os.tag == .windows) - extern struct { - name: usize = @ptrToInt(@as(?[*]u8, null)), - name_len: c_int = 0, - - buffers: usize = undefined, - buffers_len: u32 = undefined, - - control: Buffer = .{ - .ptr = @ptrToInt(@as(?[*]u8, null)), - .len = 0, - }, - flags: u32 = 0, - - pub usingnamespace MessageMixin(Message); - } - else if (@sizeOf(usize) > 4 and native_endian == .Big) - extern struct { - name: usize = @ptrToInt(@as(?[*]u8, null)), - name_len: c_uint = 0, - - buffers: usize = undefined, - _pad_1: c_int = 0, - buffers_len: c_int = undefined, - - control: usize = @ptrToInt(@as(?[*]u8, null)), - _pad_2: c_int = 0, - control_len: c_uint = 0, - - flags: c_int = 0, - - pub usingnamespace MessageMixin(Message); - } - else if (@sizeOf(usize) > 4 and native_endian == .Little) - extern struct { - name: usize = @ptrToInt(@as(?[*]u8, null)), - name_len: c_uint = 0, - - buffers: usize = undefined, - buffers_len: c_int = undefined, - _pad_1: c_int = 0, - - control: usize = @ptrToInt(@as(?[*]u8, null)), - control_len: c_uint = 0, - _pad_2: c_int = 0, - - flags: c_int = 0, - - pub usingnamespace MessageMixin(Message); - } - else - extern struct { - name: usize = @ptrToInt(@as(?[*]u8, null)), - name_len: c_uint = 0, - - buffers: usize = undefined, - buffers_len: c_int = undefined, - - control: usize = @ptrToInt(@as(?[*]u8, null)), - control_len: c_uint = 0, - - flags: c_int = 0, - - pub usingnamespace MessageMixin(Message); - }; - - fn MessageMixin(comptime Self: type) type { - return struct { - pub fn fromBuffers(buffers: []const Buffer) Self { - var self: Self = .{}; - self.setBuffers(buffers); - return self; - } - - pub fn setName(self: *Self, name: []const u8) void { - self.name = @ptrToInt(name.ptr); - self.name_len = @intCast(meta.fieldInfo(Self, .name_len).type, name.len); - } - - pub fn 
setBuffers(self: *Self, buffers: []const Buffer) void { - self.buffers = @ptrToInt(buffers.ptr); - self.buffers_len = @intCast(meta.fieldInfo(Self, .buffers_len).type, buffers.len); - } - - pub fn setControl(self: *Self, control: []const u8) void { - if (native_os.tag == .windows) { - self.control = Buffer.from(control); - } else { - self.control = @ptrToInt(control.ptr); - self.control_len = @intCast(meta.fieldInfo(Self, .control_len).type, control.len); - } - } - - pub fn setFlags(self: *Self, flags: u32) void { - self.flags = @intCast(meta.fieldInfo(Self, .flags).type, flags); - } - - pub fn getName(self: Self) []const u8 { - return @intToPtr([*]const u8, self.name)[0..@intCast(usize, self.name_len)]; - } - - pub fn getBuffers(self: Self) []const Buffer { - return @intToPtr([*]const Buffer, self.buffers)[0..@intCast(usize, self.buffers_len)]; - } - - pub fn getControl(self: Self) []const u8 { - if (native_os.tag == .windows) { - return self.control.into(); - } else { - return @intToPtr([*]const u8, self.control)[0..@intCast(usize, self.control_len)]; - } - } - - pub fn getFlags(self: Self) u32 { - return @intCast(u32, self.flags); - } - }; - } - - /// POSIX `linger`, denoting the linger settings of a socket. - /// - /// Microsoft's documentation and glibc denote the fields to be unsigned - /// short's on Windows, whereas glibc and musl denote the fields to be - /// int's on every other platform. - pub const Linger = extern struct { - pub const Field = switch (native_os.tag) { - .windows => c_ushort, - else => c_int, - }; - - enabled: Field, - timeout_seconds: Field, - - pub fn init(timeout_seconds: ?u16) Socket.Linger { - return .{ - .enabled = @intCast(Socket.Linger.Field, @boolToInt(timeout_seconds != null)), - .timeout_seconds = if (timeout_seconds) |seconds| @intCast(Socket.Linger.Field, seconds) else 0, - }; - } - }; - - /// Possible set of flags to initialize a socket with. - pub const InitFlags = enum { - // Initialize a socket to be non-blocking. - nonblocking, - - // Have a socket close itself on exec syscalls. - close_on_exec, - }; - - /// The underlying handle of a socket. - fd: os.socket_t, - - /// Enclose a socket abstraction over an existing socket file descriptor. - pub fn from(fd: os.socket_t) Socket { - return Socket{ .fd = fd }; - } - - /// Mix in socket syscalls depending on the platform we are compiling against. - pub usingnamespace switch (native_os.tag) { - .windows => @import("socket_windows.zig"), - else => @import("socket_posix.zig"), - }.Mixin(Socket); -}; diff --git a/lib/std/x/os/socket_posix.zig b/lib/std/x/os/socket_posix.zig deleted file mode 100644 index 859075aa20..0000000000 --- a/lib/std/x/os/socket_posix.zig +++ /dev/null @@ -1,275 +0,0 @@ -const std = @import("../../std.zig"); - -const os = std.os; -const mem = std.mem; -const time = std.time; - -pub fn Mixin(comptime Socket: type) type { - return struct { - /// Open a new socket. - pub fn init(domain: u32, socket_type: u32, protocol: u32, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Socket { - var raw_flags: u32 = socket_type; - const set = std.EnumSet(Socket.InitFlags).init(flags); - if (set.contains(.close_on_exec)) raw_flags |= os.SOCK.CLOEXEC; - if (set.contains(.nonblocking)) raw_flags |= os.SOCK.NONBLOCK; - return Socket{ .fd = try os.socket(domain, raw_flags, protocol) }; - } - - /// Closes the socket. - pub fn deinit(self: Socket) void { - os.closeSocket(self.fd); - } - - /// Shutdown either the read side, write side, or all side of the socket. 
- pub fn shutdown(self: Socket, how: os.ShutdownHow) !void { - return os.shutdown(self.fd, how); - } - - /// Binds the socket to an address. - pub fn bind(self: Socket, address: Socket.Address) !void { - return os.bind(self.fd, @ptrCast(*const os.sockaddr, &address.toNative()), address.getNativeSize()); - } - - /// Start listening for incoming connections on the socket. - pub fn listen(self: Socket, max_backlog_size: u31) !void { - return os.listen(self.fd, max_backlog_size); - } - - /// Have the socket attempt to the connect to an address. - pub fn connect(self: Socket, address: Socket.Address) !void { - return os.connect(self.fd, @ptrCast(*const os.sockaddr, &address.toNative()), address.getNativeSize()); - } - - /// Accept a pending incoming connection queued to the kernel backlog - /// of the socket. - pub fn accept(self: Socket, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Socket.Connection { - var address: Socket.Address.Native.Storage = undefined; - var address_len: u32 = @sizeOf(Socket.Address.Native.Storage); - - var raw_flags: u32 = 0; - const set = std.EnumSet(Socket.InitFlags).init(flags); - if (set.contains(.close_on_exec)) raw_flags |= os.SOCK.CLOEXEC; - if (set.contains(.nonblocking)) raw_flags |= os.SOCK.NONBLOCK; - - const socket = Socket{ .fd = try os.accept(self.fd, @ptrCast(*os.sockaddr, &address), &address_len, raw_flags) }; - const socket_address = Socket.Address.fromNative(@ptrCast(*os.sockaddr, &address)); - - return Socket.Connection.from(socket, socket_address); - } - - /// Read data from the socket into the buffer provided with a set of flags - /// specified. It returns the number of bytes read into the buffer provided. - pub fn read(self: Socket, buf: []u8, flags: u32) !usize { - return os.recv(self.fd, buf, flags); - } - - /// Write a buffer of data provided to the socket with a set of flags specified. - /// It returns the number of bytes that are written to the socket. - pub fn write(self: Socket, buf: []const u8, flags: u32) !usize { - return os.send(self.fd, buf, flags); - } - - /// Writes multiple I/O vectors with a prepended message header to the socket - /// with a set of flags specified. It returns the number of bytes that are - /// written to the socket. - pub fn writeMessage(self: Socket, msg: Socket.Message, flags: u32) !usize { - while (true) { - const rc = os.system.sendmsg(self.fd, &msg, @intCast(c_int, flags)); - return switch (os.errno(rc)) { - .SUCCESS => return @intCast(usize, rc), - .ACCES => error.AccessDenied, - .AGAIN => error.WouldBlock, - .ALREADY => error.FastOpenAlreadyInProgress, - .BADF => unreachable, // always a race condition - .CONNRESET => error.ConnectionResetByPeer, - .DESTADDRREQ => unreachable, // The socket is not connection-mode, and no peer address is set. - .FAULT => unreachable, // An invalid user space address was specified for an argument. - .INTR => continue, - .INVAL => unreachable, // Invalid argument passed. - .ISCONN => unreachable, // connection-mode socket was connected already but a recipient was specified - .MSGSIZE => error.MessageTooBig, - .NOBUFS => error.SystemResources, - .NOMEM => error.SystemResources, - .NOTSOCK => unreachable, // The file descriptor sockfd does not refer to a socket. - .OPNOTSUPP => unreachable, // Some bit in the flags argument is inappropriate for the socket type. 
- .PIPE => error.BrokenPipe, - .AFNOSUPPORT => error.AddressFamilyNotSupported, - .LOOP => error.SymLinkLoop, - .NAMETOOLONG => error.NameTooLong, - .NOENT => error.FileNotFound, - .NOTDIR => error.NotDir, - .HOSTUNREACH => error.NetworkUnreachable, - .NETUNREACH => error.NetworkUnreachable, - .NOTCONN => error.SocketNotConnected, - .NETDOWN => error.NetworkSubsystemFailed, - else => |err| os.unexpectedErrno(err), - }; - } - } - - /// Read multiple I/O vectors with a prepended message header from the socket - /// with a set of flags specified. It returns the number of bytes that were - /// read into the buffer provided. - pub fn readMessage(self: Socket, msg: *Socket.Message, flags: u32) !usize { - while (true) { - const rc = os.system.recvmsg(self.fd, msg, @intCast(c_int, flags)); - return switch (os.errno(rc)) { - .SUCCESS => @intCast(usize, rc), - .BADF => unreachable, // always a race condition - .FAULT => unreachable, - .INVAL => unreachable, - .NOTCONN => unreachable, - .NOTSOCK => unreachable, - .INTR => continue, - .AGAIN => error.WouldBlock, - .NOMEM => error.SystemResources, - .CONNREFUSED => error.ConnectionRefused, - .CONNRESET => error.ConnectionResetByPeer, - else => |err| os.unexpectedErrno(err), - }; - } - } - - /// Query the address that the socket is locally bounded to. - pub fn getLocalAddress(self: Socket) !Socket.Address { - var address: Socket.Address.Native.Storage = undefined; - var address_len: u32 = @sizeOf(Socket.Address.Native.Storage); - try os.getsockname(self.fd, @ptrCast(*os.sockaddr, &address), &address_len); - return Socket.Address.fromNative(@ptrCast(*os.sockaddr, &address)); - } - - /// Query the address that the socket is connected to. - pub fn getRemoteAddress(self: Socket) !Socket.Address { - var address: Socket.Address.Native.Storage = undefined; - var address_len: u32 = @sizeOf(Socket.Address.Native.Storage); - try os.getpeername(self.fd, @ptrCast(*os.sockaddr, &address), &address_len); - return Socket.Address.fromNative(@ptrCast(*os.sockaddr, &address)); - } - - /// Query and return the latest cached error on the socket. - pub fn getError(self: Socket) !void { - return os.getsockoptError(self.fd); - } - - /// Query the read buffer size of the socket. - pub fn getReadBufferSize(self: Socket) !u32 { - var value: u32 = undefined; - var value_len: u32 = @sizeOf(u32); - - const rc = os.system.getsockopt(self.fd, os.SOL.SOCKET, os.SO.RCVBUF, mem.asBytes(&value), &value_len); - return switch (os.errno(rc)) { - .SUCCESS => value, - .BADF => error.BadFileDescriptor, - .FAULT => error.InvalidAddressSpace, - .INVAL => error.InvalidSocketOption, - .NOPROTOOPT => error.UnknownSocketOption, - .NOTSOCK => error.NotASocket, - else => |err| os.unexpectedErrno(err), - }; - } - - /// Query the write buffer size of the socket. - pub fn getWriteBufferSize(self: Socket) !u32 { - var value: u32 = undefined; - var value_len: u32 = @sizeOf(u32); - - const rc = os.system.getsockopt(self.fd, os.SOL.SOCKET, os.SO.SNDBUF, mem.asBytes(&value), &value_len); - return switch (os.errno(rc)) { - .SUCCESS => value, - .BADF => error.BadFileDescriptor, - .FAULT => error.InvalidAddressSpace, - .INVAL => error.InvalidSocketOption, - .NOPROTOOPT => error.UnknownSocketOption, - .NOTSOCK => error.NotASocket, - else => |err| os.unexpectedErrno(err), - }; - } - - /// Set a socket option. 
- pub fn setOption(self: Socket, level: u32, code: u32, value: []const u8) !void { - return os.setsockopt(self.fd, level, code, value); - } - - /// Have close() or shutdown() syscalls block until all queued messages in the socket have been successfully - /// sent, or if the timeout specified in seconds has been reached. It returns `error.UnsupportedSocketOption` - /// if the host does not support the option for a socket to linger around up until a timeout specified in - /// seconds. - pub fn setLinger(self: Socket, timeout_seconds: ?u16) !void { - if (@hasDecl(os.SO, "LINGER")) { - const settings = Socket.Linger.init(timeout_seconds); - return self.setOption(os.SOL.SOCKET, os.SO.LINGER, mem.asBytes(&settings)); - } - - return error.UnsupportedSocketOption; - } - - /// On connection-oriented sockets, have keep-alive messages be sent periodically. The timing in which keep-alive - /// messages are sent are dependant on operating system settings. It returns `error.UnsupportedSocketOption` if - /// the host does not support periodically sending keep-alive messages on connection-oriented sockets. - pub fn setKeepAlive(self: Socket, enabled: bool) !void { - if (@hasDecl(os.SO, "KEEPALIVE")) { - return self.setOption(os.SOL.SOCKET, os.SO.KEEPALIVE, mem.asBytes(&@as(u32, @boolToInt(enabled)))); - } - return error.UnsupportedSocketOption; - } - - /// Allow multiple sockets on the same host to listen on the same address. It returns `error.UnsupportedSocketOption` if - /// the host does not support sockets listening the same address. - pub fn setReuseAddress(self: Socket, enabled: bool) !void { - if (@hasDecl(os.SO, "REUSEADDR")) { - return self.setOption(os.SOL.SOCKET, os.SO.REUSEADDR, mem.asBytes(&@as(u32, @boolToInt(enabled)))); - } - return error.UnsupportedSocketOption; - } - - /// Allow multiple sockets on the same host to listen on the same port. It returns `error.UnsupportedSocketOption` if - /// the host does not supports sockets listening on the same port. - pub fn setReusePort(self: Socket, enabled: bool) !void { - if (@hasDecl(os.SO, "REUSEPORT")) { - return self.setOption(os.SOL.SOCKET, os.SO.REUSEPORT, mem.asBytes(&@as(u32, @boolToInt(enabled)))); - } - return error.UnsupportedSocketOption; - } - - /// Set the write buffer size of the socket. - pub fn setWriteBufferSize(self: Socket, size: u32) !void { - return self.setOption(os.SOL.SOCKET, os.SO.SNDBUF, mem.asBytes(&size)); - } - - /// Set the read buffer size of the socket. - pub fn setReadBufferSize(self: Socket, size: u32) !void { - return self.setOption(os.SOL.SOCKET, os.SO.RCVBUF, mem.asBytes(&size)); - } - - /// WARNING: Timeouts only affect blocking sockets. It is undefined behavior if a timeout is - /// set on a non-blocking socket. - /// - /// Set a timeout on the socket that is to occur if no messages are successfully written - /// to its bound destination after a specified number of milliseconds. A subsequent write - /// to the socket will thereafter return `error.WouldBlock` should the timeout be exceeded. - pub fn setWriteTimeout(self: Socket, milliseconds: usize) !void { - const timeout = os.timeval{ - .tv_sec = @intCast(i32, milliseconds / time.ms_per_s), - .tv_usec = @intCast(i32, (milliseconds % time.ms_per_s) * time.us_per_ms), - }; - - return self.setOption(os.SOL.SOCKET, os.SO.SNDTIMEO, mem.asBytes(&timeout)); - } - - /// WARNING: Timeouts only affect blocking sockets. It is undefined behavior if a timeout is - /// set on a non-blocking socket. 
- /// - /// Set a timeout on the socket that is to occur if no messages are successfully read - /// from its bound destination after a specified number of milliseconds. A subsequent - /// read from the socket will thereafter return `error.WouldBlock` should the timeout be - /// exceeded. - pub fn setReadTimeout(self: Socket, milliseconds: usize) !void { - const timeout = os.timeval{ - .tv_sec = @intCast(i32, milliseconds / time.ms_per_s), - .tv_usec = @intCast(i32, (milliseconds % time.ms_per_s) * time.us_per_ms), - }; - - return self.setOption(os.SOL.SOCKET, os.SO.RCVTIMEO, mem.asBytes(&timeout)); - } - }; -} diff --git a/lib/std/x/os/socket_windows.zig b/lib/std/x/os/socket_windows.zig deleted file mode 100644 index 43b047dd10..0000000000 --- a/lib/std/x/os/socket_windows.zig +++ /dev/null @@ -1,458 +0,0 @@ -const std = @import("../../std.zig"); -const net = @import("net.zig"); - -const os = std.os; -const mem = std.mem; - -const windows = std.os.windows; -const ws2_32 = windows.ws2_32; - -pub fn Mixin(comptime Socket: type) type { - return struct { - /// Open a new socket. - pub fn init(domain: u32, socket_type: u32, protocol: u32, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Socket { - var raw_flags: u32 = ws2_32.WSA_FLAG_OVERLAPPED; - const set = std.EnumSet(Socket.InitFlags).init(flags); - if (set.contains(.close_on_exec)) raw_flags |= ws2_32.WSA_FLAG_NO_HANDLE_INHERIT; - - const fd = ws2_32.WSASocketW( - @intCast(i32, domain), - @intCast(i32, socket_type), - @intCast(i32, protocol), - null, - 0, - raw_flags, - ); - if (fd == ws2_32.INVALID_SOCKET) { - return switch (ws2_32.WSAGetLastError()) { - .WSANOTINITIALISED => { - _ = try windows.WSAStartup(2, 2); - return init(domain, socket_type, protocol, flags); - }, - .WSAEAFNOSUPPORT => error.AddressFamilyNotSupported, - .WSAEMFILE => error.ProcessFdQuotaExceeded, - .WSAENOBUFS => error.SystemResources, - .WSAEPROTONOSUPPORT => error.ProtocolNotSupported, - else => |err| windows.unexpectedWSAError(err), - }; - } - - if (set.contains(.nonblocking)) { - var enabled: c_ulong = 1; - const rc = ws2_32.ioctlsocket(fd, ws2_32.FIONBIO, &enabled); - if (rc == ws2_32.SOCKET_ERROR) { - return windows.unexpectedWSAError(ws2_32.WSAGetLastError()); - } - } - - return Socket{ .fd = fd }; - } - - /// Closes the socket. - pub fn deinit(self: Socket) void { - _ = ws2_32.closesocket(self.fd); - } - - /// Shutdown either the read side, write side, or all side of the socket. - pub fn shutdown(self: Socket, how: os.ShutdownHow) !void { - const rc = ws2_32.shutdown(self.fd, switch (how) { - .recv => ws2_32.SD_RECEIVE, - .send => ws2_32.SD_SEND, - .both => ws2_32.SD_BOTH, - }); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSAECONNABORTED => return error.ConnectionAborted, - .WSAECONNRESET => return error.ConnectionResetByPeer, - .WSAEINPROGRESS => return error.BlockingOperationInProgress, - .WSAEINVAL => unreachable, - .WSAENETDOWN => return error.NetworkSubsystemFailed, - .WSAENOTCONN => return error.SocketNotConnected, - .WSAENOTSOCK => unreachable, - .WSANOTINITIALISED => unreachable, - else => |err| return windows.unexpectedWSAError(err), - }; - } - } - - /// Binds the socket to an address. 
- pub fn bind(self: Socket, address: Socket.Address) !void { - const rc = ws2_32.bind(self.fd, @ptrCast(*const ws2_32.sockaddr, &address.toNative()), @intCast(c_int, address.getNativeSize())); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSAENETDOWN => error.NetworkSubsystemFailed, - .WSAEACCES => error.AccessDenied, - .WSAEADDRINUSE => error.AddressInUse, - .WSAEADDRNOTAVAIL => error.AddressNotAvailable, - .WSAEFAULT => error.BadAddress, - .WSAEINPROGRESS => error.WouldBlock, - .WSAEINVAL => error.AlreadyBound, - .WSAENOBUFS => error.NoEphemeralPortsAvailable, - .WSAENOTSOCK => error.NotASocket, - else => |err| windows.unexpectedWSAError(err), - }; - } - } - - /// Start listening for incoming connections on the socket. - pub fn listen(self: Socket, max_backlog_size: u31) !void { - const rc = ws2_32.listen(self.fd, max_backlog_size); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSAENETDOWN => error.NetworkSubsystemFailed, - .WSAEADDRINUSE => error.AddressInUse, - .WSAEISCONN => error.AlreadyConnected, - .WSAEINVAL => error.SocketNotBound, - .WSAEMFILE, .WSAENOBUFS => error.SystemResources, - .WSAENOTSOCK => error.FileDescriptorNotASocket, - .WSAEOPNOTSUPP => error.OperationNotSupported, - .WSAEINPROGRESS => error.WouldBlock, - else => |err| windows.unexpectedWSAError(err), - }; - } - } - - /// Have the socket attempt to the connect to an address. - pub fn connect(self: Socket, address: Socket.Address) !void { - const rc = ws2_32.connect(self.fd, @ptrCast(*const ws2_32.sockaddr, &address.toNative()), @intCast(c_int, address.getNativeSize())); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSAEADDRINUSE => error.AddressInUse, - .WSAEADDRNOTAVAIL => error.AddressNotAvailable, - .WSAECONNREFUSED => error.ConnectionRefused, - .WSAETIMEDOUT => error.ConnectionTimedOut, - .WSAEFAULT => error.BadAddress, - .WSAEINVAL => error.ListeningSocket, - .WSAEISCONN => error.AlreadyConnected, - .WSAENOTSOCK => error.NotASocket, - .WSAEACCES => error.BroadcastNotEnabled, - .WSAENOBUFS => error.SystemResources, - .WSAEAFNOSUPPORT => error.AddressFamilyNotSupported, - .WSAEINPROGRESS, .WSAEWOULDBLOCK => error.WouldBlock, - .WSAEHOSTUNREACH, .WSAENETUNREACH => error.NetworkUnreachable, - else => |err| windows.unexpectedWSAError(err), - }; - } - } - - /// Accept a pending incoming connection queued to the kernel backlog - /// of the socket. 
- pub fn accept(self: Socket, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Socket.Connection { - var address: Socket.Address.Native.Storage = undefined; - var address_len: c_int = @sizeOf(Socket.Address.Native.Storage); - - const fd = ws2_32.accept(self.fd, @ptrCast(*ws2_32.sockaddr, &address), &address_len); - if (fd == ws2_32.INVALID_SOCKET) { - return switch (ws2_32.WSAGetLastError()) { - .WSANOTINITIALISED => unreachable, - .WSAECONNRESET => error.ConnectionResetByPeer, - .WSAEFAULT => unreachable, - .WSAEINVAL => error.SocketNotListening, - .WSAEMFILE => error.ProcessFdQuotaExceeded, - .WSAENETDOWN => error.NetworkSubsystemFailed, - .WSAENOBUFS => error.FileDescriptorNotASocket, - .WSAEOPNOTSUPP => error.OperationNotSupported, - .WSAEWOULDBLOCK => error.WouldBlock, - else => |err| windows.unexpectedWSAError(err), - }; - } - - const socket = Socket.from(fd); - errdefer socket.deinit(); - - const socket_address = Socket.Address.fromNative(@ptrCast(*ws2_32.sockaddr, &address)); - - const set = std.EnumSet(Socket.InitFlags).init(flags); - if (set.contains(.nonblocking)) { - var enabled: c_ulong = 1; - const rc = ws2_32.ioctlsocket(fd, ws2_32.FIONBIO, &enabled); - if (rc == ws2_32.SOCKET_ERROR) { - return windows.unexpectedWSAError(ws2_32.WSAGetLastError()); - } - } - - return Socket.Connection.from(socket, socket_address); - } - - /// Read data from the socket into the buffer provided with a set of flags - /// specified. It returns the number of bytes read into the buffer provided. - pub fn read(self: Socket, buf: []u8, flags: u32) !usize { - var bufs = &[_]ws2_32.WSABUF{.{ .len = @intCast(u32, buf.len), .buf = buf.ptr }}; - var num_bytes: u32 = undefined; - var flags_ = flags; - - const rc = ws2_32.WSARecv(self.fd, bufs, 1, &num_bytes, &flags_, null, null); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSAECONNABORTED => error.ConnectionAborted, - .WSAECONNRESET => error.ConnectionResetByPeer, - .WSAEDISCON => error.ConnectionClosedByPeer, - .WSAEFAULT => error.BadBuffer, - .WSAEINPROGRESS, - .WSAEWOULDBLOCK, - .WSA_IO_PENDING, - .WSAETIMEDOUT, - => error.WouldBlock, - .WSAEINTR => error.Cancelled, - .WSAEINVAL => error.SocketNotBound, - .WSAEMSGSIZE => error.MessageTooLarge, - .WSAENETDOWN => error.NetworkSubsystemFailed, - .WSAENETRESET => error.NetworkReset, - .WSAENOTCONN => error.SocketNotConnected, - .WSAENOTSOCK => error.FileDescriptorNotASocket, - .WSAEOPNOTSUPP => error.OperationNotSupported, - .WSAESHUTDOWN => error.AlreadyShutdown, - .WSA_OPERATION_ABORTED => error.OperationAborted, - else => |err| windows.unexpectedWSAError(err), - }; - } - - return @intCast(usize, num_bytes); - } - - /// Write a buffer of data provided to the socket with a set of flags specified. - /// It returns the number of bytes that are written to the socket. 
- pub fn write(self: Socket, buf: []const u8, flags: u32) !usize { - var bufs = &[_]ws2_32.WSABUF{.{ .len = @intCast(u32, buf.len), .buf = @intToPtr([*]u8, @ptrToInt(buf.ptr)) }}; - var num_bytes: u32 = undefined; - - const rc = ws2_32.WSASend(self.fd, bufs, 1, &num_bytes, flags, null, null); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSAECONNABORTED => error.ConnectionAborted, - .WSAECONNRESET => error.ConnectionResetByPeer, - .WSAEFAULT => error.BadBuffer, - .WSAEINPROGRESS, - .WSAEWOULDBLOCK, - .WSA_IO_PENDING, - .WSAETIMEDOUT, - => error.WouldBlock, - .WSAEINTR => error.Cancelled, - .WSAEINVAL => error.SocketNotBound, - .WSAEMSGSIZE => error.MessageTooLarge, - .WSAENETDOWN => error.NetworkSubsystemFailed, - .WSAENETRESET => error.NetworkReset, - .WSAENOBUFS => error.BufferDeadlock, - .WSAENOTCONN => error.SocketNotConnected, - .WSAENOTSOCK => error.FileDescriptorNotASocket, - .WSAEOPNOTSUPP => error.OperationNotSupported, - .WSAESHUTDOWN => error.AlreadyShutdown, - .WSA_OPERATION_ABORTED => error.OperationAborted, - else => |err| windows.unexpectedWSAError(err), - }; - } - - return @intCast(usize, num_bytes); - } - - /// Writes multiple I/O vectors with a prepended message header to the socket - /// with a set of flags specified. It returns the number of bytes that are - /// written to the socket. - pub fn writeMessage(self: Socket, msg: Socket.Message, flags: u32) !usize { - const call = try windows.loadWinsockExtensionFunction(ws2_32.LPFN_WSASENDMSG, self.fd, ws2_32.WSAID_WSASENDMSG); - - var num_bytes: u32 = undefined; - - const rc = call(self.fd, &msg, flags, &num_bytes, null, null); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSAECONNABORTED => error.ConnectionAborted, - .WSAECONNRESET => error.ConnectionResetByPeer, - .WSAEFAULT => error.BadBuffer, - .WSAEINPROGRESS, - .WSAEWOULDBLOCK, - .WSA_IO_PENDING, - .WSAETIMEDOUT, - => error.WouldBlock, - .WSAEINTR => error.Cancelled, - .WSAEINVAL => error.SocketNotBound, - .WSAEMSGSIZE => error.MessageTooLarge, - .WSAENETDOWN => error.NetworkSubsystemFailed, - .WSAENETRESET => error.NetworkReset, - .WSAENOBUFS => error.BufferDeadlock, - .WSAENOTCONN => error.SocketNotConnected, - .WSAENOTSOCK => error.FileDescriptorNotASocket, - .WSAEOPNOTSUPP => error.OperationNotSupported, - .WSAESHUTDOWN => error.AlreadyShutdown, - .WSA_OPERATION_ABORTED => error.OperationAborted, - else => |err| windows.unexpectedWSAError(err), - }; - } - - return @intCast(usize, num_bytes); - } - - /// Read multiple I/O vectors with a prepended message header from the socket - /// with a set of flags specified. It returns the number of bytes that were - /// read into the buffer provided. 
- pub fn readMessage(self: Socket, msg: *Socket.Message, flags: u32) !usize { - _ = flags; - const call = try windows.loadWinsockExtensionFunction(ws2_32.LPFN_WSARECVMSG, self.fd, ws2_32.WSAID_WSARECVMSG); - - var num_bytes: u32 = undefined; - - const rc = call(self.fd, msg, &num_bytes, null, null); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSAECONNABORTED => error.ConnectionAborted, - .WSAECONNRESET => error.ConnectionResetByPeer, - .WSAEDISCON => error.ConnectionClosedByPeer, - .WSAEFAULT => error.BadBuffer, - .WSAEINPROGRESS, - .WSAEWOULDBLOCK, - .WSA_IO_PENDING, - .WSAETIMEDOUT, - => error.WouldBlock, - .WSAEINTR => error.Cancelled, - .WSAEINVAL => error.SocketNotBound, - .WSAEMSGSIZE => error.MessageTooLarge, - .WSAENETDOWN => error.NetworkSubsystemFailed, - .WSAENETRESET => error.NetworkReset, - .WSAENOTCONN => error.SocketNotConnected, - .WSAENOTSOCK => error.FileDescriptorNotASocket, - .WSAEOPNOTSUPP => error.OperationNotSupported, - .WSAESHUTDOWN => error.AlreadyShutdown, - .WSA_OPERATION_ABORTED => error.OperationAborted, - else => |err| windows.unexpectedWSAError(err), - }; - } - - return @intCast(usize, num_bytes); - } - - /// Query the address that the socket is locally bounded to. - pub fn getLocalAddress(self: Socket) !Socket.Address { - var address: Socket.Address.Native.Storage = undefined; - var address_len: c_int = @sizeOf(Socket.Address.Native.Storage); - - const rc = ws2_32.getsockname(self.fd, @ptrCast(*ws2_32.sockaddr, &address), &address_len); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSANOTINITIALISED => unreachable, - .WSAEFAULT => unreachable, - .WSAENETDOWN => error.NetworkSubsystemFailed, - .WSAENOTSOCK => error.FileDescriptorNotASocket, - .WSAEINVAL => error.SocketNotBound, - else => |err| windows.unexpectedWSAError(err), - }; - } - - return Socket.Address.fromNative(@ptrCast(*ws2_32.sockaddr, &address)); - } - - /// Query the address that the socket is connected to. - pub fn getRemoteAddress(self: Socket) !Socket.Address { - var address: Socket.Address.Native.Storage = undefined; - var address_len: c_int = @sizeOf(Socket.Address.Native.Storage); - - const rc = ws2_32.getpeername(self.fd, @ptrCast(*ws2_32.sockaddr, &address), &address_len); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSANOTINITIALISED => unreachable, - .WSAEFAULT => unreachable, - .WSAENETDOWN => error.NetworkSubsystemFailed, - .WSAENOTSOCK => error.FileDescriptorNotASocket, - .WSAEINVAL => error.SocketNotBound, - else => |err| windows.unexpectedWSAError(err), - }; - } - - return Socket.Address.fromNative(@ptrCast(*ws2_32.sockaddr, &address)); - } - - /// Query and return the latest cached error on the socket. - pub fn getError(self: Socket) !void { - _ = self; - return {}; - } - - /// Query the read buffer size of the socket. - pub fn getReadBufferSize(self: Socket) !u32 { - _ = self; - return 0; - } - - /// Query the write buffer size of the socket. - pub fn getWriteBufferSize(self: Socket) !u32 { - _ = self; - return 0; - } - - /// Set a socket option. 
- pub fn setOption(self: Socket, level: u32, code: u32, value: []const u8) !void { - const rc = ws2_32.setsockopt(self.fd, @intCast(i32, level), @intCast(i32, code), value.ptr, @intCast(i32, value.len)); - if (rc == ws2_32.SOCKET_ERROR) { - return switch (ws2_32.WSAGetLastError()) { - .WSANOTINITIALISED => unreachable, - .WSAENETDOWN => return error.NetworkSubsystemFailed, - .WSAEFAULT => unreachable, - .WSAENOTSOCK => return error.FileDescriptorNotASocket, - .WSAEINVAL => return error.SocketNotBound, - else => |err| windows.unexpectedWSAError(err), - }; - } - } - - /// Have close() or shutdown() syscalls block until all queued messages in the socket have been successfully - /// sent, or if the timeout specified in seconds has been reached. It returns `error.UnsupportedSocketOption` - /// if the host does not support the option for a socket to linger around up until a timeout specified in - /// seconds. - pub fn setLinger(self: Socket, timeout_seconds: ?u16) !void { - const settings = Socket.Linger.init(timeout_seconds); - return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.LINGER, mem.asBytes(&settings)); - } - - /// On connection-oriented sockets, have keep-alive messages be sent periodically. The timing in which keep-alive - /// messages are sent are dependant on operating system settings. It returns `error.UnsupportedSocketOption` if - /// the host does not support periodically sending keep-alive messages on connection-oriented sockets. - pub fn setKeepAlive(self: Socket, enabled: bool) !void { - return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.KEEPALIVE, mem.asBytes(&@as(u32, @boolToInt(enabled)))); - } - - /// Allow multiple sockets on the same host to listen on the same address. It returns `error.UnsupportedSocketOption` if - /// the host does not support sockets listening the same address. - pub fn setReuseAddress(self: Socket, enabled: bool) !void { - return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.REUSEADDR, mem.asBytes(&@as(u32, @boolToInt(enabled)))); - } - - /// Allow multiple sockets on the same host to listen on the same port. It returns `error.UnsupportedSocketOption` if - /// the host does not supports sockets listening on the same port. - /// - /// TODO: verify if this truly mimicks SO.REUSEPORT behavior, or if SO.REUSE_UNICASTPORT provides the correct behavior - pub fn setReusePort(self: Socket, enabled: bool) !void { - try self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.BROADCAST, mem.asBytes(&@as(u32, @boolToInt(enabled)))); - try self.setReuseAddress(enabled); - } - - /// Set the write buffer size of the socket. - pub fn setWriteBufferSize(self: Socket, size: u32) !void { - return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.SNDBUF, mem.asBytes(&size)); - } - - /// Set the read buffer size of the socket. - pub fn setReadBufferSize(self: Socket, size: u32) !void { - return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.RCVBUF, mem.asBytes(&size)); - } - - /// WARNING: Timeouts only affect blocking sockets. It is undefined behavior if a timeout is - /// set on a non-blocking socket. - /// - /// Set a timeout on the socket that is to occur if no messages are successfully written - /// to its bound destination after a specified number of milliseconds. A subsequent write - /// to the socket will thereafter return `error.WouldBlock` should the timeout be exceeded. - pub fn setWriteTimeout(self: Socket, milliseconds: u32) !void { - return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.SNDTIMEO, mem.asBytes(&milliseconds)); - } - - /// WARNING: Timeouts only affect blocking sockets. 
It is undefined behavior if a timeout is - /// set on a non-blocking socket. - /// - /// Set a timeout on the socket that is to occur if no messages are successfully read - /// from its bound destination after a specified number of milliseconds. A subsequent - /// read from the socket will thereafter return `error.WouldBlock` should the timeout be - /// exceeded. - pub fn setReadTimeout(self: Socket, milliseconds: u32) !void { - return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.RCVTIMEO, mem.asBytes(&milliseconds)); - } - }; -} From ba44513c2fe363b55b2c534be98179286b832b7e Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 12 Dec 2022 21:18:56 -0700 Subject: [PATCH 03/59] std.http reorg; introduce std.crypto.Tls TLS is capable of sending a Client Hello --- lib/std/crypto.zig | 2 + lib/std/crypto/Tls.zig | 342 ++++++++++++++++++++++++++++++++++++++++ lib/std/http.zig | 251 ++++++++++++++++++++++++++++- lib/std/http/Client.zig | 114 ++++++++++++++ lib/std/http/method.zig | 65 -------- lib/std/http/status.zig | 182 --------------------- lib/std/net.zig | 7 + 7 files changed, 712 insertions(+), 251 deletions(-) create mode 100644 lib/std/crypto/Tls.zig create mode 100644 lib/std/http/Client.zig delete mode 100644 lib/std/http/method.zig delete mode 100644 lib/std/http/status.zig diff --git a/lib/std/crypto.zig b/lib/std/crypto.zig index 8aaf305143..7b4a116d35 100644 --- a/lib/std/crypto.zig +++ b/lib/std/crypto.zig @@ -176,6 +176,8 @@ const std = @import("std.zig"); pub const errors = @import("crypto/errors.zig"); +pub const Tls = @import("crypto/Tls.zig"); + test { _ = aead.aegis.Aegis128L; _ = aead.aegis.Aegis256; diff --git a/lib/std/crypto/Tls.zig b/lib/std/crypto/Tls.zig new file mode 100644 index 0000000000..ab54b42b70 --- /dev/null +++ b/lib/std/crypto/Tls.zig @@ -0,0 +1,342 @@ +const std = @import("../std.zig"); +const Tls = @This(); +const net = std.net; +const mem = std.mem; +const crypto = std.crypto; +const assert = std.debug.assert; + +state: State = .start, +x25519_priv_key: [32]u8 = undefined, +x25519_pub_key: [32]u8 = undefined, + +const State = enum { + /// In this state, all fields are undefined except state. 
+ start, + sent_hello, +}; + +const ContentType = enum(u8) { + invalid = 0, + change_cipher_spec = 20, + alert = 21, + handshake = 22, + application_data = 23, + _, +}; + +const HandshakeType = enum(u8) { + client_hello = 1, + server_hello = 2, + new_session_ticket = 4, + end_of_early_data = 5, + encrypted_extensions = 8, + certificate = 11, + certificate_request = 13, + certificate_verify = 15, + finished = 20, + key_update = 24, + message_hash = 254, +}; + +const ExtensionType = enum(u16) { + /// RFC 6066 + server_name = 0, + /// RFC 6066 + max_fragment_length = 1, + /// RFC 6066 + status_request = 5, + /// RFC 8422, 7919 + supported_groups = 10, + /// RFC 8446 + signature_algorithms = 13, + /// RFC 5764 + use_srtp = 14, + /// RFC 6520 + heartbeat = 15, + /// RFC 7301 + application_layer_protocol_negotiation = 16, + /// RFC 6962 + signed_certificate_timestamp = 18, + /// RFC 7250 + client_certificate_type = 19, + /// RFC 7250 + server_certificate_type = 20, + /// RFC 7685 + padding = 21, + /// RFC 8446 + pre_shared_key = 41, + /// RFC 8446 + early_data = 42, + /// RFC 8446 + supported_versions = 43, + /// RFC 8446 + cookie = 44, + /// RFC 8446 + psk_key_exchange_modes = 45, + /// RFC 8446 + certificate_authorities = 47, + /// RFC 8446 + oid_filters = 48, + /// RFC 8446 + post_handshake_auth = 49, + /// RFC 8446 + signature_algorithms_cert = 50, + /// RFC 8446 + key_share = 51, +}; + +const AlertLevel = enum(u8) { + warning = 1, + fatal = 2, + _, +}; + +const AlertDescription = enum(u8) { + close_notify = 0, + unexpected_message = 10, + bad_record_mac = 20, + record_overflow = 22, + handshake_failure = 40, + bad_certificate = 42, + unsupported_certificate = 43, + certificate_revoked = 44, + certificate_expired = 45, + certificate_unknown = 46, + illegal_parameter = 47, + unknown_ca = 48, + access_denied = 49, + decode_error = 50, + decrypt_error = 51, + protocol_version = 70, + insufficient_security = 71, + internal_error = 80, + inappropriate_fallback = 86, + user_canceled = 90, + missing_extension = 109, + unsupported_extension = 110, + unrecognized_name = 112, + bad_certificate_status_response = 113, + unknown_psk_identity = 115, + certificate_required = 116, + no_application_protocol = 120, + _, +}; + +const SignatureScheme = enum(u16) { + // RSASSA-PKCS1-v1_5 algorithms + rsa_pkcs1_sha256 = 0x0401, + rsa_pkcs1_sha384 = 0x0501, + rsa_pkcs1_sha512 = 0x0601, + + // ECDSA algorithms + ecdsa_secp256r1_sha256 = 0x0403, + ecdsa_secp384r1_sha384 = 0x0503, + ecdsa_secp521r1_sha512 = 0x0603, + + // RSASSA-PSS algorithms with public key OID rsaEncryption + rsa_pss_rsae_sha256 = 0x0804, + rsa_pss_rsae_sha384 = 0x0805, + rsa_pss_rsae_sha512 = 0x0806, + + // EdDSA algorithms + ed25519 = 0x0807, + ed448 = 0x0808, + + // RSASSA-PSS algorithms with public key OID RSASSA-PSS + rsa_pss_pss_sha256 = 0x0809, + rsa_pss_pss_sha384 = 0x080a, + rsa_pss_pss_sha512 = 0x080b, + + // Legacy algorithms + rsa_pkcs1_sha1 = 0x0201, + ecdsa_sha1 = 0x0203, + + _, +}; + +const NamedGroup = enum(u16) { + // Elliptic Curve Groups (ECDHE) + secp256r1 = 0x0017, + secp384r1 = 0x0018, + secp521r1 = 0x0019, + x25519 = 0x001D, + x448 = 0x001E, + + // Finite Field Groups (DHE) + ffdhe2048 = 0x0100, + ffdhe3072 = 0x0101, + ffdhe4096 = 0x0102, + ffdhe6144 = 0x0103, + ffdhe8192 = 0x0104, + + _, +}; + +// Plaintext: +// * type: ContentType +// * legacy_record_version: u16 = 0x0303, +// * length: u16, +// - The length (in bytes) of the following TLSPlaintext.fragment. The +// length MUST NOT exceed 2^14 bytes. 
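+//     (i.e. at most 16384 bytes; this length counts only the fragment, not
+//     the 5-byte record header itself)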
+// * fragment: opaque +// - the data being transmitted + +// Handshake: +// * type: HandshakeType +// * length: u24 +// * data: opaque + +const CipherSuite = enum(u16) { + TLS_AES_128_GCM_SHA256 = 0x1301, + TLS_AES_256_GCM_SHA384 = 0x1302, + TLS_CHACHA20_POLY1305_SHA256 = 0x1303, + TLS_AES_128_CCM_SHA256 = 0x1304, + TLS_AES_128_CCM_8_SHA256 = 0x1305, +}; + +const cipher_suites = blk: { + const fields = @typeInfo(CipherSuite).Enum.fields; + var result: [(fields.len + 1) * 2]u8 = undefined; + mem.writeIntBig(u16, result[0..2], result.len - 2); + for (fields) |field, i| { + const int = @enumToInt(@field(CipherSuite, field.name)); + result[(i + 1) * 2] = @truncate(u8, int >> 8); + result[(i + 1) * 2 + 1] = @truncate(u8, int); + } + break :blk result; +}; + +pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { + assert(tls.state == .start); + crypto.random.bytes(&tls.x25519_priv_key); + tls.x25519_pub_key = try crypto.dh.X25519.recoverPublicKey(tls.x25519_priv_key); + + // random (u32) + var rand_buf: [32]u8 = undefined; + crypto.random.bytes(&rand_buf); + + const extensions_header = [_]u8{ + // Extensions byte length + undefined, undefined, + + // Extension: supported_versions (only TLS 1.3) + 0, 43, // ExtensionType.supported_versions + 0x00, 0x05, // byte length of this extension payload + 0x04, // byte length of supported versions + 0x03, 0x04, // TLS 1.3 + 0x03, 0x03, // TLS 1.2 + + // Extension: signature_algorithms + 0, 13, // ExtensionType.signature_algorithms + 0x00, 0x22, // byte length of this extension payload + 0x00, 0x20, // byte length of signature algorithms list + 0x04, 0x01, // rsa_pkcs1_sha256 + 0x05, 0x01, // rsa_pkcs1_sha384 + 0x06, 0x01, // rsa_pkcs1_sha512 + 0x04, 0x03, // ecdsa_secp256r1_sha256 + 0x05, 0x03, // ecdsa_secp384r1_sha384 + 0x06, 0x03, // ecdsa_secp521r1_sha512 + 0x08, 0x04, // rsa_pss_rsae_sha256 + 0x08, 0x05, // rsa_pss_rsae_sha384 + 0x08, 0x06, // rsa_pss_rsae_sha512 + 0x08, 0x07, // ed25519 + 0x08, 0x08, // ed448 + 0x08, 0x09, // rsa_pss_pss_sha256 + 0x08, 0x0a, // rsa_pss_pss_sha384 + 0x08, 0x0b, // rsa_pss_pss_sha512 + 0x02, 0x01, // rsa_pkcs1_sha1 + 0x02, 0x03, // ecdsa_sha1 + + // Extension: supported_groups + 0, 10, // ExtensionType.supported_groups + 0x00, 0x0c, // byte length of this extension payload + 0x00, 0x0a, // byte length of supported groups list + 0x00, 0x17, // secp256r1 + 0x00, 0x18, // secp384r1 + 0x00, 0x19, // secp521r1 + 0x00, 0x1D, // x25519 + 0x00, 0x1E, // x448 + + // Extension: key_share + 0, 51, // ExtensionType.key_share + 0x00, 38, // byte length of this extension payload + 0x00, 36, // byte length of client_shares + 0x00, 0x1D, // NamedGroup.x25519 + 0x00, 32, // byte length of key_exchange + } ++ tls.x25519_pub_key ++ [_]u8{ + + // Extension: server_name + 0, 0, // ExtensionType.server_name + undefined, undefined, // byte length of this extension payload + undefined, undefined, // server_name_list byte count + 0x00, // name_type + undefined, undefined, // host name len + }; + + var hello_header = [_]u8{ + // Plaintext header + @enumToInt(ContentType.handshake), + 0x03, 0x01, // legacy_record_version + undefined, undefined, // Plaintext fragment length (u16) + + // Handshake header + @enumToInt(HandshakeType.client_hello), + undefined, undefined, undefined, // handshake length (u24) + + // ClientHello + 0x03, 0x03, // legacy_version + } ++ rand_buf ++ [1]u8{0} ++ cipher_suites ++ [_]u8{ + 0x01, 0x00, // legacy_compression_methods + } ++ extensions_header; + + mem.writeIntBig(u16, hello_header[3..][0..2], 
@intCast(u16, hello_header.len - 5 + host.len)); + mem.writeIntBig(u24, hello_header[6..][0..3], @intCast(u24, hello_header.len - 9 + host.len)); + mem.writeIntBig( + u16, + hello_header[hello_header.len - extensions_header.len ..][0..2], + @intCast(u16, extensions_header.len - 2 + host.len), + ); + mem.writeIntBig(u16, hello_header[hello_header.len - 7 ..][0..2], @intCast(u16, 5 + host.len)); + mem.writeIntBig(u16, hello_header[hello_header.len - 5 ..][0..2], @intCast(u16, 3 + host.len)); + mem.writeIntBig(u16, hello_header[hello_header.len - 2 ..][0..2], @intCast(u16, 0 + host.len)); + + var iovecs = [_]std.os.iovec_const{ + .{ + .iov_base = &hello_header, + .iov_len = hello_header.len, + }, + .{ + .iov_base = host.ptr, + .iov_len = host.len, + }, + }; + try stream.writevAll(&iovecs); + + { + var buf: [1000]u8 = undefined; + const amt = try stream.read(&buf); + const resp = buf[0..amt]; + const ct = @intToEnum(ContentType, resp[0]); + if (ct == .alert) { + //const prot_ver = @bitCast(u16, resp[1..][0..2].*); + const len = std.mem.readIntBig(u16, resp[3..][0..2]); + const alert = resp[5..][0..len]; + const level = @intToEnum(AlertLevel, alert[0]); + const desc = @intToEnum(AlertDescription, alert[1]); + std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) }); + std.process.exit(1); + } else { + std.debug.print("content_type: {s}\n", .{@tagName(ct)}); + std.debug.print("got {d} bytes: {s}\n", .{ amt, std.fmt.fmtSliceHexLower(resp) }); + } + } + + tls.state = .sent_hello; +} + +pub fn writeAll(tls: *Tls, stream: net.Stream, buffer: []const u8) !void { + _ = tls; + _ = stream; + _ = buffer; + @panic("hold on a minute, we didn't finish implementing the handshake yet"); +} diff --git a/lib/std/http.zig b/lib/std/http.zig index 8da6968403..cf92b462b8 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -1,8 +1,251 @@ +pub const Client = @import("http/Client.zig"); + +/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods +/// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definiton +/// https://datatracker.ietf.org/doc/html/rfc5789#section-2 PATCH +pub const Method = enum { + GET, + HEAD, + POST, + PUT, + DELETE, + CONNECT, + OPTIONS, + TRACE, + PATCH, + + /// Returns true if a request of this method is allowed to have a body + /// Actual behavior from servers may vary and should still be checked + pub fn requestHasBody(self: Method) bool { + return switch (self) { + .POST, .PUT, .PATCH => true, + .GET, .HEAD, .DELETE, .CONNECT, .OPTIONS, .TRACE => false, + }; + } + + /// Returns true if a response to this method is allowed to have a body + /// Actual behavior from clients may vary and should still be checked + pub fn responseHasBody(self: Method) bool { + return switch (self) { + .GET, .POST, .DELETE, .CONNECT, .OPTIONS, .PATCH => true, + .HEAD, .PUT, .TRACE => false, + }; + } + + /// An HTTP method is safe if it doesn't alter the state of the server. + /// https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP + /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 + pub fn safe(self: Method) bool { + return switch (self) { + .GET, .HEAD, .OPTIONS, .TRACE => true, + .POST, .PUT, .DELETE, .CONNECT, .PATCH => false, + }; + } + + /// An HTTP method is idempotent if an identical request can be made once or several times in a row with the same effect while leaving the server in the same state. 
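+    /// For example, repeating a DELETE of the same resource has the same effect as issuing it once, whereas repeating a POST may create additional resources.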
+ /// https://developer.mozilla.org/en-US/docs/Glossary/Idempotent + /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.2 + pub fn idempotent(self: Method) bool { + return switch (self) { + .GET, .HEAD, .PUT, .DELETE, .OPTIONS, .TRACE => true, + .CONNECT, .POST, .PATCH => false, + }; + } + + /// A cacheable response is an HTTP response that can be cached, that is stored to be retrieved and used later, saving a new request to the server. + /// https://developer.mozilla.org/en-US/docs/Glossary/cacheable + /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.3 + pub fn cacheable(self: Method) bool { + return switch (self) { + .GET, .HEAD => true, + .POST, .PUT, .DELETE, .CONNECT, .OPTIONS, .TRACE, .PATCH => false, + }; + } +}; + +/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Status +pub const Status = enum(u10) { + @"continue" = 100, // RFC7231, Section 6.2.1 + switching_protocols = 101, // RFC7231, Section 6.2.2 + processing = 102, // RFC2518 + early_hints = 103, // RFC8297 + + ok = 200, // RFC7231, Section 6.3.1 + created = 201, // RFC7231, Section 6.3.2 + accepted = 202, // RFC7231, Section 6.3.3 + non_authoritative_info = 203, // RFC7231, Section 6.3.4 + no_content = 204, // RFC7231, Section 6.3.5 + reset_content = 205, // RFC7231, Section 6.3.6 + partial_content = 206, // RFC7233, Section 4.1 + multi_status = 207, // RFC4918 + already_reported = 208, // RFC5842 + im_used = 226, // RFC3229 + + multiple_choice = 300, // RFC7231, Section 6.4.1 + moved_permanently = 301, // RFC7231, Section 6.4.2 + found = 302, // RFC7231, Section 6.4.3 + see_other = 303, // RFC7231, Section 6.4.4 + not_modified = 304, // RFC7232, Section 4.1 + use_proxy = 305, // RFC7231, Section 6.4.5 + temporary_redirect = 307, // RFC7231, Section 6.4.7 + permanent_redirect = 308, // RFC7538 + + bad_request = 400, // RFC7231, Section 6.5.1 + unauthorized = 401, // RFC7235, Section 3.1 + payment_required = 402, // RFC7231, Section 6.5.2 + forbidden = 403, // RFC7231, Section 6.5.3 + not_found = 404, // RFC7231, Section 6.5.4 + method_not_allowed = 405, // RFC7231, Section 6.5.5 + not_acceptable = 406, // RFC7231, Section 6.5.6 + proxy_auth_required = 407, // RFC7235, Section 3.2 + request_timeout = 408, // RFC7231, Section 6.5.7 + conflict = 409, // RFC7231, Section 6.5.8 + gone = 410, // RFC7231, Section 6.5.9 + length_required = 411, // RFC7231, Section 6.5.10 + precondition_failed = 412, // RFC7232, Section 4.2][RFC8144, Section 3.2 + payload_too_large = 413, // RFC7231, Section 6.5.11 + uri_too_long = 414, // RFC7231, Section 6.5.12 + unsupported_media_type = 415, // RFC7231, Section 6.5.13][RFC7694, Section 3 + range_not_satisfiable = 416, // RFC7233, Section 4.4 + expectation_failed = 417, // RFC7231, Section 6.5.14 + teapot = 418, // RFC 7168, 2.3.3 + misdirected_request = 421, // RFC7540, Section 9.1.2 + unprocessable_entity = 422, // RFC4918 + locked = 423, // RFC4918 + failed_dependency = 424, // RFC4918 + too_early = 425, // RFC8470 + upgrade_required = 426, // RFC7231, Section 6.5.15 + precondition_required = 428, // RFC6585 + too_many_requests = 429, // RFC6585 + header_fields_too_large = 431, // RFC6585 + unavailable_for_legal_reasons = 451, // RFC7725 + + internal_server_error = 500, // RFC7231, Section 6.6.1 + not_implemented = 501, // RFC7231, Section 6.6.2 + bad_gateway = 502, // RFC7231, Section 6.6.3 + service_unavailable = 503, // RFC7231, Section 6.6.4 + gateway_timeout = 504, // RFC7231, Section 6.6.5 + http_version_not_supported = 505, // RFC7231, Section 6.6.6 + 
variant_also_negotiates = 506, // RFC2295 + insufficient_storage = 507, // RFC4918 + loop_detected = 508, // RFC5842 + not_extended = 510, // RFC2774 + network_authentication_required = 511, // RFC6585 + + _, + + pub fn phrase(self: Status) ?[]const u8 { + return switch (self) { + // 1xx statuses + .@"continue" => "Continue", + .switching_protocols => "Switching Protocols", + .processing => "Processing", + .early_hints => "Early Hints", + + // 2xx statuses + .ok => "OK", + .created => "Created", + .accepted => "Accepted", + .non_authoritative_info => "Non-Authoritative Information", + .no_content => "No Content", + .reset_content => "Reset Content", + .partial_content => "Partial Content", + .multi_status => "Multi-Status", + .already_reported => "Already Reported", + .im_used => "IM Used", + + // 3xx statuses + .multiple_choice => "Multiple Choice", + .moved_permanently => "Moved Permanently", + .found => "Found", + .see_other => "See Other", + .not_modified => "Not Modified", + .use_proxy => "Use Proxy", + .temporary_redirect => "Temporary Redirect", + .permanent_redirect => "Permanent Redirect", + + // 4xx statuses + .bad_request => "Bad Request", + .unauthorized => "Unauthorized", + .payment_required => "Payment Required", + .forbidden => "Forbidden", + .not_found => "Not Found", + .method_not_allowed => "Method Not Allowed", + .not_acceptable => "Not Acceptable", + .proxy_auth_required => "Proxy Authentication Required", + .request_timeout => "Request Timeout", + .conflict => "Conflict", + .gone => "Gone", + .length_required => "Length Required", + .precondition_failed => "Precondition Failed", + .payload_too_large => "Payload Too Large", + .uri_too_long => "URI Too Long", + .unsupported_media_type => "Unsupported Media Type", + .range_not_satisfiable => "Range Not Satisfiable", + .expectation_failed => "Expectation Failed", + .teapot => "I'm a teapot", + .misdirected_request => "Misdirected Request", + .unprocessable_entity => "Unprocessable Entity", + .locked => "Locked", + .failed_dependency => "Failed Dependency", + .too_early => "Too Early", + .upgrade_required => "Upgrade Required", + .precondition_required => "Precondition Required", + .too_many_requests => "Too Many Requests", + .header_fields_too_large => "Request Header Fields Too Large", + .unavailable_for_legal_reasons => "Unavailable For Legal Reasons", + + // 5xx statuses + .internal_server_error => "Internal Server Error", + .not_implemented => "Not Implemented", + .bad_gateway => "Bad Gateway", + .service_unavailable => "Service Unavailable", + .gateway_timeout => "Gateway Timeout", + .http_version_not_supported => "HTTP Version Not Supported", + .variant_also_negotiates => "Variant Also Negotiates", + .insufficient_storage => "Insufficient Storage", + .loop_detected => "Loop Detected", + .not_extended => "Not Extended", + .network_authentication_required => "Network Authentication Required", + + else => return null, + }; + } + + pub const Class = enum { + informational, + success, + redirect, + client_error, + server_error, + }; + + pub fn class(self: Status) ?Class { + return switch (@enumToInt(self)) { + 100...199 => .informational, + 200...299 => .success, + 300...399 => .redirect, + 400...499 => .client_error, + 500...599 => .server_error, + else => null, + }; + } + + test { + try std.testing.expectEqualStrings("OK", Status.ok.phrase().?); + try std.testing.expectEqualStrings("Not Found", Status.not_found.phrase().?); + } + + test { + try std.testing.expectEqual(@as(?Status.Class, Status.Class.success), 
Status.ok.class()); + try std.testing.expectEqual(@as(?Status.Class, Status.Class.client_error), Status.not_found.class()); + } +}; + const std = @import("std.zig"); -pub const Method = @import("http/method.zig").Method; -pub const Status = @import("http/status.zig").Status; - test { - std.testing.refAllDecls(@This()); + _ = Client; + _ = Method; + _ = Status; } diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig new file mode 100644 index 0000000000..80904d765c --- /dev/null +++ b/lib/std/http/Client.zig @@ -0,0 +1,114 @@ +const std = @import("../std.zig"); +const assert = std.debug.assert; +const http = std.http; +const net = std.net; +const Client = @This(); + +allocator: std.mem.Allocator, +headers: std.ArrayListUnmanaged(u8) = .{}, +active_requests: usize = 0, + +pub const Request = struct { + client: *Client, + stream: net.Stream, + headers: std.ArrayListUnmanaged(u8) = .{}, + tls: std.crypto.Tls = .{}, + protocol: Protocol, + + pub const Protocol = enum { http, https }; + + pub const Options = struct { + family: Family = .any, + protocol: Protocol = .https, + method: http.Method = .GET, + host: []const u8 = "localhost", + path: []const u8 = "/", + port: u16 = 0, + + pub const Family = enum { any, ip4, ip6 }; + }; + + pub fn deinit(req: *Request) void { + req.client.active_requests -= 1; + req.headers.deinit(req.client.allocator); + req.* = undefined; + } + + pub fn addHeader(req: *Request, name: []const u8, value: []const u8) !void { + const gpa = req.client.allocator; + // Ensure an extra +2 for the \r\n in end() + try req.headers.ensureUnusedCapacity(gpa, name.len + value.len + 6); + req.headers.appendSliceAssumeCapacity(name); + req.headers.appendSliceAssumeCapacity(": "); + req.headers.appendSliceAssumeCapacity(value); + req.headers.appendSliceAssumeCapacity("\r\n"); + } + + pub fn end(req: *Request) !void { + req.headers.appendSliceAssumeCapacity("\r\n"); + switch (req.protocol) { + .http => { + try req.stream.writeAll(req.headers.items); + }, + .https => { + try req.tls.writeAll(req.stream, req.headers.items); + }, + } + } +}; + +pub fn deinit(client: *Client) void { + assert(client.active_requests == 0); + client.headers.denit(client.allocator); + client.* = undefined; +} + +pub fn request(client: *Client, options: Request.Options) !Request { + var req: Request = .{ + .client = client, + .stream = try net.tcpConnectToHost(client.allocator, options.host, options.port), + .protocol = options.protocol, + }; + errdefer req.deinit(); + + switch (options.protocol) { + .http => {}, + .https => { + try req.tls.init(req.stream, options.host); + }, + } + + try req.headers.ensureUnusedCapacity( + client.allocator, + @tagName(options.method).len + + 1 + + options.path.len + + " HTTP/2\r\nHost: ".len + + options.host.len + + "\r\nUpgrade-Insecure-Requests: 1\r\n".len + + client.headers.items.len + + 2, // for the \r\n at the end of headers + ); + req.headers.appendSliceAssumeCapacity(@tagName(options.method)); + req.headers.appendSliceAssumeCapacity(" "); + req.headers.appendSliceAssumeCapacity(options.path); + req.headers.appendSliceAssumeCapacity(" HTTP/2\r\nHost: "); + req.headers.appendSliceAssumeCapacity(options.host); + switch (options.protocol) { + .https => req.headers.appendSliceAssumeCapacity("\r\nUpgrade-Insecure-Requests: 1\r\n"), + .http => req.headers.appendSliceAssumeCapacity("\r\n"), + } + req.headers.appendSliceAssumeCapacity(client.headers.items); + + client.active_requests += 1; + return req; +} + +pub fn addHeader(client: *Client, name: []const u8, value: 
[]const u8) !void { + const gpa = client.allocator; + try client.headers.ensureUnusedCapacity(gpa, name.len + value.len + 4); + client.headers.appendSliceAssumeCapacity(name); + client.headers.appendSliceAssumeCapacity(": "); + client.headers.appendSliceAssumeCapacity(value); + client.headers.appendSliceAssumeCapacity("\r\n"); +} diff --git a/lib/std/http/method.zig b/lib/std/http/method.zig deleted file mode 100644 index c118ca9a47..0000000000 --- a/lib/std/http/method.zig +++ /dev/null @@ -1,65 +0,0 @@ -//! HTTP Methods -//! https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods - -// Style guide is violated here so that @tagName can be used effectively -/// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definiton -/// https://datatracker.ietf.org/doc/html/rfc5789#section-2 PATCH -pub const Method = enum { - GET, - HEAD, - POST, - PUT, - DELETE, - CONNECT, - OPTIONS, - TRACE, - PATCH, - - /// Returns true if a request of this method is allowed to have a body - /// Actual behavior from servers may vary and should still be checked - pub fn requestHasBody(self: Method) bool { - return switch (self) { - .POST, .PUT, .PATCH => true, - .GET, .HEAD, .DELETE, .CONNECT, .OPTIONS, .TRACE => false, - }; - } - - /// Returns true if a response to this method is allowed to have a body - /// Actual behavior from clients may vary and should still be checked - pub fn responseHasBody(self: Method) bool { - return switch (self) { - .GET, .POST, .DELETE, .CONNECT, .OPTIONS, .PATCH => true, - .HEAD, .PUT, .TRACE => false, - }; - } - - /// An HTTP method is safe if it doesn't alter the state of the server. - /// https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP - /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 - pub fn safe(self: Method) bool { - return switch (self) { - .GET, .HEAD, .OPTIONS, .TRACE => true, - .POST, .PUT, .DELETE, .CONNECT, .PATCH => false, - }; - } - - /// An HTTP method is idempotent if an identical request can be made once or several times in a row with the same effect while leaving the server in the same state. - /// https://developer.mozilla.org/en-US/docs/Glossary/Idempotent - /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.2 - pub fn idempotent(self: Method) bool { - return switch (self) { - .GET, .HEAD, .PUT, .DELETE, .OPTIONS, .TRACE => true, - .CONNECT, .POST, .PATCH => false, - }; - } - - /// A cacheable response is an HTTP response that can be cached, that is stored to be retrieved and used later, saving a new request to the server. - /// https://developer.mozilla.org/en-US/docs/Glossary/cacheable - /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.3 - pub fn cacheable(self: Method) bool { - return switch (self) { - .GET, .HEAD => true, - .POST, .PUT, .DELETE, .CONNECT, .OPTIONS, .TRACE, .PATCH => false, - }; - } -}; diff --git a/lib/std/http/status.zig b/lib/std/http/status.zig deleted file mode 100644 index 91738e0533..0000000000 --- a/lib/std/http/status.zig +++ /dev/null @@ -1,182 +0,0 @@ -//! HTTP Status -//! 
https://developer.mozilla.org/en-US/docs/Web/HTTP/Status - -const std = @import("../std.zig"); - -pub const Status = enum(u10) { - @"continue" = 100, // RFC7231, Section 6.2.1 - switching_protocols = 101, // RFC7231, Section 6.2.2 - processing = 102, // RFC2518 - early_hints = 103, // RFC8297 - - ok = 200, // RFC7231, Section 6.3.1 - created = 201, // RFC7231, Section 6.3.2 - accepted = 202, // RFC7231, Section 6.3.3 - non_authoritative_info = 203, // RFC7231, Section 6.3.4 - no_content = 204, // RFC7231, Section 6.3.5 - reset_content = 205, // RFC7231, Section 6.3.6 - partial_content = 206, // RFC7233, Section 4.1 - multi_status = 207, // RFC4918 - already_reported = 208, // RFC5842 - im_used = 226, // RFC3229 - - multiple_choice = 300, // RFC7231, Section 6.4.1 - moved_permanently = 301, // RFC7231, Section 6.4.2 - found = 302, // RFC7231, Section 6.4.3 - see_other = 303, // RFC7231, Section 6.4.4 - not_modified = 304, // RFC7232, Section 4.1 - use_proxy = 305, // RFC7231, Section 6.4.5 - temporary_redirect = 307, // RFC7231, Section 6.4.7 - permanent_redirect = 308, // RFC7538 - - bad_request = 400, // RFC7231, Section 6.5.1 - unauthorized = 401, // RFC7235, Section 3.1 - payment_required = 402, // RFC7231, Section 6.5.2 - forbidden = 403, // RFC7231, Section 6.5.3 - not_found = 404, // RFC7231, Section 6.5.4 - method_not_allowed = 405, // RFC7231, Section 6.5.5 - not_acceptable = 406, // RFC7231, Section 6.5.6 - proxy_auth_required = 407, // RFC7235, Section 3.2 - request_timeout = 408, // RFC7231, Section 6.5.7 - conflict = 409, // RFC7231, Section 6.5.8 - gone = 410, // RFC7231, Section 6.5.9 - length_required = 411, // RFC7231, Section 6.5.10 - precondition_failed = 412, // RFC7232, Section 4.2][RFC8144, Section 3.2 - payload_too_large = 413, // RFC7231, Section 6.5.11 - uri_too_long = 414, // RFC7231, Section 6.5.12 - unsupported_media_type = 415, // RFC7231, Section 6.5.13][RFC7694, Section 3 - range_not_satisfiable = 416, // RFC7233, Section 4.4 - expectation_failed = 417, // RFC7231, Section 6.5.14 - teapot = 418, // RFC 7168, 2.3.3 - misdirected_request = 421, // RFC7540, Section 9.1.2 - unprocessable_entity = 422, // RFC4918 - locked = 423, // RFC4918 - failed_dependency = 424, // RFC4918 - too_early = 425, // RFC8470 - upgrade_required = 426, // RFC7231, Section 6.5.15 - precondition_required = 428, // RFC6585 - too_many_requests = 429, // RFC6585 - header_fields_too_large = 431, // RFC6585 - unavailable_for_legal_reasons = 451, // RFC7725 - - internal_server_error = 500, // RFC7231, Section 6.6.1 - not_implemented = 501, // RFC7231, Section 6.6.2 - bad_gateway = 502, // RFC7231, Section 6.6.3 - service_unavailable = 503, // RFC7231, Section 6.6.4 - gateway_timeout = 504, // RFC7231, Section 6.6.5 - http_version_not_supported = 505, // RFC7231, Section 6.6.6 - variant_also_negotiates = 506, // RFC2295 - insufficient_storage = 507, // RFC4918 - loop_detected = 508, // RFC5842 - not_extended = 510, // RFC2774 - network_authentication_required = 511, // RFC6585 - - _, - - pub fn phrase(self: Status) ?[]const u8 { - return switch (self) { - // 1xx statuses - .@"continue" => "Continue", - .switching_protocols => "Switching Protocols", - .processing => "Processing", - .early_hints => "Early Hints", - - // 2xx statuses - .ok => "OK", - .created => "Created", - .accepted => "Accepted", - .non_authoritative_info => "Non-Authoritative Information", - .no_content => "No Content", - .reset_content => "Reset Content", - .partial_content => "Partial Content", - .multi_status => 
"Multi-Status", - .already_reported => "Already Reported", - .im_used => "IM Used", - - // 3xx statuses - .multiple_choice => "Multiple Choice", - .moved_permanently => "Moved Permanently", - .found => "Found", - .see_other => "See Other", - .not_modified => "Not Modified", - .use_proxy => "Use Proxy", - .temporary_redirect => "Temporary Redirect", - .permanent_redirect => "Permanent Redirect", - - // 4xx statuses - .bad_request => "Bad Request", - .unauthorized => "Unauthorized", - .payment_required => "Payment Required", - .forbidden => "Forbidden", - .not_found => "Not Found", - .method_not_allowed => "Method Not Allowed", - .not_acceptable => "Not Acceptable", - .proxy_auth_required => "Proxy Authentication Required", - .request_timeout => "Request Timeout", - .conflict => "Conflict", - .gone => "Gone", - .length_required => "Length Required", - .precondition_failed => "Precondition Failed", - .payload_too_large => "Payload Too Large", - .uri_too_long => "URI Too Long", - .unsupported_media_type => "Unsupported Media Type", - .range_not_satisfiable => "Range Not Satisfiable", - .expectation_failed => "Expectation Failed", - .teapot => "I'm a teapot", - .misdirected_request => "Misdirected Request", - .unprocessable_entity => "Unprocessable Entity", - .locked => "Locked", - .failed_dependency => "Failed Dependency", - .too_early => "Too Early", - .upgrade_required => "Upgrade Required", - .precondition_required => "Precondition Required", - .too_many_requests => "Too Many Requests", - .header_fields_too_large => "Request Header Fields Too Large", - .unavailable_for_legal_reasons => "Unavailable For Legal Reasons", - - // 5xx statuses - .internal_server_error => "Internal Server Error", - .not_implemented => "Not Implemented", - .bad_gateway => "Bad Gateway", - .service_unavailable => "Service Unavailable", - .gateway_timeout => "Gateway Timeout", - .http_version_not_supported => "HTTP Version Not Supported", - .variant_also_negotiates => "Variant Also Negotiates", - .insufficient_storage => "Insufficient Storage", - .loop_detected => "Loop Detected", - .not_extended => "Not Extended", - .network_authentication_required => "Network Authentication Required", - - else => return null, - }; - } - - pub const Class = enum { - informational, - success, - redirect, - client_error, - server_error, - }; - - pub fn class(self: Status) ?Class { - return switch (@enumToInt(self)) { - 100...199 => .informational, - 200...299 => .success, - 300...399 => .redirect, - 400...499 => .client_error, - 500...599 => .server_error, - else => null, - }; - } -}; - -test { - try std.testing.expectEqualStrings("OK", Status.ok.phrase().?); - try std.testing.expectEqualStrings("Not Found", Status.not_found.phrase().?); -} - -test { - try std.testing.expectEqual(@as(?Status.Class, Status.Class.success), Status.ok.class()); - try std.testing.expectEqual(@as(?Status.Class, Status.Class.client_error), Status.not_found.class()); -} diff --git a/lib/std/net.zig b/lib/std/net.zig index 4a0582e7f5..ebc4de08b9 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -1687,6 +1687,13 @@ pub const Stream = struct { } } + pub fn writeAll(self: Stream, bytes: []const u8) WriteError!void { + var index: usize = 0; + while (index < bytes.len) { + index += try self.write(bytes[index..]); + } + } + /// See https://github.com/ziglang/zig/issues/7699 /// See equivalent function: `std.fs.File.writev`. 
pub fn writev(self: Stream, iovecs: []const os.iovec_const) WriteError!usize { From d2f5d0b1990a160aa1d648531ea5b1df7b2acdce Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 13 Dec 2022 20:15:41 -0700 Subject: [PATCH 04/59] std.crypto.Tls: parse the ServerHello handshake --- lib/std/crypto/Tls.zig | 127 ++++++++++++++++++++++++++++++++++++---- lib/std/http/Client.zig | 4 +- lib/std/net.zig | 22 +++++++ 3 files changed, 138 insertions(+), 15 deletions(-) diff --git a/lib/std/crypto/Tls.zig b/lib/std/crypto/Tls.zig index ab54b42b70..0dc6946003 100644 --- a/lib/std/crypto/Tls.zig +++ b/lib/std/crypto/Tls.zig @@ -8,6 +8,13 @@ const assert = std.debug.assert; state: State = .start, x25519_priv_key: [32]u8 = undefined, x25519_pub_key: [32]u8 = undefined, +x25519_server_pub_key: [32]u8 = undefined, + +const ProtocolVersion = enum(u16) { + tls_1_2 = 0x0303, + tls_1_3 = 0x0304, + _, +}; const State = enum { /// In this state, all fields are undefined except state. @@ -186,6 +193,18 @@ const NamedGroup = enum(u16) { // * length: u24 // * data: opaque +// ServerHello: +// * ProtocolVersion legacy_version = 0x0303; +// * Random random; +// * opaque legacy_session_id_echo<0..32>; +// * CipherSuite cipher_suite; +// * uint8 legacy_compression_method = 0; +// * Extension extensions<6..2^16-1>; + +// Extension: +// * ExtensionType extension_type; +// * opaque extension_data<0..2^16-1>; + const CipherSuite = enum(u16) { TLS_AES_128_GCM_SHA256 = 0x1301, TLS_AES_256_GCM_SHA384 = 0x1302, @@ -259,10 +278,10 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { // Extension: key_share 0, 51, // ExtensionType.key_share - 0x00, 38, // byte length of this extension payload - 0x00, 36, // byte length of client_shares + 0, 38, // byte length of this extension payload + 0, 36, // byte length of client_shares 0x00, 0x1D, // NamedGroup.x25519 - 0x00, 32, // byte length of key_exchange + 0, 32, // byte length of key_exchange } ++ tls.x25519_pub_key ++ [_]u8{ // Extension: server_name @@ -313,21 +332,103 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { try stream.writevAll(&iovecs); { - var buf: [1000]u8 = undefined; - const amt = try stream.read(&buf); - const resp = buf[0..amt]; - const ct = @intToEnum(ContentType, resp[0]); + var handshake_buf: [4000]u8 = undefined; + const plaintext = handshake_buf[0..5]; + const amt = try stream.readAtLeast(&handshake_buf, plaintext.len); + if (amt < plaintext.len) return error.EndOfStream; + const ct = @intToEnum(ContentType, plaintext[0]); + const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]); + const end = plaintext.len + frag_len; + if (end > handshake_buf.len) return error.TlsServerHelloTooBig; + if (amt < end) { + const amt2 = try stream.readAll(handshake_buf[amt..end]); + if (amt2 < plaintext.len) return error.EndOfStream; + } + const frag = handshake_buf[plaintext.len..end]; + if (ct == .alert) { - //const prot_ver = @bitCast(u16, resp[1..][0..2].*); - const len = std.mem.readIntBig(u16, resp[3..][0..2]); - const alert = resp[5..][0..len]; - const level = @intToEnum(AlertLevel, alert[0]); - const desc = @intToEnum(AlertDescription, alert[1]); + const level = @intToEnum(AlertLevel, frag[0]); + const desc = @intToEnum(AlertDescription, frag[1]); std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) }); std.process.exit(1); + } else if (ct == .handshake) { + if (frag[0] != @enumToInt(HandshakeType.server_hello)) { + return error.TlsUnexpectedMessage; + } + const length = mem.readIntBig(u24, frag[1..4]); + 
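+            // Handshake header: a 1-byte type followed by a 3-byte length; the
+            // length must cover exactly the rest of this record fragment.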
if (4 + length != frag.len) return error.TlsBadLength; + const hello = frag[4..]; + const legacy_version = mem.readIntBig(u16, hello[0..2]); + const random = hello[2..34].*; + _ = random; + const legacy_session_id_echo_len = hello[34]; + if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter; + const cipher_suite_int = mem.readIntBig(u16, hello[35..37]); + const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch + return error.TlsIllegalParameter; + std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)}); + const legacy_compression_method = hello[37]; + _ = legacy_compression_method; + const extensions_size = mem.readIntBig(u16, hello[38..40]); + if (40 + extensions_size != hello.len) return error.TlsBadLength; + var i: usize = 40; + var supported_version: u16 = 0; + var have_server_pub_key = false; + while (i < hello.len) { + const et = mem.readIntBig(u16, hello[i..][0..2]); + i += 2; + const ext_size = mem.readIntBig(u16, hello[i..][0..2]); + i += 2; + const next_i = i + ext_size; + if (next_i > hello.len) return error.TlsBadLength; + switch (et) { + @enumToInt(ExtensionType.supported_versions) => { + if (supported_version != 0) return error.TlsIllegalParameter; + supported_version = mem.readIntBig(u16, hello[i..][0..2]); + }, + @enumToInt(ExtensionType.key_share) => { + if (have_server_pub_key) return error.TlsIllegalParameter; + const named_group = mem.readIntBig(u16, hello[i..][0..2]); + i += 2; + switch (named_group) { + @enumToInt(NamedGroup.x25519) => { + const key_size = mem.readIntBig(u16, hello[i..][0..2]); + i += 2; + if (key_size != 32) return error.TlsBadLength; + const encrypted_key = hello[i..][0..32].*; + const server_pub_key = try crypto.dh.X25519.scalarmult( + tls.x25519_priv_key, + encrypted_key, + ); + tls.x25519_server_pub_key = server_pub_key; + have_server_pub_key = true; + }, + else => { + std.debug.print("named group: {x}\n", .{named_group}); + return error.TlsIllegalParameter; + }, + } + }, + else => { + std.debug.print("unexpected extension: {x}\n", .{et}); + }, + } + i = next_i; + } + if (!have_server_pub_key) return error.TlsIllegalParameter; + const tls_version = if (supported_version == 0) legacy_version else supported_version; + switch (tls_version) { + @enumToInt(ProtocolVersion.tls_1_2) => { + std.debug.print("server wants TLS v1.2\n", .{}); + }, + @enumToInt(ProtocolVersion.tls_1_3) => { + std.debug.print("server wants TLS v1.3\n", .{}); + }, + else => return error.TlsIllegalParameter, + } } else { std.debug.print("content_type: {s}\n", .{@tagName(ct)}); - std.debug.print("got {d} bytes: {s}\n", .{ amt, std.fmt.fmtSliceHexLower(resp) }); + std.debug.print("got {d} bytes: {s}\n", .{ amt, std.fmt.fmtSliceHexLower(frag) }); } } diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 80904d765c..b10011a6b1 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -59,7 +59,7 @@ pub const Request = struct { pub fn deinit(client: *Client) void { assert(client.active_requests == 0); - client.headers.denit(client.allocator); + client.headers.deinit(client.allocator); client.* = undefined; } @@ -69,6 +69,7 @@ pub fn request(client: *Client, options: Request.Options) !Request { .stream = try net.tcpConnectToHost(client.allocator, options.host, options.port), .protocol = options.protocol, }; + client.active_requests += 1; errdefer req.deinit(); switch (options.protocol) { @@ -100,7 +101,6 @@ pub fn request(client: *Client, options: Request.Options) !Request { } 
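    // Headers previously registered on the Client (via Client.addHeader) are
    // included in every request.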
req.headers.appendSliceAssumeCapacity(client.headers.items); - client.active_requests += 1; return req; } diff --git a/lib/std/net.zig b/lib/std/net.zig index ebc4de08b9..8c8ab51a4a 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -1672,6 +1672,28 @@ pub const Stream = struct { } } + /// Returns the number of bytes read. If the number read is smaller than + /// `buffer.len`, it means the stream reached the end. Reaching the end of + /// a stream is not an error condition. + pub fn readAll(s: Stream, buffer: []u8) ReadError!usize { + return readAtLeast(s, buffer, buffer.len); + } + + /// Returns the number of bytes read, calling the underlying read function + /// multiple times until at least the buffer has at least `len` bytes + /// filled. If the number read is less than `len` it means the stream + /// reached the end. Reaching the end of the stream is not an error + /// condition. + pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize { + var index: usize = 0; + while (index < len) { + const amt = try s.read(buffer[index..]); + if (amt == 0) break; + index += amt; + } + return index; + } + /// TODO in evented I/O mode, this implementation incorrectly uses the event loop's /// file system thread instead of non-blocking. It needs to be reworked to properly /// use non-blocking I/O. From 920e5bc4ff4bdfee173768809e712f8004f7132d Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 13 Dec 2022 21:59:01 -0700 Subject: [PATCH 05/59] std.crypto.Tls: discard ChangeCipherSpec messages The next step here is to decrypt encrypted records --- lib/std/crypto/Tls.zig | 225 ++++++++++++++++++++++++----------------- lib/std/net.zig | 6 +- 2 files changed, 138 insertions(+), 93 deletions(-) diff --git a/lib/std/crypto/Tls.zig b/lib/std/crypto/Tls.zig index 0dc6946003..65f54ffa68 100644 --- a/lib/std/crypto/Tls.zig +++ b/lib/std/crypto/Tls.zig @@ -188,6 +188,12 @@ const NamedGroup = enum(u16) { // * fragment: opaque // - the data being transmitted +// Ciphertext +// * ContentType opaque_type = application_data; /* 23 */ +// * ProtocolVersion legacy_record_version = 0x0303; /* TLS v1.2 */ +// * uint16 length; +// * opaque encrypted_record[TLSCiphertext.length]; + // Handshake: // * type: HandshakeType // * length: u24 @@ -331,105 +337,144 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { }; try stream.writevAll(&iovecs); - { - var handshake_buf: [4000]u8 = undefined; + var handshake_buf: [4000]u8 = undefined; + var len: usize = 0; + var i: usize = i: { const plaintext = handshake_buf[0..5]; - const amt = try stream.readAtLeast(&handshake_buf, plaintext.len); - if (amt < plaintext.len) return error.EndOfStream; + len = try stream.readAtLeast(&handshake_buf, plaintext.len); + if (len < plaintext.len) return error.EndOfStream; const ct = @intToEnum(ContentType, plaintext[0]); const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]); const end = plaintext.len + frag_len; - if (end > handshake_buf.len) return error.TlsServerHelloTooBig; - if (amt < end) { - const amt2 = try stream.readAll(handshake_buf[amt..end]); - if (amt2 < plaintext.len) return error.EndOfStream; + if (end > handshake_buf.len) return error.TlsRecordOverflow; + if (end > len) { + len += try stream.readAtLeast(handshake_buf[len..], end - len); + if (end > len) return error.EndOfStream; } const frag = handshake_buf[plaintext.len..end]; - if (ct == .alert) { - const level = @intToEnum(AlertLevel, frag[0]); - const desc = @intToEnum(AlertDescription, frag[1]); - std.debug.print("alert: {s} {s}\n", 
.{ @tagName(level), @tagName(desc) }); - std.process.exit(1); - } else if (ct == .handshake) { - if (frag[0] != @enumToInt(HandshakeType.server_hello)) { - return error.TlsUnexpectedMessage; - } - const length = mem.readIntBig(u24, frag[1..4]); - if (4 + length != frag.len) return error.TlsBadLength; - const hello = frag[4..]; - const legacy_version = mem.readIntBig(u16, hello[0..2]); - const random = hello[2..34].*; - _ = random; - const legacy_session_id_echo_len = hello[34]; - if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter; - const cipher_suite_int = mem.readIntBig(u16, hello[35..37]); - const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch - return error.TlsIllegalParameter; - std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)}); - const legacy_compression_method = hello[37]; - _ = legacy_compression_method; - const extensions_size = mem.readIntBig(u16, hello[38..40]); - if (40 + extensions_size != hello.len) return error.TlsBadLength; - var i: usize = 40; - var supported_version: u16 = 0; - var have_server_pub_key = false; - while (i < hello.len) { - const et = mem.readIntBig(u16, hello[i..][0..2]); - i += 2; - const ext_size = mem.readIntBig(u16, hello[i..][0..2]); - i += 2; - const next_i = i + ext_size; - if (next_i > hello.len) return error.TlsBadLength; - switch (et) { - @enumToInt(ExtensionType.supported_versions) => { - if (supported_version != 0) return error.TlsIllegalParameter; - supported_version = mem.readIntBig(u16, hello[i..][0..2]); - }, - @enumToInt(ExtensionType.key_share) => { - if (have_server_pub_key) return error.TlsIllegalParameter; - const named_group = mem.readIntBig(u16, hello[i..][0..2]); - i += 2; - switch (named_group) { - @enumToInt(NamedGroup.x25519) => { - const key_size = mem.readIntBig(u16, hello[i..][0..2]); - i += 2; - if (key_size != 32) return error.TlsBadLength; - const encrypted_key = hello[i..][0..32].*; - const server_pub_key = try crypto.dh.X25519.scalarmult( - tls.x25519_priv_key, - encrypted_key, - ); - tls.x25519_server_pub_key = server_pub_key; - have_server_pub_key = true; - }, - else => { - std.debug.print("named group: {x}\n", .{named_group}); - return error.TlsIllegalParameter; - }, - } - }, - else => { - std.debug.print("unexpected extension: {x}\n", .{et}); - }, + switch (ct) { + .alert => { + const level = @intToEnum(AlertLevel, frag[0]); + const desc = @intToEnum(AlertDescription, frag[1]); + std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) }); + return error.TlsAlert; + }, + .handshake => { + if (frag[0] != @enumToInt(HandshakeType.server_hello)) { + return error.TlsUnexpectedMessage; } - i = next_i; - } - if (!have_server_pub_key) return error.TlsIllegalParameter; - const tls_version = if (supported_version == 0) legacy_version else supported_version; - switch (tls_version) { - @enumToInt(ProtocolVersion.tls_1_2) => { - std.debug.print("server wants TLS v1.2\n", .{}); - }, - @enumToInt(ProtocolVersion.tls_1_3) => { - std.debug.print("server wants TLS v1.3\n", .{}); - }, - else => return error.TlsIllegalParameter, - } - } else { - std.debug.print("content_type: {s}\n", .{@tagName(ct)}); - std.debug.print("got {d} bytes: {s}\n", .{ amt, std.fmt.fmtSliceHexLower(frag) }); + const length = mem.readIntBig(u24, frag[1..4]); + if (4 + length != frag.len) return error.TlsBadLength; + const hello = frag[4..]; + const legacy_version = mem.readIntBig(u16, hello[0..2]); + const random = hello[2..34].*; + _ = random; + const legacy_session_id_echo_len 
= hello[34]; + if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter; + const cipher_suite_int = mem.readIntBig(u16, hello[35..37]); + const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch + return error.TlsIllegalParameter; + std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)}); + const legacy_compression_method = hello[37]; + _ = legacy_compression_method; + const extensions_size = mem.readIntBig(u16, hello[38..40]); + if (40 + extensions_size != hello.len) return error.TlsBadLength; + var i: usize = 40; + var supported_version: u16 = 0; + var have_server_pub_key = false; + while (i < hello.len) { + const et = mem.readIntBig(u16, hello[i..][0..2]); + i += 2; + const ext_size = mem.readIntBig(u16, hello[i..][0..2]); + i += 2; + const next_i = i + ext_size; + if (next_i > hello.len) return error.TlsBadLength; + switch (et) { + @enumToInt(ExtensionType.supported_versions) => { + if (supported_version != 0) return error.TlsIllegalParameter; + supported_version = mem.readIntBig(u16, hello[i..][0..2]); + }, + @enumToInt(ExtensionType.key_share) => { + if (have_server_pub_key) return error.TlsIllegalParameter; + const named_group = mem.readIntBig(u16, hello[i..][0..2]); + i += 2; + switch (named_group) { + @enumToInt(NamedGroup.x25519) => { + const key_size = mem.readIntBig(u16, hello[i..][0..2]); + i += 2; + if (key_size != 32) return error.TlsBadLength; + const encrypted_key = hello[i..][0..32].*; + const server_pub_key = try crypto.dh.X25519.scalarmult( + tls.x25519_priv_key, + encrypted_key, + ); + tls.x25519_server_pub_key = server_pub_key; + have_server_pub_key = true; + }, + else => { + std.debug.print("named group: {x}\n", .{named_group}); + return error.TlsIllegalParameter; + }, + } + }, + else => { + std.debug.print("unexpected extension: {x}\n", .{et}); + }, + } + i = next_i; + } + if (!have_server_pub_key) return error.TlsIllegalParameter; + const tls_version = if (supported_version == 0) legacy_version else supported_version; + switch (tls_version) { + @enumToInt(ProtocolVersion.tls_1_2) => { + std.debug.print("server wants TLS v1.2\n", .{}); + }, + @enumToInt(ProtocolVersion.tls_1_3) => { + std.debug.print("server wants TLS v1.3\n", .{}); + }, + else => return error.TlsIllegalParameter, + } + }, + else => return error.TlsUnexpectedMessage, } + break :i end; + }; + + while (true) { + const end_hdr = i + 5; + if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow; + if (end_hdr > len) { + len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len); + if (end_hdr > len) return error.EndOfStream; + } + const ct = @intToEnum(ContentType, handshake_buf[i]); + i += 1; + const legacy_version = mem.readIntBig(u16, handshake_buf[i..][0..2]); + i += 2; + _ = legacy_version; + const record_size = mem.readIntBig(u16, handshake_buf[i..][0..2]); + i += 2; + const end = i + record_size; + if (end > handshake_buf.len) return error.TlsRecordOverflow; + if (end > len) { + len += try stream.readAtLeast(handshake_buf[len..], end - len); + if (end > len) return error.EndOfStream; + } + switch (ct) { + .change_cipher_spec => { + if (record_size != 1) return error.TlsUnexpectedMessage; + if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage; + }, + .application_data => { + std.debug.print("TODO: decrypt these {d} bytes\n", .{record_size}); + }, + else => { + std.debug.print("content type: {s}\n", .{@tagName(ct)}); + return error.TlsUnexpectedMessage; + }, + } + i = end; } tls.state = .sent_hello; diff --git 
a/lib/std/net.zig b/lib/std/net.zig index 8c8ab51a4a..a265fa69a9 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -1680,9 +1680,9 @@ pub const Stream = struct { } /// Returns the number of bytes read, calling the underlying read function - /// multiple times until at least the buffer has at least `len` bytes - /// filled. If the number read is less than `len` it means the stream - /// reached the end. Reaching the end of the stream is not an error + /// the minimal number of times until at least the buffer has at least + /// `len` bytes filled. If the number read is less than `len` it means the + /// stream reached the end. Reaching the end of the stream is not an error /// condition. pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize { var index: usize = 0; From 595fff7cb664b5dc517a682b3daec5ee2767fe0d Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 15 Dec 2022 00:55:33 -0700 Subject: [PATCH 06/59] std.crypto.Tls: decrypting handshake messages --- lib/std/crypto/Tls.zig | 225 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 216 insertions(+), 9 deletions(-) diff --git a/lib/std/crypto/Tls.zig b/lib/std/crypto/Tls.zig index 65f54ffa68..ea1bff9a08 100644 --- a/lib/std/crypto/Tls.zig +++ b/lib/std/crypto/Tls.zig @@ -234,7 +234,12 @@ const cipher_suites = blk: { pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { assert(tls.state == .start); crypto.random.bytes(&tls.x25519_priv_key); - tls.x25519_pub_key = try crypto.dh.X25519.recoverPublicKey(tls.x25519_priv_key); + tls.x25519_pub_key = crypto.dh.X25519.recoverPublicKey(tls.x25519_priv_key) catch |err| { + switch (err) { + // Only possible to happen if the private key is all zeroes. + error.IdentityElement => return error.InsufficientEntropy, + } + }; // random (u32) var rand_buf: [32]u8 = undefined; @@ -337,6 +342,14 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { }; try stream.writevAll(&iovecs); + const client_hello_bytes1 = hello_header[5..]; + + var client_handshake_key: [32]u8 = undefined; + var server_handshake_key: [32]u8 = undefined; + var client_handshake_iv: [12]u8 = undefined; + var server_handshake_iv: [12]u8 = undefined; + var cipher_suite: CipherSuite = undefined; + var handshake_buf: [4000]u8 = undefined; var len: usize = 0; var i: usize = i: { @@ -373,7 +386,7 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { const legacy_session_id_echo_len = hello[34]; if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter; const cipher_suite_int = mem.readIntBig(u16, hello[35..37]); - const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch + cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch return error.TlsIllegalParameter; std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)}); const legacy_compression_method = hello[37]; @@ -404,12 +417,7 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { const key_size = mem.readIntBig(u16, hello[i..][0..2]); i += 2; if (key_size != 32) return error.TlsBadLength; - const encrypted_key = hello[i..][0..32].*; - const server_pub_key = try crypto.dh.X25519.scalarmult( - tls.x25519_priv_key, - encrypted_key, - ); - tls.x25519_server_pub_key = server_pub_key; + tls.x25519_server_pub_key = hello[i..][0..32].*; have_server_pub_key = true; }, else => { @@ -435,12 +443,77 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { }, else => return error.TlsIllegalParameter, } + + const shared_key 
= crypto.dh.X25519.scalarmult( + tls.x25519_priv_key, + tls.x25519_server_pub_key, + ) catch return error.TlsDecryptFailure; + + switch (cipher_suite) { + .TLS_AES_128_GCM_SHA256 => { + const AEAD = crypto.aead.aes_gcm.Aes128Gcm; + const Hash = crypto.hash.sha2.Sha256; + const Hmac = crypto.auth.hmac.Hmac(Hash); + const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + const hello_hash = helloHash(client_hello_bytes1, host, frag, Hash); + const early_secret = Hkdf.extract(&[1]u8{0}, &([1]u8{0} ** Hash.digest_length)); + const empty_hash = emptyHash(Hash); + const derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length); + const handshake_secret = Hkdf.extract(&derived_secret, &shared_key); + const client_secret = hkdfExpandLabel(Hkdf, handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length); + const server_secret = hkdfExpandLabel(Hkdf, handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length); + client_handshake_key[0..AEAD.key_length].* = hkdfExpandLabel(Hkdf, client_secret, "key", "", AEAD.key_length); + server_handshake_key[0..AEAD.key_length].* = hkdfExpandLabel(Hkdf, server_secret, "key", "", AEAD.key_length); + client_handshake_iv = hkdfExpandLabel(Hkdf, client_secret, "iv", "", AEAD.nonce_length); + server_handshake_iv = hkdfExpandLabel(Hkdf, server_secret, "iv", "", AEAD.nonce_length); + //std.debug.print("shared_key: {}\nhello_hash: {}\nearly_secret: {}\nempty_hash: {}\nderived_secret: {}\nhandshake_secret: {}\n client_secret: {}\n server_secret: {}\n", .{ + // std.fmt.fmtSliceHexLower(&shared_key), + // std.fmt.fmtSliceHexLower(&hello_hash), + // std.fmt.fmtSliceHexLower(&early_secret), + // std.fmt.fmtSliceHexLower(&empty_hash), + // std.fmt.fmtSliceHexLower(&derived_secret), + // std.fmt.fmtSliceHexLower(&handshake_secret), + // std.fmt.fmtSliceHexLower(&client_secret), + // std.fmt.fmtSliceHexLower(&server_secret), + //}); + }, + .TLS_AES_256_GCM_SHA384 => { + const AEAD = crypto.aead.aes_gcm.Aes256Gcm; + const Hash = crypto.hash.sha2.Sha384; + const Hmac = crypto.auth.hmac.Hmac(Hash); + const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + const hello_hash = helloHash(client_hello_bytes1, host, frag, Hash); + const early_secret = Hkdf.extract(&[1]u8{0}, &([1]u8{0} ** Hash.digest_length)); + const empty_hash = emptyHash(Hash); + const derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length); + const handshake_secret = Hkdf.extract(&derived_secret, &shared_key); + const client_secret = hkdfExpandLabel(Hkdf, handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length); + const server_secret = hkdfExpandLabel(Hkdf, handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length); + client_handshake_key = hkdfExpandLabel(Hkdf, client_secret, "key", "", AEAD.key_length); + server_handshake_key = hkdfExpandLabel(Hkdf, server_secret, "key", "", AEAD.key_length); + client_handshake_iv = hkdfExpandLabel(Hkdf, client_secret, "iv", "", AEAD.nonce_length); + server_handshake_iv = hkdfExpandLabel(Hkdf, server_secret, "iv", "", AEAD.nonce_length); + }, + .TLS_CHACHA20_POLY1305_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_8_SHA256 => { + @panic("TODO"); + }, + } }, else => return error.TlsUnexpectedMessage, } break :i end; }; + var read_seq: u64 = 0; + while (true) { const end_hdr = i + 5; if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow; @@ -467,7 +540,88 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { 
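The key-schedule code above derives every secret with hkdfExpandLabel, which is HKDF-Expand-Label from RFC 8446 section 7.1: the HKDF "info" input is a big-endian u16 output length, a length-prefixed label carrying the "tls13 " prefix, and a length-prefixed context. A minimal standalone sketch of that layout, built on the same std.crypto pieces the patch uses (the secret below is a placeholder, not a value from the handshake):

    const std = @import("std");

    test "HkdfLabel layout behind hkdfExpandLabel" {
        const Hash = std.crypto.hash.sha2.Sha256;
        const Hmac = std.crypto.auth.hmac.Hmac(Hash);
        const Hkdf = std.crypto.kdf.hkdf.Hkdf(Hmac);

        // Placeholder PRK standing in for e.g. a handshake traffic secret.
        const example_secret = [_]u8{0x42} ** Hkdf.prk_length;

        // Info bytes for hkdfExpandLabel(Hkdf, example_secret, "key", "", 16):
        const info = [_]u8{ 0, 16 } ++ // u16 output length, big endian
            [_]u8{"tls13 key".len} ++ "tls13 key".* ++ // label<7..255>
            [_]u8{0}; // empty context<0..255>

        var out: [16]u8 = undefined;
        Hkdf.expand(&out, &info, example_secret);
    }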
if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage; }, .application_data => { - std.debug.print("TODO: decrypt these {d} bytes\n", .{record_size}); + var cleartext_buf: [1000]u8 = undefined; + const cleartext = switch (cipher_suite) { + .TLS_AES_128_GCM_SHA256 => c: { + const AEAD = crypto.aead.aes_gcm.Aes128Gcm; + const ciphertext_len = record_size - AEAD.tag_length; + const ciphertext = handshake_buf[i..][0..ciphertext_len]; + i += ciphertext.len; + if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow; + const cleartext = cleartext_buf[0..ciphertext.len]; + const auth_tag = handshake_buf[i..][0..AEAD.tag_length].*; + const V = @Vector(AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (AEAD.nonce_length - 8); + const operand: V = pad ++ @bitCast([8]u8, big(read_seq)); + read_seq += 1; + const nonce: [AEAD.nonce_length]u8 = @as(V, server_handshake_iv) ^ operand; + //std.debug.print("seq: {d} nonce: {} operand: {}\n", .{ + // read_seq - 1, + // std.fmt.fmtSliceHexLower(&nonce), + // std.fmt.fmtSliceHexLower(&@as([12]u8, operand)), + //}); + const ad = handshake_buf[end_hdr - 5 ..][0..5]; + const key = server_handshake_key[0..AEAD.key_length].*; + AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, key) catch + return error.TlsBadRecordMac; + + break :c cleartext; + }, + .TLS_AES_256_GCM_SHA384 => c: { + const AEAD = crypto.aead.aes_gcm.Aes256Gcm; + const ciphertext_len = record_size - AEAD.tag_length; + const ciphertext = handshake_buf[i..][0..ciphertext_len]; + i += ciphertext.len; + if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow; + const cleartext = cleartext_buf[0..ciphertext.len]; + const auth_tag = handshake_buf[i..][0..AEAD.tag_length].*; + const V = @Vector(AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (AEAD.nonce_length - 8); + const operand: V = pad ++ @bitCast([8]u8, big(read_seq)); + read_seq += 1; + const nonce: [AEAD.nonce_length]u8 = @as(V, server_handshake_iv) ^ operand; + const ad = handshake_buf[end_hdr - 5 ..][0..5]; + const key = server_handshake_key[0..AEAD.key_length].*; + AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, key) catch + return error.TlsBadRecordMac; + + break :c cleartext; + }, + .TLS_CHACHA20_POLY1305_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_8_SHA256 => { + @panic("TODO"); + }, + }; + + const inner_ct = cleartext[cleartext.len - 1]; + switch (inner_ct) { + @enumToInt(ContentType.handshake) => { + const handshake_len = mem.readIntBig(u24, cleartext[1..4]); + if (4 + handshake_len != cleartext.len - 1) return error.TlsBadLength; + switch (cleartext[0]) { + @enumToInt(HandshakeType.encrypted_extensions) => { + const ext_size = mem.readIntBig(u16, cleartext[4..6]); + if (ext_size != 0) { + @panic("TODO handle encrypted extensions"); + } + std.debug.print("empty encrypted extensions\n", .{}); + }, + else => { + std.debug.print("handshake type: {d}\n", .{cleartext[0]}); + return error.TlsUnexpectedMessage; + }, + } + }, + else => { + std.debug.print("inner content type: {d}\n", .{inner_ct}); + return error.TlsUnexpectedMessage; + }, + } }, else => { std.debug.print("content type: {s}\n", .{@tagName(ct)}); @@ -486,3 +640,56 @@ pub fn writeAll(tls: *Tls, stream: net.Stream, buffer: []const u8) !void { _ = buffer; @panic("hold on a minute, we didn't finish implementing the handshake yet"); } + +fn hkdfExpandLabel( + comptime Hkdf: type, + key: [Hkdf.prk_length]u8, + label: []const u8, + context: []const u8, + comptime len: 
usize, +) [len]u8 { + const max_label_len = 255; + const max_context_len = 255; + const tls13 = "tls13 "; + var buf: [2 + 1 + tls13.len + max_label_len + 1 + max_context_len]u8 = undefined; + mem.writeIntBig(u16, buf[0..2], len); + buf[2] = @intCast(u8, tls13.len + label.len); + buf[3..][0..tls13.len].* = tls13.*; + var i: usize = 3 + tls13.len; + mem.copy(u8, buf[i..], label); + i += label.len; + buf[i] = @intCast(u8, context.len); + i += 1; + mem.copy(u8, buf[i..], context); + i += context.len; + + var result: [len]u8 = undefined; + Hkdf.expand(&result, buf[0..i], key); + return result; +} + +fn emptyHash(comptime Hash: type) [Hash.digest_length]u8 { + var result: [Hash.digest_length]u8 = undefined; + Hash.hash(&.{}, &result, .{}); + return result; +} + +fn helloHash(s0: []const u8, s1: []const u8, s2: []const u8, comptime Hash: type) [Hash.digest_length]u8 { + var h = Hash.init(.{}); + h.update(s0); + h.update(s1); + h.update(s2); + var result: [Hash.digest_length]u8 = undefined; + h.final(&result); + return result; +} + +const builtin = @import("builtin"); +const native_endian = builtin.cpu.arch.endian(); + +inline fn big(x: anytype) @TypeOf(x) { + return switch (native_endian) { + .Big => x, + .Little => @byteSwap(x), + }; +} From 40a85506b2e6a97af9c06bdcd001b6fd84cc549a Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 15 Dec 2022 20:35:41 -0700 Subject: [PATCH 07/59] std.crypto.Tls: add read/write methods --- lib/std/crypto/Tls.zig | 560 ++++++++++++++++++++++++++++++---------- lib/std/crypto/sha2.zig | 22 ++ lib/std/http/Client.zig | 12 +- 3 files changed, 458 insertions(+), 136 deletions(-) diff --git a/lib/std/crypto/Tls.zig b/lib/std/crypto/Tls.zig index ea1bff9a08..6b5374512b 100644 --- a/lib/std/crypto/Tls.zig +++ b/lib/std/crypto/Tls.zig @@ -5,24 +5,24 @@ const mem = std.mem; const crypto = std.crypto; const assert = std.debug.assert; -state: State = .start, -x25519_priv_key: [32]u8 = undefined, -x25519_pub_key: [32]u8 = undefined, -x25519_server_pub_key: [32]u8 = undefined, +application_cipher: ApplicationCipher, +read_seq: u64, +write_seq: u64, +/// The size is enough to contain exactly one TLSCiphertext record. +partially_read_buffer: [max_ciphertext_len + ciphertext_record_header_len]u8, +/// The number of partially read bytes inside `partiall_read_buffer`. +partially_read_len: u15, -const ProtocolVersion = enum(u16) { +pub const ciphertext_record_header_len = 5; +pub const max_ciphertext_len = (1 << 14) + 256; + +pub const ProtocolVersion = enum(u16) { tls_1_2 = 0x0303, tls_1_3 = 0x0304, _, }; -const State = enum { - /// In this state, all fields are undefined except state. 
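For scale, the record constants above come straight out of RFC 8446. A protected record's ciphertext may not exceed 2^14 + 256 bytes, where the extra 256 bytes cover the appended inner content-type byte, optional padding, and AEAD expansion. The record header adds another ciphertext_record_header_len = 5 bytes, so 16384 + 256 + 5 = 16645 bytes is exactly enough to hold one complete TLSCiphertext record, which is how partially_read_buffer is sized.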
- start, - sent_hello, -}; - -const ContentType = enum(u8) { +pub const ContentType = enum(u8) { invalid = 0, change_cipher_spec = 20, alert = 21, @@ -31,7 +31,7 @@ const ContentType = enum(u8) { _, }; -const HandshakeType = enum(u8) { +pub const HandshakeType = enum(u8) { client_hello = 1, server_hello = 2, new_session_ticket = 4, @@ -45,7 +45,7 @@ const HandshakeType = enum(u8) { message_hash = 254, }; -const ExtensionType = enum(u16) { +pub const ExtensionType = enum(u16) { /// RFC 6066 server_name = 0, /// RFC 6066 @@ -92,13 +92,13 @@ const ExtensionType = enum(u16) { key_share = 51, }; -const AlertLevel = enum(u8) { +pub const AlertLevel = enum(u8) { warning = 1, fatal = 2, _, }; -const AlertDescription = enum(u8) { +pub const AlertDescription = enum(u8) { close_notify = 0, unexpected_message = 10, bad_record_mac = 20, @@ -129,7 +129,7 @@ const AlertDescription = enum(u8) { _, }; -const SignatureScheme = enum(u16) { +pub const SignatureScheme = enum(u16) { // RSASSA-PKCS1-v1_5 algorithms rsa_pkcs1_sha256 = 0x0401, rsa_pkcs1_sha384 = 0x0501, @@ -161,7 +161,7 @@ const SignatureScheme = enum(u16) { _, }; -const NamedGroup = enum(u16) { +pub const NamedGroup = enum(u16) { // Elliptic Curve Groups (ECDHE) secp256r1 = 0x0017, secp384r1 = 0x0018, @@ -211,7 +211,7 @@ const NamedGroup = enum(u16) { // * ExtensionType extension_type; // * opaque extension_data<0..2^16-1>; -const CipherSuite = enum(u16) { +pub const CipherSuite = enum(u16) { TLS_AES_128_GCM_SHA256 = 0x1301, TLS_AES_256_GCM_SHA384 = 0x1302, TLS_CHACHA20_POLY1305_SHA256 = 0x1303, @@ -219,6 +219,73 @@ const CipherSuite = enum(u16) { TLS_AES_128_CCM_8_SHA256 = 0x1305, }; +pub const CipherParams = union(CipherSuite) { + TLS_AES_128_GCM_SHA256: struct { + const AEAD = crypto.aead.aes_gcm.Aes128Gcm; + const Hash = crypto.hash.sha2.Sha256; + const Hmac = crypto.auth.hmac.Hmac(Hash); + const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + handshake_secret: [Hkdf.key_len]u8, + master_secret: [Hkdf.key_len]u8, + client_handshake_key: [AEAD.key_length]u8, + server_handshake_key: [AEAD.key_length]u8, + client_finished_key: [Hmac.key_length]u8, + server_finished_key: [Hmac.key_length]u8, + client_handshake_iv: [AEAD.nonce_length]u8, + server_handshake_iv: [AEAD.nonce_length]u8, + transcript_hash: Hash, + }, + TLS_AES_256_GCM_SHA384: struct { + const AEAD = crypto.aead.aes_gcm.Aes256Gcm; + const Hash = crypto.hash.sha2.Sha384; + const Hmac = crypto.auth.hmac.Hmac(Hash); + const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + handshake_secret: [Hkdf.key_len]u8, + master_secret: [Hkdf.key_len]u8, + client_handshake_key: [AEAD.key_length]u8, + server_handshake_key: [AEAD.key_length]u8, + client_finished_key: [Hmac.key_length]u8, + server_finished_key: [Hmac.key_length]u8, + client_handshake_iv: [AEAD.nonce_length]u8, + server_handshake_iv: [AEAD.nonce_length]u8, + transcript_hash: Hash, + }, + TLS_CHACHA20_POLY1305_SHA256: void, + TLS_AES_128_CCM_SHA256: void, + TLS_AES_128_CCM_8_SHA256: void, +}; + +/// Encryption parameters for application traffic. 
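Both CipherParams above and the ApplicationCipher union that follows rely on the same Zig pattern: each union variant is a struct whose declarations name the concrete AEAD/Hash/Hmac/Hkdf types, so an inline switch prong can recover those types at comptime through @TypeOf. A stripped-down sketch of that pattern (the Suite enum and field names here are invented for illustration, not part of the patch):

    const std = @import("std");

    const Suite = enum { aes_128_gcm, aes_256_gcm };

    const Params = union(Suite) {
        aes_128_gcm: struct {
            const AEAD = std.crypto.aead.aes_gcm.Aes128Gcm;
            key: [AEAD.key_length]u8,
        },
        aes_256_gcm: struct {
            const AEAD = std.crypto.aead.aes_gcm.Aes256Gcm;
            key: [AEAD.key_length]u8,
        },
    };

    fn tagLength(params: Params) usize {
        return switch (params) {
            // An inline prong is instantiated once per variant, so @TypeOf(p)
            // exposes that variant's AEAD declaration at comptime.
            inline .aes_128_gcm, .aes_256_gcm => |p| @TypeOf(p).AEAD.tag_length,
        };
    }

    test "tagLength" {
        const params = Params{ .aes_128_gcm = .{ .key = [_]u8{0} ** 16 } };
        try std.testing.expectEqual(@as(usize, 16), tagLength(params));
    }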
+pub const ApplicationCipher = union(CipherSuite) { + TLS_AES_128_GCM_SHA256: struct { + const AEAD = crypto.aead.aes_gcm.Aes128Gcm; + const Hash = crypto.hash.sha2.Sha256; + const Hmac = crypto.auth.hmac.Hmac(Hash); + const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + client_key: [AEAD.key_length]u8, + server_key: [AEAD.key_length]u8, + client_iv: [AEAD.nonce_length]u8, + server_iv: [AEAD.nonce_length]u8, + }, + TLS_AES_256_GCM_SHA384: struct { + const AEAD = crypto.aead.aes_gcm.Aes256Gcm; + const Hash = crypto.hash.sha2.Sha384; + const Hmac = crypto.auth.hmac.Hmac(Hash); + const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + client_key: [AEAD.key_length]u8, + server_key: [AEAD.key_length]u8, + client_iv: [AEAD.nonce_length]u8, + server_iv: [AEAD.nonce_length]u8, + }, + TLS_CHACHA20_POLY1305_SHA256: void, + TLS_AES_128_CCM_SHA256: void, + TLS_AES_128_CCM_8_SHA256: void, +}; + const cipher_suites = blk: { const fields = @typeInfo(CipherSuite).Enum.fields; var result: [(fields.len + 1) * 2]u8 = undefined; @@ -231,10 +298,11 @@ const cipher_suites = blk: { break :blk result; }; -pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { - assert(tls.state == .start); - crypto.random.bytes(&tls.x25519_priv_key); - tls.x25519_pub_key = crypto.dh.X25519.recoverPublicKey(tls.x25519_priv_key) catch |err| { +/// `host` is only borrowed during this function call. +pub fn init(stream: net.Stream, host: []const u8) !Tls { + var x25519_priv_key: [32]u8 = undefined; + crypto.random.bytes(&x25519_priv_key); + const x25519_pub_key = crypto.dh.X25519.recoverPublicKey(x25519_priv_key) catch |err| { switch (err) { // Only possible to happen if the private key is all zeroes. error.IdentityElement => return error.InsufficientEntropy, @@ -293,7 +361,7 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { 0, 36, // byte length of client_shares 0x00, 0x1D, // NamedGroup.x25519 0, 32, // byte length of key_exchange - } ++ tls.x25519_pub_key ++ [_]u8{ + } ++ x25519_pub_key ++ [_]u8{ // Extension: server_name 0, 0, // ExtensionType.server_name @@ -330,25 +398,23 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { mem.writeIntBig(u16, hello_header[hello_header.len - 5 ..][0..2], @intCast(u16, 3 + host.len)); mem.writeIntBig(u16, hello_header[hello_header.len - 2 ..][0..2], @intCast(u16, 0 + host.len)); - var iovecs = [_]std.os.iovec_const{ - .{ - .iov_base = &hello_header, - .iov_len = hello_header.len, - }, - .{ - .iov_base = host.ptr, - .iov_len = host.len, - }, - }; - try stream.writevAll(&iovecs); + { + var iovecs = [_]std.os.iovec_const{ + .{ + .iov_base = &hello_header, + .iov_len = hello_header.len, + }, + .{ + .iov_base = host.ptr, + .iov_len = host.len, + }, + }; + try stream.writevAll(&iovecs); + } const client_hello_bytes1 = hello_header[5..]; - var client_handshake_key: [32]u8 = undefined; - var server_handshake_key: [32]u8 = undefined; - var client_handshake_iv: [12]u8 = undefined; - var server_handshake_iv: [12]u8 = undefined; - var cipher_suite: CipherSuite = undefined; + var cipher_params: CipherParams = undefined; var handshake_buf: [4000]u8 = undefined; var len: usize = 0; @@ -386,16 +452,16 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { const legacy_session_id_echo_len = hello[34]; if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter; const cipher_suite_int = mem.readIntBig(u16, hello[35..37]); - cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch + const cipher_suite_tag = 
std.meta.intToEnum(CipherSuite, cipher_suite_int) catch return error.TlsIllegalParameter; - std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)}); + std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite_tag)}); const legacy_compression_method = hello[37]; _ = legacy_compression_method; const extensions_size = mem.readIntBig(u16, hello[38..40]); if (40 + extensions_size != hello.len) return error.TlsBadLength; var i: usize = 40; var supported_version: u16 = 0; - var have_server_pub_key = false; + var opt_x25519_server_pub_key: ?*[32]u8 = null; while (i < hello.len) { const et = mem.readIntBig(u16, hello[i..][0..2]); i += 2; @@ -409,7 +475,7 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { supported_version = mem.readIntBig(u16, hello[i..][0..2]); }, @enumToInt(ExtensionType.key_share) => { - if (have_server_pub_key) return error.TlsIllegalParameter; + if (opt_x25519_server_pub_key != null) return error.TlsIllegalParameter; const named_group = mem.readIntBig(u16, hello[i..][0..2]); i += 2; switch (named_group) { @@ -417,8 +483,7 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { const key_size = mem.readIntBig(u16, hello[i..][0..2]); i += 2; if (key_size != 32) return error.TlsBadLength; - tls.x25519_server_pub_key = hello[i..][0..32].*; - have_server_pub_key = true; + opt_x25519_server_pub_key = hello[i..][0..32]; }, else => { std.debug.print("named group: {x}\n", .{named_group}); @@ -432,7 +497,8 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { } i = next_i; } - if (!have_server_pub_key) return error.TlsIllegalParameter; + const x25519_server_pub_key = opt_x25519_server_pub_key orelse + return error.TlsIllegalParameter; const tls_version = if (supported_version == 0) legacy_version else supported_version; switch (tls_version) { @enumToInt(ProtocolVersion.tls_1_2) => { @@ -445,28 +511,44 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { } const shared_key = crypto.dh.X25519.scalarmult( - tls.x25519_priv_key, - tls.x25519_server_pub_key, + x25519_priv_key, + x25519_server_pub_key.*, ) catch return error.TlsDecryptFailure; - switch (cipher_suite) { - .TLS_AES_128_GCM_SHA256 => { - const AEAD = crypto.aead.aes_gcm.Aes128Gcm; - const Hash = crypto.hash.sha2.Sha256; - const Hmac = crypto.auth.hmac.Hmac(Hash); - const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); - - const hello_hash = helloHash(client_hello_bytes1, host, frag, Hash); - const early_secret = Hkdf.extract(&[1]u8{0}, &([1]u8{0} ** Hash.digest_length)); - const empty_hash = emptyHash(Hash); - const derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length); - const handshake_secret = Hkdf.extract(&derived_secret, &shared_key); - const client_secret = hkdfExpandLabel(Hkdf, handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length); - const server_secret = hkdfExpandLabel(Hkdf, handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length); - client_handshake_key[0..AEAD.key_length].* = hkdfExpandLabel(Hkdf, client_secret, "key", "", AEAD.key_length); - server_handshake_key[0..AEAD.key_length].* = hkdfExpandLabel(Hkdf, server_secret, "key", "", AEAD.key_length); - client_handshake_iv = hkdfExpandLabel(Hkdf, client_secret, "iv", "", AEAD.nonce_length); - server_handshake_iv = hkdfExpandLabel(Hkdf, server_secret, "iv", "", AEAD.nonce_length); + switch (cipher_suite_tag) { + inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |tag| { + const P = 
std.meta.TagPayload(CipherParams, tag); + cipher_params = @unionInit(CipherParams, @tagName(tag), .{ + .handshake_secret = undefined, + .master_secret = undefined, + .client_handshake_key = undefined, + .server_handshake_key = undefined, + .client_finished_key = undefined, + .server_finished_key = undefined, + .client_handshake_iv = undefined, + .server_handshake_iv = undefined, + .transcript_hash = P.Hash.init(.{}), + }); + const p = &@field(cipher_params, @tagName(tag)); + p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1 + p.transcript_hash.update(host); // Client Hello part 2 + p.transcript_hash.update(frag); // Server Hello + const hello_hash = p.transcript_hash.peek(); + const zeroes = [1]u8{0} ** P.Hash.digest_length; + const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); + const empty_hash = emptyHash(P.Hash); + const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length); + p.handshake_secret = P.Hkdf.extract(&hs_derived_secret, &shared_key); + const ap_derived_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "derived", &empty_hash, P.Hash.digest_length); + p.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); + const client_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length); + const server_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length); + p.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length); + p.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length); + p.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); + p.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); + p.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); + p.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); //std.debug.print("shared_key: {}\nhello_hash: {}\nearly_secret: {}\nempty_hash: {}\nderived_secret: {}\nhandshake_secret: {}\n client_secret: {}\n server_secret: {}\n", .{ // std.fmt.fmtSliceHexLower(&shared_key), // std.fmt.fmtSliceHexLower(&hello_hash), @@ -478,24 +560,6 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { // std.fmt.fmtSliceHexLower(&server_secret), //}); }, - .TLS_AES_256_GCM_SHA384 => { - const AEAD = crypto.aead.aes_gcm.Aes256Gcm; - const Hash = crypto.hash.sha2.Sha384; - const Hmac = crypto.auth.hmac.Hmac(Hash); - const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); - - const hello_hash = helloHash(client_hello_bytes1, host, frag, Hash); - const early_secret = Hkdf.extract(&[1]u8{0}, &([1]u8{0} ** Hash.digest_length)); - const empty_hash = emptyHash(Hash); - const derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length); - const handshake_secret = Hkdf.extract(&derived_secret, &shared_key); - const client_secret = hkdfExpandLabel(Hkdf, handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length); - const server_secret = hkdfExpandLabel(Hkdf, handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length); - client_handshake_key = hkdfExpandLabel(Hkdf, client_secret, "key", "", AEAD.key_length); - server_handshake_key = hkdfExpandLabel(Hkdf, server_secret, "key", "", AEAD.key_length); - client_handshake_iv = hkdfExpandLabel(Hkdf, client_secret, "iv", "", AEAD.nonce_length); - server_handshake_iv = 
hkdfExpandLabel(Hkdf, server_secret, "iv", "", AEAD.nonce_length); - }, .TLS_CHACHA20_POLY1305_SHA256 => { @panic("TODO"); }, @@ -541,50 +605,24 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { }, .application_data => { var cleartext_buf: [1000]u8 = undefined; - const cleartext = switch (cipher_suite) { - .TLS_AES_128_GCM_SHA256 => c: { - const AEAD = crypto.aead.aes_gcm.Aes128Gcm; - const ciphertext_len = record_size - AEAD.tag_length; + const cleartext = switch (cipher_params) { + inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: { + const P = @TypeOf(p.*); + const ciphertext_len = record_size - P.AEAD.tag_length; const ciphertext = handshake_buf[i..][0..ciphertext_len]; i += ciphertext.len; if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow; const cleartext = cleartext_buf[0..ciphertext.len]; - const auth_tag = handshake_buf[i..][0..AEAD.tag_length].*; - const V = @Vector(AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (AEAD.nonce_length - 8); + const auth_tag = handshake_buf[i..][0..P.AEAD.tag_length].*; + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); const operand: V = pad ++ @bitCast([8]u8, big(read_seq)); read_seq += 1; - const nonce: [AEAD.nonce_length]u8 = @as(V, server_handshake_iv) ^ operand; - //std.debug.print("seq: {d} nonce: {} operand: {}\n", .{ - // read_seq - 1, - // std.fmt.fmtSliceHexLower(&nonce), - // std.fmt.fmtSliceHexLower(&@as([12]u8, operand)), - //}); + const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_handshake_iv) ^ operand; const ad = handshake_buf[end_hdr - 5 ..][0..5]; - const key = server_handshake_key[0..AEAD.key_length].*; - AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, key) catch + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_handshake_key) catch return error.TlsBadRecordMac; - - break :c cleartext; - }, - .TLS_AES_256_GCM_SHA384 => c: { - const AEAD = crypto.aead.aes_gcm.Aes256Gcm; - const ciphertext_len = record_size - AEAD.tag_length; - const ciphertext = handshake_buf[i..][0..ciphertext_len]; - i += ciphertext.len; - if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow; - const cleartext = cleartext_buf[0..ciphertext.len]; - const auth_tag = handshake_buf[i..][0..AEAD.tag_length].*; - const V = @Vector(AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (AEAD.nonce_length - 8); - const operand: V = pad ++ @bitCast([8]u8, big(read_seq)); - read_seq += 1; - const nonce: [AEAD.nonce_length]u8 = @as(V, server_handshake_iv) ^ operand; - const ad = handshake_buf[end_hdr - 5 ..][0..5]; - const key = server_handshake_key[0..AEAD.key_length].*; - AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, key) catch - return error.TlsBadRecordMac; - + p.transcript_hash.update(cleartext[0 .. cleartext.len - 1]); break :c cleartext; }, .TLS_CHACHA20_POLY1305_SHA256 => { @@ -611,6 +649,86 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { } std.debug.print("empty encrypted extensions\n", .{}); }, + @enumToInt(HandshakeType.certificate) => { + std.debug.print("cool certificate bro\n", .{}); + }, + @enumToInt(HandshakeType.certificate_verify) => { + std.debug.print("the certificate came with a fancy signature\n", .{}); + }, + @enumToInt(HandshakeType.finished) => { + // This message is to trick buggy proxies into behaving correctly. 
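The decryption code above forms each record nonce exactly as RFC 8446 section 5.3 prescribes: the 64-bit record sequence number is left-padded to the IV length and XOR-ed with the static traffic IV, which the patch expresses as a @Vector XOR. A standalone sketch of that computation (the IV value and sequence number are placeholders):

    const std = @import("std");

    test "per-record nonce is the static IV xor the padded sequence number" {
        const nonce_length = 12; // AES-GCM nonce size used above
        const iv = [_]u8{0xAB} ** nonce_length; // placeholder traffic IV
        const seq: u64 = 1; // sequence number of the second record

        var seq_bytes: [8]u8 = undefined;
        std.mem.writeIntBig(u64, &seq_bytes, seq);

        const V = @Vector(nonce_length, u8);
        const pad = [1]u8{0} ** (nonce_length - 8);
        const operand: V = pad ++ seq_bytes;
        const nonce: [nonce_length]u8 = @as(V, iv) ^ operand;

        // Only the low-order byte changes for seq == 1.
        try std.testing.expectEqualSlices(u8, iv[0..11], nonce[0..11]);
        try std.testing.expectEqual(@as(u8, 0xAB ^ 1), nonce[11]);
    }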
+ const client_change_cipher_spec_msg = [_]u8{ + @enumToInt(ContentType.change_cipher_spec), + 0x03, 0x03, // legacy protocol version + 0x00, 0x01, // length + 0x01, + }; + const app_cipher = switch (cipher_params) { + inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p, tag| c: { + const P = @TypeOf(p.*); + // TODO verify the server's data + const handshake_hash = p.transcript_hash.finalResult(); + const verify_data = hmac(P.Hmac, &handshake_hash, p.client_finished_key); + const out_cleartext = [_]u8{ + @enumToInt(HandshakeType.finished), + 0, 0, verify_data.len + 1 + P.AEAD.tag_length, // length + } ++ verify_data ++ [1]u8{@enumToInt(ContentType.handshake)}; + + const wrapped_len = out_cleartext.len + P.AEAD.tag_length; + + var finished_msg = [_]u8{ + @enumToInt(ContentType.application_data), + 0x03, 0x03, // legacy protocol version + 0, wrapped_len, // byte length of encrypted record + } ++ ([1]u8{undefined} ** wrapped_len); + + const ad = finished_msg[0..5]; + const ciphertext = finished_msg[5..][0..out_cleartext.len]; + const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..]; + const nonce = p.client_handshake_iv; + P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); + + { + var iovecs = [_]std.os.iovec_const{ + .{ + .iov_base = &client_change_cipher_spec_msg, + .iov_len = client_change_cipher_spec_msg.len, + }, + .{ + .iov_base = &finished_msg, + .iov_len = finished_msg.len, + }, + }; + try stream.writevAll(&iovecs); + } + + const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); + const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); + break :c @unionInit(ApplicationCipher, @tagName(tag), .{ + .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), + .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), + .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length), + .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length), + }); + }, + .TLS_CHACHA20_POLY1305_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_8_SHA256 => { + @panic("TODO"); + }, + }; + return .{ + .application_cipher = app_cipher, + .read_seq = read_seq, + .write_seq = 1, + .partially_read_buffer = undefined, + .partially_read_len = 0, + }; + }, else => { std.debug.print("handshake type: {d}\n", .{cleartext[0]}); return error.TlsUnexpectedMessage; @@ -631,14 +749,185 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { i = end; } - tls.state = .sent_hello; + return error.TlsHandshakeFailure; } -pub fn writeAll(tls: *Tls, stream: net.Stream, buffer: []const u8) !void { - _ = tls; - _ = stream; - _ = buffer; - @panic("hold on a minute, we didn't finish implementing the handshake yet"); +pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize { + var ciphertext_buf: [max_ciphertext_len * 4]u8 = undefined; + var iovecs_buf: [5]std.os.iovec_const = undefined; + var ciphertext_end: usize = 0; + var iovec_end: usize = 0; + var bytes_i: usize = 0; + switch (tls.application_cipher) { + inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| { + const P = @TypeOf(p.*); + const V = @Vector(P.AEAD.nonce_length, u8); + while (true) { + const ciphertext_len = @intCast(u16, @min( + @min(bytes.len - bytes_i, max_ciphertext_len), + 
ciphertext_buf.len - 5 - P.AEAD.tag_length - ciphertext_end, + )); + if (ciphertext_len == 0) return bytes_i; + + const wrapped_len = ciphertext_len + P.AEAD.tag_length; + const record = ciphertext_buf[ciphertext_end..][0 .. 5 + wrapped_len]; + + const ad = record[0..5]; + ciphertext_end += 5; + const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len]; + ciphertext_end += ciphertext_len; + const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length]; + ciphertext_end += P.AEAD.tag_length; + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ @bitCast([8]u8, big(tls.write_seq)); + tls.write_seq += 1; + const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.client_iv) ^ operand; + ad.* = + [_]u8{@enumToInt(ContentType.application_data)} ++ + int2(@enumToInt(ProtocolVersion.tls_1_2)) ++ + int2(wrapped_len); + const cleartext = bytes[bytes_i..ciphertext.len]; + P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key); + + iovecs_buf[iovec_end] = .{ + .iov_base = record.ptr, + .iov_len = record.len, + }; + iovec_end += 1; + + bytes_i += ciphertext_len; + } + }, + .TLS_CHACHA20_POLY1305_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_8_SHA256 => { + @panic("TODO"); + }, + } + + // Ideally we would call writev exactly once here, however, we must ensure + // that we don't return with a record partially written. + var i: usize = 0; + var total_amt: usize = 0; + while (true) { + var amt = try stream.writev(iovecs_buf[i..iovec_end]); + total_amt += amt; + while (amt >= iovecs_buf[i].iov_len) { + amt -= iovecs_buf[i].iov_len; + i += 1; + // Rely on the property that iovecs delineate records, meaning that + // if amt equals zero here, we have fortunately found ourselves + // with a short read that aligns at the record boundary. + if (i >= iovec_end or amt == 0) return total_amt; + } + iovecs_buf[i].iov_base += amt; + iovecs_buf[i].iov_len -= amt; + } +} + +pub fn writeAll(tls: *Tls, stream: net.Stream, bytes: []const u8) !void { + var index: usize = 0; + while (index < bytes.len) { + index += try tls.write(stream, bytes[index..]); + } +} + +/// Returns number of bytes that have been read, which are now populated inside +/// `buffer`. A return value of zero bytes does not necessarily mean end of +/// stream. +pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { + const prev_len = tls.partially_read_len; + var in_buf: [max_ciphertext_len * 4]u8 = undefined; + mem.copy(u8, &in_buf, tls.partially_read_buffer[0..prev_len]); + + // Capacity of output buffer, in records, rounded up. + const buf_cap = (buffer.len +| (max_ciphertext_len - 1)) / max_ciphertext_len; + const wanted_read_len = buf_cap * (max_ciphertext_len + ciphertext_record_header_len); + const actual_read_len = try stream.read(in_buf[prev_len..@min(wanted_read_len, in_buf.len)]); + const frag = in_buf[0 .. 
prev_len + actual_read_len]; + var in: usize = 0; + var out: usize = 0; + + while (true) { + if (in + ciphertext_record_header_len > frag.len) { + return finishRead(tls, frag, in, out); + } + const ct = @intToEnum(ContentType, frag[in]); + in += 1; + const legacy_version = mem.readIntBig(u16, frag[in..][0..2]); + in += 2; + _ = legacy_version; + const record_size = mem.readIntBig(u16, frag[in..][0..2]); + in += 2; + const end = in + record_size; + if (end > frag.len) { + if (record_size > max_ciphertext_len) return error.TlsRecordOverflow; + return finishRead(tls, frag, in, out); + } + switch (ct) { + .alert => { + @panic("TODO handle an alert here"); + }, + .application_data => { + const cleartext_len = switch (tls.application_cipher) { + inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: { + const P = @TypeOf(p.*); + const V = @Vector(P.AEAD.nonce_length, u8); + const ciphertext_len = record_size - P.AEAD.tag_length; + const ciphertext = frag[in..][0..ciphertext_len]; + in += ciphertext_len; + const auth_tag = frag[in..][0..P.AEAD.tag_length].*; + const cleartext = buffer[out..][0..ciphertext_len]; + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ @bitCast([8]u8, big(tls.read_seq)); + tls.read_seq += 1; + const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand; + const ad = frag[0..ciphertext_record_header_len]; + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch + return error.TlsBadRecordMac; + break :c cleartext.len; + }, + .TLS_CHACHA20_POLY1305_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_8_SHA256 => { + @panic("TODO"); + }, + }; + + const inner_ct = buffer[out + cleartext_len - 1]; + switch (inner_ct) { + @enumToInt(ContentType.handshake) => { + std.debug.print("the server wants to keep shaking hands\n", .{}); + }, + @enumToInt(ContentType.application_data) => { + out += cleartext_len - 1; + }, + else => { + return error.TlsUnexpectedMessage; + }, + } + }, + else => { + return error.TlsUnexpectedMessage; + }, + } + in = end; + } +} + +fn finishRead(tls: *Tls, frag: []const u8, in: usize, out: usize) usize { + const saved_buf = frag[in..]; + mem.copy(u8, &tls.partially_read_buffer, saved_buf); + tls.partially_read_len = @intCast(u15, saved_buf.len); + return out; } fn hkdfExpandLabel( @@ -674,13 +963,9 @@ fn emptyHash(comptime Hash: type) [Hash.digest_length]u8 { return result; } -fn helloHash(s0: []const u8, s1: []const u8, s2: []const u8, comptime Hash: type) [Hash.digest_length]u8 { - var h = Hash.init(.{}); - h.update(s0); - h.update(s1); - h.update(s2); - var result: [Hash.digest_length]u8 = undefined; - h.final(&result); +fn hmac(comptime Hmac: type, message: []const u8, key: [Hmac.key_length]u8) [Hmac.mac_length]u8 { + var result: [Hmac.mac_length]u8 = undefined; + Hmac.create(&result, message, &key); return result; } @@ -693,3 +978,10 @@ inline fn big(x: anytype) @TypeOf(x) { .Little => @byteSwap(x), }; } + +inline fn int2(x: u16) [2]u8 { + return .{ + @truncate(u8, x >> 8), + @truncate(u8, x), + }; +} diff --git a/lib/std/crypto/sha2.zig b/lib/std/crypto/sha2.zig index 9cdf8edcf1..217dea3723 100644 --- a/lib/std/crypto/sha2.zig +++ b/lib/std/crypto/sha2.zig @@ -142,6 +142,11 @@ fn Sha2x32(comptime params: Sha2Params32) type { d.total_len += b.len; } + pub fn peek(d: Self) [digest_length]u8 { + var copy = d; + return copy.finalResult(); + } + pub fn final(d: *Self, out: *[digest_length]u8) void { // The buffer here will never 
be completely full. mem.set(u8, d.buf[d.buf_len..], 0); @@ -175,6 +180,12 @@ fn Sha2x32(comptime params: Sha2Params32) type { } } + pub fn finalResult(d: *Self) [digest_length]u8 { + var result: [digest_length]u8 = undefined; + d.final(&result); + return result; + } + const W = [64]u32{ 0x428A2F98, 0x71374491, 0xB5C0FBCF, 0xE9B5DBA5, 0x3956C25B, 0x59F111F1, 0x923F82A4, 0xAB1C5ED5, 0xD807AA98, 0x12835B01, 0x243185BE, 0x550C7DC3, 0x72BE5D74, 0x80DEB1FE, 0x9BDC06A7, 0xC19BF174, @@ -621,6 +632,11 @@ fn Sha2x64(comptime params: Sha2Params64) type { d.total_len += b.len; } + pub fn peek(d: Self) [digest_length]u8 { + var copy = d; + return copy.finalResult(); + } + pub fn final(d: *Self, out: *[digest_length]u8) void { // The buffer here will never be completely full. mem.set(u8, d.buf[d.buf_len..], 0); @@ -654,6 +670,12 @@ fn Sha2x64(comptime params: Sha2Params64) type { } } + pub fn finalResult(d: *Self) [digest_length]u8 { + var result: [digest_length]u8 = undefined; + d.final(&result); + return result; + } + fn round(d: *Self, b: *const [128]u8) void { var s: [80]u64 = undefined; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index b10011a6b1..e7b056830a 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -12,7 +12,7 @@ pub const Request = struct { client: *Client, stream: net.Stream, headers: std.ArrayListUnmanaged(u8) = .{}, - tls: std.crypto.Tls = .{}, + tls: std.crypto.Tls, protocol: Protocol, pub const Protocol = enum { http, https }; @@ -55,6 +55,13 @@ pub const Request = struct { }, } } + + pub fn read(req: *Request, buffer: []u8) !usize { + switch (req.protocol) { + .http => return req.stream.read(buffer), + .https => return req.tls.read(req.stream, buffer), + } + } }; pub fn deinit(client: *Client) void { @@ -68,6 +75,7 @@ pub fn request(client: *Client, options: Request.Options) !Request { .client = client, .stream = try net.tcpConnectToHost(client.allocator, options.host, options.port), .protocol = options.protocol, + .tls = undefined, }; client.active_requests += 1; errdefer req.deinit(); @@ -75,7 +83,7 @@ pub fn request(client: *Client, options: Request.Options) !Request { switch (options.protocol) { .http => {}, .https => { - try req.tls.init(req.stream, options.host); + req.tls = try std.crypto.Tls.init(req.stream, options.host); }, } From b97fc43baac9498799f3520ee860e5026e8dcb53 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 16 Dec 2022 02:14:35 -0700 Subject: [PATCH 08/59] std.crypto.Tls: client is working against some servers --- lib/std/crypto/Tls.zig | 151 ++++++++++++++++++++++++++-------------- lib/std/http/Client.zig | 23 +++++- 2 files changed, 121 insertions(+), 53 deletions(-) diff --git a/lib/std/crypto/Tls.zig b/lib/std/crypto/Tls.zig index 6b5374512b..19bd3442cf 100644 --- a/lib/std/crypto/Tls.zig +++ b/lib/std/crypto/Tls.zig @@ -9,12 +9,14 @@ application_cipher: ApplicationCipher, read_seq: u64, write_seq: u64, /// The size is enough to contain exactly one TLSCiphertext record. -partially_read_buffer: [max_ciphertext_len + ciphertext_record_header_len]u8, +partially_read_buffer: [max_ciphertext_record_len]u8, /// The number of partially read bytes inside `partiall_read_buffer`. 
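The peek() helper added to sha2.zig above is what lets the handshake keep a single running transcript hash: it copies the hash state and finalizes the copy, so the original digest can keep absorbing later handshake messages. A small usage sketch assuming the patched std (the update inputs are placeholders, not real handshake bytes):

    const std = @import("std");

    test "peek leaves the running hash usable" {
        const Sha256 = std.crypto.hash.sha2.Sha256;

        var transcript = Sha256.init(.{});
        transcript.update("client hello");
        transcript.update("server hello");

        // Digest of everything absorbed so far, without ending the stream.
        const so_far = transcript.peek();

        transcript.update("encrypted extensions");
        const whole = transcript.finalResult();

        try std.testing.expect(!std.mem.eql(u8, &so_far, &whole));
    }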
partially_read_len: u15, +eof: bool, pub const ciphertext_record_header_len = 5; pub const max_ciphertext_len = (1 << 14) + 256; +pub const max_ciphertext_record_len = max_ciphertext_len + ciphertext_record_header_len; pub const ProtocolVersion = enum(u16) { tls_1_2 = 0x0303, @@ -416,7 +418,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { var cipher_params: CipherParams = undefined; - var handshake_buf: [4000]u8 = undefined; + var handshake_buf: [8000]u8 = undefined; var len: usize = 0; var i: usize = i: { const plaintext = handshake_buf[0..5]; @@ -554,8 +556,8 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { // std.fmt.fmtSliceHexLower(&hello_hash), // std.fmt.fmtSliceHexLower(&early_secret), // std.fmt.fmtSliceHexLower(&empty_hash), - // std.fmt.fmtSliceHexLower(&derived_secret), - // std.fmt.fmtSliceHexLower(&handshake_secret), + // std.fmt.fmtSliceHexLower(&hs_derived_secret), + // std.fmt.fmtSliceHexLower(&p.handshake_secret), // std.fmt.fmtSliceHexLower(&client_secret), // std.fmt.fmtSliceHexLower(&server_secret), //}); @@ -582,7 +584,9 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { const end_hdr = i + 5; if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow; if (end_hdr > len) { + std.debug.print("read len={d} atleast={d}\n", .{ len, end_hdr - len }); len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len); + std.debug.print("new len: {d} bytes\n", .{len}); if (end_hdr > len) return error.EndOfStream; } const ct = @intToEnum(ContentType, handshake_buf[i]); @@ -593,9 +597,12 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { const record_size = mem.readIntBig(u16, handshake_buf[i..][0..2]); i += 2; const end = i + record_size; + std.debug.print("ct={any} record_size={d} end={d}\n", .{ ct, record_size, end }); if (end > handshake_buf.len) return error.TlsRecordOverflow; if (end > len) { + std.debug.print("read len={d} atleast={d}\n", .{ len, end - len }); len += try stream.readAtLeast(handshake_buf[len..], end - len); + std.debug.print("new len: {d} bytes\n", .{len}); if (end > len) return error.EndOfStream; } switch (ct) { @@ -604,7 +611,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage; }, .application_data => { - var cleartext_buf: [1000]u8 = undefined; + var cleartext_buf: [8000]u8 = undefined; const cleartext = switch (cipher_params) { inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: { const P = @TypeOf(p.*); @@ -637,17 +644,18 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { }; const inner_ct = cleartext[cleartext.len - 1]; + std.debug.print("inner_ct={any}\n", .{@intToEnum(ContentType, inner_ct)}); switch (inner_ct) { @enumToInt(ContentType.handshake) => { const handshake_len = mem.readIntBig(u24, cleartext[1..4]); - if (4 + handshake_len != cleartext.len - 1) return error.TlsBadLength; + if (4 + handshake_len > cleartext.len - 1) return error.TlsBadLength; + std.debug.print("handshake type: {any} size: {d}\n", .{ @intToEnum(HandshakeType, cleartext[0]), handshake_len }); switch (cleartext[0]) { @enumToInt(HandshakeType.encrypted_extensions) => { const ext_size = mem.readIntBig(u16, cleartext[4..6]); - if (ext_size != 0) { - @panic("TODO handle encrypted extensions"); - } - std.debug.print("empty encrypted extensions\n", .{}); + std.debug.print("{d} bytes of encrypted extensions\n", .{ + ext_size, + }); }, @enumToInt(HandshakeType.certificate) => { std.debug.print("cool certificate bro\n", .{}); @@ -688,22 
+696,18 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { const nonce = p.client_handshake_iv; P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); - { - var iovecs = [_]std.os.iovec_const{ - .{ - .iov_base = &client_change_cipher_spec_msg, - .iov_len = client_change_cipher_spec_msg.len, - }, - .{ - .iov_base = &finished_msg, - .iov_len = finished_msg.len, - }, - }; - try stream.writevAll(&iovecs); - } + //const both_msgs = client_change_cipher_spec_msg ++ finished_msg; + _ = client_change_cipher_spec_msg; + const both_msgs = finished_msg; + try stream.writeAll(&both_msgs); const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); + //std.debug.print("master_secret={}\nclient_secret={}\nserver_secret={}\n", .{ + // std.fmt.fmtSliceHexLower(&p.master_secret), + // std.fmt.fmtSliceHexLower(&client_secret), + // std.fmt.fmtSliceHexLower(&server_secret), + //}); break :c @unionInit(ApplicationCipher, @tagName(tag), .{ .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), @@ -721,12 +725,14 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { @panic("TODO"); }, }; + std.debug.print("remaining bytes: {d}\n", .{len - end}); return .{ .application_cipher = app_cipher, - .read_seq = read_seq, - .write_seq = 1, + .read_seq = 0, + .write_seq = 0, .partially_read_buffer = undefined, .partially_read_len = 0, + .eof = false, }; }, else => { @@ -753,49 +759,67 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { } pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize { - var ciphertext_buf: [max_ciphertext_len * 4]u8 = undefined; + var ciphertext_buf: [max_ciphertext_record_len * 4]u8 = undefined; + // Due to the trailing inner content type byte in the ciphertext, we need + // an additional buffer for storing the cleartext into before encrypting. + var cleartext_buf: [max_ciphertext_len]u8 = undefined; var iovecs_buf: [5]std.os.iovec_const = undefined; var ciphertext_end: usize = 0; var iovec_end: usize = 0; var bytes_i: usize = 0; - switch (tls.application_cipher) { - inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| { + // How many bytes are taken up by overhead per record. + const overhead_len: usize = switch (tls.application_cipher) { + inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| l: { const P = @TypeOf(p.*); const V = @Vector(P.AEAD.nonce_length, u8); + const overhead_len = ciphertext_record_header_len + P.AEAD.tag_length + 1; while (true) { - const ciphertext_len = @intCast(u16, @min( - @min(bytes.len - bytes_i, max_ciphertext_len), - ciphertext_buf.len - 5 - P.AEAD.tag_length - ciphertext_end, + const encrypted_content_len = @intCast(u16, @min( + @min(bytes.len - bytes_i, max_ciphertext_len - 1), + ciphertext_buf.len - + ciphertext_record_header_len - P.AEAD.tag_length - ciphertext_end - 1, )); - if (ciphertext_len == 0) return bytes_i; + if (encrypted_content_len == 0) break :l overhead_len; - const wrapped_len = ciphertext_len + P.AEAD.tag_length; - const record = ciphertext_buf[ciphertext_end..][0 .. 
5 + wrapped_len]; + mem.copy(u8, &cleartext_buf, bytes[bytes_i..][0..encrypted_content_len]); + cleartext_buf[encrypted_content_len] = @enumToInt(ContentType.application_data); + bytes_i += encrypted_content_len; + const ciphertext_len = encrypted_content_len + 1; + const cleartext = cleartext_buf[0..ciphertext_len]; - const ad = record[0..5]; - ciphertext_end += 5; + const record_start = ciphertext_end; + const ad = ciphertext_buf[ciphertext_end..][0..5]; + ad.* = + [_]u8{@enumToInt(ContentType.application_data)} ++ + int2(@enumToInt(ProtocolVersion.tls_1_2)) ++ + int2(ciphertext_len + P.AEAD.tag_length); + ciphertext_end += ad.len; const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len]; ciphertext_end += ciphertext_len; const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length]; - ciphertext_end += P.AEAD.tag_length; + ciphertext_end += auth_tag.len; const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); const operand: V = pad ++ @bitCast([8]u8, big(tls.write_seq)); tls.write_seq += 1; const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.client_iv) ^ operand; - ad.* = - [_]u8{@enumToInt(ContentType.application_data)} ++ - int2(@enumToInt(ProtocolVersion.tls_1_2)) ++ - int2(wrapped_len); - const cleartext = bytes[bytes_i..ciphertext.len]; P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key); + //std.debug.print("seq: {d} nonce: {} client_key: {} client_iv: {} ad: {} auth_tag: {}\nserver_key: {} server_iv: {}", .{ + // tls.write_seq - 1, + // std.fmt.fmtSliceHexLower(&nonce), + // std.fmt.fmtSliceHexLower(&p.client_key), + // std.fmt.fmtSliceHexLower(&p.client_iv), + // std.fmt.fmtSliceHexLower(ad), + // std.fmt.fmtSliceHexLower(auth_tag), + // std.fmt.fmtSliceHexLower(&p.server_key), + // std.fmt.fmtSliceHexLower(&p.server_iv), + //}); + const record = ciphertext_buf[record_start..ciphertext_end]; iovecs_buf[iovec_end] = .{ .iov_base = record.ptr, .iov_len = record.len, }; iovec_end += 1; - - bytes_i += ciphertext_len; } }, .TLS_CHACHA20_POLY1305_SHA256 => { @@ -807,7 +831,7 @@ pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize { .TLS_AES_128_CCM_8_SHA256 => { @panic("TODO"); }, - } + }; // Ideally we would call writev exactly once here, however, we must ensure // that we don't return with a record partially written. @@ -815,9 +839,10 @@ pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize { var total_amt: usize = 0; while (true) { var amt = try stream.writev(iovecs_buf[i..iovec_end]); - total_amt += amt; while (amt >= iovecs_buf[i].iov_len) { - amt -= iovecs_buf[i].iov_len; + const encrypted_amt = iovecs_buf[i].iov_len; + total_amt += encrypted_amt - overhead_len; + amt -= encrypted_amt; i += 1; // Rely on the property that iovecs delineate records, meaning that // if amt equals zero here, we have fortunately found ourselves @@ -849,11 +874,17 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { const wanted_read_len = buf_cap * (max_ciphertext_len + ciphertext_record_header_len); const actual_read_len = try stream.read(in_buf[prev_len..@min(wanted_read_len, in_buf.len)]); const frag = in_buf[0 .. 
prev_len + actual_read_len]; + if (frag.len == 0) { + tls.eof = true; + return 0; + } + std.debug.print("actual_read_len={d} frag.len={d}\n", .{ actual_read_len, frag.len }); var in: usize = 0; var out: usize = 0; while (true) { if (in + ciphertext_record_header_len > frag.len) { + std.debug.print("in={d} frag.len={d}\n", .{ in, frag.len }); return finishRead(tls, frag, in, out); } const ct = @intToEnum(ContentType, frag[in]); @@ -866,6 +897,7 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { const end = in + record_size; if (end > frag.len) { if (record_size > max_ciphertext_len) return error.TlsRecordOverflow; + std.debug.print("end={d} frag.len={d}\n", .{ end, frag.len }); return finishRead(tls, frag, in, out); } switch (ct) { @@ -877,6 +909,7 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: { const P = @TypeOf(p.*); const V = @Vector(P.AEAD.nonce_length, u8); + const ad = frag[in - 5 ..][0..5]; const ciphertext_len = record_size - P.AEAD.tag_length; const ciphertext = frag[in..][0..ciphertext_len]; in += ciphertext_len; @@ -886,7 +919,12 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { const operand: V = pad ++ @bitCast([8]u8, big(tls.read_seq)); tls.read_seq += 1; const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand; - const ad = frag[0..ciphertext_record_header_len]; + //std.debug.print("seq: {d} nonce: {} server_key: {} server_iv: {}\n", .{ + // tls.read_seq - 1, + // std.fmt.fmtSliceHexLower(&nonce), + // std.fmt.fmtSliceHexLower(&p.server_key), + // std.fmt.fmtSliceHexLower(&p.server_iv), + //}); P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch return error.TlsBadRecordMac; break :c cleartext.len; @@ -902,15 +940,26 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { }, }; - const inner_ct = buffer[out + cleartext_len - 1]; + const inner_ct = @intToEnum(ContentType, buffer[out + cleartext_len - 1]); switch (inner_ct) { - @enumToInt(ContentType.handshake) => { + .alert => { + const level = @intToEnum(AlertLevel, buffer[out]); + const desc = @intToEnum(AlertDescription, buffer[out + 1]); + if (desc == .close_notify) { + tls.eof = true; + return out; + } + std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) }); + return error.TlsAlert; + }, + .handshake => { std.debug.print("the server wants to keep shaking hands\n", .{}); }, - @enumToInt(ContentType.application_data) => { + .application_data => { out += cleartext_len - 1; }, else => { + std.debug.print("inner content type: {d}\n", .{inner_ct}); return error.TlsUnexpectedMessage; }, } diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index e7b056830a..2c92163435 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -62,6 +62,25 @@ pub const Request = struct { .https => return req.tls.read(req.stream, buffer), } } + + pub fn readAll(req: *Request, buffer: []u8) !usize { + return readAtLeast(req, buffer, buffer.len); + } + + pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize { + var index: usize = 0; + while (index < len) { + const amt = try req.read(buffer[index..]); + if (amt == 0) { + switch (req.protocol) { + .http => break, + .https => if (req.tls.eof) break, + } + } + index += amt; + } + return index; + } }; pub fn deinit(client: *Client) void { @@ -92,7 +111,7 @@ pub fn request(client: *Client, options: Request.Options) !Request { @tagName(options.method).len + 1 + 
options.path.len + - " HTTP/2\r\nHost: ".len + + " HTTP/1.1\r\nHost: ".len + options.host.len + "\r\nUpgrade-Insecure-Requests: 1\r\n".len + client.headers.items.len + @@ -101,7 +120,7 @@ pub fn request(client: *Client, options: Request.Options) !Request { req.headers.appendSliceAssumeCapacity(@tagName(options.method)); req.headers.appendSliceAssumeCapacity(" "); req.headers.appendSliceAssumeCapacity(options.path); - req.headers.appendSliceAssumeCapacity(" HTTP/2\r\nHost: "); + req.headers.appendSliceAssumeCapacity(" HTTP/1.1\r\nHost: "); req.headers.appendSliceAssumeCapacity(options.host); switch (options.protocol) { .https => req.headers.appendSliceAssumeCapacity("\r\nUpgrade-Insecure-Requests: 1\r\n"), From 462b3ed69c20ea5dcae1660761012b3d5fa91367 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 16 Dec 2022 13:16:53 -0700 Subject: [PATCH 09/59] std.crypto.Tls: handshake fixes * Handle multiple handshakes in one encrypted record * Fix incorrect handshake length sent to server --- lib/std/crypto/Tls.zig | 203 ++++++++++++++++++++++------------------- 1 file changed, 109 insertions(+), 94 deletions(-) diff --git a/lib/std/crypto/Tls.zig b/lib/std/crypto/Tls.zig index 19bd3442cf..4ea64f1be9 100644 --- a/lib/std/crypto/Tls.zig +++ b/lib/std/crypto/Tls.zig @@ -17,6 +17,10 @@ eof: bool, pub const ciphertext_record_header_len = 5; pub const max_ciphertext_len = (1 << 14) + 256; pub const max_ciphertext_record_len = max_ciphertext_len + ciphertext_record_header_len; +pub const hello_retry_request_sequence = [32]u8{ + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, + 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, +}; pub const ProtocolVersion = enum(u16) { tls_1_2 = 0x0303, @@ -450,7 +454,9 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { const hello = frag[4..]; const legacy_version = mem.readIntBig(u16, hello[0..2]); const random = hello[2..34].*; - _ = random; + if (mem.eql(u8, &random, &hello_retry_request_sequence)) { + @panic("TODO handle HelloRetryRequest"); + } const legacy_session_id_echo_len = hello[34]; if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter; const cipher_suite_int = mem.readIntBig(u16, hello[35..37]); @@ -551,7 +557,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { p.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); p.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); p.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); - //std.debug.print("shared_key: {}\nhello_hash: {}\nearly_secret: {}\nempty_hash: {}\nderived_secret: {}\nhandshake_secret: {}\n client_secret: {}\n server_secret: {}\n", .{ + //std.debug.print("shared_key: {}\nhello_hash: {}\nearly_secret: {}\nempty_hash: {}\nderived_secret: {}\nhandshake_secret: {}\n client_secret: {}\n server_secret: {}\nclient_handshake_iv: {}\nserver_handshake_iv: {}\n", .{ // std.fmt.fmtSliceHexLower(&shared_key), // std.fmt.fmtSliceHexLower(&hello_hash), // std.fmt.fmtSliceHexLower(&early_secret), @@ -560,6 +566,8 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { // std.fmt.fmtSliceHexLower(&p.handshake_secret), // std.fmt.fmtSliceHexLower(&client_secret), // std.fmt.fmtSliceHexLower(&server_secret), + // std.fmt.fmtSliceHexLower(&p.client_handshake_iv), + // std.fmt.fmtSliceHexLower(&p.server_handshake_iv), //}); }, .TLS_CHACHA20_POLY1305_SHA256 => { @@ 
-643,106 +651,113 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { }, }; - const inner_ct = cleartext[cleartext.len - 1]; - std.debug.print("inner_ct={any}\n", .{@intToEnum(ContentType, inner_ct)}); + const inner_ct = @intToEnum(ContentType, cleartext[cleartext.len - 1]); switch (inner_ct) { - @enumToInt(ContentType.handshake) => { - const handshake_len = mem.readIntBig(u24, cleartext[1..4]); - if (4 + handshake_len > cleartext.len - 1) return error.TlsBadLength; - std.debug.print("handshake type: {any} size: {d}\n", .{ @intToEnum(HandshakeType, cleartext[0]), handshake_len }); - switch (cleartext[0]) { - @enumToInt(HandshakeType.encrypted_extensions) => { - const ext_size = mem.readIntBig(u16, cleartext[4..6]); - std.debug.print("{d} bytes of encrypted extensions\n", .{ - ext_size, - }); - }, - @enumToInt(HandshakeType.certificate) => { - std.debug.print("cool certificate bro\n", .{}); - }, - @enumToInt(HandshakeType.certificate_verify) => { - std.debug.print("the certificate came with a fancy signature\n", .{}); - }, - @enumToInt(HandshakeType.finished) => { - // This message is to trick buggy proxies into behaving correctly. - const client_change_cipher_spec_msg = [_]u8{ - @enumToInt(ContentType.change_cipher_spec), - 0x03, 0x03, // legacy protocol version - 0x00, 0x01, // length - 0x01, - }; - const app_cipher = switch (cipher_params) { - inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p, tag| c: { - const P = @TypeOf(p.*); - // TODO verify the server's data - const handshake_hash = p.transcript_hash.finalResult(); - const verify_data = hmac(P.Hmac, &handshake_hash, p.client_finished_key); - const out_cleartext = [_]u8{ - @enumToInt(HandshakeType.finished), - 0, 0, verify_data.len + 1 + P.AEAD.tag_length, // length - } ++ verify_data ++ [1]u8{@enumToInt(ContentType.handshake)}; + .handshake => { + var ct_i: usize = 0; + while (true) { + const handshake_type = cleartext[ct_i]; + ct_i += 1; + const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]); + ct_i += 3; + const next_handshake_i = ct_i + handshake_len; + if (next_handshake_i > cleartext.len - 1) + return error.TlsBadLength; + switch (handshake_type) { + @enumToInt(HandshakeType.encrypted_extensions) => { + const ext_size = mem.readIntBig(u16, cleartext[ct_i..][0..2]); + ct_i += 2; + std.debug.print("{d} bytes of encrypted extensions\n", .{ + ext_size, + }); + }, + @enumToInt(HandshakeType.certificate) => { + std.debug.print("cool certificate bro\n", .{}); + }, + @enumToInt(HandshakeType.certificate_verify) => { + std.debug.print("the certificate came with a fancy signature\n", .{}); + }, + @enumToInt(HandshakeType.finished) => { + // This message is to trick buggy proxies into behaving correctly. 
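A minimal standalone sketch (not part of the patch; the names here are invented) of the framing that the rewritten loop above relies on: every handshake message inside a decrypted record starts with a 1-byte HandshakeType and a 3-byte big-endian length, so several messages can sit back to back in one record.

    const std = @import("std");

    const HandshakeHeader = struct { msg_type: u8, len: u24 };

    /// Reads the 1-byte type and 3-byte big-endian length that prefix every
    /// handshake message within a decrypted record.
    fn parseHandshakeHeader(bytes: []const u8) !HandshakeHeader {
        if (bytes.len < 4) return error.EndOfStream;
        return HandshakeHeader{
            .msg_type = bytes[0],
            .len = std.mem.readIntBig(u24, bytes[1..4]),
        };
    }

Advancing by 4 + len and repeating until the record is exhausted gives the multi-handshake behavior this patch adds.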
+ const client_change_cipher_spec_msg = [_]u8{ + @enumToInt(ContentType.change_cipher_spec), + 0x03, 0x03, // legacy protocol version + 0x00, 0x01, // length + 0x01, + }; + const app_cipher = switch (cipher_params) { + inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p, tag| c: { + const P = @TypeOf(p.*); + // TODO verify the server's data + const handshake_hash = p.transcript_hash.finalResult(); + const verify_data = hmac(P.Hmac, &handshake_hash, p.client_finished_key); + const out_cleartext = [_]u8{ + @enumToInt(HandshakeType.finished), + 0, 0, verify_data.len, // length + } ++ verify_data ++ [1]u8{@enumToInt(ContentType.handshake)}; - const wrapped_len = out_cleartext.len + P.AEAD.tag_length; + const wrapped_len = out_cleartext.len + P.AEAD.tag_length; - var finished_msg = [_]u8{ - @enumToInt(ContentType.application_data), - 0x03, 0x03, // legacy protocol version - 0, wrapped_len, // byte length of encrypted record - } ++ ([1]u8{undefined} ** wrapped_len); + var finished_msg = [_]u8{ + @enumToInt(ContentType.application_data), + 0x03, 0x03, // legacy protocol version + 0, wrapped_len, // byte length of encrypted record + } ++ ([1]u8{undefined} ** wrapped_len); - const ad = finished_msg[0..5]; - const ciphertext = finished_msg[5..][0..out_cleartext.len]; - const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..]; - const nonce = p.client_handshake_iv; - P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); + const ad = finished_msg[0..5]; + const ciphertext = finished_msg[5..][0..out_cleartext.len]; + const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..]; + const nonce = p.client_handshake_iv; + P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); - //const both_msgs = client_change_cipher_spec_msg ++ finished_msg; - _ = client_change_cipher_spec_msg; - const both_msgs = finished_msg; - try stream.writeAll(&both_msgs); + const both_msgs = client_change_cipher_spec_msg ++ finished_msg; + try stream.writeAll(&both_msgs); - const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); - const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); - //std.debug.print("master_secret={}\nclient_secret={}\nserver_secret={}\n", .{ - // std.fmt.fmtSliceHexLower(&p.master_secret), - // std.fmt.fmtSliceHexLower(&client_secret), - // std.fmt.fmtSliceHexLower(&server_secret), - //}); - break :c @unionInit(ApplicationCipher, @tagName(tag), .{ - .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), - .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), - .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length), - .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length), - }); - }, - .TLS_CHACHA20_POLY1305_SHA256 => { - @panic("TODO"); - }, - .TLS_AES_128_CCM_SHA256 => { - @panic("TODO"); - }, - .TLS_AES_128_CCM_8_SHA256 => { - @panic("TODO"); - }, - }; - std.debug.print("remaining bytes: {d}\n", .{len - end}); - return .{ - .application_cipher = app_cipher, - .read_seq = 0, - .write_seq = 0, - .partially_read_buffer = undefined, - .partially_read_len = 0, - .eof = false, - }; - }, - else => { - std.debug.print("handshake type: {d}\n", .{cleartext[0]}); - return error.TlsUnexpectedMessage; - }, + const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap 
traffic", &handshake_hash, P.Hash.digest_length); + const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); + //std.debug.print("master_secret={}\nclient_secret={}\nserver_secret={}\n", .{ + // std.fmt.fmtSliceHexLower(&p.master_secret), + // std.fmt.fmtSliceHexLower(&client_secret), + // std.fmt.fmtSliceHexLower(&server_secret), + //}); + break :c @unionInit(ApplicationCipher, @tagName(tag), .{ + .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), + .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), + .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length), + .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length), + }); + }, + .TLS_CHACHA20_POLY1305_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_SHA256 => { + @panic("TODO"); + }, + .TLS_AES_128_CCM_8_SHA256 => { + @panic("TODO"); + }, + }; + std.debug.print("remaining bytes: {d}\n", .{len - end}); + return .{ + .application_cipher = app_cipher, + .read_seq = 0, + .write_seq = 0, + .partially_read_buffer = undefined, + .partially_read_len = 0, + .eof = false, + }; + }, + else => { + std.debug.print("handshake type: {d}\n", .{cleartext[0]}); + return error.TlsUnexpectedMessage; + }, + } + ct_i = next_handshake_i; + if (ct_i >= cleartext.len - 1) break; } }, else => { - std.debug.print("inner content type: {d}\n", .{inner_ct}); + std.debug.print("inner content type: {any}\n", .{inner_ct}); return error.TlsUnexpectedMessage; }, } @@ -803,7 +818,7 @@ pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize { tls.write_seq += 1; const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.client_iv) ^ operand; P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key); - //std.debug.print("seq: {d} nonce: {} client_key: {} client_iv: {} ad: {} auth_tag: {}\nserver_key: {} server_iv: {}", .{ + //std.debug.print("seq: {d} nonce: {} client_key: {} client_iv: {} ad: {} auth_tag: {}\nserver_key: {} server_iv: {}\n", .{ // tls.write_seq - 1, // std.fmt.fmtSliceHexLower(&nonce), // std.fmt.fmtSliceHexLower(&p.client_key), From 02c33d02e05f3dd067bc5492d2617b7805ef897d Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 16 Dec 2022 13:57:56 -0700 Subject: [PATCH 10/59] std.crypto.Tls: parse encrypted extensions --- lib/std/crypto/Tls.zig | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/lib/std/crypto/Tls.zig b/lib/std/crypto/Tls.zig index 4ea64f1be9..105a3834e4 100644 --- a/lib/std/crypto/Tls.zig +++ b/lib/std/crypto/Tls.zig @@ -592,9 +592,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { const end_hdr = i + 5; if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow; if (end_hdr > len) { - std.debug.print("read len={d} atleast={d}\n", .{ len, end_hdr - len }); len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len); - std.debug.print("new len: {d} bytes\n", .{len}); if (end_hdr > len) return error.EndOfStream; } const ct = @intToEnum(ContentType, handshake_buf[i]); @@ -605,12 +603,9 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { const record_size = mem.readIntBig(u16, handshake_buf[i..][0..2]); i += 2; const end = i + record_size; - std.debug.print("ct={any} record_size={d} end={d}\n", .{ ct, record_size, end }); if (end > handshake_buf.len) return error.TlsRecordOverflow; if (end > len) { - std.debug.print("read len={d} atleast={d}\n", .{ len, end - 
len }); len += try stream.readAtLeast(handshake_buf[len..], end - len); - std.debug.print("new len: {d} bytes\n", .{len}); if (end > len) return error.EndOfStream; } switch (ct) { @@ -665,11 +660,25 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { return error.TlsBadLength; switch (handshake_type) { @enumToInt(HandshakeType.encrypted_extensions) => { - const ext_size = mem.readIntBig(u16, cleartext[ct_i..][0..2]); + const total_ext_size = mem.readIntBig(u16, cleartext[ct_i..][0..2]); ct_i += 2; - std.debug.print("{d} bytes of encrypted extensions\n", .{ - ext_size, - }); + const end_ext_i = ct_i + total_ext_size; + while (ct_i < end_ext_i) { + const et = mem.readIntBig(u16, cleartext[ct_i..][0..2]); + ct_i += 2; + const ext_size = mem.readIntBig(u16, cleartext[ct_i..][0..2]); + ct_i += 2; + const next_ext_i = ct_i + ext_size; + switch (et) { + @enumToInt(ExtensionType.server_name) => {}, + else => { + std.debug.print("encrypted extension: {any}\n", .{ + et, + }); + }, + } + ct_i = next_ext_i; + } }, @enumToInt(HandshakeType.certificate) => { std.debug.print("cool certificate bro\n", .{}); @@ -887,19 +896,18 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { // Capacity of output buffer, in records, rounded up. const buf_cap = (buffer.len +| (max_ciphertext_len - 1)) / max_ciphertext_len; const wanted_read_len = buf_cap * (max_ciphertext_len + ciphertext_record_header_len); - const actual_read_len = try stream.read(in_buf[prev_len..@min(wanted_read_len, in_buf.len)]); + const ask_slice = in_buf[prev_len..@min(wanted_read_len, in_buf.len)]; + const actual_read_len = try stream.read(ask_slice); const frag = in_buf[0 .. prev_len + actual_read_len]; if (frag.len == 0) { tls.eof = true; return 0; } - std.debug.print("actual_read_len={d} frag.len={d}\n", .{ actual_read_len, frag.len }); var in: usize = 0; var out: usize = 0; while (true) { if (in + ciphertext_record_header_len > frag.len) { - std.debug.print("in={d} frag.len={d}\n", .{ in, frag.len }); return finishRead(tls, frag, in, out); } const ct = @intToEnum(ContentType, frag[in]); @@ -912,7 +920,6 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { const end = in + record_size; if (end > frag.len) { if (record_size > max_ciphertext_len) return error.TlsRecordOverflow; - std.debug.print("end={d} frag.len={d}\n", .{ end, frag.len }); return finishRead(tls, frag, in, out); } switch (ct) { @@ -980,6 +987,7 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { } }, else => { + std.debug.print("unexpected ct: {any}\n", .{ct}); return error.TlsUnexpectedMessage; }, } From 93ab8be8d8464452af6d2e686e91be2c1da98979 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 16 Dec 2022 14:08:45 -0700 Subject: [PATCH 11/59] extract std.crypto.tls.Client into separate namespace --- lib/std/crypto.zig | 2 +- lib/std/crypto/tls.zig | 325 ++++++++++++++++ lib/std/crypto/{Tls.zig => tls/Client.zig} | 414 +++------------------ lib/std/http/Client.zig | 12 +- 4 files changed, 385 insertions(+), 368 deletions(-) create mode 100644 lib/std/crypto/tls.zig rename lib/std/crypto/{Tls.zig => tls/Client.zig} (75%) diff --git a/lib/std/crypto.zig b/lib/std/crypto.zig index 7b4a116d35..2f7f729302 100644 --- a/lib/std/crypto.zig +++ b/lib/std/crypto.zig @@ -176,7 +176,7 @@ const std = @import("std.zig"); pub const errors = @import("crypto/errors.zig"); -pub const Tls = @import("crypto/Tls.zig"); +pub const tls = @import("crypto/tls.zig"); test { _ = aead.aegis.Aegis128L; diff --git a/lib/std/crypto/tls.zig 
b/lib/std/crypto/tls.zig new file mode 100644 index 0000000000..cc47a3f688 --- /dev/null +++ b/lib/std/crypto/tls.zig @@ -0,0 +1,325 @@ +//! Plaintext: +//! * type: ContentType +//! * legacy_record_version: u16 = 0x0303, +//! * length: u16, +//! - The length (in bytes) of the following TLSPlaintext.fragment. The +//! length MUST NOT exceed 2^14 bytes. +//! * fragment: opaque +//! - the data being transmitted +//! +//! Ciphertext +//! * ContentType opaque_type = application_data; /* 23 */ +//! * ProtocolVersion legacy_record_version = 0x0303; /* TLS v1.2 */ +//! * uint16 length; +//! * opaque encrypted_record[TLSCiphertext.length]; +//! +//! Handshake: +//! * type: HandshakeType +//! * length: u24 +//! * data: opaque +//! +//! ServerHello: +//! * ProtocolVersion legacy_version = 0x0303; +//! * Random random; +//! * opaque legacy_session_id_echo<0..32>; +//! * CipherSuite cipher_suite; +//! * uint8 legacy_compression_method = 0; +//! * Extension extensions<6..2^16-1>; +//! +//! Extension: +//! * ExtensionType extension_type; +//! * opaque extension_data<0..2^16-1>; + +const std = @import("../std.zig"); +const Tls = @This(); +const net = std.net; +const mem = std.mem; +const crypto = std.crypto; +const assert = std.debug.assert; + +pub const Client = @import("tls/Client.zig"); + +pub const ciphertext_record_header_len = 5; +pub const max_ciphertext_len = (1 << 14) + 256; +pub const max_ciphertext_record_len = max_ciphertext_len + ciphertext_record_header_len; +pub const hello_retry_request_sequence = [32]u8{ + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, + 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, +}; + +pub const ProtocolVersion = enum(u16) { + tls_1_2 = 0x0303, + tls_1_3 = 0x0304, + _, +}; + +pub const ContentType = enum(u8) { + invalid = 0, + change_cipher_spec = 20, + alert = 21, + handshake = 22, + application_data = 23, + _, +}; + +pub const HandshakeType = enum(u8) { + client_hello = 1, + server_hello = 2, + new_session_ticket = 4, + end_of_early_data = 5, + encrypted_extensions = 8, + certificate = 11, + certificate_request = 13, + certificate_verify = 15, + finished = 20, + key_update = 24, + message_hash = 254, +}; + +pub const ExtensionType = enum(u16) { + /// RFC 6066 + server_name = 0, + /// RFC 6066 + max_fragment_length = 1, + /// RFC 6066 + status_request = 5, + /// RFC 8422, 7919 + supported_groups = 10, + /// RFC 8446 + signature_algorithms = 13, + /// RFC 5764 + use_srtp = 14, + /// RFC 6520 + heartbeat = 15, + /// RFC 7301 + application_layer_protocol_negotiation = 16, + /// RFC 6962 + signed_certificate_timestamp = 18, + /// RFC 7250 + client_certificate_type = 19, + /// RFC 7250 + server_certificate_type = 20, + /// RFC 7685 + padding = 21, + /// RFC 8446 + pre_shared_key = 41, + /// RFC 8446 + early_data = 42, + /// RFC 8446 + supported_versions = 43, + /// RFC 8446 + cookie = 44, + /// RFC 8446 + psk_key_exchange_modes = 45, + /// RFC 8446 + certificate_authorities = 47, + /// RFC 8446 + oid_filters = 48, + /// RFC 8446 + post_handshake_auth = 49, + /// RFC 8446 + signature_algorithms_cert = 50, + /// RFC 8446 + key_share = 51, +}; + +pub const AlertLevel = enum(u8) { + warning = 1, + fatal = 2, + _, +}; + +pub const AlertDescription = enum(u8) { + close_notify = 0, + unexpected_message = 10, + bad_record_mac = 20, + record_overflow = 22, + handshake_failure = 40, + bad_certificate = 42, + unsupported_certificate = 43, + certificate_revoked = 44, + 
certificate_expired = 45, + certificate_unknown = 46, + illegal_parameter = 47, + unknown_ca = 48, + access_denied = 49, + decode_error = 50, + decrypt_error = 51, + protocol_version = 70, + insufficient_security = 71, + internal_error = 80, + inappropriate_fallback = 86, + user_canceled = 90, + missing_extension = 109, + unsupported_extension = 110, + unrecognized_name = 112, + bad_certificate_status_response = 113, + unknown_psk_identity = 115, + certificate_required = 116, + no_application_protocol = 120, + _, +}; + +pub const SignatureScheme = enum(u16) { + // RSASSA-PKCS1-v1_5 algorithms + rsa_pkcs1_sha256 = 0x0401, + rsa_pkcs1_sha384 = 0x0501, + rsa_pkcs1_sha512 = 0x0601, + + // ECDSA algorithms + ecdsa_secp256r1_sha256 = 0x0403, + ecdsa_secp384r1_sha384 = 0x0503, + ecdsa_secp521r1_sha512 = 0x0603, + + // RSASSA-PSS algorithms with public key OID rsaEncryption + rsa_pss_rsae_sha256 = 0x0804, + rsa_pss_rsae_sha384 = 0x0805, + rsa_pss_rsae_sha512 = 0x0806, + + // EdDSA algorithms + ed25519 = 0x0807, + ed448 = 0x0808, + + // RSASSA-PSS algorithms with public key OID RSASSA-PSS + rsa_pss_pss_sha256 = 0x0809, + rsa_pss_pss_sha384 = 0x080a, + rsa_pss_pss_sha512 = 0x080b, + + // Legacy algorithms + rsa_pkcs1_sha1 = 0x0201, + ecdsa_sha1 = 0x0203, + + _, +}; + +pub const NamedGroup = enum(u16) { + // Elliptic Curve Groups (ECDHE) + secp256r1 = 0x0017, + secp384r1 = 0x0018, + secp521r1 = 0x0019, + x25519 = 0x001D, + x448 = 0x001E, + + // Finite Field Groups (DHE) + ffdhe2048 = 0x0100, + ffdhe3072 = 0x0101, + ffdhe4096 = 0x0102, + ffdhe6144 = 0x0103, + ffdhe8192 = 0x0104, + + _, +}; + +pub const CipherSuite = enum(u16) { + TLS_AES_128_GCM_SHA256 = 0x1301, + TLS_AES_256_GCM_SHA384 = 0x1302, + TLS_CHACHA20_POLY1305_SHA256 = 0x1303, + TLS_AES_128_CCM_SHA256 = 0x1304, + TLS_AES_128_CCM_8_SHA256 = 0x1305, +}; + +pub const CipherParams = union(CipherSuite) { + TLS_AES_128_GCM_SHA256: struct { + pub const AEAD = crypto.aead.aes_gcm.Aes128Gcm; + pub const Hash = crypto.hash.sha2.Sha256; + pub const Hmac = crypto.auth.hmac.Hmac(Hash); + pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + handshake_secret: [Hkdf.prk_length]u8, + master_secret: [Hkdf.prk_length]u8, + client_handshake_key: [AEAD.key_length]u8, + server_handshake_key: [AEAD.key_length]u8, + client_finished_key: [Hmac.key_length]u8, + server_finished_key: [Hmac.key_length]u8, + client_handshake_iv: [AEAD.nonce_length]u8, + server_handshake_iv: [AEAD.nonce_length]u8, + transcript_hash: Hash, + }, + TLS_AES_256_GCM_SHA384: struct { + pub const AEAD = crypto.aead.aes_gcm.Aes256Gcm; + pub const Hash = crypto.hash.sha2.Sha384; + pub const Hmac = crypto.auth.hmac.Hmac(Hash); + pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + handshake_secret: [Hkdf.prk_length]u8, + master_secret: [Hkdf.prk_length]u8, + client_handshake_key: [AEAD.key_length]u8, + server_handshake_key: [AEAD.key_length]u8, + client_finished_key: [Hmac.key_length]u8, + server_finished_key: [Hmac.key_length]u8, + client_handshake_iv: [AEAD.nonce_length]u8, + server_handshake_iv: [AEAD.nonce_length]u8, + transcript_hash: Hash, + }, + TLS_CHACHA20_POLY1305_SHA256: void, + TLS_AES_128_CCM_SHA256: void, + TLS_AES_128_CCM_8_SHA256: void, +}; + +/// Encryption parameters for application traffic. 
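The AlertLevel and AlertDescription enums above describe a two-byte payload: the level in byte 0 and the description in byte 1, with close_notify meaning a clean end of stream. A small illustrative helper, assumed to sit next to these enums in tls.zig rather than being part of the patch:

    const Alert = struct { level: AlertLevel, description: AlertDescription };

    fn parseAlert(payload: *const [2]u8) Alert {
        return .{
            .level = @intToEnum(AlertLevel, payload[0]),
            .description = @intToEnum(AlertDescription, payload[1]),
        };
    }

    fn isCloseNotify(alert: Alert) bool {
        return alert.description == .close_notify;
    }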
+pub const ApplicationCipher = union(CipherSuite) { + TLS_AES_128_GCM_SHA256: struct { + pub const AEAD = crypto.aead.aes_gcm.Aes128Gcm; + pub const Hash = crypto.hash.sha2.Sha256; + pub const Hmac = crypto.auth.hmac.Hmac(Hash); + pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + client_key: [AEAD.key_length]u8, + server_key: [AEAD.key_length]u8, + client_iv: [AEAD.nonce_length]u8, + server_iv: [AEAD.nonce_length]u8, + }, + TLS_AES_256_GCM_SHA384: struct { + pub const AEAD = crypto.aead.aes_gcm.Aes256Gcm; + pub const Hash = crypto.hash.sha2.Sha384; + pub const Hmac = crypto.auth.hmac.Hmac(Hash); + pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + client_key: [AEAD.key_length]u8, + server_key: [AEAD.key_length]u8, + client_iv: [AEAD.nonce_length]u8, + server_iv: [AEAD.nonce_length]u8, + }, + TLS_CHACHA20_POLY1305_SHA256: void, + TLS_AES_128_CCM_SHA256: void, + TLS_AES_128_CCM_8_SHA256: void, +}; + +pub fn hkdfExpandLabel( + comptime Hkdf: type, + key: [Hkdf.prk_length]u8, + label: []const u8, + context: []const u8, + comptime len: usize, +) [len]u8 { + const max_label_len = 255; + const max_context_len = 255; + const tls13 = "tls13 "; + var buf: [2 + 1 + tls13.len + max_label_len + 1 + max_context_len]u8 = undefined; + mem.writeIntBig(u16, buf[0..2], len); + buf[2] = @intCast(u8, tls13.len + label.len); + buf[3..][0..tls13.len].* = tls13.*; + var i: usize = 3 + tls13.len; + mem.copy(u8, buf[i..], label); + i += label.len; + buf[i] = @intCast(u8, context.len); + i += 1; + mem.copy(u8, buf[i..], context); + i += context.len; + + var result: [len]u8 = undefined; + Hkdf.expand(&result, buf[0..i], key); + return result; +} + +pub fn emptyHash(comptime Hash: type) [Hash.digest_length]u8 { + var result: [Hash.digest_length]u8 = undefined; + Hash.hash(&.{}, &result, .{}); + return result; +} + +pub fn hmac(comptime Hmac: type, message: []const u8, key: [Hmac.key_length]u8) [Hmac.mac_length]u8 { + var result: [Hmac.mac_length]u8 = undefined; + Hmac.create(&result, message, &key); + return result; +} diff --git a/lib/std/crypto/Tls.zig b/lib/std/crypto/tls/Client.zig similarity index 75% rename from lib/std/crypto/Tls.zig rename to lib/std/crypto/tls/Client.zig index 105a3834e4..93693e40e9 100644 --- a/lib/std/crypto/Tls.zig +++ b/lib/std/crypto/tls/Client.zig @@ -1,297 +1,28 @@ -const std = @import("../std.zig"); -const Tls = @This(); +const std = @import("../../std.zig"); +const tls = std.crypto.tls; +const Client = @This(); const net = std.net; const mem = std.mem; const crypto = std.crypto; const assert = std.debug.assert; +const ApplicationCipher = tls.ApplicationCipher; +const CipherSuite = tls.CipherSuite; +const ContentType = tls.ContentType; +const HandshakeType = tls.HandshakeType; +const CipherParams = tls.CipherParams; +const max_ciphertext_len = tls.max_ciphertext_len; +const hkdfExpandLabel = tls.hkdfExpandLabel; + application_cipher: ApplicationCipher, read_seq: u64, write_seq: u64, /// The size is enough to contain exactly one TLSCiphertext record. -partially_read_buffer: [max_ciphertext_record_len]u8, +partially_read_buffer: [tls.max_ciphertext_record_len]u8, /// The number of partially read bytes inside `partiall_read_buffer`. 
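A hedged usage sketch for the hkdfExpandLabel and Hkdf machinery introduced in tls.zig above, written as a test that could live in that file (so `std` and `crypto` refer to its file-scope imports); the all-zero secret is a dummy value for illustration only:

    test "expand a dummy secret into an AES-128 key and IV" {
        const Hmac = crypto.auth.hmac.sha2.HmacSha256;
        const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
        const secret = [_]u8{0} ** Hkdf.prk_length;
        // Mirrors the "key"/"iv" derivations done during the handshake.
        const key = hkdfExpandLabel(Hkdf, secret, "key", "", 16);
        const iv = hkdfExpandLabel(Hkdf, secret, "iv", "", 12);
        try std.testing.expect(key.len == 16 and iv.len == 12);
    }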
partially_read_len: u15, eof: bool, -pub const ciphertext_record_header_len = 5; -pub const max_ciphertext_len = (1 << 14) + 256; -pub const max_ciphertext_record_len = max_ciphertext_len + ciphertext_record_header_len; -pub const hello_retry_request_sequence = [32]u8{ - 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, - 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, -}; - -pub const ProtocolVersion = enum(u16) { - tls_1_2 = 0x0303, - tls_1_3 = 0x0304, - _, -}; - -pub const ContentType = enum(u8) { - invalid = 0, - change_cipher_spec = 20, - alert = 21, - handshake = 22, - application_data = 23, - _, -}; - -pub const HandshakeType = enum(u8) { - client_hello = 1, - server_hello = 2, - new_session_ticket = 4, - end_of_early_data = 5, - encrypted_extensions = 8, - certificate = 11, - certificate_request = 13, - certificate_verify = 15, - finished = 20, - key_update = 24, - message_hash = 254, -}; - -pub const ExtensionType = enum(u16) { - /// RFC 6066 - server_name = 0, - /// RFC 6066 - max_fragment_length = 1, - /// RFC 6066 - status_request = 5, - /// RFC 8422, 7919 - supported_groups = 10, - /// RFC 8446 - signature_algorithms = 13, - /// RFC 5764 - use_srtp = 14, - /// RFC 6520 - heartbeat = 15, - /// RFC 7301 - application_layer_protocol_negotiation = 16, - /// RFC 6962 - signed_certificate_timestamp = 18, - /// RFC 7250 - client_certificate_type = 19, - /// RFC 7250 - server_certificate_type = 20, - /// RFC 7685 - padding = 21, - /// RFC 8446 - pre_shared_key = 41, - /// RFC 8446 - early_data = 42, - /// RFC 8446 - supported_versions = 43, - /// RFC 8446 - cookie = 44, - /// RFC 8446 - psk_key_exchange_modes = 45, - /// RFC 8446 - certificate_authorities = 47, - /// RFC 8446 - oid_filters = 48, - /// RFC 8446 - post_handshake_auth = 49, - /// RFC 8446 - signature_algorithms_cert = 50, - /// RFC 8446 - key_share = 51, -}; - -pub const AlertLevel = enum(u8) { - warning = 1, - fatal = 2, - _, -}; - -pub const AlertDescription = enum(u8) { - close_notify = 0, - unexpected_message = 10, - bad_record_mac = 20, - record_overflow = 22, - handshake_failure = 40, - bad_certificate = 42, - unsupported_certificate = 43, - certificate_revoked = 44, - certificate_expired = 45, - certificate_unknown = 46, - illegal_parameter = 47, - unknown_ca = 48, - access_denied = 49, - decode_error = 50, - decrypt_error = 51, - protocol_version = 70, - insufficient_security = 71, - internal_error = 80, - inappropriate_fallback = 86, - user_canceled = 90, - missing_extension = 109, - unsupported_extension = 110, - unrecognized_name = 112, - bad_certificate_status_response = 113, - unknown_psk_identity = 115, - certificate_required = 116, - no_application_protocol = 120, - _, -}; - -pub const SignatureScheme = enum(u16) { - // RSASSA-PKCS1-v1_5 algorithms - rsa_pkcs1_sha256 = 0x0401, - rsa_pkcs1_sha384 = 0x0501, - rsa_pkcs1_sha512 = 0x0601, - - // ECDSA algorithms - ecdsa_secp256r1_sha256 = 0x0403, - ecdsa_secp384r1_sha384 = 0x0503, - ecdsa_secp521r1_sha512 = 0x0603, - - // RSASSA-PSS algorithms with public key OID rsaEncryption - rsa_pss_rsae_sha256 = 0x0804, - rsa_pss_rsae_sha384 = 0x0805, - rsa_pss_rsae_sha512 = 0x0806, - - // EdDSA algorithms - ed25519 = 0x0807, - ed448 = 0x0808, - - // RSASSA-PSS algorithms with public key OID RSASSA-PSS - rsa_pss_pss_sha256 = 0x0809, - rsa_pss_pss_sha384 = 0x080a, - rsa_pss_pss_sha512 = 0x080b, - - // Legacy algorithms - rsa_pkcs1_sha1 = 0x0201, - ecdsa_sha1 = 0x0203, - - _, -}; - 
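The enums being moved out of Client.zig here (SignatureScheme above, NamedGroup and CipherSuite below) all travel on the wire as big-endian u16 values, usually in a list prefixed by a u16 byte count. A standalone sketch of that encoding; int2 is redefined locally because this snippet is not part of the patch:

    const std = @import("std");

    fn int2(x: u16) [2]u8 {
        var result: [2]u8 = undefined;
        std.mem.writeIntBig(u16, &result, x);
        return result;
    }

    test "two u16 identifiers encode as a 2-byte count plus 4 bytes of ids" {
        // ecdsa_secp256r1_sha256 (0x0403) and ed25519 (0x0807).
        const ids = int2(0x0403) ++ int2(0x0807);
        const encoded = int2(ids.len) ++ ids;
        try std.testing.expectEqualSlices(
            u8,
            &[_]u8{ 0, 4, 0x04, 0x03, 0x08, 0x07 },
            &encoded,
        );
    }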
-pub const NamedGroup = enum(u16) { - // Elliptic Curve Groups (ECDHE) - secp256r1 = 0x0017, - secp384r1 = 0x0018, - secp521r1 = 0x0019, - x25519 = 0x001D, - x448 = 0x001E, - - // Finite Field Groups (DHE) - ffdhe2048 = 0x0100, - ffdhe3072 = 0x0101, - ffdhe4096 = 0x0102, - ffdhe6144 = 0x0103, - ffdhe8192 = 0x0104, - - _, -}; - -// Plaintext: -// * type: ContentType -// * legacy_record_version: u16 = 0x0303, -// * length: u16, -// - The length (in bytes) of the following TLSPlaintext.fragment. The -// length MUST NOT exceed 2^14 bytes. -// * fragment: opaque -// - the data being transmitted - -// Ciphertext -// * ContentType opaque_type = application_data; /* 23 */ -// * ProtocolVersion legacy_record_version = 0x0303; /* TLS v1.2 */ -// * uint16 length; -// * opaque encrypted_record[TLSCiphertext.length]; - -// Handshake: -// * type: HandshakeType -// * length: u24 -// * data: opaque - -// ServerHello: -// * ProtocolVersion legacy_version = 0x0303; -// * Random random; -// * opaque legacy_session_id_echo<0..32>; -// * CipherSuite cipher_suite; -// * uint8 legacy_compression_method = 0; -// * Extension extensions<6..2^16-1>; - -// Extension: -// * ExtensionType extension_type; -// * opaque extension_data<0..2^16-1>; - -pub const CipherSuite = enum(u16) { - TLS_AES_128_GCM_SHA256 = 0x1301, - TLS_AES_256_GCM_SHA384 = 0x1302, - TLS_CHACHA20_POLY1305_SHA256 = 0x1303, - TLS_AES_128_CCM_SHA256 = 0x1304, - TLS_AES_128_CCM_8_SHA256 = 0x1305, -}; - -pub const CipherParams = union(CipherSuite) { - TLS_AES_128_GCM_SHA256: struct { - const AEAD = crypto.aead.aes_gcm.Aes128Gcm; - const Hash = crypto.hash.sha2.Sha256; - const Hmac = crypto.auth.hmac.Hmac(Hash); - const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); - - handshake_secret: [Hkdf.key_len]u8, - master_secret: [Hkdf.key_len]u8, - client_handshake_key: [AEAD.key_length]u8, - server_handshake_key: [AEAD.key_length]u8, - client_finished_key: [Hmac.key_length]u8, - server_finished_key: [Hmac.key_length]u8, - client_handshake_iv: [AEAD.nonce_length]u8, - server_handshake_iv: [AEAD.nonce_length]u8, - transcript_hash: Hash, - }, - TLS_AES_256_GCM_SHA384: struct { - const AEAD = crypto.aead.aes_gcm.Aes256Gcm; - const Hash = crypto.hash.sha2.Sha384; - const Hmac = crypto.auth.hmac.Hmac(Hash); - const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); - - handshake_secret: [Hkdf.key_len]u8, - master_secret: [Hkdf.key_len]u8, - client_handshake_key: [AEAD.key_length]u8, - server_handshake_key: [AEAD.key_length]u8, - client_finished_key: [Hmac.key_length]u8, - server_finished_key: [Hmac.key_length]u8, - client_handshake_iv: [AEAD.nonce_length]u8, - server_handshake_iv: [AEAD.nonce_length]u8, - transcript_hash: Hash, - }, - TLS_CHACHA20_POLY1305_SHA256: void, - TLS_AES_128_CCM_SHA256: void, - TLS_AES_128_CCM_8_SHA256: void, -}; - -/// Encryption parameters for application traffic. 
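The layout comments deleted above (they reappear as doc comments in tls.zig) describe the 5-byte record header that ciphertext_record_header_len refers to. A self-contained sketch with invented names, not taken from the patch:

    const std = @import("std");

    const RecordHeader = struct { content_type: u8, len: u16 };

    fn parseRecordHeader(bytes: *const [5]u8) !RecordHeader {
        // type (1 byte), legacy_record_version 0x0303 (2 bytes), big-endian length (2 bytes).
        const len = std.mem.readIntBig(u16, bytes[3..5]);
        // A ciphertext record may carry at most 2^14 + 256 bytes.
        if (len > (1 << 14) + 256) return error.TlsRecordOverflow;
        return RecordHeader{ .content_type = bytes[0], .len = len };
    }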
-pub const ApplicationCipher = union(CipherSuite) { - TLS_AES_128_GCM_SHA256: struct { - const AEAD = crypto.aead.aes_gcm.Aes128Gcm; - const Hash = crypto.hash.sha2.Sha256; - const Hmac = crypto.auth.hmac.Hmac(Hash); - const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); - - client_key: [AEAD.key_length]u8, - server_key: [AEAD.key_length]u8, - client_iv: [AEAD.nonce_length]u8, - server_iv: [AEAD.nonce_length]u8, - }, - TLS_AES_256_GCM_SHA384: struct { - const AEAD = crypto.aead.aes_gcm.Aes256Gcm; - const Hash = crypto.hash.sha2.Sha384; - const Hmac = crypto.auth.hmac.Hmac(Hash); - const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); - - client_key: [AEAD.key_length]u8, - server_key: [AEAD.key_length]u8, - client_iv: [AEAD.nonce_length]u8, - server_iv: [AEAD.nonce_length]u8, - }, - TLS_CHACHA20_POLY1305_SHA256: void, - TLS_AES_128_CCM_SHA256: void, - TLS_AES_128_CCM_8_SHA256: void, -}; - const cipher_suites = blk: { const fields = @typeInfo(CipherSuite).Enum.fields; var result: [(fields.len + 1) * 2]u8 = undefined; @@ -305,7 +36,7 @@ const cipher_suites = blk: { }; /// `host` is only borrowed during this function call. -pub fn init(stream: net.Stream, host: []const u8) !Tls { +pub fn init(stream: net.Stream, host: []const u8) !Client { var x25519_priv_key: [32]u8 = undefined; crypto.random.bytes(&x25519_priv_key); const x25519_pub_key = crypto.dh.X25519.recoverPublicKey(x25519_priv_key) catch |err| { @@ -440,8 +171,8 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { switch (ct) { .alert => { - const level = @intToEnum(AlertLevel, frag[0]); - const desc = @intToEnum(AlertDescription, frag[1]); + const level = @intToEnum(tls.AlertLevel, frag[0]); + const desc = @intToEnum(tls.AlertDescription, frag[1]); std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) }); return error.TlsAlert; }, @@ -454,7 +185,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { const hello = frag[4..]; const legacy_version = mem.readIntBig(u16, hello[0..2]); const random = hello[2..34].*; - if (mem.eql(u8, &random, &hello_retry_request_sequence)) { + if (mem.eql(u8, &random, &tls.hello_retry_request_sequence)) { @panic("TODO handle HelloRetryRequest"); } const legacy_session_id_echo_len = hello[34]; @@ -478,16 +209,16 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { const next_i = i + ext_size; if (next_i > hello.len) return error.TlsBadLength; switch (et) { - @enumToInt(ExtensionType.supported_versions) => { + @enumToInt(tls.ExtensionType.supported_versions) => { if (supported_version != 0) return error.TlsIllegalParameter; supported_version = mem.readIntBig(u16, hello[i..][0..2]); }, - @enumToInt(ExtensionType.key_share) => { + @enumToInt(tls.ExtensionType.key_share) => { if (opt_x25519_server_pub_key != null) return error.TlsIllegalParameter; const named_group = mem.readIntBig(u16, hello[i..][0..2]); i += 2; switch (named_group) { - @enumToInt(NamedGroup.x25519) => { + @enumToInt(tls.NamedGroup.x25519) => { const key_size = mem.readIntBig(u16, hello[i..][0..2]); i += 2; if (key_size != 32) return error.TlsBadLength; @@ -509,10 +240,10 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { return error.TlsIllegalParameter; const tls_version = if (supported_version == 0) legacy_version else supported_version; switch (tls_version) { - @enumToInt(ProtocolVersion.tls_1_2) => { + @enumToInt(tls.ProtocolVersion.tls_1_2) => { std.debug.print("server wants TLS v1.2\n", .{}); }, - @enumToInt(ProtocolVersion.tls_1_3) => { + @enumToInt(tls.ProtocolVersion.tls_1_3) => { std.debug.print("server wants 
TLS v1.3\n", .{}); }, else => return error.TlsIllegalParameter, @@ -544,7 +275,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { const hello_hash = p.transcript_hash.peek(); const zeroes = [1]u8{0} ** P.Hash.digest_length; const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); - const empty_hash = emptyHash(P.Hash); + const empty_hash = tls.emptyHash(P.Hash); const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length); p.handshake_secret = P.Hkdf.extract(&hs_derived_secret, &shared_key); const ap_derived_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "derived", &empty_hash, P.Hash.digest_length); @@ -670,7 +401,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { ct_i += 2; const next_ext_i = ct_i + ext_size; switch (et) { - @enumToInt(ExtensionType.server_name) => {}, + @enumToInt(tls.ExtensionType.server_name) => {}, else => { std.debug.print("encrypted extension: {any}\n", .{ et, @@ -699,7 +430,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { const P = @TypeOf(p.*); // TODO verify the server's data const handshake_hash = p.transcript_hash.finalResult(); - const verify_data = hmac(P.Hmac, &handshake_hash, p.client_finished_key); + const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key); const out_cleartext = [_]u8{ @enumToInt(HandshakeType.finished), 0, 0, verify_data.len, // length @@ -782,8 +513,8 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { return error.TlsHandshakeFailure; } -pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize { - var ciphertext_buf: [max_ciphertext_record_len * 4]u8 = undefined; +pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize { + var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined; // Due to the trailing inner content type byte in the ciphertext, we need // an additional buffer for storing the cleartext into before encrypting. var cleartext_buf: [max_ciphertext_len]u8 = undefined; @@ -792,16 +523,16 @@ pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize { var iovec_end: usize = 0; var bytes_i: usize = 0; // How many bytes are taken up by overhead per record. 
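In the write and read paths below, each record's AEAD nonce is the static IV XORed with the left-padded big-endian 64-bit sequence number; the patch does this with a @Vector XOR. A scalar sketch of the same derivation (recordNonce is an invented name, and it assumes the nonce is at least 8 bytes long):

    const std = @import("std");

    fn recordNonce(comptime nonce_len: usize, iv: [nonce_len]u8, seq: u64) [nonce_len]u8 {
        var nonce = iv;
        var i: usize = 0;
        while (i < 8) : (i += 1) {
            // XOR the big-endian sequence number into the last 8 bytes of the IV.
            nonce[nonce_len - 8 + i] ^= @truncate(u8, seq >> @intCast(u6, 56 - 8 * i));
        }
        return nonce;
    }

    test "sequence number zero leaves the IV unchanged" {
        const iv = [_]u8{0xAB} ** 12;
        const nonce = recordNonce(12, iv, 0);
        try std.testing.expectEqualSlices(u8, &iv, &nonce);
    }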
- const overhead_len: usize = switch (tls.application_cipher) { + const overhead_len: usize = switch (c.application_cipher) { inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| l: { const P = @TypeOf(p.*); const V = @Vector(P.AEAD.nonce_length, u8); - const overhead_len = ciphertext_record_header_len + P.AEAD.tag_length + 1; + const overhead_len = tls.ciphertext_record_header_len + P.AEAD.tag_length + 1; while (true) { const encrypted_content_len = @intCast(u16, @min( @min(bytes.len - bytes_i, max_ciphertext_len - 1), ciphertext_buf.len - - ciphertext_record_header_len - P.AEAD.tag_length - ciphertext_end - 1, + tls.ciphertext_record_header_len - P.AEAD.tag_length - ciphertext_end - 1, )); if (encrypted_content_len == 0) break :l overhead_len; @@ -815,7 +546,7 @@ pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize { const ad = ciphertext_buf[ciphertext_end..][0..5]; ad.* = [_]u8{@enumToInt(ContentType.application_data)} ++ - int2(@enumToInt(ProtocolVersion.tls_1_2)) ++ + int2(@enumToInt(tls.ProtocolVersion.tls_1_2)) ++ int2(ciphertext_len + P.AEAD.tag_length); ciphertext_end += ad.len; const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len]; @@ -823,12 +554,12 @@ pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize { const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length]; ciphertext_end += auth_tag.len; const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @bitCast([8]u8, big(tls.write_seq)); - tls.write_seq += 1; + const operand: V = pad ++ @bitCast([8]u8, big(c.write_seq)); + c.write_seq += 1; const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.client_iv) ^ operand; P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key); //std.debug.print("seq: {d} nonce: {} client_key: {} client_iv: {} ad: {} auth_tag: {}\nserver_key: {} server_iv: {}\n", .{ - // tls.write_seq - 1, + // c.write_seq - 1, // std.fmt.fmtSliceHexLower(&nonce), // std.fmt.fmtSliceHexLower(&p.client_key), // std.fmt.fmtSliceHexLower(&p.client_iv), @@ -878,37 +609,37 @@ pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize { } } -pub fn writeAll(tls: *Tls, stream: net.Stream, bytes: []const u8) !void { +pub fn writeAll(c: *Client, stream: net.Stream, bytes: []const u8) !void { var index: usize = 0; while (index < bytes.len) { - index += try tls.write(stream, bytes[index..]); + index += try c.write(stream, bytes[index..]); } } /// Returns number of bytes that have been read, which are now populated inside /// `buffer`. A return value of zero bytes does not necessarily mean end of /// stream. -pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { - const prev_len = tls.partially_read_len; +pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { + const prev_len = c.partially_read_len; var in_buf: [max_ciphertext_len * 4]u8 = undefined; - mem.copy(u8, &in_buf, tls.partially_read_buffer[0..prev_len]); + mem.copy(u8, &in_buf, c.partially_read_buffer[0..prev_len]); // Capacity of output buffer, in records, rounded up. const buf_cap = (buffer.len +| (max_ciphertext_len - 1)) / max_ciphertext_len; - const wanted_read_len = buf_cap * (max_ciphertext_len + ciphertext_record_header_len); + const wanted_read_len = buf_cap * (max_ciphertext_len + tls.ciphertext_record_header_len); const ask_slice = in_buf[prev_len..@min(wanted_read_len, in_buf.len)]; const actual_read_len = try stream.read(ask_slice); const frag = in_buf[0 .. 
prev_len + actual_read_len]; if (frag.len == 0) { - tls.eof = true; + c.eof = true; return 0; } var in: usize = 0; var out: usize = 0; while (true) { - if (in + ciphertext_record_header_len > frag.len) { - return finishRead(tls, frag, in, out); + if (in + tls.ciphertext_record_header_len > frag.len) { + return finishRead(c, frag, in, out); } const ct = @intToEnum(ContentType, frag[in]); in += 1; @@ -920,14 +651,14 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { const end = in + record_size; if (end > frag.len) { if (record_size > max_ciphertext_len) return error.TlsRecordOverflow; - return finishRead(tls, frag, in, out); + return finishRead(c, frag, in, out); } switch (ct) { .alert => { @panic("TODO handle an alert here"); }, .application_data => { - const cleartext_len = switch (tls.application_cipher) { + const cleartext_len = switch (c.application_cipher) { inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: { const P = @TypeOf(p.*); const V = @Vector(P.AEAD.nonce_length, u8); @@ -938,11 +669,11 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { const auth_tag = frag[in..][0..P.AEAD.tag_length].*; const cleartext = buffer[out..][0..ciphertext_len]; const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @bitCast([8]u8, big(tls.read_seq)); - tls.read_seq += 1; + const operand: V = pad ++ @bitCast([8]u8, big(c.read_seq)); + c.read_seq += 1; const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand; //std.debug.print("seq: {d} nonce: {} server_key: {} server_iv: {}\n", .{ - // tls.read_seq - 1, + // c.read_seq - 1, // std.fmt.fmtSliceHexLower(&nonce), // std.fmt.fmtSliceHexLower(&p.server_key), // std.fmt.fmtSliceHexLower(&p.server_iv), @@ -965,10 +696,10 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { const inner_ct = @intToEnum(ContentType, buffer[out + cleartext_len - 1]); switch (inner_ct) { .alert => { - const level = @intToEnum(AlertLevel, buffer[out]); - const desc = @intToEnum(AlertDescription, buffer[out + 1]); + const level = @intToEnum(tls.AlertLevel, buffer[out]); + const desc = @intToEnum(tls.AlertDescription, buffer[out + 1]); if (desc == .close_notify) { - tls.eof = true; + c.eof = true; return out; } std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) }); @@ -995,52 +726,13 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { } } -fn finishRead(tls: *Tls, frag: []const u8, in: usize, out: usize) usize { +fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize { const saved_buf = frag[in..]; - mem.copy(u8, &tls.partially_read_buffer, saved_buf); - tls.partially_read_len = @intCast(u15, saved_buf.len); + mem.copy(u8, &c.partially_read_buffer, saved_buf); + c.partially_read_len = @intCast(u15, saved_buf.len); return out; } -fn hkdfExpandLabel( - comptime Hkdf: type, - key: [Hkdf.prk_length]u8, - label: []const u8, - context: []const u8, - comptime len: usize, -) [len]u8 { - const max_label_len = 255; - const max_context_len = 255; - const tls13 = "tls13 "; - var buf: [2 + 1 + tls13.len + max_label_len + 1 + max_context_len]u8 = undefined; - mem.writeIntBig(u16, buf[0..2], len); - buf[2] = @intCast(u8, tls13.len + label.len); - buf[3..][0..tls13.len].* = tls13.*; - var i: usize = 3 + tls13.len; - mem.copy(u8, buf[i..], label); - i += label.len; - buf[i] = @intCast(u8, context.len); - i += 1; - mem.copy(u8, buf[i..], context); - i += context.len; - - var result: [len]u8 = undefined; - Hkdf.expand(&result, buf[0..i], 
key); - return result; -} - -fn emptyHash(comptime Hash: type) [Hash.digest_length]u8 { - var result: [Hash.digest_length]u8 = undefined; - Hash.hash(&.{}, &result, .{}); - return result; -} - -fn hmac(comptime Hmac: type, message: []const u8, key: [Hmac.key_length]u8) [Hmac.mac_length]u8 { - var result: [Hmac.mac_length]u8 = undefined; - Hmac.create(&result, message, &key); - return result; -} - const builtin = @import("builtin"); const native_endian = builtin.cpu.arch.endian(); diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 2c92163435..fcadf3669b 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -12,7 +12,7 @@ pub const Request = struct { client: *Client, stream: net.Stream, headers: std.ArrayListUnmanaged(u8) = .{}, - tls: std.crypto.Tls, + tls_client: std.crypto.tls.Client, protocol: Protocol, pub const Protocol = enum { http, https }; @@ -51,7 +51,7 @@ pub const Request = struct { try req.stream.writeAll(req.headers.items); }, .https => { - try req.tls.writeAll(req.stream, req.headers.items); + try req.tls_client.writeAll(req.stream, req.headers.items); }, } } @@ -59,7 +59,7 @@ pub const Request = struct { pub fn read(req: *Request, buffer: []u8) !usize { switch (req.protocol) { .http => return req.stream.read(buffer), - .https => return req.tls.read(req.stream, buffer), + .https => return req.tls_client.read(req.stream, buffer), } } @@ -74,7 +74,7 @@ pub const Request = struct { if (amt == 0) { switch (req.protocol) { .http => break, - .https => if (req.tls.eof) break, + .https => if (req.tls_client.eof) break, } } index += amt; @@ -94,7 +94,7 @@ pub fn request(client: *Client, options: Request.Options) !Request { .client = client, .stream = try net.tcpConnectToHost(client.allocator, options.host, options.port), .protocol = options.protocol, - .tls = undefined, + .tls_client = undefined, }; client.active_requests += 1; errdefer req.deinit(); @@ -102,7 +102,7 @@ pub fn request(client: *Client, options: Request.Options) !Request { switch (options.protocol) { .http => {}, .https => { - req.tls = try std.crypto.Tls.init(req.stream, options.host); + req.tls_client = try std.crypto.tls.Client.init(req.stream, options.host); }, } From 942b5b468fe0a517618b62f0260d3a32c7cc642e Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 16 Dec 2022 18:06:00 -0700 Subject: [PATCH 12/59] std.crypto.tls: implement the rest of the cipher suites Also: * Use KeyPair.create() function * Don't bother with CCM --- lib/std/crypto/tls.zig | 92 ++++++++++++---------------- lib/std/crypto/tls/Client.zig | 112 ++++++++++++---------------------- lib/std/meta.zig | 12 ++-- 3 files changed, 85 insertions(+), 131 deletions(-) diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index cc47a3f688..babd8b465d 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -211,17 +211,20 @@ pub const NamedGroup = enum(u16) { }; pub const CipherSuite = enum(u16) { - TLS_AES_128_GCM_SHA256 = 0x1301, - TLS_AES_256_GCM_SHA384 = 0x1302, - TLS_CHACHA20_POLY1305_SHA256 = 0x1303, - TLS_AES_128_CCM_SHA256 = 0x1304, - TLS_AES_128_CCM_8_SHA256 = 0x1305, + AES_128_GCM_SHA256 = 0x1301, + AES_256_GCM_SHA384 = 0x1302, + CHACHA20_POLY1305_SHA256 = 0x1303, + AES_128_CCM_SHA256 = 0x1304, + AES_128_CCM_8_SHA256 = 0x1305, + AEGIS_256_SHA384 = 0x1306, + AEGIS_128L_SHA256 = 0x1307, + _, }; -pub const CipherParams = union(CipherSuite) { - TLS_AES_128_GCM_SHA256: struct { - pub const AEAD = crypto.aead.aes_gcm.Aes128Gcm; - pub const Hash = crypto.hash.sha2.Sha256; +pub fn 
CipherParamsT(comptime AeadType: type, comptime HashType: type) type { + return struct { + pub const AEAD = AeadType; + pub const Hash = HashType; pub const Hmac = crypto.auth.hmac.Hmac(Hash); pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); @@ -234,55 +237,38 @@ pub const CipherParams = union(CipherSuite) { client_handshake_iv: [AEAD.nonce_length]u8, server_handshake_iv: [AEAD.nonce_length]u8, transcript_hash: Hash, - }, - TLS_AES_256_GCM_SHA384: struct { - pub const AEAD = crypto.aead.aes_gcm.Aes256Gcm; - pub const Hash = crypto.hash.sha2.Sha384; + }; +} + +pub const CipherParams = union(enum) { + AES_128_GCM_SHA256: CipherParamsT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256), + AES_256_GCM_SHA384: CipherParamsT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384), + CHACHA20_POLY1305_SHA256: CipherParamsT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256), + AEGIS_256_SHA384: CipherParamsT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha384), + AEGIS_128L_SHA256: CipherParamsT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256), +}; + +pub fn ApplicationCipherT(comptime AeadType: type, comptime HashType: type) type { + return struct { + pub const AEAD = AeadType; + pub const Hash = HashType; pub const Hmac = crypto.auth.hmac.Hmac(Hash); pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); - handshake_secret: [Hkdf.prk_length]u8, - master_secret: [Hkdf.prk_length]u8, - client_handshake_key: [AEAD.key_length]u8, - server_handshake_key: [AEAD.key_length]u8, - client_finished_key: [Hmac.key_length]u8, - server_finished_key: [Hmac.key_length]u8, - client_handshake_iv: [AEAD.nonce_length]u8, - server_handshake_iv: [AEAD.nonce_length]u8, - transcript_hash: Hash, - }, - TLS_CHACHA20_POLY1305_SHA256: void, - TLS_AES_128_CCM_SHA256: void, - TLS_AES_128_CCM_8_SHA256: void, -}; + client_key: [AEAD.key_length]u8, + server_key: [AEAD.key_length]u8, + client_iv: [AEAD.nonce_length]u8, + server_iv: [AEAD.nonce_length]u8, + }; +} /// Encryption parameters for application traffic. 
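CipherParamsT and ApplicationCipherT turn the per-suite payloads into comptime type constructors, so every array length comes from the chosen AEAD and hash. A small illustrative test, assumed to live next to these definitions in tls.zig (it is not part of the patch):

    test "ApplicationCipherT takes its sizes from the AEAD" {
        const T = ApplicationCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256);
        try std.testing.expect(T.AEAD.key_length == 16);
        try std.testing.expect(T.AEAD.nonce_length == 12);
    }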
-pub const ApplicationCipher = union(CipherSuite) { - TLS_AES_128_GCM_SHA256: struct { - pub const AEAD = crypto.aead.aes_gcm.Aes128Gcm; - pub const Hash = crypto.hash.sha2.Sha256; - pub const Hmac = crypto.auth.hmac.Hmac(Hash); - pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); - - client_key: [AEAD.key_length]u8, - server_key: [AEAD.key_length]u8, - client_iv: [AEAD.nonce_length]u8, - server_iv: [AEAD.nonce_length]u8, - }, - TLS_AES_256_GCM_SHA384: struct { - pub const AEAD = crypto.aead.aes_gcm.Aes256Gcm; - pub const Hash = crypto.hash.sha2.Sha384; - pub const Hmac = crypto.auth.hmac.Hmac(Hash); - pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); - - client_key: [AEAD.key_length]u8, - server_key: [AEAD.key_length]u8, - client_iv: [AEAD.nonce_length]u8, - server_iv: [AEAD.nonce_length]u8, - }, - TLS_CHACHA20_POLY1305_SHA256: void, - TLS_AES_128_CCM_SHA256: void, - TLS_AES_128_CCM_8_SHA256: void, +pub const ApplicationCipher = union(enum) { + AES_128_GCM_SHA256: ApplicationCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256), + AES_256_GCM_SHA384: ApplicationCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384), + CHACHA20_POLY1305_SHA256: ApplicationCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256), + AEGIS_256_SHA384: ApplicationCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha384), + AEGIS_128L_SHA256: ApplicationCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256), }; pub fn hkdfExpandLabel( diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 93693e40e9..a2f8ff5733 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -23,27 +23,27 @@ partially_read_buffer: [tls.max_ciphertext_record_len]u8, partially_read_len: u15, eof: bool, -const cipher_suites = blk: { - const fields = @typeInfo(CipherSuite).Enum.fields; - var result: [(fields.len + 1) * 2]u8 = undefined; - mem.writeIntBig(u16, result[0..2], result.len - 2); - for (fields) |field, i| { - const int = @enumToInt(@field(CipherSuite, field.name)); - result[(i + 1) * 2] = @truncate(u8, int >> 8); - result[(i + 1) * 2 + 1] = @truncate(u8, int); - } - break :blk result; -}; +// Measurement taken with 0.11.0-dev.810+c2f5848fe +// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: +// zig run .lib/std/crypto/benchmark.zig -OReleaseFast +// aegis-128l: 15382 MiB/s +// aegis-256: 9553 MiB/s +// aes128-gcm: 3721 MiB/s +// aes256-gcm: 3010 MiB/s +// chacha20Poly1305: 597 MiB/s + +const cipher_suites = + int2(@enumToInt(tls.CipherSuite.AEGIS_128L_SHA256)) ++ + int2(@enumToInt(tls.CipherSuite.AEGIS_256_SHA384)) ++ + int2(@enumToInt(tls.CipherSuite.AES_128_GCM_SHA256)) ++ + int2(@enumToInt(tls.CipherSuite.AES_256_GCM_SHA384)) ++ + int2(@enumToInt(tls.CipherSuite.CHACHA20_POLY1305_SHA256)); /// `host` is only borrowed during this function call. pub fn init(stream: net.Stream, host: []const u8) !Client { - var x25519_priv_key: [32]u8 = undefined; - crypto.random.bytes(&x25519_priv_key); - const x25519_pub_key = crypto.dh.X25519.recoverPublicKey(x25519_priv_key) catch |err| { - switch (err) { - // Only possible to happen if the private key is all zeroes. - error.IdentityElement => return error.InsufficientEntropy, - } + const kp = crypto.dh.X25519.KeyPair.create(null) catch |err| switch (err) { + // Only possible to happen if the private key is all zeroes. 
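The switch to KeyPair.create above can be exercised in isolation; the sketch below mirrors the X25519 calls the handshake makes, with a second key pair standing in for the server purely for illustration:

    const std = @import("std");
    const X25519 = std.crypto.dh.X25519;

    test "both sides derive the same X25519 shared secret" {
        const client_kp = try X25519.KeyPair.create(null);
        const server_kp = try X25519.KeyPair.create(null);
        const client_shared = try X25519.scalarmult(client_kp.secret_key, server_kp.public_key);
        const server_shared = try X25519.scalarmult(server_kp.secret_key, client_kp.public_key);
        try std.testing.expectEqualSlices(u8, &client_shared, &server_shared);
    }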
+ error.IdentityElement => return error.InsufficientEntropy, }; // random (u32) @@ -98,7 +98,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { 0, 36, // byte length of client_shares 0x00, 0x1D, // NamedGroup.x25519 0, 32, // byte length of key_exchange - } ++ x25519_pub_key ++ [_]u8{ + } ++ kp.public_key ++ [_]u8{ // Extension: server_name 0, 0, // ExtensionType.server_name @@ -120,7 +120,9 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { // ClientHello 0x03, 0x03, // legacy_version - } ++ rand_buf ++ [1]u8{0} ++ cipher_suites ++ [_]u8{ + } ++ rand_buf ++ [1]u8{0} ++ + int2(cipher_suites.len) ++ cipher_suites ++ + [_]u8{ 0x01, 0x00, // legacy_compression_methods } ++ extensions_header; @@ -191,9 +193,8 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { const legacy_session_id_echo_len = hello[34]; if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter; const cipher_suite_int = mem.readIntBig(u16, hello[35..37]); - const cipher_suite_tag = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch - return error.TlsIllegalParameter; - std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite_tag)}); + const cipher_suite_tag = @intToEnum(CipherSuite, cipher_suite_int); + std.debug.print("server wants cipher suite {any}\n", .{cipher_suite_tag}); const legacy_compression_method = hello[37]; _ = legacy_compression_method; const extensions_size = mem.readIntBig(u16, hello[38..40]); @@ -250,13 +251,18 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { } const shared_key = crypto.dh.X25519.scalarmult( - x25519_priv_key, + kp.secret_key, x25519_server_pub_key.*, ) catch return error.TlsDecryptFailure; switch (cipher_suite_tag) { - inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |tag| { - const P = std.meta.TagPayload(CipherParams, tag); + inline .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + .AEGIS_256_SHA384, + .AEGIS_128L_SHA256, + => |tag| { + const P = std.meta.TagPayloadByName(CipherParams, @tagName(tag)); cipher_params = @unionInit(CipherParams, @tagName(tag), .{ .handshake_secret = undefined, .master_secret = undefined, @@ -301,14 +307,8 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { // std.fmt.fmtSliceHexLower(&p.server_handshake_iv), //}); }, - .TLS_CHACHA20_POLY1305_SHA256 => { - @panic("TODO"); - }, - .TLS_AES_128_CCM_SHA256 => { - @panic("TODO"); - }, - .TLS_AES_128_CCM_8_SHA256 => { - @panic("TODO"); + else => { + return error.TlsIllegalParameter; }, } }, @@ -347,7 +347,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { .application_data => { var cleartext_buf: [8000]u8 = undefined; const cleartext = switch (cipher_params) { - inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: { + inline else => |*p| c: { const P = @TypeOf(p.*); const ciphertext_len = record_size - P.AEAD.tag_length; const ciphertext = handshake_buf[i..][0..ciphertext_len]; @@ -366,15 +366,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { p.transcript_hash.update(cleartext[0 .. 
cleartext.len - 1]); break :c cleartext; }, - .TLS_CHACHA20_POLY1305_SHA256 => { - @panic("TODO"); - }, - .TLS_AES_128_CCM_SHA256 => { - @panic("TODO"); - }, - .TLS_AES_128_CCM_8_SHA256 => { - @panic("TODO"); - }, }; const inner_ct = @intToEnum(ContentType, cleartext[cleartext.len - 1]); @@ -426,7 +417,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { 0x01, }; const app_cipher = switch (cipher_params) { - inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p, tag| c: { + inline else => |*p, tag| c: { const P = @TypeOf(p.*); // TODO verify the server's data const handshake_hash = p.transcript_hash.finalResult(); @@ -467,15 +458,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length), }); }, - .TLS_CHACHA20_POLY1305_SHA256 => { - @panic("TODO"); - }, - .TLS_AES_128_CCM_SHA256 => { - @panic("TODO"); - }, - .TLS_AES_128_CCM_8_SHA256 => { - @panic("TODO"); - }, }; std.debug.print("remaining bytes: {d}\n", .{len - end}); return .{ @@ -524,7 +506,7 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize { var bytes_i: usize = 0; // How many bytes are taken up by overhead per record. const overhead_len: usize = switch (c.application_cipher) { - inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| l: { + inline else => |*p| l: { const P = @TypeOf(p.*); const V = @Vector(P.AEAD.nonce_length, u8); const overhead_len = tls.ciphertext_record_header_len + P.AEAD.tag_length + 1; @@ -577,15 +559,6 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize { iovec_end += 1; } }, - .TLS_CHACHA20_POLY1305_SHA256 => { - @panic("TODO"); - }, - .TLS_AES_128_CCM_SHA256 => { - @panic("TODO"); - }, - .TLS_AES_128_CCM_8_SHA256 => { - @panic("TODO"); - }, }; // Ideally we would call writev exactly once here, however, we must ensure @@ -659,7 +632,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { }, .application_data => { const cleartext_len = switch (c.application_cipher) { - inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: { + inline else => |*p| c: { const P = @TypeOf(p.*); const V = @Vector(P.AEAD.nonce_length, u8); const ad = frag[in - 5 ..][0..5]; @@ -682,15 +655,6 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { return error.TlsBadRecordMac; break :c cleartext.len; }, - .TLS_CHACHA20_POLY1305_SHA256 => { - @panic("TODO"); - }, - .TLS_AES_128_CCM_SHA256 => { - @panic("TODO"); - }, - .TLS_AES_128_CCM_8_SHA256 => { - @panic("TODO"); - }, }; const inner_ct = @intToEnum(ContentType, buffer[out + cleartext_len - 1]); diff --git a/lib/std/meta.zig b/lib/std/meta.zig index 39d561469f..db284f8b61 100644 --- a/lib/std/meta.zig +++ b/lib/std/meta.zig @@ -810,21 +810,25 @@ test "std.meta.activeTag" { const TagPayloadType = TagPayload; -///Given a tagged union type, and an enum, return the type of the union -/// field corresponding to the enum tag. -pub fn TagPayload(comptime U: type, comptime tag: Tag(U)) type { +pub fn TagPayloadByName(comptime U: type, comptime tag_name: []const u8) type { comptime debug.assert(trait.is(.Union)(U)); const info = @typeInfo(U).Union; inline for (info.fields) |field_info| { - if (comptime mem.eql(u8, field_info.name, @tagName(tag))) + if (comptime mem.eql(u8, field_info.name, tag_name)) return field_info.type; } unreachable; } +/// Given a tagged union type, and an enum, return the type of the union field +/// corresponding to the enum tag. 
+pub fn TagPayload(comptime U: type, comptime tag: Tag(U)) type { + return TagPayloadByName(U, @tagName(tag)); +} + test "std.meta.TagPayload" { const Event = union(enum) { Moved: struct { From 8ef4dcd39f08a543f067a9820d82840dbf9d2ce5 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 16 Dec 2022 18:21:40 -0700 Subject: [PATCH 13/59] std.crypto.tls: add some benchmark data points Looks like aegis-128l is the winner on baseline too. --- lib/std/crypto/tls/Client.zig | 44 ++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index a2f8ff5733..40bdf4d61a 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -23,22 +23,6 @@ partially_read_buffer: [tls.max_ciphertext_record_len]u8, partially_read_len: u15, eof: bool, -// Measurement taken with 0.11.0-dev.810+c2f5848fe -// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: -// zig run .lib/std/crypto/benchmark.zig -OReleaseFast -// aegis-128l: 15382 MiB/s -// aegis-256: 9553 MiB/s -// aes128-gcm: 3721 MiB/s -// aes256-gcm: 3010 MiB/s -// chacha20Poly1305: 597 MiB/s - -const cipher_suites = - int2(@enumToInt(tls.CipherSuite.AEGIS_128L_SHA256)) ++ - int2(@enumToInt(tls.CipherSuite.AEGIS_256_SHA384)) ++ - int2(@enumToInt(tls.CipherSuite.AES_128_GCM_SHA256)) ++ - int2(@enumToInt(tls.CipherSuite.AES_256_GCM_SHA384)) ++ - int2(@enumToInt(tls.CipherSuite.CHACHA20_POLY1305_SHA256)); - /// `host` is only borrowed during this function call. pub fn init(stream: net.Stream, host: []const u8) !Client { const kp = crypto.dh.X25519.KeyPair.create(null) catch |err| switch (err) { @@ -713,3 +697,31 @@ inline fn int2(x: u16) [2]u8 { @truncate(u8, x), }; } + +/// The priority order here is chosen based on what crypto algorithms Zig has +/// available in the standard library as well as what is faster. Following are +/// a few data points on the relative performance of these algorithms. 
+/// +/// Measurement taken with 0.11.0-dev.810+c2f5848fe +/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: +/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast +/// aegis-128l: 15382 MiB/s +/// aegis-256: 9553 MiB/s +/// aes128-gcm: 3721 MiB/s +/// aes256-gcm: 3010 MiB/s +/// chacha20Poly1305: 597 MiB/s +/// +/// Measurement taken with 0.11.0-dev.810+c2f5848fe +/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: +/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast -mcpu=baseline +/// aegis-128l: 629 MiB/s +/// chacha20Poly1305: 529 MiB/s +/// aegis-256: 461 MiB/s +/// aes128-gcm: 138 MiB/s +/// aes256-gcm: 120 MiB/s +const cipher_suites = + int2(@enumToInt(tls.CipherSuite.AEGIS_128L_SHA256)) ++ + int2(@enumToInt(tls.CipherSuite.AEGIS_256_SHA384)) ++ + int2(@enumToInt(tls.CipherSuite.AES_128_GCM_SHA256)) ++ + int2(@enumToInt(tls.CipherSuite.AES_256_GCM_SHA384)) ++ + int2(@enumToInt(tls.CipherSuite.CHACHA20_POLY1305_SHA256)); From f6c3a86f0f570a2feb721e443efea3319d19d098 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 18 Dec 2022 16:03:22 -0700 Subject: [PATCH 14/59] std.crypto.tls.Client: remove unnecessary coercion --- lib/std/crypto/tls/Client.zig | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 40bdf4d61a..4afc1b7e17 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -343,7 +343,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); const operand: V = pad ++ @bitCast([8]u8, big(read_seq)); read_seq += 1; - const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_handshake_iv) ^ operand; + const nonce = @as(V, p.server_handshake_iv) ^ operand; const ad = handshake_buf[end_hdr - 5 ..][0..5]; P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_handshake_key) catch return error.TlsBadRecordMac; @@ -522,7 +522,7 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize { const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); const operand: V = pad ++ @bitCast([8]u8, big(c.write_seq)); c.write_seq += 1; - const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.client_iv) ^ operand; + const nonce = @as(V, p.client_iv) ^ operand; P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key); //std.debug.print("seq: {d} nonce: {} client_key: {} client_iv: {} ad: {} auth_tag: {}\nserver_key: {} server_iv: {}\n", .{ // c.write_seq - 1, From 41f4461cdabb50e45c0f956d4a1380d2008cd127 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 18 Dec 2022 16:17:10 -0700 Subject: [PATCH 15/59] std.crypto.tls.Client: verify the server's Finished message --- lib/std/crypto/tls.zig | 1 + lib/std/crypto/tls/Client.zig | 12 ++++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index babd8b465d..3944b7c974 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -237,6 +237,7 @@ pub fn CipherParamsT(comptime AeadType: type, comptime HashType: type) type { client_handshake_iv: [AEAD.nonce_length]u8, server_handshake_iv: [AEAD.nonce_length]u8, transcript_hash: Hash, + finished_digest: [Hash.digest_length]u8, }; } diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 4afc1b7e17..c4ac6e508a 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -257,6 +257,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { 
.client_handshake_iv = undefined, .server_handshake_iv = undefined, .transcript_hash = P.Hash.init(.{}), + .finished_digest = undefined, }); const p = &@field(cipher_params, @tagName(tag)); p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1 @@ -391,6 +392,11 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { }, @enumToInt(HandshakeType.certificate_verify) => { std.debug.print("the certificate came with a fancy signature\n", .{}); + switch (cipher_params) { + inline else => |*p| { + p.finished_digest = p.transcript_hash.peek(); + }, + } }, @enumToInt(HandshakeType.finished) => { // This message is to trick buggy proxies into behaving correctly. @@ -403,7 +409,10 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { const app_cipher = switch (cipher_params) { inline else => |*p, tag| c: { const P = @TypeOf(p.*); - // TODO verify the server's data + const expected_server_verify_data = tls.hmac(P.Hmac, &p.finished_digest, p.server_finished_key); + const actual_server_verify_data = cleartext[ct_i..][0..handshake_len]; + if (!mem.eql(u8, &expected_server_verify_data, actual_server_verify_data)) + return error.TlsDecryptError; const handshake_hash = p.transcript_hash.finalResult(); const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key); const out_cleartext = [_]u8{ @@ -454,7 +463,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { }; }, else => { - std.debug.print("handshake type: {d}\n", .{cleartext[0]}); return error.TlsUnexpectedMessage; }, } From e2efba76aa0e1566da65721db64537d94fea69df Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 18 Dec 2022 18:01:13 -0700 Subject: [PATCH 16/59] std.crypto.tls: refactor to remove mutations build up the hello message with array concatenation and helper functions rather than hard-coded offsets and lengths. 
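
As a standalone illustration of the pattern (simplified stand-ins for the
int2/array helpers introduced here, not the std code itself), the length
prefixes fall out of comptime-known array lengths instead of being written
back into the buffer afterwards:

    const std = @import("std");

    inline fn int2(x: u16) [2]u8 {
        return .{ @truncate(u8, x >> 8), @truncate(u8, x) };
    }

    // Prepend a two-byte big-endian length to a payload whose length is
    // known at compile time.
    inline fn lengthPrefixed(bytes: anytype) [2 + bytes.len]u8 {
        return int2(bytes.len) ++ bytes;
    }

    test "length prefixes from comptime-known lengths" {
        // supported_versions payload: one list-length byte, then TLS 1.3.
        const payload = [_]u8{ 0x02, 0x03, 0x04 };
        const ext = int2(43) ++ lengthPrefixed(payload); // 43 = supported_versions
        try std.testing.expectEqualSlices(
            u8,
            &[_]u8{ 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04 },
            &ext,
        );
    }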
--- lib/std/crypto/tls.zig | 34 ++++++++ lib/std/crypto/tls/Client.zig | 152 +++++++++++++++------------------- 2 files changed, 99 insertions(+), 87 deletions(-) diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 3944b7c974..d2a347e87e 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -310,3 +310,37 @@ pub fn hmac(comptime Hmac: type, message: []const u8, key: [Hmac.key_length]u8) Hmac.create(&result, message, &key); return result; } + +pub inline fn extension(comptime et: ExtensionType, bytes: anytype) [2 + 2 + bytes.len]u8 { + return int2(@enumToInt(et)) ++ array(1, bytes); +} + +pub inline fn array(comptime elem_size: comptime_int, bytes: anytype) [2 + bytes.len]u8 { + comptime assert(bytes.len % elem_size == 0); + return int2(bytes.len) ++ bytes; +} + +pub inline fn enum_array(comptime E: type, comptime tags: []const E) [2 + @sizeOf(E) * tags.len]u8 { + assert(@sizeOf(E) == 2); + var result: [tags.len * 2]u8 = undefined; + for (tags) |elem, i| { + result[i * 2] = @truncate(u8, @enumToInt(elem) >> 8); + result[i * 2 + 1] = @truncate(u8, @enumToInt(elem)); + } + return array(2, result); +} + +pub inline fn int2(x: u16) [2]u8 { + return .{ + @truncate(u8, x >> 8), + @truncate(u8, x), + }; +} + +pub inline fn int3(x: u24) [3]u8 { + return .{ + @truncate(u8, x >> 16), + @truncate(u8, x >> 8), + @truncate(u8, x), + }; +} diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index c4ac6e508a..33146f79c9 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -13,6 +13,10 @@ const HandshakeType = tls.HandshakeType; const CipherParams = tls.CipherParams; const max_ciphertext_len = tls.max_ciphertext_len; const hkdfExpandLabel = tls.hkdfExpandLabel; +const int2 = tls.int2; +const int3 = tls.int3; +const array = tls.array; +const enum_array = tls.enum_array; application_cipher: ApplicationCipher, read_seq: u64, @@ -25,6 +29,8 @@ eof: bool, /// `host` is only borrowed during this function call. pub fn init(stream: net.Stream, host: []const u8) !Client { + const host_len = @intCast(u16, host.len); + const kp = crypto.dh.X25519.KeyPair.create(null) catch |err| switch (err) { // Only possible to happen if the private key is all zeroes. 
error.IdentityElement => return error.InsufficientEntropy, @@ -34,92 +40,70 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { var rand_buf: [32]u8 = undefined; crypto.random.bytes(&rand_buf); - const extensions_header = [_]u8{ - // Extensions byte length - undefined, undefined, - - // Extension: supported_versions (only TLS 1.3) - 0, 43, // ExtensionType.supported_versions - 0x00, 0x05, // byte length of this extension payload - 0x04, // byte length of supported versions + const extensions_payload = + tls.extension(.supported_versions, [_]u8{ + 0x02, // byte length of supported versions 0x03, 0x04, // TLS 1.3 - 0x03, 0x03, // TLS 1.2 - - // Extension: signature_algorithms - 0, 13, // ExtensionType.signature_algorithms - 0x00, 0x22, // byte length of this extension payload - 0x00, 0x20, // byte length of signature algorithms list - 0x04, 0x01, // rsa_pkcs1_sha256 - 0x05, 0x01, // rsa_pkcs1_sha384 - 0x06, 0x01, // rsa_pkcs1_sha512 - 0x04, 0x03, // ecdsa_secp256r1_sha256 - 0x05, 0x03, // ecdsa_secp384r1_sha384 - 0x06, 0x03, // ecdsa_secp521r1_sha512 - 0x08, 0x04, // rsa_pss_rsae_sha256 - 0x08, 0x05, // rsa_pss_rsae_sha384 - 0x08, 0x06, // rsa_pss_rsae_sha512 - 0x08, 0x07, // ed25519 - 0x08, 0x08, // ed448 - 0x08, 0x09, // rsa_pss_pss_sha256 - 0x08, 0x0a, // rsa_pss_pss_sha384 - 0x08, 0x0b, // rsa_pss_pss_sha512 - 0x02, 0x01, // rsa_pkcs1_sha1 - 0x02, 0x03, // ecdsa_sha1 - - // Extension: supported_groups - 0, 10, // ExtensionType.supported_groups - 0x00, 0x0c, // byte length of this extension payload - 0x00, 0x0a, // byte length of supported groups list - 0x00, 0x17, // secp256r1 - 0x00, 0x18, // secp384r1 - 0x00, 0x19, // secp521r1 - 0x00, 0x1D, // x25519 - 0x00, 0x1E, // x448 - + }) ++ tls.extension(.signature_algorithms, enum_array(tls.SignatureScheme, &.{ + .rsa_pkcs1_sha256, + .rsa_pkcs1_sha384, + .rsa_pkcs1_sha512, + .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + .ecdsa_secp521r1_sha512, + .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + .ed25519, + .ed448, + .rsa_pss_pss_sha256, + .rsa_pss_pss_sha384, + .rsa_pss_pss_sha512, + .rsa_pkcs1_sha1, + .ecdsa_sha1, + })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{ + //.secp256r1, + .x25519, + })) ++ [_]u8{ // Extension: key_share 0, 51, // ExtensionType.key_share 0, 38, // byte length of this extension payload 0, 36, // byte length of client_shares 0x00, 0x1D, // NamedGroup.x25519 0, 32, // byte length of key_exchange - } ++ kp.public_key ++ [_]u8{ + } ++ kp.public_key ++ + int2(@enumToInt(tls.ExtensionType.server_name)) ++ + int2(host_len + 5) ++ // byte length of this extension payload + int2(host_len + 3) ++ // server_name_list byte count + [1]u8{0x00} ++ // name_type + int2(host_len); - // Extension: server_name - 0, 0, // ExtensionType.server_name - undefined, undefined, // byte length of this extension payload - undefined, undefined, // server_name_list byte count - 0x00, // name_type - undefined, undefined, // host name len - }; + const extensions_header = + int2(@intCast(u16, extensions_payload.len + host_len)) ++ + extensions_payload; - var hello_header = [_]u8{ + const legacy_compression_methods = 0x0100; + + const client_hello = + int2(@enumToInt(tls.ProtocolVersion.tls_1_2)) ++ + rand_buf ++ + [1]u8{0} ++ + cipher_suites ++ + int2(legacy_compression_methods) ++ + extensions_header; + + const handshake = + [_]u8{@enumToInt(HandshakeType.client_hello)} ++ + int3(@intCast(u24, client_hello.len + host_len)) ++ + client_hello; + + const hello_header = [_]u8{ // Plaintext 
header @enumToInt(ContentType.handshake), 0x03, 0x01, // legacy_record_version - undefined, undefined, // Plaintext fragment length (u16) - - // Handshake header - @enumToInt(HandshakeType.client_hello), - undefined, undefined, undefined, // handshake length (u24) - - // ClientHello - 0x03, 0x03, // legacy_version - } ++ rand_buf ++ [1]u8{0} ++ - int2(cipher_suites.len) ++ cipher_suites ++ - [_]u8{ - 0x01, 0x00, // legacy_compression_methods - } ++ extensions_header; - - mem.writeIntBig(u16, hello_header[3..][0..2], @intCast(u16, hello_header.len - 5 + host.len)); - mem.writeIntBig(u24, hello_header[6..][0..3], @intCast(u24, hello_header.len - 9 + host.len)); - mem.writeIntBig( - u16, - hello_header[hello_header.len - extensions_header.len ..][0..2], - @intCast(u16, extensions_header.len - 2 + host.len), - ); - mem.writeIntBig(u16, hello_header[hello_header.len - 7 ..][0..2], @intCast(u16, 5 + host.len)); - mem.writeIntBig(u16, hello_header[hello_header.len - 5 ..][0..2], @intCast(u16, 3 + host.len)); - mem.writeIntBig(u16, hello_header[hello_header.len - 2 ..][0..2], @intCast(u16, 0 + host.len)); + } ++ + int2(@intCast(u16, handshake.len + host_len)) ++ + handshake; { var iovecs = [_]std.os.iovec_const{ @@ -699,13 +683,6 @@ inline fn big(x: anytype) @TypeOf(x) { }; } -inline fn int2(x: u16) [2]u8 { - return .{ - @truncate(u8, x >> 8), - @truncate(u8, x), - }; -} - /// The priority order here is chosen based on what crypto algorithms Zig has /// available in the standard library as well as what is faster. Following are /// a few data points on the relative performance of these algorithms. @@ -727,9 +704,10 @@ inline fn int2(x: u16) [2]u8 { /// aegis-256: 461 MiB/s /// aes128-gcm: 138 MiB/s /// aes256-gcm: 120 MiB/s -const cipher_suites = - int2(@enumToInt(tls.CipherSuite.AEGIS_128L_SHA256)) ++ - int2(@enumToInt(tls.CipherSuite.AEGIS_256_SHA384)) ++ - int2(@enumToInt(tls.CipherSuite.AES_128_GCM_SHA256)) ++ - int2(@enumToInt(tls.CipherSuite.AES_256_GCM_SHA384)) ++ - int2(@enumToInt(tls.CipherSuite.CHACHA20_POLY1305_SHA256)); +const cipher_suites = enum_array(tls.CipherSuite, &.{ + .AEGIS_128L_SHA256, + .AEGIS_256_SHA384, + .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, +}); From 7a2377838414157fb65850aa045c2112a0bbd006 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 18 Dec 2022 18:14:37 -0700 Subject: [PATCH 17/59] std.crypto.tls: send a legacy session id To support middlebox compatibility mode. --- lib/std/crypto/tls/Client.zig | 57 +++++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 23 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 33146f79c9..0eabfddb76 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -36,9 +36,11 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { error.IdentityElement => return error.InsufficientEntropy, }; - // random (u32) - var rand_buf: [32]u8 = undefined; - crypto.random.bytes(&rand_buf); + // This is used both for the random bytes and for the legacy session id. 
+ var random_buffer: [64]u8 = undefined; + crypto.random.bytes(&random_buffer); + const hello_rand = random_buffer[0..32].*; + const legacy_session_id = random_buffer[32..64].*; const extensions_payload = tls.extension(.supported_versions, [_]u8{ @@ -86,8 +88,8 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { const client_hello = int2(@enumToInt(tls.ProtocolVersion.tls_1_2)) ++ - rand_buf ++ - [1]u8{0} ++ + hello_rand ++ + [1]u8{32} ++ legacy_session_id ++ cipher_suites ++ int2(legacy_compression_methods) ++ extensions_header; @@ -152,46 +154,55 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { } const length = mem.readIntBig(u24, frag[1..4]); if (4 + length != frag.len) return error.TlsBadLength; - const hello = frag[4..]; - const legacy_version = mem.readIntBig(u16, hello[0..2]); - const random = hello[2..34].*; + var i: usize = 4; + const legacy_version = mem.readIntBig(u16, frag[i..][0..2]); + i += 2; + const random = frag[i..][0..32].*; + i += 32; if (mem.eql(u8, &random, &tls.hello_retry_request_sequence)) { @panic("TODO handle HelloRetryRequest"); } - const legacy_session_id_echo_len = hello[34]; - if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter; - const cipher_suite_int = mem.readIntBig(u16, hello[35..37]); + const legacy_session_id_echo_len = frag[i]; + i += 1; + if (legacy_session_id_echo_len != 32) return error.TlsIllegalParameter; + const legacy_session_id_echo = frag[i..][0..32]; + if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id)) + return error.TlsIllegalParameter; + i += 32; + const cipher_suite_int = mem.readIntBig(u16, frag[i..][0..2]); + i += 2; const cipher_suite_tag = @intToEnum(CipherSuite, cipher_suite_int); std.debug.print("server wants cipher suite {any}\n", .{cipher_suite_tag}); - const legacy_compression_method = hello[37]; + const legacy_compression_method = frag[i]; + i += 1; _ = legacy_compression_method; - const extensions_size = mem.readIntBig(u16, hello[38..40]); - if (40 + extensions_size != hello.len) return error.TlsBadLength; - var i: usize = 40; + const extensions_size = mem.readIntBig(u16, frag[i..][0..2]); + i += 2; + if (i + extensions_size != frag.len) return error.TlsBadLength; var supported_version: u16 = 0; var opt_x25519_server_pub_key: ?*[32]u8 = null; - while (i < hello.len) { - const et = mem.readIntBig(u16, hello[i..][0..2]); + while (i < frag.len) { + const et = mem.readIntBig(u16, frag[i..][0..2]); i += 2; - const ext_size = mem.readIntBig(u16, hello[i..][0..2]); + const ext_size = mem.readIntBig(u16, frag[i..][0..2]); i += 2; const next_i = i + ext_size; - if (next_i > hello.len) return error.TlsBadLength; + if (next_i > frag.len) return error.TlsBadLength; switch (et) { @enumToInt(tls.ExtensionType.supported_versions) => { if (supported_version != 0) return error.TlsIllegalParameter; - supported_version = mem.readIntBig(u16, hello[i..][0..2]); + supported_version = mem.readIntBig(u16, frag[i..][0..2]); }, @enumToInt(tls.ExtensionType.key_share) => { if (opt_x25519_server_pub_key != null) return error.TlsIllegalParameter; - const named_group = mem.readIntBig(u16, hello[i..][0..2]); + const named_group = mem.readIntBig(u16, frag[i..][0..2]); i += 2; switch (named_group) { @enumToInt(tls.NamedGroup.x25519) => { - const key_size = mem.readIntBig(u16, hello[i..][0..2]); + const key_size = mem.readIntBig(u16, frag[i..][0..2]); i += 2; if (key_size != 32) return error.TlsBadLength; - opt_x25519_server_pub_key = hello[i..][0..32]; + opt_x25519_server_pub_key = frag[i..][0..32]; }, else 
=> { std.debug.print("named group: {x}\n", .{named_group}); From f460c2170504ce94471b16631c43638f50735241 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 18 Dec 2022 18:27:53 -0700 Subject: [PATCH 18/59] std.crypto.tls.Client: avoid hard-coded bytes in key_share --- lib/std/crypto/tls/Client.zig | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 0eabfddb76..11846ca526 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -31,16 +31,22 @@ eof: bool, pub fn init(stream: net.Stream, host: []const u8) !Client { const host_len = @intCast(u16, host.len); - const kp = crypto.dh.X25519.KeyPair.create(null) catch |err| switch (err) { - // Only possible to happen if the private key is all zeroes. - error.IdentityElement => return error.InsufficientEntropy, - }; - - // This is used both for the random bytes and for the legacy session id. - var random_buffer: [64]u8 = undefined; + var random_buffer: [128]u8 = undefined; crypto.random.bytes(&random_buffer); const hello_rand = random_buffer[0..32].*; const legacy_session_id = random_buffer[32..64].*; + const x25519_kp_seed = random_buffer[64..96].*; + const secp256r1_kp_seed = random_buffer[96..128].*; + + const x25519_kp = crypto.dh.X25519.KeyPair.create(x25519_kp_seed) catch |err| switch (err) { + // Only possible to happen if the private key is all zeroes. + error.IdentityElement => return error.InsufficientEntropy, + }; + const secp256r1_kp = crypto.sign.ecdsa.EcdsaP256Sha256.KeyPair.create(secp256r1_kp_seed) catch |err| switch (err) { + // Only possible to happen if the private key is all zeroes. + error.IdentityElement => return error.InsufficientEntropy, + }; + _ = secp256r1_kp; const extensions_payload = tls.extension(.supported_versions, [_]u8{ @@ -66,14 +72,10 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{ //.secp256r1, .x25519, - })) ++ [_]u8{ - // Extension: key_share - 0, 51, // ExtensionType.key_share - 0, 38, // byte length of this extension payload - 0, 36, // byte length of client_shares - 0x00, 0x1D, // NamedGroup.x25519 - 0, 32, // byte length of key_exchange - } ++ kp.public_key ++ + })) ++ tls.extension( + .key_share, + array(1, int2(@enumToInt(tls.NamedGroup.x25519)) ++ array(1, x25519_kp.public_key)), + ) ++ int2(@enumToInt(tls.ExtensionType.server_name)) ++ int2(host_len + 5) ++ // byte length of this extension payload int2(host_len + 3) ++ // server_name_list byte count @@ -230,7 +232,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { } const shared_key = crypto.dh.X25519.scalarmult( - kp.secret_key, + x25519_kp.secret_key, x25519_server_pub_key.*, ) catch return error.TlsDecryptFailure; From e2c16d03abae01538ec8ca65e6dc1ecc6f6ec420 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 18 Dec 2022 18:51:28 -0700 Subject: [PATCH 19/59] std.crypto.tls.Client: support secp256r1 for handshake --- lib/std/crypto/tls/Client.zig | 46 +++++++++++++++++++++++------------ 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 11846ca526..6d6e0754da 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -46,7 +46,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { // Only possible to happen if the private key is all zeroes. 
error.IdentityElement => return error.InsufficientEntropy, }; - _ = secp256r1_kp; const extensions_payload = tls.extension(.supported_versions, [_]u8{ @@ -70,11 +69,14 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { .rsa_pkcs1_sha1, .ecdsa_sha1, })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{ - //.secp256r1, + .secp256r1, .x25519, })) ++ tls.extension( .key_share, - array(1, int2(@enumToInt(tls.NamedGroup.x25519)) ++ array(1, x25519_kp.public_key)), + array(1, int2(@enumToInt(tls.NamedGroup.x25519)) ++ + array(1, x25519_kp.public_key) ++ + int2(@enumToInt(tls.NamedGroup.secp256r1)) ++ + array(1, secp256r1_kp.public_key.toUncompressedSec1())), ) ++ int2(@enumToInt(tls.ExtensionType.server_name)) ++ int2(host_len + 5) ++ // byte length of this extension payload @@ -182,7 +184,8 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { i += 2; if (i + extensions_size != frag.len) return error.TlsBadLength; var supported_version: u16 = 0; - var opt_x25519_server_pub_key: ?*[32]u8 = null; + var shared_key: [32]u8 = undefined; + var have_shared_key = false; while (i < frag.len) { const et = mem.readIntBig(u16, frag[i..][0..2]); i += 2; @@ -196,15 +199,34 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { supported_version = mem.readIntBig(u16, frag[i..][0..2]); }, @enumToInt(tls.ExtensionType.key_share) => { - if (opt_x25519_server_pub_key != null) return error.TlsIllegalParameter; + if (have_shared_key) return error.TlsIllegalParameter; + have_shared_key = true; const named_group = mem.readIntBig(u16, frag[i..][0..2]); i += 2; + const key_size = mem.readIntBig(u16, frag[i..][0..2]); + i += 2; + switch (named_group) { @enumToInt(tls.NamedGroup.x25519) => { - const key_size = mem.readIntBig(u16, frag[i..][0..2]); - i += 2; if (key_size != 32) return error.TlsBadLength; - opt_x25519_server_pub_key = frag[i..][0..32]; + const server_pub_key = frag[i..][0..32]; + + shared_key = crypto.dh.X25519.scalarmult( + x25519_kp.secret_key, + server_pub_key.*, + ) catch return error.TlsDecryptFailure; + }, + @enumToInt(tls.NamedGroup.secp256r1) => { + const server_pub_key = frag[i..][0..key_size]; + + const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; + const pk = PublicKey.fromSec1(server_pub_key) catch { + return error.TlsDecryptFailure; + }; + const mul = pk.p.mulPublic(secp256r1_kp.secret_key.bytes, .Big) catch { + return error.TlsDecryptFailure; + }; + shared_key = mul.affineCoordinates().x.toBytes(.Big); }, else => { std.debug.print("named group: {x}\n", .{named_group}); @@ -218,8 +240,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { } i = next_i; } - const x25519_server_pub_key = opt_x25519_server_pub_key orelse - return error.TlsIllegalParameter; + if (!have_shared_key) return error.TlsIllegalParameter; const tls_version = if (supported_version == 0) legacy_version else supported_version; switch (tls_version) { @enumToInt(tls.ProtocolVersion.tls_1_2) => { @@ -231,11 +252,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { else => return error.TlsIllegalParameter, } - const shared_key = crypto.dh.X25519.scalarmult( - x25519_kp.secret_key, - x25519_server_pub_key.*, - ) catch return error.TlsDecryptFailure; - switch (cipher_suite_tag) { inline .AES_128_GCM_SHA256, .AES_256_GCM_SHA384, From 5d7eca6669228cec762fc9063a7ea3cb52af357c Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 18 Dec 2022 21:32:15 -0700 Subject: [PATCH 20/59] std.crypto.tls.Client: fix verify_data for batched handshakes --- 
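Note: the server Finished check follows RFC 8446, section 4.4.4, where
verify_data = HMAC(finished_key, Transcript-Hash(handshake messages so far));
the change feeds every wrapped handshake message in a batched record into the
transcript hash before that check. A rough standalone sketch of the
computation (names are illustrative, not the code in this diff):

    const std = @import("std");
    const Sha256 = std.crypto.hash.sha2.Sha256;
    const HmacSha256 = std.crypto.auth.hmac.sha2.HmacSha256;

    // `transcript` must already contain every handshake message up to, but
    // not including, the Finished message being verified.
    fn verifyData(
        finished_key: [HmacSha256.key_length]u8,
        transcript: Sha256,
    ) [HmacSha256.mac_length]u8 {
        const digest = transcript.peek();
        var out: [HmacSha256.mac_length]u8 = undefined;
        HmacSha256.create(&out, &digest, &finished_key);
        return out;
    }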
lib/std/crypto/tls.zig | 7 ++- lib/std/crypto/tls/Client.zig | 92 +++++++++++++++++++---------------- 2 files changed, 56 insertions(+), 43 deletions(-) diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index d2a347e87e..09a22f9a23 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -221,6 +221,12 @@ pub const CipherSuite = enum(u16) { _, }; +pub const CertificateType = enum(u8) { + X509 = 0, + RawPublicKey = 2, + _, +}; + pub fn CipherParamsT(comptime AeadType: type, comptime HashType: type) type { return struct { pub const AEAD = AeadType; @@ -237,7 +243,6 @@ pub fn CipherParamsT(comptime AeadType: type, comptime HashType: type) type { client_handshake_iv: [AEAD.nonce_length]u8, server_handshake_iv: [AEAD.nonce_length]u8, transcript_hash: Hash, - finished_digest: [Hash.digest_length]u8, }; } diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 6d6e0754da..7fb96ff00c 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -62,12 +62,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { .rsa_pss_rsae_sha384, .rsa_pss_rsae_sha512, .ed25519, - .ed448, - .rsa_pss_pss_sha256, - .rsa_pss_pss_sha384, - .rsa_pss_pss_sha512, - .rsa_pkcs1_sha1, - .ecdsa_sha1, })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{ .secp256r1, .x25519, @@ -98,24 +92,21 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { int2(legacy_compression_methods) ++ extensions_header; - const handshake = + const out_handshake = [_]u8{@enumToInt(HandshakeType.client_hello)} ++ int3(@intCast(u24, client_hello.len + host_len)) ++ client_hello; - const hello_header = [_]u8{ - // Plaintext header + const plaintext_header = [_]u8{ @enumToInt(ContentType.handshake), 0x03, 0x01, // legacy_record_version - } ++ - int2(@intCast(u16, handshake.len + host_len)) ++ - handshake; + } ++ int2(@intCast(u16, out_handshake.len + host_len)) ++ out_handshake; { var iovecs = [_]std.os.iovec_const{ .{ - .iov_base = &hello_header, - .iov_len = hello_header.len, + .iov_base = &plaintext_header, + .iov_len = plaintext_header.len, }, .{ .iov_base = host.ptr, @@ -125,7 +116,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { try stream.writevAll(&iovecs); } - const client_hello_bytes1 = hello_header[5..]; + const client_hello_bytes1 = plaintext_header[5..]; var cipher_params: CipherParams = undefined; @@ -176,7 +167,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { const cipher_suite_int = mem.readIntBig(u16, frag[i..][0..2]); i += 2; const cipher_suite_tag = @intToEnum(CipherSuite, cipher_suite_int); - std.debug.print("server wants cipher suite {any}\n", .{cipher_suite_tag}); const legacy_compression_method = frag[i]; i += 1; _ = legacy_compression_method; @@ -243,12 +233,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { if (!have_shared_key) return error.TlsIllegalParameter; const tls_version = if (supported_version == 0) legacy_version else supported_version; switch (tls_version) { - @enumToInt(tls.ProtocolVersion.tls_1_2) => { - std.debug.print("server wants TLS v1.2\n", .{}); - }, - @enumToInt(tls.ProtocolVersion.tls_1_3) => { - std.debug.print("server wants TLS v1.3\n", .{}); - }, + @enumToInt(tls.ProtocolVersion.tls_1_3) => {}, else => return error.TlsIllegalParameter, } @@ -270,7 +255,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { .client_handshake_iv = undefined, .server_handshake_iv = undefined, .transcript_hash = P.Hash.init(.{}), - .finished_digest = undefined, }); 
const p = &@field(cipher_params, @tagName(tag)); p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1 @@ -361,7 +345,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { const ad = handshake_buf[end_hdr - 5 ..][0..5]; P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_handshake_key) catch return error.TlsBadRecordMac; - p.transcript_hash.update(cleartext[0 .. cleartext.len - 1]); break :c cleartext; }, }; @@ -378,17 +361,22 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { const next_handshake_i = ct_i + handshake_len; if (next_handshake_i > cleartext.len - 1) return error.TlsBadLength; + const wrapped_handshake = cleartext[ct_i - 4 .. next_handshake_i]; + const handshake = cleartext[ct_i..next_handshake_i]; switch (handshake_type) { @enumToInt(HandshakeType.encrypted_extensions) => { - const total_ext_size = mem.readIntBig(u16, cleartext[ct_i..][0..2]); - ct_i += 2; - const end_ext_i = ct_i + total_ext_size; - while (ct_i < end_ext_i) { - const et = mem.readIntBig(u16, cleartext[ct_i..][0..2]); - ct_i += 2; - const ext_size = mem.readIntBig(u16, cleartext[ct_i..][0..2]); - ct_i += 2; - const next_ext_i = ct_i + ext_size; + switch (cipher_params) { + inline else => |*p| p.transcript_hash.update(wrapped_handshake), + } + const total_ext_size = mem.readIntBig(u16, handshake[0..2]); + var hs_i: usize = 2; + const end_ext_i = 2 + total_ext_size; + while (hs_i < end_ext_i) { + const et = mem.readIntBig(u16, handshake[hs_i..][0..2]); + hs_i += 2; + const ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]); + hs_i += 2; + const next_ext_i = hs_i + ext_size; switch (et) { @enumToInt(tls.ExtensionType.server_name) => {}, else => { @@ -397,19 +385,38 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { }); }, } - ct_i = next_ext_i; + hs_i = next_ext_i; } }, @enumToInt(HandshakeType.certificate) => { - std.debug.print("cool certificate bro\n", .{}); + switch (cipher_params) { + inline else => |*p| p.transcript_hash.update(wrapped_handshake), + } + var hs_i: usize = 0; + const cert_req_ctx_len = handshake[hs_i]; + hs_i += 1; + if (cert_req_ctx_len != 0) return error.TlsIllegalParameter; + const certs_size = mem.readIntBig(u24, handshake[hs_i..][0..3]); + hs_i += 3; + const end_certs = hs_i + certs_size; + while (hs_i < end_certs) { + const cert_size = mem.readIntBig(u24, handshake[hs_i..][0..3]); + hs_i += 3; + hs_i += cert_size; + const total_ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]); + hs_i += 2; + hs_i += total_ext_size; + + std.debug.print("received certificate of size {d} bytes with {d} bytes of extensions\n", .{ + cert_size, total_ext_size, + }); + } }, @enumToInt(HandshakeType.certificate_verify) => { - std.debug.print("the certificate came with a fancy signature\n", .{}); switch (cipher_params) { - inline else => |*p| { - p.finished_digest = p.transcript_hash.peek(); - }, + inline else => |*p| p.transcript_hash.update(wrapped_handshake), } + std.debug.print("ignoring certificate_verify\n", .{}); }, @enumToInt(HandshakeType.finished) => { // This message is to trick buggy proxies into behaving correctly. 
@@ -422,9 +429,10 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { const app_cipher = switch (cipher_params) { inline else => |*p, tag| c: { const P = @TypeOf(p.*); - const expected_server_verify_data = tls.hmac(P.Hmac, &p.finished_digest, p.server_finished_key); - const actual_server_verify_data = cleartext[ct_i..][0..handshake_len]; - if (!mem.eql(u8, &expected_server_verify_data, actual_server_verify_data)) + const finished_digest = p.transcript_hash.peek(); + p.transcript_hash.update(wrapped_handshake); + const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key); + if (!mem.eql(u8, &expected_server_verify_data, handshake)) return error.TlsDecryptError; const handshake_hash = p.transcript_hash.finalResult(); const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key); From 3237000d957617120e32e54498cac9afa23cbcd4 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 19 Dec 2022 01:07:52 -0700 Subject: [PATCH 21/59] std.crypto.tls: rudimentary certificate parsing --- lib/std/crypto/tls.zig | 109 ++++++++++++++++++++++++++++++++++ lib/std/crypto/tls/Client.zig | 66 +++++++++++++++++++- 2 files changed, 174 insertions(+), 1 deletion(-) diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 09a22f9a23..33d6bbe906 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -349,3 +349,112 @@ pub inline fn int3(x: u24) [3]u8 { @truncate(u8, x), }; } + +pub const Der = struct { + pub const Class = enum(u2) { + universal, + application, + context_specific, + private, + }; + + pub const PC = enum(u1) { + primitive, + constructed, + }; + + pub const Identifier = packed struct(u8) { + tag: Tag, + pc: PC, + class: Class, + }; + + pub const Tag = enum(u5) { + boolean = 1, + integer = 2, + bitstring = 3, + null = 5, + object_identifier = 6, + sequence = 16, + _, + }; + + pub const Oid = enum { + commonName, + countryName, + localityName, + stateOrProvinceName, + organizationName, + organizationalUnitName, + sha256WithRSAEncryption, + sha384WithRSAEncryption, + sha512WithRSAEncryption, + sha224WithRSAEncryption, + + pub const map = std.ComptimeStringMap(Oid, .{ + .{ &[_]u8{ 0x55, 0x04, 0x03 }, .commonName }, + .{ &[_]u8{ 0x55, 0x04, 0x06 }, .countryName }, + .{ &[_]u8{ 0x55, 0x04, 0x07 }, .localityName }, + .{ &[_]u8{ 0x55, 0x04, 0x08 }, .stateOrProvinceName }, + .{ &[_]u8{ 0x55, 0x04, 0x0A }, .organizationName }, + .{ &[_]u8{ 0x55, 0x04, 0x0B }, .organizationalUnitName }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0B }, .sha256WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C }, .sha384WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D }, .sha512WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0E }, .sha224WithRSAEncryption }, + }); + }; + + pub const Element = struct { + identifier: Identifier, + contents: []const u8, + }; + + pub const ParseElementError = error{CertificateHasFieldWithInvalidLength}; + + pub fn parseElement(bytes: []const u8, index: *usize) ParseElementError!Der.Element { + var i = index.*; + const identifier = @bitCast(Identifier, bytes[i]); + i += 1; + const size_byte = bytes[i]; + i += 1; + if ((size_byte >> 7) == 0) { + const contents = bytes[i..][0..size_byte]; + index.* = i + contents.len; + return .{ + .identifier = identifier, + .contents = contents, + }; + } + + const len_size = @truncate(u7, size_byte); + if (len_size > @sizeOf(usize)) { + return 
error.CertificateHasFieldWithInvalidLength; + } + + const end = i + len_size; + var long_form_size: usize = 0; + while (i < end) : (i += 1) { + long_form_size = (long_form_size << 8) | bytes[i]; + } + + const contents = bytes[i..][0..long_form_size]; + index.* = i + contents.len; + + return .{ + .identifier = identifier, + .contents = contents, + }; + } + + pub const ParseObjectIdError = error{ + CertificateHasUnrecognizedObjectId, + CertificateFieldHasWrongDataType, + } || ParseElementError; + + pub fn parseObjectId(bytes: []const u8, index: *usize) ParseObjectIdError!Oid { + const oid_element = try parseElement(bytes, index); + if (oid_element.identifier.tag != .object_identifier) return error.CertificateFieldHasWrongDataType; + return Oid.map.get(oid_element.contents) orelse return error.CertificateHasUnrecognizedObjectId; + } +}; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 7fb96ff00c..c167d85134 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -402,7 +402,71 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { while (hs_i < end_certs) { const cert_size = mem.readIntBig(u24, handshake[hs_i..][0..3]); hs_i += 3; - hs_i += cert_size; + const end_cert = hs_i + cert_size; + + const certificate = try tls.Der.parseElement(handshake, &hs_i); + { + var cert_i: usize = 0; + const tbs_certificate = try tls.Der.parseElement(certificate.contents, &cert_i); + { + var tbs_i: usize = 0; + const version = try tls.Der.parseElement(tbs_certificate.contents, &tbs_i); + const serial_number = try tls.Der.parseElement(tbs_certificate.contents, &tbs_i); + const signature = try tls.Der.parseElement(tbs_certificate.contents, &tbs_i); + const issuer = try tls.Der.parseElement(tbs_certificate.contents, &tbs_i); + const validity = try tls.Der.parseElement(tbs_certificate.contents, &tbs_i); + const subject = try tls.Der.parseElement(tbs_certificate.contents, &tbs_i); + const subject_pub_key = try tls.Der.parseElement(tbs_certificate.contents, &tbs_i); + const extensions = try tls.Der.parseElement(tbs_certificate.contents, &tbs_i); + + // RFC 5280, section 4.1.2.3: + // "This field MUST contain the same algorithm identifier as + // the signatureAlgorithm field in the sequence Certificate." 
+ _ = signature; + + _ = issuer; + _ = validity; + + std.debug.print("version: {any} '{}'\n", .{ + version.identifier, std.fmt.fmtSliceHexLower(version.contents), + }); + + std.debug.print("serial_number: {any} {}\n", .{ + serial_number.identifier, + std.fmt.fmtSliceHexLower(serial_number.contents), + }); + + std.debug.print("subject: {any} {}\n", .{ + subject.identifier, + std.fmt.fmtSliceHexLower(subject.contents), + }); + + std.debug.print("subject pub key: {any} {}\n", .{ + subject_pub_key.identifier, + std.fmt.fmtSliceHexLower(subject_pub_key.contents), + }); + + std.debug.print("extensions: {any} {}\n", .{ + extensions.identifier, + std.fmt.fmtSliceHexLower(extensions.contents), + }); + } + const signature_algorithm = try tls.Der.parseElement(certificate.contents, &cert_i); + const signature_value = try tls.Der.parseElement(certificate.contents, &cert_i); + + { + var sa_i: usize = 0; + const algorithm = try tls.Der.parseObjectId(signature_algorithm.contents, &sa_i); + std.debug.print("cert has this signature algorithm: {any}\n", .{algorithm}); + //const parameters = try tls.Der.parseElement(signature_algorithm.contents, &sa_i); + } + + std.debug.print("signature_value: {any} {d} bytes\n", .{ + signature_value.identifier, signature_value.contents.len, + }); + } + + hs_i = end_cert; const total_ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]); hs_i += 2; hs_i += total_ext_size; From bbc074252cde0f45576b3910bec5a0f9e867c7f2 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 19 Dec 2022 21:22:51 -0700 Subject: [PATCH 22/59] introduce std.crypto.CertificateBundle for reading root certificate authority bundles from standard installation locations on the file system. So far only Linux logic is added. --- lib/std/crypto.zig | 5 + lib/std/crypto/CertificateBundle.zig | 173 +++++++++++++++++++++++++++ lib/std/crypto/Der.zig | 153 +++++++++++++++++++++++ lib/std/crypto/tls.zig | 109 ----------------- lib/std/crypto/tls/Client.zig | 99 ++++++--------- lib/std/http/Client.zig | 3 +- 6 files changed, 369 insertions(+), 173 deletions(-) create mode 100644 lib/std/crypto/CertificateBundle.zig create mode 100644 lib/std/crypto/Der.zig diff --git a/lib/std/crypto.zig b/lib/std/crypto.zig index 2f7f729302..17c2e4afe9 100644 --- a/lib/std/crypto.zig +++ b/lib/std/crypto.zig @@ -177,6 +177,8 @@ const std = @import("std.zig"); pub const errors = @import("crypto/errors.zig"); pub const tls = @import("crypto/tls.zig"); +pub const Der = @import("crypto/Der.zig"); +pub const CertificateBundle = @import("crypto/CertificateBundle.zig"); test { _ = aead.aegis.Aegis128L; @@ -266,6 +268,9 @@ test { _ = utils; _ = random; _ = errors; + _ = tls; + _ = Der; + _ = CertificateBundle; } test "CSPRNG" { diff --git a/lib/std/crypto/CertificateBundle.zig b/lib/std/crypto/CertificateBundle.zig new file mode 100644 index 0000000000..50f43ba404 --- /dev/null +++ b/lib/std/crypto/CertificateBundle.zig @@ -0,0 +1,173 @@ +//! A set of certificates. Typically pre-installed on every operating system, +//! these are "Certificate Authorities" used to validate SSL certificates. +//! This data structure stores certificates in DER-encoded form, all of them +//! concatenated together in the `bytes` array. The `map` field contains an +//! index from the DER-encoded subject name to the index within `bytes`. 
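+//!
+//! Example use (a sketch of the API below; `gpa` stands for any
+//! std.mem.Allocator and the error name is illustrative):
+//!
+//!     var cb: CertificateBundle = .{};
+//!     defer cb.deinit(gpa);
+//!     try cb.rescan(gpa);
+//!     // `subject` is a DER-encoded subject name; `find` returns the
+//!     // stored bytes starting at the matching certificate, or null.
+//!     const cert_der = cb.find(subject) orelse return error.MissingCertificateAuthority;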
+ +map: std.HashMapUnmanaged(Key, u32, MapContext, std.hash_map.default_max_load_percentage) = .{}, +bytes: std.ArrayListUnmanaged(u8) = .{}, + +pub const Key = struct { + subject_start: u32, + subject_end: u32, +}; + +/// The returned bytes become invalid after calling any of the rescan functions +/// or add functions. +pub fn find(cb: CertificateBundle, subject_name: []const u8) ?[]const u8 { + const Adapter = struct { + cb: CertificateBundle, + + pub fn hash(ctx: @This(), k: []const u8) u64 { + _ = ctx; + return std.hash_map.hashString(k); + } + + pub fn eql(ctx: @This(), a: []const u8, b_key: Key) bool { + const b = ctx.cb.bytes.items[b_key.subject_start..b_key.subject_end]; + return mem.eql(u8, a, b); + } + }; + const index = cb.map.getAdapted(subject_name, Adapter{ .cb = cb }) orelse return null; + return cb.bytes.items[index..]; +} + +pub fn deinit(cb: *CertificateBundle, gpa: Allocator) void { + cb.map.deinit(gpa); + cb.bytes.deinit(gpa); + cb.* = undefined; +} + +/// Empties the set of certificates and then scans the host operating system +/// file system standard locations for certificates. +pub fn rescan(cb: *CertificateBundle, gpa: Allocator) !void { + switch (builtin.os.tag) { + .linux => return rescanLinux(cb, gpa), + else => @compileError("it is unknown where the root CA certificates live on this OS"), + } +} + +pub fn rescanLinux(cb: *CertificateBundle, gpa: Allocator) !void { + var dir = fs.openIterableDirAbsolute("/etc/ssl/certs", .{}) catch |err| switch (err) { + error.FileNotFound => return, + else => |e| return e, + }; + defer dir.close(); + + cb.bytes.clearRetainingCapacity(); + cb.map.clearRetainingCapacity(); + + var it = dir.iterate(); + while (try it.next()) |entry| { + switch (entry.kind) { + .File, .SymLink => {}, + else => continue, + } + + try addCertsFromFile(cb, gpa, dir.dir, entry.name); + } + + cb.bytes.shrinkAndFree(gpa, cb.bytes.items.len); +} + +pub fn addCertsFromFile( + cb: *CertificateBundle, + gpa: Allocator, + dir: fs.Dir, + sub_file_path: []const u8, +) !void { + var file = try dir.openFile(sub_file_path, .{}); + defer file.close(); + + const size = try file.getEndPos(); + + // We borrow `bytes` as a temporary buffer for the base64-encoded data. + // This is possible by computing the decoded length and reserving the space + // for the decoded bytes first. 
+ const decoded_size_upper_bound = size / 4 * 3; + try cb.bytes.ensureUnusedCapacity(gpa, decoded_size_upper_bound + size); + const end_reserved = cb.bytes.items.len + decoded_size_upper_bound; + const buffer = cb.bytes.allocatedSlice()[end_reserved..]; + const end_index = try file.readAll(buffer); + const encoded_bytes = buffer[0..end_index]; + + const begin_marker = "-----BEGIN CERTIFICATE-----"; + const end_marker = "-----END CERTIFICATE-----"; + + var start_index: usize = 0; + while (mem.indexOfPos(u8, encoded_bytes, start_index, begin_marker)) |begin_marker_start| { + const cert_start = begin_marker_start + begin_marker.len; + const cert_end = mem.indexOfPos(u8, encoded_bytes, cert_start, end_marker) orelse + return error.MissingEndCertificateMarker; + start_index = cert_end + end_marker.len; + const encoded_cert = mem.trim(u8, encoded_bytes[cert_start..cert_end], " \t\r\n"); + const decoded_start = @intCast(u32, cb.bytes.items.len); + const dest_buf = cb.bytes.allocatedSlice()[decoded_start..]; + cb.bytes.items.len += try base64.decode(dest_buf, encoded_cert); + const k = try key(cb, decoded_start); + try cb.map.putContext(gpa, k, decoded_start, .{ .cb = cb }); + } +} + +pub fn key(cb: *CertificateBundle, bytes_index: u32) !Key { + const bytes = cb.bytes.items; + const certificate = try Der.parseElement(bytes, bytes_index); + const tbs_certificate = try Der.parseElement(bytes, certificate.start); + const version = try Der.parseElement(bytes, tbs_certificate.start); + if (@bitCast(u8, version.identifier) != 0xa0 or + !mem.eql(u8, bytes[version.start..version.end], "\x02\x01\x02")) + { + return error.UnsupportedCertificateVersion; + } + + const serial_number = try Der.parseElement(bytes, version.end); + + // RFC 5280, section 4.1.2.3: + // "This field MUST contain the same algorithm identifier as + // the signatureAlgorithm field in the sequence Certificate." 
+ const signature = try Der.parseElement(bytes, serial_number.end); + const issuer = try Der.parseElement(bytes, signature.end); + const validity = try Der.parseElement(bytes, issuer.end); + const subject = try Der.parseElement(bytes, validity.end); + //const subject_pub_key = try Der.parseElement(bytes, subject.end); + //const extensions = try Der.parseElement(bytes, subject_pub_key.end); + + return .{ + .subject_start = subject.start, + .subject_end = subject.end, + }; +} + +const builtin = @import("builtin"); +const std = @import("../std.zig"); +const fs = std.fs; +const mem = std.mem; +const Allocator = std.mem.Allocator; +const Der = std.crypto.Der; +const CertificateBundle = @This(); + +const base64 = std.base64.standard.decoderWithIgnore(" \t\r\n"); + +const MapContext = struct { + cb: *const CertificateBundle, + + pub fn hash(ctx: MapContext, k: Key) u64 { + return std.hash_map.hashString(ctx.cb.bytes.items[k.subject_start..k.subject_end]); + } + + pub fn eql(ctx: MapContext, a: Key, b: Key) bool { + const bytes = ctx.cb.bytes.items; + return mem.eql( + u8, + bytes[a.subject_start..a.subject_end], + bytes[b.subject_start..b.subject_end], + ); + } +}; + +test { + var bundle: CertificateBundle = .{}; + defer bundle.deinit(std.testing.allocator); + + try bundle.rescan(std.testing.allocator); +} diff --git a/lib/std/crypto/Der.zig b/lib/std/crypto/Der.zig new file mode 100644 index 0000000000..7b183d5c34 --- /dev/null +++ b/lib/std/crypto/Der.zig @@ -0,0 +1,153 @@ +pub const Class = enum(u2) { + universal, + application, + context_specific, + private, +}; + +pub const PC = enum(u1) { + primitive, + constructed, +}; + +pub const Identifier = packed struct(u8) { + tag: Tag, + pc: PC, + class: Class, +}; + +pub const Tag = enum(u5) { + boolean = 1, + integer = 2, + bitstring = 3, + null = 5, + object_identifier = 6, + sequence = 16, + sequence_of = 17, + _, +}; + +pub const Oid = enum { + rsadsi, + pkcs, + rsaEncryption, + md2WithRSAEncryption, + md5WithRSAEncryption, + sha1WithRSAEncryption, + sha256WithRSAEncryption, + sha384WithRSAEncryption, + sha512WithRSAEncryption, + sha224WithRSAEncryption, + pbeWithMD2AndDES_CBC, + pbeWithMD5AndDES_CBC, + pkcs9_emailAddress, + md2, + md5, + rc4, + ecdsa_with_Recommended, + ecdsa_with_Specified, + ecdsa_with_SHA224, + ecdsa_with_SHA256, + ecdsa_with_SHA384, + ecdsa_with_SHA512, + X500, + X509, + commonName, + serialNumber, + countryName, + localityName, + stateOrProvinceName, + organizationName, + organizationalUnitName, + organizationIdentifier, + + pub const map = std.ComptimeStringMap(Oid, .{ + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D }, .rsadsi }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01 }, .pkcs }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01 }, .rsaEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x02 }, .md2WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x04 }, .md5WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x05 }, .sha1WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0B }, .sha256WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C }, .sha384WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D }, .sha512WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0E }, .sha224WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x05, 0x01 }, 
.pbeWithMD2AndDES_CBC }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x05, 0x03 }, .pbeWithMD5AndDES_CBC }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x09, 0x01 }, .pkcs9_emailAddress }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x02, 0x02 }, .md2 }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x02, 0x05 }, .md5 }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x03, 0x04 }, .rc4 }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x02 }, .ecdsa_with_Recommended }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03 }, .ecdsa_with_Specified }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x01 }, .ecdsa_with_SHA224 }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x02 }, .ecdsa_with_SHA256 }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x03 }, .ecdsa_with_SHA384 }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x04 }, .ecdsa_with_SHA512 }, + .{ &[_]u8{0x55}, .X500 }, + .{ &[_]u8{ 0x55, 0x04 }, .X509 }, + .{ &[_]u8{ 0x55, 0x04, 0x03 }, .commonName }, + .{ &[_]u8{ 0x55, 0x04, 0x05 }, .serialNumber }, + .{ &[_]u8{ 0x55, 0x04, 0x06 }, .countryName }, + .{ &[_]u8{ 0x55, 0x04, 0x07 }, .localityName }, + .{ &[_]u8{ 0x55, 0x04, 0x08 }, .stateOrProvinceName }, + .{ &[_]u8{ 0x55, 0x04, 0x0A }, .organizationName }, + .{ &[_]u8{ 0x55, 0x04, 0x0B }, .organizationalUnitName }, + .{ &[_]u8{ 0x55, 0x04, 0x61 }, .organizationIdentifier }, + }); +}; + +pub const Element = struct { + identifier: Identifier, + start: u32, + end: u32, +}; + +pub const ParseElementError = error{CertificateHasFieldWithInvalidLength}; + +pub fn parseElement(bytes: []const u8, index: u32) ParseElementError!Element { + var i = index; + const identifier = @bitCast(Identifier, bytes[i]); + i += 1; + const size_byte = bytes[i]; + i += 1; + if ((size_byte >> 7) == 0) { + return .{ + .identifier = identifier, + .start = i, + .end = i + size_byte, + }; + } + + const len_size = @truncate(u7, size_byte); + if (len_size > @sizeOf(u32)) { + return error.CertificateHasFieldWithInvalidLength; + } + + const end_i = i + len_size; + var long_form_size: u32 = 0; + while (i < end_i) : (i += 1) { + long_form_size = (long_form_size << 8) | bytes[i]; + } + + return .{ + .identifier = identifier, + .start = i, + .end = i + long_form_size, + }; +} + +pub const ParseObjectIdError = error{ + CertificateHasUnrecognizedObjectId, + CertificateFieldHasWrongDataType, +} || ParseElementError; + +pub fn parseObjectId(bytes: []const u8, element: Element) ParseObjectIdError!Oid { + if (element.identifier.tag != .object_identifier) + return error.CertificateFieldHasWrongDataType; + return Oid.map.get(bytes[element.start..element.end]) orelse + return error.CertificateHasUnrecognizedObjectId; +} + +const std = @import("../std.zig"); +const Der = @This(); diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 33d6bbe906..09a22f9a23 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -349,112 +349,3 @@ pub inline fn int3(x: u24) [3]u8 { @truncate(u8, x), }; } - -pub const Der = struct { - pub const Class = enum(u2) { - universal, - application, - context_specific, - private, - }; - - pub const PC = enum(u1) { - primitive, - constructed, - }; - - pub const Identifier = packed struct(u8) { - tag: Tag, - pc: PC, - class: Class, - }; - - pub const Tag = enum(u5) { - boolean = 1, - integer = 2, - bitstring = 3, - null = 5, - object_identifier = 6, - sequence = 16, - _, - }; - - pub const Oid = enum { - commonName, - countryName, - localityName, - stateOrProvinceName, - 
organizationName, - organizationalUnitName, - sha256WithRSAEncryption, - sha384WithRSAEncryption, - sha512WithRSAEncryption, - sha224WithRSAEncryption, - - pub const map = std.ComptimeStringMap(Oid, .{ - .{ &[_]u8{ 0x55, 0x04, 0x03 }, .commonName }, - .{ &[_]u8{ 0x55, 0x04, 0x06 }, .countryName }, - .{ &[_]u8{ 0x55, 0x04, 0x07 }, .localityName }, - .{ &[_]u8{ 0x55, 0x04, 0x08 }, .stateOrProvinceName }, - .{ &[_]u8{ 0x55, 0x04, 0x0A }, .organizationName }, - .{ &[_]u8{ 0x55, 0x04, 0x0B }, .organizationalUnitName }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0B }, .sha256WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C }, .sha384WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D }, .sha512WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0E }, .sha224WithRSAEncryption }, - }); - }; - - pub const Element = struct { - identifier: Identifier, - contents: []const u8, - }; - - pub const ParseElementError = error{CertificateHasFieldWithInvalidLength}; - - pub fn parseElement(bytes: []const u8, index: *usize) ParseElementError!Der.Element { - var i = index.*; - const identifier = @bitCast(Identifier, bytes[i]); - i += 1; - const size_byte = bytes[i]; - i += 1; - if ((size_byte >> 7) == 0) { - const contents = bytes[i..][0..size_byte]; - index.* = i + contents.len; - return .{ - .identifier = identifier, - .contents = contents, - }; - } - - const len_size = @truncate(u7, size_byte); - if (len_size > @sizeOf(usize)) { - return error.CertificateHasFieldWithInvalidLength; - } - - const end = i + len_size; - var long_form_size: usize = 0; - while (i < end) : (i += 1) { - long_form_size = (long_form_size << 8) | bytes[i]; - } - - const contents = bytes[i..][0..long_form_size]; - index.* = i + contents.len; - - return .{ - .identifier = identifier, - .contents = contents, - }; - } - - pub const ParseObjectIdError = error{ - CertificateHasUnrecognizedObjectId, - CertificateFieldHasWrongDataType, - } || ParseElementError; - - pub fn parseObjectId(bytes: []const u8, index: *usize) ParseObjectIdError!Oid { - const oid_element = try parseElement(bytes, index); - if (oid_element.identifier.tag != .object_identifier) return error.CertificateFieldHasWrongDataType; - return Oid.map.get(oid_element.contents) orelse return error.CertificateHasUnrecognizedObjectId; - } -}; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index c167d85134..45c96ed290 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -1,5 +1,6 @@ const std = @import("../../std.zig"); const tls = std.crypto.tls; +const Der = std.crypto.Der; const Client = @This(); const net = std.net; const mem = std.mem; @@ -28,7 +29,7 @@ partially_read_len: u15, eof: bool, /// `host` is only borrowed during this function call. 
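+/// A minimal usage sketch (illustrative only; `gpa`, `stream`, and `host` are
+/// assumed to already exist):
+///
+///     var ca_bundle: crypto.CertificateBundle = .{};
+///     defer ca_bundle.deinit(gpa);
+///     try ca_bundle.rescan(gpa); // load the OS-provided root certificates
+///     var client = try Client.init(stream, ca_bundle, host);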
-pub fn init(stream: net.Stream, host: []const u8) !Client { +pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []const u8) !Client { const host_len = @intCast(u16, host.len); var random_buffer: [128]u8 = undefined; @@ -392,7 +393,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { switch (cipher_params) { inline else => |*p| p.transcript_hash.update(wrapped_handshake), } - var hs_i: usize = 0; + var hs_i: u32 = 0; const cert_req_ctx_len = handshake[hs_i]; hs_i += 1; if (cert_req_ctx_len != 0) return error.TlsIllegalParameter; @@ -404,75 +405,47 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { hs_i += 3; const end_cert = hs_i + cert_size; - const certificate = try tls.Der.parseElement(handshake, &hs_i); + const certificate = try Der.parseElement(handshake, hs_i); + const tbs_certificate = try Der.parseElement(handshake, certificate.start); + + const version = try Der.parseElement(handshake, tbs_certificate.start); + if (@bitCast(u8, version.identifier) != 0xa0 or + !mem.eql(u8, handshake[version.start..version.end], "\x02\x01\x02")) { - var cert_i: usize = 0; - const tbs_certificate = try tls.Der.parseElement(certificate.contents, &cert_i); - { - var tbs_i: usize = 0; - const version = try tls.Der.parseElement(tbs_certificate.contents, &tbs_i); - const serial_number = try tls.Der.parseElement(tbs_certificate.contents, &tbs_i); - const signature = try tls.Der.parseElement(tbs_certificate.contents, &tbs_i); - const issuer = try tls.Der.parseElement(tbs_certificate.contents, &tbs_i); - const validity = try tls.Der.parseElement(tbs_certificate.contents, &tbs_i); - const subject = try tls.Der.parseElement(tbs_certificate.contents, &tbs_i); - const subject_pub_key = try tls.Der.parseElement(tbs_certificate.contents, &tbs_i); - const extensions = try tls.Der.parseElement(tbs_certificate.contents, &tbs_i); - - // RFC 5280, section 4.1.2.3: - // "This field MUST contain the same algorithm identifier as - // the signatureAlgorithm field in the sequence Certificate." 
- _ = signature; - - _ = issuer; - _ = validity; - - std.debug.print("version: {any} '{}'\n", .{ - version.identifier, std.fmt.fmtSliceHexLower(version.contents), - }); - - std.debug.print("serial_number: {any} {}\n", .{ - serial_number.identifier, - std.fmt.fmtSliceHexLower(serial_number.contents), - }); - - std.debug.print("subject: {any} {}\n", .{ - subject.identifier, - std.fmt.fmtSliceHexLower(subject.contents), - }); - - std.debug.print("subject pub key: {any} {}\n", .{ - subject_pub_key.identifier, - std.fmt.fmtSliceHexLower(subject_pub_key.contents), - }); - - std.debug.print("extensions: {any} {}\n", .{ - extensions.identifier, - std.fmt.fmtSliceHexLower(extensions.contents), - }); - } - const signature_algorithm = try tls.Der.parseElement(certificate.contents, &cert_i); - const signature_value = try tls.Der.parseElement(certificate.contents, &cert_i); - - { - var sa_i: usize = 0; - const algorithm = try tls.Der.parseObjectId(signature_algorithm.contents, &sa_i); - std.debug.print("cert has this signature algorithm: {any}\n", .{algorithm}); - //const parameters = try tls.Der.parseElement(signature_algorithm.contents, &sa_i); - } - - std.debug.print("signature_value: {any} {d} bytes\n", .{ - signature_value.identifier, signature_value.contents.len, - }); + return error.UnsupportedCertificateVersion; } + const serial_number = try Der.parseElement(handshake, version.end); + // RFC 5280, section 4.1.2.3: + // "This field MUST contain the same algorithm identifier as + // the signatureAlgorithm field in the sequence Certificate." + const signature = try Der.parseElement(handshake, serial_number.end); + const issuer = try Der.parseElement(handshake, signature.end); + const validity = try Der.parseElement(handshake, issuer.end); + const subject = try Der.parseElement(handshake, validity.end); + const subject_pub_key = try Der.parseElement(handshake, subject.end); + const extensions = try Der.parseElement(handshake, subject_pub_key.end); + _ = extensions; + + const signature_algorithm = try Der.parseElement(handshake, tbs_certificate.end); + const signature_value = try Der.parseElement(handshake, signature_algorithm.end); + _ = signature_value; + + const algorithm_elem = try Der.parseElement(handshake, signature_algorithm.start); + const algorithm = try Der.parseObjectId(handshake, algorithm_elem); + std.debug.print("cert has this signature algorithm: {any}\n", .{algorithm}); + //const parameters = try Der.parseElement(signature_algorithm.contents, &sa_i); + hs_i = end_cert; const total_ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]); hs_i += 2; hs_i += total_ext_size; - std.debug.print("received certificate of size {d} bytes with {d} bytes of extensions\n", .{ - cert_size, total_ext_size, + const issuer_bytes = handshake[issuer.start..issuer.end]; + const ca_cert = ca_bundle.find(issuer_bytes); + + std.debug.print("received certificate of size {d} bytes with {d} bytes of extensions. 
ca_found={any}\n", .{ + cert_size, total_ext_size, ca_cert != null, }); } }, diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index fcadf3669b..58686ed2e5 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -7,6 +7,7 @@ const Client = @This(); allocator: std.mem.Allocator, headers: std.ArrayListUnmanaged(u8) = .{}, active_requests: usize = 0, +ca_bundle: std.crypto.CertificateBundle = .{}, pub const Request = struct { client: *Client, @@ -102,7 +103,7 @@ pub fn request(client: *Client, options: Request.Options) !Request { switch (options.protocol) { .http => {}, .https => { - req.tls_client = try std.crypto.tls.Client.init(req.stream, options.host); + req.tls_client = try std.crypto.tls.Client.init(req.stream, client.ca_bundle, options.host); }, } From 504070e8fc50de0c354f6322115e08fe99503578 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 20 Dec 2022 00:47:04 -0700 Subject: [PATCH 23/59] std.crypto.CertificateBundle: ignore duplicate certificates --- lib/std/crypto/CertificateBundle.zig | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/lib/std/crypto/CertificateBundle.zig b/lib/std/crypto/CertificateBundle.zig index 50f43ba404..83560b6367 100644 --- a/lib/std/crypto/CertificateBundle.zig +++ b/lib/std/crypto/CertificateBundle.zig @@ -2,7 +2,8 @@ //! these are "Certificate Authorities" used to validate SSL certificates. //! This data structure stores certificates in DER-encoded form, all of them //! concatenated together in the `bytes` array. The `map` field contains an -//! index from the DER-encoded subject name to the index within `bytes`. +//! index from the DER-encoded subject name to the index of the containing +//! certificate within `bytes`. map: std.HashMapUnmanaged(Key, u32, MapContext, std.hash_map.default_max_load_percentage) = .{}, bytes: std.ArrayListUnmanaged(u8) = .{}, @@ -105,7 +106,12 @@ pub fn addCertsFromFile( const dest_buf = cb.bytes.allocatedSlice()[decoded_start..]; cb.bytes.items.len += try base64.decode(dest_buf, encoded_cert); const k = try key(cb, decoded_start); - try cb.map.putContext(gpa, k, decoded_start, .{ .cb = cb }); + const gop = try cb.map.getOrPutContext(gpa, k, .{ .cb = cb }); + if (gop.found_existing) { + cb.bytes.items.len = decoded_start; + } else { + gop.value_ptr.* = decoded_start; + } } } From 244a97e8ada5349136ca642d89092dbaf6e52ae2 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 20 Dec 2022 19:26:23 -0700 Subject: [PATCH 24/59] std.crypto.tls: certificate signature validation --- lib/std/crypto/CertificateBundle.zig | 313 ++++++++++++++++++++++++++- lib/std/crypto/tls/Client.zig | 50 ++--- 2 files changed, 326 insertions(+), 37 deletions(-) diff --git a/lib/std/crypto/CertificateBundle.zig b/lib/std/crypto/CertificateBundle.zig index 83560b6367..17c0286bed 100644 --- a/lib/std/crypto/CertificateBundle.zig +++ b/lib/std/crypto/CertificateBundle.zig @@ -15,7 +15,7 @@ pub const Key = struct { /// The returned bytes become invalid after calling any of the rescan functions /// or add functions. 
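+/// On success, the return value is the byte offset within `bytes` at which the
+/// matching DER-encoded certificate begins; the subject name comparison is an
+/// exact match on the raw DER bytes of the name.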
-pub fn find(cb: CertificateBundle, subject_name: []const u8) ?[]const u8 { +pub fn find(cb: CertificateBundle, subject_name: []const u8) ?u32 { const Adapter = struct { cb: CertificateBundle, @@ -29,8 +29,7 @@ pub fn find(cb: CertificateBundle, subject_name: []const u8) ?[]const u8 { return mem.eql(u8, a, b); } }; - const index = cb.map.getAdapted(subject_name, Adapter{ .cb = cb }) orelse return null; - return cb.bytes.items[index..]; + return cb.map.getAdapted(subject_name, Adapter{ .cb = cb }); } pub fn deinit(cb: *CertificateBundle, gpa: Allocator) void { @@ -105,7 +104,7 @@ pub fn addCertsFromFile( const decoded_start = @intCast(u32, cb.bytes.items.len); const dest_buf = cb.bytes.allocatedSlice()[decoded_start..]; cb.bytes.items.len += try base64.decode(dest_buf, encoded_cert); - const k = try key(cb, decoded_start); + const k = try cb.key(decoded_start); const gop = try cb.map.getOrPutContext(gpa, k, .{ .cb = cb }); if (gop.found_existing) { cb.bytes.items.len = decoded_start; @@ -115,16 +114,12 @@ pub fn addCertsFromFile( } } -pub fn key(cb: *CertificateBundle, bytes_index: u32) !Key { +pub fn key(cb: CertificateBundle, bytes_index: u32) !Key { const bytes = cb.bytes.items; const certificate = try Der.parseElement(bytes, bytes_index); const tbs_certificate = try Der.parseElement(bytes, certificate.start); const version = try Der.parseElement(bytes, tbs_certificate.start); - if (@bitCast(u8, version.identifier) != 0xa0 or - !mem.eql(u8, bytes[version.start..version.end], "\x02\x01\x02")) - { - return error.UnsupportedCertificateVersion; - } + try checkVersion(bytes, version); const serial_number = try Der.parseElement(bytes, version.end); @@ -144,10 +139,173 @@ pub fn key(cb: *CertificateBundle, bytes_index: u32) !Key { }; } +pub const Certificate = struct { + buffer: []const u8, + index: u32, + + pub fn verify(subject: Certificate, issuer: Certificate) !void { + const subject_certificate = try Der.parseElement(subject.buffer, subject.index); + const subject_tbs_certificate = try Der.parseElement(subject.buffer, subject_certificate.start); + const subject_version = try Der.parseElement(subject.buffer, subject_tbs_certificate.start); + try checkVersion(subject.buffer, subject_version); + const subject_serial_number = try Der.parseElement(subject.buffer, subject_version.end); + // RFC 5280, section 4.1.2.3: + // "This field MUST contain the same algorithm identifier as + // the signatureAlgorithm field in the sequence Certificate." 
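+ // Note that `subject_signature` below is that inner AlgorithmIdentifier
+ // field; the signature bits themselves are the signatureValue BIT STRING
+ // at the end of the outer Certificate SEQUENCE, parsed further down as
+ // `subject_sig`.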
+ const subject_signature = try Der.parseElement(subject.buffer, subject_serial_number.end); + const subject_issuer = try Der.parseElement(subject.buffer, subject_signature.end); + const subject_validity = try Der.parseElement(subject.buffer, subject_issuer.end); + //const subject_name = try Der.parseElement(subject.buffer, subject_validity.end); + + const subject_sig_algo = try Der.parseElement(subject.buffer, subject_tbs_certificate.end); + const subject_algo_elem = try Der.parseElement(subject.buffer, subject_sig_algo.start); + const subject_algo = try Der.parseObjectId(subject.buffer, subject_algo_elem); + const subject_sig_elem = try Der.parseElement(subject.buffer, subject_sig_algo.end); + const subject_sig = try parseBitString(subject, subject_sig_elem); + + const issuer_certificate = try Der.parseElement(issuer.buffer, issuer.index); + const issuer_tbs_certificate = try Der.parseElement(issuer.buffer, issuer_certificate.start); + const issuer_version = try Der.parseElement(issuer.buffer, issuer_tbs_certificate.start); + try checkVersion(issuer.buffer, issuer_version); + const issuer_serial_number = try Der.parseElement(issuer.buffer, issuer_version.end); + // RFC 5280, section 4.1.2.3: + // "This field MUST contain the same algorithm identifier as + // the signatureAlgorithm field in the sequence Certificate." + const issuer_signature = try Der.parseElement(issuer.buffer, issuer_serial_number.end); + const issuer_issuer = try Der.parseElement(issuer.buffer, issuer_signature.end); + const issuer_validity = try Der.parseElement(issuer.buffer, issuer_issuer.end); + const issuer_name = try Der.parseElement(issuer.buffer, issuer_validity.end); + const issuer_pub_key_info = try Der.parseElement(issuer.buffer, issuer_name.end); + const issuer_pub_key_signature_algorithm = try Der.parseElement(issuer.buffer, issuer_pub_key_info.start); + const issuer_pub_key_algo_elem = try Der.parseElement(issuer.buffer, issuer_pub_key_signature_algorithm.start); + const issuer_pub_key_algo = try Der.parseObjectId(issuer.buffer, issuer_pub_key_algo_elem); + const issuer_pub_key_elem = try Der.parseElement(issuer.buffer, issuer_pub_key_signature_algorithm.end); + const issuer_pub_key = try parseBitString(issuer, issuer_pub_key_elem); + + // Check that the subject's issuer name matches the issuer's subject + // name. 
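+ // The match is a byte-for-byte comparison of the DER-encoded names, not
+ // the attribute-level, case-insensitive comparison that RFC 5280
+ // describes, so names that are semantically equal but encoded
+ // differently will not match here.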
+ if (!mem.eql(u8, subject.contents(subject_issuer), issuer.contents(issuer_name))) { + return error.CertificateIssuerMismatch; + } + + // TODO check the time validity for the subject + _ = subject_validity; + // TODO check the time validity for the issuer + + const message = subject.buffer[subject_certificate.start..subject_tbs_certificate.end]; + //std.debug.print("issuer algo: {any} subject algo: {any}\n", .{ issuer_pub_key_algo, subject_algo }); + switch (subject_algo) { + // zig fmt: off + .sha1WithRSAEncryption => return verifyRsa(crypto.hash.Sha1, message, subject_sig, issuer_pub_key_algo, issuer_pub_key), + .sha224WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha224, message, subject_sig, issuer_pub_key_algo, issuer_pub_key), + .sha256WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha256, message, subject_sig, issuer_pub_key_algo, issuer_pub_key), + .sha384WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha384, message, subject_sig, issuer_pub_key_algo, issuer_pub_key), + .sha512WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha512, message, subject_sig, issuer_pub_key_algo, issuer_pub_key), + // zig fmt: on + else => { + std.debug.print("unhandled algorithm: {any}\n", .{subject_algo}); + return error.UnsupportedCertificateSignatureAlgorithm; + }, + } + } + + pub fn contents(cert: Certificate, elem: Der.Element) []const u8 { + return cert.buffer[elem.start..elem.end]; + } + + pub fn parseBitString(cert: Certificate, elem: Der.Element) ![]const u8 { + if (elem.identifier.tag != .bitstring) return error.CertificateFieldHasWrongDataType; + if (cert.buffer[elem.start] != 0) return error.CertificateHasInvalidBitString; + return cert.buffer[elem.start + 1 .. elem.end]; + } + + fn verifyRsa(comptime Hash: type, message: []const u8, sig: []const u8, pub_key_algo: Der.Oid, pub_key: []const u8) !void { + if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch; + const pub_key_seq = try Der.parseElement(pub_key, 0); + if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; + const modulus_elem = try Der.parseElement(pub_key, pub_key_seq.start); + if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; + const exponent_elem = try Der.parseElement(pub_key, modulus_elem.end); + if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; + // Skip over meaningless zeroes in the modulus. 
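+ // (DER encodes INTEGERs in two's complement, so a positive modulus whose
+ // high bit is set carries a leading 0x00 octet; stripping it leaves the
+ // big-endian magnitude whose length is the RSA key size in bytes.)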
+ const modulus_raw = pub_key[modulus_elem.start..modulus_elem.end]; + const modulus_offset = for (modulus_raw) |byte, i| { + if (byte != 0) break i; + } else modulus_raw.len; + const modulus = modulus_raw[modulus_offset..]; + const exponent = pub_key[exponent_elem.start..exponent_elem.end]; + if (exponent.len > modulus.len) return error.CertificatePublicKeyInvalid; + if (sig.len != modulus.len) return error.CertificateSignatureInvalidLength; + + const hash_der = switch (Hash) { + crypto.hash.Sha1 => [_]u8{ + 0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, + 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14, + }, + crypto.hash.sha2.Sha224 => [_]u8{ + 0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05, + 0x00, 0x04, 0x1c, + }, + crypto.hash.sha2.Sha256 => [_]u8{ + 0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, + 0x00, 0x04, 0x20, + }, + crypto.hash.sha2.Sha384 => [_]u8{ + 0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, + 0x00, 0x04, 0x30, + }, + crypto.hash.sha2.Sha512 => [_]u8{ + 0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, + 0x00, 0x04, 0x40, + }, + else => @compileError("unreachable"), + }; + + var msg_hashed: [Hash.digest_length]u8 = undefined; + Hash.hash(message, &msg_hashed, .{}); + + switch (modulus.len) { + inline 128, 256, 512 => |modulus_len| { + const ps_len = modulus_len - (hash_der.len + msg_hashed.len) - 3; + const em: [modulus_len]u8 = + [2]u8{ 0, 1 } ++ + ([1]u8{0xff} ** ps_len) ++ + [1]u8{0} ++ + hash_der ++ + msg_hashed; + + const public_key = try rsa.PublicKey.fromBytes(exponent, modulus, rsa.poop); + const em_dec = try rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key, rsa.poop); + + if (!mem.eql(u8, &em, &em_dec)) { + try std.testing.expectEqualSlices(u8, &em, &em_dec); + return error.CertificateSignatureInvalid; + } + }, + else => { + return error.CertificateSignatureUnsupportedBitCount; + }, + } + } +}; + +fn checkVersion(bytes: []const u8, version: Der.Element) !void { + if (@bitCast(u8, version.identifier) != 0xa0 or + !mem.eql(u8, bytes[version.start..version.end], "\x02\x01\x02")) + { + return error.UnsupportedCertificateVersion; + } +} + const builtin = @import("builtin"); const std = @import("../std.zig"); const fs = std.fs; const mem = std.mem; +const crypto = std.crypto; const Allocator = std.mem.Allocator; const Der = std.crypto.Der; const CertificateBundle = @This(); @@ -177,3 +335,138 @@ test { try bundle.rescan(std.testing.allocator); } + +/// TODO: replace this with Frank's upcoming RSA implementation. the verify +/// function won't have the possibility of failure - it will either identify a +/// valid signature or an invalid signature. +/// This code is borrowed from https://github.com/shiguredo/tls13-zig +/// which is licensed under the Apache License Version 2.0, January 2004 +/// http://www.apache.org/licenses/ +/// The code has been modified. 
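+/// Only the raw RSA public-key operation is needed here: `encrypt` computes
+/// sig^e mod n using std.math.big integers, and verifyRsa compares the result
+/// against the expected EMSA-PKCS1-v1_5 encoding of the message hash.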
+const rsa = struct { + const BigInt = std.math.big.int.Managed; + + const PublicKey = struct { + n: BigInt, + e: BigInt, + + pub fn deinit(self: *PublicKey) void { + self.n.deinit(); + self.e.deinit(); + } + + pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8, allocator: std.mem.Allocator) !PublicKey { + var _n = try BigInt.init(allocator); + errdefer _n.deinit(); + try setBytes(&_n, modulus_bytes, allocator); + + var _e = try BigInt.init(allocator); + errdefer _e.deinit(); + try setBytes(&_e, pub_bytes, allocator); + + return .{ + .n = _n, + .e = _e, + }; + } + }; + + fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey, allocator: std.mem.Allocator) ![modulus_len]u8 { + var m = try BigInt.init(allocator); + defer m.deinit(); + + try setBytes(&m, &msg, allocator); + + if (m.order(public_key.n) != .lt) { + return error.MessageTooLong; + } + + var e = try BigInt.init(allocator); + defer e.deinit(); + + try pow_montgomery(&e, &m, &public_key.e, &public_key.n, allocator); + + var res: [modulus_len]u8 = undefined; + + try toBytes(&res, &e, allocator); + + return res; + } + + fn setBytes(r: *BigInt, bytes: []const u8, allcator: std.mem.Allocator) !void { + try r.set(0); + var tmp = try BigInt.init(allcator); + defer tmp.deinit(); + for (bytes) |b| { + try r.shiftLeft(r, 8); + try tmp.set(b); + try r.add(r, &tmp); + } + } + + fn pow_montgomery(r: *BigInt, a: *const BigInt, x: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void { + var bin_raw: [512]u8 = undefined; + try toBytes(&bin_raw, x, allocator); + + var i: usize = 0; + while (bin_raw[i] == 0x00) : (i += 1) {} + const bin = bin_raw[i..]; + + try r.set(1); + var r1 = try BigInt.init(allocator); + defer r1.deinit(); + try BigInt.copy(&r1, a.toConst()); + i = 0; + while (i < bin.len * 8) : (i += 1) { + if (((bin[i / 8] >> @intCast(u3, (7 - (i % 8)))) & 0x1) == 0) { + try BigInt.mul(&r1, r, &r1); + try mod(&r1, &r1, n, allocator); + try BigInt.sqr(r, r); + try mod(r, r, n, allocator); + } else { + try BigInt.mul(r, r, &r1); + try mod(r, r, n, allocator); + try BigInt.sqr(&r1, &r1); + try mod(&r1, &r1, n, allocator); + } + } + } + + fn toBytes(out: []u8, a: *const BigInt, allocator: std.mem.Allocator) !void { + const Error = error{ + BufferTooSmall, + }; + + var mask = try BigInt.initSet(allocator, 0xFF); + defer mask.deinit(); + var tmp = try BigInt.init(allocator); + defer tmp.deinit(); + + var a_copy = try BigInt.init(allocator); + defer a_copy.deinit(); + try a_copy.copy(a.toConst()); + + // Encoding into big-endian bytes + var i: usize = 0; + while (i < out.len) : (i += 1) { + try tmp.bitAnd(&a_copy, &mask); + const b = try tmp.to(u8); + out[out.len - i - 1] = b; + try a_copy.shiftRight(&a_copy, 8); + } + + if (!a_copy.eqZero()) { + return Error.BufferTooSmall; + } + } + + fn mod(rem: *BigInt, a: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void { + var q = try BigInt.init(allocator); + defer q.deinit(); + + try BigInt.divFloor(&q, rem, a, n); + } + + // TODO: flush the toilet + const poop = std.heap.page_allocator; +}; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 45c96ed290..8395be4551 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -53,15 +53,12 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con 0x02, // byte length of supported versions 0x03, 0x04, // TLS 1.3 }) ++ tls.extension(.signature_algorithms, enum_array(tls.SignatureScheme, &.{ - 
.rsa_pkcs1_sha256, - .rsa_pkcs1_sha384, - .rsa_pkcs1_sha512, .ecdsa_secp256r1_sha256, .ecdsa_secp384r1_sha384, .ecdsa_secp521r1_sha512, - .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, + .rsa_pkcs1_sha256, + .rsa_pkcs1_sha384, + .rsa_pkcs1_sha512, .ed25519, })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{ .secp256r1, @@ -420,33 +417,32 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con // "This field MUST contain the same algorithm identifier as // the signatureAlgorithm field in the sequence Certificate." const signature = try Der.parseElement(handshake, serial_number.end); - const issuer = try Der.parseElement(handshake, signature.end); - const validity = try Der.parseElement(handshake, issuer.end); - const subject = try Der.parseElement(handshake, validity.end); - const subject_pub_key = try Der.parseElement(handshake, subject.end); - const extensions = try Der.parseElement(handshake, subject_pub_key.end); - _ = extensions; + const issuer_elem = try Der.parseElement(handshake, signature.end); - const signature_algorithm = try Der.parseElement(handshake, tbs_certificate.end); - const signature_value = try Der.parseElement(handshake, signature_algorithm.end); - _ = signature_value; - - const algorithm_elem = try Der.parseElement(handshake, signature_algorithm.start); - const algorithm = try Der.parseObjectId(handshake, algorithm_elem); - std.debug.print("cert has this signature algorithm: {any}\n", .{algorithm}); - //const parameters = try Der.parseElement(signature_algorithm.contents, &sa_i); + const issuer_bytes = handshake[issuer_elem.start..issuer_elem.end]; + if (ca_bundle.find(issuer_bytes)) |ca_cert_i| { + const Certificate = crypto.CertificateBundle.Certificate; + const subject: Certificate = .{ + .buffer = handshake, + .index = hs_i, + }; + const issuer: Certificate = .{ + .buffer = ca_bundle.bytes.items, + .index = ca_cert_i, + }; + if (subject.verify(issuer)) |_| { + std.debug.print("found a root CA cert matching issuer. verification success!\n", .{}); + } else |err| { + std.debug.print("found a root CA cert matching issuer. verification failure: {s}\n", .{ + @errorName(err), + }); + } + } hs_i = end_cert; const total_ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]); hs_i += 2; hs_i += total_ext_size; - - const issuer_bytes = handshake[issuer.start..issuer.end]; - const ca_cert = ca_bundle.find(issuer_bytes); - - std.debug.print("received certificate of size {d} bytes with {d} bytes of extensions. 
ca_found={any}\n", .{ - cert_size, total_ext_size, ca_cert != null, - }); } }, @enumToInt(HandshakeType.certificate_verify) => { From 7ed7bd247ed301b2904379570ea86abd04c65618 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 20 Dec 2022 21:30:38 -0700 Subject: [PATCH 25/59] std.crypto.tls: verify the common name matches --- lib/std/crypto/CertificateBundle.zig | 261 +++++++++++++++++++-------- lib/std/crypto/tls/Client.zig | 67 ++++--- 2 files changed, 223 insertions(+), 105 deletions(-) diff --git a/lib/std/crypto/CertificateBundle.zig b/lib/std/crypto/CertificateBundle.zig index 17c0286bed..bc82d54a79 100644 --- a/lib/std/crypto/CertificateBundle.zig +++ b/lib/std/crypto/CertificateBundle.zig @@ -13,6 +13,16 @@ pub const Key = struct { subject_end: u32, }; +pub fn verify(cb: CertificateBundle, subject: Certificate.Parsed) !void { + const bytes_index = cb.find(subject.issuer) orelse return error.IssuerNotFound; + const issuer_cert: Certificate = .{ + .buffer = cb.bytes.items, + .index = bytes_index, + }; + const issuer = try issuer_cert.parse(); + try subject.verify(issuer); +} + /// The returned bytes become invalid after calling any of the rescan functions /// or add functions. pub fn find(cb: CertificateBundle, subject_name: []const u8) ?u32 { @@ -120,18 +130,11 @@ pub fn key(cb: CertificateBundle, bytes_index: u32) !Key { const tbs_certificate = try Der.parseElement(bytes, certificate.start); const version = try Der.parseElement(bytes, tbs_certificate.start); try checkVersion(bytes, version); - const serial_number = try Der.parseElement(bytes, version.end); - - // RFC 5280, section 4.1.2.3: - // "This field MUST contain the same algorithm identifier as - // the signatureAlgorithm field in the sequence Certificate." const signature = try Der.parseElement(bytes, serial_number.end); const issuer = try Der.parseElement(bytes, signature.end); const validity = try Der.parseElement(bytes, issuer.end); const subject = try Der.parseElement(bytes, validity.end); - //const subject_pub_key = try Der.parseElement(bytes, subject.end); - //const extensions = try Der.parseElement(bytes, subject_pub_key.end); return .{ .subject_start = subject.start, @@ -143,70 +146,163 @@ pub const Certificate = struct { buffer: []const u8, index: u32, + pub const Algorithm = enum { + sha1WithRSAEncryption, + sha224WithRSAEncryption, + sha256WithRSAEncryption, + sha384WithRSAEncryption, + sha512WithRSAEncryption, + + pub const map = std.ComptimeStringMap(Algorithm, .{ + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x05 }, .sha1WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0B }, .sha256WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C }, .sha384WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D }, .sha512WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0E }, .sha224WithRSAEncryption }, + }); + + pub fn Hash(comptime algorithm: Algorithm) type { + return switch (algorithm) { + .sha1WithRSAEncryption => crypto.hash.Sha1, + .sha224WithRSAEncryption => crypto.hash.sha2.Sha224, + .sha256WithRSAEncryption => crypto.hash.sha2.Sha256, + .sha384WithRSAEncryption => crypto.hash.sha2.Sha384, + .sha512WithRSAEncryption => crypto.hash.sha2.Sha512, + }; + } + }; + + pub const AlgorithmCategory = enum { + rsaEncryption, + X9_62_id_ecPublicKey, + + pub const map = std.ComptimeStringMap(AlgorithmCategory, .{ + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 
0x01 }, .rsaEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01 }, .X9_62_id_ecPublicKey }, + }); + }; + + pub const Attribute = enum { + commonName, + serialNumber, + countryName, + localityName, + stateOrProvinceName, + organizationName, + organizationalUnitName, + organizationIdentifier, + + pub const map = std.ComptimeStringMap(Attribute, .{ + .{ &[_]u8{ 0x55, 0x04, 0x03 }, .commonName }, + .{ &[_]u8{ 0x55, 0x04, 0x05 }, .serialNumber }, + .{ &[_]u8{ 0x55, 0x04, 0x06 }, .countryName }, + .{ &[_]u8{ 0x55, 0x04, 0x07 }, .localityName }, + .{ &[_]u8{ 0x55, 0x04, 0x08 }, .stateOrProvinceName }, + .{ &[_]u8{ 0x55, 0x04, 0x0A }, .organizationName }, + .{ &[_]u8{ 0x55, 0x04, 0x0B }, .organizationalUnitName }, + .{ &[_]u8{ 0x55, 0x04, 0x61 }, .organizationIdentifier }, + }); + }; + + pub const Parsed = struct { + certificate: Certificate, + issuer: []const u8, + subject: []const u8, + common_name: []const u8, + signature: []const u8, + signature_algorithm: Algorithm, + message: []const u8, + pub_key_algo: AlgorithmCategory, + pub_key: []const u8, + + pub fn verify(subject: Parsed, issuer: Parsed) !void { + // Check that the subject's issuer name matches the issuer's + // subject name. + if (!mem.eql(u8, subject.issuer, issuer.subject)) { + return error.CertificateIssuerMismatch; + } + + // TODO check the time validity for the subject + // TODO check the time validity for the issuer + + switch (subject.signature_algorithm) { + inline .sha1WithRSAEncryption, + .sha224WithRSAEncryption, + .sha256WithRSAEncryption, + .sha384WithRSAEncryption, + .sha512WithRSAEncryption, + => |algorithm| return verifyRsa( + algorithm.Hash(), + subject.message, + subject.signature, + issuer.pub_key_algo, + issuer.pub_key, + ), + } + } + }; + + pub fn parse(cert: Certificate) !Parsed { + const cert_bytes = cert.buffer; + const certificate = try Der.parseElement(cert_bytes, cert.index); + const tbs_certificate = try Der.parseElement(cert_bytes, certificate.start); + const version = try Der.parseElement(cert_bytes, tbs_certificate.start); + try checkVersion(cert_bytes, version); + const serial_number = try Der.parseElement(cert_bytes, version.end); + // RFC 5280, section 4.1.2.3: + // "This field MUST contain the same algorithm identifier as + // the signatureAlgorithm field in the sequence Certificate." 
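+ // The inner `tbs_signature` AlgorithmIdentifier is parsed here only to
+ // find where the issuer field starts; the algorithm actually used for
+ // verification is taken from the outer signatureAlgorithm field parsed
+ // below as `sig_algo`.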
+ const tbs_signature = try Der.parseElement(cert_bytes, serial_number.end); + const issuer = try Der.parseElement(cert_bytes, tbs_signature.end); + const validity = try Der.parseElement(cert_bytes, issuer.end); + const subject = try Der.parseElement(cert_bytes, validity.end); + + const pub_key_info = try Der.parseElement(cert_bytes, subject.end); + const pub_key_signature_algorithm = try Der.parseElement(cert_bytes, pub_key_info.start); + const pub_key_algo_elem = try Der.parseElement(cert_bytes, pub_key_signature_algorithm.start); + const pub_key_algo = try parseAlgorithmCategory(cert_bytes, pub_key_algo_elem); + const pub_key_elem = try Der.parseElement(cert_bytes, pub_key_signature_algorithm.end); + const pub_key = try parseBitString(cert, pub_key_elem); + + const rdn = try Der.parseElement(cert_bytes, subject.start); + const atav = try Der.parseElement(cert_bytes, rdn.start); + + var common_name: []const u8 = &.{}; + var atav_i = atav.start; + while (atav_i < atav.end) { + const ty_elem = try Der.parseElement(cert_bytes, atav_i); + const ty = try parseAttribute(cert_bytes, ty_elem); + const val = try Der.parseElement(cert_bytes, ty_elem.end); + switch (ty) { + .commonName => common_name = cert.contents(val), + else => {}, + } + atav_i = val.end; + } + + const sig_algo = try Der.parseElement(cert_bytes, tbs_certificate.end); + const algo_elem = try Der.parseElement(cert_bytes, sig_algo.start); + const signature_algorithm = try parseAlgorithm(cert_bytes, algo_elem); + const sig_elem = try Der.parseElement(cert_bytes, sig_algo.end); + const signature = try parseBitString(cert, sig_elem); + + return .{ + .certificate = cert, + .common_name = common_name, + .issuer = cert.contents(issuer), + .subject = cert.contents(subject), + .signature = signature, + .signature_algorithm = signature_algorithm, + .message = cert_bytes[certificate.start..tbs_certificate.end], + .pub_key_algo = pub_key_algo, + .pub_key = pub_key, + }; + } + pub fn verify(subject: Certificate, issuer: Certificate) !void { - const subject_certificate = try Der.parseElement(subject.buffer, subject.index); - const subject_tbs_certificate = try Der.parseElement(subject.buffer, subject_certificate.start); - const subject_version = try Der.parseElement(subject.buffer, subject_tbs_certificate.start); - try checkVersion(subject.buffer, subject_version); - const subject_serial_number = try Der.parseElement(subject.buffer, subject_version.end); - // RFC 5280, section 4.1.2.3: - // "This field MUST contain the same algorithm identifier as - // the signatureAlgorithm field in the sequence Certificate." 
- const subject_signature = try Der.parseElement(subject.buffer, subject_serial_number.end); - const subject_issuer = try Der.parseElement(subject.buffer, subject_signature.end); - const subject_validity = try Der.parseElement(subject.buffer, subject_issuer.end); - //const subject_name = try Der.parseElement(subject.buffer, subject_validity.end); - - const subject_sig_algo = try Der.parseElement(subject.buffer, subject_tbs_certificate.end); - const subject_algo_elem = try Der.parseElement(subject.buffer, subject_sig_algo.start); - const subject_algo = try Der.parseObjectId(subject.buffer, subject_algo_elem); - const subject_sig_elem = try Der.parseElement(subject.buffer, subject_sig_algo.end); - const subject_sig = try parseBitString(subject, subject_sig_elem); - - const issuer_certificate = try Der.parseElement(issuer.buffer, issuer.index); - const issuer_tbs_certificate = try Der.parseElement(issuer.buffer, issuer_certificate.start); - const issuer_version = try Der.parseElement(issuer.buffer, issuer_tbs_certificate.start); - try checkVersion(issuer.buffer, issuer_version); - const issuer_serial_number = try Der.parseElement(issuer.buffer, issuer_version.end); - // RFC 5280, section 4.1.2.3: - // "This field MUST contain the same algorithm identifier as - // the signatureAlgorithm field in the sequence Certificate." - const issuer_signature = try Der.parseElement(issuer.buffer, issuer_serial_number.end); - const issuer_issuer = try Der.parseElement(issuer.buffer, issuer_signature.end); - const issuer_validity = try Der.parseElement(issuer.buffer, issuer_issuer.end); - const issuer_name = try Der.parseElement(issuer.buffer, issuer_validity.end); - const issuer_pub_key_info = try Der.parseElement(issuer.buffer, issuer_name.end); - const issuer_pub_key_signature_algorithm = try Der.parseElement(issuer.buffer, issuer_pub_key_info.start); - const issuer_pub_key_algo_elem = try Der.parseElement(issuer.buffer, issuer_pub_key_signature_algorithm.start); - const issuer_pub_key_algo = try Der.parseObjectId(issuer.buffer, issuer_pub_key_algo_elem); - const issuer_pub_key_elem = try Der.parseElement(issuer.buffer, issuer_pub_key_signature_algorithm.end); - const issuer_pub_key = try parseBitString(issuer, issuer_pub_key_elem); - - // Check that the subject's issuer name matches the issuer's subject - // name. 
- if (!mem.eql(u8, subject.contents(subject_issuer), issuer.contents(issuer_name))) { - return error.CertificateIssuerMismatch; - } - - // TODO check the time validity for the subject - _ = subject_validity; - // TODO check the time validity for the issuer - - const message = subject.buffer[subject_certificate.start..subject_tbs_certificate.end]; - //std.debug.print("issuer algo: {any} subject algo: {any}\n", .{ issuer_pub_key_algo, subject_algo }); - switch (subject_algo) { - // zig fmt: off - .sha1WithRSAEncryption => return verifyRsa(crypto.hash.Sha1, message, subject_sig, issuer_pub_key_algo, issuer_pub_key), - .sha224WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha224, message, subject_sig, issuer_pub_key_algo, issuer_pub_key), - .sha256WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha256, message, subject_sig, issuer_pub_key_algo, issuer_pub_key), - .sha384WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha384, message, subject_sig, issuer_pub_key_algo, issuer_pub_key), - .sha512WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha512, message, subject_sig, issuer_pub_key_algo, issuer_pub_key), - // zig fmt: on - else => { - std.debug.print("unhandled algorithm: {any}\n", .{subject_algo}); - return error.UnsupportedCertificateSignatureAlgorithm; - }, - } + const parsed_subject = try subject.parse(); + const parsed_issuer = try issuer.parse(); + return parsed_subject.verify(parsed_issuer); } pub fn contents(cert: Certificate, elem: Der.Element) []const u8 { @@ -219,7 +315,30 @@ pub const Certificate = struct { return cert.buffer[elem.start + 1 .. elem.end]; } - fn verifyRsa(comptime Hash: type, message: []const u8, sig: []const u8, pub_key_algo: Der.Oid, pub_key: []const u8) !void { + pub fn parseAlgorithm(bytes: []const u8, element: Der.Element) !Algorithm { + if (element.identifier.tag != .object_identifier) + return error.CertificateFieldHasWrongDataType; + return Algorithm.map.get(bytes[element.start..element.end]) orelse + return error.CertificateHasUnrecognizedAlgorithm; + } + + pub fn parseAlgorithmCategory(bytes: []const u8, element: Der.Element) !AlgorithmCategory { + if (element.identifier.tag != .object_identifier) + return error.CertificateFieldHasWrongDataType; + return AlgorithmCategory.map.get(bytes[element.start..element.end]) orelse { + std.debug.print("unrecognized algorithm category: {}\n", .{std.fmt.fmtSliceHexLower(bytes[element.start..element.end])}); + return error.CertificateHasUnrecognizedAlgorithmCategory; + }; + } + + pub fn parseAttribute(bytes: []const u8, element: Der.Element) !Attribute { + if (element.identifier.tag != .object_identifier) + return error.CertificateFieldHasWrongDataType; + return Attribute.map.get(bytes[element.start..element.end]) orelse + return error.CertificateHasUnrecognizedAlgorithm; + } + + fn verifyRsa(comptime Hash: type, message: []const u8, sig: []const u8, pub_key_algo: AlgorithmCategory, pub_key: []const u8) !void { if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch; const pub_key_seq = try Der.parseElement(pub_key, 0); if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 8395be4551..52ab1a55fa 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -18,6 +18,7 @@ const int2 = tls.int2; const int3 = tls.int3; const array = tls.array; const enum_array = tls.enum_array; +const Certificate = 
crypto.CertificateBundle.Certificate; application_cipher: ApplicationCipher, read_seq: u64, @@ -298,6 +299,8 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con }; var read_seq: u64 = 0; + var validated_cert = false; + var is_subsequent_cert = false; while (true) { const end_hdr = i + 5; @@ -386,10 +389,11 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con hs_i = next_ext_i; } }, - @enumToInt(HandshakeType.certificate) => { + @enumToInt(HandshakeType.certificate) => cert: { switch (cipher_params) { inline else => |*p| p.transcript_hash.update(wrapped_handshake), } + if (validated_cert) break :cert; var hs_i: u32 = 0; const cert_req_ctx_len = handshake[hs_i]; hs_i += 1; @@ -402,41 +406,36 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con hs_i += 3; const end_cert = hs_i + cert_size; - const certificate = try Der.parseElement(handshake, hs_i); - const tbs_certificate = try Der.parseElement(handshake, certificate.start); - - const version = try Der.parseElement(handshake, tbs_certificate.start); - if (@bitCast(u8, version.identifier) != 0xa0 or - !mem.eql(u8, handshake[version.start..version.end], "\x02\x01\x02")) - { - return error.UnsupportedCertificateVersion; + const subject_cert: Certificate = .{ + .buffer = handshake, + .index = hs_i, + }; + const subject = try subject_cert.parse(); + if (!is_subsequent_cert) { + is_subsequent_cert = true; + if (mem.eql(u8, subject.common_name, host)) { + std.debug.print("exact host match\n", .{}); + } else if (mem.startsWith(u8, subject.common_name, "*.") and + mem.eql(u8, subject.common_name[2..], host)) + { + std.debug.print("wildcard host match\n", .{}); + } else { + std.debug.print("host does not match\n", .{}); + return error.TlsCertificateInvalidHost; + } } - const serial_number = try Der.parseElement(handshake, version.end); - // RFC 5280, section 4.1.2.3: - // "This field MUST contain the same algorithm identifier as - // the signatureAlgorithm field in the sequence Certificate." - const signature = try Der.parseElement(handshake, serial_number.end); - const issuer_elem = try Der.parseElement(handshake, signature.end); - - const issuer_bytes = handshake[issuer_elem.start..issuer_elem.end]; - if (ca_bundle.find(issuer_bytes)) |ca_cert_i| { - const Certificate = crypto.CertificateBundle.Certificate; - const subject: Certificate = .{ - .buffer = handshake, - .index = hs_i, - }; - const issuer: Certificate = .{ - .buffer = ca_bundle.bytes.items, - .index = ca_cert_i, - }; - if (subject.verify(issuer)) |_| { - std.debug.print("found a root CA cert matching issuer. verification success!\n", .{}); - } else |err| { - std.debug.print("found a root CA cert matching issuer. verification failure: {s}\n", .{ - @errorName(err), - }); - } + if (ca_bundle.verify(subject)) |_| { + std.debug.print("found a root CA cert matching issuer. verification success!\n", .{}); + validated_cert = true; + break :cert; + } else |err| { + std.debug.print("unable to validate cert against system root CAs: {s}\n", .{ + @errorName(err), + }); + // TODO handle a certificate + // signing chain that ends in a + // root-validated one. 
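+ // (That would mean treating each later certificate in this
+ // Certificate message as the issuer of the one before it and
+ // verifying the pairs until one chains to a bundle root.)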
} hs_i = end_cert; From 22db1e166a6e3721df21546209fbfe9df7ddc0c0 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 21 Dec 2022 16:16:31 -0700 Subject: [PATCH 26/59] std.crypto.CertificateBundle: disable test on WASI --- lib/std/crypto/CertificateBundle.zig | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/std/crypto/CertificateBundle.zig b/lib/std/crypto/CertificateBundle.zig index bc82d54a79..6f9e77a4d7 100644 --- a/lib/std/crypto/CertificateBundle.zig +++ b/lib/std/crypto/CertificateBundle.zig @@ -448,7 +448,9 @@ const MapContext = struct { } }; -test { +test "scan for OS-provided certificates" { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + var bundle: CertificateBundle = .{}; defer bundle.deinit(std.testing.allocator); From 4f9f4575bdf35fa69f09910d3d0ae349f9071c18 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 21 Dec 2022 16:16:53 -0700 Subject: [PATCH 27/59] std.crypto.tls: rename HandshakeCipher --- lib/std/crypto/tls.zig | 14 +++++++------- lib/std/crypto/tls/Client.zig | 20 ++++++++++---------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 09a22f9a23..12401bff35 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -227,7 +227,7 @@ pub const CertificateType = enum(u8) { _, }; -pub fn CipherParamsT(comptime AeadType: type, comptime HashType: type) type { +pub fn HandshakeCipherT(comptime AeadType: type, comptime HashType: type) type { return struct { pub const AEAD = AeadType; pub const Hash = HashType; @@ -246,12 +246,12 @@ pub fn CipherParamsT(comptime AeadType: type, comptime HashType: type) type { }; } -pub const CipherParams = union(enum) { - AES_128_GCM_SHA256: CipherParamsT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256), - AES_256_GCM_SHA384: CipherParamsT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384), - CHACHA20_POLY1305_SHA256: CipherParamsT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256), - AEGIS_256_SHA384: CipherParamsT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha384), - AEGIS_128L_SHA256: CipherParamsT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256), +pub const HandshakeCipher = union(enum) { + AES_128_GCM_SHA256: HandshakeCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256), + AES_256_GCM_SHA384: HandshakeCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384), + CHACHA20_POLY1305_SHA256: HandshakeCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256), + AEGIS_256_SHA384: HandshakeCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha384), + AEGIS_128L_SHA256: HandshakeCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256), }; pub fn ApplicationCipherT(comptime AeadType: type, comptime HashType: type) type { diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 52ab1a55fa..eb1b1b80bc 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -11,7 +11,7 @@ const ApplicationCipher = tls.ApplicationCipher; const CipherSuite = tls.CipherSuite; const ContentType = tls.ContentType; const HandshakeType = tls.HandshakeType; -const CipherParams = tls.CipherParams; +const HandshakeCipher = tls.HandshakeCipher; const max_ciphertext_len = tls.max_ciphertext_len; const hkdfExpandLabel = tls.hkdfExpandLabel; const int2 = tls.int2; @@ -117,7 +117,7 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con const client_hello_bytes1 = plaintext_header[5..]; - var cipher_params: 
CipherParams = undefined; + var handshake_cipher: HandshakeCipher = undefined; var handshake_buf: [8000]u8 = undefined; var len: usize = 0; @@ -243,8 +243,8 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con .AEGIS_256_SHA384, .AEGIS_128L_SHA256, => |tag| { - const P = std.meta.TagPayloadByName(CipherParams, @tagName(tag)); - cipher_params = @unionInit(CipherParams, @tagName(tag), .{ + const P = std.meta.TagPayloadByName(HandshakeCipher, @tagName(tag)); + handshake_cipher = @unionInit(HandshakeCipher, @tagName(tag), .{ .handshake_secret = undefined, .master_secret = undefined, .client_handshake_key = undefined, @@ -255,7 +255,7 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con .server_handshake_iv = undefined, .transcript_hash = P.Hash.init(.{}), }); - const p = &@field(cipher_params, @tagName(tag)); + const p = &@field(handshake_cipher, @tagName(tag)); p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1 p.transcript_hash.update(host); // Client Hello part 2 p.transcript_hash.update(frag); // Server Hello @@ -329,7 +329,7 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con }, .application_data => { var cleartext_buf: [8000]u8 = undefined; - const cleartext = switch (cipher_params) { + const cleartext = switch (handshake_cipher) { inline else => |*p| c: { const P = @TypeOf(p.*); const ciphertext_len = record_size - P.AEAD.tag_length; @@ -366,7 +366,7 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con const handshake = cleartext[ct_i..next_handshake_i]; switch (handshake_type) { @enumToInt(HandshakeType.encrypted_extensions) => { - switch (cipher_params) { + switch (handshake_cipher) { inline else => |*p| p.transcript_hash.update(wrapped_handshake), } const total_ext_size = mem.readIntBig(u16, handshake[0..2]); @@ -390,7 +390,7 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con } }, @enumToInt(HandshakeType.certificate) => cert: { - switch (cipher_params) { + switch (handshake_cipher) { inline else => |*p| p.transcript_hash.update(wrapped_handshake), } if (validated_cert) break :cert; @@ -445,7 +445,7 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con } }, @enumToInt(HandshakeType.certificate_verify) => { - switch (cipher_params) { + switch (handshake_cipher) { inline else => |*p| p.transcript_hash.update(wrapped_handshake), } std.debug.print("ignoring certificate_verify\n", .{}); @@ -458,7 +458,7 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con 0x00, 0x01, // length 0x01, }; - const app_cipher = switch (cipher_params) { + const app_cipher = switch (handshake_cipher) { inline else => |*p, tag| c: { const P = @TypeOf(p.*); const finished_digest = p.transcript_hash.peek(); From 29475b45185f90c2437d160567e67a4b141f5845 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 21 Dec 2022 17:12:34 -0700 Subject: [PATCH 28/59] std.crypto.tls: validate previous certificate --- lib/std/crypto.zig | 8 +- lib/std/crypto/Certificate.zig | 446 +++++++++++++++++++ lib/std/crypto/Certificate/Bundle.zig | 174 ++++++++ lib/std/crypto/CertificateBundle.zig | 593 -------------------------- lib/std/crypto/{Der.zig => der.zig} | 26 +- lib/std/crypto/tls/Client.zig | 52 ++- lib/std/http/Client.zig | 2 +- 7 files changed, 679 insertions(+), 622 deletions(-) create mode 100644 lib/std/crypto/Certificate.zig create mode 100644 lib/std/crypto/Certificate/Bundle.zig 
delete mode 100644 lib/std/crypto/CertificateBundle.zig rename lib/std/crypto/{Der.zig => der.zig} (92%) diff --git a/lib/std/crypto.zig b/lib/std/crypto.zig index 17c2e4afe9..6387eb48ae 100644 --- a/lib/std/crypto.zig +++ b/lib/std/crypto.zig @@ -177,8 +177,8 @@ const std = @import("std.zig"); pub const errors = @import("crypto/errors.zig"); pub const tls = @import("crypto/tls.zig"); -pub const Der = @import("crypto/Der.zig"); -pub const CertificateBundle = @import("crypto/CertificateBundle.zig"); +pub const der = @import("crypto/der.zig"); +pub const Certificate = @import("crypto/Certificate.zig"); test { _ = aead.aegis.Aegis128L; @@ -269,8 +269,8 @@ test { _ = random; _ = errors; _ = tls; - _ = Der; - _ = CertificateBundle; + _ = der; + _ = Certificate; } test "CSPRNG" { diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig new file mode 100644 index 0000000000..3d50e43839 --- /dev/null +++ b/lib/std/crypto/Certificate.zig @@ -0,0 +1,446 @@ +buffer: []const u8, +index: u32, + +pub const Bundle = @import("Certificate/Bundle.zig"); + +pub const Algorithm = enum { + sha1WithRSAEncryption, + sha224WithRSAEncryption, + sha256WithRSAEncryption, + sha384WithRSAEncryption, + sha512WithRSAEncryption, + + pub const map = std.ComptimeStringMap(Algorithm, .{ + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x05 }, .sha1WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0B }, .sha256WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C }, .sha384WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D }, .sha512WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0E }, .sha224WithRSAEncryption }, + }); + + pub fn Hash(comptime algorithm: Algorithm) type { + return switch (algorithm) { + .sha1WithRSAEncryption => crypto.hash.Sha1, + .sha224WithRSAEncryption => crypto.hash.sha2.Sha224, + .sha256WithRSAEncryption => crypto.hash.sha2.Sha256, + .sha384WithRSAEncryption => crypto.hash.sha2.Sha384, + .sha512WithRSAEncryption => crypto.hash.sha2.Sha512, + }; + } +}; + +pub const AlgorithmCategory = enum { + rsaEncryption, + X9_62_id_ecPublicKey, + + pub const map = std.ComptimeStringMap(AlgorithmCategory, .{ + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01 }, .rsaEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01 }, .X9_62_id_ecPublicKey }, + }); +}; + +pub const Attribute = enum { + commonName, + serialNumber, + countryName, + localityName, + stateOrProvinceName, + organizationName, + organizationalUnitName, + organizationIdentifier, + + pub const map = std.ComptimeStringMap(Attribute, .{ + .{ &[_]u8{ 0x55, 0x04, 0x03 }, .commonName }, + .{ &[_]u8{ 0x55, 0x04, 0x05 }, .serialNumber }, + .{ &[_]u8{ 0x55, 0x04, 0x06 }, .countryName }, + .{ &[_]u8{ 0x55, 0x04, 0x07 }, .localityName }, + .{ &[_]u8{ 0x55, 0x04, 0x08 }, .stateOrProvinceName }, + .{ &[_]u8{ 0x55, 0x04, 0x0A }, .organizationName }, + .{ &[_]u8{ 0x55, 0x04, 0x0B }, .organizationalUnitName }, + .{ &[_]u8{ 0x55, 0x04, 0x61 }, .organizationIdentifier }, + }); +}; + +pub const Parsed = struct { + certificate: Certificate, + issuer_slice: Slice, + subject_slice: Slice, + common_name_slice: Slice, + signature_slice: Slice, + signature_algorithm: Algorithm, + pub_key_algo: AlgorithmCategory, + pub_key_slice: Slice, + message_slice: Slice, + + pub const Slice = der.Element.Slice; + + pub fn slice(p: Parsed, s: Slice) []const u8 { + return 
p.certificate.buffer[s.start..s.end]; + } + + pub fn issuer(p: Parsed) []const u8 { + return p.slice(p.issuer_slice); + } + + pub fn subject(p: Parsed) []const u8 { + return p.slice(p.subject_slice); + } + + pub fn commonName(p: Parsed) []const u8 { + return p.slice(p.common_name_slice); + } + + pub fn signature(p: Parsed) []const u8 { + return p.slice(p.signature_slice); + } + + pub fn pubKey(p: Parsed) []const u8 { + return p.slice(p.pub_key_slice); + } + + pub fn message(p: Parsed) []const u8 { + return p.slice(p.message_slice); + } + + pub fn verify(parsed_subject: Parsed, parsed_issuer: Parsed) !void { + // Check that the subject's issuer name matches the issuer's + // subject name. + if (!mem.eql(u8, parsed_subject.issuer(), parsed_issuer.subject())) { + return error.CertificateIssuerMismatch; + } + + // TODO check the time validity for the subject + // TODO check the time validity for the issuer + + switch (parsed_subject.signature_algorithm) { + inline .sha1WithRSAEncryption, + .sha224WithRSAEncryption, + .sha256WithRSAEncryption, + .sha384WithRSAEncryption, + .sha512WithRSAEncryption, + => |algorithm| return verifyRsa( + algorithm.Hash(), + parsed_subject.message(), + parsed_subject.signature(), + parsed_issuer.pub_key_algo, + parsed_issuer.pubKey(), + ), + } + } +}; + +pub fn parse(cert: Certificate) !Parsed { + const cert_bytes = cert.buffer; + const certificate = try der.parseElement(cert_bytes, cert.index); + const tbs_certificate = try der.parseElement(cert_bytes, certificate.slice.start); + const version = try der.parseElement(cert_bytes, tbs_certificate.slice.start); + try checkVersion(cert_bytes, version); + const serial_number = try der.parseElement(cert_bytes, version.slice.end); + // RFC 5280, section 4.1.2.3: + // "This field MUST contain the same algorithm identifier as + // the signatureAlgorithm field in the sequence Certificate." 
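+    // The element parsed next is used only to advance past this field; the
+    // outer signatureAlgorithm parsed near the end of this function is the
+    // one that gets recorded, and the equality the RFC requires is not
+    // itself checked here.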
+ const tbs_signature = try der.parseElement(cert_bytes, serial_number.slice.end); + const issuer = try der.parseElement(cert_bytes, tbs_signature.slice.end); + const validity = try der.parseElement(cert_bytes, issuer.slice.end); + const subject = try der.parseElement(cert_bytes, validity.slice.end); + + const pub_key_info = try der.parseElement(cert_bytes, subject.slice.end); + const pub_key_signature_algorithm = try der.parseElement(cert_bytes, pub_key_info.slice.start); + const pub_key_algo_elem = try der.parseElement(cert_bytes, pub_key_signature_algorithm.slice.start); + const pub_key_algo = try parseAlgorithmCategory(cert_bytes, pub_key_algo_elem); + const pub_key_elem = try der.parseElement(cert_bytes, pub_key_signature_algorithm.slice.end); + const pub_key = try parseBitString(cert, pub_key_elem); + + const rdn = try der.parseElement(cert_bytes, subject.slice.start); + const atav = try der.parseElement(cert_bytes, rdn.slice.start); + + var common_name = der.Element.Slice.empty; + var atav_i = atav.slice.start; + while (atav_i < atav.slice.end) { + const ty_elem = try der.parseElement(cert_bytes, atav_i); + const ty = try parseAttribute(cert_bytes, ty_elem); + const val = try der.parseElement(cert_bytes, ty_elem.slice.end); + switch (ty) { + .commonName => common_name = val.slice, + else => {}, + } + atav_i = val.slice.end; + } + + const sig_algo = try der.parseElement(cert_bytes, tbs_certificate.slice.end); + const algo_elem = try der.parseElement(cert_bytes, sig_algo.slice.start); + const signature_algorithm = try parseAlgorithm(cert_bytes, algo_elem); + const sig_elem = try der.parseElement(cert_bytes, sig_algo.slice.end); + const signature = try parseBitString(cert, sig_elem); + + return .{ + .certificate = cert, + .common_name_slice = common_name, + .issuer_slice = issuer.slice, + .subject_slice = subject.slice, + .signature_slice = signature, + .signature_algorithm = signature_algorithm, + .message_slice = .{ .start = certificate.slice.start, .end = tbs_certificate.slice.end }, + .pub_key_algo = pub_key_algo, + .pub_key_slice = pub_key, + }; +} + +pub fn verify(subject: Certificate, issuer: Certificate) !void { + const parsed_subject = try subject.parse(); + const parsed_issuer = try issuer.parse(); + return parsed_subject.verify(parsed_issuer); +} + +pub fn contents(cert: Certificate, elem: der.Element) []const u8 { + return cert.buffer[elem.start..elem.end]; +} + +pub fn parseBitString(cert: Certificate, elem: der.Element) !der.Element.Slice { + if (elem.identifier.tag != .bitstring) return error.CertificateFieldHasWrongDataType; + if (cert.buffer[elem.slice.start] != 0) return error.CertificateHasInvalidBitString; + return .{ .start = elem.slice.start + 1, .end = elem.slice.end }; +} + +pub fn parseAlgorithm(bytes: []const u8, element: der.Element) !Algorithm { + if (element.identifier.tag != .object_identifier) + return error.CertificateFieldHasWrongDataType; + return Algorithm.map.get(bytes[element.slice.start..element.slice.end]) orelse + return error.CertificateHasUnrecognizedAlgorithm; +} + +pub fn parseAlgorithmCategory(bytes: []const u8, element: der.Element) !AlgorithmCategory { + if (element.identifier.tag != .object_identifier) + return error.CertificateFieldHasWrongDataType; + return AlgorithmCategory.map.get(bytes[element.slice.start..element.slice.end]) orelse + return error.CertificateHasUnrecognizedAlgorithmCategory; +} + +pub fn parseAttribute(bytes: []const u8, element: der.Element) !Attribute { + if (element.identifier.tag != .object_identifier) + return 
error.CertificateFieldHasWrongDataType; + return Attribute.map.get(bytes[element.slice.start..element.slice.end]) orelse + return error.CertificateHasUnrecognizedAlgorithm; +} + +fn verifyRsa(comptime Hash: type, message: []const u8, sig: []const u8, pub_key_algo: AlgorithmCategory, pub_key: []const u8) !void { + if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch; + const pub_key_seq = try der.parseElement(pub_key, 0); + if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; + const modulus_elem = try der.parseElement(pub_key, pub_key_seq.slice.start); + if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; + const exponent_elem = try der.parseElement(pub_key, modulus_elem.slice.end); + if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; + // Skip over meaningless zeroes in the modulus. + const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end]; + const modulus_offset = for (modulus_raw) |byte, i| { + if (byte != 0) break i; + } else modulus_raw.len; + const modulus = modulus_raw[modulus_offset..]; + const exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end]; + if (exponent.len > modulus.len) return error.CertificatePublicKeyInvalid; + if (sig.len != modulus.len) return error.CertificateSignatureInvalidLength; + + const hash_der = switch (Hash) { + crypto.hash.Sha1 => [_]u8{ + 0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, + 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14, + }, + crypto.hash.sha2.Sha224 => [_]u8{ + 0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05, + 0x00, 0x04, 0x1c, + }, + crypto.hash.sha2.Sha256 => [_]u8{ + 0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, + 0x00, 0x04, 0x20, + }, + crypto.hash.sha2.Sha384 => [_]u8{ + 0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, + 0x00, 0x04, 0x30, + }, + crypto.hash.sha2.Sha512 => [_]u8{ + 0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, + 0x00, 0x04, 0x40, + }, + else => @compileError("unreachable"), + }; + + var msg_hashed: [Hash.digest_length]u8 = undefined; + Hash.hash(message, &msg_hashed, .{}); + + switch (modulus.len) { + inline 128, 256, 512 => |modulus_len| { + const ps_len = modulus_len - (hash_der.len + msg_hashed.len) - 3; + const em: [modulus_len]u8 = + [2]u8{ 0, 1 } ++ + ([1]u8{0xff} ** ps_len) ++ + [1]u8{0} ++ + hash_der ++ + msg_hashed; + + const public_key = try rsa.PublicKey.fromBytes(exponent, modulus, rsa.poop); + const em_dec = try rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key, rsa.poop); + + if (!mem.eql(u8, &em, &em_dec)) { + try std.testing.expectEqualSlices(u8, &em, &em_dec); + return error.CertificateSignatureInvalid; + } + }, + else => { + return error.CertificateSignatureUnsupportedBitCount; + }, + } +} + +pub fn checkVersion(bytes: []const u8, version: der.Element) !void { + if (@bitCast(u8, version.identifier) != 0xa0 or + !mem.eql(u8, bytes[version.slice.start..version.slice.end], "\x02\x01\x02")) + { + return error.UnsupportedCertificateVersion; + } +} + +const std = @import("../std.zig"); +const crypto = std.crypto; +const mem = std.mem; +const der = std.crypto.der; +const Certificate = @This(); + +/// TODO: replace this with Frank's upcoming RSA implementation. 
the verify +/// function won't have the possibility of failure - it will either identify a +/// valid signature or an invalid signature. +/// This code is borrowed from https://github.com/shiguredo/tls13-zig +/// which is licensed under the Apache License Version 2.0, January 2004 +/// http://www.apache.org/licenses/ +/// The code has been modified. +const rsa = struct { + const BigInt = std.math.big.int.Managed; + + const PublicKey = struct { + n: BigInt, + e: BigInt, + + pub fn deinit(self: *PublicKey) void { + self.n.deinit(); + self.e.deinit(); + } + + pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8, allocator: std.mem.Allocator) !PublicKey { + var _n = try BigInt.init(allocator); + errdefer _n.deinit(); + try setBytes(&_n, modulus_bytes, allocator); + + var _e = try BigInt.init(allocator); + errdefer _e.deinit(); + try setBytes(&_e, pub_bytes, allocator); + + return .{ + .n = _n, + .e = _e, + }; + } + }; + + fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey, allocator: std.mem.Allocator) ![modulus_len]u8 { + var m = try BigInt.init(allocator); + defer m.deinit(); + + try setBytes(&m, &msg, allocator); + + if (m.order(public_key.n) != .lt) { + return error.MessageTooLong; + } + + var e = try BigInt.init(allocator); + defer e.deinit(); + + try pow_montgomery(&e, &m, &public_key.e, &public_key.n, allocator); + + var res: [modulus_len]u8 = undefined; + + try toBytes(&res, &e, allocator); + + return res; + } + + fn setBytes(r: *BigInt, bytes: []const u8, allcator: std.mem.Allocator) !void { + try r.set(0); + var tmp = try BigInt.init(allcator); + defer tmp.deinit(); + for (bytes) |b| { + try r.shiftLeft(r, 8); + try tmp.set(b); + try r.add(r, &tmp); + } + } + + fn pow_montgomery(r: *BigInt, a: *const BigInt, x: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void { + var bin_raw: [512]u8 = undefined; + try toBytes(&bin_raw, x, allocator); + + var i: usize = 0; + while (bin_raw[i] == 0x00) : (i += 1) {} + const bin = bin_raw[i..]; + + try r.set(1); + var r1 = try BigInt.init(allocator); + defer r1.deinit(); + try BigInt.copy(&r1, a.toConst()); + i = 0; + while (i < bin.len * 8) : (i += 1) { + if (((bin[i / 8] >> @intCast(u3, (7 - (i % 8)))) & 0x1) == 0) { + try BigInt.mul(&r1, r, &r1); + try mod(&r1, &r1, n, allocator); + try BigInt.sqr(r, r); + try mod(r, r, n, allocator); + } else { + try BigInt.mul(r, r, &r1); + try mod(r, r, n, allocator); + try BigInt.sqr(&r1, &r1); + try mod(&r1, &r1, n, allocator); + } + } + } + + fn toBytes(out: []u8, a: *const BigInt, allocator: std.mem.Allocator) !void { + const Error = error{ + BufferTooSmall, + }; + + var mask = try BigInt.initSet(allocator, 0xFF); + defer mask.deinit(); + var tmp = try BigInt.init(allocator); + defer tmp.deinit(); + + var a_copy = try BigInt.init(allocator); + defer a_copy.deinit(); + try a_copy.copy(a.toConst()); + + // Encoding into big-endian bytes + var i: usize = 0; + while (i < out.len) : (i += 1) { + try tmp.bitAnd(&a_copy, &mask); + const b = try tmp.to(u8); + out[out.len - i - 1] = b; + try a_copy.shiftRight(&a_copy, 8); + } + + if (!a_copy.eqZero()) { + return Error.BufferTooSmall; + } + } + + fn mod(rem: *BigInt, a: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void { + var q = try BigInt.init(allocator); + defer q.deinit(); + + try BigInt.divFloor(&q, rem, a, n); + } + + // TODO: flush the toilet + const poop = std.heap.page_allocator; +}; diff --git a/lib/std/crypto/Certificate/Bundle.zig b/lib/std/crypto/Certificate/Bundle.zig 
new file mode 100644 index 0000000000..c2c18552a7 --- /dev/null +++ b/lib/std/crypto/Certificate/Bundle.zig @@ -0,0 +1,174 @@ +//! A set of certificates. Typically pre-installed on every operating system, +//! these are "Certificate Authorities" used to validate SSL certificates. +//! This data structure stores certificates in DER-encoded form, all of them +//! concatenated together in the `bytes` array. The `map` field contains an +//! index from the DER-encoded subject name to the index of the containing +//! certificate within `bytes`. + +/// The key is the contents slice of the subject. +map: std.HashMapUnmanaged(der.Element.Slice, u32, MapContext, std.hash_map.default_max_load_percentage) = .{}, +bytes: std.ArrayListUnmanaged(u8) = .{}, + +pub fn verify(cb: Bundle, subject: Certificate.Parsed) !void { + const bytes_index = cb.find(subject.issuer()) orelse return error.IssuerNotFound; + const issuer_cert: Certificate = .{ + .buffer = cb.bytes.items, + .index = bytes_index, + }; + const issuer = try issuer_cert.parse(); + try subject.verify(issuer); +} + +/// The returned bytes become invalid after calling any of the rescan functions +/// or add functions. +pub fn find(cb: Bundle, subject_name: []const u8) ?u32 { + const Adapter = struct { + cb: Bundle, + + pub fn hash(ctx: @This(), k: []const u8) u64 { + _ = ctx; + return std.hash_map.hashString(k); + } + + pub fn eql(ctx: @This(), a: []const u8, b_key: der.Element.Slice) bool { + const b = ctx.cb.bytes.items[b_key.start..b_key.end]; + return mem.eql(u8, a, b); + } + }; + return cb.map.getAdapted(subject_name, Adapter{ .cb = cb }); +} + +pub fn deinit(cb: *Bundle, gpa: Allocator) void { + cb.map.deinit(gpa); + cb.bytes.deinit(gpa); + cb.* = undefined; +} + +/// Empties the set of certificates and then scans the host operating system +/// file system standard locations for certificates. +pub fn rescan(cb: *Bundle, gpa: Allocator) !void { + switch (builtin.os.tag) { + .linux => return rescanLinux(cb, gpa), + else => @compileError("it is unknown where the root CA certificates live on this OS"), + } +} + +pub fn rescanLinux(cb: *Bundle, gpa: Allocator) !void { + var dir = fs.openIterableDirAbsolute("/etc/ssl/certs", .{}) catch |err| switch (err) { + error.FileNotFound => return, + else => |e| return e, + }; + defer dir.close(); + + cb.bytes.clearRetainingCapacity(); + cb.map.clearRetainingCapacity(); + + var it = dir.iterate(); + while (try it.next()) |entry| { + switch (entry.kind) { + .File, .SymLink => {}, + else => continue, + } + + try addCertsFromFile(cb, gpa, dir.dir, entry.name); + } + + cb.bytes.shrinkAndFree(gpa, cb.bytes.items.len); +} + +pub fn addCertsFromFile( + cb: *Bundle, + gpa: Allocator, + dir: fs.Dir, + sub_file_path: []const u8, +) !void { + var file = try dir.openFile(sub_file_path, .{}); + defer file.close(); + + const size = try file.getEndPos(); + + // We borrow `bytes` as a temporary buffer for the base64-encoded data. + // This is possible by computing the decoded length and reserving the space + // for the decoded bytes first. 
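+    // For example, a 4096-byte PEM file decodes to at most 4096 / 4 * 3 =
+    // 3072 bytes, so reserving 3072 + 4096 bytes up front leaves room for
+    // the raw file contents at the tail of the allocation while the decoded
+    // certificates are appended to `items` within the reserved region.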
+ const decoded_size_upper_bound = size / 4 * 3; + try cb.bytes.ensureUnusedCapacity(gpa, decoded_size_upper_bound + size); + const end_reserved = cb.bytes.items.len + decoded_size_upper_bound; + const buffer = cb.bytes.allocatedSlice()[end_reserved..]; + const end_index = try file.readAll(buffer); + const encoded_bytes = buffer[0..end_index]; + + const begin_marker = "-----BEGIN CERTIFICATE-----"; + const end_marker = "-----END CERTIFICATE-----"; + + var start_index: usize = 0; + while (mem.indexOfPos(u8, encoded_bytes, start_index, begin_marker)) |begin_marker_start| { + const cert_start = begin_marker_start + begin_marker.len; + const cert_end = mem.indexOfPos(u8, encoded_bytes, cert_start, end_marker) orelse + return error.MissingEndCertificateMarker; + start_index = cert_end + end_marker.len; + const encoded_cert = mem.trim(u8, encoded_bytes[cert_start..cert_end], " \t\r\n"); + const decoded_start = @intCast(u32, cb.bytes.items.len); + const dest_buf = cb.bytes.allocatedSlice()[decoded_start..]; + cb.bytes.items.len += try base64.decode(dest_buf, encoded_cert); + const k = try cb.key(decoded_start); + const gop = try cb.map.getOrPutContext(gpa, k, .{ .cb = cb }); + if (gop.found_existing) { + cb.bytes.items.len = decoded_start; + } else { + gop.value_ptr.* = decoded_start; + } + } +} + +pub fn key(cb: Bundle, bytes_index: u32) !der.Element.Slice { + const bytes = cb.bytes.items; + const certificate = try der.parseElement(bytes, bytes_index); + const tbs_certificate = try der.parseElement(bytes, certificate.slice.start); + const version = try der.parseElement(bytes, tbs_certificate.slice.start); + try Certificate.checkVersion(bytes, version); + const serial_number = try der.parseElement(bytes, version.slice.end); + const signature = try der.parseElement(bytes, serial_number.slice.end); + const issuer = try der.parseElement(bytes, signature.slice.end); + const validity = try der.parseElement(bytes, issuer.slice.end); + const subject = try der.parseElement(bytes, validity.slice.end); + + return subject.slice; +} + +const builtin = @import("builtin"); +const std = @import("../../std.zig"); +const fs = std.fs; +const mem = std.mem; +const crypto = std.crypto; +const Allocator = std.mem.Allocator; +const der = std.crypto.der; +const Certificate = std.crypto.Certificate; +const Bundle = @This(); + +const base64 = std.base64.standard.decoderWithIgnore(" \t\r\n"); + +const MapContext = struct { + cb: *const Bundle, + + pub fn hash(ctx: MapContext, k: der.Element.Slice) u64 { + return std.hash_map.hashString(ctx.cb.bytes.items[k.start..k.end]); + } + + pub fn eql(ctx: MapContext, a: der.Element.Slice, b: der.Element.Slice) bool { + const bytes = ctx.cb.bytes.items; + return mem.eql( + u8, + bytes[a.start..a.end], + bytes[b.start..b.end], + ); + } +}; + +test "scan for OS-provided certificates" { + if (builtin.os.tag == .wasi) return error.SkipZigTest; + + var bundle: Bundle = .{}; + defer bundle.deinit(std.testing.allocator); + + try bundle.rescan(std.testing.allocator); +} diff --git a/lib/std/crypto/CertificateBundle.zig b/lib/std/crypto/CertificateBundle.zig deleted file mode 100644 index 6f9e77a4d7..0000000000 --- a/lib/std/crypto/CertificateBundle.zig +++ /dev/null @@ -1,593 +0,0 @@ -//! A set of certificates. Typically pre-installed on every operating system, -//! these are "Certificate Authorities" used to validate SSL certificates. -//! This data structure stores certificates in DER-encoded form, all of them -//! concatenated together in the `bytes` array. 
The `map` field contains an -//! index from the DER-encoded subject name to the index of the containing -//! certificate within `bytes`. - -map: std.HashMapUnmanaged(Key, u32, MapContext, std.hash_map.default_max_load_percentage) = .{}, -bytes: std.ArrayListUnmanaged(u8) = .{}, - -pub const Key = struct { - subject_start: u32, - subject_end: u32, -}; - -pub fn verify(cb: CertificateBundle, subject: Certificate.Parsed) !void { - const bytes_index = cb.find(subject.issuer) orelse return error.IssuerNotFound; - const issuer_cert: Certificate = .{ - .buffer = cb.bytes.items, - .index = bytes_index, - }; - const issuer = try issuer_cert.parse(); - try subject.verify(issuer); -} - -/// The returned bytes become invalid after calling any of the rescan functions -/// or add functions. -pub fn find(cb: CertificateBundle, subject_name: []const u8) ?u32 { - const Adapter = struct { - cb: CertificateBundle, - - pub fn hash(ctx: @This(), k: []const u8) u64 { - _ = ctx; - return std.hash_map.hashString(k); - } - - pub fn eql(ctx: @This(), a: []const u8, b_key: Key) bool { - const b = ctx.cb.bytes.items[b_key.subject_start..b_key.subject_end]; - return mem.eql(u8, a, b); - } - }; - return cb.map.getAdapted(subject_name, Adapter{ .cb = cb }); -} - -pub fn deinit(cb: *CertificateBundle, gpa: Allocator) void { - cb.map.deinit(gpa); - cb.bytes.deinit(gpa); - cb.* = undefined; -} - -/// Empties the set of certificates and then scans the host operating system -/// file system standard locations for certificates. -pub fn rescan(cb: *CertificateBundle, gpa: Allocator) !void { - switch (builtin.os.tag) { - .linux => return rescanLinux(cb, gpa), - else => @compileError("it is unknown where the root CA certificates live on this OS"), - } -} - -pub fn rescanLinux(cb: *CertificateBundle, gpa: Allocator) !void { - var dir = fs.openIterableDirAbsolute("/etc/ssl/certs", .{}) catch |err| switch (err) { - error.FileNotFound => return, - else => |e| return e, - }; - defer dir.close(); - - cb.bytes.clearRetainingCapacity(); - cb.map.clearRetainingCapacity(); - - var it = dir.iterate(); - while (try it.next()) |entry| { - switch (entry.kind) { - .File, .SymLink => {}, - else => continue, - } - - try addCertsFromFile(cb, gpa, dir.dir, entry.name); - } - - cb.bytes.shrinkAndFree(gpa, cb.bytes.items.len); -} - -pub fn addCertsFromFile( - cb: *CertificateBundle, - gpa: Allocator, - dir: fs.Dir, - sub_file_path: []const u8, -) !void { - var file = try dir.openFile(sub_file_path, .{}); - defer file.close(); - - const size = try file.getEndPos(); - - // We borrow `bytes` as a temporary buffer for the base64-encoded data. - // This is possible by computing the decoded length and reserving the space - // for the decoded bytes first. 
- const decoded_size_upper_bound = size / 4 * 3; - try cb.bytes.ensureUnusedCapacity(gpa, decoded_size_upper_bound + size); - const end_reserved = cb.bytes.items.len + decoded_size_upper_bound; - const buffer = cb.bytes.allocatedSlice()[end_reserved..]; - const end_index = try file.readAll(buffer); - const encoded_bytes = buffer[0..end_index]; - - const begin_marker = "-----BEGIN CERTIFICATE-----"; - const end_marker = "-----END CERTIFICATE-----"; - - var start_index: usize = 0; - while (mem.indexOfPos(u8, encoded_bytes, start_index, begin_marker)) |begin_marker_start| { - const cert_start = begin_marker_start + begin_marker.len; - const cert_end = mem.indexOfPos(u8, encoded_bytes, cert_start, end_marker) orelse - return error.MissingEndCertificateMarker; - start_index = cert_end + end_marker.len; - const encoded_cert = mem.trim(u8, encoded_bytes[cert_start..cert_end], " \t\r\n"); - const decoded_start = @intCast(u32, cb.bytes.items.len); - const dest_buf = cb.bytes.allocatedSlice()[decoded_start..]; - cb.bytes.items.len += try base64.decode(dest_buf, encoded_cert); - const k = try cb.key(decoded_start); - const gop = try cb.map.getOrPutContext(gpa, k, .{ .cb = cb }); - if (gop.found_existing) { - cb.bytes.items.len = decoded_start; - } else { - gop.value_ptr.* = decoded_start; - } - } -} - -pub fn key(cb: CertificateBundle, bytes_index: u32) !Key { - const bytes = cb.bytes.items; - const certificate = try Der.parseElement(bytes, bytes_index); - const tbs_certificate = try Der.parseElement(bytes, certificate.start); - const version = try Der.parseElement(bytes, tbs_certificate.start); - try checkVersion(bytes, version); - const serial_number = try Der.parseElement(bytes, version.end); - const signature = try Der.parseElement(bytes, serial_number.end); - const issuer = try Der.parseElement(bytes, signature.end); - const validity = try Der.parseElement(bytes, issuer.end); - const subject = try Der.parseElement(bytes, validity.end); - - return .{ - .subject_start = subject.start, - .subject_end = subject.end, - }; -} - -pub const Certificate = struct { - buffer: []const u8, - index: u32, - - pub const Algorithm = enum { - sha1WithRSAEncryption, - sha224WithRSAEncryption, - sha256WithRSAEncryption, - sha384WithRSAEncryption, - sha512WithRSAEncryption, - - pub const map = std.ComptimeStringMap(Algorithm, .{ - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x05 }, .sha1WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0B }, .sha256WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C }, .sha384WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D }, .sha512WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0E }, .sha224WithRSAEncryption }, - }); - - pub fn Hash(comptime algorithm: Algorithm) type { - return switch (algorithm) { - .sha1WithRSAEncryption => crypto.hash.Sha1, - .sha224WithRSAEncryption => crypto.hash.sha2.Sha224, - .sha256WithRSAEncryption => crypto.hash.sha2.Sha256, - .sha384WithRSAEncryption => crypto.hash.sha2.Sha384, - .sha512WithRSAEncryption => crypto.hash.sha2.Sha512, - }; - } - }; - - pub const AlgorithmCategory = enum { - rsaEncryption, - X9_62_id_ecPublicKey, - - pub const map = std.ComptimeStringMap(AlgorithmCategory, .{ - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01 }, .rsaEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01 }, .X9_62_id_ecPublicKey }, - }); - }; - - pub const Attribute = enum 
{ - commonName, - serialNumber, - countryName, - localityName, - stateOrProvinceName, - organizationName, - organizationalUnitName, - organizationIdentifier, - - pub const map = std.ComptimeStringMap(Attribute, .{ - .{ &[_]u8{ 0x55, 0x04, 0x03 }, .commonName }, - .{ &[_]u8{ 0x55, 0x04, 0x05 }, .serialNumber }, - .{ &[_]u8{ 0x55, 0x04, 0x06 }, .countryName }, - .{ &[_]u8{ 0x55, 0x04, 0x07 }, .localityName }, - .{ &[_]u8{ 0x55, 0x04, 0x08 }, .stateOrProvinceName }, - .{ &[_]u8{ 0x55, 0x04, 0x0A }, .organizationName }, - .{ &[_]u8{ 0x55, 0x04, 0x0B }, .organizationalUnitName }, - .{ &[_]u8{ 0x55, 0x04, 0x61 }, .organizationIdentifier }, - }); - }; - - pub const Parsed = struct { - certificate: Certificate, - issuer: []const u8, - subject: []const u8, - common_name: []const u8, - signature: []const u8, - signature_algorithm: Algorithm, - message: []const u8, - pub_key_algo: AlgorithmCategory, - pub_key: []const u8, - - pub fn verify(subject: Parsed, issuer: Parsed) !void { - // Check that the subject's issuer name matches the issuer's - // subject name. - if (!mem.eql(u8, subject.issuer, issuer.subject)) { - return error.CertificateIssuerMismatch; - } - - // TODO check the time validity for the subject - // TODO check the time validity for the issuer - - switch (subject.signature_algorithm) { - inline .sha1WithRSAEncryption, - .sha224WithRSAEncryption, - .sha256WithRSAEncryption, - .sha384WithRSAEncryption, - .sha512WithRSAEncryption, - => |algorithm| return verifyRsa( - algorithm.Hash(), - subject.message, - subject.signature, - issuer.pub_key_algo, - issuer.pub_key, - ), - } - } - }; - - pub fn parse(cert: Certificate) !Parsed { - const cert_bytes = cert.buffer; - const certificate = try Der.parseElement(cert_bytes, cert.index); - const tbs_certificate = try Der.parseElement(cert_bytes, certificate.start); - const version = try Der.parseElement(cert_bytes, tbs_certificate.start); - try checkVersion(cert_bytes, version); - const serial_number = try Der.parseElement(cert_bytes, version.end); - // RFC 5280, section 4.1.2.3: - // "This field MUST contain the same algorithm identifier as - // the signatureAlgorithm field in the sequence Certificate." 
- const tbs_signature = try Der.parseElement(cert_bytes, serial_number.end); - const issuer = try Der.parseElement(cert_bytes, tbs_signature.end); - const validity = try Der.parseElement(cert_bytes, issuer.end); - const subject = try Der.parseElement(cert_bytes, validity.end); - - const pub_key_info = try Der.parseElement(cert_bytes, subject.end); - const pub_key_signature_algorithm = try Der.parseElement(cert_bytes, pub_key_info.start); - const pub_key_algo_elem = try Der.parseElement(cert_bytes, pub_key_signature_algorithm.start); - const pub_key_algo = try parseAlgorithmCategory(cert_bytes, pub_key_algo_elem); - const pub_key_elem = try Der.parseElement(cert_bytes, pub_key_signature_algorithm.end); - const pub_key = try parseBitString(cert, pub_key_elem); - - const rdn = try Der.parseElement(cert_bytes, subject.start); - const atav = try Der.parseElement(cert_bytes, rdn.start); - - var common_name: []const u8 = &.{}; - var atav_i = atav.start; - while (atav_i < atav.end) { - const ty_elem = try Der.parseElement(cert_bytes, atav_i); - const ty = try parseAttribute(cert_bytes, ty_elem); - const val = try Der.parseElement(cert_bytes, ty_elem.end); - switch (ty) { - .commonName => common_name = cert.contents(val), - else => {}, - } - atav_i = val.end; - } - - const sig_algo = try Der.parseElement(cert_bytes, tbs_certificate.end); - const algo_elem = try Der.parseElement(cert_bytes, sig_algo.start); - const signature_algorithm = try parseAlgorithm(cert_bytes, algo_elem); - const sig_elem = try Der.parseElement(cert_bytes, sig_algo.end); - const signature = try parseBitString(cert, sig_elem); - - return .{ - .certificate = cert, - .common_name = common_name, - .issuer = cert.contents(issuer), - .subject = cert.contents(subject), - .signature = signature, - .signature_algorithm = signature_algorithm, - .message = cert_bytes[certificate.start..tbs_certificate.end], - .pub_key_algo = pub_key_algo, - .pub_key = pub_key, - }; - } - - pub fn verify(subject: Certificate, issuer: Certificate) !void { - const parsed_subject = try subject.parse(); - const parsed_issuer = try issuer.parse(); - return parsed_subject.verify(parsed_issuer); - } - - pub fn contents(cert: Certificate, elem: Der.Element) []const u8 { - return cert.buffer[elem.start..elem.end]; - } - - pub fn parseBitString(cert: Certificate, elem: Der.Element) ![]const u8 { - if (elem.identifier.tag != .bitstring) return error.CertificateFieldHasWrongDataType; - if (cert.buffer[elem.start] != 0) return error.CertificateHasInvalidBitString; - return cert.buffer[elem.start + 1 .. 
elem.end]; - } - - pub fn parseAlgorithm(bytes: []const u8, element: Der.Element) !Algorithm { - if (element.identifier.tag != .object_identifier) - return error.CertificateFieldHasWrongDataType; - return Algorithm.map.get(bytes[element.start..element.end]) orelse - return error.CertificateHasUnrecognizedAlgorithm; - } - - pub fn parseAlgorithmCategory(bytes: []const u8, element: Der.Element) !AlgorithmCategory { - if (element.identifier.tag != .object_identifier) - return error.CertificateFieldHasWrongDataType; - return AlgorithmCategory.map.get(bytes[element.start..element.end]) orelse { - std.debug.print("unrecognized algorithm category: {}\n", .{std.fmt.fmtSliceHexLower(bytes[element.start..element.end])}); - return error.CertificateHasUnrecognizedAlgorithmCategory; - }; - } - - pub fn parseAttribute(bytes: []const u8, element: Der.Element) !Attribute { - if (element.identifier.tag != .object_identifier) - return error.CertificateFieldHasWrongDataType; - return Attribute.map.get(bytes[element.start..element.end]) orelse - return error.CertificateHasUnrecognizedAlgorithm; - } - - fn verifyRsa(comptime Hash: type, message: []const u8, sig: []const u8, pub_key_algo: AlgorithmCategory, pub_key: []const u8) !void { - if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch; - const pub_key_seq = try Der.parseElement(pub_key, 0); - if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; - const modulus_elem = try Der.parseElement(pub_key, pub_key_seq.start); - if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; - const exponent_elem = try Der.parseElement(pub_key, modulus_elem.end); - if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; - // Skip over meaningless zeroes in the modulus. 
- const modulus_raw = pub_key[modulus_elem.start..modulus_elem.end]; - const modulus_offset = for (modulus_raw) |byte, i| { - if (byte != 0) break i; - } else modulus_raw.len; - const modulus = modulus_raw[modulus_offset..]; - const exponent = pub_key[exponent_elem.start..exponent_elem.end]; - if (exponent.len > modulus.len) return error.CertificatePublicKeyInvalid; - if (sig.len != modulus.len) return error.CertificateSignatureInvalidLength; - - const hash_der = switch (Hash) { - crypto.hash.Sha1 => [_]u8{ - 0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, - 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14, - }, - crypto.hash.sha2.Sha224 => [_]u8{ - 0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, - 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05, - 0x00, 0x04, 0x1c, - }, - crypto.hash.sha2.Sha256 => [_]u8{ - 0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, - 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, - 0x00, 0x04, 0x20, - }, - crypto.hash.sha2.Sha384 => [_]u8{ - 0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, - 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, - 0x00, 0x04, 0x30, - }, - crypto.hash.sha2.Sha512 => [_]u8{ - 0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, - 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, - 0x00, 0x04, 0x40, - }, - else => @compileError("unreachable"), - }; - - var msg_hashed: [Hash.digest_length]u8 = undefined; - Hash.hash(message, &msg_hashed, .{}); - - switch (modulus.len) { - inline 128, 256, 512 => |modulus_len| { - const ps_len = modulus_len - (hash_der.len + msg_hashed.len) - 3; - const em: [modulus_len]u8 = - [2]u8{ 0, 1 } ++ - ([1]u8{0xff} ** ps_len) ++ - [1]u8{0} ++ - hash_der ++ - msg_hashed; - - const public_key = try rsa.PublicKey.fromBytes(exponent, modulus, rsa.poop); - const em_dec = try rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key, rsa.poop); - - if (!mem.eql(u8, &em, &em_dec)) { - try std.testing.expectEqualSlices(u8, &em, &em_dec); - return error.CertificateSignatureInvalid; - } - }, - else => { - return error.CertificateSignatureUnsupportedBitCount; - }, - } - } -}; - -fn checkVersion(bytes: []const u8, version: Der.Element) !void { - if (@bitCast(u8, version.identifier) != 0xa0 or - !mem.eql(u8, bytes[version.start..version.end], "\x02\x01\x02")) - { - return error.UnsupportedCertificateVersion; - } -} - -const builtin = @import("builtin"); -const std = @import("../std.zig"); -const fs = std.fs; -const mem = std.mem; -const crypto = std.crypto; -const Allocator = std.mem.Allocator; -const Der = std.crypto.Der; -const CertificateBundle = @This(); - -const base64 = std.base64.standard.decoderWithIgnore(" \t\r\n"); - -const MapContext = struct { - cb: *const CertificateBundle, - - pub fn hash(ctx: MapContext, k: Key) u64 { - return std.hash_map.hashString(ctx.cb.bytes.items[k.subject_start..k.subject_end]); - } - - pub fn eql(ctx: MapContext, a: Key, b: Key) bool { - const bytes = ctx.cb.bytes.items; - return mem.eql( - u8, - bytes[a.subject_start..a.subject_end], - bytes[b.subject_start..b.subject_end], - ); - } -}; - -test "scan for OS-provided certificates" { - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - var bundle: CertificateBundle = .{}; - defer bundle.deinit(std.testing.allocator); - - try bundle.rescan(std.testing.allocator); -} - -/// TODO: replace this with Frank's upcoming RSA implementation. the verify -/// function won't have the possibility of failure - it will either identify a -/// valid signature or an invalid signature. 
-/// This code is borrowed from https://github.com/shiguredo/tls13-zig -/// which is licensed under the Apache License Version 2.0, January 2004 -/// http://www.apache.org/licenses/ -/// The code has been modified. -const rsa = struct { - const BigInt = std.math.big.int.Managed; - - const PublicKey = struct { - n: BigInt, - e: BigInt, - - pub fn deinit(self: *PublicKey) void { - self.n.deinit(); - self.e.deinit(); - } - - pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8, allocator: std.mem.Allocator) !PublicKey { - var _n = try BigInt.init(allocator); - errdefer _n.deinit(); - try setBytes(&_n, modulus_bytes, allocator); - - var _e = try BigInt.init(allocator); - errdefer _e.deinit(); - try setBytes(&_e, pub_bytes, allocator); - - return .{ - .n = _n, - .e = _e, - }; - } - }; - - fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey, allocator: std.mem.Allocator) ![modulus_len]u8 { - var m = try BigInt.init(allocator); - defer m.deinit(); - - try setBytes(&m, &msg, allocator); - - if (m.order(public_key.n) != .lt) { - return error.MessageTooLong; - } - - var e = try BigInt.init(allocator); - defer e.deinit(); - - try pow_montgomery(&e, &m, &public_key.e, &public_key.n, allocator); - - var res: [modulus_len]u8 = undefined; - - try toBytes(&res, &e, allocator); - - return res; - } - - fn setBytes(r: *BigInt, bytes: []const u8, allcator: std.mem.Allocator) !void { - try r.set(0); - var tmp = try BigInt.init(allcator); - defer tmp.deinit(); - for (bytes) |b| { - try r.shiftLeft(r, 8); - try tmp.set(b); - try r.add(r, &tmp); - } - } - - fn pow_montgomery(r: *BigInt, a: *const BigInt, x: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void { - var bin_raw: [512]u8 = undefined; - try toBytes(&bin_raw, x, allocator); - - var i: usize = 0; - while (bin_raw[i] == 0x00) : (i += 1) {} - const bin = bin_raw[i..]; - - try r.set(1); - var r1 = try BigInt.init(allocator); - defer r1.deinit(); - try BigInt.copy(&r1, a.toConst()); - i = 0; - while (i < bin.len * 8) : (i += 1) { - if (((bin[i / 8] >> @intCast(u3, (7 - (i % 8)))) & 0x1) == 0) { - try BigInt.mul(&r1, r, &r1); - try mod(&r1, &r1, n, allocator); - try BigInt.sqr(r, r); - try mod(r, r, n, allocator); - } else { - try BigInt.mul(r, r, &r1); - try mod(r, r, n, allocator); - try BigInt.sqr(&r1, &r1); - try mod(&r1, &r1, n, allocator); - } - } - } - - fn toBytes(out: []u8, a: *const BigInt, allocator: std.mem.Allocator) !void { - const Error = error{ - BufferTooSmall, - }; - - var mask = try BigInt.initSet(allocator, 0xFF); - defer mask.deinit(); - var tmp = try BigInt.init(allocator); - defer tmp.deinit(); - - var a_copy = try BigInt.init(allocator); - defer a_copy.deinit(); - try a_copy.copy(a.toConst()); - - // Encoding into big-endian bytes - var i: usize = 0; - while (i < out.len) : (i += 1) { - try tmp.bitAnd(&a_copy, &mask); - const b = try tmp.to(u8); - out[out.len - i - 1] = b; - try a_copy.shiftRight(&a_copy, 8); - } - - if (!a_copy.eqZero()) { - return Error.BufferTooSmall; - } - } - - fn mod(rem: *BigInt, a: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void { - var q = try BigInt.init(allocator); - defer q.deinit(); - - try BigInt.divFloor(&q, rem, a, n); - } - - // TODO: flush the toilet - const poop = std.heap.page_allocator; -}; diff --git a/lib/std/crypto/Der.zig b/lib/std/crypto/der.zig similarity index 92% rename from lib/std/crypto/Der.zig rename to lib/std/crypto/der.zig index 7b183d5c34..82f75421ea 100644 --- a/lib/std/crypto/Der.zig +++ 
b/lib/std/crypto/der.zig @@ -99,8 +99,14 @@ pub const Oid = enum { pub const Element = struct { identifier: Identifier, - start: u32, - end: u32, + slice: Slice, + + pub const Slice = struct { + start: u32, + end: u32, + + pub const empty: Slice = .{ .start = 0, .end = 0 }; + }; }; pub const ParseElementError = error{CertificateHasFieldWithInvalidLength}; @@ -114,8 +120,10 @@ pub fn parseElement(bytes: []const u8, index: u32) ParseElementError!Element { if ((size_byte >> 7) == 0) { return .{ .identifier = identifier, - .start = i, - .end = i + size_byte, + .slice = .{ + .start = i, + .end = i + size_byte, + }, }; } @@ -132,8 +140,10 @@ pub fn parseElement(bytes: []const u8, index: u32) ParseElementError!Element { return .{ .identifier = identifier, - .start = i, - .end = i + long_form_size, + .slice = .{ + .start = i, + .end = i + long_form_size, + }, }; } @@ -145,9 +155,9 @@ pub const ParseObjectIdError = error{ pub fn parseObjectId(bytes: []const u8, element: Element) ParseObjectIdError!Oid { if (element.identifier.tag != .object_identifier) return error.CertificateFieldHasWrongDataType; - return Oid.map.get(bytes[element.start..element.end]) orelse + return Oid.map.get(bytes[element.slice.start..element.slice.end]) orelse return error.CertificateHasUnrecognizedObjectId; } const std = @import("../std.zig"); -const Der = @This(); +const der = @This(); diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index eb1b1b80bc..c8fd41f83a 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -1,6 +1,5 @@ const std = @import("../../std.zig"); const tls = std.crypto.tls; -const Der = std.crypto.Der; const Client = @This(); const net = std.net; const mem = std.mem; @@ -18,7 +17,7 @@ const int2 = tls.int2; const int3 = tls.int3; const array = tls.array; const enum_array = tls.enum_array; -const Certificate = crypto.CertificateBundle.Certificate; +const Certificate = crypto.Certificate; application_cipher: ApplicationCipher, read_seq: u64, @@ -30,7 +29,7 @@ partially_read_len: u15, eof: bool, /// `host` is only borrowed during this function call. -pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []const u8) !Client { +pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) !Client { const host_len = @intCast(u16, host.len); var random_buffer: [128]u8 = undefined; @@ -298,9 +297,19 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con break :i end; }; + // This is used for two purposes: + // * Detect whether a certificate is the first one presented, in which case + // we need to verify the host name. + // * Flip back and forth between the two cleartext buffers in order to keep + // the previous certificate in memory so that it can be verified by the + // next one. + var cert_index: usize = 0; var read_seq: u64 = 0; - var validated_cert = false; - var is_subsequent_cert = false; + var prev_cert: Certificate.Parsed = undefined; + // Set to true once a trust chain has been established from the first + // certificate to a root CA. 
+ var cert_verification_done = false; + var cleartext_bufs: [2][8000]u8 = undefined; while (true) { const end_hdr = i + 5; @@ -328,7 +337,8 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage; }, .application_data => { - var cleartext_buf: [8000]u8 = undefined; + const cleartext_buf = &cleartext_bufs[cert_index % 2]; + const cleartext = switch (handshake_cipher) { inline else => |*p| c: { const P = @TypeOf(p.*); @@ -393,7 +403,7 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con switch (handshake_cipher) { inline else => |*p| p.transcript_hash.update(wrapped_handshake), } - if (validated_cert) break :cert; + if (cert_verification_done) break :cert; var hs_i: u32 = 0; const cert_req_ctx_len = handshake[hs_i]; hs_i += 1; @@ -411,12 +421,22 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con .index = hs_i, }; const subject = try subject_cert.parse(); - if (!is_subsequent_cert) { - is_subsequent_cert = true; - if (mem.eql(u8, subject.common_name, host)) { + if (cert_index > 0) { + if (prev_cert.verify(subject)) |_| { + std.debug.print("previous certificate verified\n", .{}); + } else |err| { + std.debug.print("unable to validate previous cert: {s}\n", .{ + @errorName(err), + }); + } + } else { + // Verify the host on the first certificate. + const common_name = subject.commonName(); + if (mem.eql(u8, common_name, host)) { std.debug.print("exact host match\n", .{}); - } else if (mem.startsWith(u8, subject.common_name, "*.") and - mem.eql(u8, subject.common_name[2..], host)) + } else if (mem.startsWith(u8, common_name, "*.") and + (mem.endsWith(u8, host, common_name[1..]) or + mem.eql(u8, common_name[2..], host))) { std.debug.print("wildcard host match\n", .{}); } else { @@ -427,17 +447,17 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con if (ca_bundle.verify(subject)) |_| { std.debug.print("found a root CA cert matching issuer. verification success!\n", .{}); - validated_cert = true; + cert_verification_done = true; break :cert; } else |err| { std.debug.print("unable to validate cert against system root CAs: {s}\n", .{ @errorName(err), }); - // TODO handle a certificate - // signing chain that ends in a - // root-validated one. 
} + prev_cert = subject; + cert_index += 1; + hs_i = end_cert; const total_ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]); hs_i += 2; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 58686ed2e5..1d10870312 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -7,7 +7,7 @@ const Client = @This(); allocator: std.mem.Allocator, headers: std.ArrayListUnmanaged(u8) = .{}, active_requests: usize = 0, -ca_bundle: std.crypto.CertificateBundle = .{}, +ca_bundle: std.crypto.Certificate.Bundle = .{}, pub const Request = struct { client: *Client, From 16f936b4202d352a6a8cf91a265fdd4bc64dde5d Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 21 Dec 2022 18:54:17 -0700 Subject: [PATCH 29/59] std.crypto.tls: handle the certificate_verify message --- lib/std/crypto/Certificate.zig | 28 +++++-- lib/std/crypto/tls/Client.zig | 138 ++++++++++++++++++++++++++------- 2 files changed, 131 insertions(+), 35 deletions(-) diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index 3d50e43839..0cdafa7ade 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -9,6 +9,10 @@ pub const Algorithm = enum { sha256WithRSAEncryption, sha384WithRSAEncryption, sha512WithRSAEncryption, + ecdsa_with_SHA224, + ecdsa_with_SHA256, + ecdsa_with_SHA384, + ecdsa_with_SHA512, pub const map = std.ComptimeStringMap(Algorithm, .{ .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x05 }, .sha1WithRSAEncryption }, @@ -16,15 +20,19 @@ pub const Algorithm = enum { .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C }, .sha384WithRSAEncryption }, .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D }, .sha512WithRSAEncryption }, .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0E }, .sha224WithRSAEncryption }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x01 }, .ecdsa_with_SHA224 }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x02 }, .ecdsa_with_SHA256 }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x03 }, .ecdsa_with_SHA384 }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x04 }, .ecdsa_with_SHA512 }, }); pub fn Hash(comptime algorithm: Algorithm) type { return switch (algorithm) { .sha1WithRSAEncryption => crypto.hash.Sha1, - .sha224WithRSAEncryption => crypto.hash.sha2.Sha224, - .sha256WithRSAEncryption => crypto.hash.sha2.Sha256, - .sha384WithRSAEncryption => crypto.hash.sha2.Sha384, - .sha512WithRSAEncryption => crypto.hash.sha2.Sha512, + .ecdsa_with_SHA224, .sha224WithRSAEncryption => crypto.hash.sha2.Sha224, + .ecdsa_with_SHA256, .sha256WithRSAEncryption => crypto.hash.sha2.Sha256, + .ecdsa_with_SHA384, .sha384WithRSAEncryption => crypto.hash.sha2.Sha384, + .ecdsa_with_SHA512, .sha512WithRSAEncryption => crypto.hash.sha2.Sha512, }; } }; @@ -125,6 +133,13 @@ pub const Parsed = struct { parsed_issuer.pub_key_algo, parsed_issuer.pubKey(), ), + .ecdsa_with_SHA224, + .ecdsa_with_SHA256, + .ecdsa_with_SHA384, + .ecdsa_with_SHA512, + => { + return error.CertificateSignatureAlgorithmUnsupported; + }, } } }; @@ -205,8 +220,11 @@ pub fn parseBitString(cert: Certificate, elem: der.Element) !der.Element.Slice { pub fn parseAlgorithm(bytes: []const u8, element: der.Element) !Algorithm { if (element.identifier.tag != .object_identifier) return error.CertificateFieldHasWrongDataType; - return Algorithm.map.get(bytes[element.slice.start..element.slice.end]) orelse + const oid_bytes = bytes[element.slice.start..element.slice.end]; + return 
Algorithm.map.get(oid_bytes) orelse { + //std.debug.print("oid bytes: {}\n", .{std.fmt.fmtSliceHexLower(oid_bytes)}); return error.CertificateHasUnrecognizedAlgorithm; + }; } pub fn parseAlgorithmCategory(bytes: []const u8, element: der.Element) !AlgorithmCategory { diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index c8fd41f83a..bf6d0084f5 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -308,8 +308,23 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) var prev_cert: Certificate.Parsed = undefined; // Set to true once a trust chain has been established from the first // certificate to a root CA. - var cert_verification_done = false; + const HandshakeState = enum { + /// In this state we expect only an encrypted_extensions message. + encrypted_extensions, + /// In this state we expect certificate messages. + certificate, + /// In this state we expect certificate or certificate_verify messages. + /// certificate messages are ignored since the trust chain is already + /// established. + trust_chain_established, + /// In this state, we expect only the finished message. + finished, + }; + var handshake_state: HandshakeState = .encrypted_extensions; var cleartext_bufs: [2][8000]u8 = undefined; + var main_cert_pub_key_algo: Certificate.AlgorithmCategory = undefined; + var main_cert_pub_key_buf: [128]u8 = undefined; + var main_cert_pub_key_len: u8 = undefined; while (true) { const end_hdr = i + 5; @@ -376,6 +391,8 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) const handshake = cleartext[ct_i..next_handshake_i]; switch (handshake_type) { @enumToInt(HandshakeType.encrypted_extensions) => { + if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; + handshake_state = .certificate; switch (handshake_cipher) { inline else => |*p| p.transcript_hash.update(wrapped_handshake), } @@ -403,7 +420,11 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) switch (handshake_cipher) { inline else => |*p| p.transcript_hash.update(wrapped_handshake), } - if (cert_verification_done) break :cert; + switch (handshake_state) { + .certificate => {}, + .trust_chain_established => break :cert, + else => return error.TlsUnexpectedMessage, + } var hs_i: u32 = 0; const cert_req_ctx_len = handshake[hs_i]; hs_i += 1; @@ -421,38 +442,41 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) .index = hs_i, }; const subject = try subject_cert.parse(); - if (cert_index > 0) { - if (prev_cert.verify(subject)) |_| { - std.debug.print("previous certificate verified\n", .{}); - } else |err| { + if (cert_index == 0) { + // Verify the host on the first certificate. + if (!hostMatchesCommonName(host, subject.commonName())) { + return error.TlsCertificateHostMismatch; + } + + // Keep track of the public key for + // the certificate_verify message + // later. + main_cert_pub_key_algo = subject.pub_key_algo; + const pub_key = subject.pubKey(); + if (pub_key.len > main_cert_pub_key_buf.len) + return error.CertificatePublicKeyInvalid; + @memcpy(&main_cert_pub_key_buf, pub_key.ptr, pub_key.len); + main_cert_pub_key_len = @intCast(@TypeOf(main_cert_pub_key_len), pub_key.len); + } else { + prev_cert.verify(subject) catch |err| { std.debug.print("unable to validate previous cert: {s}\n", .{ @errorName(err), }); - } - } else { - // Verify the host on the first certificate. 
- const common_name = subject.commonName(); - if (mem.eql(u8, common_name, host)) { - std.debug.print("exact host match\n", .{}); - } else if (mem.startsWith(u8, common_name, "*.") and - (mem.endsWith(u8, host, common_name[1..]) or - mem.eql(u8, common_name[2..], host))) - { - std.debug.print("wildcard host match\n", .{}); - } else { - std.debug.print("host does not match\n", .{}); - return error.TlsCertificateInvalidHost; - } + return err; + }; } if (ca_bundle.verify(subject)) |_| { - std.debug.print("found a root CA cert matching issuer. verification success!\n", .{}); - cert_verification_done = true; + handshake_state = .trust_chain_established; break :cert; - } else |err| { - std.debug.print("unable to validate cert against system root CAs: {s}\n", .{ - @errorName(err), - }); + } else |err| switch (err) { + error.IssuerNotFound => {}, + else => |e| { + std.debug.print("unable to validate cert against system root CAs: {s}\n", .{ + @errorName(e), + }); + return e; + }, } prev_cert = subject; @@ -465,12 +489,46 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) } }, @enumToInt(HandshakeType.certificate_verify) => { - switch (handshake_cipher) { - inline else => |*p| p.transcript_hash.update(wrapped_handshake), + switch (handshake_state) { + .trust_chain_established => handshake_state = .finished, + .certificate => return error.TlsCertificateNotVerified, + else => return error.TlsUnexpectedMessage, + } + + const algorithm = @intToEnum(tls.SignatureScheme, mem.readIntBig(u16, handshake[0..2])); + const sig_len = mem.readIntBig(u16, handshake[2..4]); + if (4 + sig_len > handshake.len) return error.TlsBadLength; + const encoded_sig = handshake[4..][0..sig_len]; + const max_digest_len = 64; + var verify_buffer = + ([1]u8{0x20} ** 64) ++ + "TLS 1.3, server CertificateVerify\x00".* ++ + ([1]u8{undefined} ** max_digest_len); + + const verify_bytes = switch (handshake_cipher) { + inline else => |*p| v: { + const transcript_digest = p.transcript_hash.peek(); + verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest; + p.transcript_hash.update(wrapped_handshake); + break :v verify_buffer[0 .. verify_buffer.len - max_digest_len + transcript_digest.len]; + }, + }; + const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len]; + + switch (algorithm) { + .ecdsa_secp256r1_sha256 => { + if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey) + return error.TlsBadSignatureAlgorithm; + const P256 = std.crypto.sign.ecdsa.EcdsaP256Sha256; + const sig = try P256.Signature.fromDer(encoded_sig); + const key = try P256.PublicKey.fromSec1(main_cert_pub_key); + try sig.verify(verify_bytes, key); + }, + else => return error.TlsBadSignatureAlgorithm, } - std.debug.print("ignoring certificate_verify\n", .{}); }, @enumToInt(HandshakeType.finished) => { + if (handshake_state != .finished) return error.TlsUnexpectedMessage; // This message is to trick buggy proxies into behaving correctly. 
const client_change_cipher_spec_msg = [_]u8{ @enumToInt(ContentType.change_cipher_spec), @@ -762,6 +820,26 @@ fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize { return out; } +fn hostMatchesCommonName(host: []const u8, common_name: []const u8) bool { + if (mem.eql(u8, common_name, host)) { + return true; // exact match + } + + if (mem.startsWith(u8, common_name, "*.")) { + // wildcard certificate, matches any subdomain + if (mem.endsWith(u8, host, common_name[1..])) { + // The host has a subdomain, but the important part matches. + return true; + } + if (mem.eql(u8, common_name[2..], host)) { + // The host has no subdomain and matches exactly. + return true; + } + } + + return false; +} + const builtin = @import("builtin"); const native_endian = builtin.cpu.arch.endian(); From 862ecf23442b7b399f07400c8997c6481f329853 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 21 Dec 2022 19:01:21 -0700 Subject: [PATCH 30/59] std.crypto.tls.Client: handle extra data after handshake --- lib/std/crypto/tls/Client.zig | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index bf6d0084f5..f46b233fb3 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -583,15 +583,16 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) }); }, }; - std.debug.print("remaining bytes: {d}\n", .{len - end}); - return .{ + var client: Client = .{ .application_cipher = app_cipher, .read_seq = 0, .write_seq = 0, .partially_read_buffer = undefined, - .partially_read_len = 0, + .partially_read_len = @intCast(u15, len - end), .eof = false, }; + mem.copy(u8, &client.partially_read_buffer, handshake_buf[len..end]); + return client; }, else => { return error.TlsUnexpectedMessage; From 7cb535d4b54a4e5627edc6b558d1f31b41651328 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 22 Dec 2022 19:56:43 -0700 Subject: [PATCH 31/59] std.crypto.tls.Certificate: verify time validity When scanning the file system for root certificates, expired certificates are skipped and therefore not used for verification in TLS sessions. There is only this one check, however, so a long-running server will need to periodically rescan for a new Certificate.Bundle and strategically start using it for new sessions. In this commit I made the judgement call that applications would like to opt-in to root certificate rescanning at a point in time that makes sense for that application, as opposed to having the system clock potentially start causing connections to fail. Certificate verification checks the subject only, as opposed to both the subject and the issuer. The idea is that the trust chain analysis will always check the subject, leading to every certificate in the chain's validity being checked exactly once, with the root certificate's validity checked upon scanning. Furthermore, this commit adjusts the scanning logic to fully parse certificates, even though only the subject is technically needed. This allows relying on parsing to succeed later on. 
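As a minimal sketch of the opt-in rescanning described above (the helper name
and allocator parameter are illustrative and not part of this patch; only the
Bundle.rescan function added in this series is assumed), a long-running server
might refresh its roots from a timer or an administrative hook rather than
letting the clock silently invalidate live sessions:

    const std = @import("std");

    /// Re-read the OS root certificate locations at a moment the application
    /// chooses. rescan() empties the bundle first, so previously loaded
    /// (possibly expired) roots are dropped for sessions started afterwards.
    fn refreshRootCerts(gpa: std.mem.Allocator, bundle: *std.crypto.Certificate.Bundle) !void {
        try bundle.rescan(gpa);
    }

TLS sessions initiated after such a call then verify against the refreshed
bundle, while existing sessions are left untouched.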
--- lib/std/crypto/Certificate.zig | 164 +++++++++++++++++++++++++- lib/std/crypto/Certificate/Bundle.zig | 45 ++++--- lib/std/crypto/der.zig | 2 + 3 files changed, 188 insertions(+), 23 deletions(-) diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index 0cdafa7ade..ba36958ae8 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -79,6 +79,12 @@ pub const Parsed = struct { pub_key_algo: AlgorithmCategory, pub_key_slice: Slice, message_slice: Slice, + validity: Validity, + + pub const Validity = struct { + not_before: u64, + not_after: u64, + }; pub const Slice = der.Element.Slice; @@ -110,6 +116,8 @@ pub const Parsed = struct { return p.slice(p.message_slice); } + /// This function checks the time validity for the subject only. Checking + /// the issuer's time validity is out of scope. pub fn verify(parsed_subject: Parsed, parsed_issuer: Parsed) !void { // Check that the subject's issuer name matches the issuer's // subject name. @@ -117,8 +125,11 @@ pub const Parsed = struct { return error.CertificateIssuerMismatch; } - // TODO check the time validity for the subject - // TODO check the time validity for the issuer + const now_sec = std.time.timestamp(); + if (now_sec < parsed_subject.validity.not_before) + return error.CertificateNotYetValid; + if (now_sec > parsed_subject.validity.not_after) + return error.CertificateExpired; switch (parsed_subject.signature_algorithm) { inline .sha1WithRSAEncryption, @@ -157,6 +168,10 @@ pub fn parse(cert: Certificate) !Parsed { const tbs_signature = try der.parseElement(cert_bytes, serial_number.slice.end); const issuer = try der.parseElement(cert_bytes, tbs_signature.slice.end); const validity = try der.parseElement(cert_bytes, issuer.slice.end); + const not_before = try der.parseElement(cert_bytes, validity.slice.start); + const not_before_utc = try parseTime(cert, not_before); + const not_after = try der.parseElement(cert_bytes, not_before.slice.end); + const not_after_utc = try parseTime(cert, not_after); const subject = try der.parseElement(cert_bytes, validity.slice.end); const pub_key_info = try der.parseElement(cert_bytes, subject.slice.end); @@ -198,6 +213,10 @@ pub fn parse(cert: Certificate) !Parsed { .message_slice = .{ .start = certificate.slice.start, .end = tbs_certificate.slice.end }, .pub_key_algo = pub_key_algo, .pub_key_slice = pub_key, + .validity = .{ + .not_before = not_before_utc, + .not_after = not_after_utc, + }, }; } @@ -208,7 +227,7 @@ pub fn verify(subject: Certificate, issuer: Certificate) !void { } pub fn contents(cert: Certificate, elem: der.Element) []const u8 { - return cert.buffer[elem.start..elem.end]; + return cert.buffer[elem.slice.start..elem.slice.end]; } pub fn parseBitString(cert: Certificate, elem: der.Element) !der.Element.Slice { @@ -217,6 +236,133 @@ pub fn parseBitString(cert: Certificate, elem: der.Element) !der.Element.Slice { return .{ .start = elem.slice.start + 1, .end = elem.slice.end }; } +/// Returns number of seconds since epoch. 
+pub fn parseTime(cert: Certificate, elem: der.Element) !u64 { + const bytes = cert.contents(elem); + switch (elem.identifier.tag) { + .utc_time => { + // Example: "YYMMDD000000Z" + if (bytes.len != 13) + return error.CertificateTimeInvalid; + if (bytes[12] != 'Z') + return error.CertificateTimeInvalid; + + return Date.toSeconds(.{ + .year = @as(u16, 2000) + try parseTimeDigits(bytes[0..2].*, 0, 99), + .month = try parseTimeDigits(bytes[2..4].*, 1, 12), + .day = try parseTimeDigits(bytes[4..6].*, 1, 31), + .hour = try parseTimeDigits(bytes[6..8].*, 0, 23), + .minute = try parseTimeDigits(bytes[8..10].*, 0, 59), + .second = try parseTimeDigits(bytes[10..12].*, 0, 59), + }); + }, + .generalized_time => { + // Examples: + // "19920521000000Z" + // "19920622123421Z" + // "19920722132100.3Z" + if (bytes.len < 15) + return error.CertificateTimeInvalid; + return Date.toSeconds(.{ + .year = try parseYear4(bytes[0..4]), + .month = try parseTimeDigits(bytes[4..6].*, 1, 12), + .day = try parseTimeDigits(bytes[6..8].*, 1, 31), + .hour = try parseTimeDigits(bytes[8..10].*, 0, 23), + .minute = try parseTimeDigits(bytes[10..12].*, 0, 59), + .second = try parseTimeDigits(bytes[12..14].*, 0, 59), + }); + }, + else => return error.CertificateFieldHasWrongDataType, + } +} + +const Date = struct { + /// example: 1999 + year: u16, + /// range: 1 to 12 + month: u8, + /// range: 1 to 31 + day: u8, + /// range: 0 to 59 + hour: u8, + /// range: 0 to 59 + minute: u8, + /// range: 0 to 59 + second: u8, + + /// Convert to number of seconds since epoch. + pub fn toSeconds(date: Date) u64 { + var sec: u64 = 0; + + { + var year: u16 = 1970; + while (year < date.year) : (year += 1) { + const days: u64 = std.time.epoch.getDaysInYear(year); + sec += days * std.time.epoch.secs_per_day; + } + } + + { + const is_leap = std.time.epoch.isLeapYear(date.year); + var month: u4 = 1; + while (month < date.month) : (month += 1) { + const days: u64 = std.time.epoch.getDaysInMonth( + @intToEnum(std.time.epoch.YearLeapKind, @boolToInt(is_leap)), + @intToEnum(std.time.epoch.Month, month), + ); + sec += days * std.time.epoch.secs_per_day; + } + } + + sec += (date.day - 1) * @as(u64, std.time.epoch.secs_per_day); + sec += date.hour * @as(u64, 60 * 60); + sec += date.minute * @as(u64, 60); + sec += date.second; + + return sec; + } +}; + +pub fn parseTimeDigits(nn: @Vector(2, u8), min: u8, max: u8) !u8 { + const zero: @Vector(2, u8) = .{ '0', '0' }; + const mm: @Vector(2, u8) = .{ 10, 1 }; + const result = @reduce(.Add, (nn -% zero) *% mm); + if (result < min) return error.CertificateTimeInvalid; + if (result > max) return error.CertificateTimeInvalid; + return result; +} + +test parseTimeDigits { + const expectEqual = std.testing.expectEqual; + try expectEqual(@as(u8, 0), try parseTimeDigits("00".*, 0, 99)); + try expectEqual(@as(u8, 99), try parseTimeDigits("99".*, 0, 99)); + try expectEqual(@as(u8, 42), try parseTimeDigits("42".*, 0, 99)); + + const expectError = std.testing.expectError; + try expectError(error.CertificateTimeInvalid, parseTimeDigits("13".*, 1, 12)); + try expectError(error.CertificateTimeInvalid, parseTimeDigits("00".*, 1, 12)); +} + +pub fn parseYear4(text: *const [4]u8) !u16 { + const nnnn: @Vector(4, u16) = .{ text[0], text[1], text[2], text[3] }; + const zero: @Vector(4, u16) = .{ '0', '0', '0', '0' }; + const mmmm: @Vector(4, u16) = .{ 1000, 100, 10, 1 }; + const result = @reduce(.Add, (nnnn -% zero) *% mmmm); + if (result > 9999) return error.CertificateTimeInvalid; + return result; +} + +test parseYear4 { + const 
expectEqual = std.testing.expectEqual; + try expectEqual(@as(u16, 0), try parseYear4("0000")); + try expectEqual(@as(u16, 9999), try parseYear4("9999")); + try expectEqual(@as(u16, 1988), try parseYear4("1988")); + + const expectError = std.testing.expectError; + try expectError(error.CertificateTimeInvalid, parseYear4("999b")); + try expectError(error.CertificateTimeInvalid, parseYear4("crap")); +} + pub fn parseAlgorithm(bytes: []const u8, element: der.Element) !Algorithm { if (element.identifier.tag != .object_identifier) return error.CertificateFieldHasWrongDataType; @@ -241,7 +387,13 @@ pub fn parseAttribute(bytes: []const u8, element: der.Element) !Attribute { return error.CertificateHasUnrecognizedAlgorithm; } -fn verifyRsa(comptime Hash: type, message: []const u8, sig: []const u8, pub_key_algo: AlgorithmCategory, pub_key: []const u8) !void { +fn verifyRsa( + comptime Hash: type, + message: []const u8, + sig: []const u8, + pub_key_algo: AlgorithmCategory, + pub_key: []const u8, +) !void { if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch; const pub_key_seq = try der.parseElement(pub_key, 0); if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; @@ -328,6 +480,10 @@ const mem = std.mem; const der = std.crypto.der; const Certificate = @This(); +test { + _ = Bundle; +} + /// TODO: replace this with Frank's upcoming RSA implementation. the verify /// function won't have the possibility of failure - it will either identify a /// valid signature or an invalid signature. diff --git a/lib/std/crypto/Certificate/Bundle.zig b/lib/std/crypto/Certificate/Bundle.zig index c2c18552a7..68b2967d10 100644 --- a/lib/std/crypto/Certificate/Bundle.zig +++ b/lib/std/crypto/Certificate/Bundle.zig @@ -44,12 +44,20 @@ pub fn deinit(cb: *Bundle, gpa: Allocator) void { cb.* = undefined; } -/// Empties the set of certificates and then scans the host operating system +/// Clears the set of certificates and then scans the host operating system /// file system standard locations for certificates. +/// For operating systems that do not have standard CA installations to be +/// found, this function clears the set of certificates. pub fn rescan(cb: *Bundle, gpa: Allocator) !void { switch (builtin.os.tag) { .linux => return rescanLinux(cb, gpa), - else => @compileError("it is unknown where the root CA certificates live on this OS"), + .windows => { + // TODO + }, + .macos => { + // TODO + }, + else => {}, } } @@ -100,6 +108,8 @@ pub fn addCertsFromFile( const begin_marker = "-----BEGIN CERTIFICATE-----"; const end_marker = "-----END CERTIFICATE-----"; + const now_sec = std.time.timestamp(); + var start_index: usize = 0; while (mem.indexOfPos(u8, encoded_bytes, start_index, begin_marker)) |begin_marker_start| { const cert_start = begin_marker_start + begin_marker.len; @@ -110,8 +120,20 @@ pub fn addCertsFromFile( const decoded_start = @intCast(u32, cb.bytes.items.len); const dest_buf = cb.bytes.allocatedSlice()[decoded_start..]; cb.bytes.items.len += try base64.decode(dest_buf, encoded_cert); - const k = try cb.key(decoded_start); - const gop = try cb.map.getOrPutContext(gpa, k, .{ .cb = cb }); + // Even though we could only partially parse the certificate to find + // the subject name, we pre-parse all of them to make sure and only + // include in the bundle ones that we know will parse. This way we can + // use `catch unreachable` later. 
+ const parsed_cert = try Certificate.parse(.{ + .buffer = cb.bytes.items, + .index = decoded_start, + }); + if (now_sec > parsed_cert.validity.not_after) { + // Ignore expired cert. + cb.bytes.items.len = decoded_start; + continue; + } + const gop = try cb.map.getOrPutContext(gpa, parsed_cert.subject_slice, .{ .cb = cb }); if (gop.found_existing) { cb.bytes.items.len = decoded_start; } else { @@ -120,21 +142,6 @@ pub fn addCertsFromFile( } } -pub fn key(cb: Bundle, bytes_index: u32) !der.Element.Slice { - const bytes = cb.bytes.items; - const certificate = try der.parseElement(bytes, bytes_index); - const tbs_certificate = try der.parseElement(bytes, certificate.slice.start); - const version = try der.parseElement(bytes, tbs_certificate.slice.start); - try Certificate.checkVersion(bytes, version); - const serial_number = try der.parseElement(bytes, version.slice.end); - const signature = try der.parseElement(bytes, serial_number.slice.end); - const issuer = try der.parseElement(bytes, signature.slice.end); - const validity = try der.parseElement(bytes, issuer.slice.end); - const subject = try der.parseElement(bytes, validity.slice.end); - - return subject.slice; -} - const builtin = @import("builtin"); const std = @import("../../std.zig"); const fs = std.fs; diff --git a/lib/std/crypto/der.zig b/lib/std/crypto/der.zig index 82f75421ea..27c8049758 100644 --- a/lib/std/crypto/der.zig +++ b/lib/std/crypto/der.zig @@ -24,6 +24,8 @@ pub const Tag = enum(u5) { object_identifier = 6, sequence = 16, sequence_of = 17, + utc_time = 23, + generalized_time = 24, _, }; From 642a8b05c3687d5c084ed164c773bd4d0a4faaef Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 22 Dec 2022 20:19:25 -0700 Subject: [PATCH 32/59] std.crypto.tls.Certificate: explicit error set for verify --- lib/std/crypto/Certificate.zig | 30 +++++++++++++++++++++++---- lib/std/crypto/Certificate/Bundle.zig | 12 ++++++++--- lib/std/crypto/der.zig | 4 ++-- lib/std/crypto/tls/Client.zig | 2 +- 4 files changed, 38 insertions(+), 10 deletions(-) diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index ba36958ae8..330c380e5d 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -116,9 +116,23 @@ pub const Parsed = struct { return p.slice(p.message_slice); } + pub const VerifyError = error{ + CertificateIssuerMismatch, + CertificateNotYetValid, + CertificateExpired, + CertificateSignatureAlgorithmUnsupported, + CertificateSignatureAlgorithmMismatch, + CertificateFieldHasInvalidLength, + CertificateFieldHasWrongDataType, + CertificatePublicKeyInvalid, + CertificateSignatureInvalidLength, + CertificateSignatureInvalid, + CertificateSignatureUnsupportedBitCount, + }; + /// This function checks the time validity for the subject only. Checking /// the issuer's time validity is out of scope. - pub fn verify(parsed_subject: Parsed, parsed_issuer: Parsed) !void { + pub fn verify(parsed_subject: Parsed, parsed_issuer: Parsed) VerifyError!void { // Check that the subject's issuer name matches the issuer's // subject name. 
if (!mem.eql(u8, parsed_subject.issuer(), parsed_issuer.subject())) { @@ -452,11 +466,19 @@ fn verifyRsa( hash_der ++ msg_hashed; - const public_key = try rsa.PublicKey.fromBytes(exponent, modulus, rsa.poop); - const em_dec = try rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key, rsa.poop); + const public_key = rsa.PublicKey.fromBytes(exponent, modulus, rsa.poop) catch |err| switch (err) { + error.OutOfMemory => @panic("TODO don't heap allocate"), + }; + const em_dec = rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key, rsa.poop) catch |err| switch (err) { + error.OutOfMemory => @panic("TODO don't heap allocate"), + + error.MessageTooLong => unreachable, + error.NegativeIntoUnsigned => @panic("TODO make RSA not emit this error"), + error.TargetTooSmall => @panic("TODO make RSA not emit this error"), + error.BufferTooSmall => @panic("TODO make RSA not emit this error"), + }; if (!mem.eql(u8, &em, &em_dec)) { - try std.testing.expectEqualSlices(u8, &em, &em_dec); return error.CertificateSignatureInvalid; } }, diff --git a/lib/std/crypto/Certificate/Bundle.zig b/lib/std/crypto/Certificate/Bundle.zig index 68b2967d10..ea2831bcd9 100644 --- a/lib/std/crypto/Certificate/Bundle.zig +++ b/lib/std/crypto/Certificate/Bundle.zig @@ -9,13 +9,19 @@ map: std.HashMapUnmanaged(der.Element.Slice, u32, MapContext, std.hash_map.default_max_load_percentage) = .{}, bytes: std.ArrayListUnmanaged(u8) = .{}, -pub fn verify(cb: Bundle, subject: Certificate.Parsed) !void { - const bytes_index = cb.find(subject.issuer()) orelse return error.IssuerNotFound; +pub const VerifyError = Certificate.Parsed.VerifyError || error{ + CertificateIssuerNotFound, +}; + +pub fn verify(cb: Bundle, subject: Certificate.Parsed) VerifyError!void { + const bytes_index = cb.find(subject.issuer()) orelse return error.CertificateIssuerNotFound; const issuer_cert: Certificate = .{ .buffer = cb.bytes.items, .index = bytes_index, }; - const issuer = try issuer_cert.parse(); + // Every certificate in the bundle is pre-parsed before adding it, ensuring + // that parsing will succeed here. 
+ const issuer = issuer_cert.parse() catch unreachable; try subject.verify(issuer); } diff --git a/lib/std/crypto/der.zig b/lib/std/crypto/der.zig index 27c8049758..9f4065eeb7 100644 --- a/lib/std/crypto/der.zig +++ b/lib/std/crypto/der.zig @@ -111,7 +111,7 @@ pub const Element = struct { }; }; -pub const ParseElementError = error{CertificateHasFieldWithInvalidLength}; +pub const ParseElementError = error{CertificateFieldHasInvalidLength}; pub fn parseElement(bytes: []const u8, index: u32) ParseElementError!Element { var i = index; @@ -131,7 +131,7 @@ pub fn parseElement(bytes: []const u8, index: u32) ParseElementError!Element { const len_size = @truncate(u7, size_byte); if (len_size > @sizeOf(u32)) { - return error.CertificateHasFieldWithInvalidLength; + return error.CertificateFieldHasInvalidLength; } const end_i = i + len_size; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index f46b233fb3..8e46ce5053 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -470,7 +470,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) handshake_state = .trust_chain_established; break :cert; } else |err| switch (err) { - error.IssuerNotFound => {}, + error.CertificateIssuerNotFound => {}, else => |e| { std.debug.print("unable to validate cert against system root CAs: {s}\n", .{ @errorName(e), From c71c562486c5b3e92a1ea936f3c7b848853b2d5c Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 22 Dec 2022 20:23:50 -0700 Subject: [PATCH 33/59] remove std.crypto.der Only a little bit of generalized logic for DER encoding is needed and so it can live inside the Certificate namespace. This commit removes the generic "parse object id" function which is no longer used in favor of more specific, smaller sets of object ids used with ComptimeStringMap. 
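For reference, the pattern this commit favors looks roughly like the sketch below. The OID bytes mirror two entries of the existing Attribute table in Certificate.zig; the two-entry map and the test are illustrative only:

    const std = @import("std");

    // Each category keeps its own small enum plus a ComptimeStringMap keyed by
    // the raw DER-encoded OID bytes, instead of one generic Oid enum.
    const Attribute = enum {
        commonName,
        countryName,

        pub const map = std.ComptimeStringMap(Attribute, .{
            .{ &[_]u8{ 0x55, 0x04, 0x03 }, .commonName },
            .{ &[_]u8{ 0x55, 0x04, 0x06 }, .countryName },
        });
    };

    test "look up an attribute by DER OID bytes" {
        // 2.5.4.3 (commonName) encodes to the bytes 0x55 0x04 0x03.
        try std.testing.expect(Attribute.map.get(&[_]u8{ 0x55, 0x04, 0x03 }).? == .commonName);
        // Unrecognized OIDs return null, so callers pick their own error.
        try std.testing.expect(Attribute.map.get(&[_]u8{ 0x55, 0x04, 0xFF }) == null);
    }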
--- lib/std/crypto.zig | 2 - lib/std/crypto/Certificate.zig | 84 ++++++++++++- lib/std/crypto/Certificate/Bundle.zig | 2 +- lib/std/crypto/der.zig | 165 -------------------------- 4 files changed, 84 insertions(+), 169 deletions(-) delete mode 100644 lib/std/crypto/der.zig diff --git a/lib/std/crypto.zig b/lib/std/crypto.zig index 6387eb48ae..20522c175d 100644 --- a/lib/std/crypto.zig +++ b/lib/std/crypto.zig @@ -177,7 +177,6 @@ const std = @import("std.zig"); pub const errors = @import("crypto/errors.zig"); pub const tls = @import("crypto/tls.zig"); -pub const der = @import("crypto/der.zig"); pub const Certificate = @import("crypto/Certificate.zig"); test { @@ -269,7 +268,6 @@ test { _ = random; _ = errors; _ = tls; - _ = der; _ = Certificate; } diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index 330c380e5d..6c51ccb133 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -499,9 +499,91 @@ pub fn checkVersion(bytes: []const u8, version: der.Element) !void { const std = @import("../std.zig"); const crypto = std.crypto; const mem = std.mem; -const der = std.crypto.der; const Certificate = @This(); +pub const der = struct { + pub const Class = enum(u2) { + universal, + application, + context_specific, + private, + }; + + pub const PC = enum(u1) { + primitive, + constructed, + }; + + pub const Identifier = packed struct(u8) { + tag: Tag, + pc: PC, + class: Class, + }; + + pub const Tag = enum(u5) { + boolean = 1, + integer = 2, + bitstring = 3, + null = 5, + object_identifier = 6, + sequence = 16, + sequence_of = 17, + utc_time = 23, + generalized_time = 24, + _, + }; + + pub const Element = struct { + identifier: Identifier, + slice: Slice, + + pub const Slice = struct { + start: u32, + end: u32, + + pub const empty: Slice = .{ .start = 0, .end = 0 }; + }; + }; + + pub const ParseElementError = error{CertificateFieldHasInvalidLength}; + + pub fn parseElement(bytes: []const u8, index: u32) ParseElementError!Element { + var i = index; + const identifier = @bitCast(Identifier, bytes[i]); + i += 1; + const size_byte = bytes[i]; + i += 1; + if ((size_byte >> 7) == 0) { + return .{ + .identifier = identifier, + .slice = .{ + .start = i, + .end = i + size_byte, + }, + }; + } + + const len_size = @truncate(u7, size_byte); + if (len_size > @sizeOf(u32)) { + return error.CertificateFieldHasInvalidLength; + } + + const end_i = i + len_size; + var long_form_size: u32 = 0; + while (i < end_i) : (i += 1) { + long_form_size = (long_form_size << 8) | bytes[i]; + } + + return .{ + .identifier = identifier, + .slice = .{ + .start = i, + .end = i + long_form_size, + }, + }; + } +}; + test { _ = Bundle; } diff --git a/lib/std/crypto/Certificate/Bundle.zig b/lib/std/crypto/Certificate/Bundle.zig index ea2831bcd9..8c1a63cd46 100644 --- a/lib/std/crypto/Certificate/Bundle.zig +++ b/lib/std/crypto/Certificate/Bundle.zig @@ -154,8 +154,8 @@ const fs = std.fs; const mem = std.mem; const crypto = std.crypto; const Allocator = std.mem.Allocator; -const der = std.crypto.der; const Certificate = std.crypto.Certificate; +const der = Certificate.der; const Bundle = @This(); const base64 = std.base64.standard.decoderWithIgnore(" \t\r\n"); diff --git a/lib/std/crypto/der.zig b/lib/std/crypto/der.zig deleted file mode 100644 index 9f4065eeb7..0000000000 --- a/lib/std/crypto/der.zig +++ /dev/null @@ -1,165 +0,0 @@ -pub const Class = enum(u2) { - universal, - application, - context_specific, - private, -}; - -pub const PC = enum(u1) { - primitive, - constructed, -}; - 
-pub const Identifier = packed struct(u8) { - tag: Tag, - pc: PC, - class: Class, -}; - -pub const Tag = enum(u5) { - boolean = 1, - integer = 2, - bitstring = 3, - null = 5, - object_identifier = 6, - sequence = 16, - sequence_of = 17, - utc_time = 23, - generalized_time = 24, - _, -}; - -pub const Oid = enum { - rsadsi, - pkcs, - rsaEncryption, - md2WithRSAEncryption, - md5WithRSAEncryption, - sha1WithRSAEncryption, - sha256WithRSAEncryption, - sha384WithRSAEncryption, - sha512WithRSAEncryption, - sha224WithRSAEncryption, - pbeWithMD2AndDES_CBC, - pbeWithMD5AndDES_CBC, - pkcs9_emailAddress, - md2, - md5, - rc4, - ecdsa_with_Recommended, - ecdsa_with_Specified, - ecdsa_with_SHA224, - ecdsa_with_SHA256, - ecdsa_with_SHA384, - ecdsa_with_SHA512, - X500, - X509, - commonName, - serialNumber, - countryName, - localityName, - stateOrProvinceName, - organizationName, - organizationalUnitName, - organizationIdentifier, - - pub const map = std.ComptimeStringMap(Oid, .{ - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D }, .rsadsi }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01 }, .pkcs }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01 }, .rsaEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x02 }, .md2WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x04 }, .md5WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x05 }, .sha1WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0B }, .sha256WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C }, .sha384WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D }, .sha512WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0E }, .sha224WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x05, 0x01 }, .pbeWithMD2AndDES_CBC }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x05, 0x03 }, .pbeWithMD5AndDES_CBC }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x09, 0x01 }, .pkcs9_emailAddress }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x02, 0x02 }, .md2 }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x02, 0x05 }, .md5 }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x03, 0x04 }, .rc4 }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x02 }, .ecdsa_with_Recommended }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03 }, .ecdsa_with_Specified }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x01 }, .ecdsa_with_SHA224 }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x02 }, .ecdsa_with_SHA256 }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x03 }, .ecdsa_with_SHA384 }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x04 }, .ecdsa_with_SHA512 }, - .{ &[_]u8{0x55}, .X500 }, - .{ &[_]u8{ 0x55, 0x04 }, .X509 }, - .{ &[_]u8{ 0x55, 0x04, 0x03 }, .commonName }, - .{ &[_]u8{ 0x55, 0x04, 0x05 }, .serialNumber }, - .{ &[_]u8{ 0x55, 0x04, 0x06 }, .countryName }, - .{ &[_]u8{ 0x55, 0x04, 0x07 }, .localityName }, - .{ &[_]u8{ 0x55, 0x04, 0x08 }, .stateOrProvinceName }, - .{ &[_]u8{ 0x55, 0x04, 0x0A }, .organizationName }, - .{ &[_]u8{ 0x55, 0x04, 0x0B }, .organizationalUnitName }, - .{ &[_]u8{ 0x55, 0x04, 0x61 }, .organizationIdentifier }, - }); -}; - -pub const Element = struct { - identifier: Identifier, - slice: Slice, - - pub const Slice = struct { - start: u32, - end: u32, - - pub const empty: Slice = .{ .start = 0, .end = 0 }; - 
}; -}; - -pub const ParseElementError = error{CertificateFieldHasInvalidLength}; - -pub fn parseElement(bytes: []const u8, index: u32) ParseElementError!Element { - var i = index; - const identifier = @bitCast(Identifier, bytes[i]); - i += 1; - const size_byte = bytes[i]; - i += 1; - if ((size_byte >> 7) == 0) { - return .{ - .identifier = identifier, - .slice = .{ - .start = i, - .end = i + size_byte, - }, - }; - } - - const len_size = @truncate(u7, size_byte); - if (len_size > @sizeOf(u32)) { - return error.CertificateFieldHasInvalidLength; - } - - const end_i = i + len_size; - var long_form_size: u32 = 0; - while (i < end_i) : (i += 1) { - long_form_size = (long_form_size << 8) | bytes[i]; - } - - return .{ - .identifier = identifier, - .slice = .{ - .start = i, - .end = i + long_form_size, - }, - }; -} - -pub const ParseObjectIdError = error{ - CertificateHasUnrecognizedObjectId, - CertificateFieldHasWrongDataType, -} || ParseElementError; - -pub fn parseObjectId(bytes: []const u8, element: Element) ParseObjectIdError!Oid { - if (element.identifier.tag != .object_identifier) - return error.CertificateFieldHasWrongDataType; - return Oid.map.get(bytes[element.slice.start..element.slice.end]) orelse - return error.CertificateHasUnrecognizedObjectId; -} - -const std = @import("../std.zig"); -const der = @This(); From 5b8b5f2505ca63dd62f487dcd0357112f959dde7 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 25 Dec 2022 23:45:49 -0700 Subject: [PATCH 34/59] add url parsing to the std lib --- lib/std/Url.zig | 98 +++++++++++++++++++++++++++ lib/std/crypto/Certificate/Bundle.zig | 4 +- lib/std/http/Client.zig | 35 +++++----- lib/std/std.zig | 1 + 4 files changed, 120 insertions(+), 18 deletions(-) create mode 100644 lib/std/Url.zig diff --git a/lib/std/Url.zig b/lib/std/Url.zig new file mode 100644 index 0000000000..8887f5de92 --- /dev/null +++ b/lib/std/Url.zig @@ -0,0 +1,98 @@ +scheme: []const u8, +host: []const u8, +path: []const u8, +port: ?u16, + +/// TODO: redo this implementation according to RFC 1738. This code is only a +/// placeholder for now. 
+pub fn parse(s: []const u8) !Url { + var scheme_end: usize = 0; + var host_start: usize = 0; + var host_end: usize = 0; + var path_start: usize = 0; + var port_start: usize = 0; + var port_end: usize = 0; + var state: enum { + scheme, + scheme_slash1, + scheme_slash2, + host, + port, + path, + } = .scheme; + + for (s) |b, i| switch (state) { + .scheme => switch (b) { + ':' => { + state = .scheme_slash1; + scheme_end = i; + }, + else => {}, + }, + .scheme_slash1 => switch (b) { + '/' => { + state = .scheme_slash2; + }, + else => return error.InvalidUrl, + }, + .scheme_slash2 => switch (b) { + '/' => { + state = .host; + host_start = i + 1; + }, + else => return error.InvalidUrl, + }, + .host => switch (b) { + ':' => { + state = .port; + host_end = i; + port_start = i + 1; + }, + '/' => { + state = .path; + host_end = i; + path_start = i; + }, + else => {}, + }, + .port => switch (b) { + '/' => { + port_end = i; + state = .path; + path_start = i; + }, + else => {}, + }, + .path => {}, + }; + + const port_slice = s[port_start..port_end]; + const port = if (port_slice.len == 0) null else try std.fmt.parseInt(u16, port_slice, 10); + + return .{ + .scheme = s[0..scheme_end], + .host = s[host_start..host_end], + .path = s[path_start..], + .port = port, + }; +} + +const Url = @This(); +const std = @import("std.zig"); +const testing = std.testing; + +test "basic" { + const parsed = try parse("https://ziglang.org/download"); + try testing.expectEqualStrings("https", parsed.scheme); + try testing.expectEqualStrings("ziglang.org", parsed.host); + try testing.expectEqualStrings("/download", parsed.path); + try testing.expectEqual(@as(?u16, null), parsed.port); +} + +test "with port" { + const parsed = try parse("http://example:1337/"); + try testing.expectEqualStrings("http", parsed.scheme); + try testing.expectEqualStrings("example", parsed.host); + try testing.expectEqualStrings("/", parsed.path); + try testing.expectEqual(@as(?u16, 1337), parsed.port); +} diff --git a/lib/std/crypto/Certificate/Bundle.zig b/lib/std/crypto/Certificate/Bundle.zig index 8c1a63cd46..4177676d96 100644 --- a/lib/std/crypto/Certificate/Bundle.zig +++ b/lib/std/crypto/Certificate/Bundle.zig @@ -105,7 +105,9 @@ pub fn addCertsFromFile( // This is possible by computing the decoded length and reserving the space // for the decoded bytes first. 
const decoded_size_upper_bound = size / 4 * 3; - try cb.bytes.ensureUnusedCapacity(gpa, decoded_size_upper_bound + size); + const needed_capacity = std.math.cast(u32, decoded_size_upper_bound + size) orelse + return error.CertificateAuthorityBundleTooBig; + try cb.bytes.ensureUnusedCapacity(gpa, needed_capacity); const end_reserved = cb.bytes.items.len + decoded_size_upper_bound; const buffer = cb.bytes.allocatedSlice()[end_reserved..]; const end_index = try file.readAll(buffer); diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 1d10870312..4e5bd3da0c 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -3,6 +3,7 @@ const assert = std.debug.assert; const http = std.http; const net = std.net; const Client = @This(); +const Url = std.Url; allocator: std.mem.Allocator, headers: std.ArrayListUnmanaged(u8) = .{}, @@ -19,14 +20,7 @@ pub const Request = struct { pub const Protocol = enum { http, https }; pub const Options = struct { - family: Family = .any, - protocol: Protocol = .https, method: http.Method = .GET, - host: []const u8 = "localhost", - path: []const u8 = "/", - port: u16 = 0, - - pub const Family = enum { any, ip4, ip6 }; }; pub fn deinit(req: *Request) void { @@ -90,20 +84,27 @@ pub fn deinit(client: *Client) void { client.* = undefined; } -pub fn request(client: *Client, options: Request.Options) !Request { +pub fn request(client: *Client, url: Url, options: Request.Options) !Request { + const protocol = std.meta.stringToEnum(Request.Protocol, url.scheme) orelse + return error.UnsupportedUrlScheme; + const port: u16 = url.port orelse switch (protocol) { + .http => 80, + .https => 443, + }; + var req: Request = .{ .client = client, - .stream = try net.tcpConnectToHost(client.allocator, options.host, options.port), - .protocol = options.protocol, + .stream = try net.tcpConnectToHost(client.allocator, url.host, port), + .protocol = protocol, .tls_client = undefined, }; client.active_requests += 1; errdefer req.deinit(); - switch (options.protocol) { + switch (protocol) { .http => {}, .https => { - req.tls_client = try std.crypto.tls.Client.init(req.stream, client.ca_bundle, options.host); + req.tls_client = try std.crypto.tls.Client.init(req.stream, client.ca_bundle, url.host); }, } @@ -111,19 +112,19 @@ pub fn request(client: *Client, options: Request.Options) !Request { client.allocator, @tagName(options.method).len + 1 + - options.path.len + + url.path.len + " HTTP/1.1\r\nHost: ".len + - options.host.len + + url.host.len + "\r\nUpgrade-Insecure-Requests: 1\r\n".len + client.headers.items.len + 2, // for the \r\n at the end of headers ); req.headers.appendSliceAssumeCapacity(@tagName(options.method)); req.headers.appendSliceAssumeCapacity(" "); - req.headers.appendSliceAssumeCapacity(options.path); + req.headers.appendSliceAssumeCapacity(url.path); req.headers.appendSliceAssumeCapacity(" HTTP/1.1\r\nHost: "); - req.headers.appendSliceAssumeCapacity(options.host); - switch (options.protocol) { + req.headers.appendSliceAssumeCapacity(url.host); + switch (protocol) { .https => req.headers.appendSliceAssumeCapacity("\r\nUpgrade-Insecure-Requests: 1\r\n"), .http => req.headers.appendSliceAssumeCapacity("\r\n"), } diff --git a/lib/std/std.zig b/lib/std/std.zig index 4bfb44d12f..1cbcd6bad7 100644 --- a/lib/std/std.zig +++ b/lib/std/std.zig @@ -42,6 +42,7 @@ pub const Target = @import("target.zig").Target; pub const Thread = @import("Thread.zig"); pub const Treap = @import("treap.zig").Treap; pub const Tz = tz.Tz; +pub const Url = @import("Url.zig"); 
pub const array_hash_map = @import("array_hash_map.zig"); pub const atomic = @import("atomic.zig"); From a1f6a08dcb91c74f31d9a2c75a73c7efb724bf92 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 26 Dec 2022 16:32:25 -0700 Subject: [PATCH 35/59] std.crypto.Certificate.Bundle: fix 32-bit build --- lib/std/crypto/Certificate/Bundle.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/std/crypto/Certificate/Bundle.zig b/lib/std/crypto/Certificate/Bundle.zig index 4177676d96..b30fa531ec 100644 --- a/lib/std/crypto/Certificate/Bundle.zig +++ b/lib/std/crypto/Certificate/Bundle.zig @@ -108,7 +108,7 @@ pub fn addCertsFromFile( const needed_capacity = std.math.cast(u32, decoded_size_upper_bound + size) orelse return error.CertificateAuthorityBundleTooBig; try cb.bytes.ensureUnusedCapacity(gpa, needed_capacity); - const end_reserved = cb.bytes.items.len + decoded_size_upper_bound; + const end_reserved = @intCast(u32, cb.bytes.items.len + decoded_size_upper_bound); const buffer = cb.bytes.allocatedSlice()[end_reserved..]; const end_index = try file.readAll(buffer); const encoded_bytes = buffer[0..end_index]; From b24f178029f20cacb559b14e5e5e095fabea4e62 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 27 Dec 2022 17:36:04 -0700 Subject: [PATCH 36/59] std.crypto.tls.Certificate: fix parsing missing subsequent fields Instead of seeing all the attributed types and values, the code was only seeing the first one. --- lib/std/crypto/Certificate.zig | 54 +++++++++++++++++++++++++--------- lib/std/crypto/tls/Client.zig | 18 ++++++++---- 2 files changed, 53 insertions(+), 19 deletions(-) diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index 6c51ccb133..4f34412b1f 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -56,6 +56,8 @@ pub const Attribute = enum { organizationName, organizationalUnitName, organizationIdentifier, + subject_alt_name, + pkcs9_emailAddress, pub const map = std.ComptimeStringMap(Attribute, .{ .{ &[_]u8{ 0x55, 0x04, 0x03 }, .commonName }, @@ -66,6 +68,8 @@ pub const Attribute = enum { .{ &[_]u8{ 0x55, 0x04, 0x0A }, .organizationName }, .{ &[_]u8{ 0x55, 0x04, 0x0B }, .organizationalUnitName }, .{ &[_]u8{ 0x55, 0x04, 0x61 }, .organizationIdentifier }, + .{ &[_]u8{ 0x55, 0x1D, 0x11 }, .subject_alt_name }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x09, 0x01 }, .pkcs9_emailAddress }, }); }; @@ -74,6 +78,7 @@ pub const Parsed = struct { issuer_slice: Slice, subject_slice: Slice, common_name_slice: Slice, + subject_alt_name_slice: Slice, signature_slice: Slice, signature_algorithm: Algorithm, pub_key_algo: AlgorithmCategory, @@ -104,6 +109,10 @@ pub const Parsed = struct { return p.slice(p.common_name_slice); } + pub fn subjectAltName(p: Parsed) []const u8 { + return p.slice(p.subject_alt_name_slice); + } + pub fn signature(p: Parsed) []const u8 { return p.slice(p.signature_slice); } @@ -195,20 +204,33 @@ pub fn parse(cert: Certificate) !Parsed { const pub_key_elem = try der.parseElement(cert_bytes, pub_key_signature_algorithm.slice.end); const pub_key = try parseBitString(cert, pub_key_elem); - const rdn = try der.parseElement(cert_bytes, subject.slice.start); - const atav = try der.parseElement(cert_bytes, rdn.slice.start); - var common_name = der.Element.Slice.empty; - var atav_i = atav.slice.start; - while (atav_i < atav.slice.end) { - const ty_elem = try der.parseElement(cert_bytes, atav_i); - const ty = try parseAttribute(cert_bytes, ty_elem); - const val = try 
der.parseElement(cert_bytes, ty_elem.slice.end); - switch (ty) { - .commonName => common_name = val.slice, - else => {}, + var subject_alt_name = der.Element.Slice.empty; + var name_i = subject.slice.start; + //std.debug.print("subject name:\n", .{}); + while (name_i < subject.slice.end) { + const rdn = try der.parseElement(cert_bytes, name_i); + var rdn_i = rdn.slice.start; + while (rdn_i < rdn.slice.end) { + const atav = try der.parseElement(cert_bytes, rdn_i); + var atav_i = atav.slice.start; + while (atav_i < atav.slice.end) { + const ty_elem = try der.parseElement(cert_bytes, atav_i); + const ty = try parseAttribute(cert_bytes, ty_elem); + const val = try der.parseElement(cert_bytes, ty_elem.slice.end); + //std.debug.print(" {s}: '{s}'\n", .{ + // @tagName(ty), cert_bytes[val.slice.start..val.slice.end], + //}); + switch (ty) { + .commonName => common_name = val.slice, + .subject_alt_name => subject_alt_name = val.slice, + else => {}, + } + atav_i = val.slice.end; + } + rdn_i = atav.slice.end; } - atav_i = val.slice.end; + name_i = rdn.slice.end; } const sig_algo = try der.parseElement(cert_bytes, tbs_certificate.slice.end); @@ -220,6 +242,7 @@ pub fn parse(cert: Certificate) !Parsed { return .{ .certificate = cert, .common_name_slice = common_name, + .subject_alt_name_slice = subject_alt_name, .issuer_slice = issuer.slice, .subject_slice = subject.slice, .signature_slice = signature, @@ -397,8 +420,11 @@ pub fn parseAlgorithmCategory(bytes: []const u8, element: der.Element) !Algorith pub fn parseAttribute(bytes: []const u8, element: der.Element) !Attribute { if (element.identifier.tag != .object_identifier) return error.CertificateFieldHasWrongDataType; - return Attribute.map.get(bytes[element.slice.start..element.slice.end]) orelse - return error.CertificateHasUnrecognizedAlgorithm; + const oid_bytes = bytes[element.slice.start..element.slice.end]; + return Attribute.map.get(oid_bytes) orelse { + //std.debug.print("attr: {}\n", .{std.fmt.fmtSliceHexLower(oid_bytes)}); + return error.CertificateHasUnrecognizedAttribute; + }; } fn verifyRsa( diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 8e46ce5053..aa9df520e6 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -323,8 +323,8 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) var handshake_state: HandshakeState = .encrypted_extensions; var cleartext_bufs: [2][8000]u8 = undefined; var main_cert_pub_key_algo: Certificate.AlgorithmCategory = undefined; - var main_cert_pub_key_buf: [128]u8 = undefined; - var main_cert_pub_key_len: u8 = undefined; + var main_cert_pub_key_buf: [300]u8 = undefined; + var main_cert_pub_key_len: u16 = undefined; while (true) { const end_hdr = i + 5; @@ -503,7 +503,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) var verify_buffer = ([1]u8{0x20} ** 64) ++ "TLS 1.3, server CertificateVerify\x00".* ++ - ([1]u8{undefined} ** max_digest_len); + @as([max_digest_len]u8, undefined); const verify_bytes = switch (handshake_cipher) { inline else => |*p| v: { @@ -524,7 +524,15 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) const key = try P256.PublicKey.fromSec1(main_cert_pub_key); try sig.verify(verify_bytes, key); }, - else => return error.TlsBadSignatureAlgorithm, + .rsa_pss_rsae_sha256 => { + @panic("TODO signature algorithm: rsa_pss_rsae_sha256"); + }, + else => { + //std.debug.print("signature algorithm: {any}\n", .{ + // algorithm, + //}); + return 
error.TlsBadSignatureAlgorithm; + }, } }, @enumToInt(HandshakeType.finished) => { @@ -557,7 +565,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) @enumToInt(ContentType.application_data), 0x03, 0x03, // legacy protocol version 0, wrapped_len, // byte length of encrypted record - } ++ ([1]u8{undefined} ** wrapped_len); + } ++ @as([wrapped_len]u8, undefined); const ad = finished_msg[0..5]; const ciphertext = finished_msg[5..][0..out_cleartext.len]; From b1cbfa0ec640c3a998e4a59352b93da37f359ff7 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 27 Dec 2022 17:37:15 -0700 Subject: [PATCH 37/59] std.crypto.Certificate: remove subject_alt_name parsing I believe this is provided as an extension, not in this location. --- lib/std/crypto/Certificate.zig | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index 4f34412b1f..1fa451a680 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -56,7 +56,6 @@ pub const Attribute = enum { organizationName, organizationalUnitName, organizationIdentifier, - subject_alt_name, pkcs9_emailAddress, pub const map = std.ComptimeStringMap(Attribute, .{ @@ -68,7 +67,6 @@ pub const Attribute = enum { .{ &[_]u8{ 0x55, 0x04, 0x0A }, .organizationName }, .{ &[_]u8{ 0x55, 0x04, 0x0B }, .organizationalUnitName }, .{ &[_]u8{ 0x55, 0x04, 0x61 }, .organizationIdentifier }, - .{ &[_]u8{ 0x55, 0x1D, 0x11 }, .subject_alt_name }, .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x09, 0x01 }, .pkcs9_emailAddress }, }); }; @@ -78,7 +76,6 @@ pub const Parsed = struct { issuer_slice: Slice, subject_slice: Slice, common_name_slice: Slice, - subject_alt_name_slice: Slice, signature_slice: Slice, signature_algorithm: Algorithm, pub_key_algo: AlgorithmCategory, @@ -109,10 +106,6 @@ pub const Parsed = struct { return p.slice(p.common_name_slice); } - pub fn subjectAltName(p: Parsed) []const u8 { - return p.slice(p.subject_alt_name_slice); - } - pub fn signature(p: Parsed) []const u8 { return p.slice(p.signature_slice); } @@ -205,7 +198,6 @@ pub fn parse(cert: Certificate) !Parsed { const pub_key = try parseBitString(cert, pub_key_elem); var common_name = der.Element.Slice.empty; - var subject_alt_name = der.Element.Slice.empty; var name_i = subject.slice.start; //std.debug.print("subject name:\n", .{}); while (name_i < subject.slice.end) { @@ -223,7 +215,6 @@ pub fn parse(cert: Certificate) !Parsed { //}); switch (ty) { .commonName => common_name = val.slice, - .subject_alt_name => subject_alt_name = val.slice, else => {}, } atav_i = val.slice.end; @@ -242,7 +233,6 @@ pub fn parse(cert: Certificate) !Parsed { return .{ .certificate = cert, .common_name_slice = common_name, - .subject_alt_name_slice = subject_alt_name, .issuer_slice = issuer.slice, .subject_slice = subject.slice, .signature_slice = signature, From 5bbedb63cf43297bcb6ff1dca12affafebc9c09e Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 27 Dec 2022 19:10:31 -0700 Subject: [PATCH 38/59] std.crypto.Certificate: support verifying secp384r1 pub keys --- lib/std/crypto/Certificate.zig | 237 ++++++++++++++++++++++----------- 1 file changed, 157 insertions(+), 80 deletions(-) diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index 1fa451a680..a8511d4d9e 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -71,6 +71,16 @@ pub const Attribute = enum { }); }; +pub const NamedCurve = enum { + secp384r1, + X9_62_prime256v1, + 
+ pub const map = std.ComptimeStringMap(NamedCurve, .{ + .{ &[_]u8{ 0x2B, 0x81, 0x04, 0x00, 0x22 }, .secp384r1 }, + .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x07 }, .X9_62_prime256v1 }, + }); +}; + pub const Parsed = struct { certificate: Certificate, issuer_slice: Slice, @@ -78,11 +88,16 @@ pub const Parsed = struct { common_name_slice: Slice, signature_slice: Slice, signature_algorithm: Algorithm, - pub_key_algo: AlgorithmCategory, + pub_key_algo: PubKeyAlgo, pub_key_slice: Slice, message_slice: Slice, validity: Validity, + pub const PubKeyAlgo = union(AlgorithmCategory) { + rsaEncryption: void, + X9_62_id_ecPublicKey: NamedCurve, + }; + pub const Validity = struct { not_before: u64, not_after: u64, @@ -114,6 +129,10 @@ pub const Parsed = struct { return p.slice(p.pub_key_slice); } + pub fn pubKeySigAlgo(p: Parsed) []const u8 { + return p.slice(p.pub_key_signature_algorithm_slice); + } + pub fn message(p: Parsed) []const u8 { return p.slice(p.message_slice); } @@ -130,6 +149,7 @@ pub const Parsed = struct { CertificateSignatureInvalidLength, CertificateSignatureInvalid, CertificateSignatureUnsupportedBitCount, + CertificateSignatureNamedCurveUnsupported, }; /// This function checks the time validity for the subject only. Checking @@ -160,56 +180,78 @@ pub const Parsed = struct { parsed_issuer.pub_key_algo, parsed_issuer.pubKey(), ), - .ecdsa_with_SHA224, + + inline .ecdsa_with_SHA224, .ecdsa_with_SHA256, .ecdsa_with_SHA384, .ecdsa_with_SHA512, - => { - return error.CertificateSignatureAlgorithmUnsupported; - }, + => |algorithm| return verify_ecdsa( + algorithm.Hash(), + parsed_subject.message(), + parsed_subject.signature(), + parsed_issuer.pub_key_algo, + parsed_issuer.pubKey(), + ), } } }; pub fn parse(cert: Certificate) !Parsed { const cert_bytes = cert.buffer; - const certificate = try der.parseElement(cert_bytes, cert.index); - const tbs_certificate = try der.parseElement(cert_bytes, certificate.slice.start); - const version = try der.parseElement(cert_bytes, tbs_certificate.slice.start); + const certificate = try der.Element.parse(cert_bytes, cert.index); + const tbs_certificate = try der.Element.parse(cert_bytes, certificate.slice.start); + const version = try der.Element.parse(cert_bytes, tbs_certificate.slice.start); try checkVersion(cert_bytes, version); - const serial_number = try der.parseElement(cert_bytes, version.slice.end); + const serial_number = try der.Element.parse(cert_bytes, version.slice.end); // RFC 5280, section 4.1.2.3: // "This field MUST contain the same algorithm identifier as // the signatureAlgorithm field in the sequence Certificate." 
- const tbs_signature = try der.parseElement(cert_bytes, serial_number.slice.end); - const issuer = try der.parseElement(cert_bytes, tbs_signature.slice.end); - const validity = try der.parseElement(cert_bytes, issuer.slice.end); - const not_before = try der.parseElement(cert_bytes, validity.slice.start); + const tbs_signature = try der.Element.parse(cert_bytes, serial_number.slice.end); + const issuer = try der.Element.parse(cert_bytes, tbs_signature.slice.end); + const validity = try der.Element.parse(cert_bytes, issuer.slice.end); + const not_before = try der.Element.parse(cert_bytes, validity.slice.start); const not_before_utc = try parseTime(cert, not_before); - const not_after = try der.parseElement(cert_bytes, not_before.slice.end); + const not_after = try der.Element.parse(cert_bytes, not_before.slice.end); const not_after_utc = try parseTime(cert, not_after); - const subject = try der.parseElement(cert_bytes, validity.slice.end); + const subject = try der.Element.parse(cert_bytes, validity.slice.end); - const pub_key_info = try der.parseElement(cert_bytes, subject.slice.end); - const pub_key_signature_algorithm = try der.parseElement(cert_bytes, pub_key_info.slice.start); - const pub_key_algo_elem = try der.parseElement(cert_bytes, pub_key_signature_algorithm.slice.start); - const pub_key_algo = try parseAlgorithmCategory(cert_bytes, pub_key_algo_elem); - const pub_key_elem = try der.parseElement(cert_bytes, pub_key_signature_algorithm.slice.end); + const pub_key_info = try der.Element.parse(cert_bytes, subject.slice.end); + const pub_key_signature_algorithm = try der.Element.parse(cert_bytes, pub_key_info.slice.start); + const pub_key_algo_elem = try der.Element.parse(cert_bytes, pub_key_signature_algorithm.slice.start); + const pub_key_algo_tag = try parseAlgorithmCategory(cert_bytes, pub_key_algo_elem); + var pub_key_algo: Parsed.PubKeyAlgo = undefined; + switch (pub_key_algo_tag) { + .rsaEncryption => { + pub_key_algo = .{ .rsaEncryption = {} }; + }, + .X9_62_id_ecPublicKey => { + // RFC 5480 Section 2.1.1.1 Named Curve + // ECParameters ::= CHOICE { + // namedCurve OBJECT IDENTIFIER + // -- implicitCurve NULL + // -- specifiedCurve SpecifiedECDomain + // } + const params_elem = try der.Element.parse(cert_bytes, pub_key_algo_elem.slice.end); + const named_curve = try parseNamedCurve(cert_bytes, params_elem); + pub_key_algo = .{ .X9_62_id_ecPublicKey = named_curve }; + }, + } + const pub_key_elem = try der.Element.parse(cert_bytes, pub_key_signature_algorithm.slice.end); const pub_key = try parseBitString(cert, pub_key_elem); var common_name = der.Element.Slice.empty; var name_i = subject.slice.start; //std.debug.print("subject name:\n", .{}); while (name_i < subject.slice.end) { - const rdn = try der.parseElement(cert_bytes, name_i); + const rdn = try der.Element.parse(cert_bytes, name_i); var rdn_i = rdn.slice.start; while (rdn_i < rdn.slice.end) { - const atav = try der.parseElement(cert_bytes, rdn_i); + const atav = try der.Element.parse(cert_bytes, rdn_i); var atav_i = atav.slice.start; while (atav_i < atav.slice.end) { - const ty_elem = try der.parseElement(cert_bytes, atav_i); + const ty_elem = try der.Element.parse(cert_bytes, atav_i); const ty = try parseAttribute(cert_bytes, ty_elem); - const val = try der.parseElement(cert_bytes, ty_elem.slice.end); + const val = try der.Element.parse(cert_bytes, ty_elem.slice.end); //std.debug.print(" {s}: '{s}'\n", .{ // @tagName(ty), cert_bytes[val.slice.start..val.slice.end], //}); @@ -224,10 +266,10 @@ pub fn parse(cert: 
Certificate) !Parsed { name_i = rdn.slice.end; } - const sig_algo = try der.parseElement(cert_bytes, tbs_certificate.slice.end); - const algo_elem = try der.parseElement(cert_bytes, sig_algo.slice.start); + const sig_algo = try der.Element.parse(cert_bytes, tbs_certificate.slice.end); + const algo_elem = try der.Element.parse(cert_bytes, sig_algo.slice.start); const signature_algorithm = try parseAlgorithm(cert_bytes, algo_elem); - const sig_elem = try der.parseElement(cert_bytes, sig_algo.slice.end); + const sig_elem = try der.Element.parse(cert_bytes, sig_algo.slice.end); const signature = try parseBitString(cert, sig_elem); return .{ @@ -391,45 +433,52 @@ test parseYear4 { } pub fn parseAlgorithm(bytes: []const u8, element: der.Element) !Algorithm { - if (element.identifier.tag != .object_identifier) - return error.CertificateFieldHasWrongDataType; - const oid_bytes = bytes[element.slice.start..element.slice.end]; - return Algorithm.map.get(oid_bytes) orelse { - //std.debug.print("oid bytes: {}\n", .{std.fmt.fmtSliceHexLower(oid_bytes)}); - return error.CertificateHasUnrecognizedAlgorithm; - }; + return parseEnum(Algorithm, bytes, element); } pub fn parseAlgorithmCategory(bytes: []const u8, element: der.Element) !AlgorithmCategory { - if (element.identifier.tag != .object_identifier) - return error.CertificateFieldHasWrongDataType; - return AlgorithmCategory.map.get(bytes[element.slice.start..element.slice.end]) orelse - return error.CertificateHasUnrecognizedAlgorithmCategory; + return parseEnum(AlgorithmCategory, bytes, element); } pub fn parseAttribute(bytes: []const u8, element: der.Element) !Attribute { + return parseEnum(Attribute, bytes, element); +} + +pub fn parseNamedCurve(bytes: []const u8, element: der.Element) !NamedCurve { + return parseEnum(NamedCurve, bytes, element); +} + +fn parseEnum(comptime E: type, bytes: []const u8, element: der.Element) !E { if (element.identifier.tag != .object_identifier) return error.CertificateFieldHasWrongDataType; const oid_bytes = bytes[element.slice.start..element.slice.end]; - return Attribute.map.get(oid_bytes) orelse { - //std.debug.print("attr: {}\n", .{std.fmt.fmtSliceHexLower(oid_bytes)}); - return error.CertificateHasUnrecognizedAttribute; + return E.map.get(oid_bytes) orelse { + //std.debug.print("tag: {}\n", .{std.fmt.fmtSliceHexLower(oid_bytes)}); + return error.CertificateHasUnrecognizedObjectId; }; } +pub fn checkVersion(bytes: []const u8, version: der.Element) !void { + if (@bitCast(u8, version.identifier) != 0xa0 or + !mem.eql(u8, bytes[version.slice.start..version.slice.end], "\x02\x01\x02")) + { + return error.UnsupportedCertificateVersion; + } +} + fn verifyRsa( comptime Hash: type, message: []const u8, sig: []const u8, - pub_key_algo: AlgorithmCategory, + pub_key_algo: Parsed.PubKeyAlgo, pub_key: []const u8, ) !void { if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch; - const pub_key_seq = try der.parseElement(pub_key, 0); + const pub_key_seq = try der.Element.parse(pub_key, 0); if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; - const modulus_elem = try der.parseElement(pub_key, pub_key_seq.slice.start); + const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start); if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; - const exponent_elem = try der.parseElement(pub_key, modulus_elem.slice.end); + const exponent_elem = try der.Element.parse(pub_key, modulus_elem.slice.end); if 
(exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; // Skip over meaningless zeroes in the modulus. const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end]; @@ -504,11 +553,39 @@ fn verifyRsa( } } -pub fn checkVersion(bytes: []const u8, version: der.Element) !void { - if (@bitCast(u8, version.identifier) != 0xa0 or - !mem.eql(u8, bytes[version.slice.start..version.slice.end], "\x02\x01\x02")) - { - return error.UnsupportedCertificateVersion; +fn verify_ecdsa( + comptime Hash: type, + message: []const u8, + encoded_sig: []const u8, + pub_key_algo: Parsed.PubKeyAlgo, + sec1_pub_key: []const u8, +) !void { + const sig_named_curve = switch (pub_key_algo) { + .X9_62_id_ecPublicKey => |named_curve| named_curve, + else => return error.CertificateSignatureAlgorithmMismatch, + }; + + switch (sig_named_curve) { + .secp384r1 => { + const P = crypto.ecc.P384; + const Ecdsa = crypto.sign.ecdsa.Ecdsa(P, Hash); + const sig = Ecdsa.Signature.fromDer(encoded_sig) catch |err| switch (err) { + error.InvalidEncoding => return error.CertificateSignatureInvalid, + }; + const pub_key = Ecdsa.PublicKey.fromSec1(sec1_pub_key) catch |err| switch (err) { + error.InvalidEncoding => return error.CertificateSignatureInvalid, + error.NonCanonical => return error.CertificateSignatureInvalid, + error.NotSquare => return error.CertificateSignatureInvalid, + }; + sig.verify(message, pub_key) catch |err| switch (err) { + error.IdentityElement => return error.CertificateSignatureInvalid, + error.NonCanonical => return error.CertificateSignatureInvalid, + error.SignatureVerificationFailed => return error.CertificateSignatureInvalid, + }; + }, + .X9_62_prime256v1 => { + return error.CertificateSignatureNamedCurveUnsupported; + }, } } @@ -559,45 +636,45 @@ pub const der = struct { pub const empty: Slice = .{ .start = 0, .end = 0 }; }; - }; - pub const ParseElementError = error{CertificateFieldHasInvalidLength}; + pub const ParseError = error{CertificateFieldHasInvalidLength}; + + pub fn parse(bytes: []const u8, index: u32) ParseError!Element { + var i = index; + const identifier = @bitCast(Identifier, bytes[i]); + i += 1; + const size_byte = bytes[i]; + i += 1; + if ((size_byte >> 7) == 0) { + return .{ + .identifier = identifier, + .slice = .{ + .start = i, + .end = i + size_byte, + }, + }; + } + + const len_size = @truncate(u7, size_byte); + if (len_size > @sizeOf(u32)) { + return error.CertificateFieldHasInvalidLength; + } + + const end_i = i + len_size; + var long_form_size: u32 = 0; + while (i < end_i) : (i += 1) { + long_form_size = (long_form_size << 8) | bytes[i]; + } - pub fn parseElement(bytes: []const u8, index: u32) ParseElementError!Element { - var i = index; - const identifier = @bitCast(Identifier, bytes[i]); - i += 1; - const size_byte = bytes[i]; - i += 1; - if ((size_byte >> 7) == 0) { return .{ .identifier = identifier, .slice = .{ .start = i, - .end = i + size_byte, + .end = i + long_form_size, }, }; } - - const len_size = @truncate(u7, size_byte); - if (len_size > @sizeOf(u32)) { - return error.CertificateFieldHasInvalidLength; - } - - const end_i = i + len_size; - var long_form_size: u32 = 0; - while (i < end_i) : (i += 1) { - long_form_size = (long_form_size << 8) | bytes[i]; - } - - return .{ - .identifier = identifier, - .slice = .{ - .start = i, - .end = i + long_form_size, - }, - }; - } + }; }; test { From ceb211e65fe6bfe864b1150d08cd5e0383f6c2c2 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 27 Dec 2022 22:59:19 -0700 Subject: [PATCH 
39/59] std.crypto.tls.Client: handle key_update message --- lib/std/crypto/tls.zig | 8 +++++ lib/std/crypto/tls/Client.zig | 60 ++++++++++++++++++++++++++++++++--- 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 12401bff35..acfa8558c1 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -227,6 +227,12 @@ pub const CertificateType = enum(u8) { _, }; +pub const KeyUpdateRequest = enum(u8) { + update_not_requested = 0, + update_requested = 1, + _, +}; + pub fn HandshakeCipherT(comptime AeadType: type, comptime HashType: type) type { return struct { pub const AEAD = AeadType; @@ -261,6 +267,8 @@ pub fn ApplicationCipherT(comptime AeadType: type, comptime HashType: type) type pub const Hmac = crypto.auth.hmac.Hmac(Hash); pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + client_secret: [Hash.digest_length]u8, + server_secret: [Hash.digest_length]u8, client_key: [AEAD.key_length]u8, server_key: [AEAD.key_length]u8, client_iv: [AEAD.nonce_length]u8, diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index aa9df520e6..9ab9197dc8 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -24,7 +24,7 @@ read_seq: u64, write_seq: u64, /// The size is enough to contain exactly one TLSCiphertext record. partially_read_buffer: [tls.max_ciphertext_record_len]u8, -/// The number of partially read bytes inside `partiall_read_buffer`. +/// The number of partially read bytes inside `partially_read_buffer`. partially_read_len: u15, eof: bool, @@ -584,6 +584,8 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) // std.fmt.fmtSliceHexLower(&server_secret), //}); break :c @unionInit(ApplicationCipher, @tagName(tag), .{ + .client_secret = client_secret, + .server_secret = server_secret, .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length), @@ -669,7 +671,7 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize { ciphertext_end += auth_tag.len; const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); const operand: V = pad ++ @bitCast([8]u8, big(c.write_seq)); - c.write_seq += 1; + c.write_seq += 1; // TODO send key_update on overflow const nonce = @as(V, p.client_iv) ^ operand; P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key); //std.debug.print("seq: {d} nonce: {} client_key: {} client_iv: {} ad: {} auth_tag: {}\nserver_key: {} server_iv: {}\n", .{ @@ -789,7 +791,8 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { }, }; - const inner_ct = @intToEnum(ContentType, buffer[out + cleartext_len - 1]); + const cleartext = buffer[out..][0..cleartext_len]; + const inner_ct = @intToEnum(ContentType, cleartext[cleartext.len - 1]); switch (inner_ct) { .alert => { const level = @intToEnum(tls.AlertLevel, buffer[out]); @@ -802,7 +805,56 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { return error.TlsAlert; }, .handshake => { - std.debug.print("the server wants to keep shaking hands\n", .{}); + var ct_i: usize = 0; + while (true) { + const handshake_type = cleartext[ct_i]; + ct_i += 1; + const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]); + ct_i += 3; + const next_handshake_i = ct_i + handshake_len; + if (next_handshake_i > cleartext.len - 1) + return error.TlsBadLength; + const 
handshake = cleartext[ct_i..next_handshake_i]; + switch (handshake_type) { + @enumToInt(HandshakeType.new_session_ticket) => { + std.debug.print("server sent a new session ticket\n", .{}); + }, + @enumToInt(HandshakeType.key_update) => { + switch (c.application_cipher) { + inline else => |*p| { + const P = @TypeOf(p.*); + const server_secret = hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length); + p.server_secret = server_secret; + p.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); + p.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); + }, + } + c.read_seq = 0; + + switch (@intToEnum(tls.KeyUpdateRequest, handshake[0])) { + .update_requested => { + switch (c.application_cipher) { + inline else => |*p| { + const P = @TypeOf(p.*); + const client_secret = hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length); + p.client_secret = client_secret; + p.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); + p.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); + }, + } + c.write_seq = 0; + }, + .update_not_requested => {}, + _ => return error.TlsIllegalParameter, + } + }, + else => { + return error.TlsUnexpectedMessage; + }, + } + ct_i = next_handshake_i; + if (ct_i >= cleartext.len - 1) break; + } }, .application_data => { out += cleartext_len - 1; From 477864dca560b03bb3c3e8e7e1e50362e7ed681f Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 27 Dec 2022 23:30:43 -0700 Subject: [PATCH 40/59] std.crypto.tls.Client: fix truncation attack vulnerability --- lib/std/crypto/tls/Client.zig | 9 ++++++--- lib/std/http/Client.zig | 10 ++++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 9ab9197dc8..260441295d 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -725,8 +725,11 @@ pub fn writeAll(c: *Client, stream: net.Stream, bytes: []const u8) !void { /// Returns number of bytes that have been read, which are now populated inside /// `buffer`. A return value of zero bytes does not necessarily mean end of -/// stream. +/// stream. Instead, the `eof` flag is set upon end of stream. The `eof` flag +/// may be set after any call to `read`, including when greater than zero bytes +/// are returned, and this function asserts that `eof` is `false`. pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { + assert(!c.eof); const prev_len = c.partially_read_len; var in_buf: [max_ciphertext_len * 4]u8 = undefined; mem.copy(u8, &in_buf, c.partially_read_buffer[0..prev_len]); @@ -738,8 +741,8 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { const actual_read_len = try stream.read(ask_slice); const frag = in_buf[0 .. prev_len + actual_read_len]; if (frag.len == 0) { - c.eof = true; - return 0; + // This is either a truncation attack, or a bug in the server. 
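            // A conforming TLS 1.3 peer must send a close_notify alert before
            // closing its write side of the transport (RFC 8446, Section 6.1);
            // treating a bare EOF as end-of-stream would let an attacker silently
            // truncate the application data, hence the hard error below.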
+ return error.TlsConnectionTruncated; } var in: usize = 0; var out: usize = 0; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 4e5bd3da0c..f1f61cae0c 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -66,13 +66,11 @@ pub const Request = struct { var index: usize = 0; while (index < len) { const amt = try req.read(buffer[index..]); - if (amt == 0) { - switch (req.protocol) { - .http => break, - .https => if (req.tls_client.eof) break, - } - } index += amt; + switch (req.protocol) { + .http => if (amt == 0) break, + .https => if (req.tls_client.eof) break, + } } return index; } From 21ab99174eabc9ae8efa2b19890d9cab51773b35 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 27 Dec 2022 23:49:15 -0700 Subject: [PATCH 41/59] std.crypto.tls.Client: use enums more --- lib/std/crypto/tls.zig | 3 +++ lib/std/crypto/tls/Client.zig | 41 +++++++++++++++++------------------ 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index acfa8558c1..fc2523f02a 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -74,6 +74,7 @@ pub const HandshakeType = enum(u8) { finished = 20, key_update = 24, message_hash = 254, + _, }; pub const ExtensionType = enum(u16) { @@ -121,6 +122,8 @@ pub const ExtensionType = enum(u16) { signature_algorithms_cert = 50, /// RFC 8446 key_share = 51, + + _, }; pub const AlertLevel = enum(u8) { diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 260441295d..fd22a503c1 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -9,7 +9,6 @@ const assert = std.debug.assert; const ApplicationCipher = tls.ApplicationCipher; const CipherSuite = tls.CipherSuite; const ContentType = tls.ContentType; -const HandshakeType = tls.HandshakeType; const HandshakeCipher = tls.HandshakeCipher; const max_ciphertext_len = tls.max_ciphertext_len; const hkdfExpandLabel = tls.hkdfExpandLabel; @@ -91,7 +90,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) extensions_header; const out_handshake = - [_]u8{@enumToInt(HandshakeType.client_hello)} ++ + [_]u8{@enumToInt(tls.HandshakeType.client_hello)} ++ int3(@intCast(u24, client_hello.len + host_len)) ++ client_hello; @@ -142,7 +141,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) return error.TlsAlert; }, .handshake => { - if (frag[0] != @enumToInt(HandshakeType.server_hello)) { + if (frag[0] != @enumToInt(tls.HandshakeType.server_hello)) { return error.TlsUnexpectedMessage; } const length = mem.readIntBig(u24, frag[1..4]); @@ -175,27 +174,27 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) var shared_key: [32]u8 = undefined; var have_shared_key = false; while (i < frag.len) { - const et = mem.readIntBig(u16, frag[i..][0..2]); + const et = @intToEnum(tls.ExtensionType, mem.readIntBig(u16, frag[i..][0..2])); i += 2; const ext_size = mem.readIntBig(u16, frag[i..][0..2]); i += 2; const next_i = i + ext_size; if (next_i > frag.len) return error.TlsBadLength; switch (et) { - @enumToInt(tls.ExtensionType.supported_versions) => { + .supported_versions => { if (supported_version != 0) return error.TlsIllegalParameter; supported_version = mem.readIntBig(u16, frag[i..][0..2]); }, - @enumToInt(tls.ExtensionType.key_share) => { + .key_share => { if (have_shared_key) return error.TlsIllegalParameter; have_shared_key = true; - const named_group = mem.readIntBig(u16, frag[i..][0..2]); + const 
named_group = @intToEnum(tls.NamedGroup, mem.readIntBig(u16, frag[i..][0..2])); i += 2; const key_size = mem.readIntBig(u16, frag[i..][0..2]); i += 2; switch (named_group) { - @enumToInt(tls.NamedGroup.x25519) => { + .x25519 => { if (key_size != 32) return error.TlsBadLength; const server_pub_key = frag[i..][0..32]; @@ -204,7 +203,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) server_pub_key.*, ) catch return error.TlsDecryptFailure; }, - @enumToInt(tls.NamedGroup.secp256r1) => { + .secp256r1 => { const server_pub_key = frag[i..][0..key_size]; const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; @@ -217,7 +216,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) shared_key = mul.affineCoordinates().x.toBytes(.Big); }, else => { - std.debug.print("named group: {x}\n", .{named_group}); + //std.debug.print("named group: {x}\n", .{named_group}); return error.TlsIllegalParameter; }, } @@ -380,7 +379,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) .handshake => { var ct_i: usize = 0; while (true) { - const handshake_type = cleartext[ct_i]; + const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]); ct_i += 1; const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]); ct_i += 3; @@ -390,7 +389,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) const wrapped_handshake = cleartext[ct_i - 4 .. next_handshake_i]; const handshake = cleartext[ct_i..next_handshake_i]; switch (handshake_type) { - @enumToInt(HandshakeType.encrypted_extensions) => { + .encrypted_extensions => { if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; handshake_state = .certificate; switch (handshake_cipher) { @@ -400,13 +399,13 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) var hs_i: usize = 2; const end_ext_i = 2 + total_ext_size; while (hs_i < end_ext_i) { - const et = mem.readIntBig(u16, handshake[hs_i..][0..2]); + const et = @intToEnum(tls.ExtensionType, mem.readIntBig(u16, handshake[hs_i..][0..2])); hs_i += 2; const ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]); hs_i += 2; const next_ext_i = hs_i + ext_size; switch (et) { - @enumToInt(tls.ExtensionType.server_name) => {}, + .server_name => {}, else => { std.debug.print("encrypted extension: {any}\n", .{ et, @@ -416,7 +415,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) hs_i = next_ext_i; } }, - @enumToInt(HandshakeType.certificate) => cert: { + .certificate => cert: { switch (handshake_cipher) { inline else => |*p| p.transcript_hash.update(wrapped_handshake), } @@ -488,7 +487,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) hs_i += total_ext_size; } }, - @enumToInt(HandshakeType.certificate_verify) => { + .certificate_verify => { switch (handshake_state) { .trust_chain_established => handshake_state = .finished, .certificate => return error.TlsCertificateNotVerified, @@ -535,7 +534,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) }, } }, - @enumToInt(HandshakeType.finished) => { + .finished => { if (handshake_state != .finished) return error.TlsUnexpectedMessage; // This message is to trick buggy proxies into behaving correctly. 
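                        // Specifically, this is the "middlebox compatibility mode"
                        // change_cipher_spec record described in RFC 8446, Appendix D.4;
                        // a TLS 1.3 peer is required to drop it without further processing.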
const client_change_cipher_spec_msg = [_]u8{ @@ -555,7 +554,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) const handshake_hash = p.transcript_hash.finalResult(); const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key); const out_cleartext = [_]u8{ - @enumToInt(HandshakeType.finished), + @enumToInt(tls.HandshakeType.finished), 0, 0, verify_data.len, // length } ++ verify_data ++ [1]u8{@enumToInt(ContentType.handshake)}; @@ -810,7 +809,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { .handshake => { var ct_i: usize = 0; while (true) { - const handshake_type = cleartext[ct_i]; + const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]); ct_i += 1; const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]); ct_i += 3; @@ -819,10 +818,10 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { return error.TlsBadLength; const handshake = cleartext[ct_i..next_handshake_i]; switch (handshake_type) { - @enumToInt(HandshakeType.new_session_ticket) => { + .new_session_ticket => { std.debug.print("server sent a new session ticket\n", .{}); }, - @enumToInt(HandshakeType.key_update) => { + .key_update => { switch (c.application_cipher) { inline else => |*p| { const P = @TypeOf(p.*); From 940d368e7ea95d2bb8185e71af3d1ec0328917dc Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 28 Dec 2022 16:37:22 -0700 Subject: [PATCH 42/59] std.crypto.tls.Client: fix the read function The read function has been renamed to readAdvanced since it has slightly different semantics than typical read functions, specifically regarding the end-of-file. A higher level read function is implemented on top. Now, API users may pass small buffers to the read function and everything will work fine. This is done by re-decrypting the same ciphertext record with each call to read() until the record is finished being transmitted. If the buffer supplied to read() is large enough, then any given ciphertext record will only be decrypted once, since it decrypts directly to the read() buffer and therefore does not need any memcpy. On the other hand, if the buffer supplied to read() is small, then the ciphertext is decrypted into a stack buffer, a subset is copied to the read() buffer, and then the entire ciphertext record is saved for the next call to read(). --- lib/std/crypto/tls/Client.zig | 163 +++++++++++++++++++++++++++------- lib/std/http/Client.zig | 12 +-- lib/std/net.zig | 7 +- 3 files changed, 136 insertions(+), 46 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index fd22a503c1..8d37e82117 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -18,14 +18,20 @@ const array = tls.array; const enum_array = tls.enum_array; const Certificate = crypto.Certificate; -application_cipher: ApplicationCipher, read_seq: u64, write_seq: u64, -/// The size is enough to contain exactly one TLSCiphertext record. -partially_read_buffer: [tls.max_ciphertext_record_len]u8, /// The number of partially read bytes inside `partially_read_buffer`. partially_read_len: u15, +/// The number of cleartext bytes from decoding `partially_read_buffer` which +/// have already been transferred via read() calls. This implementation will +/// re-decrypt bytes from `partially_read_buffer` when the buffer supplied by +/// the read() API user is not large enough. 
+partial_cleartext_index: u15, +application_cipher: ApplicationCipher, eof: bool, +/// The size is enough to contain exactly one TLSCiphertext record. +/// Contains encrypted bytes. +partially_read_buffer: [tls.max_ciphertext_record_len]u8, /// `host` is only borrowed during this function call. pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) !Client { @@ -596,6 +602,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) .application_cipher = app_cipher, .read_seq = 0, .write_seq = 0, + .partial_cleartext_index = 0, .partially_read_buffer = undefined, .partially_read_len = @intCast(u15, len - end), .eof = false, @@ -722,27 +729,85 @@ pub fn writeAll(c: *Client, stream: net.Stream, bytes: []const u8) !void { } } -/// Returns number of bytes that have been read, which are now populated inside -/// `buffer`. A return value of zero bytes does not necessarily mean end of -/// stream. Instead, the `eof` flag is set upon end of stream. The `eof` flag -/// may be set after any call to `read`, including when greater than zero bytes -/// are returned, and this function asserts that `eof` is `false`. -pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { +/// Returns the number of bytes read, calling the underlying read function the +/// minimal number of times until the buffer has at least `len` bytes filled. +/// If the number read is less than `len` it means the stream reached the end. +/// Reaching the end of the stream is not an error condition. +pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize { + assert(len <= buffer.len); + if (c.eof) return 0; + var index: usize = 0; + while (index < len) { + index += try c.readAdvanced(stream, buffer[index..]); + if (c.eof) break; + } + return index; +} + +pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize { + return readAtLeast(c, stream, buffer, 1); +} + +/// Returns the number of bytes read. If the number read is smaller than +/// `buffer.len`, it means the stream reached the end. Reaching the end of the +/// stream is not an error condition. +pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize { + return readAtLeast(c, stream, buffer, buffer.len); +} + +/// Returns number of bytes that have been read, populated inside `buffer`. A +/// return value of zero bytes does not mean end of stream. Instead, the `eof` +/// flag is set upon end of stream. The `eof` flag may be set after any call to +/// `read`, including when greater than zero bytes are returned, and this +/// function asserts that `eof` is `false`. +/// See `read` for a higher level function that has the same, familiar API +/// as other read functions, such as `std.fs.File.read`. +/// It is recommended to use a buffer size with length at least +/// `tls.max_ciphertext_len` bytes to avoid redundantly decrypting the same +/// encoded data. +pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize { assert(!c.eof); const prev_len = c.partially_read_len; - var in_buf: [max_ciphertext_len * 4]u8 = undefined; - mem.copy(u8, &in_buf, c.partially_read_buffer[0..prev_len]); + // Ideally, this buffer would never be used. It is needed when `buffer` is too small + // to fit the cleartext, which may be as large as `max_ciphertext_len`. 
+ var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined; + // This buffer is typically used, except, as an optimization when a very large + // `buffer` is provided, we use half of it for buffering ciphertext and the + // other half for outputting cleartext. + var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined; + const half_buffer_len = buffer.len / 2; + const out_in: struct { []u8, []u8 } = if (half_buffer_len >= in_stack_buffer.len) .{ + buffer[0..half_buffer_len], + buffer[half_buffer_len..], + } else .{ + buffer, + &in_stack_buffer, + }; + const out_buf = out_in[0]; + const in_buf = out_in[1]; + mem.copy(u8, in_buf, c.partially_read_buffer[0..prev_len]); // Capacity of output buffer, in records, rounded up. - const buf_cap = (buffer.len +| (max_ciphertext_len - 1)) / max_ciphertext_len; + const buf_cap = (out_buf.len +| (max_ciphertext_len - 1)) / max_ciphertext_len; const wanted_read_len = buf_cap * (max_ciphertext_len + tls.ciphertext_record_header_len); - const ask_slice = in_buf[prev_len..@min(wanted_read_len, in_buf.len)]; - const actual_read_len = try stream.read(ask_slice); - const frag = in_buf[0 .. prev_len + actual_read_len]; - if (frag.len == 0) { - // This is either a truncation attack, or a bug in the server. - return error.TlsConnectionTruncated; - } + const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len); + const ask_slice = in_buf[prev_len..][0..@min(ask_len, in_buf.len - prev_len)]; + assert(ask_slice.len > 0); + const frag = frag: { + if (prev_len >= 5) { + const record_size = mem.readIntBig(u16, in_buf[3..][0..2]); + if (prev_len >= 5 + record_size) { + // We can use our buffered data without calling read(). + break :frag in_buf[0..prev_len]; + } + } + const actual_read_len = try stream.read(ask_slice); + if (actual_read_len == 0) { + // This is either a truncation attack, or a bug in the server. + return error.TlsConnectionTruncated; + } + break :frag in_buf[0 .. prev_len + actual_read_len]; + }; var in: usize = 0; var out: usize = 0; @@ -750,6 +815,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { if (in + tls.ciphertext_record_header_len > frag.len) { return finishRead(c, frag, in, out); } + const record_start = in; const ct = @intToEnum(ContentType, frag[in]); in += 1; const legacy_version = mem.readIntBig(u16, frag[in..][0..2]); @@ -767,7 +833,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { @panic("TODO handle an alert here"); }, .application_data => { - const cleartext_len = switch (c.application_cipher) { + const cleartext = switch (c.application_cipher) { inline else => |*p| c: { const P = @TypeOf(p.*); const V = @Vector(P.AEAD.nonce_length, u8); @@ -776,29 +842,29 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { const ciphertext = frag[in..][0..ciphertext_len]; in += ciphertext_len; const auth_tag = frag[in..][0..P.AEAD.tag_length].*; - const cleartext = buffer[out..][0..ciphertext_len]; const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + // Here we use read_seq and then intentionally don't + // increment it until later when it is certain the same + // ciphertext does not need to be decrypted again. 
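                        // Per RFC 8446, Section 5.3, the per-record nonce is the 64-bit
                        // record sequence number, big-endian and left-padded with zeroes
                        // to the AEAD nonce length, then XORed with the static IV derived
                        // from the traffic secret; `operand` and `nonce` below compute
                        // exactly that for the server's write direction.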
const operand: V = pad ++ @bitCast([8]u8, big(c.read_seq)); - c.read_seq += 1; const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand; - //std.debug.print("seq: {d} nonce: {} server_key: {} server_iv: {}\n", .{ - // c.read_seq - 1, - // std.fmt.fmtSliceHexLower(&nonce), - // std.fmt.fmtSliceHexLower(&p.server_key), - // std.fmt.fmtSliceHexLower(&p.server_iv), - //}); + const cleartext_buf = if (c.partial_cleartext_index == 0 and out + ciphertext.len <= out_buf.len) + out_buf[out..] + else + &cleartext_stack_buffer; + const cleartext = cleartext_buf[0..ciphertext.len]; P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch return error.TlsBadRecordMac; - break :c cleartext.len; + break :c cleartext; }, }; - const cleartext = buffer[out..][0..cleartext_len]; const inner_ct = @intToEnum(ContentType, cleartext[cleartext.len - 1]); switch (inner_ct) { .alert => { - const level = @intToEnum(tls.AlertLevel, buffer[out]); - const desc = @intToEnum(tls.AlertDescription, buffer[out + 1]); + c.read_seq += 1; + const level = @intToEnum(tls.AlertLevel, out_buf[out]); + const desc = @intToEnum(tls.AlertDescription, out_buf[out + 1]); if (desc == .close_notify) { c.eof = true; return out; @@ -807,6 +873,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { return error.TlsAlert; }, .handshake => { + c.read_seq += 1; var ct_i: usize = 0; while (true) { const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]); @@ -819,7 +886,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { const handshake = cleartext[ct_i..next_handshake_i]; switch (handshake_type) { .new_session_ticket => { - std.debug.print("server sent a new session ticket\n", .{}); + // This client implementation ignores new session tickets. }, .key_update => { switch (c.application_cipher) { @@ -859,7 +926,35 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize { } }, .application_data => { - out += cleartext_len - 1; + // Determine whether the output buffer or a stack + // buffer was used for storing the cleartext. + if (c.partial_cleartext_index == 0 and + out + cleartext.len <= out_buf.len) + { + // Output buffer was used directly which means no + // memory copying needs to occur, and we can move + // on to the next ciphertext record. + out += cleartext.len - 1; + c.read_seq += 1; + } else { + // Stack buffer was used, so we must copy to the output buffer. 
+ const dest = out_buf[out..]; + const rest = cleartext[c.partial_cleartext_index..]; + const src = rest[0..@min(rest.len, dest.len)]; + mem.copy(u8, dest, src); + out += src.len; + c.partial_cleartext_index = @intCast( + @TypeOf(c.partial_cleartext_index), + c.partial_cleartext_index + src.len, + ); + if (c.partial_cleartext_index >= cleartext.len) { + c.partial_cleartext_index = 0; + c.read_seq += 1; + } else { + in = record_start; + return finishRead(c, frag, in, out); + } + } }, else => { std.debug.print("inner content type: {d}\n", .{inner_ct}); diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index f1f61cae0c..d27d879663 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -63,16 +63,10 @@ pub const Request = struct { } pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize { - var index: usize = 0; - while (index < len) { - const amt = try req.read(buffer[index..]); - index += amt; - switch (req.protocol) { - .http => if (amt == 0) break, - .https => if (req.tls_client.eof) break, - } + switch (req.protocol) { + .http => return req.stream.readAtLeast(buffer, len), + .https => return req.tls_client.readAtLeast(req.stream, buffer, len), } - return index; } }; diff --git a/lib/std/net.zig b/lib/std/net.zig index a265fa69a9..0112d5be8c 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -1680,11 +1680,12 @@ pub const Stream = struct { } /// Returns the number of bytes read, calling the underlying read function - /// the minimal number of times until at least the buffer has at least - /// `len` bytes filled. If the number read is less than `len` it means the - /// stream reached the end. Reaching the end of the stream is not an error + /// the minimal number of times until the buffer has at least `len` bytes + /// filled. If the number read is less than `len` it means the stream + /// reached the end. Reaching the end of the stream is not an error /// condition. 
pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize { + assert(len <= buffer.len); var index: usize = 0; while (index < len) { const amt = try s.read(buffer[index..]); From 16af6286c8055a870128aa1f7c785273b23cad55 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 28 Dec 2022 17:11:10 -0700 Subject: [PATCH 43/59] std.crypto.tls.Client: support SignatureScheme.ecdsa_secp384r1_sha384 --- lib/std/crypto/tls/Client.zig | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 8d37e82117..d4a5b55023 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -500,7 +500,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) else => return error.TlsUnexpectedMessage, } - const algorithm = @intToEnum(tls.SignatureScheme, mem.readIntBig(u16, handshake[0..2])); + const scheme = @intToEnum(tls.SignatureScheme, mem.readIntBig(u16, handshake[0..2])); const sig_len = mem.readIntBig(u16, handshake[2..4]); if (4 + sig_len > handshake.len) return error.TlsBadLength; const encoded_sig = handshake[4..][0..sig_len]; @@ -520,23 +520,25 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) }; const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len]; - switch (algorithm) { - .ecdsa_secp256r1_sha256 => { + switch (scheme) { + inline .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + => |comptime_scheme| { if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey) - return error.TlsBadSignatureAlgorithm; - const P256 = std.crypto.sign.ecdsa.EcdsaP256Sha256; - const sig = try P256.Signature.fromDer(encoded_sig); - const key = try P256.PublicKey.fromSec1(main_cert_pub_key); + return error.TlsBadSignatureScheme; + const Ecdsa = SchemeEcdsa(comptime_scheme); + const sig = try Ecdsa.Signature.fromDer(encoded_sig); + const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key); try sig.verify(verify_bytes, key); }, .rsa_pss_rsae_sha256 => { - @panic("TODO signature algorithm: rsa_pss_rsae_sha256"); + @panic("TODO signature scheme: rsa_pss_rsae_sha256"); }, else => { - //std.debug.print("signature algorithm: {any}\n", .{ - // algorithm, + //std.debug.print("signature scheme: {any}\n", .{ + // scheme, //}); - return error.TlsBadSignatureAlgorithm; + return error.TlsBadSignatureScheme; }, } }, @@ -1008,6 +1010,15 @@ inline fn big(x: anytype) @TypeOf(x) { }; } +fn SchemeEcdsa(comptime scheme: tls.SignatureScheme) type { + return switch (scheme) { + .ecdsa_secp256r1_sha256 => crypto.sign.ecdsa.EcdsaP256Sha256, + .ecdsa_secp384r1_sha384 => crypto.sign.ecdsa.EcdsaP384Sha384, + .ecdsa_secp521r1_sha512 => crypto.sign.ecdsa.EcdsaP512Sha512, + else => @compileError("bad scheme"), + }; +} + /// The priority order here is chosen based on what crypto algorithms Zig has /// available in the standard library as well as what is faster. Following are /// a few data points on the relative performance of these algorithms. 
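The `inline` prongs in the scheme switch above are what allow a runtime `tls.SignatureScheme` value to select a concrete signer type at compile time, so `SchemeEcdsa` is only ever instantiated with the schemes actually listed. A condensed sketch of that pattern, using hypothetical parameter names and only the calls already shown in this patch:

    fn verifyEcdsaSig(
        scheme: tls.SignatureScheme,
        verify_bytes: []const u8,
        encoded_sig: []const u8,
        sec1_pub_key: []const u8,
    ) !void {
        switch (scheme) {
            inline .ecdsa_secp256r1_sha256, .ecdsa_secp384r1_sha384 => |comptime_scheme| {
                // The prong capture is comptime-known, so it can pick a type.
                const Ecdsa = SchemeEcdsa(comptime_scheme);
                const sig = try Ecdsa.Signature.fromDer(encoded_sig);
                const key = try Ecdsa.PublicKey.fromSec1(sec1_pub_key);
                try sig.verify(verify_bytes, key);
            },
            else => return error.TlsBadSignatureScheme,
        }
    }
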
From 1d20ada3665102ad09fec0ff486b22f4f0a56141 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 28 Dec 2022 19:54:17 -0700 Subject: [PATCH 44/59] std.crypto.tls.Client: refactor to reduce namespace bloat --- lib/std/crypto/tls/Client.zig | 40 ++++++++++++++++------------------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index d4a5b55023..4fa9991bd4 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -5,18 +5,14 @@ const net = std.net; const mem = std.mem; const crypto = std.crypto; const assert = std.debug.assert; +const Certificate = std.crypto.Certificate; -const ApplicationCipher = tls.ApplicationCipher; -const CipherSuite = tls.CipherSuite; -const ContentType = tls.ContentType; -const HandshakeCipher = tls.HandshakeCipher; const max_ciphertext_len = tls.max_ciphertext_len; const hkdfExpandLabel = tls.hkdfExpandLabel; const int2 = tls.int2; const int3 = tls.int3; const array = tls.array; const enum_array = tls.enum_array; -const Certificate = crypto.Certificate; read_seq: u64, write_seq: u64, @@ -27,7 +23,7 @@ partially_read_len: u15, /// re-decrypt bytes from `partially_read_buffer` when the buffer supplied by /// the read() API user is not large enough. partial_cleartext_index: u15, -application_cipher: ApplicationCipher, +application_cipher: tls.ApplicationCipher, eof: bool, /// The size is enough to contain exactly one TLSCiphertext record. /// Contains encrypted bytes. @@ -101,7 +97,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) client_hello; const plaintext_header = [_]u8{ - @enumToInt(ContentType.handshake), + @enumToInt(tls.ContentType.handshake), 0x03, 0x01, // legacy_record_version } ++ int2(@intCast(u16, out_handshake.len + host_len)) ++ out_handshake; @@ -121,7 +117,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) const client_hello_bytes1 = plaintext_header[5..]; - var handshake_cipher: HandshakeCipher = undefined; + var handshake_cipher: tls.HandshakeCipher = undefined; var handshake_buf: [8000]u8 = undefined; var len: usize = 0; @@ -129,7 +125,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) const plaintext = handshake_buf[0..5]; len = try stream.readAtLeast(&handshake_buf, plaintext.len); if (len < plaintext.len) return error.EndOfStream; - const ct = @intToEnum(ContentType, plaintext[0]); + const ct = @intToEnum(tls.ContentType, plaintext[0]); const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]); const end = plaintext.len + frag_len; if (end > handshake_buf.len) return error.TlsRecordOverflow; @@ -169,7 +165,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) i += 32; const cipher_suite_int = mem.readIntBig(u16, frag[i..][0..2]); i += 2; - const cipher_suite_tag = @intToEnum(CipherSuite, cipher_suite_int); + const cipher_suite_tag = @intToEnum(tls.CipherSuite, cipher_suite_int); const legacy_compression_method = frag[i]; i += 1; _ = legacy_compression_method; @@ -247,8 +243,8 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) .AEGIS_256_SHA384, .AEGIS_128L_SHA256, => |tag| { - const P = std.meta.TagPayloadByName(HandshakeCipher, @tagName(tag)); - handshake_cipher = @unionInit(HandshakeCipher, @tagName(tag), .{ + const P = std.meta.TagPayloadByName(tls.HandshakeCipher, @tagName(tag)); + handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag), .{ .handshake_secret 
= undefined, .master_secret = undefined, .client_handshake_key = undefined, @@ -338,7 +334,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len); if (end_hdr > len) return error.EndOfStream; } - const ct = @intToEnum(ContentType, handshake_buf[i]); + const ct = @intToEnum(tls.ContentType, handshake_buf[i]); i += 1; const legacy_version = mem.readIntBig(u16, handshake_buf[i..][0..2]); i += 2; @@ -380,7 +376,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) }, }; - const inner_ct = @intToEnum(ContentType, cleartext[cleartext.len - 1]); + const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]); switch (inner_ct) { .handshake => { var ct_i: usize = 0; @@ -546,7 +542,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) if (handshake_state != .finished) return error.TlsUnexpectedMessage; // This message is to trick buggy proxies into behaving correctly. const client_change_cipher_spec_msg = [_]u8{ - @enumToInt(ContentType.change_cipher_spec), + @enumToInt(tls.ContentType.change_cipher_spec), 0x03, 0x03, // legacy protocol version 0x00, 0x01, // length 0x01, @@ -564,12 +560,12 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) const out_cleartext = [_]u8{ @enumToInt(tls.HandshakeType.finished), 0, 0, verify_data.len, // length - } ++ verify_data ++ [1]u8{@enumToInt(ContentType.handshake)}; + } ++ verify_data ++ [1]u8{@enumToInt(tls.ContentType.handshake)}; const wrapped_len = out_cleartext.len + P.AEAD.tag_length; var finished_msg = [_]u8{ - @enumToInt(ContentType.application_data), + @enumToInt(tls.ContentType.application_data), 0x03, 0x03, // legacy protocol version 0, wrapped_len, // byte length of encrypted record } ++ @as([wrapped_len]u8, undefined); @@ -590,7 +586,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) // std.fmt.fmtSliceHexLower(&client_secret), // std.fmt.fmtSliceHexLower(&server_secret), //}); - break :c @unionInit(ApplicationCipher, @tagName(tag), .{ + break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{ .client_secret = client_secret, .server_secret = server_secret, .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), @@ -661,7 +657,7 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize { if (encrypted_content_len == 0) break :l overhead_len; mem.copy(u8, &cleartext_buf, bytes[bytes_i..][0..encrypted_content_len]); - cleartext_buf[encrypted_content_len] = @enumToInt(ContentType.application_data); + cleartext_buf[encrypted_content_len] = @enumToInt(tls.ContentType.application_data); bytes_i += encrypted_content_len; const ciphertext_len = encrypted_content_len + 1; const cleartext = cleartext_buf[0..ciphertext_len]; @@ -669,7 +665,7 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize { const record_start = ciphertext_end; const ad = ciphertext_buf[ciphertext_end..][0..5]; ad.* = - [_]u8{@enumToInt(ContentType.application_data)} ++ + [_]u8{@enumToInt(tls.ContentType.application_data)} ++ int2(@enumToInt(tls.ProtocolVersion.tls_1_2)) ++ int2(ciphertext_len + P.AEAD.tag_length); ciphertext_end += ad.len; @@ -818,7 +814,7 @@ pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize { return finishRead(c, frag, in, out); } const record_start = in; - const ct = @intToEnum(ContentType, frag[in]); + const ct = @intToEnum(tls.ContentType, 
frag[in]); in += 1; const legacy_version = mem.readIntBig(u16, frag[in..][0..2]); in += 2; @@ -861,7 +857,7 @@ pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize { }, }; - const inner_ct = @intToEnum(ContentType, cleartext[cleartext.len - 1]); + const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]); switch (inner_ct) { .alert => { c.read_seq += 1; From 7391df2be5143db8308ab7c5281842aea99cb1d7 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 29 Dec 2022 15:43:22 -0700 Subject: [PATCH 45/59] std.crypto: make proper use of `undefined` --- lib/std/crypto/aegis.zig | 4 ++-- lib/std/crypto/aes_gcm.zig | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/std/crypto/aegis.zig b/lib/std/crypto/aegis.zig index 01dd5d547b..da09aca351 100644 --- a/lib/std/crypto/aegis.zig +++ b/lib/std/crypto/aegis.zig @@ -174,7 +174,7 @@ pub const Aegis128L = struct { acc |= (computed_tag[j] ^ tag[j]); } if (acc != 0) { - mem.set(u8, m, 0xaa); + @memset(m.ptr, undefined, m.len); return error.AuthenticationFailed; } } @@ -343,7 +343,7 @@ pub const Aegis256 = struct { acc |= (computed_tag[j] ^ tag[j]); } if (acc != 0) { - mem.set(u8, m, 0xaa); + @memset(m.ptr, undefined, m.len); return error.AuthenticationFailed; } } diff --git a/lib/std/crypto/aes_gcm.zig b/lib/std/crypto/aes_gcm.zig index 30fd37e6a0..6eadcdee2f 100644 --- a/lib/std/crypto/aes_gcm.zig +++ b/lib/std/crypto/aes_gcm.zig @@ -91,7 +91,7 @@ fn AesGcm(comptime Aes: anytype) type { acc |= (computed_tag[p] ^ tag[p]); } if (acc != 0) { - mem.set(u8, m, 0xaa); + @memset(m.ptr, undefined, m.len); return error.AuthenticationFailed; } From e4a9b19a1490d5c41a4d8c10f47ba5639de48404 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 29 Dec 2022 15:45:51 -0700 Subject: [PATCH 46/59] std.crypto.tls.Client: rework the read function Here's what I landed on for the TLS client. It's 16896 bytes (max_ciphertext_record_len is 16640). I believe this is the theoretical minimum size, give or take a few bytes. These constraints are satisfied: * a call to the readvAdvanced() function makes at most one call to the underlying readv function * iovecs are provided by the API, and used by the implementation for underlying readv() calls to the socket * the theoretical minimum number of memcpy() calls are issued in all circumstances * decryption is only performed once for any given TLS record * large read buffers are fully exploited This is accomplished by using the partial read buffer to storing both cleartext and ciphertext. --- lib/std/crypto/tls/Client.zig | 455 +++++++++++++++++++++++++--------- lib/std/net.zig | 11 + 2 files changed, 347 insertions(+), 119 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 4fa9991bd4..0e23101ee3 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -16,17 +16,25 @@ const enum_array = tls.enum_array; read_seq: u64, write_seq: u64, -/// The number of partially read bytes inside `partially_read_buffer`. -partially_read_len: u15, -/// The number of cleartext bytes from decoding `partially_read_buffer` which -/// have already been transferred via read() calls. This implementation will -/// re-decrypt bytes from `partially_read_buffer` when the buffer supplied by -/// the read() API user is not large enough. -partial_cleartext_index: u15, +/// The starting index of cleartext bytes inside `partially_read_buffer`. 
+partial_cleartext_idx: u15, +/// The ending index of cleartext bytes inside `partially_read_buffer` as well +/// as the starting index of ciphertext bytes. +partial_ciphertext_idx: u15, +/// The ending index of ciphertext bytes inside `partially_read_buffer`. +partial_ciphertext_end: u15, +/// When this is true, the stream may still not be at the end because there +/// may be data in `partially_read_buffer`. +received_close_notify: bool, application_cipher: tls.ApplicationCipher, -eof: bool, /// The size is enough to contain exactly one TLSCiphertext record. -/// Contains encrypted bytes. +/// This buffer is segmented into four parts: +/// 0. unused +/// 1. cleartext +/// 2. ciphertext +/// 3. unused +/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and +/// `partial_ciphertext_end` describe the span of the segments. partially_read_buffer: [tls.max_ciphertext_record_len]u8, /// `host` is only borrowed during this function call. @@ -597,13 +605,14 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) }, }; var client: Client = .{ - .application_cipher = app_cipher, .read_seq = 0, .write_seq = 0, - .partial_cleartext_index = 0, + .partial_cleartext_idx = 0, + .partial_ciphertext_idx = 0, + .partial_ciphertext_end = @intCast(u15, len - end), + .received_close_notify = false, + .application_cipher = app_cipher, .partially_read_buffer = undefined, - .partially_read_len = @intCast(u15, len - end), - .eof = false, }; mem.copy(u8, &client.partially_read_buffer, handshake_buf[len..end]); return client; @@ -727,19 +736,17 @@ pub fn writeAll(c: *Client, stream: net.Stream, bytes: []const u8) !void { } } +pub fn eof(c: Client) bool { + return c.received_close_notify and c.partial_ciphertext_end == 0; +} + /// Returns the number of bytes read, calling the underlying read function the /// minimal number of times until the buffer has at least `len` bytes filled. /// If the number read is less than `len` it means the stream reached the end. /// Reaching the end of the stream is not an error condition. pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize { - assert(len <= buffer.len); - if (c.eof) return 0; - var index: usize = 0; - while (index < len) { - index += try c.readAdvanced(stream, buffer[index..]); - if (c.eof) break; - } - return index; + var iovecs = [1]std.os.iovec{.{ .iov_base = buffer.ptr, .iov_len = buffer.len }}; + return readvAtLeast(c, stream, &iovecs, len); } pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize { @@ -753,78 +760,180 @@ pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize { return readAtLeast(c, stream, buffer, buffer.len); } -/// Returns number of bytes that have been read, populated inside `buffer`. A -/// return value of zero bytes does not mean end of stream. Instead, the `eof` -/// flag is set upon end of stream. The `eof` flag may be set after any call to -/// `read`, including when greater than zero bytes are returned, and this -/// function asserts that `eof` is `false`. -/// See `read` for a higher level function that has the same, familiar API -/// as other read functions, such as `std.fs.File.read`. -/// It is recommended to use a buffer size with length at least -/// `tls.max_ciphertext_len` bytes to avoid redundantly decrypting the same -/// encoded data. -pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize { - assert(!c.eof); - const prev_len = c.partially_read_len; - // Ideally, this buffer would never be used. 
It is needed when `buffer` is too small - // to fit the cleartext, which may be as large as `max_ciphertext_len`. - var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined; - // This buffer is typically used, except, as an optimization when a very large - // `buffer` is provided, we use half of it for buffering ciphertext and the - // other half for outputting cleartext. - var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined; - const half_buffer_len = buffer.len / 2; - const out_in: struct { []u8, []u8 } = if (half_buffer_len >= in_stack_buffer.len) .{ - buffer[0..half_buffer_len], - buffer[half_buffer_len..], - } else .{ - buffer, - &in_stack_buffer, - }; - const out_buf = out_in[0]; - const in_buf = out_in[1]; - mem.copy(u8, in_buf, c.partially_read_buffer[0..prev_len]); +/// Returns the number of bytes read. If the number read is less than the space +/// provided it means the stream reached the end. Reaching the end of the +/// stream is not an error condition. +/// The `iovecs` parameter is mutable because this function needs to mutate the fields in +/// order to handle partial reads from the underlying stream layer. +pub fn readv(c: *Client, stream: anytype, iovecs: []std.os.iovec) !usize { + return readvAtLeast(c, stream, iovecs); +} - // Capacity of output buffer, in records, rounded up. - const buf_cap = (out_buf.len +| (max_ciphertext_len - 1)) / max_ciphertext_len; +/// Returns the number of bytes read, calling the underlying read function the +/// minimal number of times until the iovecs have at least `len` bytes filled. +/// If the number read is less than `len` it means the stream reached the end. +/// Reaching the end of the stream is not an error condition. +/// The `iovecs` parameter is mutable because this function needs to mutate the fields in +/// order to handle partial reads from the underlying stream layer. +pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.os.iovec, len: usize) !usize { + if (c.eof()) return 0; + + var off_i: usize = 0; + var vec_i: usize = 0; + while (true) { + var amt = try c.readvAdvanced(stream, iovecs[vec_i..]); + off_i += amt; + if (c.eof() or off_i >= len) return off_i; + while (amt >= iovecs[vec_i].iov_len) { + amt -= iovecs[vec_i].iov_len; + vec_i += 1; + } + iovecs[vec_i].iov_base += amt; + iovecs[vec_i].iov_len -= amt; + } +} + +/// Returns number of bytes that have been read, populated inside `iovecs`. A +/// return value of zero bytes does not mean end of stream. Instead, check the `eof()` +/// for the end of stream. The `eof()` may be true after any call to +/// `read`, including when greater than zero bytes are returned, and this +/// function asserts that `eof()` is `false`. +/// See `readv` for a higher level function that has the same, familiar API as +/// other read functions, such as `std.fs.File.read`. +pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iovec) !usize { + var vp: VecPut = .{ .iovecs = iovecs }; + + // Give away the buffered cleartext we have, if any. + const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx]; + if (partial_cleartext.len > 0) { + const amt = @intCast(u15, vp.put(partial_cleartext)); + c.partial_cleartext_idx += amt; + if (amt < partial_cleartext.len) { + // We still have cleartext left so we cannot issue another read() call yet. 
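            // Index bookkeeping, as documented on `partially_read_buffer`: bytes
            // before `partial_cleartext_idx` are already consumed or unused, bytes
            // up to `partial_ciphertext_idx` are buffered cleartext, bytes up to
            // `partial_ciphertext_end` are buffered ciphertext, and the rest of the
            // buffer is unused.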
+ assert(vp.total == amt); + return amt; + } + if (c.received_close_notify) { + c.partial_ciphertext_end = 0; + assert(vp.total == amt); + return amt; + } + if (c.partial_ciphertext_end == c.partial_ciphertext_idx) { + c.partial_cleartext_idx = 0; + c.partial_ciphertext_idx = 0; + c.partial_ciphertext_end = 0; + } + } + + assert(!c.received_close_notify); + + // Ideally, this buffer would never be used. It is needed when `iovecs` are + // too small to fit the cleartext, which may be as large as `max_ciphertext_len`. + var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined; + // Temporarily stores ciphertext before decrypting it and giving it to `iovecs`. + var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined; + // How many bytes left in the user's buffer. + const free_size = vp.freeSize(); + // The amount of the user's buffer that we need to repurpose for storing + // ciphertext. The end of the buffer will be used for such purposes. + const ciphertext_buf_len = (free_size / 2) -| in_stack_buffer.len; + // The amount of the user's buffer that will be used to give cleartext. The + // beginning of the buffer will be used for such purposes. + const cleartext_buf_len = free_size - ciphertext_buf_len; + const first_iov = c.partially_read_buffer[c.partial_ciphertext_end..]; + + var ask_iovecs_buf: [2]std.os.iovec = .{ + .{ + .iov_base = first_iov.ptr, + .iov_len = first_iov.len, + }, + .{ + .iov_base = &in_stack_buffer, + .iov_len = in_stack_buffer.len, + }, + }; + + // Cleartext capacity of output buffer, in records, rounded up. + const buf_cap = (cleartext_buf_len +| (max_ciphertext_len - 1)) / max_ciphertext_len; const wanted_read_len = buf_cap * (max_ciphertext_len + tls.ciphertext_record_header_len); const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len); - const ask_slice = in_buf[prev_len..][0..@min(ask_len, in_buf.len - prev_len)]; - assert(ask_slice.len > 0); - const frag = frag: { - if (prev_len >= 5) { - const record_size = mem.readIntBig(u16, in_buf[3..][0..2]); - if (prev_len >= 5 + record_size) { - // We can use our buffered data without calling read(). - break :frag in_buf[0..prev_len]; - } - } - const actual_read_len = try stream.read(ask_slice); - if (actual_read_len == 0) { - // This is either a truncation attack, or a bug in the server. - return error.TlsConnectionTruncated; - } - break :frag in_buf[0 .. prev_len + actual_read_len]; - }; - var in: usize = 0; - var out: usize = 0; + const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len); + const actual_read_len = try stream.readv(ask_iovecs); + if (actual_read_len == 0) { + // This is either a truncation attack, or a bug in the server. + return error.TlsConnectionTruncated; + } + // There might be more bytes inside `in_stack_buffer` that need to be processed, + // but at least frag0 will have one complete ciphertext record. + const frag0 = c.partially_read_buffer[0..@min(c.partially_read_buffer.len, actual_read_len)]; + var frag1 = in_stack_buffer[0 .. actual_read_len - frag0.len]; + // We need to decipher frag0 and frag1 but there may be a ciphertext record + // straddling the boundary. We can handle this with two memcpy() calls to + // assemble the straddling record in between handling the two sides. + var frag = frag0; + var in: usize = 0; while (true) { - if (in + tls.ciphertext_record_header_len > frag.len) { - return finishRead(c, frag, in, out); + if (in == frag.len) { + // Perfect split. 
+ if (frag.ptr == frag1.ptr) { + c.partial_ciphertext_end = c.partial_ciphertext_idx; + return vp.total; + } + frag = frag1; + in = 0; + continue; + } + + if (in + tls.ciphertext_record_header_len > frag.len) { + if (frag.ptr == frag1.ptr) + return finishRead(c, frag, in, vp.total); + + const first = frag[in..]; + + if (frag1.len < tls.ciphertext_record_header_len) + return finishRead2(c, first, frag1, vp.total); + + // A record straddles the two fragments. Copy into the now-empty first fragment. + const record_len_byte_0: u16 = straddleByte(frag, frag1, in + 3); + const record_len_byte_1: u16 = straddleByte(frag, frag1, in + 4); + const record_len = (record_len_byte_0 << 8) | record_len_byte_1; + if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; + + const second_len = record_len + tls.ciphertext_record_header_len - first.len; + if (frag1.len < second_len) + return finishRead2(c, first, frag1, vp.total); + + mem.copy(u8, frag[0..in], first); + mem.copy(u8, frag[first.len..], frag1[0..second_len]); + frag1 = frag1[second_len..]; + in = 0; + continue; } - const record_start = in; const ct = @intToEnum(tls.ContentType, frag[in]); in += 1; const legacy_version = mem.readIntBig(u16, frag[in..][0..2]); in += 2; _ = legacy_version; - const record_size = mem.readIntBig(u16, frag[in..][0..2]); + const record_len = mem.readIntBig(u16, frag[in..][0..2]); + if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; in += 2; - const end = in + record_size; + const end = in + record_len; if (end > frag.len) { - if (record_size > max_ciphertext_len) return error.TlsRecordOverflow; - return finishRead(c, frag, in, out); + if (frag.ptr == frag1.ptr) + return finishRead(c, frag, in, vp.total); + + // A record straddles the two fragments. Copy into the now-empty first fragment. + const first = frag[in..]; + const second_len = record_len + tls.ciphertext_record_header_len - first.len; + if (frag1.len < second_len) + return finishRead2(c, first, frag1, vp.total); + + mem.copy(u8, frag[0..in], first); + mem.copy(u8, frag[first.len..], frag1[0..second_len]); + frag1 = frag1[second_len..]; + in = 0; + continue; } switch (ct) { .alert => { @@ -836,18 +945,16 @@ pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize { const P = @TypeOf(p.*); const V = @Vector(P.AEAD.nonce_length, u8); const ad = frag[in - 5 ..][0..5]; - const ciphertext_len = record_size - P.AEAD.tag_length; + const ciphertext_len = record_len - P.AEAD.tag_length; const ciphertext = frag[in..][0..ciphertext_len]; in += ciphertext_len; const auth_tag = frag[in..][0..P.AEAD.tag_length].*; const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - // Here we use read_seq and then intentionally don't - // increment it until later when it is certain the same - // ciphertext does not need to be decrypted again. const operand: V = pad ++ @bitCast([8]u8, big(c.read_seq)); const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand; - const cleartext_buf = if (c.partial_cleartext_index == 0 and out + ciphertext.len <= out_buf.len) - out_buf[out..] 
+ const out_buf = vp.peek(); + const cleartext_buf = if (ciphertext.len <= out_buf.len) + out_buf else &cleartext_stack_buffer; const cleartext = cleartext_buf[0..ciphertext.len]; @@ -856,22 +963,22 @@ pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize { break :c cleartext; }, }; + c.read_seq += 1; const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]); switch (inner_ct) { .alert => { - c.read_seq += 1; - const level = @intToEnum(tls.AlertLevel, out_buf[out]); - const desc = @intToEnum(tls.AlertDescription, out_buf[out + 1]); + const level = @intToEnum(tls.AlertLevel, cleartext[0]); + const desc = @intToEnum(tls.AlertDescription, cleartext[1]); if (desc == .close_notify) { - c.eof = true; - return out; + c.received_close_notify = true; + c.partial_ciphertext_end = c.partial_ciphertext_idx; + return vp.total; } std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) }); return error.TlsAlert; }, .handshake => { - c.read_seq += 1; var ct_i: usize = 0; while (true) { const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]); @@ -926,42 +1033,37 @@ pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize { .application_data => { // Determine whether the output buffer or a stack // buffer was used for storing the cleartext. - if (c.partial_cleartext_index == 0 and - out + cleartext.len <= out_buf.len) - { + if (cleartext.ptr == &cleartext_stack_buffer) { + // Stack buffer was used, so we must copy to the output buffer. + const msg = cleartext[0 .. cleartext.len - 1]; + if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { + // We have already run out of room in iovecs. Continue + // appending to `partially_read_buffer`. + const dest = c.partially_read_buffer[c.partial_ciphertext_idx..]; + mem.copy(u8, dest, msg); + c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), c.partial_ciphertext_idx + msg.len); + } else { + const amt = vp.put(msg); + if (amt < msg.len) { + const rest = msg[amt..]; + c.partial_cleartext_idx = 0; + c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), rest.len); + mem.copy(u8, &c.partially_read_buffer, rest); + } + } + } else { // Output buffer was used directly which means no // memory copying needs to occur, and we can move // on to the next ciphertext record. - out += cleartext.len - 1; - c.read_seq += 1; - } else { - // Stack buffer was used, so we must copy to the output buffer. 
- const dest = out_buf[out..]; - const rest = cleartext[c.partial_cleartext_index..]; - const src = rest[0..@min(rest.len, dest.len)]; - mem.copy(u8, dest, src); - out += src.len; - c.partial_cleartext_index = @intCast( - @TypeOf(c.partial_cleartext_index), - c.partial_cleartext_index + src.len, - ); - if (c.partial_cleartext_index >= cleartext.len) { - c.partial_cleartext_index = 0; - c.read_seq += 1; - } else { - in = record_start; - return finishRead(c, frag, in, out); - } + vp.next(cleartext.len - 1); } }, else => { - std.debug.print("inner content type: {d}\n", .{inner_ct}); return error.TlsUnexpectedMessage; }, } }, else => { - std.debug.print("unexpected ct: {any}\n", .{ct}); return error.TlsUnexpectedMessage; }, } @@ -971,11 +1073,43 @@ pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize { fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize { const saved_buf = frag[in..]; - mem.copy(u8, &c.partially_read_buffer, saved_buf); - c.partially_read_len = @intCast(u15, saved_buf.len); + if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { + // There is cleartext at the beginning already which we need to preserve. + c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), c.partial_ciphertext_idx + saved_buf.len); + mem.copy(u8, c.partially_read_buffer[c.partial_ciphertext_idx..], saved_buf); + } else { + c.partial_cleartext_idx = 0; + c.partial_ciphertext_idx = 0; + c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), saved_buf.len); + mem.copy(u8, &c.partially_read_buffer, saved_buf); + } return out; } +fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize { + if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { + // There is cleartext at the beginning already which we need to preserve. + c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), c.partial_ciphertext_idx + first.len + frag1.len); + mem.copy(u8, c.partially_read_buffer[c.partial_ciphertext_idx..], first); + mem.copy(u8, c.partially_read_buffer[c.partial_ciphertext_idx + first.len ..], frag1); + } else { + c.partial_cleartext_idx = 0; + c.partial_ciphertext_idx = 0; + c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), first.len + frag1.len); + mem.copy(u8, &c.partially_read_buffer, first); + mem.copy(u8, c.partially_read_buffer[first.len..], frag1); + } + return out; +} + +fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 { + if (index < s1.len) { + return s1[index]; + } else { + return s2[index - s1.len]; + } +} + fn hostMatchesCommonName(host: []const u8, common_name: []const u8) bool { if (mem.eql(u8, common_name, host)) { return true; // exact match @@ -1015,6 +1149,89 @@ fn SchemeEcdsa(comptime scheme: tls.SignatureScheme) type { }; } +/// Abstraction for sending multiple byte buffers to a slice of iovecs. +const VecPut = struct { + iovecs: []const std.os.iovec, + idx: usize = 0, + off: usize = 0, + total: usize = 0, + + /// Returns the amount actually put which is always equal to bytes.len + /// unless the vectors ran out of space. 
+ fn put(vp: *VecPut, bytes: []const u8) usize { + var bytes_i: usize = 0; + while (true) { + const v = vp.iovecs[vp.idx]; + const dest = v.iov_base[vp.off..v.iov_len]; + const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)]; + mem.copy(u8, dest, src); + bytes_i += src.len; + if (bytes_i >= bytes.len) { + vp.total += bytes_i; + return bytes_i; + } + vp.off += src.len; + if (vp.off >= v.iov_len) { + vp.off = 0; + vp.idx += 1; + if (vp.idx >= vp.iovecs.len) { + vp.total += bytes_i; + return bytes_i; + } + } + } + } + + /// Returns the next buffer that consecutive bytes can go into. + fn peek(vp: VecPut) []u8 { + if (vp.idx >= vp.iovecs.len) return &.{}; + const v = vp.iovecs[vp.idx]; + return v.iov_base[vp.off..v.iov_len]; + } + + // After writing to the result of peek(), one can call next() to + // advance the cursor. + fn next(vp: *VecPut, len: usize) void { + vp.total += len; + vp.off += len; + if (vp.off >= vp.iovecs[vp.idx].iov_len) { + vp.off = 0; + vp.idx += 1; + } + } + + fn freeSize(vp: VecPut) usize { + var total: usize = 0; + + total += vp.iovecs[vp.idx].iov_len - vp.off; + + if (vp.idx + 1 >= vp.iovecs.len) + return total; + + for (vp.iovecs[vp.idx + 1 ..]) |v| { + total += v.iov_len; + } + + return total; + } +}; + +/// Limit iovecs to a specific byte size. +fn limitVecs(iovecs: []std.os.iovec, len: usize) []std.os.iovec { + var vec_i: usize = 0; + var bytes_left: usize = len; + while (true) { + if (bytes_left >= iovecs[vec_i].iov_len) { + bytes_left -= iovecs[vec_i].iov_len; + vec_i += 1; + if (vec_i == iovecs.len or bytes_left == 0) return iovecs[0..vec_i]; + continue; + } + iovecs[vec_i].iov_len = bytes_left; + return iovecs[0..vec_i]; + } +} + /// The priority order here is chosen based on what crypto algorithms Zig has /// available in the standard library as well as what is faster. Following are /// a few data points on the relative performance of these algorithms. diff --git a/lib/std/net.zig b/lib/std/net.zig index 0112d5be8c..aa51176184 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -1672,6 +1672,17 @@ pub const Stream = struct { } } + pub fn readv(s: Stream, iovecs: []const os.iovec) ReadError!usize { + if (builtin.os.tag == .windows) { + // TODO improve this to use ReadFileScatter + if (iovecs.len == 0) return @as(usize, 0); + const first = iovecs[0]; + return os.windows.ReadFile(s.handle, first.iov_base[0..first.iov_len], null, io.default_mode); + } + + return os.readv(s.handle, iovecs); + } + /// Returns the number of bytes read. If the number read is smaller than /// `buffer.len`, it means the stream reached the end. Reaching the end of /// a stream is not an error condition. 
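As a quick illustration of the VecPut abstraction introduced above (a sketch only, not part of the patch series): assuming the file-private VecPut declaration from lib/std/crypto/tls/Client.zig is in scope, the hypothetical buffers and input string below show how put() walks the iovecs in order and reports how many bytes actually fit.

    const std = @import("std");

    test "VecPut scatters bytes across iovecs in order" {
        // Two made-up destination buffers exposed as iovecs.
        var buf_a: [8]u8 = undefined;
        var buf_b: [8]u8 = undefined;
        var iovecs = [_]std.os.iovec{
            .{ .iov_base = &buf_a, .iov_len = buf_a.len },
            .{ .iov_base = &buf_b, .iov_len = buf_b.len },
        };
        var vp: VecPut = .{ .iovecs = &iovecs };

        // 14 bytes of input: the first 8 fill buf_a, the remaining 6 spill into buf_b.
        const n = vp.put("hello, iovecs!");
        try std.testing.expectEqual(@as(usize, 14), n);
        try std.testing.expectEqual(@as(usize, 14), vp.total);
        try std.testing.expectEqualStrings("hello, i", &buf_a);
        try std.testing.expectEqualStrings("ovecs!", buf_b[0..6]);
    }

This is the mechanism that lets readvAdvanced() hand decrypted application data directly to the caller's iovecs and fall back to partially_read_buffer only once they run out of room.
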
From 22e2aaa283646858502ac1075c9657383366005d Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 29 Dec 2022 17:56:46 -0700 Subject: [PATCH 47/59] crypto.tls: support rsa_pss_rsae_sha256 and fixes * fix eof logic * fix read logic * fix VecPut logic * add some debug prints to remove later --- lib/std/crypto/Certificate.zig | 198 ++++++++++++++++++++++++++++++--- lib/std/crypto/tls/Client.zig | 78 +++++++++---- 2 files changed, 239 insertions(+), 37 deletions(-) diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index a8511d4d9e..cce0193cf0 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -474,19 +474,9 @@ fn verifyRsa( pub_key: []const u8, ) !void { if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch; - const pub_key_seq = try der.Element.parse(pub_key, 0); - if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; - const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start); - if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; - const exponent_elem = try der.Element.parse(pub_key, modulus_elem.slice.end); - if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; - // Skip over meaningless zeroes in the modulus. - const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end]; - const modulus_offset = for (modulus_raw) |byte, i| { - if (byte != 0) break i; - } else modulus_raw.len; - const modulus = modulus_raw[modulus_offset..]; - const exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end]; + const pk_components = try rsa.PublicKey.parseDer(pub_key); + const exponent = pk_components.exponent; + const modulus = pk_components.modulus; if (exponent.len > modulus.len) return error.CertificatePublicKeyInvalid; if (sig.len != modulus.len) return error.CertificateSignatureInvalidLength; @@ -688,10 +678,154 @@ test { /// which is licensed under the Apache License Version 2.0, January 2004 /// http://www.apache.org/licenses/ /// The code has been modified. -const rsa = struct { +pub const rsa = struct { const BigInt = std.math.big.int.Managed; - const PublicKey = struct { + pub const PSSSignature = struct { + pub fn fromBytes(comptime modulus_len: usize, msg: []const u8) [modulus_len]u8 { + var result = [1]u8{0} ** modulus_len; + std.mem.copy(u8, &result, msg); + return result; + } + + pub fn verify(comptime modulus_len: usize, sig: [modulus_len]u8, msg: []const u8, public_key: PublicKey, comptime Hash: type, allocator: std.mem.Allocator) !void { + const mod_bits = try countBits(public_key.n.toConst(), allocator); + const em_dec = try encrypt(modulus_len, sig, public_key, allocator); + + try EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash, allocator); + } + + fn EMSA_PSS_VERIFY(msg: []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type, allocator: std.mem.Allocator) !void { + // TODO + // 1. If the length of M is greater than the input limitation for + // the hash function (2^61 - 1 octets for SHA-1), output + // "inconsistent" and stop. + + // emLen = \ceil(emBits/8) + const emLen = ((emBit - 1) / 8) + 1; + std.debug.assert(emLen == em.len); + + // 2. Let mHash = Hash(M), an octet string of length hLen. + var mHash: [Hash.digest_length]u8 = undefined; + Hash.hash(msg, &mHash, .{}); + + // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. 
+ if (emLen < Hash.digest_length + sLen + 2) { + return error.InvalidSignature; + } + + // 4. If the rightmost octet of EM does not have hexadecimal value + // 0xbc, output "inconsistent" and stop. + if (em[em.len - 1] != 0xbc) { + return error.InvalidSignature; + } + + // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, + // and let H be the next hLen octets. + const maskedDB = em[0..(emLen - Hash.digest_length - 1)]; + const h = em[(emLen - Hash.digest_length - 1)..(emLen - 1)]; + + // 6. If the leftmost 8emLen - emBits bits of the leftmost octet in + // maskedDB are not all equal to zero, output "inconsistent" and + // stop. + const zero_bits = emLen * 8 - emBit; + var mask: u8 = maskedDB[0]; + var i: usize = 0; + while (i < 8 - zero_bits) : (i += 1) { + mask = mask >> 1; + } + if (mask != 0) { + return error.InvalidSignature; + } + + // 7. Let dbMask = MGF(H, emLen - hLen - 1). + const mgf_len = emLen - Hash.digest_length - 1; + var mgf_out = try allocator.alloc(u8, ((mgf_len - 1) / Hash.digest_length + 1) * Hash.digest_length); + defer allocator.free(mgf_out); + var dbMask = try MGF1(mgf_out, h, mgf_len, Hash, allocator); + + // 8. Let DB = maskedDB \xor dbMask. + i = 0; + while (i < dbMask.len) : (i += 1) { + dbMask[i] = maskedDB[i] ^ dbMask[i]; + } + + // 9. Set the leftmost 8emLen - emBits bits of the leftmost octet + // in DB to zero. + i = 0; + mask = 0; + while (i < 8 - zero_bits) : (i += 1) { + mask = mask << 1; + mask += 1; + } + dbMask[0] = dbMask[0] & mask; + + // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not + // zero or if the octet at position emLen - hLen - sLen - 1 (the + // leftmost position is "position 1") does not have hexadecimal + // value 0x01, output "inconsistent" and stop. + if (dbMask[mgf_len - sLen - 2] != 0x00) { + return error.InvalidSignature; + } + + if (dbMask[mgf_len - sLen - 1] != 0x01) { + return error.InvalidSignature; + } + + // 11. Let salt be the last sLen octets of DB. + const salt = dbMask[(mgf_len - sLen)..]; + + // 12. Let + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; + // M' is an octet string of length 8 + hLen + sLen with eight + // initial zero octets. + var m_p = try allocator.alloc(u8, 8 + Hash.digest_length + sLen); + defer allocator.free(m_p); + std.mem.copy(u8, m_p, &([_]u8{0} ** 8)); + std.mem.copy(u8, m_p[8..], &mHash); + std.mem.copy(u8, m_p[(8 + Hash.digest_length)..], salt); + + // 13. Let H' = Hash(M'), an octet string of length hLen. + var h_p: [Hash.digest_length]u8 = undefined; + Hash.hash(m_p, &h_p, .{}); + + // 14. If H = H', output "consistent". Otherwise, output + // "inconsistent". 
+ if (!std.mem.eql(u8, h, &h_p)) { + return error.InvalidSignature; + } + } + + fn MGF1(out: []u8, seed: []const u8, len: usize, comptime Hash: type, allocator: std.mem.Allocator) ![]u8 { + var counter: usize = 0; + var idx: usize = 0; + var c: [4]u8 = undefined; + + var hash = try allocator.alloc(u8, seed.len + c.len); + defer allocator.free(hash); + std.mem.copy(u8, hash, seed); + var hashed: [Hash.digest_length]u8 = undefined; + + while (idx < len) { + c[0] = @intCast(u8, (counter >> 24) & 0xFF); + c[1] = @intCast(u8, (counter >> 16) & 0xFF); + c[2] = @intCast(u8, (counter >> 8) & 0xFF); + c[3] = @intCast(u8, counter & 0xFF); + + std.mem.copy(u8, hash[seed.len..], &c); + Hash.hash(hash, &hashed, .{}); + + std.mem.copy(u8, out[idx..], &hashed); + idx += hashed.len; + + counter += 1; + } + + return out[0..len]; + } + }; + + pub const PublicKey = struct { n: BigInt, e: BigInt, @@ -714,6 +848,24 @@ const rsa = struct { .e = _e, }; } + + pub fn parseDer(pub_key: []const u8) !struct { modulus: []const u8, exponent: []const u8 } { + const pub_key_seq = try der.Element.parse(pub_key, 0); + if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; + const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start); + if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; + const exponent_elem = try der.Element.parse(pub_key, modulus_elem.slice.end); + if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; + // Skip over meaningless zeroes in the modulus. + const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end]; + const modulus_offset = for (modulus_raw) |byte, i| { + if (byte != 0) break i; + } else modulus_raw.len; + return .{ + .modulus = modulus_raw[modulus_offset..], + .exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end], + }; + } }; fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey, allocator: std.mem.Allocator) ![modulus_len]u8 { @@ -812,6 +964,20 @@ const rsa = struct { try BigInt.divFloor(&q, rem, a, n); } + fn countBits(a: std.math.big.int.Const, allocator: std.mem.Allocator) !usize { + var i: usize = 0; + var a_copy = try BigInt.init(allocator); + defer a_copy.deinit(); + try a_copy.copy(a); + + while (!a_copy.eqZero()) { + try a_copy.shiftRight(&a_copy, 1); + i += 1; + } + + return i; + } + // TODO: flush the toilet - const poop = std.heap.page_allocator; + pub const poop = std.heap.page_allocator; }; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 0e23101ee3..2eb5923187 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -536,7 +536,24 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) try sig.verify(verify_bytes, key); }, .rsa_pss_rsae_sha256 => { - @panic("TODO signature scheme: rsa_pss_rsae_sha256"); + if (main_cert_pub_key_algo != .rsaEncryption) + return error.TlsBadSignatureScheme; + + const Hash = crypto.hash.sha2.Sha256; + const rsa = Certificate.rsa; + const components = try rsa.PublicKey.parseDer(main_cert_pub_key); + const exponent = components.exponent; + const modulus = components.modulus; + switch (modulus.len) { + inline 128, 256, 512 => |modulus_len| { + const key = try rsa.PublicKey.fromBytes(exponent, modulus, rsa.poop); + const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig); + try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, rsa.poop); + }, + 
else => { + return error.TlsBadRsaSignatureBitCount; + }, + } }, else => { //std.debug.print("signature scheme: {any}\n", .{ @@ -737,7 +754,7 @@ pub fn writeAll(c: *Client, stream: net.Stream, bytes: []const u8) !void { } pub fn eof(c: Client) bool { - return c.received_close_notify and c.partial_ciphertext_end == 0; + return c.received_close_notify and c.partial_ciphertext_idx >= c.partial_ciphertext_end; } /// Returns the number of bytes read, calling the underlying read function the @@ -822,6 +839,10 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove c.partial_cleartext_idx = 0; c.partial_ciphertext_idx = 0; c.partial_ciphertext_end = 0; + } else { + std.debug.print("finished giving partial cleartext. {d} bytes ciphertext remain\n", .{ + c.partial_ciphertext_end - c.partial_ciphertext_idx, + }); } } @@ -866,8 +887,9 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove // There might be more bytes inside `in_stack_buffer` that need to be processed, // but at least frag0 will have one complete ciphertext record. - const frag0 = c.partially_read_buffer[0..@min(c.partially_read_buffer.len, actual_read_len)]; - var frag1 = in_stack_buffer[0 .. actual_read_len - frag0.len]; + const frag0_end = @min(c.partially_read_buffer.len, c.partial_ciphertext_end + actual_read_len); + const frag0 = c.partially_read_buffer[c.partial_ciphertext_idx..frag0_end]; + var frag1 = in_stack_buffer[0..actual_read_len -| first_iov.len]; // We need to decipher frag0 and frag1 but there may be a ciphertext record // straddling the boundary. We can handle this with two memcpy() calls to // assemble the straddling record in between handling the two sides. @@ -900,12 +922,14 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove const record_len = (record_len_byte_0 << 8) | record_len_byte_1; if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; - const second_len = record_len + tls.ciphertext_record_header_len - first.len; + const full_record_len = record_len + tls.ciphertext_record_header_len; + const second_len = full_record_len - first.len; if (frag1.len < second_len) return finishRead2(c, first, frag1, vp.total); mem.copy(u8, frag[0..in], first); mem.copy(u8, frag[first.len..], frag1[0..second_len]); + frag = frag[0..full_record_len]; frag1 = frag1[second_len..]; in = 0; continue; @@ -914,23 +938,35 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove in += 1; const legacy_version = mem.readIntBig(u16, frag[in..][0..2]); in += 2; - _ = legacy_version; + //_ = legacy_version; const record_len = mem.readIntBig(u16, frag[in..][0..2]); + std.debug.print("ct={any} legacy_version={x} record_len={d}\n", .{ + ct, legacy_version, record_len, + }); if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; in += 2; const end = in + record_len; if (end > frag.len) { + // We need the record header on the next iteration of the loop. + in -= tls.ciphertext_record_header_len; + if (frag.ptr == frag1.ptr) return finishRead(c, frag, in, vp.total); // A record straddles the two fragments. Copy into the now-empty first fragment. 
const first = frag[in..]; - const second_len = record_len + tls.ciphertext_record_header_len - first.len; - if (frag1.len < second_len) + const full_record_len = record_len + tls.ciphertext_record_header_len; + const second_len = full_record_len - first.len; + if (frag1.len < second_len) { + std.debug.print("end > frag.len finishRead2 end={d} frag.len={d}\n", .{ + end, frag.len, + }); return finishRead2(c, first, frag1, vp.total); + } mem.copy(u8, frag[0..in], first); mem.copy(u8, frag[first.len..], frag1[0..second_len]); + frag = frag[0..full_record_len]; frag1 = frag1[second_len..]; in = 0; continue; @@ -991,9 +1027,11 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove const handshake = cleartext[ct_i..next_handshake_i]; switch (handshake_type) { .new_session_ticket => { + std.debug.print("new_session_ticket\n", .{}); // This client implementation ignores new session tickets. }, .key_update => { + std.debug.print("key_update\n", .{}); switch (c.application_cipher) { inline else => |*p| { const P = @TypeOf(p.*); @@ -1042,10 +1080,13 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove const dest = c.partially_read_buffer[c.partial_ciphertext_idx..]; mem.copy(u8, dest, msg); c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), c.partial_ciphertext_idx + msg.len); + std.debug.print("application_data {d} bytes to partial buffer\n", .{msg.len}); } else { const amt = vp.put(msg); + std.debug.print("application_data {d} bytes to read buffer\n", .{msg.len}); if (amt < msg.len) { const rest = msg[amt..]; + std.debug.print(" {d} bytes to partial buffer\n", .{rest.len}); c.partial_cleartext_idx = 0; c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), rest.len); mem.copy(u8, &c.partially_read_buffer, rest); @@ -1055,6 +1096,7 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove // Output buffer was used directly which means no // memory copying needs to occur, and we can move // on to the next ciphertext record. + std.debug.print("application_data {d} bytes directly to read buffer\n", .{cleartext.len - 1}); vp.next(cleartext.len - 1); } }, @@ -1166,10 +1208,6 @@ const VecPut = struct { const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)]; mem.copy(u8, dest, src); bytes_i += src.len; - if (bytes_i >= bytes.len) { - vp.total += bytes_i; - return bytes_i; - } vp.off += src.len; if (vp.off >= v.iov_len) { vp.off = 0; @@ -1179,6 +1217,10 @@ const VecPut = struct { return bytes_i; } } + if (bytes_i >= bytes.len) { + vp.total += bytes_i; + return bytes_i; + } } } @@ -1201,17 +1243,11 @@ const VecPut = struct { } fn freeSize(vp: VecPut) usize { + if (vp.idx >= vp.iovecs.len) return 0; var total: usize = 0; - total += vp.iovecs[vp.idx].iov_len - vp.off; - - if (vp.idx + 1 >= vp.iovecs.len) - return total; - - for (vp.iovecs[vp.idx + 1 ..]) |v| { - total += v.iov_len; - } - + if (vp.idx + 1 >= vp.iovecs.len) return total; + for (vp.iovecs[vp.idx + 1 ..]) |v| total += v.iov_len; return total; } }; From 05fee3b22b593c6b0829499b53f26f5750df3645 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 29 Dec 2022 18:56:51 -0700 Subject: [PATCH 48/59] std.crypto.tls.Client: fix eof logic Before this, it incorrectly returned true when there was still cleartext to be read. 
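The reason the extra comparison is needed: the client keeps three cursors into partially_read_buffer, and end-of-stream must wait until both the undelivered-cleartext region and the buffered-ciphertext region are empty. A minimal sketch follows (illustration only; the wrapper struct and the usize field types are made up here, but the eof() body mirrors the diff below).

    const std = @import("std");

    // Rough layout of partially_read_buffer as the surrounding code uses it
    // (illustration only):
    //
    //   [ consumed cleartext | cleartext not yet handed out | buffered ciphertext | unused ]
    //   0 ........ partial_cleartext_idx ........ partial_ciphertext_idx ... partial_ciphertext_end
    const Cursors = struct {
        received_close_notify: bool = false,
        partial_cleartext_idx: usize = 0,
        partial_ciphertext_idx: usize = 0,
        partial_ciphertext_end: usize = 0,

        // Mirrors the corrected eof(): the stream is only finished once the
        // peer sent close_notify *and* both buffered regions are drained.
        fn eof(c: Cursors) bool {
            return c.received_close_notify and
                c.partial_cleartext_idx >= c.partial_ciphertext_idx and
                c.partial_ciphertext_idx >= c.partial_ciphertext_end;
        }
    };

    test "eof waits for buffered cleartext to drain" {
        var c: Cursors = .{ .received_close_notify = true };
        // 10 bytes of decrypted cleartext are buffered, waiting for the caller.
        c.partial_ciphertext_idx = 10;
        c.partial_ciphertext_end = 10;
        try std.testing.expect(!c.eof());
        // Once the caller has consumed them, eof() may report end of stream.
        c.partial_cleartext_idx = 10;
        try std.testing.expect(c.eof());
    }
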
--- lib/std/crypto/tls/Client.zig | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 2eb5923187..6260995685 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -754,7 +754,9 @@ pub fn writeAll(c: *Client, stream: net.Stream, bytes: []const u8) !void { } pub fn eof(c: Client) bool { - return c.received_close_notify and c.partial_ciphertext_idx >= c.partial_ciphertext_end; + return c.received_close_notify and + c.partial_cleartext_idx >= c.partial_ciphertext_idx and + c.partial_ciphertext_idx >= c.partial_ciphertext_end; } /// Returns the number of bytes read, calling the underlying read function the From 2d090f61be89d0738b9d2644236717e0fada7e4d Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 29 Dec 2022 18:57:53 -0700 Subject: [PATCH 49/59] add std.http.Headers This is a streaming HTTP header parser. All it currently does is detect the end of headers. This will be a non-allocating parser where one can bring supply their own buffer if they want to handle custom headers. This commit also improves std.http.Client to not return the HTTP headers with the read functions. --- lib/std/http.zig | 50 +++++++++++++++++++++++++++++++++++++++++ lib/std/http/Client.zig | 46 ++++++++++++++++++++++++++++++++----- 2 files changed, 91 insertions(+), 5 deletions(-) diff --git a/lib/std/http.zig b/lib/std/http.zig index cf92b462b8..944271df27 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -242,10 +242,60 @@ pub const Status = enum(u10) { } }; +pub const Headers = struct { + state: State = .start, + invalid_index: u32 = undefined, + + pub const State = enum { invalid, start, line, nl_r, nl_n, nl2_r, finished }; + + /// Returns how many bytes are processed into headers. Always less than or + /// equal to bytes.len. If the amount returned is less than bytes.len, it + /// means the headers ended and the first byte after the double \r\n\r\n is + /// located at `bytes[result]`. 
+ pub fn feed(h: *Headers, bytes: []const u8) usize { + for (bytes) |b, i| { + switch (h.state) { + .start => switch (b) { + '\r' => h.state = .nl_r, + '\n' => return invalid(h, i), + else => {}, + }, + .nl_r => switch (b) { + '\n' => h.state = .nl_n, + else => return invalid(h, i), + }, + .nl_n => switch (b) { + '\r' => h.state = .nl2_r, + else => h.state = .line, + }, + .nl2_r => switch (b) { + '\n' => h.state = .finished, + else => return invalid(h, i), + }, + .line => switch (b) { + '\r' => h.state = .nl_r, + '\n' => return invalid(h, i), + else => {}, + }, + .invalid => return i, + .finished => return i, + } + } + return bytes.len; + } + + fn invalid(h: *Headers, i: usize) usize { + h.invalid_index = @intCast(u32, i); + h.state = .invalid; + return i; + } +}; + const std = @import("std.zig"); test { _ = Client; _ = Method; _ = Status; + _ = Headers; } diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index d27d879663..efae62680d 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -16,6 +16,7 @@ pub const Request = struct { headers: std.ArrayListUnmanaged(u8) = .{}, tls_client: std.crypto.tls.Client, protocol: Protocol, + response_headers: http.Headers = .{}, pub const Protocol = enum { http, https }; @@ -51,18 +52,53 @@ pub const Request = struct { } } + pub fn readAll(req: *Request, buffer: []u8) !usize { + return readAtLeast(req, buffer, buffer.len); + } + pub fn read(req: *Request, buffer: []u8) !usize { + return readAtLeast(req, buffer, 1); + } + + pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize { + assert(len <= buffer.len); + var index: usize = 0; + while (index < len) { + const headers_finished = req.response_headers.state == .finished; + const amt = try readAdvanced(req, buffer[index..]); + if (amt == 0 and headers_finished) break; + index += amt; + } + return index; + } + + /// This one can return 0 without meaning EOF. + /// TODO change to readvAdvanced + pub fn readAdvanced(req: *Request, buffer: []u8) !usize { + if (req.response_headers.state == .finished) return readRaw(req, buffer); + + const amt = try readRaw(req, buffer); + const data = buffer[0..amt]; + const i = req.response_headers.feed(data); + if (req.response_headers.state == .invalid) return error.InvalidHttpHeaders; + if (i < data.len) { + const rest = data[i..]; + std.mem.copy(u8, buffer, rest); + return rest.len; + } + return 0; + } + + /// Only abstracts over http/https. + fn readRaw(req: *Request, buffer: []u8) !usize { switch (req.protocol) { .http => return req.stream.read(buffer), .https => return req.tls_client.read(req.stream, buffer), } } - pub fn readAll(req: *Request, buffer: []u8) !usize { - return readAtLeast(req, buffer, buffer.len); - } - - pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize { + /// Only abstracts over http/https. + fn readAtLeastRaw(req: *Request, buffer: []u8, len: usize) !usize { switch (req.protocol) { .http => return req.stream.readAtLeast(buffer, len), .https => return req.tls_client.readAtLeast(req.stream, buffer, len), From 79b41dbdbfd6c511b2e206397788b81bc720d266 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 29 Dec 2022 20:49:56 -0700 Subject: [PATCH 50/59] std.crypto.tls: avoid heap allocation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The code we are borrowing from https://github.com/shiguredo/tls13-zig requires an Allocator for doing RSA certificate verification. 
As a stopgap measure, this commit uses a FixedBufferAllocator to avoid heap allocation for these functions. Thank you to @naoki9911 for providing this great resource which has been extremely helpful for me when working on this standard library TLS implementation. Until Zig has std.crypto.rsa officially, we will borrow this implementation of RSA. 🙏 --- lib/std/crypto/Certificate.zig | 15 ++++++++------- lib/std/crypto/tls/Client.zig | 7 +++++-- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index cce0193cf0..c4fd66bbc9 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -511,6 +511,10 @@ fn verifyRsa( var msg_hashed: [Hash.digest_length]u8 = undefined; Hash.hash(message, &msg_hashed, .{}); + var rsa_mem_buf: [512 * 32]u8 = undefined; + var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf); + const ally = fba.allocator(); + switch (modulus.len) { inline 128, 256, 512 => |modulus_len| { const ps_len = modulus_len - (hash_der.len + msg_hashed.len) - 3; @@ -521,11 +525,11 @@ fn verifyRsa( hash_der ++ msg_hashed; - const public_key = rsa.PublicKey.fromBytes(exponent, modulus, rsa.poop) catch |err| switch (err) { - error.OutOfMemory => @panic("TODO don't heap allocate"), + const public_key = rsa.PublicKey.fromBytes(exponent, modulus, ally) catch |err| switch (err) { + error.OutOfMemory => unreachable, // rsa_mem_buf is big enough }; - const em_dec = rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key, rsa.poop) catch |err| switch (err) { - error.OutOfMemory => @panic("TODO don't heap allocate"), + const em_dec = rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key, ally) catch |err| switch (err) { + error.OutOfMemory => unreachable, // rsa_mem_buf is big enough error.MessageTooLong => unreachable, error.NegativeIntoUnsigned => @panic("TODO make RSA not emit this error"), @@ -977,7 +981,4 @@ pub const rsa = struct { return i; } - - // TODO: flush the toilet - pub const poop = std.heap.page_allocator; }; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 6260995685..c4206862dd 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -544,11 +544,14 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) const components = try rsa.PublicKey.parseDer(main_cert_pub_key); const exponent = components.exponent; const modulus = components.modulus; + var rsa_mem_buf: [512 * 32]u8 = undefined; + var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf); + const ally = fba.allocator(); switch (modulus.len) { inline 128, 256, 512 => |modulus_len| { - const key = try rsa.PublicKey.fromBytes(exponent, modulus, rsa.poop); + const key = try rsa.PublicKey.fromBytes(exponent, modulus, ally); const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig); - try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, rsa.poop); + try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, ally); }, else => { return error.TlsBadRsaSignatureBitCount; From 341e68ff8fb83a146f6000b2214c1eae668e9667 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 29 Dec 2022 20:58:42 -0700 Subject: [PATCH 51/59] std.crypto.tls.Client: remove debug prints --- lib/std/crypto/Certificate.zig | 9 +--- lib/std/crypto/tls/Client.zig | 83 ++++------------------------------ 2 files changed, 10 insertions(+), 82 deletions(-) diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index 
c4fd66bbc9..93383f3615 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -241,7 +241,6 @@ pub fn parse(cert: Certificate) !Parsed { var common_name = der.Element.Slice.empty; var name_i = subject.slice.start; - //std.debug.print("subject name:\n", .{}); while (name_i < subject.slice.end) { const rdn = try der.Element.parse(cert_bytes, name_i); var rdn_i = rdn.slice.start; @@ -252,9 +251,6 @@ pub fn parse(cert: Certificate) !Parsed { const ty_elem = try der.Element.parse(cert_bytes, atav_i); const ty = try parseAttribute(cert_bytes, ty_elem); const val = try der.Element.parse(cert_bytes, ty_elem.slice.end); - //std.debug.print(" {s}: '{s}'\n", .{ - // @tagName(ty), cert_bytes[val.slice.start..val.slice.end], - //}); switch (ty) { .commonName => common_name = val.slice, else => {}, @@ -452,10 +448,7 @@ fn parseEnum(comptime E: type, bytes: []const u8, element: der.Element) !E { if (element.identifier.tag != .object_identifier) return error.CertificateFieldHasWrongDataType; const oid_bytes = bytes[element.slice.start..element.slice.end]; - return E.map.get(oid_bytes) orelse { - //std.debug.print("tag: {}\n", .{std.fmt.fmtSliceHexLower(oid_bytes)}); - return error.CertificateHasUnrecognizedObjectId; - }; + return E.map.get(oid_bytes) orelse return error.CertificateHasUnrecognizedObjectId; } pub fn checkVersion(bytes: []const u8, version: der.Element) !void { diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index c4206862dd..ec6f00ad8a 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -147,7 +147,8 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) .alert => { const level = @intToEnum(tls.AlertLevel, frag[0]); const desc = @intToEnum(tls.AlertDescription, frag[1]); - std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) }); + _ = level; + _ = desc; return error.TlsAlert; }, .handshake => { @@ -226,14 +227,11 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) shared_key = mul.affineCoordinates().x.toBytes(.Big); }, else => { - //std.debug.print("named group: {x}\n", .{named_group}); return error.TlsIllegalParameter; }, } }, - else => { - std.debug.print("unexpected extension: {x}\n", .{et}); - }, + else => {}, } i = next_i; } @@ -283,18 +281,6 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) p.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); p.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); p.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); - //std.debug.print("shared_key: {}\nhello_hash: {}\nearly_secret: {}\nempty_hash: {}\nderived_secret: {}\nhandshake_secret: {}\n client_secret: {}\n server_secret: {}\nclient_handshake_iv: {}\nserver_handshake_iv: {}\n", .{ - // std.fmt.fmtSliceHexLower(&shared_key), - // std.fmt.fmtSliceHexLower(&hello_hash), - // std.fmt.fmtSliceHexLower(&early_secret), - // std.fmt.fmtSliceHexLower(&empty_hash), - // std.fmt.fmtSliceHexLower(&hs_derived_secret), - // std.fmt.fmtSliceHexLower(&p.handshake_secret), - // std.fmt.fmtSliceHexLower(&client_secret), - // std.fmt.fmtSliceHexLower(&server_secret), - // std.fmt.fmtSliceHexLower(&p.client_handshake_iv), - // std.fmt.fmtSliceHexLower(&p.server_handshake_iv), - //}); }, else => { return error.TlsIllegalParameter; @@ -416,11 +402,7 @@ pub fn init(stream: net.Stream, ca_bundle: 
Certificate.Bundle, host: []const u8) const next_ext_i = hs_i + ext_size; switch (et) { .server_name => {}, - else => { - std.debug.print("encrypted extension: {any}\n", .{ - et, - }); - }, + else => {}, } hs_i = next_ext_i; } @@ -467,12 +449,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) @memcpy(&main_cert_pub_key_buf, pub_key.ptr, pub_key.len); main_cert_pub_key_len = @intCast(@TypeOf(main_cert_pub_key_len), pub_key.len); } else { - prev_cert.verify(subject) catch |err| { - std.debug.print("unable to validate previous cert: {s}\n", .{ - @errorName(err), - }); - return err; - }; + try prev_cert.verify(subject); } if (ca_bundle.verify(subject)) |_| { @@ -480,12 +457,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) break :cert; } else |err| switch (err) { error.CertificateIssuerNotFound => {}, - else => |e| { - std.debug.print("unable to validate cert against system root CAs: {s}\n", .{ - @errorName(e), - }); - return e; - }, + else => |e| return e, } prev_cert = subject; @@ -559,9 +531,6 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) } }, else => { - //std.debug.print("signature scheme: {any}\n", .{ - // scheme, - //}); return error.TlsBadSignatureScheme; }, } @@ -609,11 +578,6 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); - //std.debug.print("master_secret={}\nclient_secret={}\nserver_secret={}\n", .{ - // std.fmt.fmtSliceHexLower(&p.master_secret), - // std.fmt.fmtSliceHexLower(&client_secret), - // std.fmt.fmtSliceHexLower(&server_secret), - //}); break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{ .client_secret = client_secret, .server_secret = server_secret, @@ -646,13 +610,11 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) } }, else => { - std.debug.print("inner content type: {any}\n", .{inner_ct}); return error.TlsUnexpectedMessage; }, } }, else => { - std.debug.print("content type: {s}\n", .{@tagName(ct)}); return error.TlsUnexpectedMessage; }, } @@ -707,16 +669,6 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize { c.write_seq += 1; // TODO send key_update on overflow const nonce = @as(V, p.client_iv) ^ operand; P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key); - //std.debug.print("seq: {d} nonce: {} client_key: {} client_iv: {} ad: {} auth_tag: {}\nserver_key: {} server_iv: {}\n", .{ - // c.write_seq - 1, - // std.fmt.fmtSliceHexLower(&nonce), - // std.fmt.fmtSliceHexLower(&p.client_key), - // std.fmt.fmtSliceHexLower(&p.client_iv), - // std.fmt.fmtSliceHexLower(ad), - // std.fmt.fmtSliceHexLower(auth_tag), - // std.fmt.fmtSliceHexLower(&p.server_key), - // std.fmt.fmtSliceHexLower(&p.server_iv), - //}); const record = ciphertext_buf[record_start..ciphertext_end]; iovecs_buf[iovec_end] = .{ @@ -844,10 +796,6 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove c.partial_cleartext_idx = 0; c.partial_ciphertext_idx = 0; c.partial_ciphertext_end = 0; - } else { - std.debug.print("finished giving partial cleartext. 
{d} bytes ciphertext remain\n", .{ - c.partial_ciphertext_end - c.partial_ciphertext_idx, - }); } } @@ -943,11 +891,8 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove in += 1; const legacy_version = mem.readIntBig(u16, frag[in..][0..2]); in += 2; - //_ = legacy_version; + _ = legacy_version; const record_len = mem.readIntBig(u16, frag[in..][0..2]); - std.debug.print("ct={any} legacy_version={x} record_len={d}\n", .{ - ct, legacy_version, record_len, - }); if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; in += 2; const end = in + record_len; @@ -962,12 +907,8 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove const first = frag[in..]; const full_record_len = record_len + tls.ciphertext_record_header_len; const second_len = full_record_len - first.len; - if (frag1.len < second_len) { - std.debug.print("end > frag.len finishRead2 end={d} frag.len={d}\n", .{ - end, frag.len, - }); + if (frag1.len < second_len) return finishRead2(c, first, frag1, vp.total); - } mem.copy(u8, frag[0..in], first); mem.copy(u8, frag[first.len..], frag1[0..second_len]); @@ -1016,7 +957,7 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove c.partial_ciphertext_end = c.partial_ciphertext_idx; return vp.total; } - std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) }); + _ = level; return error.TlsAlert; }, .handshake => { @@ -1032,11 +973,9 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove const handshake = cleartext[ct_i..next_handshake_i]; switch (handshake_type) { .new_session_ticket => { - std.debug.print("new_session_ticket\n", .{}); // This client implementation ignores new session tickets. }, .key_update => { - std.debug.print("key_update\n", .{}); switch (c.application_cipher) { inline else => |*p| { const P = @TypeOf(p.*); @@ -1085,13 +1024,10 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove const dest = c.partially_read_buffer[c.partial_ciphertext_idx..]; mem.copy(u8, dest, msg); c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), c.partial_ciphertext_idx + msg.len); - std.debug.print("application_data {d} bytes to partial buffer\n", .{msg.len}); } else { const amt = vp.put(msg); - std.debug.print("application_data {d} bytes to read buffer\n", .{msg.len}); if (amt < msg.len) { const rest = msg[amt..]; - std.debug.print(" {d} bytes to partial buffer\n", .{rest.len}); c.partial_cleartext_idx = 0; c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), rest.len); mem.copy(u8, &c.partially_read_buffer, rest); @@ -1101,7 +1037,6 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove // Output buffer was used directly which means no // memory copying needs to occur, and we can move // on to the next ciphertext record. - std.debug.print("application_data {d} bytes directly to read buffer\n", .{cleartext.len - 1}); vp.next(cleartext.len - 1); } }, From 0fb78b15aad74498d7f36785d5618edca7e83508 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 30 Dec 2022 17:57:31 -0700 Subject: [PATCH 52/59] std.crypto.tls: use a Decoder abstraction This commit introduces tls.Decoder and then uses it in tls.Client. The purpose is to make it difficult to introduce vulnerabilities in the parsing code. 
With this abstraction in place, bugs in the TLS implementation will trip checks in the decoder, regardless of the actual length of packets sent by the other party, so that we can have confidence when using ReleaseFast builds. --- lib/std/crypto/tls.zig | 131 ++++++- lib/std/crypto/tls/Client.zig | 628 ++++++++++++++++------------------ 2 files changed, 423 insertions(+), 336 deletions(-) diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index fc2523f02a..8ef4d9bfad 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -39,9 +39,9 @@ const assert = std.debug.assert; pub const Client = @import("tls/Client.zig"); -pub const ciphertext_record_header_len = 5; +pub const record_header_len = 5; pub const max_ciphertext_len = (1 << 14) + 256; -pub const max_ciphertext_record_len = max_ciphertext_len + ciphertext_record_header_len; +pub const max_ciphertext_record_len = max_ciphertext_len + record_header_len; pub const hello_retry_request_sequence = [32]u8{ 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, @@ -360,3 +360,130 @@ pub inline fn int3(x: u24) [3]u8 { @truncate(u8, x), }; } + +/// An abstraction to ensure that protocol-parsing code does not perform an +/// out-of-bounds read. +pub const Decoder = struct { + buf: []u8, + /// Points to the next byte in buffer that will be decoded. + idx: usize = 0, + /// Up to this point in `buf` we have already checked that `cap` is greater than it. + our_end: usize = 0, + /// Beyond this point in `buf` is extra tag-along bytes beyond the amount we + /// requested with `readAtLeast`. + their_end: usize = 0, + /// Points to the end within buffer that has been filled. Beyond this point + /// in buf is undefined bytes. + cap: usize = 0, + /// Debug helper to prevent illegal calls to read functions. + disable_reads: bool = false, + + pub fn fromTheirSlice(buf: []u8) Decoder { + return .{ + .buf = buf, + .their_end = buf.len, + .cap = buf.len, + .disable_reads = true, + }; + } + + /// Use this function to increase `their_end`. + pub fn readAtLeast(d: *Decoder, stream: anytype, their_amt: usize) !void { + assert(!d.disable_reads); + const existing_amt = d.cap - d.idx; + d.their_end = d.idx + their_amt; + if (their_amt <= existing_amt) return; + const request_amt = their_amt - existing_amt; + const dest = d.buf[d.cap..]; + if (request_amt > dest.len) return error.TlsRecordOverflow; + const actual_amt = try stream.readAtLeast(dest, request_amt); + if (actual_amt < request_amt) return error.TlsConnectionTruncated; + d.cap += actual_amt; + } + + /// Same as `readAtLeast` but also increases `our_end` by exactly `our_amt`. + /// Use when `our_amt` is calculated by us, not by them. + pub fn readAtLeastOurAmt(d: *Decoder, stream: anytype, our_amt: usize) !void { + assert(!d.disable_reads); + try readAtLeast(d, stream, our_amt); + d.our_end = d.idx + our_amt; + } + + /// Use this function to increase `our_end`. + /// This should always be called with an amount provided by us, not them. + pub fn ensure(d: *Decoder, amt: usize) !void { + d.our_end = @max(d.idx + amt, d.our_end); + if (d.our_end > d.their_end) return error.TlsDecodeError; + } + + /// Use this function to increase `idx`. 
+ pub fn decode(d: *Decoder, comptime T: type) T { + switch (@typeInfo(T)) { + .Int => |info| switch (info.bits) { + 8 => { + skip(d, 1); + return d.buf[d.idx - 1]; + }, + 16 => { + skip(d, 2); + const b0: u16 = d.buf[d.idx - 2]; + const b1: u16 = d.buf[d.idx - 1]; + return (b0 << 8) | b1; + }, + 24 => { + skip(d, 3); + const b0: u24 = d.buf[d.idx - 3]; + const b1: u24 = d.buf[d.idx - 2]; + const b2: u24 = d.buf[d.idx - 1]; + return (b0 << 16) | (b1 << 8) | b2; + }, + else => @compileError("unsupported int type: " ++ @typeName(T)), + }, + .Enum => |info| { + const int = d.decode(info.tag_type); + if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); + return @intToEnum(T, int); + }, + else => @compileError("unsupported type: " ++ @typeName(T)), + } + } + + /// Use this function to increase `idx`. + pub fn array(d: *Decoder, comptime len: usize) *[len]u8 { + skip(d, len); + return d.buf[d.idx - len ..][0..len]; + } + + /// Use this function to increase `idx`. + pub fn slice(d: *Decoder, len: usize) []u8 { + skip(d, len); + return d.buf[d.idx - len ..][0..len]; + } + + /// Use this function to increase `idx`. + pub fn skip(d: *Decoder, amt: usize) void { + d.idx += amt; + assert(d.idx <= d.our_end); // insufficient ensured bytes + } + + pub fn eof(d: Decoder) bool { + assert(d.our_end <= d.their_end); + assert(d.idx <= d.our_end); + return d.idx == d.their_end; + } + + /// Provide the length they claim, and receive a sub-decoder specific to that slice. + /// The parent decoder is advanced to the end. + pub fn sub(d: *Decoder, their_len: usize) !Decoder { + const end = d.idx + their_len; + if (end > d.their_end) return error.TlsDecodeError; + const sub_buf = d.buf[d.idx..end]; + d.idx = end; + d.our_end = end; + return fromTheirSlice(sub_buf); + } + + pub fn rest(d: Decoder) []u8 { + return d.buf[d.idx..d.cap]; + } +}; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index ec6f00ad8a..bca05a3ffd 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -126,88 +126,73 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) const client_hello_bytes1 = plaintext_header[5..]; var handshake_cipher: tls.HandshakeCipher = undefined; - - var handshake_buf: [8000]u8 = undefined; - var len: usize = 0; - var i: usize = i: { - const plaintext = handshake_buf[0..5]; - len = try stream.readAtLeast(&handshake_buf, plaintext.len); - if (len < plaintext.len) return error.EndOfStream; - const ct = @intToEnum(tls.ContentType, plaintext[0]); - const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]); - const end = plaintext.len + frag_len; - if (end > handshake_buf.len) return error.TlsRecordOverflow; - if (end > len) { - len += try stream.readAtLeast(handshake_buf[len..], end - len); - if (end > len) return error.EndOfStream; - } - const frag = handshake_buf[plaintext.len..end]; - + var handshake_buffer: [8000]u8 = undefined; + var d: tls.Decoder = .{ .buf = &handshake_buffer }; + { + try d.readAtLeastOurAmt(stream, tls.record_header_len); + const ct = d.decode(tls.ContentType); + d.skip(2); // legacy_record_version + const record_len = d.decode(u16); + try d.readAtLeast(stream, record_len); + const server_hello_fragment = d.buf[d.idx..][0..record_len]; + var ptd = try d.sub(record_len); switch (ct) { .alert => { - const level = @intToEnum(tls.AlertLevel, frag[0]); - const desc = @intToEnum(tls.AlertDescription, frag[1]); + try ptd.ensure(2); + const level = ptd.decode(tls.AlertLevel); + const desc = 
ptd.decode(tls.AlertDescription); _ = level; _ = desc; return error.TlsAlert; }, .handshake => { - if (frag[0] != @enumToInt(tls.HandshakeType.server_hello)) { + try ptd.ensure(4); + const handshake_type = ptd.decode(tls.HandshakeType); + if (handshake_type != .server_hello) return error.TlsUnexpectedMessage; + const length = ptd.decode(u24); + var hsd = try ptd.sub(length); + try hsd.ensure(2 + 32 + 1 + 32 + 2 + 1 + 2); + const legacy_version = hsd.decode(u16); + const random = hsd.array(32); + if (mem.eql(u8, random, &tls.hello_retry_request_sequence)) { + // This is a HelloRetryRequest message. This client implementation + // does not expect to get one. return error.TlsUnexpectedMessage; } - const length = mem.readIntBig(u24, frag[1..4]); - if (4 + length != frag.len) return error.TlsBadLength; - var i: usize = 4; - const legacy_version = mem.readIntBig(u16, frag[i..][0..2]); - i += 2; - const random = frag[i..][0..32].*; - i += 32; - if (mem.eql(u8, &random, &tls.hello_retry_request_sequence)) { - @panic("TODO handle HelloRetryRequest"); - } - const legacy_session_id_echo_len = frag[i]; - i += 1; + const legacy_session_id_echo_len = hsd.decode(u8); if (legacy_session_id_echo_len != 32) return error.TlsIllegalParameter; - const legacy_session_id_echo = frag[i..][0..32]; + const legacy_session_id_echo = hsd.array(32); if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id)) return error.TlsIllegalParameter; - i += 32; - const cipher_suite_int = mem.readIntBig(u16, frag[i..][0..2]); - i += 2; - const cipher_suite_tag = @intToEnum(tls.CipherSuite, cipher_suite_int); - const legacy_compression_method = frag[i]; - i += 1; - _ = legacy_compression_method; - const extensions_size = mem.readIntBig(u16, frag[i..][0..2]); - i += 2; - if (i + extensions_size != frag.len) return error.TlsBadLength; + const cipher_suite_tag = hsd.decode(tls.CipherSuite); + hsd.skip(1); // legacy_compression_method + const extensions_size = hsd.decode(u16); + var all_extd = try hsd.sub(extensions_size); var supported_version: u16 = 0; var shared_key: [32]u8 = undefined; var have_shared_key = false; - while (i < frag.len) { - const et = @intToEnum(tls.ExtensionType, mem.readIntBig(u16, frag[i..][0..2])); - i += 2; - const ext_size = mem.readIntBig(u16, frag[i..][0..2]); - i += 2; - const next_i = i + ext_size; - if (next_i > frag.len) return error.TlsBadLength; + while (!all_extd.eof()) { + try all_extd.ensure(2 + 2); + const et = all_extd.decode(tls.ExtensionType); + const ext_size = all_extd.decode(u16); + var extd = try all_extd.sub(ext_size); switch (et) { .supported_versions => { if (supported_version != 0) return error.TlsIllegalParameter; - supported_version = mem.readIntBig(u16, frag[i..][0..2]); + try extd.ensure(2); + supported_version = extd.decode(u16); }, .key_share => { if (have_shared_key) return error.TlsIllegalParameter; have_shared_key = true; - const named_group = @intToEnum(tls.NamedGroup, mem.readIntBig(u16, frag[i..][0..2])); - i += 2; - const key_size = mem.readIntBig(u16, frag[i..][0..2]); - i += 2; - + try extd.ensure(4); + const named_group = extd.decode(tls.NamedGroup); + const key_size = extd.decode(u16); + try extd.ensure(key_size); switch (named_group) { .x25519 => { - if (key_size != 32) return error.TlsBadLength; - const server_pub_key = frag[i..][0..32]; + if (key_size != 32) return error.TlsIllegalParameter; + const server_pub_key = extd.array(32); shared_key = crypto.dh.X25519.scalarmult( x25519_kp.secret_key, @@ -215,7 +200,7 @@ pub fn init(stream: net.Stream, ca_bundle: 
Certificate.Bundle, host: []const u8) ) catch return error.TlsDecryptFailure; }, .secp256r1 => { - const server_pub_key = frag[i..][0..key_size]; + const server_pub_key = extd.slice(key_size); const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; const pk = PublicKey.fromSec1(server_pub_key) catch { @@ -233,14 +218,12 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) }, else => {}, } - i = next_i; } if (!have_shared_key) return error.TlsIllegalParameter; + const tls_version = if (supported_version == 0) legacy_version else supported_version; - switch (tls_version) { - @enumToInt(tls.ProtocolVersion.tls_1_3) => {}, - else => return error.TlsIllegalParameter, - } + if (tls_version != @enumToInt(tls.ProtocolVersion.tls_1_3)) + return error.TlsIllegalParameter; switch (cipher_suite_tag) { inline .AES_128_GCM_SHA256, @@ -264,7 +247,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) const p = &@field(handshake_cipher, @tagName(tag)); p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1 p.transcript_hash.update(host); // Client Hello part 2 - p.transcript_hash.update(frag); // Server Hello + p.transcript_hash.update(server_hello_fragment); const hello_hash = p.transcript_hash.peek(); const zeroes = [1]u8{0} ** P.Hash.digest_length; const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); @@ -289,8 +272,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) }, else => return error.TlsUnexpectedMessage, } - break :i end; - }; + } // This is used for two purposes: // * Detect whether a certificate is the first one presented, in which case @@ -322,29 +304,17 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) var main_cert_pub_key_len: u16 = undefined; while (true) { - const end_hdr = i + 5; - if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow; - if (end_hdr > len) { - len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len); - if (end_hdr > len) return error.EndOfStream; - } - const ct = @intToEnum(tls.ContentType, handshake_buf[i]); - i += 1; - const legacy_version = mem.readIntBig(u16, handshake_buf[i..][0..2]); - i += 2; - _ = legacy_version; - const record_size = mem.readIntBig(u16, handshake_buf[i..][0..2]); - i += 2; - const end = i + record_size; - if (end > handshake_buf.len) return error.TlsRecordOverflow; - if (end > len) { - len += try stream.readAtLeast(handshake_buf[len..], end - len); - if (end > len) return error.EndOfStream; - } + try d.readAtLeastOurAmt(stream, tls.record_header_len); + const record_header = d.buf[d.idx..][0..5]; + const ct = d.decode(tls.ContentType); + d.skip(2); // legacy_version + const record_len = d.decode(u16); + try d.readAtLeast(stream, record_len); + var record_decoder = try d.sub(record_len); switch (ct) { .change_cipher_spec => { - if (record_size != 1) return error.TlsUnexpectedMessage; - if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage; + try record_decoder.ensure(1); + if (record_decoder.decode(u8) != 0x01) return error.TlsIllegalParameter; }, .application_data => { const cleartext_buf = &cleartext_bufs[cert_index % 2]; @@ -352,276 +322,261 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) const cleartext = switch (handshake_cipher) { inline else => |*p| c: { const P = @TypeOf(p.*); - const ciphertext_len = record_size - P.AEAD.tag_length; - const ciphertext = handshake_buf[i..][0..ciphertext_len]; - i += ciphertext.len; + 
const ciphertext_len = record_len - P.AEAD.tag_length; + try record_decoder.ensure(ciphertext_len + P.AEAD.tag_length); + const ciphertext = record_decoder.slice(ciphertext_len); if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow; const cleartext = cleartext_buf[0..ciphertext.len]; - const auth_tag = handshake_buf[i..][0..P.AEAD.tag_length].*; + const auth_tag = record_decoder.array(P.AEAD.tag_length).*; const V = @Vector(P.AEAD.nonce_length, u8); const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); const operand: V = pad ++ @bitCast([8]u8, big(read_seq)); read_seq += 1; const nonce = @as(V, p.server_handshake_iv) ^ operand; - const ad = handshake_buf[end_hdr - 5 ..][0..5]; - P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_handshake_key) catch + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, p.server_handshake_key) catch return error.TlsBadRecordMac; break :c cleartext; }, }; const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]); - switch (inner_ct) { - .handshake => { - var ct_i: usize = 0; - while (true) { - const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]); - ct_i += 1; - const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]); - ct_i += 3; - const next_handshake_i = ct_i + handshake_len; - if (next_handshake_i > cleartext.len - 1) - return error.TlsBadLength; - const wrapped_handshake = cleartext[ct_i - 4 .. next_handshake_i]; - const handshake = cleartext[ct_i..next_handshake_i]; - switch (handshake_type) { - .encrypted_extensions => { - if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; - handshake_state = .certificate; - switch (handshake_cipher) { - inline else => |*p| p.transcript_hash.update(wrapped_handshake), - } - const total_ext_size = mem.readIntBig(u16, handshake[0..2]); - var hs_i: usize = 2; - const end_ext_i = 2 + total_ext_size; - while (hs_i < end_ext_i) { - const et = @intToEnum(tls.ExtensionType, mem.readIntBig(u16, handshake[hs_i..][0..2])); - hs_i += 2; - const ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]); - hs_i += 2; - const next_ext_i = hs_i + ext_size; - switch (et) { - .server_name => {}, - else => {}, - } - hs_i = next_ext_i; + if (inner_ct != .handshake) return error.TlsUnexpectedMessage; + + var ctd = tls.Decoder.fromTheirSlice(cleartext[0 .. cleartext.len - 1]); + while (true) { + try ctd.ensure(4); + const handshake_type = ctd.decode(tls.HandshakeType); + const handshake_len = ctd.decode(u24); + var hsd = try ctd.sub(handshake_len); + const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx]; + const handshake = ctd.buf[ctd.idx - handshake_len .. 
ctd.idx]; + switch (handshake_type) { + .encrypted_extensions => { + if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; + handshake_state = .certificate; + switch (handshake_cipher) { + inline else => |*p| p.transcript_hash.update(wrapped_handshake), + } + try hsd.ensure(2); + const total_ext_size = hsd.decode(u16); + var all_extd = try hsd.sub(total_ext_size); + while (!all_extd.eof()) { + try all_extd.ensure(4); + const et = all_extd.decode(tls.ExtensionType); + const ext_size = all_extd.decode(u16); + var extd = try all_extd.sub(ext_size); + _ = extd; + switch (et) { + .server_name => {}, + else => {}, + } + } + }, + .certificate => cert: { + switch (handshake_cipher) { + inline else => |*p| p.transcript_hash.update(wrapped_handshake), + } + switch (handshake_state) { + .certificate => {}, + .trust_chain_established => break :cert, + else => return error.TlsUnexpectedMessage, + } + try hsd.ensure(1 + 4); + const cert_req_ctx_len = hsd.decode(u8); + if (cert_req_ctx_len != 0) return error.TlsIllegalParameter; + const certs_size = hsd.decode(u24); + var certs_decoder = try hsd.sub(certs_size); + while (!certs_decoder.eof()) { + try certs_decoder.ensure(3); + const cert_size = certs_decoder.decode(u24); + var certd = try certs_decoder.sub(cert_size); + + const subject_cert: Certificate = .{ + .buffer = certd.buf, + .index = @intCast(u32, certd.idx), + }; + const subject = try subject_cert.parse(); + if (cert_index == 0) { + // Verify the host on the first certificate. + if (!hostMatchesCommonName(host, subject.commonName())) { + return error.TlsCertificateHostMismatch; } + + // Keep track of the public key for the + // certificate_verify message later. + main_cert_pub_key_algo = subject.pub_key_algo; + const pub_key = subject.pubKey(); + if (pub_key.len > main_cert_pub_key_buf.len) + return error.CertificatePublicKeyInvalid; + @memcpy(&main_cert_pub_key_buf, pub_key.ptr, pub_key.len); + main_cert_pub_key_len = @intCast(@TypeOf(main_cert_pub_key_len), pub_key.len); + } else { + try prev_cert.verify(subject); + } + + if (ca_bundle.verify(subject)) |_| { + handshake_state = .trust_chain_established; + break :cert; + } else |err| switch (err) { + error.CertificateIssuerNotFound => {}, + else => |e| return e, + } + + prev_cert = subject; + cert_index += 1; + + try certs_decoder.ensure(2); + const total_ext_size = certs_decoder.decode(u16); + var all_extd = try certs_decoder.sub(total_ext_size); + _ = all_extd; + } + }, + .certificate_verify => { + switch (handshake_state) { + .trust_chain_established => handshake_state = .finished, + .certificate => return error.TlsCertificateNotVerified, + else => return error.TlsUnexpectedMessage, + } + + try hsd.ensure(4); + const scheme = hsd.decode(tls.SignatureScheme); + const sig_len = hsd.decode(u16); + try hsd.ensure(sig_len); + const encoded_sig = hsd.slice(sig_len); + const max_digest_len = 64; + var verify_buffer = + ([1]u8{0x20} ** 64) ++ + "TLS 1.3, server CertificateVerify\x00".* ++ + @as([max_digest_len]u8, undefined); + + const verify_bytes = switch (handshake_cipher) { + inline else => |*p| v: { + const transcript_digest = p.transcript_hash.peek(); + verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest; + p.transcript_hash.update(wrapped_handshake); + break :v verify_buffer[0 .. 
verify_buffer.len - max_digest_len + transcript_digest.len]; }, - .certificate => cert: { - switch (handshake_cipher) { - inline else => |*p| p.transcript_hash.update(wrapped_handshake), - } - switch (handshake_state) { - .certificate => {}, - .trust_chain_established => break :cert, - else => return error.TlsUnexpectedMessage, - } - var hs_i: u32 = 0; - const cert_req_ctx_len = handshake[hs_i]; - hs_i += 1; - if (cert_req_ctx_len != 0) return error.TlsIllegalParameter; - const certs_size = mem.readIntBig(u24, handshake[hs_i..][0..3]); - hs_i += 3; - const end_certs = hs_i + certs_size; - while (hs_i < end_certs) { - const cert_size = mem.readIntBig(u24, handshake[hs_i..][0..3]); - hs_i += 3; - const end_cert = hs_i + cert_size; + }; + const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len]; - const subject_cert: Certificate = .{ - .buffer = handshake, - .index = hs_i, - }; - const subject = try subject_cert.parse(); - if (cert_index == 0) { - // Verify the host on the first certificate. - if (!hostMatchesCommonName(host, subject.commonName())) { - return error.TlsCertificateHostMismatch; - } - - // Keep track of the public key for - // the certificate_verify message - // later. - main_cert_pub_key_algo = subject.pub_key_algo; - const pub_key = subject.pubKey(); - if (pub_key.len > main_cert_pub_key_buf.len) - return error.CertificatePublicKeyInvalid; - @memcpy(&main_cert_pub_key_buf, pub_key.ptr, pub_key.len); - main_cert_pub_key_len = @intCast(@TypeOf(main_cert_pub_key_len), pub_key.len); - } else { - try prev_cert.verify(subject); - } - - if (ca_bundle.verify(subject)) |_| { - handshake_state = .trust_chain_established; - break :cert; - } else |err| switch (err) { - error.CertificateIssuerNotFound => {}, - else => |e| return e, - } - - prev_cert = subject; - cert_index += 1; - - hs_i = end_cert; - const total_ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]); - hs_i += 2; - hs_i += total_ext_size; - } + switch (scheme) { + inline .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + => |comptime_scheme| { + if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey) + return error.TlsBadSignatureScheme; + const Ecdsa = SchemeEcdsa(comptime_scheme); + const sig = try Ecdsa.Signature.fromDer(encoded_sig); + const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key); + try sig.verify(verify_bytes, key); }, - .certificate_verify => { - switch (handshake_state) { - .trust_chain_established => handshake_state = .finished, - .certificate => return error.TlsCertificateNotVerified, - else => return error.TlsUnexpectedMessage, - } + .rsa_pss_rsae_sha256 => { + if (main_cert_pub_key_algo != .rsaEncryption) + return error.TlsBadSignatureScheme; - const scheme = @intToEnum(tls.SignatureScheme, mem.readIntBig(u16, handshake[0..2])); - const sig_len = mem.readIntBig(u16, handshake[2..4]); - if (4 + sig_len > handshake.len) return error.TlsBadLength; - const encoded_sig = handshake[4..][0..sig_len]; - const max_digest_len = 64; - var verify_buffer = - ([1]u8{0x20} ** 64) ++ - "TLS 1.3, server CertificateVerify\x00".* ++ - @as([max_digest_len]u8, undefined); - - const verify_bytes = switch (handshake_cipher) { - inline else => |*p| v: { - const transcript_digest = p.transcript_hash.peek(); - verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest; - p.transcript_hash.update(wrapped_handshake); - break :v verify_buffer[0 .. 
verify_buffer.len - max_digest_len + transcript_digest.len]; - }, - }; - const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len]; - - switch (scheme) { - inline .ecdsa_secp256r1_sha256, - .ecdsa_secp384r1_sha384, - => |comptime_scheme| { - if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey) - return error.TlsBadSignatureScheme; - const Ecdsa = SchemeEcdsa(comptime_scheme); - const sig = try Ecdsa.Signature.fromDer(encoded_sig); - const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key); - try sig.verify(verify_bytes, key); - }, - .rsa_pss_rsae_sha256 => { - if (main_cert_pub_key_algo != .rsaEncryption) - return error.TlsBadSignatureScheme; - - const Hash = crypto.hash.sha2.Sha256; - const rsa = Certificate.rsa; - const components = try rsa.PublicKey.parseDer(main_cert_pub_key); - const exponent = components.exponent; - const modulus = components.modulus; - var rsa_mem_buf: [512 * 32]u8 = undefined; - var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf); - const ally = fba.allocator(); - switch (modulus.len) { - inline 128, 256, 512 => |modulus_len| { - const key = try rsa.PublicKey.fromBytes(exponent, modulus, ally); - const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig); - try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, ally); - }, - else => { - return error.TlsBadRsaSignatureBitCount; - }, - } + const Hash = crypto.hash.sha2.Sha256; + const rsa = Certificate.rsa; + const components = try rsa.PublicKey.parseDer(main_cert_pub_key); + const exponent = components.exponent; + const modulus = components.modulus; + var rsa_mem_buf: [512 * 32]u8 = undefined; + var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf); + const ally = fba.allocator(); + switch (modulus.len) { + inline 128, 256, 512 => |modulus_len| { + const key = try rsa.PublicKey.fromBytes(exponent, modulus, ally); + const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig); + try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, ally); }, else => { - return error.TlsBadSignatureScheme; + return error.TlsBadRsaSignatureBitCount; }, } }, - .finished => { - if (handshake_state != .finished) return error.TlsUnexpectedMessage; - // This message is to trick buggy proxies into behaving correctly. 
- const client_change_cipher_spec_msg = [_]u8{ - @enumToInt(tls.ContentType.change_cipher_spec), - 0x03, 0x03, // legacy protocol version - 0x00, 0x01, // length - 0x01, - }; - const app_cipher = switch (handshake_cipher) { - inline else => |*p, tag| c: { - const P = @TypeOf(p.*); - const finished_digest = p.transcript_hash.peek(); - p.transcript_hash.update(wrapped_handshake); - const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key); - if (!mem.eql(u8, &expected_server_verify_data, handshake)) - return error.TlsDecryptError; - const handshake_hash = p.transcript_hash.finalResult(); - const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key); - const out_cleartext = [_]u8{ - @enumToInt(tls.HandshakeType.finished), - 0, 0, verify_data.len, // length - } ++ verify_data ++ [1]u8{@enumToInt(tls.ContentType.handshake)}; - - const wrapped_len = out_cleartext.len + P.AEAD.tag_length; - - var finished_msg = [_]u8{ - @enumToInt(tls.ContentType.application_data), - 0x03, 0x03, // legacy protocol version - 0, wrapped_len, // byte length of encrypted record - } ++ @as([wrapped_len]u8, undefined); - - const ad = finished_msg[0..5]; - const ciphertext = finished_msg[5..][0..out_cleartext.len]; - const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..]; - const nonce = p.client_handshake_iv; - P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); - - const both_msgs = client_change_cipher_spec_msg ++ finished_msg; - try stream.writeAll(&both_msgs); - - const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); - const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); - break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{ - .client_secret = client_secret, - .server_secret = server_secret, - .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), - .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), - .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length), - .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length), - }); - }, - }; - var client: Client = .{ - .read_seq = 0, - .write_seq = 0, - .partial_cleartext_idx = 0, - .partial_ciphertext_idx = 0, - .partial_ciphertext_end = @intCast(u15, len - end), - .received_close_notify = false, - .application_cipher = app_cipher, - .partially_read_buffer = undefined, - }; - mem.copy(u8, &client.partially_read_buffer, handshake_buf[len..end]); - return client; - }, else => { - return error.TlsUnexpectedMessage; + return error.TlsBadSignatureScheme; }, } - ct_i = next_handshake_i; - if (ct_i >= cleartext.len - 1) break; - } - }, - else => { - return error.TlsUnexpectedMessage; - }, + }, + .finished => { + if (handshake_state != .finished) return error.TlsUnexpectedMessage; + // This message is to trick buggy proxies into behaving correctly. 
+ const client_change_cipher_spec_msg = [_]u8{ + @enumToInt(tls.ContentType.change_cipher_spec), + 0x03, 0x03, // legacy protocol version + 0x00, 0x01, // length + 0x01, + }; + const app_cipher = switch (handshake_cipher) { + inline else => |*p, tag| c: { + const P = @TypeOf(p.*); + const finished_digest = p.transcript_hash.peek(); + p.transcript_hash.update(wrapped_handshake); + const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key); + if (!mem.eql(u8, &expected_server_verify_data, handshake)) + return error.TlsDecryptError; + const handshake_hash = p.transcript_hash.finalResult(); + const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key); + const out_cleartext = [_]u8{ + @enumToInt(tls.HandshakeType.finished), + 0, 0, verify_data.len, // length + } ++ verify_data ++ [1]u8{@enumToInt(tls.ContentType.handshake)}; + + const wrapped_len = out_cleartext.len + P.AEAD.tag_length; + + var finished_msg = [_]u8{ + @enumToInt(tls.ContentType.application_data), + 0x03, 0x03, // legacy protocol version + 0, wrapped_len, // byte length of encrypted record + } ++ @as([wrapped_len]u8, undefined); + + const ad = finished_msg[0..5]; + const ciphertext = finished_msg[5..][0..out_cleartext.len]; + const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..]; + const nonce = p.client_handshake_iv; + P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); + + const both_msgs = client_change_cipher_spec_msg ++ finished_msg; + try stream.writeAll(&both_msgs); + + const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); + const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); + break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{ + .client_secret = client_secret, + .server_secret = server_secret, + .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), + .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), + .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length), + .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length), + }); + }, + }; + const leftover = d.rest(); + var client: Client = .{ + .read_seq = 0, + .write_seq = 0, + .partial_cleartext_idx = 0, + .partial_ciphertext_idx = 0, + .partial_ciphertext_end = @intCast(u15, leftover.len), + .received_close_notify = false, + .application_cipher = app_cipher, + .partially_read_buffer = undefined, + }; + mem.copy(u8, &client.partially_read_buffer, leftover); + return client; + }, + else => { + return error.TlsUnexpectedMessage; + }, + } + if (ctd.eof()) break; } }, else => { return error.TlsUnexpectedMessage; }, } - i = end; } - - return error.TlsHandshakeFailure; } pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize { @@ -638,12 +593,12 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize { inline else => |*p| l: { const P = @TypeOf(p.*); const V = @Vector(P.AEAD.nonce_length, u8); - const overhead_len = tls.ciphertext_record_header_len + P.AEAD.tag_length + 1; + const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1; while (true) { const encrypted_content_len = @intCast(u16, @min( @min(bytes.len - bytes_i, max_ciphertext_len - 1), ciphertext_buf.len - - tls.ciphertext_record_header_len - P.AEAD.tag_length - ciphertext_end - 1, + tls.record_header_len - 
P.AEAD.tag_length - ciphertext_end - 1, )); if (encrypted_content_len == 0) break :l overhead_len; @@ -829,7 +784,7 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove // Cleartext capacity of output buffer, in records, rounded up. const buf_cap = (cleartext_buf_len +| (max_ciphertext_len - 1)) / max_ciphertext_len; - const wanted_read_len = buf_cap * (max_ciphertext_len + tls.ciphertext_record_header_len); + const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len); const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len); const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len); const actual_read_len = try stream.readv(ask_iovecs); @@ -860,13 +815,13 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove continue; } - if (in + tls.ciphertext_record_header_len > frag.len) { + if (in + tls.record_header_len > frag.len) { if (frag.ptr == frag1.ptr) return finishRead(c, frag, in, vp.total); const first = frag[in..]; - if (frag1.len < tls.ciphertext_record_header_len) + if (frag1.len < tls.record_header_len) return finishRead2(c, first, frag1, vp.total); // A record straddles the two fragments. Copy into the now-empty first fragment. @@ -875,7 +830,7 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove const record_len = (record_len_byte_0 << 8) | record_len_byte_1; if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; - const full_record_len = record_len + tls.ciphertext_record_header_len; + const full_record_len = record_len + tls.record_header_len; const second_len = full_record_len - first.len; if (frag1.len < second_len) return finishRead2(c, first, frag1, vp.total); @@ -898,14 +853,14 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove const end = in + record_len; if (end > frag.len) { // We need the record header on the next iteration of the loop. - in -= tls.ciphertext_record_header_len; + in -= tls.record_header_len; if (frag.ptr == frag1.ptr) return finishRead(c, frag, in, vp.total); // A record straddles the two fragments. Copy into the now-empty first fragment. 
const first = frag[in..]; - const full_record_len = record_len + tls.ciphertext_record_header_len; + const full_record_len = record_len + tls.record_header_len; const second_len = full_record_len - first.len; if (frag1.len < second_len) return finishRead2(c, first, frag1, vp.total); @@ -919,7 +874,12 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove } switch (ct) { .alert => { - @panic("TODO handle an alert here"); + if (in + 2 > frag.len) return error.TlsDecodeError; + const level = @intToEnum(tls.AlertLevel, frag[in]); + const desc = @intToEnum(tls.AlertDescription, frag[in + 1]); + _ = level; + _ = desc; + return error.TlsAlert; }, .application_data => { const cleartext = switch (c.application_cipher) { From 66b07fd67215f2cccf126f271defc5e028227d7e Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 30 Dec 2022 20:05:17 -0700 Subject: [PATCH 53/59] std.crypto.Certificate: bump RSA needed memory --- lib/std/crypto/Certificate.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index 93383f3615..1bd7446fb6 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -504,7 +504,7 @@ fn verifyRsa( var msg_hashed: [Hash.digest_length]u8 = undefined; Hash.hash(message, &msg_hashed, .{}); - var rsa_mem_buf: [512 * 32]u8 = undefined; + var rsa_mem_buf: [512 * 64]u8 = undefined; var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf); const ally = fba.allocator(); From b3c8c383bbba05d9a9d28073e8e8ceba3f089ae8 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 30 Dec 2022 20:06:13 -0700 Subject: [PATCH 54/59] std.os: add missing handling of ECONNRESET in readv --- lib/std/os.zig | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/std/os.zig b/lib/std/os.zig index a47e3d0068..143aa77bff 100644 --- a/lib/std/os.zig +++ b/lib/std/os.zig @@ -767,6 +767,7 @@ pub fn readv(fd: fd_t, iov: []const iovec) ReadError!usize { .ISDIR => return error.IsDir, .NOBUFS => return error.SystemResources, .NOMEM => return error.SystemResources, + .CONNRESET => return error.ConnectionResetByPeer, else => |err| return unexpectedErrno(err), } } From 611a1fdd6df81a95a74162a0ebdd5afba94d29d4 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 30 Dec 2022 20:06:42 -0700 Subject: [PATCH 55/59] std.crypto.tls: add API for sending close_notify This commit adds `writeEnd` and `writeAllEnd` in order to send data and also notify the server that there will be no more data written. Unfortunately, it seems most TLS implementations in the wild get this wrong and immediately close the socket when they see a close_notify, rather than only ending the data stream on the application layer. 
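
A brief usage sketch (hypothetical names: `tls_client` is an initialized
`Client`, `stream` a connected `std.net.Stream`, and `request_bytes` the
application data): a caller with no more data to send passes `end = true`
on its final write so that a close_notify alert is appended after the last
record:

    // Encrypt and send the remaining bytes, then append a close_notify
    // alert so the peer can tell a clean end-of-data from truncation.
    try tls_client.writeAllEnd(stream, request_bytes, true);

    // Previous behavior, still available: send data without close_notify.
    try tls_client.writeAll(stream, request_bytes);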
---
 lib/std/crypto/tls.zig | 5 +
 lib/std/crypto/tls/Client.zig | 194 +++++++++++++++++++++++++++-------
 lib/std/http/Client.zig | 2 +-
 3 files changed, 160 insertions(+), 41 deletions(-)

diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig
index 8ef4d9bfad..7d89da8929 100644
--- a/lib/std/crypto/tls.zig
+++ b/lib/std/crypto/tls.zig
@@ -47,6 +47,11 @@ pub const hello_retry_request_sequence = [32]u8{
     0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
 };
 
+pub const close_notify_alert = [_]u8{
+    @enumToInt(AlertLevel.warning),
+    @enumToInt(AlertDescription.close_notify),
+};
+
 pub const ProtocolVersion = enum(u16) {
     tls_1_2 = 0x0303,
     tls_1_3 = 0x0304,
diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig
index bca05a3ffd..df59932d4a 100644
--- a/lib/std/crypto/tls/Client.zig
+++ b/lib/std/crypto/tls/Client.zig
@@ -37,8 +37,54 @@ application_cipher: tls.ApplicationCipher,
 /// `partial_ciphertext_end` describe the span of the segments.
 partially_read_buffer: [tls.max_ciphertext_record_len]u8,
 
+/// This is an example of the type that is needed by the read and write
+/// functions. It can have any fields but it must at least have these
+/// functions.
+///
+/// Note that `std.net.Stream` conforms to this interface.
+///
+/// This declaration serves as documentation only.
+pub const StreamInterface = struct {
+    /// Can be any error set.
+    pub const ReadError = error{};
+
+    /// Returns the number of bytes read. The number read may be less than the
+    /// buffer space provided. End-of-stream is indicated by a return value of 0.
+    ///
+    /// The `iovecs` parameter is mutable so that this function may
+    /// mutate the fields in order to handle partial reads from the underlying
+    /// stream layer.
+    pub fn readv(this: @This(), iovecs: []std.os.iovec) ReadError!usize {
+        _ = .{ this, iovecs };
+        @panic("unimplemented");
+    }
+
+    /// Can be any error set.
+    pub const WriteError = error{};
+
+    /// Returns the number of bytes written, which may be less than the buffer
+    /// space provided. A short write does not indicate end-of-stream.
+    pub fn writev(this: @This(), iovecs: []const std.os.iovec_const) WriteError!usize {
+        _ = .{ this, iovecs };
+        @panic("unimplemented");
+    }
+
+    /// Returns the number of bytes written, which may be less than the buffer
+    /// space provided.
+    /// The `iovecs` parameter is mutable in case this function needs to mutate
+    /// the fields in order to handle partial writes from the underlying layer.
+    pub fn writevAll(this: @This(), iovecs: []std.os.iovec_const) WriteError!usize {
+        // This can be implemented in terms of writev, or specialized if desired.
+        _ = .{ this, iovecs };
+        @panic("unimplemented");
+    }
+};
+
+/// Initiates a TLS handshake and establishes a TLSv1.3 session with `stream`, which
+/// must conform to `StreamInterface`.
+///
 /// `host` is only borrowed during this function call.
-pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) !Client {
+pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) !Client {
     const host_len = @intCast(u16, host.len);
 
     var random_buffer: [128]u8 = undefined;
@@ -579,31 +625,115 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
     }
 }
 
-pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize {
+/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
+/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. +pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize { + return writeEnd(c, stream, bytes, false); +} + +/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. +pub fn writeAll(c: *Client, stream: anytype, bytes: []const u8) !void { + var index: usize = 0; + while (index < bytes.len) { + index += try c.write(stream, bytes[index..]); + } +} + +/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. +/// If `end` is true, then this function additionally sends a `close_notify` alert, +/// which is necessary for the server to distinguish between a properly finished +/// TLS session, or a truncation attack. +pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !void { + var index: usize = 0; + while (index < bytes.len) { + index += try c.writeEnd(stream, bytes[index..], end); + } +} + +/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. +/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. +/// If `end` is true, then this function additionally sends a `close_notify` alert, +/// which is necessary for the server to distinguish between a properly finished +/// TLS session, or a truncation attack. +pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize { var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined; + var iovecs_buf: [6]std.os.iovec_const = undefined; + var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data); + if (end) { + prepared.iovec_end += prepareCiphertextRecord( + c, + iovecs_buf[prepared.iovec_end..], + ciphertext_buf[prepared.ciphertext_end..], + &tls.close_notify_alert, + .alert, + ).iovec_end; + } + + const iovec_end = prepared.iovec_end; + const overhead_len = prepared.overhead_len; + + // Ideally we would call writev exactly once here, however, we must ensure + // that we don't return with a record partially written. + var i: usize = 0; + var total_amt: usize = 0; + while (true) { + var amt = try stream.writev(iovecs_buf[i..iovec_end]); + while (amt >= iovecs_buf[i].iov_len) { + const encrypted_amt = iovecs_buf[i].iov_len; + total_amt += encrypted_amt - overhead_len; + amt -= encrypted_amt; + i += 1; + // Rely on the property that iovecs delineate records, meaning that + // if amt equals zero here, we have fortunately found ourselves + // with a short read that aligns at the record boundary. + if (i >= iovec_end) return total_amt; + // We also cannot return on a vector boundary if the final close_notify is + // not sent; otherwise the caller would not know to retry the call. + if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt; + } + iovecs_buf[i].iov_base += amt; + iovecs_buf[i].iov_len -= amt; + } +} + +fn prepareCiphertextRecord( + c: *Client, + iovecs: []std.os.iovec_const, + ciphertext_buf: []u8, + bytes: []const u8, + inner_content_type: tls.ContentType, +) struct { + iovec_end: usize, + ciphertext_end: usize, + /// How many bytes are taken up by overhead per record. + overhead_len: usize, +} { // Due to the trailing inner content type byte in the ciphertext, we need // an additional buffer for storing the cleartext into before encrypting. 
var cleartext_buf: [max_ciphertext_len]u8 = undefined; - var iovecs_buf: [5]std.os.iovec_const = undefined; var ciphertext_end: usize = 0; var iovec_end: usize = 0; var bytes_i: usize = 0; - // How many bytes are taken up by overhead per record. - const overhead_len: usize = switch (c.application_cipher) { - inline else => |*p| l: { + switch (c.application_cipher) { + inline else => |*p| { const P = @TypeOf(p.*); const V = @Vector(P.AEAD.nonce_length, u8); const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1; + const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; while (true) { const encrypted_content_len = @intCast(u16, @min( @min(bytes.len - bytes_i, max_ciphertext_len - 1), - ciphertext_buf.len - - tls.record_header_len - P.AEAD.tag_length - ciphertext_end - 1, + ciphertext_buf.len - close_notify_alert_reserved - + overhead_len - ciphertext_end, )); - if (encrypted_content_len == 0) break :l overhead_len; + if (encrypted_content_len == 0) return .{ + .iovec_end = iovec_end, + .ciphertext_end = ciphertext_end, + .overhead_len = overhead_len, + }; mem.copy(u8, &cleartext_buf, bytes[bytes_i..][0..encrypted_content_len]); - cleartext_buf[encrypted_content_len] = @enumToInt(tls.ContentType.application_data); + cleartext_buf[encrypted_content_len] = @enumToInt(inner_content_type); bytes_i += encrypted_content_len; const ciphertext_len = encrypted_content_len + 1; const cleartext = cleartext_buf[0..ciphertext_len]; @@ -626,40 +756,13 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize { P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key); const record = ciphertext_buf[record_start..ciphertext_end]; - iovecs_buf[iovec_end] = .{ + iovecs[iovec_end] = .{ .iov_base = record.ptr, .iov_len = record.len, }; iovec_end += 1; } }, - }; - - // Ideally we would call writev exactly once here, however, we must ensure - // that we don't return with a record partially written. - var i: usize = 0; - var total_amt: usize = 0; - while (true) { - var amt = try stream.writev(iovecs_buf[i..iovec_end]); - while (amt >= iovecs_buf[i].iov_len) { - const encrypted_amt = iovecs_buf[i].iov_len; - total_amt += encrypted_amt - overhead_len; - amt -= encrypted_amt; - i += 1; - // Rely on the property that iovecs delineate records, meaning that - // if amt equals zero here, we have fortunately found ourselves - // with a short read that aligns at the record boundary. - if (i >= iovec_end or amt == 0) return total_amt; - } - iovecs_buf[i].iov_base += amt; - iovecs_buf[i].iov_len -= amt; - } -} - -pub fn writeAll(c: *Client, stream: net.Stream, bytes: []const u8) !void { - var index: usize = 0; - while (index < bytes.len) { - index += try c.write(stream, bytes[index..]); } } @@ -669,6 +772,7 @@ pub fn eof(c: Client) bool { c.partial_ciphertext_idx >= c.partial_ciphertext_end; } +/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. /// Returns the number of bytes read, calling the underlying read function the /// minimal number of times until the buffer has at least `len` bytes filled. /// If the number read is less than `len` it means the stream reached the end. @@ -678,10 +782,12 @@ pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize return readvAtLeast(c, stream, &iovecs, len); } +/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. 
pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize { return readAtLeast(c, stream, buffer, 1); } +/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. /// Returns the number of bytes read. If the number read is smaller than /// `buffer.len`, it means the stream reached the end. Reaching the end of the /// stream is not an error condition. @@ -689,6 +795,7 @@ pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize { return readAtLeast(c, stream, buffer, buffer.len); } +/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. /// Returns the number of bytes read. If the number read is less than the space /// provided it means the stream reached the end. Reaching the end of the /// stream is not an error condition. @@ -698,6 +805,7 @@ pub fn readv(c: *Client, stream: anytype, iovecs: []std.os.iovec) !usize { return readvAtLeast(c, stream, iovecs); } +/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. /// Returns the number of bytes read, calling the underlying read function the /// minimal number of times until the iovecs have at least `len` bytes filled. /// If the number read is less than `len` it means the stream reached the end. @@ -722,6 +830,7 @@ pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.os.iovec, len: us } } +/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. /// Returns number of bytes that have been read, populated inside `iovecs`. A /// return value of zero bytes does not mean end of stream. Instead, check the `eof()` /// for the end of stream. The `eof()` may be true after any call to @@ -729,7 +838,7 @@ pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.os.iovec, len: us /// function asserts that `eof()` is `false`. /// See `readv` for a higher level function that has the same, familiar API as /// other read functions, such as `std.fs.File.read`. -pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iovec) !usize { +pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec) !usize { var vp: VecPut = .{ .iovecs = iovecs }; // Give away the buffered cleartext we have, if any. @@ -905,7 +1014,8 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove break :c cleartext; }, }; - c.read_seq += 1; + + c.read_seq = try std.math.add(u64, c.read_seq, 1); const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]); switch (inner_ct) { @@ -1196,3 +1306,7 @@ const cipher_suites = enum_array(tls.CipherSuite, &.{ .AES_256_GCM_SHA384, .CHACHA20_POLY1305_SHA256, }); + +test { + _ = StreamInterface; +} diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index efae62680d..33df40866a 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -47,7 +47,7 @@ pub const Request = struct { try req.stream.writeAll(req.headers.items); }, .https => { - try req.tls_client.writeAll(req.stream, req.headers.items); + try req.tls_client.writeAllEnd(req.stream, req.headers.items, true); }, } } From 3127bd79fb8106c3fdf487486c686a0551c5efa8 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 30 Dec 2022 20:26:49 -0700 Subject: [PATCH 56/59] std.http.Client: don't send TLS close_notify It appears to be implemented incorrectly in the wild and causes the read connection to be closed even though that is a direct violation of RFC 8446 Section 6.1. The writeEnd function variants are still there, ready to be used. 
--- lib/std/http/Client.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 33df40866a..efae62680d 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -47,7 +47,7 @@ pub const Request = struct { try req.stream.writeAll(req.headers.items); }, .https => { - try req.tls_client.writeAllEnd(req.stream, req.headers.items, true); + try req.tls_client.writeAll(req.stream, req.headers.items); }, } } From 97acdeeca86af4111972aeb57dd9c792e7f1f419 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 1 Jan 2023 17:52:28 -0700 Subject: [PATCH 57/59] std.crypto.tls: verify via Subject Alt Name Previously, the code only checked Common Name, leading to unable to validate valid certificates which relied on the subject_alt_name extension for host name verification. This commit also adds rsa_pss_rsae_* back to the signature algorithms list in the ClientHello. --- lib/std/crypto/Certificate.zig | 143 ++++++++++++++++++++++++++++++++- lib/std/crypto/tls/Client.zig | 27 +------ 2 files changed, 145 insertions(+), 25 deletions(-) diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index 1bd7446fb6..f81676a977 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -81,6 +81,43 @@ pub const NamedCurve = enum { }); }; +pub const ExtensionId = enum { + subject_key_identifier, + key_usage, + private_key_usage_period, + subject_alt_name, + issuer_alt_name, + basic_constraints, + crl_number, + certificate_policies, + authority_key_identifier, + + pub const map = std.ComptimeStringMap(ExtensionId, .{ + .{ &[_]u8{ 0x55, 0x1D, 0x0E }, .subject_key_identifier }, + .{ &[_]u8{ 0x55, 0x1D, 0x0F }, .key_usage }, + .{ &[_]u8{ 0x55, 0x1D, 0x10 }, .private_key_usage_period }, + .{ &[_]u8{ 0x55, 0x1D, 0x11 }, .subject_alt_name }, + .{ &[_]u8{ 0x55, 0x1D, 0x12 }, .issuer_alt_name }, + .{ &[_]u8{ 0x55, 0x1D, 0x13 }, .basic_constraints }, + .{ &[_]u8{ 0x55, 0x1D, 0x14 }, .crl_number }, + .{ &[_]u8{ 0x55, 0x1D, 0x20 }, .certificate_policies }, + .{ &[_]u8{ 0x55, 0x1D, 0x23 }, .authority_key_identifier }, + }); +}; + +pub const GeneralNameTag = enum(u5) { + otherName = 0, + rfc822Name = 1, + dNSName = 2, + x400Address = 3, + directoryName = 4, + ediPartyName = 5, + uniformResourceIdentifier = 6, + iPAddress = 7, + registeredID = 8, + _, +}; + pub const Parsed = struct { certificate: Certificate, issuer_slice: Slice, @@ -91,6 +128,7 @@ pub const Parsed = struct { pub_key_algo: PubKeyAlgo, pub_key_slice: Slice, message_slice: Slice, + subject_alt_name_slice: Slice, validity: Validity, pub const PubKeyAlgo = union(AlgorithmCategory) { @@ -137,6 +175,10 @@ pub const Parsed = struct { return p.slice(p.message_slice); } + pub fn subjectAltName(p: Parsed) []const u8 { + return p.slice(p.subject_alt_name_slice); + } + pub const VerifyError = error{ CertificateIssuerMismatch, CertificateNotYetValid, @@ -152,8 +194,10 @@ pub const Parsed = struct { CertificateSignatureNamedCurveUnsupported, }; - /// This function checks the time validity for the subject only. Checking - /// the issuer's time validity is out of scope. + /// This function verifies: + /// * That the subject's issuer is indeed the provided issuer. + /// * The time validity of the subject. + /// * The signature. pub fn verify(parsed_subject: Parsed, parsed_issuer: Parsed) VerifyError!void { // Check that the subject's issuer name matches the issuer's // subject name. 
@@ -194,6 +238,62 @@ pub const Parsed = struct { ), } } + + pub const VerifyHostNameError = error{ + CertificateHostMismatch, + CertificateFieldHasInvalidLength, + }; + + pub fn verifyHostName(parsed_subject: Parsed, host_name: []const u8) VerifyHostNameError!void { + // If the Subject Alternative Names extension is present, this is + // what to check. Otherwise, only the common name is checked. + const subject_alt_name = parsed_subject.subjectAltName(); + if (subject_alt_name.len == 0) { + if (checkHostName(host_name, parsed_subject.commonName())) { + return; + } else { + return error.CertificateHostMismatch; + } + } + + const general_names = try der.Element.parse(subject_alt_name, 0); + var name_i = general_names.slice.start; + while (name_i < general_names.slice.end) { + const general_name = try der.Element.parse(subject_alt_name, name_i); + name_i = general_name.slice.end; + switch (@intToEnum(GeneralNameTag, @enumToInt(general_name.identifier.tag))) { + .dNSName => { + const dns_name = subject_alt_name[general_name.slice.start..general_name.slice.end]; + if (checkHostName(host_name, dns_name)) return; + }, + else => {}, + } + } + + return error.CertificateHostMismatch; + } + + fn checkHostName(host_name: []const u8, dns_name: []const u8) bool { + if (mem.eql(u8, dns_name, host_name)) { + return true; // exact match + } + + if (mem.startsWith(u8, dns_name, "*.")) { + // wildcard certificate, matches any subdomain + // TODO: I think wildcards are not supposed to match any prefix but + // only match exactly one subdomain. + if (mem.endsWith(u8, host_name, dns_name[1..])) { + // The host_name has a subdomain, but the important part matches. + return true; + } + if (mem.eql(u8, dns_name[2..], host_name)) { + // The host_name has no subdomain and matches exactly. 
+ return true; + } + } + + return false; + } }; pub fn parse(cert: Certificate) !Parsed { @@ -268,6 +368,39 @@ pub fn parse(cert: Certificate) !Parsed { const sig_elem = try der.Element.parse(cert_bytes, sig_algo.slice.end); const signature = try parseBitString(cert, sig_elem); + // Extensions + var subject_alt_name_slice = der.Element.Slice.empty; + ext: { + if (pub_key_info.slice.end >= tbs_certificate.slice.end) + break :ext; + + const outer_extensions = try der.Element.parse(cert_bytes, pub_key_info.slice.end); + if (outer_extensions.identifier.tag != .bitstring) + break :ext; + + const extensions = try der.Element.parse(cert_bytes, outer_extensions.slice.start); + + var ext_i = extensions.slice.start; + while (ext_i < extensions.slice.end) { + const extension = try der.Element.parse(cert_bytes, ext_i); + ext_i = extension.slice.end; + const oid_elem = try der.Element.parse(cert_bytes, extension.slice.start); + const ext_id = parseExtensionId(cert_bytes, oid_elem) catch |err| switch (err) { + error.CertificateHasUnrecognizedObjectId => continue, + else => |e| return e, + }; + const critical_elem = try der.Element.parse(cert_bytes, oid_elem.slice.end); + const ext_bytes_elem = if (critical_elem.identifier.tag != .boolean) + critical_elem + else + try der.Element.parse(cert_bytes, critical_elem.slice.end); + switch (ext_id) { + .subject_alt_name => subject_alt_name_slice = ext_bytes_elem.slice, + else => continue, + } + } + } + return .{ .certificate = cert, .common_name_slice = common_name, @@ -282,6 +415,7 @@ pub fn parse(cert: Certificate) !Parsed { .not_before = not_before_utc, .not_after = not_after_utc, }, + .subject_alt_name_slice = subject_alt_name_slice, }; } @@ -444,6 +578,10 @@ pub fn parseNamedCurve(bytes: []const u8, element: der.Element) !NamedCurve { return parseEnum(NamedCurve, bytes, element); } +pub fn parseExtensionId(bytes: []const u8, element: der.Element) !ExtensionId { + return parseEnum(ExtensionId, bytes, element); +} + fn parseEnum(comptime E: type, bytes: []const u8, element: der.Element) !E { if (element.identifier.tag != .object_identifier) return error.CertificateFieldHasWrongDataType; @@ -604,6 +742,7 @@ pub const der = struct { boolean = 1, integer = 2, bitstring = 3, + octetstring = 4, null = 5, object_identifier = 6, sequence = 16, diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index df59932d4a..0251350dad 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -111,6 +111,9 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) !C .ecdsa_secp256r1_sha256, .ecdsa_secp384r1_sha384, .ecdsa_secp521r1_sha512, + .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, .rsa_pkcs1_sha256, .rsa_pkcs1_sha384, .rsa_pkcs1_sha512, @@ -444,9 +447,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) !C const subject = try subject_cert.parse(); if (cert_index == 0) { // Verify the host on the first certificate. - if (!hostMatchesCommonName(host, subject.commonName())) { - return error.TlsCertificateHostMismatch; - } + try subject.verifyHostName(host); // Keep track of the public key for the // certificate_verify message later. 
@@ -1162,26 +1163,6 @@ fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 { } } -fn hostMatchesCommonName(host: []const u8, common_name: []const u8) bool { - if (mem.eql(u8, common_name, host)) { - return true; // exact match - } - - if (mem.startsWith(u8, common_name, "*.")) { - // wildcard certificate, matches any subdomain - if (mem.endsWith(u8, host, common_name[1..])) { - // The host has a subdomain, but the important part matches. - return true; - } - if (mem.eql(u8, common_name[2..], host)) { - // The host has no subdomain and matches exactly. - return true; - } - } - - return false; -} - const builtin = @import("builtin"); const native_endian = builtin.cpu.arch.endian(); From 9ca6d673457723548d8fe721b44499817ead1d2d Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 2 Jan 2023 13:18:56 -0700 Subject: [PATCH 58/59] std.crypto.tls.Certificate: make the current time a parameter --- lib/std/crypto/Certificate.zig | 7 +++---- lib/std/crypto/Certificate/Bundle.zig | 4 ++-- lib/std/crypto/tls/Client.zig | 5 +++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index f81676a977..fe211c6146 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -198,14 +198,13 @@ pub const Parsed = struct { /// * That the subject's issuer is indeed the provided issuer. /// * The time validity of the subject. /// * The signature. - pub fn verify(parsed_subject: Parsed, parsed_issuer: Parsed) VerifyError!void { + pub fn verify(parsed_subject: Parsed, parsed_issuer: Parsed, now_sec: i64) VerifyError!void { // Check that the subject's issuer name matches the issuer's // subject name. if (!mem.eql(u8, parsed_subject.issuer(), parsed_issuer.subject())) { return error.CertificateIssuerMismatch; } - const now_sec = std.time.timestamp(); if (now_sec < parsed_subject.validity.not_before) return error.CertificateNotYetValid; if (now_sec > parsed_subject.validity.not_after) @@ -419,10 +418,10 @@ pub fn parse(cert: Certificate) !Parsed { }; } -pub fn verify(subject: Certificate, issuer: Certificate) !void { +pub fn verify(subject: Certificate, issuer: Certificate, now_sec: i64) !void { const parsed_subject = try subject.parse(); const parsed_issuer = try issuer.parse(); - return parsed_subject.verify(parsed_issuer); + return parsed_subject.verify(parsed_issuer, now_sec); } pub fn contents(cert: Certificate, elem: der.Element) []const u8 { diff --git a/lib/std/crypto/Certificate/Bundle.zig b/lib/std/crypto/Certificate/Bundle.zig index b30fa531ec..a1684fda73 100644 --- a/lib/std/crypto/Certificate/Bundle.zig +++ b/lib/std/crypto/Certificate/Bundle.zig @@ -13,7 +13,7 @@ pub const VerifyError = Certificate.Parsed.VerifyError || error{ CertificateIssuerNotFound, }; -pub fn verify(cb: Bundle, subject: Certificate.Parsed) VerifyError!void { +pub fn verify(cb: Bundle, subject: Certificate.Parsed, now_sec: i64) VerifyError!void { const bytes_index = cb.find(subject.issuer()) orelse return error.CertificateIssuerNotFound; const issuer_cert: Certificate = .{ .buffer = cb.bytes.items, @@ -22,7 +22,7 @@ pub fn verify(cb: Bundle, subject: Certificate.Parsed) VerifyError!void { // Every certificate in the bundle is pre-parsed before adding it, ensuring // that parsing will succeed here. 
const issuer = issuer_cert.parse() catch unreachable; - try subject.verify(issuer); + try subject.verify(issuer, now_sec); } /// The returned bytes become invalid after calling any of the rescan functions diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 0251350dad..877ad33bb4 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -351,6 +351,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) !C var main_cert_pub_key_algo: Certificate.AlgorithmCategory = undefined; var main_cert_pub_key_buf: [300]u8 = undefined; var main_cert_pub_key_len: u16 = undefined; + const now_sec = std.time.timestamp(); while (true) { try d.readAtLeastOurAmt(stream, tls.record_header_len); @@ -458,10 +459,10 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) !C @memcpy(&main_cert_pub_key_buf, pub_key.ptr, pub_key.len); main_cert_pub_key_len = @intCast(@TypeOf(main_cert_pub_key_len), pub_key.len); } else { - try prev_cert.verify(subject); + try prev_cert.verify(subject, now_sec); } - if (ca_bundle.verify(subject)) |_| { + if (ca_bundle.verify(subject, now_sec)) |_| { handshake_state = .trust_chain_established; break :cert; } else |err| switch (err) { From 7178451d6258d9d04fdf03269478948643c39f02 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 2 Jan 2023 18:24:13 -0700 Subject: [PATCH 59/59] std.crypto.tls.Client: make close_notify optional Although RFC 8446 states: > Each party MUST send a "close_notify" alert before closing its write > side of the connection In practice many servers do not do this. Also in practice, truncation attacks are thwarted at the application layer by comparing the amount of bytes received with the amount expected via the HTTP headers. --- lib/std/crypto/tls/Client.zig | 18 ++++++++++++++++-- lib/std/http/Client.zig | 10 ++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 877ad33bb4..44891a1973 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -26,6 +26,14 @@ partial_ciphertext_end: u15, /// When this is true, the stream may still not be at the end because there /// may be data in `partially_read_buffer`. received_close_notify: bool, +/// By default, reaching the end-of-stream when reading from the server will +/// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify +/// message has been received. By setting this flag to `true`, instead, the +/// end-of-stream will be forwarded to the application layer above TLS. +/// This makes the application vulnerable to truncation attacks unless the +/// application layer itself verifies that the amount of data received equals +/// the amount of data expected, such as HTTP with the Content-Length header. +allow_truncation_attacks: bool = false, application_cipher: tls.ApplicationCipher, /// The size is enough to contain exactly one TLSCiphertext record. /// This buffer is segmented into four parts: @@ -900,8 +908,14 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec) const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len); const actual_read_len = try stream.readv(ask_iovecs); if (actual_read_len == 0) { - // This is either a truncation attack, or a bug in the server. 
- return error.TlsConnectionTruncated; + // This is either a truncation attack, a bug in the server, or an + // intentional omission of the close_notify message due to truncation + // detection handled above the TLS layer. + if (c.allow_truncation_attacks) { + c.received_close_notify = true; + } else { + return error.TlsConnectionTruncated; + } } // There might be more bytes inside `in_stack_buffer` that need to be processed, diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index efae62680d..8a4a771416 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -1,3 +1,7 @@ +//! This API is a barely-touched, barely-functional http client, just the +//! absolute minimum thing I needed in order to test `std.crypto.tls`. Bear +//! with me and I promise the API will become useful and streamlined. + const std = @import("../std.zig"); const assert = std.debug.assert; const http = std.http; @@ -10,6 +14,9 @@ headers: std.ArrayListUnmanaged(u8) = .{}, active_requests: usize = 0, ca_bundle: std.crypto.Certificate.Bundle = .{}, +/// TODO: emit error.UnexpectedEndOfStream or something like that when the read +/// data does not match the content length. This is necessary since HTTPS disables +/// close_notify protection on underlying TLS streams. pub const Request = struct { client: *Client, stream: net.Stream, @@ -133,6 +140,9 @@ pub fn request(client: *Client, url: Url, options: Request.Options) !Request { .http => {}, .https => { req.tls_client = try std.crypto.tls.Client.init(req.stream, client.ca_bundle, url.host); + // This is appropriate for HTTPS because the HTTP headers contain + // the content length which is used to detect truncation attacks. + req.tls_client.allow_truncation_attacks = true; }, }