Merge pull request #24740 from ziglang/http-plus-fixes

fetch, tls, and http fixes
This commit is contained in:
Andrew Kelley 2025-08-08 12:33:53 -07:00 committed by GitHub
commit 1ba6838bc3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 172 additions and 155 deletions

View File

@ -25,9 +25,7 @@ pub const VTable = struct {
///
/// Returns the number of bytes written, which will be at minimum `0` and
/// at most `limit`. The number returned, including zero, does not indicate
/// end of stream. `limit` is guaranteed to be at least as large as the
/// buffer capacity of `w`, a value whose minimum size is determined by the
/// stream implementation.
/// end of stream.
///
/// The reader's internal logical seek position moves forward in accordance
/// with the number of bytes returned from this function.

View File

@ -61,9 +61,6 @@ pub const ReadError = error{
TlsUnexpectedMessage,
TlsIllegalParameter,
TlsSequenceOverflow,
/// The buffer provided to the read function was not at least
/// `min_buffer_len`.
OutputBufferUndersize,
};
pub const SslKeyLog = struct {
@ -372,7 +369,8 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
};
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, pv.server_handshake_key) catch
return error.TlsBadRecordMac;
cleartext_fragment_end += std.mem.trimEnd(u8, cleartext, "\x00").len;
// TODO use scalar, non-slice version
cleartext_fragment_end += mem.trimEnd(u8, cleartext, "\x00").len;
},
}
read_seq += 1;
@ -395,9 +393,9 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
const cleartext_fragment_buf = cleartext_buf[cleartext_fragment_end..];
if (message_len > cleartext_fragment_buf.len) return error.TlsRecordOverflow;
const cleartext = cleartext_fragment_buf[0..message_len];
const ad = std.mem.toBytes(big(read_seq)) ++
const ad = mem.toBytes(big(read_seq)) ++
record_header[0 .. 1 + 2] ++
std.mem.toBytes(big(message_len));
mem.toBytes(big(message_len));
const record_iv = record_decoder.array(P.record_iv_length).*;
const masked_read_seq = read_seq &
comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length);
@ -738,7 +736,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
&.{ "server finished", &p.transcript_hash.finalResult() },
P.verify_data_length,
),
.app_cipher = std.mem.bytesToValue(P.Tls_1_2, &key_block),
.app_cipher = mem.bytesToValue(P.Tls_1_2, &key_block),
} };
const pv = &p.version.tls_1_2;
const nonce: [P.AEAD.nonce_length]u8 = nonce: {
@ -756,7 +754,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
client_verify_cleartext.len ..][0..client_verify_cleartext.len],
client_verify_msg[client_verify_msg.len - P.mac_length ..][0..P.mac_length],
&client_verify_cleartext,
std.mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int(u16, client_verify_cleartext.len),
mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int(u16, client_verify_cleartext.len),
nonce,
pv.app_cipher.client_write_key,
);
@ -873,7 +871,10 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
.input = input,
.reader = .{
.buffer = options.read_buffer,
.vtable = &.{ .stream = stream },
.vtable = &.{
.stream = stream,
.readVec = readVec,
},
.seek = 0,
.end = 0,
},
@ -1017,7 +1018,7 @@ fn prepareCiphertextRecord(
const nonce = nonce: {
const V = @Vector(P.AEAD.nonce_length, u8);
const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
const operand: V = pad ++ std.mem.toBytes(big(c.write_seq));
const operand: V = pad ++ mem.toBytes(big(c.write_seq));
break :nonce @as(V, pv.client_iv) ^ operand;
};
P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_key);
@ -1048,7 +1049,7 @@ fn prepareCiphertextRecord(
record_header.* = .{@intFromEnum(inner_content_type)} ++
int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
int(u16, P.record_iv_length + message_len + P.mac_length);
const ad = std.mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int(u16, message_len);
const ad = mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int(u16, message_len);
const record_iv = ciphertext_buf[ciphertext_end..][0..P.record_iv_length];
ciphertext_end += P.record_iv_length;
const nonce: [P.AEAD.nonce_length]u8 = nonce: {
@ -1076,7 +1077,22 @@ pub fn eof(c: Client) bool {
}
fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
// This function writes exclusively to the buffer.
_ = w;
_ = limit;
const c: *Client = @alignCast(@fieldParentPtr("reader", r));
return readIndirect(c);
}
fn readVec(r: *Reader, data: [][]u8) Reader.Error!usize {
// This function writes exclusively to the buffer.
_ = data;
const c: *Client = @alignCast(@fieldParentPtr("reader", r));
return readIndirect(c);
}
fn readIndirect(c: *Client) Reader.Error!usize {
const r = &c.reader;
if (c.eof()) return error.EndOfStream;
const input = c.input;
// If at least one full encrypted record is not buffered, read once.
@ -1108,8 +1124,13 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
if (record_end > input.buffered().len) return 0;
}
var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined;
const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) {
if (r.seek == r.end) {
r.seek = 0;
r.end = 0;
}
const cleartext_buffer = r.buffer[r.end..];
const cleartext_len, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) {
inline else => |*p| switch (c.tls_version) {
.tls_1_3 => {
const pv = &p.tls_1_3;
@ -1121,23 +1142,24 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
const nonce = nonce: {
const V = @Vector(P.AEAD.nonce_length, u8);
const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
const operand: V = pad ++ std.mem.toBytes(big(c.read_seq));
const operand: V = pad ++ mem.toBytes(big(c.read_seq));
break :nonce @as(V, pv.server_iv) ^ operand;
};
const cleartext = cleartext_stack_buffer[0..ciphertext.len];
const cleartext = cleartext_buffer[0..ciphertext.len];
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch
return failRead(c, error.TlsBadRecordMac);
// TODO use scalar, non-slice version
const msg = mem.trimRight(u8, cleartext, "\x00");
break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) };
break :cleartext .{ msg.len - 1, @enumFromInt(msg[msg.len - 1]) };
},
.tls_1_2 => {
const pv = &p.tls_1_2;
const P = @TypeOf(p.*);
const message_len: u16 = record_len - P.record_iv_length - P.mac_length;
const ad_header = input.take(tls.record_header_len) catch unreachable; // already peeked
const ad = std.mem.toBytes(big(c.read_seq)) ++
const ad = mem.toBytes(big(c.read_seq)) ++
ad_header[0 .. 1 + 2] ++
std.mem.toBytes(big(message_len));
mem.toBytes(big(message_len));
const record_iv = (input.takeArray(P.record_iv_length) catch unreachable).*; // already peeked
const masked_read_seq = c.read_seq &
comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length);
@ -1149,14 +1171,15 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
};
const ciphertext = input.take(message_len) catch unreachable; // already peeked
const auth_tag = (input.takeArray(P.mac_length) catch unreachable).*; // already peeked
const cleartext = cleartext_stack_buffer[0..ciphertext.len];
const cleartext = cleartext_buffer[0..ciphertext.len];
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch
return failRead(c, error.TlsBadRecordMac);
break :cleartext .{ cleartext, ct };
break :cleartext .{ cleartext.len, ct };
},
else => unreachable,
},
};
const cleartext = cleartext_buffer[0..cleartext_len];
c.read_seq = std.math.add(u64, c.read_seq, 1) catch return failRead(c, error.TlsSequenceOverflow);
switch (inner_ct) {
.alert => {
@ -1245,9 +1268,8 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
return 0;
},
.application_data => {
if (@intFromEnum(limit) < cleartext.len) return failRead(c, error.OutputBufferUndersize);
try w.writeAll(cleartext);
return cleartext.len;
r.end += cleartext.len;
return 0;
},
else => return failRead(c, error.TlsUnexpectedMessage),
}

