std: update http.WebSocket to new API

This commit is contained in:
Andrew Kelley 2025-05-27 20:18:20 -07:00
parent da303bdaf1
commit 74c56376ee
9 changed files with 449 additions and 420 deletions

View File

@ -894,11 +894,7 @@ pub fn init(
pub fn reader(c: *Client) Reader {
return .{
.context = c,
.vtable = &.{
.read = read,
.readVec = readVec,
.discard = discard,
},
.vtable = &.{ .read = read },
};
}
@ -1225,24 +1221,6 @@ fn read(context: ?*anyopaque, bw: *std.io.BufferedWriter, limit: Reader.Limit) R
}
}
fn readVec(context: ?*anyopaque, data: []const []u8) Reader.Error!usize {
var bw: std.io.BufferedWriter = undefined;
bw.initVec(data);
return read(context, &bw, .countVec(data)) catch |err| switch (err) {
error.WriteFailed => unreachable,
else => |e| return e,
};
}
fn discard(context: ?*anyopaque, limit: Reader.Limit) Reader.Error!usize {
var null_writer: Writer.Null = undefined;
var bw = null_writer.writer().unbuffered();
return read(context, &bw, limit) catch |err| switch (err) {
error.WriteFailed => unreachable,
else => |e| return e,
};
}
fn failRead(c: *Client, err: ReadError) error{ReadFailed} {
c.read_err = err;
return error.ReadFailed;

View File

@ -899,6 +899,15 @@ pub fn writeFileAll(self: File, in_file: File, options: BufferedWriter.WriteFile
};
}
/// Memoizes key information about a file handle such as:
/// * The size from calling stat, or the error that occurred therein.
/// * The current seek position.
/// * The error that occurred when trying to seek.
/// * Whether reading should be done positionally or streaming.
/// * Whether reading should be done via fd-to-fd syscalls (e.g. `sendfile`)
/// versus plain variants (e.g. `read`).
///
/// Fulfills the `std.io.Reader` interface.
pub const Reader = struct {
file: File,
err: ?ReadError = null,
@ -951,14 +960,40 @@ pub const Reader = struct {
};
}
/// Seeks relative to the current logical position, keeping the memoized
/// position `r.pos` in sync.
///
/// In positional modes the seek is pure bookkeeping (no syscall). In
/// streaming modes this first attempts `lseek`; once a seek has failed the
/// error is memoized in `r.seek_err` and forward seeks are instead emulated
/// by reading and discarding bytes.
pub fn seekBy(r: *Reader, offset: i64) SeekError!void {
    switch (r.mode) {
        .positional, .positional_reading => {
            // NOTE(review): `offset` is i64 while `r.pos` appears unsigned;
            // mixed-sign `+=` requires a cast in Zig — confirm the field type
            // and the handling of negative offsets here.
            r.pos += offset;
        },
        .streaming, .streaming_reading => {
            // Reuse a previously memoized seek failure instead of retrying
            // the syscall; otherwise try lseek and memoize any failure.
            const seek_err = r.seek_err orelse e: {
                if (posix.lseek_CUR(r.file.handle, offset)) |_| {
                    r.pos += offset;
                    return;
                } else |err| {
                    r.seek_err = err;
                    break :e err;
                }
            };
            // Backward seeks cannot be emulated by discarding bytes; surface
            // the recorded seek error.
            if (offset < 0) return seek_err;
            // Emulate a forward seek by discarding `offset` bytes.
            var remaining = offset;
            while (remaining > 0) {
                // NOTE(review): the empty `switch (err) {}` compiles only if
                // `discard`'s error set is empty in this context — confirm;
                // otherwise errors such as EndOfStream need explicit arms.
                // Also confirm `.limited(remaining)` accepts an i64.
                const n = discard(r, .limited(remaining)) catch |err| switch (err) {};
                r.pos += n;
                remaining -= n;
            }
        },
    }
}
pub fn seekTo(r: *Reader, offset: u64) SeekError!void {
// TODO if the offset is after the current offset, seek by discarding.
if (r.seek_err) |err| return err;
switch (r.mode) {
.positional, .positional_reading => {
r.pos = offset;
},
.streaming, .streaming_reading => {
if (offset >= r.pos) return Reader.seekBy(r, offset - r.pos);
if (r.seek_err) |err| return err;
posix.lseek_SET(r.file.handle, offset) catch |err| {
r.seek_err = err;
return err;

View File

@ -7,7 +7,6 @@ pub const Server = @import("http/Server.zig");
pub const HeadParser = @import("http/HeadParser.zig");
pub const ChunkParser = @import("http/ChunkParser.zig");
pub const HeaderIterator = @import("http/HeaderIterator.zig");
pub const WebSocket = @import("http/WebSocket.zig");
pub const Version = enum {
@"HTTP/1.0",
@ -508,7 +507,7 @@ pub const Reader = struct {
fn contentLengthRead(
ctx: ?*anyopaque,
bw: *std.io.BufferedWriter,
limit: std.io.Reader.Limit,
limit: std.io.Limit,
) std.io.Reader.RwError!usize {
const reader: *Reader = @alignCast(@ptrCast(ctx));
const remaining_content_length = &reader.state.body_remaining_content_length;
@ -535,7 +534,7 @@ pub const Reader = struct {
return n;
}
fn contentLengthDiscard(ctx: ?*anyopaque, limit: std.io.Reader.Limit) std.io.Reader.Error!usize {
fn contentLengthDiscard(ctx: ?*anyopaque, limit: std.io.Limit) std.io.Reader.Error!usize {
const reader: *Reader = @alignCast(@ptrCast(ctx));
const remaining_content_length = &reader.state.body_remaining_content_length;
const remaining = remaining_content_length.*;
@ -551,7 +550,7 @@ pub const Reader = struct {
fn chunkedRead(
ctx: ?*anyopaque,
bw: *std.io.BufferedWriter,
limit: std.io.Reader.Limit,
limit: std.io.Limit,
) std.io.Reader.RwError!usize {
const reader: *Reader = @alignCast(@ptrCast(ctx));
const chunk_len_ptr = switch (reader.state) {
@ -576,7 +575,7 @@ pub const Reader = struct {
fn chunkedReadEndless(
reader: *Reader,
bw: *std.io.BufferedWriter,
limit: std.io.Reader.Limit,
limit: std.io.Limit,
chunk_len_ptr: *RemainingChunkLen,
) (BodyError || std.io.Reader.RwError)!usize {
const in = reader.in;
@ -712,7 +711,7 @@ pub const Reader = struct {
return amt_read;
}
fn chunkedDiscard(ctx: ?*anyopaque, limit: std.io.Reader.Limit) std.io.Reader.Error!usize {
fn chunkedDiscard(ctx: ?*anyopaque, limit: std.io.Limit) std.io.Reader.Error!usize {
const reader: *Reader = @alignCast(@ptrCast(ctx));
const chunk_len_ptr = switch (reader.state) {
.ready => return error.EndOfStream,
@ -734,7 +733,7 @@ pub const Reader = struct {
fn chunkedDiscardEndless(
reader: *Reader,
limit: std.io.Reader.Limit,
limit: std.io.Limit,
chunk_len_ptr: *RemainingChunkLen,
) (BodyError || std.io.Reader.Error)!usize {
const in = reader.in;
@ -812,8 +811,8 @@ pub const Decompressor = struct {
buffered_reader: std.io.BufferedReader,
pub const Compression = union(enum) {
deflate: std.compress.zlib.Decompressor,
gzip: std.compress.gzip.Decompressor,
deflate: std.compress.flate.Decompressor,
gzip: std.compress.flate.Decompressor,
zstd: std.compress.zstd.Decompress,
none: void,
};
@ -1238,7 +1237,6 @@ test {
_ = Method;
_ = ChunkParser;
_ = HeadParser;
_ = WebSocket;
if (builtin.os.tag != .wasi) {
_ = Client;

View File

@ -57,6 +57,12 @@ pub const Request = struct {
/// Pointers in this struct are invalidated with the next call to
/// `receiveHead`.
head: Head,
respond_err: ?RespondError,
pub const RespondError = error{
/// The request contained an `expect` header with an unrecognized value.
HttpExpectationFailed,
};
pub const Head = struct {
method: http.Method,
@ -306,7 +312,7 @@ pub const Request = struct {
request: *Request,
content: []const u8,
options: RespondOptions,
) std.io.Writer.Error!void {
) ExpectContinueError!void {
try respondUnflushed(request, content, options);
try request.server.out.flush();
}
@ -315,7 +321,7 @@ pub const Request = struct {
request: *Request,
content: []const u8,
options: RespondOptions,
) std.io.Writer.Error!void {
) ExpectContinueError!void {
assert(options.status != .@"continue");
if (std.debug.runtime_safety) {
for (options.extra_headers) |header| {
@ -325,6 +331,7 @@ pub const Request = struct {
assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null);
}
}
try writeExpectContinue(request);
const transfer_encoding_none = (options.transfer_encoding orelse .chunked) == .none;
const server_keep_alive = !transfer_encoding_none and options.keep_alive;
@ -333,17 +340,6 @@ pub const Request = struct {
const phrase = options.reason orelse options.status.phrase() orelse "";
const out = request.server.out;
if (request.head.expect != null) {
// reader() and hence discardBody() above sets expect to null if it
// is handled. So the fact that it is not null here means unhandled.
var vecs: [3][]const u8 = .{
"HTTP/1.1 417 Expectation Failed\r\n",
if (keep_alive) "" else "connection: close\r\n",
"content-length: 0\r\n\r\n",
};
try out.writeVecAll(&vecs);
return;
}
try out.print("{s} {d} {s}\r\n", .{
@tagName(options.version), @intFromEnum(options.status), phrase,
});
@ -402,6 +398,7 @@ pub const Request = struct {
///
/// Asserts status is not `continue`.
pub fn respondStreaming(request: *Request, options: RespondStreamingOptions) std.io.Writer.Error!http.BodyWriter {
try writeExpectContinue(request);
const o = options.respond_options;
assert(o.status != .@"continue");
const transfer_encoding_none = (o.transfer_encoding orelse .chunked) == .none;
@ -410,43 +407,34 @@ pub const Request = struct {
const phrase = o.reason orelse o.status.phrase() orelse "";
const out = request.server.out;
const elide_body = if (request.head.expect != null) eb: {
// reader() and hence discardBody() above sets expect to null if it
// is handled. So the fact that it is not null here means unhandled.
try out.writeAll("HTTP/1.1 417 Expectation Failed\r\n");
if (!keep_alive) try out.writeAll("connection: close\r\n");
try out.writeAll("content-length: 0\r\n\r\n");
break :eb true;
} else eb: {
try out.print("{s} {d} {s}\r\n", .{
@tagName(o.version), @intFromEnum(o.status), phrase,
});
try out.print("{s} {d} {s}\r\n", .{
@tagName(o.version), @intFromEnum(o.status), phrase,
});
switch (o.version) {
.@"HTTP/1.0" => if (keep_alive) try out.writeAll("connection: keep-alive\r\n"),
.@"HTTP/1.1" => if (!keep_alive) try out.writeAll("connection: close\r\n"),
}
switch (o.version) {
.@"HTTP/1.0" => if (keep_alive) try out.writeAll("connection: keep-alive\r\n"),
.@"HTTP/1.1" => if (!keep_alive) try out.writeAll("connection: close\r\n"),
}
if (o.transfer_encoding) |transfer_encoding| switch (transfer_encoding) {
.chunked => try out.writeAll("transfer-encoding: chunked\r\n"),
.none => {},
} else if (options.content_length) |len| {
try out.print("content-length: {d}\r\n", .{len});
} else {
try out.writeAll("transfer-encoding: chunked\r\n");
}
for (o.extra_headers) |header| {
assert(header.name.len != 0);
try out.writeAll(header.name);
try out.writeAll(": ");
try out.writeAll(header.value);
try out.writeAll("\r\n");
}
if (o.transfer_encoding) |transfer_encoding| switch (transfer_encoding) {
.chunked => try out.writeAll("transfer-encoding: chunked\r\n"),
.none => {},
} else if (options.content_length) |len| {
try out.print("content-length: {d}\r\n", .{len});
} else {
try out.writeAll("transfer-encoding: chunked\r\n");
}
for (o.extra_headers) |header| {
assert(header.name.len != 0);
try out.writeAll(header.name);
try out.writeAll(": ");
try out.writeAll(header.value);
try out.writeAll("\r\n");
break :eb request.head.method == .HEAD;
};
}
try out.writeAll("\r\n");
const elide_body = request.head.method == .HEAD;
return .{
.http_protocol_output = request.server.out,
@ -460,36 +448,126 @@ pub const Request = struct {
};
}
pub const ReaderError = error{
/// Failed to write "100-continue" to the stream.
WriteFailed,
/// Failed to write "100-continue" to the stream because it ended.
EndOfStream,
/// The client sent an expect HTTP header value other than
/// "100-continue".
HttpExpectationFailed,
pub const UpgradeRequest = union(enum) {
websocket: ?[]const u8,
other: []const u8,
none,
};
/// Inspects the request head and headers to determine whether the client
/// requested a connection upgrade.
///
/// Returns `.websocket` (carrying the optional "sec-websocket-key" header
/// value) when an "upgrade: websocket" header is present, `.other` with the
/// raw upgrade token for any other protocol, and `.none` when no upgrade was
/// requested. HTTP/1.0 requests and non-GET HTTP/1.1 requests never upgrade.
pub fn upgradeRequested(request: *const Request) UpgradeRequest {
    // Upgrades are only defined for HTTP/1.1 GET requests. The return type
    // is a non-optional union, so the "no upgrade" result is `.none`, not
    // `null` (which would not compile).
    switch (request.head.version) {
        .@"HTTP/1.0" => return .none,
        .@"HTTP/1.1" => if (request.head.method != .GET) return .none,
    }
    var sec_websocket_key: ?[]const u8 = null;
    var upgrade_name: ?[]const u8 = null;
    var it = request.iterateHeaders();
    while (it.next()) |header| {
        if (std.ascii.eqlIgnoreCase(header.name, "sec-websocket-key")) {
            sec_websocket_key = header.value;
        } else if (std.ascii.eqlIgnoreCase(header.name, "upgrade")) {
            upgrade_name = header.value;
        }
    }
    const name = upgrade_name orelse return .none;
    if (std.ascii.eqlIgnoreCase(name, "websocket")) return .{ .websocket = sec_websocket_key };
    return .{ .other = name };
}
pub const WebSocketOptions = struct {
/// The value from `UpgradeRequest.websocket` (sec-websocket-key header value).
key: []const u8,
reason: ?[]const u8 = null,
extra_headers: []const http.Header = &.{},
};
/// Responds with "101 Switching Protocols", completing the server side of
/// the WebSocket opening handshake by echoing the client's key hashed per
/// RFC 6455 section 4.2.2.
///
/// The header is not guaranteed to be sent until `WebSocket.flush` is
/// called on the returned struct.
///
/// Asserts the request is an HTTP/1.1 GET (see `upgradeRequested`).
pub fn respondWebSocket(
    request: *Request,
    options: WebSocketOptions,
) (std.io.Writer.Error || error{HttpExpectationFailed})!WebSocket {
    // An unhandled "expect" header cannot be combined with a protocol
    // upgrade; the caller must resolve the expectation first. This error is
    // not a member of `std.io.Writer.Error`, so the error set is widened
    // accordingly.
    if (request.head.expect != null) return error.HttpExpectationFailed;
    const out = request.server.out;
    const version: http.Version = .@"HTTP/1.1";
    const status: http.Status = .switching_protocols;
    const phrase = options.reason orelse status.phrase() orelse "";
    assert(request.head.version == version);
    assert(request.head.method == .GET);
    // sec-websocket-accept = base64(sha1(key ++ magic GUID)), RFC 6455.
    var sha1 = std.crypto.hash.Sha1.init(.{});
    sha1.update(options.key);
    sha1.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
    var digest: [std.crypto.hash.Sha1.digest_length]u8 = undefined;
    sha1.final(&digest);
    try out.print("{s} {d} {s}\r\n", .{ @tagName(version), @intFromEnum(status), phrase });
    try out.writeAll("connection: upgrade\r\nupgrade: websocket\r\nsec-websocket-accept: ");
    // Base64 of a 20-byte SHA-1 digest is always 28 bytes. `writableArray`
    // yields a pointer into the output buffer; pass it directly (it coerces
    // to a slice) rather than taking its address again.
    const base64_digest = try out.writableArray(28);
    assert(std.base64.standard.Encoder.encode(base64_digest, &digest).len == base64_digest.len);
    out.advance(base64_digest.len);
    try out.writeAll("\r\n");
    for (options.extra_headers) |header| {
        assert(header.name.len != 0);
        try out.writeAll(header.name);
        try out.writeAll(": ");
        try out.writeAll(header.value);
        try out.writeAll("\r\n");
    }
    try out.writeAll("\r\n");
    return .{
        .input = request.server.reader.in,
        .output = request.server.out,
        .key = options.key,
    };
}
/// In the case that the request contains "expect: 100-continue", this
/// function writes the continuation header, which means it can fail with a
/// write error. After sending the continuation header, it sets the
/// request's expect field to `null`.
///
/// Asserts that this function is only called once.
pub fn reader(request: *Request) ReaderError!std.io.Reader {
///
/// See `readerExpectNone` for an infallible alternative that cannot write
/// to the server output stream.
pub fn readerExpectContinue(request: *Request) ExpectContinueError!std.io.Reader {
const flush = request.head.expect != null;
try writeExpectContinue(request);
if (flush) try request.server.out.flush();
return readerExpectNone(request);
}
/// Asserts the expect header is `null`. The caller must handle the
/// expectation manually and then set the value to `null` prior to calling
/// this function.
///
/// Asserts that this function is only called once.
pub fn readerExpectNone(request: *Request) std.io.Reader {
assert(request.server.reader.state == .received_head);
if (request.head.expect) |expect| {
if (mem.eql(u8, expect, "100-continue")) {
try request.server.out.writeAll("HTTP/1.1 100 Continue\r\n\r\n");
request.head.expect = null;
} else {
return error.HttpExpectationFailed;
}
}
assert(request.head.expect == null);
if (!request.head.method.requestHasBody()) return .ending;
return request.server.reader.bodyReader(request.head.transfer_encoding, request.head.content_length);
}
pub const ExpectContinueError = error{
/// Failed to write "HTTP/1.1 100 Continue\r\n\r\n" to the stream.
WriteFailed,
/// The client sent an expect HTTP header value other than
/// "100-continue".
HttpExpectationFailed,
};
/// Handles the request's "expect" header, if present: for "100-continue"
/// this writes the interim "HTTP/1.1 100 Continue" response and clears
/// `head.expect`; any other expectation value is rejected with
/// `error.HttpExpectationFailed`. A no-op when no expect header was sent.
pub fn writeExpectContinue(request: *Request) ExpectContinueError!void {
    if (request.head.expect) |expect| {
        if (!mem.eql(u8, expect, "100-continue")) return error.HttpExpectationFailed;
        try request.server.out.writeAll("HTTP/1.1 100 Continue\r\n\r\n");
        request.head.expect = null;
    }
}
/// Returns whether the connection should remain persistent.
///
/// If it would fail, it instead sets the Server state to receiving body
@ -528,3 +606,150 @@ pub const Request = struct {
return false;
}
};
/// Server side of a WebSocket connection, operating directly on the HTTP
/// connection's buffered input and output streams after a successful
/// upgrade handshake (see `Request.respondWebSocket`).
///
/// See https://tools.ietf.org/html/rfc6455
pub const WebSocket = struct {
    /// The client's "sec-websocket-key" header value, as passed to
    /// `Request.respondWebSocket`.
    key: []const u8,
    input: *std.io.BufferedReader,
    output: *std.io.BufferedWriter,

    /// First byte of a frame header (RFC 6455 section 5.2).
    pub const Header0 = packed struct(u8) {
        opcode: Opcode,
        rsv3: u1 = 0,
        rsv2: u1 = 0,
        rsv1: u1 = 0,
        fin: bool,
    };

    /// Second byte of a frame header: 7-bit payload length (126/127 escape
    /// to 16/64-bit extended lengths) plus the mask bit.
    pub const Header1 = packed struct(u8) {
        payload_len: enum(u7) {
            len16 = 126,
            len64 = 127,
            _,
        },
        mask: bool,
    };

    /// Frame opcodes per RFC 6455 section 5.2. Non-exhaustive because the
    /// remaining values are reserved.
    pub const Opcode = enum(u4) {
        continuation = 0,
        text = 1,
        binary = 2,
        connection_close = 8,
        ping = 9,
        /// "A Pong frame MAY be sent unsolicited. This serves as a unidirectional
        /// heartbeat. A response to an unsolicited Pong frame is not expected."
        pong = 10,
        _,
    };

    /// Includes the stream reader errors because `readSmallMessage` pulls
    /// header and payload bytes from `input` (the previous implementation
    /// likewise unioned its receive errors into this set).
    pub const ReadSmallTextMessageError = error{
        ConnectionClose,
        UnexpectedOpCode,
        MessageTooBig,
        MissingMaskBit,
    } || std.io.Reader.Error;

    pub const SmallMessage = struct {
        /// Can be text, binary, or ping.
        opcode: Opcode,
        data: []u8,
    };

    /// Reads the next message from the WebSocket stream, failing if the
    /// message does not fit into the input buffer. The returned memory points
    /// into the input buffer and is invalidated on the next read.
    ///
    /// Pong frames are skipped; close frames surface as `error.ConnectionClose`.
    pub fn readSmallMessage(ws: *WebSocket) ReadSmallTextMessageError!SmallMessage {
        const in = ws.input;
        while (true) {
            // Reading can fail, so the results must be unwrapped with `try`
            // (both header bytes, like the other `take*` calls below).
            const h0 = try in.takeStruct(Header0);
            const h1 = try in.takeStruct(Header1);
            switch (h0.opcode) {
                .text, .binary, .pong, .ping => {},
                .connection_close => return error.ConnectionClose,
                .continuation => return error.UnexpectedOpCode,
                _ => return error.UnexpectedOpCode,
            }
            // Fragmented messages would exceed the "small message" contract.
            if (!h0.fin) return error.MessageTooBig;
            // RFC 6455 section 5.1: client-to-server frames must be masked.
            if (!h1.mask) return error.MissingMaskBit;
            const len: usize = switch (h1.payload_len) {
                .len16 => try in.takeInt(u16, .big),
                .len64 => std.math.cast(usize, try in.takeInt(u64, .big)) orelse return error.MessageTooBig,
                else => @intFromEnum(h1.payload_len),
            };
            if (len > in.buffer.len) return error.MessageTooBig;
            const mask: u32 = @bitCast((try in.takeArray(4)).*);
            const payload = try in.take(len);

            // Skip pongs.
            if (h0.opcode == .pong) continue;

            // Unmask in place, one word at a time; XORing the mask in memory
            // order is endian-agnostic.
            const floored_len = (payload.len / 4) * 4;
            const u32_payload: []align(1) u32 = @ptrCast(payload[0..floored_len]);
            for (u32_payload) |*elem| elem.* ^= mask;
            // The last item may contain a partial word of unused data.
            // (`asBytes` gives a byte view of the mask; a single-item pointer
            // cannot be `@ptrCast` to a slice.)
            const mask_bytes = std.mem.asBytes(&mask);
            for (payload[floored_len..], mask_bytes[0 .. payload.len - floored_len]) |*leftover, m|
                leftover.* ^= m;

            return .{
                .opcode = h0.opcode,
                .data = payload,
            };
        }
    }

    /// Writes a single unfragmented, unmasked frame and flushes the output.
    pub fn writeMessage(ws: *WebSocket, data: []const u8, op: Opcode) std.io.Writer.Error!void {
        try writeMessageVecUnflushed(ws, &.{data}, op);
        try ws.output.flush();
    }

    /// Same as `writeMessage` but leaves the frame in the output buffer;
    /// call `flush` to guarantee it is sent.
    pub fn writeMessageUnflushed(ws: *WebSocket, data: []const u8, op: Opcode) std.io.Writer.Error!void {
        try writeMessageVecUnflushed(ws, &.{data}, op);
    }

    /// Vectored variant of `writeMessage`: all of `data` becomes the payload
    /// of one frame. Flushes the output.
    pub fn writeMessageVec(ws: *WebSocket, data: []const []const u8, op: Opcode) std.io.Writer.Error!void {
        try writeMessageVecUnflushed(ws, data, op);
        try ws.output.flush();
    }

    /// Writes one unfragmented, unmasked frame whose payload is the
    /// concatenation of `data`, choosing the smallest length encoding
    /// (7-bit, 16-bit, or 64-bit) for the total payload size.
    pub fn writeMessageVecUnflushed(ws: *WebSocket, data: []const []const u8, op: Opcode) std.io.Writer.Error!void {
        const total_len = l: {
            var total_len: u64 = 0;
            for (data) |iovec| total_len += iovec.len;
            break :l total_len;
        };
        const out = ws.output;
        try out.writeStruct(@as(Header0, .{
            .opcode = op,
            .fin = true,
        }));
        switch (total_len) {
            // Lengths up to 125 fit directly in the 7-bit field.
            0...125 => try out.writeStruct(@as(Header1, .{
                .payload_len = @enumFromInt(total_len),
                .mask = false,
            })),
            126...0xffff => {
                try out.writeStruct(@as(Header1, .{
                    .payload_len = .len16,
                    .mask = false,
                }));
                try out.writeInt(u16, @intCast(total_len), .big);
            },
            else => {
                try out.writeStruct(@as(Header1, .{
                    .payload_len = .len64,
                    .mask = false,
                }));
                try out.writeInt(u64, total_len, .big);
            },
        }
        // NOTE(review): elsewhere in this changeset `writeVecAll` is called
        // with a mutable slice-of-slices; confirm it accepts `[]const`.
        try out.writeVecAll(data);
    }

    /// Sends any frames still sitting in the output buffer.
    pub fn flush(ws: *WebSocket) std.io.Writer.Error!void {
        try ws.output.flush();
    }
};

View File

@ -1,243 +0,0 @@
//! See https://tools.ietf.org/html/rfc6455
const builtin = @import("builtin");
const std = @import("std");
const WebSocket = @This();
const assert = std.debug.assert;
const native_endian = builtin.cpu.arch.endian();
key: []const u8,
request: *std.http.Server.Request,
recv_fifo: std.fifo.LinearFifo(u8, .Slice),
reader: std.io.BufferedReader,
body_writer: std.http.BodyWriter,
/// Number of bytes that have been peeked but not discarded yet.
outstanding_len: usize,
pub const InitError = error{WebSocketUpgradeMissingKey} ||
std.http.Server.Request.ReaderError;
pub fn init(
ws: *WebSocket,
request: *std.http.Server.Request,
recv_buffer: []align(4) u8,
) InitError!bool {
switch (request.head.version) {
.@"HTTP/1.0" => return false,
.@"HTTP/1.1" => if (request.head.method != .GET) return false,
}
var sec_websocket_key: ?[]const u8 = null;
var upgrade_websocket: bool = false;
var it = request.iterateHeaders();
while (it.next()) |header| {
if (std.ascii.eqlIgnoreCase(header.name, "sec-websocket-key")) {
sec_websocket_key = header.value;
} else if (std.ascii.eqlIgnoreCase(header.name, "upgrade")) {
if (!std.ascii.eqlIgnoreCase(header.value, "websocket"))
return false;
upgrade_websocket = true;
}
}
if (!upgrade_websocket)
return false;
const key = sec_websocket_key orelse return error.WebSocketUpgradeMissingKey;
var sha1 = std.crypto.hash.Sha1.init(.{});
sha1.update(key);
sha1.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
var digest: [std.crypto.hash.Sha1.digest_length]u8 = undefined;
sha1.final(&digest);
var base64_digest: [28]u8 = undefined;
assert(std.base64.standard.Encoder.encode(&base64_digest, &digest).len == base64_digest.len);
request.head.content_length = std.math.maxInt(u64);
ws.* = .{
.key = key,
.recv_fifo = .init(recv_buffer),
.reader = (try request.reader()).unbuffered(),
.body_writer = try request.respondStreaming(.{
.respond_options = .{
.status = .switching_protocols,
.extra_headers = &.{
.{ .name = "upgrade", .value = "websocket" },
.{ .name = "connection", .value = "upgrade" },
.{ .name = "sec-websocket-accept", .value = &base64_digest },
},
.transfer_encoding = .none,
},
}),
.request = request,
.outstanding_len = 0,
};
return true;
}
pub const Header0 = packed struct(u8) {
opcode: Opcode,
rsv3: u1 = 0,
rsv2: u1 = 0,
rsv1: u1 = 0,
fin: bool,
};
pub const Header1 = packed struct(u8) {
payload_len: enum(u7) {
len16 = 126,
len64 = 127,
_,
},
mask: bool,
};
pub const Opcode = enum(u4) {
continuation = 0,
text = 1,
binary = 2,
connection_close = 8,
ping = 9,
/// "A Pong frame MAY be sent unsolicited. This serves as a unidirectional
/// heartbeat. A response to an unsolicited Pong frame is not expected."
pong = 10,
_,
};
pub const ReadSmallTextMessageError = error{
ConnectionClose,
UnexpectedOpCode,
MessageTooBig,
MissingMaskBit,
} || RecvError;
pub const SmallMessage = struct {
/// Can be text, binary, or ping.
opcode: Opcode,
data: []u8,
};
/// Reads the next message from the WebSocket stream, failing if the message does not fit
/// into `recv_buffer`.
pub fn readSmallMessage(ws: *WebSocket) ReadSmallTextMessageError!SmallMessage {
while (true) {
const header_bytes = (try recv(ws, 2))[0..2];
const h0: Header0 = @bitCast(header_bytes[0]);
const h1: Header1 = @bitCast(header_bytes[1]);
switch (h0.opcode) {
.text, .binary, .pong, .ping => {},
.connection_close => return error.ConnectionClose,
.continuation => return error.UnexpectedOpCode,
_ => return error.UnexpectedOpCode,
}
if (!h0.fin) return error.MessageTooBig;
if (!h1.mask) return error.MissingMaskBit;
const len: usize = switch (h1.payload_len) {
.len16 => try recvReadInt(ws, u16),
.len64 => std.math.cast(usize, try recvReadInt(ws, u64)) orelse return error.MessageTooBig,
else => @intFromEnum(h1.payload_len),
};
if (len > ws.recv_fifo.buf.len) return error.MessageTooBig;
const mask: u32 = @bitCast((try recv(ws, 4))[0..4].*);
const payload = try recv(ws, len);
// Skip pongs.
if (h0.opcode == .pong) continue;
// The last item may contain a partial word of unused data.
const floored_len = (payload.len / 4) * 4;
const u32_payload: []align(1) u32 = @alignCast(std.mem.bytesAsSlice(u32, payload[0..floored_len]));
for (u32_payload) |*elem| elem.* ^= mask;
const mask_bytes = std.mem.asBytes(&mask)[0 .. payload.len - floored_len];
for (payload[floored_len..], mask_bytes) |*leftover, m| leftover.* ^= m;
return .{
.opcode = h0.opcode,
.data = payload,
};
}
}
const RecvError = std.http.Server.Request.ReadError || error{EndOfStream};
fn recv(ws: *WebSocket, len: usize) RecvError![]u8 {
ws.recv_fifo.discard(ws.outstanding_len);
assert(len <= ws.recv_fifo.buf.len);
if (len > ws.recv_fifo.count) {
const small_buf = ws.recv_fifo.writableSlice(0);
const needed = len - ws.recv_fifo.count;
const buf = if (small_buf.len >= needed) small_buf else b: {
ws.recv_fifo.realign();
break :b ws.recv_fifo.writableSlice(0);
};
const n = try @as(RecvError!usize, @errorCast(ws.reader.readAtLeast(buf, needed)));
if (n < needed) return error.EndOfStream;
ws.recv_fifo.update(n);
}
ws.outstanding_len = len;
// TODO: improve the std lib API so this cast isn't necessary.
return @constCast(ws.recv_fifo.readableSliceOfLen(len));
}
fn recvReadInt(ws: *WebSocket, comptime I: type) !I {
const unswapped: I = @bitCast((try recv(ws, @sizeOf(I)))[0..@sizeOf(I)].*);
return switch (native_endian) {
.little => @byteSwap(unswapped),
.big => unswapped,
};
}
pub fn writeMessage(ws: *WebSocket, message: []const u8, opcode: Opcode) std.io.Writer.Error!void {
const iovecs: [1]std.posix.iovec_const = .{
.{ .base = message.ptr, .len = message.len },
};
return writeMessagev(ws, &iovecs, opcode);
}
pub fn writeMessagev(ws: *WebSocket, message: []const std.posix.iovec_const, opcode: Opcode) std.io.Writer.Error!void {
const total_len = l: {
var total_len: u64 = 0;
for (message) |iovec| total_len += iovec.len;
break :l total_len;
};
var header_buf: [2 + 8]u8 = undefined;
header_buf[0] = @bitCast(@as(Header0, .{
.opcode = opcode,
.fin = true,
}));
const header = switch (total_len) {
0...125 => blk: {
header_buf[1] = @bitCast(@as(Header1, .{
.payload_len = @enumFromInt(total_len),
.mask = false,
}));
break :blk header_buf[0..2];
},
126...0xffff => blk: {
header_buf[1] = @bitCast(@as(Header1, .{
.payload_len = .len16,
.mask = false,
}));
std.mem.writeInt(u16, header_buf[2..4], @intCast(total_len), .big);
break :blk header_buf[0..4];
},
else => blk: {
header_buf[1] = @bitCast(@as(Header1, .{
.payload_len = .len64,
.mask = false,
}));
std.mem.writeInt(u64, header_buf[2..10], total_len, .big);
break :blk header_buf[0..10];
},
};
var bw = ws.body_writer.writer().unbuffered();
try bw.writeAll(header);
for (message) |iovec| try bw.writeAll(iovec.base[0..iovec.len]);
try bw.flush();
}

View File

@ -6,8 +6,10 @@ const assert = std.debug.assert;
const testing = std.testing;
const BufferedWriter = std.io.BufferedWriter;
const Reader = std.io.Reader;
const Writer = std.io.Writer;
const Allocator = std.mem.Allocator;
const ArrayList = std.ArrayListUnmanaged;
const Limit = std.io.Limit;
const BufferedReader = @This();
@ -63,12 +65,12 @@ pub fn readVec(br: *BufferedReader, data: []const []u8) Reader.Error!usize {
}
/// Equivalent semantics to `std.io.Reader.VTable.read`.
pub fn read(br: *BufferedReader, bw: *BufferedWriter, limit: Reader.Limit) Reader.RwError!usize {
pub fn read(br: *BufferedReader, bw: *BufferedWriter, limit: Limit) Reader.StreamError!usize {
return passthruRead(br, bw, limit);
}
/// Equivalent semantics to `std.io.Reader.VTable.discard`.
pub fn discard(br: *BufferedReader, limit: Reader.Limit) Reader.Error!usize {
pub fn discard(br: *BufferedReader, limit: Limit) Reader.Error!usize {
return passthruDiscard(br, limit);
}
@ -90,7 +92,7 @@ pub fn readVecAll(br: *BufferedReader, data: [][]u8) Reader.Error!void {
}
/// "Pump" data from the reader to the writer.
pub fn readAll(br: *BufferedReader, bw: *BufferedWriter, limit: Reader.Limit) Reader.RwError!void {
pub fn readAll(br: *BufferedReader, bw: *BufferedWriter, limit: Limit) Reader.StreamError!void {
var remaining = limit;
while (remaining.nonzero()) {
const n = try br.read(bw, remaining);
@ -113,8 +115,8 @@ pub fn readRemaining(br: *BufferedReader, bw: *BufferedWriter) Reader.RwRemainin
}
/// Equivalent to `readVec` but reads at most `limit` bytes.
pub fn readVecLimit(br: *BufferedReader, data: []const []u8, limit: Reader.Limit) Reader.Error!usize {
assert(@intFromEnum(Reader.Limit.unlimited) == std.math.maxInt(usize));
pub fn readVecLimit(br: *BufferedReader, data: []const []u8, limit: Limit) Reader.Error!usize {
assert(@intFromEnum(Limit.unlimited) == std.math.maxInt(usize));
var remaining = @intFromEnum(limit);
for (data, 0..) |buf, i| {
const buffered = br.buffer[br.seek..br.end];
@ -165,7 +167,7 @@ pub fn readVecLimit(br: *BufferedReader, data: []const []u8, limit: Reader.Limit
return @intFromEnum(limit) - remaining;
}
fn passthruRead(context: ?*anyopaque, bw: *BufferedWriter, limit: Reader.Limit) Reader.RwError!usize {
fn passthruRead(context: ?*anyopaque, bw: *BufferedWriter, limit: Limit) Reader.StreamError!usize {
const br: *BufferedReader = @alignCast(@ptrCast(context));
const buffer = limit.slice(br.buffer[br.seek..br.end]);
if (buffer.len > 0) {
@ -176,22 +178,19 @@ fn passthruRead(context: ?*anyopaque, bw: *BufferedWriter, limit: Reader.Limit)
return br.unbuffered_reader.read(bw, limit);
}
fn passthruDiscard(context: ?*anyopaque, limit: Reader.Limit) Reader.Error!usize {
fn passthruDiscard(context: ?*anyopaque, limit: Limit) Reader.Error!usize {
const br: *BufferedReader = @alignCast(@ptrCast(context));
const buffered_len = br.end - br.seek;
if (limit.toInt()) |n| {
const remaining: Limit = if (limit.toInt()) |n| l: {
if (buffered_len >= n) {
br.seek += n;
return n;
}
br.seek = 0;
br.end = 0;
const additional = try br.unbuffered_reader.discard(.limited(n - buffered_len));
return n + additional;
}
const n = try br.unbuffered_reader.discard(.unlimited);
break :l .limited(n - buffered_len);
} else .unlimited;
br.seek = 0;
br.end = 0;
const n = if (br.unbuffered_reader.discard) |f| try f(remaining) else try br.defaultDiscard(remaining);
return buffered_len + n;
}
@ -200,6 +199,72 @@ fn passthruReadVec(context: ?*anyopaque, data: []const []u8) Reader.Error!usize
return readVecLimit(br, data, .unlimited);
}
/// Fallback for `Reader.VTable.discard` when the underlying reader supplies
/// none: pumps bytes into a throwaway writer that counts and drops them.
///
/// Asserts the caller has already emptied the buffer (`seek == end == 0`).
fn defaultDiscard(br: *BufferedReader, limit: Limit) Reader.Error!usize {
    assert(br.seek == 0);
    assert(br.end == 0);
    // A writer whose vtable discards everything. It shares `br.buffer` so
    // any bytes produced beyond `limit` land somewhere recoverable.
    var bw: BufferedWriter = .{
        .unbuffered_writer = .{
            .context = undefined,
            .vtable = &.{
                .writeSplat = defaultDiscardWriteSplat,
                .writeFile = defaultDiscardWriteFile,
            },
        },
        .buffer = br.buffer,
    };
    const n = br.read(&bw, limit) catch |err| switch (err) {
        // The discarding writer never fails.
        error.WriteFailed => unreachable,
        error.ReadFailed => return error.ReadFailed,
        error.EndOfStream => return error.EndOfStream,
    };
    if (n > @intFromEnum(limit)) {
        // Bytes past the limit are still buffered inside `bw` (shared
        // buffer); reclaim them as buffered input rather than losing them.
        const over_amt = n - @intFromEnum(limit);
        // `bw.buffer` is a slice and has no `end` field; the bound that the
        // reclamation below relies on is the writer's buffered length.
        assert(over_amt <= bw.end); // limit may be exceeded only by an amount within buffer capacity.
        br.seek = bw.end - over_amt;
        br.end = bw.end;
        return @intFromEnum(limit);
    }
    return n;
}
/// `Writer.VTable.writeSplat` implementation that discards all bytes,
/// reporting everything as written.
///
/// Per the splat contract, the final element of `data` is a pattern to be
/// repeated `splat` times; the preceding elements are written once each.
fn defaultDiscardWriteSplat(context: ?*anyopaque, data: []const []const u8, splat: usize) Writer.Error!usize {
    _ = context;
    const headers = data[0 .. data.len - 1];
    // The pattern is the last *element*, not a one-element sub-slice:
    // `data[headers.len..]` has `.len == 1`, which would count `splat`
    // bytes instead of `pattern.len * splat`.
    const pattern = data[data.len - 1];
    var written: usize = pattern.len * splat;
    for (headers) |bytes| written += bytes.len;
    return written;
}
/// `Writer.VTable.writeFile` implementation that discards file contents by
/// seeking the file reader forward instead of copying bytes.
///
/// Returns the number of bytes "written": the seek distance plus the header
/// bytes, and the trailer bytes too when the whole remaining file was
/// consumed. Returns `error.Unimplemented` when the file size is unknown so
/// the caller falls back to treating the file as a pipe.
fn defaultDiscardWriteFile(
    context: ?*anyopaque,
    file_reader: *std.fs.File.Reader,
    limit: Limit,
    headers_and_trailers: []const []const u8,
    headers_len: usize,
) Writer.FileError!usize {
    _ = context;
    if (file_reader.getSize()) |size| {
        const remaining = size - file_reader.pos;
        // Seek no further than both the caller's limit and end-of-file.
        const seek_amt = limit.minInt(remaining);
        // Error is observable on `file_reader` instance, and is safe to ignore
        // depending on the caller's needs. Caller can make that decision.
        file_reader.seekForward(seek_amt) catch {};
        var n: usize = seek_amt;
        // Headers always count as written.
        for (headers_and_trailers[0..headers_len]) |bytes| n += bytes.len;
        if (seek_amt == remaining) {
            // Since we made it all the way through the file, the trailers are
            // also included.
            for (headers_and_trailers[headers_len..]) |bytes| n += bytes.len;
        }
        return n;
    } else |_| {
        // Error is observable on `file_reader` instance, and it is better to
        // treat the file as a pipe.
        return error.Unimplemented;
    }
}
/// Returns the next `len` bytes from `unbuffered_reader`, filling the buffer as
/// necessary.
///
@ -475,7 +540,7 @@ pub fn readSliceAlloc(br: *BufferedReader, allocator: Allocator, len: usize) Rea
///
/// See also:
/// * `readRemainingArrayList`
pub fn readRemainingAlloc(r: Reader, gpa: Allocator, limit: Reader.Limit) Reader.LimitedAllocError![]u8 {
pub fn readRemainingAlloc(r: Reader, gpa: Allocator, limit: Limit) Reader.LimitedAllocError![]u8 {
var buffer: ArrayList(u8) = .empty;
defer buffer.deinit(gpa);
try readRemainingArrayList(r, gpa, null, &buffer, limit);
@ -499,7 +564,7 @@ pub fn readRemainingArrayList(
gpa: Allocator,
comptime alignment: ?std.mem.Alignment,
list: *std.ArrayListAlignedUnmanaged(u8, alignment),
limit: Reader.Limit,
limit: Limit,
) Reader.LimitedAllocError!void {
const buffer = br.buffer;
const buffered = buffer[br.seek..br.end];
@ -680,7 +745,7 @@ pub fn peekDelimiterExclusive(br: *BufferedReader, delimiter: u8) DelimiterError
/// found. Does not write the delimiter itself.
///
/// Returns number of bytes streamed.
pub fn readDelimiter(br: *BufferedReader, bw: *BufferedWriter, delimiter: u8) Reader.RwError!usize {
pub fn readDelimiter(br: *BufferedReader, bw: *BufferedWriter, delimiter: u8) Reader.StreamError!usize {
const amount, const to = try br.readAny(bw, delimiter, .unlimited);
return switch (to) {
.delimiter => amount,
@ -722,7 +787,7 @@ pub fn readDelimiterLimit(
br: *BufferedReader,
bw: *BufferedWriter,
delimiter: u8,
limit: Reader.Limit,
limit: Limit,
) StreamDelimiterLimitedError!usize {
const amount, const to = try br.readAny(bw, delimiter, limit);
return switch (to) {
@ -736,7 +801,7 @@ fn readAny(
br: *BufferedReader,
bw: *BufferedWriter,
delimiter: ?u8,
limit: Reader.Limit,
limit: Limit,
) Reader.RwRemainingError!struct { usize, enum { delimiter, limit, end } } {
var amount: usize = 0;
var remaining = limit;

View File

@ -5,6 +5,7 @@ const native_endian = @import("builtin").target.cpu.arch.endian();
const Writer = std.io.Writer;
const Allocator = std.mem.Allocator;
const testing = std.testing;
const Limit = std.io.Limit;
/// Underlying stream to send bytes to.
///
@ -42,7 +43,7 @@ pub fn writer(bw: *BufferedWriter) Writer {
const fixed_vtable: Writer.VTable = .{
.writeSplat = fixedWriteSplat,
.writeFile = Writer.failingWriteFile,
.writeFile = Writer.unimplementedWriteFile,
};
/// Replaces the `BufferedWriter` with one that writes to `buffer` and returns
@ -82,7 +83,7 @@ pub fn flush(bw: *BufferedWriter) Writer.Error!void {
bw.end = 0;
}
pub fn flushLimit(bw: *BufferedWriter, limit: Writer.Limit) Writer.Error!void {
pub fn flushLimit(bw: *BufferedWriter, limit: Limit) Writer.Error!void {
const buffer = limit.slice(bw.buffer[0..bw.end]);
var index: usize = 0;
while (index < buffer.len) index += try bw.unbuffered_writer.writeVec(&.{buffer[index..]});
@ -228,7 +229,7 @@ pub fn writeSplatLimit(
bw: *BufferedWriter,
data: []const []const u8,
splat: usize,
limit: Writer.Limit,
limit: Limit,
) Writer.Error!usize {
_ = bw;
_ = data;
@ -544,7 +545,7 @@ pub fn writeFile(
bw: *BufferedWriter,
file: std.fs.File,
offset: Writer.Offset,
limit: Writer.Limit,
limit: Limit,
headers_and_trailers: []const []const u8,
headers_len: usize,
) Writer.FileError!usize {
@ -560,7 +561,7 @@ pub fn writeFileReading(
bw: *BufferedWriter,
file: std.fs.File,
offset: Writer.Offset,
limit: Writer.Limit,
limit: Limit,
) WriteFileReadingError!usize {
const dest = limit.slice(try bw.writableSliceGreedy(1));
const n = if (offset.toInt()) |pos| try file.pread(dest, pos) else try file.read(dest);
@ -572,7 +573,7 @@ fn passthruWriteFile(
context: ?*anyopaque,
file: std.fs.File,
offset: Writer.Offset,
limit: Writer.Limit,
limit: Limit,
headers_and_trailers: []const []const u8,
headers_len: usize,
) Writer.FileError!usize {
@ -653,7 +654,7 @@ pub const WriteFileOptions = struct {
offset: Writer.Offset = .none,
/// If the size of the source file is known, it is likely that passing the
/// size here will save one syscall.
limit: Writer.Limit = .unlimited,
limit: Limit = .unlimited,
/// Headers and trailers must be passed together so that in case `len` is
/// zero, they can be forwarded directly to `Writer.VTable.writeSplat`.
///
@ -749,7 +750,7 @@ pub fn writeFileReadingAll(
bw: *BufferedWriter,
file: std.fs.File,
offset: Writer.Offset,
limit: Writer.Limit,
limit: Limit,
) WriteFileReadingError!void {
if (offset.toInt()) |start_pos| {
var remaining = limit;

View File

@ -5,6 +5,7 @@ const BufferedWriter = std.io.BufferedWriter;
const BufferedReader = std.io.BufferedReader;
const Allocator = std.mem.Allocator;
const ArrayList = std.ArrayListUnmanaged;
const Limit = std.io.Limit;
pub const Limited = @import("Reader/Limited.zig");
@ -25,21 +26,7 @@ pub const VTable = struct {
/// Implementations are encouraged to utilize mandatory minimum buffer
/// sizes combined with short reads (returning a value less than `limit`)
/// in order to minimize complexity.
read: *const fn (context: ?*anyopaque, bw: *BufferedWriter, limit: Limit) RwError!usize,
/// Writes bytes from the internally tracked stream position to `data`.
///
/// Returns the number of bytes written, which will be at minimum `0` and
/// at most the sum of each data slice length. The number of bytes read,
/// including zero, does not indicate end of stream.
///
/// The reader's internal logical seek position moves forward in accordance
/// with the number of bytes returned from this function.
///
/// Implementations are encouraged to utilize mandatory minimum buffer
/// sizes combined with short reads (returning a value less than the total
/// buffer capacity inside `data`) in order to minimize complexity.
readVec: *const fn (context: ?*anyopaque, data: []const []u8) Error!usize,
read: *const fn (context: ?*anyopaque, bw: *BufferedWriter, limit: Limit) StreamError!usize,
/// Consumes bytes from the internally tracked stream position without
/// providing access to them.
@ -54,10 +41,15 @@ pub const VTable = struct {
/// Implementations are encouraged to utilize mandatory minimum buffer
/// sizes combined with short reads (returning a value less than `limit`)
/// in order to minimize complexity.
discard: *const fn (context: ?*anyopaque, limit: Limit) Error!usize,
///
/// If an implementation sets this to `null`, a default implementation is
/// provided which is based on calling `read`, borrowing
/// `BufferedReader.buffer` to construct a temporary `BufferedWriter` and
/// ignoring the written data.
discard: *const fn (context: ?*anyopaque, limit: Limit) DiscardError!usize = null,
};
pub const RwError = error{
pub const StreamError = error{
/// See the `Reader` implementation for detailed diagnostics.
ReadFailed,
/// See the `Writer` implementation for detailed diagnostics.
@ -67,7 +59,7 @@ pub const RwError = error{
EndOfStream,
};
pub const Error = error{
pub const DiscardError = error{
/// See the `Reader` implementation for detailed diagnostics.
ReadFailed,
EndOfStream,
@ -85,10 +77,7 @@ pub const ShortError = error{
ReadFailed,
};
/// TODO: no pub
pub const Limit = std.io.Limit;
pub fn read(r: Reader, bw: *BufferedWriter, limit: Limit) RwError!usize {
pub fn read(r: Reader, bw: *BufferedWriter, limit: Limit) StreamError!usize {
const before = bw.count;
const n = try r.vtable.read(r.context, bw, limit);
assert(n <= @intFromEnum(limit));
@ -96,11 +85,7 @@ pub fn read(r: Reader, bw: *BufferedWriter, limit: Limit) RwError!usize {
return n;
}
pub fn readVec(r: Reader, data: []const []u8) Error!usize {
return r.vtable.readVec(r.context, data);
}
pub fn discard(r: Reader, limit: Limit) Error!usize {
pub fn discard(r: Reader, limit: Limit) DiscardError!usize {
const n = try r.vtable.discard(r.context, limit);
assert(n <= @intFromEnum(limit));
return n;
@ -188,7 +173,6 @@ pub const failing: Reader = .{
.context = undefined,
.vtable = &.{
.read = failingRead,
.readVec = failingReadVec,
.discard = failingDiscard,
},
};
@ -197,7 +181,6 @@ pub const ending: Reader = .{
.context = undefined,
.vtable = &.{
.read = endingRead,
.readVec = endingReadVec,
.discard = endingDiscard,
},
};
@ -222,39 +205,27 @@ pub fn limited(r: Reader, limit: Limit) Limited {
};
}
fn endingRead(context: ?*anyopaque, bw: *BufferedWriter, limit: Limit) RwError!usize {
/// `read` implementation for the `ending` reader: always reports end of stream.
fn endingRead(context: ?*anyopaque, bw: *BufferedWriter, limit: Limit) StreamError!usize {
    _ = limit;
    _ = bw;
    _ = context;
    return error.EndOfStream;
}
/// `readVec` implementation for the `ending` reader: always reports end of stream.
fn endingReadVec(context: ?*anyopaque, data: []const []u8) Error!usize {
    _ = data;
    _ = context;
    return error.EndOfStream;
}
fn endingDiscard(context: ?*anyopaque, limit: Limit) Error!usize {
/// `discard` implementation for the `ending` reader: always reports end of stream.
fn endingDiscard(context: ?*anyopaque, limit: Limit) DiscardError!usize {
    _ = limit;
    _ = context;
    return error.EndOfStream;
}
fn failingRead(context: ?*anyopaque, bw: *BufferedWriter, limit: Limit) RwError!usize {
/// `read` implementation for the `failing` reader: every call fails.
fn failingRead(context: ?*anyopaque, bw: *BufferedWriter, limit: Limit) StreamError!usize {
    _ = limit;
    _ = bw;
    _ = context;
    return error.ReadFailed;
}
/// `readVec` implementation for the `failing` reader: every call fails.
fn failingReadVec(context: ?*anyopaque, data: []const []u8) Error!usize {
    _ = data;
    _ = context;
    return error.ReadFailed;
}
fn failingDiscard(context: ?*anyopaque, limit: Limit) Error!usize {
fn failingDiscard(context: ?*anyopaque, limit: Limit) DiscardError!usize {
_ = context;
_ = limit;
return error.ReadFailed;
@ -308,7 +279,6 @@ pub fn Hashed(comptime Hasher: type) type {
.context = this,
.vtable = &.{
.read = @This().read,
.readVec = @This().readVec,
.discard = @This().discard,
},
},
@ -318,7 +288,7 @@ pub fn Hashed(comptime Hasher: type) type {
};
}
fn read(context: ?*anyopaque, bw: *BufferedWriter, limit: Limit) RwError!usize {
fn read(context: ?*anyopaque, bw: *BufferedWriter, limit: Limit) StreamError!usize {
const this: *@This() = @alignCast(@ptrCast(context));
const slice = limit.slice(try bw.writableSliceGreedy(1));
const n = try this.in.readVec(&.{slice});
@ -327,7 +297,7 @@ pub fn Hashed(comptime Hasher: type) type {
return n;
}
fn discard(context: ?*anyopaque, limit: Limit) Error!usize {
fn discard(context: ?*anyopaque, limit: Limit) DiscardError!usize {
const this: *@This() = @alignCast(@ptrCast(context));
var bw = this.hasher.writable(&.{});
const n = this.in.read(&bw, limit) catch |err| switch (err) {
@ -337,7 +307,7 @@ pub fn Hashed(comptime Hasher: type) type {
return n;
}
fn readVec(context: ?*anyopaque, data: []const []u8) Error!usize {
fn readVec(context: ?*anyopaque, data: []const []u8) DiscardError!usize {
const this: *@This() = @alignCast(@ptrCast(context));
const n = try this.in.readVec(data);
var remaining: usize = n;

View File

@ -1,6 +1,7 @@
const std = @import("../std.zig");
const assert = std.debug.assert;
const Writer = @This();
const Limit = std.io.Limit;
pub const Null = @import("Writer/Null.zig");
@ -47,6 +48,8 @@ pub const VTable = struct {
offset: Offset,
/// Maximum amount of bytes to read from the file. Implementations may
/// assume that the file size does not exceed this amount.
///
/// `headers_and_trailers` do not count towards this limit.
limit: Limit,
/// Headers and trailers must be passed together so that in case `len` is
/// zero, they can be forwarded directly to `VTable.writeVec`.
@ -68,9 +71,6 @@ pub const FileError = std.fs.File.PReadError || error{
Unimplemented,
};
/// TODO: no pub
pub const Limit = std.io.Limit;
pub const Offset = enum(u64) {
zero = 0,
/// Indicates to read the file as a stream.