From 598db831f3ea1267d469162db1a54c2d62ff3e87 Mon Sep 17 00:00:00 2001 From: Sreehari Sreedev Date: Sun, 18 Jul 2021 02:13:04 -0700 Subject: [PATCH] FileProtocol: add Reader, Writer, SeekableStream --- lib/std/os/uefi/protocols/file_protocol.zig | 73 ++++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/lib/std/os/uefi/protocols/file_protocol.zig b/lib/std/os/uefi/protocols/file_protocol.zig index 1f80df0af2..ba4fad9c75 100644 --- a/lib/std/os/uefi/protocols/file_protocol.zig +++ b/lib/std/os/uefi/protocols/file_protocol.zig @@ -1,4 +1,6 @@ -const uefi = @import("std").os.uefi; +const std = @import("std"); +const uefi = std.os.uefi; +const io = std.io; const Guid = uefi.Guid; const Time = uefi.Time; const Status = uefi.Status; @@ -16,6 +18,27 @@ pub const FileProtocol = extern struct { _set_info: fn (*const FileProtocol, *align(8) const Guid, usize, [*]const u8) callconv(.C) Status, _flush: fn (*const FileProtocol) callconv(.C) Status, + pub const SeekError = error{SeekError}; + pub const GetSeekPosError = error{GetSeekPosError}; + pub const ReadError = error{ReadError}; + pub const WriteError = error{WriteError}; + + pub const SeekableStream = io.SeekableStream(*const FileProtocol, SeekError, GetSeekPosError, seekTo, seekBy, getPos, getEndPos); + pub const Reader = io.Reader(*const FileProtocol, ReadError, readFn); + pub const Writer = io.Writer(*const FileProtocol, WriteError, writeFn); + + pub fn seekableStream(self: *FileProtocol) SeekableStream { + return .{ .context = self }; + } + + pub fn reader(self: *FileProtocol) Reader { + return .{ .context = self }; + } + + pub fn writer(self: *FileProtocol) Writer { + return .{ .context = self }; + } + pub fn open(self: *const FileProtocol, new_handle: **const FileProtocol, file_name: [*:0]const u16, open_mode: u64, attributes: u64) Status { return self._open(self, new_handle, file_name, open_mode, attributes); } @@ -32,18 +55,66 @@ pub const FileProtocol = extern struct { return self._read(self, buffer_size, buffer); } + fn readFn(self: *const FileProtocol, buffer: []u8) ReadError!usize { + var size: usize = buffer.len; + if (.Success != self.read(&size, buffer.ptr)) return ReadError.ReadError; + return size; + } + pub fn write(self: *const FileProtocol, buffer_size: *usize, buffer: [*]const u8) Status { return self._write(self, buffer_size, buffer); } + fn writeFn(self: *const FileProtocol, bytes: []const u8) WriteError!usize { + var size: usize = bytes.len; + if (.Success != self.write(&size, bytes.ptr)) return WriteError.WriteError; + return size; + } + pub fn getPosition(self: *const FileProtocol, position: *u64) Status { return self._get_position(self, position); } + fn getPos(self: *const FileProtocol) GetSeekPosError!u64 { + var pos: u64 = undefined; + if (.Success != self.getPosition(&pos)) return GetSeekPosError.GetSeekPosError; + return pos; + } + + fn getEndPos(self: *const FileProtocol) GetSeekPosError!u64 { + // preserve the old file position + var pos: u64 = undefined; + if (.Success != self.getPosition(&pos)) return GetSeekPosError.GetSeekPosError; + // seek to end of file to get position = file size + if (.Success != self.setPosition(efi_file_position_end_of_file)) return GetSeekPosError.GetSeekPosError; + // restore the old position + if (.Success != self.setPosition(pos)) return GetSeekPosError.GetSeekPosError; + // return the file size = position + return pos; + } + pub fn setPosition(self: *const FileProtocol, position: u64) Status { return self._set_position(self, position); } + fn seekTo(self: *const FileProtocol, pos: u64) SeekError!void { + if (.Success != self.setPosition(pos)) return SeekError.SeekError; + } + + fn seekBy(self: *const FileProtocol, offset: i64) SeekError!void { + // save the old position and calculate the delta + var pos: u64 = undefined; + if (.Success != self.getPosition(&pos)) return SeekError.SeekError; + const seek_back = offset < 0; + const amt = std.math.absCast(offset); + if (seek_back) { + pos += amt; + } else { + pos -= amt; + } + if (.Success != self.setPosition(pos)) return SeekError.SeekError; + } + pub fn getInfo(self: *const FileProtocol, information_type: *align(8) const Guid, buffer_size: *usize, buffer: [*]u8) Status { return self._get_info(self, information_type, buffer_size, buffer); }