From 330d353d6e09ac1d48dedd1bfc127f81021b4b1f Mon Sep 17 00:00:00 2001 From: clickingbuttons Date: Wed, 15 May 2024 13:54:20 -0400 Subject: [PATCH 1/3] std.crypto: Add ASN1 module with OIDs and DER Add module for mapping ASN1 types to Zig types. See `asn1.Tag.fromZig` for the mapping. Add DER encoder and decoder. See `asn1/test.zig` for example usage of every ASN1 type. This implementation allows ASN1 tags to be overriden with `asn1_tag` and `asn1_tags`: ```zig const MyContainer = (enum | union | struct) { field: u32, pub const asn1_tag = asn1.Tag.init(...); // This specifies a tag's class, and if explicit, additional encoding // rules. pub const asn1_tags = .{ .field = asn1.FieldTag.explicit(0, .context_specific), }; }; ``` Despite having an enum tag type, ASN1 frequently uses OIDs as enum values. This is supported via an `pub const oids` field. ```zig const MyEnum = enum { a, pub const oids = asn1.Oid.StaticMap(MyEnum).initComptime(.{ .a = "1.2.3.4", }); }; ``` Futhermore, a container may choose to implement encoding and decoding however it deems fit. This allows for derived fields since Zig has a far more powerful type system than ASN1. ```zig // ASN1 has no standard way of tagging unions. const MyContainer = union(enum) { derived: PowerfulZigType, const WeakAsn1Type = ...; pub fn encodeDer(self: MyContainer, encoder: *der.Encoder) !void { try encoder.any(WeakAsn1Type{...}); } pub fn decodeDer(decoder: *der.Decoder) !MyContainer { const weak_asn1_type = try decoder.any(WeakAsn1Type); return .{ .derived = PowerfulZigType{...} }; } }; ``` An unfortunate side-effect is that decoding and encoding cannot have complete complete error sets unless we limit what errors users may return. Luckily, PKI ASN1 types are NOT recursive so the inferred error set should be sufficient. Finally, other encodings are possible, but this patch only implements a buffered DER encoder and decoder. In an effort to keep the changeset minimal this PR does not actually use the DER parser for stdlib PKI, but a tested example of how it may be used for Certificate is available [here.](https://github.com/clickingbuttons/asn1/blob/69c5709d/src/Certificate.zig) Closes #19775. --- lib/std/crypto.zig | 2 + lib/std/crypto/asn1.zig | 359 ++++++++++++++++++ lib/std/crypto/asn1/Oid.zig | 204 ++++++++++ lib/std/crypto/asn1/der.zig | 29 ++ lib/std/crypto/asn1/der/ArrayListReverse.zig | 97 +++++ lib/std/crypto/asn1/der/Decoder.zig | 165 ++++++++ lib/std/crypto/asn1/der/Encoder.zig | 162 ++++++++ .../crypto/asn1/der/testdata/all_types.der | Bin 0 -> 49 bytes .../crypto/asn1/der/testdata/id_ecc.pub.der | Bin 0 -> 91 bytes lib/std/crypto/asn1/test.zig | 79 ++++ 10 files changed, 1097 insertions(+) create mode 100644 lib/std/crypto/asn1.zig create mode 100644 lib/std/crypto/asn1/Oid.zig create mode 100644 lib/std/crypto/asn1/der.zig create mode 100644 lib/std/crypto/asn1/der/ArrayListReverse.zig create mode 100644 lib/std/crypto/asn1/der/Decoder.zig create mode 100644 lib/std/crypto/asn1/der/Encoder.zig create mode 100644 lib/std/crypto/asn1/der/testdata/all_types.der create mode 100644 lib/std/crypto/asn1/der/testdata/id_ecc.pub.der create mode 100644 lib/std/crypto/asn1/test.zig diff --git a/lib/std/crypto.zig b/lib/std/crypto.zig index 3dc48ce146..2812c8c4d1 100644 --- a/lib/std/crypto.zig +++ b/lib/std/crypto.zig @@ -194,6 +194,7 @@ pub const errors = @import("crypto/errors.zig"); pub const tls = @import("crypto/tls.zig"); pub const Certificate = @import("crypto/Certificate.zig"); +pub const asn1 = @import("crypto/asn1.zig"); /// Side-channels mitigations. pub const SideChannelsMitigations = enum { @@ -307,6 +308,7 @@ test { _ = errors; _ = tls; _ = Certificate; + _ = asn1; } test "CSPRNG" { diff --git a/lib/std/crypto/asn1.zig b/lib/std/crypto/asn1.zig new file mode 100644 index 0000000000..b9c0a9e109 --- /dev/null +++ b/lib/std/crypto/asn1.zig @@ -0,0 +1,359 @@ +//! ASN.1 types for public consumption. +const std = @import("std"); +pub const der = @import("./asn1/der.zig"); +pub const Oid = @import("./asn1/Oid.zig"); + +pub const Index = u32; + +pub const Tag = struct { + number: Number, + /// Whether this ASN.1 type contains other ASN.1 types. + constructed: bool, + class: Class, + + /// These values apply to class == .universal. + pub const Number = enum(u16) { + // 0 is reserved by spec + boolean = 1, + integer = 2, + bitstring = 3, + octetstring = 4, + null = 5, + oid = 6, + object_descriptor = 7, + real = 9, + enumerated = 10, + embedded = 11, + string_utf8 = 12, + oid_relative = 13, + time = 14, + // 15 is reserved to mean that the tag is >= 32 + sequence = 16, + /// Elements may appear in any order. + sequence_of = 17, + string_numeric = 18, + string_printable = 19, + string_teletex = 20, + string_videotex = 21, + string_ia5 = 22, + utc_time = 23, + generalized_time = 24, + string_graphic = 25, + string_visible = 26, + string_general = 27, + string_universal = 28, + string_char = 29, + string_bmp = 30, + date = 31, + time_of_day = 32, + date_time = 33, + duration = 34, + /// IRI = Internationalized Resource Identifier + oid_iri = 35, + oid_iri_relative = 36, + _, + }; + + pub const Class = enum(u2) { + universal, + application, + context_specific, + private, + }; + + pub fn init(number: Tag.Number, constructed: bool, class: Tag.Class) Tag { + return .{ .number = number, .constructed = constructed, .class = class }; + } + + pub fn universal(number: Tag.Number, constructed: bool) Tag { + return .{ .number = number, .constructed = constructed, .class = .universal }; + } + + pub fn decode(reader: anytype) !Tag { + const tag1: FirstTag = @bitCast(try reader.readByte()); + var number: u14 = tag1.number; + + if (tag1.number == 15) { + const tag2: NextTag = @bitCast(try reader.readByte()); + number = tag2.number; + if (tag2.continues) { + const tag3: NextTag = @bitCast(try reader.readByte()); + number = (number << 7) + tag3.number; + if (tag3.continues) return error.InvalidLength; + } + } + + return Tag{ + .number = @enumFromInt(number), + .constructed = tag1.constructed, + .class = tag1.class, + }; + } + + pub fn encode(self: Tag, writer: anytype) @TypeOf(writer).Error!void { + var tag1 = FirstTag{ + .number = undefined, + .constructed = self.constructed, + .class = self.class, + }; + + var buffer: [3]u8 = undefined; + var stream = std.io.fixedBufferStream(&buffer); + var writer2 = stream.writer(); + + switch (@intFromEnum(self.number)) { + 0...std.math.maxInt(u5) => |n| { + tag1.number = @intCast(n); + writer2.writeByte(@bitCast(tag1)) catch unreachable; + }, + std.math.maxInt(u5) + 1...std.math.maxInt(u7) => |n| { + tag1.number = 15; + const tag2 = NextTag{ .number = @intCast(n), .continues = false }; + writer2.writeByte(@bitCast(tag1)) catch unreachable; + writer2.writeByte(@bitCast(tag2)) catch unreachable; + }, + else => |n| { + tag1.number = 15; + const tag2 = NextTag{ .number = @intCast(n >> 7), .continues = true }; + const tag3 = NextTag{ .number = @truncate(n), .continues = false }; + writer2.writeByte(@bitCast(tag1)) catch unreachable; + writer2.writeByte(@bitCast(tag2)) catch unreachable; + writer2.writeByte(@bitCast(tag3)) catch unreachable; + }, + } + + _ = try writer.write(stream.getWritten()); + } + + const FirstTag = packed struct(u8) { number: u5, constructed: bool, class: Tag.Class }; + const NextTag = packed struct(u8) { number: u7, continues: bool }; + + pub fn toExpected(self: Tag) ExpectedTag { + return ExpectedTag{ + .number = self.number, + .constructed = self.constructed, + .class = self.class, + }; + } + + pub fn fromZig(comptime T: type) Tag { + switch (@typeInfo(T)) { + .Struct, .Enum, .Union => { + if (@hasDecl(T, "asn1_tag")) return T.asn1_tag; + }, + else => {}, + } + + switch (@typeInfo(T)) { + .Struct, .Union => return universal(.sequence, true), + .Bool => return universal(.boolean, false), + .Int => return universal(.integer, false), + .Enum => |e| { + if (@hasDecl(T, "oids")) return Oid.asn1_tag; + return universal(if (e.is_exhaustive) .enumerated else .integer, false); + }, + .Optional => |o| return fromZig(o.child), + .Null => return universal(.null, false), + else => @compileError("cannot map Zig type to asn1_tag " ++ @typeName(T)), + } + } +}; + +test Tag { + const buf = [_]u8{0xa3}; + var stream = std.io.fixedBufferStream(&buf); + const t = Tag.decode(stream.reader()); + try std.testing.expectEqual(Tag.init(@enumFromInt(3), true, .context_specific), t); +} + +/// A decoded view. +pub const Element = struct { + tag: Tag, + slice: Slice, + + pub const Slice = struct { + start: Index, + end: Index, + + pub fn len(self: Slice) Index { + return self.end - self.start; + } + + pub fn view(self: Slice, bytes: []const u8) []const u8 { + return bytes[self.start..self.end]; + } + }; + + pub const DecodeError = error{ InvalidLength, EndOfStream }; + + /// Safely decode a DER/BER/CER element at `index`: + /// - Ensures length uses shortest form + /// - Ensures length is within `bytes` + /// - Ensures length is less than `std.math.maxInt(Index)` + pub fn decode(bytes: []const u8, index: Index) DecodeError!Element { + var stream = std.io.fixedBufferStream(bytes[index..]); + var reader = stream.reader(); + + const tag = try Tag.decode(reader); + const size_or_len_size = try reader.readByte(); + + var start = index + 2; + var end = start + size_or_len_size; + // short form between 0-127 + if (size_or_len_size < 128) { + if (end > bytes.len) return error.InvalidLength; + } else { + // long form between 0 and std.math.maxInt(u1024) + const len_size: u7 = @truncate(size_or_len_size); + start += len_size; + if (len_size > @sizeOf(Index)) return error.InvalidLength; + + const len = try reader.readVarInt(Index, .big, len_size); + if (len < 128) return error.InvalidLength; // should have used short form + + end = std.math.add(Index, start, len) catch return error.InvalidLength; + if (end > bytes.len) return error.InvalidLength; + } + + return Element{ .tag = tag, .slice = Slice{ .start = start, .end = end } }; + } +}; + +test Element { + const short_form = [_]u8{ 0x30, 0x03, 0x02, 0x01, 0x09 }; + try std.testing.expectEqual(Element{ + .tag = Tag.universal(.sequence, true), + .slice = Element.Slice{ .start = 2, .end = short_form.len }, + }, Element.decode(&short_form, 0)); + + const long_form = [_]u8{ 0x30, 129, 129 } ++ [_]u8{0} ** 129; + try std.testing.expectEqual(Element{ + .tag = Tag.universal(.sequence, true), + .slice = Element.Slice{ .start = 3, .end = long_form.len }, + }, Element.decode(&long_form, 0)); +} + +/// For decoding. +pub const ExpectedTag = struct { + number: ?Tag.Number = null, + constructed: ?bool = null, + class: ?Tag.Class = null, + + pub fn init(number: ?Tag.Number, constructed: ?bool, class: ?Tag.Class) ExpectedTag { + return .{ .number = number, .constructed = constructed, .class = class }; + } + + pub fn primitive(number: ?Tag.Number) ExpectedTag { + return .{ .number = number, .constructed = false, .class = .universal }; + } + + pub fn match(self: ExpectedTag, tag: Tag) bool { + if (self.number) |e| { + if (tag.number != e) return false; + } + if (self.constructed) |e| { + if (tag.constructed != e) return false; + } + if (self.class) |e| { + if (tag.class != e) return false; + } + return true; + } +}; + +pub const FieldTag = struct { + number: std.meta.Tag(Tag.Number), + class: Tag.Class, + explicit: bool = true, + + pub fn explicit(number: std.meta.Tag(Tag.Number), class: Tag.Class) FieldTag { + return FieldTag{ .number = number, .class = class, .explicit = true }; + } + + pub fn implicit(number: std.meta.Tag(Tag.Number), class: Tag.Class) FieldTag { + return FieldTag{ .number = number, .class = class, .explicit = false }; + } + + pub fn fromContainer(comptime Container: type, comptime field_name: []const u8) ?FieldTag { + if (@hasDecl(Container, "asn1_tags") and @hasField(@TypeOf(Container.asn1_tags), field_name)) { + return @field(Container.asn1_tags, field_name); + } + + return null; + } + + pub fn toTag(self: FieldTag) Tag { + return Tag.init(@enumFromInt(self.number), self.explicit, self.class); + } +}; + +pub const BitString = struct { + /// Number of bits in rightmost byte that are unused. + right_padding: u3 = 0, + bytes: []const u8, + + pub fn bitLen(self: BitString) usize { + return self.bytes.len * 8 - self.right_padding; + } + + const asn1_tag = Tag.universal(.bitstring, false); + + pub fn decodeDer(decoder: *der.Decoder) !BitString { + const ele = try decoder.element(asn1_tag.toExpected()); + const bytes = decoder.view(ele); + + if (bytes.len < 1) return error.InvalidBitString; + const padding = bytes[0]; + if (padding >= 8) return error.InvalidBitString; + const right_padding: u3 = @intCast(padding); + + // DER requires that unused bits be zero. + if (@ctz(bytes[bytes.len - 1]) < right_padding) return error.InvalidBitString; + + return BitString{ .bytes = bytes[1..], .right_padding = right_padding }; + } + + pub fn encodeDer(self: BitString, encoder: *der.Encoder) !void { + try encoder.writer().writeAll(self.bytes); + try encoder.writer().writeByte(self.right_padding); + try encoder.length(self.bytes.len + 1); + try encoder.tag(asn1_tag); + } +}; + +pub fn Opaque(comptime tag: Tag) type { + return struct { + bytes: []const u8, + + pub fn decodeDer(decoder: *der.Decoder) !@This() { + const ele = try decoder.element(tag.toExpected()); + if (tag.constructed) decoder.index = ele.slice.end; + return .{ .bytes = decoder.view(ele) }; + } + + pub fn encodeDer(self: @This(), encoder: *der.Encoder) !void { + try encoder.tagBytes(tag, self.bytes); + } + }; +} + +/// Use sparingly. +pub const Any = struct { + tag: Tag, + bytes: []const u8, + + pub fn decodeDer(decoder: *der.Decoder) !@This() { + const ele = try decoder.element(ExpectedTag{}); + return .{ .tag = ele.tag, .bytes = decoder.view(ele) }; + } + + pub fn encodeDer(self: @This(), encoder: *der.Encoder) !void { + try encoder.tagBytes(self.tag, self.bytes); + } +}; + +test { + _ = der; + _ = Oid; + _ = @import("asn1/test.zig"); +} diff --git a/lib/std/crypto/asn1/Oid.zig b/lib/std/crypto/asn1/Oid.zig new file mode 100644 index 0000000000..18d10cfda1 --- /dev/null +++ b/lib/std/crypto/asn1/Oid.zig @@ -0,0 +1,204 @@ +//! Globally unique hierarchical identifier made of a sequence of integers. +//! +//! Commonly used to identify standards, algorithms, certificate extensions, +//! organizations, or policy documents. +encoded: []const u8, + +pub const InitError = std.fmt.ParseIntError || error{MissingPrefix} || std.io.FixedBufferStream(u8).WriteError; + +pub fn fromDot(dot_notation: []const u8, out: []u8) InitError!Oid { + var split = std.mem.splitScalar(u8, dot_notation, '.'); + const first_str = split.next() orelse return error.MissingPrefix; + const second_str = split.next() orelse return error.MissingPrefix; + + const first = try std.fmt.parseInt(u8, first_str, 10); + const second = try std.fmt.parseInt(u8, second_str, 10); + + var stream = std.io.fixedBufferStream(out); + var writer = stream.writer(); + + try writer.writeByte(first * 40 + second); + + var i: usize = 1; + while (split.next()) |s| { + var parsed = try std.fmt.parseUnsigned(Arc, s, 10); + const n_bytes = if (parsed == 0) 0 else std.math.log(Arc, encoding_base, parsed); + + for (0..n_bytes) |j| { + const place = std.math.pow(Arc, encoding_base, n_bytes - @as(Arc, @intCast(j))); + const digit: u8 = @intCast(@divFloor(parsed, place)); + + try writer.writeByte(digit | 0x80); + parsed -= digit * place; + + i += 1; + } + try writer.writeByte(@intCast(parsed)); + i += 1; + } + + return .{ .encoded = stream.getWritten() }; +} + +test fromDot { + var buf: [256]u8 = undefined; + for (test_cases) |t| { + const actual = try fromDot(t.dot_notation, &buf); + try std.testing.expectEqualSlices(u8, t.encoded, actual.encoded); + } +} + +pub fn toDot(self: Oid, writer: anytype) @TypeOf(writer).Error!void { + const encoded = self.encoded; + const first = @divTrunc(encoded[0], 40); + const second = encoded[0] - first * 40; + try writer.print("{d}.{d}", .{ first, second }); + + var i: usize = 1; + while (i != encoded.len) { + const n_bytes: usize = brk: { + var res: usize = 1; + var j: usize = i; + while (encoded[j] & 0x80 != 0) { + res += 1; + j += 1; + } + break :brk res; + }; + + var n: usize = 0; + for (0..n_bytes) |j| { + const place = std.math.pow(usize, encoding_base, n_bytes - j - 1); + n += place * (encoded[i] & 0b01111111); + i += 1; + } + try writer.print(".{d}", .{n}); + } +} + +test toDot { + var buf: [256]u8 = undefined; + + for (test_cases) |t| { + var stream = std.io.fixedBufferStream(&buf); + try toDot(Oid{ .encoded = t.encoded }, stream.writer()); + try std.testing.expectEqualStrings(t.dot_notation, stream.getWritten()); + } +} + +const TestCase = struct { + encoded: []const u8, + dot_notation: []const u8, + + pub fn init(comptime hex: []const u8, dot_notation: []const u8) TestCase { + return .{ .encoded = &hexToBytes(hex), .dot_notation = dot_notation }; + } +}; + +const test_cases = [_]TestCase{ + // https://learn.microsoft.com/en-us/windows/win32/seccertenroll/about-object-identifier + TestCase.init("2b0601040182371514", "1.3.6.1.4.1.311.21.20"), + // https://luca.ntop.org/Teaching/Appunti/asn1.html + TestCase.init("2a864886f70d", "1.2.840.113549"), + // https://www.sysadmins.lv/blog-en/how-to-encode-object-identifier-to-an-asn1-der-encoded-string.aspx + TestCase.init("2a868d20", "1.2.100000"), + TestCase.init("2a864886f70d01010b", "1.2.840.113549.1.1.11"), + TestCase.init("2b6570", "1.3.101.112"), +}; + +pub const asn1_tag = asn1.Tag.init(.oid, false, .universal); + +pub fn decodeDer(decoder: *der.Decoder) !Oid { + const ele = try decoder.element(asn1_tag.toExpected()); + return Oid{ .encoded = decoder.view(ele) }; +} + +pub fn encodeDer(self: Oid, encoder: *der.Encoder) !void { + try encoder.tagBytes(asn1_tag, self.encoded); +} + +fn encodedLen(dot_notation: []const u8) usize { + var buf: [256]u8 = undefined; + const oid = fromDot(dot_notation, &buf) catch unreachable; + return oid.encoded.len; +} + +pub fn encodeComptime(comptime dot_notation: []const u8) [encodedLen(dot_notation)]u8 { + @setEvalBranchQuota(4000); + comptime var buf: [256]u8 = undefined; + const oid = comptime fromDot(dot_notation, &buf) catch unreachable; + return oid.encoded[0..oid.encoded.len].*; +} + +test encodeComptime { + try std.testing.expectEqual( + hexToBytes("2b0601040182371514"), + comptime encodeComptime("1.3.6.1.4.1.311.21.20"), + ); +} + +/// Maps of: +/// - Oid -> enum +/// - Enum -> oid +pub fn StaticMap(comptime Enum: type) type { + const enum_info = @typeInfo(Enum).Enum; + const EnumToOid = std.EnumArray(Enum, []const u8); + const ReturnType = struct { + oid_to_enum: std.StaticStringMap(Enum), + enum_to_oid: EnumToOid, + + pub fn oidToEnum(self: @This(), encoded: []const u8) ?Enum { + return self.oid_to_enum.get(encoded); + } + + pub fn enumToOid(self: @This(), value: Enum) Oid { + const bytes = self.enum_to_oid.get(value); + return .{ .encoded = bytes }; + } + }; + + return struct { + pub fn initComptime(comptime key_pairs: anytype) ReturnType { + const struct_info = @typeInfo(@TypeOf(key_pairs)).Struct; + const error_msg = "Each field of '" ++ @typeName(Enum) ++ "' must map to exactly one OID"; + if (!enum_info.is_exhaustive or enum_info.fields.len != struct_info.fields.len) { + @compileError(error_msg); + } + + comptime var enum_to_oid = EnumToOid.initUndefined(); + + const KeyPair = struct { []const u8, Enum }; + comptime var static_key_pairs: [enum_info.fields.len]KeyPair = undefined; + + comptime for (enum_info.fields, 0..) |f, i| { + if (!@hasField(@TypeOf(key_pairs), f.name)) { + @compileError("Field '" ++ f.name ++ "' missing Oid.StaticMap entry"); + } + const encoded = &encodeComptime(@field(key_pairs, f.name)); + const tag: Enum = @enumFromInt(f.value); + static_key_pairs[i] = .{ encoded, tag }; + enum_to_oid.set(tag, encoded); + }; + + const oid_to_enum = std.StaticStringMap(Enum).initComptime(static_key_pairs); + if (oid_to_enum.values().len != enum_info.fields.len) @compileError(error_msg); + + return ReturnType{ .oid_to_enum = oid_to_enum, .enum_to_oid = enum_to_oid }; + } + }; +} + +/// Strictly for testing. +fn hexToBytes(comptime hex: []const u8) [hex.len / 2]u8 { + var res: [hex.len / 2]u8 = undefined; + _ = std.fmt.hexToBytes(&res, hex) catch unreachable; + return res; +} + +const std = @import("std"); +const Oid = @This(); +const Arc = u32; +const encoding_base = 128; +const Allocator = std.mem.Allocator; +const der = @import("der.zig"); +const asn1 = @import("../asn1.zig"); diff --git a/lib/std/crypto/asn1/der.zig b/lib/std/crypto/asn1/der.zig new file mode 100644 index 0000000000..28e5f77988 --- /dev/null +++ b/lib/std/crypto/asn1/der.zig @@ -0,0 +1,29 @@ +//! Distinguised Encoding Rules as defined in X.690 and X.691. +//! +//! Subset of Basic Encoding Rules (BER) which eliminates flexibility in +//! an effort to acheive normality. Used in PKI. +const std = @import("std"); +const asn1 = @import("../asn1.zig"); + +pub const Decoder = @import("der/Decoder.zig"); +pub const Encoder = @import("der/Encoder.zig"); + +pub fn decode(comptime T: type, encoded: []const u8) !T { + var decoder = Decoder{ .bytes = encoded }; + const res = try decoder.any(T); + std.debug.assert(decoder.index == encoded.len); + return res; +} + +/// Caller owns returned memory. +pub fn encode(allocator: std.mem.Allocator, value: anytype) ![]u8 { + var encoder = Encoder.init(allocator); + defer encoder.deinit(); + try encoder.any(value); + return try encoder.buffer.toOwnedSlice(); +} + +test { + _ = Decoder; + _ = Encoder; +} diff --git a/lib/std/crypto/asn1/der/ArrayListReverse.zig b/lib/std/crypto/asn1/der/ArrayListReverse.zig new file mode 100644 index 0000000000..f580c54546 --- /dev/null +++ b/lib/std/crypto/asn1/der/ArrayListReverse.zig @@ -0,0 +1,97 @@ +//! An ArrayList that grows backwards. Counts nested prefix length fields +//! in O(n) instead of O(n^depth) at the cost of extra buffering. +//! +//! Laid out in memory like: +//! capacity |--------------------------| +//! data |-------------| +data: []u8, +capacity: usize, +allocator: Allocator, + +const ArrayListReverse = @This(); +const Error = Allocator.Error; + +pub fn init(allocator: Allocator) ArrayListReverse { + return .{ .data = &.{}, .capacity = 0, .allocator = allocator }; +} + +pub fn deinit(self: *ArrayListReverse) void { + self.allocator.free(self.allocatedSlice()); +} + +pub fn ensureCapacity(self: *ArrayListReverse, new_capacity: usize) Error!void { + if (self.capacity >= new_capacity) return; + + const old_memory = self.allocatedSlice(); + // Just make a new allocation to not worry about aliasing. + const new_memory = try self.allocator.alloc(u8, new_capacity); + @memcpy(new_memory[new_capacity - self.data.len ..], self.data); + self.allocator.free(old_memory); + self.data.ptr = new_memory.ptr + new_capacity - self.data.len; + self.capacity = new_memory.len; +} + +pub fn prependSlice(self: *ArrayListReverse, data: []const u8) Error!void { + try self.ensureCapacity(self.data.len + data.len); + const old_len = self.data.len; + const new_len = old_len + data.len; + assert(new_len <= self.capacity); + self.data.len = new_len; + + const end = self.data.ptr; + const begin = end - data.len; + const slice = begin[0..data.len]; + @memcpy(slice, data); + self.data.ptr = begin; +} + +pub const Writer = std.io.Writer(*ArrayListReverse, Error, prependSliceSize); +/// Warning: This writer writes backwards. `fn print` will NOT work as expected. +pub fn writer(self: *ArrayListReverse) Writer { + return .{ .context = self }; +} + +fn prependSliceSize(self: *ArrayListReverse, data: []const u8) Error!usize { + try self.prependSlice(data); + return data.len; +} + +fn allocatedSlice(self: *ArrayListReverse) []u8 { + return (self.data.ptr + self.data.len - self.capacity)[0..self.capacity]; +} + +/// Invalidates all element pointers. +pub fn clearAndFree(self: *ArrayListReverse) void { + self.allocator.free(self.allocatedSlice()); + self.data.len = 0; + self.capacity = 0; +} + +/// The caller owns the returned memory. +/// Capacity is cleared, making deinit() safe but unnecessary to call. +pub fn toOwnedSlice(self: *ArrayListReverse) Error![]u8 { + const new_memory = try self.allocator.alloc(u8, self.data.len); + @memcpy(new_memory, self.data); + @memset(self.data, undefined); + self.clearAndFree(); + return new_memory; +} + +const std = @import("std"); +const Allocator = std.mem.Allocator; +const assert = std.debug.assert; +const testing = std.testing; + +test ArrayListReverse { + var b = ArrayListReverse.init(testing.allocator); + defer b.deinit(); + const data: []const u8 = &.{ 4, 5, 6 }; + try b.prependSlice(data); + try testing.expectEqual(data.len, b.data.len); + try testing.expectEqualSlices(u8, data, b.data); + + const data2: []const u8 = &.{ 1, 2, 3 }; + try b.prependSlice(data2); + try testing.expectEqual(data.len + data2.len, b.data.len); + try testing.expectEqualSlices(u8, data2 ++ data, b.data); +} diff --git a/lib/std/crypto/asn1/der/Decoder.zig b/lib/std/crypto/asn1/der/Decoder.zig new file mode 100644 index 0000000000..dc8894c94f --- /dev/null +++ b/lib/std/crypto/asn1/der/Decoder.zig @@ -0,0 +1,165 @@ +//! A secure DER parser that: +//! - Prefers calling `fn decodeDer(self: @This(), decoder: *der.Decoder)` +//! - Does NOT allocate. If you wish to parse lists you can do so lazily +//! with an opaque type. +//! - Does NOT read memory outside `bytes`. +//! - Does NOT return elements with slices outside `bytes`. +//! - Errors on values that do NOT follow DER rules: +//! - Lengths that could be represented in a shorter form. +//! - Booleans that are not 0xff or 0x00. +bytes: []const u8, +index: Index = 0, +/// The field tag of the most recently visited field. +/// This is needed because we might visit an implicitly tagged container with a `fn decodeDer`. +field_tag: ?FieldTag = null, + +pub fn any(self: *Decoder, comptime T: type) !T { + if (std.meta.hasFn(T, "decodeDer")) return try T.decodeDer(self); + + const tag = Tag.fromZig(T).toExpected(); + switch (@typeInfo(T)) { + .Struct => { + const ele = try self.element(tag); + defer self.index = ele.slice.end; // don't force parsing all fields + + var res: T = undefined; + + inline for (std.meta.fields(T)) |f| { + self.field_tag = FieldTag.fromContainer(T, f.name); + + if (self.field_tag) |ft| { + if (ft.explicit) { + const seq = try self.element(ft.toTag().toExpected()); + self.index = seq.slice.start; + self.field_tag = null; + } + } + + @field(res, f.name) = self.any(f.type) catch |err| brk: { + if (f.default_value) |d| { + break :brk @as(*const f.type, @alignCast(@ptrCast(d))).*; + } + return err; + }; + // DER encodes null values by skipping them. + if (@typeInfo(f.type) == .Optional and @field(res, f.name) == null) { + if (f.default_value) |d| { + @field(res, f.name) = @as(*const f.type, @alignCast(@ptrCast(d))).*; + } + } + } + + return res; + }, + .Bool => { + const ele = try self.element(tag); + const bytes = self.view(ele); + if (bytes.len != 1) return error.InvalidBool; + + return switch (bytes[0]) { + 0x00 => false, + 0xff => true, + else => error.InvalidBool, + }; + }, + .Int => { + const ele = try self.element(tag); + const bytes = self.view(ele); + return try int(T, bytes); + }, + .Enum => |e| { + const ele = try self.element(tag); + const bytes = self.view(ele); + if (@hasDecl(T, "oids")) { + return T.oids.oidToEnum(bytes) orelse return error.UnknownOid; + } + return @enumFromInt(try int(e.tag_type, bytes)); + }, + .Optional => |o| return self.any(o.child) catch return null, + else => @compileError("cannot decode type " ++ @typeName(T)), + } +} + +pub fn sequence(self: *Decoder) !Element { + return try self.element(ExpectedTag.init(.sequence, true, .universal)); +} + +pub fn element(self: *Decoder, expected: ExpectedTag) (error{ EndOfStream, UnexpectedElement } || Element.DecodeError)!Element { + if (self.index >= self.bytes.len) return error.EndOfStream; + + const res = try Element.decode(self.bytes, self.index); + var e = expected; + if (self.field_tag) |ft| { + e.number = @enumFromInt(ft.number); + e.class = ft.class; + } + if (!e.match(res.tag)) { + return error.UnexpectedElement; + } + + self.index = if (res.tag.constructed) res.slice.start else res.slice.end; + return res; +} + +pub fn view(self: Decoder, elem: Element) []const u8 { + return elem.slice.view(self.bytes); +} + +fn int(comptime T: type, value: []const u8) error{ NonCanonical, LargeValue }!T { + if (@typeInfo(T).Int.bits % 8 != 0) @compileError("T must be byte aligned"); + + var bytes = value; + if (bytes.len >= 2) { + if (bytes[0] == 0) { + if (@clz(bytes[1]) > 0) return error.NonCanonical; + bytes.ptr += 1; + } + if (bytes[0] == 0xff and @clz(bytes[1]) == 0) return error.NonCanonical; + } + + if (bytes.len > @sizeOf(T)) return error.LargeValue; + if (@sizeOf(T) == 1) return @bitCast(bytes[0]); + + return std.mem.readVarInt(T, bytes, .big); +} + +test int { + try expectEqual(@as(u8, 1), try int(u8, &[_]u8{1})); + try expectError(error.NonCanonical, int(u8, &[_]u8{ 0, 1 })); + try expectError(error.NonCanonical, int(u8, &[_]u8{ 0xff, 0xff })); + + const big = [_]u8{ 0xef, 0xff }; + try expectError(error.LargeValue, int(u8, &big)); + try expectEqual(0xefff, int(u16, &big)); +} + +test Decoder { + var parser = Decoder{ .bytes = @embedFile("./testdata/id_ecc.pub.der") }; + const seq = try parser.sequence(); + + { + const seq2 = try parser.sequence(); + _ = try parser.element(ExpectedTag.init(.oid, false, .universal)); + _ = try parser.element(ExpectedTag.init(.oid, false, .universal)); + + try std.testing.expectEqual(parser.index, seq2.slice.end); + } + _ = try parser.element(ExpectedTag.init(.bitstring, false, .universal)); + + try std.testing.expectEqual(parser.index, seq.slice.end); + try std.testing.expectEqual(parser.index, parser.bytes.len); +} + +const std = @import("std"); +const builtin = @import("builtin"); +const asn1 = @import("../../asn1.zig"); +const Oid = @import("../Oid.zig"); + +const expectEqual = std.testing.expectEqual; +const expectError = std.testing.expectError; +const Decoder = @This(); +const Index = asn1.Index; +const Tag = asn1.Tag; +const FieldTag = asn1.FieldTag; +const ExpectedTag = asn1.ExpectedTag; +const Element = asn1.Element; diff --git a/lib/std/crypto/asn1/der/Encoder.zig b/lib/std/crypto/asn1/der/Encoder.zig new file mode 100644 index 0000000000..709fbe75c6 --- /dev/null +++ b/lib/std/crypto/asn1/der/Encoder.zig @@ -0,0 +1,162 @@ +//! A buffered DER encoder. +//! +//! Prefers calling container's `fn encodeDer(self: @This(), encoder: *der.Encoder)`. +//! That function should encode values, lengths, then tags. +buffer: ArrayListReverse, +/// The field tag set by a parent container. +/// This is needed because we might visit an implicitly tagged container with a `fn encodeDer`. +field_tag: ?FieldTag = null, + +pub fn init(allocator: std.mem.Allocator) Encoder { + return Encoder{ .buffer = ArrayListReverse.init(allocator) }; +} + +pub fn deinit(self: *Encoder) void { + self.buffer.deinit(); +} + +pub fn any(self: *Encoder, val: anytype) !void { + const T = @TypeOf(val); + try self.anyTag(Tag.fromZig(T), val); +} + +fn anyTag(self: *Encoder, tag_: Tag, val: anytype) !void { + const T = @TypeOf(val); + if (std.meta.hasFn(T, "encodeDer")) return try val.encodeDer(self); + const start = self.buffer.data.len; + const merged_tag = self.mergedTag(tag_); + + switch (@typeInfo(T)) { + .Struct => |info| { + inline for (0..info.fields.len) |i| { + const f = info.fields[info.fields.len - i - 1]; + const field_val = @field(val, f.name); + const field_tag = FieldTag.fromContainer(T, f.name); + + // > The encoding of a set value or sequence value shall not include an encoding for any + // > component value which is equal to its default value. + const is_default = if (f.is_comptime) false else if (f.default_value) |v| brk: { + const default_val: *const f.type = @alignCast(@ptrCast(v)); + break :brk std.mem.eql(u8, std.mem.asBytes(default_val), std.mem.asBytes(&field_val)); + } else false; + + if (!is_default) { + const start2 = self.buffer.data.len; + self.field_tag = field_tag; + // will merge with self.field_tag. + // may mutate self.field_tag. + try self.anyTag(Tag.fromZig(f.type), field_val); + if (field_tag) |ft| { + if (ft.explicit) { + try self.length(self.buffer.data.len - start2); + try self.tag(ft.toTag()); + self.field_tag = null; + } + } + } + } + }, + .Bool => try self.buffer.prependSlice(&[_]u8{if (val) 0xff else 0}), + .Int => try self.int(T, val), + .Enum => |e| { + if (@hasDecl(T, "oids")) { + return self.any(T.oids.enumToOid(val)); + } else { + try self.int(e.tag_type, @intFromEnum(val)); + } + }, + .Optional => if (val) |v| return try self.anyTag(tag_, v), + .Null => {}, + else => @compileError("cannot encode type " ++ @typeName(T)), + } + + try self.length(self.buffer.data.len - start); + try self.tag(merged_tag); +} + +pub fn tag(self: *Encoder, tag_: Tag) !void { + const t = self.mergedTag(tag_); + try t.encode(self.writer()); +} + +pub fn tagBytes(self: *Encoder, tag_: Tag, bytes: []const u8) !void { + try self.buffer.prependSlice(bytes); + try self.length(bytes.len); + try self.tag(tag_); +} + +fn mergedTag(self: *Encoder, tag_: Tag) Tag { + var res = tag_; + if (self.field_tag) |ft| { + if (!ft.explicit) { + res.number = @enumFromInt(ft.number); + res.class = ft.class; + } + } + return res; +} + +pub fn length(self: *Encoder, len: usize) !void { + const writer_ = self.writer(); + if (len < 128) { + try writer_.writeInt(u8, @intCast(len), .big); + return; + } + inline for ([_]type{ u8, u16, u32 }) |T| { + if (len < std.math.maxInt(T)) { + try writer_.writeInt(T, @intCast(len), .big); + try writer_.writeInt(u8, @sizeOf(T) | 0x80, .big); + return; + } + } + return error.InvalidLength; +} + +/// Warning: This writer writes backwards. `fn print` will NOT work as expected. +pub fn writer(self: *Encoder) ArrayListReverse.Writer { + return self.buffer.writer(); +} + +fn int(self: *Encoder, comptime T: type, value: T) !void { + const big = std.mem.nativeTo(T, value, .big); + const big_bytes = std.mem.asBytes(&big); + + const bits_needed = @bitSizeOf(T) - @clz(value); + const needs_padding: u1 = if (value == 0) + 1 + else if (bits_needed > 8) brk: { + const RightShift = std.meta.Int(.unsigned, @bitSizeOf(@TypeOf(bits_needed)) - 1); + const right_shift: RightShift = @intCast(bits_needed - 9); + break :brk if (value >> right_shift == 0x1ff) 1 else 0; + } else 0; + const bytes_needed = try std.math.divCeil(usize, bits_needed, 8) + needs_padding; + + const writer_ = self.writer(); + for (0..bytes_needed - needs_padding) |i| try writer_.writeByte(big_bytes[big_bytes.len - i - 1]); + if (needs_padding == 1) try writer_.writeByte(0); +} + +test int { + const allocator = std.testing.allocator; + var encoder = Encoder.init(allocator); + defer encoder.deinit(); + + try encoder.int(u8, 0); + try std.testing.expectEqualSlices(u8, &[_]u8{0}, encoder.buffer.data); + + encoder.buffer.clearAndFree(); + try encoder.int(u16, 0x00ff); + try std.testing.expectEqualSlices(u8, &[_]u8{0xff}, encoder.buffer.data); + + encoder.buffer.clearAndFree(); + try encoder.int(u32, 0xffff); + try std.testing.expectEqualSlices(u8, &[_]u8{ 0, 0xff, 0xff }, encoder.buffer.data); +} + +const std = @import("std"); +const Oid = @import("../Oid.zig"); +const asn1 = @import("../../asn1.zig"); +const ArrayListReverse = @import("./ArrayListReverse.zig"); +const Tag = asn1.Tag; +const FieldTag = asn1.FieldTag; +const Encoder = @This(); diff --git a/lib/std/crypto/asn1/der/testdata/all_types.der b/lib/std/crypto/asn1/der/testdata/all_types.der new file mode 100644 index 0000000000000000000000000000000000000000..a4f784938b54d2f3aef4d4049a02c35342f106ea GIT binary patch literal 49 zcmXreU%<@7$h44^nVErQK@+nUGYbz(VsT0u3rku`aUv5FBjaK=CPo$}MpiMfH~`uM B2-pAs literal 0 HcmV?d00001 diff --git a/lib/std/crypto/asn1/der/testdata/id_ecc.pub.der b/lib/std/crypto/asn1/der/testdata/id_ecc.pub.der new file mode 100644 index 0000000000000000000000000000000000000000..3964a8f0fcaa5710ea9b8f198b1dfb7ee9ced25e GIT binary patch literal 91 zcmXqrG!SNE*J|@PXUoLM#sOw9GqN)~F|aHUHClXN0z>eI*+6v0WsWU! u`F>`87Cq{M`@R?*ohi~YMa0(sOODQVJG(5$>+e3 Date: Thu, 16 May 2024 13:11:58 -0400 Subject: [PATCH 2/3] std.crypto.asn1: add short comments and der tests --- lib/std/crypto/asn1/Oid.zig | 8 +++++++- lib/std/crypto/asn1/der.zig | 26 ++++++++++++++++++++++++++ lib/std/crypto/asn1/der/Decoder.zig | 9 ++++++++- lib/std/crypto/asn1/der/Encoder.zig | 16 ++++++++++------ 4 files changed, 51 insertions(+), 8 deletions(-) diff --git a/lib/std/crypto/asn1/Oid.zig b/lib/std/crypto/asn1/Oid.zig index 18d10cfda1..897b4050bf 100644 --- a/lib/std/crypto/asn1/Oid.zig +++ b/lib/std/crypto/asn1/Oid.zig @@ -123,7 +123,8 @@ fn encodedLen(dot_notation: []const u8) usize { return oid.encoded.len; } -pub fn encodeComptime(comptime dot_notation: []const u8) [encodedLen(dot_notation)]u8 { +/// Returns encoded bytes of OID. +fn encodeComptime(comptime dot_notation: []const u8) [encodedLen(dot_notation)]u8 { @setEvalBranchQuota(4000); comptime var buf: [256]u8 = undefined; const oid = comptime fromDot(dot_notation, &buf) catch unreachable; @@ -137,6 +138,11 @@ test encodeComptime { ); } +pub fn fromDotComptime(comptime dot_notation: []const u8) Oid { + const tmp = comptime encodeComptime(dot_notation); + return Oid{ .encoded = &tmp }; +} + /// Maps of: /// - Oid -> enum /// - Enum -> oid diff --git a/lib/std/crypto/asn1/der.zig b/lib/std/crypto/asn1/der.zig index 28e5f77988..4395f9f3b6 100644 --- a/lib/std/crypto/asn1/der.zig +++ b/lib/std/crypto/asn1/der.zig @@ -23,6 +23,32 @@ pub fn encode(allocator: std.mem.Allocator, value: anytype) ![]u8 { return try encoder.buffer.toOwnedSlice(); } +test encode { + // https://lapo.it/asn1js/#MAgGAyoDBAIBBA + const Value = struct { a: asn1.Oid, b: i32 }; + const test_case = .{ + .value = Value{ .a = asn1.Oid.fromDotComptime("1.2.3.4"), .b = 4 }, + .encoded = &[_]u8{ 0x30, 0x08, 0x06, 0x03, 0x2A, 0x03, 0x04, 0x02, 0x01, 0x04 }, + }; + const allocator = std.testing.allocator; + const actual = try encode(allocator, test_case.value); + defer allocator.free(actual); + + try std.testing.expectEqualSlices(u8, test_case.encoded, actual); +} + +test decode { + // https://lapo.it/asn1js/#MAgGAyoDBAIBBA + const Value = struct { a: asn1.Oid, b: i32 }; + const test_case = .{ + .value = Value{ .a = asn1.Oid.fromDotComptime("1.2.3.4"), .b = 4 }, + .encoded = &[_]u8{ 0x30, 0x08, 0x06, 0x03, 0x2A, 0x03, 0x04, 0x02, 0x01, 0x04 }, + }; + const decoded = try decode(Value, test_case.encoded); + + try std.testing.expectEqualDeep(test_case.value, decoded); +} + test { _ = Decoder; _ = Encoder; diff --git a/lib/std/crypto/asn1/der/Decoder.zig b/lib/std/crypto/asn1/der/Decoder.zig index dc8894c94f..2eedbee957 100644 --- a/lib/std/crypto/asn1/der/Decoder.zig +++ b/lib/std/crypto/asn1/der/Decoder.zig @@ -13,6 +13,7 @@ index: Index = 0, /// This is needed because we might visit an implicitly tagged container with a `fn decodeDer`. field_tag: ?FieldTag = null, +/// Expect a value. pub fn any(self: *Decoder, comptime T: type) !T { if (std.meta.hasFn(T, "decodeDer")) return try T.decodeDer(self); @@ -80,11 +81,16 @@ pub fn any(self: *Decoder, comptime T: type) !T { } } +//// Expect a sequence. pub fn sequence(self: *Decoder) !Element { return try self.element(ExpectedTag.init(.sequence, true, .universal)); } -pub fn element(self: *Decoder, expected: ExpectedTag) (error{ EndOfStream, UnexpectedElement } || Element.DecodeError)!Element { +//// Expect an element. +pub fn element( + self: *Decoder, + expected: ExpectedTag, +) (error{ EndOfStream, UnexpectedElement } || Element.DecodeError)!Element { if (self.index >= self.bytes.len) return error.EndOfStream; const res = try Element.decode(self.bytes, self.index); @@ -101,6 +107,7 @@ pub fn element(self: *Decoder, expected: ExpectedTag) (error{ EndOfStream, Unexp return res; } +/// View of element bytes. pub fn view(self: Decoder, elem: Element) []const u8 { return elem.slice.view(self.bytes); } diff --git a/lib/std/crypto/asn1/der/Encoder.zig b/lib/std/crypto/asn1/der/Encoder.zig index 709fbe75c6..939a1a5aa7 100644 --- a/lib/std/crypto/asn1/der/Encoder.zig +++ b/lib/std/crypto/asn1/der/Encoder.zig @@ -15,6 +15,7 @@ pub fn deinit(self: *Encoder) void { self.buffer.deinit(); } +/// Encode any value. pub fn any(self: *Encoder, val: anytype) !void { const T = @TypeOf(val); try self.anyTag(Tag.fromZig(T), val); @@ -74,17 +75,12 @@ fn anyTag(self: *Encoder, tag_: Tag, val: anytype) !void { try self.tag(merged_tag); } +/// Encode a tag. pub fn tag(self: *Encoder, tag_: Tag) !void { const t = self.mergedTag(tag_); try t.encode(self.writer()); } -pub fn tagBytes(self: *Encoder, tag_: Tag, bytes: []const u8) !void { - try self.buffer.prependSlice(bytes); - try self.length(bytes.len); - try self.tag(tag_); -} - fn mergedTag(self: *Encoder, tag_: Tag) Tag { var res = tag_; if (self.field_tag) |ft| { @@ -96,6 +92,7 @@ fn mergedTag(self: *Encoder, tag_: Tag) Tag { return res; } +/// Encode a length. pub fn length(self: *Encoder, len: usize) !void { const writer_ = self.writer(); if (len < 128) { @@ -112,6 +109,13 @@ pub fn length(self: *Encoder, len: usize) !void { return error.InvalidLength; } +/// Encode a tag and length-prefixed bytes. +pub fn tagBytes(self: *Encoder, tag_: Tag, bytes: []const u8) !void { + try self.buffer.prependSlice(bytes); + try self.length(bytes.len); + try self.tag(tag_); +} + /// Warning: This writer writes backwards. `fn print` will NOT work as expected. pub fn writer(self: *Encoder) ArrayListReverse.Writer { return self.buffer.writer(); From a51bc1d1d11717e9aab72121918ddb09fb3333f7 Mon Sep 17 00:00:00 2001 From: clickingbuttons Date: Thu, 16 May 2024 13:13:42 -0400 Subject: [PATCH 3/3] std.crypto.asn1: add lapo.it url for all_types.der --- lib/std/crypto/asn1/test.zig | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/std/crypto/asn1/test.zig b/lib/std/crypto/asn1/test.zig index b616491760..261f5a4310 100644 --- a/lib/std/crypto/asn1/test.zig +++ b/lib/std/crypto/asn1/test.zig @@ -61,6 +61,7 @@ test AllTypes { .g = .{ .inner = .{ .a = 4, .b = 5 }, .sum = 9 }, .h = .{ .tag = Tag.init(.string_ia5, false, .universal), .bytes = "asdf" }, }; + // https://lapo.it/asn1js/#MC-gAwIBAqEFAwMABKCCAyoDBAwEYXNkZgQEZmRzYQICAQGjBgIBBAIBBRYEYXNkZg const path = "./der/testdata/all_types.der"; const encoded = @embedFile(path); const actual = try asn1.der.decode(AllTypes, encoded);