View File

@ -292,6 +292,14 @@ pub const ContentEncoding = enum {
});
return map.get(s);
}
pub fn minBufferCapacity(ce: ContentEncoding) usize {
return switch (ce) {
.zstd => std.compress.zstd.default_window_len,
.gzip, .deflate => std.compress.flate.max_window_len,
.compress, .identity => 0,
};
}
};
pub const Connection = enum {
@ -412,7 +420,7 @@ pub const Reader = struct {
/// * `interfaceDecompressing`
pub fn bodyReader(
reader: *Reader,
buffer: []u8,
transfer_buffer: []u8,
transfer_encoding: TransferEncoding,
content_length: ?u64,
) *std.Io.Reader {
@ -421,7 +429,7 @@ pub const Reader = struct {
.chunked => {
reader.state = .{ .body_remaining_chunk_len = .head };
reader.interface = .{
.buffer = buffer,
.buffer = transfer_buffer,
.seek = 0,
.end = 0,
.vtable = &.{
@ -435,7 +443,7 @@ pub const Reader = struct {
if (content_length) |len| {
reader.state = .{ .body_remaining_content_length = len };
reader.interface = .{
.buffer = buffer,
.buffer = transfer_buffer,
.seek = 0,
.end = 0,
.vtable = &.{
@ -460,11 +468,12 @@ pub const Reader = struct {
/// * `interface`
pub fn bodyReaderDecompressing(
reader: *Reader,
transfer_buffer: []u8,
transfer_encoding: TransferEncoding,
content_length: ?u64,
content_encoding: ContentEncoding,
decompressor: *Decompressor,
decompression_buffer: []u8,
decompress: *Decompress,
decompress_buffer: []u8,
) *std.Io.Reader {
if (transfer_encoding == .none and content_length == null) {
assert(reader.state == .received_head);
@ -474,22 +483,22 @@ pub const Reader = struct {
return reader.in;
},
.deflate => {
decompressor.* = .{ .flate = .init(reader.in, .zlib, decompression_buffer) };
return &decompressor.flate.reader;
decompress.* = .{ .flate = .init(reader.in, .zlib, decompress_buffer) };
return &decompress.flate.reader;
},
.gzip => {
decompressor.* = .{ .flate = .init(reader.in, .gzip, decompression_buffer) };
return &decompressor.flate.reader;
decompress.* = .{ .flate = .init(reader.in, .gzip, decompress_buffer) };
return &decompress.flate.reader;
},
.zstd => {
decompressor.* = .{ .zstd = .init(reader.in, decompression_buffer, .{ .verify_checksum = false }) };
return &decompressor.zstd.reader;
decompress.* = .{ .zstd = .init(reader.in, decompress_buffer, .{ .verify_checksum = false }) };
return &decompress.zstd.reader;
},
.compress => unreachable,
}
}
const transfer_reader = bodyReader(reader, &.{}, transfer_encoding, content_length);
return decompressor.init(transfer_reader, decompression_buffer, content_encoding);
const transfer_reader = bodyReader(reader, transfer_buffer, transfer_encoding, content_length);
return decompress.init(transfer_reader, decompress_buffer, content_encoding);
}
fn contentLengthStream(
@ -691,33 +700,33 @@ pub const Reader = struct {
}
};
pub const Decompressor = union(enum) {
pub const Decompress = union(enum) {
flate: std.compress.flate.Decompress,
zstd: std.compress.zstd.Decompress,
none: *std.Io.Reader,
pub fn init(
decompressor: *Decompressor,
decompress: *Decompress,
transfer_reader: *std.Io.Reader,
buffer: []u8,
content_encoding: ContentEncoding,
) *std.Io.Reader {
switch (content_encoding) {
.identity => {
decompressor.* = .{ .none = transfer_reader };
decompress.* = .{ .none = transfer_reader };
return transfer_reader;
},
.deflate => {
decompressor.* = .{ .flate = .init(transfer_reader, .zlib, buffer) };
return &decompressor.flate.reader;
decompress.* = .{ .flate = .init(transfer_reader, .zlib, buffer) };
return &decompress.flate.reader;
},
.gzip => {
decompressor.* = .{ .flate = .init(transfer_reader, .gzip, buffer) };
return &decompressor.flate.reader;
decompress.* = .{ .flate = .init(transfer_reader, .gzip, buffer) };
return &decompress.flate.reader;
},
.zstd => {
decompressor.* = .{ .zstd = .init(transfer_reader, buffer, .{ .verify_checksum = false }) };
return &decompressor.zstd.reader;
decompress.* = .{ .zstd = .init(transfer_reader, buffer, .{ .verify_checksum = false }) };
return &decompress.zstd.reader;
},
.compress => unreachable,
}
@ -794,7 +803,7 @@ pub const BodyWriter = struct {
}
/// When using content-length, asserts that the amount of data sent matches
/// the value sent in the header, then flushes.
/// the value sent in the header, then flushes `http_protocol_output`.
///
/// When using transfer-encoding: chunked, writes the end-of-stream message
/// with empty trailers, then flushes the stream to the system. Asserts any
@ -818,10 +827,13 @@ pub const BodyWriter = struct {
///
/// Respects the value of `isEliding` to omit all data after the headers.
///
/// Does not flush `http_protocol_output`, but does flush `writer`.
///
/// See also:
/// * `end`
/// * `endChunked`
pub fn endUnflushed(w: *BodyWriter) Error!void {
try w.writer.flush();
switch (w.state) {
.end => unreachable,
.content_length => |len| {

View File

@ -13,8 +13,8 @@ const net = std.net;
const Uri = std.Uri;
const Allocator = mem.Allocator;
const assert = std.debug.assert;
const Writer = std.io.Writer;
const Reader = std.io.Reader;
const Writer = std.Io.Writer;
const Reader = std.Io.Reader;
const Client = @This();
@ -704,12 +704,12 @@ pub const Response = struct {
///
/// See also:
/// * `readerDecompressing`
pub fn reader(response: *Response, buffer: []u8) *Reader {
pub fn reader(response: *Response, transfer_buffer: []u8) *Reader {
response.head.invalidateStrings();
const req = response.request;
if (!req.method.responseHasBody()) return .ending;
const head = &response.head;
return req.reader.bodyReader(buffer, head.transfer_encoding, head.content_length);
return req.reader.bodyReader(transfer_buffer, head.transfer_encoding, head.content_length);
}
/// If compressed body has been negotiated this will return decompressed bytes.
@ -723,17 +723,19 @@ pub const Response = struct {
/// * `reader`
pub fn readerDecompressing(
response: *Response,
decompressor: *http.Decompressor,
decompression_buffer: []u8,
transfer_buffer: []u8,
decompress: *http.Decompress,
decompress_buffer: []u8,
) *Reader {
response.head.invalidateStrings();
const head = &response.head;
return response.request.reader.bodyReaderDecompressing(
transfer_buffer,
head.transfer_encoding,
head.content_length,
head.content_encoding,
decompressor,
decompression_buffer,
decompress,
decompress_buffer,
);
}
@ -1322,7 +1324,7 @@ pub const basic_authorization = struct {
const user: Uri.Component = uri.user orelse .empty;
const password: Uri.Component = uri.password orelse .empty;
var dw: std.io.Writer.Discarding = .init(&.{});
var dw: Writer.Discarding = .init(&.{});
user.formatUser(&dw.writer) catch unreachable; // discarding
const user_len = dw.count + dw.writer.end;
@ -1696,8 +1698,8 @@ pub const FetchOptions = struct {
/// `null` means it will be heap-allocated.
decompress_buffer: ?[]u8 = null,
redirect_behavior: ?Request.RedirectBehavior = null,
/// If the server sends a body, it will be stored here.
response_storage: ?ResponseStorage = null,
/// If the server sends a body, it will be written here.
response_writer: ?*Writer = null,
location: Location,
method: ?http.Method = null,
@ -1725,7 +1727,7 @@ pub const FetchOptions = struct {
list: *std.ArrayListUnmanaged(u8),
/// If null then only the existing capacity will be used.
allocator: ?Allocator = null,
append_limit: std.io.Limit = .unlimited,
append_limit: std.Io.Limit = .unlimited,
};
};
@ -1778,7 +1780,7 @@ pub fn fetch(client: *Client, options: FetchOptions) FetchError!FetchResult {
var response = try req.receiveHead(redirect_buffer);
const storage = options.response_storage orelse {
const response_writer = options.response_writer orelse {
const reader = response.reader(&.{});
_ = reader.discardRemaining() catch |err| switch (err) {
error.ReadFailed => return response.bodyErr().?,
@ -1794,21 +1796,14 @@ pub fn fetch(client: *Client, options: FetchOptions) FetchError!FetchResult {
};
defer if (options.decompress_buffer == null) client.allocator.free(decompress_buffer);
var decompressor: http.Decompressor = undefined;
const reader = response.readerDecompressing(&decompressor, decompress_buffer);
const list = storage.list;
var transfer_buffer: [64]u8 = undefined;
var decompress: http.Decompress = undefined;
const reader = response.readerDecompressing(&transfer_buffer, &decompress, decompress_buffer);
if (storage.allocator) |allocator| {
reader.appendRemaining(allocator, null, list, storage.append_limit) catch |err| switch (err) {
error.ReadFailed => return response.bodyErr().?,
else => |e| return e,
};
} else {
const buf = storage.append_limit.slice(list.unusedCapacitySlice());
list.items.len += reader.readSliceShort(buf) catch |err| switch (err) {
error.ReadFailed => return response.bodyErr().?,
};
}
_ = reader.streamRemaining(response_writer) catch |err| switch (err) {
error.ReadFailed => return response.bodyErr().?,
else => |e| return e,
};
return .{ .status = response.head.status };
}

View File

@ -1006,8 +1006,9 @@ fn echoTests(client: *http.Client, port: u16) !void {
const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/echo-content#fetch", .{port});
defer gpa.free(location);
var body: std.ArrayListUnmanaged(u8) = .empty;
defer body.deinit(gpa);
var body: std.Io.Writer.Allocating = .init(gpa);
defer body.deinit();
try body.ensureUnusedCapacity(64);
const res = try client.fetch(.{
.location = .{ .url = location },
@ -1016,10 +1017,10 @@ fn echoTests(client: *http.Client, port: u16) !void {
.extra_headers = &.{
.{ .name = "content-type", .value = "text/plain" },
},
.response_storage = .{ .allocator = gpa, .list = &body },
.response_writer = &body.writer,
});
try expectEqual(.ok, res.status);
try expectEqualStrings("Hello, World!\n", body.items);
try expectEqualStrings("Hello, World!\n", body.getWritten());
}
{ // expect: 100-continue

View File

@ -883,7 +883,9 @@ const Resource = union(enum) {
const HttpRequest = struct {
request: std.http.Client.Request,
response: std.http.Client.Response,
buffer: []u8,
transfer_buffer: []u8,
decompress: std.http.Decompress,
decompress_buffer: []u8,
};
fn deinit(resource: *Resource) void {
@ -892,7 +894,6 @@ const Resource = union(enum) {
.http_request => |*http_request| http_request.request.deinit(),
.git => |*git_resource| {
git_resource.fetch_stream.deinit();
git_resource.session.deinit();
},
.dir => |*dir| dir.close(),
}
@ -902,7 +903,11 @@ const Resource = union(enum) {
fn reader(resource: *Resource) *std.Io.Reader {
return switch (resource.*) {
.file => |*file_reader| return &file_reader.interface,
.http_request => |*http_request| return http_request.response.reader(http_request.buffer),
.http_request => |*http_request| return http_request.response.readerDecompressing(
http_request.transfer_buffer,
&http_request.decompress,
http_request.decompress_buffer,
),
.git => |*g| return &g.fetch_stream.reader,
.dir => unreachable,
};
@ -971,7 +976,6 @@ const FileType = enum {
const init_resource_buffer_size = git.Packet.max_data_length;
fn initResource(f: *Fetch, uri: std.Uri, resource: *Resource, reader_buffer: []u8) RunError!void {
const gpa = f.arena.child_allocator;
const arena = f.arena.allocator();
const eb = &f.error_bundle;
@ -995,7 +999,9 @@ fn initResource(f: *Fetch, uri: std.Uri, resource: *Resource, reader_buffer: []u
.request = http_client.request(.GET, uri, .{}) catch |err|
return f.fail(f.location_tok, try eb.printString("unable to connect to server: {t}", .{err})),
.response = undefined,
.buffer = reader_buffer,
.transfer_buffer = reader_buffer,
.decompress_buffer = &.{},
.decompress = undefined,
} };
const request = &resource.http_request.request;
errdefer request.deinit();
@ -1019,6 +1025,7 @@ fn initResource(f: *Fetch, uri: std.Uri, resource: *Resource, reader_buffer: []u
.{ response.head.status, response.head.status.phrase() orelse "" },
));
resource.http_request.decompress_buffer = try arena.alloc(u8, response.head.content_encoding.minBufferCapacity());
return;
}
@ -1027,13 +1034,12 @@ fn initResource(f: *Fetch, uri: std.Uri, resource: *Resource, reader_buffer: []u
{
var transport_uri = uri;
transport_uri.scheme = uri.scheme["git+".len..];
var session = git.Session.init(gpa, http_client, transport_uri, reader_buffer) catch |err| {
return f.fail(f.location_tok, try eb.printString(
"unable to discover remote git server capabilities: {s}",
.{@errorName(err)},
));
var session = git.Session.init(arena, http_client, transport_uri, reader_buffer) catch |err| {
return f.fail(
f.location_tok,
try eb.printString("unable to discover remote git server capabilities: {t}", .{err}),
);
};
errdefer session.deinit();
const want_oid = want_oid: {
const want_ref =
@ -1086,17 +1092,17 @@ fn initResource(f: *Fetch, uri: std.Uri, resource: *Resource, reader_buffer: []u
var want_oid_buf: [git.Oid.max_formatted_length]u8 = undefined;
_ = std.fmt.bufPrint(&want_oid_buf, "{f}", .{want_oid}) catch unreachable;
var fetch_stream: git.Session.FetchStream = undefined;
session.fetch(&fetch_stream, &.{&want_oid_buf}, reader_buffer) catch |err| {
return f.fail(f.location_tok, try eb.printString("unable to create fetch stream: {t}", .{err}));
};
errdefer fetch_stream.deinit();
resource.* = .{ .git = .{
.session = session,
.fetch_stream = fetch_stream,
.fetch_stream = undefined,
.want_oid = want_oid,
} };
const fetch_stream = &resource.git.fetch_stream;
session.fetch(fetch_stream, &.{&want_oid_buf}, reader_buffer) catch |err| {
return f.fail(f.location_tok, try eb.printString("unable to create fetch stream: {t}", .{err}));
};
errdefer fetch_stream.deinit(fetch_stream);
return;
}

View File

@ -644,7 +644,7 @@ pub const Session = struct {
supports_agent: bool,
supports_shallow: bool,
object_format: Oid.Format,
allocator: Allocator,
arena: Allocator,
const agent = "zig/" ++ @import("builtin").zig_version_string;
const agent_capability = std.fmt.comptimePrint("agent={s}\n", .{agent});
@ -652,7 +652,7 @@ pub const Session = struct {
/// Initializes a client session and discovers the capabilities of the
/// server for optimal transport.
pub fn init(
allocator: Allocator,
arena: Allocator,
transport: *std.http.Client,
uri: std.Uri,
/// Asserted to be at least `Packet.max_data_length`
@ -661,13 +661,12 @@ pub const Session = struct {
assert(response_buffer.len >= Packet.max_data_length);
var session: Session = .{
.transport = transport,
.location = try .init(allocator, uri),
.location = try .init(arena, uri),
.supports_agent = false,
.supports_shallow = false,
.object_format = .sha1,
.allocator = allocator,
.arena = arena,
};
errdefer session.deinit();
var capability_iterator: CapabilityIterator = undefined;
try session.getCapabilities(&capability_iterator, response_buffer);
defer capability_iterator.deinit();
@ -690,34 +689,24 @@ pub const Session = struct {
return session;
}
pub fn deinit(session: *Session) void {
session.location.deinit(session.allocator);
session.* = undefined;
}
/// An owned `std.Uri` representing the location of the server (base URI).
const Location = struct {
uri: std.Uri,
fn init(allocator: Allocator, uri: std.Uri) !Location {
const scheme = try allocator.dupe(u8, uri.scheme);
errdefer allocator.free(scheme);
const user = if (uri.user) |user| try std.fmt.allocPrint(allocator, "{f}", .{
fn init(arena: Allocator, uri: std.Uri) !Location {
const scheme = try arena.dupe(u8, uri.scheme);
const user = if (uri.user) |user| try std.fmt.allocPrint(arena, "{f}", .{
std.fmt.alt(user, .formatUser),
}) else null;
errdefer if (user) |s| allocator.free(s);
const password = if (uri.password) |password| try std.fmt.allocPrint(allocator, "{f}", .{
const password = if (uri.password) |password| try std.fmt.allocPrint(arena, "{f}", .{
std.fmt.alt(password, .formatPassword),
}) else null;
errdefer if (password) |s| allocator.free(s);
const host = if (uri.host) |host| try std.fmt.allocPrint(allocator, "{f}", .{
const host = if (uri.host) |host| try std.fmt.allocPrint(arena, "{f}", .{
std.fmt.alt(host, .formatHost),
}) else null;
errdefer if (host) |s| allocator.free(s);
const path = try std.fmt.allocPrint(allocator, "{f}", .{
const path = try std.fmt.allocPrint(arena, "{f}", .{
std.fmt.alt(uri.path, .formatPath),
});
errdefer allocator.free(path);
// The query and fragment are not used as part of the base server URI.
return .{
.uri = .{
@ -730,14 +719,6 @@ pub const Session = struct {
},
};
}
fn deinit(loc: *Location, allocator: Allocator) void {
allocator.free(loc.uri.scheme);
if (loc.uri.user) |user| allocator.free(user.percent_encoded);
if (loc.uri.password) |password| allocator.free(password.percent_encoded);
if (loc.uri.host) |host| allocator.free(host.percent_encoded);
allocator.free(loc.uri.path.percent_encoded);
}
};
/// Returns an iterator over capabilities supported by the server.
@ -745,16 +726,17 @@ pub const Session = struct {
/// The `session.location` is updated if the server returns a redirect, so
/// that subsequent session functions do not need to handle redirects.
fn getCapabilities(session: *Session, it: *CapabilityIterator, response_buffer: []u8) !void {
const arena = session.arena;
assert(response_buffer.len >= Packet.max_data_length);
var info_refs_uri = session.location.uri;
{
const session_uri_path = try std.fmt.allocPrint(session.allocator, "{f}", .{
const session_uri_path = try std.fmt.allocPrint(arena, "{f}", .{
std.fmt.alt(session.location.uri.path, .formatPath),
});
defer session.allocator.free(session_uri_path);
info_refs_uri.path = .{ .percent_encoded = try std.fs.path.resolvePosix(session.allocator, &.{ "/", session_uri_path, "info/refs" }) };
info_refs_uri.path = .{ .percent_encoded = try std.fs.path.resolvePosix(arena, &.{
"/", session_uri_path, "info/refs",
}) };
}
defer session.allocator.free(info_refs_uri.path.percent_encoded);
info_refs_uri.query = .{ .percent_encoded = "service=git-upload-pack" };
info_refs_uri.fragment = null;
@ -767,6 +749,7 @@ pub const Session = struct {
},
}),
.reader = undefined,
.decompress = undefined,
};
errdefer it.deinit();
const request = &it.request;
@ -777,19 +760,17 @@ pub const Session = struct {
if (response.head.status != .ok) return error.ProtocolError;
const any_redirects_occurred = request.redirect_behavior.remaining() < max_redirects;
if (any_redirects_occurred) {
const request_uri_path = try std.fmt.allocPrint(session.allocator, "{f}", .{
const request_uri_path = try std.fmt.allocPrint(arena, "{f}", .{
std.fmt.alt(request.uri.path, .formatPath),
});
defer session.allocator.free(request_uri_path);
if (!mem.endsWith(u8, request_uri_path, "/info/refs")) return error.UnparseableRedirect;
var new_uri = request.uri;
new_uri.path = .{ .percent_encoded = request_uri_path[0 .. request_uri_path.len - "/info/refs".len] };
const new_location: Location = try .init(session.allocator, new_uri);
session.location.deinit(session.allocator);
session.location = new_location;
session.location = try .init(arena, new_uri);
}
it.reader = response.reader(response_buffer);
const decompress_buffer = try arena.alloc(u8, response.head.content_encoding.minBufferCapacity());
it.reader = response.readerDecompressing(response_buffer, &it.decompress, decompress_buffer);
var state: enum { response_start, response_content } = .response_start;
while (true) {
// Some Git servers (at least GitHub) include an additional
@ -821,6 +802,7 @@ pub const Session = struct {
const CapabilityIterator = struct {
request: std.http.Client.Request,
reader: *std.Io.Reader,
decompress: std.http.Decompress,
const Capability = struct {
key: []const u8,
@ -864,16 +846,15 @@ pub const Session = struct {
/// Returns an iterator over refs known to the server.
pub fn listRefs(session: Session, it: *RefIterator, options: ListRefsOptions) !void {
const arena = session.arena;
assert(options.buffer.len >= Packet.max_data_length);
var upload_pack_uri = session.location.uri;
{
const session_uri_path = try std.fmt.allocPrint(session.allocator, "{f}", .{
const session_uri_path = try std.fmt.allocPrint(arena, "{f}", .{
std.fmt.alt(session.location.uri.path, .formatPath),
});
defer session.allocator.free(session_uri_path);
upload_pack_uri.path = .{ .percent_encoded = try std.fs.path.resolvePosix(session.allocator, &.{ "/", session_uri_path, "git-upload-pack" }) };
upload_pack_uri.path = .{ .percent_encoded = try std.fs.path.resolvePosix(arena, &.{ "/", session_uri_path, "git-upload-pack" }) };
}
defer session.allocator.free(upload_pack_uri.path.percent_encoded);
upload_pack_uri.query = null;
upload_pack_uri.fragment = null;
@ -883,16 +864,14 @@ pub const Session = struct {
try Packet.write(.{ .data = agent_capability }, &body);
}
{
const object_format_packet = try std.fmt.allocPrint(session.allocator, "object-format={t}\n", .{
const object_format_packet = try std.fmt.allocPrint(arena, "object-format={t}\n", .{
session.object_format,
});
defer session.allocator.free(object_format_packet);
try Packet.write(.{ .data = object_format_packet }, &body);
}
try Packet.write(.delimiter, &body);
for (options.ref_prefixes) |ref_prefix| {
const ref_prefix_packet = try std.fmt.allocPrint(session.allocator, "ref-prefix {s}\n", .{ref_prefix});
defer session.allocator.free(ref_prefix_packet);
const ref_prefix_packet = try std.fmt.allocPrint(arena, "ref-prefix {s}\n", .{ref_prefix});
try Packet.write(.{ .data = ref_prefix_packet }, &body);
}
if (options.include_symrefs) {
@ -913,6 +892,7 @@ pub const Session = struct {
}),
.reader = undefined,
.format = session.object_format,
.decompress = undefined,
};
const request = &it.request;
errdefer request.deinit();
@ -920,13 +900,15 @@ pub const Session = struct {
var response = try request.receiveHead(options.buffer);
if (response.head.status != .ok) return error.ProtocolError;
it.reader = response.reader(options.buffer);
const decompress_buffer = try arena.alloc(u8, response.head.content_encoding.minBufferCapacity());
it.reader = response.readerDecompressing(options.buffer, &it.decompress, decompress_buffer);
}
pub const RefIterator = struct {
format: Oid.Format,
request: std.http.Client.Request,
reader: *std.Io.Reader,
decompress: std.http.Decompress,
pub const Ref = struct {
oid: Oid,
@ -981,16 +963,15 @@ pub const Session = struct {
/// Asserted to be at least `Packet.max_data_length`.
response_buffer: []u8,
) !void {
const arena = session.arena;
assert(response_buffer.len >= Packet.max_data_length);
var upload_pack_uri = session.location.uri;
{
const session_uri_path = try std.fmt.allocPrint(session.allocator, "{f}", .{
const session_uri_path = try std.fmt.allocPrint(arena, "{f}", .{
std.fmt.alt(session.location.uri.path, .formatPath),
});
defer session.allocator.free(session_uri_path);
upload_pack_uri.path = .{ .percent_encoded = try std.fs.path.resolvePosix(session.allocator, &.{ "/", session_uri_path, "git-upload-pack" }) };
upload_pack_uri.path = .{ .percent_encoded = try std.fs.path.resolvePosix(arena, &.{ "/", session_uri_path, "git-upload-pack" }) };
}
defer session.allocator.free(upload_pack_uri.path.percent_encoded);
upload_pack_uri.query = null;
upload_pack_uri.fragment = null;
@ -1000,8 +981,7 @@ pub const Session = struct {
try Packet.write(.{ .data = agent_capability }, &body);
}
{
const object_format_packet = try std.fmt.allocPrint(session.allocator, "object-format={s}\n", .{@tagName(session.object_format)});
defer session.allocator.free(object_format_packet);
const object_format_packet = try std.fmt.allocPrint(arena, "object-format={s}\n", .{@tagName(session.object_format)});
try Packet.write(.{ .data = object_format_packet }, &body);
}
try Packet.write(.delimiter, &body);
@ -1031,6 +1011,7 @@ pub const Session = struct {
.input = undefined,
.reader = undefined,
.remaining_len = undefined,
.decompress = undefined,
};
const request = &fs.request;
errdefer request.deinit();
@ -1040,7 +1021,8 @@ pub const Session = struct {
var response = try request.receiveHead(&.{});
if (response.head.status != .ok) return error.ProtocolError;
const reader = response.reader(response_buffer);
const decompress_buffer = try arena.alloc(u8, response.head.content_encoding.minBufferCapacity());
const reader = response.readerDecompressing(response_buffer, &fs.decompress, decompress_buffer);
// We are not interested in any of the sections of the returned fetch
// data other than the packfile section, since we aren't doing anything
// complex like ref negotiation (this is a fresh clone).
@ -1079,6 +1061,7 @@ pub const Session = struct {
reader: std.Io.Reader,
err: ?Error = null,
remaining_len: usize,
decompress: std.http.Decompress,
pub fn deinit(fs: *FetchStream) void {
fs.request.deinit();
@ -1131,8 +1114,8 @@ pub const Session = struct {
}
const buf = limit.slice(try w.writableSliceGreedy(1));
const n = @min(buf.len, fs.remaining_len);
@memcpy(buf[0..n], input.buffered()[0..n]);
input.toss(n);
try input.readSliceAll(buf[0..n]);
w.advance(n);
fs.remaining_len -= n;
return n;
}