From 88693a56fc6a7129d4967957f387b2e404e32f4d Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 2 Oct 2025 20:45:16 -0700 Subject: [PATCH] std.Io.net.HostName: implement DNS name expansion --- lib/std/Io/net/HostName.zig | 96 +++++++++++++++++-------------------- 1 file changed, 44 insertions(+), 52 deletions(-) diff --git a/lib/std/Io/net/HostName.zig b/lib/std/Io/net/HostName.zig index edc8d45965..c8924e1fa6 100644 --- a/lib/std/Io/net/HostName.zig +++ b/lib/std/Io/net/HostName.zig @@ -51,6 +51,7 @@ pub const LookupError = error{ ResolvConfParseFailed, InvalidDnsARecord, InvalidDnsAAAARecord, + InvalidDnsCnameRecord, NameServerFailure, } || Io.Timestamp.Error || IpAddress.BindError || Io.File.OpenError || Io.File.Reader.Error || Io.Cancelable; @@ -381,16 +382,8 @@ fn lookupDns(io: Io, lookup_canon_name: []const u8, rc: *const ResolvConf, optio addresses_len += 1; }, std.posix.RR.CNAME => { - _ = &canonical_name; - @panic("TODO"); - //var tmp: [256]u8 = undefined; - //// Returns len of compressed name. strlen to get canon name. - //_ = try posix.dn_expand(packet, record.data, &tmp); - //const canon_name = mem.sliceTo(&tmp, 0); - //if (isValidHostName(canon_name)) { - // ctx.canon.items.len = 0; - // try ctx.canon.appendSlice(gpa, canon_name); - //} + _, canonical_name = expand(record.packet, record.data_off, options.canonical_name_buffer) catch + return error.InvalidDnsCnameRecord; }, else => continue, }; @@ -525,51 +518,50 @@ fn writeResolutionQuery(q: *[280]u8, op: u4, dname: []const u8, class: u8, ty: u return n; } -pub const ExpandDomainNameError = error{InvalidDnsPacket}; +pub const ExpandError = error{InvalidDnsPacket} || InitError; -pub fn expandDomainName( - msg: []const u8, - comp_dn: []const u8, - exp_dn: []u8, -) ExpandDomainNameError!usize { - // This implementation is ported from musl libc. - // A more idiomatic "ziggy" implementation would be welcome. - var p = comp_dn.ptr; - var len: usize = std.math.maxInt(usize); - const end = msg.ptr + msg.len; - if (p == end or exp_dn.len == 0) return error.InvalidDnsPacket; - var dest = exp_dn.ptr; - const dend = dest + @min(exp_dn.len, 254); - // detect reference loop using an iteration counter - var i: usize = 0; - while (i < msg.len) : (i += 2) { - // loop invariants: p= msg.len) return error.InvalidDnsPacket; - p = msg.ptr + j; - } else if (p[0] != 0) { - if (dest != exp_dn.ptr) { - dest[0] = '.'; - dest += 1; - } - var j = p[0]; - p += 1; - if (j >= @intFromPtr(end) - @intFromPtr(p) or j >= @intFromPtr(dend) - @intFromPtr(dest)) { - return error.InvalidDnsPacket; - } - while (j != 0) { - j -= 1; - dest[0] = p[0]; - dest += 1; - p += 1; +/// Decompresses a DNS name. +/// +/// Returns number of bytes consumed from `packet` starting at `i`, +/// along with the expanded `HostName`. +/// +/// Asserts `buffer` is has length at least `max_len`. +pub fn expand(noalias packet: []const u8, start_i: usize, noalias dest_buffer: []u8) ExpandError!struct { usize, HostName } { + const dest = dest_buffer[0..max_len]; + + var i = start_i; + var dest_i: usize = 0; + var len: ?usize = null; + + // Detect reference loop using an iteration counter. + for (0..packet.len / 2) |_| { + if (i >= packet.len) return error.InvalidDnsPacket; + + const c = packet[i]; + if ((c & 0xc0) != 0) { + if (i + 1 >= packet.len) return error.InvalidDnsPacket; + const j: usize = (@as(usize, c & 0x3F) << 8) | packet[i + 1]; + if (j >= packet.len) return error.InvalidDnsPacket; + if (len == null) len = (i + 2) - start_i; + i = j; + } else if (c != 0) { + if (dest_i != 0) { + dest[dest_i] = '.'; + dest_i += 1; } + const label_len: usize = c; + if (i + 1 + label_len > packet.len) return error.InvalidDnsPacket; + if (dest_i + label_len + 1 > dest.len) return error.InvalidDnsPacket; + @memcpy(dest[dest_i..][0..label_len], packet[i + 1 ..][0..label_len]); + dest_i += label_len; + i += 1 + label_len; } else { - dest[0] = 0; - if (len == std.math.maxInt(usize)) len = @intFromPtr(p) + 1 - @intFromPtr(comp_dn.ptr); - return len; + dest[dest_i] = 0; + dest_i += 1; + return .{ + len orelse i - start_i + 1, + try .init(dest[0..dest_i]), + }; } } return error.InvalidDnsPacket;