mirror of
https://github.com/ziglang/zig.git
synced 2026-02-21 16:54:52 +00:00
std: WIP update more to new reader/writer
delete some bad readers/writers add limited reader update TLS about to do something drastic to compress
This commit is contained in:
parent
c7040171fb
commit
25ac70f973
@ -1,75 +1,19 @@
|
||||
//! Compression algorithms.
|
||||
|
||||
const std = @import("std.zig");
|
||||
|
||||
pub const flate = @import("compress/flate.zig");
|
||||
pub const gzip = @import("compress/gzip.zig");
|
||||
pub const zlib = @import("compress/zlib.zig");
|
||||
pub const lzma = @import("compress/lzma.zig");
|
||||
pub const lzma2 = @import("compress/lzma2.zig");
|
||||
pub const xz = @import("compress/xz.zig");
|
||||
pub const zlib = @import("compress/zlib.zig");
|
||||
pub const zstd = @import("compress/zstandard.zig");
|
||||
|
||||
pub fn HashedReader(ReaderType: type, HasherType: type) type {
|
||||
return struct {
|
||||
child_reader: ReaderType,
|
||||
hasher: HasherType,
|
||||
|
||||
pub const Error = ReaderType.Error;
|
||||
pub const Reader = std.io.Reader(*@This(), Error, read);
|
||||
|
||||
pub fn read(self: *@This(), buf: []u8) Error!usize {
|
||||
const amt = try self.child_reader.read(buf);
|
||||
self.hasher.update(buf[0..amt]);
|
||||
return amt;
|
||||
}
|
||||
|
||||
pub fn reader(self: *@This()) Reader {
|
||||
return .{ .context = self };
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn hashedReader(
|
||||
reader: anytype,
|
||||
hasher: anytype,
|
||||
) HashedReader(@TypeOf(reader), @TypeOf(hasher)) {
|
||||
return .{ .child_reader = reader, .hasher = hasher };
|
||||
}
|
||||
|
||||
pub fn HashedWriter(WriterType: type, HasherType: type) type {
|
||||
return struct {
|
||||
child_writer: WriterType,
|
||||
hasher: HasherType,
|
||||
|
||||
pub const Error = WriterType.Error;
|
||||
pub const Writer = std.io.Writer(*@This(), Error, write);
|
||||
|
||||
pub fn write(self: *@This(), buf: []const u8) Error!usize {
|
||||
const amt = try self.child_writer.write(buf);
|
||||
self.hasher.update(buf[0..amt]);
|
||||
return amt;
|
||||
}
|
||||
|
||||
pub fn writer(self: *@This()) Writer {
|
||||
return .{ .context = self };
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn hashedWriter(
|
||||
writer: anytype,
|
||||
hasher: anytype,
|
||||
) HashedWriter(@TypeOf(writer), @TypeOf(hasher)) {
|
||||
return .{ .child_writer = writer, .hasher = hasher };
|
||||
}
|
||||
|
||||
test {
|
||||
_ = flate;
|
||||
_ = gzip;
|
||||
_ = lzma;
|
||||
_ = lzma2;
|
||||
_ = xz;
|
||||
_ = zstd;
|
||||
_ = flate;
|
||||
_ = gzip;
|
||||
_ = zlib;
|
||||
_ = zstd;
|
||||
}
|
||||
|
||||
@ -821,7 +821,7 @@ pub fn BitReader(comptime T: type) type {
|
||||
/// Skip zero terminated string.
|
||||
pub fn skipStringZ(self: *Self) !void {
|
||||
while (true) {
|
||||
if (try self.readF(u8, 0) == 0) break;
|
||||
if (try self.readF(u8, .{}) == 0) break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,7 +1,12 @@
|
||||
const std = @import("std");
|
||||
const std = @import("../std.zig");
|
||||
const RingBuffer = std.RingBuffer;
|
||||
|
||||
const types = @import("zstandard/types.zig");
|
||||
|
||||
/// Recommended amount by the standard. Lower than this may result in inability
|
||||
/// to decompress common streams.
|
||||
pub const default_window_len = 8 * 1024 * 1024;
|
||||
|
||||
pub const frame = types.frame;
|
||||
pub const compressed_block = types.compressed_block;
|
||||
|
||||
@ -10,7 +15,8 @@ pub const decompress = @import("zstandard/decompress.zig");
|
||||
pub const Decompressor = struct {
|
||||
const table_size_max = types.compressed_block.table_size_max;
|
||||
|
||||
source: *std.io.BufferedReader,
|
||||
input: *std.io.BufferedReader,
|
||||
bytes_read: usize,
|
||||
state: enum { NewFrame, InFrame, LastBlock },
|
||||
decode_state: decompress.block.DecodeState,
|
||||
frame_context: decompress.FrameContext,
|
||||
@ -23,14 +29,12 @@ pub const Decompressor = struct {
|
||||
verify_checksum: bool,
|
||||
checksum: ?u32,
|
||||
current_frame_decompressed_size: usize,
|
||||
err: ?Error = null,
|
||||
|
||||
pub const Options = struct {
|
||||
verify_checksum: bool = true,
|
||||
/// See `default_window_len`.
|
||||
window_buffer: []u8,
|
||||
|
||||
/// Recommended amount by the standard. Lower than this may result
|
||||
/// in inability to decompress common streams.
|
||||
pub const default_window_buffer_len = 8 * 1024 * 1024;
|
||||
};
|
||||
|
||||
const WindowBuffer = struct {
|
||||
@ -45,11 +49,13 @@ pub const Decompressor = struct {
|
||||
MalformedBlock,
|
||||
MalformedFrame,
|
||||
OutOfMemory,
|
||||
EndOfStream,
|
||||
};
|
||||
|
||||
pub fn init(source: *std.io.BufferedReader, options: Options) Decompressor {
|
||||
pub fn init(input: *std.io.BufferedReader, options: Options) Decompressor {
|
||||
return .{
|
||||
.source = source,
|
||||
.input = input,
|
||||
.bytes_read = 0,
|
||||
.state = .NewFrame,
|
||||
.decode_state = undefined,
|
||||
.frame_context = undefined,
|
||||
@ -65,100 +71,128 @@ pub const Decompressor = struct {
|
||||
};
|
||||
}
|
||||
|
||||
fn frameInit(self: *Decompressor) !void {
|
||||
const source_reader = self.source;
|
||||
switch (try decompress.decodeFrameHeader(source_reader)) {
|
||||
fn frameInit(d: *Decompressor) !void {
|
||||
const in = d.input;
|
||||
switch (try decompress.decodeFrameHeader(in, &d.bytes_read)) {
|
||||
.skippable => |header| {
|
||||
try source_reader.skipBytes(header.frame_size, .{});
|
||||
self.state = .NewFrame;
|
||||
try in.discardAll(header.frame_size);
|
||||
d.bytes_read += header.frame_size;
|
||||
d.state = .NewFrame;
|
||||
},
|
||||
.zstandard => |header| {
|
||||
const frame_context = try decompress.FrameContext.init(
|
||||
header,
|
||||
self.buffer.data.len,
|
||||
self.verify_checksum,
|
||||
d.buffer.data.len,
|
||||
d.verify_checksum,
|
||||
);
|
||||
|
||||
const decode_state = decompress.block.DecodeState.init(
|
||||
&self.literal_fse_buffer,
|
||||
&self.match_fse_buffer,
|
||||
&self.offset_fse_buffer,
|
||||
&d.literal_fse_buffer,
|
||||
&d.match_fse_buffer,
|
||||
&d.offset_fse_buffer,
|
||||
);
|
||||
|
||||
self.decode_state = decode_state;
|
||||
self.frame_context = frame_context;
|
||||
d.decode_state = decode_state;
|
||||
d.frame_context = frame_context;
|
||||
|
||||
self.checksum = null;
|
||||
self.current_frame_decompressed_size = 0;
|
||||
d.checksum = null;
|
||||
d.current_frame_decompressed_size = 0;
|
||||
|
||||
self.state = .InFrame;
|
||||
d.state = .InFrame;
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reader(self: *Decompressor) std.io.Reader {
|
||||
return .{ .context = self };
|
||||
return .{
|
||||
.context = self,
|
||||
.vtable = &.{
|
||||
.read = read,
|
||||
.readVec = readVec,
|
||||
.discard = discard,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
pub fn read(self: *Decompressor, buffer: []u8) Error!usize {
|
||||
if (buffer.len == 0) return 0;
|
||||
fn read(context: ?*anyopaque, bw: *std.io.BufferedWriter, limit: std.io.Reader.Limit) std.io.Reader.RwError!usize {
|
||||
const buf = limit.slice(try bw.writableSliceGreedy(1));
|
||||
const n = try readVec(context, &.{buf});
|
||||
bw.advance(n);
|
||||
return n;
|
||||
}
|
||||
|
||||
var size: usize = 0;
|
||||
while (size == 0) {
|
||||
while (self.state == .NewFrame) {
|
||||
const initial_count = self.source.bytes_read;
|
||||
self.frameInit() catch |err| switch (err) {
|
||||
error.DictionaryIdFlagUnsupported => return error.DictionaryIdFlagUnsupported,
|
||||
error.EndOfStream => return if (self.source.bytes_read == initial_count)
|
||||
0
|
||||
else
|
||||
error.MalformedFrame,
|
||||
else => return error.MalformedFrame,
|
||||
};
|
||||
}
|
||||
size = try self.readInner(buffer);
|
||||
fn discard(context: ?*anyopaque, limit: std.io.Reader.Limit) std.io.Reader.Error!usize {
|
||||
var trash: [128]u8 = undefined;
|
||||
const buf = limit.slice(&trash);
|
||||
return readVec(context, &.{buf});
|
||||
}
|
||||
|
||||
fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
|
||||
const d: *Decompressor = @ptrCast(@alignCast(context));
|
||||
if (data.len == 0) return 0;
|
||||
const buffer = data[0];
|
||||
while (d.state == .NewFrame) {
|
||||
const initial_count = d.bytes_read;
|
||||
d.frameInit() catch |err| switch (err) {
|
||||
error.DictionaryIdFlagUnsupported => {
|
||||
d.err = error.DictionaryIdFlagUnsupported;
|
||||
return error.ReadFailed;
|
||||
},
|
||||
error.EndOfStream => {
|
||||
if (d.bytes_read == initial_count) return error.EndOfStream;
|
||||
d.err = error.MalformedFrame;
|
||||
return error.ReadFailed;
|
||||
},
|
||||
else => {
|
||||
d.err = error.MalformedFrame;
|
||||
return error.ReadFailed;
|
||||
},
|
||||
};
|
||||
}
|
||||
return size;
|
||||
return d.readInner(buffer) catch |err| {
|
||||
d.err = err;
|
||||
return error.ReadFailed;
|
||||
};
|
||||
}
|
||||
|
||||
fn readInner(self: *Decompressor, buffer: []u8) Error!usize {
|
||||
std.debug.assert(self.state != .NewFrame);
|
||||
fn readInner(d: *Decompressor, buffer: []u8) Error!usize {
|
||||
std.debug.assert(d.state != .NewFrame);
|
||||
|
||||
var ring_buffer = RingBuffer{
|
||||
.data = self.buffer.data,
|
||||
.read_index = self.buffer.read_index,
|
||||
.write_index = self.buffer.write_index,
|
||||
.data = d.buffer.data,
|
||||
.read_index = d.buffer.read_index,
|
||||
.write_index = d.buffer.write_index,
|
||||
};
|
||||
defer {
|
||||
self.buffer.read_index = ring_buffer.read_index;
|
||||
self.buffer.write_index = ring_buffer.write_index;
|
||||
d.buffer.read_index = ring_buffer.read_index;
|
||||
d.buffer.write_index = ring_buffer.write_index;
|
||||
}
|
||||
|
||||
const source_reader = self.source;
|
||||
while (ring_buffer.isEmpty() and self.state != .LastBlock) {
|
||||
const header_bytes = source_reader.readBytesNoEof(3) catch
|
||||
return error.MalformedFrame;
|
||||
const block_header = decompress.block.decodeBlockHeader(&header_bytes);
|
||||
const in = d.input;
|
||||
while (ring_buffer.isEmpty() and d.state != .LastBlock) {
|
||||
const header_bytes = try in.takeArray(3);
|
||||
d.bytes_read += header_bytes.len;
|
||||
const block_header = decompress.block.decodeBlockHeader(header_bytes);
|
||||
|
||||
decompress.block.decodeBlockReader(
|
||||
&ring_buffer,
|
||||
source_reader,
|
||||
in,
|
||||
&d.bytes_read,
|
||||
block_header,
|
||||
&self.decode_state,
|
||||
self.frame_context.block_size_max,
|
||||
&self.literals_buffer,
|
||||
&self.sequence_buffer,
|
||||
) catch
|
||||
return error.MalformedBlock;
|
||||
&d.decode_state,
|
||||
d.frame_context.block_size_max,
|
||||
&d.literals_buffer,
|
||||
&d.sequence_buffer,
|
||||
) catch return error.MalformedBlock;
|
||||
|
||||
if (self.frame_context.content_size) |size| {
|
||||
if (self.current_frame_decompressed_size > size) return error.MalformedFrame;
|
||||
if (d.frame_context.content_size) |size| {
|
||||
if (d.current_frame_decompressed_size > size) return error.MalformedFrame;
|
||||
}
|
||||
|
||||
const size = ring_buffer.len();
|
||||
self.current_frame_decompressed_size += size;
|
||||
d.current_frame_decompressed_size += size;
|
||||
|
||||
if (self.frame_context.hasher_opt) |*hasher| {
|
||||
if (d.frame_context.hasher_opt) |*hasher| {
|
||||
if (size > 0) {
|
||||
const written_slice = ring_buffer.sliceLast(size);
|
||||
hasher.update(written_slice.first);
|
||||
@ -166,19 +200,19 @@ pub const Decompressor = struct {
|
||||
}
|
||||
}
|
||||
if (block_header.last_block) {
|
||||
self.state = .LastBlock;
|
||||
if (self.frame_context.has_checksum) {
|
||||
const checksum = source_reader.readInt(u32, .little) catch
|
||||
return error.MalformedFrame;
|
||||
if (self.verify_checksum) {
|
||||
if (self.frame_context.hasher_opt) |*hasher| {
|
||||
d.state = .LastBlock;
|
||||
if (d.frame_context.has_checksum) {
|
||||
const checksum = in.readInt(u32, .little) catch return error.MalformedFrame;
|
||||
d.bytes_read += 4;
|
||||
if (d.verify_checksum) {
|
||||
if (d.frame_context.hasher_opt) |*hasher| {
|
||||
if (checksum != decompress.computeChecksum(hasher))
|
||||
return error.ChecksumFailure;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (self.frame_context.content_size) |content_size| {
|
||||
if (content_size != self.current_frame_decompressed_size) {
|
||||
if (d.frame_context.content_size) |content_size| {
|
||||
if (content_size != d.current_frame_decompressed_size) {
|
||||
return error.MalformedFrame;
|
||||
}
|
||||
}
|
||||
@ -189,8 +223,8 @@ pub const Decompressor = struct {
|
||||
if (size > 0) {
|
||||
ring_buffer.readFirstAssumeLength(buffer, size);
|
||||
}
|
||||
if (self.state == .LastBlock and ring_buffer.len() == 0) {
|
||||
self.state = .NewFrame;
|
||||
if (d.state == .LastBlock and ring_buffer.len() == 0) {
|
||||
d.state = .NewFrame;
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
@ -807,7 +807,8 @@ pub fn decodeBlockRingBuffer(
|
||||
/// contain enough bytes.
|
||||
pub fn decodeBlockReader(
|
||||
dest: *RingBuffer,
|
||||
source: anytype,
|
||||
in: *std.io.BufferedReader,
|
||||
bytes_read: *usize,
|
||||
block_header: frame.Zstandard.Block.Header,
|
||||
decode_state: *DecodeState,
|
||||
block_size_max: usize,
|
||||
@ -815,26 +816,29 @@ pub fn decodeBlockReader(
|
||||
sequence_buffer: []u8,
|
||||
) !void {
|
||||
const block_size = block_header.block_size;
|
||||
var block_reader_limited = std.io.limitedReader(source, block_size);
|
||||
const block_reader = block_reader_limited.reader();
|
||||
if (block_size_max < block_size) return error.BlockSizeOverMaximum;
|
||||
switch (block_header.block_type) {
|
||||
.raw => {
|
||||
if (block_size == 0) return;
|
||||
const slice = dest.sliceAt(dest.write_index, block_size);
|
||||
try source.readNoEof(slice.first);
|
||||
try source.readNoEof(slice.second);
|
||||
var vecs: [2][]u8 = &.{slice.first, slice.second };
|
||||
try in.readVecAll(&vecs);
|
||||
assert(slice.first.len + slice.second.len == block_size);
|
||||
bytes_read.* += block_size;
|
||||
dest.write_index = dest.mask2(dest.write_index + block_size);
|
||||
decode_state.written_count += block_size;
|
||||
},
|
||||
.rle => {
|
||||
const byte = try source.readByte();
|
||||
const byte = try in.takeByte();
|
||||
bytes_read.* += 1;
|
||||
for (0..block_size) |_| {
|
||||
dest.writeAssumeCapacity(byte);
|
||||
}
|
||||
decode_state.written_count += block_size;
|
||||
},
|
||||
.compressed => {
|
||||
var block_reader_limited = std.io.limitedReader(source, block_size);
|
||||
const block_reader = block_reader_limited.reader();
|
||||
const literals = try decodeLiteralsSection(block_reader, literals_buffer);
|
||||
const sequences_header = try decodeSequencesHeader(block_reader);
|
||||
|
||||
|
||||
@ -50,7 +50,7 @@ pub const FrameHeader = union(enum) {
|
||||
skippable: SkippableHeader,
|
||||
};
|
||||
|
||||
pub const HeaderError = error{ BadMagic, EndOfStream, ReservedBitSet };
|
||||
pub const HeaderError = error{ ReadFailed, BadMagic, EndOfStream, ReservedBitSet };
|
||||
|
||||
/// Returns the header of the frame at the beginning of `source`.
|
||||
///
|
||||
@ -61,16 +61,21 @@ pub const HeaderError = error{ BadMagic, EndOfStream, ReservedBitSet };
|
||||
/// - `error.EndOfStream` if `source` contains fewer than 4 bytes
|
||||
/// - `error.ReservedBitSet` if the frame is a Zstandard frame and any of the
|
||||
/// reserved bits are set
|
||||
pub fn decodeFrameHeader(source: anytype) (@TypeOf(source).Error || HeaderError)!FrameHeader {
|
||||
const magic = try source.readInt(u32, .little);
|
||||
pub fn decodeFrameHeader(br: *std.io.BufferedReader, bytes_read: *usize) HeaderError!FrameHeader {
|
||||
const magic = try br.readInt(u32, .little);
|
||||
bytes_read.* += 4;
|
||||
const frame_type = try frameType(magic);
|
||||
switch (frame_type) {
|
||||
.zstandard => return FrameHeader{ .zstandard = try decodeZstandardHeader(source) },
|
||||
.skippable => return FrameHeader{
|
||||
.skippable = .{
|
||||
.magic_number = magic,
|
||||
.frame_size = try source.readInt(u32, .little),
|
||||
},
|
||||
.zstandard => return .{ .zstandard = try decodeZstandardHeader(br, bytes_read) },
|
||||
.skippable => {
|
||||
const result: FrameHeader = .{
|
||||
.skippable = .{
|
||||
.magic_number = magic,
|
||||
.frame_size = try br.readInt(u32, .little),
|
||||
},
|
||||
};
|
||||
bytes_read.* += 4;
|
||||
return result;
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -49,8 +49,8 @@ pub const hello_retry_request_sequence = [32]u8{
|
||||
};
|
||||
|
||||
pub const close_notify_alert = [_]u8{
|
||||
@intFromEnum(AlertLevel.warning),
|
||||
@intFromEnum(AlertDescription.close_notify),
|
||||
@intFromEnum(Alert.Level.warning),
|
||||
@intFromEnum(Alert.Description.close_notify),
|
||||
};
|
||||
|
||||
pub const ProtocolVersion = enum(u16) {
|
||||
@ -138,103 +138,108 @@ pub const ExtensionType = enum(u16) {
|
||||
_,
|
||||
};
|
||||
|
||||
pub const AlertLevel = enum(u8) {
|
||||
warning = 1,
|
||||
fatal = 2,
|
||||
_,
|
||||
};
|
||||
pub const Alert = struct {
|
||||
level: Level,
|
||||
description: Description,
|
||||
|
||||
pub const AlertDescription = enum(u8) {
|
||||
pub const Error = error{
|
||||
TlsAlertUnexpectedMessage,
|
||||
TlsAlertBadRecordMac,
|
||||
TlsAlertRecordOverflow,
|
||||
TlsAlertHandshakeFailure,
|
||||
TlsAlertBadCertificate,
|
||||
TlsAlertUnsupportedCertificate,
|
||||
TlsAlertCertificateRevoked,
|
||||
TlsAlertCertificateExpired,
|
||||
TlsAlertCertificateUnknown,
|
||||
TlsAlertIllegalParameter,
|
||||
TlsAlertUnknownCa,
|
||||
TlsAlertAccessDenied,
|
||||
TlsAlertDecodeError,
|
||||
TlsAlertDecryptError,
|
||||
TlsAlertProtocolVersion,
|
||||
TlsAlertInsufficientSecurity,
|
||||
TlsAlertInternalError,
|
||||
TlsAlertInappropriateFallback,
|
||||
TlsAlertMissingExtension,
|
||||
TlsAlertUnsupportedExtension,
|
||||
TlsAlertUnrecognizedName,
|
||||
TlsAlertBadCertificateStatusResponse,
|
||||
TlsAlertUnknownPskIdentity,
|
||||
TlsAlertCertificateRequired,
|
||||
TlsAlertNoApplicationProtocol,
|
||||
TlsAlertUnknown,
|
||||
pub const Level = enum(u8) {
|
||||
warning = 1,
|
||||
fatal = 2,
|
||||
_,
|
||||
};
|
||||
|
||||
close_notify = 0,
|
||||
unexpected_message = 10,
|
||||
bad_record_mac = 20,
|
||||
record_overflow = 22,
|
||||
handshake_failure = 40,
|
||||
bad_certificate = 42,
|
||||
unsupported_certificate = 43,
|
||||
certificate_revoked = 44,
|
||||
certificate_expired = 45,
|
||||
certificate_unknown = 46,
|
||||
illegal_parameter = 47,
|
||||
unknown_ca = 48,
|
||||
access_denied = 49,
|
||||
decode_error = 50,
|
||||
decrypt_error = 51,
|
||||
protocol_version = 70,
|
||||
insufficient_security = 71,
|
||||
internal_error = 80,
|
||||
inappropriate_fallback = 86,
|
||||
user_canceled = 90,
|
||||
missing_extension = 109,
|
||||
unsupported_extension = 110,
|
||||
unrecognized_name = 112,
|
||||
bad_certificate_status_response = 113,
|
||||
unknown_psk_identity = 115,
|
||||
certificate_required = 116,
|
||||
no_application_protocol = 120,
|
||||
_,
|
||||
pub const Description = enum(u8) {
|
||||
pub const Error = error{
|
||||
TlsAlertUnexpectedMessage,
|
||||
TlsAlertBadRecordMac,
|
||||
TlsAlertRecordOverflow,
|
||||
TlsAlertHandshakeFailure,
|
||||
TlsAlertBadCertificate,
|
||||
TlsAlertUnsupportedCertificate,
|
||||
TlsAlertCertificateRevoked,
|
||||
TlsAlertCertificateExpired,
|
||||
TlsAlertCertificateUnknown,
|
||||
TlsAlertIllegalParameter,
|
||||
TlsAlertUnknownCa,
|
||||
TlsAlertAccessDenied,
|
||||
TlsAlertDecodeError,
|
||||
TlsAlertDecryptError,
|
||||
TlsAlertProtocolVersion,
|
||||
TlsAlertInsufficientSecurity,
|
||||
TlsAlertInternalError,
|
||||
TlsAlertInappropriateFallback,
|
||||
TlsAlertMissingExtension,
|
||||
TlsAlertUnsupportedExtension,
|
||||
TlsAlertUnrecognizedName,
|
||||
TlsAlertBadCertificateStatusResponse,
|
||||
TlsAlertUnknownPskIdentity,
|
||||
TlsAlertCertificateRequired,
|
||||
TlsAlertNoApplicationProtocol,
|
||||
TlsAlertUnknown,
|
||||
};
|
||||
|
||||
pub fn toError(alert: AlertDescription) Error!void {
|
||||
switch (alert) {
|
||||
.close_notify => {}, // not an error
|
||||
.unexpected_message => return error.TlsAlertUnexpectedMessage,
|
||||
.bad_record_mac => return error.TlsAlertBadRecordMac,
|
||||
.record_overflow => return error.TlsAlertRecordOverflow,
|
||||
.handshake_failure => return error.TlsAlertHandshakeFailure,
|
||||
.bad_certificate => return error.TlsAlertBadCertificate,
|
||||
.unsupported_certificate => return error.TlsAlertUnsupportedCertificate,
|
||||
.certificate_revoked => return error.TlsAlertCertificateRevoked,
|
||||
.certificate_expired => return error.TlsAlertCertificateExpired,
|
||||
.certificate_unknown => return error.TlsAlertCertificateUnknown,
|
||||
.illegal_parameter => return error.TlsAlertIllegalParameter,
|
||||
.unknown_ca => return error.TlsAlertUnknownCa,
|
||||
.access_denied => return error.TlsAlertAccessDenied,
|
||||
.decode_error => return error.TlsAlertDecodeError,
|
||||
.decrypt_error => return error.TlsAlertDecryptError,
|
||||
.protocol_version => return error.TlsAlertProtocolVersion,
|
||||
.insufficient_security => return error.TlsAlertInsufficientSecurity,
|
||||
.internal_error => return error.TlsAlertInternalError,
|
||||
.inappropriate_fallback => return error.TlsAlertInappropriateFallback,
|
||||
.user_canceled => {}, // not an error
|
||||
.missing_extension => return error.TlsAlertMissingExtension,
|
||||
.unsupported_extension => return error.TlsAlertUnsupportedExtension,
|
||||
.unrecognized_name => return error.TlsAlertUnrecognizedName,
|
||||
.bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse,
|
||||
.unknown_psk_identity => return error.TlsAlertUnknownPskIdentity,
|
||||
.certificate_required => return error.TlsAlertCertificateRequired,
|
||||
.no_application_protocol => return error.TlsAlertNoApplicationProtocol,
|
||||
_ => return error.TlsAlertUnknown,
|
||||
close_notify = 0,
|
||||
unexpected_message = 10,
|
||||
bad_record_mac = 20,
|
||||
record_overflow = 22,
|
||||
handshake_failure = 40,
|
||||
bad_certificate = 42,
|
||||
unsupported_certificate = 43,
|
||||
certificate_revoked = 44,
|
||||
certificate_expired = 45,
|
||||
certificate_unknown = 46,
|
||||
illegal_parameter = 47,
|
||||
unknown_ca = 48,
|
||||
access_denied = 49,
|
||||
decode_error = 50,
|
||||
decrypt_error = 51,
|
||||
protocol_version = 70,
|
||||
insufficient_security = 71,
|
||||
internal_error = 80,
|
||||
inappropriate_fallback = 86,
|
||||
user_canceled = 90,
|
||||
missing_extension = 109,
|
||||
unsupported_extension = 110,
|
||||
unrecognized_name = 112,
|
||||
bad_certificate_status_response = 113,
|
||||
unknown_psk_identity = 115,
|
||||
certificate_required = 116,
|
||||
no_application_protocol = 120,
|
||||
_,
|
||||
|
||||
pub fn toError(description: Description) Error!void {
|
||||
switch (description) {
|
||||
.close_notify => {}, // not an error
|
||||
.unexpected_message => return error.TlsAlertUnexpectedMessage,
|
||||
.bad_record_mac => return error.TlsAlertBadRecordMac,
|
||||
.record_overflow => return error.TlsAlertRecordOverflow,
|
||||
.handshake_failure => return error.TlsAlertHandshakeFailure,
|
||||
.bad_certificate => return error.TlsAlertBadCertificate,
|
||||
.unsupported_certificate => return error.TlsAlertUnsupportedCertificate,
|
||||
.certificate_revoked => return error.TlsAlertCertificateRevoked,
|
||||
.certificate_expired => return error.TlsAlertCertificateExpired,
|
||||
.certificate_unknown => return error.TlsAlertCertificateUnknown,
|
||||
.illegal_parameter => return error.TlsAlertIllegalParameter,
|
||||
.unknown_ca => return error.TlsAlertUnknownCa,
|
||||
.access_denied => return error.TlsAlertAccessDenied,
|
||||
.decode_error => return error.TlsAlertDecodeError,
|
||||
.decrypt_error => return error.TlsAlertDecryptError,
|
||||
.protocol_version => return error.TlsAlertProtocolVersion,
|
||||
.insufficient_security => return error.TlsAlertInsufficientSecurity,
|
||||
.internal_error => return error.TlsAlertInternalError,
|
||||
.inappropriate_fallback => return error.TlsAlertInappropriateFallback,
|
||||
.user_canceled => {}, // not an error
|
||||
.missing_extension => return error.TlsAlertMissingExtension,
|
||||
.unsupported_extension => return error.TlsAlertUnsupportedExtension,
|
||||
.unrecognized_name => return error.TlsAlertUnrecognizedName,
|
||||
.bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse,
|
||||
.unknown_psk_identity => return error.TlsAlertUnknownPskIdentity,
|
||||
.certificate_required => return error.TlsAlertCertificateRequired,
|
||||
.no_application_protocol => return error.TlsAlertNoApplicationProtocol,
|
||||
_ => return error.TlsAlertUnknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
pub const SignatureScheme = enum(u16) {
|
||||
|
||||
@ -39,8 +39,9 @@ output: *std.io.BufferedWriter,
|
||||
///
|
||||
/// Its buffer aliases the buffer of `input`.
|
||||
reader: std.io.BufferedReader,
|
||||
/// Populated under various error conditions.
|
||||
diagnostics: Diagnostics,
|
||||
/// Populated when `error.TlsAlert` is returned.
|
||||
alert: ?tls.Alert,
|
||||
read_err: ?ReadError,
|
||||
|
||||
tls_version: tls.ProtocolVersion,
|
||||
read_seq: u64,
|
||||
@ -69,15 +70,16 @@ application_cipher: tls.ApplicationCipher,
|
||||
/// this connection.
|
||||
ssl_key_log: ?*SslKeyLog,
|
||||
|
||||
pub const Diagnostics = union(enum) {
|
||||
/// Any `ReadFailure` and `WriteFailure` was due to `input` or `output`
|
||||
/// returning the error, respectively.
|
||||
transitive,
|
||||
/// Populated on `error.TlsAlert`.
|
||||
///
|
||||
/// If this isn't a error alert, then it's a closure alert, which makes
|
||||
/// no sense in a handshake.
|
||||
alert: tls.AlertDescription,
|
||||
pub const ReadError = error{
|
||||
/// The alert description will be stored in `alert`.
|
||||
TlsAlert,
|
||||
TlsBadLength,
|
||||
TlsBadRecordMac,
|
||||
TlsConnectionTruncated,
|
||||
TlsDecodeError,
|
||||
TlsRecordOverflow,
|
||||
TlsUnexpectedMessage,
|
||||
TlsIllegalParameter,
|
||||
};
|
||||
|
||||
pub const SslKeyLog = struct {
|
||||
@ -128,14 +130,13 @@ pub const Options = struct {
|
||||
};
|
||||
|
||||
const InitError = error{
|
||||
//OutOfMemory,
|
||||
WriteFailure,
|
||||
ReadFailure,
|
||||
WriteFailed,
|
||||
ReadFailed,
|
||||
InsufficientEntropy,
|
||||
DiskQuota,
|
||||
LockViolation,
|
||||
NotOpenForWriting,
|
||||
/// The alert description will be stored in `Options.Diagnostics.alert`.
|
||||
/// The alert description will be stored in `alert`.
|
||||
TlsAlert,
|
||||
TlsUnexpectedMessage,
|
||||
TlsIllegalParameter,
|
||||
@ -192,7 +193,7 @@ pub fn init(
|
||||
) InitError!void {
|
||||
assert(input.storage.buffer.len >= min_buffer_len);
|
||||
assert(output.buffer.len >= min_buffer_len);
|
||||
client.diagnostics = .transient;
|
||||
client.alert = null;
|
||||
const host = switch (options.host) {
|
||||
.no_verification => "",
|
||||
.explicit => |host| host,
|
||||
@ -417,10 +418,10 @@ pub fn init(
|
||||
switch (ct) {
|
||||
.alert => {
|
||||
ctd.ensure(2) catch continue :fragment;
|
||||
const level = ctd.decode(tls.AlertLevel);
|
||||
const desc = ctd.decode(tls.AlertDescription);
|
||||
_ = level;
|
||||
client.diagnostics = .{ .alert = desc };
|
||||
client.alert = .{
|
||||
.level = ctd.decode(tls.Alert.Level),
|
||||
.description = ctd.decode(tls.Alert.Description),
|
||||
};
|
||||
return error.TlsAlert;
|
||||
},
|
||||
.change_cipher_spec => {
|
||||
@ -924,7 +925,7 @@ pub fn writer(c: *Client) std.io.Writer {
|
||||
fn writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) std.io.Writer.Error!usize {
|
||||
const c: *Client = @alignCast(@ptrCast(context));
|
||||
const sliced_data = if (splat == 0) data[0..data.len -| 1] else data;
|
||||
const output = &c.output;
|
||||
const output = c.output;
|
||||
const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
|
||||
var total_clear: usize = 0;
|
||||
var ciphertext_end: usize = 0;
|
||||
@ -942,7 +943,7 @@ fn writeSplat(context: *anyopaque, data: []const []const u8, splat: usize) std.i
|
||||
/// distinguish between a properly finished TLS session, or a truncation
|
||||
/// attack.
|
||||
pub fn end(c: *Client) std.io.Writer.Error!void {
|
||||
const output = &c.output;
|
||||
const output = c.output;
|
||||
const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
|
||||
const prepared = prepareCiphertextRecord(c, ciphertext_buf, &tls.close_notify_alert, .alert);
|
||||
output.advance(prepared.cleartext_len);
|
||||
@ -1062,16 +1063,16 @@ fn read(
|
||||
context: ?*anyopaque,
|
||||
bw: *std.io.BufferedWriter,
|
||||
limit: std.io.Reader.Limit,
|
||||
) std.io.Reader.RwError!std.io.Reader.Status {
|
||||
) std.io.Reader.RwError!usize {
|
||||
const buf = limit.slice(try bw.writableSliceGreedy(1));
|
||||
const status = try readVec(context, &.{buf});
|
||||
bw.advance(status.len);
|
||||
return status;
|
||||
const n = try readVec(context, &.{buf});
|
||||
bw.advance(n);
|
||||
return n;
|
||||
}
|
||||
|
||||
fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
|
||||
const c: *Client = @ptrCast(@alignCast(context));
|
||||
if (c.eof()) return .{ .end = true };
|
||||
if (c.eof()) return error.EndOfStream;
|
||||
|
||||
var vp: VecPut = .{ .iovecs = data };
|
||||
|
||||
@ -1093,11 +1094,11 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
|
||||
if (c.received_close_notify) {
|
||||
c.partial_ciphertext_end = 0;
|
||||
assert(vp.total == amt);
|
||||
return .{ .len = amt, .end = c.eof() };
|
||||
return amt;
|
||||
} else if (amt > 0) {
|
||||
// We don't need more data, so don't call read.
|
||||
assert(vp.total == amt);
|
||||
return .{ .len = amt, .end = c.eof() };
|
||||
return amt;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1149,7 +1150,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
|
||||
if (c.allow_truncation_attacks) {
|
||||
c.received_close_notify = true;
|
||||
} else {
|
||||
return error.TlsConnectionTruncated;
|
||||
return failRead(c, error.TlsConnectionTruncated);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1168,7 +1169,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
|
||||
// Perfect split.
|
||||
if (frag.ptr == frag1.ptr) {
|
||||
c.partial_ciphertext_end = c.partial_ciphertext_idx;
|
||||
return .{ .len = vp.total, .end = c.eof() };
|
||||
return vp.total;
|
||||
}
|
||||
frag = frag1;
|
||||
in = 0;
|
||||
@ -1188,7 +1189,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
|
||||
const record_len_byte_0: u16 = straddleByte(frag, frag1, in + 3);
|
||||
const record_len_byte_1: u16 = straddleByte(frag, frag1, in + 4);
|
||||
const record_len = (record_len_byte_0 << 8) | record_len_byte_1;
|
||||
if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
|
||||
if (record_len > max_ciphertext_len) return failRead(c, error.TlsRecordOverflow);
|
||||
|
||||
const full_record_len = record_len + tls.record_header_len;
|
||||
const second_len = full_record_len - first.len;
|
||||
@ -1208,7 +1209,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
|
||||
in += 2;
|
||||
_ = legacy_version;
|
||||
const record_len = mem.readInt(u16, frag[in..][0..2], .big);
|
||||
if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
|
||||
if (record_len > max_ciphertext_len) return failRead(c, error.TlsRecordOverflow);
|
||||
in += 2;
|
||||
const the_end = in + record_len;
|
||||
if (the_end > frag.len) {
|
||||
@ -1255,7 +1256,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
|
||||
&cleartext_stack_buffer;
|
||||
const cleartext = cleartext_buf[0..ciphertext.len];
|
||||
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch
|
||||
return error.TlsBadRecordMac;
|
||||
return failRead(c, error.TlsBadRecordMac);
|
||||
const msg = mem.trimEnd(u8, cleartext, "\x00");
|
||||
break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) };
|
||||
},
|
||||
@ -1287,7 +1288,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
|
||||
&cleartext_stack_buffer;
|
||||
const cleartext = cleartext_buf[0..ciphertext.len];
|
||||
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch
|
||||
return error.TlsBadRecordMac;
|
||||
return failRead(c, error.TlsBadRecordMac);
|
||||
break :cleartext .{ cleartext, ct };
|
||||
},
|
||||
else => unreachable,
|
||||
@ -1296,23 +1297,24 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
|
||||
c.read_seq = try std.math.add(u64, c.read_seq, 1);
|
||||
switch (inner_ct) {
|
||||
.alert => {
|
||||
if (cleartext.len != 2) return error.TlsDecodeError;
|
||||
const level: tls.AlertLevel = @enumFromInt(cleartext[0]);
|
||||
_ = level;
|
||||
const desc: tls.AlertDescription = @enumFromInt(cleartext[1]);
|
||||
switch (desc) {
|
||||
if (cleartext.len != 2) return failRead(c, error.TlsDecodeError);
|
||||
const alert: tls.Alert = .{
|
||||
.level = @enumFromInt(cleartext[0]),
|
||||
.description = @enumFromInt(cleartext[1]),
|
||||
};
|
||||
switch (alert.description) {
|
||||
.close_notify => {
|
||||
c.received_close_notify = true;
|
||||
c.partial_ciphertext_end = c.partial_ciphertext_idx;
|
||||
return .{ .len = vp.total, .end = c.eof() };
|
||||
return vp.total;
|
||||
},
|
||||
.user_canceled => {
|
||||
// TODO: handle server-side closures
|
||||
return error.TlsUnexpectedMessage;
|
||||
return failRead(c, error.TlsUnexpectedMessage);
|
||||
},
|
||||
else => {
|
||||
c.diagnostics = .{ .alert = desc };
|
||||
return error.TlsAlert;
|
||||
c.alert = alert;
|
||||
return failRead(c, error.TlsAlert);
|
||||
},
|
||||
}
|
||||
},
|
||||
@ -1324,8 +1326,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
|
||||
const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big);
|
||||
ct_i += 3;
|
||||
const next_handshake_i = ct_i + handshake_len;
|
||||
if (next_handshake_i > cleartext.len)
|
||||
return error.TlsBadLength;
|
||||
if (next_handshake_i > cleartext.len) return failRead(c, error.TlsBadLength);
|
||||
const handshake = cleartext[ct_i..next_handshake_i];
|
||||
switch (handshake_type) {
|
||||
.new_session_ticket => {
|
||||
@ -1371,12 +1372,10 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
|
||||
c.write_seq = 0;
|
||||
},
|
||||
.update_not_requested => {},
|
||||
_ => return error.TlsIllegalParameter,
|
||||
_ => return failRead(c, error.TlsIllegalParameter),
|
||||
}
|
||||
},
|
||||
else => {
|
||||
return error.TlsUnexpectedMessage;
|
||||
},
|
||||
else => return failRead(c, error.TlsUnexpectedMessage),
|
||||
}
|
||||
ct_i = next_handshake_i;
|
||||
if (ct_i >= cleartext.len) break;
|
||||
@ -1411,7 +1410,7 @@ fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
|
||||
vp.next(cleartext.len);
|
||||
}
|
||||
},
|
||||
else => return error.TlsUnexpectedMessage,
|
||||
else => return failRead(c, error.TlsUnexpectedMessage),
|
||||
}
|
||||
in = end;
|
||||
}
|
||||
@ -1423,6 +1422,11 @@ fn discard(context: ?*anyopaque, limit: std.io.Reader.Limit) std.io.Reader.Error
|
||||
@panic("TODO");
|
||||
}
|
||||
|
||||
fn failRead(c: *Client, err: ReadError) error{ReadFailed} {
|
||||
c.read_err = err;
|
||||
return error.ReadFailed;
|
||||
}
|
||||
|
||||
fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) void {
|
||||
const locked = if (key_log_file.lock(.exclusive)) |_| true else |_| false;
|
||||
defer if (locked) key_log_file.unlock();
|
||||
|
||||
@ -28,7 +28,7 @@ tls_buffer_size: if (disable_tls) u0 else usize = if (disable_tls) 0 else std.cr
|
||||
/// If non-null, ssl secrets are logged to a stream. Creating such a stream
|
||||
/// allows other processes with access to that stream to decrypt all
|
||||
/// traffic over connections created with this `Client`.
|
||||
ssl_key_logger: ?*std.io.BufferedWriter = null,
|
||||
ssl_key_log: ?*std.io.BufferedWriter = null,
|
||||
|
||||
/// When this is `true`, the next time this client performs an HTTPS request,
|
||||
/// it will first rescan the system for root certificates.
|
||||
@ -342,7 +342,7 @@ pub const Connection = struct {
|
||||
tls.client.init(&tls.reader, &tls.writer, .{
|
||||
.host = .{ .explicit = remote_host },
|
||||
.ca = .{ .bundle = client.ca_bundle },
|
||||
.ssl_key_logger = client.ssl_key_logger,
|
||||
.ssl_key_log = client.ssl_key_log,
|
||||
}) catch return error.TlsInitializationFailed;
|
||||
// This is appropriate for HTTPS because the HTTP headers contain
|
||||
// the content length which is used to detect truncation attacks.
|
||||
@ -1671,7 +1671,7 @@ pub fn fetch(client: *Client, options: FetchOptions) FetchError!FetchResult {
|
||||
const decompress_buffer: []u8 = switch (response.head.content_encoding) {
|
||||
.identity => &.{},
|
||||
.zstd => options.decompress_buffer orelse
|
||||
try client.allocator.alloc(u8, std.compress.zstd.Decompressor.Options.default_window_buffer_len * 2),
|
||||
try client.allocator.alloc(u8, std.compress.zstd.default_window_len * 2),
|
||||
else => options.decompress_buffer orelse try client.allocator.alloc(u8, 8 * 1024),
|
||||
};
|
||||
defer if (options.decompress_buffer == null) client.allocator.free(decompress_buffer);
|
||||
|
||||
@ -1,16 +1,11 @@
|
||||
const std = @import("std.zig");
|
||||
const builtin = @import("builtin");
|
||||
const root = @import("root");
|
||||
const c = std.c;
|
||||
const is_windows = builtin.os.tag == .windows;
|
||||
|
||||
const std = @import("std.zig");
|
||||
const windows = std.os.windows;
|
||||
const posix = std.posix;
|
||||
const math = std.math;
|
||||
const assert = std.debug.assert;
|
||||
const fs = std.fs;
|
||||
const mem = std.mem;
|
||||
const meta = std.meta;
|
||||
const File = std.fs.File;
|
||||
const Allocator = std.mem.Allocator;
|
||||
const Alignment = std.mem.Alignment;
|
||||
|
||||
@ -21,12 +16,6 @@ pub const BufferedReader = @import("io/BufferedReader.zig");
|
||||
pub const BufferedWriter = @import("io/BufferedWriter.zig");
|
||||
pub const AllocatingWriter = @import("io/AllocatingWriter.zig");
|
||||
|
||||
pub const CWriter = @import("io/c_writer.zig").CWriter;
|
||||
pub const cWriter = @import("io/c_writer.zig").cWriter;
|
||||
|
||||
pub const LimitedReader = @import("io/limited_reader.zig").LimitedReader;
|
||||
pub const limitedReader = @import("io/limited_reader.zig").limitedReader;
|
||||
|
||||
pub const MultiWriter = @import("io/multi_writer.zig").MultiWriter;
|
||||
pub const multiWriter = @import("io/multi_writer.zig").multiWriter;
|
||||
|
||||
@ -38,9 +27,6 @@ pub const bitWriter = @import("io/bit_writer.zig").bitWriter;
|
||||
pub const ChangeDetectionStream = @import("io/change_detection_stream.zig").ChangeDetectionStream;
|
||||
pub const changeDetectionStream = @import("io/change_detection_stream.zig").changeDetectionStream;
|
||||
|
||||
pub const FindByteWriter = @import("io/find_byte_writer.zig").FindByteWriter;
|
||||
pub const findByteWriter = @import("io/find_byte_writer.zig").findByteWriter;
|
||||
|
||||
pub const BufferedAtomicFile = @import("io/buffered_atomic_file.zig").BufferedAtomicFile;
|
||||
|
||||
pub const tty = @import("io/tty.zig");
|
||||
@ -63,7 +49,7 @@ pub fn poll(
|
||||
.windows = if (is_windows) .{
|
||||
.first_read_done = false,
|
||||
.overlapped = [1]windows.OVERLAPPED{
|
||||
mem.zeroes(windows.OVERLAPPED),
|
||||
std.mem.zeroes(windows.OVERLAPPED),
|
||||
} ** enum_fields.len,
|
||||
.small_bufs = undefined,
|
||||
.active = .{
|
||||
@ -436,10 +422,10 @@ pub fn PollFiles(comptime StreamEnum: type) type {
|
||||
for (&struct_fields, enum_fields) |*struct_field, enum_field| {
|
||||
struct_field.* = .{
|
||||
.name = enum_field.name,
|
||||
.type = fs.File,
|
||||
.type = std.fs.File,
|
||||
.default_value_ptr = null,
|
||||
.is_comptime = false,
|
||||
.alignment = @alignOf(fs.File),
|
||||
.alignment = @alignOf(std.fs.File),
|
||||
};
|
||||
}
|
||||
return @Type(.{ .@"struct" = .{
|
||||
@ -459,6 +445,5 @@ test {
|
||||
_ = @import("io/bit_reader.zig");
|
||||
_ = @import("io/bit_writer.zig");
|
||||
_ = @import("io/buffered_atomic_file.zig");
|
||||
_ = @import("io/c_writer.zig");
|
||||
_ = @import("io/test.zig");
|
||||
}
|
||||
|
||||
@ -6,6 +6,8 @@ const BufferedReader = std.io.BufferedReader;
|
||||
const Allocator = std.mem.Allocator;
|
||||
const ArrayList = std.ArrayListUnmanaged;
|
||||
|
||||
pub const Limited = @import("Reader/Limited.zig");
|
||||
|
||||
context: ?*anyopaque,
|
||||
vtable: *const VTable,
|
||||
|
||||
@ -252,6 +254,13 @@ pub fn buffered(r: Reader, buffer: []u8) BufferedReader {
|
||||
};
|
||||
}
|
||||
|
||||
pub fn limited(r: Reader, limit: Limit) Limited {
|
||||
return .{
|
||||
.unlimited_reader = r,
|
||||
.remaining = limit,
|
||||
};
|
||||
}
|
||||
|
||||
fn endingRead(context: ?*anyopaque, bw: *BufferedWriter, limit: Limit) RwError!usize {
|
||||
_ = context;
|
||||
_ = bw;
|
||||
|
||||
55
lib/std/io/Reader/Limited.zig
Normal file
55
lib/std/io/Reader/Limited.zig
Normal file
@ -0,0 +1,55 @@
|
||||
const Limited = @This();
|
||||
|
||||
const std = @import("../../std.zig");
|
||||
const Reader = std.io.Reader;
|
||||
const BufferedWriter = std.io.BufferedWriter;
|
||||
|
||||
unlimited_reader: Reader,
|
||||
remaining: Reader.Limit,
|
||||
|
||||
pub fn reader(l: *Limited) Reader {
|
||||
return .{
|
||||
.context = l,
|
||||
.vtable = &.{
|
||||
.read = passthruRead,
|
||||
.readVec = passthruReadVec,
|
||||
.discard = passthruDiscard,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
fn passthruRead(context: ?*anyopaque, bw: *BufferedWriter, limit: Reader.Limit) Reader.RwError!usize {
|
||||
const l: *Limited = @alignCast(@ptrCast(context));
|
||||
const combined_limit = limit.min(l.remaining);
|
||||
const n = try l.unlimited_reader.read(bw, combined_limit);
|
||||
l.remaining.subtract(n);
|
||||
return n;
|
||||
}
|
||||
|
||||
fn passthruDiscard(context: ?*anyopaque, limit: Reader.Limit) Reader.Error!usize {
|
||||
const l: *Limited = @alignCast(@ptrCast(context));
|
||||
const combined_limit = limit.min(l.remaining);
|
||||
const n = try l.unlimited_reader.discard(combined_limit);
|
||||
l.remaining.subtract(n);
|
||||
return n;
|
||||
}
|
||||
|
||||
fn passthruReadVec(context: ?*anyopaque, data: []const []u8) Reader.Error!usize {
|
||||
const l: *Limited = @alignCast(@ptrCast(context));
|
||||
if (data.len == 0) return 0;
|
||||
if (data[0].len >= @intFromEnum(l.limit)) {
|
||||
const n = try l.unlimited_reader.readVec(&.{l.limit.slice(data[0])});
|
||||
l.remaining.subtract(n);
|
||||
return n;
|
||||
}
|
||||
var total: usize = 0;
|
||||
for (data, 0..) |buf, i| {
|
||||
total += buf.len;
|
||||
if (total > @intFromEnum(l.limit)) {
|
||||
const n = try l.unlimited_reader.readVec(data[0..i]);
|
||||
l.remaining.subtract(n);
|
||||
return n;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@ -95,10 +95,12 @@ pub const Offset = enum(u64) {
|
||||
};
|
||||
|
||||
pub fn writeVec(w: Writer, data: []const []const u8) Error!usize {
|
||||
assert(data.len > 0);
|
||||
return w.vtable.writeSplat(w.context, data, 1);
|
||||
}
|
||||
|
||||
pub fn writeSplat(w: Writer, data: []const []const u8, splat: usize) Error!usize {
|
||||
assert(data.len > 0);
|
||||
return w.vtable.writeSplat(w.context, data, splat);
|
||||
}
|
||||
|
||||
|
||||
@ -1,44 +0,0 @@
|
||||
const std = @import("../std.zig");
|
||||
const builtin = @import("builtin");
|
||||
const io = std.io;
|
||||
const testing = std.testing;
|
||||
|
||||
pub const CWriter = io.Writer(*std.c.FILE, std.fs.File.WriteError, cWriterWrite);
|
||||
|
||||
pub fn cWriter(c_file: *std.c.FILE) CWriter {
|
||||
return .{ .context = c_file };
|
||||
}
|
||||
|
||||
fn cWriterWrite(c_file: *std.c.FILE, bytes: []const u8) std.fs.File.WriteError!usize {
|
||||
const amt_written = std.c.fwrite(bytes.ptr, 1, bytes.len, c_file);
|
||||
if (amt_written >= 0) return amt_written;
|
||||
switch (@as(std.c.E, @enumFromInt(std.c._errno().*))) {
|
||||
.SUCCESS => unreachable,
|
||||
.INVAL => unreachable,
|
||||
.FAULT => unreachable,
|
||||
.AGAIN => unreachable, // this is a blocking API
|
||||
.BADF => unreachable, // always a race condition
|
||||
.DESTADDRREQ => unreachable, // connect was never called
|
||||
.DQUOT => return error.DiskQuota,
|
||||
.FBIG => return error.FileTooBig,
|
||||
.IO => return error.InputOutput,
|
||||
.NOSPC => return error.NoSpaceLeft,
|
||||
.PERM => return error.PermissionDenied,
|
||||
.PIPE => return error.BrokenPipe,
|
||||
else => |err| return std.posix.unexpectedErrno(err),
|
||||
}
|
||||
}
|
||||
|
||||
test cWriter {
|
||||
if (!builtin.link_libc or builtin.os.tag == .wasi) return error.SkipZigTest;
|
||||
|
||||
const filename = "tmp_io_test_file.txt";
|
||||
const out_file = std.c.fopen(filename, "w") orelse return error.UnableToOpenTestFile;
|
||||
defer {
|
||||
_ = std.c.fclose(out_file);
|
||||
std.fs.cwd().deleteFileZ(filename) catch {};
|
||||
}
|
||||
|
||||
const writer = cWriter(out_file);
|
||||
try writer.print("hi: {}\n", .{@as(i32, 123)});
|
||||
}
|
||||
@ -1,45 +0,0 @@
|
||||
const std = @import("../std.zig");
|
||||
const io = std.io;
|
||||
const assert = std.debug.assert;
|
||||
const testing = std.testing;
|
||||
|
||||
pub fn LimitedReader(comptime ReaderType: type) type {
|
||||
return struct {
|
||||
inner_reader: ReaderType,
|
||||
bytes_left: u64,
|
||||
|
||||
pub const Error = ReaderType.Error;
|
||||
pub const Reader = io.Reader(*Self, Error, read);
|
||||
|
||||
const Self = @This();
|
||||
|
||||
pub fn read(self: *Self, dest: []u8) Error!usize {
|
||||
const max_read = @min(self.bytes_left, dest.len);
|
||||
const n = try self.inner_reader.read(dest[0..max_read]);
|
||||
self.bytes_left -= n;
|
||||
return n;
|
||||
}
|
||||
|
||||
pub fn reader(self: *Self) Reader {
|
||||
return .{ .context = self };
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Returns an initialised `LimitedReader`.
|
||||
/// `bytes_left` is a `u64` to be able to take 64 bit file offsets
|
||||
pub fn limitedReader(inner_reader: anytype, bytes_left: u64) LimitedReader(@TypeOf(inner_reader)) {
|
||||
return .{ .inner_reader = inner_reader, .bytes_left = bytes_left };
|
||||
}
|
||||
|
||||
test "basic usage" {
|
||||
const data = "hello world";
|
||||
var fbs = std.io.fixedBufferStream(data);
|
||||
var early_stream = limitedReader(fbs.reader(), 3);
|
||||
|
||||
var buf: [5]u8 = undefined;
|
||||
try testing.expectEqual(@as(usize, 3), try early_stream.reader().read(&buf));
|
||||
try testing.expectEqualSlices(u8, data[0..3], buf[0..3]);
|
||||
try testing.expectEqual(@as(usize, 0), try early_stream.reader().read(&buf));
|
||||
try testing.expectError(error.EndOfStream, early_stream.reader().skipBytes(10, .{}));
|
||||
}
|
||||
@ -1919,9 +1919,9 @@ pub const Stream = struct {
|
||||
limit: std.io.Reader.Limit,
|
||||
) std.io.Reader.Error!usize {
|
||||
const buf = limit.slice(try bw.writableSliceGreedy(1));
|
||||
const status = try readVec(context, &.{buf});
|
||||
bw.advance(status.len);
|
||||
return status;
|
||||
const n = try readVec(context, &.{buf});
|
||||
bw.advance(n);
|
||||
return n;
|
||||
}
|
||||
|
||||
fn readVec(context: ?*anyopaque, data: []const []u8) std.io.Reader.Error!usize {
|
||||
|
||||
@ -154,35 +154,30 @@ pub fn findEndRecord(seekable_stream: anytype, stream_len: u64) !EndRecord {
|
||||
pub fn decompress(
|
||||
method: CompressionMethod,
|
||||
uncompressed_size: u64,
|
||||
reader: anytype,
|
||||
writer: anytype,
|
||||
reader: *std.io.BufferedReader,
|
||||
writer: *std.io.BufferedWriter,
|
||||
compressed_remaining: *u64,
|
||||
) !u32 {
|
||||
var hash = std.hash.Crc32.init();
|
||||
|
||||
var total_uncompressed: u64 = 0;
|
||||
switch (method) {
|
||||
.store => {
|
||||
var buf: [4096]u8 = undefined;
|
||||
while (true) {
|
||||
const len = try reader.read(&buf);
|
||||
if (len == 0) break;
|
||||
try writer.writeAll(buf[0..len]);
|
||||
hash.update(buf[0..len]);
|
||||
total_uncompressed += @intCast(len);
|
||||
}
|
||||
reader.writeAll(writer, .limited(compressed_remaining.*)) catch |err| switch (err) {
|
||||
error.EndOfStream => return error.ZipDecompressTruncated,
|
||||
else => |e| return e,
|
||||
};
|
||||
total_uncompressed += compressed_remaining.*;
|
||||
},
|
||||
.deflate => {
|
||||
var br = std.io.bufferedReader(reader);
|
||||
var decompressor = std.compress.flate.decompressor(br.reader());
|
||||
var decompressor: std.compress.flate.Decompressor = .init(reader);
|
||||
while (try decompressor.next()) |chunk| {
|
||||
try writer.writeAll(chunk);
|
||||
hash.update(chunk);
|
||||
total_uncompressed += @intCast(chunk.len);
|
||||
if (total_uncompressed > uncompressed_size)
|
||||
return error.ZipUncompressSizeTooSmall;
|
||||
compressed_remaining.* -= chunk.len;
|
||||
}
|
||||
if (br.end != br.start)
|
||||
return error.ZipDeflateTruncated;
|
||||
},
|
||||
_ => return error.UnsupportedCompressionMethod,
|
||||
}
|
||||
@ -552,15 +547,15 @@ pub fn Iterator(comptime SeekableStream: type) type {
|
||||
@as(u64, @sizeOf(LocalFileHeader)) +
|
||||
local_data_header_offset;
|
||||
try stream.seekTo(local_data_file_offset);
|
||||
var limited_reader = std.io.limitedReader(stream.context.reader(), self.compressed_size);
|
||||
var compressed_remaining: u64 = self.compressed_size;
|
||||
const crc = try decompress(
|
||||
self.compression_method,
|
||||
self.uncompressed_size,
|
||||
limited_reader.reader(),
|
||||
stream.context.reader(),
|
||||
out_file.writer(),
|
||||
&compressed_remaining,
|
||||
);
|
||||
if (limited_reader.bytes_left != 0)
|
||||
return error.ZipDecompressTruncated;
|
||||
if (compressed_remaining != 0) return error.ZipDecompressTruncated;
|
||||
return crc;
|
||||
}
|
||||
};
|
||||
|
||||
@ -1199,7 +1199,7 @@ fn unpackResource(
|
||||
return try unpackTarball(f, tmp_directory.handle, dcp.reader());
|
||||
},
|
||||
.@"tar.zst" => {
|
||||
const window_size = std.compress.zstd.DecompressorOptions.default_window_buffer_len;
|
||||
const window_size = std.compress.zstd.default_window_len;
|
||||
const window_buffer = try f.arena.allocator().create([window_size]u8);
|
||||
const reader = resource.reader();
|
||||
var br = std.io.bufferedReaderSize(std.crypto.tls.max_ciphertext_record_len, reader);
|
||||
|
||||
@ -1490,7 +1490,7 @@ fn readObjectRaw(allocator: Allocator, reader: anytype, size: u64) ![]u8 {
|
||||
///
|
||||
/// The format of the delta data is documented in
|
||||
/// [pack-format](https://git-scm.com/docs/pack-format).
|
||||
fn expandDelta(base_object: anytype, delta_reader: anytype, writer: anytype) !void {
|
||||
fn expandDelta(base_object: anytype, delta_reader: *std.io.BufferedReader, writer: *std.io.BufferedWriter) !void {
|
||||
while (true) {
|
||||
const inst: packed struct { value: u7, copy: bool } = @bitCast(delta_reader.readByte() catch |e| switch (e) {
|
||||
error.EndOfStream => return,
|
||||
@ -1521,13 +1521,11 @@ fn expandDelta(base_object: anytype, delta_reader: anytype, writer: anytype) !vo
|
||||
var size: u24 = @bitCast(size_parts);
|
||||
if (size == 0) size = 0x10000;
|
||||
try base_object.seekTo(offset);
|
||||
var copy_reader = std.io.limitedReader(base_object.reader(), size);
|
||||
var fifo = std.fifo.LinearFifo(u8, .{ .Static = 4096 }).init();
|
||||
try fifo.pump(copy_reader.reader(), writer);
|
||||
|
||||
var base_object_br = base_object.reader();
|
||||
try base_object_br.readAll(writer, .limited(size));
|
||||
} else if (inst.value != 0) {
|
||||
var data_reader = std.io.limitedReader(delta_reader, inst.value);
|
||||
var fifo = std.fifo.LinearFifo(u8, .{ .Static = 4096 }).init();
|
||||
try fifo.pump(data_reader.reader(), writer);
|
||||
try delta_reader.readAll(writer, .limited(inst.value));
|
||||
} else {
|
||||
return error.InvalidDeltaInstruction;
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user